src/sliding_fb.rs

branch
dev
changeset 41
b6bdb6cb4d44
parent 39
6316d68b58af
child 44
03251c546744
equal deleted inserted replaced
40:896b42b5ac1a 41:b6bdb6cb4d44
13 use alg_tools::iterate::AlgIteratorFactory; 13 use alg_tools::iterate::AlgIteratorFactory;
14 use alg_tools::euclidean::Euclidean; 14 use alg_tools::euclidean::Euclidean;
15 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; 15 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
16 use alg_tools::norms::Norm; 16 use alg_tools::norms::Norm;
17 use alg_tools::nalgebra_support::ToNalgebraRealField; 17 use alg_tools::nalgebra_support::ToNalgebraRealField;
18 use alg_tools::norms::L2;
19 18
20 use crate::types::*; 19 use crate::types::*;
21 use crate::measures::{DiscreteMeasure, Radon, RNDM}; 20 use crate::measures::{DiscreteMeasure, Radon, RNDM};
22 use crate::measures::merging::SpikeMerging; 21 use crate::measures::merging::SpikeMerging;
23 use crate::forward_model::{ 22 use crate::forward_model::{
47 pub struct TransportConfig<F : Float> { 46 pub struct TransportConfig<F : Float> {
48 /// Transport step length $θ$ normalised to $(0, 1)$. 47 /// Transport step length $θ$ normalised to $(0, 1)$.
49 pub θ0 : F, 48 pub θ0 : F,
50 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. 49 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
51 pub adaptation : F, 50 pub adaptation : F,
52 /// A priori transport tolerance multiplier (C_pri)
53 pub tolerance_mult_pri : F,
54 /// A posteriori transport tolerance multiplier (C_pos) 51 /// A posteriori transport tolerance multiplier (C_pos)
55 pub tolerance_mult_pos : F, 52 pub tolerance_mult_pos : F,
56 } 53 }
57 54
58 #[replace_float_literals(F::cast_from(literal))] 55 #[replace_float_literals(F::cast_from(literal))]
59 impl <F : Float> TransportConfig<F> { 56 impl <F : Float> TransportConfig<F> {
60 /// Check that the parameters are ok. Panics if not. 57 /// Check that the parameters are ok. Panics if not.
61 pub fn check(&self) { 58 pub fn check(&self) {
62 assert!(self.θ0 > 0.0); 59 assert!(self.θ0 > 0.0);
63 assert!(0.0 < self.adaptation && self.adaptation < 1.0); 60 assert!(0.0 < self.adaptation && self.adaptation < 1.0);
64 assert!(self.tolerance_mult_pri > 0.0);
65 assert!(self.tolerance_mult_pos > 0.0); 61 assert!(self.tolerance_mult_pos > 0.0);
66 } 62 }
67 } 63 }
68 64
69 #[replace_float_literals(F::cast_from(literal))] 65 #[replace_float_literals(F::cast_from(literal))]
71 fn default() -> Self { 67 fn default() -> Self {
72 TransportConfig { 68 TransportConfig {
73 θ0 : 0.4, 69 θ0 : 0.4,
74 adaptation : 0.9, 70 adaptation : 0.9,
75 tolerance_mult_pos : 100.0, 71 tolerance_mult_pos : 100.0,
76 tolerance_mult_pri : 1000.0,
77 } 72 }
78 } 73 }
79 } 74 }
80 75
81 /// Settings for [`pointsource_sliding_fb_reg`]. 76 /// Settings for [`pointsource_sliding_fb_reg`].
111 /// Adaptive step length. 106 /// Adaptive step length.
112 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. 107 /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
113 FullyAdaptive{ l : F, max_transport : F, g : G }, 108 FullyAdaptive{ l : F, max_transport : F, g : G },
114 } 109 }
115 110
116 /// Constrution and a priori transport adaptation. 111 /// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
117 #[replace_float_literals(F::cast_from(literal))] 112 /// with step lengh τ and transport step length `θ_or_adaptive`.
118 pub(crate) fn initial_transport<F, G, D, Observable, const N : usize>( 113 #[replace_float_literals(F::cast_from(literal))]
114 pub(crate) fn initial_transport<F, G, D, const N : usize>(
119 γ1 : &mut RNDM<F, N>, 115 γ1 : &mut RNDM<F, N>,
120 μ : &mut RNDM<F, N>, 116 μ : &mut RNDM<F, N>,
121 opAapply : impl Fn(&RNDM<F, N>) -> Observable,
122 ε : F,
123 τ : F, 117 τ : F,
124 θ_or_adaptive : &mut TransportStepLength<F, G>, 118 θ_or_adaptive : &mut TransportStepLength<F, G>,
125 opAnorm : F,
126 v : D, 119 v : D,
127 tconfig : &TransportConfig<F>
128 ) -> (Vec<F>, RNDM<F, N>) 120 ) -> (Vec<F>, RNDM<F, N>)
129 where 121 where
130 F : Float + ToNalgebraRealField, 122 F : Float + ToNalgebraRealField,
131 G : Fn(F, F) -> F, 123 G : Fn(F, F) -> F,
132 Observable : Euclidean<F, Output=Observable>,
133 for<'a> &'a Observable : Instance<Observable>,
134 //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
135 D : DifferentiableRealMapping<F, N>, 124 D : DifferentiableRealMapping<F, N>,
136 { 125 {
137 126
138 use TransportStepLength::*; 127 use TransportStepLength::*;
139 128
159 } else { 148 } else {
160 δ.α 149 δ.α
161 }; 150 };
162 }; 151 };
163 152
164 // A priori transport adaptation based on bounding 2 ‖A‖ ‖A(γ₁-γ₀)‖‖γ‖ by scaling γ. 153 // Calculate transport rays.
165 // 1. Calculate transport rays.
166 // If the Lipschitz factor of the values v=∇F(μ) are not known, estimate it.
167 match *θ_or_adaptive { 154 match *θ_or_adaptive {
168 Fixed(θ) => { 155 Fixed(θ) => {
169 let θτ = τ * θ; 156 let θτ = τ * θ;
170 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 157 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
171 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 158 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
203 } 190 }
204 } 191 }
205 } 192 }
206 } 193 }
207 194
208 // 2. Adjust transport mass, if needed.
209 // This tries to remove the smallest transport masses first.
210 if true {
211 // Alternative 1 : subtract same amount from all transport rays until reaching zero
212 loop {
213 let nr =γ1.norm(Radon);
214 let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2();
215 if n <= 0.0 || nr <= 0.0 {
216 break
217 }
218 let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n);
219 if reduction_needed <= 0.0 {
220 break
221 }
222 let (min_nonzero, n_nonzero) = γ1.iter_masses()
223 .map(|α| α.abs())
224 .filter(|α| *α > F::EPSILON)
225 .fold((F::INFINITY, 0), |(a, n), b| (a.min(b), n+1));
226 assert!(n_nonzero > 0);
227 // Reduction that can be done in all nonzero spikes simultaneously
228 let h = (reduction_needed / F::cast_from(n_nonzero)).min(min_nonzero);
229 for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) {
230 ρ.α = ρ.α.signum() * (ρ.α.abs() - h).max(0.0);
231 δ.α = ρ.α;
232 }
233 if min_nonzero * F::cast_from(n_nonzero) >= reduction_needed {
234 break
235 }
236 }
237 } else {
238 // Alternative 2: first reduce transport rays with greater effect based on differential.
239 // This is a an inefficient quick-and-dirty implementation.
240 loop {
241 let nr = γ1.norm(Radon);
242 let a = opAapply(&*γ1)-opAapply(&*μ);
243 let na = a.norm2();
244 let n = τ * 2.0 * opAnorm * na;
245 if n <= 0.0 || nr <= 0.0 {
246 break
247 }
248 let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n);
249 if reduction_needed <= 0.0 {
250 break
251 }
252 let mut max_d = 0.0;
253 let mut max_d_ind = 0;
254 for (δ, ρ, i) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), 0..) {
255 // Calculate differential of ‖A(γ₁-γ₀)‖‖γ‖ wrt. each spike
256 let s = δ.α.signum();
257 // TODO: this is very inefficient implementation due to the limitations
258 // of the closure parameters.
259 let δ1 = DiscreteMeasure::from([(ρ.x, s)]);
260 let δ2 = DiscreteMeasure::from([(δ.x, s)]);
261 let a_part = opAapply(&δ1)-opAapply(&δ2);
262 let d = a.dot(&a_part)/na * nr + 2.0 * na;
263 if d > max_d {
264 max_d = d;
265 max_d_ind = i;
266 }
267 }
268 // Just set mass to zero for transport ray with greater differential
269 assert!(max_d > 0.0);
270 γ1[max_d_ind].α = 0.0;
271 μ[max_d_ind].α = 0.0;
272 }
273 }
274
275 // Set initial guess for μ=μ^{k+1}. 195 // Set initial guess for μ=μ^{k+1}.
276 for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) { 196 for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) {
277 if ρ.α.abs() > F::EPSILON { 197 if ρ.α.abs() > F::EPSILON {
278 δ.x = ρ.x; 198 δ.x = ρ.x;
279 //δ.α = ρ.α; // already set above 199 //δ.α = ρ.α; // already set above
373 let mut μ = DiscreteMeasure::new(); 293 let mut μ = DiscreteMeasure::new();
374 let mut γ1 = DiscreteMeasure::new(); 294 let mut γ1 = DiscreteMeasure::new();
375 let mut residual = -b; // Has to equal $Aμ-b$. 295 let mut residual = -b; // Has to equal $Aμ-b$.
376 296
377 // Set up parameters 297 // Set up parameters
378 let opAnorm = opA.opnorm_bound(Radon, L2); 298 // let opAnorm = opA.opnorm_bound(Radon, L2);
379 //let max_transport = config.max_transport.scale 299 //let max_transport = config.max_transport.scale
380 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); 300 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
381 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; 301 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
382 let ℓ = 0.0; 302 let ℓ = 0.0;
383 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); 303 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
413 // Run the algorithm 333 // Run the algorithm
414 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { 334 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
415 // Calculate initial transport 335 // Calculate initial transport
416 let v = opA.preadjoint().apply(residual); 336 let v = opA.preadjoint().apply(residual);
417 let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport( 337 let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport(
418 &mut γ1, &mut μ, |ν| opA.apply(ν), 338 &mut γ1, &mut μ, τ, &mut θ_or_adaptive, v
419 ε, τ, &mut θ_or_adaptive, opAnorm,
420 v, &config.transport,
421 ); 339 );
422 340
423 // Solve finite-dimensional subproblem several times until the dual variable for the 341 // Solve finite-dimensional subproblem several times until the dual variable for the
424 // regularisation term conforms to the assumptions made for the transport above. 342 // regularisation term conforms to the assumptions made for the transport above.
425 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { 343 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop {

mercurial