13 use alg_tools::iterate::AlgIteratorFactory; |
13 use alg_tools::iterate::AlgIteratorFactory; |
14 use alg_tools::euclidean::Euclidean; |
14 use alg_tools::euclidean::Euclidean; |
15 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; |
15 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; |
16 use alg_tools::norms::Norm; |
16 use alg_tools::norms::Norm; |
17 use alg_tools::nalgebra_support::ToNalgebraRealField; |
17 use alg_tools::nalgebra_support::ToNalgebraRealField; |
18 use alg_tools::norms::L2; |
|
19 |
18 |
20 use crate::types::*; |
19 use crate::types::*; |
21 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
20 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
22 use crate::measures::merging::SpikeMerging; |
21 use crate::measures::merging::SpikeMerging; |
23 use crate::forward_model::{ |
22 use crate::forward_model::{ |
47 pub struct TransportConfig<F : Float> { |
46 pub struct TransportConfig<F : Float> { |
48 /// Transport step length $θ$ normalised to $(0, 1)$. |
47 /// Transport step length $θ$ normalised to $(0, 1)$. |
49 pub θ0 : F, |
48 pub θ0 : F, |
50 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. |
49 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. |
51 pub adaptation : F, |
50 pub adaptation : F, |
52 /// A priori transport tolerance multiplier (C_pri) |
|
53 pub tolerance_mult_pri : F, |
|
54 /// A posteriori transport tolerance multiplier (C_pos) |
51 /// A posteriori transport tolerance multiplier (C_pos) |
55 pub tolerance_mult_pos : F, |
52 pub tolerance_mult_pos : F, |
56 } |
53 } |
57 |
54 |
58 #[replace_float_literals(F::cast_from(literal))] |
55 #[replace_float_literals(F::cast_from(literal))] |
59 impl <F : Float> TransportConfig<F> { |
56 impl <F : Float> TransportConfig<F> { |
60 /// Check that the parameters are ok. Panics if not. |
57 /// Check that the parameters are ok. Panics if not. |
61 pub fn check(&self) { |
58 pub fn check(&self) { |
62 assert!(self.θ0 > 0.0); |
59 assert!(self.θ0 > 0.0); |
63 assert!(0.0 < self.adaptation && self.adaptation < 1.0); |
60 assert!(0.0 < self.adaptation && self.adaptation < 1.0); |
64 assert!(self.tolerance_mult_pri > 0.0); |
|
65 assert!(self.tolerance_mult_pos > 0.0); |
61 assert!(self.tolerance_mult_pos > 0.0); |
66 } |
62 } |
67 } |
63 } |
68 |
64 |
69 #[replace_float_literals(F::cast_from(literal))] |
65 #[replace_float_literals(F::cast_from(literal))] |
111 /// Adaptive step length. |
106 /// Adaptive step length. |
112 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. |
107 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. |
113 FullyAdaptive{ l : F, max_transport : F, g : G }, |
108 FullyAdaptive{ l : F, max_transport : F, g : G }, |
114 } |
109 } |
115 |
110 |
116 /// Construction and a priori transport adaptation. |
111 /// Construction of initial transport `γ1` from initial measure `μ` and `v=F'(μ)` |
117 #[replace_float_literals(F::cast_from(literal))] |
112 /// with step length τ and transport step length `θ_or_adaptive`. |
118 pub(crate) fn initial_transport<F, G, D, Observable, const N : usize>( |
113 #[replace_float_literals(F::cast_from(literal))] |
|
114 pub(crate) fn initial_transport<F, G, D, const N : usize>( |
119 γ1 : &mut RNDM<F, N>, |
115 γ1 : &mut RNDM<F, N>, |
120 μ : &mut RNDM<F, N>, |
116 μ : &mut RNDM<F, N>, |
121 opAapply : impl Fn(&RNDM<F, N>) -> Observable, |
|
122 ε : F, |
|
123 τ : F, |
117 τ : F, |
124 θ_or_adaptive : &mut TransportStepLength<F, G>, |
118 θ_or_adaptive : &mut TransportStepLength<F, G>, |
125 opAnorm : F, |
|
126 v : D, |
119 v : D, |
127 tconfig : &TransportConfig<F> |
|
128 ) -> (Vec<F>, RNDM<F, N>) |
120 ) -> (Vec<F>, RNDM<F, N>) |
129 where |
121 where |
130 F : Float + ToNalgebraRealField, |
122 F : Float + ToNalgebraRealField, |
131 G : Fn(F, F) -> F, |
123 G : Fn(F, F) -> F, |
132 Observable : Euclidean<F, Output=Observable>, |
|
133 for<'a> &'a Observable : Instance<Observable>, |
|
134 //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, |
|
135 D : DifferentiableRealMapping<F, N>, |
124 D : DifferentiableRealMapping<F, N>, |
136 { |
125 { |
137 |
126 |
138 use TransportStepLength::*; |
127 use TransportStepLength::*; |
139 |
128 |
159 } else { |
148 } else { |
160 δ.α |
149 δ.α |
161 }; |
150 }; |
162 }; |
151 }; |
163 |
152 |
164 // A priori transport adaptation based on bounding 2 ‖A‖ ‖A(γ₁-γ₀)‖‖γ‖ by scaling γ. |
153 // Calculate transport rays. |
165 // 1. Calculate transport rays. |
|
166 // If the Lipschitz factor of the values v=∇F(μ) is not known, estimate it. |
|
167 match *θ_or_adaptive { |
154 match *θ_or_adaptive { |
168 Fixed(θ) => { |
155 Fixed(θ) => { |
169 let θτ = τ * θ; |
156 let θτ = τ * θ; |
170 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
157 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
171 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); |
158 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); |
203 } |
190 } |
204 } |
191 } |
205 } |
192 } |
206 } |
193 } |
207 |
194 |
208 // 2. Adjust transport mass, if needed. |
|
209 // This tries to remove the smallest transport masses first. |
|
210 if true { |
|
211 // Alternative 1 : subtract same amount from all transport rays until reaching zero |
|
212 loop { |
|
213 let nr =γ1.norm(Radon); |
|
214 let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2(); |
|
215 if n <= 0.0 || nr <= 0.0 { |
|
216 break |
|
217 } |
|
218 let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n); |
|
219 if reduction_needed <= 0.0 { |
|
220 break |
|
221 } |
|
222 let (min_nonzero, n_nonzero) = γ1.iter_masses() |
|
223 .map(|α| α.abs()) |
|
224 .filter(|α| *α > F::EPSILON) |
|
225 .fold((F::INFINITY, 0), |(a, n), b| (a.min(b), n+1)); |
|
226 assert!(n_nonzero > 0); |
|
227 // Reduction that can be done in all nonzero spikes simultaneously |
|
228 let h = (reduction_needed / F::cast_from(n_nonzero)).min(min_nonzero); |
|
229 for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) { |
|
230 ρ.α = ρ.α.signum() * (ρ.α.abs() - h).max(0.0); |
|
231 δ.α = ρ.α; |
|
232 } |
|
233 if min_nonzero * F::cast_from(n_nonzero) >= reduction_needed { |
|
234 break |
|
235 } |
|
236 } |
|
237 } else { |
|
238 // Alternative 2: first reduce transport rays with greater effect based on differential. |
|
239 // This is an inefficient quick-and-dirty implementation. |
|
240 loop { |
|
241 let nr = γ1.norm(Radon); |
|
242 let a = opAapply(&*γ1)-opAapply(&*μ); |
|
243 let na = a.norm2(); |
|
244 let n = τ * 2.0 * opAnorm * na; |
|
245 if n <= 0.0 || nr <= 0.0 { |
|
246 break |
|
247 } |
|
248 let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n); |
|
249 if reduction_needed <= 0.0 { |
|
250 break |
|
251 } |
|
252 let mut max_d = 0.0; |
|
253 let mut max_d_ind = 0; |
|
254 for (δ, ρ, i) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), 0..) { |
|
255 // Calculate differential of ‖A(γ₁-γ₀)‖‖γ‖ wrt. each spike |
|
256 let s = δ.α.signum(); |
|
257 // TODO: this is a very inefficient implementation due to the limitations |
|
258 // of the closure parameters. |
|
259 let δ1 = DiscreteMeasure::from([(ρ.x, s)]); |
|
260 let δ2 = DiscreteMeasure::from([(δ.x, s)]); |
|
261 let a_part = opAapply(&δ1)-opAapply(&δ2); |
|
262 let d = a.dot(&a_part)/na * nr + 2.0 * na; |
|
263 if d > max_d { |
|
264 max_d = d; |
|
265 max_d_ind = i; |
|
266 } |
|
267 } |
|
268 // Just set mass to zero for transport ray with greater differential |
|
269 assert!(max_d > 0.0); |
|
270 γ1[max_d_ind].α = 0.0; |
|
271 μ[max_d_ind].α = 0.0; |
|
272 } |
|
273 } |
|
274 |
|
275 // Set initial guess for μ=μ^{k+1}. |
195 // Set initial guess for μ=μ^{k+1}. |
276 for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) { |
196 for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) { |
277 if ρ.α.abs() > F::EPSILON { |
197 if ρ.α.abs() > F::EPSILON { |
278 δ.x = ρ.x; |
198 δ.x = ρ.x; |
279 //δ.α = ρ.α; // already set above |
199 //δ.α = ρ.α; // already set above |
373 let mut μ = DiscreteMeasure::new(); |
293 let mut μ = DiscreteMeasure::new(); |
374 let mut γ1 = DiscreteMeasure::new(); |
294 let mut γ1 = DiscreteMeasure::new(); |
375 let mut residual = -b; // Has to equal $Aμ-b$. |
295 let mut residual = -b; // Has to equal $Aμ-b$. |
376 |
296 |
377 // Set up parameters |
297 // Set up parameters |
378 let opAnorm = opA.opnorm_bound(Radon, L2); |
298 // let opAnorm = opA.opnorm_bound(Radon, L2); |
379 //let max_transport = config.max_transport.scale |
299 //let max_transport = config.max_transport.scale |
380 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
300 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
381 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
301 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
382 let ℓ = 0.0; |
302 let ℓ = 0.0; |
383 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
303 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
413 // Run the algorithm |
333 // Run the algorithm |
414 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
334 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
415 // Calculate initial transport |
335 // Calculate initial transport |
416 let v = opA.preadjoint().apply(residual); |
336 let v = opA.preadjoint().apply(residual); |
417 let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport( |
337 let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport( |
418 &mut γ1, &mut μ, |ν| opA.apply(ν), |
338 &mut γ1, &mut μ, τ, &mut θ_or_adaptive, v |
419 ε, τ, &mut θ_or_adaptive, opAnorm, |
|
420 v, &config.transport, |
|
421 ); |
339 ); |
422 |
340 |
423 // Solve finite-dimensional subproblem several times until the dual variable for the |
341 // Solve finite-dimensional subproblem several times until the dual variable for the |
424 // regularisation term conforms to the assumptions made for the transport above. |
342 // regularisation term conforms to the assumptions made for the transport above. |
425 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { |
343 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { |