src/prox_penalty/wave.rs

changeset 70
ed16d0f10d08
parent 63
7a8a55fd41c0
equal deleted inserted replaced
58:6099ba025aac 70:ed16d0f10d08
1 /*! 1 /*!
2 Basic proximal penalty based on convolution operators $𝒟$. 2 Basic proximal penalty based on convolution operators $𝒟$.
3 */ 3 */
4 4
5 use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD};
6 use crate::dataterm::QuadraticDataTerm;
7 use crate::forward_model::ForwardModel;
8 use crate::measures::merging::SpikeMerging;
9 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon};
10 use crate::regularisation::RegTerm;
11 use crate::seminorms::DiscreteMeasureOp;
12 use crate::types::IterInfo;
13 use alg_tools::bounds::MinMaxMapping;
14 use alg_tools::error::DynResult;
15 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration};
16 use alg_tools::linops::BoundedLinear;
17 use alg_tools::mapping::{Instance, Mapping, Space};
18 use alg_tools::nalgebra_support::ToNalgebraRealField;
19 use alg_tools::norms::{Linfinity, Norm, NormExponent, L2};
20 use alg_tools::types::*;
21 use colored::Colorize;
22 use nalgebra::DVector;
5 use numeric_literals::replace_float_literals; 23 use numeric_literals::replace_float_literals;
6 use nalgebra::DVector;
7 use colored::Colorize;
8
9 use alg_tools::types::*;
10 use alg_tools::loc::Loc;
11 use alg_tools::mapping::{Mapping, RealMapping};
12 use alg_tools::nalgebra_support::ToNalgebraRealField;
13 use alg_tools::norms::Linfinity;
14 use alg_tools::iterate::{
15 AlgIteratorIteration,
16 AlgIterator,
17 };
18 use alg_tools::bisection_tree::{
19 BTFN,
20 PreBTFN,
21 Bounds,
22 BTSearch,
23 SupportGenerator,
24 LocalAnalysis,
25 BothGenerators,
26 };
27 use crate::measures::{
28 RNDM,
29 DeltaMeasure,
30 Radon,
31 };
32 use crate::measures::merging::{
33 SpikeMerging,
34 };
35 use crate::seminorms::DiscreteMeasureOp;
36 use crate::types::{
37 IterInfo,
38 };
39 use crate::regularisation::RegTerm;
40 use super::{ProxPenalty, FBGenericConfig};
41 24
42 #[replace_float_literals(F::cast_from(literal))] 25 #[replace_float_literals(F::cast_from(literal))]
43 impl<F, GA, BTA, S, Reg, 𝒟, G𝒟, K, const N : usize> 26 impl<F, M, Reg, 𝒟, O, Domain> ProxPenalty<Domain, M, Reg, F> for 𝒟
44 ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for 𝒟
45 where 27 where
46 F : Float + ToNalgebraRealField, 28 Domain: Space + Clone + PartialEq + 'static,
47 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 29 for<'a> &'a Domain: Instance<Domain>,
48 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 30 F: Float + ToNalgebraRealField,
49 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 31 𝒟: DiscreteMeasureOp<Domain, F>,
50 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 32 𝒟::Codomain: Mapping<Domain, Codomain = F>,
51 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 33 M: Mapping<Domain, Codomain = F>,
52 𝒟::Codomain : RealMapping<F, N>, 34 for<'a> &'a M: std::ops::Add<𝒟::PreCodomain, Output = O>,
53 K : RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 35 O: MinMaxMapping<Domain, F>,
54 Reg : RegTerm<F, N>, 36 Reg: RegTerm<Domain, F>,
55 RNDM<F, N> : SpikeMerging<F>, 37 DiscreteMeasure<Domain, F>: SpikeMerging<F>,
56 { 38 {
57 type ReturnMapping = BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>; 39 type ReturnMapping = O;
40
41 fn prox_type() -> ProxTerm {
42 ProxTerm::Wave
43 }
58 44
59 fn insert_and_reweigh<I>( 45 fn insert_and_reweigh<I>(
60 &self, 46 &self,
61 μ : &mut RNDM<F, N>, 47 μ: &mut DiscreteMeasure<Domain, F>,
62 τv : &mut BTFN<F, GA, BTA, N>, 48 τv: &mut M,
63 μ_base : &RNDM<F, N>, 49 τ: F,
64 ν_delta: Option<&RNDM<F, N>>, 50 ε: F,
65 τ : F, 51 config: &InsertionConfig<F>,
66 ε : F, 52 reg: &Reg,
67 config : &FBGenericConfig<F>, 53 state: &AlgIteratorIteration<I>,
68 reg : &Reg, 54 stats: &mut IterInfo<F>,
69 state : &AlgIteratorIteration<I>, 55 ) -> DynResult<(Option<Self::ReturnMapping>, bool)>
70 stats : &mut IterInfo<F, N>,
71 ) -> (Option<BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>>, bool)
72 where 56 where
73 I : AlgIterator 57 I: AlgIterator,
74 { 58 {
75 59 let op𝒟norm = self.opnorm_bound(Radon, Linfinity)?;
76 let op𝒟norm = self.opnorm_bound(Radon, Linfinity);
77 60
78 // Maximum insertion count and measure difference calculation depend on insertion style. 61 // Maximum insertion count and measure difference calculation depend on insertion style.
79 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { 62 let (max_insertions, warn_insertions) =
80 (i, Some((l, k))) if i <= l => (k, false), 63 match (state.iteration(), config.bootstrap_insertions) {
81 _ => (config.max_insertions, !state.is_quiet()), 64 (i, Some((l, k))) if i <= l => (k, false),
82 }; 65 _ => (config.max_insertions, !state.is_quiet()),
83 66 };
84 let ω0 = match ν_delta { 67
85 None => self.apply(μ_base), 68 let μ_base = μ.clone();
86 Some(ν) => self.apply(μ_base + ν), 69 let ω0 = self.apply(&μ_base);
87 };
88 70
89 // Add points to support until within error tolerance or maximum insertion count reached. 71 // Add points to support until within error tolerance or maximum insertion count reached.
90 let mut count = 0; 72 let mut count = 0;
91 let (within_tolerances, d) = 'insertion: loop { 73 let (within_tolerances, d) = 'insertion: loop {
92 if μ.len() > 0 { 74 if μ.len() > 0 {
93 // Form finite-dimensional subproblem. The subproblem references to the original μ^k 75 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
94 // from the beginning of the iteration are all contained in the immutable c and g. 76 // from the beginning of the iteration are all contained in the immutable c and g.
95 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional 77 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
96 // problems have not yet been updated to sign change. 78 // problems have not yet been updated to sign change.
97 let à = self.findim_matrix(μ.iter_locations()); 79 let à = self.findim_matrix(μ.iter_locations());
98 let g̃ = DVector::from_iterator(μ.len(), 80 let g̃ = DVector::from_iterator(
99 μ.iter_locations() 81 μ.len(),
100 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) 82 μ.iter_locations()
101 .map(F::to_nalgebra_mixed)); 83 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
84 .map(F::to_nalgebra_mixed),
85 );
102 let mut x = μ.masses_dvector(); 86 let mut x = μ.masses_dvector();
103 87
104 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. 88 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
105 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ 89 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
106 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ 90 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
115 μ.set_masses_dvector(&x); 99 μ.set_masses_dvector(&x);
116 } 100 }
117 101
118 // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality 102 // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality
119 // conditions in the predual space, and finding new points for insertion, if necessary. 103 // conditions in the predual space, and finding new points for insertion, if necessary.
120 let mut d = &*τv + match ν_delta { 104 let mut d = &*τv + self.preapply(μ.sub_matching(&μ_base));
121 None => self.preapply(μ.sub_matching(μ_base)),
122 Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν)
123 };
124 105
125 // If no merging heuristic is used, let's be more conservative about spike insertion, 106 // If no merging heuristic is used, let's be more conservative about spike insertion,
126 // and skip it after first round. If merging is done, being more greedy about spike 107 // and skip it after first round. If merging is done, being more greedy about spike
127 // insertion also seems to improve performance. 108 // insertion also seems to improve performance.
128 let skip_by_rough_check = if config.merging.enabled { 109 let skip_by_rough_check = if config.merging.enabled {
130 } else { 111 } else {
131 count > 0 112 count > 0
132 }; 113 };
133 114
134 // Find a spike to insert, if needed 115 // Find a spike to insert, if needed
135 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( 116 let (ξ, _v_ξ, in_bounds) =
136 &mut d, τ, ε, skip_by_rough_check, config 117 match reg.find_tolerance_violation(&mut d, τ, ε, skip_by_rough_check, config) {
137 ) { 118 None => break 'insertion (true, d),
138 None => break 'insertion (true, d), 119 Some(res) => res,
139 Some(res) => res, 120 };
140 };
141 121
142 // Break if maximum insertion count reached 122 // Break if maximum insertion count reached
143 if count >= max_insertions { 123 if count >= max_insertions {
144 break 'insertion (in_bounds, d) 124 break 'insertion (in_bounds, d);
145 } 125 }
146 126
147 // No point in optimising the weight here; the finite-dimensional algorithm is fast. 127 // No point in optimising the weight here; the finite-dimensional algorithm is fast.
148 *μ += DeltaMeasure { x : ξ, α : 0.0 }; 128 *μ += DeltaMeasure { x: ξ, α: 0.0 };
149 count += 1; 129 count += 1;
150 stats.inserted += 1;
151 }; 130 };
152 131
153 if !within_tolerances && warn_insertions { 132 if !within_tolerances && warn_insertions {
154 // Complain (but continue) if we failed to get within tolerances 133 // Complain (but continue) if we failed to get within tolerances
155 // by inserting more points. 134 // by inserting more points.
156 let err = format!("Maximum insertions reached without achieving \ 135 let err = format!(
157 subproblem solution tolerance"); 136 "Maximum insertions reached without achieving \
137 subproblem solution tolerance"
138 );
158 println!("{}", err.red()); 139 println!("{}", err.red());
159 } 140 }
160 141
161 (Some(d), within_tolerances) 142 Ok((Some(d), within_tolerances))
162 } 143 }
163 144
164 fn merge_spikes( 145 fn merge_spikes(
165 &self, 146 &self,
166 μ : &mut RNDM<F, N>, 147 μ: &mut DiscreteMeasure<Domain, F>,
167 τv : &mut BTFN<F, GA, BTA, N>, 148 τv: &mut M,
168 μ_base : &RNDM<F, N>, 149 μ_base: &DiscreteMeasure<Domain, F>,
169 ν_delta: Option<&RNDM<F, N>>, 150 τ: F,
170 τ : F, 151 ε: F,
171 ε : F, 152 config: &InsertionConfig<F>,
172 config : &FBGenericConfig<F>, 153 reg: &Reg,
173 reg : &Reg, 154 fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>,
174 fitness : Option<impl Fn(&RNDM<F, N>) -> F>, 155 ) -> usize {
175 ) -> usize
176 {
177 if config.fitness_merging { 156 if config.fitness_merging {
178 if let Some(f) = fitness { 157 if let Some(f) = fitness {
179 return μ.merge_spikes_fitness(config.merging, f, |&v| v) 158 return μ.merge_spikes_fitness(config.merging, f, |&v| v).1;
180 .1
181 } 159 }
182 } 160 }
183 μ.merge_spikes(config.merging, |μ_candidate| { 161 μ.merge_spikes(config.merging, |μ_candidate| {
184 let mut d = &*τv + self.preapply(match ν_delta { 162 let mut d = &*τv + self.preapply(μ_candidate.sub_matching(μ_base));
185 None => μ_candidate.sub_matching(μ_base),
186 Some(ν) => μ_candidate.sub_matching(μ_base) - ν,
187 });
188 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config) 163 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config)
189 }) 164 })
190 } 165 }
191 } 166 }
167
168 #[replace_float_literals(F::cast_from(literal))]
169 impl<'a, F, A, 𝒟, Domain> StepLengthBound<F, QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>>
170 for 𝒟
171 where
172 Domain: Space + Clone + PartialEq + 'static,
173 F: Float + ToNalgebraRealField,
174 𝒟: DiscreteMeasureOp<Domain, F>,
175 A: ForwardModel<DiscreteMeasure<Domain, F>, F>
176 + for<'b> BoundedLinear<DiscreteMeasure<Domain, F>, &'b 𝒟, L2, F>,
177 DiscreteMeasure<Domain, F>: for<'b> Norm<&'b 𝒟, F>,
178 for<'b> &'b 𝒟: NormExponent,
179 {
180 fn step_length_bound(
181 &self,
182 f: &QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>,
183 ) -> DynResult<F> {
184 // TODO: direct squared calculation
185 Ok(f.operator().opnorm_bound(self, L2)?.powi(2))
186 }
187 }
188
189 #[replace_float_literals(F::cast_from(literal))]
190 impl<F, A, 𝒟, Domain> StepLengthBoundPD<F, A, DiscreteMeasure<Domain, F>> for 𝒟
191 where
192 Domain: Space + Clone + PartialEq + 'static,
193 F: Float + ToNalgebraRealField,
194 𝒟: DiscreteMeasureOp<Domain, F>,
195 A: for<'a> BoundedLinear<DiscreteMeasure<Domain, F>, &'a 𝒟, L2, F>,
196 DiscreteMeasure<Domain, F>: for<'a> Norm<&'a 𝒟, F>,
197 for<'b> &'b 𝒟: NormExponent,
198 {
199 fn step_length_bound_pd(&self, opA: &A) -> DynResult<F> {
200 opA.opnorm_bound(self, L2)
201 }
202 }

mercurial