src/fb.rs

changeset 70
ed16d0f10d08
parent 63
7a8a55fd41c0
equal deleted inserted replaced
58:6099ba025aac 70:ed16d0f10d08
72 \end{aligned} 72 \end{aligned}
73 $$ 73 $$
74 </p> 74 </p>
75 75
76 We solve this with either SSN or FB as determined by 76 We solve this with either SSN or FB as determined by
77 [`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. 77 [`crate::subproblem::InnerSettings`] in [`InsertionConfig::inner`].
78 */ 78 */
79 79
80 use crate::measures::merging::SpikeMerging;
81 use crate::measures::{DiscreteMeasure, RNDM};
82 use crate::plot::Plotter;
83 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound};
84 use crate::regularisation::RegTerm;
85 use crate::types::*;
86 use alg_tools::error::DynResult;
87 use alg_tools::instance::Instance;
88 use alg_tools::iterate::AlgIteratorFactory;
89 use alg_tools::mapping::DifferentiableMapping;
90 use alg_tools::nalgebra_support::ToNalgebraRealField;
80 use colored::Colorize; 91 use colored::Colorize;
81 use numeric_literals::replace_float_literals; 92 use numeric_literals::replace_float_literals;
82 use serde::{Deserialize, Serialize}; 93 use serde::{Deserialize, Serialize};
83
84 use alg_tools::euclidean::Euclidean;
85 use alg_tools::instance::Instance;
86 use alg_tools::iterate::AlgIteratorFactory;
87 use alg_tools::linops::{Mapping, GEMV};
88 use alg_tools::mapping::RealMapping;
89 use alg_tools::nalgebra_support::ToNalgebraRealField;
90
91 use crate::dataterm::{calculate_residual, DataTerm, L2Squared};
92 use crate::forward_model::{AdjointProductBoundedBy, ForwardModel};
93 use crate::measures::merging::SpikeMerging;
94 use crate::measures::{DiscreteMeasure, RNDM};
95 use crate::plot::{PlotLookup, Plotting, SeqPlotter};
96 pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty};
97 use crate::regularisation::RegTerm;
98 use crate::types::*;
99 94
100 /// Settings for [`pointsource_fb_reg`]. 95 /// Settings for [`pointsource_fb_reg`].
101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 96 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
102 #[serde(default)] 97 #[serde(default)]
103 pub struct FBConfig<F: Float> { 98 pub struct FBConfig<F: Float> {
104 /// Step length scaling 99 /// Step length scaling
105 pub τ0: F, 100 pub τ0: F,
101 // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`]
102 pub σp0: F,
106 /// Generic parameters 103 /// Generic parameters
107 pub generic: FBGenericConfig<F>, 104 pub insertion: InsertionConfig<F>,
108 } 105 }
109 106
110 #[replace_float_literals(F::cast_from(literal))] 107 #[replace_float_literals(F::cast_from(literal))]
111 impl<F: Float> Default for FBConfig<F> { 108 impl<F: Float> Default for FBConfig<F> {
112 fn default() -> Self { 109 fn default() -> Self {
113 FBConfig { 110 FBConfig { τ0: 0.99, σp0: 0.99, insertion: Default::default() }
114 τ0: 0.99,
115 generic: Default::default(),
116 }
117 } 111 }
118 } 112 }
119 113
120 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize { 114 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize {
121 let n_before_prune = μ.len(); 115 let n_before_prune = μ.len();
122 μ.prune(); 116 μ.prune();
123 debug_assert!(μ.len() <= n_before_prune); 117 debug_assert!(μ.len() <= n_before_prune);
124 n_before_prune - μ.len() 118 n_before_prune - μ.len()
125 } 119 }
126 120
127 #[replace_float_literals(F::cast_from(literal))] 121 #[replace_float_literals(F::cast_from(literal))]
128 pub(crate) fn postprocess< 122 pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>(
129 F: Float, 123 mut μ: RNDM<N, F>,
130 V: Euclidean<F> + Clone, 124 config: &InsertionConfig<F>,
131 A: GEMV<F, RNDM<F, N>, Codomain = V>, 125 f: Dat,
132 D: DataTerm<F, V, N>, 126 ) -> DynResult<RNDM<N, F>>
133 const N: usize,
134 >(
135 mut μ: RNDM<F, N>,
136 config: &FBGenericConfig<F>,
137 dataterm: D,
138 opA: &A,
139 b: &V,
140 ) -> RNDM<F, N>
141 where 127 where
142 RNDM<F, N>: SpikeMerging<F>, 128 RNDM<N, F>: SpikeMerging<F>,
143 for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>, 129 for<'a> &'a RNDM<N, F>: Instance<RNDM<N, F>>,
144 { 130 {
145 μ.merge_spikes_fitness( 131 //μ.merge_spikes_fitness(config.final_merging_method(), |μ̃| f.apply(μ̃), |&v| v);
146 config.final_merging_method(), 132 μ.merge_spikes_fitness(config.final_merging_method(), f, |&v| v);
147 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
148 |&v| v,
149 );
150 μ.prune(); 133 μ.prune();
151 μ 134 Ok(μ)
152 } 135 }
153 136
154 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. 137 /// Iteratively solve the pointsource localisation problem using forward-backward splitting.
155 /// 138 ///
156 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the 139 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
159 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control 142 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
160 /// as documented in [`alg_tools::iterate`]. 143 /// as documented in [`alg_tools::iterate`].
161 /// 144 ///
162 /// For details on the mathematical formulation, see the [module level](self) documentation. 145 /// For details on the mathematical formulation, see the [module level](self) documentation.
163 /// 146 ///
164 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
165 /// sums of simple functions usign bisection trees, and the related
166 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
167 /// active at a specific points, and to maximise their sums. Through the implementation of the
168 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
169 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
170 ///
171 /// Returns the final iterate. 147 /// Returns the final iterate.
172 #[replace_float_literals(F::cast_from(literal))] 148 #[replace_float_literals(F::cast_from(literal))]
173 pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>( 149 pub fn pointsource_fb_reg<F, I, Dat, Reg, P, Plot, const N: usize>(
174 opA: &A, 150 f: &Dat,
175 b: &A::Observable, 151 reg: &Reg,
176 reg: Reg,
177 prox_penalty: &P, 152 prox_penalty: &P,
178 fbconfig: &FBConfig<F>, 153 fbconfig: &FBConfig<F>,
179 iterator: I, 154 iterator: I,
180 mut plotter: SeqPlotter<F, N>, 155 mut plotter: Plot,
181 ) -> RNDM<F, N> 156 μ0: Option<RNDM<N, F>>,
157 ) -> DynResult<RNDM<N, F>>
182 where 158 where
183 F: Float + ToNalgebraRealField, 159 F: Float + ToNalgebraRealField,
184 I: AlgIteratorFactory<IterInfo<F, N>>, 160 I: AlgIteratorFactory<IterInfo<F>>,
185 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, 161 RNDM<N, F>: SpikeMerging<F>,
186 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, 162 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
187 A::PreadjointCodomain: RealMapping<F, N>, 163 Dat::DerivativeDomain: ClosedMul<F>,
188 PlotLookup: Plotting<N>, 164 Reg: RegTerm<Loc<N, F>, F>,
189 RNDM<F, N>: SpikeMerging<F>, 165 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
190 Reg: RegTerm<F, N>, 166 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
191 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
192 { 167 {
193 // Set up parameters 168 // Set up parameters
194 let config = &fbconfig.generic; 169 let config = &fbconfig.insertion;
195 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); 170 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?;
196 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 171 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
197 // by τ compared to the conditional gradient approach. 172 // by τ compared to the conditional gradient approach.
198 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 173 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
199 let mut ε = tolerance.initial(); 174 let mut ε = tolerance.initial();
200 175
201 // Initialise iterates 176 // Initialise iterates
202 let mut μ = DiscreteMeasure::new(); 177 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
203 let mut residual = -b;
204 178
205 // Statistics 179 // Statistics
206 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { 180 let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
207 value: residual.norm2_squared_div2() + reg.apply(μ), 181 value: f.apply(μ) + reg.apply(μ),
208 n_spikes: μ.len(), 182 n_spikes: μ.len(),
209 ε, 183 ε,
210 //postprocessing: config.postprocessing.then(|| μ.clone()), 184 //postprocessing: config.postprocessing.then(|| μ.clone()),
211 ..stats 185 ..stats
212 }; 186 };
213 let mut stats = IterInfo::new(); 187 let mut stats = IterInfo::new();
214 188
215 // Run the algorithm 189 // Run the algorithm
216 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { 190 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
217 // Calculate smooth part of surrogate model. 191 // Calculate smooth part of surrogate model.
218 let mut τv = opA.preadjoint().apply(residual * τ); 192 // TODO: optimise τ to be applied to residual.
219 193 let mut τv = f.differential(&μ) * τ;
220 // Save current base point 194
221 let μ_base = μ.clone(); 195 // Save current base point for merge
196 let μ_base_len = μ.len();
197 let maybe_μ_base = config.merge_now(&state).then(|| μ.clone());
222 198
223 // Insert and reweigh 199 // Insert and reweigh
224 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 200 let (maybe_d, _within_tolerances) = prox_penalty
225 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats, 201 .insert_and_reweigh(&mut μ, &mut τv, τ, ε, config, &reg, &state, &mut stats)?;
226 ); 202
203 stats.inserted += μ.len() - μ_base_len;
227 204
228 // Prune and possibly merge spikes 205 // Prune and possibly merge spikes
229 if config.merge_now(&state) { 206 if let Some(μ_base) = maybe_μ_base {
230 stats.merged += prox_penalty.merge_spikes( 207 stats.merged += prox_penalty.merge_spikes(
231 &mut μ, 208 &mut μ,
232 &mut τv, 209 &mut τv,
233 &μ_base, 210 &μ_base,
234 None,
235 τ, 211 τ,
236 ε, 212 ε,
237 config, 213 config,
238 &reg, 214 &reg,
239 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), 215 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
240 ); 216 );
241 } 217 }
242 218
243 stats.pruned += prune_with_stats(&mut μ); 219 stats.pruned += prune_with_stats(&mut μ);
244
245 // Update residual
246 residual = calculate_residual(&μ, opA, b);
247 220
248 let iter = state.iteration(); 221 let iter = state.iteration();
249 stats.this_iters += 1; 222 stats.this_iters += 1;
250 223
251 // Give statistics if needed 224 // Give statistics if needed
252 state.if_verbose(|| { 225 state.if_verbose(|| {
253 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); 226 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
254 full_stats( 227 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
255 &residual,
256 &μ,
257 ε,
258 std::mem::replace(&mut stats, IterInfo::new()),
259 )
260 }); 228 });
261 229
262 // Update main tolerance for next iteration 230 // Update main tolerance for next iteration
263 ε = tolerance.update(ε, iter); 231 ε = tolerance.update(ε, iter);
264 } 232 }
265 233
266 postprocess(μ, config, L2Squared, opA, b) 234 //postprocess(μ_prev, config, f)
235 postprocess(μ, config, |μ̃| f.apply(μ̃))
267 } 236 }
268 237
269 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. 238 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting.
270 /// 239 ///
271 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the 240 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
274 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control 243 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
275 /// as documented in [`alg_tools::iterate`]. 244 /// as documented in [`alg_tools::iterate`].
276 /// 245 ///
277 /// For details on the mathematical formulation, see the [module level](self) documentation. 246 /// For details on the mathematical formulation, see the [module level](self) documentation.
278 /// 247 ///
279 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
280 /// sums of simple functions usign bisection trees, and the related
281 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
282 /// active at a specific points, and to maximise their sums. Through the implementation of the
283 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
284 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
285 ///
286 /// Returns the final iterate. 248 /// Returns the final iterate.
287 #[replace_float_literals(F::cast_from(literal))] 249 #[replace_float_literals(F::cast_from(literal))]
288 pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>( 250 pub fn pointsource_fista_reg<F, I, Dat, Reg, P, Plot, const N: usize>(
289 opA: &A, 251 f: &Dat,
290 b: &A::Observable, 252 reg: &Reg,
291 reg: Reg,
292 prox_penalty: &P, 253 prox_penalty: &P,
293 fbconfig: &FBConfig<F>, 254 fbconfig: &FBConfig<F>,
294 iterator: I, 255 iterator: I,
295 mut plotter: SeqPlotter<F, N>, 256 mut plotter: Plot,
296 ) -> RNDM<F, N> 257 μ0: Option<RNDM<N, F>>,
258 ) -> DynResult<RNDM<N, F>>
297 where 259 where
298 F: Float + ToNalgebraRealField, 260 F: Float + ToNalgebraRealField,
299 I: AlgIteratorFactory<IterInfo<F, N>>, 261 I: AlgIteratorFactory<IterInfo<F>>,
300 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, 262 RNDM<N, F>: SpikeMerging<F>,
301 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, 263 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
302 A::PreadjointCodomain: RealMapping<F, N>, 264 Dat::DerivativeDomain: ClosedMul<F>,
303 PlotLookup: Plotting<N>, 265 Reg: RegTerm<Loc<N, F>, F>,
304 RNDM<F, N>: SpikeMerging<F>, 266 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
305 Reg: RegTerm<F, N>, 267 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
306 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
307 { 268 {
308 // Set up parameters 269 // Set up parameters
309 let config = &fbconfig.generic; 270 let config = &fbconfig.insertion;
310 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); 271 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?;
311 let mut λ = 1.0; 272 let mut λ = 1.0;
312 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 273 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
313 // by τ compared to the conditional gradient approach. 274 // by τ compared to the conditional gradient approach.
314 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 275 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
315 let mut ε = tolerance.initial(); 276 let mut ε = tolerance.initial();
316 277
317 // Initialise iterates 278 // Initialise iterates
318 let mut μ = DiscreteMeasure::new(); 279 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
319 let mut μ_prev = DiscreteMeasure::new(); 280 let mut μ_prev = μ.clone();
320 let mut residual = -b;
321 let mut warned_merging = false; 281 let mut warned_merging = false;
322 282
323 // Statistics 283 // Statistics
324 let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo { 284 let full_stats = |ν: &RNDM<N, F>, ε, stats| IterInfo {
325 value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), 285 value: f.apply(ν) + reg.apply(ν),
326 n_spikes: ν.len(), 286 n_spikes: ν.len(),
327 ε, 287 ε,
328 // postprocessing: config.postprocessing.then(|| ν.clone()), 288 // postprocessing: config.postprocessing.then(|| ν.clone()),
329 ..stats 289 ..stats
330 }; 290 };
331 let mut stats = IterInfo::new(); 291 let mut stats = IterInfo::new();
332 292
333 // Run the algorithm 293 // Run the algorithm
334 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 294 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
335 // Calculate smooth part of surrogate model. 295 // Calculate smooth part of surrogate model.
336 let mut τv = opA.preadjoint().apply(residual * τ); 296 let mut τv = f.differential(&μ) * τ;
337 297
338 // Save current base point 298 let μ_base_len = μ.len();
339 let μ_base = μ.clone();
340 299
341 // Insert new spikes and reweigh 300 // Insert new spikes and reweigh
342 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 301 let (maybe_d, _within_tolerances) = prox_penalty
343 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats, 302 .insert_and_reweigh(&mut μ, &mut τv, τ, ε, config, &reg, &state, &mut stats)?;
344 ); 303
304 stats.inserted += μ.len() - μ_base_len;
345 305
346 // (Do not) merge spikes. 306 // (Do not) merge spikes.
347 if config.merge_now(&state) && !warned_merging { 307 if config.merge_now(&state) && !warned_merging {
348 let err = format!("Merging not supported for μFISTA"); 308 let err = format!("Merging not supported for μFISTA");
349 println!("{}", err.red()); 309 println!("{}", err.red());
367 // μ_prev = μ; 327 // μ_prev = μ;
368 // μ = μ_new; 328 // μ = μ_new;
369 debug_assert!(μ.len() <= n_before_prune); 329 debug_assert!(μ.len() <= n_before_prune);
370 stats.pruned += n_before_prune - μ.len(); 330 stats.pruned += n_before_prune - μ.len();
371 331
372 // Update residual
373 residual = calculate_residual(&μ, opA, b);
374
375 let iter = state.iteration(); 332 let iter = state.iteration();
376 stats.this_iters += 1; 333 stats.this_iters += 1;
377 334
378 // Give statistics if needed 335 // Give statistics if needed
379 state.if_verbose(|| { 336 state.if_verbose(|| {
383 340
384 // Update main tolerance for next iteration 341 // Update main tolerance for next iteration
385 ε = tolerance.update(ε, iter); 342 ε = tolerance.update(ε, iter);
386 } 343 }
387 344
388 postprocess(μ_prev, config, L2Squared, opA, b) 345 //postprocess(μ_prev, config, f)
389 } 346 postprocess(μ_prev, config, |μ̃| f.apply(μ̃))
347 }

mercurial