| 1 /*! |
1 /*! |
| 2 Basic proximal penalty based on convolution operators $𝒟$. |
2 Basic proximal penalty based on convolution operators $𝒟$. |
| 3 */ |
3 */ |
| 4 |
4 |
| |
5 use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD}; |
| |
6 use crate::dataterm::QuadraticDataTerm; |
| |
7 use crate::forward_model::ForwardModel; |
| |
8 use crate::measures::merging::SpikeMerging; |
| |
9 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon}; |
| |
10 use crate::regularisation::RegTerm; |
| |
11 use crate::seminorms::DiscreteMeasureOp; |
| |
12 use crate::types::IterInfo; |
| |
13 use alg_tools::bounds::MinMaxMapping; |
| |
14 use alg_tools::error::DynResult; |
| |
15 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration}; |
| |
16 use alg_tools::linops::BoundedLinear; |
| |
17 use alg_tools::mapping::{Instance, Mapping, Space}; |
| |
18 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| |
19 use alg_tools::norms::{Linfinity, Norm, NormExponent, L2}; |
| |
20 use alg_tools::types::*; |
| |
21 use colored::Colorize; |
| |
22 use nalgebra::DVector; |
| 5 use numeric_literals::replace_float_literals; |
23 use numeric_literals::replace_float_literals; |
| 6 use nalgebra::DVector; |
|
| 7 use colored::Colorize; |
|
| 8 |
|
| 9 use alg_tools::types::*; |
|
| 10 use alg_tools::loc::Loc; |
|
| 11 use alg_tools::mapping::{Mapping, RealMapping}; |
|
| 12 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
| 13 use alg_tools::norms::Linfinity; |
|
| 14 use alg_tools::iterate::{ |
|
| 15 AlgIteratorIteration, |
|
| 16 AlgIterator, |
|
| 17 }; |
|
| 18 use alg_tools::bisection_tree::{ |
|
| 19 BTFN, |
|
| 20 PreBTFN, |
|
| 21 Bounds, |
|
| 22 BTSearch, |
|
| 23 SupportGenerator, |
|
| 24 LocalAnalysis, |
|
| 25 BothGenerators, |
|
| 26 }; |
|
| 27 use crate::measures::{ |
|
| 28 RNDM, |
|
| 29 DeltaMeasure, |
|
| 30 Radon, |
|
| 31 }; |
|
| 32 use crate::measures::merging::{ |
|
| 33 SpikeMerging, |
|
| 34 }; |
|
| 35 use crate::seminorms::DiscreteMeasureOp; |
|
| 36 use crate::types::{ |
|
| 37 IterInfo, |
|
| 38 }; |
|
| 39 use crate::regularisation::RegTerm; |
|
| 40 use super::{ProxPenalty, FBGenericConfig}; |
|
| 41 |
24 |
| 42 #[replace_float_literals(F::cast_from(literal))] |
25 #[replace_float_literals(F::cast_from(literal))] |
| 43 impl<F, GA, BTA, S, Reg, 𝒟, G𝒟, K, const N : usize> |
26 impl<F, M, Reg, 𝒟, O, Domain> ProxPenalty<Domain, M, Reg, F> for 𝒟 |
| 44 ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for 𝒟 |
|
| 45 where |
27 where |
| 46 F : Float + ToNalgebraRealField, |
28 Domain: Space + Clone + PartialEq + 'static, |
| 47 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
29 for<'a> &'a Domain: Instance<Domain>, |
| 48 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
30 F: Float + ToNalgebraRealField, |
| 49 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
31 𝒟: DiscreteMeasureOp<Domain, F>, |
| 50 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
32 𝒟::Codomain: Mapping<Domain, Codomain = F>, |
| 51 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
33 M: Mapping<Domain, Codomain = F>, |
| 52 𝒟::Codomain : RealMapping<F, N>, |
34 for<'a> &'a M: std::ops::Add<𝒟::PreCodomain, Output = O>, |
| 53 K : RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
35 O: MinMaxMapping<Domain, F>, |
| 54 Reg : RegTerm<F, N>, |
36 Reg: RegTerm<Domain, F>, |
| 55 RNDM<F, N> : SpikeMerging<F>, |
37 DiscreteMeasure<Domain, F>: SpikeMerging<F>, |
| 56 { |
38 { |
| 57 type ReturnMapping = BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>; |
39 type ReturnMapping = O; |
| |
40 |
| |
41 fn prox_type() -> ProxTerm { |
| |
42 ProxTerm::Wave |
| |
43 } |
| 58 |
44 |
| 59 fn insert_and_reweigh<I>( |
45 fn insert_and_reweigh<I>( |
| 60 &self, |
46 &self, |
| 61 μ : &mut RNDM<F, N>, |
47 μ: &mut DiscreteMeasure<Domain, F>, |
| 62 τv : &mut BTFN<F, GA, BTA, N>, |
48 τv: &mut M, |
| 63 μ_base : &RNDM<F, N>, |
49 τ: F, |
| 64 ν_delta: Option<&RNDM<F, N>>, |
50 ε: F, |
| 65 τ : F, |
51 config: &InsertionConfig<F>, |
| 66 ε : F, |
52 reg: &Reg, |
| 67 config : &FBGenericConfig<F>, |
53 state: &AlgIteratorIteration<I>, |
| 68 reg : &Reg, |
54 stats: &mut IterInfo<F>, |
| 69 state : &AlgIteratorIteration<I>, |
55 ) -> DynResult<(Option<Self::ReturnMapping>, bool)> |
| 70 stats : &mut IterInfo<F, N>, |
|
| 71 ) -> (Option<BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>>, bool) |
|
| 72 where |
56 where |
| 73 I : AlgIterator |
57 I: AlgIterator, |
| 74 { |
58 { |
| 75 |
59 let op𝒟norm = self.opnorm_bound(Radon, Linfinity)?; |
| 76 let op𝒟norm = self.opnorm_bound(Radon, Linfinity); |
|
| 77 |
60 |
| 78 // Maximum insertion count and measure difference calculation depend on insertion style. |
61 // Maximum insertion count and measure difference calculation depend on insertion style. |
| 79 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
62 let (max_insertions, warn_insertions) = |
| 80 (i, Some((l, k))) if i <= l => (k, false), |
63 match (state.iteration(), config.bootstrap_insertions) { |
| 81 _ => (config.max_insertions, !state.is_quiet()), |
64 (i, Some((l, k))) if i <= l => (k, false), |
| 82 }; |
65 _ => (config.max_insertions, !state.is_quiet()), |
| 83 |
66 }; |
| 84 let ω0 = match ν_delta { |
67 |
| 85 None => self.apply(μ_base), |
68 let μ_base = μ.clone(); |
| 86 Some(ν) => self.apply(μ_base + ν), |
69 let ω0 = self.apply(&μ_base); |
| 87 }; |
|
| 88 |
70 |
| 89 // Add points to support until within error tolerance or maximum insertion count reached. |
71 // Add points to support until within error tolerance or maximum insertion count reached. |
| 90 let mut count = 0; |
72 let mut count = 0; |
| 91 let (within_tolerances, d) = 'insertion: loop { |
73 let (within_tolerances, d) = 'insertion: loop { |
| 92 if μ.len() > 0 { |
74 if μ.len() > 0 { |
| 93 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
75 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
| 94 // from the beginning of the iteration are all contained in the immutable c and g. |
76 // from the beginning of the iteration are all contained in the immutable c and g. |
| 95 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional |
77 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional |
| 96 // problems have not yet been updated to sign change. |
78 // problems have not yet been updated to sign change. |
| 97 let à = self.findim_matrix(μ.iter_locations()); |
79 let à = self.findim_matrix(μ.iter_locations()); |
| 98 let g̃ = DVector::from_iterator(μ.len(), |
80 let g̃ = DVector::from_iterator( |
| 99 μ.iter_locations() |
81 μ.len(), |
| 100 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) |
82 μ.iter_locations() |
| 101 .map(F::to_nalgebra_mixed)); |
83 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) |
| |
84 .map(F::to_nalgebra_mixed), |
| |
85 ); |
| 102 let mut x = μ.masses_dvector(); |
86 let mut x = μ.masses_dvector(); |
| 103 |
87 |
| 104 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
88 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
| 105 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
89 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
| 106 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ |
90 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ |
| 130 } else { |
111 } else { |
| 131 count > 0 |
112 count > 0 |
| 132 }; |
113 }; |
| 133 |
114 |
| 134 // Find a spike to insert, if needed |
115 // Find a spike to insert, if needed |
| 135 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( |
116 let (ξ, _v_ξ, in_bounds) = |
| 136 &mut d, τ, ε, skip_by_rough_check, config |
117 match reg.find_tolerance_violation(&mut d, τ, ε, skip_by_rough_check, config) { |
| 137 ) { |
118 None => break 'insertion (true, d), |
| 138 None => break 'insertion (true, d), |
119 Some(res) => res, |
| 139 Some(res) => res, |
120 }; |
| 140 }; |
|
| 141 |
121 |
| 142 // Break if maximum insertion count reached |
122 // Break if maximum insertion count reached |
| 143 if count >= max_insertions { |
123 if count >= max_insertions { |
| 144 break 'insertion (in_bounds, d) |
124 break 'insertion (in_bounds, d); |
| 145 } |
125 } |
| 146 |
126 |
| 147 // No point in optimising the weight here; the finite-dimensional algorithm is fast. |
127 // No point in optimising the weight here; the finite-dimensional algorithm is fast. |
| 148 *μ += DeltaMeasure { x : ξ, α : 0.0 }; |
128 *μ += DeltaMeasure { x: ξ, α: 0.0 }; |
| 149 count += 1; |
129 count += 1; |
| 150 stats.inserted += 1; |
|
| 151 }; |
130 }; |
| 152 |
131 |
| 153 if !within_tolerances && warn_insertions { |
132 if !within_tolerances && warn_insertions { |
| 154 // Complain (but continue) if we failed to get within tolerances |
133 // Complain (but continue) if we failed to get within tolerances |
| 155 // by inserting more points. |
134 // by inserting more points. |
| 156 let err = format!("Maximum insertions reached without achieving \ |
135 let err = format!( |
| 157 subproblem solution tolerance"); |
136 "Maximum insertions reached without achieving \ |
| |
137 subproblem solution tolerance" |
| |
138 ); |
| 158 println!("{}", err.red()); |
139 println!("{}", err.red()); |
| 159 } |
140 } |
| 160 |
141 |
| 161 (Some(d), within_tolerances) |
142 Ok((Some(d), within_tolerances)) |
| 162 } |
143 } |
| 163 |
144 |
| 164 fn merge_spikes( |
145 fn merge_spikes( |
| 165 &self, |
146 &self, |
| 166 μ : &mut RNDM<F, N>, |
147 μ: &mut DiscreteMeasure<Domain, F>, |
| 167 τv : &mut BTFN<F, GA, BTA, N>, |
148 τv: &mut M, |
| 168 μ_base : &RNDM<F, N>, |
149 μ_base: &DiscreteMeasure<Domain, F>, |
| 169 ν_delta: Option<&RNDM<F, N>>, |
150 τ: F, |
| 170 τ : F, |
151 ε: F, |
| 171 ε : F, |
152 config: &InsertionConfig<F>, |
| 172 config : &FBGenericConfig<F>, |
153 reg: &Reg, |
| 173 reg : &Reg, |
154 fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>, |
| 174 fitness : Option<impl Fn(&RNDM<F, N>) -> F>, |
155 ) -> usize { |
| 175 ) -> usize |
|
| 176 { |
|
| 177 if config.fitness_merging { |
156 if config.fitness_merging { |
| 178 if let Some(f) = fitness { |
157 if let Some(f) = fitness { |
| 179 return μ.merge_spikes_fitness(config.merging, f, |&v| v) |
158 return μ.merge_spikes_fitness(config.merging, f, |&v| v).1; |
| 180 .1 |
|
| 181 } |
159 } |
| 182 } |
160 } |
| 183 μ.merge_spikes(config.merging, |μ_candidate| { |
161 μ.merge_spikes(config.merging, |μ_candidate| { |
| 184 let mut d = &*τv + self.preapply(match ν_delta { |
162 let mut d = &*τv + self.preapply(μ_candidate.sub_matching(μ_base)); |
| 185 None => μ_candidate.sub_matching(μ_base), |
|
| 186 Some(ν) => μ_candidate.sub_matching(μ_base) - ν, |
|
| 187 }); |
|
| 188 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config) |
163 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config) |
| 189 }) |
164 }) |
| 190 } |
165 } |
| 191 } |
166 } |
| |
167 |
| |
168 #[replace_float_literals(F::cast_from(literal))] |
| |
169 impl<'a, F, A, 𝒟, Domain> StepLengthBound<F, QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>> |
| |
170 for 𝒟 |
| |
171 where |
| |
172 Domain: Space + Clone + PartialEq + 'static, |
| |
173 F: Float + ToNalgebraRealField, |
| |
174 𝒟: DiscreteMeasureOp<Domain, F>, |
| |
175 A: ForwardModel<DiscreteMeasure<Domain, F>, F> |
| |
176 + for<'b> BoundedLinear<DiscreteMeasure<Domain, F>, &'b 𝒟, L2, F>, |
| |
177 DiscreteMeasure<Domain, F>: for<'b> Norm<&'b 𝒟, F>, |
| |
178 for<'b> &'b 𝒟: NormExponent, |
| |
179 { |
| |
180 fn step_length_bound( |
| |
181 &self, |
| |
182 f: &QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>, |
| |
183 ) -> DynResult<F> { |
| |
184 // TODO: direct squared calculation |
| |
185 Ok(f.operator().opnorm_bound(self, L2)?.powi(2)) |
| |
186 } |
| |
187 } |
| |
188 |
| |
189 #[replace_float_literals(F::cast_from(literal))] |
| |
190 impl<F, A, 𝒟, Domain> StepLengthBoundPD<F, A, DiscreteMeasure<Domain, F>> for 𝒟 |
| |
191 where |
| |
192 Domain: Space + Clone + PartialEq + 'static, |
| |
193 F: Float + ToNalgebraRealField, |
| |
194 𝒟: DiscreteMeasureOp<Domain, F>, |
| |
195 A: for<'a> BoundedLinear<DiscreteMeasure<Domain, F>, &'a 𝒟, L2, F>, |
| |
196 DiscreteMeasure<Domain, F>: for<'a> Norm<&'a 𝒟, F>, |
| |
197 for<'b> &'b 𝒟: NormExponent, |
| |
198 { |
| |
199 fn step_length_bound_pd(&self, opA: &A) -> DynResult<F> { |
| |
200 opA.opnorm_bound(self, L2) |
| |
201 } |
| |
202 } |