src/sliding_pdps.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
child 41
b6bdb6cb4d44
equal deleted inserted replaced
37:c5d8bd1a7728 39:6316d68b58af
10 use std::iter::Iterator; 10 use std::iter::Iterator;
11 11
12 use alg_tools::iterate::AlgIteratorFactory; 12 use alg_tools::iterate::AlgIteratorFactory;
13 use alg_tools::euclidean::Euclidean; 13 use alg_tools::euclidean::Euclidean;
14 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; 14 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
15 use alg_tools::norms::Norm; 15 use alg_tools::norms::{Norm, Dist};
16 use alg_tools::direct_product::Pair; 16 use alg_tools::direct_product::Pair;
17 use alg_tools::nalgebra_support::ToNalgebraRealField; 17 use alg_tools::nalgebra_support::ToNalgebraRealField;
18 use alg_tools::linops::{ 18 use alg_tools::linops::{
19 BoundedLinear, AXPY, GEMV, Adjointable, IdOp, 19 BoundedLinear, AXPY, GEMV, Adjointable, IdOp,
20 }; 20 };
43 TransportConfig, 43 TransportConfig,
44 TransportStepLength, 44 TransportStepLength,
45 initial_transport, 45 initial_transport,
46 aposteriori_transport, 46 aposteriori_transport,
47 }; 47 };
48 use crate::dataterm::{calculate_residual, calculate_residual2}; 48 use crate::dataterm::{
49 calculate_residual2,
50 calculate_residual,
51 };
52
49 53
50 /// Settings for [`pointsource_sliding_pdps_pair`]. 54 /// Settings for [`pointsource_sliding_pdps_pair`].
51 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 55 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
52 #[serde(default)] 56 #[serde(default)]
53 pub struct SlidingPDPSConfig<F : Float> { 57 pub struct SlidingPDPSConfig<F : Float> {
64 } 68 }
65 69
66 #[replace_float_literals(F::cast_from(literal))] 70 #[replace_float_literals(F::cast_from(literal))]
67 impl<F : Float> Default for SlidingPDPSConfig<F> { 71 impl<F : Float> Default for SlidingPDPSConfig<F> {
68 fn default() -> Self { 72 fn default() -> Self {
69 let τ0 = 0.99;
70 SlidingPDPSConfig { 73 SlidingPDPSConfig {
71 τ0, 74 τ0 : 0.99,
72 σd0 : 0.1, 75 σd0 : 0.05,
73 σp0 : 0.99, 76 σp0 : 0.99,
74 transport : Default::default(), 77 transport : TransportConfig { θ0 : 0.1, ..Default::default()},
75 insertion : Default::default() 78 insertion : Default::default()
76 } 79 }
77 } 80 }
78 } 81 }
79 82
132 + GEMV<F, Z> 135 + GEMV<F, Z>
133 + Adjointable<Z, Y, AdjointCodomain = Z>, 136 + Adjointable<Z, Y, AdjointCodomain = Z>,
134 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, 137 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>,
135 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, 138 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd,
136 for<'b> &'b Y : Instance<Y>, 139 for<'b> &'b Y : Instance<Y>,
137 Z : AXPY<F, Owned=Z> + Euclidean<F, Output=Z> + Clone + Norm<F, L2>, 140 Z : AXPY<F, Owned=Z> + Euclidean<F, Output=Z> + Clone + Norm<F, L2> + Dist<F, L2>,
138 for<'b> &'b Z : Instance<Z>, 141 for<'b> &'b Z : Instance<Z>,
139 R : Prox<Z, Codomain=F>, 142 R : Prox<Z, Codomain=F>,
140 H : Conjugable<Y, F, Codomain=F>, 143 H : Conjugable<Y, F, Codomain=F>,
141 for<'b> H::Conjugate<'b> : Prox<Y>, 144 for<'b> H::Conjugate<'b> : Prox<Y>,
142 { 145 {
200 l: ℓ_v0 * b.norm2(), 203 l: ℓ_v0 * b.norm2(),
201 max_transport : 0.0, 204 max_transport : 0.0,
202 g : calculate_θ 205 g : calculate_θ
203 }, 206 },
204 None => TransportStepLength::FullyAdaptive{ 207 None => TransportStepLength::FullyAdaptive{
205 l : 0.0, 208 l : F::EPSILON,
206 max_transport : 0.0, 209 max_transport : 0.0,
207 g : calculate_θ 210 g : calculate_θ
208 }, 211 },
209 }; 212 };
210 // Acceleration is not currently supported 213 // Acceleration is not currently supported
232 // Run the algorithm 235 // Run the algorithm
233 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { 236 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) {
234 // Calculate initial transport 237 // Calculate initial transport
235 let Pair(v, _) = opA.preadjoint().apply(&residual); 238 let Pair(v, _) = opA.preadjoint().apply(&residual);
236 //opKμ.preadjoint().apply_add(&mut v, y); 239 //opKμ.preadjoint().apply_add(&mut v, y);
237 let z_base = z.clone();
238 // We want to proceed as in Example 4.12 but with v and v̆ as in §5. 240 // We want to proceed as in Example 4.12 but with v and v̆ as in §5.
239 // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have 241 // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have
240 // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, 242 // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν,
241 // where A_ν^* becomes a multiplier. 243 // where A_ν^* becomes a multiplier.
242 // This is much easier with K_μ = 0, which is the only reason why we are enforcing it. 244 // This is much easier with K_μ = 0, which is the only reason why we are enforcing it.
248 v, &config.transport, 250 v, &config.transport,
249 ); 251 );
250 252
251 // Solve finite-dimensional subproblem several times until the dual variable for the 253 // Solve finite-dimensional subproblem several times until the dual variable for the
252 // regularisation term conforms to the assumptions made for the transport above. 254 // regularisation term conforms to the assumptions made for the transport above.
253 let (maybe_d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { 255 let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop {
254 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) 256 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
255 let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), 257 let residual_μ̆ = calculate_residual2(Pair(&γ1, &z),
256 Pair(&μ_base_minus_γ0, &zero_z), 258 Pair(&μ_base_minus_γ0, &zero_z),
257 opA, b); 259 opA, b);
258 let mut τv̆z = opA.preadjoint().apply(residual_μ̆ * τ); 260 let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ);
259 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); 261 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0);
260 262
261 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 263 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
262 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( 264 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
263 &mut μ, &mut τv̆z.0, &γ1, Some(&μ_base_minus_γ0), 265 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0),
264 τ, ε, &config.insertion, 266 τ, ε, &config.insertion,
265 &reg, &state, &mut stats, 267 &reg, &state, &mut stats,
266 ); 268 );
267 269
270 // Do z variable primal update here to be able to estimate B_{v̆^k-v^{k+1}}
271 let mut z_new = τz̆;
272 opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ);
273 z_new = fnR.prox(σ_p, z_new + &z);
274
268 // A posteriori transport adaptation. 275 // A posteriori transport adaptation.
269 // TODO: this does not properly treat v^{k+1} - v̆^k that depends on z^{k+1}!
270 if aposteriori_transport( 276 if aposteriori_transport(
271 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, 277 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
278 Some(z_new.dist(&z, L2)),
272 ε, &config.transport 279 ε, &config.transport
273 ) { 280 ) {
274 break 'adapt_transport (maybe_d, within_tolerances, τv̆z) 281 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new)
275 } 282 }
276 }; 283 };
277 284
278 stats.untransported_fraction = Some({ 285 stats.untransported_fraction = Some({
279 assert_eq!(μ_base_masses.len(), γ1.len()); 286 assert_eq!(μ_base_masses.len(), γ1.len());
285 assert_eq!(μ_base_masses.len(), γ1.len()); 292 assert_eq!(μ_base_masses.len(), γ1.len());
286 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); 293 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
287 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) 294 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
288 }); 295 });
289 296
290 // // Merge spikes. 297 // Merge spikes.
291 // // This expects the prune below to prune γ. 298 // This crucially expects the merge routine to be stable with respect to spike locations,
292 // // TODO: This may not work correctly in all cases. 299 // and not to perform any pruning. That is to be done below simultaneously for γ.
293 // let ins = &config.insertion; 300 let ins = &config.insertion;
294 // if ins.merge_now(&state) { 301 if ins.merge_now(&state) {
295 // if let SpikeMergingMethod::None = ins.merging { 302 stats.merged += prox_penalty.merge_spikes_no_fitness(
296 // } else { 303 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, &reg,
297 // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { 304 //Some(|μ̃ : &RNDM<F, N>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
298 // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; 305 );
299 // let mut d = &τv̆ + op𝒟.preapply(ν); 306 }
300 // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
301 // });
302 // }
303 // }
304 307
305 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the 308 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
306 // latter needs to be pruned when μ is. 309 // latter needs to be pruned when μ is.
307 // TODO: This could do with a two-vector Vec::retain to avoid copies. 310 // TODO: This could do with a two-vector Vec::retain to avoid copies.
308 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); 311 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
311 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); 314 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
312 stats.pruned += μ.len() - μ_new.len(); 315 stats.pruned += μ.len() - μ_new.len();
313 μ = μ_new; 316 μ = μ_new;
314 } 317 }
315 318
316 // Do z variable primal update
317 z.axpy(-σ_p/τ, τz̆, 1.0); // TODO: simplify nasty factors
318 opKz.adjoint().gemv(&mut z, -σ_p, &y, 1.0);
319 z = fnR.prox(σ_p, z);
320 // Do dual update 319 // Do dual update
321 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] 320 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]
322 opKz.gemv(&mut y, σ_d*(1.0 + ω), &z, 1.0); 321 opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0);
323 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b 322 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
324 opKz.gemv(&mut y, -σ_d*ω, z_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b 323 opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
325 y = starH.prox(σ_d, y); 324 y = starH.prox(σ_d, y);
325 z = z_new;
326 326
327 // Update residual 327 // Update residual
328 residual = calculate_residual(Pair(&μ, &z), opA, b); 328 residual = calculate_residual(Pair(&μ, &z), opA, b);
329 329
330 // Update step length parameters 330 // Update step length parameters
347 (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() 347 (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2()
348 //+ fnR.apply(z) + reg.apply(μ) 348 //+ fnR.apply(z) + reg.apply(μ)
349 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) 349 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z))
350 }; 350 };
351 351
352 μ.merge_spikes_fitness(config.insertion.merging, fit, |&v| v); 352 μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v);
353 μ.prune(); 353 μ.prune();
354 Pair(μ, z) 354 Pair(μ, z)
355 } 355 }

mercurial