src/sliding_pdps.rs

branch
dev
changeset 63
7a8a55fd41c0
parent 61
4f468d35fa29
child 66
fe47ad484deb
equal deleted inserted replaced
61:4f468d35fa29 63:7a8a55fd41c0
4 */ 4 */
5 5
6 use crate::fb::*; 6 use crate::fb::*;
7 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; 7 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
8 use crate::measures::merging::SpikeMerging; 8 use crate::measures::merging::SpikeMerging;
9 use crate::measures::{DiscreteMeasure, Radon, RNDM}; 9 use crate::measures::{DiscreteMeasure, RNDM};
10 use crate::plot::Plotter; 10 use crate::plot::Plotter;
11 use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; 11 use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair};
12 use crate::regularisation::SlidingRegTerm; 12 use crate::regularisation::SlidingRegTerm;
13 use crate::sliding_fb::{ 13 use crate::sliding_fb::{SlidingFBConfig, Transport, TransportConfig, TransportStepLength};
14 aposteriori_transport, initial_transport, SlidingFBConfig, TransportConfig, TransportStepLength,
15 };
16 use crate::types::*; 14 use crate::types::*;
17 use alg_tools::convex::{Conjugable, Prox, Zero}; 15 use alg_tools::convex::{Conjugable, Prox, Zero};
18 use alg_tools::direct_product::Pair; 16 use alg_tools::direct_product::Pair;
19 use alg_tools::error::DynResult; 17 use alg_tools::error::DynResult;
20 use alg_tools::euclidean::ClosedEuclidean; 18 use alg_tools::euclidean::ClosedEuclidean;
22 use alg_tools::linops::{ 20 use alg_tools::linops::{
23 BoundedLinear, IdOp, SimplyAdjointable, StaticEuclideanOriginGenerator, ZeroOp, AXPY, GEMV, 21 BoundedLinear, IdOp, SimplyAdjointable, StaticEuclideanOriginGenerator, ZeroOp, AXPY, GEMV,
24 }; 22 };
25 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; 23 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance};
26 use alg_tools::nalgebra_support::ToNalgebraRealField; 24 use alg_tools::nalgebra_support::ToNalgebraRealField;
27 use alg_tools::norms::{Norm, L2}; 25 use alg_tools::norms::L2;
28 use anyhow::ensure; 26 use anyhow::ensure;
29 use numeric_literals::replace_float_literals; 27 use numeric_literals::replace_float_literals;
30 use serde::{Deserialize, Serialize}; 28 use serde::{Deserialize, Serialize};
31 //use colored::Colorize;
32 //use nalgebra::{DVector, DMatrix};
33 use std::iter::Iterator;
34 29
35 /// Settings for [`pointsource_sliding_pdps_pair`]. 30 /// Settings for [`pointsource_sliding_pdps_pair`].
36 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 31 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
37 #[serde(default)] 32 #[serde(default)]
38 pub struct SlidingPDPSConfig<F: Float> { 33 pub struct SlidingPDPSConfig<F: Float> {
146 );*/ 141 );*/
147 config.transport.check()?; 142 config.transport.check()?;
148 143
149 // Initialise iterates 144 // Initialise iterates
150 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); 145 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
151 let mut γ1 = DiscreteMeasure::new(); 146 let mut γ = Transport::new();
152 //let zero_z = z.similar_origin(); 147 //let zero_z = z.similar_origin();
153 148
154 // Set up parameters 149 // Set up parameters
155 // TODO: maybe this PairNorm doesn't make sense here? 150 // TODO: maybe this PairNorm doesn't make sense here?
156 // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); 151 // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2);
184 ensure!(β < 1.0); 179 ensure!(β < 1.0);
185 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: 180 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as:
186 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); 181 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM);
187 // The factor two in the manuscript disappears due to the definition of 𝚹 being 182 // The factor two in the manuscript disappears due to the definition of 𝚹 being
188 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. 183 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2.
189 let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); 184
190 let transport_lip = maybe_transport_lip?; 185 let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) {
191 let calculate_θ = |ℓ_F, max_transport| { 186 (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0),
192 let ℓ_r = transport_lip * max_transport; 187 (maybe_ℓ_F, Ok(transport_lip)) => {
193 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) 188 let calculate_θτ = move |ℓ_F, max_transport| {
194 }; 189 let ℓ_r = transport_lip * max_transport;
195 let mut θ_or_adaptive = match maybe_ℓ_F { 190 config.transport.θ0 / ((ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport / τ)
196 // We assume that the residual is decreasing. 191 };
197 Ok(ℓ_F) => TransportStepLength::AdaptiveMax { 192 match maybe_ℓ_F {
198 l: ℓ_F, // TODO: could estimate by computing the real residual 193 Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
199 max_transport: 0.0, 194 l: ℓ_F, // TODO: could estimate by computing the real residual
200 g: calculate_θ, 195 max_transport: 0.0,
201 }, 196 g: calculate_θτ,
202 Err(_) => { 197 },
203 TransportStepLength::FullyAdaptive { 198 Err(_) => TransportStepLength::FullyAdaptive {
204 l: F::EPSILON, max_transport: 0.0, g: calculate_θ 199 l: F::EPSILON, // Start with something very small to estimate differentials
200 max_transport: 0.0,
201 g: calculate_θτ,
202 },
205 } 203 }
206 } 204 }
207 }; 205 };
208 // Acceleration is not currently supported 206 // Acceleration is not currently supported
209 // let γ = dataterm.factor_of_strong_convexity(); 207 // let γ = dataterm.factor_of_strong_convexity();
241 // This is much easier with K_μ = 0, which is the only reason why we are enforcing it. 239 // This is much easier with K_μ = 0, which is the only reason why we are enforcing it.
242 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. 240 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0.
243 241
244 //dbg!(&μ); 242 //dbg!(&μ);
245 243
246 let (μ_base_masses, mut μ_base_minus_γ0) = 244 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v);
247 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); 245
246 let mut attempts = 0;
248 247
249 // Solve finite-dimensional subproblem several times until the dual variable for the 248 // Solve finite-dimensional subproblem several times until the dual variable for the
250 // regularisation term conforms to the assumptions made for the transport above. 249 // regularisation term conforms to the assumptions made for the transport above.
251 let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { 250 let (maybe_d, _within_tolerances, mut τv̆, z_new, μ̆) = 'adapt_transport: loop {
251 // Set initial guess for μ=μ^{k+1}.
252 γ.μ̆_into(&mut μ);
253 let μ̆ = μ.clone();
254
252 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) 255 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
253 // let residual_μ̆ = 256 let Pair(mut τv̆, τz̆) = f.differential(Pair(&μ̆, &z)) * τ;
254 // calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b);
255 // let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ);
256 // TODO: might be able to optimise the measure sum working as calculate_residual2 above.
257 let Pair(mut τv̆, τz̆) = f.differential(Pair(&γ1 + &μ_base_minus_γ0, &z)) * τ;
258 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); 257 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0);
259 258
260 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 259 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
261 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( 260 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
262 &mut μ, 261 &mut μ,
263 &mut τv̆, 262 &mut τv̆,
264 &γ1,
265 Some(&μ_base_minus_γ0),
266 τ, 263 τ,
267 ε, 264 ε,
268 &config.insertion, 265 &config.insertion,
269 &reg, 266 &reg,
270 &state, 267 &state,
275 let mut z_new = τz̆; 272 let mut z_new = τz̆;
276 opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ); 273 opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ);
277 z_new = fnR.prox(σ_p, z_new + &z); 274 z_new = fnR.prox(σ_p, z_new + &z);
278 275
279 // A posteriori transport adaptation. 276 // A posteriori transport adaptation.
280 if aposteriori_transport( 277 if γ.aposteriori_transport(
281 &mut γ1, 278 &μ,
282 &mut μ, 279 &μ̆,
283 &mut μ_base_minus_γ0, 280 &mut τv̆,
284 &μ_base_masses,
285 Some(z_new.dist2(&z)), 281 Some(z_new.dist2(&z)),
286 ε, 282 ε,
287 &config.transport, 283 &config.transport,
284 &mut attempts,
288 ) { 285 ) {
289 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new); 286 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new, μ̆);
290 } 287 }
291 }; 288 };
292 289
293 stats.untransported_fraction = Some({ 290 γ.get_transport_stats(&mut stats, &μ);
294 assert_eq!(μ_base_masses.len(), γ1.len());
295 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
296 let source = μ_base_masses.iter().map(|v| v.abs()).sum();
297 (a + μ_base_minus_γ0.norm(Radon), b + source)
298 });
299 stats.transport_error = Some({
300 assert_eq!(μ_base_masses.len(), γ1.len());
301 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
302 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
303 });
304 291
305 // Merge spikes. 292 // Merge spikes.
306 // This crucially expects the merge routine to be stable with respect to spike locations, 293 // This crucially expects the merge routine to be stable with respect to spike locations,
307 // and not to perform any pruning. That is to be done below simultaneously for γ. 294 // and not to perform any pruning. That is to be done below simultaneously for γ.
308 let ins = &config.insertion; 295 if config.insertion.merge_now(&state) {
309 if ins.merge_now(&state) {
310 stats.merged += prox_penalty.merge_spikes_no_fitness( 296 stats.merged += prox_penalty.merge_spikes_no_fitness(
311 &mut μ, 297 &mut μ,
312 &mut τv̆, 298 &mut τv̆,
313 &γ1, 299 &μ̆,
314 Some(&μ_base_minus_γ0),
315 τ, 300 τ,
316 ε, 301 ε,
317 ins, 302 &config.insertion,
318 &reg, 303 &reg,
319 //Some(|μ̃ : &RNDM<N, F>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
320 ); 304 );
321 } 305 }
322 306
323 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the 307 γ.prune_compat(&mut μ, &mut stats);
324 // latter needs to be pruned when μ is.
325 // TODO: This could do with a two-vector Vec::retain to avoid copies.
326 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
327 if μ_new.len() != μ.len() {
328 let mut μ_iter = μ.iter_spikes();
329 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
330 stats.pruned += μ.len() - μ_new.len();
331 μ = μ_new;
332 }
333 308
334 // Do dual update 309 // Do dual update
335 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] 310 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]
336 opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); 311 opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0);
337 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b 312 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b

mercurial