src/prox_penalty/wave.rs

branch
dev
changeset 61
4f468d35fa29
parent 39
6316d68b58af
child 62
32328a74c790
child 63
7a8a55fd41c0
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
1 /*! 1 /*!
2 Basic proximal penalty based on convolution operators $𝒟$. 2 Basic proximal penalty based on convolution operators $𝒟$.
3 */ 3 */
4 4
5 use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD};
6 use crate::dataterm::QuadraticDataTerm;
7 use crate::forward_model::ForwardModel;
8 use crate::measures::merging::SpikeMerging;
9 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon};
10 use crate::regularisation::RegTerm;
11 use crate::seminorms::DiscreteMeasureOp;
12 use crate::types::IterInfo;
13 use alg_tools::bounds::MinMaxMapping;
14 use alg_tools::error::DynResult;
15 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration};
16 use alg_tools::linops::BoundedLinear;
17 use alg_tools::mapping::{Instance, Mapping, Space};
18 use alg_tools::nalgebra_support::ToNalgebraRealField;
19 use alg_tools::norms::{Linfinity, Norm, NormExponent, L2};
20 use alg_tools::types::*;
21 use colored::Colorize;
22 use nalgebra::DVector;
5 use numeric_literals::replace_float_literals; 23 use numeric_literals::replace_float_literals;
6 use nalgebra::DVector;
7 use colored::Colorize;
8
9 use alg_tools::types::*;
10 use alg_tools::loc::Loc;
11 use alg_tools::mapping::{Mapping, RealMapping};
12 use alg_tools::nalgebra_support::ToNalgebraRealField;
13 use alg_tools::norms::Linfinity;
14 use alg_tools::iterate::{
15 AlgIteratorIteration,
16 AlgIterator,
17 };
18 use alg_tools::bisection_tree::{
19 BTFN,
20 PreBTFN,
21 Bounds,
22 BTSearch,
23 SupportGenerator,
24 LocalAnalysis,
25 BothGenerators,
26 };
27 use crate::measures::{
28 RNDM,
29 DeltaMeasure,
30 Radon,
31 };
32 use crate::measures::merging::{
33 SpikeMerging,
34 };
35 use crate::seminorms::DiscreteMeasureOp;
36 use crate::types::{
37 IterInfo,
38 };
39 use crate::regularisation::RegTerm;
40 use super::{ProxPenalty, FBGenericConfig};
41 24
42 #[replace_float_literals(F::cast_from(literal))] 25 #[replace_float_literals(F::cast_from(literal))]
43 impl<F, GA, BTA, S, Reg, 𝒟, G𝒟, K, const N : usize> 26 impl<F, M, Reg, 𝒟, O, Domain> ProxPenalty<Domain, M, Reg, F> for 𝒟
44 ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for 𝒟
45 where 27 where
46 F : Float + ToNalgebraRealField, 28 Domain: Space + Clone + PartialEq + 'static,
47 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 29 for<'a> &'a Domain: Instance<Domain>,
48 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 30 F: Float + ToNalgebraRealField,
49 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 31 𝒟: DiscreteMeasureOp<Domain, F>,
50 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 32 𝒟::Codomain: Mapping<Domain, Codomain = F>,
51 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 33 M: Mapping<Domain, Codomain = F>,
52 𝒟::Codomain : RealMapping<F, N>, 34 for<'a> &'a M: std::ops::Add<𝒟::PreCodomain, Output = O>,
53 K : RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 35 O: MinMaxMapping<Domain, F>,
54 Reg : RegTerm<F, N>, 36 Reg: RegTerm<Domain, F>,
55 RNDM<F, N> : SpikeMerging<F>, 37 DiscreteMeasure<Domain, F>: SpikeMerging<F>,
56 { 38 {
57 type ReturnMapping = BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>; 39 type ReturnMapping = O;
40
41 fn prox_type() -> ProxTerm {
42 ProxTerm::Wave
43 }
58 44
59 fn insert_and_reweigh<I>( 45 fn insert_and_reweigh<I>(
60 &self, 46 &self,
61 μ : &mut RNDM<F, N>, 47 μ: &mut DiscreteMeasure<Domain, F>,
62 τv : &mut BTFN<F, GA, BTA, N>, 48 τv: &mut M,
63 μ_base : &RNDM<F, N>, 49 μ_base: &DiscreteMeasure<Domain, F>,
64 ν_delta: Option<&RNDM<F, N>>, 50 ν_delta: Option<&DiscreteMeasure<Domain, F>>,
65 τ : F, 51 τ: F,
66 ε : F, 52 ε: F,
67 config : &FBGenericConfig<F>, 53 config: &InsertionConfig<F>,
68 reg : &Reg, 54 reg: &Reg,
69 state : &AlgIteratorIteration<I>, 55 state: &AlgIteratorIteration<I>,
70 stats : &mut IterInfo<F, N>, 56 stats: &mut IterInfo<F>,
71 ) -> (Option<BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>>, bool) 57 ) -> DynResult<(Option<Self::ReturnMapping>, bool)>
72 where 58 where
73 I : AlgIterator 59 I: AlgIterator,
74 { 60 {
75 61 let op𝒟norm = self.opnorm_bound(Radon, Linfinity)?;
76 let op𝒟norm = self.opnorm_bound(Radon, Linfinity);
77 62
78 // Maximum insertion count and measure difference calculation depend on insertion style. 63 // Maximum insertion count and measure difference calculation depend on insertion style.
79 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { 64 let (max_insertions, warn_insertions) =
80 (i, Some((l, k))) if i <= l => (k, false), 65 match (state.iteration(), config.bootstrap_insertions) {
81 _ => (config.max_insertions, !state.is_quiet()), 66 (i, Some((l, k))) if i <= l => (k, false),
82 }; 67 _ => (config.max_insertions, !state.is_quiet()),
68 };
83 69
84 let ω0 = match ν_delta { 70 let ω0 = match ν_delta {
85 None => self.apply(μ_base), 71 None => self.apply(μ_base),
86 Some(ν) => self.apply(μ_base + ν), 72 Some(ν) => self.apply(μ_base + ν),
87 }; 73 };
93 // Form finite-dimensional subproblem. The subproblem references to the original μ^k 79 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
94 // from the beginning of the iteration are all contained in the immutable c and g. 80 // from the beginning of the iteration are all contained in the immutable c and g.
95 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional 81 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
96 // problems have not yet been updated to sign change. 82 // problems have not yet been updated to sign change.
97 let à = self.findim_matrix(μ.iter_locations()); 83 let à = self.findim_matrix(μ.iter_locations());
98 let g̃ = DVector::from_iterator(μ.len(), 84 let g̃ = DVector::from_iterator(
99 μ.iter_locations() 85 μ.len(),
100 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) 86 μ.iter_locations()
101 .map(F::to_nalgebra_mixed)); 87 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
88 .map(F::to_nalgebra_mixed),
89 );
102 let mut x = μ.masses_dvector(); 90 let mut x = μ.masses_dvector();
103 91
104 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. 92 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
105 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ 93 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
106 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ 94 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
115 μ.set_masses_dvector(&x); 103 μ.set_masses_dvector(&x);
116 } 104 }
117 105
118 // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality 106 // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality
119 // conditions in the predual space, and finding new points for insertion, if necessary. 107 // conditions in the predual space, and finding new points for insertion, if necessary.
120 let mut d = &*τv + match ν_delta { 108 let mut d = &*τv
121 None => self.preapply(μ.sub_matching(μ_base)), 109 + match ν_delta {
122 Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν) 110 None => self.preapply(μ.sub_matching(μ_base)),
123 }; 111 Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν),
112 };
124 113
125 // If no merging heuristic is used, let's be more conservative about spike insertion, 114 // If no merging heuristic is used, let's be more conservative about spike insertion,
126 // and skip it after first round. If merging is done, being more greedy about spike 115 // and skip it after first round. If merging is done, being more greedy about spike
127 // insertion also seems to improve performance. 116 // insertion also seems to improve performance.
128 let skip_by_rough_check = if config.merging.enabled { 117 let skip_by_rough_check = if config.merging.enabled {
130 } else { 119 } else {
131 count > 0 120 count > 0
132 }; 121 };
133 122
134 // Find a spike to insert, if needed 123 // Find a spike to insert, if needed
135 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( 124 let (ξ, _v_ξ, in_bounds) =
136 &mut d, τ, ε, skip_by_rough_check, config 125 match reg.find_tolerance_violation(&mut d, τ, ε, skip_by_rough_check, config) {
137 ) { 126 None => break 'insertion (true, d),
138 None => break 'insertion (true, d), 127 Some(res) => res,
139 Some(res) => res, 128 };
140 };
141 129
142 // Break if maximum insertion count reached 130 // Break if maximum insertion count reached
143 if count >= max_insertions { 131 if count >= max_insertions {
144 break 'insertion (in_bounds, d) 132 break 'insertion (in_bounds, d);
145 } 133 }
146 134
147 // No point in optimising the weight here; the finite-dimensional algorithm is fast. 135 // No point in optimising the weight here; the finite-dimensional algorithm is fast.
148 *μ += DeltaMeasure { x : ξ, α : 0.0 }; 136 *μ += DeltaMeasure { x: ξ, α: 0.0 };
149 count += 1; 137 count += 1;
150 stats.inserted += 1; 138 stats.inserted += 1;
151 }; 139 };
152 140
153 if !within_tolerances && warn_insertions { 141 if !within_tolerances && warn_insertions {
154 // Complain (but continue) if we failed to get within tolerances 142 // Complain (but continue) if we failed to get within tolerances
155 // by inserting more points. 143 // by inserting more points.
156 let err = format!("Maximum insertions reached without achieving \ 144 let err = format!(
157 subproblem solution tolerance"); 145 "Maximum insertions reached without achieving \
146 subproblem solution tolerance"
147 );
158 println!("{}", err.red()); 148 println!("{}", err.red());
159 } 149 }
160 150
161 (Some(d), within_tolerances) 151 Ok((Some(d), within_tolerances))
162 } 152 }
163 153
164 fn merge_spikes( 154 fn merge_spikes(
165 &self, 155 &self,
166 μ : &mut RNDM<F, N>, 156 μ: &mut DiscreteMeasure<Domain, F>,
167 τv : &mut BTFN<F, GA, BTA, N>, 157 τv: &mut M,
168 μ_base : &RNDM<F, N>, 158 μ_base: &DiscreteMeasure<Domain, F>,
169 ν_delta: Option<&RNDM<F, N>>, 159 ν_delta: Option<&DiscreteMeasure<Domain, F>>,
170 τ : F, 160 τ: F,
171 ε : F, 161 ε: F,
172 config : &FBGenericConfig<F>, 162 config: &InsertionConfig<F>,
173 reg : &Reg, 163 reg: &Reg,
174 fitness : Option<impl Fn(&RNDM<F, N>) -> F>, 164 fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>,
175 ) -> usize 165 ) -> usize {
176 {
177 if config.fitness_merging { 166 if config.fitness_merging {
178 if let Some(f) = fitness { 167 if let Some(f) = fitness {
179 return μ.merge_spikes_fitness(config.merging, f, |&v| v) 168 return μ.merge_spikes_fitness(config.merging, f, |&v| v).1;
180 .1
181 } 169 }
182 } 170 }
183 μ.merge_spikes(config.merging, |μ_candidate| { 171 μ.merge_spikes(config.merging, |μ_candidate| {
184 let mut d = &*τv + self.preapply(match ν_delta { 172 let mut d = &*τv
185 None => μ_candidate.sub_matching(μ_base), 173 + self.preapply(match ν_delta {
186 Some(ν) => μ_candidate.sub_matching(μ_base) - ν, 174 None => μ_candidate.sub_matching(μ_base),
187 }); 175 Some(ν) => μ_candidate.sub_matching(μ_base) - ν,
176 });
188 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config) 177 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config)
189 }) 178 })
190 } 179 }
191 } 180 }
181
182 #[replace_float_literals(F::cast_from(literal))]
183 impl<'a, F, A, 𝒟, Domain> StepLengthBound<F, QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>>
184 for 𝒟
185 where
186 Domain: Space + Clone + PartialEq + 'static,
187 F: Float + ToNalgebraRealField,
188 𝒟: DiscreteMeasureOp<Domain, F>,
189 A: ForwardModel<DiscreteMeasure<Domain, F>, F>
190 + for<'b> BoundedLinear<DiscreteMeasure<Domain, F>, &'b 𝒟, L2, F>,
191 DiscreteMeasure<Domain, F>: for<'b> Norm<&'b 𝒟, F>,
192 for<'b> &'b 𝒟: NormExponent,
193 {
194 fn step_length_bound(
195 &self,
196 f: &QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>,
197 ) -> DynResult<F> {
198 // TODO: direct squared calculation
199 Ok(f.operator().opnorm_bound(self, L2)?.powi(2))
200 }
201 }
202
203 #[replace_float_literals(F::cast_from(literal))]
204 impl<F, A, 𝒟, Domain> StepLengthBoundPD<F, A, DiscreteMeasure<Domain, F>> for 𝒟
205 where
206 Domain: Space + Clone + PartialEq + 'static,
207 F: Float + ToNalgebraRealField,
208 𝒟: DiscreteMeasureOp<Domain, F>,
209 A: for<'a> BoundedLinear<DiscreteMeasure<Domain, F>, &'a 𝒟, L2, F>,
210 DiscreteMeasure<Domain, F>: for<'a> Norm<&'a 𝒟, F>,
211 for<'b> &'b 𝒟: NormExponent,
212 {
213 fn step_length_bound_pd(&self, opA: &A) -> DynResult<F> {
214 opA.opnorm_bound(self, L2)
215 }
216 }

mercurial