src/sliding_fb.rs

branch
dev
changeset 61
4f468d35fa29
parent 49
6b0db7251ebe
child 62
32328a74c790
child 63
7a8a55fd41c0
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
8 //use colored::Colorize; 8 //use colored::Colorize;
9 //use nalgebra::{DVector, DMatrix}; 9 //use nalgebra::{DVector, DMatrix};
10 use itertools::izip; 10 use itertools::izip;
11 use std::iter::Iterator; 11 use std::iter::Iterator;
12 12
13 use crate::fb::*;
14 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
15 use crate::measures::merging::SpikeMerging;
16 use crate::measures::{DiscreteMeasure, Radon, RNDM};
17 use crate::plot::Plotter;
18 use crate::prox_penalty::{ProxPenalty, StepLengthBound};
19 use crate::regularisation::SlidingRegTerm;
20 use crate::types::*;
21 use alg_tools::error::DynResult;
13 use alg_tools::euclidean::Euclidean; 22 use alg_tools::euclidean::Euclidean;
14 use alg_tools::iterate::AlgIteratorFactory; 23 use alg_tools::iterate::AlgIteratorFactory;
15 use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; 24 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping};
16 use alg_tools::nalgebra_support::ToNalgebraRealField; 25 use alg_tools::nalgebra_support::ToNalgebraRealField;
17 use alg_tools::norms::Norm; 26 use alg_tools::norms::Norm;
18 27 use anyhow::ensure;
19 use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel};
20 use crate::measures::merging::SpikeMerging;
21 use crate::measures::{DiscreteMeasure, Radon, RNDM};
22 use crate::types::*;
23 //use crate::tolerance::Tolerance;
24 use crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared};
25 use crate::fb::*;
26 use crate::plot::{PlotLookup, Plotting, SeqPlotter};
27 use crate::regularisation::SlidingRegTerm;
28 //use crate::transport::TransportLipschitz;
29 28
30 /// Transport settings for [`pointsource_sliding_fb_reg`]. 29 /// Transport settings for [`pointsource_sliding_fb_reg`].
31 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 30 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
32 #[serde(default)] 31 #[serde(default)]
33 pub struct TransportConfig<F: Float> { 32 pub struct TransportConfig<F: Float> {
40 } 39 }
41 40
42 #[replace_float_literals(F::cast_from(literal))] 41 #[replace_float_literals(F::cast_from(literal))]
43 impl<F: Float> TransportConfig<F> { 42 impl<F: Float> TransportConfig<F> {
44 /// Check that the parameters are ok. Panics if not. 43 /// Check that the parameters are ok. Panics if not.
45 pub fn check(&self) { 44 pub fn check(&self) -> DynResult<()> {
46 assert!(self.θ0 > 0.0); 45 ensure!(self.θ0 > 0.0);
47 assert!(0.0 < self.adaptation && self.adaptation < 1.0); 46 ensure!(0.0 < self.adaptation && self.adaptation < 1.0);
48 assert!(self.tolerance_mult_con > 0.0); 47 ensure!(self.tolerance_mult_con > 0.0);
48 Ok(())
49 } 49 }
50 } 50 }
51 51
52 #[replace_float_literals(F::cast_from(literal))] 52 #[replace_float_literals(F::cast_from(literal))]
53 impl<F: Float> Default for TransportConfig<F> { 53 impl<F: Float> Default for TransportConfig<F> {
54 fn default() -> Self { 54 fn default() -> Self {
55 TransportConfig { 55 TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0 }
56 θ0: 0.9,
57 adaptation: 0.9,
58 tolerance_mult_con: 100.0,
59 }
60 } 56 }
61 } 57 }
62 58
63 /// Settings for [`pointsource_sliding_fb_reg`]. 59 /// Settings for [`pointsource_sliding_fb_reg`].
64 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 60 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
65 #[serde(default)] 61 #[serde(default)]
66 pub struct SlidingFBConfig<F: Float> { 62 pub struct SlidingFBConfig<F: Float> {
67 /// Step length scaling 63 /// Step length scaling
68 pub τ0: F, 64 pub τ0: F,
65 // Auxiliary variable step length scaling for [`crate::sliding_pdps::pointsource_sliding_fb_pair`]
66 pub σp0: F,
69 /// Transport parameters 67 /// Transport parameters
70 pub transport: TransportConfig<F>, 68 pub transport: TransportConfig<F>,
71 /// Generic parameters 69 /// Generic parameters
72 pub insertion: FBGenericConfig<F>, 70 pub insertion: InsertionConfig<F>,
71 /// Guess for curvature bound calculations.
72 pub guess: BoundedCurvatureGuess,
73 } 73 }
74 74
75 #[replace_float_literals(F::cast_from(literal))] 75 #[replace_float_literals(F::cast_from(literal))]
76 impl<F: Float> Default for SlidingFBConfig<F> { 76 impl<F: Float> Default for SlidingFBConfig<F> {
77 fn default() -> Self { 77 fn default() -> Self {
78 SlidingFBConfig { 78 SlidingFBConfig {
79 τ0: 0.99, 79 τ0: 0.99,
80 σp0: 0.99,
80 transport: Default::default(), 81 transport: Default::default(),
81 insertion: Default::default(), 82 insertion: Default::default(),
83 guess: BoundedCurvatureGuess::BetterThanZero,
82 } 84 }
83 } 85 }
84 } 86 }
85 87
86 /// Internal type of adaptive transport step length calculation 88 /// Internal type of adaptive transport step length calculation
98 100
99 /// Construction of initial transport `γ1` from initial measure `μ` and `v=F'(μ)` 101 /// Construction of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
100 /// with step length τ and transport step length `θ_or_adaptive`. 102 /// with step length τ and transport step length `θ_or_adaptive`.
101 #[replace_float_literals(F::cast_from(literal))] 103 #[replace_float_literals(F::cast_from(literal))]
102 pub(crate) fn initial_transport<F, G, D, const N: usize>( 104 pub(crate) fn initial_transport<F, G, D, const N: usize>(
103 γ1: &mut RNDM<F, N>, 105 γ1: &mut RNDM<N, F>,
104 μ: &mut RNDM<F, N>, 106 μ: &mut RNDM<N, F>,
105 τ: F, 107 τ: F,
106 θ_or_adaptive: &mut TransportStepLength<F, G>, 108 θ_or_adaptive: &mut TransportStepLength<F, G>,
107 v: D, 109 v: D,
108 ) -> (Vec<F>, RNDM<F, N>) 110 ) -> (Vec<F>, RNDM<N, F>)
109 where 111 where
110 F: Float + ToNalgebraRealField, 112 F: Float + ToNalgebraRealField,
111 G: Fn(F, F) -> F, 113 G: Fn(F, F) -> F,
112 D: DifferentiableRealMapping<F, N>, 114 D: DifferentiableRealMapping<N, F>,
113 { 115 {
114 use TransportStepLength::*; 116 use TransportStepLength::*;
115 117
116 // Save current base point and shift μ to new positions. Idea is that 118 // Save current base point and shift μ to new positions. Idea is that
117 // μ_base(_masses) = μ^k (vector of masses) 119 // μ_base(_masses) = μ^k (vector of masses)
143 let θτ = τ * θ; 145 let θτ = τ * θ;
144 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 146 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
145 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 147 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
146 } 148 }
147 } 149 }
148 AdaptiveMax { 150 AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θ } => {
149 l: ℓ_F,
150 ref mut max_transport,
151 g: ref calculate_θ,
152 } => {
153 *max_transport = max_transport.max(γ1.norm(Radon)); 151 *max_transport = max_transport.max(γ1.norm(Radon));
154 let θτ = τ * calculate_θ(ℓ_F, *max_transport); 152 let θτ = τ * calculate_θ(ℓ_F, *max_transport);
155 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 153 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
156 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 154 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
157 } 155 }
158 } 156 }
159 FullyAdaptive { 157 FullyAdaptive { l: ref mut adaptive_ℓ_F, ref mut max_transport, g: ref calculate_θ } => {
160 l: ref mut adaptive_ℓ_F,
161 ref mut max_transport,
162 g: ref calculate_θ,
163 } => {
164 *max_transport = max_transport.max(γ1.norm(Radon)); 158 *max_transport = max_transport.max(γ1.norm(Radon));
165 let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport); 159 let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
166 // Do two runs through the spikes to update θ, breaking if first run did not cause 160 // Do two runs through the spikes to update θ, breaking if first run did not cause
167 // a change. 161 // a change.
168 for _i in 0..=1 { 162 for _i in 0..=1 {
207 } 201 }
208 202
209 /// A posteriori transport adaptation. 203 /// A posteriori transport adaptation.
210 #[replace_float_literals(F::cast_from(literal))] 204 #[replace_float_literals(F::cast_from(literal))]
211 pub(crate) fn aposteriori_transport<F, const N: usize>( 205 pub(crate) fn aposteriori_transport<F, const N: usize>(
212 γ1: &mut RNDM<F, N>, 206 γ1: &mut RNDM<N, F>,
213 μ: &mut RNDM<F, N>, 207 μ: &mut RNDM<N, F>,
214 μ_base_minus_γ0: &mut RNDM<F, N>, 208 μ_base_minus_γ0: &mut RNDM<N, F>,
215 μ_base_masses: &Vec<F>, 209 μ_base_masses: &Vec<F>,
216 extra: Option<F>, 210 extra: Option<F>,
217 ε: F, 211 ε: F,
218 tconfig: &TransportConfig<F>, 212 tconfig: &TransportConfig<F>,
219 ) -> bool 213 ) -> bool
262 /// splitting 256 /// splitting
263 /// 257 ///
264 /// The parametrisation is as for [`pointsource_fb_reg`]. 258 /// The parametrisation is as for [`pointsource_fb_reg`].
265 /// Inertia is currently not supported. 259 /// Inertia is currently not supported.
266 #[replace_float_literals(F::cast_from(literal))] 260 #[replace_float_literals(F::cast_from(literal))]
267 pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N: usize>( 261 pub fn pointsource_sliding_fb_reg<F, I, Dat, Reg, Plot, P, const N: usize>(
268 opA: &A, 262 f: &Dat,
269 b: &A::Observable, 263 reg: &Reg,
270 reg: Reg,
271 prox_penalty: &P, 264 prox_penalty: &P,
272 config: &SlidingFBConfig<F>, 265 config: &SlidingFBConfig<F>,
273 iterator: I, 266 iterator: I,
274 mut plotter: SeqPlotter<F, N>, 267 mut plotter: Plot,
275 ) -> RNDM<F, N> 268 μ0: Option<RNDM<N, F>>,
269 ) -> DynResult<RNDM<N, F>>
276 where 270 where
277 F: Float + ToNalgebraRealField, 271 F: Float + ToNalgebraRealField,
278 I: AlgIteratorFactory<IterInfo<F, N>>, 272 I: AlgIteratorFactory<IterInfo<F>>,
279 A: ForwardModel<RNDM<F, N>, F> 273 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>,
280 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F> 274 Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>,
281 + BoundedCurvature<FloatType = F>, 275 //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>,
282 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>, 276 RNDM<N, F>: SpikeMerging<F>,
283 A::PreadjointCodomain: DifferentiableRealMapping<F, N>, 277 Reg: SlidingRegTerm<Loc<N, F>, F>,
284 RNDM<F, N>: SpikeMerging<F>, 278 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
285 Reg: SlidingRegTerm<F, N>, 279 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
286 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
287 PlotLookup: Plotting<N>,
288 { 280 {
289 // Check parameters 281 // Check parameters
290 assert!(config.τ0 > 0.0, "Invalid step length parameter"); 282 ensure!(config.τ0 > 0.0, "Invalid step length parameter");
291 config.transport.check(); 283 config.transport.check()?;
292 284
293 // Initialise iterates 285 // Initialise iterates
294 let mut μ = DiscreteMeasure::new(); 286 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
295 let mut γ1 = DiscreteMeasure::new(); 287 let mut γ1 = DiscreteMeasure::new();
296 let mut residual = -b; // Has to equal $Aμ-b$.
297 288
298 // Set up parameters 289 // Set up parameters
299 // let opAnorm = opA.opnorm_bound(Radon, L2); 290 // let opAnorm = opA.opnorm_bound(Radon, L2);
300 //let max_transport = config.max_transport.scale 291 //let max_transport = config.max_transport.scale
301 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); 292 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
302 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; 293 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
303 let ℓ = 0.0; 294 let ℓ = 0.0;
304 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); 295 let τ = config.τ0 / prox_penalty.step_length_bound(&f)?;
305 let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); 296 let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess);
306 let transport_lip = maybe_transport_lip.unwrap(); 297 let transport_lip = maybe_transport_lip?;
307 let calculate_θ = |ℓ_F, max_transport| { 298 let calculate_θ = |ℓ_F, max_transport| {
308 let ℓ_r = transport_lip * max_transport; 299 let ℓ_r = transport_lip * max_transport;
309 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) 300 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r))
310 }; 301 };
311 let mut θ_or_adaptive = match maybe_ℓ_F0 { 302 let mut θ_or_adaptive = match maybe_ℓ_F {
312 //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)), 303 //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)),
313 Some(ℓ_F0) => TransportStepLength::AdaptiveMax { 304 Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
314 l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real residual 305 l: ℓ_F, // TODO: could estimate computing the real residual
315 max_transport: 0.0, 306 max_transport: 0.0,
316 g: calculate_θ, 307 g: calculate_θ,
317 }, 308 },
318 None => TransportStepLength::FullyAdaptive { 309 Err(_) => TransportStepLength::FullyAdaptive {
319 l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials 310 l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
320 max_transport: 0.0, 311 max_transport: 0.0,
321 g: calculate_θ, 312 g: calculate_θ,
322 }, 313 },
323 }; 314 };
325 // by τ compared to the conditional gradient approach. 316 // by τ compared to the conditional gradient approach.
326 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling(); 317 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
327 let mut ε = tolerance.initial(); 318 let mut ε = tolerance.initial();
328 319
329 // Statistics 320 // Statistics
330 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { 321 let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
331 value: residual.norm2_squared_div2() + reg.apply(μ), 322 value: f.apply(μ) + reg.apply(μ),
332 n_spikes: μ.len(), 323 n_spikes: μ.len(),
333 ε, 324 ε,
334 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), 325 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
335 ..stats 326 ..stats
336 }; 327 };
337 let mut stats = IterInfo::new(); 328 let mut stats = IterInfo::new();
338 329
339 // Run the algorithm 330 // Run the algorithm
340 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { 331 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
341 // Calculate initial transport 332 // Calculate initial transport
342 let v = opA.preadjoint().apply(residual); 333 let v = f.differential(&μ);
343 let (μ_base_masses, mut μ_base_minus_γ0) = 334 let (μ_base_masses, mut μ_base_minus_γ0) =
344 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); 335 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
345 336
346 // Solve finite-dimensional subproblem several times until the dual variable for the 337 // Solve finite-dimensional subproblem several times until the dual variable for the
347 // regularisation term conforms to the assumptions made for the transport above. 338 // regularisation term conforms to the assumptions made for the transport above.
348 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { 339 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop {
349 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) 340 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
350 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); 341 //let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
351 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); 342 // TODO: this could be optimised by doing the differential like the
343 // old residual2.
344 let μ̆ = &γ1 + &μ_base_minus_γ0;
345 let mut τv̆ = f.differential(μ̆) * τ;
352 346
353 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 347 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
354 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( 348 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
355 &mut μ, 349 &mut μ,
356 &mut τv̆, 350 &mut τv̆,
360 ε, 354 ε,
361 &config.insertion, 355 &config.insertion,
362 &reg, 356 &reg,
363 &state, 357 &state,
364 &mut stats, 358 &mut stats,
365 ); 359 )?;
366 360
367 // A posteriori transport adaptation. 361 // A posteriori transport adaptation.
368 if aposteriori_transport( 362 if aposteriori_transport(
369 &mut γ1, 363 &mut γ1,
370 &mut μ, 364 &mut μ,
402 Some(&μ_base_minus_γ0), 396 Some(&μ_base_minus_γ0),
403 τ, 397 τ,
404 ε, 398 ε,
405 ins, 399 ins,
406 &reg, 400 &reg,
407 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), 401 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
408 ); 402 );
409 } 403 }
410 404
411 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the 405 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
412 // latter needs to be pruned when μ is. 406 // latter needs to be pruned when μ is.
417 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); 411 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
418 stats.pruned += μ.len() - μ_new.len(); 412 stats.pruned += μ.len() - μ_new.len();
419 μ = μ_new; 413 μ = μ_new;
420 } 414 }
421 415
422 // Update residual
423 residual = calculate_residual(&μ, opA, b);
424
425 let iter = state.iteration(); 416 let iter = state.iteration();
426 stats.this_iters += 1; 417 stats.this_iters += 1;
427 418
428 // Give statistics if requested 419 // Give statistics if requested
429 state.if_verbose(|| { 420 state.if_verbose(|| {
430 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); 421 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
431 full_stats( 422 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
432 &residual,
433 &μ,
434 ε,
435 std::mem::replace(&mut stats, IterInfo::new()),
436 )
437 }); 423 });
438 424
439 // Update main tolerance for next iteration 425 // Update main tolerance for next iteration
440 ε = tolerance.update(ε, iter); 426 ε = tolerance.update(ε, iter);
441 } 427 }
442 428
443 postprocess(μ, &config.insertion, L2Squared, opA, b) 429 //postprocess(μ, &config.insertion, f)
444 } 430 postprocess(μ, &config.insertion, |μ̃| f.apply(μ̃))
431 }

mercurial