| 10 use std::iter::Iterator; | 10 use std::iter::Iterator; | 
| 11 | 11 | 
| 12 use alg_tools::iterate::AlgIteratorFactory; | 12 use alg_tools::iterate::AlgIteratorFactory; | 
| 13 use alg_tools::euclidean::Euclidean; | 13 use alg_tools::euclidean::Euclidean; | 
| 14 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; | 14 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; | 
| 15 use alg_tools::norms::Norm; | 15 use alg_tools::norms::{Norm, Dist}; | 
| 16 use alg_tools::direct_product::Pair; | 16 use alg_tools::direct_product::Pair; | 
| 17 use alg_tools::nalgebra_support::ToNalgebraRealField; | 17 use alg_tools::nalgebra_support::ToNalgebraRealField; | 
| 18 use alg_tools::linops::{ | 18 use alg_tools::linops::{ | 
| 19     BoundedLinear, AXPY, GEMV, Adjointable, IdOp, | 19     BoundedLinear, AXPY, GEMV, Adjointable, IdOp, | 
| 20 }; | 20 }; | 
| 43     TransportConfig, | 43     TransportConfig, | 
| 44     TransportStepLength, | 44     TransportStepLength, | 
| 45     initial_transport, | 45     initial_transport, | 
| 46     aposteriori_transport, | 46     aposteriori_transport, | 
| 47 }; | 47 }; | 
| 48 use crate::dataterm::{calculate_residual, calculate_residual2}; | 48 use crate::dataterm::{ | 
|  | 49     calculate_residual2, | 
|  | 50     calculate_residual, | 
|  | 51 }; | 
|  | 52 | 
| 49 | 53 | 
| 50 /// Settings for [`pointsource_sliding_pdps_pair`]. | 54 /// Settings for [`pointsource_sliding_pdps_pair`]. | 
| 51 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | 55 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | 
| 52 #[serde(default)] | 56 #[serde(default)] | 
| 53 pub struct SlidingPDPSConfig<F : Float> { | 57 pub struct SlidingPDPSConfig<F : Float> { | 
| 132         + GEMV<F, Z> | 135         + GEMV<F, Z> | 
| 133         + Adjointable<Z, Y, AdjointCodomain = Z>, | 136         + Adjointable<Z, Y, AdjointCodomain = Z>, | 
| 134     for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, | 137     for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, | 
| 135     Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, | 138     Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, | 
| 136     for<'b> &'b Y : Instance<Y>, | 139     for<'b> &'b Y : Instance<Y>, | 
| 137     Z : AXPY<F, Owned=Z> + Euclidean<F, Output=Z> + Clone + Norm<F, L2>, | 140     Z : AXPY<F, Owned=Z> + Euclidean<F, Output=Z> + Clone + Norm<F, L2> + Dist<F, L2>, | 
| 138     for<'b> &'b Z : Instance<Z>, | 141     for<'b> &'b Z : Instance<Z>, | 
| 139     R : Prox<Z, Codomain=F>, | 142     R : Prox<Z, Codomain=F>, | 
| 140     H : Conjugable<Y, F, Codomain=F>, | 143     H : Conjugable<Y, F, Codomain=F>, | 
| 141     for<'b> H::Conjugate<'b> : Prox<Y>, | 144     for<'b> H::Conjugate<'b> : Prox<Y>, | 
| 142 { | 145 { | 
| 232     // Run the algorithm | 235     // Run the algorithm | 
| 233     for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { | 236     for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { | 
| 234         // Calculate initial transport | 237         // Calculate initial transport | 
| 235         let Pair(v, _) = opA.preadjoint().apply(&residual); | 238         let Pair(v, _) = opA.preadjoint().apply(&residual); | 
| 236         //opKμ.preadjoint().apply_add(&mut v, y); | 239         //opKμ.preadjoint().apply_add(&mut v, y); | 
| 237         let z_base = z.clone(); |  | 
| 238         // We want to proceed as in Example 4.12 but with v and v̆ as in §5. | 240         // We want to proceed as in Example 4.12 but with v and v̆ as in §5. | 
| 239         // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have | 241         // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have | 
| 240         // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, | 242         // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, | 
| 241         // where A_ν^* becomes a multiplier. | 243         // where A_ν^* becomes a multiplier. | 
| 242         // This is much easier with K_μ = 0, which is the only reason why we are enforcing it. | 244         // This is much easier with K_μ = 0, which is the only reason why we are enforcing it. | 
| 248             v, &config.transport, | 250             v, &config.transport, | 
| 249         ); | 251         ); | 
| 250 | 252 | 
| 251         // Solve finite-dimensional subproblem several times until the dual variable for the | 253         // Solve finite-dimensional subproblem several times until the dual variable for the | 
| 252         // regularisation term conforms to the assumptions made for the transport above. | 254         // regularisation term conforms to the assumptions made for the transport above. | 
| 253         let (maybe_d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { | 255         let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { | 
| 254             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) | 256             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) | 
| 255             let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), | 257             let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), | 
| 256                                                  Pair(&μ_base_minus_γ0, &zero_z), | 258                                                  Pair(&μ_base_minus_γ0, &zero_z), | 
| 257                                                  opA, b); | 259                                                  opA, b); | 
| 258             let mut τv̆z = opA.preadjoint().apply(residual_μ̆ * τ); | 260             let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); | 
| 259             // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); | 261             // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); | 
| 260 | 262 | 
| 261             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. | 263             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. | 
| 262             let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( | 264             let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( | 
| 263                 &mut μ, &mut τv̆z.0, &γ1, Some(&μ_base_minus_γ0), | 265                 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), | 
| 264                 τ, ε, &config.insertion, | 266                 τ, ε, &config.insertion, | 
| 265                 &reg, &state, &mut stats, | 267                 &reg, &state, &mut stats, | 
| 266             ); | 268             ); | 
| 267 | 269 | 
|  | 270             // Do z variable primal update here to be able to estimate B_{v̆^k-v^{k+1}} | 
|  | 271             let mut z_new = τz̆; | 
|  | 272             opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ); | 
|  | 273             z_new = fnR.prox(σ_p, z_new + &z); | 
|  | 274 | 
| 268             // A posteriori transport adaptation. | 275             // A posteriori transport adaptation. | 
| 269             // TODO: this does not properly treat v^{k+1} - v̆^k that depends on z^{k+1}! |  | 
| 270             if aposteriori_transport( | 276             if aposteriori_transport( | 
| 271                 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, | 277                 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, | 
|  | 278                 Some(z_new.dist(&z, L2)), | 
| 272                 ε, &config.transport | 279                 ε, &config.transport | 
| 273             ) { | 280             ) { | 
| 274                 break 'adapt_transport (maybe_d, within_tolerances, τv̆z) | 281                 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new) | 
| 275             } | 282             } | 
| 276         }; | 283         }; | 
| 277 | 284 | 
| 278         stats.untransported_fraction = Some({ | 285         stats.untransported_fraction = Some({ | 
| 279             assert_eq!(μ_base_masses.len(), γ1.len()); | 286             assert_eq!(μ_base_masses.len(), γ1.len()); | 
| 285             assert_eq!(μ_base_masses.len(), γ1.len()); | 292             assert_eq!(μ_base_masses.len(), γ1.len()); | 
| 286             let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); | 293             let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); | 
| 287             (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) | 294             (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) | 
| 288         }); | 295         }); | 
| 289 | 296 | 
| 290         // // Merge spikes. | 297         // Merge spikes. | 
| 291         // // This expects the prune below to prune γ. | 298         // This crucially expects the merge routine to be stable with respect to spike locations, | 
| 292         // // TODO: This may not work correctly in all cases. | 299         // and not to perform any pruning. That is to be done below simultaneously for γ. | 
| 293         // let ins = &config.insertion; | 300         let ins = &config.insertion; | 
| 294         // if ins.merge_now(&state) { | 301         if ins.merge_now(&state) { | 
| 295         //     if let SpikeMergingMethod::None = ins.merging { | 302             stats.merged += prox_penalty.merge_spikes_no_fitness( | 
| 296         //     } else { | 303                 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, &reg, | 
| 297         //         stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { | 304                 //Some(|μ̃ : &RNDM<F, N>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), | 
| 298         //             let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; | 305             ); | 
| 299         //             let mut d = &τv̆ + op𝒟.preapply(ν); | 306         } | 
| 300         //             reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) |  | 
| 301         //         }); |  | 
| 302         //     } |  | 
| 303         // } |  | 
| 304 | 307 | 
| 305         // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the | 308         // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the | 
| 306         // latter needs to be pruned when μ is. | 309         // latter needs to be pruned when μ is. | 
| 307         // TODO: This could do with a two-vector Vec::retain to avoid copies. | 310         // TODO: This could do with a two-vector Vec::retain to avoid copies. | 
| 308         let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); | 311         let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); | 
| 311             γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); | 314             γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); | 
| 312             stats.pruned += μ.len() - μ_new.len(); | 315             stats.pruned += μ.len() - μ_new.len(); | 
| 313             μ = μ_new; | 316             μ = μ_new; | 
| 314         } | 317         } | 
| 315 | 318 | 
| 316         // Do z variable primal update |  | 
| 317         z.axpy(-σ_p/τ, τz̆, 1.0); // TODO: simplify nasty factors |  | 
| 318         opKz.adjoint().gemv(&mut z, -σ_p, &y, 1.0); |  | 
| 319         z = fnR.prox(σ_p, z); |  | 
| 320         // Do dual update | 319         // Do dual update | 
| 321         // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0);    // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] | 320         // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0);    // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] | 
| 322         opKz.gemv(&mut y, σ_d*(1.0 + ω), &z, 1.0); | 321         opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0); | 
| 323         // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b | 322         // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b | 
| 324         opKz.gemv(&mut y, -σ_d*ω, z_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b | 323         opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b | 
| 325         y = starH.prox(σ_d, y); | 324         y = starH.prox(σ_d, y); | 
|  | 325         z = z_new; | 
| 326 | 326 | 
| 327         // Update residual | 327         // Update residual | 
| 328         residual = calculate_residual(Pair(&μ, &z), opA, b); | 328         residual = calculate_residual(Pair(&μ, &z), opA, b); | 
| 329 | 329 | 
| 330         // Update step length parameters | 330         // Update step length parameters |