src/frank_wolfe.rs

changeset 52
f0e8704d3f0e
parent 39
6316d68b58af
equal deleted inserted replaced
31:6105b5cd8d89 52:f0e8704d3f0e
12 * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_, 12 * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_,
13 DOI: [10.1051/cocv/2011205](https://doi.org/0.1051/cocv/2011205). 13 DOI: [10.1051/cocv/2011205](https://doi.org/0.1051/cocv/2011205).
14 */ 14 */
15 15
16 use numeric_literals::replace_float_literals; 16 use numeric_literals::replace_float_literals;
17 use nalgebra::{DMatrix, DVector};
17 use serde::{Serialize, Deserialize}; 18 use serde::{Serialize, Deserialize};
18 //use colored::Colorize; 19 //use colored::Colorize;
19 20
20 use alg_tools::iterate::{ 21 use alg_tools::iterate::{
21 AlgIteratorFactory, 22 AlgIteratorFactory,
22 AlgIteratorState,
23 AlgIteratorOptions, 23 AlgIteratorOptions,
24 ValueIteratorFactory, 24 ValueIteratorFactory,
25 }; 25 };
26 use alg_tools::euclidean::Euclidean; 26 use alg_tools::euclidean::Euclidean;
27 use alg_tools::norms::Norm; 27 use alg_tools::norms::Norm;
28 use alg_tools::linops::Apply; 28 use alg_tools::linops::Mapping;
29 use alg_tools::sets::Cube; 29 use alg_tools::sets::Cube;
30 use alg_tools::loc::Loc; 30 use alg_tools::loc::Loc;
31 use alg_tools::bisection_tree::{ 31 use alg_tools::bisection_tree::{
32 BTFN, 32 BTFN,
33 Bounds, 33 Bounds,
38 SupportGenerator, 38 SupportGenerator,
39 LocalAnalysis, 39 LocalAnalysis,
40 }; 40 };
41 use alg_tools::mapping::RealMapping; 41 use alg_tools::mapping::RealMapping;
42 use alg_tools::nalgebra_support::ToNalgebraRealField; 42 use alg_tools::nalgebra_support::ToNalgebraRealField;
43 use alg_tools::norms::L2;
43 44
44 use crate::types::*; 45 use crate::types::*;
45 use crate::measures::{ 46 use crate::measures::{
47 RNDM,
46 DiscreteMeasure, 48 DiscreteMeasure,
47 DeltaMeasure, 49 DeltaMeasure,
48 Radon, 50 Radon,
49 }; 51 };
50 use crate::measures::merging::{ 52 use crate::measures::merging::{
66 PlotLookup 68 PlotLookup
67 }; 69 };
68 use crate::regularisation::{ 70 use crate::regularisation::{
69 NonnegRadonRegTerm, 71 NonnegRadonRegTerm,
70 RadonRegTerm, 72 RadonRegTerm,
71 }; 73 RegTerm
72 use crate::fb::RegTerm; 74 };
73 75
74 /// Settings for [`pointsource_fw`]. 76 /// Settings for [`pointsource_fw_reg`].
75 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 77 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
76 #[serde(default)] 78 #[serde(default)]
77 pub struct FWConfig<F : Float> { 79 pub struct FWConfig<F : Float> {
78 /// Tolerance for branch-and-bound new spike location discovery 80 /// Tolerance for branch-and-bound new spike location discovery
79 pub tolerance : Tolerance<F>, 81 pub tolerance : Tolerance<F>,
104 FWConfig { 106 FWConfig {
105 tolerance : Default::default(), 107 tolerance : Default::default(),
106 refinement : Default::default(), 108 refinement : Default::default(),
107 inner : Default::default(), 109 inner : Default::default(),
108 variant : FWVariant::FullyCorrective, 110 variant : FWVariant::FullyCorrective,
109 merging : Default::default(), 111 merging : SpikeMergingMethod { enabled : true, ..Default::default() },
110 } 112 }
111 } 113 }
112 } 114 }
113 115
114 /// Helper struct for pre-initialising the finite-dimensional subproblems solver 116 pub trait FindimQuadraticModel<Domain, F> : ForwardModel<DiscreteMeasure<Domain, F>, F>
115 /// [`prepare_optimise_weights`]. 117 where
116 /// 118 F : Float + ToNalgebraRealField,
117 /// The pre-initialisation is done by [`prepare_optimise_weights`]. 119 Domain : Clone + PartialEq,
120 {
121 /// Return A_*A and A_* b
122 fn findim_quadratic_model(
123 &self,
124 μ : &DiscreteMeasure<Domain, F>,
125 b : &Self::Observable
126 ) -> (DMatrix<F::MixedType>, DVector<F::MixedType>);
127 }
128
129 /// Helper struct for pre-initialising the finite-dimensional subproblem solver.
118 pub struct FindimData<F : Float> { 130 pub struct FindimData<F : Float> {
119 /// ‖A‖^2 131 /// ‖A‖^2
120 opAnorm_squared : F, 132 opAnorm_squared : F,
121 /// Bound $M_0$ from the Bredies–Pikkarainen article. 133 /// Bound $M_0$ from the Bredies–Pikkarainen article.
122 m0 : F 134 m0 : F
123 } 135 }
124 136
125 /// Trait for finite dimensional weight optimisation. 137 /// Trait for finite dimensional weight optimisation.
126 pub trait WeightOptim< 138 pub trait WeightOptim<
127 F : Float + ToNalgebraRealField, 139 F : Float + ToNalgebraRealField,
128 A : ForwardModel<Loc<F, N>, F>, 140 A : ForwardModel<RNDM<F, N>, F>,
129 I : AlgIteratorFactory<F>, 141 I : AlgIteratorFactory<F>,
130 const N : usize 142 const N : usize
131 > { 143 > {
132 144
133 /// Return a pre-initialisation struct for [`Self::optimise_weights`]. 145 /// Return a pre-initialisation struct for [`Self::optimise_weights`].
152 /// prepared using [`Self::prepare_optimise_weights`]: 164 /// prepared using [`Self::prepare_optimise_weights`]:
153 /// 165 ///
154 /// Returns the number of iterations taken by the method configured in `inner`. 166 /// Returns the number of iterations taken by the method configured in `inner`.
155 fn optimise_weights<'a>( 167 fn optimise_weights<'a>(
156 &self, 168 &self,
157 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 169 μ : &mut RNDM<F, N>,
158 opA : &'a A, 170 opA : &'a A,
159 b : &A::Observable, 171 b : &A::Observable,
160 findim_data : &FindimData<F>, 172 findim_data : &FindimData<F>,
161 inner : &InnerSettings<F>, 173 inner : &InnerSettings<F>,
162 iterator : I 174 iterator : I
164 } 176 }
165 177
166 /// Trait for regularisation terms supported by [`pointsource_fw_reg`]. 178 /// Trait for regularisation terms supported by [`pointsource_fw_reg`].
167 pub trait RegTermFW< 179 pub trait RegTermFW<
168 F : Float + ToNalgebraRealField, 180 F : Float + ToNalgebraRealField,
169 A : ForwardModel<Loc<F, N>, F>, 181 A : ForwardModel<RNDM<F, N>, F>,
170 I : AlgIteratorFactory<F>, 182 I : AlgIteratorFactory<F>,
171 const N : usize 183 const N : usize
172 > : RegTerm<F, N> 184 > : RegTerm<F, N>
173 + WeightOptim<F, A, I, N> 185 + WeightOptim<F, A, I, N>
174 + for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> { 186 + Mapping<RNDM<F, N>, Codomain = F> {
175 187
176 /// With $g = A\_\*(Aμ-b)$, returns $(x, g(x))$ for $x$ a new point to be inserted 188 /// With $g = A\_\*(Aμ-b)$, returns $(x, g(x))$ for $x$ a new point to be inserted
177 /// into $μ$, as determined by the regulariser. 189 /// into $μ$, as determined by the regulariser.
178 /// 190 ///
179 /// The parameters `refinement_tolerance` and `max_steps` are passed to relevant 191 /// The parameters `refinement_tolerance` and `max_steps` are passed to relevant
186 ) -> (Loc<F, N>, F); 198 ) -> (Loc<F, N>, F);
187 199
188 /// Insert point `ξ` into `μ` for the relaxed algorithm from Bredies–Pikkarainen. 200 /// Insert point `ξ` into `μ` for the relaxed algorithm from Bredies–Pikkarainen.
189 fn relaxed_insert<'a>( 201 fn relaxed_insert<'a>(
190 &self, 202 &self,
191 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 203 μ : &mut RNDM<F, N>,
192 g : &A::PreadjointCodomain, 204 g : &A::PreadjointCodomain,
193 opA : &'a A, 205 opA : &'a A,
194 ξ : Loc<F, N>, 206 ξ : Loc<F, N>,
195 v_ξ : F, 207 v_ξ : F,
196 findim_data : &FindimData<F> 208 findim_data : &FindimData<F>
199 211
200 #[replace_float_literals(F::cast_from(literal))] 212 #[replace_float_literals(F::cast_from(literal))]
201 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N> 213 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N>
202 for RadonRegTerm<F> 214 for RadonRegTerm<F>
203 where I : AlgIteratorFactory<F>, 215 where I : AlgIteratorFactory<F>,
204 A : ForwardModel<Loc<F, N>, F> { 216 A : FindimQuadraticModel<Loc<F, N>, F> {
205 217
206 fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> { 218 fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> {
207 FindimData{ 219 FindimData{
208 opAnorm_squared : opA.opnorm_bound().powi(2), 220 opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2),
209 m0 : b.norm2_squared() / (2.0 * self.α()), 221 m0 : b.norm2_squared() / (2.0 * self.α()),
210 } 222 }
211 } 223 }
212 224
213 fn optimise_weights<'a>( 225 fn optimise_weights<'a>(
214 &self, 226 &self,
215 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 227 μ : &mut RNDM<F, N>,
216 opA : &'a A, 228 opA : &'a A,
217 b : &A::Observable, 229 b : &A::Observable,
218 findim_data : &FindimData<F>, 230 findim_data : &FindimData<F>,
219 inner : &InnerSettings<F>, 231 inner : &InnerSettings<F>,
220 iterator : I 232 iterator : I
230 // 2-norm, we estimate 242 // 2-norm, we estimate
231 // ‖A‖_{2,2} := sup_{‖x‖_2 ≤ 1} ‖Ax‖_2 ≤ sup_{‖x‖_1 ≤ C} ‖Ax‖_2 243 // ‖A‖_{2,2} := sup_{‖x‖_2 ≤ 1} ‖Ax‖_2 ≤ sup_{‖x‖_1 ≤ C} ‖Ax‖_2
232 // = C sup_{‖x‖_1 ≤ 1} ‖Ax‖_2 = C ‖A‖_{1,2}, 244 // = C sup_{‖x‖_1 ≤ 1} ‖Ax‖_2 = C ‖A‖_{1,2},
233 // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no 245 // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no
234 // square root is needed when we scale: 246 // square root is needed when we scale:
235 let inner_τ = inner.τ0 / (findim_data.opAnorm_squared * F::cast_from(μ.len())); 247 let normest = findim_data.opAnorm_squared * F::cast_from(μ.len());
236 let iters = quadratic_unconstrained(inner.method, &Ã, &g̃, self.α(), 248 let iters = quadratic_unconstrained(&Ã, &g̃, self.α(), &mut x,
237 &mut x, inner_τ, iterator); 249 normest, inner, iterator);
238 // Update masses of μ based on solution of finite-dimensional subproblem. 250 // Update masses of μ based on solution of finite-dimensional subproblem.
239 μ.set_masses_dvector(&x); 251 μ.set_masses_dvector(&x);
240 252
241 iters 253 iters
242 } 254 }
243 } 255 }
244 256
245 #[replace_float_literals(F::cast_from(literal))] 257 #[replace_float_literals(F::cast_from(literal))]
246 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N> 258 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N>
247 for RadonRegTerm<F> 259 for RadonRegTerm<F>
248 where Cube<F, N> : P2Minimise<Loc<F, N>, F>, 260 where
249 I : AlgIteratorFactory<F>, 261 Cube<F, N> : P2Minimise<Loc<F, N>, F>,
250 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 262 I : AlgIteratorFactory<F>,
251 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 263 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
252 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, 264 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
253 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>> { 265 A : FindimQuadraticModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
266 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
267 // FIXME: the following *should not* be needed, they are already implied
268 RNDM<F, N> : Mapping<A::PreadjointCodomain, Codomain = F>,
269 DeltaMeasure<Loc<F, N>, F> : Mapping<A::PreadjointCodomain, Codomain = F>,
270 //A : Mapping<RNDM<F, N>, Codomain = A::Observable>,
271 //A : Mapping<DeltaMeasure<Loc<F, N>, F>, Codomain = A::Observable>,
272 {
254 273
255 fn find_insertion( 274 fn find_insertion(
256 &self, 275 &self,
257 g : &mut A::PreadjointCodomain, 276 g : &mut A::PreadjointCodomain,
258 refinement_tolerance : F, 277 refinement_tolerance : F,
267 } 286 }
268 } 287 }
269 288
270 fn relaxed_insert<'a>( 289 fn relaxed_insert<'a>(
271 &self, 290 &self,
272 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 291 μ : &mut RNDM<F, N>,
273 g : &A::PreadjointCodomain, 292 g : &A::PreadjointCodomain,
274 opA : &'a A, 293 opA : &'a A,
275 ξ : Loc<F, N>, 294 ξ : Loc<F, N>,
276 v_ξ : F, 295 v_ξ : F,
277 findim_data : &FindimData<F> 296 findim_data : &FindimData<F>
280 let m0 = findim_data.m0; 299 let m0 = findim_data.m0;
281 let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) }; 300 let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) };
282 let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ }; 301 let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ };
283 let δ = DeltaMeasure { x : ξ, α : v }; 302 let δ = DeltaMeasure { x : ξ, α : v };
284 let dp = μ.apply(g) - δ.apply(g); 303 let dp = μ.apply(g) - δ.apply(g);
285 let d = opA.apply(&*μ) - opA.apply(&δ); 304 let d = opA.apply(&*μ) - opA.apply(δ);
286 let r = d.norm2_squared(); 305 let r = d.norm2_squared();
287 let s = if r == 0.0 { 306 let s = if r == 0.0 {
288 1.0 307 1.0
289 } else { 308 } else {
290 1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r) 309 1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r)
296 315
297 #[replace_float_literals(F::cast_from(literal))] 316 #[replace_float_literals(F::cast_from(literal))]
298 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N> 317 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N>
299 for NonnegRadonRegTerm<F> 318 for NonnegRadonRegTerm<F>
300 where I : AlgIteratorFactory<F>, 319 where I : AlgIteratorFactory<F>,
301 A : ForwardModel<Loc<F, N>, F> { 320 A : FindimQuadraticModel<Loc<F, N>, F> {
302 321
303 fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> { 322 fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> {
304 FindimData{ 323 FindimData{
305 opAnorm_squared : opA.opnorm_bound().powi(2), 324 opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2),
306 m0 : b.norm2_squared() / (2.0 * self.α()), 325 m0 : b.norm2_squared() / (2.0 * self.α()),
307 } 326 }
308 } 327 }
309 328
310 fn optimise_weights<'a>( 329 fn optimise_weights<'a>(
311 &self, 330 &self,
312 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 331 μ : &mut RNDM<F, N>,
313 opA : &'a A, 332 opA : &'a A,
314 b : &A::Observable, 333 b : &A::Observable,
315 findim_data : &FindimData<F>, 334 findim_data : &FindimData<F>,
316 inner : &InnerSettings<F>, 335 inner : &InnerSettings<F>,
317 iterator : I 336 iterator : I
327 // 2-norm, we estimate 346 // 2-norm, we estimate
328 // ‖A‖_{2,2} := sup_{‖x‖_2 ≤ 1} ‖Ax‖_2 ≤ sup_{‖x‖_1 ≤ C} ‖Ax‖_2 347 // ‖A‖_{2,2} := sup_{‖x‖_2 ≤ 1} ‖Ax‖_2 ≤ sup_{‖x‖_1 ≤ C} ‖Ax‖_2
329 // = C sup_{‖x‖_1 ≤ 1} ‖Ax‖_2 = C ‖A‖_{1,2}, 348 // = C sup_{‖x‖_1 ≤ 1} ‖Ax‖_2 = C ‖A‖_{1,2},
330 // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no 349 // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no
331 // square root is needed when we scale: 350 // square root is needed when we scale:
332 let inner_τ = inner.τ0 / (findim_data.opAnorm_squared * F::cast_from(μ.len())); 351 let normest = findim_data.opAnorm_squared * F::cast_from(μ.len());
333 let iters = quadratic_nonneg(inner.method, &Ã, &g̃, self.α(), 352 let iters = quadratic_nonneg(&Ã, &g̃, self.α(), &mut x,
334 &mut x, inner_τ, iterator); 353 normest, inner, iterator);
335 // Update masses of μ based on solution of finite-dimensional subproblem. 354 // Update masses of μ based on solution of finite-dimensional subproblem.
336 μ.set_masses_dvector(&x); 355 μ.set_masses_dvector(&x);
337 356
338 iters 357 iters
339 } 358 }
340 } 359 }
341 360
342 #[replace_float_literals(F::cast_from(literal))] 361 #[replace_float_literals(F::cast_from(literal))]
343 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N> 362 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N>
344 for NonnegRadonRegTerm<F> 363 for NonnegRadonRegTerm<F>
345 where Cube<F, N> : P2Minimise<Loc<F, N>, F>, 364 where
346 I : AlgIteratorFactory<F>, 365 Cube<F, N> : P2Minimise<Loc<F, N>, F>,
347 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 366 I : AlgIteratorFactory<F>,
348 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 367 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
349 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, 368 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
350 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>> { 369 A : FindimQuadraticModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
370 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
371 // FIXME: the following *should not* be needed, they are already implied
372 RNDM<F, N> : Mapping<A::PreadjointCodomain, Codomain = F>,
373 DeltaMeasure<Loc<F, N>, F> : Mapping<A::PreadjointCodomain, Codomain = F>,
374 {
351 375
352 fn find_insertion( 376 fn find_insertion(
353 &self, 377 &self,
354 g : &mut A::PreadjointCodomain, 378 g : &mut A::PreadjointCodomain,
355 refinement_tolerance : F, 379 refinement_tolerance : F,
359 } 383 }
360 384
361 385
362 fn relaxed_insert<'a>( 386 fn relaxed_insert<'a>(
363 &self, 387 &self,
364 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 388 μ : &mut RNDM<F, N>,
365 g : &A::PreadjointCodomain, 389 g : &A::PreadjointCodomain,
366 opA : &'a A, 390 opA : &'a A,
367 ξ : Loc<F, N>, 391 ξ : Loc<F, N>,
368 v_ξ : F, 392 v_ξ : F,
369 findim_data : &FindimData<F> 393 findim_data : &FindimData<F>
399 /// The `opA` parameter is the forward operator $A$, while `b`$ is as in the 423 /// The `opA` parameter is the forward operator $A$, while `b`$ is as in the
400 /// objective above. The method parameter are set in `config` (see [`FWConfig`]), while 424 /// objective above. The method parameter are set in `config` (see [`FWConfig`]), while
401 /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to 425 /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to
402 /// save intermediate iteration states as images. 426 /// save intermediate iteration states as images.
403 #[replace_float_literals(F::cast_from(literal))] 427 #[replace_float_literals(F::cast_from(literal))]
404 pub fn pointsource_fw_reg<'a, F, I, A, GA, BTA, S, Reg, const N : usize>( 428 pub fn pointsource_fw_reg<F, I, A, GA, BTA, S, Reg, const N : usize>(
405 opA : &'a A, 429 opA : &A,
406 b : &A::Observable, 430 b : &A::Observable,
407 reg : Reg, 431 reg : Reg,
408 //domain : Cube<F, N>, 432 //domain : Cube<F, N>,
409 config : &FWConfig<F>, 433 config : &FWConfig<F>,
410 iterator : I, 434 iterator : I,
411 mut plotter : SeqPlotter<F, N>, 435 mut plotter : SeqPlotter<F, N>,
412 ) -> DiscreteMeasure<Loc<F, N>, F> 436 ) -> RNDM<F, N>
413 where F : Float + ToNalgebraRealField, 437 where F : Float + ToNalgebraRealField,
414 I : AlgIteratorFactory<IterInfo<F, N>>, 438 I : AlgIteratorFactory<IterInfo<F, N>>,
415 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 439 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
416 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
417 A::Observable : std::ops::MulAssign<F>,
418 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 440 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
419 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, 441 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
420 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 442 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
421 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 443 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
422 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 444 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
423 Cube<F, N>: P2Minimise<Loc<F, N>, F>, 445 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
424 PlotLookup : Plotting<N>, 446 PlotLookup : Plotting<N>,
425 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 447 RNDM<F, N> : SpikeMerging<F>,
426 Reg : RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N> { 448 Reg : RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N> {
427 449
428 // Set up parameters 450 // Set up parameters
429 // We multiply tolerance by α for all algoritms. 451 // We multiply tolerance by α for all algoritms.
430 let tolerance = config.tolerance * reg.tolerance_scaling(); 452 let tolerance = config.tolerance * reg.tolerance_scaling();
436 458
437 // Initialise iterates 459 // Initialise iterates
438 let mut μ = DiscreteMeasure::new(); 460 let mut μ = DiscreteMeasure::new();
439 let mut residual = -b; 461 let mut residual = -b;
440 462
441 let mut inner_iters = 0; 463 // Statistics
442 let mut this_iters = 0; 464 let full_stats = |residual : &A::Observable,
443 let mut pruned = 0; 465 ν : &RNDM<F, N>,
444 let mut merged = 0; 466 ε, stats| IterInfo {
467 value : residual.norm2_squared_div2() + reg.apply(ν),
468 n_spikes : ν.len(),
469 ε,
470 .. stats
471 };
472 let mut stats = IterInfo::new();
445 473
446 // Run the algorithm 474 // Run the algorithm
447 iterator.iterate(|state| { 475 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
448 // Update tolerance
449 let inner_tolerance = ε * config.inner.tolerance_mult; 476 let inner_tolerance = ε * config.inner.tolerance_mult;
450 let refinement_tolerance = ε * config.refinement.tolerance_mult; 477 let refinement_tolerance = ε * config.refinement.tolerance_mult;
451 let ε_prev = ε;
452 ε = tolerance.update(ε, state.iteration());
453 478
454 // Calculate smooth part of surrogate model. 479 // Calculate smooth part of surrogate model.
455 // 480 let mut g = preadjA.apply(residual * (-1.0));
456 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
457 // has no significant overhead. For some reosn Rust doesn't allow us simply moving
458 // the residual and replacing it below before the end of this closure.
459 let r = std::mem::replace(&mut residual, opA.empty_observable());
460 let mut g = -preadjA.apply(r);
461 481
462 // Find absolute value maximising point 482 // Find absolute value maximising point
463 let (ξ, v_ξ) = reg.find_insertion(&mut g, refinement_tolerance, 483 let (ξ, v_ξ) = reg.find_insertion(&mut g, refinement_tolerance,
464 config.refinement.max_steps); 484 config.refinement.max_steps);
465 485
466 let inner_it = match config.variant { 486 let inner_it = match config.variant {
467 FWVariant::FullyCorrective => { 487 FWVariant::FullyCorrective => {
468 // No point in optimising the weight here: the finite-dimensional algorithm is fast. 488 // No point in optimising the weight here: the finite-dimensional algorithm is fast.
469 μ += DeltaMeasure { x : ξ, α : 0.0 }; 489 μ += DeltaMeasure { x : ξ, α : 0.0 };
490 stats.inserted += 1;
470 config.inner.iterator_options.stop_target(inner_tolerance) 491 config.inner.iterator_options.stop_target(inner_tolerance)
471 }, 492 },
472 FWVariant::Relaxed => { 493 FWVariant::Relaxed => {
473 // Perform a relaxed initialisation of μ 494 // Perform a relaxed initialisation of μ
474 reg.relaxed_insert(&mut μ, &g, opA, ξ, v_ξ, &findim_data); 495 reg.relaxed_insert(&mut μ, &g, opA, ξ, v_ξ, &findim_data);
496 stats.inserted += 1;
475 // The stop_target is only needed for the type system. 497 // The stop_target is only needed for the type system.
476 AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0) 498 AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0)
477 } 499 }
478 }; 500 };
479 501
480 inner_iters += reg.optimise_weights(&mut μ, opA, b, &findim_data, &config.inner, inner_it); 502 stats.inner_iters += reg.optimise_weights(&mut μ, opA, b, &findim_data,
503 &config.inner, inner_it);
481 504
482 // Merge spikes and update residual for next step and `if_verbose` below. 505 // Merge spikes and update residual for next step and `if_verbose` below.
483 let n_before_merge = μ.len(); 506 let (r, count) = μ.merge_spikes_fitness(config.merging,
484 residual = μ.merge_spikes_fitness(config.merging, 507 |μ̃| opA.apply(μ̃) - b,
485 |μ̃| opA.apply(μ̃) - b, 508 A::Observable::norm2_squared);
486 A::Observable::norm2_squared); 509 residual = r;
487 assert!(μ.len() >= n_before_merge); 510 stats.merged += count;
488 merged += μ.len() - n_before_merge;
489
490 511
491 // Prune points with zero mass 512 // Prune points with zero mass
492 let n_before_prune = μ.len(); 513 let n_before_prune = μ.len();
493 μ.prune(); 514 μ.prune();
494 debug_assert!(μ.len() <= n_before_prune); 515 debug_assert!(μ.len() <= n_before_prune);
495 pruned += n_before_prune - μ.len(); 516 stats.pruned += n_before_prune - μ.len();
496 517
497 this_iters +=1; 518 stats.this_iters += 1;
498 519 let iter = state.iteration();
499 // Give function value if needed 520
521 // Give statistics if needed
500 state.if_verbose(|| { 522 state.if_verbose(|| {
501 plotter.plot_spikes( 523 plotter.plot_spikes(iter, Some(&g), Option::<&S>::None, &μ);
502 format!("iter {} start", state.iteration()), &g, 524 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
503 "".to_string(), None::<&A::PreadjointCodomain>, 525 });
504 None, &μ 526
505 ); 527 // Update tolerance
506 let res = IterInfo { 528 ε = tolerance.update(ε, iter);
507 value : residual.norm2_squared_div2() + reg.apply(&μ), 529 }
508 n_spikes : μ.len(),
509 inner_iters,
510 this_iters,
511 merged,
512 pruned,
513 ε : ε_prev,
514 postprocessing : None,
515 };
516 inner_iters = 0;
517 this_iters = 0;
518 merged = 0;
519 pruned = 0;
520 res
521 })
522 });
523 530
524 // Return final iterate 531 // Return final iterate
525 μ 532 μ
526 } 533 }
527
528 //
529 // Deprecated interface
530 //
531
532 #[deprecated(note = "Use `pointsource_fw_reg`")]
533 pub fn pointsource_fw<'a, F, I, A, GA, BTA, S, const N : usize>(
534 opA : &'a A,
535 b : &A::Observable,
536 α : F,
537 //domain : Cube<F, N>,
538 config : &FWConfig<F>,
539 iterator : I,
540 plotter : SeqPlotter<F, N>,
541 ) -> DiscreteMeasure<Loc<F, N>, F>
542 where F : Float + ToNalgebraRealField,
543 I : AlgIteratorFactory<IterInfo<F, N>>,
544 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
545 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
546 A::Observable : std::ops::MulAssign<F>,
547 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
548 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
549 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
550 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
551 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
552 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
553 PlotLookup : Plotting<N>,
554 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
555
556 pointsource_fw_reg(opA, b, NonnegRadonRegTerm(α), config, iterator, plotter)
557 }
558
559 #[deprecated(note = "Use `WeightOptim::optimise_weights`")]
560 pub fn optimise_weights<'a, F, A, I, const N : usize>(
561 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
562 opA : &'a A,
563 b : &A::Observable,
564 α : F,
565 findim_data : &FindimData<F>,
566 inner : &InnerSettings<F>,
567 iterator : I
568 ) -> usize
569 where F : Float + ToNalgebraRealField,
570 I : AlgIteratorFactory<F>,
571 A : ForwardModel<Loc<F, N>, F>
572 {
573 NonnegRadonRegTerm(α).optimise_weights(μ, opA, b, findim_data, inner, iterator)
574 }

mercurial