| |
1 /*! |
| |
2 Basic proximal penalty based on convolution operators $𝒟$. |
| |
3 */ |
| |
4 |
| |
5 use numeric_literals::replace_float_literals; |
| |
6 use nalgebra::DVector; |
| |
7 use colored::Colorize; |
| |
8 |
| |
9 use alg_tools::types::*; |
| |
10 use alg_tools::loc::Loc; |
| |
11 use alg_tools::mapping::{Mapping, RealMapping}; |
| |
12 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| |
13 use alg_tools::norms::Linfinity; |
| |
14 use alg_tools::iterate::{ |
| |
15 AlgIteratorIteration, |
| |
16 AlgIterator, |
| |
17 }; |
| |
18 use alg_tools::bisection_tree::{ |
| |
19 BTFN, |
| |
20 PreBTFN, |
| |
21 Bounds, |
| |
22 BTSearch, |
| |
23 SupportGenerator, |
| |
24 LocalAnalysis, |
| |
25 BothGenerators, |
| |
26 }; |
| |
27 use crate::measures::{ |
| |
28 RNDM, |
| |
29 DeltaMeasure, |
| |
30 Radon, |
| |
31 }; |
| |
32 use crate::measures::merging::{ |
| |
33 SpikeMerging, |
| |
34 }; |
| |
35 use crate::seminorms::DiscreteMeasureOp; |
| |
36 use crate::types::{ |
| |
37 IterInfo, |
| |
38 }; |
| |
39 use crate::regularisation::RegTerm; |
| |
40 use super::{ProxPenalty, FBGenericConfig}; |
| |
41 |
| |
42 #[replace_float_literals(F::cast_from(literal))] |
| |
43 impl<F, GA, BTA, S, Reg, 𝒟, G𝒟, K, const N : usize> |
| |
44 ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for 𝒟 |
| |
45 where |
| |
46 F : Float + ToNalgebraRealField, |
| |
47 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
| |
48 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
| |
49 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
| |
50 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
| |
51 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
| |
52 𝒟::Codomain : RealMapping<F, N>, |
| |
53 K : RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
| |
54 Reg : RegTerm<F, N>, |
| |
55 RNDM<F, N> : SpikeMerging<F>, |
| |
56 { |
| |
57 type ReturnMapping = BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>; |
| |
58 |
| |
59 fn insert_and_reweigh<I>( |
| |
60 &self, |
| |
61 μ : &mut RNDM<F, N>, |
| |
62 τv : &mut BTFN<F, GA, BTA, N>, |
| |
63 μ_base : &RNDM<F, N>, |
| |
64 ν_delta: Option<&RNDM<F, N>>, |
| |
65 τ : F, |
| |
66 ε : F, |
| |
67 config : &FBGenericConfig<F>, |
| |
68 reg : &Reg, |
| |
69 state : &AlgIteratorIteration<I>, |
| |
70 stats : &mut IterInfo<F, N>, |
| |
71 ) -> (Option<BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>>, bool) |
| |
72 where |
| |
73 I : AlgIterator |
| |
74 { |
| |
75 |
| |
76 let op𝒟norm = self.opnorm_bound(Radon, Linfinity); |
| |
77 |
| |
78 // Maximum insertion count and measure difference calculation depend on insertion style. |
| |
79 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
| |
80 (i, Some((l, k))) if i <= l => (k, false), |
| |
81 _ => (config.max_insertions, !state.is_quiet()), |
| |
82 }; |
| |
83 |
| |
84 let ω0 = match ν_delta { |
| |
85 None => self.apply(μ_base), |
| |
86 Some(ν) => self.apply(μ_base + ν), |
| |
87 }; |
| |
88 |
| |
89 // Add points to support until within error tolerance or maximum insertion count reached. |
| |
90 let mut count = 0; |
| |
91 let (within_tolerances, d) = 'insertion: loop { |
| |
92 if μ.len() > 0 { |
| |
93 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
| |
94 // from the beginning of the iteration are all contained in the immutable c and g. |
| |
95 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional |
| |
96 // problems have not yet been updated to sign change. |
| |
97 let à = self.findim_matrix(μ.iter_locations()); |
| |
98 let g̃ = DVector::from_iterator(μ.len(), |
| |
99 μ.iter_locations() |
| |
100 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) |
| |
101 .map(F::to_nalgebra_mixed)); |
| |
102 let mut x = μ.masses_dvector(); |
| |
103 |
| |
104 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
| |
105 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
| |
106 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ |
| |
107 // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 |
| |
108 // = n |𝒟| |x|_2, where n is the number of points. Therefore |
| |
109 let Ã_normest = op𝒟norm * F::cast_from(μ.len()); |
| |
110 |
| |
111 // Solve finite-dimensional subproblem. |
| |
112 stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); |
| |
113 |
| |
114 // Update masses of μ based on solution of finite-dimensional subproblem. |
| |
115 μ.set_masses_dvector(&x); |
| |
116 } |
| |
117 |
| |
118 // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality |
| |
119 // conditions in the predual space, and finding new points for insertion, if necessary. |
| |
120 let mut d = &*τv + match ν_delta { |
| |
121 None => self.preapply(μ.sub_matching(μ_base)), |
| |
122 Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν) |
| |
123 }; |
| |
124 |
| |
125 // If no merging heuristic is used, let's be more conservative about spike insertion, |
| |
126 // and skip it after first round. If merging is done, being more greedy about spike |
| |
127 // insertion also seems to improve performance. |
| |
128 let skip_by_rough_check = if config.merging.enabled { |
| |
129 false |
| |
130 } else { |
| |
131 count > 0 |
| |
132 }; |
| |
133 |
| |
134 // Find a spike to insert, if needed |
| |
135 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( |
| |
136 &mut d, τ, ε, skip_by_rough_check, config |
| |
137 ) { |
| |
138 None => break 'insertion (true, d), |
| |
139 Some(res) => res, |
| |
140 }; |
| |
141 |
| |
142 // Break if maximum insertion count reached |
| |
143 if count >= max_insertions { |
| |
144 break 'insertion (in_bounds, d) |
| |
145 } |
| |
146 |
| |
147 // No point in optimising the weight here; the finite-dimensional algorithm is fast. |
| |
148 *μ += DeltaMeasure { x : ξ, α : 0.0 }; |
| |
149 count += 1; |
| |
150 stats.inserted += 1; |
| |
151 }; |
| |
152 |
| |
153 if !within_tolerances && warn_insertions { |
| |
154 // Complain (but continue) if we failed to get within tolerances |
| |
155 // by inserting more points. |
| |
156 let err = format!("Maximum insertions reached without achieving \ |
| |
157 subproblem solution tolerance"); |
| |
158 println!("{}", err.red()); |
| |
159 } |
| |
160 |
| |
161 (Some(d), within_tolerances) |
| |
162 } |
| |
163 |
| |
164 fn merge_spikes( |
| |
165 &self, |
| |
166 μ : &mut RNDM<F, N>, |
| |
167 τv : &mut BTFN<F, GA, BTA, N>, |
| |
168 μ_base : &RNDM<F, N>, |
| |
169 ν_delta: Option<&RNDM<F, N>>, |
| |
170 τ : F, |
| |
171 ε : F, |
| |
172 config : &FBGenericConfig<F>, |
| |
173 reg : &Reg, |
| |
174 fitness : Option<impl Fn(&RNDM<F, N>) -> F>, |
| |
175 ) -> usize |
| |
176 { |
| |
177 if config.fitness_merging { |
| |
178 if let Some(f) = fitness { |
| |
179 return μ.merge_spikes_fitness(config.merging, f, |&v| v) |
| |
180 .1 |
| |
181 } |
| |
182 } |
| |
183 μ.merge_spikes(config.merging, |μ_candidate| { |
| |
184 let mut d = &*τv + self.preapply(match ν_delta { |
| |
185 None => μ_candidate.sub_matching(μ_base), |
| |
186 Some(ν) => μ_candidate.sub_matching(μ_base) - ν, |
| |
187 }); |
| |
188 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config) |
| |
189 }) |
| |
190 } |
| |
191 } |