src/sliding_fb.rs

changeset 70
ed16d0f10d08
parent 68
00d0881f89a6
equal deleted inserted replaced
58:6099ba025aac 70:ed16d0f10d08
7 use serde::{Deserialize, Serialize}; 7 use serde::{Deserialize, Serialize};
8 //use colored::Colorize; 8 //use colored::Colorize;
9 //use nalgebra::{DVector, DMatrix}; 9 //use nalgebra::{DVector, DMatrix};
10 use itertools::izip; 10 use itertools::izip;
11 use std::iter::Iterator; 11 use std::iter::Iterator;
12 12 use std::ops::MulAssign;
13
14 use crate::fb::*;
15 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
16 use crate::measures::merging::SpikeMerging;
17 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon, RNDM};
18 use crate::plot::Plotter;
19 use crate::prox_penalty::{ProxPenalty, StepLengthBound};
20 use crate::regularisation::SlidingRegTerm;
21 use crate::types::*;
22 use alg_tools::error::DynResult;
13 use alg_tools::euclidean::Euclidean; 23 use alg_tools::euclidean::Euclidean;
14 use alg_tools::iterate::AlgIteratorFactory; 24 use alg_tools::iterate::AlgIteratorFactory;
15 use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; 25 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping};
16 use alg_tools::nalgebra_support::ToNalgebraRealField; 26 use alg_tools::nalgebra_support::ToNalgebraRealField;
17 use alg_tools::norms::Norm; 27 use alg_tools::norms::Norm;
18 28 use anyhow::ensure;
19 use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel}; 29 use std::ops::ControlFlow;
20 use crate::measures::merging::SpikeMerging;
21 use crate::measures::{DiscreteMeasure, Radon, RNDM};
22 use crate::types::*;
23 //use crate::tolerance::Tolerance;
24 use crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared};
25 use crate::fb::*;
26 use crate::plot::{PlotLookup, Plotting, SeqPlotter};
27 use crate::regularisation::SlidingRegTerm;
28 //use crate::transport::TransportLipschitz;
29 30
30 /// Transport settings for [`pointsource_sliding_fb_reg`]. 31 /// Transport settings for [`pointsource_sliding_fb_reg`].
31 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 32 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
32 #[serde(default)] 33 #[serde(default)]
33 pub struct TransportConfig<F: Float> { 34 pub struct TransportConfig<F: Float> {
35 pub θ0: F, 36 pub θ0: F,
36 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. 37 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
37 pub adaptation: F, 38 pub adaptation: F,
38 /// A posteriori transport tolerance multiplier (C_pos) 39 /// A posteriori transport tolerance multiplier (C_pos)
39 pub tolerance_mult_con: F, 40 pub tolerance_mult_con: F,
41 /// maximum number of adaptation iterations, until cancelling transport.
42 pub max_attempts: usize,
43 /// Maximum number of failed transportations for a single source point
44 pub max_fail: usize,
40 } 45 }
41 46
42 #[replace_float_literals(F::cast_from(literal))] 47 #[replace_float_literals(F::cast_from(literal))]
43 impl<F: Float> TransportConfig<F> { 48 impl<F: Float> TransportConfig<F> {
44 /// Check that the parameters are ok. Panics if not. 49 /// Check that the parameters are ok. Panics if not.
45 pub fn check(&self) { 50 pub fn check(&self) -> DynResult<()> {
46 assert!(self.θ0 > 0.0); 51 ensure!(self.θ0 > 0.0);
47 assert!(0.0 < self.adaptation && self.adaptation < 1.0); 52 ensure!(0.0 < self.adaptation && self.adaptation < 1.0);
48 assert!(self.tolerance_mult_con > 0.0); 53 ensure!(self.tolerance_mult_con > 0.0);
54 Ok(())
49 } 55 }
50 } 56 }
51 57
52 #[replace_float_literals(F::cast_from(literal))] 58 #[replace_float_literals(F::cast_from(literal))]
53 impl<F: Float> Default for TransportConfig<F> { 59 impl<F: Float> Default for TransportConfig<F> {
54 fn default() -> Self { 60 fn default() -> Self {
55 TransportConfig { 61 TransportConfig {
56 θ0: 0.9, 62 θ0: 0.9,
57 adaptation: 0.9, 63 adaptation: 0.9,
58 tolerance_mult_con: 100.0, 64 tolerance_mult_con: 100.0,
65 max_attempts: 2,
66 max_fail: usize::MAX,
59 } 67 }
60 } 68 }
61 } 69 }
62 70
63 /// Settings for [`pointsource_sliding_fb_reg`]. 71 /// Settings for [`pointsource_sliding_fb_reg`].
64 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 72 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
65 #[serde(default)] 73 #[serde(default)]
66 pub struct SlidingFBConfig<F: Float> { 74 pub struct SlidingFBConfig<F: Float> {
67 /// Step length scaling 75 /// Step length scaling
68 pub τ0: F, 76 pub τ0: F,
77 // Auxiliary variable step length scaling for [`crate::sliding_pdps::pointsource_sliding_fb_pair`]
78 pub σp0: F,
69 /// Transport parameters 79 /// Transport parameters
70 pub transport: TransportConfig<F>, 80 pub transport: TransportConfig<F>,
71 /// Generic parameters 81 /// Generic parameters
72 pub insertion: FBGenericConfig<F>, 82 pub insertion: InsertionConfig<F>,
83 /// Guess for curvature bound calculations.
84 pub guess: BoundedCurvatureGuess,
73 } 85 }
74 86
75 #[replace_float_literals(F::cast_from(literal))] 87 #[replace_float_literals(F::cast_from(literal))]
76 impl<F: Float> Default for SlidingFBConfig<F> { 88 impl<F: Float> Default for SlidingFBConfig<F> {
77 fn default() -> Self { 89 fn default() -> Self {
78 SlidingFBConfig { 90 SlidingFBConfig {
79 τ0: 0.99, 91 τ0: 0.99,
92 σp0: 0.99,
80 transport: Default::default(), 93 transport: Default::default(),
81 insertion: Default::default(), 94 insertion: Default::default(),
95 guess: BoundedCurvatureGuess::BetterThanZero,
82 } 96 }
83 } 97 }
84 } 98 }
85 99
86 /// Internal type of adaptive transport step length calculation 100 /// Internal type of adaptive transport step length calculation
94 /// Adaptive step length. 108 /// Adaptive step length.
95 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. 109 /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
96 FullyAdaptive { l: F, max_transport: F, g: G }, 110 FullyAdaptive { l: F, max_transport: F, g: G },
97 } 111 }
98 112
99 /// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)` 113 #[derive(Clone, Debug, Serialize)]
100 /// with step lengh τ and transport step length `θ_or_adaptive`. 114 pub struct SingleTransport<const N: usize, F: Float> {
101 #[replace_float_literals(F::cast_from(literal))] 115 /// Source point
102 pub(crate) fn initial_transport<F, G, D, const N: usize>( 116 x: Loc<N, F>,
103 γ1: &mut RNDM<F, N>, 117 /// Target point
104 μ: &mut RNDM<F, N>, 118 y: Loc<N, F>,
105 τ: F, 119 /// Original mass
106 θ_or_adaptive: &mut TransportStepLength<F, G>, 120 α_μ_orig: F,
107 v: D, 121 /// Transported mass
108 ) -> (Vec<F>, RNDM<F, N>) 122 α_γ: F,
109 where 123 /// Helper for pruning
110 F: Float + ToNalgebraRealField, 124 prune: bool,
111 G: Fn(F, F) -> F, 125 /// Fail count
112 D: DifferentiableRealMapping<F, N>, 126 fail_count: usize,
113 { 127 }
114 use TransportStepLength::*; 128
115 129 #[derive(Clone, Debug, Serialize)]
116 // Save current base point and shift μ to new positions. Idea is that 130 pub struct Transport<const N: usize, F: Float> {
117 // μ_base(_masses) = μ^k (vector of masses) 131 vec: Vec<SingleTransport<N, F>>,
118 // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} 132 }
119 // γ1 = π_♯^1γ^{k+1} 133
120 // μ = μ^{k+1} 134 /// Whether partially transported points are allowed.
121 let μ_base_masses: Vec<F> = μ.iter_masses().collect(); 135 ///
122 let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below. 136 /// Partial transport can cause spike count explosion, so full or zero
123 // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates 137 /// transport is generally preferred. If this is set to `true`, different
124 //let mut sum_norm_dv = 0.0; 138 /// transport adaptation heuristics will be used.
125 let γ_prev_len = γ1.len(); 139 const ALLOW_PARTIAL_TRANSPORT: bool = true;
126 assert!(μ.len() >= γ_prev_len); 140 const MINIMAL_PARTIAL_TRANSPORT: bool = true;
127 γ1.extend(μ[γ_prev_len..].iter().cloned()); 141
128 142 impl<const N: usize, F: Float> Transport<N, F> {
129 // Calculate initial transport and step length. 143 pub(crate) fn new() -> Self {
130 // First calculate initial transported weights 144 Transport { vec: Vec::new() }
131 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 145 }
132 // If old transport has opposing sign, the new transport will be none. 146
133 ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) { 147 pub(crate) fn iter(&self) -> impl Iterator<Item = &'_ SingleTransport<N, F>> {
134 0.0 148 self.vec.iter()
135 } else { 149 }
136 δ.α 150
137 }; 151 pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut SingleTransport<N, F>> {
138 } 152 self.vec.iter_mut()
139 153 }
140 // Calculate transport rays. 154
141 match *θ_or_adaptive { 155 pub(crate) fn extend<I>(&mut self, it: I)
142 Fixed(θ) => { 156 where
143 let θτ = τ * θ; 157 I: IntoIterator<Item = SingleTransport<N, F>>,
144 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 158 {
145 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 159 self.vec.extend(it)
146 } 160 }
147 } 161
148 AdaptiveMax { 162 pub(crate) fn len(&self) -> usize {
149 l: ℓ_F, 163 self.vec.len()
150 ref mut max_transport, 164 }
151 g: ref calculate_θ, 165
152 } => { 166 // pub(crate) fn dist_matching(&self, μ: &RNDM<N, F>) -> F {
153 *max_transport = max_transport.max(γ1.norm(Radon)); 167 // self.iter()
154 let θτ = τ * calculate_θ(ℓ_F, *max_transport); 168 // .zip(μ.iter_spikes())
155 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 169 // .map(|(ρ, δ)| (ρ.α_γ - δ.α).abs())
156 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 170 // .sum()
157 } 171 // }
158 } 172
159 FullyAdaptive { 173 /// Construct `μ̆`, replacing the contents of `μ`.
160 l: ref mut adaptive_ℓ_F, 174 #[replace_float_literals(F::cast_from(literal))]
161 ref mut max_transport, 175 pub(crate) fn μ̆_into(&self, μ: &mut RNDM<N, F>) {
162 g: ref calculate_θ, 176 assert!(self.len() <= μ.len());
163 } => { 177
164 *max_transport = max_transport.max(γ1.norm(Radon)); 178 // First transported points
165 let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport); 179 for (δ, ρ) in izip!(μ.iter_spikes_mut(), self.iter()) {
166 // Do two runs through the spikes to update θ, breaking if first run did not cause 180 if ρ.α_γ.abs() > 0.0 {
167 // a change. 181 // Transport – transported point
168 for _i in 0..=1 { 182 δ.α = ρ.α_γ;
169 let mut changes = false; 183 δ.x = ρ.y;
170 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 184 } else {
171 let dv_x = v.differential(&δ.x); 185 // No transport – original point
172 let g = &dv_x * (ρ.α.signum() * θ * τ); 186 δ.α = ρ.α_μ_orig;
173 ρ.x = δ.x - g; 187 δ.x = ρ.x;
174 let n = g.norm2(); 188 }
175 if n >= F::EPSILON { 189 }
176 // Estimate Lipschitz factor of ∇v 190
177 let this_ℓ_F = (dv_x - v.differential(&ρ.x)).norm2() / n; 191 // Then source points with partial transport
178 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); 192 let mut i = self.len();
179 θ = calculate_θ(*adaptive_ℓ_F, *max_transport); 193 if ALLOW_PARTIAL_TRANSPORT {
180 changes = true 194 // This can cause the number of points to explode, so cannot have partial transport.
195 for ρ in self.iter() {
196 let α = ρ.α_μ_orig - ρ.α_γ;
197 if ρ.α_γ.abs() > F::EPSILON && α != 0.0 {
198 let δ = DeltaMeasure { α, x: ρ.x };
199 if i < μ.len() {
200 μ[i] = δ;
201 } else {
202 μ.push(δ)
203 }
204 i += 1;
205 }
206 }
207 }
208 μ.truncate(i);
209 }
210
211 /// Construction of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
212 /// with step length τ and transport step length `θ_or_adaptive`.
213 #[replace_float_literals(F::cast_from(literal))]
214 pub(crate) fn initial_transport<G, D>(
215 &mut self,
216 μ: &RNDM<N, F>,
217 _τ: F,
218 τθ_or_adaptive: &mut TransportStepLength<F, G>,
219 v: D,
220 tconfig: &TransportConfig<F>,
221 ) where
222 G: Fn(F, F) -> F,
223 D: DifferentiableRealMapping<N, F>,
224 {
225 use TransportStepLength::*;
226
227 // Initialise transport structure weights
228 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) {
229 ρ.α_μ_orig = δ.α;
230 ρ.x = δ.x;
231 if ρ.fail_count > tconfig.max_fail {
232 ρ.α_γ = 0.0
233 } else {
234 // If old transport has opposing sign, the new transport will be none.
235 ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) {
236 0.0
237 } else {
238 δ.α
239 }
240 }
241 }
242
243 let γ_prev_len = self.len();
244 assert!(μ.len() >= γ_prev_len);
245 self.extend(μ[γ_prev_len..].iter().map(|δ| SingleTransport {
246 x: δ.x,
247 y: δ.x, // Just something, will be filled properly in the next phase
248 α_μ_orig: δ.α,
249 α_γ: δ.α,
250 prune: false,
251 fail_count: 0,
252 }));
253
254 // Calculate transport rays.
255 match *τθ_or_adaptive {
256 Fixed(θ) => {
257 for ρ in self.iter_mut() {
258 if ρ.fail_count <= tconfig.max_fail {
259 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ);
181 } 260 }
182 } 261 }
183 if !changes { 262 }
184 break; 263 AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => {
264 *max_transport = max_transport.max(self.norm(Radon));
265 let θτ = calculate_θτ(ℓ_F, *max_transport);
266 for ρ in self.iter_mut() {
267 if ρ.fail_count <= tconfig.max_fail {
268 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ);
269 }
185 } 270 }
186 } 271 }
187 } 272 FullyAdaptive {
188 } 273 l: ref mut adaptive_ℓ_F,
189 274 ref mut max_transport,
190 // Set initial guess for μ=μ^{k+1}. 275 g: ref calculate_θτ,
191 for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) { 276 } => {
192 if ρ.α.abs() > F::EPSILON { 277 *max_transport = max_transport.max(self.norm(Radon));
193 δ.x = ρ.x; 278 let mut θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport);
194 //δ.α = ρ.α; // already set above 279 // Do two runs through the spikes to update θ, breaking if first run did not cause
195 } else { 280 // a change.
196 δ.α = β; 281 for _i in 0..=1 {
197 } 282 let mut changes = false;
198 } 283 for ρ in self.iter_mut() {
199 // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b) 284 if ρ.fail_count < tconfig.max_fail {
200 μ_base_minus_γ0.set_masses( 285 let dv_x = v.differential(&ρ.x);
201 μ_base_masses 286 let g = &dv_x * (ρ.α_γ.signum() * θτ);
287 ρ.y = ρ.x - g;
288 let n = g.norm2();
289 if n >= F::EPSILON {
290 // Estimate Lipschitz factor of ∇v
291 let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n;
292 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F);
293 θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport);
294 changes = true
295 }
296 }
297 }
298 if !changes {
299 break;
300 }
301 }
302 }
303 }
304 }
305
306 /// A posteriori transport adaptation.
307 #[replace_float_literals(F::cast_from(literal))]
308 pub(crate) fn aposteriori_transport<D>(
309 &mut self,
310 μ: &RNDM<N, F>,
311 μ̆: &RNDM<N, F>,
312 _v: &mut D,
313 extra: Option<F>,
314 ε: F,
315 tconfig: &TransportConfig<F>,
316 attempts: &mut usize,
317 ) -> bool
318 where
319 D: DifferentiableRealMapping<N, F>,
320 {
321 *attempts += 1;
322
323 // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
324 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
325 // at that point to zero, and retry.
326 let mut all_ok = true;
327 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) {
328 if δ.α == 0.0 && ρ.α_γ != 0.0 {
329 all_ok = false;
330 ρ.α_γ = 0.0;
331 }
332 }
333
334 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z).
335 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ̆^k
336 // which holds for some C if the convolution kernel in 𝒟 has Lipschitz gradient.
337 let nγ = self.norm(Radon);
338 let nΔ = μ.dist_matching(&μ̆) + extra.unwrap_or(0.0);
339 let t = ε * tconfig.tolerance_mult_con;
340 if nγ * nΔ > t && *attempts >= tconfig.max_attempts {
341 all_ok = false;
342 } else if nγ * nΔ > t {
343 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
344 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
345 // will not enter here.
346 //*self *= tconfig.adaptation * t / (nγ * nΔ);
347
348 // We want a consistent behaviour that has the potential to set many weights to zero.
349 // Therefore, we find the smallest uniform reduction `chg_one`, subtracted
350 // from all weights, that achieves total `adapt` adaptation.
351 let adapt_to = tconfig.adaptation * t / nΔ;
352 let reduction_target = nγ - adapt_to;
353 assert!(reduction_target > 0.0);
354 if ALLOW_PARTIAL_TRANSPORT {
355 if MINIMAL_PARTIAL_TRANSPORT {
356 // This reduces weights of transport, starting from … until `adapt` is
357 // exhausted. It will, therefore, only ever cause one extra point insertion
358 // at the sources, unlike “full” partial transport.
359 //let refs = self.vec.iter_mut().collect::<Vec<_>>();
360 //refs.sort_by(|ρ1, ρ2| ρ1.α_γ.abs().partial_cmp(&ρ2.α_γ.abs()).unwrap());
361 // let mut it = refs.into_iter();
362 //
363 // Maybe sort by differential norm
364 // let mut refs = self
365 // .vec
366 // .iter_mut()
367 // .map(|ρ| {
368 // let val = v.differential(&ρ.x).norm2_squared();
369 // (ρ, val)
370 // })
371 // .collect::<Vec<_>>();
372 // refs.sort_by(|(_, v1), (_, v2)| v2.partial_cmp(&v1).unwrap());
373 // let mut it = refs.into_iter().map(|(ρ, _)| ρ);
374 let mut it = self.vec.iter_mut().rev();
375 let _unused = it.try_fold(reduction_target, |left, ρ| {
376 let w = ρ.α_γ.abs();
377 if left <= w {
378 ρ.α_γ = ρ.α_γ.signum() * (w - left);
379 ControlFlow::Break(())
380 } else {
381 ρ.α_γ = 0.0;
382 ControlFlow::Continue(left - w)
383 }
384 });
385 } else {
386 // This version equally reduces all weights. It causes partial transport, which
387 // has the problem that we need to then adapt weights in both start and
388 // end points, in insert_and_reweigh, sometimes causing the number of spikes μ
389 // to explode.
390 let mut abs_weights = self
391 .vec
392 .iter()
393 .map(|ρ| ρ.α_γ.abs())
394 .filter(|t| *t > F::EPSILON)
395 .collect::<Vec<F>>();
396 abs_weights.sort_by(|a, b| a.partial_cmp(b).unwrap());
397 let n = abs_weights.len();
398 // Cannot have partial transport; can cause spike count explosion
399 let chg = abs_weights.into_iter().zip((1..=n).rev()).try_fold(
400 0.0,
401 |smaller_total, (w, m)| {
402 let mf = F::cast_from(m);
403 let reduction = w * mf + smaller_total;
404 if reduction >= reduction_target {
405 ControlFlow::Break((reduction_target - smaller_total) / mf)
406 } else {
407 ControlFlow::Continue(smaller_total + w)
408 }
409 },
410 );
411 match chg {
412 ControlFlow::Continue(_) => self.vec.iter_mut().for_each(|δ| δ.α_γ = 0.0),
413 ControlFlow::Break(chg_one) => self.vec.iter_mut().for_each(|ρ| {
414 let t = ρ.α_γ.abs();
415 if t > 0.0 {
416 if ALLOW_PARTIAL_TRANSPORT {
417 let new = (t - chg_one).max(0.0);
418 ρ.α_γ = ρ.α_γ.signum() * new;
419 }
420 }
421 }),
422 }
423 }
424 } else {
425 // This version zeroes smallest weights, avoiding partial transport.
426 let mut abs_weights_idx = self
427 .vec
428 .iter()
429 .map(|ρ| ρ.α_γ.abs())
430 .zip(0..)
431 .filter(|(w, _)| *w >= 0.0)
432 .collect::<Vec<(F, usize)>>();
433 abs_weights_idx.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap());
434
435 let mut left = reduction_target;
436
437 for (w, i) in abs_weights_idx {
438 left -= w;
439 let ρ = &mut self.vec[i];
440 ρ.α_γ = 0.0;
441 if left < 0.0 {
442 break;
443 }
444 }
445 }
446
447 all_ok = false
448 }
449
450 if !all_ok && *attempts >= tconfig.max_attempts {
451 for ρ in self.iter_mut() {
452 ρ.α_γ = 0.0;
453 }
454 }
455
456 for ρ in self.iter_mut() {
457 if ρ.α_γ == 0.0 {
458 ρ.fail_count += 1;
459 } else if all_ok {
460 ρ.fail_count = 0;
461 }
462 }
463
464 all_ok
465 }
466
467 /// Returns $‖μ\^k - π\_♯\^0γ\^{k+1}‖$
468 pub(crate) fn μ0_minus_γ0_radon(&self) -> F {
469 self.vec.iter().map(|ρ| (ρ.α_μ_orig - ρ.α_γ).abs()).sum()
470 }
471
472 /// Returns $∫ c_2 d|γ|$
473 #[replace_float_literals(F::cast_from(literal))]
474 pub(crate) fn c2integral(&self) -> F {
475 self.vec
202 .iter() 476 .iter()
203 .zip(γ1.iter_masses()) 477 .map(|ρ| ρ.y.dist2_squared(&ρ.x) / 2.0 * ρ.α_γ.abs())
204 .map(|(&a, b)| a - b), 478 .sum()
205 ); 479 }
206 (μ_base_masses, μ_base_minus_γ0) 480
207 } 481 #[replace_float_literals(F::cast_from(literal))]
208 482 pub(crate) fn get_transport_stats(&self, stats: &mut IterInfo<F>, μ: &RNDM<N, F>) {
209 /// A posteriori transport adaptation. 483 // TODO: This doesn't take into account μ[i].α becoming zero in the latest transport
210 #[replace_float_literals(F::cast_from(literal))] 484 // attempt, for i < self.len(), when a corresponding source term also exists with index
211 pub(crate) fn aposteriori_transport<F, const N: usize>( 485 // j ≥ self.len(). For now, we let that be reflected in the prune count.
212 γ1: &mut RNDM<F, N>, 486 stats.inserted += μ.len() - self.len();
213 μ: &mut RNDM<F, N>, 487
214 μ_base_minus_γ0: &mut RNDM<F, N>, 488 let transp = stats.get_transport_mut();
215 μ_base_masses: &Vec<F>, 489
216 extra: Option<F>, 490 transp.dist = {
217 ε: F, 491 let (a, b) = transp.dist;
218 tconfig: &TransportConfig<F>, 492 (a + self.c2integral(), b + self.norm(Radon))
219 ) -> bool 493 };
220 where 494 transp.untransported_fraction = {
221 F: Float + ToNalgebraRealField, 495 let (a, b) = transp.untransported_fraction;
222 { 496 let source = self.iter().map(|ρ| ρ.α_μ_orig.abs()).sum();
223 // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not, 497 (a + self.μ0_minus_γ0_radon(), b + source)
224 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1 498 };
225 // at that point to zero, and retry. 499 transp.transport_error = {
226 let mut all_ok = true; 500 let (a, b) = transp.transport_error;
227 for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) { 501 //(a + self.dist_matching(&μ), b + self.norm(Radon))
228 if α_μ == 0.0 && *α_γ1 != 0.0 { 502
229 all_ok = false; 503 // This ignores points that have been not transported at all, to only calculate
230 *α_γ1 = 0.0; 504 // destnation error; untransported_fraction accounts for not being able to transport
231 } 505 // at all.
232 } 506 self.iter()
233 507 .zip(μ.iter_spikes())
234 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z). 508 .fold((a, b), |(a, b), (ρ, δ)| {
235 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, 509 let transported = ρ.α_γ.abs();
236 // which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient. 510 if transported > F::EPSILON {
237 let nγ = γ1.norm(Radon); 511 (a + (ρ.α_γ - δ.α).abs(), b + transported)
238 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0); 512 } else {
239 let t = ε * tconfig.tolerance_mult_con; 513 (a, b)
240 if nγ * nΔ > t { 514 }
241 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, 515 })
242 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we 516 };
243 // will not enter here. 517 }
244 *γ1 *= tconfig.adaptation * t / (nγ * nΔ); 518
245 all_ok = false 519 /// Prune spikes with zero weight. To maintain correct ordering between μ and γ, also the
246 } 520 /// latter needs to be pruned when μ is.
247 521 pub(crate) fn prune_compat(&mut self, μ: &mut RNDM<N, F>, stats: &mut IterInfo<F>) {
248 if !all_ok { 522 assert!(self.vec.len() <= μ.len());
249 // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} 523 let old_len = μ.len();
250 μ_base_minus_γ0.set_masses( 524 for (ρ, δ) in self.vec.iter_mut().zip(μ.iter_spikes()) {
251 μ_base_masses 525 ρ.prune = !(δ.α.abs() > F::EPSILON);
252 .iter() 526 }
253 .zip(γ1.iter_masses()) 527 μ.prune_by(|δ| δ.α.abs() > F::EPSILON);
254 .map(|(&a, b)| a - b), 528 stats.pruned += old_len - μ.len();
255 ); 529 self.vec.retain(|ρ| !ρ.prune);
256 } 530 assert!(self.vec.len() <= μ.len());
257 531 }
258 all_ok 532 }
533
534 impl<const N: usize, F: Float> Norm<Radon, F> for Transport<N, F> {
535 fn norm(&self, _: Radon) -> F {
536 self.iter().map(|ρ| ρ.α_γ.abs()).sum()
537 }
538 }
539
540 impl<const N: usize, F: Float> MulAssign<F> for Transport<N, F> {
541 fn mul_assign(&mut self, factor: F) {
542 for ρ in self.iter_mut() {
543 ρ.α_γ *= factor;
544 }
545 }
259 } 546 }
260 547
261 /// Iteratively solve the pointsource localisation problem using sliding forward-backward 548 /// Iteratively solve the pointsource localisation problem using sliding forward-backward
262 /// splitting 549 /// splitting
263 /// 550 ///
264 /// The parametrisation is as for [`pointsource_fb_reg`]. 551 /// The parametrisation is as for [`pointsource_fb_reg`].
265 /// Inertia is currently not supported. 552 /// Inertia is currently not supported.
266 #[replace_float_literals(F::cast_from(literal))] 553 #[replace_float_literals(F::cast_from(literal))]
267 pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N: usize>( 554 pub fn pointsource_sliding_fb_reg<F, I, Dat, Reg, Plot, P, const N: usize>(
268 opA: &A, 555 f: &Dat,
269 b: &A::Observable, 556 reg: &Reg,
270 reg: Reg,
271 prox_penalty: &P, 557 prox_penalty: &P,
272 config: &SlidingFBConfig<F>, 558 config: &SlidingFBConfig<F>,
273 iterator: I, 559 iterator: I,
274 mut plotter: SeqPlotter<F, N>, 560 mut plotter: Plot,
275 ) -> RNDM<F, N> 561 μ0: Option<RNDM<N, F>>,
562 ) -> DynResult<RNDM<N, F>>
276 where 563 where
277 F: Float + ToNalgebraRealField, 564 F: Float + ToNalgebraRealField,
278 I: AlgIteratorFactory<IterInfo<F, N>>, 565 I: AlgIteratorFactory<IterInfo<F>>,
279 A: ForwardModel<RNDM<F, N>, F> 566 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>,
280 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F> 567 Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>,
281 + BoundedCurvature<FloatType = F>, 568 //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>,
282 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>, 569 RNDM<N, F>: SpikeMerging<F>,
283 A::PreadjointCodomain: DifferentiableRealMapping<F, N>, 570 Reg: SlidingRegTerm<Loc<N, F>, F>,
284 RNDM<F, N>: SpikeMerging<F>, 571 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
285 Reg: SlidingRegTerm<F, N>, 572 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
286 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
287 PlotLookup: Plotting<N>,
288 { 573 {
289 // Check parameters 574 // Check parameters
290 assert!(config.τ0 > 0.0, "Invalid step length parameter"); 575 ensure!(config.τ0 > 0.0, "Invalid step length parameter");
291 config.transport.check(); 576 config.transport.check()?;
292 577
293 // Initialise iterates 578 // Initialise iterates
294 let mut μ = DiscreteMeasure::new(); 579 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
295 let mut γ1 = DiscreteMeasure::new(); 580 let mut γ = Transport::new();
296 let mut residual = -b; // Has to equal $Aμ-b$.
297 581
298 // Set up parameters 582 // Set up parameters
299 // let opAnorm = opA.opnorm_bound(Radon, L2); 583 // let opAnorm = opA.opnorm_bound(Radon, L2);
300 //let max_transport = config.max_transport.scale 584 //let max_transport = config.max_transport.scale
301 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); 585 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
302 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; 586 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
303 let ℓ = 0.0; 587 let ℓ = 0.0;
304 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); 588 let τ = config.τ0 / prox_penalty.step_length_bound(&f)?;
305 let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); 589
306 let transport_lip = maybe_transport_lip.unwrap(); 590 let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) {
307 let calculate_θ = |ℓ_F, max_transport| { 591 (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0),
308 let ℓ_r = transport_lip * max_transport; 592 (maybe_ℓ_F, Ok(transport_lip)) => {
309 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) 593 let calculate_θτ = move |ℓ_F, max_transport| {
310 }; 594 let ℓ_r = transport_lip * max_transport;
311 let mut θ_or_adaptive = match maybe_ℓ_F0 { 595 config.transport.θ0 / (ℓ + ℓ_F + ℓ_r)
312 //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)), 596 };
313 Some(ℓ_F0) => TransportStepLength::AdaptiveMax { 597 match maybe_ℓ_F {
314 l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual 598 Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
315 max_transport: 0.0, 599 l: ℓ_F, // TODO: could estimate computing the real reesidual
316 g: calculate_θ, 600 max_transport: 0.0,
317 }, 601 g: calculate_θτ,
318 None => TransportStepLength::FullyAdaptive { 602 },
319 l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials 603 Err(_) => TransportStepLength::FullyAdaptive {
320 max_transport: 0.0, 604 l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
321 g: calculate_θ, 605 max_transport: 0.0,
322 }, 606 g: calculate_θτ,
607 },
608 }
609 }
323 }; 610 };
324 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 611 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
325 // by τ compared to the conditional gradient approach. 612 // by τ compared to the conditional gradient approach.
326 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling(); 613 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
327 let mut ε = tolerance.initial(); 614 let mut ε = tolerance.initial();
328 615
329 // Statistics 616 // Statistics
330 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { 617 let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
331 value: residual.norm2_squared_div2() + reg.apply(μ), 618 value: f.apply(μ) + reg.apply(μ),
332 n_spikes: μ.len(), 619 n_spikes: μ.len(),
333 ε, 620 ε,
334 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), 621 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
335 ..stats 622 ..stats
336 }; 623 };
337 let mut stats = IterInfo::new(); 624 let mut stats = IterInfo::new();
338 625
339 // Run the algorithm 626 // Run the algorithm
340 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { 627 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
341 // Calculate initial transport 628 // Calculate initial transport
342 let v = opA.preadjoint().apply(residual); 629 let v = f.differential(&μ);
343 let (μ_base_masses, mut μ_base_minus_γ0) = 630 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v, &config.transport);
344 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); 631
632 let mut attempts = 0;
345 633
346 // Solve finite-dimensional subproblem several times until the dual variable for the 634 // Solve finite-dimensional subproblem several times until the dual variable for the
347 // regularisation term conforms to the assumptions made for the transport above. 635 // regularisation term conforms to the assumptions made for the transport above.
348 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { 636 let (maybe_d, _within_tolerances, mut τv̆, μ̆) = 'adapt_transport: loop {
637 // Set initial guess for μ=μ^{k+1}.
638 γ.μ̆_into(&mut μ);
639 let μ̆ = μ.clone();
640
349 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) 641 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
350 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); 642 //let residual_μ̆ = calculate_residual2(&γ1, &μ0_minus_γ0, opA, b);
351 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); 643 // TODO: this could be optimised by doing the differential like the
644 // old residual2.
645 // NOTE: This assumes that μ = γ1
646 let mut τv̆ = f.differential(&μ̆) * τ;
352 647
353 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 648 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
354 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( 649 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
355 &mut μ, 650 &mut μ,
356 &mut τv̆, 651 &mut τv̆,
357 &γ1,
358 Some(&μ_base_minus_γ0),
359 τ, 652 τ,
360 ε, 653 ε,
361 &config.insertion, 654 &config.insertion,
362 &reg, 655 &reg,
363 &state, 656 &state,
364 &mut stats, 657 &mut stats,
365 ); 658 )?;
366 659
367 // A posteriori transport adaptation. 660 // A posteriori transport adaptation.
368 if aposteriori_transport( 661 if γ.aposteriori_transport(&μ, &μ̆, &mut τv̆, None, ε, &config.transport, &mut attempts)
369 &mut γ1, 662 {
370 &mut μ, 663 break 'adapt_transport (maybe_d, within_tolerances, τv̆, μ̆);
371 &mut μ_base_minus_γ0, 664 }
372 &μ_base_masses, 665
373 None, 666 stats.get_transport_mut().readjustment_iters += 1;
374 ε,
375 &config.transport,
376 ) {
377 break 'adapt_transport (maybe_d, within_tolerances, τv̆);
378 }
379 }; 667 };
380 668
381 stats.untransported_fraction = Some({ 669 γ.get_transport_stats(&mut stats, &μ);
382 assert_eq!(μ_base_masses.len(), γ1.len());
383 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
384 let source = μ_base_masses.iter().map(|v| v.abs()).sum();
385 (a + μ_base_minus_γ0.norm(Radon), b + source)
386 });
387 stats.transport_error = Some({
388 assert_eq!(μ_base_masses.len(), γ1.len());
389 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
390 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
391 });
392 670
393 // Merge spikes. 671 // Merge spikes.
394 // This crucially expects the merge routine to be stable with respect to spike locations, 672 // This crucially expects the merge routine to be stable with respect to spike locations,
395 // and not to perform any pruning. That is to be done below simultaneously for γ. 673 // and not to perform any pruning. That is to be done below simultaneously for γ.
396 let ins = &config.insertion; 674 if config.insertion.merge_now(&state) {
397 if ins.merge_now(&state) {
398 stats.merged += prox_penalty.merge_spikes( 675 stats.merged += prox_penalty.merge_spikes(
399 &mut μ, 676 &mut μ,
400 &mut τv̆, 677 &mut τv̆,
401 &γ1, 678 &μ̆,
402 Some(&μ_base_minus_γ0),
403 τ, 679 τ,
404 ε, 680 ε,
405 ins, 681 &config.insertion,
406 &reg, 682 &reg,
407 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), 683 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
408 ); 684 );
409 } 685 }
410 686
411 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the 687 γ.prune_compat(&mut μ, &mut stats);
412 // latter needs to be pruned when μ is.
413 // TODO: This could do with a two-vector Vec::retain to avoid copies.
414 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
415 if μ_new.len() != μ.len() {
416 let mut μ_iter = μ.iter_spikes();
417 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
418 stats.pruned += μ.len() - μ_new.len();
419 μ = μ_new;
420 }
421
422 // Update residual
423 residual = calculate_residual(&μ, opA, b);
424 688
425 let iter = state.iteration(); 689 let iter = state.iteration();
426 stats.this_iters += 1; 690 stats.this_iters += 1;
427 691
428 // Give statistics if requested 692 // Give statistics if requested
429 state.if_verbose(|| { 693 state.if_verbose(|| {
430 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); 694 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
431 full_stats( 695 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
432 &residual,
433 &μ,
434 ε,
435 std::mem::replace(&mut stats, IterInfo::new()),
436 )
437 }); 696 });
438 697
439 // Update main tolerance for next iteration 698 // Update main tolerance for next iteration
440 ε = tolerance.update(ε, iter); 699 ε = tolerance.update(ε, iter);
441 } 700 }
442 701
443 postprocess(μ, &config.insertion, L2Squared, opA, b) 702 //postprocess(μ, &config.insertion, f)
444 } 703 postprocess(μ, &config.insertion, |μ̃| f.apply(μ̃))
704 }

mercurial