| 4 */ |
4 */ |
| 5 |
5 |
| 6 use crate::fb::*; |
6 use crate::fb::*; |
| 7 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; |
7 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; |
| 8 use crate::measures::merging::SpikeMerging; |
8 use crate::measures::merging::SpikeMerging; |
| 9 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
9 use crate::measures::{DiscreteMeasure, RNDM}; |
| 10 use crate::plot::Plotter; |
10 use crate::plot::Plotter; |
| 11 use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; |
11 use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; |
| 12 use crate::regularisation::SlidingRegTerm; |
12 use crate::regularisation::SlidingRegTerm; |
| 13 use crate::sliding_fb::{ |
13 use crate::sliding_fb::{SlidingFBConfig, Transport, TransportConfig, TransportStepLength}; |
| 14 aposteriori_transport, initial_transport, SlidingFBConfig, TransportConfig, TransportStepLength, |
|
| 15 }; |
|
| 16 use crate::types::*; |
14 use crate::types::*; |
| 17 use alg_tools::convex::{Conjugable, Prox, Zero}; |
15 use alg_tools::convex::{Conjugable, Prox, Zero}; |
| 18 use alg_tools::direct_product::Pair; |
16 use alg_tools::direct_product::Pair; |
| 19 use alg_tools::error::DynResult; |
17 use alg_tools::error::DynResult; |
| 20 use alg_tools::euclidean::ClosedEuclidean; |
18 use alg_tools::euclidean::ClosedEuclidean; |
| 22 use alg_tools::linops::{ |
20 use alg_tools::linops::{ |
| 23 BoundedLinear, IdOp, SimplyAdjointable, StaticEuclideanOriginGenerator, ZeroOp, AXPY, GEMV, |
21 BoundedLinear, IdOp, SimplyAdjointable, StaticEuclideanOriginGenerator, ZeroOp, AXPY, GEMV, |
| 24 }; |
22 }; |
| 25 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; |
23 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; |
| 26 use alg_tools::nalgebra_support::ToNalgebraRealField; |
24 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 27 use alg_tools::norms::{Norm, L2}; |
25 use alg_tools::norms::L2; |
| 28 use anyhow::ensure; |
26 use anyhow::ensure; |
| 29 use numeric_literals::replace_float_literals; |
27 use numeric_literals::replace_float_literals; |
| 30 use serde::{Deserialize, Serialize}; |
28 use serde::{Deserialize, Serialize}; |
| 31 //use colored::Colorize; |
|
| 32 //use nalgebra::{DVector, DMatrix}; |
|
| 33 use std::iter::Iterator; |
|
| 34 |
29 |
| 35 /// Settings for [`pointsource_sliding_pdps_pair`]. |
30 /// Settings for [`pointsource_sliding_pdps_pair`]. |
| 36 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
31 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 37 #[serde(default)] |
32 #[serde(default)] |
| 38 pub struct SlidingPDPSConfig<F: Float> { |
33 pub struct SlidingPDPSConfig<F: Float> { |
| 146 );*/ |
141 );*/ |
| 147 config.transport.check()?; |
142 config.transport.check()?; |
| 148 |
143 |
| 149 // Initialise iterates |
144 // Initialise iterates |
| 150 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
145 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 151 let mut γ1 = DiscreteMeasure::new(); |
146 let mut γ = Transport::new(); |
| 152 //let zero_z = z.similar_origin(); |
147 //let zero_z = z.similar_origin(); |
| 153 |
148 |
| 154 // Set up parameters |
149 // Set up parameters |
| 155 // TODO: maybe this PairNorm doesn't make sense here? |
150 // TODO: maybe this PairNorm doesn't make sense here? |
| 156 // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); |
151 // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); |
| 184 ensure!(β < 1.0); |
179 ensure!(β < 1.0); |
| 185 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: |
180 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: |
| 186 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); |
181 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); |
| 187 // The factor two in the manuscript disappears due to the definition of 𝚹 being |
182 // The factor two in the manuscript disappears due to the definition of 𝚹 being |
| 188 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. |
183 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. |
| 189 let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); |
184 |
| 190 let transport_lip = maybe_transport_lip?; |
185 let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) { |
| 191 let calculate_θ = |ℓ_F, max_transport| { |
186 (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0), |
| 192 let ℓ_r = transport_lip * max_transport; |
187 (maybe_ℓ_F, Ok(transport_lip)) => { |
| 193 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) |
188 let calculate_θτ = move |ℓ_F, max_transport| { |
| 194 }; |
189 let ℓ_r = transport_lip * max_transport; |
| 195 let mut θ_or_adaptive = match maybe_ℓ_F { |
190 config.transport.θ0 / ((ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport / τ) |
| 196 // We assume that the residual is decreasing. |
191 }; |
| 197 Ok(ℓ_F) => TransportStepLength::AdaptiveMax { |
192 match maybe_ℓ_F { |
| 198 l: ℓ_F, // TODO: could estimate computing the real reesidual |
193 Ok(ℓ_F) => TransportStepLength::AdaptiveMax { |
| 199 max_transport: 0.0, |
194 l: ℓ_F, // TODO: could estimate computing the real reesidual |
| 200 g: calculate_θ, |
195 max_transport: 0.0, |
| 201 }, |
196 g: calculate_θτ, |
| 202 Err(_) => { |
197 }, |
| 203 TransportStepLength::FullyAdaptive { |
198 Err(_) => TransportStepLength::FullyAdaptive { |
| 204 l: F::EPSILON, max_transport: 0.0, g: calculate_θ |
199 l: F::EPSILON, // Start with something very small to estimate differentials |
| |
200 max_transport: 0.0, |
| |
201 g: calculate_θτ, |
| |
202 }, |
| 205 } |
203 } |
| 206 } |
204 } |
| 207 }; |
205 }; |
| 208 // Acceleration is not currently supported |
206 // Acceleration is not currently supported |
| 209 // let γ = dataterm.factor_of_strong_convexity(); |
207 // let γ = dataterm.factor_of_strong_convexity(); |
| 241 // This is much easier with K_μ = 0, which is the only reason why are enforcing it. |
239 // This is much easier with K_μ = 0, which is the only reason why are enforcing it. |
| 242 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. |
240 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. |
| 243 |
241 |
| 244 //dbg!(&μ); |
242 //dbg!(&μ); |
| 245 |
243 |
| 246 let (μ_base_masses, mut μ_base_minus_γ0) = |
244 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v); |
| 247 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); |
245 |
| |
246 let mut attempts = 0; |
| 248 |
247 |
| 249 // Solve finite-dimensional subproblem several times until the dual variable for the |
248 // Solve finite-dimensional subproblem several times until the dual variable for the |
| 250 // regularisation term conforms to the assumptions made for the transport above. |
249 // regularisation term conforms to the assumptions made for the transport above. |
| 251 let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { |
250 let (maybe_d, _within_tolerances, mut τv̆, z_new, μ̆) = 'adapt_transport: loop { |
| |
251 // Set initial guess for μ=μ^{k+1}. |
| |
252 γ.μ̆_into(&mut μ); |
| |
253 let μ̆ = μ.clone(); |
| |
254 |
| 252 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
255 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
| 253 // let residual_μ̆ = |
256 let Pair(mut τv̆, τz̆) = f.differential(Pair(&μ̆, &z)) * τ; |
| 254 // calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); |
|
| 255 // let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); |
|
| 256 // TODO: might be able to optimise the measure sum working as calculate_residual2 above. |
|
| 257 let Pair(mut τv̆, τz̆) = f.differential(Pair(&γ1 + &μ_base_minus_γ0, &z)) * τ; |
|
| 258 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); |
257 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); |
| 259 |
258 |
| 260 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
259 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
| 261 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
260 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
| 262 &mut μ, |
261 &mut μ, |
| 263 &mut τv̆, |
262 &mut τv̆, |
| 264 &γ1, |
|
| 265 Some(&μ_base_minus_γ0), |
|
| 266 τ, |
263 τ, |
| 267 ε, |
264 ε, |
| 268 &config.insertion, |
265 &config.insertion, |
| 269 ®, |
266 ®, |
| 270 &state, |
267 &state, |
| 275 let mut z_new = τz̆; |
272 let mut z_new = τz̆; |
| 276 opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ); |
273 opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ); |
| 277 z_new = fnR.prox(σ_p, z_new + &z); |
274 z_new = fnR.prox(σ_p, z_new + &z); |
| 278 |
275 |
| 279 // A posteriori transport adaptation. |
276 // A posteriori transport adaptation. |
| 280 if aposteriori_transport( |
277 if γ.aposteriori_transport( |
| 281 &mut γ1, |
278 &μ, |
| 282 &mut μ, |
279 &μ̆, |
| 283 &mut μ_base_minus_γ0, |
280 &mut τv̆, |
| 284 &μ_base_masses, |
|
| 285 Some(z_new.dist2(&z)), |
281 Some(z_new.dist2(&z)), |
| 286 ε, |
282 ε, |
| 287 &config.transport, |
283 &config.transport, |
| |
284 &mut attempts, |
| 288 ) { |
285 ) { |
| 289 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new); |
286 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new, μ̆); |
| 290 } |
287 } |
| 291 }; |
288 }; |
| 292 |
289 |
| 293 stats.untransported_fraction = Some({ |
290 γ.get_transport_stats(&mut stats, &μ); |
| 294 assert_eq!(μ_base_masses.len(), γ1.len()); |
|
| 295 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0)); |
|
| 296 let source = μ_base_masses.iter().map(|v| v.abs()).sum(); |
|
| 297 (a + μ_base_minus_γ0.norm(Radon), b + source) |
|
| 298 }); |
|
| 299 stats.transport_error = Some({ |
|
| 300 assert_eq!(μ_base_masses.len(), γ1.len()); |
|
| 301 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
|
| 302 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
|
| 303 }); |
|
| 304 |
291 |
| 305 // Merge spikes. |
292 // Merge spikes. |
| 306 // This crucially expects the merge routine to be stable with respect to spike locations, |
293 // This crucially expects the merge routine to be stable with respect to spike locations, |
| 307 // and not to performing any pruning. That is be to done below simultaneously for γ. |
294 // and not to performing any pruning. That is be to done below simultaneously for γ. |
| 308 let ins = &config.insertion; |
295 if config.insertion.merge_now(&state) { |
| 309 if ins.merge_now(&state) { |
|
| 310 stats.merged += prox_penalty.merge_spikes_no_fitness( |
296 stats.merged += prox_penalty.merge_spikes_no_fitness( |
| 311 &mut μ, |
297 &mut μ, |
| 312 &mut τv̆, |
298 &mut τv̆, |
| 313 &γ1, |
299 &μ̆, |
| 314 Some(&μ_base_minus_γ0), |
|
| 315 τ, |
300 τ, |
| 316 ε, |
301 ε, |
| 317 ins, |
302 &config.insertion, |
| 318 ®, |
303 ®, |
| 319 //Some(|μ̃ : &RNDM<N, F>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), |
|
| 320 ); |
304 ); |
| 321 } |
305 } |
| 322 |
306 |
| 323 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
307 γ.prune_compat(&mut μ, &mut stats); |
| 324 // latter needs to be pruned when μ is. |
|
| 325 // TODO: This could do with a two-vector Vec::retain to avoid copies. |
|
| 326 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |
|
| 327 if μ_new.len() != μ.len() { |
|
| 328 let mut μ_iter = μ.iter_spikes(); |
|
| 329 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); |
|
| 330 stats.pruned += μ.len() - μ_new.len(); |
|
| 331 μ = μ_new; |
|
| 332 } |
|
| 333 |
308 |
| 334 // Do dual update |
309 // Do dual update |
| 335 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] |
310 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] |
| 336 opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); |
311 opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); |
| 337 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
312 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |