src/fb.rs

branch
dev
changeset 61
4f468d35fa29
parent 51
0693cc9ba9f0
child 62
32328a74c790
child 63
7a8a55fd41c0
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
72 \end{aligned} 72 \end{aligned}
73 $$ 73 $$
74 </p> 74 </p>
75 75
76 We solve this with either SSN or FB as determined by 76 We solve this with either SSN or FB as determined by
77 [`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. 77 [`crate::subproblem::InnerSettings`] in [`InsertionConfig::inner`].
78 */ 78 */
79 79
80 use crate::measures::merging::SpikeMerging;
81 use crate::measures::{DiscreteMeasure, RNDM};
82 use crate::plot::Plotter;
83 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound};
84 use crate::regularisation::RegTerm;
85 use crate::types::*;
86 use alg_tools::error::DynResult;
87 use alg_tools::instance::Instance;
88 use alg_tools::iterate::AlgIteratorFactory;
89 use alg_tools::mapping::DifferentiableMapping;
90 use alg_tools::nalgebra_support::ToNalgebraRealField;
80 use colored::Colorize; 91 use colored::Colorize;
81 use numeric_literals::replace_float_literals; 92 use numeric_literals::replace_float_literals;
82 use serde::{Deserialize, Serialize}; 93 use serde::{Deserialize, Serialize};
83
84 use alg_tools::euclidean::Euclidean;
85 use alg_tools::instance::Instance;
86 use alg_tools::iterate::AlgIteratorFactory;
87 use alg_tools::linops::{Mapping, GEMV};
88 use alg_tools::mapping::RealMapping;
89 use alg_tools::nalgebra_support::ToNalgebraRealField;
90
91 use crate::dataterm::{calculate_residual, DataTerm, L2Squared};
92 use crate::forward_model::{AdjointProductBoundedBy, ForwardModel};
93 use crate::measures::merging::SpikeMerging;
94 use crate::measures::{DiscreteMeasure, RNDM};
95 use crate::plot::{PlotLookup, Plotting, SeqPlotter};
96 pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty};
97 use crate::regularisation::RegTerm;
98 use crate::types::*;
99 94
100 /// Settings for [`pointsource_fb_reg`]. 95 /// Settings for [`pointsource_fb_reg`].
101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 96 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
102 #[serde(default)] 97 #[serde(default)]
103 pub struct FBConfig<F: Float> { 98 pub struct FBConfig<F: Float> {
104 /// Step length scaling 99 /// Step length scaling
105 pub τ0: F, 100 pub τ0: F,
101 // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`]
102 pub σp0: F,
106 /// Generic parameters 103 /// Generic parameters
107 pub generic: FBGenericConfig<F>, 104 pub insertion: InsertionConfig<F>,
108 } 105 }
109 106
110 #[replace_float_literals(F::cast_from(literal))] 107 #[replace_float_literals(F::cast_from(literal))]
111 impl<F: Float> Default for FBConfig<F> { 108 impl<F: Float> Default for FBConfig<F> {
112 fn default() -> Self { 109 fn default() -> Self {
113 FBConfig { 110 FBConfig {
114 τ0: 0.99, 111 τ0: 0.99,
115 generic: Default::default(), 112 σp0: 0.99,
113 insertion: Default::default(),
116 } 114 }
117 } 115 }
118 } 116 }
119 117
120 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize { 118 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize {
121 let n_before_prune = μ.len(); 119 let n_before_prune = μ.len();
122 μ.prune(); 120 μ.prune();
123 debug_assert!(μ.len() <= n_before_prune); 121 debug_assert!(μ.len() <= n_before_prune);
124 n_before_prune - μ.len() 122 n_before_prune - μ.len()
125 } 123 }
126 124
127 #[replace_float_literals(F::cast_from(literal))] 125 #[replace_float_literals(F::cast_from(literal))]
128 pub(crate) fn postprocess< 126 pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>(
129 F: Float, 127 mut μ: RNDM<N, F>,
130 V: Euclidean<F> + Clone, 128 config: &InsertionConfig<F>,
131 A: GEMV<F, RNDM<F, N>, Codomain = V>, 129 f: Dat,
132 D: DataTerm<F, V, N>, 130 ) -> DynResult<RNDM<N, F>>
133 const N: usize,
134 >(
135 mut μ: RNDM<F, N>,
136 config: &FBGenericConfig<F>,
137 dataterm: D,
138 opA: &A,
139 b: &V,
140 ) -> RNDM<F, N>
141 where 131 where
142 RNDM<F, N>: SpikeMerging<F>, 132 RNDM<N, F>: SpikeMerging<F>,
143 for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>, 133 for<'a> &'a RNDM<N, F>: Instance<RNDM<N, F>>,
144 { 134 {
145 μ.merge_spikes_fitness( 135 //μ.merge_spikes_fitness(config.final_merging_method(), |μ̃| f.apply(μ̃), |&v| v);
146 config.final_merging_method(), 136 μ.merge_spikes_fitness(config.final_merging_method(), f, |&v| v);
147 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
148 |&v| v,
149 );
150 μ.prune(); 137 μ.prune();
151 μ 138 Ok(μ)
152 } 139 }
153 140
154 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. 141 /// Iteratively solve the pointsource localisation problem using forward-backward splitting.
155 /// 142 ///
156 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the 143 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
159 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control 146 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
160 /// as documented in [`alg_tools::iterate`]. 147 /// as documented in [`alg_tools::iterate`].
161 /// 148 ///
162 /// For details on the mathematical formulation, see the [module level](self) documentation. 149 /// For details on the mathematical formulation, see the [module level](self) documentation.
163 /// 150 ///
164 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
165 /// sums of simple functions using bisection trees, and the related
166 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
167 /// active at specific points, and to maximise their sums. Through the implementation of the
168 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
169 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
170 ///
171 /// Returns the final iterate. 151 /// Returns the final iterate.
172 #[replace_float_literals(F::cast_from(literal))] 152 #[replace_float_literals(F::cast_from(literal))]
173 pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>( 153 pub fn pointsource_fb_reg<F, I, Dat, Reg, P, Plot, const N: usize>(
174 opA: &A, 154 f: &Dat,
175 b: &A::Observable, 155 reg: &Reg,
176 reg: Reg,
177 prox_penalty: &P, 156 prox_penalty: &P,
178 fbconfig: &FBConfig<F>, 157 fbconfig: &FBConfig<F>,
179 iterator: I, 158 iterator: I,
180 mut plotter: SeqPlotter<F, N>, 159 mut plotter: Plot,
181 ) -> RNDM<F, N> 160 μ0 : Option<RNDM<N, F>>,
161 ) -> DynResult<RNDM<N, F>>
182 where 162 where
183 F: Float + ToNalgebraRealField, 163 F: Float + ToNalgebraRealField,
184 I: AlgIteratorFactory<IterInfo<F, N>>, 164 I: AlgIteratorFactory<IterInfo<F>>,
185 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, 165 RNDM<N, F>: SpikeMerging<F>,
186 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, 166 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
187 A::PreadjointCodomain: RealMapping<F, N>, 167 Dat::DerivativeDomain: ClosedMul<F>,
188 PlotLookup: Plotting<N>, 168 Reg: RegTerm<Loc<N, F>, F>,
189 RNDM<F, N>: SpikeMerging<F>, 169 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
190 Reg: RegTerm<F, N>, 170 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
191 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
192 { 171 {
193 // Set up parameters 172 // Set up parameters
194 let config = &fbconfig.generic; 173 let config = &fbconfig.insertion;
195 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); 174 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?;
196 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 175 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
197 // by τ compared to the conditional gradient approach. 176 // by τ compared to the conditional gradient approach.
198 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 177 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
199 let mut ε = tolerance.initial(); 178 let mut ε = tolerance.initial();
200 179
201 // Initialise iterates 180 // Initialise iterates
202 let mut μ = DiscreteMeasure::new(); 181 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
203 let mut residual = -b;
204 182
205 // Statistics 183 // Statistics
206 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { 184 let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
207 value: residual.norm2_squared_div2() + reg.apply(μ), 185 value: f.apply(μ) + reg.apply(μ),
208 n_spikes: μ.len(), 186 n_spikes: μ.len(),
209 ε, 187 ε,
210 //postprocessing: config.postprocessing.then(|| μ.clone()), 188 //postprocessing: config.postprocessing.then(|| μ.clone()),
211 ..stats 189 ..stats
212 }; 190 };
213 let mut stats = IterInfo::new(); 191 let mut stats = IterInfo::new();
214 192
215 // Run the algorithm 193 // Run the algorithm
216 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { 194 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
217 // Calculate smooth part of surrogate model. 195 // Calculate smooth part of surrogate model.
218 let mut τv = opA.preadjoint().apply(residual * τ); 196 // TODO: optimise τ to be applied to residual.
197 let mut τv = f.differential(&μ) * τ;
219 198
220 // Save current base point 199 // Save current base point
221 let μ_base = μ.clone(); 200 let μ_base = μ.clone();
222 201
223 // Insert and reweigh 202 // Insert and reweigh
224 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 203 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
225 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats, 204 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
226 ); 205 )?;
227 206
228 // Prune and possibly merge spikes 207 // Prune and possibly merge spikes
229 if config.merge_now(&state) { 208 if config.merge_now(&state) {
230 stats.merged += prox_penalty.merge_spikes( 209 stats.merged += prox_penalty.merge_spikes(
231 &mut μ, 210 &mut μ,
234 None, 213 None,
235 τ, 214 τ,
236 ε, 215 ε,
237 config, 216 config,
238 &reg, 217 &reg,
239 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), 218 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
240 ); 219 );
241 } 220 }
242 221
243 stats.pruned += prune_with_stats(&mut μ); 222 stats.pruned += prune_with_stats(&mut μ);
244
245 // Update residual
246 residual = calculate_residual(&μ, opA, b);
247 223
248 let iter = state.iteration(); 224 let iter = state.iteration();
249 stats.this_iters += 1; 225 stats.this_iters += 1;
250 226
251 // Give statistics if needed 227 // Give statistics if needed
252 state.if_verbose(|| { 228 state.if_verbose(|| {
253 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); 229 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
254 full_stats( 230 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
255 &residual,
256 &μ,
257 ε,
258 std::mem::replace(&mut stats, IterInfo::new()),
259 )
260 }); 231 });
261 232
262 // Update main tolerance for next iteration 233 // Update main tolerance for next iteration
263 ε = tolerance.update(ε, iter); 234 ε = tolerance.update(ε, iter);
264 } 235 }
265 236
266 postprocess(μ, config, L2Squared, opA, b) 237 //postprocess(μ_prev, config, f)
238 postprocess(μ, config, |μ̃| f.apply(μ̃))
267 } 239 }
268 240
269 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. 241 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting.
270 /// 242 ///
271 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the 243 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
274 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control 246 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
275 /// as documented in [`alg_tools::iterate`]. 247 /// as documented in [`alg_tools::iterate`].
276 /// 248 ///
277 /// For details on the mathematical formulation, see the [module level](self) documentation. 249 /// For details on the mathematical formulation, see the [module level](self) documentation.
278 /// 250 ///
279 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
280 /// sums of simple functions using bisection trees, and the related
281 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
282 /// active at specific points, and to maximise their sums. Through the implementation of the
283 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
284 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
285 ///
286 /// Returns the final iterate. 251 /// Returns the final iterate.
287 #[replace_float_literals(F::cast_from(literal))] 252 #[replace_float_literals(F::cast_from(literal))]
288 pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>( 253 pub fn pointsource_fista_reg<F, I, Dat, Reg, P, Plot, const N: usize>(
289 opA: &A, 254 f: &Dat,
290 b: &A::Observable, 255 reg: &Reg,
291 reg: Reg,
292 prox_penalty: &P, 256 prox_penalty: &P,
293 fbconfig: &FBConfig<F>, 257 fbconfig: &FBConfig<F>,
294 iterator: I, 258 iterator: I,
295 mut plotter: SeqPlotter<F, N>, 259 mut plotter: Plot,
296 ) -> RNDM<F, N> 260 μ0: Option<RNDM<N, F>>
261 ) -> DynResult<RNDM<N, F>>
297 where 262 where
298 F: Float + ToNalgebraRealField, 263 F: Float + ToNalgebraRealField,
299 I: AlgIteratorFactory<IterInfo<F, N>>, 264 I: AlgIteratorFactory<IterInfo<F>>,
300 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, 265 RNDM<N, F>: SpikeMerging<F>,
301 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, 266 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
302 A::PreadjointCodomain: RealMapping<F, N>, 267 Dat::DerivativeDomain: ClosedMul<F>,
303 PlotLookup: Plotting<N>, 268 Reg: RegTerm<Loc<N, F>, F>,
304 RNDM<F, N>: SpikeMerging<F>, 269 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
305 Reg: RegTerm<F, N>, 270 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
306 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
307 { 271 {
308 // Set up parameters 272 // Set up parameters
309 let config = &fbconfig.generic; 273 let config = &fbconfig.insertion;
310 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); 274 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?;
311 let mut λ = 1.0; 275 let mut λ = 1.0;
312 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 276 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
313 // by τ compared to the conditional gradient approach. 277 // by τ compared to the conditional gradient approach.
314 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 278 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
315 let mut ε = tolerance.initial(); 279 let mut ε = tolerance.initial();
316 280
317 // Initialise iterates 281 // Initialise iterates
318 let mut μ = DiscreteMeasure::new(); 282 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
319 let mut μ_prev = DiscreteMeasure::new(); 283 let mut μ_prev = μ.clone();
320 let mut residual = -b;
321 let mut warned_merging = false; 284 let mut warned_merging = false;
322 285
323 // Statistics 286 // Statistics
324 let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo { 287 let full_stats = |ν: &RNDM<N, F>, ε, stats| IterInfo {
325 value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), 288 value: f.apply(ν) + reg.apply(ν),
326 n_spikes: ν.len(), 289 n_spikes: ν.len(),
327 ε, 290 ε,
328 // postprocessing: config.postprocessing.then(|| ν.clone()), 291 // postprocessing: config.postprocessing.then(|| ν.clone()),
329 ..stats 292 ..stats
330 }; 293 };
331 let mut stats = IterInfo::new(); 294 let mut stats = IterInfo::new();
332 295
333 // Run the algorithm 296 // Run the algorithm
334 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 297 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
335 // Calculate smooth part of surrogate model. 298 // Calculate smooth part of surrogate model.
336 let mut τv = opA.preadjoint().apply(residual * τ); 299 let mut τv = f.differential(&μ) * τ;
337 300
338 // Save current base point 301 // Save current base point
339 let μ_base = μ.clone(); 302 let μ_base = μ.clone();
340 303
341 // Insert new spikes and reweigh 304 // Insert new spikes and reweigh
342 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 305 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
343 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats, 306 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
344 ); 307 )?;
345 308
346 // (Do not) merge spikes. 309 // (Do not) merge spikes.
347 if config.merge_now(&state) && !warned_merging { 310 if config.merge_now(&state) && !warned_merging {
348 let err = format!("Merging not supported for μFISTA"); 311 let err = format!("Merging not supported for μFISTA");
349 println!("{}", err.red()); 312 println!("{}", err.red());
367 // μ_prev = μ; 330 // μ_prev = μ;
368 // μ = μ_new; 331 // μ = μ_new;
369 debug_assert!(μ.len() <= n_before_prune); 332 debug_assert!(μ.len() <= n_before_prune);
370 stats.pruned += n_before_prune - μ.len(); 333 stats.pruned += n_before_prune - μ.len();
371 334
372 // Update residual
373 residual = calculate_residual(&μ, opA, b);
374
375 let iter = state.iteration(); 335 let iter = state.iteration();
376 stats.this_iters += 1; 336 stats.this_iters += 1;
377 337
378 // Give statistics if needed 338 // Give statistics if needed
379 state.if_verbose(|| { 339 state.if_verbose(|| {
383 343
384 // Update main tolerance for next iteration 344 // Update main tolerance for next iteration
385 ε = tolerance.update(ε, iter); 345 ε = tolerance.update(ε, iter);
386 } 346 }
387 347
388 postprocess(μ_prev, config, L2Squared, opA, b) 348 //postprocess(μ_prev, config, f)
389 } 349 postprocess(μ_prev, config, |μ̃| f.apply(μ̃))
350 }

mercurial