src/fb.rs

branch
dev
changeset 61
4f468d35fa29
parent 51
0693cc9ba9f0
child 62
32328a74c790
child 63
7a8a55fd41c0
--- a/src/fb.rs	Sun Apr 27 15:03:51 2025 -0500
+++ b/src/fb.rs	Thu Feb 26 11:38:43 2026 -0500
@@ -74,37 +74,34 @@
 </p>
 
 We solve this with either SSN or FB as determined by
-[`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`].
+[`crate::subproblem::InnerSettings`] in [`InsertionConfig::inner`].
 */
 
+use crate::measures::merging::SpikeMerging;
+use crate::measures::{DiscreteMeasure, RNDM};
+use crate::plot::Plotter;
+pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound};
+use crate::regularisation::RegTerm;
+use crate::types::*;
+use alg_tools::error::DynResult;
+use alg_tools::instance::Instance;
+use alg_tools::iterate::AlgIteratorFactory;
+use alg_tools::mapping::DifferentiableMapping;
+use alg_tools::nalgebra_support::ToNalgebraRealField;
 use colored::Colorize;
 use numeric_literals::replace_float_literals;
 use serde::{Deserialize, Serialize};
 
-use alg_tools::euclidean::Euclidean;
-use alg_tools::instance::Instance;
-use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::linops::{Mapping, GEMV};
-use alg_tools::mapping::RealMapping;
-use alg_tools::nalgebra_support::ToNalgebraRealField;
-
-use crate::dataterm::{calculate_residual, DataTerm, L2Squared};
-use crate::forward_model::{AdjointProductBoundedBy, ForwardModel};
-use crate::measures::merging::SpikeMerging;
-use crate::measures::{DiscreteMeasure, RNDM};
-use crate::plot::{PlotLookup, Plotting, SeqPlotter};
-pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty};
-use crate::regularisation::RegTerm;
-use crate::types::*;
-
 /// Settings for [`pointsource_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
 pub struct FBConfig<F: Float> {
     /// Step length scaling
     pub τ0: F,
+    /// Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`]
+    pub σp0: F,
     /// Generic parameters
-    pub generic: FBGenericConfig<F>,
+    pub insertion: InsertionConfig<F>,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
@@ -112,12 +109,13 @@
     fn default() -> Self {
         FBConfig {
             τ0: 0.99,
-            generic: Default::default(),
+            σp0: 0.99,
+            insertion: Default::default(),
         }
     }
 }
 
-pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize {
+pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize {
     let n_before_prune = μ.len();
     μ.prune();
     debug_assert!(μ.len() <= n_before_prune);
@@ -125,30 +123,19 @@
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn postprocess<
-    F: Float,
-    V: Euclidean<F> + Clone,
-    A: GEMV<F, RNDM<F, N>, Codomain = V>,
-    D: DataTerm<F, V, N>,
-    const N: usize,
->(
-    mut μ: RNDM<F, N>,
-    config: &FBGenericConfig<F>,
-    dataterm: D,
-    opA: &A,
-    b: &V,
-) -> RNDM<F, N>
+pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>(
+    mut μ: RNDM<N, F>,
+    config: &InsertionConfig<F>,
+    f: Dat,
+) -> DynResult<RNDM<N, F>>
 where
-    RNDM<F, N>: SpikeMerging<F>,
-    for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>,
+    RNDM<N, F>: SpikeMerging<F>,
+    for<'a> &'a RNDM<N, F>: Instance<RNDM<N, F>>,
 {
-    μ.merge_spikes_fitness(
-        config.final_merging_method(),
-        |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
-        |&v| v,
-    );
+    //μ.merge_spikes_fitness(config.final_merging_method(), |μ̃| f.apply(μ̃), |&v| v);
+    μ.merge_spikes_fitness(config.final_merging_method(), f, |&v| v);
     μ.prune();
-    μ
+    Ok(μ)
 }
 
 /// Iteratively solve the pointsource localisation problem using forward-backward splitting.
@@ -161,50 +148,41 @@
 ///
 /// For details on the mathematical formulation, see the [module level](self) documentation.
 ///
-/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
-/// sums of simple functions usign bisection trees, and the related
-/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
-/// active at a specific points, and to maximise their sums. Through the implementation of the
-/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
-/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
-///
 /// Returns the final iterate.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>(
-    opA: &A,
-    b: &A::Observable,
-    reg: Reg,
+pub fn pointsource_fb_reg<F, I, Dat, Reg, P, Plot, const N: usize>(
+    f: &Dat,
+    reg: &Reg,
     prox_penalty: &P,
     fbconfig: &FBConfig<F>,
     iterator: I,
-    mut plotter: SeqPlotter<F, N>,
-) -> RNDM<F, N>
+    mut plotter: Plot,
+    μ0 : Option<RNDM<N, F>>,
+) -> DynResult<RNDM<N, F>>
 where
     F: Float + ToNalgebraRealField,
-    I: AlgIteratorFactory<IterInfo<F, N>>,
-    for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>,
-    A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>,
-    A::PreadjointCodomain: RealMapping<F, N>,
-    PlotLookup: Plotting<N>,
-    RNDM<F, N>: SpikeMerging<F>,
-    Reg: RegTerm<F, N>,
-    P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+    I: AlgIteratorFactory<IterInfo<F>>,
+    RNDM<N, F>: SpikeMerging<F>,
+    Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
+    Dat::DerivativeDomain: ClosedMul<F>,
+    Reg: RegTerm<Loc<N, F>, F>,
+    P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
+    Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
 {
     // Set up parameters
-    let config = &fbconfig.generic;
-    let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
+    let config = &fbconfig.insertion;
+    let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?;
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
     let tolerance = config.tolerance * τ * reg.tolerance_scaling();
     let mut ε = tolerance.initial();
 
     // Initialise iterates
-    let mut μ = DiscreteMeasure::new();
-    let mut residual = -b;
+    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
 
     // Statistics
-    let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo {
-        value: residual.norm2_squared_div2() + reg.apply(μ),
+    let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
+        value: f.apply(μ) + reg.apply(μ),
         n_spikes: μ.len(),
         ε,
         //postprocessing: config.postprocessing.then(|| μ.clone()),
@@ -213,9 +191,10 @@
     let mut stats = IterInfo::new();
 
     // Run the algorithm
-    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
+    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        let mut τv = opA.preadjoint().apply(residual * τ);
+        // TODO: optimise τ to be applied to residual.
+        let mut τv = f.differential(&μ) * τ;
 
         // Save current base point
         let μ_base = μ.clone();
@@ -223,7 +202,7 @@
         // Insert and reweigh
         let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
             &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
-        );
+        )?;
 
         // Prune and possibly merge spikes
         if config.merge_now(&state) {
@@ -236,34 +215,27 @@
                 ε,
                 config,
                 &reg,
-                Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
+                Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
             );
         }
 
         stats.pruned += prune_with_stats(&mut μ);
 
