src/fb.rs

branch
dev
changeset 62
32328a74c790
parent 61
4f468d35fa29
equal deleted inserted replaced
61:4f468d35fa29 62:32328a74c790
78 */ 78 */
79 79
80 use crate::measures::merging::SpikeMerging; 80 use crate::measures::merging::SpikeMerging;
81 use crate::measures::{DiscreteMeasure, RNDM}; 81 use crate::measures::{DiscreteMeasure, RNDM};
82 use crate::plot::Plotter; 82 use crate::plot::Plotter;
83 use crate::prox_penalty::StepLengthBoundValue;
83 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound}; 84 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound};
84 use crate::regularisation::RegTerm; 85 use crate::regularisation::RegTerm;
85 use crate::types::*; 86 use crate::types::*;
86 use alg_tools::error::DynResult; 87 use alg_tools::error::DynResult;
87 use alg_tools::instance::Instance; 88 use alg_tools::instance::{ClosedSpace, Instance};
88 use alg_tools::iterate::AlgIteratorFactory; 89 use alg_tools::iterate::AlgIteratorFactory;
89 use alg_tools::mapping::DifferentiableMapping; 90 use alg_tools::mapping::{DifferentiableMapping, Mapping};
90 use alg_tools::nalgebra_support::ToNalgebraRealField; 91 use alg_tools::nalgebra_support::ToNalgebraRealField;
92 use anyhow::anyhow;
91 use colored::Colorize; 93 use colored::Colorize;
92 use numeric_literals::replace_float_literals; 94 use numeric_literals::replace_float_literals;
93 use serde::{Deserialize, Serialize}; 95 use serde::{Deserialize, Serialize};
94 96
95 /// Settings for [`pointsource_fb_reg`]. 97 /// Settings for [`pointsource_fb_reg`].
100 pub τ0: F, 102 pub τ0: F,
101 // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`] 103 // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`]
102 pub σp0: F, 104 pub σp0: F,
103 /// Generic parameters 105 /// Generic parameters
104 pub insertion: InsertionConfig<F>, 106 pub insertion: InsertionConfig<F>,
107 /// Always use an adaptive step length, even when a reliable Lipschitz factor is available
108 pub always_adaptive_τ: bool,
105 } 109 }
106 110
107 #[replace_float_literals(F::cast_from(literal))] 111 #[replace_float_literals(F::cast_from(literal))]
108 impl<F: Float> Default for FBConfig<F> { 112 impl<F: Float> Default for FBConfig<F> {
109 fn default() -> Self { 113 fn default() -> Self {
110 FBConfig { 114 FBConfig { τ0: 0.99, σp0: 0.99, always_adaptive_τ: false, insertion: Default::default() }
111 τ0: 0.99,
112 σp0: 0.99,
113 insertion: Default::default(),
114 }
115 } 115 }
116 } 116 }
117 117
118 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize { 118 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize {
119 let n_before_prune = μ.len(); 119 let n_before_prune = μ.len();
120 μ.prune(); 120 μ.prune();
121 debug_assert!(μ.len() <= n_before_prune); 121 debug_assert!(μ.len() <= n_before_prune);
122 n_before_prune - μ.len() 122 n_before_prune - μ.len()
123 }
124
125 /// Adaptive step length and Lipschitz parameter estimation state.
126 #[derive(Clone, Debug, Serialize)]
127 pub enum AdaptiveStepLength<const N: usize, F: Float> {
128 Adaptive {
129 l: F,
130 μ_old: RNDM<N, F>,
131 fμ_old: F,
132 μ_dist: F,
133 τ0: F,
134 l_is_initial: bool,
135 },
136 Fixed {
137 τ: F,
138 },
139 }
140
141 #[replace_float_literals(F::cast_from(literal))]
142 impl<const N: usize, F: Float> AdaptiveStepLength<N, F> {
143 pub fn new<Dat, Reg, P>(f: &Dat, prox_penalty: &P, fbconfig: &FBConfig<F>) -> DynResult<Self>
144 where
145 F: ToNalgebraRealField,
146 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
147 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
148 Reg: RegTerm<Loc<N, F>, F>,
149 {
150 match (
151 prox_penalty.step_length_bound(&f),
152 fbconfig.always_adaptive_τ,
153 ) {
154 (StepLengthBoundValue::LipschitzFactor(l), false) => {
155 Ok(AdaptiveStepLength::Fixed { τ: fbconfig.τ0 / l })
156 }
157 (StepLengthBoundValue::LipschitzFactor(l), true) => {
158 let μ_old = DiscreteMeasure::new();
159 let fμ_old = f.apply(&μ_old);
160 Ok(AdaptiveStepLength::Adaptive {
161 l: l,
162 μ_old,
163 fμ_old,
164 μ_dist: 0.0,
165 τ0: fbconfig.τ0,
166 l_is_initial: false,
167 })
168 }
169 (StepLengthBoundValue::UnreliableLipschitzFactor(l), _) => {
170 println!("Lipschitz factor is unreliable; calculating adaptively.");
171 let μ_old = DiscreteMeasure::new();
172 let fμ_old = f.apply(&μ_old);
173 Ok(AdaptiveStepLength::Adaptive {
174 l: l,
175 μ_old,
176 fμ_old,
177 μ_dist: 0.0,
178 τ0: fbconfig.τ0,
179 l_is_initial: true,
180 })
181 }
182 (StepLengthBoundValue::Failure, _) => Err(anyhow!("No Lipschitz estimate available")),
183 }
184 }
185
186 /// Returns the current value of the step length parameter.
187 pub fn current(&self) -> F {
188 match *self {
189 AdaptiveStepLength::Adaptive { τ0, l, .. } => τ0 / l,
190 AdaptiveStepLength::Fixed { τ } => τ,
191 }
192 }
193
194 /// Update the adaptive Lipschitz factor and return the new step length parameter `τ`.
195 ///
196 /// Inputs:
197 /// * `μ`: current point
198 /// * `fμ`: value of the function `f` at `μ`.
199 /// * `v`: derivative of the function `f` at `μ`.
200 /// * `τ0` (stored in `self`, not an argument): fractional step length parameter in $[0, 1)$.
201 pub fn update<'a, G>(&mut self, μ: &RNDM<N, F>, fμ: F, v: &'a G) -> F
202 where
203 G: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace,
204 &'a G: Instance<G>,
205 {
206 match self {
207 AdaptiveStepLength::Adaptive { l, μ_old, fμ_old, μ_dist, τ0, l_is_initial } => {
208 // Estimate step length parameter
209 let b = *fμ_old - fμ - μ_old.apply(v) + μ.apply(v);
210 let d = *μ_dist;
211 if d.abs() > F::EPSILON && μ.len() > 0 && μ_old.len() > 0 {
212 let lc = b / (d * d / 2.0);
213 dbg!(b, d, lc);
214 if *l_is_initial {
215 *l = lc;
216 *l_is_initial = false;
217 } else {
218 *l = l.max(lc);
219 }
220 }
221
222 // Store for next iteration
223 *μ_old = μ.clone();
224 *fμ_old = fμ;
225
226 return *τ0 / *l;
227 }
228 AdaptiveStepLength::Fixed { τ } => *τ,
229 }
230 }
231
232 /// Finalises a step, storing μ and its distance to the previous μ.
233 ///
234 /// This is not included in [`Self::update`], as this function is to be called
235 /// before pruning and merging, while μ and its previous version still have matching
236 /// indices for the same coordinates in their internal representation.
237 pub fn finish_step(&mut self, μ: &RNDM<N, F>) {
238 if let AdaptiveStepLength::Adaptive { μ_dist, μ_old, .. } = self {
239 *μ_dist = μ.dist_matching(&μ_old);
240 }
241 }
123 } 242 }
124 243
125 #[replace_float_literals(F::cast_from(literal))] 244 #[replace_float_literals(F::cast_from(literal))]
126 pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>( 245 pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>(
127 mut μ: RNDM<N, F>, 246 mut μ: RNDM<N, F>,
155 reg: &Reg, 274 reg: &Reg,
156 prox_penalty: &P, 275 prox_penalty: &P,
157 fbconfig: &FBConfig<F>, 276 fbconfig: &FBConfig<F>,
158 iterator: I, 277 iterator: I,
159 mut plotter: Plot, 278 mut plotter: Plot,
160 μ0 : Option<RNDM<N, F>>, 279 μ0: Option<RNDM<N, F>>,
161 ) -> DynResult<RNDM<N, F>> 280 ) -> DynResult<RNDM<N, F>>
162 where 281 where
163 F: Float + ToNalgebraRealField, 282 F: Float + ToNalgebraRealField,
164 I: AlgIteratorFactory<IterInfo<F>>, 283 I: AlgIteratorFactory<IterInfo<F>>,
165 RNDM<N, F>: SpikeMerging<F>, 284 RNDM<N, F>: SpikeMerging<F>,
166 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, 285 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
167 Dat::DerivativeDomain: ClosedMul<F>, 286 Dat::DerivativeDomain: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace,
168 Reg: RegTerm<Loc<N, F>, F>, 287 Reg: RegTerm<Loc<N, F>, F>,
169 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, 288 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
170 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, 289 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
290 for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>,
171 { 291 {
172 // Set up parameters 292 // Set up parameters
173 let config = &fbconfig.insertion; 293 let config = &fbconfig.insertion;
174 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; 294 let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?;
295
175 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 296 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
176 // by τ compared to the conditional gradient approach. 297 // by τ compared to the conditional gradient approach.
177 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 298 let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling();
178 let mut ε = tolerance.initial(); 299 let mut ε = tolerance.initial();
179 300
180 // Initialise iterates 301 // Initialise iterates
181 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); 302 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
182 303
192 313
193 // Run the algorithm 314 // Run the algorithm
194 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 315 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
195 // Calculate smooth part of surrogate model. 316 // Calculate smooth part of surrogate model.
196 // TODO: optimise τ to be applied to residual. 317 // TODO: optimise τ to be applied to residual.
197 let mut τv = f.differential(&μ) * τ; 318 let (fμ, v) = f.apply_and_differential(&μ);
319 let τ = adaptive_τ.update(&μ, fμ, &v);
320 dbg!(τ);
321 let mut τv = v * τ;
198 322
199 // Save current base point 323 // Save current base point
200 let μ_base = μ.clone(); 324 let μ_base = μ.clone();
201 325
202 // Insert and reweigh 326 // Insert and reweigh
203 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 327 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
204 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats, 328 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
205 )?; 329 )?;
330
331 // Merging is not accounted for in the adaptive Lipschitz estimation.
332 adaptive_τ.finish_step(&μ);
206 333
207 // Prune and possibly merge spikes 334 // Prune and possibly merge spikes
208 if config.merge_now(&state) { 335 if config.merge_now(&state) {
209 stats.merged += prox_penalty.merge_spikes( 336 stats.merged += prox_penalty.merge_spikes(
210 &mut μ, 337 &mut μ,
255 reg: &Reg, 382 reg: &Reg,
256 prox_penalty: &P, 383 prox_penalty: &P,
257 fbconfig: &FBConfig<F>, 384 fbconfig: &FBConfig<F>,
258 iterator: I, 385 iterator: I,
259 mut plotter: Plot, 386 mut plotter: Plot,
260 μ0: Option<RNDM<N, F>> 387 μ0: Option<RNDM<N, F>>,
261 ) -> DynResult<RNDM<N, F>> 388 ) -> DynResult<RNDM<N, F>>
262 where 389 where
263 F: Float + ToNalgebraRealField, 390 F: Float + ToNalgebraRealField,
264 I: AlgIteratorFactory<IterInfo<F>>, 391 I: AlgIteratorFactory<IterInfo<F>>,
265 RNDM<N, F>: SpikeMerging<F>, 392 RNDM<N, F>: SpikeMerging<F>,
266 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, 393 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
267 Dat::DerivativeDomain: ClosedMul<F>, 394 Dat::DerivativeDomain: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace,
268 Reg: RegTerm<Loc<N, F>, F>, 395 Reg: RegTerm<Loc<N, F>, F>,
269 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, 396 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
270 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, 397 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
398 for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>,
271 { 399 {
272 // Set up parameters 400 // Set up parameters
273 let config = &fbconfig.insertion; 401 let config = &fbconfig.insertion;
274 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; 402 let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?;
403
275 let mut λ = 1.0; 404 let mut λ = 1.0;
276 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 405 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
277 // by τ compared to the conditional gradient approach. 406 // by τ compared to the conditional gradient approach.
278 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 407 let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling();
279 let mut ε = tolerance.initial(); 408 let mut ε = tolerance.initial();
280 409
281 // Initialise iterates 410 // Initialise iterates
282 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); 411 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
283 let mut μ_prev = μ.clone(); 412 let mut μ_prev = μ.clone();
294 let mut stats = IterInfo::new(); 423 let mut stats = IterInfo::new();
295 424
296 // Run the algorithm 425 // Run the algorithm
297 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 426 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
298 // Calculate smooth part of surrogate model. 427 // Calculate smooth part of surrogate model.
299 let mut τv = f.differential(&μ) * τ; 428 let (fμ, v) = f.apply_and_differential(&μ);
429 let τ = adaptive_τ.update(&μ, fμ, &v);
430 let mut τv = v * τ;
300 431
301 // Save current base point 432 // Save current base point
302 let μ_base = μ.clone(); 433 let μ_base = μ.clone();
303 434
304 // Insert new spikes and reweigh 435 // Insert new spikes and reweigh
305 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 436 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
306 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats, 437 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
307 )?; 438 )?;
439
440 // Merging is not accounted for in the adaptive Lipschitz estimation.
441 adaptive_τ.finish_step(&μ);
308 442
309 // (Do not) merge spikes. 443 // (Do not) merge spikes.
310 if config.merge_now(&state) && !warned_merging { 444 if config.merge_now(&state) && !warned_merging {
311 let err = format!("Merging not supported for μFISTA"); 445 let err = format!("Merging not supported for μFISTA");
312 println!("{}", err.red()); 446 println!("{}", err.red());

mercurial