src/radon_fb.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
equal deleted inserted replaced
34:efa60bc4f743 35:b087e3eab191
9 use colored::Colorize; 9 use colored::Colorize;
10 use nalgebra::DVector; 10 use nalgebra::DVector;
11 11
12 use alg_tools::iterate::{ 12 use alg_tools::iterate::{
13 AlgIteratorFactory, 13 AlgIteratorFactory,
14 AlgIteratorState, 14 AlgIteratorIteration,
15 AlgIterator
15 }; 16 };
16 use alg_tools::euclidean::Euclidean; 17 use alg_tools::euclidean::Euclidean;
17 use alg_tools::linops::Apply; 18 use alg_tools::linops::Mapping;
18 use alg_tools::sets::Cube; 19 use alg_tools::sets::Cube;
19 use alg_tools::loc::Loc; 20 use alg_tools::loc::Loc;
20 use alg_tools::bisection_tree::{ 21 use alg_tools::bisection_tree::{
21 BTFN, 22 BTFN,
22 Bounds, 23 Bounds,
27 SupportGenerator, 28 SupportGenerator,
28 LocalAnalysis, 29 LocalAnalysis,
29 }; 30 };
30 use alg_tools::mapping::RealMapping; 31 use alg_tools::mapping::RealMapping;
31 use alg_tools::nalgebra_support::ToNalgebraRealField; 32 use alg_tools::nalgebra_support::ToNalgebraRealField;
33 use alg_tools::norms::L2;
32 34
33 use crate::types::*; 35 use crate::types::*;
34 use crate::measures::{ 36 use crate::measures::{
37 RNDM,
35 DiscreteMeasure, 38 DiscreteMeasure,
36 DeltaMeasure, 39 DeltaMeasure,
40 Radon,
37 }; 41 };
38 use crate::measures::merging::{ 42 use crate::measures::merging::{
39 SpikeMergingMethod, 43 SpikeMergingMethod,
40 SpikeMerging, 44 SpikeMerging,
41 }; 45 };
52 DataTerm, 56 DataTerm,
53 }; 57 };
54 58
55 use crate::fb::{ 59 use crate::fb::{
56 FBGenericConfig, 60 FBGenericConfig,
57 postprocess 61 postprocess,
58 }; 62 prune_with_stats
59 63 };
60 /// Settings for [`pointsource_fb_reg`]. 64
65 /// Settings for [`pointsource_radon_fb_reg`].
61 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 66 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
62 #[serde(default)] 67 #[serde(default)]
63 pub struct RadonFBConfig<F : Float> { 68 pub struct RadonFBConfig<F : Float> {
64 /// Step length scaling 69 /// Step length scaling
65 pub τ0 : F, 70 pub τ0 : F,
77 } 82 }
78 } 83 }
79 84
80 #[replace_float_literals(F::cast_from(literal))] 85 #[replace_float_literals(F::cast_from(literal))]
81 pub(crate) fn insert_and_reweigh< 86 pub(crate) fn insert_and_reweigh<
82 'a, F, GA, BTA, S, Reg, State, const N : usize 87 'a, F, GA, BTA, S, Reg, I, const N : usize
83 >( 88 >(
84 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 89 μ : &mut RNDM<F, N>,
85 minus_τv : &mut BTFN<F, GA, BTA, N>, 90 τv : &mut BTFN<F, GA, BTA, N>,
86 μ_base : &mut DiscreteMeasure<Loc<F, N>, F>, 91 μ_base : &mut RNDM<F, N>,
87 _ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>, 92 //_ν_delta: Option<&RNDM<F, N>>,
88 τ : F, 93 τ : F,
89 ε : F, 94 ε : F,
90 config : &FBGenericConfig<F>, 95 config : &FBGenericConfig<F>,
91 reg : &Reg, 96 reg : &Reg,
92 _state : &State, 97 _state : &AlgIteratorIteration<I>,
93 stats : &mut IterInfo<F, N>, 98 stats : &mut IterInfo<F, N>,
94 ) 99 )
95 where F : Float + ToNalgebraRealField, 100 where F : Float + ToNalgebraRealField,
96 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 101 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
97 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 102 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
98 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 103 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
99 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 104 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
100 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 105 RNDM<F, N> : SpikeMerging<F>,
101 Reg : RegTerm<F, N>, 106 Reg : RegTerm<F, N>,
102 State : AlgIteratorState { 107 I : AlgIterator {
103 108
104 'i_and_w: for i in 0..=1 { 109 'i_and_w: for i in 0..=1 {
105 // Optimise weights 110 // Optimise weights
106 if μ.len() > 0 { 111 if μ.len() > 0 {
107 // Form finite-dimensional subproblem. The subproblem references to the original μ^k 112 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
108 // from the beginning of the iteration are all contained in the immutable c and g. 113 // from the beginning of the iteration are all contained in the immutable c and g.
114 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
115 // problems have not yet been updated to sign change.
109 let g̃ = DVector::from_iterator(μ.len(), 116 let g̃ = DVector::from_iterator(μ.len(),
110 μ.iter_locations() 117 μ.iter_locations()
111 .map(|ζ| F::to_nalgebra_mixed(minus_τv.apply(ζ)))); 118 .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ))));
112 let mut x = μ.masses_dvector(); 119 let mut x = μ.masses_dvector();
113 let y = μ_base.masses_dvector(); 120 let y = μ_base.masses_dvector();
114 121
115 // Solve finite-dimensional subproblem. 122 // Solve finite-dimensional subproblem.
116 stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config); 123 stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config);
120 } 127 }
121 128
122 if i>0 { 129 if i>0 {
123 // Simple debugging test to see if more inserts would be needed. Doesn't seem so. 130 // Simple debugging test to see if more inserts would be needed. Doesn't seem so.
124 //let n = μ.dist_matching(μ_base); 131 //let n = μ.dist_matching(μ_base);
125 //println!("{:?}", reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n)); 132 //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n));
126 break 'i_and_w 133 break 'i_and_w
127 } 134 }
128 135
129 // Calculate ‖μ - μ_base‖_ℳ 136 // Calculate ‖μ - μ_base‖_ℳ
130 let n = μ.dist_matching(μ_base); 137 let n = μ.dist_matching(μ_base);
131 138
132 // Find a spike to insert, if needed. 139 // Find a spike to insert, if needed.
133 // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, 140 // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ,
134 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. 141 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation.
135 match reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n) { 142 match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) {
136 None => { break 'i_and_w }, 143 None => { break 'i_and_w },
137 Some((ξ, _v_ξ, _in_bounds)) => { 144 Some((ξ, _v_ξ, _in_bounds)) => {
138 // Weight is found out by running the finite-dimensional optimisation algorithm 145 // Weight is found out by running the finite-dimensional optimisation algorithm
139 // above 146 // above
140 *μ += DeltaMeasure { x : ξ, α : 0.0 }; 147 *μ += DeltaMeasure { x : ξ, α : 0.0 };
141 *μ_base += DeltaMeasure { x : ξ, α : 0.0 }; 148 *μ_base += DeltaMeasure { x : ξ, α : 0.0 };
149 stats.inserted += 1;
142 } 150 }
143 }; 151 };
144 } 152 }
145 } 153 }
146 154
147 #[replace_float_literals(F::cast_from(literal))]
148 pub(crate) fn prune_and_maybe_simple_merge<
149 'a, F, GA, BTA, S, Reg, State, const N : usize
150 >(
151 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
152 minus_τv : &mut BTFN<F, GA, BTA, N>,
153 μ_base : &DiscreteMeasure<Loc<F, N>, F>,
154 τ : F,
155 ε : F,
156 config : &FBGenericConfig<F>,
157 reg : &Reg,
158 state : &State,
159 stats : &mut IterInfo<F, N>,
160 )
161 where F : Float + ToNalgebraRealField,
162 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
163 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
164 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
165 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
166 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
167 Reg : RegTerm<F, N>,
168 State : AlgIteratorState {
169
170 assert!(μ_base.len() <= μ.len());
171
172 if state.iteration() % config.merge_every == 0 {
173 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
174 // Important: μ_candidate's new points are afterwards,
175 // and do not conflict with μ_base.
176 // TODO: could simplify to requiring μ_base instead of μ_radon.
177 // but may complicate with sliding base's exgtra points that need to be
178 // after μ_candidate's extra points.
179 // TODO: doesn't seem to work, maybe need to merge μ_base as well?
180 // Although that doesn't seem to make sense.
181 let μ_radon = μ_candidate.sub_matching(μ_base);
182 reg.verify_merge_candidate_radonsq(minus_τv, μ_candidate, τ, ε, &config, &μ_radon)
183 //let n = μ_candidate.dist_matching(μ_base);
184 //reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n).is_none()
185 });
186 }
187
188 let n_before_prune = μ.len();
189 μ.prune();
190 debug_assert!(μ.len() <= n_before_prune);
191 stats.pruned += n_before_prune - μ.len();
192 }
193
194 155
195 /// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting. 156 /// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting.
196 /// 157 ///
197 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the 158 /// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the
198 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. 159 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
199 /// Finally, the `iterator` is an outer loop verbosity and iteration count control 160 /// Finally, the `iterator` is an outer loop verbosity and iteration count control
200 /// as documented in [`alg_tools::iterate`]. 161 /// as documented in [`alg_tools::iterate`].
201 /// 162 ///
202 /// For details on the mathematical formulation, see the [module level](self) documentation. 163 /// For details on the mathematical formulation, see the [module level](self) documentation.
217 b : &A::Observable, 178 b : &A::Observable,
218 reg : Reg, 179 reg : Reg,
219 fbconfig : &RadonFBConfig<F>, 180 fbconfig : &RadonFBConfig<F>,
220 iterator : I, 181 iterator : I,
221 mut _plotter : SeqPlotter<F, N>, 182 mut _plotter : SeqPlotter<F, N>,
222 ) -> DiscreteMeasure<Loc<F, N>, F> 183 ) -> RNDM<F, N>
223 where F : Float + ToNalgebraRealField, 184 where F : Float + ToNalgebraRealField,
224 I : AlgIteratorFactory<IterInfo<F, N>>, 185 I : AlgIteratorFactory<IterInfo<F, N>>,
225 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 186 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
226 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
227 A::Observable : std::ops::MulAssign<F>,
228 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 187 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
229 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, 188 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
230 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 189 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
231 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 190 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
232 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 191 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
233 Cube<F, N>: P2Minimise<Loc<F, N>, F>, 192 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
234 PlotLookup : Plotting<N>, 193 RNDM<F, N> : SpikeMerging<F>,
235 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
236 Reg : RegTerm<F, N> { 194 Reg : RegTerm<F, N> {
237 195
238 // Set up parameters 196 // Set up parameters
239 let config = &fbconfig.insertion; 197 let config = &fbconfig.insertion;
240 // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ 198 // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ
241 // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such 199 // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such
242 // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. 200 // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L.
243 let τ = fbconfig.τ0/opA.opnorm_bound().powi(2); 201 let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2);
244 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 202 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
245 // by τ compared to the conditional gradient approach. 203 // by τ compared to the conditional gradient approach.
246 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 204 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
247 let mut ε = tolerance.initial(); 205 let mut ε = tolerance.initial();
248 206
249 // Initialise iterates 207 // Initialise iterates
250 let mut μ = DiscreteMeasure::new(); 208 let mut μ = DiscreteMeasure::new();
251 let mut residual = -b; 209 let mut residual = -b;
210
211 // Statistics
212 let full_stats = |residual : &A::Observable,
213 μ : &RNDM<F, N>,
214 ε, stats| IterInfo {
215 value : residual.norm2_squared_div2() + reg.apply(μ),
216 n_spikes : μ.len(),
217 ε,
218 // postprocessing: config.postprocessing.then(|| μ.clone()),
219 .. stats
220 };
252 let mut stats = IterInfo::new(); 221 let mut stats = IterInfo::new();
253 222
254 // Run the algorithm 223 // Run the algorithm
255 iterator.iterate(|state| { 224 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
256 // Calculate smooth part of surrogate model. 225 // Calculate smooth part of surrogate model.
257 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` 226 let mut τv = opA.preadjoint().apply(residual * τ);
258 // has no significant overhead. For some reosn Rust doesn't allow us simply moving
259 // the residual and replacing it below before the end of this closure.
260 residual *= -τ;
261 let r = std::mem::replace(&mut residual, opA.empty_observable());
262 let mut minus_τv = opA.preadjoint().apply(r);
263 227
264 // Save current base point 228 // Save current base point
265 let mut μ_base = μ.clone(); 229 let mut μ_base = μ.clone();
266 230
267 // Insert and reweigh 231 // Insert and reweigh
268 insert_and_reweigh( 232 insert_and_reweigh(
269 &mut μ, &mut minus_τv, &mut μ_base, None, 233 &mut μ, &mut τv, &mut μ_base, //None,
270 τ, ε, 234 τ, ε,
271 config, &reg, state, &mut stats 235 config, &reg, &state, &mut stats
272 ); 236 );
273 237
274 // Prune and possibly merge spikes 238 // Prune and possibly merge spikes
275 prune_and_maybe_simple_merge( 239 assert!(μ_base.len() <= μ.len());
276 &mut μ, &mut minus_τv, &μ_base, 240 if config.merge_now(&state) {
277 τ, ε, 241 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
278 config, &reg, state, &mut stats 242 // Important: μ_candidate's new points are afterwards,
279 ); 243 // and do not conflict with μ_base.
244 // TODO: could simplify to requiring μ_base instead of μ_radon.
245 // but may complicate with sliding base's exgtra points that need to be
246 // after μ_candidate's extra points.
247 // TODO: doesn't seem to work, maybe need to merge μ_base as well?
248 // Although that doesn't seem to make sense.
249 let μ_radon = μ_candidate.sub_matching(&μ_base);
250 reg.verify_merge_candidate_radonsq(&mut τv, μ_candidate, τ, ε, &config, &μ_radon)
251 //let n = μ_candidate.dist_matching(μ_base);
252 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()
253 });
254 }
255 stats.pruned += prune_with_stats(&mut μ);
280 256
281 // Update residual 257 // Update residual
282 residual = calculate_residual(&μ, opA, b); 258 residual = calculate_residual(&μ, opA, b);
283 259
260 let iter = state.iteration();
261 stats.this_iters += 1;
262
263 // Give statistics if needed
264 state.if_verbose(|| {
265 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
266 });
267
284 // Update main tolerance for next iteration 268 // Update main tolerance for next iteration
285 let ε_prev = ε; 269 ε = tolerance.update(ε, iter);
286 ε = tolerance.update(ε, state.iteration()); 270 }
287 stats.this_iters += 1;
288
289 // Give function value if needed
290 state.if_verbose(|| {
291 // Plot if so requested
292 // plotter.plot_spikes(
293 // format!("iter {} end;", state.iteration()), &d,
294 // "start".to_string(), Some(&minus_τv),
295 // reg.target_bounds(τ, ε_prev), &μ,
296 // );
297 // Calculate mean inner iterations and reset relevant counters.
298 // Return the statistics
299 let res = IterInfo {
300 value : residual.norm2_squared_div2() + reg.apply(&μ),
301 n_spikes : μ.len(),
302 ε : ε_prev,
303 postprocessing: config.postprocessing.then(|| μ.clone()),
304 .. stats
305 };
306 stats = IterInfo::new();
307 res
308 })
309 });
310 271
311 postprocess(μ, config, L2Squared, opA, b) 272 postprocess(μ, config, L2Squared, opA, b)
312 } 273 }
313 274
314 /// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting. 275 /// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting.
315 /// 276 ///
316 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the 277 /// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the
317 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. 278 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
318 /// Finally, the `iterator` is an outer loop verbosity and iteration count control 279 /// Finally, the `iterator` is an outer loop verbosity and iteration count control
319 /// as documented in [`alg_tools::iterate`]. 280 /// as documented in [`alg_tools::iterate`].
320 /// 281 ///
321 /// For details on the mathematical formulation, see the [module level](self) documentation. 282 /// For details on the mathematical formulation, see the [module level](self) documentation.
335 opA : &'a A, 296 opA : &'a A,
336 b : &A::Observable, 297 b : &A::Observable,
337 reg : Reg, 298 reg : Reg,
338 fbconfig : &RadonFBConfig<F>, 299 fbconfig : &RadonFBConfig<F>,
339 iterator : I, 300 iterator : I,
340 mut _plotter : SeqPlotter<F, N>, 301 mut plotter : SeqPlotter<F, N>,
341 ) -> DiscreteMeasure<Loc<F, N>, F> 302 ) -> RNDM<F, N>
342 where F : Float + ToNalgebraRealField, 303 where F : Float + ToNalgebraRealField,
343 I : AlgIteratorFactory<IterInfo<F, N>>, 304 I : AlgIteratorFactory<IterInfo<F, N>>,
344 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 305 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
345 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
346 A::Observable : std::ops::MulAssign<F>,
347 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 306 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
348 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, 307 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
349 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 308 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
350 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 309 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
351 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 310 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
352 Cube<F, N>: P2Minimise<Loc<F, N>, F>, 311 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
353 PlotLookup : Plotting<N>, 312 PlotLookup : Plotting<N>,
354 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 313 RNDM<F, N> : SpikeMerging<F>,
355 Reg : RegTerm<F, N> { 314 Reg : RegTerm<F, N> {
356 315
357 // Set up parameters 316 // Set up parameters
358 let config = &fbconfig.insertion; 317 let config = &fbconfig.insertion;
359 // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ 318 // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ
360 // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such 319 // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such
361 // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. 320 // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L.
362 let τ = fbconfig.τ0/opA.opnorm_bound().powi(2); 321 let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2);
363 let mut λ = 1.0; 322 let mut λ = 1.0;
364 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 323 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
365 // by τ compared to the conditional gradient approach. 324 // by τ compared to the conditional gradient approach.
366 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 325 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
367 let mut ε = tolerance.initial(); 326 let mut ε = tolerance.initial();
368 327
369 // Initialise iterates 328 // Initialise iterates
370 let mut μ = DiscreteMeasure::new(); 329 let mut μ = DiscreteMeasure::new();
371 let mut μ_prev = DiscreteMeasure::new(); 330 let mut μ_prev = DiscreteMeasure::new();
372 let mut residual = -b; 331 let mut residual = -b;
332 let mut warned_merging = false;
333
334 // Statistics
335 let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo {
336 value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
337 n_spikes : ν.len(),
338 ε,
339 // postprocessing: config.postprocessing.then(|| ν.clone()),
340 .. stats
341 };
373 let mut stats = IterInfo::new(); 342 let mut stats = IterInfo::new();
374 let mut warned_merging = false;
375 343
376 // Run the algorithm 344 // Run the algorithm
377 iterator.iterate(|state| { 345 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
378 // Calculate smooth part of surrogate model. 346 // Calculate smooth part of surrogate model.
379 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` 347 let mut τv = opA.preadjoint().apply(residual * τ);
380 // has no significant overhead. For some reosn Rust doesn't allow us simply moving
381 // the residual and replacing it below before the end of this closure.
382 residual *= -τ;
383 let r = std::mem::replace(&mut residual, opA.empty_observable());
384 let mut minus_τv = opA.preadjoint().apply(r);
385 348
386 // Save current base point 349 // Save current base point
387 let mut μ_base = μ.clone(); 350 let mut μ_base = μ.clone();
388 351
389 // Insert new spikes and reweigh 352 // Insert new spikes and reweigh
390 insert_and_reweigh( 353 insert_and_reweigh(
391 &mut μ, &mut minus_τv, &mut μ_base, None, 354 &mut μ, &mut τv, &mut μ_base, //None,
392 τ, ε, 355 τ, ε,
393 config, &reg, state, &mut stats 356 config, &reg, &state, &mut stats
394 ); 357 );
395 358
396 // (Do not) merge spikes. 359 // (Do not) merge spikes.
397 if state.iteration() % config.merge_every == 0 { 360 if config.merge_now(&state) {
398 match config.merging { 361 match config.merging {
399 SpikeMergingMethod::None => { }, 362 SpikeMergingMethod::None => { },
400 _ => if !warned_merging { 363 _ => if !warned_merging {
401 let err = format!("Merging not supported for μFISTA"); 364 let err = format!("Merging not supported for μFISTA");
402 println!("{}", err.red()); 365 println!("{}", err.red());
421 debug_assert!(μ.len() <= n_before_prune); 384 debug_assert!(μ.len() <= n_before_prune);
422 stats.pruned += n_before_prune - μ.len(); 385 stats.pruned += n_before_prune - μ.len();
423 386
424 // Update residual 387 // Update residual
425 residual = calculate_residual(&μ, opA, b); 388 residual = calculate_residual(&μ, opA, b);
389
390 let iter = state.iteration();
391 stats.this_iters += 1;
392
393 // Give statistics if needed
394 state.if_verbose(|| {
395 plotter.plot_spikes(iter, Option::<&S>::None, Some(&τv), &μ_prev);
396 full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new()))
397 });
426 398
427 // Update main tolerance for next iteration 399 // Update main tolerance for next iteration
428 let ε_prev = ε; 400 ε = tolerance.update(ε, iter);
429 ε = tolerance.update(ε, state.iteration()); 401 }
430 stats.this_iters += 1;
431
432 // Give function value if needed
433 state.if_verbose(|| {
434 // Plot if so requested
435 // plotter.plot_spikes(
436 // format!("iter {} end;", state.iteration()), &d,
437 // "start".to_string(), Some(&minus_τv),
438 // reg.target_bounds(τ, ε_prev), &μ_prev,
439 // );
440 // Calculate mean inner iterations and reset relevant counters.
441 // Return the statistics
442 let res = IterInfo {
443 value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev),
444 n_spikes : μ_prev.len(),
445 ε : ε_prev,
446 postprocessing: config.postprocessing.then(|| μ_prev.clone()),
447 .. stats
448 };
449 stats = IterInfo::new();
450 res
451 })
452 });
453 402
454 postprocess(μ_prev, config, L2Squared, opA, b) 403 postprocess(μ_prev, config, L2Squared, opA, b)
455 } 404 }

mercurial