-        // Update residual
-        residual = calculate_residual(&μ, opA, b);
-
         let iter = state.iteration();
         stats.this_iters += 1;
 
         // Give statistics if needed
         state.if_verbose(|| {
             plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
-            full_stats(
-                &residual,
-                &μ,
-                ε,
-                std::mem::replace(&mut stats, IterInfo::new()),
-            )
+            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
 
         // Update main tolerance for next iteration
         ε = tolerance.update(ε, iter);
     }
 
-    postprocess(μ, config, L2Squared, opA, b)
+    // Final spike merging of the last iterate, with the data term value as fitness, then pruning.
+    postprocess(μ, config, |μ̃| f.apply(μ̃))
 }
 
 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting.
@@ -276,38 +248,30 @@
 ///
 /// For details on the mathematical formulation, see the [module level](self) documentation.
 ///
-/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
-/// sums of simple functions usign bisection trees, and the related
-/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
-/// active at a specific points, and to maximise their sums. Through the implementation of the
-/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
-/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
-///
 /// Returns the final iterate.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>(
-    opA: &A,
-    b: &A::Observable,
-    reg: Reg,
+pub fn pointsource_fista_reg<F, I, Dat, Reg, P, Plot, const N: usize>(
+    f: &Dat,
+    reg: &Reg,
     prox_penalty: &P,
     fbconfig: &FBConfig<F>,
     iterator: I,
-    mut plotter: SeqPlotter<F, N>,
-) -> RNDM<F, N>
+    mut plotter: Plot,
+    μ0: Option<RNDM<N, F>>
+) -> DynResult<RNDM<N, F>>
 where
     F: Float + ToNalgebraRealField,
-    I: AlgIteratorFactory<IterInfo<F, N>>,
-    for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>,
-    A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>,
-    A::PreadjointCodomain: RealMapping<F, N>,
-    PlotLookup: Plotting<N>,
-    RNDM<F, N>: SpikeMerging<F>,
-    Reg: RegTerm<F, N>,
-    P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+    I: AlgIteratorFactory<IterInfo<F>>,
+    RNDM<N, F>: SpikeMerging<F>,
+    Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
+    Dat::DerivativeDomain: ClosedMul<F>,
+    Reg: RegTerm<Loc<N, F>, F>,
+    P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
+    Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
 {
     // Set up parameters
-    let config = &fbconfig.generic;
-    let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
+    let config = &fbconfig.insertion;
+    let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?;
     let mut λ = 1.0;
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
@@ -315,14 +279,13 @@
     let mut ε = tolerance.initial();
 
     // Initialise iterates
-    let mut μ = DiscreteMeasure::new();
-    let mut μ_prev = DiscreteMeasure::new();
-    let mut residual = -b;
+    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
+    let mut μ_prev = μ.clone();
     let mut warned_merging = false;
 
     // Statistics
-    let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo {
-        value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
+    let full_stats = |ν: &RNDM<N, F>, ε, stats| IterInfo {
+        value: f.apply(ν) + reg.apply(ν),
         n_spikes: ν.len(),
         ε,
         // postprocessing: config.postprocessing.then(|| ν.clone()),
@@ -333,7 +296,7 @@
     // Run the algorithm
     for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        let mut τv = opA.preadjoint().apply(residual * τ);
+        let mut τv = f.differential(&μ) * τ;
 
         // Save current base point
         let μ_base = μ.clone();
@@ -341,7 +304,7 @@
         // Insert new spikes and reweigh
         let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
             &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
-        );
+        )?;
 
         // (Do not) merge spikes.
         if config.merge_now(&state) && !warned_merging {
@@ -369,9 +332,6 @@
         debug_assert!(μ.len() <= n_before_prune);
         stats.pruned += n_before_prune - μ.len();
 
-        // Update residual
-        residual = calculate_residual(&μ, opA, b);
-
         let iter = state.iteration();
         stats.this_iters += 1;
 
@@ -385,5 +345,6 @@
         ε = tolerance.update(ε, iter);
     }
 
-    postprocess(μ_prev, config, L2Squared, opA, b)
+    //postprocess(μ_prev, config, f)
+    postprocess(μ_prev, config, |μ̃| f.apply(μ̃))
 }

mercurial