| |
/*!
Solver for the point source localisation problem using a sliding
forward-backward splitting method.
*/
| |
5 |
| |
6 use numeric_literals::replace_float_literals; |
| |
7 use serde::{Serialize, Deserialize}; |
| |
8 //use colored::Colorize; |
| |
9 //use nalgebra::{DVector, DMatrix}; |
| |
10 use itertools::izip; |
| |
11 use std::iter::{Map, Flatten}; |
| |
12 |
| |
13 use alg_tools::iterate::{ |
| |
14 AlgIteratorFactory, |
| |
15 AlgIteratorState |
| |
16 }; |
| |
17 use alg_tools::euclidean::{ |
| |
18 Euclidean, |
| |
19 Dot |
| |
20 }; |
| |
21 use alg_tools::sets::Cube; |
| |
22 use alg_tools::loc::Loc; |
| |
23 use alg_tools::mapping::{Apply, Differentiable}; |
| |
24 use alg_tools::bisection_tree::{ |
| |
25 BTFN, |
| |
26 PreBTFN, |
| |
27 Bounds, |
| |
28 BTNodeLookup, |
| |
29 BTNode, |
| |
30 BTSearch, |
| |
31 P2Minimise, |
| |
32 SupportGenerator, |
| |
33 LocalAnalysis, |
| |
34 //Bounded, |
| |
35 }; |
| |
36 use alg_tools::mapping::RealMapping; |
| |
37 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| |
38 |
| |
39 use crate::types::*; |
| |
40 use crate::measures::{ |
| |
41 DiscreteMeasure, |
| |
42 DeltaMeasure, |
| |
43 }; |
| |
44 use crate::measures::merging::{ |
| |
45 //SpikeMergingMethod, |
| |
46 SpikeMerging, |
| |
47 }; |
| |
48 use crate::forward_model::ForwardModel; |
| |
49 use crate::seminorms::DiscreteMeasureOp; |
| |
50 //use crate::tolerance::Tolerance; |
| |
51 use crate::plot::{ |
| |
52 SeqPlotter, |
| |
53 Plotting, |
| |
54 PlotLookup |
| |
55 }; |
| |
56 use crate::fb::*; |
| |
57 use crate::regularisation::SlidingRegTerm; |
| |
58 use crate::dataterm::{ |
| |
59 L2Squared, |
| |
60 //DataTerm, |
| |
61 calculate_residual, |
| |
62 calculate_residual2, |
| |
63 }; |
| |
64 use crate::transport::TransportLipschitz; |
| |
65 |
| |
66 /// Settings for [`pointsource_sliding_fb_reg`]. |
| |
67 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| |
68 #[serde(default)] |
| |
69 pub struct SlidingFBConfig<F : Float> { |
| |
70 /// Step length scaling |
| |
71 pub τ0 : F, |
| |
72 /// Transport smoothness assumption |
| |
73 pub ℓ0 : F, |
| |
74 /// Inverse of the scaling factor $θ$ of the 2-norm-squared transport cost. |
| |
75 /// This means that $τθ$ is the step length for the transport step. |
| |
76 pub inverse_transport_scaling : F, |
| |
77 /// Factor for deciding transport reduction based on smoothness assumption violation |
| |
78 pub minimum_goodness_factor : F, |
| |
79 /// Maximum rays to retain in transports from each source. |
| |
80 pub maximum_rays : usize, |
| |
81 /// Generic parameters |
| |
82 pub insertion : FBGenericConfig<F>, |
| |
83 } |
| |
84 |
| |
85 #[replace_float_literals(F::cast_from(literal))] |
| |
86 impl<F : Float> Default for SlidingFBConfig<F> { |
| |
87 fn default() -> Self { |
| |
88 SlidingFBConfig { |
| |
89 τ0 : 0.99, |
| |
90 ℓ0 : 1.5, |
| |
91 inverse_transport_scaling : 1.0, |
| |
92 minimum_goodness_factor : 1.0, // TODO: totally arbitrary choice, |
| |
93 // should be scaled by problem data? |
| |
94 maximum_rays : 10, |
| |
95 insertion : Default::default() |
| |
96 } |
| |
97 } |
| |
98 } |
| |
99 |
| |
100 /// A transport ray (including various additional computational information). |
| |
101 #[derive(Clone, Debug)] |
| |
102 pub struct Ray<Domain, F : Num> { |
| |
103 /// The destination of the ray, and the mass. The source is indicated in a [`RaySet`]. |
| |
104 δ : DeltaMeasure<Domain, F>, |
| |
105 /// Goodness of the data term for the aray: $v(z)-v(y)-⟨∇v(x), z-y⟩ + ℓ‖z-y‖^2$. |
| |
106 goodness : F, |
| |
107 /// Goodness of the regularisation term for the ray: $w(z)-w(y)$. |
| |
108 /// Initially zero until $w$ can be constructed. |
| |
109 reg_goodness : F, |
| |
110 /// Indicates that this ray also forms a component in γ^{k+1} with the mass `to_return`. |
| |
111 to_return : F, |
| |
112 } |
| |
113 |
| |
114 /// A set of transport rays with the same source point. |
| |
115 #[derive(Clone, Debug)] |
| |
116 pub struct RaySet<Domain, F : Num> { |
| |
117 /// Source of every ray in thset |
| |
118 source : Domain, |
| |
119 /// Mass of the diagonal ray, with destination the same as the source. |
| |
120 diagonal: F, |
| |
121 /// Goodness of the data term for the diagonal ray with $z=x$: |
| |
122 /// $v(x)-v(y)-⟨∇v(x), x-y⟩ + ℓ‖x-y‖^2$. |
| |
123 diagonal_goodness : F, |
| |
124 /// Goodness of the data term for the diagonal ray with $z=x$: $w(x)-w(y)$. |
| |
125 diagonal_reg_goodness : F, |
| |
126 /// The non-diagonal rays. |
| |
127 rays : Vec<Ray<Domain, F>>, |
| |
128 } |
| |
129 |
| |
130 #[replace_float_literals(F::cast_from(literal))] |
| |
131 impl<Domain, F : Float> RaySet<Domain, F> { |
| |
132 fn non_diagonal_mass(&self) -> F { |
| |
133 self.rays |
| |
134 .iter() |
| |
135 .map(|Ray{ δ : DeltaMeasure{ α, .. }, .. }| *α) |
| |
136 .sum() |
| |
137 } |
| |
138 |
| |
139 fn total_mass(&self) -> F { |
| |
140 self.non_diagonal_mass() + self.diagonal |
| |
141 } |
| |
142 |
| |
143 fn targets<'a>(&'a self) |
| |
144 -> Map< |
| |
145 std::slice::Iter<'a, Ray<Domain, F>>, |
| |
146 fn(&'a Ray<Domain, F>) -> &'a DeltaMeasure<Domain, F> |
| |
147 > { |
| |
148 fn get_δ<'b, Domain, F : Float>(Ray{ δ, .. }: &'b Ray<Domain, F>) |
| |
149 -> &'b DeltaMeasure<Domain, F> { |
| |
150 δ |
| |
151 } |
| |
152 self.rays |
| |
153 .iter() |
| |
154 .map(get_δ) |
| |
155 } |
| |
156 |
| |
157 // fn non_diagonal_goodness(&self) -> F { |
| |
158 // self.rays |
| |
159 // .iter() |
| |
160 // .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| { |
| |
161 // α * (goodness + reg_goodness) |
| |
162 // }) |
| |
163 // .sum() |
| |
164 // } |
| |
165 |
| |
166 // fn total_goodness(&self) -> F { |
| |
167 // self.non_diagonal_goodness() + (self.diagonal_goodness + self.diagonal_reg_goodness) |
| |
168 // } |
| |
169 |
| |
170 fn non_diagonal_badness(&self) -> F { |
| |
171 self.rays |
| |
172 .iter() |
| |
173 .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| { |
| |
174 0.0.max(- α * (goodness + reg_goodness)) |
| |
175 }) |
| |
176 .sum() |
| |
177 } |
| |
178 |
| |
179 fn total_badness(&self) -> F { |
| |
180 self.non_diagonal_badness() |
| |
181 + 0.0.max(- self.diagonal * (self.diagonal_goodness + self.diagonal_reg_goodness)) |
| |
182 } |
| |
183 |
| |
184 fn total_return(&self) -> F { |
| |
185 self.rays |
| |
186 .iter() |
| |
187 .map(|&Ray{ to_return, .. }| to_return) |
| |
188 .sum() |
| |
189 } |
| |
190 } |
| |
191 |
| |
192 #[replace_float_literals(F::cast_from(literal))] |
| |
193 impl<Domain : Clone, F : Num> RaySet<Domain, F> { |
| |
194 fn return_targets<'a>(&'a self) |
| |
195 -> Flatten<Map< |
| |
196 std::slice::Iter<'a, Ray<Domain, F>>, |
| |
197 fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>> |
| |
198 >> { |
| |
199 fn get_return<'b, Domain : Clone, F : Num>(ray: &'b Ray<Domain, F>) |
| |
200 -> Option<DeltaMeasure<Domain, F>> { |
| |
201 (ray.to_return != 0.0).then_some( |
| |
202 DeltaMeasure{x : ray.δ.x.clone(), α : ray.to_return} |
| |
203 ) |
| |
204 } |
| |
205 let tmp : Map< |
| |
206 std::slice::Iter<'a, Ray<Domain, F>>, |
| |
207 fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>> |
| |
208 > = self.rays |
| |
209 .iter() |
| |
210 .map(get_return); |
| |
211 tmp.flatten() |
| |
212 } |
| |
213 } |
| |
214 |
| |
215 /// Iteratively solve the pointsource localisation problem using sliding forward-backward |
| |
216 /// splitting |
| |
217 /// |
| |
218 /// The parametrisatio is as for [`pointsource_fb_reg`]. |
| |
219 /// Inertia is currently not supported. |
| |
220 #[replace_float_literals(F::cast_from(literal))] |
| |
221 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>( |
| |
222 opA : &'a A, |
| |
223 b : &A::Observable, |
| |
224 reg : Reg, |
| |
225 op𝒟 : &'a 𝒟, |
| |
226 sfbconfig : &SlidingFBConfig<F>, |
| |
227 iterator : I, |
| |
228 mut plotter : SeqPlotter<F, N>, |
| |
229 ) -> DiscreteMeasure<Loc<F, N>, F> |
| |
230 where F : Float + ToNalgebraRealField, |
| |
231 I : AlgIteratorFactory<IterInfo<F, N>>, |
| |
232 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
| |
233 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
| |
234 A::Observable : std::ops::MulAssign<F>, |
| |
235 A::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output=Loc<F, N>>, |
| |
236 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
| |
237 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
| |
238 + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>, |
| |
239 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
| |
240 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
| |
241 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
| |
242 𝒟::Codomain : RealMapping<F, N>, |
| |
243 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> |
| |
244 + Differentiable<Loc<F, N>, Output=Loc<F,N>>, |
| |
245 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
| |
246 //+ Differentiable<Loc<F, N>, Output=Loc<F,N>>, |
| |
247 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
| |
248 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
| |
249 PlotLookup : Plotting<N>, |
| |
250 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
| |
251 Reg : SlidingRegTerm<F, N> { |
| |
252 |
| |
253 assert!(sfbconfig.τ0 > 0.0 && |
| |
254 sfbconfig.inverse_transport_scaling > 0.0 && |
| |
255 sfbconfig.ℓ0 > 0.0); |
| |
256 |
| |
257 // Set up parameters |
| |
258 let config = &sfbconfig.insertion; |
| |
259 let op𝒟norm = op𝒟.opnorm_bound(); |
| |
260 let θ = sfbconfig.inverse_transport_scaling; |
| |
261 let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap() |
| |
262 .max(opA.transport_lipschitz_factor(L2Squared) * θ); |
| |
263 let ℓ = sfbconfig.ℓ0; // TODO: v scaling? |
| |
264 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| |
265 // by τ compared to the conditional gradient approach. |
| |
266 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| |
267 let mut ε = tolerance.initial(); |
| |
268 |
| |
269 // Initialise iterates |
| |
270 let mut μ : DiscreteMeasure<Loc<F, N>, F> = DiscreteMeasure::new(); |
| |
271 let mut μ_transported_base = DiscreteMeasure::new(); |
| |
272 let mut γ_hat : Vec<RaySet<Loc<F, N>, F>> = Vec::new(); // γ̂_k and extra info |
| |
273 let mut residual = -b; |
| |
274 let mut stats = IterInfo::new(); |
| |
275 |
| |
276 // Run the algorithm |
| |
277 iterator.iterate(|state| { |
| |
278 // Calculate smooth part of surrogate model. |
| |
279 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
| |
280 // has no significant overhead. For some reosn Rust doesn't allow us simply moving |
| |
281 // the residual and replacing it below before the end of this closure. |
| |
282 residual *= -τ; |
| |
283 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
| |
284 let minus_τv = opA.preadjoint().apply(r); |
| |
285 |
| |
286 // Save current base point and shift μ to new positions. |
| |
287 let μ_base = μ.clone(); |
| |
288 for δ in μ.iter_spikes_mut() { |
| |
289 δ.x += minus_τv.differential(&δ.x) * θ; |
| |
290 } |
| |
291 let mut μ_transported = μ.clone(); |
| |
292 |
| |
293 assert_eq!(μ.len(), γ_hat.len()); |
| |
294 |
| |
295 // Calculate the goodness λ formed from γ_hat (≈ γ̂_k) and γ^{k+1}, where the latter |
| |
296 // transports points x from μ_base to points y in μ as shifted above, or “returns” |
| |
297 // them “home” to z given by the rays in γ_hat. Returning is necessary if the rays |
| |
298 // are not “good” for the smoothness assumptions, or if γ_hat has more mass than |
| |
299 // μ_base. |
| |
300 let mut total_goodness = 0.0; // data term goodness |
| |
301 let mut total_reg_goodness = 0.0; // regulariser goodness |
| |
302 let minimum_goodness = - ε * sfbconfig.minimum_goodness_factor; |
| |
303 |
| |
304 for (δ, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) { |
| |
305 // Calculate data term goodness for all rays. |
| |
306 let &DeltaMeasure{ x : ref y, α : δ_mass } = δ; |
| |
307 let x = &r.source; |
| |
308 let mvy = minus_τv.apply(y); |
| |
309 let mdvx = minus_τv.differential(x); |
| |
310 let mut r_total_mass = 0.0; // Total mass of all rays with source r.source. |
| |
311 let mut bad_mass = 0.0; |
| |
312 let mut calc_goodness = |goodness : &mut F, reg_goodness : &mut F, α, z : &Loc<F, N>| { |
| |
313 *reg_goodness = 0.0; // Initial guess |
| |
314 *goodness = mvy - minus_τv.apply(z) + mdvx.dot(&(z-y)) |
| |
315 + ℓ * z.dist2_squared(&y); |
| |
316 total_goodness += *goodness * α; |
| |
317 r_total_mass += α; // TODO: should this include to_return from staging? (Probably not) |
| |
318 if *goodness < 0.0 { |
| |
319 bad_mass += α; |
| |
320 } |
| |
321 }; |
| |
322 for ray in r.rays.iter_mut() { |
| |
323 calc_goodness(&mut ray.goodness, &mut ray.reg_goodness, ray.δ.α, &ray.δ.x); |
| |
324 } |
| |
325 calc_goodness(&mut r.diagonal_goodness, &mut r.diagonal_reg_goodness, r.diagonal, x); |
| |
326 |
| |
327 // If the total mass of the ray set is less than that of μ at the same source, |
| |
328 // a diagonal component needs to be added to be able to (attempt to) transport |
| |
329 // all mass of μ. In the opposite case, we need to construct γ_{k+1} to ‘return’ |
| |
330 // the the extra mass of γ̂_k to the target z. We return mass from the oldest “bad” |
| |
331 // rays in the set. |
| |
332 if δ_mass >= r_total_mass { |
| |
333 r.diagonal += δ_mass - r_total_mass; |
| |
334 } else { |
| |
335 let mut reduce_transport = r_total_mass - δ_mass; |
| |
336 let mut good_needed = (bad_mass - reduce_transport).max(0.0); |
| |
337 // NOTE: reg_goodness is zero at this point, so it is not used in this code. |
| |
338 let mut reduce_ray = |goodness, to_return : Option<&mut F>, α : &mut F| { |
| |
339 if reduce_transport > 0.0 { |
| |
340 let return_amount = if goodness < 0.0 { |
| |
341 α.min(reduce_transport) |
| |
342 } else { |
| |
343 let amount = α.min(good_needed); |
| |
344 good_needed -= amount; |
| |
345 amount |
| |
346 }; |
| |
347 |
| |
348 if return_amount > 0.0 { |
| |
349 reduce_transport -= return_amount; |
| |
350 // Adjust total goodness by returned amount |
| |
351 total_goodness -= goodness * return_amount; |
| |
352 to_return.map(|tr| *tr += return_amount); |
| |
353 *α -= return_amount; |
| |
354 *α > 0.0 |
| |
355 } else { |
| |
356 true |
| |
357 } |
| |
358 } else { |
| |
359 true |
| |
360 } |
| |
361 }; |
| |
362 r.rays.retain_mut(|ray| { |
| |
363 reduce_ray(ray.goodness, Some(&mut ray.to_return), &mut ray.δ.α) |
| |
364 }); |
| |
365 // A bad diagonal is simply reduced without any 'return'. |
| |
366 // It was, after all, just added to match μ, but there is no need to match it. |
| |
367 // It's just a heuristic. |
| |
368 // TODO: Maybe a bad diagonal should be the first to go. |
| |
369 reduce_ray(r.diagonal_goodness, None, &mut r.diagonal); |
| |
370 } |
| |
371 } |
| |
372 |
| |
373 // Solve finite-dimensional subproblem several times until the dual variable for the |
| |
374 // regularisation term conforms to the assumptions made for the transport above. |
| |
375 let (d, within_tolerances) = 'adapt_transport: loop { |
| |
376 // If transport violates goodness requirements, shift it to ‘return’ mass to z, |
| |
377 // forcing y = z. Based on the badness of each ray set (sum of bad rays' goodness), |
| |
378 // we proportionally distribute the reductions to each ray set, and within each ray |
| |
379 // set, prioritise reducing the oldest bad rays' weight. |
| |
380 let tg = total_goodness + total_reg_goodness; |
| |
381 let adaptation_needed = minimum_goodness - tg; |
| |
382 if adaptation_needed > 0.0 { |
| |
383 let total_badness = γ_hat.iter().map(|r| r.total_badness()).sum(); |
| |
384 |
| |
385 let mut return_ray = |goodness : F, |
| |
386 reg_goodness : F, |
| |
387 to_return : Option<&mut F>, |
| |
388 α : &mut F, |
| |
389 left_to_return : &mut F| { |
| |
390 let g = goodness + reg_goodness; |
| |
391 assert!(*α >= 0.0 && *left_to_return >= 0.0); |
| |
392 if *left_to_return > 0.0 && g < 0.0 { |
| |
393 let return_amount = (*left_to_return / (-g)).min(*α); |
| |
394 *left_to_return -= (-g) * return_amount; |
| |
395 total_goodness -= goodness * return_amount; |
| |
396 total_reg_goodness -= reg_goodness * return_amount; |
| |
397 to_return.map(|tr| *tr += return_amount); |
| |
398 *α -= return_amount; |
| |
399 *α > 0.0 |
| |
400 } else { |
| |
401 true |
| |
402 } |
| |
403 }; |
| |
404 |
| |
405 for r in γ_hat.iter_mut() { |
| |
406 let mut left_to_return = adaptation_needed * r.total_badness() / total_badness; |
| |
407 if left_to_return > 0.0 { |
| |
408 for ray in r.rays.iter_mut() { |
| |
409 return_ray(ray.goodness, ray.reg_goodness, |
| |
410 Some(&mut ray.to_return), &mut ray.δ.α, &mut left_to_return); |
| |
411 } |
| |
412 return_ray(r.diagonal_goodness, r.diagonal_reg_goodness, |
| |
413 None, &mut r.diagonal, &mut left_to_return); |
| |
414 } |
| |
415 } |
| |
416 } |
| |
417 |
| |
418 // Construct μ_k + (π_#^1-π_#^0)γ_{k+1}. |
| |
419 // This can be broken down into |
| |
420 // |
| |
421 // μ_transported_base = [μ - π_#^0 (γ_shift + γ_return)] + π_#^1 γ_return, and |
| |
422 // μ_transported = π_#^1 γ_shift |
| |
423 // |
| |
424 // where γ_shift is our “true” γ_{k+1}, and γ_return is the return compoennt. |
| |
425 // The former can be constructed from δ.x and δ_new.x for δ in μ_base and δ_new in μ |
| |
426 // (which has already been shifted), and the mass stored in a γ_hat ray's δ measure |
| |
427 // The latter can be constructed from γ_hat rays' source and destination with the |
| |
428 // to_return mass. |
| |
429 // |
| |
430 // Note that μ_transported is constructed to have the same spike locations as μ, but |
| |
431 // to have same length as μ_base. This loop does not iterate over the spikes of μ |
| |
432 // (and corresponding transports of γ_hat) that have been newly added in the current |
| |
433 // 'adapt_transport loop. |
| |
434 for (δ, δ_transported, r) in izip!(μ_base.iter_spikes(), |
| |
435 μ_transported.iter_spikes_mut(), |
| |
436 γ_hat.iter()) { |
| |
437 let &DeltaMeasure{ref x, α} = δ; |
| |
438 debug_assert_eq!(*x, r.source); |
| |
439 let shifted_mass = r.total_mass(); |
| |
440 let ret_mass = r.total_return(); |
| |
441 // μ - π_#^0 (γ_shift + γ_return) |
| |
442 μ_transported_base += DeltaMeasure { x : *x, α : α - shifted_mass - ret_mass }; |
| |
443 // π_#^1 γ_return |
| |
444 μ_transported_base.extend(r.return_targets()); |
| |
445 // π_#^1 γ_shift |
| |
446 δ_transported.set_mass(shifted_mass); |
| |
447 } |
| |
448 // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b) |
| |
449 let transported_residual = calculate_residual2(&μ_transported, |
| |
450 &μ_transported_base, |
| |
451 opA, b); |
| |
452 let transported_minus_τv = opA.preadjoint() |
| |
453 .apply(transported_residual); |
| |
454 |
| |
455 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
| |
456 let (mut d, within_tolerances) = insert_and_reweigh( |
| |
457 &mut μ, &transported_minus_τv, &μ_transported, Some(&μ_transported_base), |
| |
458 op𝒟, op𝒟norm, |
| |
459 τ, ε, |
| |
460 config, ®, state, &mut stats |
| |
461 ); |
| |
462 |
| |
463 // We have d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv; more precisely |
| |
464 // d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_transported, config)); |
| |
465 // We “essentially” assume that the subdifferential w of the regularisation term |
| |
466 // satisfies w'(y)=0, so for a “goodness” estimate τ[w(y)-w(z)-w'(y)(z-y)] |
| |
467 // that incorporates the assumption, we need to calculate τ[w(z) - w(y)] for |
| |
468 // some w in the subdifferential of the regularisation term, such that |
| |
469 // -ε ≤ τw - d ≤ ε. This is done by [`RegTerm::goodness`]. |
| |
470 for r in γ_hat.iter_mut() { |
| |
471 for ray in r.rays.iter_mut() { |
| |
472 ray.reg_goodness = reg.goodness(&mut d, &μ, &r.source, &ray.δ.x, τ, ε, config); |
| |
473 total_reg_goodness += ray.reg_goodness * ray.δ.α; |
| |
474 } |
| |
475 } |
| |
476 |
| |
477 // If update of regularisation term goodness didn't invalidate minimum goodness |
| |
478 // requirements, we have found our step. Otherwise we need to keep reducing |
| |
479 // transport by repeating the loop. |
| |
480 if total_goodness + total_reg_goodness >= minimum_goodness { |
| |
481 break 'adapt_transport (d, within_tolerances) |
| |
482 } |
| |
483 }; |
| |
484 |
| |
485 // Update γ_hat to new location |
| |
486 for (δ_new, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) { |
| |
487 // Prune rays that only had a return component, as the return component becomes |
| |
488 // a diagonal in γ̂^{k+1}. |
| |
489 r.rays.retain(|ray| ray.δ.α != 0.0); |
| |
490 // Otherwise zero out the return component, or stage rays for pruning |
| |
491 // to keep memory and computational demands reasonable. |
| |
492 let n_rays = r.rays.len(); |
| |
493 for (ray, ir) in izip!(r.rays.iter_mut(), (0..n_rays).rev()) { |
| |
494 if ir >= sfbconfig.maximum_rays { |
| |
495 // Only keep sfbconfig.maximum_rays - 1 previous rays, staging others for |
| |
496 // pruning in next step. |
| |
497 ray.to_return = ray.δ.α; |
| |
498 ray.δ.α = 0.0; |
| |
499 } else { |
| |
500 ray.to_return = 0.0; |
| |
501 } |
| |
502 ray.goodness = 0.0; // TODO: probably not needed |
| |
503 ray.reg_goodness = 0.0; |
| |
504 } |
| |
505 // Add a new ray for the currently diagonal component |
| |
506 if r.diagonal > 0.0 { |
| |
507 r.rays.push(Ray{ |
| |
508 δ : DeltaMeasure{x : r.source, α : r.diagonal}, |
| |
509 goodness : 0.0, |
| |
510 reg_goodness : 0.0, |
| |
511 to_return : 0.0, |
| |
512 }); |
| |
513 // TODO: Maybe this does not need to be done here, and is sufficent to to do where |
| |
514 // the goodness is calculated. |
| |
515 r.diagonal = 0.0; |
| |
516 } |
| |
517 r.diagonal_goodness = 0.0; |
| |
518 |
| |
519 // Shift source |
| |
520 r.source = δ_new.x; |
| |
521 } |
| |
522 // Extend to new spikes |
| |
523 γ_hat.extend(μ[γ_hat.len()..].iter().map(|δ_new| { |
| |
524 RaySet{ |
| |
525 source : δ_new.x, |
| |
526 rays : [].into(), |
| |
527 diagonal : 0.0, |
| |
528 diagonal_goodness : 0.0, |
| |
529 diagonal_reg_goodness : 0.0 |
| |
530 } |
| |
531 })); |
| |
532 |
| |
533 // Prune spikes with zero weight. This also moves the marginal differences of corresponding |
| |
534 // transports from γ_hat to γ_pruned_marginal_diff. |
| |
535 // TODO: optimise standard prune with swap_remove. |
| |
536 μ_transported_base.clear(); |
| |
537 let mut i = 0; |
| |
538 assert_eq!(μ.len(), γ_hat.len()); |
| |
539 while i < μ.len() { |
| |
540 if μ[i].α == F::ZERO { |
| |
541 μ.swap_remove(i); |
| |
542 let r = γ_hat.swap_remove(i); |
| |
543 μ_transported_base.extend(r.targets().cloned()); |
| |
544 μ_transported_base -= DeltaMeasure{ α : r.non_diagonal_mass(), x : r.source }; |
| |
545 } else { |
| |
546 i += 1; |
| |
547 } |
| |
548 } |
| |
549 |
| |
550 // TODO: how to merge? |
| |
551 |
| |
552 // Update residual |
| |
553 residual = calculate_residual(&μ, opA, b); |
| |
554 |
| |
555 // Update main tolerance for next iteration |
| |
556 let ε_prev = ε; |
| |
557 ε = tolerance.update(ε, state.iteration()); |
| |
558 stats.this_iters += 1; |
| |
559 |
| |
560 // Give function value if needed |
| |
561 state.if_verbose(|| { |
| |
562 // Plot if so requested |
| |
563 plotter.plot_spikes( |
| |
564 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
| |
565 "start".to_string(), Some(&minus_τv), |
| |
566 reg.target_bounds(τ, ε_prev), &μ, |
| |
567 ); |
| |
568 // Calculate mean inner iterations and reset relevant counters. |
| |
569 // Return the statistics |
| |
570 let res = IterInfo { |
| |
571 value : residual.norm2_squared_div2() + reg.apply(&μ), |
| |
572 n_spikes : μ.len(), |
| |
573 ε : ε_prev, |
| |
574 postprocessing: config.postprocessing.then(|| μ.clone()), |
| |
575 .. stats |
| |
576 }; |
| |
577 stats = IterInfo::new(); |
| |
578 res |
| |
579 }) |
| |
580 }); |
| |
581 |
| |
582 postprocess(μ, config, L2Squared, opA, b) |
| |
583 } |