diff -r 9738b51d90d7 -r 4f468d35fa29 src/frank_wolfe.rs --- a/src/frank_wolfe.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/frank_wolfe.rs Thu Feb 26 11:38:43 2026 -0500 @@ -13,82 +13,51 @@ DOI: [10.1051/cocv/2011205](https://doi.org/0.1051/cocv/2011205). */ +use nalgebra::{DMatrix, DVector}; use numeric_literals::replace_float_literals; -use nalgebra::{DMatrix, DVector}; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; //use colored::Colorize; - -use alg_tools::iterate::{ - AlgIteratorFactory, - AlgIteratorOptions, - ValueIteratorFactory, -}; -use alg_tools::euclidean::Euclidean; -use alg_tools::norms::Norm; -use alg_tools::linops::Mapping; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::bisection_tree::{ - BTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, -}; -use alg_tools::mapping::RealMapping; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::L2; - -use crate::types::*; -use crate::measures::{ - RNDM, - DiscreteMeasure, - DeltaMeasure, - Radon, -}; -use crate::measures::merging::{ - SpikeMergingMethod, - SpikeMerging, -}; +use crate::dataterm::QuadraticDataTerm; use crate::forward_model::ForwardModel; +use crate::measures::merging::{SpikeMerging, SpikeMergingMethod}; +use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon, RNDM}; +use crate::plot::Plotter; +use crate::regularisation::{NonnegRadonRegTerm, RadonRegTerm, RegTerm}; #[allow(unused_imports)] // Used in documentation use crate::subproblem::{ - unconstrained::quadratic_unconstrained, - nonneg::quadratic_nonneg, - InnerSettings, - InnerMethod, + nonneg::quadratic_nonneg, unconstrained::quadratic_unconstrained, InnerMethod, InnerSettings, }; use crate::tolerance::Tolerance; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; -use crate::regularisation::{ - NonnegRadonRegTerm, - RadonRegTerm, - RegTerm -}; +use crate::types::*; +use alg_tools::bisection_tree::P2Minimise; +use alg_tools::bounds::MinMaxMapping; +use alg_tools::error::DynResult; +use alg_tools::euclidean::Euclidean; +use alg_tools::instance::Instance; +use alg_tools::iterate::{AlgIteratorFactory, AlgIteratorOptions, ValueIteratorFactory}; +use alg_tools::linops::Mapping; +use alg_tools::loc::Loc; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::Norm; +use alg_tools::norms::L2; +use alg_tools::sets::Cube; /// Settings for [`pointsource_fw_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct FWConfig { +pub struct FWConfig { /// Tolerance for branch-and-bound new spike location discovery - pub tolerance : Tolerance, + pub tolerance: Tolerance, /// Inner problem solution configuration. Has to have `method` set to [`InnerMethod::FB`] /// as the conditional gradient subproblems' optimality conditions do not in general have an /// invertible Newton derivative for SSN. - pub inner : InnerSettings, + pub inner: InnerSettings, /// Variant of the conditional gradient method - pub variant : FWVariant, + pub variant: FWVariant, /// Settings for branch and bound refinement when looking for predual maxima - pub refinement : RefinementSettings, + pub refinement: RefinementSettings, /// Spike merging heuristic - pub merging : SpikeMergingMethod, + pub merging: SpikeMergingMethod, } /// Conditional gradient method variant; see also [`FWConfig`]. @@ -101,51 +70,51 @@ Relaxed, } -impl Default for FWConfig { +impl Default for FWConfig { fn default() -> Self { FWConfig { - tolerance : Default::default(), - refinement : Default::default(), - inner : Default::default(), - variant : FWVariant::FullyCorrective, - merging : SpikeMergingMethod { enabled : true, ..Default::default() }, + tolerance: Default::default(), + refinement: Default::default(), + inner: Default::default(), + variant: FWVariant::FullyCorrective, + merging: SpikeMergingMethod { enabled: true, ..Default::default() }, } } } -pub trait FindimQuadraticModel : ForwardModel, F> +pub trait FindimQuadraticModel: ForwardModel, F> where - F : Float + ToNalgebraRealField, - Domain : Clone + PartialEq, + F: Float + ToNalgebraRealField, + Domain: Clone + PartialEq, { /// Return A_*A and A_* b fn findim_quadratic_model( &self, - μ : &DiscreteMeasure, - b : &Self::Observable + μ: &DiscreteMeasure, + b: &Self::Observable, ) -> (DMatrix, DVector); } /// Helper struct for pre-initialising the finite-dimensional subproblem solver. -pub struct FindimData { +pub struct FindimData { /// ‖A‖^2 - opAnorm_squared : F, + opAnorm_squared: F, /// Bound $M_0$ from the Bredies–Pikkarainen article. - m0 : F + m0: F, } /// Trait for finite dimensional weight optimisation. pub trait WeightOptim< - F : Float + ToNalgebraRealField, - A : ForwardModel, F>, - I : AlgIteratorFactory, - const N : usize -> { - + F: Float + ToNalgebraRealField, + A: ForwardModel, F>, + I: AlgIteratorFactory, + const N: usize, +> +{ /// Return a pre-initialisation struct for [`Self::optimise_weights`]. /// /// The parameter `opA` is the forward operator $A$. - fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData; + fn prepare_optimise_weights(&self, opA: &A, b: &A::Observable) -> DynResult>; /// Solve the finite-dimensional weight optimisation problem for the 2-norm-squared data fidelity /// point source localisation problem. @@ -166,72 +135,70 @@ /// Returns the number of iterations taken by the method configured in `inner`. fn optimise_weights<'a>( &self, - μ : &mut RNDM, - opA : &'a A, - b : &A::Observable, - findim_data : &FindimData, - inner : &InnerSettings, - iterator : I + μ: &mut RNDM, + opA: &'a A, + b: &A::Observable, + findim_data: &FindimData, + inner: &InnerSettings, + iterator: I, ) -> usize; } /// Trait for regularisation terms supported by [`pointsource_fw_reg`]. pub trait RegTermFW< - F : Float + ToNalgebraRealField, - A : ForwardModel, F>, - I : AlgIteratorFactory, - const N : usize -> : RegTerm - + WeightOptim - + Mapping, Codomain = F> { - + F: Float + ToNalgebraRealField, + A: ForwardModel, F>, + I: AlgIteratorFactory, + const N: usize, +>: RegTerm, F> + WeightOptim + Mapping, Codomain = F> +{ /// With $g = A\_\*(Aμ-b)$, returns $(x, g(x))$ for $x$ a new point to be inserted /// into $μ$, as determined by the regulariser. /// /// The parameters `refinement_tolerance` and `max_steps` are passed to relevant - /// [`BTFN`] minimisation and maximisation routines. + /// [`MinMaxMapping`] minimisation and maximisation routines. fn find_insertion( &self, - g : &mut A::PreadjointCodomain, - refinement_tolerance : F, - max_steps : usize - ) -> (Loc, F); + g: &mut A::PreadjointCodomain, + refinement_tolerance: F, + max_steps: usize, + ) -> (Loc, F); /// Insert point `ξ` into `μ` for the relaxed algorithm from Bredies–Pikkarainen. fn relaxed_insert<'a>( &self, - μ : &mut RNDM, - g : &A::PreadjointCodomain, - opA : &'a A, - ξ : Loc, - v_ξ : F, - findim_data : &FindimData + μ: &mut RNDM, + g: &A::PreadjointCodomain, + opA: &'a A, + ξ: Loc, + v_ξ: F, + findim_data: &FindimData, ); } #[replace_float_literals(F::cast_from(literal))] -impl WeightOptim -for RadonRegTerm -where I : AlgIteratorFactory, - A : FindimQuadraticModel, F> { - - fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData { - FindimData{ - opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2), - m0 : b.norm2_squared() / (2.0 * self.α()), - } +impl WeightOptim + for RadonRegTerm +where + I: AlgIteratorFactory, + A: FindimQuadraticModel, F>, +{ + fn prepare_optimise_weights(&self, opA: &A, b: &A::Observable) -> DynResult> { + Ok(FindimData { + opAnorm_squared: opA.opnorm_bound(Radon, L2)?.powi(2), + m0: b.norm2_squared() / (2.0 * self.α()), + }) } fn optimise_weights<'a>( &self, - μ : &mut RNDM, - opA : &'a A, - b : &A::Observable, - findim_data : &FindimData, - inner : &InnerSettings, - iterator : I + μ: &mut RNDM, + opA: &'a A, + b: &A::Observable, + findim_data: &FindimData, + inner: &InnerSettings, + iterator: I, ) -> usize { - // Form and solve finite-dimensional subproblem. let (Ã, g̃) = opA.findim_quadratic_model(&μ, b); let mut x = μ.masses_dvector(); @@ -245,8 +212,7 @@ // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no // square root is needed when we scale: let normest = findim_data.opAnorm_squared * F::cast_from(μ.len()); - let iters = quadratic_unconstrained(&Ã, &g̃, self.α(), &mut x, - normest, inner, iterator); + let iters = quadratic_unconstrained(&Ã, &g̃, self.α(), &mut x, normest, inner, iterator); // Update masses of μ based on solution of finite-dimensional subproblem. μ.set_masses_dvector(&x); @@ -255,28 +221,23 @@ } #[replace_float_literals(F::cast_from(literal))] -impl RegTermFW -for RadonRegTerm +impl RegTermFW for RadonRegTerm where - Cube : P2Minimise, F>, - I : AlgIteratorFactory, - S: RealMapping + LocalAnalysis, N>, - GA : SupportGenerator + Clone, - A : FindimQuadraticModel, F, PreadjointCodomain = BTFN>, - BTA : BTSearch>, + Cube: P2Minimise, F>, + I: AlgIteratorFactory, + A: FindimQuadraticModel, F>, + A::PreadjointCodomain: MinMaxMapping, F>, + for<'a> &'a A::PreadjointCodomain: Instance, // FIXME: the following *should not* be needed, they are already implied - RNDM : Mapping, - DeltaMeasure, F> : Mapping, - //A : Mapping, Codomain = A::Observable>, - //A : Mapping, F>, Codomain = A::Observable>, + RNDM: Mapping, + DeltaMeasure, F>: Mapping, { - fn find_insertion( &self, - g : &mut A::PreadjointCodomain, - refinement_tolerance : F, - max_steps : usize - ) -> (Loc, F) { + g: &mut A::PreadjointCodomain, + refinement_tolerance: F, + max_steps: usize, + ) -> (Loc, F) { let (ξmax, v_ξmax) = g.maximise(refinement_tolerance, max_steps); let (ξmin, v_ξmin) = g.minimise(refinement_tolerance, max_steps); if v_ξmin < 0.0 && -v_ξmin > v_ξmax { @@ -288,25 +249,35 @@ fn relaxed_insert<'a>( &self, - μ : &mut RNDM, - g : &A::PreadjointCodomain, - opA : &'a A, - ξ : Loc, - v_ξ : F, - findim_data : &FindimData + μ: &mut RNDM, + g: &A::PreadjointCodomain, + opA: &'a A, + ξ: Loc, + v_ξ: F, + findim_data: &FindimData, ) { let α = self.0; let m0 = findim_data.m0; - let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) }; - let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ }; - let δ = DeltaMeasure { x : ξ, α : v }; + let φ = |t| { + if t <= m0 { + α * t + } else { + α / (2.0 * m0) * (t * t + m0 * m0) + } + }; + let v = if v_ξ.abs() <= α { + 0.0 + } else { + m0 / α * v_ξ + }; + let δ = DeltaMeasure { x: ξ, α: v }; let dp = μ.apply(g) - δ.apply(g); let d = opA.apply(&*μ) - opA.apply(δ); let r = d.norm2_squared(); let s = if r == 0.0 { 1.0 } else { - 1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r) + 1.0.min((α * μ.norm(Radon) - φ(v.abs()) - dp) / r) }; *μ *= 1.0 - s; *μ += δ * s; @@ -314,28 +285,28 @@ } #[replace_float_literals(F::cast_from(literal))] -impl WeightOptim -for NonnegRadonRegTerm -where I : AlgIteratorFactory, - A : FindimQuadraticModel, F> { - - fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData { - FindimData{ - opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2), - m0 : b.norm2_squared() / (2.0 * self.α()), - } +impl WeightOptim + for NonnegRadonRegTerm +where + I: AlgIteratorFactory, + A: FindimQuadraticModel, F>, +{ + fn prepare_optimise_weights(&self, opA: &A, b: &A::Observable) -> DynResult> { + Ok(FindimData { + opAnorm_squared: opA.opnorm_bound(Radon, L2)?.powi(2), + m0: b.norm2_squared() / (2.0 * self.α()), + }) } fn optimise_weights<'a>( &self, - μ : &mut RNDM, - opA : &'a A, - b : &A::Observable, - findim_data : &FindimData, - inner : &InnerSettings, - iterator : I + μ: &mut RNDM, + opA: &'a A, + b: &A::Observable, + findim_data: &FindimData, + inner: &InnerSettings, + iterator: I, ) -> usize { - // Form and solve finite-dimensional subproblem. let (Ã, g̃) = opA.findim_quadratic_model(&μ, b); let mut x = μ.masses_dvector(); @@ -349,8 +320,7 @@ // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no // square root is needed when we scale: let normest = findim_data.opAnorm_squared * F::cast_from(μ.len()); - let iters = quadratic_nonneg(&Ã, &g̃, self.α(), &mut x, - normest, inner, iterator); + let iters = quadratic_nonneg(&Ã, &g̃, self.α(), &mut x, normest, inner, iterator); // Update masses of μ based on solution of finite-dimensional subproblem. μ.set_masses_dvector(&x); @@ -359,59 +329,65 @@ } #[replace_float_literals(F::cast_from(literal))] -impl RegTermFW -for NonnegRadonRegTerm +impl RegTermFW + for NonnegRadonRegTerm where - Cube : P2Minimise, F>, - I : AlgIteratorFactory, - S: RealMapping + LocalAnalysis, N>, - GA : SupportGenerator + Clone, - A : FindimQuadraticModel, F, PreadjointCodomain = BTFN>, - BTA : BTSearch>, + Cube: P2Minimise, F>, + I: AlgIteratorFactory, + A: FindimQuadraticModel, F>, + A::PreadjointCodomain: MinMaxMapping, F>, + for<'a> &'a A::PreadjointCodomain: Instance, // FIXME: the following *should not* be needed, they are already implied - RNDM : Mapping, - DeltaMeasure, F> : Mapping, + RNDM: Mapping, + DeltaMeasure, F>: Mapping, { - fn find_insertion( &self, - g : &mut A::PreadjointCodomain, - refinement_tolerance : F, - max_steps : usize - ) -> (Loc, F) { + g: &mut A::PreadjointCodomain, + refinement_tolerance: F, + max_steps: usize, + ) -> (Loc, F) { g.maximise(refinement_tolerance, max_steps) } - fn relaxed_insert<'a>( &self, - μ : &mut RNDM, - g : &A::PreadjointCodomain, - opA : &'a A, - ξ : Loc, - v_ξ : F, - findim_data : &FindimData + μ: &mut RNDM, + g: &A::PreadjointCodomain, + opA: &'a A, + ξ: Loc, + v_ξ: F, + findim_data: &FindimData, ) { // This is just a verbatim copy of RadonRegTerm::relaxed_insert. let α = self.0; let m0 = findim_data.m0; - let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) }; - let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ }; - let δ = DeltaMeasure { x : ξ, α : v }; + let φ = |t| { + if t <= m0 { + α * t + } else { + α / (2.0 * m0) * (t * t + m0 * m0) + } + }; + let v = if v_ξ.abs() <= α { + 0.0 + } else { + m0 / α * v_ξ + }; + let δ = DeltaMeasure { x: ξ, α: v }; let dp = μ.apply(g) - δ.apply(g); let d = opA.apply(&*μ) - opA.apply(&δ); let r = d.norm2_squared(); let s = if r == 0.0 { 1.0 } else { - 1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r) + 1.0.min((α * μ.norm(Radon) - φ(v.abs()) - dp) / r) }; *μ *= 1.0 - s; *μ += δ * s; } } - /// Solve point source localisation problem using a conditional gradient method /// for the 2-norm-squared data fidelity, i.e., the problem ///
$$ @@ -425,49 +401,48 @@ /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to /// save intermediate iteration states as images. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fw_reg( - opA : &A, - b : &A::Observable, - reg : Reg, - //domain : Cube, - config : &FWConfig, - iterator : I, - mut plotter : SeqPlotter, -) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN>, - BTA : BTSearch>, - S: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : RegTermFW, N> { +pub fn pointsource_fw_reg<'a, F, I, A, Reg, Plot, const N: usize>( + f: &'a QuadraticDataTerm, A>, + reg: &Reg, + //domain : Cube, + config: &FWConfig, + iterator: I, + mut plotter: Plot, + μ0 : Option>, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + A: ForwardModel, F>, + A::PreadjointCodomain: MinMaxMapping, F>, + &'a A::PreadjointCodomain: Instance, + Cube: P2Minimise, F>, + RNDM: SpikeMerging, + Reg: RegTermFW, N>, + Plot: Plotter>, +{ + let opA = f.operator(); + let b = f.data(); // Set up parameters // We multiply tolerance by α for all algoritms. let tolerance = config.tolerance * reg.tolerance_scaling(); let mut ε = tolerance.initial(); - let findim_data = reg.prepare_optimise_weights(opA, b); + let findim_data = reg.prepare_optimise_weights(opA, b)?; // Initialise operators let preadjA = opA.preadjoint(); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut residual = -b; + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); + let mut residual = f.residual(&μ); // Statistics - let full_stats = |residual : &A::Observable, - ν : &RNDM, - ε, stats| IterInfo { - value : residual.norm2_squared_div2() + reg.apply(ν), - n_spikes : ν.len(), + let full_stats = |residual: &A::Observable, ν: &RNDM, ε, stats| IterInfo { + value: residual.norm2_squared_div2() + reg.apply(ν), + n_spikes: ν.len(), ε, - .. stats + ..stats }; let mut stats = IterInfo::new(); @@ -480,32 +455,34 @@ let mut g = preadjA.apply(residual * (-1.0)); // Find absolute value maximising point - let (ξ, v_ξ) = reg.find_insertion(&mut g, refinement_tolerance, - config.refinement.max_steps); + let (ξ, v_ξ) = + reg.find_insertion(&mut g, refinement_tolerance, config.refinement.max_steps); let inner_it = match config.variant { FWVariant::FullyCorrective => { // No point in optimising the weight here: the finite-dimensional algorithm is fast. - μ += DeltaMeasure { x : ξ, α : 0.0 }; + μ += DeltaMeasure { x: ξ, α: 0.0 }; stats.inserted += 1; config.inner.iterator_options.stop_target(inner_tolerance) - }, + } FWVariant::Relaxed => { // Perform a relaxed initialisation of μ reg.relaxed_insert(&mut μ, &g, opA, ξ, v_ξ, &findim_data); stats.inserted += 1; // The stop_target is only needed for the type system. - AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0) + AlgIteratorOptions { max_iter: 1, ..config.inner.iterator_options }.stop_target(0.0) } }; - stats.inner_iters += reg.optimise_weights(&mut μ, opA, b, &findim_data, - &config.inner, inner_it); - + stats.inner_iters += + reg.optimise_weights(&mut μ, opA, b, &findim_data, &config.inner, inner_it); + // Merge spikes and update residual for next step and `if_verbose` below. - let (r, count) = μ.merge_spikes_fitness(config.merging, - |μ̃| opA.apply(μ̃) - b, - A::Observable::norm2_squared); + let (r, count) = μ.merge_spikes_fitness( + config.merging, + |μ̃| f.residual(μ̃), + A::Observable::norm2_squared, + ); residual = r; stats.merged += count; @@ -520,8 +497,13 @@ // Give statistics if needed state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&g), Option::<&S>::None, &μ); - full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) + plotter.plot_spikes(iter, Some(&g), Option::<&A::PreadjointCodomain>::None, &μ); + full_stats( + &residual, + &μ, + ε, + std::mem::replace(&mut stats, IterInfo::new()), + ) }); // Update tolerance @@ -529,5 +511,5 @@ } // Return final iterate - μ + Ok(μ) }