diff -r 6105b5cd8d89 -r f0e8704d3f0e src/prox_penalty/wave.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/prox_penalty/wave.rs Mon Feb 17 13:54:53 2025 -0500 @@ -0,0 +1,191 @@ +/*! +Basic proximal penalty based on convolution operators $𝒟$. + */ + +use numeric_literals::replace_float_literals; +use nalgebra::DVector; +use colored::Colorize; + +use alg_tools::types::*; +use alg_tools::loc::Loc; +use alg_tools::mapping::{Mapping, RealMapping}; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::Linfinity; +use alg_tools::iterate::{ + AlgIteratorIteration, + AlgIterator, +}; +use alg_tools::bisection_tree::{ + BTFN, + PreBTFN, + Bounds, + BTSearch, + SupportGenerator, + LocalAnalysis, + BothGenerators, +}; +use crate::measures::{ + RNDM, + DeltaMeasure, + Radon, +}; +use crate::measures::merging::{ + SpikeMerging, +}; +use crate::seminorms::DiscreteMeasureOp; +use crate::types::{ + IterInfo, +}; +use crate::regularisation::RegTerm; +use super::{ProxPenalty, FBGenericConfig}; + +#[replace_float_literals(F::cast_from(literal))] +impl +ProxPenalty, Reg, N> for 𝒟 +where + F : Float + ToNalgebraRealField, + GA : SupportGenerator + Clone, + BTA : BTSearch>, + S: RealMapping + LocalAnalysis, N>, + G𝒟 : SupportGenerator + Clone, + 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, + 𝒟::Codomain : RealMapping, + K : RealMapping + LocalAnalysis, N>, + Reg : RegTerm, + RNDM : SpikeMerging, +{ + type ReturnMapping = BTFN, BTA, N>; + + fn insert_and_reweigh( + &self, + μ : &mut RNDM, + τv : &mut BTFN, + μ_base : &RNDM, + ν_delta: Option<&RNDM>, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + state : &AlgIteratorIteration, + stats : &mut IterInfo, + ) -> (Option, BTA, N>>, bool) + where + I : AlgIterator + { + + let op𝒟norm = self.opnorm_bound(Radon, Linfinity); + + // Maximum insertion count and measure difference calculation depend on insertion style. + let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { + (i, Some((l, k))) if i <= l => (k, false), + _ => (config.max_insertions, !state.is_quiet()), + }; + + let ω0 = match ν_delta { + None => self.apply(μ_base), + Some(ν) => self.apply(μ_base + ν), + }; + + // Add points to support until within error tolerance or maximum insertion count reached. + let mut count = 0; + let (within_tolerances, d) = 'insertion: loop { + if μ.len() > 0 { + // Form finite-dimensional subproblem. The subproblem references to the original μ^k + // from the beginning of the iteration are all contained in the immutable c and g. + // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional + // problems have not yet been updated to sign change. + let à = self.findim_matrix(μ.iter_locations()); + let g̃ = DVector::from_iterator(μ.len(), + μ.iter_locations() + .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) + .map(F::to_nalgebra_mixed)); + let mut x = μ.masses_dvector(); + + // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. + // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ + // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ + // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 + // = n |𝒟| |x|_2, where n is the number of points. Therefore + let Ã_normest = op𝒟norm * F::cast_from(μ.len()); + + // Solve finite-dimensional subproblem. + stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); + + // Update masses of μ based on solution of finite-dimensional subproblem. + μ.set_masses_dvector(&x); + } + + // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality + // conditions in the predual space, and finding new points for insertion, if necessary. + let mut d = &*τv + match ν_delta { + None => self.preapply(μ.sub_matching(μ_base)), + Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν) + }; + + // If no merging heuristic is used, let's be more conservative about spike insertion, + // and skip it after first round. If merging is done, being more greedy about spike + // insertion also seems to improve performance. + let skip_by_rough_check = if config.merging.enabled { + false + } else { + count > 0 + }; + + // Find a spike to insert, if needed + let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( + &mut d, τ, ε, skip_by_rough_check, config + ) { + None => break 'insertion (true, d), + Some(res) => res, + }; + + // Break if maximum insertion count reached + if count >= max_insertions { + break 'insertion (in_bounds, d) + } + + // No point in optimising the weight here; the finite-dimensional algorithm is fast. + *μ += DeltaMeasure { x : ξ, α : 0.0 }; + count += 1; + stats.inserted += 1; + }; + + if !within_tolerances && warn_insertions { + // Complain (but continue) if we failed to get within tolerances + // by inserting more points. + let err = format!("Maximum insertions reached without achieving \ + subproblem solution tolerance"); + println!("{}", err.red()); + } + + (Some(d), within_tolerances) + } + + fn merge_spikes( + &self, + μ : &mut RNDM, + τv : &mut BTFN, + μ_base : &RNDM, + ν_delta: Option<&RNDM>, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + fitness : Option) -> F>, + ) -> usize + { + if config.fitness_merging { + if let Some(f) = fitness { + return μ.merge_spikes_fitness(config.merging, f, |&v| v) + .1 + } + } + μ.merge_spikes(config.merging, |μ_candidate| { + let mut d = &*τv + self.preapply(match ν_delta { + None => μ_candidate.sub_matching(μ_base), + Some(ν) => μ_candidate.sub_matching(μ_base) - ν, + }); + reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config) + }) + } +}