src/regularisation.rs

branch
dev
changeset 32
56c8adc32b09
parent 24
d29d1fcf5423
child 34
efa60bc4f743
equal deleted inserted replaced
30:bd13c2ae3450 32:56c8adc32b09
1 /*! 1 /*!
2 Regularisation terms 2 Regularisation terms
3 */ 3 */
4 4
5 use numeric_literals::replace_float_literals;
5 use serde::{Serialize, Deserialize}; 6 use serde::{Serialize, Deserialize};
6 use alg_tools::norms::Norm; 7 use alg_tools::norms::Norm;
7 use alg_tools::linops::Apply; 8 use alg_tools::linops::Apply;
8 use alg_tools::loc::Loc; 9 use alg_tools::loc::Loc;
9 use crate::types::*; 10 use crate::types::*;
10 use crate::measures::{ 11 use crate::measures::{
11 DiscreteMeasure, 12 DiscreteMeasure,
13 DeltaMeasure,
12 Radon 14 Radon
13 }; 15 };
16 use crate::fb::FBGenericConfig;
14 #[allow(unused_imports)] // Used by documentation. 17 #[allow(unused_imports)] // Used by documentation.
15 use crate::fb::generic_pointsource_fb_reg; 18 use crate::fb::pointsource_fb_reg;
16 19 #[allow(unused_imports)] // Used by documentation.
17 /// The regularisation term $α\\|μ\\|\_{ℳ(Ω)} + δ_{≥ 0}(μ)$ for [`generic_pointsource_fb_reg`]. 20 use crate::sliding_fb::pointsource_sliding_fb_reg;
21
22 use nalgebra::{DVector, DMatrix};
23 use alg_tools::nalgebra_support::ToNalgebraRealField;
24 use alg_tools::mapping::Mapping;
25 use alg_tools::bisection_tree::{
26 BTFN,
27 Bounds,
28 BTSearch,
29 P2Minimise,
30 SupportGenerator,
31 LocalAnalysis,
32 Bounded,
33 };
34 use crate::subproblem::{
35 nonneg::quadratic_nonneg,
36 unconstrained::quadratic_unconstrained,
37 };
38 use alg_tools::iterate::AlgIteratorFactory;
39
40 /// The regularisation term $α\\|μ\\|\_{ℳ(Ω)} + δ_{≥ 0}(μ)$ for [`pointsource_fb_reg`] and other
41 /// algorithms.
18 /// 42 ///
19 /// The only member of the struct is the regularisation parameter α. 43 /// The only member of the struct is the regularisation parameter α.
20 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 44 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
21 pub struct NonnegRadonRegTerm<F : Float>(pub F /* α */); 45 pub struct NonnegRadonRegTerm<F : Float>(pub F /* α */);
22 46
36 self.α() * μ.norm(Radon) 60 self.α() * μ.norm(Radon)
37 } 61 }
38 } 62 }
39 63
40 64
41 /// The regularisation term $α\|μ\|_{ℳ(Ω)}$ for [`generic_pointsource_fb_reg`]. 65 /// The regularisation term $α\|μ\|_{ℳ(Ω)}$ for [`pointsource_fb_reg`].
42 /// 66 ///
43 /// The only member of the struct is the regularisation parameter α. 67 /// The only member of the struct is the regularisation parameter α.
44 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 68 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
45 pub struct RadonRegTerm<F : Float>(pub F /* α */); 69 pub struct RadonRegTerm<F : Float>(pub F /* α */);
46 70
80 Self::Radon(α) => RadonRegTerm(α).apply(μ), 104 Self::Radon(α) => RadonRegTerm(α).apply(μ),
81 Self::NonnegRadon(α) => NonnegRadonRegTerm(α).apply(μ), 105 Self::NonnegRadon(α) => NonnegRadonRegTerm(α).apply(μ),
82 } 106 }
83 } 107 }
84 } 108 }
109
110 /// Abstraction of regularisation terms for [`generic_pointsource_fb_reg`].
111 pub trait RegTerm<F : Float + ToNalgebraRealField, const N : usize>
112 : for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> {
113 /// Approximately solve the problem
114 /// <div>$$
115 /// \min_{x ∈ ℝ^n} \frac{1}{2} x^⊤Ax - g^⊤ x + τ G(x)
116 /// $$</div>
117 /// for $G$ depending on the trait implementation.
118 ///
119 /// The parameter `mA` is $A$. An estimate for its opeator norm should be provided in
120 /// `mA_normest`. The initial iterate and output is `x`. The current main tolerance is `ε`.
121 ///
122 /// Returns the number of iterations taken.
123 fn solve_findim(
124 &self,
125 mA : &DMatrix<F::MixedType>,
126 g : &DVector<F::MixedType>,
127 τ : F,
128 x : &mut DVector<F::MixedType>,
129 mA_normest : F,
130 ε : F,
131 config : &FBGenericConfig<F>
132 ) -> usize;
133
134 /// Find a point where `d` may violate the tolerance `ε`.
135 ///
136 /// If `skip_by_rough_check` is set, do not find the point if a rough check indicates that we
137 /// are in bounds. `ε` is the current main tolerance and `τ` a scaling factor for the
138 /// regulariser.
139 ///
140 /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check
141 /// terminating early. Otherwise returns a possibly violating point, the value of `d` there,
142 /// and a boolean indicating whether the found point is in bounds.
143 fn find_tolerance_violation<G, BT>(
144 &self,
145 d : &mut BTFN<F, G, BT, N>,
146 τ : F,
147 ε : F,
148 skip_by_rough_check : bool,
149 config : &FBGenericConfig<F>,
150 ) -> Option<(Loc<F, N>, F, bool)>
151 where BT : BTSearch<F, N, Agg=Bounds<F>>,
152 G : SupportGenerator<F, N, Id=BT::Data>,
153 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
154 + LocalAnalysis<F, Bounds<F>, N>;
155
156 /// Verify that `d` is in bounds `ε` for a merge candidate `μ`
157 ///
158 /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser.
159 fn verify_merge_candidate<G, BT>(
160 &self,
161 d : &mut BTFN<F, G, BT, N>,
162 μ : &DiscreteMeasure<Loc<F, N>, F>,
163 τ : F,
164 ε : F,
165 config : &FBGenericConfig<F>,
166 ) -> bool
167 where BT : BTSearch<F, N, Agg=Bounds<F>>,
168 G : SupportGenerator<F, N, Id=BT::Data>,
169 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
170 + LocalAnalysis<F, Bounds<F>, N>;
171
172 /// TODO: document this
173 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>>;
174
175 /// Returns a scaling factor for the tolerance sequence.
176 ///
177 /// Typically this is the regularisation parameter.
178 fn tolerance_scaling(&self) -> F;
179 }
180
181 /// Abstraction of regularisation terms for [`pointsource_sliding_fb_reg`].
182 pub trait SlidingRegTerm<F : Float + ToNalgebraRealField, const N : usize>
183 : RegTerm<F, N> {
184 /// Calculate $τ[w(z) - w(y)]$ for some w in the subdifferential of the regularisation
185 /// term, such that $-ε ≤ τw - d ≤ ε$.
186 fn goodness<G, BT>(
187 &self,
188 d : &mut BTFN<F, G, BT, N>,
189 μ : &DiscreteMeasure<Loc<F, N>, F>,
190 y : &Loc<F, N>,
191 z : &Loc<F, N>,
192 τ : F,
193 ε : F,
194 config : &FBGenericConfig<F>,
195 ) -> F
196 where BT : BTSearch<F, N, Agg=Bounds<F>>,
197 G : SupportGenerator<F, N, Id=BT::Data>,
198 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
199 + LocalAnalysis<F, Bounds<F>, N>;
200 }
201
202 #[replace_float_literals(F::cast_from(literal))]
203 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N>
204 for NonnegRadonRegTerm<F>
205 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
206 fn solve_findim(
207 &self,
208 mA : &DMatrix<F::MixedType>,
209 g : &DVector<F::MixedType>,
210 τ : F,
211 x : &mut DVector<F::MixedType>,
212 mA_normest : F,
213 ε : F,
214 config : &FBGenericConfig<F>
215 ) -> usize {
216 let inner_tolerance = ε * config.inner.tolerance_mult;
217 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
218 let inner_τ = config.inner.τ0 / mA_normest;
219 quadratic_nonneg(config.inner.method, mA, g, τ * self.α(), x,
220 inner_τ, inner_it)
221 }
222
223 #[inline]
224 fn find_tolerance_violation<G, BT>(
225 &self,
226 d : &mut BTFN<F, G, BT, N>,
227 τ : F,
228 ε : F,
229 skip_by_rough_check : bool,
230 config : &FBGenericConfig<F>,
231 ) -> Option<(Loc<F, N>, F, bool)>
232 where BT : BTSearch<F, N, Agg=Bounds<F>>,
233 G : SupportGenerator<F, N, Id=BT::Data>,
234 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
235 + LocalAnalysis<F, Bounds<F>, N> {
236 let τα = τ * self.α();
237 let keep_below = τα + ε;
238 let maximise_above = τα + ε * config.insertion_cutoff_factor;
239 let refinement_tolerance = ε * config.refinement.tolerance_mult;
240
241 // If preliminary check indicates that we are in bonds, and if it otherwise matches
242 // the insertion strategy, skip insertion.
243 if skip_by_rough_check && d.bounds().upper() <= keep_below {
244 None
245 } else {
246 // If the rough check didn't indicate no insertion needed, find maximising point.
247 d.maximise_above(maximise_above, refinement_tolerance, config.refinement.max_steps)
248 .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ <= keep_below))
249 }
250 }
251
252 fn verify_merge_candidate<G, BT>(
253 &self,
254 d : &mut BTFN<F, G, BT, N>,
255 μ : &DiscreteMeasure<Loc<F, N>, F>,
256 τ : F,
257 ε : F,
258 config : &FBGenericConfig<F>,
259 ) -> bool
260 where BT : BTSearch<F, N, Agg=Bounds<F>>,
261 G : SupportGenerator<F, N, Id=BT::Data>,
262 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
263 + LocalAnalysis<F, Bounds<F>, N> {
264 let τα = τ * self.α();
265 let refinement_tolerance = ε * config.refinement.tolerance_mult;
266 let merge_tolerance = config.merge_tolerance_mult * ε;
267 let keep_below = τα + merge_tolerance;
268 let keep_supp_above = τα - merge_tolerance;
269 let bnd = d.bounds();
270
271 return (
272 bnd.lower() >= keep_supp_above
273 ||
274 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
275 (β == 0.0) || d.apply(x) >= keep_supp_above
276 }).all(std::convert::identity)
277 ) && (
278 bnd.upper() <= keep_below
279 ||
280 d.has_upper_bound(keep_below, refinement_tolerance, config.refinement.max_steps)
281 )
282 }
283
284 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
285 let τα = τ * self.α();
286 Some(Bounds(τα - ε, τα + ε))
287 }
288
289 fn tolerance_scaling(&self) -> F {
290 self.α()
291 }
292 }
293
294 #[replace_float_literals(F::cast_from(literal))]
295 impl<F : Float + ToNalgebraRealField, const N : usize> SlidingRegTerm<F, N>
296 for NonnegRadonRegTerm<F>
297 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
298
299 fn goodness<G, BT>(
300 &self,
301 d : &mut BTFN<F, G, BT, N>,
302 _μ : &DiscreteMeasure<Loc<F, N>, F>,
303 y : &Loc<F, N>,
304 z : &Loc<F, N>,
305 τ : F,
306 ε : F,
307 _config : &FBGenericConfig<F>,
308 ) -> F
309 where BT : BTSearch<F, N, Agg=Bounds<F>>,
310 G : SupportGenerator<F, N, Id=BT::Data>,
311 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
312 + LocalAnalysis<F, Bounds<F>, N> {
313 //let w = |x| 1.0.min((ε + d.apply(x))/(τ * self.α()));
314 let τw = |x| τ.min((ε + d.apply(x))/self.α());
315 τw(z) - τw(y)
316 }
317 }
318
319 #[replace_float_literals(F::cast_from(literal))]
320 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for RadonRegTerm<F>
321 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
322 fn solve_findim(
323 &self,
324 mA : &DMatrix<F::MixedType>,
325 g : &DVector<F::MixedType>,
326 τ : F,
327 x : &mut DVector<F::MixedType>,
328 mA_normest: F,
329 ε : F,
330 config : &FBGenericConfig<F>
331 ) -> usize {
332 let inner_tolerance = ε * config.inner.tolerance_mult;
333 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
334 let inner_τ = config.inner.τ0 / mA_normest;
335 quadratic_unconstrained(config.inner.method, mA, g, τ * self.α(), x,
336 inner_τ, inner_it)
337 }
338
339 fn find_tolerance_violation<G, BT>(
340 &self,
341 d : &mut BTFN<F, G, BT, N>,
342 τ : F,
343 ε : F,
344 skip_by_rough_check : bool,
345 config : &FBGenericConfig<F>,
346 ) -> Option<(Loc<F, N>, F, bool)>
347 where BT : BTSearch<F, N, Agg=Bounds<F>>,
348 G : SupportGenerator<F, N, Id=BT::Data>,
349 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
350 + LocalAnalysis<F, Bounds<F>, N> {
351 let τα = τ * self.α();
352 let keep_below = τα + ε;
353 let keep_above = -τα - ε;
354 let maximise_above = τα + ε * config.insertion_cutoff_factor;
355 let minimise_below = -τα - ε * config.insertion_cutoff_factor;
356 let refinement_tolerance = ε * config.refinement.tolerance_mult;
357
358 // If preliminary check indicates that we are in bonds, and if it otherwise matches
359 // the insertion strategy, skip insertion.
360 if skip_by_rough_check && Bounds(keep_above, keep_below).superset(&d.bounds()) {
361 None
362 } else {
363 // If the rough check didn't indicate no insertion needed, find maximising point.
364 let mx = d.maximise_above(maximise_above, refinement_tolerance,
365 config.refinement.max_steps);
366 let mi = d.minimise_below(minimise_below, refinement_tolerance,
367 config.refinement.max_steps);
368
369 match (mx, mi) {
370 (None, None) => None,
371 (Some((ξ, v_ξ)), None) => Some((ξ, v_ξ, keep_below >= v_ξ)),
372 (None, Some((ζ, v_ζ))) => Some((ζ, v_ζ, keep_above <= v_ζ)),
373 (Some((ξ, v_ξ)), Some((ζ, v_ζ))) => {
374 if v_ξ - τα > τα - v_ζ {
375 Some((ξ, v_ξ, keep_below >= v_ξ))
376 } else {
377 Some((ζ, v_ζ, keep_above <= v_ζ))
378 }
379 }
380 }
381 }
382 }
383
384 fn verify_merge_candidate<G, BT>(
385 &self,
386 d : &mut BTFN<F, G, BT, N>,
387 μ : &DiscreteMeasure<Loc<F, N>, F>,
388 τ : F,
389 ε : F,
390 config : &FBGenericConfig<F>,
391 ) -> bool
392 where BT : BTSearch<F, N, Agg=Bounds<F>>,
393 G : SupportGenerator<F, N, Id=BT::Data>,
394 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
395 + LocalAnalysis<F, Bounds<F>, N> {
396 let τα = τ * self.α();
397 let refinement_tolerance = ε * config.refinement.tolerance_mult;
398 let merge_tolerance = config.merge_tolerance_mult * ε;
399 let keep_below = τα + merge_tolerance;
400 let keep_above = -τα - merge_tolerance;
401 let keep_supp_pos_above = τα - merge_tolerance;
402 let keep_supp_neg_below = -τα + merge_tolerance;
403 let bnd = d.bounds();
404
405 return (
406 (bnd.lower() >= keep_supp_pos_above && bnd.upper() <= keep_supp_neg_below)
407 ||
408 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
409 use std::cmp::Ordering::*;
410 match β.partial_cmp(&0.0) {
411 Some(Greater) => d.apply(x) >= keep_supp_pos_above,
412 Some(Less) => d.apply(x) <= keep_supp_neg_below,
413 _ => true,
414 }
415 }).all(std::convert::identity)
416 ) && (
417 bnd.upper() <= keep_below
418 ||
419 d.has_upper_bound(keep_below, refinement_tolerance,
420 config.refinement.max_steps)
421 ) && (
422 bnd.lower() >= keep_above
423 ||
424 d.has_lower_bound(keep_above, refinement_tolerance,
425 config.refinement.max_steps)
426 )
427 }
428
429 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
430 let τα = τ * self.α();
431 Some(Bounds(-τα - ε, τα + ε))
432 }
433
434 fn tolerance_scaling(&self) -> F {
435 self.α()
436 }
437 }
438
439 #[replace_float_literals(F::cast_from(literal))]
440 impl<F : Float + ToNalgebraRealField, const N : usize> SlidingRegTerm<F, N>
441 for RadonRegTerm<F>
442 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
443
444 fn goodness<G, BT>(
445 &self,
446 d : &mut BTFN<F, G, BT, N>,
447 _μ : &DiscreteMeasure<Loc<F, N>, F>,
448 y : &Loc<F, N>,
449 z : &Loc<F, N>,
450 τ : F,
451 ε : F,
452 _config : &FBGenericConfig<F>,
453 ) -> F
454 where BT : BTSearch<F, N, Agg=Bounds<F>>,
455 G : SupportGenerator<F, N, Id=BT::Data>,
456 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
457 + LocalAnalysis<F, Bounds<F>, N> {
458
459 let α = self.α();
460 // let w = |x| {
461 // let dx = d.apply(x);
462 // ((-ε + dx)/(τ * α)).max(1.0.min(ε + dx)/(τ * α))
463 // };
464 let τw = |x| {
465 let dx = d.apply(x);
466 ((-ε + dx)/α).max(τ.min(ε + dx)/α)
467 };
468 τw(z) - τw(y)
469 }
470 }

mercurial