src/fb.rs

branch
dev
changeset 51
0693cc9ba9f0
parent 39
6316d68b58af
equal deleted inserted replaced
50:39c5e6c7759d 51:0693cc9ba9f0
72 \end{aligned} 72 \end{aligned}
73 $$ 73 $$
74 </p> 74 </p>
75 75
76 We solve this with either SSN or FB as determined by 76 We solve this with either SSN or FB as determined by
77 [`InnerSettings`] in [`FBGenericConfig::inner`]. 77 [`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`].
78 */ 78 */
79 79
80 use colored::Colorize;
80 use numeric_literals::replace_float_literals; 81 use numeric_literals::replace_float_literals;
81 use serde::{Serialize, Deserialize}; 82 use serde::{Deserialize, Serialize};
82 use colored::Colorize; 83
83 84 use alg_tools::euclidean::Euclidean;
85 use alg_tools::instance::Instance;
84 use alg_tools::iterate::AlgIteratorFactory; 86 use alg_tools::iterate::AlgIteratorFactory;
85 use alg_tools::euclidean::Euclidean;
86 use alg_tools::linops::{Mapping, GEMV}; 87 use alg_tools::linops::{Mapping, GEMV};
87 use alg_tools::mapping::RealMapping; 88 use alg_tools::mapping::RealMapping;
88 use alg_tools::nalgebra_support::ToNalgebraRealField; 89 use alg_tools::nalgebra_support::ToNalgebraRealField;
89 use alg_tools::instance::Instance; 90
90 91 use crate::dataterm::{calculate_residual, DataTerm, L2Squared};
92 use crate::forward_model::{AdjointProductBoundedBy, ForwardModel};
93 use crate::measures::merging::SpikeMerging;
94 use crate::measures::{DiscreteMeasure, RNDM};
95 use crate::plot::{PlotLookup, Plotting, SeqPlotter};
96 pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty};
97 use crate::regularisation::RegTerm;
91 use crate::types::*; 98 use crate::types::*;
92 use crate::measures::{
93 DiscreteMeasure,
94 RNDM,
95 };
96 use crate::measures::merging::SpikeMerging;
97 use crate::forward_model::{
98 ForwardModel,
99 AdjointProductBoundedBy,
100 };
101 use crate::plot::{
102 SeqPlotter,
103 Plotting,
104 PlotLookup
105 };
106 use crate::regularisation::RegTerm;
107 use crate::dataterm::{
108 calculate_residual,
109 L2Squared,
110 DataTerm,
111 };
112 pub use crate::prox_penalty::{
113 FBGenericConfig,
114 ProxPenalty
115 };
116 99
117 /// Settings for [`pointsource_fb_reg`]. 100 /// Settings for [`pointsource_fb_reg`].
118 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
119 #[serde(default)] 102 #[serde(default)]
120 pub struct FBConfig<F : Float> { 103 pub struct FBConfig<F: Float> {
121 /// Step length scaling 104 /// Step length scaling
122 pub τ0 : F, 105 pub τ0: F,
123 /// Generic parameters 106 /// Generic parameters
124 pub generic : FBGenericConfig<F>, 107 pub generic: FBGenericConfig<F>,
125 } 108 }
126 109
127 #[replace_float_literals(F::cast_from(literal))] 110 #[replace_float_literals(F::cast_from(literal))]
128 impl<F : Float> Default for FBConfig<F> { 111 impl<F: Float> Default for FBConfig<F> {
129 fn default() -> Self { 112 fn default() -> Self {
130 FBConfig { 113 FBConfig {
131 τ0 : 0.99, 114 τ0: 0.99,
132 generic : Default::default(), 115 generic: Default::default(),
133 } 116 }
134 } 117 }
135 } 118 }
136 119
137 pub(crate) fn prune_with_stats<F : Float, const N : usize>( 120 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize {
138 μ : &mut RNDM<F, N>,
139 ) -> usize {
140 let n_before_prune = μ.len(); 121 let n_before_prune = μ.len();
141 μ.prune(); 122 μ.prune();
142 debug_assert!(μ.len() <= n_before_prune); 123 debug_assert!(μ.len() <= n_before_prune);
143 n_before_prune - μ.len() 124 n_before_prune - μ.len()
144 } 125 }
145 126
146 #[replace_float_literals(F::cast_from(literal))] 127 #[replace_float_literals(F::cast_from(literal))]
147 pub(crate) fn postprocess< 128 pub(crate) fn postprocess<
148 F : Float, 129 F: Float,
149 V : Euclidean<F> + Clone, 130 V: Euclidean<F> + Clone,
150 A : GEMV<F, RNDM<F, N>, Codomain = V>, 131 A: GEMV<F, RNDM<F, N>, Codomain = V>,
151 D : DataTerm<F, V, N>, 132 D: DataTerm<F, V, N>,
152 const N : usize 133 const N: usize,
153 > ( 134 >(
154 mut μ : RNDM<F, N>, 135 mut μ: RNDM<F, N>,
155 config : &FBGenericConfig<F>, 136 config: &FBGenericConfig<F>,
156 dataterm : D, 137 dataterm: D,
157 opA : &A, 138 opA: &A,
158 b : &V, 139 b: &V,
159 ) -> RNDM<F, N> 140 ) -> RNDM<F, N>
160 where 141 where
161 RNDM<F, N> : SpikeMerging<F>, 142 RNDM<F, N>: SpikeMerging<F>,
162 for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>, 143 for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>,
163 { 144 {
164 μ.merge_spikes_fitness(config.final_merging_method(), 145 μ.merge_spikes_fitness(
165 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), 146 config.final_merging_method(),
166 |&v| v); 147 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
148 |&v| v,
149 );
167 μ.prune(); 150 μ.prune();
168 μ 151 μ
169 } 152 }
170 153
171 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. 154 /// Iteratively solve the pointsource localisation problem using forward-backward splitting.
185 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features 168 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
186 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. 169 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
187 /// 170 ///
188 /// Returns the final iterate. 171 /// Returns the final iterate.
189 #[replace_float_literals(F::cast_from(literal))] 172 #[replace_float_literals(F::cast_from(literal))]
190 pub fn pointsource_fb_reg< 173 pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>(
191 F, I, A, Reg, P, const N : usize 174 opA: &A,
192 >( 175 b: &A::Observable,
193 opA : &A, 176 reg: Reg,
194 b : &A::Observable, 177 prox_penalty: &P,
195 reg : Reg, 178 fbconfig: &FBConfig<F>,
196 prox_penalty : &P, 179 iterator: I,
197 fbconfig : &FBConfig<F>, 180 mut plotter: SeqPlotter<F, N>,
198 iterator : I,
199 mut plotter : SeqPlotter<F, N>,
200 ) -> RNDM<F, N> 181 ) -> RNDM<F, N>
201 where 182 where
202 F : Float + ToNalgebraRealField, 183 F: Float + ToNalgebraRealField,
203 I : AlgIteratorFactory<IterInfo<F, N>>, 184 I: AlgIteratorFactory<IterInfo<F, N>>,
204 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 185 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>,
205 A : ForwardModel<RNDM<F, N>, F> 186 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>,
206 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, 187 A::PreadjointCodomain: RealMapping<F, N>,
207 A::PreadjointCodomain : RealMapping<F, N>, 188 PlotLookup: Plotting<N>,
208 PlotLookup : Plotting<N>, 189 RNDM<F, N>: SpikeMerging<F>,
209 RNDM<F, N> : SpikeMerging<F>, 190 Reg: RegTerm<F, N>,
210 Reg : RegTerm<F, N>, 191 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
211 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
212 { 192 {
213
214 // Set up parameters 193 // Set up parameters
215 let config = &fbconfig.generic; 194 let config = &fbconfig.generic;
216 let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); 195 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
217 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 196 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
218 // by τ compared to the conditional gradient approach. 197 // by τ compared to the conditional gradient approach.
219 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 198 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
220 let mut ε = tolerance.initial(); 199 let mut ε = tolerance.initial();
221 200
222 // Initialise iterates 201 // Initialise iterates
223 let mut μ = DiscreteMeasure::new(); 202 let mut μ = DiscreteMeasure::new();
224 let mut residual = -b; 203 let mut residual = -b;
225 204
226 // Statistics 205 // Statistics
227 let full_stats = |residual : &A::Observable, 206 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo {
228 μ : &RNDM<F, N>, 207 value: residual.norm2_squared_div2() + reg.apply(μ),
229 ε, stats| IterInfo { 208 n_spikes: μ.len(),
230 value : residual.norm2_squared_div2() + reg.apply(μ),
231 n_spikes : μ.len(),
232 ε, 209 ε,
233 //postprocessing: config.postprocessing.then(|| μ.clone()), 210 //postprocessing: config.postprocessing.then(|| μ.clone()),
234 .. stats 211 ..stats
235 }; 212 };
236 let mut stats = IterInfo::new(); 213 let mut stats = IterInfo::new();
237 214
238 // Run the algorithm 215 // Run the algorithm
239 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { 216 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
240 // Calculate smooth part of surrogate model. 217 // Calculate smooth part of surrogate model.
241 let mut τv = opA.preadjoint().apply(residual * τ); 218 let mut τv = opA.preadjoint().apply(residual * τ);
242 219
243 // Save current base point 220 // Save current base point
244 let μ_base = μ.clone(); 221 let μ_base = μ.clone();
245 222
246 // Insert and reweigh 223 // Insert and reweigh
247 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 224 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
248 &mut μ, &mut τv, &μ_base, None, 225 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
249 τ, ε,
250 config, &reg, &state, &mut stats
251 ); 226 );
252 227
253 // Prune and possibly merge spikes 228 // Prune and possibly merge spikes
254 if config.merge_now(&state) { 229 if config.merge_now(&state) {
255 stats.merged += prox_penalty.merge_spikes( 230 stats.merged += prox_penalty.merge_spikes(
256 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, 231 &mut μ,
257 Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), 232 &mut τv,
233 &μ_base,
234 None,
235 τ,
236 ε,
237 config,
238 &reg,
239 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
258 ); 240 );
259 } 241 }
260 242
261 stats.pruned += prune_with_stats(&mut μ); 243 stats.pruned += prune_with_stats(&mut μ);
262 244
267 stats.this_iters += 1; 249 stats.this_iters += 1;
268 250
269 // Give statistics if needed 251 // Give statistics if needed
270 state.if_verbose(|| { 252 state.if_verbose(|| {
271 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); 253 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
272 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) 254 full_stats(
255 &residual,
256 &μ,
257 ε,
258 std::mem::replace(&mut stats, IterInfo::new()),
259 )
273 }); 260 });
274 261
275 // Update main tolerance for next iteration 262 // Update main tolerance for next iteration
276 ε = tolerance.update(ε, iter); 263 ε = tolerance.update(ε, iter);
277 } 264 }
278 265
279 postprocess(μ, config, L2Squared, opA, b) 266 postprocess(μ, config, L2Squared, opA, b)
296 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features 283 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
297 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. 284 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
298 /// 285 ///
299 /// Returns the final iterate. 286 /// Returns the final iterate.
300 #[replace_float_literals(F::cast_from(literal))] 287 #[replace_float_literals(F::cast_from(literal))]
301 pub fn pointsource_fista_reg< 288 pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>(
302 F, I, A, Reg, P, const N : usize 289 opA: &A,
303 >( 290 b: &A::Observable,
304 opA : &A, 291 reg: Reg,
305 b : &A::Observable, 292 prox_penalty: &P,
306 reg : Reg, 293 fbconfig: &FBConfig<F>,
307 prox_penalty : &P, 294 iterator: I,
308 fbconfig : &FBConfig<F>, 295 mut plotter: SeqPlotter<F, N>,
309 iterator : I,
310 mut plotter : SeqPlotter<F, N>,
311 ) -> RNDM<F, N> 296 ) -> RNDM<F, N>
312 where 297 where
313 F : Float + ToNalgebraRealField, 298 F: Float + ToNalgebraRealField,
314 I : AlgIteratorFactory<IterInfo<F, N>>, 299 I: AlgIteratorFactory<IterInfo<F, N>>,
315 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 300 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>,
316 A : ForwardModel<RNDM<F, N>, F> 301 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>,
317 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, 302 A::PreadjointCodomain: RealMapping<F, N>,
318 A::PreadjointCodomain : RealMapping<F, N>, 303 PlotLookup: Plotting<N>,
319 PlotLookup : Plotting<N>, 304 RNDM<F, N>: SpikeMerging<F>,
320 RNDM<F, N> : SpikeMerging<F>, 305 Reg: RegTerm<F, N>,
321 Reg : RegTerm<F, N>, 306 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
322 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
323 { 307 {
324
325 // Set up parameters 308 // Set up parameters
326 let config = &fbconfig.generic; 309 let config = &fbconfig.generic;
327 let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); 310 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
328 let mut λ = 1.0; 311 let mut λ = 1.0;
329 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 312 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
330 // by τ compared to the conditional gradient approach. 313 // by τ compared to the conditional gradient approach.
331 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 314 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
332 let mut ε = tolerance.initial(); 315 let mut ε = tolerance.initial();
336 let mut μ_prev = DiscreteMeasure::new(); 319 let mut μ_prev = DiscreteMeasure::new();
337 let mut residual = -b; 320 let mut residual = -b;
338 let mut warned_merging = false; 321 let mut warned_merging = false;
339 322
340 // Statistics 323 // Statistics
341 let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { 324 let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo {
342 value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), 325 value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
343 n_spikes : ν.len(), 326 n_spikes: ν.len(),
344 ε, 327 ε,
345 // postprocessing: config.postprocessing.then(|| ν.clone()), 328 // postprocessing: config.postprocessing.then(|| ν.clone()),
346 .. stats 329 ..stats
347 }; 330 };
348 let mut stats = IterInfo::new(); 331 let mut stats = IterInfo::new();
349 332
350 // Run the algorithm 333 // Run the algorithm
351 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 334 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
352 // Calculate smooth part of surrogate model. 335 // Calculate smooth part of surrogate model.
353 let mut τv = opA.preadjoint().apply(residual * τ); 336 let mut τv = opA.preadjoint().apply(residual * τ);
354 337
355 // Save current base point 338 // Save current base point
356 let μ_base = μ.clone(); 339 let μ_base = μ.clone();
357 340
358 // Insert new spikes and reweigh 341 // Insert new spikes and reweigh
359 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 342 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
360 &mut μ, &mut τv, &μ_base, None, 343 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
361 τ, ε,
362 config, &reg, &state, &mut stats
363 ); 344 );
364 345
365 // (Do not) merge spikes. 346 // (Do not) merge spikes.
366 if config.merge_now(&state) && !warned_merging { 347 if config.merge_now(&state) && !warned_merging {
367 let err = format!("Merging not supported for μFISTA"); 348 let err = format!("Merging not supported for μFISTA");
369 warned_merging = true; 350 warned_merging = true;
370 } 351 }
371 352
372 // Update inertial prameters 353 // Update inertial prameters
373 let λ_prev = λ; 354 let λ_prev = λ;
374 λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); 355 λ = 2.0 * λ_prev / (λ_prev + (4.0 + λ_prev * λ_prev).sqrt());
375 let θ = λ / λ_prev - λ; 356 let θ = λ / λ_prev - λ;
376 357
377 // Perform inertial update on μ. 358 // Perform inertial update on μ.
378 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ 359 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
379 // and μ_prev have zero weight. Since both have weights from the finite-dimensional 360 // and μ_prev have zero weight. Since both have weights from the finite-dimensional

mercurial