| 35 This is the task of <code>generic_pointsource_fb</code>, where we use <code>FBSpecialisation</code> |
34 This is the task of <code>generic_pointsource_fb</code>, where we use <code>FBSpecialisation</code> |
| 36 to replace the specific residual $Aμ-b$ by $y$. |
35 to replace the specific residual $Aμ-b$ by $y$. |
| 37 For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$. |
36 For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$. |
| 38 For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$. |
37 For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$. |
| 39 </p> |
38 </p> |
| 40 |
|
| 41 Based on zero initialisation for $μ$, we use the [`Subdifferentiable`] trait to make an |
|
| 42 initialisation corresponding to the second part of the optimality conditions. |
|
| 43 In the algorithm itself, standard proximal steps are taking with respect to $F\_0^* + ⟨b, ·⟩$. |
|
| 44 */ |
39 */ |
| 45 |
40 |
| 46 use numeric_literals::replace_float_literals; |
41 use numeric_literals::replace_float_literals; |
| 47 use serde::{Serialize, Deserialize}; |
42 use serde::{Serialize, Deserialize}; |
| 48 use nalgebra::DVector; |
43 use nalgebra::DVector; |
| 49 use clap::ValueEnum; |
44 use clap::ValueEnum; |
| 50 |
45 |
| 51 use alg_tools::iterate:: AlgIteratorFactory; |
46 use alg_tools::iterate::AlgIteratorFactory; |
| 52 use alg_tools::sets::Cube; |
|
| 53 use alg_tools::loc::Loc; |
|
| 54 use alg_tools::euclidean::Euclidean; |
47 use alg_tools::euclidean::Euclidean; |
| |
48 use alg_tools::linops::Mapping; |
| 55 use alg_tools::norms::{ |
49 use alg_tools::norms::{ |
| 56 L1, Linfinity, |
50 Linfinity, |
| 57 Projection, Norm, |
51 Projection, |
| 58 }; |
52 }; |
| 59 use alg_tools::bisection_tree::{ |
53 use alg_tools::mapping::{RealMapping, Instance}; |
| 60 BTFN, |
|
| 61 PreBTFN, |
|
| 62 Bounds, |
|
| 63 BTNodeLookup, |
|
| 64 BTNode, |
|
| 65 BTSearch, |
|
| 66 P2Minimise, |
|
| 67 SupportGenerator, |
|
| 68 LocalAnalysis, |
|
| 69 }; |
|
| 70 use alg_tools::mapping::RealMapping; |
|
| 71 use alg_tools::nalgebra_support::ToNalgebraRealField; |
54 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 72 use alg_tools::linops::AXPY; |
55 use alg_tools::linops::AXPY; |
| 73 |
56 |
| 74 use crate::types::*; |
57 use crate::types::*; |
| 75 use crate::measures::DiscreteMeasure; |
58 use crate::measures::{DiscreteMeasure, RNDM}; |
| 76 use crate::measures::merging::{ |
59 use crate::measures::merging::SpikeMerging; |
| 77 SpikeMerging, |
60 use crate::forward_model::{ |
| 78 }; |
61 ForwardModel, |
| 79 use crate::forward_model::ForwardModel; |
62 AdjointProductBoundedBy, |
| 80 use crate::seminorms::{ |
|
| 81 DiscreteMeasureOp, Lipschitz |
|
| 82 }; |
63 }; |
| 83 use crate::plot::{ |
64 use crate::plot::{ |
| 84 SeqPlotter, |
65 SeqPlotter, |
| 85 Plotting, |
66 Plotting, |
| 86 PlotLookup |
67 PlotLookup |
| 87 }; |
68 }; |
| 88 use crate::fb::{ |
69 use crate::fb::{ |
| |
70 postprocess, |
| |
71 prune_with_stats |
| |
72 }; |
| |
73 pub use crate::prox_penalty::{ |
| 89 FBGenericConfig, |
74 FBGenericConfig, |
| 90 FBSpecialisation, |
75 ProxPenalty |
| 91 generic_pointsource_fb_reg, |
76 }; |
| 92 RegTerm, |
77 use crate::regularisation::RegTerm; |
| 93 }; |
78 use crate::dataterm::{ |
| 94 use crate::regularisation::NonnegRadonRegTerm; |
79 DataTerm, |
| |
80 L2Squared, |
| |
81 L1 |
| |
82 }; |
| |
83 use crate::measures::merging::SpikeMergingMethod; |
| |
84 |
| 95 |
85 |
| 96 /// Acceleration |
86 /// Acceleration |
| 97 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)] |
87 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)] |
| 98 pub enum Acceleration { |
88 pub enum Acceleration { |
| 99 /// No acceleration |
89 /// No acceleration |
| 105 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed |
95 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed |
| 106 #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")] |
96 #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")] |
| 107 Full |
97 Full |
| 108 } |
98 } |
| 109 |
99 |
| 110 /// Settings for [`pointsource_pdps`]. |
100 #[replace_float_literals(F::cast_from(literal))] |
| |
101 impl Acceleration { |
| |
102 /// PDPS parameter acceleration. Updates τ and σ and returns ω. |
| |
103 /// This uses dual strong convexity, not primal. |
| |
104 fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F { |
| |
105 match self { |
| |
106 Acceleration::None => 1.0, |
| |
107 Acceleration::Partial => { |
| |
108 let ω = 1.0 / (1.0 + γ * (*σ)).sqrt(); |
| |
109 *σ *= ω; |
| |
110 *τ /= ω; |
| |
111 ω |
| |
112 }, |
| |
113 Acceleration::Full => { |
| |
114 let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt(); |
| |
115 *σ *= ω; |
| |
116 *τ /= ω; |
| |
117 ω |
| |
118 }, |
| |
119 } |
| |
120 } |
| |
121 } |
| |
122 |
| |
123 /// Settings for [`pointsource_pdps_reg`]. |
| 111 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
124 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 112 #[serde(default)] |
125 #[serde(default)] |
| 113 pub struct PDPSConfig<F : Float> { |
126 pub struct PDPSConfig<F : Float> { |
| 114 /// Primal step length scaling. We must have `τ0 * σ0 < 1`. |
127 /// Primal step length scaling. We must have `τ0 * σ0 < 1`. |
| 115 pub τ0 : F, |
128 pub τ0 : F, |
| 116 /// Dual step length scaling. We must have `τ0 * σ0 < 1`. |
129 /// Dual step length scaling. We must have `τ0 * σ0 < 1`. |
| 117 pub σ0 : F, |
130 pub σ0 : F, |
| 118 /// Accelerate if available |
131 /// Accelerate if available |
| 119 pub acceleration : Acceleration, |
132 pub acceleration : Acceleration, |
| 120 /// Generic parameters |
133 /// Generic parameters |
| 121 pub insertion : FBGenericConfig<F>, |
134 pub generic : FBGenericConfig<F>, |
| 122 } |
135 } |
| 123 |
136 |
| 124 #[replace_float_literals(F::cast_from(literal))] |
137 #[replace_float_literals(F::cast_from(literal))] |
| 125 impl<F : Float> Default for PDPSConfig<F> { |
138 impl<F : Float> Default for PDPSConfig<F> { |
| 126 fn default() -> Self { |
139 fn default() -> Self { |
| 127 let τ0 = 0.5; |
140 let τ0 = 5.0; |
| 128 PDPSConfig { |
141 PDPSConfig { |
| 129 τ0, |
142 τ0, |
| 130 σ0 : 0.99/τ0, |
143 σ0 : 0.99/τ0, |
| 131 acceleration : Acceleration::Partial, |
144 acceleration : Acceleration::Partial, |
| 132 insertion : Default::default() |
145 generic : FBGenericConfig { |
| |
146 merging : SpikeMergingMethod { enabled : true, ..Default::default() }, |
| |
147 .. Default::default() |
| |
148 }, |
| 133 } |
149 } |
| 134 } |
150 } |
| 135 } |
151 } |
| 136 |
152 |
| 137 /// Trait for subdifferentiable objects |
153 /// Trait for data terms for the PDPS |
| 138 pub trait Subdifferentiable<F : Float, V, U=V> { |
154 #[replace_float_literals(F::cast_from(literal))] |
| 139 /// Calculate some subdifferential at `x` |
155 pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> { |
| 140 fn some_subdifferential(&self, x : V) -> U; |
156 /// Calculate some subdifferential at `x` for the conjugate |
| 141 } |
157 fn some_subdifferential(&self, x : V) -> V; |
| 142 |
158 |
| 143 /// Type for indicating norm-2-squared data fidelity. |
159 /// Factor of strong convexity of the conjugate |
| 144 pub struct L2Squared; |
160 #[inline] |
| 145 |
161 fn factor_of_strong_convexity(&self) -> F { |
| 146 impl<F : Float, V : Euclidean<F>> Subdifferentiable<F, V> for L2Squared { |
162 0.0 |
| |
163 } |
| |
164 |
| |
165 /// Perform dual update |
| |
166 fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F); |
| |
167 } |
| |
168 |
| |
169 |
| |
170 #[replace_float_literals(F::cast_from(literal))] |
| |
171 impl<F, V, const N : usize> PDPSDataTerm<F, V, N> |
| |
172 for L2Squared |
| |
173 where |
| |
174 F : Float, |
| |
175 V : Euclidean<F> + AXPY<F>, |
| |
176 for<'b> &'b V : Instance<V>, |
| |
177 { |
| 147 fn some_subdifferential(&self, x : V) -> V { x } |
178 fn some_subdifferential(&self, x : V) -> V { x } |
| 148 } |
179 |
| 149 |
180 fn factor_of_strong_convexity(&self) -> F { |
| 150 impl<F : Float + nalgebra::RealField> Subdifferentiable<F, DVector<F>> for L1 { |
181 1.0 |
| |
182 } |
| |
183 |
| |
184 #[inline] |
| |
185 fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { |
| |
186 y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ)); |
| |
187 } |
| |
188 } |
| |
189 |
| |
190 #[replace_float_literals(F::cast_from(literal))] |
| |
191 impl<F : Float + nalgebra::RealField, const N : usize> |
| |
192 PDPSDataTerm<F, DVector<F>, N> |
| |
193 for L1 { |
| 151 fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> { |
194 fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> { |
| 152 // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. |
195 // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. |
| 153 x.iter_mut() |
196 x.iter_mut() |
| 154 .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) }); |
197 .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) }); |
| 155 x |
198 x |
| 156 } |
199 } |
| 157 } |
200 |
| 158 |
201 #[inline] |
| 159 /// Specialisation of [`generic_pointsource_fb_reg`] to PDPS. |
202 fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) { |
| 160 pub struct PDPS< |
203 y.axpy(1.0, y_prev, σ); |
| 161 'a, |
|
| 162 F : Float + ToNalgebraRealField, |
|
| 163 A : ForwardModel<Loc<F, N>, F>, |
|
| 164 D, |
|
| 165 const N : usize |
|
| 166 > { |
|
| 167 /// The data |
|
| 168 b : &'a A::Observable, |
|
| 169 /// The forward operator |
|
| 170 opA : &'a A, |
|
| 171 /// Primal step length |
|
| 172 τ : F, |
|
| 173 // Dual step length |
|
| 174 σ : F, |
|
| 175 /// Whether acceleration should be applied (if data term supports) |
|
| 176 acceleration : Acceleration, |
|
| 177 /// The dataterm. Only used by the type system. |
|
| 178 _dataterm : D, |
|
| 179 /// Previous dual iterate. |
|
| 180 y_prev : A::Observable, |
|
| 181 } |
|
| 182 |
|
| 183 /// Implementation of [`FBSpecialisation`] for μPDPS with norm-2-squared data fidelity. |
|
| 184 #[replace_float_literals(F::cast_from(literal))] |
|
| 185 impl< |
|
| 186 'a, |
|
| 187 F : Float + ToNalgebraRealField, |
|
| 188 A : ForwardModel<Loc<F, N>, F>, |
|
| 189 const N : usize |
|
| 190 > FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L2Squared, N> |
|
| 191 where for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> { |
|
| 192 |
|
| 193 fn update( |
|
| 194 &mut self, |
|
| 195 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
| 196 μ_base : &DiscreteMeasure<Loc<F, N>, F> |
|
| 197 ) -> (A::Observable, Option<F>) { |
|
| 198 let σ = self.σ; |
|
| 199 let τ = self.τ; |
|
| 200 let ω = match self.acceleration { |
|
| 201 Acceleration::None => 1.0, |
|
| 202 Acceleration::Partial => { |
|
| 203 let ω = 1.0 / (1.0 + σ).sqrt(); |
|
| 204 self.σ = σ * ω; |
|
| 205 self.τ = τ / ω; |
|
| 206 ω |
|
| 207 }, |
|
| 208 Acceleration::Full => { |
|
| 209 let ω = 1.0 / (1.0 + 2.0 * σ).sqrt(); |
|
| 210 self.σ = σ * ω; |
|
| 211 self.τ = τ / ω; |
|
| 212 ω |
|
| 213 }, |
|
| 214 }; |
|
| 215 |
|
| 216 μ.prune(); |
|
| 217 |
|
| 218 let mut y = self.b.clone(); |
|
| 219 self.opA.gemv(&mut y, 1.0 + ω, μ, -1.0); |
|
| 220 self.opA.gemv(&mut y, -ω, μ_base, 1.0); |
|
| 221 y.axpy(1.0 / (1.0 + σ), &self.y_prev, σ / (1.0 + σ)); |
|
| 222 self.y_prev.copy_from(&y); |
|
| 223 |
|
| 224 (y, Some(self.τ)) |
|
| 225 } |
|
| 226 |
|
| 227 fn calculate_fit( |
|
| 228 &self, |
|
| 229 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 230 _y : &A::Observable |
|
| 231 ) -> F { |
|
| 232 self.calculate_fit_simple(μ) |
|
| 233 } |
|
| 234 |
|
| 235 fn calculate_fit_simple( |
|
| 236 &self, |
|
| 237 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 238 ) -> F { |
|
| 239 let mut residual = self.b.clone(); |
|
| 240 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
| 241 residual.norm2_squared_div2() |
|
| 242 } |
|
| 243 } |
|
| 244 |
|
| 245 /// Implementation of [`FBSpecialisation`] for μPDPS with norm-1 data fidelity. |
|
| 246 #[replace_float_literals(F::cast_from(literal))] |
|
| 247 impl< |
|
| 248 'a, |
|
| 249 F : Float + ToNalgebraRealField, |
|
| 250 A : ForwardModel<Loc<F, N>, F>, |
|
| 251 const N : usize |
|
| 252 > FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L1, N> |
|
| 253 where A::Observable : Projection<F, Linfinity> + Norm<F, L1>, |
|
| 254 for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> { |
|
| 255 fn update( |
|
| 256 &mut self, |
|
| 257 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
| 258 μ_base : &DiscreteMeasure<Loc<F, N>, F> |
|
| 259 ) -> (A::Observable, Option<F>) { |
|
| 260 let σ = self.σ; |
|
| 261 |
|
| 262 μ.prune(); |
|
| 263 |
|
| 264 //let ȳ = self.opA.apply(μ) * 2.0 - self.opA.apply(μ_base); |
|
| 265 //*y = proj_{[-1,1]}(&self.y_prev + (ȳ - self.b) * σ) |
|
| 266 let mut y = self.y_prev.clone(); |
|
| 267 self.opA.gemv(&mut y, 2.0 * σ, μ, 1.0); |
|
| 268 self.opA.gemv(&mut y, -σ, μ_base, 1.0); |
|
| 269 y.axpy(-σ, self.b, 1.0); |
|
| 270 y.proj_ball_mut(1.0, Linfinity); |
204 y.proj_ball_mut(1.0, Linfinity); |
| 271 self.y_prev.copy_from(&y); |
|
| 272 |
|
| 273 (y, None) |
|
| 274 } |
|
| 275 |
|
| 276 fn calculate_fit( |
|
| 277 &self, |
|
| 278 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 279 _y : &A::Observable |
|
| 280 ) -> F { |
|
| 281 self.calculate_fit_simple(μ) |
|
| 282 } |
|
| 283 |
|
| 284 fn calculate_fit_simple( |
|
| 285 &self, |
|
| 286 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 287 ) -> F { |
|
| 288 let mut residual = self.b.clone(); |
|
| 289 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
| 290 residual.norm(L1) |
|
| 291 } |
205 } |
| 292 } |
206 } |
| 293 |
207 |
| 294 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. |
208 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. |
| 295 /// |
209 /// |
| 302 /// |
216 /// |
| 303 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
217 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
| 304 /// |
218 /// |
| 305 /// Returns the final iterate. |
219 /// Returns the final iterate. |
| 306 #[replace_float_literals(F::cast_from(literal))] |
220 #[replace_float_literals(F::cast_from(literal))] |
| 307 pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( |
221 pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>( |
| 308 opA : &'a A, |
222 opA : &A, |
| 309 b : &'a A::Observable, |
223 b : &A::Observable, |
| 310 reg : Reg, |
224 reg : Reg, |
| 311 op𝒟 : &'a 𝒟, |
225 prox_penalty : &P, |
| 312 config : &PDPSConfig<F>, |
226 pdpsconfig : &PDPSConfig<F>, |
| 313 iterator : I, |
227 iterator : I, |
| 314 plotter : SeqPlotter<F, N>, |
228 mut plotter : SeqPlotter<F, N>, |
| 315 dataterm : D, |
229 dataterm : D, |
| 316 ) -> DiscreteMeasure<Loc<F, N>, F> |
230 ) -> RNDM<F, N> |
| 317 where F : Float + ToNalgebraRealField, |
231 where |
| 318 I : AlgIteratorFactory<IterInfo<F, N>>, |
232 F : Float + ToNalgebraRealField, |
| 319 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> |
233 I : AlgIteratorFactory<IterInfo<F, N>>, |
| 320 + std::ops::Add<A::Observable, Output=A::Observable>, |
234 A : ForwardModel<RNDM<F, N>, F> |
| 321 //+ std::ops::Mul<F, Output=A::Observable>, // <-- FIXME: compiler overflow |
235 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
| 322 A::Observable : std::ops::MulAssign<F>, |
236 A::PreadjointCodomain : RealMapping<F, N>, |
| 323 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
237 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
| 324 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
238 PlotLookup : Plotting<N>, |
| 325 + Lipschitz<𝒟, FloatType=F>, |
239 RNDM<F, N> : SpikeMerging<F>, |
| 326 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
240 D : PDPSDataTerm<F, A::Observable, N>, |
| 327 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
241 Reg : RegTerm<F, N>, |
| 328 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
242 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
| 329 𝒟::Codomain : RealMapping<F, N>, |
243 { |
| 330 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
244 |
| 331 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
245 // Check parameters |
| 332 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
246 assert!(pdpsconfig.τ0 > 0.0 && |
| 333 PlotLookup : Plotting<N>, |
247 pdpsconfig.σ0 > 0.0 && |
| 334 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
248 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, |
| 335 PDPS<'a, F, A, D, N> : FBSpecialisation<F, A::Observable, N>, |
249 "Invalid step length parameters"); |
| 336 D : Subdifferentiable<F, A::Observable>, |
250 |
| 337 Reg : RegTerm<F, N> { |
251 // Set up parameters |
| 338 |
252 let config = &pdpsconfig.generic; |
| 339 let y = dataterm.some_subdifferential(-b); |
253 let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); |
| 340 let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt(); |
254 let mut τ = pdpsconfig.τ0 / l; |
| 341 let τ = config.τ0 / l; |
255 let mut σ = pdpsconfig.σ0 / l; |
| 342 let σ = config.σ0 / l; |
256 let γ = dataterm.factor_of_strong_convexity(); |
| 343 |
257 |
| 344 let pdps = PDPS { |
258 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 345 b, |
259 // by τ compared to the conditional gradient approach. |
| 346 opA, |
260 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 347 τ, |
261 let mut ε = tolerance.initial(); |
| 348 σ, |
262 |
| 349 acceleration : config.acceleration, |
263 // Initialise iterates |
| 350 _dataterm : dataterm, |
264 let mut μ = DiscreteMeasure::new(); |
| 351 y_prev : y.clone(), |
265 let mut y = dataterm.some_subdifferential(-b); |
| |
266 let mut y_prev = y.clone(); |
| |
267 let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo { |
| |
268 value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ), |
| |
269 n_spikes : μ.len(), |
| |
270 ε, |
| |
271 // postprocessing: config.postprocessing.then(|| μ.clone()), |
| |
272 .. stats |
| 352 }; |
273 }; |
| 353 |
274 let mut stats = IterInfo::new(); |
| 354 generic_pointsource_fb_reg( |
275 |
| 355 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, y, pdps |
276 // Run the algorithm |
| 356 ) |
277 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 357 } |
278 // Calculate smooth part of surrogate model. |
| 358 |
279 let mut τv = opA.preadjoint().apply(y * τ); |
| 359 // |
280 |
| 360 // Deprecated interfaces |
281 // Save current base point |
| 361 // |
282 let μ_base = μ.clone(); |
| 362 |
283 |
| 363 #[deprecated(note = "Use `pointsource_pdps_reg`")] |
284 // Insert and reweigh |
| 364 pub fn pointsource_pdps<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, const N : usize>( |
285 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 365 opA : &'a A, |
286 &mut μ, &mut τv, &μ_base, None, |
| 366 b : &'a A::Observable, |
287 τ, ε, |
| 367 α : F, |
288 config, ®, &state, &mut stats |
| 368 op𝒟 : &'a 𝒟, |
289 ); |
| 369 config : &PDPSConfig<F>, |
290 |
| 370 iterator : I, |
291 // Prune and possibly merge spikes |
| 371 plotter : SeqPlotter<F, N>, |
292 if config.merge_now(&state) { |
| 372 dataterm : D, |
293 stats.merged += prox_penalty.merge_spikes_no_fitness( |
| 373 ) -> DiscreteMeasure<Loc<F, N>, F> |
294 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, |
| 374 where F : Float + ToNalgebraRealField, |
295 ); |
| 375 I : AlgIteratorFactory<IterInfo<F, N>>, |
296 } |
| 376 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> |
297 stats.pruned += prune_with_stats(&mut μ); |
| 377 + std::ops::Add<A::Observable, Output=A::Observable>, |
298 |
| 378 A::Observable : std::ops::MulAssign<F>, |
299 // Update step length parameters |
| 379 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
300 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
| 380 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
301 |
| 381 + Lipschitz<𝒟, FloatType=F>, |
302 // Do dual update |
| 382 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
303 y = b.clone(); // y = b |
| 383 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
304 opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b |
| 384 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
305 opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b |
| 385 𝒟::Codomain : RealMapping<F, N>, |
306 dataterm.dual_update(&mut y, &y_prev, σ); |
| 386 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
307 y_prev.copy_from(&y); |
| 387 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
308 |
| 388 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
309 // Give statistics if requested |
| 389 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
310 let iter = state.iteration(); |
| 390 PlotLookup : Plotting<N>, |
311 stats.this_iters += 1; |
| 391 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
312 |
| 392 PDPS<'a, F, A, D, N> : FBSpecialisation<F, A::Observable, N>, |
313 state.if_verbose(|| { |
| 393 D : Subdifferentiable<F, A::Observable> { |
314 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
| 394 |
315 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 395 pointsource_pdps_reg(opA, b, NonnegRadonRegTerm(α), op𝒟, config, iterator, plotter, dataterm) |
316 }); |
| 396 } |
317 |
| |
318 ε = tolerance.update(ε, iter); |
| |
319 } |
| |
320 |
| |
321 postprocess(μ, config, dataterm, opA, b) |
| |
322 } |
| |
323 |