src/forward_pdps.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 35
b087e3eab191
child 39
6316d68b58af
equal deleted inserted replaced
36:fb911f72e698 37:c5d8bd1a7728
6 use numeric_literals::replace_float_literals; 6 use numeric_literals::replace_float_literals;
7 use serde::{Serialize, Deserialize}; 7 use serde::{Serialize, Deserialize};
8 8
9 use alg_tools::iterate::AlgIteratorFactory; 9 use alg_tools::iterate::AlgIteratorFactory;
10 use alg_tools::euclidean::Euclidean; 10 use alg_tools::euclidean::Euclidean;
11 use alg_tools::sets::Cube; 11 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
12 use alg_tools::loc::Loc;
13 use alg_tools::mapping::{Mapping, Instance};
14 use alg_tools::norms::Norm; 12 use alg_tools::norms::Norm;
15 use alg_tools::direct_product::Pair; 13 use alg_tools::direct_product::Pair;
16 use alg_tools::bisection_tree::{
17 BTFN,
18 PreBTFN,
19 Bounds,
20 BTNodeLookup,
21 BTNode,
22 BTSearch,
23 P2Minimise,
24 SupportGenerator,
25 LocalAnalysis,
26 //Bounded,
27 };
28 use alg_tools::mapping::RealMapping;
29 use alg_tools::nalgebra_support::ToNalgebraRealField; 14 use alg_tools::nalgebra_support::ToNalgebraRealField;
30 use alg_tools::linops::{ 15 use alg_tools::linops::{
31 BoundedLinear, AXPY, GEMV, Adjointable, IdOp, 16 BoundedLinear, AXPY, GEMV, Adjointable, IdOp,
32 }; 17 };
33 use alg_tools::convex::{Conjugable, Prox}; 18 use alg_tools::convex::{Conjugable, Prox};
34 use alg_tools::norms::{L2, Linfinity, PairNorm}; 19 use alg_tools::norms::{L2, PairNorm};
35 20
36 use crate::types::*; 21 use crate::types::*;
37 use crate::measures::{DiscreteMeasure, Radon, RNDM}; 22 use crate::measures::{DiscreteMeasure, Radon, RNDM};
38 use crate::measures::merging::SpikeMerging; 23 use crate::measures::merging::SpikeMerging;
39 use crate::forward_model::{ 24 use crate::forward_model::{
40 ForwardModel, 25 ForwardModel,
41 AdjointProductPairBoundedBy, 26 AdjointProductPairBoundedBy,
42 }; 27 };
43 use crate::seminorms::DiscreteMeasureOp;
44 use crate::plot::{ 28 use crate::plot::{
45 SeqPlotter, 29 SeqPlotter,
46 Plotting, 30 Plotting,
47 PlotLookup 31 PlotLookup
48 }; 32 };
81 65
82 /// Iteratively solve the pointsource localisation with an additional variable 66 /// Iteratively solve the pointsource localisation with an additional variable
83 /// using primal-dual proximal splitting with a forward step. 67 /// using primal-dual proximal splitting with a forward step.
84 #[replace_float_literals(F::cast_from(literal))] 68 #[replace_float_literals(F::cast_from(literal))]
85 pub fn pointsource_forward_pdps_pair< 69 pub fn pointsource_forward_pdps_pair<
86 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize 70 F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize
87 >( 71 >(
88 opA : &'a A, 72 opA : &A,
89 b : &A::Observable, 73 b : &A::Observable,
90 reg : Reg, 74 reg : Reg,
91 op𝒟 : &'a 𝒟, 75 prox_penalty : &P,
92 config : &ForwardPDPSConfig<F>, 76 config : &ForwardPDPSConfig<F>,
93 iterator : I, 77 iterator : I,
94 mut plotter : SeqPlotter<F, N>, 78 mut plotter : SeqPlotter<F, N>,
95 //opKμ : KOpM, 79 //opKμ : KOpM,
96 opKz : &KOpZ, 80 opKz : &KOpZ,
100 mut y : Y, 84 mut y : Y,
101 ) -> MeasureZ<F, Z, N> 85 ) -> MeasureZ<F, Z, N>
102 where 86 where
103 F : Float + ToNalgebraRealField, 87 F : Float + ToNalgebraRealField,
104 I : AlgIteratorFactory<IterInfo<F, N>>, 88 I : AlgIteratorFactory<IterInfo<F, N>>,
105 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
106 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
107 A : ForwardModel< 89 A : ForwardModel<
108 MeasureZ<F, Z, N>, 90 MeasureZ<F, Z, N>,
109 F, 91 F,
110 PairNorm<Radon, L2, L2>, 92 PairNorm<Radon, L2, L2>,
111 PreadjointCodomain = Pair<BTFN<F, GA, BTA, N>, Z>, 93 PreadjointCodomain = Pair<S, Z>,
112 > 94 >
113 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, 𝒟, IdOp<Z>, FloatType=F>, 95 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>,
114 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 96 S: DifferentiableRealMapping<F, N>,
115 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 97 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
116 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>,
117 Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
118 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
119 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
120 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
121 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
122 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
123 PlotLookup : Plotting<N>, 98 PlotLookup : Plotting<N>,
124 RNDM<F, N> : SpikeMerging<F>, 99 RNDM<F, N> : SpikeMerging<F>,
125 Reg : RegTerm<F, N>, 100 Reg : RegTerm<F, N>,
101 P : ProxPenalty<F, S, Reg, N>,
126 KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y> 102 KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y>
127 + GEMV<F, Z> 103 + GEMV<F, Z>
128 + Adjointable<Z, Y, AdjointCodomain = Z>, 104 + Adjointable<Z, Y, AdjointCodomain = Z>,
129 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, 105 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>,
130 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, 106 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd,
148 // Initialise iterates 124 // Initialise iterates
149 let mut μ = DiscreteMeasure::new(); 125 let mut μ = DiscreteMeasure::new();
150 let mut residual = calculate_residual(Pair(&μ, &z), opA, b); 126 let mut residual = calculate_residual(Pair(&μ, &z), opA, b);
151 127
152 // Set up parameters 128 // Set up parameters
153 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); 129 let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt();
154 let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt();
155 let nKz = opKz.opnorm_bound(L2, L2); 130 let nKz = opKz.opnorm_bound(L2, L2);
156 let opIdZ = IdOp::new(); 131 let opIdZ = IdOp::new();
157 let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); 132 let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap();
158 // We need to satisfy 133 // We need to satisfy
159 // 134 //
160 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 135 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1
161 // ^^^^^^^^^^^^^^^^^^^^^^^^^ 136 // ^^^^^^^^^^^^^^^^^^^^^^^^^
162 // with 1 > σ_p L_z and 1 > τ L. 137 // with 1 > σ_p L_z and 1 > τ L.
194 let mut stats = IterInfo::new(); 169 let mut stats = IterInfo::new();
195 170
196 // Run the algorithm 171 // Run the algorithm
197 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { 172 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) {
198 // Calculate initial transport 173 // Calculate initial transport
199 let Pair(τv, τz) = opA.preadjoint().apply(residual * τ); 174 let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ);
200 let z_base = z.clone(); 175 let z_base = z.clone();
201 let μ_base = μ.clone(); 176 let μ_base = μ.clone();
202 177
203 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 178 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
204 let (d, _within_tolerances) = insert_and_reweigh( 179 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
205 &mut μ, &τv, &μ_base, None, 180 &mut μ, &mut τv, &μ_base, None,
206 op𝒟, op𝒟norm,
207 τ, ε, &config.insertion, 181 τ, ε, &config.insertion,
208 &reg, &state, &mut stats, 182 &reg, &state, &mut stats,
209 ); 183 );
210 184
211 // // Merge spikes. 185 // // Merge spikes.
246 // Give statistics if requested 220 // Give statistics if requested
247 let iter = state.iteration(); 221 let iter = state.iteration();
248 stats.this_iters += 1; 222 stats.this_iters += 1;
249 223
250 state.if_verbose(|| { 224 state.if_verbose(|| {
251 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); 225 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
252 full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) 226 full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new()))
253 }); 227 });
254 228
255 // Update main tolerance for next iteration 229 // Update main tolerance for next iteration
256 ε = tolerance.update(ε, iter); 230 ε = tolerance.update(ε, iter);

mercurial