src/sliding_fb.rs

branch
dev
changeset 49
6b0db7251ebe
parent 46
f358958cc1a6
equal deleted inserted replaced
48:53136eba9abf 49:6b0db7251ebe
2 Solver for the point source localisation problem using a sliding 2 Solver for the point source localisation problem using a sliding
3 forward-backward splitting method. 3 forward-backward splitting method.
4 */ 4 */
5 5
6 use numeric_literals::replace_float_literals; 6 use numeric_literals::replace_float_literals;
7 use serde::{Serialize, Deserialize}; 7 use serde::{Deserialize, Serialize};
8 //use colored::Colorize; 8 //use colored::Colorize;
9 //use nalgebra::{DVector, DMatrix}; 9 //use nalgebra::{DVector, DMatrix};
10 use itertools::izip; 10 use itertools::izip;
11 use std::iter::Iterator; 11 use std::iter::Iterator;
12 12
13 use alg_tools::euclidean::Euclidean;
13 use alg_tools::iterate::AlgIteratorFactory; 14 use alg_tools::iterate::AlgIteratorFactory;
14 use alg_tools::euclidean::Euclidean; 15 use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping};
15 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; 16 use alg_tools::nalgebra_support::ToNalgebraRealField;
16 use alg_tools::norms::Norm; 17 use alg_tools::norms::Norm;
17 use alg_tools::nalgebra_support::ToNalgebraRealField; 18
18 19 use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel};
20 use crate::measures::merging::SpikeMerging;
21 use crate::measures::{DiscreteMeasure, Radon, RNDM};
19 use crate::types::*; 22 use crate::types::*;
20 use crate::measures::{DiscreteMeasure, Radon, RNDM};
21 use crate::measures::merging::SpikeMerging;
22 use crate::forward_model::{
23 ForwardModel,
24 AdjointProductBoundedBy,
25 BoundedCurvature,
26 };
27 //use crate::tolerance::Tolerance; 23 //use crate::tolerance::Tolerance;
28 use crate::plot::{ 24 use crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared};
29 SeqPlotter,
30 Plotting,
31 PlotLookup
32 };
33 use crate::fb::*; 25 use crate::fb::*;
26 use crate::plot::{PlotLookup, Plotting, SeqPlotter};
34 use crate::regularisation::SlidingRegTerm; 27 use crate::regularisation::SlidingRegTerm;
35 use crate::dataterm::{
36 L2Squared,
37 DataTerm,
38 calculate_residual,
39 calculate_residual2,
40 };
41 //use crate::transport::TransportLipschitz; 28 //use crate::transport::TransportLipschitz;
42 29
43 /// Transport settings for [`pointsource_sliding_fb_reg`]. 30 /// Transport settings for [`pointsource_sliding_fb_reg`].
44 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 31 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
45 #[serde(default)] 32 #[serde(default)]
46 pub struct TransportConfig<F : Float> { 33 pub struct TransportConfig<F: Float> {
47 /// Transport step length $θ$ normalised to $(0, 1)$. 34 /// Transport step length $θ$ normalised to $(0, 1)$.
48 pub θ0 : F, 35 pub θ0: F,
49 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. 36 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
50 pub adaptation : F, 37 pub adaptation: F,
51 /// A posteriori transport tolerance multiplier (C_pos) 38 /// A posteriori transport tolerance multiplier (C_pos)
52 pub tolerance_mult_con : F, 39 pub tolerance_mult_con: F,
53 } 40 }
54 41
55 #[replace_float_literals(F::cast_from(literal))] 42 #[replace_float_literals(F::cast_from(literal))]
56 impl <F : Float> TransportConfig<F> { 43 impl<F: Float> TransportConfig<F> {
57 /// Check that the parameters are ok. Panics if not. 44 /// Check that the parameters are ok. Panics if not.
58 pub fn check(&self) { 45 pub fn check(&self) {
59 assert!(self.θ0 > 0.0); 46 assert!(self.θ0 > 0.0);
60 assert!(0.0 < self.adaptation && self.adaptation < 1.0); 47 assert!(0.0 < self.adaptation && self.adaptation < 1.0);
61 assert!(self.tolerance_mult_con > 0.0); 48 assert!(self.tolerance_mult_con > 0.0);
62 } 49 }
63 } 50 }
64 51
65 #[replace_float_literals(F::cast_from(literal))] 52 #[replace_float_literals(F::cast_from(literal))]
66 impl<F : Float> Default for TransportConfig<F> { 53 impl<F: Float> Default for TransportConfig<F> {
67 fn default() -> Self { 54 fn default() -> Self {
68 TransportConfig { 55 TransportConfig {
69 θ0 : 0.9, 56 θ0: 0.9,
70 adaptation : 0.9, 57 adaptation: 0.9,
71 tolerance_mult_con : 100.0, 58 tolerance_mult_con: 100.0,
72 } 59 }
73 } 60 }
74 } 61 }
75 62
76 /// Settings for [`pointsource_sliding_fb_reg`]. 63 /// Settings for [`pointsource_sliding_fb_reg`].
77 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 64 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
78 #[serde(default)] 65 #[serde(default)]
79 pub struct SlidingFBConfig<F : Float> { 66 pub struct SlidingFBConfig<F: Float> {
80 /// Step length scaling 67 /// Step length scaling
81 pub τ0 : F, 68 pub τ0: F,
82 /// Transport parameters 69 /// Transport parameters
83 pub transport : TransportConfig<F>, 70 pub transport: TransportConfig<F>,
84 /// Generic parameters 71 /// Generic parameters
85 pub insertion : FBGenericConfig<F>, 72 pub insertion: FBGenericConfig<F>,
86 } 73 }
87 74
88 #[replace_float_literals(F::cast_from(literal))] 75 #[replace_float_literals(F::cast_from(literal))]
89 impl<F : Float> Default for SlidingFBConfig<F> { 76 impl<F: Float> Default for SlidingFBConfig<F> {
90 fn default() -> Self { 77 fn default() -> Self {
91 SlidingFBConfig { 78 SlidingFBConfig {
92 τ0 : 0.99, 79 τ0: 0.99,
93 transport : Default::default(), 80 transport: Default::default(),
94 insertion : Default::default() 81 insertion: Default::default(),
95 } 82 }
96 } 83 }
97 } 84 }
98 85
99 /// Internal type of adaptive transport step length calculation 86 /// Internal type of adaptive transport step length calculation
100 pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> { 87 pub(crate) enum TransportStepLength<F: Float, G: Fn(F, F) -> F> {
101 /// Fixed, known step length 88 /// Fixed, known step length
102 #[allow(dead_code)] 89 #[allow(dead_code)]
103 Fixed(F), 90 Fixed(F),
104 /// Adaptive step length, only wrt. maximum transport. 91 /// Adaptive step length, only wrt. maximum transport.
105 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. 92 /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
106 AdaptiveMax{ l : F, max_transport : F, g : G }, 93 AdaptiveMax { l: F, max_transport: F, g: G },
107 /// Adaptive step length. 94 /// Adaptive step length.
108 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. 95 /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
109 FullyAdaptive{ l : F, max_transport : F, g : G }, 96 FullyAdaptive { l: F, max_transport: F, g: G },
110 } 97 }
111 98
112 /// Construction of initial transport `γ1` from initial measure `μ` and `v=F'(μ)` 99 /// Construction of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
113 /// with step length τ and transport step length `θ_or_adaptive`. 100 /// with step length τ and transport step length `θ_or_adaptive`.
114 #[replace_float_literals(F::cast_from(literal))] 101 #[replace_float_literals(F::cast_from(literal))]
115 pub(crate) fn initial_transport<F, G, D, const N : usize>( 102 pub(crate) fn initial_transport<F, G, D, const N: usize>(
116 γ1 : &mut RNDM<F, N>, 103 γ1: &mut RNDM<F, N>,
117 μ : &mut RNDM<F, N>, 104 μ: &mut RNDM<F, N>,
118 τ : F, 105 τ: F,
119 θ_or_adaptive : &mut TransportStepLength<F, G>, 106 θ_or_adaptive: &mut TransportStepLength<F, G>,
120 v : D, 107 v: D,
121 ) -> (Vec<F>, RNDM<F, N>) 108 ) -> (Vec<F>, RNDM<F, N>)
122 where 109 where
123 F : Float + ToNalgebraRealField, 110 F: Float + ToNalgebraRealField,
124 G : Fn(F, F) -> F, 111 G: Fn(F, F) -> F,
125 D : DifferentiableRealMapping<F, N>, 112 D: DifferentiableRealMapping<F, N>,
126 { 113 {
127
128 use TransportStepLength::*; 114 use TransportStepLength::*;
129 115
130 // Save current base point and shift μ to new positions. Idea is that 116 // Save current base point and shift μ to new positions. Idea is that
131 // μ_base(_masses) = μ^k (vector of masses) 117 // μ_base(_masses) = μ^k (vector of masses)
132 // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} 118 // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
133 // γ1 = π_♯^1γ^{k+1} 119 // γ1 = π_♯^1γ^{k+1}
134 // μ = μ^{k+1} 120 // μ = μ^{k+1}
135 let μ_base_masses : Vec<F> = μ.iter_masses().collect(); 121 let μ_base_masses: Vec<F> = μ.iter_masses().collect();
136 let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below. 122 let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
137 // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates 123 // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
138 //let mut sum_norm_dv = 0.0; 124 //let mut sum_norm_dv = 0.0;
139 let γ_prev_len = γ1.len(); 125 let γ_prev_len = γ1.len();
140 assert!(μ.len() >= γ_prev_len); 126 assert!(μ.len() >= γ_prev_len);
141 γ1.extend(μ[γ_prev_len..].iter().cloned()); 127 γ1.extend(μ[γ_prev_len..].iter().cloned());
142 128
143 // Calculate initial transport and step length. 129 // Calculate initial transport and step length.
147 ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) { 133 ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
148 0.0 134 0.0
149 } else { 135 } else {
150 δ.α 136 δ.α
151 }; 137 };
152 }; 138 }
153 139
154 // Calculate transport rays. 140 // Calculate transport rays.
155 match *θ_or_adaptive { 141 match *θ_or_adaptive {
156 Fixed(θ) => { 142 Fixed(θ) => {
157 let θτ = τ * θ; 143 let θτ = τ * θ;
158 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 144 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
159 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 145 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
160 } 146 }
161 }, 147 }
162 AdaptiveMax{ l : ℓ_F, ref mut max_transport, g : ref calculate_θ } => { 148 AdaptiveMax {
149 l: ℓ_F,
150 ref mut max_transport,
151 g: ref calculate_θ,
152 } => {
163 *max_transport = max_transport.max(γ1.norm(Radon)); 153 *max_transport = max_transport.max(γ1.norm(Radon));
164 let θτ = τ * calculate_θ(ℓ_F, *max_transport); 154 let θτ = τ * calculate_θ(ℓ_F, *max_transport);
165 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 155 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
166 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 156 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
167 } 157 }
168 }, 158 }
169 FullyAdaptive{ l : ref mut adaptive_ℓ_F, ref mut max_transport, g : ref calculate_θ } => { 159 FullyAdaptive {
160 l: ref mut adaptive_ℓ_F,
161 ref mut max_transport,
162 g: ref calculate_θ,
163 } => {
170 *max_transport = max_transport.max(γ1.norm(Radon)); 164 *max_transport = max_transport.max(γ1.norm(Radon));
171 let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport); 165 let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
172 // Do two runs through the spikes to update θ, breaking if first run did not cause 166 // Do two runs through the spikes to update θ, breaking if first run did not cause
173 // a change. 167 // a change.
174 for _i in 0..=1 { 168 for _i in 0..=1 {
185 θ = calculate_θ(*adaptive_ℓ_F, *max_transport); 179 θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
186 changes = true 180 changes = true
187 } 181 }
188 } 182 }
189 if !changes { 183 if !changes {
190 break 184 break;
191 } 185 }
192 } 186 }
193 } 187 }
194 } 188 }
195 189
201 } else { 195 } else {
202 δ.α = β; 196 δ.α = β;
203 } 197 }
204 } 198 }
205 // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b) 199 // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b)
206 μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses()) 200 μ_base_minus_γ0.set_masses(
207 .map(|(&a,b)| a - b)); 201 μ_base_masses
202 .iter()
203 .zip(γ1.iter_masses())
204 .map(|(&a, b)| a - b),
205 );
208 (μ_base_masses, μ_base_minus_γ0) 206 (μ_base_masses, μ_base_minus_γ0)
209 } 207 }
210 208
211 /// A posteriori transport adaptation. 209 /// A posteriori transport adaptation.
212 #[replace_float_literals(F::cast_from(literal))] 210 #[replace_float_literals(F::cast_from(literal))]
213 pub(crate) fn aposteriori_transport<F, const N : usize>( 211 pub(crate) fn aposteriori_transport<F, const N: usize>(
214 γ1 : &mut RNDM<F, N>, 212 γ1: &mut RNDM<F, N>,
215 μ : &mut RNDM<F, N>, 213 μ: &mut RNDM<F, N>,
216 μ_base_minus_γ0 : &mut RNDM<F, N>, 214 μ_base_minus_γ0: &mut RNDM<F, N>,
217 μ_base_masses : &Vec<F>, 215 μ_base_masses: &Vec<F>,
218 extra : Option<F>, 216 extra: Option<F>,
219 ε : F, 217 ε: F,
220 tconfig : &TransportConfig<F> 218 tconfig: &TransportConfig<F>,
221 ) -> bool 219 ) -> bool
222 where F : Float + ToNalgebraRealField { 220 where
223 221 F: Float + ToNalgebraRealField,
222 {
224 // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not, 223 // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
225 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1 224 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
226 // at that point to zero, and retry. 225 // at that point to zero, and retry.
227 let mut all_ok = true; 226 let mut all_ok = true;
228 for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) { 227 for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
236 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, 235 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1},
237 // which holds for some C if the convolution kernel in 𝒟 has Lipschitz gradient. 236 // which holds for some C if the convolution kernel in 𝒟 has Lipschitz gradient.
238 let nγ = γ1.norm(Radon); 237 let nγ = γ1.norm(Radon);
239 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0); 238 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0);
240 let t = ε * tconfig.tolerance_mult_con; 239 let t = ε * tconfig.tolerance_mult_con;
241 if nγ*nΔ > t { 240 if nγ * nΔ > t {
242 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, 241 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
243 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we 242 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
244 // will not enter here. 243 // will not enter here.
245 *γ1 *= tconfig.adaptation * t / ( nγ * nΔ ); 244 *γ1 *= tconfig.adaptation * t / (nγ * nΔ);
246 all_ok = false 245 all_ok = false
247 } 246 }
248 247
249 if !all_ok { 248 if !all_ok {
250 // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} 249 // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
251 μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses()) 250 μ_base_minus_γ0.set_masses(
252 .map(|(&a,b)| a - b)); 251 μ_base_masses
253 252 .iter()
253 .zip(γ1.iter_masses())
254 .map(|(&a, b)| a - b),
255 );
254 } 256 }
255 257
256 all_ok 258 all_ok
257 } 259 }
258 260
260 /// splitting 262 /// splitting
261 /// 263 ///
262 /// The parametrisation is as for [`pointsource_fb_reg`]. 264 /// The parametrisation is as for [`pointsource_fb_reg`].
263 /// Inertia is currently not supported. 265 /// Inertia is currently not supported.
264 #[replace_float_literals(F::cast_from(literal))] 266 #[replace_float_literals(F::cast_from(literal))]
265 pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>( 267 pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N: usize>(
266 opA : &A, 268 opA: &A,
267 b : &A::Observable, 269 b: &A::Observable,
268 reg : Reg, 270 reg: Reg,
269 prox_penalty : &P, 271 prox_penalty: &P,
270 config : &SlidingFBConfig<F>, 272 config: &SlidingFBConfig<F>,
271 iterator : I, 273 iterator: I,
272 mut plotter : SeqPlotter<F, N>, 274 mut plotter: SeqPlotter<F, N>,
273 ) -> RNDM<F, N> 275 ) -> RNDM<F, N>
274 where 276 where
275 F : Float + ToNalgebraRealField, 277 F: Float + ToNalgebraRealField,
276 I : AlgIteratorFactory<IterInfo<F, N>>, 278 I: AlgIteratorFactory<IterInfo<F, N>>,
277 A : ForwardModel<RNDM<F, N>, F> 279 A: ForwardModel<RNDM<F, N>, F>
278 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F> 280 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>
279 + BoundedCurvature<FloatType=F>, 281 + BoundedCurvature<FloatType = F>,
280 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, 282 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>,
281 A::PreadjointCodomain : DifferentiableRealMapping<F, N>, 283 A::PreadjointCodomain: DifferentiableRealMapping<F, N>,
282 RNDM<F, N> : SpikeMerging<F>, 284 RNDM<F, N>: SpikeMerging<F>,
283 Reg : SlidingRegTerm<F, N>, 285 Reg: SlidingRegTerm<F, N>,
284 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, 286 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
285 PlotLookup : Plotting<N>, 287 PlotLookup: Plotting<N>,
286 { 288 {
287
288 // Check parameters 289 // Check parameters
289 assert!(config.τ0 > 0.0, "Invalid step length parameter"); 290 assert!(config.τ0 > 0.0, "Invalid step length parameter");
290 config.transport.check(); 291 config.transport.check();
291 292
292 // Initialise iterates 293 // Initialise iterates
299 //let max_transport = config.max_transport.scale 300 //let max_transport = config.max_transport.scale
300 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); 301 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
301 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; 302 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
302 let ℓ = 0.0; 303 let ℓ = 0.0;
303 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); 304 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
304 let (maybe_ℓ_v0, maybe_transport_lip) = opA.curvature_bound_components(); 305 let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components();
305 let transport_lip = maybe_transport_lip.unwrap(); 306 let transport_lip = maybe_transport_lip.unwrap();
306 let calculate_θ = |ℓ_v, max_transport| { 307 let calculate_θ = |ℓ_F, max_transport| {
307 let ℓ_F = ℓ_v + transport_lip * max_transport; 308 let ℓ_r = transport_lip * max_transport;
308 config.transport.θ0 / (τ*(ℓ + ℓ_F)) 309 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r))
309 }; 310 };
310 let mut θ_or_adaptive = match maybe_ℓ_v0 { 311 let mut θ_or_adaptive = match maybe_ℓ_F0 {
311 //Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)), 312 //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)),
312 Some(ℓ_v0) => TransportStepLength::AdaptiveMax { 313 Some(ℓ_F0) => TransportStepLength::AdaptiveMax {
313 l: ℓ_v0 * b.norm2(), // TODO: could estimate computing the real residual 314 l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real residual
314 max_transport : 0.0, 315 max_transport: 0.0,
315 g : calculate_θ 316 g: calculate_θ,
316 }, 317 },
317 None => TransportStepLength::FullyAdaptive { 318 None => TransportStepLength::FullyAdaptive {
318 l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials 319 l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
319 max_transport : 0.0, 320 max_transport: 0.0,
320 g : calculate_θ 321 g: calculate_θ,
321 }, 322 },
322 }; 323 };
323 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 324 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
324 // by τ compared to the conditional gradient approach. 325 // by τ compared to the conditional gradient approach.
325 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling(); 326 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
326 let mut ε = tolerance.initial(); 327 let mut ε = tolerance.initial();
327 328
328 // Statistics 329 // Statistics
329 let full_stats = |residual : &A::Observable, 330 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo {
330 μ : &RNDM<F, N>, 331 value: residual.norm2_squared_div2() + reg.apply(μ),
331 ε, stats| IterInfo { 332 n_spikes: μ.len(),
332 value : residual.norm2_squared_div2() + reg.apply(μ),
333 n_spikes : μ.len(),
334 ε, 333 ε,
335 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), 334 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
336 .. stats 335 ..stats
337 }; 336 };
338 let mut stats = IterInfo::new(); 337 let mut stats = IterInfo::new();
339 338
340 // Run the algorithm 339 // Run the algorithm
341 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { 340 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
342 // Calculate initial transport 341 // Calculate initial transport
343 let v = opA.preadjoint().apply(residual); 342 let v = opA.preadjoint().apply(residual);
344 let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport( 343 let (μ_base_masses, mut μ_base_minus_γ0) =
345 &mut γ1, &mut μ, τ, &mut θ_or_adaptive, v 344 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
346 );
347 345
348 // Solve finite-dimensional subproblem several times until the dual variable for the 346 // Solve finite-dimensional subproblem several times until the dual variable for the
349 // regularisation term conforms to the assumptions made for the transport above. 347 // regularisation term conforms to the assumptions made for the transport above.
350 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { 348 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop {
351 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) 349 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
352 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); 350 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
353 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); 351 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
354 352
355 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 353 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
356 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( 354 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
357 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), 355 &mut μ,
358 τ, ε, &config.insertion, 356 &mut τv̆,
359 &reg, &state, &mut stats, 357 &γ1,
358 Some(&μ_base_minus_γ0),
359 τ,
360 ε,
361 &config.insertion,
362 &reg,
363 &state,
364 &mut stats,
360 ); 365 );
361 366
362 // A posteriori transport adaptation. 367 // A posteriori transport adaptation.
363 if aposteriori_transport( 368 if aposteriori_transport(
364 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, 369 &mut γ1,
370 &mut μ,
371 &mut μ_base_minus_γ0,
372 &μ_base_masses,
365 None, 373 None,
366 ε, &config.transport 374 ε,
375 &config.transport,
367 ) { 376 ) {
368 break 'adapt_transport (maybe_d, within_tolerances, τv̆) 377 break 'adapt_transport (maybe_d, within_tolerances, τv̆);
369 } 378 }
370 }; 379 };
371 380
372 stats.untransported_fraction = Some({ 381 stats.untransported_fraction = Some({
373 assert_eq!(μ_base_masses.len(), γ1.len()); 382 assert_eq!(μ_base_masses.len(), γ1.len());
385 // This crucially expects the merge routine to be stable with respect to spike locations, 394 // This crucially expects the merge routine to be stable with respect to spike locations,
386 // and not to perform any pruning. That is to be done below simultaneously for γ. 395 // and not to perform any pruning. That is to be done below simultaneously for γ.
387 let ins = &config.insertion; 396 let ins = &config.insertion;
388 if ins.merge_now(&state) { 397 if ins.merge_now(&state) {
389 stats.merged += prox_penalty.merge_spikes( 398 stats.merged += prox_penalty.merge_spikes(
390 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, &reg, 399 &mut μ,
391 Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), 400 &mut τv̆,
401 &γ1,
402 Some(&μ_base_minus_γ0),
403 τ,
404 ε,
405 ins,
406 &reg,
407 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
392 ); 408 );
393 } 409 }
394 410
395 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the 411 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
396 // latter needs to be pruned when μ is. 412 // latter needs to be pruned when μ is.
410 stats.this_iters += 1; 426 stats.this_iters += 1;
411 427
412 // Give statistics if requested 428 // Give statistics if requested
413 state.if_verbose(|| { 429 state.if_verbose(|| {
414 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); 430 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
415 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) 431 full_stats(
432 &residual,
433 &μ,
434 ε,
435 std::mem::replace(&mut stats, IterInfo::new()),
436 )
416 }); 437 });
417 438
418 // Update main tolerance for next iteration 439 // Update main tolerance for next iteration
419 ε = tolerance.update(ε, iter); 440 ε = tolerance.update(ε, iter);
420 } 441 }

mercurial