| 1 /*! |
1 /*! |
| 2 Solver for the point source localisation problem using a sliding |
2 Solver for the point source localisation problem using a sliding |
| 3 primal-dual proximal splitting method. |
3 primal-dual proximal splitting method. |
| 4 */ |
4 */ |
| 5 |
5 |
| |
6 use crate::fb::*; |
| |
7 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; |
| |
8 use crate::measures::merging::SpikeMerging; |
| |
9 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
| |
10 use crate::plot::Plotter; |
| |
11 use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; |
| |
12 use crate::regularisation::SlidingRegTerm; |
| |
13 use crate::sliding_fb::{ |
| |
14 aposteriori_transport, initial_transport, SlidingFBConfig, TransportConfig, TransportStepLength, |
| |
15 }; |
| |
16 use crate::types::*; |
| |
17 use alg_tools::convex::{Conjugable, Prox, Zero}; |
| |
18 use alg_tools::direct_product::Pair; |
| |
19 use alg_tools::error::DynResult; |
| |
20 use alg_tools::euclidean::ClosedEuclidean; |
| |
21 use alg_tools::iterate::AlgIteratorFactory; |
| |
22 use alg_tools::linops::{ |
| |
23 BoundedLinear, IdOp, SimplyAdjointable, StaticEuclideanOriginGenerator, ZeroOp, AXPY, GEMV, |
| |
24 }; |
| |
25 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; |
| |
26 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| |
27 use alg_tools::norms::{Norm, L2}; |
| |
28 use anyhow::ensure; |
| 6 use numeric_literals::replace_float_literals; |
29 use numeric_literals::replace_float_literals; |
| 7 use serde::{Deserialize, Serialize}; |
30 use serde::{Deserialize, Serialize}; |
| 8 //use colored::Colorize; |
31 //use colored::Colorize; |
| 9 //use nalgebra::{DVector, DMatrix}; |
32 //use nalgebra::{DVector, DMatrix}; |
| 10 use std::iter::Iterator; |
33 use std::iter::Iterator; |
| 11 |
34 |
| 12 use alg_tools::convex::{Conjugable, Prox}; |
|
| 13 use alg_tools::direct_product::Pair; |
|
| 14 use alg_tools::euclidean::Euclidean; |
|
| 15 use alg_tools::iterate::AlgIteratorFactory; |
|
| 16 use alg_tools::linops::{Adjointable, BoundedLinear, IdOp, AXPY, GEMV}; |
|
| 17 use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; |
|
| 18 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
| 19 use alg_tools::norms::{Dist, Norm}; |
|
| 20 use alg_tools::norms::{PairNorm, L2}; |
|
| 21 |
|
| 22 use crate::forward_model::{AdjointProductPairBoundedBy, BoundedCurvature, ForwardModel}; |
|
| 23 use crate::measures::merging::SpikeMerging; |
|
| 24 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
|
| 25 use crate::types::*; |
|
| 26 // use crate::transport::TransportLipschitz; |
|
| 27 //use crate::tolerance::Tolerance; |
|
| 28 use crate::fb::*; |
|
| 29 use crate::plot::{PlotLookup, Plotting, SeqPlotter}; |
|
| 30 use crate::regularisation::SlidingRegTerm; |
|
| 31 // use crate::dataterm::L2Squared; |
|
| 32 use crate::dataterm::{calculate_residual, calculate_residual2}; |
|
| 33 use crate::sliding_fb::{ |
|
| 34 aposteriori_transport, initial_transport, TransportConfig, TransportStepLength, |
|
| 35 }; |
|
| 36 |
|
| 37 /// Settings for [`pointsource_sliding_pdps_pair`]. |
35 /// Settings for [`pointsource_sliding_pdps_pair`]. |
| 38 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
36 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 39 #[serde(default)] |
37 #[serde(default)] |
| 40 pub struct SlidingPDPSConfig<F: Float> { |
38 pub struct SlidingPDPSConfig<F: Float> { |
| 41 /// Primal step length scaling. |
39 /// Overall primal step length scaling. |
| 42 pub τ0: F, |
40 pub τ0: F, |
| 43 /// Primal step length scaling. |
41 /// Primal step length scaling for additional variable. |
| 44 pub σp0: F, |
42 pub σp0: F, |
| 45 /// Dual step length scaling. |
43 /// Dual step length scaling for additional variable. |
| |
44 /// |
| |
45 /// Taken zero for [`pointsource_sliding_fb_pair`]. |
| 46 pub σd0: F, |
46 pub σd0: F, |
| 47 /// Transport parameters |
47 /// Transport parameters |
| 48 pub transport: TransportConfig<F>, |
48 pub transport: TransportConfig<F>, |
| 49 /// Generic parameters |
49 /// Generic parameters |
| 50 pub insertion: FBGenericConfig<F>, |
50 pub insertion: InsertionConfig<F>, |
| |
51 /// Guess for curvature bound calculations. |
| |
52 pub guess: BoundedCurvatureGuess, |
| 51 } |
53 } |
| 52 |
54 |
| 53 #[replace_float_literals(F::cast_from(literal))] |
55 #[replace_float_literals(F::cast_from(literal))] |
| 54 impl<F: Float> Default for SlidingPDPSConfig<F> { |
56 impl<F: Float> Default for SlidingPDPSConfig<F> { |
| 55 fn default() -> Self { |
57 fn default() -> Self { |
| 56 SlidingPDPSConfig { |
58 SlidingPDPSConfig { |
| 57 τ0: 0.99, |
59 τ0: 0.99, |
| 58 σd0: 0.05, |
60 σd0: 0.05, |
| 59 σp0: 0.99, |
61 σp0: 0.99, |
| 60 transport: TransportConfig { |
62 transport: TransportConfig { θ0: 0.9, ..Default::default() }, |
| 61 θ0: 0.9, |
|
| 62 ..Default::default() |
|
| 63 }, |
|
| 64 insertion: Default::default(), |
63 insertion: Default::default(), |
| |
64 guess: BoundedCurvatureGuess::BetterThanZero, |
| 65 } |
65 } |
| 66 } |
66 } |
| 67 } |
67 } |
| 68 |
68 |
| 69 type MeasureZ<F, Z, const N: usize> = Pair<RNDM<F, N>, Z>; |
69 type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>; |
| 70 |
70 |
| 71 /// Iteratively solve the pointsource localisation with an additional variable |
71 /// Iteratively solve the pointsource localisation with an additional variable |
| 72 /// using sliding primal-dual proximal splitting |
72 /// using sliding primal-dual proximal splitting |
| 73 /// |
73 /// |
| 74 /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. |
74 /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. |
| 75 #[replace_float_literals(F::cast_from(literal))] |
75 #[replace_float_literals(F::cast_from(literal))] |
| 76 pub fn pointsource_sliding_pdps_pair< |
76 pub fn pointsource_sliding_pdps_pair< |
| 77 F, |
77 F, |
| 78 I, |
78 I, |
| 79 A, |
|
| 80 S, |
79 S, |
| |
80 Dat, |
| 81 Reg, |
81 Reg, |
| 82 P, |
82 P, |
| 83 Z, |
83 Z, |
| 84 R, |
84 R, |
| 85 Y, |
85 Y, |
| |
86 Plot, |
| 86 /*KOpM, */ KOpZ, |
87 /*KOpM, */ KOpZ, |
| 87 H, |
88 H, |
| 88 const N: usize, |
89 const N: usize, |
| 89 >( |
90 >( |
| 90 opA: &A, |
91 f: &Dat, |
| 91 b: &A::Observable, |
92 reg: &Reg, |
| 92 reg: Reg, |
|
| 93 prox_penalty: &P, |
93 prox_penalty: &P, |
| 94 config: &SlidingPDPSConfig<F>, |
94 config: &SlidingPDPSConfig<F>, |
| 95 iterator: I, |
95 iterator: I, |
| 96 mut plotter: SeqPlotter<F, N>, |
96 mut plotter: Plot, |
| |
97 (μ0, mut z, mut y): (Option<RNDM<N, F>>, Z, Y), |
| 97 //opKμ : KOpM, |
98 //opKμ : KOpM, |
| 98 opKz: &KOpZ, |
99 opKz: &KOpZ, |
| 99 fnR: &R, |
100 fnR: &R, |
| 100 fnH: &H, |
101 fnH: &H, |
| 101 mut z: Z, |
102 ) -> DynResult<MeasureZ<F, Z, N>> |
| 102 mut y: Y, |
|
| 103 ) -> MeasureZ<F, Z, N> |
|
| 104 where |
103 where |
| 105 F: Float + ToNalgebraRealField, |
104 F: Float + ToNalgebraRealField, |
| 106 I: AlgIteratorFactory<IterInfo<F, N>>, |
105 I: AlgIteratorFactory<IterInfo<F>>, |
| 107 A: ForwardModel<MeasureZ<F, Z, N>, F, PairNorm<Radon, L2, L2>, PreadjointCodomain = Pair<S, Z>> |
106 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>> |
| 108 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType = F> |
107 + BoundedCurvature<F>, |
| 109 + BoundedCurvature<FloatType = F>, |
108 S: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| 110 S: DifferentiableRealMapping<F, N>, |
109 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, |
| 111 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>, |
110 //Pair<S, Z>: ClosedMul<F>, |
| 112 PlotLookup: Plotting<N>, |
111 RNDM<N, F>: SpikeMerging<F>, |
| 113 RNDM<F, N>: SpikeMerging<F>, |
112 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| 114 Reg: SlidingRegTerm<F, N>, |
113 P: ProxPenalty<Loc<N, F>, S, Reg, F>, |
| 115 P: ProxPenalty<F, S, Reg, N>, |
114 // KOpM : Linear<RNDM<N, F>, Codomain=Y> |
| 116 // KOpM : Linear<RNDM<F, N>, Codomain=Y> |
115 // + GEMV<F, RNDM<N, F>> |
| 117 // + GEMV<F, RNDM<F, N>> |
|
| 118 // + Preadjointable< |
116 // + Preadjointable< |
| 119 // RNDM<F, N>, Y, |
117 // RNDM<N, F>, Y, |
| 120 // PreadjointCodomain = S, |
118 // PreadjointCodomain = S, |
| 121 // > |
119 // > |
| 122 // + TransportLipschitz<L2Squared, FloatType=F> |
120 // + TransportLipschitz<L2Squared, FloatType=F> |
| 123 // + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
121 // + AdjointProductBoundedBy<RNDM<N, F>, 𝒟, FloatType=F>, |
| 124 // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>, |
122 // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>, |
| 125 // Since Z is Hilbert, we may just as well use adjoints for K_z. |
123 // Since Z is Hilbert, we may just as well use adjoints for K_z. |
| 126 KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y> |
124 KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y> |
| 127 + GEMV<F, Z> |
125 + GEMV<F, Z> |
| 128 + Adjointable<Z, Y, AdjointCodomain = Z>, |
126 + SimplyAdjointable<Z, Y, AdjointCodomain = Z>, |
| 129 for<'b> KOpZ::Adjoint<'b>: GEMV<F, Y>, |
127 KOpZ::SimpleAdjoint: GEMV<F, Y>, |
| 130 Y: AXPY<F> + Euclidean<F, Output = Y> + Clone + ClosedAdd, |
128 Y: ClosedEuclidean<F>, |
| 131 for<'b> &'b Y: Instance<Y>, |
129 for<'b> &'b Y: Instance<Y>, |
| 132 Z: AXPY<F, Owned = Z> + Euclidean<F, Output = Z> + Clone + Norm<F, L2> + Dist<F, L2>, |
130 Z: ClosedEuclidean<F>, |
| 133 for<'b> &'b Z: Instance<Z>, |
131 for<'b> &'b Z: Instance<Z>, |
| 134 R: Prox<Z, Codomain = F>, |
132 R: Prox<Z, Codomain = F>, |
| 135 H: Conjugable<Y, F, Codomain = F>, |
133 H: Conjugable<Y, F, Codomain = F>, |
| 136 for<'b> H::Conjugate<'b>: Prox<Y>, |
134 for<'b> H::Conjugate<'b>: Prox<Y>, |
| |
135 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, |
| 137 { |
136 { |
| 138 // Check parameters |
137 // Check parameters |
| 139 assert!( |
138 /*ensure!( |
| 140 config.τ0 > 0.0 |
139 config.τ0 > 0.0 |
| 141 && config.τ0 < 1.0 |
140 && config.τ0 < 1.0 |
| 142 && config.σp0 > 0.0 |
141 && config.σp0 > 0.0 |
| 143 && config.σp0 < 1.0 |
142 && config.σp0 < 1.0 |
| 144 && config.σd0 > 0.0 |
143 && config.σd0 > 0.0 |
| 145 && config.σp0 * config.σd0 <= 1.0, |
144 && config.σp0 * config.σd0 <= 1.0, |
| 146 "Invalid step length parameters" |
145 "Invalid step length parameters" |
| 147 ); |
146 );*/ |
| 148 config.transport.check(); |
147 config.transport.check()?; |
| 149 |
148 |
| 150 // Initialise iterates |
149 // Initialise iterates |
| 151 let mut μ = DiscreteMeasure::new(); |
150 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 152 let mut γ1 = DiscreteMeasure::new(); |
151 let mut γ1 = DiscreteMeasure::new(); |
| 153 let mut residual = calculate_residual(Pair(&μ, &z), opA, b); |
152 //let zero_z = z.similar_origin(); |
| 154 let zero_z = z.similar_origin(); |
|
| 155 |
153 |
| 156 // Set up parameters |
154 // Set up parameters |
| 157 // TODO: maybe this PairNorm doesn't make sense here? |
155 // TODO: maybe this PairNorm doesn't make sense here? |
| 158 // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); |
156 // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); |
| 159 let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); |
157 let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); |
| 160 let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); |
158 let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); |
| 161 let nKz = opKz.opnorm_bound(L2, L2); |
159 let nKz = opKz.opnorm_bound(L2, L2)?; |
| 162 let ℓ = 0.0; |
160 let ℓ = 0.0; |
| 163 let opIdZ = IdOp::new(); |
161 let idOpZ = IdOp::new(); |
| 164 let (l, l_z) = opA |
162 let opKz_adj = opKz.adjoint(); |
| 165 .adjoint_product_pair_bound(prox_penalty, &opIdZ) |
163 let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?; |
| 166 .unwrap(); |
164 |
| 167 // We need to satisfy |
165 // We need to satisfy |
| 168 // |
166 // |
| 169 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
167 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
| 170 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
168 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 171 // with 1 > σ_p L_z and 1 > τ L. |
169 // with 1 > σ_p L_z and 1 > τ L. |
| 172 // |
170 // |
| 173 // To do so, we first solve σ_p and σ_d from standard PDPS step length condition |
171 // To do so, we first solve σ_p and σ_d from standard PDPS step length condition |
| 174 // ^^^^^ < 1. then we solve τ from the rest. |
172 // ^^^^^ < 1. then we solve τ from the rest. |
| 175 let σ_d = config.σd0 / nKz; |
173 // If opKZ is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below. |
| |
174 let σ_d = if nKz == 0.0 { 0.0 } else { config.σd0 / nKz }; |
| 176 let σ_p = config.σp0 / (l_z + config.σd0 * nKz); |
175 let σ_p = config.σp0 / (l_z + config.σd0 * nKz); |
| 177 // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} |
176 // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} |
| 178 // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) |
177 // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) |
| 179 // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) |
178 // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) |
| 180 let φ = 1.0 - config.σp0; |
179 let φ = 1.0 - config.σp0; |
| 181 let a = 1.0 - σ_p * l_z; |
180 let a = 1.0 - σ_p * l_z; |
| 182 let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); |
181 let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); |
| 183 let ψ = 1.0 - τ * l; |
182 let ψ = 1.0 - τ * l; |
| 184 let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a; |
183 let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a; |
| 185 assert!(β < 1.0); |
184 ensure!(β < 1.0); |
| 186 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: |
185 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: |
| 187 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); |
186 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); |
| 188 // The factor two in the manuscript disappears due to the definition of 𝚹 being |
187 // The factor two in the manuscript disappears due to the definition of 𝚹 being |
| 189 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. |
188 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. |
| 190 let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); |
189 let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); |
| 191 let transport_lip = maybe_transport_lip.unwrap(); |
190 let transport_lip = maybe_transport_lip?; |
| 192 let calculate_θ = |ℓ_F, max_transport| { |
191 let calculate_θ = |ℓ_F, max_transport| { |
| 193 let ℓ_r = transport_lip * max_transport; |
192 let ℓ_r = transport_lip * max_transport; |
| 194 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) |
193 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) |
| 195 }; |
194 }; |
| 196 let mut θ_or_adaptive = match maybe_ℓ_F0 { |
195 let mut θ_or_adaptive = match maybe_ℓ_F { |
| 197 // We assume that the residual is decreasing. |
196 // We assume that the residual is decreasing. |
| 198 Some(ℓ_F0) => TransportStepLength::AdaptiveMax { |
197 Ok(ℓ_F) => TransportStepLength::AdaptiveMax { |
| 199 l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual |
198 l: ℓ_F, // TODO: could estimate computing the real reesidual |
| 200 max_transport: 0.0, |
199 max_transport: 0.0, |
| 201 g: calculate_θ, |
200 g: calculate_θ, |
| 202 }, |
201 }, |
| 203 None => TransportStepLength::FullyAdaptive { |
202 Err(_) => { |
| 204 l: F::EPSILON, |
203 TransportStepLength::FullyAdaptive { |
| 205 max_transport: 0.0, |
204 l: F::EPSILON, max_transport: 0.0, g: calculate_θ |
| 206 g: calculate_θ, |
205 } |
| 207 }, |
206 } |
| 208 }; |
207 }; |
| 209 // Acceleration is not currently supported |
208 // Acceleration is not currently supported |
| 210 // let γ = dataterm.factor_of_strong_convexity(); |
209 // let γ = dataterm.factor_of_strong_convexity(); |
| 211 let ω = 1.0; |
210 let ω = 1.0; |
| 212 |
211 |
| 229 ..stats |
228 ..stats |
| 230 }; |
229 }; |
| 231 let mut stats = IterInfo::new(); |
230 let mut stats = IterInfo::new(); |
| 232 |
231 |
| 233 // Run the algorithm |
232 // Run the algorithm |
| 234 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { |
233 for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) { |
| 235 // Calculate initial transport |
234 // Calculate initial transport |
| 236 let Pair(v, _) = opA.preadjoint().apply(&residual); |
235 let Pair(v, _) = f.differential(Pair(&μ, &z)); |
| 237 //opKμ.preadjoint().apply_add(&mut v, y); |
236 //opKμ.preadjoint().apply_add(&mut v, y); |
| 238 // We want to proceed as in Example 4.12 but with v and v̆ as in §5. |
237 // We want to proceed as in Example 4.12 but with v and v̆ as in §5. |
| 239 // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have |
238 // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have |
| 240 // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, |
239 // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, |
| 241 // where A_ν^* becomes a multiplier. |
240 // where A_ν^* becomes a multiplier. |
| 242 // This is much easier with K_μ = 0, which is the only reason why are enforcing it. |
241 // This is much easier with K_μ = 0, which is the only reason why are enforcing it. |
| 243 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. |
242 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. |
| 244 |
243 |
| |
244 //dbg!(&μ); |
| |
245 |
| 245 let (μ_base_masses, mut μ_base_minus_γ0) = |
246 let (μ_base_masses, mut μ_base_minus_γ0) = |
| 246 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); |
247 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); |
| 247 |
248 |
| 248 // Solve finite-dimensional subproblem several times until the dual variable for the |
249 // Solve finite-dimensional subproblem several times until the dual variable for the |
| 249 // regularisation term conforms to the assumptions made for the transport above. |
250 // regularisation term conforms to the assumptions made for the transport above. |
| 250 let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { |
251 let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { |
| 251 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
252 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
| 252 let residual_μ̆ = |
253 // let residual_μ̆ = |
| 253 calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); |
254 // calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); |
| 254 let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); |
255 // let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); |
| |
256 // TODO: might be able to optimise the measure sum working as calculate_residual2 above. |
| |
257 let Pair(mut τv̆, τz̆) = f.differential(Pair(&γ1 + &μ_base_minus_γ0, &z)) * τ; |
| 255 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); |
258 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); |
| 256 |
259 |
| 257 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
260 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
| 258 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
261 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
| 259 &mut μ, |
262 &mut μ, |
| 334 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
337 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
| 335 opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
338 opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
| 336 y = starH.prox(σ_d, y); |
339 y = starH.prox(σ_d, y); |
| 337 z = z_new; |
340 z = z_new; |
| 338 |
341 |
| 339 // Update residual |
|
| 340 residual = calculate_residual(Pair(&μ, &z), opA, b); |
|
| 341 |
|
| 342 // Update step length parameters |
342 // Update step length parameters |
| 343 // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
343 // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
| 344 |
344 |
| 345 // Give statistics if requested |
345 // Give statistics if requested |
| 346 let iter = state.iteration(); |
346 let iter = state.iteration(); |
| 347 stats.this_iters += 1; |
347 stats.this_iters += 1; |
| 348 |
348 |
| 349 state.if_verbose(|| { |
349 state.if_verbose(|| { |
| 350 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); |
350 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); |
| 351 full_stats( |
351 full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 352 &residual, |
|
| 353 &μ, |
|
| 354 &z, |
|
| 355 ε, |
|
| 356 std::mem::replace(&mut stats, IterInfo::new()), |
|
| 357 ) |
|
| 358 }); |
352 }); |
| 359 |
353 |
| 360 // Update main tolerance for next iteration |
354 // Update main tolerance for next iteration |
| 361 ε = tolerance.update(ε, iter); |
355 ε = tolerance.update(ε, iter); |
| 362 } |
356 } |
| 363 |
357 |
| 364 let fit = |μ̃: &RNDM<F, N>| { |
358 let fit = |μ̃: &RNDM<N, F>| { |
| 365 (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() |
359 f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/ |
| 366 //+ fnR.apply(z) + reg.apply(μ) |
|
| 367 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) |
360 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) |
| 368 }; |
361 }; |
| 369 |
362 |
| 370 μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); |
363 μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); |
| 371 μ.prune(); |
364 μ.prune(); |
| 372 Pair(μ, z) |
365 Ok(Pair(μ, z)) |
| 373 } |
366 } |
| |
367 |
| |
368 /// Iteratively solve the pointsource localisation with an additional variable |
| |
369 /// using sliding forward-backward splitting. |
| |
370 /// |
| |
371 /// The implementation uses [`pointsource_sliding_pdps_pair`] with appropriate dummy |
| |
372 /// variables, operators, and functions. |
| |
373 #[replace_float_literals(F::cast_from(literal))] |
| |
374 pub fn pointsource_sliding_fb_pair<F, I, S, Dat, Reg, P, Z, R, Plot, const N: usize>( |
| |
375 f: &Dat, |
| |
376 reg: &Reg, |
| |
377 prox_penalty: &P, |
| |
378 config: &SlidingFBConfig<F>, |
| |
379 iterator: I, |
| |
380 plotter: Plot, |
| |
381 (μ0, z): (Option<RNDM<N, F>>, Z), |
| |
382 //opKμ : KOpM, |
| |
383 fnR: &R, |
| |
384 ) -> DynResult<MeasureZ<F, Z, N>> |
| |
385 where |
| |
386 F: Float + ToNalgebraRealField, |
| |
387 I: AlgIteratorFactory<IterInfo<F>>, |
| |
388 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>> |
| |
389 + BoundedCurvature<F>, |
| |
390 S: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| |
391 RNDM<N, F>: SpikeMerging<F>, |
| |
392 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| |
393 P: ProxPenalty<Loc<N, F>, S, Reg, F>, |
| |
394 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, |
| |
395 Z: ClosedEuclidean<F> + AXPY + Clone, |
| |
396 for<'b> &'b Z: Instance<Z>, |
| |
397 R: Prox<Z, Codomain = F>, |
| |
398 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, |
| |
399 // We should not need to explicitly require this: |
| |
400 for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>, |
| |
401 // Loc<0, F>: StaticEuclidean<Field = F, PrincipalE = Loc<0, F>> |
| |
402 // + Instance<Loc<0, F>> |
| |
403 // + VectorSpace<Field = F>, |
| |
404 { |
| |
405 let opKz: ZeroOp<Z, Loc<0, F>, _, _, F> = |
| |
406 ZeroOp::new_dualisable(StaticEuclideanOriginGenerator, z.dual_origin()); |
| |
407 let fnH = Zero::new(); |
| |
408 // Convert config. We don't implement From (that could be done with the o2o crate), as σd0 |
| |
409 // needs to be chosen in a general case; for the problem of this fucntion, anything is valid. |
| |
410 let &SlidingFBConfig { τ0, σp0, insertion, transport, guess } = config; |
| |
411 let pdps_config = SlidingPDPSConfig { τ0, σp0, insertion, transport, guess, σd0: 0.0 }; |
| |
412 |
| |
413 pointsource_sliding_pdps_pair( |
| |
414 f, |
| |
415 reg, |
| |
416 prox_penalty, |
| |
417 &pdps_config, |
| |
418 iterator, |
| |
419 plotter, |
| |
420 (μ0, z, Loc([])), |
| |
421 &opKz, |
| |
422 fnR, |
| |
423 &fnH, |
| |
424 ) |
| |
425 } |