| 72 \end{aligned} |
72 \end{aligned} |
| 73 $$ |
73 $$ |
| 74 </p> |
74 </p> |
| 75 |
75 |
| 76 We solve this with either SSN or FB as determined by |
76 We solve this with either SSN or FB as determined by |
| 77 [`InnerSettings`] in [`FBGenericConfig::inner`]. |
77 [`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. |
| 78 */ |
78 */ |
| 79 |
79 |
| |
80 use colored::Colorize; |
| 80 use numeric_literals::replace_float_literals; |
81 use numeric_literals::replace_float_literals; |
| 81 use serde::{Serialize, Deserialize}; |
82 use serde::{Deserialize, Serialize}; |
| 82 use colored::Colorize; |
83 |
| 83 |
84 use alg_tools::euclidean::Euclidean; |
| |
85 use alg_tools::instance::Instance; |
| 84 use alg_tools::iterate::AlgIteratorFactory; |
86 use alg_tools::iterate::AlgIteratorFactory; |
| 85 use alg_tools::euclidean::Euclidean; |
|
| 86 use alg_tools::linops::{Mapping, GEMV}; |
87 use alg_tools::linops::{Mapping, GEMV}; |
| 87 use alg_tools::mapping::RealMapping; |
88 use alg_tools::mapping::RealMapping; |
| 88 use alg_tools::nalgebra_support::ToNalgebraRealField; |
89 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 89 use alg_tools::instance::Instance; |
90 |
| 90 |
91 use crate::dataterm::{calculate_residual, DataTerm, L2Squared}; |
| |
92 use crate::forward_model::{AdjointProductBoundedBy, ForwardModel}; |
| |
93 use crate::measures::merging::SpikeMerging; |
| |
94 use crate::measures::{DiscreteMeasure, RNDM}; |
| |
95 use crate::plot::{PlotLookup, Plotting, SeqPlotter}; |
| |
96 pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty}; |
| |
97 use crate::regularisation::RegTerm; |
| 91 use crate::types::*; |
98 use crate::types::*; |
| 92 use crate::measures::{ |
|
| 93 DiscreteMeasure, |
|
| 94 RNDM, |
|
| 95 }; |
|
| 96 use crate::measures::merging::SpikeMerging; |
|
| 97 use crate::forward_model::{ |
|
| 98 ForwardModel, |
|
| 99 AdjointProductBoundedBy, |
|
| 100 }; |
|
| 101 use crate::plot::{ |
|
| 102 SeqPlotter, |
|
| 103 Plotting, |
|
| 104 PlotLookup |
|
| 105 }; |
|
| 106 use crate::regularisation::RegTerm; |
|
| 107 use crate::dataterm::{ |
|
| 108 calculate_residual, |
|
| 109 L2Squared, |
|
| 110 DataTerm, |
|
| 111 }; |
|
| 112 pub use crate::prox_penalty::{ |
|
| 113 FBGenericConfig, |
|
| 114 ProxPenalty |
|
| 115 }; |
|
| 116 |
99 |
| 117 /// Settings for [`pointsource_fb_reg`]. |
100 /// Settings for [`pointsource_fb_reg`]. |
| 118 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 119 #[serde(default)] |
102 #[serde(default)] |
| 120 pub struct FBConfig<F : Float> { |
103 pub struct FBConfig<F: Float> { |
| 121 /// Step length scaling |
104 /// Step length scaling |
| 122 pub τ0 : F, |
105 pub τ0: F, |
| 123 /// Generic parameters |
106 /// Generic parameters |
| 124 pub generic : FBGenericConfig<F>, |
107 pub generic: FBGenericConfig<F>, |
| 125 } |
108 } |
| 126 |
109 |
| 127 #[replace_float_literals(F::cast_from(literal))] |
110 #[replace_float_literals(F::cast_from(literal))] |
| 128 impl<F : Float> Default for FBConfig<F> { |
111 impl<F: Float> Default for FBConfig<F> { |
| 129 fn default() -> Self { |
112 fn default() -> Self { |
| 130 FBConfig { |
113 FBConfig { |
| 131 τ0 : 0.99, |
114 τ0: 0.99, |
| 132 generic : Default::default(), |
115 generic: Default::default(), |
| 133 } |
116 } |
| 134 } |
117 } |
| 135 } |
118 } |
| 136 |
119 |
| 137 pub(crate) fn prune_with_stats<F : Float, const N : usize>( |
120 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize { |
| 138 μ : &mut RNDM<F, N>, |
|
| 139 ) -> usize { |
|
| 140 let n_before_prune = μ.len(); |
121 let n_before_prune = μ.len(); |
| 141 μ.prune(); |
122 μ.prune(); |
| 142 debug_assert!(μ.len() <= n_before_prune); |
123 debug_assert!(μ.len() <= n_before_prune); |
| 143 n_before_prune - μ.len() |
124 n_before_prune - μ.len() |
| 144 } |
125 } |
| 145 |
126 |
| 146 #[replace_float_literals(F::cast_from(literal))] |
127 #[replace_float_literals(F::cast_from(literal))] |
| 147 pub(crate) fn postprocess< |
128 pub(crate) fn postprocess< |
| 148 F : Float, |
129 F: Float, |
| 149 V : Euclidean<F> + Clone, |
130 V: Euclidean<F> + Clone, |
| 150 A : GEMV<F, RNDM<F, N>, Codomain = V>, |
131 A: GEMV<F, RNDM<F, N>, Codomain = V>, |
| 151 D : DataTerm<F, V, N>, |
132 D: DataTerm<F, V, N>, |
| 152 const N : usize |
133 const N: usize, |
| 153 > ( |
134 >( |
| 154 mut μ : RNDM<F, N>, |
135 mut μ: RNDM<F, N>, |
| 155 config : &FBGenericConfig<F>, |
136 config: &FBGenericConfig<F>, |
| 156 dataterm : D, |
137 dataterm: D, |
| 157 opA : &A, |
138 opA: &A, |
| 158 b : &V, |
139 b: &V, |
| 159 ) -> RNDM<F, N> |
140 ) -> RNDM<F, N> |
| 160 where |
141 where |
| 161 RNDM<F, N> : SpikeMerging<F>, |
142 RNDM<F, N>: SpikeMerging<F>, |
| 162 for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>, |
143 for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>, |
| 163 { |
144 { |
| 164 μ.merge_spikes_fitness(config.final_merging_method(), |
145 μ.merge_spikes_fitness( |
| 165 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
146 config.final_merging_method(), |
| 166 |&v| v); |
147 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
| |
148 |&v| v, |
| |
149 ); |
| 167 μ.prune(); |
150 μ.prune(); |
| 168 μ |
151 μ |
| 169 } |
152 } |
| 170 |
153 |
| 171 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
154 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
| 185 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
168 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
| 186 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
169 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
| 187 /// |
170 /// |
| 188 /// Returns the final iterate. |
171 /// Returns the final iterate. |
| 189 #[replace_float_literals(F::cast_from(literal))] |
172 #[replace_float_literals(F::cast_from(literal))] |
| 190 pub fn pointsource_fb_reg< |
173 pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>( |
| 191 F, I, A, Reg, P, const N : usize |
174 opA: &A, |
| 192 >( |
175 b: &A::Observable, |
| 193 opA : &A, |
176 reg: Reg, |
| 194 b : &A::Observable, |
177 prox_penalty: &P, |
| 195 reg : Reg, |
178 fbconfig: &FBConfig<F>, |
| 196 prox_penalty : &P, |
179 iterator: I, |
| 197 fbconfig : &FBConfig<F>, |
180 mut plotter: SeqPlotter<F, N>, |
| 198 iterator : I, |
|
| 199 mut plotter : SeqPlotter<F, N>, |
|
| 200 ) -> RNDM<F, N> |
181 ) -> RNDM<F, N> |
| 201 where |
182 where |
| 202 F : Float + ToNalgebraRealField, |
183 F: Float + ToNalgebraRealField, |
| 203 I : AlgIteratorFactory<IterInfo<F, N>>, |
184 I: AlgIteratorFactory<IterInfo<F, N>>, |
| 204 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
185 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
| 205 A : ForwardModel<RNDM<F, N>, F> |
186 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
| 206 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
187 A::PreadjointCodomain: RealMapping<F, N>, |
| 207 A::PreadjointCodomain : RealMapping<F, N>, |
188 PlotLookup: Plotting<N>, |
| 208 PlotLookup : Plotting<N>, |
189 RNDM<F, N>: SpikeMerging<F>, |
| 209 RNDM<F, N> : SpikeMerging<F>, |
190 Reg: RegTerm<F, N>, |
| 210 Reg : RegTerm<F, N>, |
191 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
| 211 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
| 212 { |
192 { |
| 213 |
|
| 214 // Set up parameters |
193 // Set up parameters |
| 215 let config = &fbconfig.generic; |
194 let config = &fbconfig.generic; |
| 216 let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); |
195 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
| 217 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
196 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 218 // by τ compared to the conditional gradient approach. |
197 // by τ compared to the conditional gradient approach. |
| 219 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
198 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 220 let mut ε = tolerance.initial(); |
199 let mut ε = tolerance.initial(); |
| 221 |
200 |
| 222 // Initialise iterates |
201 // Initialise iterates |
| 223 let mut μ = DiscreteMeasure::new(); |
202 let mut μ = DiscreteMeasure::new(); |
| 224 let mut residual = -b; |
203 let mut residual = -b; |
| 225 |
204 |
| 226 // Statistics |
205 // Statistics |
| 227 let full_stats = |residual : &A::Observable, |
206 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { |
| 228 μ : &RNDM<F, N>, |
207 value: residual.norm2_squared_div2() + reg.apply(μ), |
| 229 ε, stats| IterInfo { |
208 n_spikes: μ.len(), |
| 230 value : residual.norm2_squared_div2() + reg.apply(μ), |
|
| 231 n_spikes : μ.len(), |
|
| 232 ε, |
209 ε, |
| 233 //postprocessing: config.postprocessing.then(|| μ.clone()), |
210 //postprocessing: config.postprocessing.then(|| μ.clone()), |
| 234 .. stats |
211 ..stats |
| 235 }; |
212 }; |
| 236 let mut stats = IterInfo::new(); |
213 let mut stats = IterInfo::new(); |
| 237 |
214 |
| 238 // Run the algorithm |
215 // Run the algorithm |
| 239 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
216 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
| 240 // Calculate smooth part of surrogate model. |
217 // Calculate smooth part of surrogate model. |
| 241 let mut τv = opA.preadjoint().apply(residual * τ); |
218 let mut τv = opA.preadjoint().apply(residual * τ); |
| 242 |
219 |
| 243 // Save current base point |
220 // Save current base point |
| 244 let μ_base = μ.clone(); |
221 let μ_base = μ.clone(); |
| 245 |
222 |
| 246 // Insert and reweigh |
223 // Insert and reweigh |
| 247 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
224 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 248 &mut μ, &mut τv, &μ_base, None, |
225 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
| 249 τ, ε, |
|
| 250 config, ®, &state, &mut stats |
|
| 251 ); |
226 ); |
| 252 |
227 |
| 253 // Prune and possibly merge spikes |
228 // Prune and possibly merge spikes |
| 254 if config.merge_now(&state) { |
229 if config.merge_now(&state) { |
| 255 stats.merged += prox_penalty.merge_spikes( |
230 stats.merged += prox_penalty.merge_spikes( |
| 256 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, |
231 &mut μ, |
| 257 Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
232 &mut τv, |
| |
233 &μ_base, |
| |
234 None, |
| |
235 τ, |
| |
236 ε, |
| |
237 config, |
| |
238 ®, |
| |
239 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
| 258 ); |
240 ); |
| 259 } |
241 } |
| 260 |
242 |
| 261 stats.pruned += prune_with_stats(&mut μ); |
243 stats.pruned += prune_with_stats(&mut μ); |
| 262 |
244 |
| 296 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
283 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
| 297 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
284 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
| 298 /// |
285 /// |
| 299 /// Returns the final iterate. |
286 /// Returns the final iterate. |
| 300 #[replace_float_literals(F::cast_from(literal))] |
287 #[replace_float_literals(F::cast_from(literal))] |
| 301 pub fn pointsource_fista_reg< |
288 pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>( |
| 302 F, I, A, Reg, P, const N : usize |
289 opA: &A, |
| 303 >( |
290 b: &A::Observable, |
| 304 opA : &A, |
291 reg: Reg, |
| 305 b : &A::Observable, |
292 prox_penalty: &P, |
| 306 reg : Reg, |
293 fbconfig: &FBConfig<F>, |
| 307 prox_penalty : &P, |
294 iterator: I, |
| 308 fbconfig : &FBConfig<F>, |
295 mut plotter: SeqPlotter<F, N>, |
| 309 iterator : I, |
|
| 310 mut plotter : SeqPlotter<F, N>, |
|
| 311 ) -> RNDM<F, N> |
296 ) -> RNDM<F, N> |
| 312 where |
297 where |
| 313 F : Float + ToNalgebraRealField, |
298 F: Float + ToNalgebraRealField, |
| 314 I : AlgIteratorFactory<IterInfo<F, N>>, |
299 I: AlgIteratorFactory<IterInfo<F, N>>, |
| 315 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
300 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
| 316 A : ForwardModel<RNDM<F, N>, F> |
301 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
| 317 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
302 A::PreadjointCodomain: RealMapping<F, N>, |
| 318 A::PreadjointCodomain : RealMapping<F, N>, |
303 PlotLookup: Plotting<N>, |
| 319 PlotLookup : Plotting<N>, |
304 RNDM<F, N>: SpikeMerging<F>, |
| 320 RNDM<F, N> : SpikeMerging<F>, |
305 Reg: RegTerm<F, N>, |
| 321 Reg : RegTerm<F, N>, |
306 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
| 322 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
| 323 { |
307 { |
| 324 |
|
| 325 // Set up parameters |
308 // Set up parameters |
| 326 let config = &fbconfig.generic; |
309 let config = &fbconfig.generic; |
| 327 let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); |
310 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
| 328 let mut λ = 1.0; |
311 let mut λ = 1.0; |
| 329 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
312 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 330 // by τ compared to the conditional gradient approach. |
313 // by τ compared to the conditional gradient approach. |
| 331 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
314 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 332 let mut ε = tolerance.initial(); |
315 let mut ε = tolerance.initial(); |
| 336 let mut μ_prev = DiscreteMeasure::new(); |
319 let mut μ_prev = DiscreteMeasure::new(); |
| 337 let mut residual = -b; |
320 let mut residual = -b; |
| 338 let mut warned_merging = false; |
321 let mut warned_merging = false; |
| 339 |
322 |
| 340 // Statistics |
323 // Statistics |
| 341 let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { |
324 let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo { |
| 342 value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
325 value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
| 343 n_spikes : ν.len(), |
326 n_spikes: ν.len(), |
| 344 ε, |
327 ε, |
| 345 // postprocessing: config.postprocessing.then(|| ν.clone()), |
328 // postprocessing: config.postprocessing.then(|| ν.clone()), |
| 346 .. stats |
329 ..stats |
| 347 }; |
330 }; |
| 348 let mut stats = IterInfo::new(); |
331 let mut stats = IterInfo::new(); |
| 349 |
332 |
| 350 // Run the algorithm |
333 // Run the algorithm |
| 351 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
334 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 352 // Calculate smooth part of surrogate model. |
335 // Calculate smooth part of surrogate model. |
| 353 let mut τv = opA.preadjoint().apply(residual * τ); |
336 let mut τv = opA.preadjoint().apply(residual * τ); |
| 354 |
337 |
| 355 // Save current base point |
338 // Save current base point |
| 356 let μ_base = μ.clone(); |
339 let μ_base = μ.clone(); |
| 357 |
340 |
| 358 // Insert new spikes and reweigh |
341 // Insert new spikes and reweigh |
| 359 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
342 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 360 &mut μ, &mut τv, &μ_base, None, |
343 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
| 361 τ, ε, |
|
| 362 config, ®, &state, &mut stats |
|
| 363 ); |
344 ); |
| 364 |
345 |
| 365 // (Do not) merge spikes. |
346 // (Do not) merge spikes. |
| 366 if config.merge_now(&state) && !warned_merging { |
347 if config.merge_now(&state) && !warned_merging { |
| 367 let err = format!("Merging not supported for μFISTA"); |
348 let err = format!("Merging not supported for μFISTA"); |