Mon, 06 Jan 2025 21:37:03 -0500
Attempt to do more Serialize / Deserialize but run into csv problems
| 32 | 1 | /*! |
| 2 | Solver for the point source localisation problem using a sliding | |
| 3 | forward-backward splitting method. | |
| 4 | */ | |
| 5 | ||
| 6 | use numeric_literals::replace_float_literals; | |
| 7 | use serde::{Serialize, Deserialize}; | |
| 8 | //use colored::Colorize; | |
| 9 | //use nalgebra::{DVector, DMatrix}; | |
| 10 | use itertools::izip; | |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
11 | use std::iter::Iterator; |
| 32 | 12 | |
| 35 | 13 | use alg_tools::iterate::AlgIteratorFactory; |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
14 | use alg_tools::euclidean::Euclidean; |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
15 | use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; |
| 35 | 16 | use alg_tools::norms::Norm; |
| 32 | 17 | use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
18 | use alg_tools::norms::L2; |
| 32 | 19 | |
| 20 | use crate::types::*; | |
| 35 | 21 | use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
22 | use crate::measures::merging::SpikeMerging; |
| 35 | 23 | use crate::forward_model::{ |
| 24 | ForwardModel, | |
| 25 | AdjointProductBoundedBy, | |
| 26 | LipschitzValues, | |
| 27 | }; | |
| 32 | 28 | //use crate::tolerance::Tolerance; |
| 29 | use crate::plot::{ | |
| 30 | SeqPlotter, | |
| 31 | Plotting, | |
| 32 | PlotLookup | |
| 33 | }; | |
| 34 | use crate::fb::*; | |
| 35 | use crate::regularisation::SlidingRegTerm; | |
| 36 | use crate::dataterm::{ | |
| 37 | L2Squared, | |
| 38 | //DataTerm, | |
| 39 | calculate_residual, | |
| 40 | calculate_residual2, | |
| 41 | }; | |
| 35 | 42 | //use crate::transport::TransportLipschitz; |
| 43 | ||
| 44 | /// Transport settings for [`pointsource_sliding_fb_reg`]. | |
| 45 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | |
| 46 | #[serde(default)] | |
| 47 | pub struct TransportConfig<F : Float> { | |
| 48 | /// Transport step length $θ$ normalised to $(0, 1)$. | |
| 49 | pub θ0 : F, | |
| 50 | /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. | |
| 51 | pub adaptation : F, | |
| 52 | /// Transport tolerance wrt. ω | |
| 53 | pub tolerance_ω : F, | |
| 54 | /// Transport tolerance wrt. ∇v | |
| 55 | pub tolerance_dv : F, | |
| 56 | } | |
| 57 | ||
| 58 | #[replace_float_literals(F::cast_from(literal))] | |
| 59 | impl <F : Float> TransportConfig<F> { | |
| 60 | /// Check that the parameters are ok. Panics if not. | |
| 61 | pub fn check(&self) { | |
| 62 | assert!(self.θ0 > 0.0); | |
| 63 | assert!(0.0 < self.adaptation && self.adaptation < 1.0); | |
| 64 | assert!(self.tolerance_dv > 0.0); | |
| 65 | assert!(self.tolerance_ω > 0.0); | |
| 66 | } | |
| 67 | } | |
| 68 | ||
| 69 | #[replace_float_literals(F::cast_from(literal))] | |
| 70 | impl<F : Float> Default for TransportConfig<F> { | |
| 71 | fn default() -> Self { | |
| 72 | TransportConfig { | |
| 73 | θ0 : 0.01, | |
| 74 | adaptation : 0.9, | |
| 75 | tolerance_ω : 1000.0, // TODO: no idea what this should be | |
| 76 | tolerance_dv : 1000.0, // TODO: no idea what this should be | |
| 77 | } | |
| 78 | } | |
| 79 | } | |
| 32 | 80 | |
| 81 | /// Settings for [`pointsource_sliding_fb_reg`]. | |
| 82 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | |
| 83 | #[serde(default)] | |
| 84 | pub struct SlidingFBConfig<F : Float> { | |
| 85 | /// Step length scaling | |
| 86 | pub τ0 : F, | |
| 35 | 87 | /// Transport parameters |
| 88 | pub transport : TransportConfig<F>, | |
| 32 | 89 | /// Generic parameters |
| 90 | pub insertion : FBGenericConfig<F>, | |
| 91 | } | |
| 92 | ||
| 93 | #[replace_float_literals(F::cast_from(literal))] | |
| 94 | impl<F : Float> Default for SlidingFBConfig<F> { | |
| 95 | fn default() -> Self { | |
| 96 | SlidingFBConfig { | |
| 97 | τ0 : 0.99, | |
| 35 | 98 | transport : Default::default(), |
| 32 | 99 | insertion : Default::default() |
| 100 | } | |
| 101 | } | |
| 102 | } | |
| 103 | ||
| 35 | 104 | /// Internal type of adaptive transport step length calculation |
| 105 | pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> { | |
| 106 | /// Fixed, known step length | |
| 107 | Fixed(F), | |
| 108 | /// Adaptive step length, only wrt. maximum transport. | |
| 109 | /// Content of `l` depends on use case, while `g` calculates the step length from `l`. | |
| 110 | AdaptiveMax{ l : F, max_transport : F, g : G }, | |
| 111 | /// Adaptive step length. | |
| 112 | /// Content of `l` depends on use case, while `g` calculates the step length from `l`. | |
| 113 | FullyAdaptive{ l : F, max_transport : F, g : G }, | |
| 114 | } | |
| 115 | ||
| 116 | /// Constrution and a priori transport adaptation. | |
| 32 | 117 | #[replace_float_literals(F::cast_from(literal))] |
| 35 | 118 | pub(crate) fn initial_transport<F, G, D, Observable, const N : usize>( |
| 119 | γ1 : &mut RNDM<F, N>, | |
| 120 | μ : &mut RNDM<F, N>, | |
| 121 | opAapply : impl Fn(&RNDM<F, N>) -> Observable, | |
| 122 | ε : F, | |
| 123 | τ : F, | |
| 124 | θ_or_adaptive : &mut TransportStepLength<F, G>, | |
| 125 | opAnorm : F, | |
| 126 | v : D, | |
| 127 | tconfig : &TransportConfig<F> | |
| 128 | ) -> (Vec<F>, RNDM<F, N>) | |
| 129 | where | |
| 130 | F : Float + ToNalgebraRealField, | |
| 131 | G : Fn(F, F) -> F, | |
| 132 | Observable : Euclidean<F, Output=Observable>, | |
| 133 | for<'a> &'a Observable : Instance<Observable>, | |
| 134 | //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, | |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
135 | D : DifferentiableRealMapping<F, N>, |
| 35 | 136 | { |
| 137 | ||
| 138 | use TransportStepLength::*; | |
| 139 | ||
| 140 | // Save current base point and shift μ to new positions. Idea is that | |
| 141 | // μ_base(_masses) = μ^k (vector of masses) | |
| 142 | // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} | |
| 143 | // γ1 = π_♯^1γ^{k+1} | |
| 144 | // μ = μ^{k+1} | |
| 145 | let μ_base_masses : Vec<F> = μ.iter_masses().collect(); | |
| 146 | let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below. | |
| 147 | // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates | |
| 148 | //let mut sum_norm_dv = 0.0; | |
| 149 | let γ_prev_len = γ1.len(); | |
| 150 | assert!(μ.len() >= γ_prev_len); | |
| 151 | γ1.extend(μ[γ_prev_len..].iter().cloned()); | |
| 152 | ||
| 153 | // Calculate initial transport and step length. | |
| 154 | // First calculate initial transported weights | |
| 155 | for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { | |
| 156 | // If old transport has opposing sign, the new transport will be none. | |
| 157 | ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) { | |
| 158 | 0.0 | |
| 159 | } else { | |
| 160 | δ.α | |
| 161 | }; | |
| 162 | }; | |
| 163 | ||
| 164 | // A priori transport adaptation based on bounding 2 ‖A‖ ‖A(γ₁-γ₀)‖‖γ‖ by scaling γ. | |
| 165 | // 1. Calculate transport rays. | |
| 166 | // If the Lipschitz factor of the values v=∇F(μ) are not known, estimate it. | |
| 167 | match *θ_or_adaptive { | |
| 168 | Fixed(θ) => { | |
| 169 | let θτ = τ * θ; | |
| 170 | for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { | |
| 171 | ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); | |
| 172 | } | |
| 173 | }, | |
| 174 | AdaptiveMax{ l : ℓ_v, ref mut max_transport, g : ref calculate_θ } => { | |
| 175 | *max_transport = max_transport.max(γ1.norm(Radon)); | |
| 176 | let θτ = τ * calculate_θ(ℓ_v, *max_transport); | |
| 177 | for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { | |
| 178 | ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); | |
| 179 | } | |
| 180 | }, | |
| 181 | FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => { | |
| 182 | *max_transport = max_transport.max(γ1.norm(Radon)); | |
| 183 | let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport); | |
| 184 | loop { | |
| 185 | let θτ = τ * θ; | |
| 186 | for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { | |
| 187 | let dv_x = v.differential(&δ.x); | |
| 188 | ρ.x = δ.x - &dv_x * (ρ.α.signum() * θτ); | |
| 189 | // Estimate Lipschitz factor of ∇v | |
| 190 | let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2(); | |
| 191 | *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v); | |
| 192 | } | |
| 193 | let new_θ = calculate_θ(*adaptive_ℓ_v / tconfig.adaptation, *max_transport); | |
| 194 | if new_θ <= θ { | |
| 195 | break | |
| 196 | } | |
| 197 | θ = new_θ; | |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
198 | } |
| 32 | 199 | } |
| 35 | 200 | } |
| 201 | ||
| 202 | // 2. Adjust transport mass, if needed. | |
| 203 | // This tries to remove the smallest transport masses first. | |
| 204 | if true { | |
| 205 | // Alternative 1 : subtract same amount from all transport rays until reaching zero | |
| 206 | loop { | |
| 207 | let nr =γ1.norm(Radon); | |
| 208 | let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2(); | |
| 209 | if n <= 0.0 || nr <= 0.0 { | |
| 210 | break | |
| 211 | } | |
| 212 | let reduction_needed = nr - (ε * tconfig.tolerance_dv / n); | |
| 213 | if reduction_needed <= 0.0 { | |
| 214 | break | |
| 215 | } | |
| 216 | let (min_nonzero, n_nonzero) = γ1.iter_masses() | |
| 217 | .map(|α| α.abs()) | |
| 218 | .filter(|α| *α > F::EPSILON) | |
| 219 | .fold((F::INFINITY, 0), |(a, n), b| (a.min(b), n+1)); | |
| 220 | assert!(n_nonzero > 0); | |
| 221 | // Reduction that can be done in all nonzero spikes simultaneously | |
| 222 | let h = (reduction_needed / F::cast_from(n_nonzero)).min(min_nonzero); | |
| 223 | for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) { | |
| 224 | ρ.α = ρ.α.signum() * (ρ.α.abs() - h).max(0.0); | |
| 225 | δ.α = ρ.α; | |
| 226 | } | |
| 227 | if min_nonzero * F::cast_from(n_nonzero) >= reduction_needed { | |
| 228 | break | |
| 229 | } | |
| 230 | } | |
| 231 | } else { | |
| 232 | // Alternative 2: first reduce transport rays with greater effect based on differential. | |
| 233 | // This is a an inefficient quick-and-dirty implementation. | |
| 234 | loop { | |
| 235 | let nr = γ1.norm(Radon); | |
| 236 | let a = opAapply(&*γ1)-opAapply(&*μ); | |
| 237 | let na = a.norm2(); | |
| 238 | let n = τ * 2.0 * opAnorm * na; | |
| 239 | if n <= 0.0 || nr <= 0.0 { | |
| 240 | break | |
| 241 | } | |
| 242 | let reduction_needed = nr - (ε * tconfig.tolerance_dv / n); | |
| 243 | if reduction_needed <= 0.0 { | |
| 244 | break | |
| 245 | } | |
| 246 | let mut max_d = 0.0; | |
| 247 | let mut max_d_ind = 0; | |
| 248 | for (δ, ρ, i) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), 0..) { | |
| 249 | // Calculate differential of ‖A(γ₁-γ₀)‖‖γ‖ wrt. each spike | |
| 250 | let s = δ.α.signum(); | |
| 251 | // TODO: this is very inefficient implementation due to the limitations | |
| 252 | // of the closure parameters. | |
| 253 | let δ1 = DiscreteMeasure::from([(ρ.x, s)]); | |
| 254 | let δ2 = DiscreteMeasure::from([(δ.x, s)]); | |
| 255 | let a_part = opAapply(&δ1)-opAapply(&δ2); | |
| 256 | let d = a.dot(&a_part)/na * nr + 2.0 * na; | |
| 257 | if d > max_d { | |
| 258 | max_d = d; | |
| 259 | max_d_ind = i; | |
| 260 | } | |
| 261 | } | |
| 262 | // Just set mass to zero for transport ray with greater differential | |
| 263 | assert!(max_d > 0.0); | |
| 264 | γ1[max_d_ind].α = 0.0; | |
| 265 | μ[max_d_ind].α = 0.0; | |
| 266 | } | |
| 267 | } | |
| 268 | ||
| 269 | // Set initial guess for μ=μ^{k+1}. | |
| 270 | for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) { | |
| 271 | if ρ.α.abs() > F::EPSILON { | |
| 272 | δ.x = ρ.x; | |
| 273 | //δ.α = ρ.α; // already set above | |
| 274 | } else { | |
| 275 | δ.α = β; | |
| 276 | } | |
| 277 | } | |
| 278 | // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b) | |
| 279 | μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses()) | |
| 280 | .map(|(&a,b)| a - b)); | |
| 281 | (μ_base_masses, μ_base_minus_γ0) | |
| 282 | } | |
| 283 | ||
| 284 | /// A posteriori transport adaptation. | |
| 285 | #[replace_float_literals(F::cast_from(literal))] | |
| 286 | pub(crate) fn aposteriori_transport<F, const N : usize>( | |
| 287 | γ1 : &mut RNDM<F, N>, | |
| 288 | μ : &mut RNDM<F, N>, | |
| 289 | μ_base_minus_γ0 : &mut RNDM<F, N>, | |
| 290 | μ_base_masses : &Vec<F>, | |
| 291 | ε : F, | |
| 292 | tconfig : &TransportConfig<F> | |
| 293 | ) -> bool | |
| 294 | where F : Float + ToNalgebraRealField { | |
| 295 | ||
| 296 | // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not, | |
| 297 | // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1 | |
| 298 | // at that point to zero, and retry. | |
| 299 | let mut all_ok = true; | |
| 300 | for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) { | |
| 301 | if α_μ == 0.0 && *α_γ1 != 0.0 { | |
| 302 | all_ok = false; | |
| 303 | *α_γ1 = 0.0; | |
| 304 | } | |
| 305 | } | |
| 306 | ||
| 307 | // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z). | |
| 308 | // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, | |
| 309 | // which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient. | |
| 310 | let nγ = γ1.norm(Radon); | |
| 311 | let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1); | |
| 312 | let t = ε * tconfig.tolerance_ω; | |
| 313 | if nγ*nΔ > t { | |
| 314 | // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, | |
| 315 | // this will guarantee that eventually ‖γ‖ decreases sufficiently that we | |
| 316 | // will not enter here. | |
| 317 | *γ1 *= tconfig.adaptation * t / ( nγ * nΔ ); | |
| 318 | all_ok = false | |
| 319 | } | |
| 320 | ||
| 321 | if !all_ok { | |
| 322 | // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} | |
| 323 | μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses()) | |
| 324 | .map(|(&a,b)| a - b)); | |
| 325 | ||
| 326 | } | |
| 327 | ||
| 328 | all_ok | |
| 32 | 329 | } |
| 330 | ||
| 331 | /// Iteratively solve the pointsource localisation problem using sliding forward-backward | |
| 332 | /// splitting | |
| 333 | /// | |
| 35 | 334 | /// The parametrisation is as for [`pointsource_fb_reg`]. |
| 32 | 335 | /// Inertia is currently not supported. |
| 336 | #[replace_float_literals(F::cast_from(literal))] | |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
337 | pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>( |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
338 | opA : &A, |
| 32 | 339 | b : &A::Observable, |
| 340 | reg : Reg, | |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
341 | prox_penalty : &P, |
| 35 | 342 | config : &SlidingFBConfig<F>, |
| 32 | 343 | iterator : I, |
| 344 | mut plotter : SeqPlotter<F, N>, | |
| 35 | 345 | ) -> RNDM<F, N> |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
346 | where |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
347 | F : Float + ToNalgebraRealField, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
348 | I : AlgIteratorFactory<IterInfo<F, N>>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
349 | A : ForwardModel<RNDM<F, N>, F> |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
350 | + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
351 | //+ TransportLipschitz<L2Squared, FloatType=F>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
352 | for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
353 | for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
354 | A::PreadjointCodomain : DifferentiableRealMapping<F, N>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
355 | RNDM<F, N> : SpikeMerging<F>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
356 | Reg : SlidingRegTerm<F, N>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
357 | P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
358 | PlotLookup : Plotting<N>, |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
359 | { |
| 32 | 360 | |
| 35 | 361 | // Check parameters |
| 362 | assert!(config.τ0 > 0.0, "Invalid step length parameter"); | |
| 363 | config.transport.check(); | |
| 32 | 364 | |
| 365 | // Initialise iterates | |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
366 | let mut μ = DiscreteMeasure::new(); |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
367 | let mut γ1 = DiscreteMeasure::new(); |
| 35 | 368 | let mut residual = -b; // Has to equal $Aμ-b$. |
| 369 | ||
| 370 | // Set up parameters | |
| 371 | let opAnorm = opA.opnorm_bound(Radon, L2); | |
| 372 | //let max_transport = config.max_transport.scale | |
| 373 | // * reg.radon_norm_bound(b.norm2_squared() / 2.0); | |
| 374 | //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; | |
| 375 | let ℓ = 0.0; | |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
376 | let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
| 35 | 377 | let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); |
| 378 | let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { | |
| 379 | // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v | |
| 380 | // (the uniform Lipschitz factor of ∇v). | |
| 381 | // We assume that the residual is decreasing. | |
| 382 | Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * residual.norm2(), 0.0)), | |
| 383 | None => TransportStepLength::FullyAdaptive { | |
| 384 | l : 0.0, | |
| 385 | max_transport : 0.0, | |
| 386 | g : calculate_θ | |
| 387 | }, | |
| 388 | }; | |
| 389 | // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled | |
| 390 | // by τ compared to the conditional gradient approach. | |
| 391 | let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling(); | |
| 392 | let mut ε = tolerance.initial(); | |
| 393 | ||
| 394 | // Statistics | |
| 395 | let full_stats = |residual : &A::Observable, | |
| 396 | μ : &RNDM<F, N>, | |
| 397 | ε, stats| IterInfo { | |
| 398 | value : residual.norm2_squared_div2() + reg.apply(μ), | |
| 399 | n_spikes : μ.len(), | |
| 400 | ε, | |
| 401 | // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), | |
| 402 | .. stats | |
| 403 | }; | |
| 32 | 404 | let mut stats = IterInfo::new(); |
| 405 | ||
| 406 | // Run the algorithm | |
| 35 | 407 | for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
| 408 | // Calculate initial transport | |
| 409 | let v = opA.preadjoint().apply(residual); | |
| 410 | let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport( | |
| 411 | &mut γ1, &mut μ, |ν| opA.apply(ν), | |
| 412 | ε, τ, &mut θ_or_adaptive, opAnorm, | |
| 413 | v, &config.transport, | |
| 414 | ); | |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
415 | |
| 32 | 416 | // Solve finite-dimensional subproblem several times until the dual variable for the |
| 417 | // regularisation term conforms to the assumptions made for the transport above. | |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
418 | let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { |
| 35 | 419 | // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
420 | let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
421 | let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); |
| 32 | 422 | |
| 423 | // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. | |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
424 | let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
425 | &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), |
| 35 | 426 | τ, ε, &config.insertion, |
| 427 | ®, &state, &mut stats, | |
| 32 | 428 | ); |
| 429 | ||
| 35 | 430 | // A posteriori transport adaptation. |
| 431 | if aposteriori_transport( | |
| 432 | &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, | |
| 433 | ε, &config.transport | |
| 434 | ) { | |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
435 | break 'adapt_transport (maybe_d, within_tolerances, τv̆) |
| 32 | 436 | } |
| 437 | }; | |
| 438 | ||
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
439 | stats.untransported_fraction = Some({ |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
440 | assert_eq!(μ_base_masses.len(), γ1.len()); |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
441 | let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0)); |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
442 | let source = μ_base_masses.iter().map(|v| v.abs()).sum(); |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
443 | (a + μ_base_minus_γ0.norm(Radon), b + source) |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
444 | }); |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
445 | stats.transport_error = Some({ |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
446 | assert_eq!(μ_base_masses.len(), γ1.len()); |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
447 | let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
| 35 | 448 | (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
449 | }); |
| 32 | 450 | |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
451 | // // Merge spikes. |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
452 | // // This expects the prune below to prune γ. |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
453 | // // TODO: This may not work correctly in all cases. |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
454 | // let ins = &config.insertion; |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
455 | // if ins.merge_now(&state) { |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
456 | // if let SpikeMergingMethod::None = ins.merging { |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
457 | // } else { |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
458 | // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
459 | // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
460 | // let mut d = &τv̆ + op𝒟.preapply(ν); |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
461 | // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
462 | // }); |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
463 | // } |
|
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
464 | // } |
| 35 | 465 | |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
466 | // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
467 | // latter needs to be pruned when μ is. |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
468 | // TODO: This could do with a two-vector Vec::retain to avoid copies. |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
469 | let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
470 | if μ_new.len() != μ.len() { |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
471 | let mut μ_iter = μ.iter_spikes(); |
|
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
472 | γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); |
| 35 | 473 | stats.pruned += μ.len() - μ_new.len(); |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
474 | μ = μ_new; |
| 32 | 475 | } |
| 476 | ||
| 477 | // Update residual | |
| 478 | residual = calculate_residual(&μ, opA, b); | |
| 479 | ||
| 35 | 480 | let iter = state.iteration(); |
| 32 | 481 | stats.this_iters += 1; |
| 482 | ||
| 35 | 483 | // Give statistics if requested |
| 32 | 484 | state.if_verbose(|| { |
|
37
c5d8bd1a7728
Generic proximal penalty support
Tuomo Valkonen <tuomov@iki.fi>
parents:
35
diff
changeset
|
485 | plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); |
| 35 | 486 | full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 487 | }); | |
| 32 | 488 | |
| 35 | 489 | // Update main tolerance for next iteration |
| 490 | ε = tolerance.update(ε, iter); | |
| 491 | } | |
| 492 | ||
| 493 | postprocess(μ, &config.insertion, L2Squared, opA, b) | |
| 32 | 494 | } |