| 1 /*! |
1 /*! |
| 2 Solver for the point source localisation problem using a sliding |
2 Solver for the point source localisation problem using a sliding |
| 3 primal-dual proximal splitting method. |
3 primal-dual proximal splitting method. |
| 4 */ |
4 */ |
| 5 |
5 |
| |
6 use crate::fb::*; |
| |
7 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; |
| |
8 use crate::measures::merging::SpikeMerging; |
| |
9 use crate::measures::{DiscreteMeasure, RNDM}; |
| |
10 use crate::plot::Plotter; |
| |
11 use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; |
| |
12 use crate::regularisation::SlidingRegTerm; |
| |
13 use crate::sliding_fb::{SlidingFBConfig, Transport, TransportConfig, TransportStepLength}; |
| |
14 use crate::types::*; |
| |
15 use alg_tools::convex::{Conjugable, Prox, Zero}; |
| |
16 use alg_tools::direct_product::Pair; |
| |
17 use alg_tools::error::DynResult; |
| |
18 use alg_tools::euclidean::ClosedEuclidean; |
| |
19 use alg_tools::iterate::AlgIteratorFactory; |
| |
20 use alg_tools::linops::{ |
| |
21 BoundedLinear, IdOp, SimplyAdjointable, StaticEuclideanOriginGenerator, ZeroOp, AXPY, GEMV, |
| |
22 }; |
| |
23 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; |
| |
24 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| |
25 use alg_tools::norms::L2; |
| |
26 use anyhow::ensure; |
| 6 use numeric_literals::replace_float_literals; |
27 use numeric_literals::replace_float_literals; |
| 7 use serde::{Deserialize, Serialize}; |
28 use serde::{Deserialize, Serialize}; |
| 8 //use colored::Colorize; |
|
| 9 //use nalgebra::{DVector, DMatrix}; |
|
| 10 use std::iter::Iterator; |
|
| 11 |
|
| 12 use alg_tools::convex::{Conjugable, Prox}; |
|
| 13 use alg_tools::direct_product::Pair; |
|
| 14 use alg_tools::euclidean::Euclidean; |
|
| 15 use alg_tools::iterate::AlgIteratorFactory; |
|
| 16 use alg_tools::linops::{Adjointable, BoundedLinear, IdOp, AXPY, GEMV}; |
|
| 17 use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; |
|
| 18 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
| 19 use alg_tools::norms::{Dist, Norm}; |
|
| 20 use alg_tools::norms::{PairNorm, L2}; |
|
| 21 |
|
| 22 use crate::forward_model::{AdjointProductPairBoundedBy, BoundedCurvature, ForwardModel}; |
|
| 23 use crate::measures::merging::SpikeMerging; |
|
| 24 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
|
| 25 use crate::types::*; |
|
| 26 // use crate::transport::TransportLipschitz; |
|
| 27 //use crate::tolerance::Tolerance; |
|
| 28 use crate::fb::*; |
|
| 29 use crate::plot::{PlotLookup, Plotting, SeqPlotter}; |
|
| 30 use crate::regularisation::SlidingRegTerm; |
|
| 31 // use crate::dataterm::L2Squared; |
|
| 32 use crate::dataterm::{calculate_residual, calculate_residual2}; |
|
| 33 use crate::sliding_fb::{ |
|
| 34 aposteriori_transport, initial_transport, TransportConfig, TransportStepLength, |
|
| 35 }; |
|
| 36 |
29 |
| 37 /// Settings for [`pointsource_sliding_pdps_pair`]. |
30 /// Settings for [`pointsource_sliding_pdps_pair`]. |
| 38 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
31 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 39 #[serde(default)] |
32 #[serde(default)] |
| 40 pub struct SlidingPDPSConfig<F: Float> { |
33 pub struct SlidingPDPSConfig<F: Float> { |
| 41 /// Primal step length scaling. |
34 /// Overall primal step length scaling. |
| 42 pub τ0: F, |
35 pub τ0: F, |
| 43 /// Primal step length scaling. |
36 /// Primal step length scaling for additional variable. |
| 44 pub σp0: F, |
37 pub σp0: F, |
| 45 /// Dual step length scaling. |
38 /// Dual step length scaling for additional variable. |
| |
39 /// |
| |
40 /// Taken zero for [`pointsource_sliding_fb_pair`]. |
| 46 pub σd0: F, |
41 pub σd0: F, |
| 47 /// Transport parameters |
42 /// Transport parameters |
| 48 pub transport: TransportConfig<F>, |
43 pub transport: TransportConfig<F>, |
| 49 /// Generic parameters |
44 /// Generic parameters |
| 50 pub insertion: FBGenericConfig<F>, |
45 pub insertion: InsertionConfig<F>, |
| |
46 /// Guess for curvature bound calculations. |
| |
47 pub guess: BoundedCurvatureGuess, |
| 51 } |
48 } |
| 52 |
49 |
| 53 #[replace_float_literals(F::cast_from(literal))] |
50 #[replace_float_literals(F::cast_from(literal))] |
| 54 impl<F: Float> Default for SlidingPDPSConfig<F> { |
51 impl<F: Float> Default for SlidingPDPSConfig<F> { |
| 55 fn default() -> Self { |
52 fn default() -> Self { |
| 56 SlidingPDPSConfig { |
53 SlidingPDPSConfig { |
| 57 τ0: 0.99, |
54 τ0: 0.99, |
| 58 σd0: 0.05, |
55 σd0: 0.05, |
| 59 σp0: 0.99, |
56 σp0: 0.99, |
| 60 transport: TransportConfig { |
57 transport: TransportConfig { θ0: 0.9, ..Default::default() }, |
| 61 θ0: 0.9, |
|
| 62 ..Default::default() |
|
| 63 }, |
|
| 64 insertion: Default::default(), |
58 insertion: Default::default(), |
| |
59 guess: BoundedCurvatureGuess::BetterThanZero, |
| 65 } |
60 } |
| 66 } |
61 } |
| 67 } |
62 } |
| 68 |
63 |
| 69 type MeasureZ<F, Z, const N: usize> = Pair<RNDM<F, N>, Z>; |
64 type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>; |
| 70 |
65 |
| 71 /// Iteratively solve the pointsource localisation with an additional variable |
66 /// Iteratively solve the pointsource localisation with an additional variable |
| 72 /// using sliding primal-dual proximal splitting |
67 /// using sliding primal-dual proximal splitting |
| 73 /// |
68 /// |
| 74 /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. |
69 /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. |
| 75 #[replace_float_literals(F::cast_from(literal))] |
70 #[replace_float_literals(F::cast_from(literal))] |
| 76 pub fn pointsource_sliding_pdps_pair< |
71 pub fn pointsource_sliding_pdps_pair< |
| 77 F, |
72 F, |
| 78 I, |
73 I, |
| 79 A, |
|
| 80 S, |
74 S, |
| |
75 Dat, |
| 81 Reg, |
76 Reg, |
| 82 P, |
77 P, |
| 83 Z, |
78 Z, |
| 84 R, |
79 R, |
| 85 Y, |
80 Y, |
| |
81 Plot, |
| 86 /*KOpM, */ KOpZ, |
82 /*KOpM, */ KOpZ, |
| 87 H, |
83 H, |
| 88 const N: usize, |
84 const N: usize, |
| 89 >( |
85 >( |
| 90 opA: &A, |
86 f: &Dat, |
| 91 b: &A::Observable, |
87 reg: &Reg, |
| 92 reg: Reg, |
|
| 93 prox_penalty: &P, |
88 prox_penalty: &P, |
| 94 config: &SlidingPDPSConfig<F>, |
89 config: &SlidingPDPSConfig<F>, |
| 95 iterator: I, |
90 iterator: I, |
| 96 mut plotter: SeqPlotter<F, N>, |
91 mut plotter: Plot, |
| |
92 (μ0, mut z, mut y): (Option<RNDM<N, F>>, Z, Y), |
| 97 //opKμ : KOpM, |
93 //opKμ : KOpM, |
| 98 opKz: &KOpZ, |
94 opKz: &KOpZ, |
| 99 fnR: &R, |
95 fnR: &R, |
| 100 fnH: &H, |
96 fnH: &H, |
| 101 mut z: Z, |
97 ) -> DynResult<MeasureZ<F, Z, N>> |
| 102 mut y: Y, |
|
| 103 ) -> MeasureZ<F, Z, N> |
|
| 104 where |
98 where |
| 105 F: Float + ToNalgebraRealField, |
99 F: Float + ToNalgebraRealField, |
| 106 I: AlgIteratorFactory<IterInfo<F, N>>, |
100 I: AlgIteratorFactory<IterInfo<F>>, |
| 107 A: ForwardModel<MeasureZ<F, Z, N>, F, PairNorm<Radon, L2, L2>, PreadjointCodomain = Pair<S, Z>> |
101 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>> |
| 108 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType = F> |
102 + BoundedCurvature<F>, |
| 109 + BoundedCurvature<FloatType = F>, |
103 S: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| 110 S: DifferentiableRealMapping<F, N>, |
104 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, |
| 111 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>, |
105 //Pair<S, Z>: ClosedMul<F>, |
| 112 PlotLookup: Plotting<N>, |
106 RNDM<N, F>: SpikeMerging<F>, |
| 113 RNDM<F, N>: SpikeMerging<F>, |
107 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| 114 Reg: SlidingRegTerm<F, N>, |
108 P: ProxPenalty<Loc<N, F>, S, Reg, F>, |
| 115 P: ProxPenalty<F, S, Reg, N>, |
109 // KOpM : Linear<RNDM<N, F>, Codomain=Y> |
| 116 // KOpM : Linear<RNDM<F, N>, Codomain=Y> |
110 // + GEMV<F, RNDM<N, F>> |
| 117 // + GEMV<F, RNDM<F, N>> |
|
| 118 // + Preadjointable< |
111 // + Preadjointable< |
| 119 // RNDM<F, N>, Y, |
112 // RNDM<N, F>, Y, |
| 120 // PreadjointCodomain = S, |
113 // PreadjointCodomain = S, |
| 121 // > |
114 // > |
| 122 // + TransportLipschitz<L2Squared, FloatType=F> |
115 // + TransportLipschitz<L2Squared, FloatType=F> |
| 123 // + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
116 // + AdjointProductBoundedBy<RNDM<N, F>, 𝒟, FloatType=F>, |
| 124 // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>, |
117 // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>, |
| 125 // Since Z is Hilbert, we may just as well use adjoints for K_z. |
118 // Since Z is Hilbert, we may just as well use adjoints for K_z. |
| 126 KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y> |
119 KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y> |
| 127 + GEMV<F, Z> |
120 + GEMV<F, Z> |
| 128 + Adjointable<Z, Y, AdjointCodomain = Z>, |
121 + SimplyAdjointable<Z, Y, AdjointCodomain = Z>, |
| 129 for<'b> KOpZ::Adjoint<'b>: GEMV<F, Y>, |
122 KOpZ::SimpleAdjoint: GEMV<F, Y>, |
| 130 Y: AXPY<F> + Euclidean<F, Output = Y> + Clone + ClosedAdd, |
123 Y: ClosedEuclidean<F>, |
| 131 for<'b> &'b Y: Instance<Y>, |
124 for<'b> &'b Y: Instance<Y>, |
| 132 Z: AXPY<F, Owned = Z> + Euclidean<F, Output = Z> + Clone + Norm<F, L2> + Dist<F, L2>, |
125 Z: ClosedEuclidean<F>, |
| 133 for<'b> &'b Z: Instance<Z>, |
126 for<'b> &'b Z: Instance<Z>, |
| 134 R: Prox<Z, Codomain = F>, |
127 R: Prox<Z, Codomain = F>, |
| 135 H: Conjugable<Y, F, Codomain = F>, |
128 H: Conjugable<Y, F, Codomain = F>, |
| 136 for<'b> H::Conjugate<'b>: Prox<Y>, |
129 for<'b> H::Conjugate<'b>: Prox<Y>, |
| |
130 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, |
| 137 { |
131 { |
| 138 // Check parameters |
132 // Check parameters |
| 139 assert!( |
133 /*ensure!( |
| 140 config.τ0 > 0.0 |
134 config.τ0 > 0.0 |
| 141 && config.τ0 < 1.0 |
135 && config.τ0 < 1.0 |
| 142 && config.σp0 > 0.0 |
136 && config.σp0 > 0.0 |
| 143 && config.σp0 < 1.0 |
137 && config.σp0 < 1.0 |
| 144 && config.σd0 > 0.0 |
138 && config.σd0 > 0.0 |
| 145 && config.σp0 * config.σd0 <= 1.0, |
139 && config.σp0 * config.σd0 <= 1.0, |
| 146 "Invalid step length parameters" |
140 "Invalid step length parameters" |
| 147 ); |
141 );*/ |
| 148 config.transport.check(); |
142 config.transport.check()?; |
| 149 |
143 |
| 150 // Initialise iterates |
144 // Initialise iterates |
| 151 let mut μ = DiscreteMeasure::new(); |
145 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 152 let mut γ1 = DiscreteMeasure::new(); |
146 let mut γ = Transport::new(); |
| 153 let mut residual = calculate_residual(Pair(&μ, &z), opA, b); |
147 //let zero_z = z.similar_origin(); |
| 154 let zero_z = z.similar_origin(); |
|
| 155 |
148 |
| 156 // Set up parameters |
149 // Set up parameters |
| 157 // TODO: maybe this PairNorm doesn't make sense here? |
150 // TODO: maybe this PairNorm doesn't make sense here? |
| 158 // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); |
151 // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); |
| 159 let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); |
152 let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); |
| 160 let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); |
153 let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); |
| 161 let nKz = opKz.opnorm_bound(L2, L2); |
154 let nKz = opKz.opnorm_bound(L2, L2)?; |
| |
155 let is_fb = nKz == 0.0; |
| 162 let ℓ = 0.0; |
156 let ℓ = 0.0; |
| 163 let opIdZ = IdOp::new(); |
157 let idOpZ = IdOp::new(); |
| 164 let (l, l_z) = opA |
158 let opKz_adj = opKz.adjoint(); |
| 165 .adjoint_product_pair_bound(prox_penalty, &opIdZ) |
159 let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?; |
| 166 .unwrap(); |
160 |
| 167 // We need to satisfy |
161 // We need to satisfy |
| 168 // |
162 // |
| 169 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
163 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
| 170 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
164 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 171 // with 1 > σ_p L_z and 1 > τ L. |
165 // with 1 > σ_p L_z and 1 > τ L. |
| 172 // |
166 // |
| 173 // To do so, we first solve σ_p and σ_d from standard PDPS step length condition |
167 // To do so, we first solve σ_p and σ_d from standard PDPS step length condition |
| 174 // ^^^^^ < 1. then we solve τ from the rest. |
168 // ^^^^^ < 1. then we solve τ from the rest. |
| 175 let σ_d = config.σd0 / nKz; |
169 // If opKZ is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below. |
| |
170 let σ_d = if is_fb { 0.0 } else { config.σd0 / nKz }; |
| 176 let σ_p = config.σp0 / (l_z + config.σd0 * nKz); |
171 let σ_p = config.σp0 / (l_z + config.σd0 * nKz); |
| 177 // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} |
172 // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} |
| 178 // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) |
173 // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) |
| 179 // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) |
174 // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) |
| 180 let φ = 1.0 - config.σp0; |
175 let φ = 1.0 - config.σp0; |
| 181 let a = 1.0 - σ_p * l_z; |
176 let a = 1.0 - σ_p * l_z; |
| 182 let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); |
177 let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); |
| 183 let ψ = 1.0 - τ * l; |
178 let ψ = 1.0 - τ * l; |
| 184 let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a; |
179 let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a; |
| 185 assert!(β < 1.0); |
180 ensure!(β < 1.0); |
| 186 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: |
181 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: |
| 187 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); |
182 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); |
| 188 // The factor two in the manuscript disappears due to the definition of 𝚹 being |
183 // The factor two in the manuscript disappears due to the definition of 𝚹 being |
| 189 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. |
184 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. |
| 190 let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); |
185 |
| 191 let transport_lip = maybe_transport_lip.unwrap(); |
186 let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) { |
| 192 let calculate_θ = |ℓ_F, max_transport| { |
187 (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0), |
| 193 let ℓ_r = transport_lip * max_transport; |
188 (maybe_ℓ_F, Ok(transport_lip)) => { |
| 194 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) |
189 let calculate_θτ = move |ℓ_F, max_transport| { |
| 195 }; |
190 let ℓ_r = transport_lip * max_transport; |
| 196 let mut θ_or_adaptive = match maybe_ℓ_F0 { |
191 config.transport.θ0 / ((ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport / τ) |
| 197 // We assume that the residual is decreasing. |
192 }; |
| 198 Some(ℓ_F0) => TransportStepLength::AdaptiveMax { |
193 match maybe_ℓ_F { |
| 199 l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual |
194 Ok(ℓ_F) => TransportStepLength::AdaptiveMax { |
| 200 max_transport: 0.0, |
195 l: ℓ_F, // TODO: could estimate computing the real reesidual |
| 201 g: calculate_θ, |
196 max_transport: 0.0, |
| 202 }, |
197 g: calculate_θτ, |
| 203 None => TransportStepLength::FullyAdaptive { |
198 }, |
| 204 l: F::EPSILON, |
199 Err(_) => TransportStepLength::FullyAdaptive { |
| 205 max_transport: 0.0, |
200 l: F::EPSILON, // Start with something very small to estimate differentials |
| 206 g: calculate_θ, |
201 max_transport: 0.0, |
| 207 }, |
202 g: calculate_θτ, |
| |
203 }, |
| |
204 } |
| |
205 } |
| 208 }; |
206 }; |
| 209 // Acceleration is not currently supported |
207 // Acceleration is not currently supported |
| 210 // let γ = dataterm.factor_of_strong_convexity(); |
208 // let γ = dataterm.factor_of_strong_convexity(); |
| 211 let ω = 1.0; |
209 let ω = 1.0; |
| 212 |
210 |
| 229 ..stats |
227 ..stats |
| 230 }; |
228 }; |
| 231 let mut stats = IterInfo::new(); |
229 let mut stats = IterInfo::new(); |
| 232 |
230 |
| 233 // Run the algorithm |
231 // Run the algorithm |
| 234 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { |
232 for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) { |
| 235 // Calculate initial transport |
233 // Calculate initial transport |
| 236 let Pair(v, _) = opA.preadjoint().apply(&residual); |
234 let Pair(v, _) = f.differential(Pair(&μ, &z)); |
| 237 //opKμ.preadjoint().apply_add(&mut v, y); |
235 //opKμ.preadjoint().apply_add(&mut v, y); |
| 238 // We want to proceed as in Example 4.12 but with v and v̆ as in §5. |
236 // We want to proceed as in Example 4.12 but with v and v̆ as in §5. |
| 239 // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have |
237 // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have |
| 240 // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, |
238 // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, |
| 241 // where A_ν^* becomes a multiplier. |
239 // where A_ν^* becomes a multiplier. |
| 242 // This is much easier with K_μ = 0, which is the only reason why are enforcing it. |
240 // This is much easier with K_μ = 0, which is the only reason why are enforcing it. |
| 243 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. |
241 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. |
| 244 |
242 |
| 245 let (μ_base_masses, mut μ_base_minus_γ0) = |
243 //dbg!(&μ); |
| 246 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); |
244 |
| |
245 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v, &config.transport); |
| |
246 |
| |
247 let mut attempts = 0; |
| 247 |
248 |
| 248 // Solve finite-dimensional subproblem several times until the dual variable for the |
249 // Solve finite-dimensional subproblem several times until the dual variable for the |
| 249 // regularisation term conforms to the assumptions made for the transport above. |
250 // regularisation term conforms to the assumptions made for the transport above. |
| 250 let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { |
251 let (maybe_d, _within_tolerances, mut τv̆, z_new, μ̆) = 'adapt_transport: loop { |
| |
252 // Set initial guess for μ=μ^{k+1}. |
| |
253 γ.μ̆_into(&mut μ); |
| |
254 let μ̆ = μ.clone(); |
| |
255 |
| 251 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
256 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
| 252 let residual_μ̆ = |
257 let Pair(mut τv̆, τz̆) = f.differential(Pair(&μ̆, &z)) * τ; |
| 253 calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); |
|
| 254 let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); |
|
| 255 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); |
258 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); |
| 256 |
259 |
| 257 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
260 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
| 258 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
261 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
| 259 &mut μ, |
262 &mut μ, |
| 260 &mut τv̆, |
263 &mut τv̆, |
| 261 &γ1, |
|
| 262 Some(&μ_base_minus_γ0), |
|
| 263 τ, |
264 τ, |
| 264 ε, |
265 ε, |
| 265 &config.insertion, |
266 &config.insertion, |
| 266 ®, |
267 ®, |
| 267 &state, |
268 &state, |
| 268 &mut stats, |
269 &mut stats, |
| 269 ); |
270 )?; |
| 270 |
271 |
| 271 // Do z variable primal update here to able to estimate B_{v̆^k-v^{k+1}} |
272 // Do z variable primal update here to able to estimate B_{v̆^k-v^{k+1}} |
| 272 let mut z_new = τz̆; |
273 let mut z_new = τz̆; |
| 273 opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p / τ); |
274 opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ); |
| 274 z_new = fnR.prox(σ_p, z_new + &z); |
275 z_new = fnR.prox(σ_p, z_new + &z); |
| 275 |
276 |
| 276 // A posteriori transport adaptation. |
277 // A posteriori transport adaptation. |
| 277 if aposteriori_transport( |
278 if γ.aposteriori_transport( |
| 278 &mut γ1, |
279 &μ, |
| 279 &mut μ, |
280 &μ̆, |
| 280 &mut μ_base_minus_γ0, |
281 &mut τv̆, |
| 281 &μ_base_masses, |
282 Some(z_new.dist2(&z)), |
| 282 Some(z_new.dist(&z, L2)), |
|
| 283 ε, |
283 ε, |
| 284 &config.transport, |
284 &config.transport, |
| |
285 &mut attempts, |
| 285 ) { |
286 ) { |
| 286 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new); |
287 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new, μ̆); |
| 287 } |
288 } |
| 288 }; |
289 }; |
| 289 |
290 |
| 290 stats.untransported_fraction = Some({ |
291 γ.get_transport_stats(&mut stats, &μ); |
| 291 assert_eq!(μ_base_masses.len(), γ1.len()); |
|
| 292 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0)); |
|
| 293 let source = μ_base_masses.iter().map(|v| v.abs()).sum(); |
|
| 294 (a + μ_base_minus_γ0.norm(Radon), b + source) |
|
| 295 }); |
|
| 296 stats.transport_error = Some({ |
|
| 297 assert_eq!(μ_base_masses.len(), γ1.len()); |
|
| 298 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
|
| 299 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
|
| 300 }); |
|
| 301 |
292 |
| 302 // Merge spikes. |
293 // Merge spikes. |
| 303 // This crucially expects the merge routine to be stable with respect to spike locations, |
294 // This crucially expects the merge routine to be stable with respect to spike locations, |
| 304 // and not to performing any pruning. That is be to done below simultaneously for γ. |
295 // and not to performing any pruning. That is be to done below simultaneously for γ. |
| 305 let ins = &config.insertion; |
296 if config.insertion.merge_now(&state) { |
| 306 if ins.merge_now(&state) { |
297 stats.merged += prox_penalty.merge_spikes( |
| 307 stats.merged += prox_penalty.merge_spikes_no_fitness( |
|
| 308 &mut μ, |
298 &mut μ, |
| 309 &mut τv̆, |
299 &mut τv̆, |
| 310 &γ1, |
300 &μ̆, |
| 311 Some(&μ_base_minus_γ0), |
|
| 312 τ, |
301 τ, |
| 313 ε, |
302 ε, |
| 314 ins, |
303 &config.insertion, |
| 315 ®, |
304 ®, |
| 316 //Some(|μ̃ : &RNDM<F, N>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), |
305 is_fb.then_some(|μ̃: &RNDM<N, F>| f.apply(Pair(μ̃, &z))), |
| 317 ); |
306 ); |
| 318 } |
307 } |
| 319 |
308 |
| 320 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
309 γ.prune_compat(&mut μ, &mut stats); |
| 321 // latter needs to be pruned when μ is. |
|
| 322 // TODO: This could do with a two-vector Vec::retain to avoid copies. |
|
| 323 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |
|
| 324 if μ_new.len() != μ.len() { |
|
| 325 let mut μ_iter = μ.iter_spikes(); |
|
| 326 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); |
|
| 327 stats.pruned += μ.len() - μ_new.len(); |
|
| 328 μ = μ_new; |
|
| 329 } |
|
| 330 |
310 |
| 331 // Do dual update |
311 // Do dual update |
| 332 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] |
312 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] |
| 333 opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); |
313 opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); |
| 334 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
314 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
| 335 opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
315 opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
| 336 y = starH.prox(σ_d, y); |
316 y = starH.prox(σ_d, y); |
| 337 z = z_new; |
317 z = z_new; |
| 338 |
318 |
| 339 // Update residual |
|
| 340 residual = calculate_residual(Pair(&μ, &z), opA, b); |
|
| 341 |
|
| 342 // Update step length parameters |
319 // Update step length parameters |
| 343 // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
320 // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
| 344 |
321 |
| 345 // Give statistics if requested |
322 // Give statistics if requested |
| 346 let iter = state.iteration(); |
323 let iter = state.iteration(); |
| 347 stats.this_iters += 1; |
324 stats.this_iters += 1; |
| 348 |
325 |
| 349 state.if_verbose(|| { |
326 state.if_verbose(|| { |
| 350 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); |
327 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); |
| 351 full_stats( |
328 full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 352 &residual, |
|
| 353 &μ, |
|
| 354 &z, |
|
| 355 ε, |
|
| 356 std::mem::replace(&mut stats, IterInfo::new()), |
|
| 357 ) |
|
| 358 }); |
329 }); |
| 359 |
330 |
| 360 // Update main tolerance for next iteration |
331 // Update main tolerance for next iteration |
| 361 ε = tolerance.update(ε, iter); |
332 ε = tolerance.update(ε, iter); |
| 362 } |
333 } |
| 363 |
334 |
| 364 let fit = |μ̃: &RNDM<F, N>| { |
335 let fit = |μ̃: &RNDM<N, F>| { |
| 365 (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() |
336 f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/ |
| 366 //+ fnR.apply(z) + reg.apply(μ) |
|
| 367 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) |
337 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) |
| 368 }; |
338 }; |
| 369 |
339 |
| 370 μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); |
340 μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); |
| 371 μ.prune(); |
341 μ.prune(); |
| 372 Pair(μ, z) |
342 Ok(Pair(μ, z)) |
| 373 } |
343 } |
| |
344 |
| |
345 /// Iteratively solve the pointsource localisation with an additional variable |
| |
346 /// using sliding forward-backward splitting. |
| |
347 /// |
| |
348 /// The implementation uses [`pointsource_sliding_pdps_pair`] with appropriate dummy |
| |
349 /// variables, operators, and functions. |
| |
350 #[replace_float_literals(F::cast_from(literal))] |
| |
351 pub fn pointsource_sliding_fb_pair<F, I, S, Dat, Reg, P, Z, R, Plot, const N: usize>( |
| |
352 f: &Dat, |
| |
353 reg: &Reg, |
| |
354 prox_penalty: &P, |
| |
355 config: &SlidingFBConfig<F>, |
| |
356 iterator: I, |
| |
357 plotter: Plot, |
| |
358 (μ0, z): (Option<RNDM<N, F>>, Z), |
| |
359 //opKμ : KOpM, |
| |
360 fnR: &R, |
| |
361 ) -> DynResult<MeasureZ<F, Z, N>> |
| |
362 where |
| |
363 F: Float + ToNalgebraRealField, |
| |
364 I: AlgIteratorFactory<IterInfo<F>>, |
| |
365 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>> |
| |
366 + BoundedCurvature<F>, |
| |
367 S: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| |
368 RNDM<N, F>: SpikeMerging<F>, |
| |
369 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| |
370 P: ProxPenalty<Loc<N, F>, S, Reg, F>, |
| |
371 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, |
| |
372 Z: ClosedEuclidean<F> + AXPY + Clone, |
| |
373 for<'b> &'b Z: Instance<Z>, |
| |
374 R: Prox<Z, Codomain = F>, |
| |
375 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, |
| |
376 // We should not need to explicitly require this: |
| |
377 for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>, |
| |
378 // Loc<0, F>: StaticEuclidean<Field = F, PrincipalE = Loc<0, F>> |
| |
379 // + Instance<Loc<0, F>> |
| |
380 // + VectorSpace<Field = F>, |
| |
381 { |
| |
382 let opKz: ZeroOp<Z, Loc<0, F>, _, _, F> = |
| |
383 ZeroOp::new_dualisable(StaticEuclideanOriginGenerator, z.dual_origin()); |
| |
384 let fnH = Zero::new(); |
| |
385 // Convert config. We don't implement From (that could be done with the o2o crate), as σd0 |
| |
386 // needs to be chosen in a general case; for the problem of this fucntion, anything is valid. |
| |
387 let &SlidingFBConfig { τ0, σp0, insertion, transport, guess } = config; |
| |
388 let pdps_config = SlidingPDPSConfig { τ0, σp0, insertion, transport, guess, σd0: 0.0 }; |
| |
389 |
| |
390 pointsource_sliding_pdps_pair( |
| |
391 f, |
| |
392 reg, |
| |
393 prox_penalty, |
| |
394 &pdps_config, |
| |
395 iterator, |
| |
396 plotter, |
| |
397 (μ0, z, Loc([])), |
| |
398 &opKz, |
| |
399 fnR, |
| |
400 &fnH, |
| |
401 ) |
| |
402 } |