src/sliding_fb.rs

branch
dev
changeset 34
efa60bc4f743
parent 32
56c8adc32b09
child 35
b087e3eab191
equal deleted inserted replaced
33:aec67cdd6b14 34:efa60bc4f743
6 use numeric_literals::replace_float_literals; 6 use numeric_literals::replace_float_literals;
7 use serde::{Serialize, Deserialize}; 7 use serde::{Serialize, Deserialize};
8 //use colored::Colorize; 8 //use colored::Colorize;
9 //use nalgebra::{DVector, DMatrix}; 9 //use nalgebra::{DVector, DMatrix};
10 use itertools::izip; 10 use itertools::izip;
11 use std::iter::{Map, Flatten}; 11 use std::iter::Iterator;
12 12
13 use alg_tools::iterate::{ 13 use alg_tools::iterate::{
14 AlgIteratorFactory, 14 AlgIteratorFactory,
15 AlgIteratorState 15 AlgIteratorState
16 }; 16 };
17 use alg_tools::euclidean::{ 17 use alg_tools::euclidean::Euclidean;
18 Euclidean,
19 Dot
20 };
21 use alg_tools::sets::Cube; 18 use alg_tools::sets::Cube;
22 use alg_tools::loc::Loc; 19 use alg_tools::loc::Loc;
23 use alg_tools::mapping::{Apply, Differentiable}; 20 use alg_tools::mapping::{Apply, Differentiable};
21 use alg_tools::norms::{Norm, L2};
24 use alg_tools::bisection_tree::{ 22 use alg_tools::bisection_tree::{
25 BTFN, 23 BTFN,
26 PreBTFN, 24 PreBTFN,
27 Bounds, 25 Bounds,
28 BTNodeLookup, 26 BTNodeLookup,
35 }; 33 };
36 use alg_tools::mapping::RealMapping; 34 use alg_tools::mapping::RealMapping;
37 use alg_tools::nalgebra_support::ToNalgebraRealField; 35 use alg_tools::nalgebra_support::ToNalgebraRealField;
38 36
39 use crate::types::*; 37 use crate::types::*;
40 use crate::measures::{ 38 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon};
41 DiscreteMeasure,
42 DeltaMeasure,
43 };
44 use crate::measures::merging::{ 39 use crate::measures::merging::{
45 //SpikeMergingMethod, 40 //SpikeMergingMethod,
46 SpikeMerging, 41 SpikeMerging,
47 }; 42 };
48 use crate::forward_model::ForwardModel; 43 use crate::forward_model::ForwardModel;
67 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 62 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
68 #[serde(default)] 63 #[serde(default)]
69 pub struct SlidingFBConfig<F : Float> { 64 pub struct SlidingFBConfig<F : Float> {
70 /// Step length scaling 65 /// Step length scaling
71 pub τ0 : F, 66 pub τ0 : F,
72 /// Transport smoothness assumption 67 /// Transport step length $θ$ normalised to $(0, 1)$.
73 pub ℓ0 : F, 68 pub θ0 : F,
74 /// Inverse of the scaling factor $θ$ of the 2-norm-squared transport cost. 69 /// Maximum transport mass scaling.
75 /// This means that $τθ$ is the step length for the transport step. 70 // /// The maximum transported mass is this factor times $\norm{b}^2/(2α)$.
76 pub inverse_transport_scaling : F, 71 // pub max_transport_scale : F,
77 /// Factor for deciding transport reduction based on smoothness assumption violation 72 /// Transport tolerance wrt. ω
78 pub minimum_goodness_factor : F, 73 pub transport_tolerance_ω : F,
79 /// Maximum rays to retain in transports from each source. 74 /// Transport tolerance wrt. ∇v
80 pub maximum_rays : usize, 75 pub transport_tolerance_dv : F,
81 /// Generic parameters 76 /// Generic parameters
82 pub insertion : FBGenericConfig<F>, 77 pub insertion : FBGenericConfig<F>,
83 } 78 }
84 79
85 #[replace_float_literals(F::cast_from(literal))] 80 #[replace_float_literals(F::cast_from(literal))]
86 impl<F : Float> Default for SlidingFBConfig<F> { 81 impl<F : Float> Default for SlidingFBConfig<F> {
87 fn default() -> Self { 82 fn default() -> Self {
88 SlidingFBConfig { 83 SlidingFBConfig {
89 τ0 : 0.99, 84 τ0 : 0.99,
90 ℓ0 : 1.5, 85 θ0 : 0.99,
91 inverse_transport_scaling : 1.0, 86 //max_transport_scale : 10.0,
92 minimum_goodness_factor : 1.0, // TODO: totally arbitrary choice, 87 transport_tolerance_ω : 1.0, // TODO: no idea what this should be
93 // should be scaled by problem data? 88 transport_tolerance_dv : 1.0, // TODO: no idea what this should be
94 maximum_rays : 10,
95 insertion : Default::default() 89 insertion : Default::default()
96 } 90 }
97 } 91 }
98 } 92 }
99 93
100 /// A transport ray (including various additional computational information). 94 /// Scale each |γ|_i ≠ 0 by q_i=q̄/g(γ_i)
101 #[derive(Clone, Debug)]
102 pub struct Ray<Domain, F : Num> {
103 /// The destination of the ray, and the mass. The source is indicated in a [`RaySet`].
104 δ : DeltaMeasure<Domain, F>,
105 /// Goodness of the data term for the aray: $v(z)-v(y)-⟨∇v(x), z-y⟩ + ℓ‖z-y‖^2$.
106 goodness : F,
107 /// Goodness of the regularisation term for the ray: $w(z)-w(y)$.
108 /// Initially zero until $w$ can be constructed.
109 reg_goodness : F,
110 /// Indicates that this ray also forms a component in γ^{k+1} with the mass `to_return`.
111 to_return : F,
112 }
113
114 /// A set of transport rays with the same source point.
115 #[derive(Clone, Debug)]
116 pub struct RaySet<Domain, F : Num> {
117 /// Source of every ray in thset
118 source : Domain,
119 /// Mass of the diagonal ray, with destination the same as the source.
120 diagonal: F,
121 /// Goodness of the data term for the diagonal ray with $z=x$:
122 /// $v(x)-v(y)-⟨∇v(x), x-y⟩ + ℓ‖x-y‖^2$.
123 diagonal_goodness : F,
124 /// Goodness of the data term for the diagonal ray with $z=x$: $w(x)-w(y)$.
125 diagonal_reg_goodness : F,
126 /// The non-diagonal rays.
127 rays : Vec<Ray<Domain, F>>,
128 }
129
130 #[replace_float_literals(F::cast_from(literal))] 95 #[replace_float_literals(F::cast_from(literal))]
131 impl<Domain, F : Float> RaySet<Domain, F> { 96 fn scale_down<'a, I, F, G, const N : usize>(
132 fn non_diagonal_mass(&self) -> F { 97 iter : I,
133 self.rays 98 q̄ : F,
134 .iter() 99 mut g : G
135 .map(|Ray{ δ : DeltaMeasure{ α, .. }, .. }| *α) 100 ) where F : Float,
136 .sum() 101 I : Iterator<Item = &'a mut DeltaMeasure<Loc<F,N>, F>>,
137 } 102 G : FnMut(&DeltaMeasure<Loc<F,N>, F>) -> F {
138 103 iter.for_each(|δ| {
139 fn total_mass(&self) -> F { 104 if δ.α != 0.0 {
140 self.non_diagonal_mass() + self.diagonal 105 let b = g(δ);
141 } 106 if b * δ.α > 0.0 {
142 107 δ.α *= q̄/b;
143 fn targets<'a>(&'a self) 108 }
144 -> Map< 109 }
145 std::slice::Iter<'a, Ray<Domain, F>>, 110 });
146 fn(&'a Ray<Domain, F>) -> &'a DeltaMeasure<Domain, F>
147 > {
148 fn get_δ<'b, Domain, F : Float>(Ray{ δ, .. }: &'b Ray<Domain, F>)
149 -> &'b DeltaMeasure<Domain, F> {
150 δ
151 }
152 self.rays
153 .iter()
154 .map(get_δ)
155 }
156
157 // fn non_diagonal_goodness(&self) -> F {
158 // self.rays
159 // .iter()
160 // .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| {
161 // α * (goodness + reg_goodness)
162 // })
163 // .sum()
164 // }
165
166 // fn total_goodness(&self) -> F {
167 // self.non_diagonal_goodness() + (self.diagonal_goodness + self.diagonal_reg_goodness)
168 // }
169
170 fn non_diagonal_badness(&self) -> F {
171 self.rays
172 .iter()
173 .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| {
174 0.0.max(- α * (goodness + reg_goodness))
175 })
176 .sum()
177 }
178
179 fn total_badness(&self) -> F {
180 self.non_diagonal_badness()
181 + 0.0.max(- self.diagonal * (self.diagonal_goodness + self.diagonal_reg_goodness))
182 }
183
184 fn total_return(&self) -> F {
185 self.rays
186 .iter()
187 .map(|&Ray{ to_return, .. }| to_return)
188 .sum()
189 }
190 }
191
192 #[replace_float_literals(F::cast_from(literal))]
193 impl<Domain : Clone, F : Num> RaySet<Domain, F> {
194 fn return_targets<'a>(&'a self)
195 -> Flatten<Map<
196 std::slice::Iter<'a, Ray<Domain, F>>,
197 fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>>
198 >> {
199 fn get_return<'b, Domain : Clone, F : Num>(ray: &'b Ray<Domain, F>)
200 -> Option<DeltaMeasure<Domain, F>> {
201 (ray.to_return != 0.0).then_some(
202 DeltaMeasure{x : ray.δ.x.clone(), α : ray.to_return}
203 )
204 }
205 let tmp : Map<
206 std::slice::Iter<'a, Ray<Domain, F>>,
207 fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>>
208 > = self.rays
209 .iter()
210 .map(get_return);
211 tmp.flatten()
212 }
213 } 111 }
214 112
215 /// Iteratively solve the pointsource localisation problem using sliding forward-backward 113 /// Iteratively solve the pointsource localisation problem using sliding forward-backward
216 /// splitting 114 /// splitting
217 /// 115 ///
218 /// The parametrisatio is as for [`pointsource_fb_reg`]. 116 /// The parametrisatio is as for [`pointsource_fb_reg`].
219 /// Inertia is currently not supported. 117 /// Inertia is currently not supported.
220 #[replace_float_literals(F::cast_from(literal))] 118 #[replace_float_literals(F::cast_from(literal))]
221 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>( 119 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>(
222 opA : &'a A, 120 opA : &'a A,
223 b : &A::Observable, 121 b : &A::Observable,
224 reg : Reg, 122 reg : Reg,
225 op𝒟 : &'a 𝒟, 123 op𝒟 : &'a 𝒟,
226 sfbconfig : &SlidingFBConfig<F>, 124 sfbconfig : &SlidingFBConfig<F>,
236 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 134 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
237 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 135 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
238 + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>, 136 + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>,
239 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 137 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
240 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 138 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
241 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 139 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>,
242 𝒟::Codomain : RealMapping<F, N>, 140 Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
141 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
243 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> 142 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
244 + Differentiable<Loc<F, N>, Output=Loc<F,N>>, 143 + Differentiable<Loc<F, N>, Output=Loc<F,N>>,
245 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 144 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
246 //+ Differentiable<Loc<F, N>, Output=Loc<F,N>>, 145 //+ Differentiable<Loc<F, N>, Output=Loc<F,N>>,
247 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 146 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
249 PlotLookup : Plotting<N>, 148 PlotLookup : Plotting<N>,
250 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 149 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
251 Reg : SlidingRegTerm<F, N> { 150 Reg : SlidingRegTerm<F, N> {
252 151
253 assert!(sfbconfig.τ0 > 0.0 && 152 assert!(sfbconfig.τ0 > 0.0 &&
254 sfbconfig.inverse_transport_scaling > 0.0 && 153 sfbconfig.θ0 > 0.0);
255 sfbconfig.ℓ0 > 0.0);
256 154
257 // Set up parameters 155 // Set up parameters
258 let config = &sfbconfig.insertion; 156 let config = &sfbconfig.insertion;
259 let op𝒟norm = op𝒟.opnorm_bound(); 157 let op𝒟norm = op𝒟.opnorm_bound();
260 let θ = sfbconfig.inverse_transport_scaling; 158 //let max_transport = sfbconfig.max_transport_scale
261 let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap() 159 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
262 .max(opA.transport_lipschitz_factor(L2Squared) * θ); 160 //let tlip = opA.transport_lipschitz_factor(L2Squared) * max_transport;
263 let ℓ = sfbconfig.ℓ0; // TODO: v scaling? 161 //let ℓ = 0.0;
162 let θ = sfbconfig.θ0; // (ℓ + tlip);
163 let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
264 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 164 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
265 // by τ compared to the conditional gradient approach. 165 // by τ compared to the conditional gradient approach.
266 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 166 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
267 let mut ε = tolerance.initial(); 167 let mut ε = tolerance.initial();
268 168
269 // Initialise iterates 169 // Initialise iterates
270 let mut μ : DiscreteMeasure<Loc<F, N>, F> = DiscreteMeasure::new(); 170 let mut μ = DiscreteMeasure::new();
271 let mut μ_transported_base = DiscreteMeasure::new(); 171 let mut γ1 = DiscreteMeasure::new();
272 let mut γ_hat : Vec<RaySet<Loc<F, N>, F>> = Vec::new(); // γ̂_k and extra info
273 let mut residual = -b; 172 let mut residual = -b;
274 let mut stats = IterInfo::new(); 173 let mut stats = IterInfo::new();
275 174
276 // Run the algorithm 175 // Run the algorithm
277 iterator.iterate(|state| { 176 iterator.iterate(|state| {
278 // Calculate smooth part of surrogate model. 177 // Calculate smooth part of surrogate model.
279 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` 178 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
280 // has no significant overhead. For some reosn Rust doesn't allow us simply moving 179 // has no significant overhead. For some reosn Rust doesn't allow us simply moving
281 // the residual and replacing it below before the end of this closure. 180 // the residual and replacing it below before the end of this closure.
282 residual *= -τ;
283 let r = std::mem::replace(&mut residual, opA.empty_observable()); 181 let r = std::mem::replace(&mut residual, opA.empty_observable());
284 let minus_τv = opA.preadjoint().apply(r); 182 let v = opA.preadjoint().apply(r);
285 183
286 // Save current base point and shift μ to new positions. 184 // Save current base point and shift μ to new positions. Idea is that
287 let μ_base = μ.clone(); 185 // μ_base(_masses) = μ^k (vector of masses)
288 for δ in μ.iter_spikes_mut() { 186 // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
289 δ.x += minus_τv.differential(&δ.x) * θ; 187 // γ1 = π_♯^1γ^{k+1}
290 } 188 // μ = μ^{k+1}
291 let mut μ_transported = μ.clone(); 189 let μ_base_masses : Vec<F> = μ.iter_masses().collect();
292 190 let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
293 assert_eq!(μ.len(), γ_hat.len()); 191 // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
294 192 let mut sum_norm_dv_times_γinit = 0.0;
295 // Calculate the goodness λ formed from γ_hat (≈ γ̂_k) and γ^{k+1}, where the latter 193 let mut sum_abs_γinit = 0.0;
296 // transports points x from μ_base to points y in μ as shifted above, or “returns” 194 //let mut sum_norm_dv = 0.0;
297 // them “home” to z given by the rays in γ_hat. Returning is necessary if the rays 195 let γ_prev_len = γ1.len();
298 // are not “good” for the smoothness assumptions, or if γ_hat has more mass than 196 assert!(μ.len() >= γ_prev_len);
299 // μ_base. 197 γ1.extend(μ[γ_prev_len..].iter().cloned());
300 let mut total_goodness = 0.0; // data term goodness 198 for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) {
301 let mut total_reg_goodness = 0.0; // regulariser goodness 199 let d_v_x = v.differential(&δ.x);
302 let minimum_goodness = - ε * sfbconfig.minimum_goodness_factor; 200 // If old transport has opposing sign, the new transport will be none.
303 201 ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
304 for (δ, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) { 202 0.0
305 // Calculate data term goodness for all rays. 203 } else {
306 let &DeltaMeasure{ x : ref y, α : δ_mass } = δ; 204 δ.α
307 let x = &r.source;
308 let mvy = minus_τv.apply(y);
309 let mdvx = minus_τv.differential(x);
310 let mut r_total_mass = 0.0; // Total mass of all rays with source r.source.
311 let mut bad_mass = 0.0;
312 let mut calc_goodness = |goodness : &mut F, reg_goodness : &mut F, α, z : &Loc<F, N>| {
313 *reg_goodness = 0.0; // Initial guess
314 *goodness = mvy - minus_τv.apply(z) + mdvx.dot(&(z-y))
315 + ℓ * z.dist2_squared(&y);
316 total_goodness += *goodness * α;
317 r_total_mass += α; // TODO: should this include to_return from staging? (Probably not)
318 if *goodness < 0.0 {
319 bad_mass += α;
320 }
321 }; 205 };
322 for ray in r.rays.iter_mut() { 206 δ.x -= d_v_x * (θ * δ.α.signum()); // This is δ.α.signum() when δ.α ≠ 0.
323 calc_goodness(&mut ray.goodness, &mut ray.reg_goodness, ray.δ.α, &ray.δ.x); 207 ρ.x = δ.x;
208 let nrm = d_v_x.norm(L2);
209 let a = ρ.α.abs();
210 let v = nrm * a;
211 if v > 0.0 {
212 sum_norm_dv_times_γinit += v;
213 sum_abs_γinit += a;
324 } 214 }
325 calc_goodness(&mut r.diagonal_goodness, &mut r.diagonal_reg_goodness, r.diagonal, x); 215 }
326 216
327 // If the total mass of the ray set is less than that of μ at the same source, 217 // A priori transport adaptation based on bounding ∫ ⟨∇v(x), z-y⟩ dλ(x, y, z).
328 // a diagonal component needs to be added to be able to (attempt to) transport 218 // This is just one option, there are many.
329 // all mass of μ. In the opposite case, we need to construct γ_{k+1} to ‘return’ 219 let t = ε * sfbconfig.transport_tolerance_dv;
330 // the the extra mass of γ̂_k to the target z. We return mass from the oldest “bad” 220 if sum_norm_dv_times_γinit > t {
331 // rays in the set. 221 // Scale each |γ|_i by q_i=q̄/‖vx‖_i such that ∑_i |γ|_i q_i ‖vx‖_i = t
332 if δ_mass >= r_total_mass { 222 // TODO: store the closure values above?
333 r.diagonal += δ_mass - r_total_mass; 223 scale_down(γ1.iter_spikes_mut(),
334 } else { 224 t / sum_abs_γinit,
335 let mut reduce_transport = r_total_mass - δ_mass; 225 |δ| v.differential(&δ.x).norm(L2));
336 let mut good_needed = (bad_mass - reduce_transport).max(0.0); 226 }
337 // NOTE: reg_goodness is zero at this point, so it is not used in this code. 227 //println!("|γ| = {}, |μ| = {}", γ1.norm(crate::measures::Radon), μ.norm(crate::measures::Radon));
338 let mut reduce_ray = |goodness, to_return : Option<&mut F>, α : &mut F| {
339 if reduce_transport > 0.0 {
340 let return_amount = if goodness < 0.0 {
341 α.min(reduce_transport)
342 } else {
343 let amount = α.min(good_needed);
344 good_needed -= amount;
345 amount
346 };
347
348 if return_amount > 0.0 {
349 reduce_transport -= return_amount;
350 // Adjust total goodness by returned amount
351 total_goodness -= goodness * return_amount;
352 to_return.map(|tr| *tr += return_amount);
353 *α -= return_amount;
354 *α > 0.0
355 } else {
356 true
357 }
358 } else {
359 true
360 }
361 };
362 r.rays.retain_mut(|ray| {
363 reduce_ray(ray.goodness, Some(&mut ray.to_return), &mut ray.δ.α)
364 });
365 // A bad diagonal is simply reduced without any 'return'.
366 // It was, after all, just added to match μ, but there is no need to match it.
367 // It's just a heuristic.
368 // TODO: Maybe a bad diagonal should be the first to go.
369 reduce_ray(r.diagonal_goodness, None, &mut r.diagonal);
370 }
371 }
372 228
373 // Solve finite-dimensional subproblem several times until the dual variable for the 229 // Solve finite-dimensional subproblem several times until the dual variable for the
374 // regularisation term conforms to the assumptions made for the transport above. 230 // regularisation term conforms to the assumptions made for the transport above.
375 let (d, within_tolerances) = 'adapt_transport: loop { 231 let (d, within_tolerances) = 'adapt_transport: loop {
376 // If transport violates goodness requirements, shift it to ‘return’ mass to z, 232 // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
377 // forcing y = z. Based on the badness of each ray set (sum of bad rays' goodness), 233 for (δ_γ1, δ_μ_base_minus_γ0, &α_μ_base) in izip!(γ1.iter_spikes(),
378 // we proportionally distribute the reductions to each ray set, and within each ray 234 μ_base_minus_γ0.iter_spikes_mut(),
379 // set, prioritise reducing the oldest bad rays' weight. 235 μ_base_masses.iter()) {
380 let tg = total_goodness + total_reg_goodness; 236 δ_μ_base_minus_γ0.set_mass(α_μ_base - δ_γ1.get_mass());
381 let adaptation_needed = minimum_goodness - tg; 237 }
382 if adaptation_needed > 0.0 { 238
383 let total_badness = γ_hat.iter().map(|r| r.total_badness()).sum(); 239 // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b)
384 240 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
385 let mut return_ray = |goodness : F, 241 let transported_minus_τv̆ = opA.preadjoint().apply(residual_μ̆ * (-τ));
386 reg_goodness : F, 242
387 to_return : Option<&mut F>, 243 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
388 α : &mut F, 244 let (d, within_tolerances) = insert_and_reweigh(
389 left_to_return : &mut F| { 245 &mut μ, &transported_minus_τv̆, &γ1, Some(&μ_base_minus_γ0),
390 let g = goodness + reg_goodness; 246 op𝒟, op𝒟norm,
391 assert!(*α >= 0.0 && *left_to_return >= 0.0); 247 τ, ε,
392 if *left_to_return > 0.0 && g < 0.0 { 248 config,
393 let return_amount = (*left_to_return / (-g)).min(*α); 249 &reg, state, &mut stats,
394 *left_to_return -= (-g) * return_amount; 250 );
395 total_goodness -= goodness * return_amount; 251
396 total_reg_goodness -= reg_goodness * return_amount; 252 // A posteriori transport adaptation based on bounding (1/τ)∫ ω(z) - ω(y) dλ(x, y, z).
397 to_return.map(|tr| *tr += return_amount); 253 let all_ok = if false { // Basic check
398 *α -= return_amount; 254 // If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
399 *α > 0.0 255 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
400 } else { 256 // at that point to zero, and retry.
401 true 257 let mut all_ok = true;
402 } 258 for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
403 }; 259 if α_μ == 0.0 && *α_γ1 != 0.0 {
404 260 all_ok = false;
405 for r in γ_hat.iter_mut() { 261 *α_γ1 = 0.0;
406 let mut left_to_return = adaptation_needed * r.total_badness() / total_badness;
407 if left_to_return > 0.0 {
408 for ray in r.rays.iter_mut() {
409 return_ray(ray.goodness, ray.reg_goodness,
410 Some(&mut ray.to_return), &mut ray.δ.α, &mut left_to_return);
411 }
412 return_ray(r.diagonal_goodness, r.diagonal_reg_goodness,
413 None, &mut r.diagonal, &mut left_to_return);
414 } 262 }
415 } 263 }
416 } 264 all_ok
417 265 } else {
418 // Construct μ_k + (π_#^1-π_#^0)γ_{k+1}. 266 // TODO: Could maybe optimise, as this is also formed in insert_and_reweigh above.
419 // This can be broken down into 267 let mut minus_ω = op𝒟.apply(γ1.sub_matching(&μ) + &μ_base_minus_γ0);
420 // 268
421 // μ_transported_base = [μ - π_#^0 (γ_shift + γ_return)] + π_#^1 γ_return, and 269 // let vpos = γ1.iter_spikes()
422 // μ_transported = π_#^1 γ_shift 270 // .filter(|δ| δ.α > 0.0)
423 // 271 // .map(|δ| minus_ω.apply(&δ.x))
424 // where γ_shift is our “true” γ_{k+1}, and γ_return is the return compoennt. 272 // .reduce(F::max)
425 // The former can be constructed from δ.x and δ_new.x for δ in μ_base and δ_new in μ 273 // .and_then(|threshold| {
426 // (which has already been shifted), and the mass stored in a γ_hat ray's δ measure 274 // minus_ω.minimise_below(threshold,
427 // The latter can be constructed from γ_hat rays' source and destination with the 275 // ε * config.refinement.tolerance_mult,
428 // to_return mass. 276 // config.refinement.max_steps)
429 // 277 // .map(|(_z, minus_ω_z)| minus_ω_z)
430 // Note that μ_transported is constructed to have the same spike locations as μ, but 278 // });
431 // to have same length as μ_base. This loop does not iterate over the spikes of μ 279
432 // (and corresponding transports of γ_hat) that have been newly added in the current 280 // let vneg = γ1.iter_spikes()
433 // 'adapt_transport loop. 281 // .filter(|δ| δ.α < 0.0)
434 for (δ, δ_transported, r) in izip!(μ_base.iter_spikes(), 282 // .map(|δ| minus_ω.apply(&δ.x))
435 μ_transported.iter_spikes_mut(), 283 // .reduce(F::min)
436 γ_hat.iter()) { 284 // .and_then(|threshold| {
437 let &DeltaMeasure{ref x, α} = δ; 285 // minus_ω.maximise_above(threshold,
438 debug_assert_eq!(*x, r.source); 286 // ε * config.refinement.tolerance_mult,
439 let shifted_mass = r.total_mass(); 287 // config.refinement.max_steps)
440 let ret_mass = r.total_return(); 288 // .map(|(_z, minus_ω_z)| minus_ω_z)
441 // μ - π_#^0 (γ_shift + γ_return) 289 // });
442 μ_transported_base += DeltaMeasure { x : *x, α : α - shifted_mass - ret_mass }; 290 let (_, vpos) = minus_ω.minimise(ε * config.refinement.tolerance_mult,
443 // π_#^1 γ_return 291 config.refinement.max_steps);
444 μ_transported_base.extend(r.return_targets()); 292 let (_, vneg) = minus_ω.maximise(ε * config.refinement.tolerance_mult,
445 // π_#^1 γ_shift 293 config.refinement.max_steps);
446 δ_transported.set_mass(shifted_mass); 294
447 } 295 let t = τ * ε * sfbconfig.transport_tolerance_ω;
448 // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b) 296 let val = |δ : &DeltaMeasure<Loc<F, N>, F>| {
449 let transported_residual = calculate_residual2(&μ_transported, 297 δ.α * (minus_ω.apply(&δ.x) - if δ.α >= 0.0 { vpos } else { vneg })
450 &μ_transported_base, 298 // match if δ.α >= 0.0 { vpos } else { vneg } {
451 opA, b); 299 // None => 0.0,
452 let transported_minus_τv = opA.preadjoint() 300 // Some(v) => δ.α * (minus_ω.apply(&δ.x) - v)
453 .apply(transported_residual); 301 // }
454 302 };
455 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 303 // Calculate positive/bad (rp) values under the integral.
456 let (mut d, within_tolerances) = insert_and_reweigh( 304 // Also store sum of masses for the positive entries.
457 &mut μ, &transported_minus_τv, &μ_transported, Some(&μ_transported_base), 305 let (rp, w) = γ1.iter_spikes().fold((0.0, 0.0), |(p, w), δ| {
458 op𝒟, op𝒟norm, 306 let v = val(δ);
459 τ, ε, 307 if v <= 0.0 { (p, w) } else { (p + v, w + δ.α.abs()) }
460 config, &reg, state, &mut stats 308 });
461 ); 309
462 310 if rp > t {
463 // We have d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv; more precisely 311 // TODO: store v above?
464 // d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_transported, config)); 312 scale_down(γ1.iter_spikes_mut(), t / w, val);
465 // We “essentially” assume that the subdifferential w of the regularisation term 313 false
466 // satisfies w'(y)=0, so for a “goodness” estimate τ[w(y)-w(z)-w'(y)(z-y)] 314 } else {
467 // that incorporates the assumption, we need to calculate τ[w(z) - w(y)] for 315 true
468 // some w in the subdifferential of the regularisation term, such that
469 // -ε ≤ τw - d ≤ ε. This is done by [`RegTerm::goodness`].
470 for r in γ_hat.iter_mut() {
471 for ray in r.rays.iter_mut() {
472 ray.reg_goodness = reg.goodness(&mut d, &μ, &r.source, &ray.δ.x, τ, ε, config);
473 total_reg_goodness += ray.reg_goodness * ray.δ.α;
474 } 316 }
475 } 317 };
476 318
477 // If update of regularisation term goodness didn't invalidate minimum goodness 319 if all_ok {
478 // requirements, we have found our step. Otherwise we need to keep reducing
479 // transport by repeating the loop.
480 if total_goodness + total_reg_goodness >= minimum_goodness {
481 break 'adapt_transport (d, within_tolerances) 320 break 'adapt_transport (d, within_tolerances)
482 } 321 }
483 }; 322 };
484 323
485 // Update γ_hat to new location 324 stats.untransported_fraction = Some({
486 for (δ_new, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) { 325 assert_eq!(μ_base_masses.len(), γ1.len());
487 // Prune rays that only had a return component, as the return component becomes 326 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
488 // a diagonal in γ̂^{k+1}. 327 let source = μ_base_masses.iter().map(|v| v.abs()).sum();
489 r.rays.retain(|ray| ray.δ.α != 0.0); 328 (a + μ_base_minus_γ0.norm(Radon), b + source)
490 // Otherwise zero out the return component, or stage rays for pruning 329 });
491 // to keep memory and computational demands reasonable. 330 stats.transport_error = Some({
492 let n_rays = r.rays.len(); 331 assert_eq!(μ_base_masses.len(), γ1.len());
493 for (ray, ir) in izip!(r.rays.iter_mut(), (0..n_rays).rev()) { 332 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
494 if ir >= sfbconfig.maximum_rays { 333 let err = izip!(μ.iter_masses(), γ1.iter_masses()).map(|(v,w)| (v-w).abs()).sum();
495 // Only keep sfbconfig.maximum_rays - 1 previous rays, staging others for 334 (a + err, b + γ1.norm(Radon))
496 // pruning in next step. 335 });
497 ray.to_return = ray.δ.α; 336
498 ray.δ.α = 0.0; 337 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
499 } else { 338 // latter needs to be pruned when μ is.
500 ray.to_return = 0.0; 339 // TODO: This could do with a two-vector Vec::retain to avoid copies.
501 } 340 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
502 ray.goodness = 0.0; // TODO: probably not needed 341 if μ_new.len() != μ.len() {
503 ray.reg_goodness = 0.0; 342 let mut μ_iter = μ.iter_spikes();
504 } 343 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
505 // Add a new ray for the currently diagonal component 344 μ = μ_new;
506 if r.diagonal > 0.0 {
507 r.rays.push(Ray{
508 δ : DeltaMeasure{x : r.source, α : r.diagonal},
509 goodness : 0.0,
510 reg_goodness : 0.0,
511 to_return : 0.0,
512 });
513 // TODO: Maybe this does not need to be done here, and is sufficent to to do where
514 // the goodness is calculated.
515 r.diagonal = 0.0;
516 }
517 r.diagonal_goodness = 0.0;
518
519 // Shift source
520 r.source = δ_new.x;
521 }
522 // Extend to new spikes
523 γ_hat.extend(μ[γ_hat.len()..].iter().map(|δ_new| {
524 RaySet{
525 source : δ_new.x,
526 rays : [].into(),
527 diagonal : 0.0,
528 diagonal_goodness : 0.0,
529 diagonal_reg_goodness : 0.0
530 }
531 }));
532
533 // Prune spikes with zero weight. This also moves the marginal differences of corresponding
534 // transports from γ_hat to γ_pruned_marginal_diff.
535 // TODO: optimise standard prune with swap_remove.
536 μ_transported_base.clear();
537 let mut i = 0;
538 assert_eq!(μ.len(), γ_hat.len());
539 while i < μ.len() {
540 if μ[i].α == F::ZERO {
541 μ.swap_remove(i);
542 let r = γ_hat.swap_remove(i);
543 μ_transported_base.extend(r.targets().cloned());
544 μ_transported_base -= DeltaMeasure{ α : r.non_diagonal_mass(), x : r.source };
545 } else {
546 i += 1;
547 }
548 } 345 }
549 346
550 // TODO: how to merge? 347 // TODO: how to merge?
551 348
552 // Update residual 349 // Update residual
560 // Give function value if needed 357 // Give function value if needed
561 state.if_verbose(|| { 358 state.if_verbose(|| {
562 // Plot if so requested 359 // Plot if so requested
563 plotter.plot_spikes( 360 plotter.plot_spikes(
564 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, 361 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
565 "start".to_string(), Some(&minus_τv), 362 "start".to_string(), None::<&A::PreadjointCodomain>, // TODO: Should be Some(&((-τ) * v)), but not implemented
566 reg.target_bounds(τ, ε_prev), &μ, 363 reg.target_bounds(τ, ε_prev), &μ,
567 ); 364 );
568 // Calculate mean inner iterations and reset relevant counters. 365 // Calculate mean inner iterations and reset relevant counters.
569 // Return the statistics 366 // Return the statistics
570 let res = IterInfo { 367 let res = IterInfo {

mercurial