| 78 */ |
78 */ |
| 79 |
79 |
| 80 use crate::measures::merging::SpikeMerging; |
80 use crate::measures::merging::SpikeMerging; |
| 81 use crate::measures::{DiscreteMeasure, RNDM}; |
81 use crate::measures::{DiscreteMeasure, RNDM}; |
| 82 use crate::plot::Plotter; |
82 use crate::plot::Plotter; |
| |
83 use crate::prox_penalty::StepLengthBoundValue; |
| 83 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound}; |
84 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound}; |
| 84 use crate::regularisation::RegTerm; |
85 use crate::regularisation::RegTerm; |
| 85 use crate::types::*; |
86 use crate::types::*; |
| 86 use alg_tools::error::DynResult; |
87 use alg_tools::error::DynResult; |
| 87 use alg_tools::instance::Instance; |
88 use alg_tools::instance::{ClosedSpace, Instance}; |
| 88 use alg_tools::iterate::AlgIteratorFactory; |
89 use alg_tools::iterate::AlgIteratorFactory; |
| 89 use alg_tools::mapping::DifferentiableMapping; |
90 use alg_tools::mapping::{DifferentiableMapping, Mapping}; |
| 90 use alg_tools::nalgebra_support::ToNalgebraRealField; |
91 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| |
92 use anyhow::anyhow; |
| 91 use colored::Colorize; |
93 use colored::Colorize; |
| 92 use numeric_literals::replace_float_literals; |
94 use numeric_literals::replace_float_literals; |
| 93 use serde::{Deserialize, Serialize}; |
95 use serde::{Deserialize, Serialize}; |
| 94 |
96 |
| 95 /// Settings for [`pointsource_fb_reg`]. |
97 /// Settings for [`pointsource_fb_reg`]. |
| 100 pub τ0: F, |
102 pub τ0: F, |
| 101 // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`] |
103 // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`] |
| 102 pub σp0: F, |
104 pub σp0: F, |
| 103 /// Generic parameters |
105 /// Generic parameters |
| 104 pub insertion: InsertionConfig<F>, |
106 pub insertion: InsertionConfig<F>, |
| |
107 /// Always adaptive step length |
| |
108 pub always_adaptive_τ: bool, |
| 105 } |
109 } |
| 106 |
110 |
| 107 #[replace_float_literals(F::cast_from(literal))] |
111 #[replace_float_literals(F::cast_from(literal))] |
| 108 impl<F: Float> Default for FBConfig<F> { |
112 impl<F: Float> Default for FBConfig<F> { |
| 109 fn default() -> Self { |
113 fn default() -> Self { |
| 110 FBConfig { |
114 FBConfig { τ0: 0.99, σp0: 0.99, always_adaptive_τ: false, insertion: Default::default() } |
| 111 τ0: 0.99, |
|
| 112 σp0: 0.99, |
|
| 113 insertion: Default::default(), |
|
| 114 } |
|
| 115 } |
115 } |
| 116 } |
116 } |
| 117 |
117 |
| 118 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize { |
118 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize { |
| 119 let n_before_prune = μ.len(); |
119 let n_before_prune = μ.len(); |
| 120 μ.prune(); |
120 μ.prune(); |
| 121 debug_assert!(μ.len() <= n_before_prune); |
121 debug_assert!(μ.len() <= n_before_prune); |
| 122 n_before_prune - μ.len() |
122 n_before_prune - μ.len() |
| |
123 } |
| |
124 |
| |
125 /// Adaptive step length and Lipschitz parameter estimation state. |
| |
126 #[derive(Clone, Debug, Serialize)] |
| |
127 pub enum AdaptiveStepLength<const N: usize, F: Float> { |
| |
128 Adaptive { |
| |
129 l: F, |
| |
130 μ_old: RNDM<N, F>, |
| |
131 fμ_old: F, |
| |
132 μ_dist: F, |
| |
133 τ0: F, |
| |
134 l_is_initial: bool, |
| |
135 }, |
| |
136 Fixed { |
| |
137 τ: F, |
| |
138 }, |
| |
139 } |
| |
140 |
| |
141 #[replace_float_literals(F::cast_from(literal))] |
| |
142 impl<const N: usize, F: Float> AdaptiveStepLength<N, F> { |
| |
143 pub fn new<Dat, Reg, P>(f: &Dat, prox_penalty: &P, fbconfig: &FBConfig<F>) -> DynResult<Self> |
| |
144 where |
| |
145 F: ToNalgebraRealField, |
| |
146 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, |
| |
147 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| |
148 Reg: RegTerm<Loc<N, F>, F>, |
| |
149 { |
| |
150 match ( |
| |
151 prox_penalty.step_length_bound(&f), |
| |
152 fbconfig.always_adaptive_τ, |
| |
153 ) { |
| |
154 (StepLengthBoundValue::LipschitzFactor(l), false) => { |
| |
155 Ok(AdaptiveStepLength::Fixed { τ: fbconfig.τ0 / l }) |
| |
156 } |
| |
157 (StepLengthBoundValue::LipschitzFactor(l), true) => { |
| |
158 let μ_old = DiscreteMeasure::new(); |
| |
159 let fμ_old = f.apply(&μ_old); |
| |
160 Ok(AdaptiveStepLength::Adaptive { |
| |
161 l: l, |
| |
162 μ_old, |
| |
163 fμ_old, |
| |
164 μ_dist: 0.0, |
| |
165 τ0: fbconfig.τ0, |
| |
166 l_is_initial: false, |
| |
167 }) |
| |
168 } |
| |
169 (StepLengthBoundValue::UnreliableLipschitzFactor(l), _) => { |
| |
170 println!("Lipschitz factor is unreliable; calculating adaptively."); |
| |
171 let μ_old = DiscreteMeasure::new(); |
| |
172 let fμ_old = f.apply(&μ_old); |
| |
173 Ok(AdaptiveStepLength::Adaptive { |
| |
174 l: l, |
| |
175 μ_old, |
| |
176 fμ_old, |
| |
177 μ_dist: 0.0, |
| |
178 τ0: fbconfig.τ0, |
| |
179 l_is_initial: true, |
| |
180 }) |
| |
181 } |
| |
182 (StepLengthBoundValue::Failure, _) => Err(anyhow!("No Lipschitz estimate available")), |
| |
183 } |
| |
184 } |
| |
185 |
| |
186 /// Returns the current value of the step length parameter. |
| |
187 pub fn current(&self) -> F { |
| |
188 match *self { |
| |
189 AdaptiveStepLength::Adaptive { τ0, l, .. } => τ0 / l, |
| |
190 AdaptiveStepLength::Fixed { τ } => τ, |
| |
191 } |
| |
192 } |
| |
193 |
| |
194 /// Update daptive Lipschitz factor and return new step length parameter `τ`. |
| |
195 /// |
| |
196 /// Inputs: |
| |
197 /// * `μ`: current point |
| |
198 /// * `fμ`: value of the function `f` at `μ`. |
| |
199 /// * `ν`: derivative of the function `f` at `μ`. |
| |
200 /// * `τ0`: fractional step length parameter in $[0, 1)$. |
| |
201 pub fn update<'a, G>(&mut self, μ: &RNDM<N, F>, fμ: F, v: &'a G) -> F |
| |
202 where |
| |
203 G: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace, |
| |
204 &'a G: Instance<G>, |
| |
205 { |
| |
206 match self { |
| |
207 AdaptiveStepLength::Adaptive { l, μ_old, fμ_old, μ_dist, τ0, l_is_initial } => { |
| |
208 // Estimate step length parameter |
| |
209 let b = *fμ_old - fμ - μ_old.apply(v) + μ.apply(v); |
| |
210 let d = *μ_dist; |
| |
211 if d.abs() > F::EPSILON && μ.len() > 0 && μ_old.len() > 0 { |
| |
212 let lc = b / (d * d / 2.0); |
| |
213 dbg!(b, d, lc); |
| |
214 if *l_is_initial { |
| |
215 *l = lc; |
| |
216 *l_is_initial = false; |
| |
217 } else { |
| |
218 *l = l.max(lc); |
| |
219 } |
| |
220 } |
| |
221 |
| |
222 // Store for next iteration |
| |
223 *μ_old = μ.clone(); |
| |
224 *fμ_old = fμ; |
| |
225 |
| |
226 return *τ0 / *l; |
| |
227 } |
| |
228 AdaptiveStepLength::Fixed { τ } => *τ, |
| |
229 } |
| |
230 } |
| |
231 |
| |
232 /// Finalises a step, storing μ and its distance to the previous μ. |
| |
233 /// |
| |
234 /// This is not included in [`Self::update`], as this function is to be called |
| |
235 /// before pruning and merging, while μ and its previous version in their internal |
| |
236 /// presentation still having matching indices for the same coordinate. |
| |
237 pub fn finish_step(&mut self, μ: &RNDM<N, F>) { |
| |
238 if let AdaptiveStepLength::Adaptive { μ_dist, μ_old, .. } = self { |
| |
239 *μ_dist = μ.dist_matching(&μ_old); |
| |
240 } |
| |
241 } |
| 123 } |
242 } |
| 124 |
243 |
| 125 #[replace_float_literals(F::cast_from(literal))] |
244 #[replace_float_literals(F::cast_from(literal))] |
| 126 pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>( |
245 pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>( |
| 127 mut μ: RNDM<N, F>, |
246 mut μ: RNDM<N, F>, |
| 155 reg: &Reg, |
274 reg: &Reg, |
| 156 prox_penalty: &P, |
275 prox_penalty: &P, |
| 157 fbconfig: &FBConfig<F>, |
276 fbconfig: &FBConfig<F>, |
| 158 iterator: I, |
277 iterator: I, |
| 159 mut plotter: Plot, |
278 mut plotter: Plot, |
| 160 μ0 : Option<RNDM<N, F>>, |
279 μ0: Option<RNDM<N, F>>, |
| 161 ) -> DynResult<RNDM<N, F>> |
280 ) -> DynResult<RNDM<N, F>> |
| 162 where |
281 where |
| 163 F: Float + ToNalgebraRealField, |
282 F: Float + ToNalgebraRealField, |
| 164 I: AlgIteratorFactory<IterInfo<F>>, |
283 I: AlgIteratorFactory<IterInfo<F>>, |
| 165 RNDM<N, F>: SpikeMerging<F>, |
284 RNDM<N, F>: SpikeMerging<F>, |
| 166 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, |
285 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, |
| 167 Dat::DerivativeDomain: ClosedMul<F>, |
286 Dat::DerivativeDomain: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace, |
| 168 Reg: RegTerm<Loc<N, F>, F>, |
287 Reg: RegTerm<Loc<N, F>, F>, |
| 169 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
288 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| 170 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
289 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
| |
290 for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>, |
| 171 { |
291 { |
| 172 // Set up parameters |
292 // Set up parameters |
| 173 let config = &fbconfig.insertion; |
293 let config = &fbconfig.insertion; |
| 174 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; |
294 let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?; |
| |
295 |
| 175 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
296 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 176 // by τ compared to the conditional gradient approach. |
297 // by τ compared to the conditional gradient approach. |
| 177 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
298 let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling(); |
| 178 let mut ε = tolerance.initial(); |
299 let mut ε = tolerance.initial(); |
| 179 |
300 |
| 180 // Initialise iterates |
301 // Initialise iterates |
| 181 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
302 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 182 |
303 |
| 192 |
313 |
| 193 // Run the algorithm |
314 // Run the algorithm |
| 194 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
315 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 195 // Calculate smooth part of surrogate model. |
316 // Calculate smooth part of surrogate model. |
| 196 // TODO: optimise τ to be applied to residual. |
317 // TODO: optimise τ to be applied to residual. |
| 197 let mut τv = f.differential(&μ) * τ; |
318 let (fμ, v) = f.apply_and_differential(&μ); |
| |
319 let τ = adaptive_τ.update(&μ, fμ, &v); |
| |
320 dbg!(τ); |
| |
321 let mut τv = v * τ; |
| 198 |
322 |
| 199 // Save current base point |
323 // Save current base point |
| 200 let μ_base = μ.clone(); |
324 let μ_base = μ.clone(); |
| 201 |
325 |
| 202 // Insert and reweigh |
326 // Insert and reweigh |
| 203 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
327 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 204 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
328 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
| 205 )?; |
329 )?; |
| |
330 |
| |
331 // We don't treat merge in adaptive Lipschitz. |
| |
332 adaptive_τ.finish_step(&μ); |
| 206 |
333 |
| 207 // Prune and possibly merge spikes |
334 // Prune and possibly merge spikes |
| 208 if config.merge_now(&state) { |
335 if config.merge_now(&state) { |
| 209 stats.merged += prox_penalty.merge_spikes( |
336 stats.merged += prox_penalty.merge_spikes( |
| 210 &mut μ, |
337 &mut μ, |
| 255 reg: &Reg, |
382 reg: &Reg, |
| 256 prox_penalty: &P, |
383 prox_penalty: &P, |
| 257 fbconfig: &FBConfig<F>, |
384 fbconfig: &FBConfig<F>, |
| 258 iterator: I, |
385 iterator: I, |
| 259 mut plotter: Plot, |
386 mut plotter: Plot, |
| 260 μ0: Option<RNDM<N, F>> |
387 μ0: Option<RNDM<N, F>>, |
| 261 ) -> DynResult<RNDM<N, F>> |
388 ) -> DynResult<RNDM<N, F>> |
| 262 where |
389 where |
| 263 F: Float + ToNalgebraRealField, |
390 F: Float + ToNalgebraRealField, |
| 264 I: AlgIteratorFactory<IterInfo<F>>, |
391 I: AlgIteratorFactory<IterInfo<F>>, |
| 265 RNDM<N, F>: SpikeMerging<F>, |
392 RNDM<N, F>: SpikeMerging<F>, |
| 266 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, |
393 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, |
| 267 Dat::DerivativeDomain: ClosedMul<F>, |
394 Dat::DerivativeDomain: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace, |
| 268 Reg: RegTerm<Loc<N, F>, F>, |
395 Reg: RegTerm<Loc<N, F>, F>, |
| 269 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
396 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| 270 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
397 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
| |
398 for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>, |
| 271 { |
399 { |
| 272 // Set up parameters |
400 // Set up parameters |
| 273 let config = &fbconfig.insertion; |
401 let config = &fbconfig.insertion; |
| 274 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; |
402 let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?; |
| |
403 |
| 275 let mut λ = 1.0; |
404 let mut λ = 1.0; |
| 276 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
405 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 277 // by τ compared to the conditional gradient approach. |
406 // by τ compared to the conditional gradient approach. |
| 278 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
407 let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling(); |
| 279 let mut ε = tolerance.initial(); |
408 let mut ε = tolerance.initial(); |
| 280 |
409 |
| 281 // Initialise iterates |
410 // Initialise iterates |
| 282 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
411 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 283 let mut μ_prev = μ.clone(); |
412 let mut μ_prev = μ.clone(); |
| 294 let mut stats = IterInfo::new(); |
423 let mut stats = IterInfo::new(); |
| 295 |
424 |
| 296 // Run the algorithm |
425 // Run the algorithm |
| 297 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
426 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 298 // Calculate smooth part of surrogate model. |
427 // Calculate smooth part of surrogate model. |
| 299 let mut τv = f.differential(&μ) * τ; |
428 let (fμ, v) = f.apply_and_differential(&μ); |
| |
429 let τ = adaptive_τ.update(&μ, fμ, &v); |
| |
430 let mut τv = v * τ; |
| 300 |
431 |
| 301 // Save current base point |
432 // Save current base point |
| 302 let μ_base = μ.clone(); |
433 let μ_base = μ.clone(); |
| 303 |
434 |
| 304 // Insert new spikes and reweigh |
435 // Insert new spikes and reweigh |
| 305 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
436 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 306 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
437 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
| 307 )?; |
438 )?; |
| |
439 |
| |
440 // We don't treat merge in adaptive Lipschitz. |
| |
441 adaptive_τ.finish_step(&μ); |
| 308 |
442 |
| 309 // (Do not) merge spikes. |
443 // (Do not) merge spikes. |
| 310 if config.merge_now(&state) && !warned_merging { |
444 if config.merge_now(&state) && !warned_merging { |
| 311 let err = format!("Merging not supported for μFISTA"); |
445 let err = format!("Merging not supported for μFISTA"); |
| 312 println!("{}", err.red()); |
446 println!("{}", err.red()); |