72 \end{aligned} |
72 \end{aligned} |
73 $$ |
73 $$ |
74 </p> |
74 </p> |
75 |
75 |
We solve this with either SSN or FB as determined by
[`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`].
78 */ |
78 */ |
79 |
79 |
|
80 use colored::Colorize; |
80 use numeric_literals::replace_float_literals; |
81 use numeric_literals::replace_float_literals; |
81 use serde::{Serialize, Deserialize}; |
82 use serde::{Deserialize, Serialize}; |
82 use colored::Colorize; |
83 |
83 |
84 use alg_tools::euclidean::Euclidean; |
|
85 use alg_tools::instance::Instance; |
84 use alg_tools::iterate::AlgIteratorFactory; |
86 use alg_tools::iterate::AlgIteratorFactory; |
85 use alg_tools::euclidean::Euclidean; |
|
86 use alg_tools::linops::{Mapping, GEMV}; |
87 use alg_tools::linops::{Mapping, GEMV}; |
87 use alg_tools::mapping::RealMapping; |
88 use alg_tools::mapping::RealMapping; |
88 use alg_tools::nalgebra_support::ToNalgebraRealField; |
89 use alg_tools::nalgebra_support::ToNalgebraRealField; |
89 use alg_tools::instance::Instance; |
90 |
90 |
91 use crate::dataterm::{calculate_residual, DataTerm, L2Squared}; |
|
92 use crate::forward_model::{AdjointProductBoundedBy, ForwardModel}; |
|
93 use crate::measures::merging::SpikeMerging; |
|
94 use crate::measures::{DiscreteMeasure, RNDM}; |
|
95 use crate::plot::{PlotLookup, Plotting, SeqPlotter}; |
|
96 pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty}; |
|
97 use crate::regularisation::RegTerm; |
91 use crate::types::*; |
98 use crate::types::*; |
92 use crate::measures::{ |
|
93 DiscreteMeasure, |
|
94 RNDM, |
|
95 }; |
|
96 use crate::measures::merging::SpikeMerging; |
|
97 use crate::forward_model::{ |
|
98 ForwardModel, |
|
99 AdjointProductBoundedBy, |
|
100 }; |
|
101 use crate::plot::{ |
|
102 SeqPlotter, |
|
103 Plotting, |
|
104 PlotLookup |
|
105 }; |
|
106 use crate::regularisation::RegTerm; |
|
107 use crate::dataterm::{ |
|
108 calculate_residual, |
|
109 L2Squared, |
|
110 DataTerm, |
|
111 }; |
|
112 pub use crate::prox_penalty::{ |
|
113 FBGenericConfig, |
|
114 ProxPenalty |
|
115 }; |
|
116 |
99 |
117 /// Settings for [`pointsource_fb_reg`]. |
100 /// Settings for [`pointsource_fb_reg`]. |
118 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
119 #[serde(default)] |
102 #[serde(default)] |
120 pub struct FBConfig<F : Float> { |
103 pub struct FBConfig<F: Float> { |
121 /// Step length scaling |
104 /// Step length scaling |
122 pub τ0 : F, |
105 pub τ0: F, |
123 /// Generic parameters |
106 /// Generic parameters |
124 pub generic : FBGenericConfig<F>, |
107 pub generic: FBGenericConfig<F>, |
125 } |
108 } |
126 |
109 |
127 #[replace_float_literals(F::cast_from(literal))] |
110 #[replace_float_literals(F::cast_from(literal))] |
128 impl<F : Float> Default for FBConfig<F> { |
111 impl<F: Float> Default for FBConfig<F> { |
129 fn default() -> Self { |
112 fn default() -> Self { |
130 FBConfig { |
113 FBConfig { |
131 τ0 : 0.99, |
114 τ0: 0.99, |
132 generic : Default::default(), |
115 generic: Default::default(), |
133 } |
116 } |
134 } |
117 } |
135 } |
118 } |
136 |
119 |
137 pub(crate) fn prune_with_stats<F : Float, const N : usize>( |
120 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize { |
138 μ : &mut RNDM<F, N>, |
|
139 ) -> usize { |
|
140 let n_before_prune = μ.len(); |
121 let n_before_prune = μ.len(); |
141 μ.prune(); |
122 μ.prune(); |
142 debug_assert!(μ.len() <= n_before_prune); |
123 debug_assert!(μ.len() <= n_before_prune); |
143 n_before_prune - μ.len() |
124 n_before_prune - μ.len() |
144 } |
125 } |
145 |
126 |
146 #[replace_float_literals(F::cast_from(literal))] |
127 #[replace_float_literals(F::cast_from(literal))] |
147 pub(crate) fn postprocess< |
128 pub(crate) fn postprocess< |
148 F : Float, |
129 F: Float, |
149 V : Euclidean<F> + Clone, |
130 V: Euclidean<F> + Clone, |
150 A : GEMV<F, RNDM<F, N>, Codomain = V>, |
131 A: GEMV<F, RNDM<F, N>, Codomain = V>, |
151 D : DataTerm<F, V, N>, |
132 D: DataTerm<F, V, N>, |
152 const N : usize |
133 const N: usize, |
153 > ( |
134 >( |
154 mut μ : RNDM<F, N>, |
135 mut μ: RNDM<F, N>, |
155 config : &FBGenericConfig<F>, |
136 config: &FBGenericConfig<F>, |
156 dataterm : D, |
137 dataterm: D, |
157 opA : &A, |
138 opA: &A, |
158 b : &V, |
139 b: &V, |
159 ) -> RNDM<F, N> |
140 ) -> RNDM<F, N> |
160 where |
141 where |
161 RNDM<F, N> : SpikeMerging<F>, |
142 RNDM<F, N>: SpikeMerging<F>, |
162 for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>, |
143 for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>, |
163 { |
144 { |
164 μ.merge_spikes_fitness(config.final_merging_method(), |
145 μ.merge_spikes_fitness( |
165 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
146 config.final_merging_method(), |
166 |&v| v); |
147 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
|
148 |&v| v, |
|
149 ); |
167 μ.prune(); |
150 μ.prune(); |
168 μ |
151 μ |
169 } |
152 } |
170 |
153 |
171 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
154 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
185 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
168 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
186 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
169 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
187 /// |
170 /// |
188 /// Returns the final iterate. |
171 /// Returns the final iterate. |
189 #[replace_float_literals(F::cast_from(literal))] |
172 #[replace_float_literals(F::cast_from(literal))] |
190 pub fn pointsource_fb_reg< |
173 pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>( |
191 F, I, A, Reg, P, const N : usize |
174 opA: &A, |
192 >( |
175 b: &A::Observable, |
193 opA : &A, |
176 reg: Reg, |
194 b : &A::Observable, |
177 prox_penalty: &P, |
195 reg : Reg, |
178 fbconfig: &FBConfig<F>, |
196 prox_penalty : &P, |
179 iterator: I, |
197 fbconfig : &FBConfig<F>, |
180 mut plotter: SeqPlotter<F, N>, |
198 iterator : I, |
|
199 mut plotter : SeqPlotter<F, N>, |
|
200 ) -> RNDM<F, N> |
181 ) -> RNDM<F, N> |
201 where |
182 where |
202 F : Float + ToNalgebraRealField, |
183 F: Float + ToNalgebraRealField, |
203 I : AlgIteratorFactory<IterInfo<F, N>>, |
184 I: AlgIteratorFactory<IterInfo<F, N>>, |
204 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
185 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
205 A : ForwardModel<RNDM<F, N>, F> |
186 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
206 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
187 A::PreadjointCodomain: RealMapping<F, N>, |
207 A::PreadjointCodomain : RealMapping<F, N>, |
188 PlotLookup: Plotting<N>, |
208 PlotLookup : Plotting<N>, |
189 RNDM<F, N>: SpikeMerging<F>, |
209 RNDM<F, N> : SpikeMerging<F>, |
190 Reg: RegTerm<F, N>, |
210 Reg : RegTerm<F, N>, |
191 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
211 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
212 { |
192 { |
213 |
|
214 // Set up parameters |
193 // Set up parameters |
215 let config = &fbconfig.generic; |
194 let config = &fbconfig.generic; |
216 let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); |
195 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
217 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
196 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
218 // by τ compared to the conditional gradient approach. |
197 // by τ compared to the conditional gradient approach. |
219 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
198 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
220 let mut ε = tolerance.initial(); |
199 let mut ε = tolerance.initial(); |
221 |
200 |
222 // Initialise iterates |
201 // Initialise iterates |
223 let mut μ = DiscreteMeasure::new(); |
202 let mut μ = DiscreteMeasure::new(); |
224 let mut residual = -b; |
203 let mut residual = -b; |
225 |
204 |
226 // Statistics |
205 // Statistics |
227 let full_stats = |residual : &A::Observable, |
206 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { |
228 μ : &RNDM<F, N>, |
207 value: residual.norm2_squared_div2() + reg.apply(μ), |
229 ε, stats| IterInfo { |
208 n_spikes: μ.len(), |
230 value : residual.norm2_squared_div2() + reg.apply(μ), |
|
231 n_spikes : μ.len(), |
|
232 ε, |
209 ε, |
233 //postprocessing: config.postprocessing.then(|| μ.clone()), |
210 //postprocessing: config.postprocessing.then(|| μ.clone()), |
234 .. stats |
211 ..stats |
235 }; |
212 }; |
236 let mut stats = IterInfo::new(); |
213 let mut stats = IterInfo::new(); |
237 |
214 |
238 // Run the algorithm |
215 // Run the algorithm |
239 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
216 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
240 // Calculate smooth part of surrogate model. |
217 // Calculate smooth part of surrogate model. |
241 let mut τv = opA.preadjoint().apply(residual * τ); |
218 let mut τv = opA.preadjoint().apply(residual * τ); |
242 |
219 |
243 // Save current base point |
220 // Save current base point |
244 let μ_base = μ.clone(); |
221 let μ_base = μ.clone(); |
245 |
222 |
246 // Insert and reweigh |
223 // Insert and reweigh |
247 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
224 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
248 &mut μ, &mut τv, &μ_base, None, |
225 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
249 τ, ε, |
|
250 config, ®, &state, &mut stats |
|
251 ); |
226 ); |
252 |
227 |
253 // Prune and possibly merge spikes |
228 // Prune and possibly merge spikes |
254 if config.merge_now(&state) { |
229 if config.merge_now(&state) { |
255 stats.merged += prox_penalty.merge_spikes( |
230 stats.merged += prox_penalty.merge_spikes( |
256 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, |
231 &mut μ, |
257 Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
232 &mut τv, |
|
233 &μ_base, |
|
234 None, |
|
235 τ, |
|
236 ε, |
|
237 config, |
|
238 ®, |
|
239 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
258 ); |
240 ); |
259 } |
241 } |
260 |
242 |
261 stats.pruned += prune_with_stats(&mut μ); |
243 stats.pruned += prune_with_stats(&mut μ); |
262 |
244 |
296 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
283 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
297 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
284 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
298 /// |
285 /// |
299 /// Returns the final iterate. |
286 /// Returns the final iterate. |
300 #[replace_float_literals(F::cast_from(literal))] |
287 #[replace_float_literals(F::cast_from(literal))] |
301 pub fn pointsource_fista_reg< |
288 pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>( |
302 F, I, A, Reg, P, const N : usize |
289 opA: &A, |
303 >( |
290 b: &A::Observable, |
304 opA : &A, |
291 reg: Reg, |
305 b : &A::Observable, |
292 prox_penalty: &P, |
306 reg : Reg, |
293 fbconfig: &FBConfig<F>, |
307 prox_penalty : &P, |
294 iterator: I, |
308 fbconfig : &FBConfig<F>, |
295 mut plotter: SeqPlotter<F, N>, |
309 iterator : I, |
|
310 mut plotter : SeqPlotter<F, N>, |
|
311 ) -> RNDM<F, N> |
296 ) -> RNDM<F, N> |
312 where |
297 where |
313 F : Float + ToNalgebraRealField, |
298 F: Float + ToNalgebraRealField, |
314 I : AlgIteratorFactory<IterInfo<F, N>>, |
299 I: AlgIteratorFactory<IterInfo<F, N>>, |
315 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
300 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
316 A : ForwardModel<RNDM<F, N>, F> |
301 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
317 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
302 A::PreadjointCodomain: RealMapping<F, N>, |
318 A::PreadjointCodomain : RealMapping<F, N>, |
303 PlotLookup: Plotting<N>, |
319 PlotLookup : Plotting<N>, |
304 RNDM<F, N>: SpikeMerging<F>, |
320 RNDM<F, N> : SpikeMerging<F>, |
305 Reg: RegTerm<F, N>, |
321 Reg : RegTerm<F, N>, |
306 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
322 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
323 { |
307 { |
324 |
|
325 // Set up parameters |
308 // Set up parameters |
326 let config = &fbconfig.generic; |
309 let config = &fbconfig.generic; |
327 let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); |
310 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
328 let mut λ = 1.0; |
311 let mut λ = 1.0; |
329 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
312 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
330 // by τ compared to the conditional gradient approach. |
313 // by τ compared to the conditional gradient approach. |
331 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
314 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
332 let mut ε = tolerance.initial(); |
315 let mut ε = tolerance.initial(); |
336 let mut μ_prev = DiscreteMeasure::new(); |
319 let mut μ_prev = DiscreteMeasure::new(); |
337 let mut residual = -b; |
320 let mut residual = -b; |
338 let mut warned_merging = false; |
321 let mut warned_merging = false; |
339 |
322 |
340 // Statistics |
323 // Statistics |
341 let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { |
324 let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo { |
342 value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
325 value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
343 n_spikes : ν.len(), |
326 n_spikes: ν.len(), |
344 ε, |
327 ε, |
345 // postprocessing: config.postprocessing.then(|| ν.clone()), |
328 // postprocessing: config.postprocessing.then(|| ν.clone()), |
346 .. stats |
329 ..stats |
347 }; |
330 }; |
348 let mut stats = IterInfo::new(); |
331 let mut stats = IterInfo::new(); |
349 |
332 |
350 // Run the algorithm |
333 // Run the algorithm |
351 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
334 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
352 // Calculate smooth part of surrogate model. |
335 // Calculate smooth part of surrogate model. |
353 let mut τv = opA.preadjoint().apply(residual * τ); |
336 let mut τv = opA.preadjoint().apply(residual * τ); |
354 |
337 |
355 // Save current base point |
338 // Save current base point |
356 let μ_base = μ.clone(); |
339 let μ_base = μ.clone(); |
357 |
340 |
358 // Insert new spikes and reweigh |
341 // Insert new spikes and reweigh |
359 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
342 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
360 &mut μ, &mut τv, &μ_base, None, |
343 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
361 τ, ε, |
|
362 config, ®, &state, &mut stats |
|
363 ); |
344 ); |
364 |
345 |
365 // (Do not) merge spikes. |
346 // (Do not) merge spikes. |
366 if config.merge_now(&state) && !warned_merging { |
347 if config.merge_now(&state) && !warned_merging { |
367 let err = format!("Merging not supported for μFISTA"); |
348 let err = format!("Merging not supported for μFISTA"); |