| 94 /// Adaptive step length. |
108 /// Adaptive step length. |
| 95 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. |
109 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. |
| 96 FullyAdaptive { l: F, max_transport: F, g: G }, |
110 FullyAdaptive { l: F, max_transport: F, g: G }, |
| 97 } |
111 } |
| 98 |
112 |
| 99 /// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)` |
113 #[derive(Clone, Debug, Serialize)] |
| 100 /// with step lengh τ and transport step length `θ_or_adaptive`. |
114 pub struct SingleTransport<const N: usize, F: Float> { |
| 101 #[replace_float_literals(F::cast_from(literal))] |
115 /// Source point |
| 102 pub(crate) fn initial_transport<F, G, D, const N: usize>( |
116 x: Loc<N, F>, |
| 103 γ1: &mut RNDM<F, N>, |
117 /// Target point |
| 104 μ: &mut RNDM<F, N>, |
118 y: Loc<N, F>, |
| 105 τ: F, |
119 /// Original mass |
| 106 θ_or_adaptive: &mut TransportStepLength<F, G>, |
120 α_μ_orig: F, |
| 107 v: D, |
121 /// Transported mass |
| 108 ) -> (Vec<F>, RNDM<F, N>) |
122 α_γ: F, |
| 109 where |
123 /// Helper for pruning |
| 110 F: Float + ToNalgebraRealField, |
124 prune: bool, |
| 111 G: Fn(F, F) -> F, |
125 /// Fail count |
| 112 D: DifferentiableRealMapping<F, N>, |
126 fail_count: usize, |
| 113 { |
127 } |
| 114 use TransportStepLength::*; |
128 |
| 115 |
129 #[derive(Clone, Debug, Serialize)] |
| 116 // Save current base point and shift μ to new positions. Idea is that |
130 pub struct Transport<const N: usize, F: Float> { |
| 117 // μ_base(_masses) = μ^k (vector of masses) |
131 vec: Vec<SingleTransport<N, F>>, |
| 118 // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} |
132 } |
| 119 // γ1 = π_♯^1γ^{k+1} |
133 |
| 120 // μ = μ^{k+1} |
134 /// Whether partiall transported points are allowed. |
| 121 let μ_base_masses: Vec<F> = μ.iter_masses().collect(); |
135 /// |
| 122 let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below. |
136 /// Partial transport can cause spike count explosion, so full or zero |
| 123 // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates |
137 /// transport is generally preferred. If this is set to `true`, different |
| 124 //let mut sum_norm_dv = 0.0; |
138 /// transport adaptation heuristics will be used. |
| 125 let γ_prev_len = γ1.len(); |
139 const ALLOW_PARTIAL_TRANSPORT: bool = true; |
| 126 assert!(μ.len() >= γ_prev_len); |
140 const MINIMAL_PARTIAL_TRANSPORT: bool = true; |
| 127 γ1.extend(μ[γ_prev_len..].iter().cloned()); |
141 |
| 128 |
142 impl<const N: usize, F: Float> Transport<N, F> { |
| 129 // Calculate initial transport and step length. |
143 pub(crate) fn new() -> Self { |
| 130 // First calculate initial transported weights |
144 Transport { vec: Vec::new() } |
| 131 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
145 } |
| 132 // If old transport has opposing sign, the new transport will be none. |
146 |
| 133 ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) { |
147 pub(crate) fn iter(&self) -> impl Iterator<Item = &'_ SingleTransport<N, F>> { |
| 134 0.0 |
148 self.vec.iter() |
| 135 } else { |
149 } |
| 136 δ.α |
150 |
| 137 }; |
151 pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut SingleTransport<N, F>> { |
| 138 } |
152 self.vec.iter_mut() |
| 139 |
153 } |
| 140 // Calculate transport rays. |
154 |
| 141 match *θ_or_adaptive { |
155 pub(crate) fn extend<I>(&mut self, it: I) |
| 142 Fixed(θ) => { |
156 where |
| 143 let θτ = τ * θ; |
157 I: IntoIterator<Item = SingleTransport<N, F>>, |
| 144 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
158 { |
| 145 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); |
159 self.vec.extend(it) |
| 146 } |
160 } |
| 147 } |
161 |
| 148 AdaptiveMax { |
162 pub(crate) fn len(&self) -> usize { |
| 149 l: ℓ_F, |
163 self.vec.len() |
| 150 ref mut max_transport, |
164 } |
| 151 g: ref calculate_θ, |
165 |
| 152 } => { |
166 // pub(crate) fn dist_matching(&self, μ: &RNDM<N, F>) -> F { |
| 153 *max_transport = max_transport.max(γ1.norm(Radon)); |
167 // self.iter() |
| 154 let θτ = τ * calculate_θ(ℓ_F, *max_transport); |
168 // .zip(μ.iter_spikes()) |
| 155 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
169 // .map(|(ρ, δ)| (ρ.α_γ - δ.α).abs()) |
| 156 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); |
170 // .sum() |
| 157 } |
171 // } |
| 158 } |
172 |
| 159 FullyAdaptive { |
173 /// Construct `μ̆`, replacing the contents of `μ`. |
| 160 l: ref mut adaptive_ℓ_F, |
174 #[replace_float_literals(F::cast_from(literal))] |
| 161 ref mut max_transport, |
175 pub(crate) fn μ̆_into(&self, μ: &mut RNDM<N, F>) { |
| 162 g: ref calculate_θ, |
176 assert!(self.len() <= μ.len()); |
| 163 } => { |
177 |
| 164 *max_transport = max_transport.max(γ1.norm(Radon)); |
178 // First transported points |
| 165 let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport); |
179 for (δ, ρ) in izip!(μ.iter_spikes_mut(), self.iter()) { |
| 166 // Do two runs through the spikes to update θ, breaking if first run did not cause |
180 if ρ.α_γ.abs() > 0.0 { |
| 167 // a change. |
181 // Transport – transported point |
| 168 for _i in 0..=1 { |
182 δ.α = ρ.α_γ; |
| 169 let mut changes = false; |
183 δ.x = ρ.y; |
| 170 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
184 } else { |
| 171 let dv_x = v.differential(&δ.x); |
185 // No transport – original point |
| 172 let g = &dv_x * (ρ.α.signum() * θ * τ); |
186 δ.α = ρ.α_μ_orig; |
| 173 ρ.x = δ.x - g; |
187 δ.x = ρ.x; |
| 174 let n = g.norm2(); |
188 } |
| 175 if n >= F::EPSILON { |
189 } |
| 176 // Estimate Lipschitz factor of ∇v |
190 |
| 177 let this_ℓ_F = (dv_x - v.differential(&ρ.x)).norm2() / n; |
191 // Then source points with partial transport |
| 178 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); |
192 let mut i = self.len(); |
| 179 θ = calculate_θ(*adaptive_ℓ_F, *max_transport); |
193 if ALLOW_PARTIAL_TRANSPORT { |
| 180 changes = true |
194 // This can cause the number of points to explode, so cannot have partial transport. |
| |
195 for ρ in self.iter() { |
| |
196 let α = ρ.α_μ_orig - ρ.α_γ; |
| |
197 if ρ.α_γ.abs() > F::EPSILON && α != 0.0 { |
| |
198 let δ = DeltaMeasure { α, x: ρ.x }; |
| |
199 if i < μ.len() { |
| |
200 μ[i] = δ; |
| |
201 } else { |
| |
202 μ.push(δ) |
| |
203 } |
| |
204 i += 1; |
| |
205 } |
| |
206 } |
| |
207 } |
| |
208 μ.truncate(i); |
| |
209 } |
| |
210 |
| |
211 /// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)` |
| |
212 /// with step lengh τ and transport step length `θ_or_adaptive`. |
| |
213 #[replace_float_literals(F::cast_from(literal))] |
| |
214 pub(crate) fn initial_transport<G, D>( |
| |
215 &mut self, |
| |
216 μ: &RNDM<N, F>, |
| |
217 _τ: F, |
| |
218 τθ_or_adaptive: &mut TransportStepLength<F, G>, |
| |
219 v: D, |
| |
220 tconfig: &TransportConfig<F>, |
| |
221 ) where |
| |
222 G: Fn(F, F) -> F, |
| |
223 D: DifferentiableRealMapping<N, F>, |
| |
224 { |
| |
225 use TransportStepLength::*; |
| |
226 |
| |
227 // Initialise transport structure weights |
| |
228 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) { |
| |
229 ρ.α_μ_orig = δ.α; |
| |
230 ρ.x = δ.x; |
| |
231 if ρ.fail_count > tconfig.max_fail { |
| |
232 ρ.α_γ = 0.0 |
| |
233 } else { |
| |
234 // If old transport has opposing sign, the new transport will be none. |
| |
235 ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) { |
| |
236 0.0 |
| |
237 } else { |
| |
238 δ.α |
| |
239 } |
| |
240 } |
| |
241 } |
| |
242 |
| |
243 let γ_prev_len = self.len(); |
| |
244 assert!(μ.len() >= γ_prev_len); |
| |
245 self.extend(μ[γ_prev_len..].iter().map(|δ| SingleTransport { |
| |
246 x: δ.x, |
| |
247 y: δ.x, // Just something, will be filled properly in the next phase |
| |
248 α_μ_orig: δ.α, |
| |
249 α_γ: δ.α, |
| |
250 prune: false, |
| |
251 fail_count: 0, |
| |
252 })); |
| |
253 |
| |
254 // Calculate transport rays. |
| |
255 match *τθ_or_adaptive { |
| |
256 Fixed(θ) => { |
| |
257 for ρ in self.iter_mut() { |
| |
258 if ρ.fail_count <= tconfig.max_fail { |
| |
259 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ); |
| 181 } |
260 } |
| 182 } |
261 } |
| 183 if !changes { |
262 } |
| 184 break; |
263 AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => { |
| |
264 *max_transport = max_transport.max(self.norm(Radon)); |
| |
265 let θτ = calculate_θτ(ℓ_F, *max_transport); |
| |
266 for ρ in self.iter_mut() { |
| |
267 if ρ.fail_count <= tconfig.max_fail { |
| |
268 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ); |
| |
269 } |
| 185 } |
270 } |
| 186 } |
271 } |
| 187 } |
272 FullyAdaptive { |
| 188 } |
273 l: ref mut adaptive_ℓ_F, |
| 189 |
274 ref mut max_transport, |
| 190 // Set initial guess for μ=μ^{k+1}. |
275 g: ref calculate_θτ, |
| 191 for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) { |
276 } => { |
| 192 if ρ.α.abs() > F::EPSILON { |
277 *max_transport = max_transport.max(self.norm(Radon)); |
| 193 δ.x = ρ.x; |
278 let mut θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport); |
| 194 //δ.α = ρ.α; // already set above |
279 // Do two runs through the spikes to update θ, breaking if first run did not cause |
| 195 } else { |
280 // a change. |
| 196 δ.α = β; |
281 for _i in 0..=1 { |
| 197 } |
282 let mut changes = false; |
| 198 } |
283 for ρ in self.iter_mut() { |
| 199 // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b) |
284 if ρ.fail_count < tconfig.max_fail { |
| 200 μ_base_minus_γ0.set_masses( |
285 let dv_x = v.differential(&ρ.x); |
| 201 μ_base_masses |
286 let g = &dv_x * (ρ.α_γ.signum() * θτ); |
| |
287 ρ.y = ρ.x - g; |
| |
288 let n = g.norm2(); |
| |
289 if n >= F::EPSILON { |
| |
290 // Estimate Lipschitz factor of ∇v |
| |
291 let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n; |
| |
292 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); |
| |
293 θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport); |
| |
294 changes = true |
| |
295 } |
| |
296 } |
| |
297 } |
| |
298 if !changes { |
| |
299 break; |
| |
300 } |
| |
301 } |
| |
302 } |
| |
303 } |
| |
304 } |
| |
305 |
| |
306 /// A posteriori transport adaptation. |
| |
307 #[replace_float_literals(F::cast_from(literal))] |
| |
308 pub(crate) fn aposteriori_transport<D>( |
| |
309 &mut self, |
| |
310 μ: &RNDM<N, F>, |
| |
311 μ̆: &RNDM<N, F>, |
| |
312 _v: &mut D, |
| |
313 extra: Option<F>, |
| |
314 ε: F, |
| |
315 tconfig: &TransportConfig<F>, |
| |
316 attempts: &mut usize, |
| |
317 ) -> bool |
| |
318 where |
| |
319 D: DifferentiableRealMapping<N, F>, |
| |
320 { |
| |
321 *attempts += 1; |
| |
322 |
| |
323 // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not, |
| |
324 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1 |
| |
325 // at that point to zero, and retry. |
| |
326 let mut all_ok = true; |
| |
327 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) { |
| |
328 if δ.α == 0.0 && ρ.α_γ != 0.0 { |
| |
329 all_ok = false; |
| |
330 ρ.α_γ = 0.0; |
| |
331 } |
| |
332 } |
| |
333 |
| |
334 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z). |
| |
335 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ̆^k |
| |
336 // which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient. |
| |
337 let nγ = self.norm(Radon); |
| |
338 let nΔ = μ.dist_matching(&μ̆) + extra.unwrap_or(0.0); |
| |
339 let t = ε * tconfig.tolerance_mult_con; |
| |
340 if nγ * nΔ > t && *attempts >= tconfig.max_attempts { |
| |
341 all_ok = false; |
| |
342 } else if nγ * nΔ > t { |
| |
343 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, |
| |
344 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we |
| |
345 // will not enter here. |
| |
346 //*self *= tconfig.adaptation * t / (nγ * nΔ); |
| |
347 |
| |
348 // We want a consistent behaviour that has the potential to set many weights to zero. |
| |
349 // Therefore, we find the smallest uniform reduction `chg_one`, subtracted |
| |
350 // from all weights, that achieves total `adapt` adaptation. |
| |
351 let adapt_to = tconfig.adaptation * t / nΔ; |
| |
352 let reduction_target = nγ - adapt_to; |
| |
353 assert!(reduction_target > 0.0); |
| |
354 if ALLOW_PARTIAL_TRANSPORT { |
| |
355 if MINIMAL_PARTIAL_TRANSPORT { |
| |
356 // This reduces weights of transport, starting from … until `adapt` is |
| |
357 // exhausted. It will, therefore, only ever cause one extrap point insertion |
| |
358 // at the sources, unlike “full” partial transport. |
| |
359 //let refs = self.vec.iter_mut().collect::<Vec<_>>(); |
| |
360 //refs.sort_by(|ρ1, ρ2| ρ1.α_γ.abs().partial_cmp(&ρ2.α_γ.abs()).unwrap()); |
| |
361 // let mut it = refs.into_iter(); |
| |
362 // |
| |
363 // Maybe sort by differential norm |
| |
364 // let mut refs = self |
| |
365 // .vec |
| |
366 // .iter_mut() |
| |
367 // .map(|ρ| { |
| |
368 // let val = v.differential(&ρ.x).norm2_squared(); |
| |
369 // (ρ, val) |
| |
370 // }) |
| |
371 // .collect::<Vec<_>>(); |
| |
372 // refs.sort_by(|(_, v1), (_, v2)| v2.partial_cmp(&v1).unwrap()); |
| |
373 // let mut it = refs.into_iter().map(|(ρ, _)| ρ); |
| |
374 let mut it = self.vec.iter_mut().rev(); |
| |
375 let _unused = it.try_fold(reduction_target, |left, ρ| { |
| |
376 let w = ρ.α_γ.abs(); |
| |
377 if left <= w { |
| |
378 ρ.α_γ = ρ.α_γ.signum() * (w - left); |
| |
379 ControlFlow::Break(()) |
| |
380 } else { |
| |
381 ρ.α_γ = 0.0; |
| |
382 ControlFlow::Continue(left - w) |
| |
383 } |
| |
384 }); |
| |
385 } else { |
| |
386 // This version equally reduces all weights. It causes partial transport, which |
| |
387 // has the problem that that we need to then adapt weights in both start and |
| |
388 // end points, in insert_and_reweigh, somtimes causing the number of spikes μ |
| |
389 // to explode. |
| |
390 let mut abs_weights = self |
| |
391 .vec |
| |
392 .iter() |
| |
393 .map(|ρ| ρ.α_γ.abs()) |
| |
394 .filter(|t| *t > F::EPSILON) |
| |
395 .collect::<Vec<F>>(); |
| |
396 abs_weights.sort_by(|a, b| a.partial_cmp(b).unwrap()); |
| |
397 let n = abs_weights.len(); |
| |
398 // Cannot have partial transport; can cause spike count explosion |
| |
399 let chg = abs_weights.into_iter().zip((1..=n).rev()).try_fold( |
| |
400 0.0, |
| |
401 |smaller_total, (w, m)| { |
| |
402 let mf = F::cast_from(m); |
| |
403 let reduction = w * mf + smaller_total; |
| |
404 if reduction >= reduction_target { |
| |
405 ControlFlow::Break((reduction_target - smaller_total) / mf) |
| |
406 } else { |
| |
407 ControlFlow::Continue(smaller_total + w) |
| |
408 } |
| |
409 }, |
| |
410 ); |
| |
411 match chg { |
| |
412 ControlFlow::Continue(_) => self.vec.iter_mut().for_each(|δ| δ.α_γ = 0.0), |
| |
413 ControlFlow::Break(chg_one) => self.vec.iter_mut().for_each(|ρ| { |
| |
414 let t = ρ.α_γ.abs(); |
| |
415 if t > 0.0 { |
| |
416 if ALLOW_PARTIAL_TRANSPORT { |
| |
417 let new = (t - chg_one).max(0.0); |
| |
418 ρ.α_γ = ρ.α_γ.signum() * new; |
| |
419 } |
| |
420 } |
| |
421 }), |
| |
422 } |
| |
423 } |
| |
424 } else { |
| |
425 // This version zeroes smallest weights, avoiding partial transport. |
| |
426 let mut abs_weights_idx = self |
| |
427 .vec |
| |
428 .iter() |
| |
429 .map(|ρ| ρ.α_γ.abs()) |
| |
430 .zip(0..) |
| |
431 .filter(|(w, _)| *w >= 0.0) |
| |
432 .collect::<Vec<(F, usize)>>(); |
| |
433 abs_weights_idx.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap()); |
| |
434 |
| |
435 let mut left = reduction_target; |
| |
436 |
| |
437 for (w, i) in abs_weights_idx { |
| |
438 left -= w; |
| |
439 let ρ = &mut self.vec[i]; |
| |
440 ρ.α_γ = 0.0; |
| |
441 if left < 0.0 { |
| |
442 break; |
| |
443 } |
| |
444 } |
| |
445 } |
| |
446 |
| |
447 all_ok = false |
| |
448 } |
| |
449 |
| |
450 if !all_ok && *attempts >= tconfig.max_attempts { |
| |
451 for ρ in self.iter_mut() { |
| |
452 ρ.α_γ = 0.0; |
| |
453 } |
| |
454 } |
| |
455 |
| |
456 for ρ in self.iter_mut() { |
| |
457 if ρ.α_γ == 0.0 { |
| |
458 ρ.fail_count += 1; |
| |
459 } else if all_ok { |
| |
460 ρ.fail_count = 0; |
| |
461 } |
| |
462 } |
| |
463 |
| |
464 all_ok |
| |
465 } |
| |
466 |
| |
467 /// Returns $‖μ\^k - π\_♯\^0γ\^{k+1}‖$ |
| |
468 pub(crate) fn μ0_minus_γ0_radon(&self) -> F { |
| |
469 self.vec.iter().map(|ρ| (ρ.α_μ_orig - ρ.α_γ).abs()).sum() |
| |
470 } |
| |
471 |
| |
472 /// Returns $∫ c_2 d|γ|$ |
| |
473 #[replace_float_literals(F::cast_from(literal))] |
| |
474 pub(crate) fn c2integral(&self) -> F { |
| |
475 self.vec |
| 202 .iter() |
476 .iter() |
| 203 .zip(γ1.iter_masses()) |
477 .map(|ρ| ρ.y.dist2_squared(&ρ.x) / 2.0 * ρ.α_γ.abs()) |
| 204 .map(|(&a, b)| a - b), |
478 .sum() |
| 205 ); |
479 } |
| 206 (μ_base_masses, μ_base_minus_γ0) |
480 |
| 207 } |
481 #[replace_float_literals(F::cast_from(literal))] |
| 208 |
482 pub(crate) fn get_transport_stats(&self, stats: &mut IterInfo<F>, μ: &RNDM<N, F>) { |
| 209 /// A posteriori transport adaptation. |
483 // TODO: This doesn't take into account μ[i].α becoming zero in the latest tranport |
| 210 #[replace_float_literals(F::cast_from(literal))] |
484 // attempt, for i < self.len(), when a corresponding source term also exists with index |
| 211 pub(crate) fn aposteriori_transport<F, const N: usize>( |
485 // j ≥ self.len(). For now, we let that be reflected in the prune count. |
| 212 γ1: &mut RNDM<F, N>, |
486 stats.inserted += μ.len() - self.len(); |
| 213 μ: &mut RNDM<F, N>, |
487 |
| 214 μ_base_minus_γ0: &mut RNDM<F, N>, |
488 let transp = stats.get_transport_mut(); |
| 215 μ_base_masses: &Vec<F>, |
489 |
| 216 extra: Option<F>, |
490 transp.dist = { |
| 217 ε: F, |
491 let (a, b) = transp.dist; |
| 218 tconfig: &TransportConfig<F>, |
492 (a + self.c2integral(), b + self.norm(Radon)) |
| 219 ) -> bool |
493 }; |
| 220 where |
494 transp.untransported_fraction = { |
| 221 F: Float + ToNalgebraRealField, |
495 let (a, b) = transp.untransported_fraction; |
| 222 { |
496 let source = self.iter().map(|ρ| ρ.α_μ_orig.abs()).sum(); |
| 223 // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not, |
497 (a + self.μ0_minus_γ0_radon(), b + source) |
| 224 // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1 |
498 }; |
| 225 // at that point to zero, and retry. |
499 transp.transport_error = { |
| 226 let mut all_ok = true; |
500 let (a, b) = transp.transport_error; |
| 227 for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) { |
501 //(a + self.dist_matching(&μ), b + self.norm(Radon)) |
| 228 if α_μ == 0.0 && *α_γ1 != 0.0 { |
502 |
| 229 all_ok = false; |
503 // This ignores points that have been not transported at all, to only calculate |
| 230 *α_γ1 = 0.0; |
504 // destnation error; untransported_fraction accounts for not being able to transport |
| 231 } |
505 // at all. |
| 232 } |
506 self.iter() |
| 233 |
507 .zip(μ.iter_spikes()) |
| 234 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z). |
508 .fold((a, b), |(a, b), (ρ, δ)| { |
| 235 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, |
509 let transported = ρ.α_γ.abs(); |
| 236 // which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient. |
510 if transported > F::EPSILON { |
| 237 let nγ = γ1.norm(Radon); |
511 (a + (ρ.α_γ - δ.α).abs(), b + transported) |
| 238 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0); |
512 } else { |
| 239 let t = ε * tconfig.tolerance_mult_con; |
513 (a, b) |
| 240 if nγ * nΔ > t { |
514 } |
| 241 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, |
515 }) |
| 242 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we |
516 }; |
| 243 // will not enter here. |
517 } |
| 244 *γ1 *= tconfig.adaptation * t / (nγ * nΔ); |
518 |
| 245 all_ok = false |
519 /// Prune spikes with zero weight. To maintain correct ordering between μ and γ, also the |
| 246 } |
520 /// latter needs to be pruned when μ is. |
| 247 |
521 pub(crate) fn prune_compat(&mut self, μ: &mut RNDM<N, F>, stats: &mut IterInfo<F>) { |
| 248 if !all_ok { |
522 assert!(self.vec.len() <= μ.len()); |
| 249 // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} |
523 let old_len = μ.len(); |
| 250 μ_base_minus_γ0.set_masses( |
524 for (ρ, δ) in self.vec.iter_mut().zip(μ.iter_spikes()) { |
| 251 μ_base_masses |
525 ρ.prune = !(δ.α.abs() > F::EPSILON); |
| 252 .iter() |
526 } |
| 253 .zip(γ1.iter_masses()) |
527 μ.prune_by(|δ| δ.α.abs() > F::EPSILON); |
| 254 .map(|(&a, b)| a - b), |
528 stats.pruned += old_len - μ.len(); |
| 255 ); |
529 self.vec.retain(|ρ| !ρ.prune); |
| 256 } |
530 assert!(self.vec.len() <= μ.len()); |
| 257 |
531 } |
| 258 all_ok |
532 } |
| |
533 |
| |
534 impl<const N: usize, F: Float> Norm<Radon, F> for Transport<N, F> { |
| |
535 fn norm(&self, _: Radon) -> F { |
| |
536 self.iter().map(|ρ| ρ.α_γ.abs()).sum() |
| |
537 } |
| |
538 } |
| |
539 |
| |
540 impl<const N: usize, F: Float> MulAssign<F> for Transport<N, F> { |
| |
541 fn mul_assign(&mut self, factor: F) { |
| |
542 for ρ in self.iter_mut() { |
| |
543 ρ.α_γ *= factor; |
| |
544 } |
| |
545 } |
| 259 } |
546 } |
| 260 |
547 |
| 261 /// Iteratively solve the pointsource localisation problem using sliding forward-backward |
548 /// Iteratively solve the pointsource localisation problem using sliding forward-backward |
| 262 /// splitting |
549 /// splitting |
| 263 /// |
550 /// |
| 264 /// The parametrisation is as for [`pointsource_fb_reg`]. |
551 /// The parametrisation is as for [`pointsource_fb_reg`]. |
| 265 /// Inertia is currently not supported. |
552 /// Inertia is currently not supported. |
| 266 #[replace_float_literals(F::cast_from(literal))] |
553 #[replace_float_literals(F::cast_from(literal))] |
| 267 pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N: usize>( |
554 pub fn pointsource_sliding_fb_reg<F, I, Dat, Reg, Plot, P, const N: usize>( |
| 268 opA: &A, |
555 f: &Dat, |
| 269 b: &A::Observable, |
556 reg: &Reg, |
| 270 reg: Reg, |
|
| 271 prox_penalty: &P, |
557 prox_penalty: &P, |
| 272 config: &SlidingFBConfig<F>, |
558 config: &SlidingFBConfig<F>, |
| 273 iterator: I, |
559 iterator: I, |
| 274 mut plotter: SeqPlotter<F, N>, |
560 mut plotter: Plot, |
| 275 ) -> RNDM<F, N> |
561 μ0: Option<RNDM<N, F>>, |
| |
562 ) -> DynResult<RNDM<N, F>> |
| 276 where |
563 where |
| 277 F: Float + ToNalgebraRealField, |
564 F: Float + ToNalgebraRealField, |
| 278 I: AlgIteratorFactory<IterInfo<F, N>>, |
565 I: AlgIteratorFactory<IterInfo<F>>, |
| 279 A: ForwardModel<RNDM<F, N>, F> |
566 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>, |
| 280 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F> |
567 Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| 281 + BoundedCurvature<FloatType = F>, |
568 //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>, |
| 282 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>, |
569 RNDM<N, F>: SpikeMerging<F>, |
| 283 A::PreadjointCodomain: DifferentiableRealMapping<F, N>, |
570 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| 284 RNDM<F, N>: SpikeMerging<F>, |
571 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| 285 Reg: SlidingRegTerm<F, N>, |
572 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
| 286 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
| 287 PlotLookup: Plotting<N>, |
|
| 288 { |
573 { |
| 289 // Check parameters |
574 // Check parameters |
| 290 assert!(config.τ0 > 0.0, "Invalid step length parameter"); |
575 ensure!(config.τ0 > 0.0, "Invalid step length parameter"); |
| 291 config.transport.check(); |
576 config.transport.check()?; |
| 292 |
577 |
| 293 // Initialise iterates |
578 // Initialise iterates |
| 294 let mut μ = DiscreteMeasure::new(); |
579 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 295 let mut γ1 = DiscreteMeasure::new(); |
580 let mut γ = Transport::new(); |
| 296 let mut residual = -b; // Has to equal $Aμ-b$. |
|
| 297 |
581 |
| 298 // Set up parameters |
582 // Set up parameters |
| 299 // let opAnorm = opA.opnorm_bound(Radon, L2); |
583 // let opAnorm = opA.opnorm_bound(Radon, L2); |
| 300 //let max_transport = config.max_transport.scale |
584 //let max_transport = config.max_transport.scale |
| 301 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
585 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
| 302 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
586 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
| 303 let ℓ = 0.0; |
587 let ℓ = 0.0; |
| 304 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
588 let τ = config.τ0 / prox_penalty.step_length_bound(&f)?; |
| 305 let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); |
589 |
| 306 let transport_lip = maybe_transport_lip.unwrap(); |
590 let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) { |
| 307 let calculate_θ = |ℓ_F, max_transport| { |
591 (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0), |
| 308 let ℓ_r = transport_lip * max_transport; |
592 (maybe_ℓ_F, Ok(transport_lip)) => { |
| 309 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) |
593 let calculate_θτ = move |ℓ_F, max_transport| { |
| 310 }; |
594 let ℓ_r = transport_lip * max_transport; |
| 311 let mut θ_or_adaptive = match maybe_ℓ_F0 { |
595 config.transport.θ0 / (ℓ + ℓ_F + ℓ_r) |
| 312 //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)), |
596 }; |
| 313 Some(ℓ_F0) => TransportStepLength::AdaptiveMax { |
597 match maybe_ℓ_F { |
| 314 l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual |
598 Ok(ℓ_F) => TransportStepLength::AdaptiveMax { |
| 315 max_transport: 0.0, |
599 l: ℓ_F, // TODO: could estimate computing the real reesidual |
| 316 g: calculate_θ, |
600 max_transport: 0.0, |
| 317 }, |
601 g: calculate_θτ, |
| 318 None => TransportStepLength::FullyAdaptive { |
602 }, |
| 319 l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials |
603 Err(_) => TransportStepLength::FullyAdaptive { |
| 320 max_transport: 0.0, |
604 l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials |
| 321 g: calculate_θ, |
605 max_transport: 0.0, |
| 322 }, |
606 g: calculate_θτ, |
| |
607 }, |
| |
608 } |
| |
609 } |
| 323 }; |
610 }; |
| 324 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
611 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 325 // by τ compared to the conditional gradient approach. |
612 // by τ compared to the conditional gradient approach. |
| 326 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling(); |
613 let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling(); |
| 327 let mut ε = tolerance.initial(); |
614 let mut ε = tolerance.initial(); |
| 328 |
615 |
| 329 // Statistics |
616 // Statistics |
| 330 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { |
617 let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo { |
| 331 value: residual.norm2_squared_div2() + reg.apply(μ), |
618 value: f.apply(μ) + reg.apply(μ), |
| 332 n_spikes: μ.len(), |
619 n_spikes: μ.len(), |
| 333 ε, |
620 ε, |
| 334 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), |
621 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), |
| 335 ..stats |
622 ..stats |
| 336 }; |
623 }; |
| 337 let mut stats = IterInfo::new(); |
624 let mut stats = IterInfo::new(); |
| 338 |
625 |
| 339 // Run the algorithm |
626 // Run the algorithm |
| 340 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
627 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 341 // Calculate initial transport |
628 // Calculate initial transport |
| 342 let v = opA.preadjoint().apply(residual); |
629 let v = f.differential(&μ); |
| 343 let (μ_base_masses, mut μ_base_minus_γ0) = |
630 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v, &config.transport); |
| 344 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); |
631 |
| |
632 let mut attempts = 0; |
| 345 |
633 |
| 346 // Solve finite-dimensional subproblem several times until the dual variable for the |
634 // Solve finite-dimensional subproblem several times until the dual variable for the |
| 347 // regularisation term conforms to the assumptions made for the transport above. |
635 // regularisation term conforms to the assumptions made for the transport above. |
| 348 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { |
636 let (maybe_d, _within_tolerances, mut τv̆, μ̆) = 'adapt_transport: loop { |
| |
637 // Set initial guess for μ=μ^{k+1}. |
| |
638 γ.μ̆_into(&mut μ); |
| |
639 let μ̆ = μ.clone(); |
| |
640 |
| 349 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
641 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
| 350 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); |
642 //let residual_μ̆ = calculate_residual2(&γ1, &μ0_minus_γ0, opA, b); |
| 351 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); |
643 // TODO: this could be optimised by doing the differential like the |
| |
644 // old residual2. |
| |
645 // NOTE: This assumes that μ = γ1 |
| |
646 let mut τv̆ = f.differential(&μ̆) * τ; |
| 352 |
647 |
| 353 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
648 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
| 354 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
649 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
| 355 &mut μ, |
650 &mut μ, |
| 356 &mut τv̆, |
651 &mut τv̆, |
| 357 &γ1, |
|
| 358 Some(&μ_base_minus_γ0), |
|
| 359 τ, |
652 τ, |
| 360 ε, |
653 ε, |
| 361 &config.insertion, |
654 &config.insertion, |
| 362 ®, |
655 ®, |
| 363 &state, |
656 &state, |
| 364 &mut stats, |
657 &mut stats, |
| 365 ); |
658 )?; |
| 366 |
659 |
| 367 // A posteriori transport adaptation. |
660 // A posteriori transport adaptation. |
| 368 if aposteriori_transport( |
661 if γ.aposteriori_transport(&μ, &μ̆, &mut τv̆, None, ε, &config.transport, &mut attempts) |
| 369 &mut γ1, |
662 { |
| 370 &mut μ, |
663 break 'adapt_transport (maybe_d, within_tolerances, τv̆, μ̆); |
| 371 &mut μ_base_minus_γ0, |
664 } |
| 372 &μ_base_masses, |
665 |
| 373 None, |
666 stats.get_transport_mut().readjustment_iters += 1; |
| 374 ε, |
|
| 375 &config.transport, |
|
| 376 ) { |
|
| 377 break 'adapt_transport (maybe_d, within_tolerances, τv̆); |
|
| 378 } |
|
| 379 }; |
667 }; |
| 380 |
668 |
| 381 stats.untransported_fraction = Some({ |
669 γ.get_transport_stats(&mut stats, &μ); |
| 382 assert_eq!(μ_base_masses.len(), γ1.len()); |
|
| 383 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0)); |
|
| 384 let source = μ_base_masses.iter().map(|v| v.abs()).sum(); |
|
| 385 (a + μ_base_minus_γ0.norm(Radon), b + source) |
|
| 386 }); |
|
| 387 stats.transport_error = Some({ |
|
| 388 assert_eq!(μ_base_masses.len(), γ1.len()); |
|
| 389 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
|
| 390 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
|
| 391 }); |
|
| 392 |
670 |
| 393 // Merge spikes. |
671 // Merge spikes. |
| 394 // This crucially expects the merge routine to be stable with respect to spike locations, |
672 // This crucially expects the merge routine to be stable with respect to spike locations, |
| 395 // and not to performing any pruning. That is be to done below simultaneously for γ. |
673 // and not to performing any pruning. That is be to done below simultaneously for γ. |
| 396 let ins = &config.insertion; |
674 if config.insertion.merge_now(&state) { |
| 397 if ins.merge_now(&state) { |
|
| 398 stats.merged += prox_penalty.merge_spikes( |
675 stats.merged += prox_penalty.merge_spikes( |
| 399 &mut μ, |
676 &mut μ, |
| 400 &mut τv̆, |
677 &mut τv̆, |
| 401 &γ1, |
678 &μ̆, |
| 402 Some(&μ_base_minus_γ0), |
|
| 403 τ, |
679 τ, |
| 404 ε, |
680 ε, |
| 405 ins, |
681 &config.insertion, |
| 406 ®, |
682 ®, |
| 407 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
683 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), |
| 408 ); |
684 ); |
| 409 } |
685 } |
| 410 |
686 |
| 411 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
687 γ.prune_compat(&mut μ, &mut stats); |
| 412 // latter needs to be pruned when μ is. |
|
| 413 // TODO: This could do with a two-vector Vec::retain to avoid copies. |
|
| 414 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |
|
| 415 if μ_new.len() != μ.len() { |
|
| 416 let mut μ_iter = μ.iter_spikes(); |
|
| 417 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); |
|
| 418 stats.pruned += μ.len() - μ_new.len(); |
|
| 419 μ = μ_new; |
|
| 420 } |
|
| 421 |
|
| 422 // Update residual |
|
| 423 residual = calculate_residual(&μ, opA, b); |
|
| 424 |
688 |
| 425 let iter = state.iteration(); |
689 let iter = state.iteration(); |
| 426 stats.this_iters += 1; |
690 stats.this_iters += 1; |
| 427 |
691 |
| 428 // Give statistics if requested |
692 // Give statistics if requested |
| 429 state.if_verbose(|| { |
693 state.if_verbose(|| { |
| 430 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); |
694 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); |
| 431 full_stats( |
695 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 432 &residual, |
|
| 433 &μ, |
|
| 434 ε, |
|
| 435 std::mem::replace(&mut stats, IterInfo::new()), |
|
| 436 ) |
|
| 437 }); |
696 }); |
| 438 |
697 |
| 439 // Update main tolerance for next iteration |
698 // Update main tolerance for next iteration |
| 440 ε = tolerance.update(ε, iter); |
699 ε = tolerance.update(ε, iter); |
| 441 } |
700 } |
| 442 |
701 |
| 443 postprocess(μ, &config.insertion, L2Squared, opA, b) |
702 //postprocess(μ, &config.insertion, f) |
| 444 } |
703 postprocess(μ, &config.insertion, |μ̃| f.apply(μ̃)) |
| |
704 } |