src/sliding_fb.rs

branch
dev
changeset 63
7a8a55fd41c0
parent 61
4f468d35fa29
child 68
00d0881f89a6
equal deleted inserted replaced
61:4f468d35fa29 63:7a8a55fd41c0
7 use serde::{Deserialize, Serialize}; 7 use serde::{Deserialize, Serialize};
8 //use colored::Colorize; 8 //use colored::Colorize;
9 //use nalgebra::{DVector, DMatrix}; 9 //use nalgebra::{DVector, DMatrix};
10 use itertools::izip; 10 use itertools::izip;
11 use std::iter::Iterator; 11 use std::iter::Iterator;
12 use std::ops::MulAssign;
12 13
13 use crate::fb::*; 14 use crate::fb::*;
14 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; 15 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
15 use crate::measures::merging::SpikeMerging; 16 use crate::measures::merging::SpikeMerging;
16 use crate::measures::{DiscreteMeasure, Radon, RNDM}; 17 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon, RNDM};
17 use crate::plot::Plotter; 18 use crate::plot::Plotter;
18 use crate::prox_penalty::{ProxPenalty, StepLengthBound}; 19 use crate::prox_penalty::{ProxPenalty, StepLengthBound};
19 use crate::regularisation::SlidingRegTerm; 20 use crate::regularisation::SlidingRegTerm;
20 use crate::types::*; 21 use crate::types::*;
21 use alg_tools::error::DynResult; 22 use alg_tools::error::DynResult;
23 use alg_tools::iterate::AlgIteratorFactory; 24 use alg_tools::iterate::AlgIteratorFactory;
24 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping}; 25 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping};
25 use alg_tools::nalgebra_support::ToNalgebraRealField; 26 use alg_tools::nalgebra_support::ToNalgebraRealField;
26 use alg_tools::norms::Norm; 27 use alg_tools::norms::Norm;
27 use anyhow::ensure; 28 use anyhow::ensure;
29 use std::ops::ControlFlow;
28 30
29 /// Transport settings for [`pointsource_sliding_fb_reg`]. 31 /// Transport settings for [`pointsource_sliding_fb_reg`].
30 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 32 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
31 #[serde(default)] 33 #[serde(default)]
32 pub struct TransportConfig<F: Float> { 34 pub struct TransportConfig<F: Float> {
34 pub θ0: F, 36 pub θ0: F,
35 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. 37 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
36 pub adaptation: F, 38 pub adaptation: F,
37 /// A posteriori transport tolerance multiplier (C_pos) 39 /// A posteriori transport tolerance multiplier (C_pos)
38 pub tolerance_mult_con: F, 40 pub tolerance_mult_con: F,
41 /// maximum number of adaptation iterations, until cancelling transport.
42 pub max_attempts: usize,
39 } 43 }
40 44
41 #[replace_float_literals(F::cast_from(literal))] 45 #[replace_float_literals(F::cast_from(literal))]
42 impl<F: Float> TransportConfig<F> { 46 impl<F: Float> TransportConfig<F> {
43 /// Check that the parameters are ok. Panics if not. 47 /// Check that the parameters are ok. Panics if not.
50 } 54 }
51 55
52 #[replace_float_literals(F::cast_from(literal))] 56 #[replace_float_literals(F::cast_from(literal))]
53 impl<F: Float> Default for TransportConfig<F> { 57 impl<F: Float> Default for TransportConfig<F> {
54 fn default() -> Self { 58 fn default() -> Self {
55 TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0 } 59 TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0, max_attempts: 2 }
56 } 60 }
57 } 61 }
58 62
59 /// Settings for [`pointsource_sliding_fb_reg`]. 63 /// Settings for [`pointsource_sliding_fb_reg`].
60 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 64 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
96 /// Adaptive step length. 100 /// Adaptive step length.
97 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. 101 /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
98 FullyAdaptive { l: F, max_transport: F, g: G }, 102 FullyAdaptive { l: F, max_transport: F, g: G },
99 } 103 }
100 104
101 /// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)` 105 #[derive(Clone, Debug, Serialize)]
102 /// with step lengh τ and transport step length `θ_or_adaptive`. 106 pub struct SingleTransport<const N: usize, F: Float> {
103 #[replace_float_literals(F::cast_from(literal))] 107 /// Source point
104 pub(crate) fn initial_transport<F, G, D, const N: usize>( 108 x: Loc<N, F>,
105 γ1: &mut RNDM<N, F>, 109 /// Target point
106 μ: &mut RNDM<N, F>, 110 y: Loc<N, F>,
107 τ: F, 111 /// Original mass
108 θ_or_adaptive: &mut TransportStepLength<F, G>, 112 α_μ_orig: F,
109 v: D, 113 /// Transported mass
110 ) -> (Vec<F>, RNDM<N, F>) 114 α_γ: F,
111 where 115 /// Helper for pruning
112 F: Float + ToNalgebraRealField, 116 prune: bool,
113 G: Fn(F, F) -> F, 117 }
114 D: DifferentiableRealMapping<N, F>, 118
115 { 119 #[derive(Clone, Debug, Serialize)]
116 use TransportStepLength::*; 120 pub struct Transport<const N: usize, F: Float> {
117 121 vec: Vec<SingleTransport<N, F>>,
118 // Save current base point and shift μ to new positions. Idea is that 122 }
119 // μ_base(_masses) = μ^k (vector of masses) 123
120 // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} 124 /// Whether partiall transported points are allowed.
121 // γ1 = π_♯^1γ^{k+1} 125 ///
122 // μ = μ^{k+1} 126 /// Partial transport can cause spike count explosion, so full or zero
123 let μ_base_masses: Vec<F> = μ.iter_masses().collect(); 127 /// transport is generally preferred. If this is set to `true`, different
124 let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below. 128 /// transport adaptation heuristics will be used.
125 // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates 129 const ALLOW_PARTIAL_TRANSPORT: bool = true;
126 //let mut sum_norm_dv = 0.0; 130 const MINIMAL_PARTIAL_TRANSPORT: bool = true;
127 let γ_prev_len = γ1.len(); 131
128 assert!(μ.len() >= γ_prev_len); 132 impl<const N: usize, F: Float> Transport<N, F> {
129 γ1.extend(μ[γ_prev_len..].iter().cloned()); 133 pub(crate) fn new() -> Self {
130 134 Transport { vec: Vec::new() }
131 // Calculate initial transport and step length. 135 }
132 // First calculate initial transported weights 136
133 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 137 pub(crate) fn iter(&self) -> impl Iterator<Item = &'_ SingleTransport<N, F>> {
134 // If old transport has opposing sign, the new transport will be none. 138 self.vec.iter()
135 ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) { 139 }
136 0.0 140
137 } else { 141 pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut SingleTransport<N, F>> {
138 δ.α 142 self.vec.iter_mut()
139 }; 143 }
140 } 144
141 145 pub(crate) fn extend<I>(&mut self, it: I)
142 // Calculate transport rays. 146 where
143 match *θ_or_adaptive { 147 I: IntoIterator<Item = SingleTransport<N, F>>,
144 Fixed(θ) => { 148 {
145 let θτ = τ * θ; 149 self.vec.extend(it)
146 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 150 }
147 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 151
148 } 152 pub(crate) fn len(&self) -> usize {
149 } 153 self.vec.len()
150 AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θ } => { 154 }
151 *max_transport = max_transport.max(γ1.norm(Radon)); 155
152 let θτ = τ * calculate_θ(ℓ_F, *max_transport); 156 // pub(crate) fn dist_matching(&self, μ: &RNDM<N, F>) -> F {
153 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 157 // self.iter()
154 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 158 // .zip(μ.iter_spikes())
155 } 159 // .map(|(ρ, δ)| (ρ.α_γ - δ.α).abs())
156 } 160 // .sum()
157 FullyAdaptive { l: ref mut adaptive_ℓ_F, ref mut max_transport, g: ref calculate_θ } => { 161 // }
158 *max_transport = max_transport.max(γ1.norm(Radon)); 162
159 let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport); 163 /// Construct `μ̆`, replacing the contents of `μ`.
160 // Do two runs through the spikes to update θ, breaking if first run did not cause 164 #[replace_float_literals(F::cast_from(literal))]
161 // a change. 165 pub(crate) fn μ̆_into(&self, μ: &mut RNDM<N, F>) {
162 for _i in 0..=1 { 166 assert!(self.len() <= μ.len());
163 let mut changes = false; 167
164 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 168 // First transported points
165 let dv_x = v.differential(&δ.x); 169 for (δ, ρ) in izip!(μ.iter_spikes_mut(), self.iter()) {
166 let g = &dv_x * (ρ.α.signum() * θ * τ); 170 if ρ.α_γ.abs() > 0.0 {
167 ρ.x = δ.x - g; 171 // Transport – transported point
168 let n = g.norm2(); 172 δ.α = ρ.α_γ;
169 if n >= F::EPSILON { 173 δ.x = ρ.y;
170 // Estimate Lipschitz factor of ∇v 174 } else {
171 let this_ℓ_F = (dv_x - v.differential(&ρ.x)).norm2() / n; 175 // No transport – original point
172 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); 176 δ.α = ρ.α_μ_orig;
173 θ = calculate_θ(*adaptive_ℓ_F, *max_transport); 177 δ.x = ρ.x;
174 changes = true 178 }
179 }
180
181 // Then source points with partial transport
182 let mut i = self.len();
183 if ALLOW_PARTIAL_TRANSPORT {
184 // This can cause the number of points to explode, so cannot have partial transport.
185 for ρ in self.iter() {
186 let α = ρ.α_μ_orig - ρ.α_γ;
187 if ρ.α_γ.abs() > F::EPSILON && α != 0.0 {
188 let δ = DeltaMeasure { α, x: ρ.x };
189 if i < μ.len() {
190 μ[i] = δ;
191 } else {
192 μ.push(δ)
193 }
194 i += 1;
195 }
196 }
197 }
198 μ.truncate(i);
199 }
200
201 /// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
202 /// with step lengh τ and transport step length `θ_or_adaptive`.
203 #[replace_float_literals(F::cast_from(literal))]
204 pub(crate) fn initial_transport<G, D>(
205 &mut self,
206 μ: &RNDM<N, F>,
207 _τ: F,
208 τθ_or_adaptive: &mut TransportStepLength<F, G>,
209 v: D,
210 ) where
211 G: Fn(F, F) -> F,
212 D: DifferentiableRealMapping<N, F>,
213 {
214 use TransportStepLength::*;
215
216 // Initialise transport structure weights
217 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) {
218 ρ.α_μ_orig = δ.α;
219 ρ.x = δ.x;
220 // If old transport has opposing sign, the new transport will be none.
221 ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) {
222 0.0
223 } else {
224 δ.α
225 }
226 }
227
228 let γ_prev_len = self.len();
229 assert!(μ.len() >= γ_prev_len);
230 self.extend(μ[γ_prev_len..].iter().map(|δ| SingleTransport {
231 x: δ.x,
232 y: δ.x, // Just something, will be filled properly in the next phase
233 α_μ_orig: δ.α,
234 α_γ: δ.α,
235 prune: false,
236 }));
237
238 // Calculate transport rays.
239 match *τθ_or_adaptive {
240 Fixed(θ) => {
241 for ρ in self.iter_mut() {
242 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ);
243 }
244 }
245 AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => {
246 *max_transport = max_transport.max(self.norm(Radon));
247 let θτ = calculate_θτ(ℓ_F, *max_transport);
248 for ρ in self.iter_mut() {
249 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ);
250 }
251 }
252 FullyAdaptive {
253 l: ref mut adaptive_ℓ_F,
254 ref mut max_transport,
255 g: ref calculate_θτ,
256 } => {
257 *max_transport = max_transport.max(self.norm(Radon));
258 let mut θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport);
259 // Do two runs through the spikes to update θ, breaking if first run did not cause
260 // a change.
261 for _i in 0..=1 {
262 let mut changes = false;
263 for ρ in self.iter_mut() {
264 let dv_x = v.differential(&ρ.x);
265 let g = &dv_x * (ρ.α_γ.signum() * θτ);
266 ρ.y = ρ.x - g;
267 let n = g.norm2();
268 if n >= F::EPSILON {
269 // Estimate Lipschitz factor of ∇v
270 let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n;
271 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F);
272 θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport);
273 changes = true
274 }
275 }
276 if !changes {
277 break;
175 } 278 }
176 } 279 }
177 if !changes { 280 }
178 break; 281 }
282 }
283
284 /// A posteriori transport adaptation.
285 #[replace_float_literals(F::cast_from(literal))]
286 pub(crate) fn aposteriori_transport<D>(
287 &mut self,
288 μ: &RNDM<N, F>,
289 μ̆: &RNDM<N, F>,
290 _v: &mut D,
291 extra: Option<F>,
292 ε: F,
293 tconfig: &TransportConfig<F>,
294 attempts: &mut usize,
295 ) -> bool
296 where
297 D: DifferentiableRealMapping<N, F>,
298 {
299 *attempts += 1;
300
301 // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
302 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
303 // at that point to zero, and retry.
304 let mut all_ok = true;
305 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) {
306 if δ.α == 0.0 && ρ.α_γ != 0.0 {
307 all_ok = false;
308 ρ.α_γ = 0.0;
309 }
310 }
311
312 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z).
313 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ̆^k
314 // which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient.
315 let nγ = self.norm(Radon);
316 let nΔ = μ.dist_matching(&μ̆) + extra.unwrap_or(0.0);
317 let t = ε * tconfig.tolerance_mult_con;
318 if nγ * nΔ > t && *attempts >= tconfig.max_attempts {
319 all_ok = false;
320 } else if nγ * nΔ > t {
321 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
322 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
323 // will not enter here.
324 //*self *= tconfig.adaptation * t / (nγ * nΔ);
325
326 // We want a consistent behaviour that has the potential to set many weights to zero.
327 // Therefore, we find the smallest uniform reduction `chg_one`, subtracted
328 // from all weights, that achieves total `adapt` adaptation.
329 let adapt_to = tconfig.adaptation * t / nΔ;
330 let reduction_target = nγ - adapt_to;
331 assert!(reduction_target > 0.0);
332 if ALLOW_PARTIAL_TRANSPORT {
333 if MINIMAL_PARTIAL_TRANSPORT {
334 // This reduces weights of transport, starting from … until `adapt` is
335 // exhausted. It will, therefore, only ever cause one extrap point insertion
336 // at the sources, unlike “full” partial transport.
337 //let refs = self.vec.iter_mut().collect::<Vec<_>>();
338 //refs.sort_by(|ρ1, ρ2| ρ1.α_γ.abs().partial_cmp(&ρ2.α_γ.abs()).unwrap());
339 // let mut it = refs.into_iter();
340 //
341 // Maybe sort by differential norm
342 // let mut refs = self
343 // .vec
344 // .iter_mut()
345 // .map(|ρ| {
346 // let val = v.differential(&ρ.x).norm2_squared();
347 // (ρ, val)
348 // })
349 // .collect::<Vec<_>>();
350 // refs.sort_by(|(_, v1), (_, v2)| v2.partial_cmp(&v1).unwrap());
351 // let mut it = refs.into_iter().map(|(ρ, _)| ρ);
352 let mut it = self.vec.iter_mut().rev();
353 let _unused = it.try_fold(reduction_target, |left, ρ| {
354 let w = ρ.α_γ.abs();
355 if left <= w {
356 ρ.α_γ = ρ.α_γ.signum() * (w - left);
357 ControlFlow::Break(())
358 } else {
359 ρ.α_γ = 0.0;
360 ControlFlow::Continue(left - w)
361 }
362 });
363 } else {
364 // This version equally reduces all weights. It causes partial transport, which
365 // has the problem that that we need to then adapt weights in both start and
366 // end points, in insert_and_reweigh, somtimes causing the number of spikes μ
367 // to explode.
368 let mut abs_weights = self
369 .vec
370 .iter()
371 .map(|ρ| ρ.α_γ.abs())
372 .filter(|t| *t > F::EPSILON)
373 .collect::<Vec<F>>();
374 abs_weights.sort_by(|a, b| a.partial_cmp(b).unwrap());
375 let n = abs_weights.len();
376 // Cannot have partial transport; can cause spike count explosion
377 let chg = abs_weights.into_iter().zip((1..=n).rev()).try_fold(
378 0.0,
379 |smaller_total, (w, m)| {
380 let mf = F::cast_from(m);
381 let reduction = w * mf + smaller_total;
382 if reduction >= reduction_target {
383 ControlFlow::Break((reduction_target - smaller_total) / mf)
384 } else {
385 ControlFlow::Continue(smaller_total + w)
386 }
387 },
388 );
389 match chg {
390 ControlFlow::Continue(_) => self.vec.iter_mut().for_each(|δ| δ.α_γ = 0.0),
391 ControlFlow::Break(chg_one) => self.vec.iter_mut().for_each(|ρ| {
392 let t = ρ.α_γ.abs();
393 if t > 0.0 {
394 if ALLOW_PARTIAL_TRANSPORT {
395 let new = (t - chg_one).max(0.0);
396 ρ.α_γ = ρ.α_γ.signum() * new;
397 }
398 }
399 }),
400 }
179 } 401 }
180 } 402 } else {
181 } 403 // This version zeroes smallest weights, avoiding partial transport.
182 } 404 let mut abs_weights_idx = self
183 405 .vec
184 // Set initial guess for μ=μ^{k+1}. 406 .iter()
185 for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) { 407 .map(|ρ| ρ.α_γ.abs())
186 if ρ.α.abs() > F::EPSILON { 408 .zip(0..)
187 δ.x = ρ.x; 409 .filter(|(w, _)| *w >= 0.0)
188 //δ.α = ρ.α; // already set above 410 .collect::<Vec<(F, usize)>>();
189 } else { 411 abs_weights_idx.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap());
190 δ.α = β; 412
191 } 413 let mut left = reduction_target;
192 } 414
193 // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b) 415 for (w, i) in abs_weights_idx {
194 μ_base_minus_γ0.set_masses( 416 left -= w;
195 μ_base_masses 417 let ρ = &mut self.vec[i];
418 ρ.α_γ = 0.0;
419 if left < 0.0 {
420 break;
421 }
422 }
423 }
424
425 all_ok = false
426 }
427
428 if !all_ok && *attempts >= tconfig.max_attempts {
429 for ρ in self.iter_mut() {
430 ρ.α_γ = 0.0;
431 }
432 }
433
434 all_ok
435 }
436
437 /// Returns $‖μ\^k - π\_♯\^0γ\^{k+1}‖$
438 pub(crate) fn μ0_minus_γ0_radon(&self) -> F {
439 self.vec.iter().map(|ρ| (ρ.α_μ_orig - ρ.α_γ).abs()).sum()
440 }
441
442 /// Returns $∫ c_2 d|γ|$
443 #[replace_float_literals(F::cast_from(literal))]
444 pub(crate) fn c2integral(&self) -> F {
445 self.vec
196 .iter() 446 .iter()
197 .zip(γ1.iter_masses()) 447 .map(|ρ| ρ.y.dist2_squared(&ρ.x) / 2.0 * ρ.α_γ.abs())
198 .map(|(&a, b)| a - b), 448 .sum()
199 ); 449 }
200 (μ_base_masses, μ_base_minus_γ0) 450
201 } 451 #[replace_float_literals(F::cast_from(literal))]
202 452 pub(crate) fn get_transport_stats(&self, stats: &mut IterInfo<F>, μ: &RNDM<N, F>) {
203 /// A posteriori transport adaptation. 453 // TODO: This doesn't take into account μ[i].α becoming zero in the latest tranport
204 #[replace_float_literals(F::cast_from(literal))] 454 // attempt, for i < self.len(), when a corresponding source term also exists with index
205 pub(crate) fn aposteriori_transport<F, const N: usize>( 455 // j ≥ self.len(). For now, we let that be reflected in the prune count.
206 γ1: &mut RNDM<N, F>, 456 stats.inserted += μ.len() - self.len();
207 μ: &mut RNDM<N, F>, 457
208 μ_base_minus_γ0: &mut RNDM<N, F>, 458 let transp = stats.get_transport_mut();
209 μ_base_masses: &Vec<F>, 459
210 extra: Option<F>, 460 transp.dist = {
211 ε: F, 461 let (a, b) = transp.dist;
212 tconfig: &TransportConfig<F>, 462 (a + self.c2integral(), b + self.norm(Radon))
213 ) -> bool 463 };
214 where 464 transp.untransported_fraction = {
215 F: Float + ToNalgebraRealField, 465 let (a, b) = transp.untransported_fraction;
216 { 466 let source = self.iter().map(|ρ| ρ.α_μ_orig.abs()).sum();
217 // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not, 467 (a + self.μ0_minus_γ0_radon(), b + source)
218 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1 468 };
219 // at that point to zero, and retry. 469 transp.transport_error = {
220 let mut all_ok = true; 470 let (a, b) = transp.transport_error;
221 for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) { 471 //(a + self.dist_matching(&μ), b + self.norm(Radon))
222 if α_μ == 0.0 && *α_γ1 != 0.0 { 472
223 all_ok = false; 473 // This ignores points that have been not transported at all, to only calculate
224 *α_γ1 = 0.0; 474 // destnation error; untransported_fraction accounts for not being able to transport
225 } 475 // at all.
226 } 476 self.iter()
227 477 .zip(μ.iter_spikes())
228 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z). 478 .fold((a, b), |(a, b), (ρ, δ)| {
229 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, 479 let transported = ρ.α_γ.abs();
230 // which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient. 480 if transported > F::EPSILON {
231 let nγ = γ1.norm(Radon); 481 (a + (ρ.α_γ - δ.α).abs(), b + transported)
232 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0); 482 } else {
233 let t = ε * tconfig.tolerance_mult_con; 483 (a, b)
234 if nγ * nΔ > t { 484 }
235 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, 485 })
236 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we 486 };
237 // will not enter here. 487 }
238 *γ1 *= tconfig.adaptation * t / (nγ * nΔ); 488
239 all_ok = false 489 /// Prune spikes with zero weight. To maintain correct ordering between μ and γ, also the
240 } 490 /// latter needs to be pruned when μ is.
241 491 pub(crate) fn prune_compat(&mut self, μ: &mut RNDM<N, F>, stats: &mut IterInfo<F>) {
242 if !all_ok { 492 assert!(self.vec.len() <= μ.len());
243 // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} 493 let old_len = μ.len();
244 μ_base_minus_γ0.set_masses( 494 for (ρ, δ) in self.vec.iter_mut().zip(μ.iter_spikes()) {
245 μ_base_masses 495 ρ.prune = !(δ.α.abs() > F::EPSILON);
246 .iter() 496 }
247 .zip(γ1.iter_masses()) 497 μ.prune_by(|δ| δ.α.abs() > F::EPSILON);
248 .map(|(&a, b)| a - b), 498 stats.pruned += old_len - μ.len();
249 ); 499 self.vec.retain(|ρ| !ρ.prune);
250 } 500 assert!(self.vec.len() <= μ.len());
251 501 }
252 all_ok 502 }
503
504 impl<const N: usize, F: Float> Norm<Radon, F> for Transport<N, F> {
505 fn norm(&self, _: Radon) -> F {
506 self.iter().map(|ρ| ρ.α_γ.abs()).sum()
507 }
508 }
509
510 impl<const N: usize, F: Float> MulAssign<F> for Transport<N, F> {
511 fn mul_assign(&mut self, factor: F) {
512 for ρ in self.iter_mut() {
513 ρ.α_γ *= factor;
514 }
515 }
253 } 516 }
254 517
255 /// Iteratively solve the pointsource localisation problem using sliding forward-backward 518 /// Iteratively solve the pointsource localisation problem using sliding forward-backward
256 /// splitting 519 /// splitting
257 /// 520 ///
282 ensure!(config.τ0 > 0.0, "Invalid step length parameter"); 545 ensure!(config.τ0 > 0.0, "Invalid step length parameter");
283 config.transport.check()?; 546 config.transport.check()?;
284 547
285 // Initialise iterates 548 // Initialise iterates
286 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); 549 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
287 let mut γ1 = DiscreteMeasure::new(); 550 let mut γ = Transport::new();
288 551
289 // Set up parameters 552 // Set up parameters
290 // let opAnorm = opA.opnorm_bound(Radon, L2); 553 // let opAnorm = opA.opnorm_bound(Radon, L2);
291 //let max_transport = config.max_transport.scale 554 //let max_transport = config.max_transport.scale
292 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); 555 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
293 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; 556 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
294 let ℓ = 0.0; 557 let ℓ = 0.0;
295 let τ = config.τ0 / prox_penalty.step_length_bound(&f)?; 558 let τ = config.τ0 / prox_penalty.step_length_bound(&f)?;
296 let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); 559
297 let transport_lip = maybe_transport_lip?; 560 let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) {
298 let calculate_θ = |ℓ_F, max_transport| { 561 (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0),
299 let ℓ_r = transport_lip * max_transport; 562 (maybe_ℓ_F, Ok(transport_lip)) => {
300 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) 563 let calculate_θτ = move |ℓ_F, max_transport| {
301 }; 564 let ℓ_r = transport_lip * max_transport;
302 let mut θ_or_adaptive = match maybe_ℓ_F { 565 config.transport.θ0 / (ℓ + ℓ_F + ℓ_r)
303 //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)), 566 };
304 Ok(ℓ_F) => TransportStepLength::AdaptiveMax { 567 match maybe_ℓ_F {
305 l: ℓ_F, // TODO: could estimate computing the real reesidual 568 Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
306 max_transport: 0.0, 569 l: ℓ_F, // TODO: could estimate computing the real reesidual
307 g: calculate_θ, 570 max_transport: 0.0,
308 }, 571 g: calculate_θτ,
309 Err(_) => TransportStepLength::FullyAdaptive { 572 },
310 l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials 573 Err(_) => TransportStepLength::FullyAdaptive {
311 max_transport: 0.0, 574 l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
312 g: calculate_θ, 575 max_transport: 0.0,
313 }, 576 g: calculate_θτ,
577 },
578 }
579 }
314 }; 580 };
315 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 581 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
316 // by τ compared to the conditional gradient approach. 582 // by τ compared to the conditional gradient approach.
317 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling(); 583 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
318 let mut ε = tolerance.initial(); 584 let mut ε = tolerance.initial();
329 595
330 // Run the algorithm 596 // Run the algorithm
331 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 597 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
332 // Calculate initial transport 598 // Calculate initial transport
333 let v = f.differential(&μ); 599 let v = f.differential(&μ);
334 let (μ_base_masses, mut μ_base_minus_γ0) = 600 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v);
335 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); 601
602 let mut attempts = 0;
336 603
337 // Solve finite-dimensional subproblem several times until the dual variable for the 604 // Solve finite-dimensional subproblem several times until the dual variable for the
338 // regularisation term conforms to the assumptions made for the transport above. 605 // regularisation term conforms to the assumptions made for the transport above.
339 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { 606 let (maybe_d, _within_tolerances, mut τv̆, μ̆) = 'adapt_transport: loop {
607 // Set initial guess for μ=μ^{k+1}.
608 γ.μ̆_into(&mut μ);
609 let μ̆ = μ.clone();
610
340 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) 611 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
341 //let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); 612 //let residual_μ̆ = calculate_residual2(&γ1, &μ0_minus_γ0, opA, b);
342 // TODO: this could be optimised by doing the differential like the 613 // TODO: this could be optimised by doing the differential like the
343 // old residual2. 614 // old residual2.
344 let μ̆ = &γ1 + &μ_base_minus_γ0; 615 // NOTE: This assumes that μ = γ1
345 let mut τv̆ = f.differential(μ̆) * τ; 616 let mut τv̆ = f.differential(&μ̆) * τ;
346 617
347 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 618 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
348 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( 619 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
349 &mut μ, 620 &mut μ,
350 &mut τv̆, 621 &mut τv̆,
351 &γ1,
352 Some(&μ_base_minus_γ0),
353 τ, 622 τ,
354 ε, 623 ε,
355 &config.insertion, 624 &config.insertion,
356 &reg, 625 &reg,
357 &state, 626 &state,
358 &mut stats, 627 &mut stats,
359 )?; 628 )?;
360 629
361 // A posteriori transport adaptation. 630 // A posteriori transport adaptation.
362 if aposteriori_transport( 631 if γ.aposteriori_transport(&μ, &μ̆, &mut τv̆, None, ε, &config.transport, &mut attempts)
363 &mut γ1, 632 {
364 &mut μ, 633 break 'adapt_transport (maybe_d, within_tolerances, τv̆, μ̆);
365 &mut μ_base_minus_γ0, 634 }
366 &μ_base_masses, 635
367 None, 636 stats.get_transport_mut().readjustment_iters += 1;
368 ε,
369 &config.transport,
370 ) {
371 break 'adapt_transport (maybe_d, within_tolerances, τv̆);
372 }
373 }; 637 };
374 638
375 stats.untransported_fraction = Some({ 639 γ.get_transport_stats(&mut stats, &μ);
376 assert_eq!(μ_base_masses.len(), γ1.len());
377 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
378 let source = μ_base_masses.iter().map(|v| v.abs()).sum();
379 (a + μ_base_minus_γ0.norm(Radon), b + source)
380 });
381 stats.transport_error = Some({
382 assert_eq!(μ_base_masses.len(), γ1.len());
383 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
384 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
385 });
386 640
387 // Merge spikes. 641 // Merge spikes.
388 // This crucially expects the merge routine to be stable with respect to spike locations, 642 // This crucially expects the merge routine to be stable with respect to spike locations,
389 // and not to performing any pruning. That is be to done below simultaneously for γ. 643 // and not to performing any pruning. That is be to done below simultaneously for γ.
390 let ins = &config.insertion; 644 if config.insertion.merge_now(&state) {
391 if ins.merge_now(&state) {
392 stats.merged += prox_penalty.merge_spikes( 645 stats.merged += prox_penalty.merge_spikes(
393 &mut μ, 646 &mut μ,
394 &mut τv̆, 647 &mut τv̆,
395 &γ1, 648 &μ̆,
396 Some(&μ_base_minus_γ0),
397 τ, 649 τ,
398 ε, 650 ε,
399 ins, 651 &config.insertion,
400 &reg, 652 &reg,
401 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), 653 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
402 ); 654 );
403 } 655 }
404 656
405 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the 657 γ.prune_compat(&mut μ, &mut stats);
406 // latter needs to be pruned when μ is.
407 // TODO: This could do with a two-vector Vec::retain to avoid copies.
408 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
409 if μ_new.len() != μ.len() {
410 let mut μ_iter = μ.iter_spikes();
411 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
412 stats.pruned += μ.len() - μ_new.len();
413 μ = μ_new;
414 }
415 658
416 let iter = state.iteration(); 659 let iter = state.iteration();
417 stats.this_iters += 1; 660 stats.this_iters += 1;
418 661
419 // Give statistics if requested 662 // Give statistics if requested

mercurial