src/prox_penalty.rs

branch
dev
changeset 61
4f468d35fa29
parent 51
0693cc9ba9f0
child 62
32328a74c790
child 63
7a8a55fd41c0
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
5 use alg_tools::types::*; 5 use alg_tools::types::*;
6 use numeric_literals::replace_float_literals; 6 use numeric_literals::replace_float_literals;
7 use serde::{Deserialize, Serialize}; 7 use serde::{Deserialize, Serialize};
8 8
9 use crate::measures::merging::SpikeMergingMethod; 9 use crate::measures::merging::SpikeMergingMethod;
10 use crate::measures::RNDM; 10 use crate::measures::DiscreteMeasure;
11 use crate::regularisation::RegTerm; 11 use crate::regularisation::RegTerm;
12 use crate::subproblem::InnerSettings; 12 use crate::subproblem::InnerSettings;
13 use crate::tolerance::Tolerance; 13 use crate::tolerance::Tolerance;
14 use crate::types::{IterInfo, RefinementSettings}; 14 use crate::types::{IterInfo, RefinementSettings};
15 use alg_tools::error::DynResult;
16 use alg_tools::instance::Space;
15 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration}; 17 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration};
16 use alg_tools::mapping::RealMapping; 18 use alg_tools::mapping::Mapping;
17 use alg_tools::nalgebra_support::ToNalgebraRealField; 19 use alg_tools::nalgebra_support::ToNalgebraRealField;
18 20
19 pub mod radon_squared; 21 pub mod radon_squared;
20 pub mod wave; 22 pub mod wave;
21 pub use radon_squared::RadonSquared; 23 pub use radon_squared::RadonSquared;
22 24
23 /// Settings for the solution of the stepwise optimality condition. 25 /// Settings for the solution of the stepwise optimality condition.
24 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 26 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
25 #[serde(default)] 27 #[serde(default)]
26 pub struct FBGenericConfig<F: Float> { 28 pub struct InsertionConfig<F: Float> {
27 /// Tolerance for point insertion. 29 /// Tolerance for point insertion.
28 pub tolerance: Tolerance<F>, 30 pub tolerance: Tolerance<F>,
29 31
30 /// Stop looking for predual maximum (where to insert a new point) below 32 /// Stop looking for predual maximum (where to insert a new point) below
31 /// `tolerance` multiplied by this factor. 33 /// `tolerance` multiplied by this factor.
66 // /// Save $μ$ for postprocessing optimisation 68 // /// Save $μ$ for postprocessing optimisation
67 // pub postprocessing : bool 69 // pub postprocessing : bool
68 } 70 }
69 71
70 #[replace_float_literals(F::cast_from(literal))] 72 #[replace_float_literals(F::cast_from(literal))]
71 impl<F: Float> Default for FBGenericConfig<F> { 73 impl<F: Float> Default for InsertionConfig<F> {
72 fn default() -> Self { 74 fn default() -> Self {
73 FBGenericConfig { 75 InsertionConfig {
74 tolerance: Default::default(), 76 tolerance: Default::default(),
75 insertion_cutoff_factor: 1.0, 77 insertion_cutoff_factor: 1.0,
76 refinement: Default::default(), 78 refinement: Default::default(),
77 max_insertions: 100, 79 max_insertions: 100,
78 //bootstrap_insertions : None, 80 //bootstrap_insertions : None,
86 // postprocessing : false, 88 // postprocessing : false,
87 } 89 }
88 } 90 }
89 } 91 }
90 92
91 impl<F: Float> FBGenericConfig<F> { 93 impl<F: Float> InsertionConfig<F> {
92 /// Check if merging should be attempted this iteration 94 /// Check if merging should be attempted this iteration
93 pub fn merge_now<I: AlgIterator>(&self, state: &AlgIteratorIteration<I>) -> bool { 95 pub fn merge_now<I: AlgIterator>(&self, state: &AlgIteratorIteration<I>) -> bool {
94 self.merging.enabled && state.iteration() % self.merge_every == 0 96 self.merging.enabled && state.iteration() % self.merge_every == 0
95 } 97 }
96 98
97 /// Returns the final merging method 99 /// Returns the final merging method
98 pub fn final_merging_method(&self) -> SpikeMergingMethod<F> { 100 pub fn final_merging_method(&self) -> SpikeMergingMethod<F> {
99 SpikeMergingMethod { 101 SpikeMergingMethod { enabled: self.final_merging, ..self.merging }
100 enabled: self.final_merging, 102 }
101 ..self.merging 103 }
102 } 104
103 } 105 /// Available proximal terms
106 #[derive(Copy, Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
107 pub enum ProxTerm {
108 /// Partial-to-wave operator 𝒟.
109 Wave,
110 /// Radon-norm squared
111 RadonSquared,
104 } 112 }
105 113
106 /// Trait for proximal penalties 114 /// Trait for proximal penalties
107 pub trait ProxPenalty<F, PreadjointCodomain, Reg, const N: usize> 115 pub trait ProxPenalty<Domain, PreadjointCodomain, Reg, F = f64>
108 where 116 where
109 F: Float + ToNalgebraRealField, 117 F: Float + ToNalgebraRealField,
110 Reg: RegTerm<F, N>, 118 Reg: RegTerm<Domain, F>,
119 Domain: Space + Clone,
111 { 120 {
112 type ReturnMapping: RealMapping<F, N>; 121 type ReturnMapping: Mapping<Domain, Codomain = F>;
122
123 /// Returns the type of this proximality penalty
124 fn prox_type() -> ProxTerm;
113 125
114 /// Insert new spikes into `μ` to approximately satisfy optimality conditions 126 /// Insert new spikes into `μ` to approximately satisfy optimality conditions
115 /// with the forward step term fixed to `τv`. 127 /// with the forward step term fixed to `τv`.
116 /// 128 ///
117 /// May return `τv + w` for `w` a subdifferential of the regularisation term `reg`, 129 /// May return `τv + w` for `w` a subdifferential of the regularisation term `reg`,
118 /// as well as an indication of whether the tolerance bounds `ε` are satisfied. 130 /// as well as an indication of whether the tolerance bounds `ε` are satisfied.
119 /// 131 ///
120 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same 132 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same
121 /// spike locations, while `ν_delta` may have different locations. 133 /// spike locations, while `ν_delta` may have different locations.
122 /// 134 ///
123 /// `τv` is mutable to allow [`alg_tools::bisection_tree::BTFN`] refinement. 135 /// `τv` is mutable to allow [`alg_tools::bounds::MinMaxMapping`] optimisation to
124 /// Actual values of `τv` are not supposed to be mutated. 136 /// refine data. Actual values of `τv` are not supposed to be mutated.
125 fn insert_and_reweigh<I>( 137 fn insert_and_reweigh<I>(
126 &self, 138 &self,
127 μ: &mut RNDM<F, N>, 139 μ: &mut DiscreteMeasure<Domain, F>,
128 τv: &mut PreadjointCodomain, 140 τv: &mut PreadjointCodomain,
129 μ_base: &RNDM<F, N>, 141 μ_base: &DiscreteMeasure<Domain, F>,
130 ν_delta: Option<&RNDM<F, N>>, 142 ν_delta: Option<&DiscreteMeasure<Domain, F>>,
131 τ: F, 143 τ: F,
132 ε: F, 144 ε: F,
133 config: &FBGenericConfig<F>, 145 config: &InsertionConfig<F>,
134 reg: &Reg, 146 reg: &Reg,
135 state: &AlgIteratorIteration<I>, 147 state: &AlgIteratorIteration<I>,
136 stats: &mut IterInfo<F, N>, 148 stats: &mut IterInfo<F>,
137 ) -> (Option<Self::ReturnMapping>, bool) 149 ) -> DynResult<(Option<Self::ReturnMapping>, bool)>
138 where 150 where
139 I: AlgIterator; 151 I: AlgIterator;
140 152
141 /// Merge spikes, if possible. 153 /// Merge spikes, if possible.
142 /// 154 ///
143 /// Either optimality condition merging or objective value (fitness) merging 155 /// Either optimality condition merging or objective value (fitness) merging
144 /// may be used, the latter only if `fitness` is provided and `config.fitness_merging` 156 /// may be used, the latter only if `fitness` is provided and `config.fitness_merging`
145 /// is set. 157 /// is set.
146 fn merge_spikes( 158 fn merge_spikes(
147 &self, 159 &self,
148 μ: &mut RNDM<F, N>, 160 μ: &mut DiscreteMeasure<Domain, F>,
149 τv: &mut PreadjointCodomain, 161 τv: &mut PreadjointCodomain,
150 μ_base: &RNDM<F, N>, 162 μ_base: &DiscreteMeasure<Domain, F>,
151 ν_delta: Option<&RNDM<F, N>>, 163 ν_delta: Option<&DiscreteMeasure<Domain, F>>,
152 τ: F, 164 τ: F,
153 ε: F, 165 ε: F,
154 config: &FBGenericConfig<F>, 166 config: &InsertionConfig<F>,
155 reg: &Reg, 167 reg: &Reg,
156 fitness: Option<impl Fn(&RNDM<F, N>) -> F>, 168 fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>,
157 ) -> usize; 169 ) -> usize;
158 170
159 /// Merge spikes, if possible. 171 /// Merge spikes, if possible.
160 /// 172 ///
161 /// Unlike [`Self::merge_spikes`], this variant only supports optimality condition based merging 173 /// Unlike [`Self::merge_spikes`], this variant only supports optimality condition based merging
162 #[inline] 174 #[inline]
163 fn merge_spikes_no_fitness( 175 fn merge_spikes_no_fitness(
164 &self, 176 &self,
165 μ: &mut RNDM<F, N>, 177 μ: &mut DiscreteMeasure<Domain, F>,
166 τv: &mut PreadjointCodomain, 178 τv: &mut PreadjointCodomain,
167 μ_base: &RNDM<F, N>, 179 μ_base: &DiscreteMeasure<Domain, F>,
168 ν_delta: Option<&RNDM<F, N>>, 180 ν_delta: Option<&DiscreteMeasure<Domain, F>>,
169 τ: F, 181 τ: F,
170 ε: F, 182 ε: F,
171 config: &FBGenericConfig<F>, 183 config: &InsertionConfig<F>,
172 reg: &Reg, 184 reg: &Reg,
173 ) -> usize { 185 ) -> usize {
174 /// This is a hack to create a `None` of same type as a `Some` 186 /// This is a hack to create a `None` of same type as a `Some`
175 // for the `impl Fn` parameter of `merge_spikes`. 187 // for the `impl Fn` parameter of `merge_spikes`.
176 #[inline] 188 #[inline]
184 ν_delta, 196 ν_delta,
185 τ, 197 τ,
186 ε, 198 ε,
187 config, 199 config,
188 reg, 200 reg,
189 into_none(Some(|_: &RNDM<F, N>| F::ZERO)), 201 into_none(Some(|_: &DiscreteMeasure<Domain, F>| F::ZERO)),
190 ) 202 )
191 } 203 }
192 } 204 }
205
206 /// Trait to calculate step length bound by `Dat` when the proximal penalty is `Self`,
207 /// which is typically also a [`ProxPenalty`]. If it is given by a (squared) norm $\|.\|_*$,
208 /// and `Dat` represents the function $f$, then this trait should calculate `L` such that
209 /// $\|f'(x)-f'(y)\| ≤ L\|x-y\|_*$, where the step length is supposed to satisfy $τ L ≤ 1$.
210 pub trait StepLengthBound<F: Float, Dat> {
211 /// Returns $L$.
212 fn step_length_bound(&self, f: &Dat) -> DynResult<F>;
213 }
214
215 /// A variant of [`StepLengthBound`] for step length parameters for [`Pair`]s of variables.
216 pub trait StepLengthBoundPair<F: Float, Dat> {
217 fn step_length_bound_pair(&self, f: &Dat) -> DynResult<(F, F)>;
218 }
219
220 /// Trait to calculate step length bound by the operator `A` when the proximal penalty is `Self`,
221 /// which is typically also a [`ProxPenalty`]. If it is given by a (squared) norm $\|.\|_*$,
222 /// then this trait should calculate `L` such that
223 /// $\|Ax\| ≤ L\|x\|_*$, where the primal-dual step lengths are supposed to satisfy $τσ L^2 ≤ 1$.
224 /// The domain needs to be specified here, because A can operate on various domains.
225 pub trait StepLengthBoundPD<F: Float, A, Domain> {
226 /// Returns $L$.
227 fn step_length_bound_pd(&self, f: &A) -> DynResult<F>;
228 }

mercurial