src/sliding_fb.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
child 37
c5d8bd1a7728
equal deleted inserted replaced
34:efa60bc4f743 35:b087e3eab191
8 //use colored::Colorize; 8 //use colored::Colorize;
9 //use nalgebra::{DVector, DMatrix}; 9 //use nalgebra::{DVector, DMatrix};
10 use itertools::izip; 10 use itertools::izip;
11 use std::iter::Iterator; 11 use std::iter::Iterator;
12 12
13 use alg_tools::iterate::{ 13 use alg_tools::iterate::AlgIteratorFactory;
14 AlgIteratorFactory,
15 AlgIteratorState
16 };
17 use alg_tools::euclidean::Euclidean; 14 use alg_tools::euclidean::Euclidean;
18 use alg_tools::sets::Cube; 15 use alg_tools::sets::Cube;
19 use alg_tools::loc::Loc; 16 use alg_tools::loc::Loc;
20 use alg_tools::mapping::{Apply, Differentiable}; 17 use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance};
21 use alg_tools::norms::{Norm, L2}; 18 use alg_tools::norms::Norm;
22 use alg_tools::bisection_tree::{ 19 use alg_tools::bisection_tree::{
23 BTFN, 20 BTFN,
24 PreBTFN, 21 PreBTFN,
25 Bounds, 22 Bounds,
26 BTNodeLookup, 23 BTNodeLookup,
31 LocalAnalysis, 28 LocalAnalysis,
32 //Bounded, 29 //Bounded,
33 }; 30 };
34 use alg_tools::mapping::RealMapping; 31 use alg_tools::mapping::RealMapping;
35 use alg_tools::nalgebra_support::ToNalgebraRealField; 32 use alg_tools::nalgebra_support::ToNalgebraRealField;
33 use alg_tools::norms::{L2, Linfinity};
36 34
37 use crate::types::*; 35 use crate::types::*;
38 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon}; 36 use crate::measures::{DiscreteMeasure, Radon, RNDM};
39 use crate::measures::merging::{ 37 use crate::measures::merging::{
40 //SpikeMergingMethod, 38 SpikeMergingMethod,
41 SpikeMerging, 39 SpikeMerging,
42 }; 40 };
43 use crate::forward_model::ForwardModel; 41 use crate::forward_model::{
42 ForwardModel,
43 AdjointProductBoundedBy,
44 LipschitzValues,
45 };
44 use crate::seminorms::DiscreteMeasureOp; 46 use crate::seminorms::DiscreteMeasureOp;
45 //use crate::tolerance::Tolerance; 47 //use crate::tolerance::Tolerance;
46 use crate::plot::{ 48 use crate::plot::{
47 SeqPlotter, 49 SeqPlotter,
48 Plotting, 50 Plotting,
54 L2Squared, 56 L2Squared,
55 //DataTerm, 57 //DataTerm,
56 calculate_residual, 58 calculate_residual,
57 calculate_residual2, 59 calculate_residual2,
58 }; 60 };
59 use crate::transport::TransportLipschitz; 61 //use crate::transport::TransportLipschitz;
62
63 /// Transport settings for [`pointsource_sliding_fb_reg`].
64 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
65 #[serde(default)]
66 pub struct TransportConfig<F : Float> {
67 /// Transport step length $θ$ normalised to $(0, 1)$.
68 pub θ0 : F,
69 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
70 pub adaptation : F,
71 /// Transport tolerance wrt. ω
72 pub tolerance_ω : F,
73 /// Transport tolerance wrt. ∇v
74 pub tolerance_dv : F,
75 }
76
77 #[replace_float_literals(F::cast_from(literal))]
78 impl <F : Float> TransportConfig<F> {
79 /// Check that the parameters are ok. Panics if not.
80 pub fn check(&self) {
81 assert!(self.θ0 > 0.0);
82 assert!(0.0 < self.adaptation && self.adaptation < 1.0);
83 assert!(self.tolerance_dv > 0.0);
84 assert!(self.tolerance_ω > 0.0);
85 }
86 }
87
88 #[replace_float_literals(F::cast_from(literal))]
89 impl<F : Float> Default for TransportConfig<F> {
90 fn default() -> Self {
91 TransportConfig {
92 θ0 : 0.01,
93 adaptation : 0.9,
94 tolerance_ω : 1000.0, // TODO: no idea what this should be
95 tolerance_dv : 1000.0, // TODO: no idea what this should be
96 }
97 }
98 }
60 99
61 /// Settings for [`pointsource_sliding_fb_reg`]. 100 /// Settings for [`pointsource_sliding_fb_reg`].
62 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
63 #[serde(default)] 102 #[serde(default)]
64 pub struct SlidingFBConfig<F : Float> { 103 pub struct SlidingFBConfig<F : Float> {
65 /// Step length scaling 104 /// Step length scaling
66 pub τ0 : F, 105 pub τ0 : F,
67 /// Transport step length $θ$ normalised to $(0, 1)$. 106 /// Transport parameters
68 pub θ0 : F, 107 pub transport : TransportConfig<F>,
69 /// Maximum transport mass scaling.
70 // /// The maximum transported mass is this factor times $\norm{b}^2/(2α)$.
71 // pub max_transport_scale : F,
72 /// Transport tolerance wrt. ω
73 pub transport_tolerance_ω : F,
74 /// Transport tolerance wrt. ∇v
75 pub transport_tolerance_dv : F,
76 /// Generic parameters 108 /// Generic parameters
77 pub insertion : FBGenericConfig<F>, 109 pub insertion : FBGenericConfig<F>,
78 } 110 }
79 111
80 #[replace_float_literals(F::cast_from(literal))] 112 #[replace_float_literals(F::cast_from(literal))]
81 impl<F : Float> Default for SlidingFBConfig<F> { 113 impl<F : Float> Default for SlidingFBConfig<F> {
82 fn default() -> Self { 114 fn default() -> Self {
83 SlidingFBConfig { 115 SlidingFBConfig {
84 τ0 : 0.99, 116 τ0 : 0.99,
85 θ0 : 0.99, 117 transport : Default::default(),
86 //max_transport_scale : 10.0,
87 transport_tolerance_ω : 1.0, // TODO: no idea what this should be
88 transport_tolerance_dv : 1.0, // TODO: no idea what this should be
89 insertion : Default::default() 118 insertion : Default::default()
90 } 119 }
91 } 120 }
92 } 121 }
93 122
94 /// Scale each |γ|_i ≠ 0 by q_i=q̄/g(γ_i) 123 /// Internal type of adaptive transport step length calculation
124 pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> {
125 /// Fixed, known step length
126 Fixed(F),
127 /// Adaptive step length, only wrt. maximum transport.
128 /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
129 AdaptiveMax{ l : F, max_transport : F, g : G },
130 /// Adaptive step length.
131 /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
132 FullyAdaptive{ l : F, max_transport : F, g : G },
133 }
134
135 /// Constrution and a priori transport adaptation.
95 #[replace_float_literals(F::cast_from(literal))] 136 #[replace_float_literals(F::cast_from(literal))]
96 fn scale_down<'a, I, F, G, const N : usize>( 137 pub(crate) fn initial_transport<F, G, D, Observable, const N : usize>(
97 iter : I, 138 γ1 : &mut RNDM<F, N>,
98 q̄ : F, 139 μ : &mut RNDM<F, N>,
99 mut g : G 140 opAapply : impl Fn(&RNDM<F, N>) -> Observable,
100 ) where F : Float, 141 ε : F,
101 I : Iterator<Item = &'a mut DeltaMeasure<Loc<F,N>, F>>, 142 τ : F,
102 G : FnMut(&DeltaMeasure<Loc<F,N>, F>) -> F { 143 θ_or_adaptive : &mut TransportStepLength<F, G>,
103 iter.for_each(|δ| { 144 opAnorm : F,
104 if δ.α != 0.0 { 145 v : D,
105 let b = g(δ); 146 tconfig : &TransportConfig<F>
106 if b * δ.α > 0.0 { 147 ) -> (Vec<F>, RNDM<F, N>)
107 δ.α *= q̄/b; 148 where
108 } 149 F : Float + ToNalgebraRealField,
109 } 150 G : Fn(F, F) -> F,
110 }); 151 Observable : Euclidean<F, Output=Observable>,
152 for<'a> &'a Observable : Instance<Observable>,
153 //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
154 D : DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F, N>>,
155 {
156
157 use TransportStepLength::*;
158
159 // Save current base point and shift μ to new positions. Idea is that
160 // μ_base(_masses) = μ^k (vector of masses)
161 // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
162 // γ1 = π_♯^1γ^{k+1}
163 // μ = μ^{k+1}
164 let μ_base_masses : Vec<F> = μ.iter_masses().collect();
165 let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
166 // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
167 //let mut sum_norm_dv = 0.0;
168 let γ_prev_len = γ1.len();
169 assert!(μ.len() >= γ_prev_len);
170 γ1.extend(μ[γ_prev_len..].iter().cloned());
171
172 // Calculate initial transport and step length.
173 // First calculate initial transported weights
174 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
175 // If old transport has opposing sign, the new transport will be none.
176 ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
177 0.0
178 } else {
179 δ.α
180 };
181 };
182
183 // A priori transport adaptation based on bounding 2 ‖A‖ ‖A(γ₁-γ₀)‖‖γ‖ by scaling γ.
184 // 1. Calculate transport rays.
185 // If the Lipschitz factor of the values v=∇F(μ) are not known, estimate it.
186 match *θ_or_adaptive {
187 Fixed(θ) => {
188 let θτ = τ * θ;
189 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
190 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
191 }
192 },
193 AdaptiveMax{ l : ℓ_v, ref mut max_transport, g : ref calculate_θ } => {
194 *max_transport = max_transport.max(γ1.norm(Radon));
195 let θτ = τ * calculate_θ(ℓ_v, *max_transport);
196 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
197 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
198 }
199 },
200 FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => {
201 *max_transport = max_transport.max(γ1.norm(Radon));
202 let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport);
203 loop {
204 let θτ = τ * θ;
205 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
206 let dv_x = v.differential(&δ.x);
207 ρ.x = δ.x - &dv_x * (ρ.α.signum() * θτ);
208 // Estimate Lipschitz factor of ∇v
209 let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2();
210 *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v);
211 }
212 let new_θ = calculate_θ(*adaptive_ℓ_v / tconfig.adaptation, *max_transport);
213 if new_θ <= θ {
214 break
215 }
216 θ = new_θ;
217 }
218 }
219 }
220
221 // 2. Adjust transport mass, if needed.
222 // This tries to remove the smallest transport masses first.
223 if true {
224 // Alternative 1 : subtract same amount from all transport rays until reaching zero
225 loop {
226 let nr =γ1.norm(Radon);
227 let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2();
228 if n <= 0.0 || nr <= 0.0 {
229 break
230 }
231 let reduction_needed = nr - (ε * tconfig.tolerance_dv / n);
232 if reduction_needed <= 0.0 {
233 break
234 }
235 let (min_nonzero, n_nonzero) = γ1.iter_masses()
236 .map(|α| α.abs())
237 .filter(|α| *α > F::EPSILON)
238 .fold((F::INFINITY, 0), |(a, n), b| (a.min(b), n+1));
239 assert!(n_nonzero > 0);
240 // Reduction that can be done in all nonzero spikes simultaneously
241 let h = (reduction_needed / F::cast_from(n_nonzero)).min(min_nonzero);
242 for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) {
243 ρ.α = ρ.α.signum() * (ρ.α.abs() - h).max(0.0);
244 δ.α = ρ.α;
245 }
246 if min_nonzero * F::cast_from(n_nonzero) >= reduction_needed {
247 break
248 }
249 }
250 } else {
251 // Alternative 2: first reduce transport rays with greater effect based on differential.
252 // This is a an inefficient quick-and-dirty implementation.
253 loop {
254 let nr = γ1.norm(Radon);
255 let a = opAapply(&*γ1)-opAapply(&*μ);
256 let na = a.norm2();
257 let n = τ * 2.0 * opAnorm * na;
258 if n <= 0.0 || nr <= 0.0 {
259 break
260 }
261 let reduction_needed = nr - (ε * tconfig.tolerance_dv / n);
262 if reduction_needed <= 0.0 {
263 break
264 }
265 let mut max_d = 0.0;
266 let mut max_d_ind = 0;
267 for (δ, ρ, i) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), 0..) {
268 // Calculate differential of ‖A(γ₁-γ₀)‖‖γ‖ wrt. each spike
269 let s = δ.α.signum();
270 // TODO: this is very inefficient implementation due to the limitations
271 // of the closure parameters.
272 let δ1 = DiscreteMeasure::from([(ρ.x, s)]);
273 let δ2 = DiscreteMeasure::from([(δ.x, s)]);
274 let a_part = opAapply(&δ1)-opAapply(&δ2);
275 let d = a.dot(&a_part)/na * nr + 2.0 * na;
276 if d > max_d {
277 max_d = d;
278 max_d_ind = i;
279 }
280 }
281 // Just set mass to zero for transport ray with greater differential
282 assert!(max_d > 0.0);
283 γ1[max_d_ind].α = 0.0;
284 μ[max_d_ind].α = 0.0;
285 }
286 }
287
288 // Set initial guess for μ=μ^{k+1}.
289 for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) {
290 if ρ.α.abs() > F::EPSILON {
291 δ.x = ρ.x;
292 //δ.α = ρ.α; // already set above
293 } else {
294 δ.α = β;
295 }
296 }
297 // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b)
298 μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses())
299 .map(|(&a,b)| a - b));
300 (μ_base_masses, μ_base_minus_γ0)
301 }
302
303 /// A posteriori transport adaptation.
304 #[replace_float_literals(F::cast_from(literal))]
305 pub(crate) fn aposteriori_transport<F, const N : usize>(
306 γ1 : &mut RNDM<F, N>,
307 μ : &mut RNDM<F, N>,
308 μ_base_minus_γ0 : &mut RNDM<F, N>,
309 μ_base_masses : &Vec<F>,
310 ε : F,
311 tconfig : &TransportConfig<F>
312 ) -> bool
313 where F : Float + ToNalgebraRealField {
314
315 // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
316 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
317 // at that point to zero, and retry.
318 let mut all_ok = true;
319 for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
320 if α_μ == 0.0 && *α_γ1 != 0.0 {
321 all_ok = false;
322 *α_γ1 = 0.0;
323 }
324 }
325
326 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z).
327 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1},
328 // which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient.
329 let nγ = γ1.norm(Radon);
330 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1);
331 let t = ε * tconfig.tolerance_ω;
332 if nγ*nΔ > t {
333 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
334 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
335 // will not enter here.
336 *γ1 *= tconfig.adaptation * t / ( nγ * nΔ );
337 all_ok = false
338 }
339
340 if !all_ok {
341 // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
342 μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses())
343 .map(|(&a,b)| a - b));
344
345 }
346
347 all_ok
111 } 348 }
112 349
113 /// Iteratively solve the pointsource localisation problem using sliding forward-backward 350 /// Iteratively solve the pointsource localisation problem using sliding forward-backward
114 /// splitting 351 /// splitting
115 /// 352 ///
116 /// The parametrisatio is as for [`pointsource_fb_reg`]. 353 /// The parametrisation is as for [`pointsource_fb_reg`].
117 /// Inertia is currently not supported. 354 /// Inertia is currently not supported.
118 #[replace_float_literals(F::cast_from(literal))] 355 #[replace_float_literals(F::cast_from(literal))]
119 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>( 356 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>(
120 opA : &'a A, 357 opA : &'a A,
121 b : &A::Observable, 358 b : &A::Observable,
122 reg : Reg, 359 reg : Reg,
123 op𝒟 : &'a 𝒟, 360 op𝒟 : &'a 𝒟,
124 sfbconfig : &SlidingFBConfig<F>, 361 config : &SlidingFBConfig<F>,
125 iterator : I, 362 iterator : I,
126 mut plotter : SeqPlotter<F, N>, 363 mut plotter : SeqPlotter<F, N>,
127 ) -> DiscreteMeasure<Loc<F, N>, F> 364 ) -> RNDM<F, N>
128 where F : Float + ToNalgebraRealField, 365 where F : Float + ToNalgebraRealField,
129 I : AlgIteratorFactory<IterInfo<F, N>>, 366 I : AlgIteratorFactory<IterInfo<F, N>>,
130 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 367 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
131 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow 368 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
132 A::Observable : std::ops::MulAssign<F>, 369 A::PreadjointCodomain : DifferentiableMapping<
133 A::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output=Loc<F, N>>, 370 Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F
371 >,
134 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 372 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
135 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 373 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
136 + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>, 374 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
375 //+ TransportLipschitz<L2Squared, FloatType=F>,
137 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 376 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
138 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 377 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
139 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, 378 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>,
140 Codomain = BTFN<F, G𝒟, BT𝒟, N>>, 379 Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
141 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 380 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
142 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> 381 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
143 + Differentiable<Loc<F, N>, Output=Loc<F,N>>, 382 + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>,
144 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 383 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
145 //+ Differentiable<Loc<F, N>, Output=Loc<F,N>>, 384 //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>,
146 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 385 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
147 Cube<F, N>: P2Minimise<Loc<F, N>, F>, 386 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
148 PlotLookup : Plotting<N>, 387 PlotLookup : Plotting<N>,
149 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 388 RNDM<F, N> : SpikeMerging<F>,
150 Reg : SlidingRegTerm<F, N> { 389 Reg : SlidingRegTerm<F, N> {
151 390
152 assert!(sfbconfig.τ0 > 0.0 && 391 // Check parameters
153 sfbconfig.θ0 > 0.0); 392 assert!(config.τ0 > 0.0, "Invalid step length parameter");
154 393 config.transport.check();
155 // Set up parameters
156 let config = &sfbconfig.insertion;
157 let op𝒟norm = op𝒟.opnorm_bound();
158 //let max_transport = sfbconfig.max_transport_scale
159 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
160 //let tlip = opA.transport_lipschitz_factor(L2Squared) * max_transport;
161 //let ℓ = 0.0;
162 let θ = sfbconfig.θ0; // (ℓ + tlip);
163 let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
164 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
165 // by τ compared to the conditional gradient approach.
166 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
167 let mut ε = tolerance.initial();
168 394
169 // Initialise iterates 395 // Initialise iterates
170 let mut μ = DiscreteMeasure::new(); 396 let mut μ = DiscreteMeasure::new();
171 let mut γ1 = DiscreteMeasure::new(); 397 let mut γ1 = DiscreteMeasure::new();
172 let mut residual = -b; 398 let mut residual = -b; // Has to equal $Aμ-b$.
399
400 // Set up parameters
401 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
402 let opAnorm = opA.opnorm_bound(Radon, L2);
403 //let max_transport = config.max_transport.scale
404 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
405 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
406 let ℓ = 0.0;
407 let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap();
408 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v));
409 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() {
410 // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v
411 // (the uniform Lipschitz factor of ∇v).
412 // We assume that the residual is decreasing.
413 Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * residual.norm2(), 0.0)),
414 None => TransportStepLength::FullyAdaptive {
415 l : 0.0,
416 max_transport : 0.0,
417 g : calculate_θ
418 },
419 };
420 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
421 // by τ compared to the conditional gradient approach.
422 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
423 let mut ε = tolerance.initial();
424
425 // Statistics
426 let full_stats = |residual : &A::Observable,
427 μ : &RNDM<F, N>,
428 ε, stats| IterInfo {
429 value : residual.norm2_squared_div2() + reg.apply(μ),
430 n_spikes : μ.len(),
431 ε,
432 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
433 .. stats
434 };
173 let mut stats = IterInfo::new(); 435 let mut stats = IterInfo::new();
174 436
175 // Run the algorithm 437 // Run the algorithm
176 iterator.iterate(|state| { 438 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
177 // Calculate smooth part of surrogate model. 439 // Calculate initial transport
178 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` 440 let v = opA.preadjoint().apply(residual);
179 // has no significant overhead. For some reosn Rust doesn't allow us simply moving 441 let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport(
180 // the residual and replacing it below before the end of this closure. 442 &mut γ1, &mut μ, |ν| opA.apply(ν),
181 let r = std::mem::replace(&mut residual, opA.empty_observable()); 443 ε, τ, &mut θ_or_adaptive, opAnorm,
182 let v = opA.preadjoint().apply(r); 444 v, &config.transport,
183 445 );
184 // Save current base point and shift μ to new positions. Idea is that
185 // μ_base(_masses) = μ^k (vector of masses)
186 // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
187 // γ1 = π_♯^1γ^{k+1}
188 // μ = μ^{k+1}
189 let μ_base_masses : Vec<F> = μ.iter_masses().collect();
190 let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
191 // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
192 let mut sum_norm_dv_times_γinit = 0.0;
193 let mut sum_abs_γinit = 0.0;
194 //let mut sum_norm_dv = 0.0;
195 let γ_prev_len = γ1.len();
196 assert!(μ.len() >= γ_prev_len);
197 γ1.extend(μ[γ_prev_len..].iter().cloned());
198 for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) {
199 let d_v_x = v.differential(&δ.x);
200 // If old transport has opposing sign, the new transport will be none.
201 ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
202 0.0
203 } else {
204 δ.α
205 };
206 δ.x -= d_v_x * (θ * δ.α.signum()); // This is δ.α.signum() when δ.α ≠ 0.
207 ρ.x = δ.x;
208 let nrm = d_v_x.norm(L2);
209 let a = ρ.α.abs();
210 let v = nrm * a;
211 if v > 0.0 {
212 sum_norm_dv_times_γinit += v;
213 sum_abs_γinit += a;
214 }
215 }
216
217 // A priori transport adaptation based on bounding ∫ ⟨∇v(x), z-y⟩ dλ(x, y, z).
218 // This is just one option, there are many.
219 let t = ε * sfbconfig.transport_tolerance_dv;
220 if sum_norm_dv_times_γinit > t {
221 // Scale each |γ|_i by q_i=q̄/‖vx‖_i such that ∑_i |γ|_i q_i ‖vx‖_i = t
222 // TODO: store the closure values above?
223 scale_down(γ1.iter_spikes_mut(),
224 t / sum_abs_γinit,
225 |δ| v.differential(&δ.x).norm(L2));
226 }
227 //println!("|γ| = {}, |μ| = {}", γ1.norm(crate::measures::Radon), μ.norm(crate::measures::Radon));
228 446
229 // Solve finite-dimensional subproblem several times until the dual variable for the 447 // Solve finite-dimensional subproblem several times until the dual variable for the
230 // regularisation term conforms to the assumptions made for the transport above. 448 // regularisation term conforms to the assumptions made for the transport above.
231 let (d, within_tolerances) = 'adapt_transport: loop { 449 let (d, _within_tolerances, τv̆) = 'adapt_transport: loop {
232 // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} 450 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
233 for (δ_γ1, δ_μ_base_minus_γ0, &α_μ_base) in izip!(γ1.iter_spikes(),
234 μ_base_minus_γ0.iter_spikes_mut(),
235 μ_base_masses.iter()) {
236 δ_μ_base_minus_γ0.set_mass(α_μ_base - δ_γ1.get_mass());
237 }
238
239 // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b)
240 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); 451 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
241 let transported_minus_τv̆ = opA.preadjoint().apply(residual_μ̆ * (-τ)); 452 let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
242 453
243 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 454 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
244 let (d, within_tolerances) = insert_and_reweigh( 455 let (d, within_tolerances) = insert_and_reweigh(
245 &mut μ, &transported_minus_τv̆, &γ1, Some(&μ_base_minus_γ0), 456 &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0),
246 op𝒟, op𝒟norm, 457 op𝒟, op𝒟norm,
247 τ, ε, 458 τ, ε, &config.insertion,
248 config, 459 &reg, &state, &mut stats,
249 &reg, state, &mut stats,
250 ); 460 );
251 461
252 // A posteriori transport adaptation based on bounding (1/τ)∫ ω(z) - ω(y) dλ(x, y, z). 462 // A posteriori transport adaptation.
253 let all_ok = if false { // Basic check 463 if aposteriori_transport(
254 // If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not, 464 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
255 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1 465 ε, &config.transport
256 // at that point to zero, and retry. 466 ) {
257 let mut all_ok = true; 467 break 'adapt_transport (d, within_tolerances, τv̆)
258 for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
259 if α_μ == 0.0 && *α_γ1 != 0.0 {
260 all_ok = false;
261 *α_γ1 = 0.0;
262 }
263 }
264 all_ok
265 } else {
266 // TODO: Could maybe optimise, as this is also formed in insert_and_reweigh above.
267 let mut minus_ω = op𝒟.apply(γ1.sub_matching(&μ) + &μ_base_minus_γ0);
268
269 // let vpos = γ1.iter_spikes()
270 // .filter(|δ| δ.α > 0.0)
271 // .map(|δ| minus_ω.apply(&δ.x))
272 // .reduce(F::max)
273 // .and_then(|threshold| {
274 // minus_ω.minimise_below(threshold,
275 // ε * config.refinement.tolerance_mult,
276 // config.refinement.max_steps)
277 // .map(|(_z, minus_ω_z)| minus_ω_z)
278 // });
279
280 // let vneg = γ1.iter_spikes()
281 // .filter(|δ| δ.α < 0.0)
282 // .map(|δ| minus_ω.apply(&δ.x))
283 // .reduce(F::min)
284 // .and_then(|threshold| {
285 // minus_ω.maximise_above(threshold,
286 // ε * config.refinement.tolerance_mult,
287 // config.refinement.max_steps)
288 // .map(|(_z, minus_ω_z)| minus_ω_z)
289 // });
290 let (_, vpos) = minus_ω.minimise(ε * config.refinement.tolerance_mult,
291 config.refinement.max_steps);
292 let (_, vneg) = minus_ω.maximise(ε * config.refinement.tolerance_mult,
293 config.refinement.max_steps);
294
295 let t = τ * ε * sfbconfig.transport_tolerance_ω;
296 let val = |δ : &DeltaMeasure<Loc<F, N>, F>| {
297 δ.α * (minus_ω.apply(&δ.x) - if δ.α >= 0.0 { vpos } else { vneg })
298 // match if δ.α >= 0.0 { vpos } else { vneg } {
299 // None => 0.0,
300 // Some(v) => δ.α * (minus_ω.apply(&δ.x) - v)
301 // }
302 };
303 // Calculate positive/bad (rp) values under the integral.
304 // Also store sum of masses for the positive entries.
305 let (rp, w) = γ1.iter_spikes().fold((0.0, 0.0), |(p, w), δ| {
306 let v = val(δ);
307 if v <= 0.0 { (p, w) } else { (p + v, w + δ.α.abs()) }
308 });
309
310 if rp > t {
311 // TODO: store v above?
312 scale_down(γ1.iter_spikes_mut(), t / w, val);
313 false
314 } else {
315 true
316 }
317 };
318
319 if all_ok {
320 break 'adapt_transport (d, within_tolerances)
321 } 468 }
322 }; 469 };
323 470
324 stats.untransported_fraction = Some({ 471 stats.untransported_fraction = Some({
325 assert_eq!(μ_base_masses.len(), γ1.len()); 472 assert_eq!(μ_base_masses.len(), γ1.len());
328 (a + μ_base_minus_γ0.norm(Radon), b + source) 475 (a + μ_base_minus_γ0.norm(Radon), b + source)
329 }); 476 });
330 stats.transport_error = Some({ 477 stats.transport_error = Some({
331 assert_eq!(μ_base_masses.len(), γ1.len()); 478 assert_eq!(μ_base_masses.len(), γ1.len());
332 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); 479 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
333 let err = izip!(μ.iter_masses(), γ1.iter_masses()).map(|(v,w)| (v-w).abs()).sum(); 480 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
334 (a + err, b + γ1.norm(Radon))
335 }); 481 });
482
483 // Merge spikes.
484 // This expects the prune below to prune γ.
485 // TODO: This may not work correctly in all cases.
486 let ins = &config.insertion;
487 if ins.merge_now(&state) {
488 if let SpikeMergingMethod::None = ins.merging {
489 } else {
490 stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| {
491 let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0;
492 let mut d = &τv̆ + op𝒟.preapply(ν);
493 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
494 });
495 }
496 }
336 497
337 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the 498 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
338 // latter needs to be pruned when μ is. 499 // latter needs to be pruned when μ is.
339 // TODO: This could do with a two-vector Vec::retain to avoid copies. 500 // TODO: This could do with a two-vector Vec::retain to avoid copies.
340 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); 501 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
341 if μ_new.len() != μ.len() { 502 if μ_new.len() != μ.len() {
342 let mut μ_iter = μ.iter_spikes(); 503 let mut μ_iter = μ.iter_spikes();
343 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); 504 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
505 stats.pruned += μ.len() - μ_new.len();
344 μ = μ_new; 506 μ = μ_new;
345 } 507 }
346
347 // TODO: how to merge?
348 508
349 // Update residual 509 // Update residual
350 residual = calculate_residual(&μ, opA, b); 510 residual = calculate_residual(&μ, opA, b);
351 511
512 let iter = state.iteration();
513 stats.this_iters += 1;
514
515 // Give statistics if requested
516 state.if_verbose(|| {
517 plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ);
518 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
519 });
520
352 // Update main tolerance for next iteration 521 // Update main tolerance for next iteration
353 let ε_prev = ε; 522 ε = tolerance.update(ε, iter);
354 ε = tolerance.update(ε, state.iteration()); 523 }
355 stats.this_iters += 1; 524
356 525 postprocess(μ, &config.insertion, L2Squared, opA, b)
357 // Give function value if needed 526 }
358 state.if_verbose(|| {
359 // Plot if so requested
360 plotter.plot_spikes(
361 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
362 "start".to_string(), None::<&A::PreadjointCodomain>, // TODO: Should be Some(&((-τ) * v)), but not implemented
363 reg.target_bounds(τ, ε_prev), &μ,
364 );
365 // Calculate mean inner iterations and reset relevant counters.
366 // Return the statistics
367 let res = IterInfo {
368 value : residual.norm2_squared_div2() + reg.apply(&μ),
369 n_spikes : μ.len(),
370 ε : ε_prev,
371 postprocessing: config.postprocessing.then(|| μ.clone()),
372 .. stats
373 };
374 stats = IterInfo::new();
375 res
376 })
377 });
378
379 postprocess(μ, config, L2Squared, opA, b)
380 }

mercurial