src/sliding_pdps.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 36
fb911f72e698
child 39
6316d68b58af
equal deleted inserted replaced
36:fb911f72e698 37:c5d8bd1a7728
9 //use nalgebra::{DVector, DMatrix}; 9 //use nalgebra::{DVector, DMatrix};
10 use std::iter::Iterator; 10 use std::iter::Iterator;
11 11
12 use alg_tools::iterate::AlgIteratorFactory; 12 use alg_tools::iterate::AlgIteratorFactory;
13 use alg_tools::euclidean::Euclidean; 13 use alg_tools::euclidean::Euclidean;
14 use alg_tools::sets::Cube;
15 use alg_tools::loc::Loc;
16 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; 14 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
17 use alg_tools::norms::Norm; 15 use alg_tools::norms::Norm;
18 use alg_tools::direct_product::Pair; 16 use alg_tools::direct_product::Pair;
19 use alg_tools::bisection_tree::{
20 BTFN,
21 PreBTFN,
22 Bounds,
23 BTNodeLookup,
24 BTNode,
25 BTSearch,
26 P2Minimise,
27 SupportGenerator,
28 LocalAnalysis,
29 //Bounded,
30 };
31 use alg_tools::mapping::RealMapping;
32 use alg_tools::nalgebra_support::ToNalgebraRealField; 17 use alg_tools::nalgebra_support::ToNalgebraRealField;
33 use alg_tools::linops::{ 18 use alg_tools::linops::{
34 BoundedLinear, AXPY, GEMV, Adjointable, IdOp, 19 BoundedLinear, AXPY, GEMV, Adjointable, IdOp,
35 }; 20 };
36 use alg_tools::convex::{Conjugable, Prox}; 21 use alg_tools::convex::{Conjugable, Prox};
37 use alg_tools::norms::{L2, Linfinity, PairNorm}; 22 use alg_tools::norms::{L2, PairNorm};
38 23
39 use crate::types::*; 24 use crate::types::*;
40 use crate::measures::{DiscreteMeasure, Radon, RNDM}; 25 use crate::measures::{DiscreteMeasure, Radon, RNDM};
41 use crate::measures::merging::SpikeMerging; 26 use crate::measures::merging::SpikeMerging;
42 use crate::forward_model::{ 27 use crate::forward_model::{
43 ForwardModel, 28 ForwardModel,
44 AdjointProductPairBoundedBy, 29 AdjointProductPairBoundedBy,
45 LipschitzValues, 30 LipschitzValues,
46 }; 31 };
47 // use crate::transport::TransportLipschitz; 32 // use crate::transport::TransportLipschitz;
48 use crate::seminorms::DiscreteMeasureOp;
49 //use crate::tolerance::Tolerance; 33 //use crate::tolerance::Tolerance;
50 use crate::plot::{ 34 use crate::plot::{
51 SeqPlotter, 35 SeqPlotter,
52 Plotting, 36 Plotting,
53 PlotLookup 37 PlotLookup
99 /// using sliding primal-dual proximal splitting 83 /// using sliding primal-dual proximal splitting
100 /// 84 ///
101 /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. 85 /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`].
102 #[replace_float_literals(F::cast_from(literal))] 86 #[replace_float_literals(F::cast_from(literal))]
103 pub fn pointsource_sliding_pdps_pair< 87 pub fn pointsource_sliding_pdps_pair<
104 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize 88 F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize
105 >( 89 >(
106 opA : &'a A, 90 opA : &A,
107 b : &A::Observable, 91 b : &A::Observable,
108 reg : Reg, 92 reg : Reg,
109 op𝒟 : &'a 𝒟, 93 prox_penalty : &P,
110 config : &SlidingPDPSConfig<F>, 94 config : &SlidingPDPSConfig<F>,
111 iterator : I, 95 iterator : I,
112 mut plotter : SeqPlotter<F, N>, 96 mut plotter : SeqPlotter<F, N>,
113 //opKμ : KOpM, 97 //opKμ : KOpM,
114 opKz : &KOpZ, 98 opKz : &KOpZ,
118 mut y : Y, 102 mut y : Y,
119 ) -> MeasureZ<F, Z, N> 103 ) -> MeasureZ<F, Z, N>
120 where 104 where
121 F : Float + ToNalgebraRealField, 105 F : Float + ToNalgebraRealField,
122 I : AlgIteratorFactory<IterInfo<F, N>>, 106 I : AlgIteratorFactory<IterInfo<F, N>>,
123 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
124 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
125 BTFN<F, GA, BTA, N> : DifferentiableRealMapping<F, N>,
126 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
127 A : ForwardModel< 107 A : ForwardModel<
128 MeasureZ<F, Z, N>, 108 MeasureZ<F, Z, N>,
129 F, 109 F,
130 PairNorm<Radon, L2, L2>, 110 PairNorm<Radon, L2, L2>,
131 PreadjointCodomain = Pair<BTFN<F, GA, BTA, N>, Z>, 111 PreadjointCodomain = Pair<S, Z>,
132 > 112 >
133 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, 𝒟, IdOp<Z>, FloatType=F>, 113 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>,
134 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 114 S : DifferentiableRealMapping<F, N>,
135 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 115 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
136 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, 116 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
137 Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
138 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
139 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
140 + DifferentiableRealMapping<F, N>,
141 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
142 //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>,
143 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
144 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
145 PlotLookup : Plotting<N>, 117 PlotLookup : Plotting<N>,
146 RNDM<F, N> : SpikeMerging<F>, 118 RNDM<F, N> : SpikeMerging<F>,
147 Reg : SlidingRegTerm<F, N>, 119 Reg : SlidingRegTerm<F, N>,
120 P : ProxPenalty<F, S, Reg, N>,
148 // KOpM : Linear<RNDM<F, N>, Codomain=Y> 121 // KOpM : Linear<RNDM<F, N>, Codomain=Y>
149 // + GEMV<F, RNDM<F, N>> 122 // + GEMV<F, RNDM<F, N>>
150 // + Preadjointable< 123 // + Preadjointable<
151 // RNDM<F, N>, Y, 124 // RNDM<F, N>, Y,
152 // PreadjointCodomain = BTFN<F, GA, BTA, N>, 125 // PreadjointCodomain = S,
153 // > 126 // >
154 // + TransportLipschitz<L2Squared, FloatType=F> 127 // + TransportLipschitz<L2Squared, FloatType=F>
155 // + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, 128 // + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
156 // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>, 129 // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>,
157 // Since Z is Hilbert, we may just as well use adjoints for K_z. 130 // Since Z is Hilbert, we may just as well use adjoints for K_z.
183 let mut γ1 = DiscreteMeasure::new(); 156 let mut γ1 = DiscreteMeasure::new();
184 let mut residual = calculate_residual(Pair(&μ, &z), opA, b); 157 let mut residual = calculate_residual(Pair(&μ, &z), opA, b);
185 let zero_z = z.similar_origin(); 158 let zero_z = z.similar_origin();
186 159
187 // Set up parameters 160 // Set up parameters
188 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
189 // TODO: maybe this PairNorm doesn't make sense here? 161 // TODO: maybe this PairNorm doesn't make sense here?
190 let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); 162 let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2);
191 let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); 163 let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared);
192 let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); 164 let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt();
193 let nKz = opKz.opnorm_bound(L2, L2); 165 let nKz = opKz.opnorm_bound(L2, L2);
194 let ℓ = 0.0; 166 let ℓ = 0.0;
195 let opIdZ = IdOp::new(); 167 let opIdZ = IdOp::new();
196 let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); 168 let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap();
197 // We need to satisfy 169 // We need to satisfy
198 // 170 //
199 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 171 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1
200 // ^^^^^^^^^^^^^^^^^^^^^^^^^ 172 // ^^^^^^^^^^^^^^^^^^^^^^^^^
201 // with 1 > σ_p L_z and 1 > τ L. 173 // with 1 > σ_p L_z and 1 > τ L.
276 v, &config.transport, 248 v, &config.transport,
277 ); 249 );
278 250
279 // Solve finite-dimensional subproblem several times until the dual variable for the 251 // Solve finite-dimensional subproblem several times until the dual variable for the
280 // regularisation term conforms to the assumptions made for the transport above. 252 // regularisation term conforms to the assumptions made for the transport above.
281 let (d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { 253 let (maybe_d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop {
282 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) 254 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
283 let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), 255 let residual_μ̆ = calculate_residual2(Pair(&γ1, &z),
284 Pair(&μ_base_minus_γ0, &zero_z), 256 Pair(&μ_base_minus_γ0, &zero_z),
285 opA, b); 257 opA, b);
286 let Pair(τv̆, τz) = opA.preadjoint().apply(residual_μ̆ * τ); 258 let mut τv̆z = opA.preadjoint().apply(residual_μ̆ * τ);
287 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); 259 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0);
288 260
289 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 261 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
290 let (d, within_tolerances) = insert_and_reweigh( 262 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
291 &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), 263 &mut μ, &mut τv̆z.0, &γ1, Some(&μ_base_minus_γ0),
292 op𝒟, op𝒟norm,
293 τ, ε, &config.insertion, 264 τ, ε, &config.insertion,
294 &reg, &state, &mut stats, 265 &reg, &state, &mut stats,
295 ); 266 );
296 267
297 // A posteriori transport adaptation. 268 // A posteriori transport adaptation.
298 // TODO: this does not properly treat v^{k+1} - v̆^k that depends on z^{k+1}! 269 // TODO: this does not properly treat v^{k+1} - v̆^k that depends on z^{k+1}!
299 if aposteriori_transport( 270 if aposteriori_transport(
300 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, 271 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
301 ε, &config.transport 272 ε, &config.transport
302 ) { 273 ) {
303 break 'adapt_transport (d, within_tolerances, Pair(τv̆, τz)) 274 break 'adapt_transport (maybe_d, within_tolerances, τv̆z)
304 } 275 }
305 }; 276 };
306 277
307 stats.untransported_fraction = Some({ 278 stats.untransported_fraction = Some({
308 assert_eq!(μ_base_masses.len(), γ1.len()); 279 assert_eq!(μ_base_masses.len(), γ1.len());
362 // Give statistics if requested 333 // Give statistics if requested
363 let iter = state.iteration(); 334 let iter = state.iteration();
364 stats.this_iters += 1; 335 stats.this_iters += 1;
365 336
366 state.if_verbose(|| { 337 state.if_verbose(|| {
367 plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); 338 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
368 full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) 339 full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new()))
369 }); 340 });
370 341
371 // Update main tolerance for next iteration 342 // Update main tolerance for next iteration
372 ε = tolerance.update(ε, iter); 343 ε = tolerance.update(ε, iter);

mercurial