src/radon_fb.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
--- a/src/radon_fb.rs	Thu Aug 29 00:00:00 2024 -0500
+++ b/src/radon_fb.rs	Tue Dec 31 09:25:45 2024 -0500
@@ -11,10 +11,11 @@
 
 use alg_tools::iterate::{
     AlgIteratorFactory,
-    AlgIteratorState,
+    AlgIteratorIteration,
+    AlgIterator
 };
 use alg_tools::euclidean::Euclidean;
-use alg_tools::linops::Apply;
+use alg_tools::linops::Mapping;
 use alg_tools::sets::Cube;
 use alg_tools::loc::Loc;
 use alg_tools::bisection_tree::{
@@ -29,11 +30,14 @@
 };
 use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
+use alg_tools::norms::L2;
 
 use crate::types::*;
 use crate::measures::{
+    RNDM,
     DiscreteMeasure,
     DeltaMeasure,
+    Radon,
 };
 use crate::measures::merging::{
     SpikeMergingMethod,
@@ -54,10 +58,11 @@
 
 use crate::fb::{
     FBGenericConfig,
-    postprocess
+    postprocess,
+    prune_with_stats
 };
 
-/// Settings for [`pointsource_fb_reg`].
+/// Settings for [`pointsource_radon_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
 pub struct RadonFBConfig<F : Float> {
@@ -79,17 +84,17 @@
 
 #[replace_float_literals(F::cast_from(literal))]
 pub(crate) fn insert_and_reweigh<
-    'a, F, GA, BTA, S, Reg, State, const N : usize
+    'a, F, GA, BTA, S, Reg, I, const N : usize
 >(
-    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-    minus_τv : &mut BTFN<F, GA, BTA, N>,
-    μ_base : &mut DiscreteMeasure<Loc<F, N>, F>,
-    _ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>,
+    μ : &mut RNDM<F, N>,
+    τv : &mut BTFN<F, GA, BTA, N>,
+    μ_base : &mut RNDM<F, N>,
+    //_ν_delta: Option<&RNDM<F, N>>,
     τ : F,
     ε : F,
     config : &FBGenericConfig<F>,
     reg : &Reg,
-    _state : &State,
+    _state : &AlgIteratorIteration<I>,
     stats : &mut IterInfo<F, N>,
 )
 where F : Float + ToNalgebraRealField,
@@ -97,18 +102,20 @@
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      RNDM<F, N> : SpikeMerging<F>,
       Reg : RegTerm<F, N>,
-      State : AlgIteratorState {
+      I : AlgIterator {
 
     'i_and_w: for i in 0..=1 {
         // Optimise weights
         if μ.len() > 0 {
             // Form finite-dimensional subproblem. The subproblem references to the original μ^k
             // from the beginning of the iteration are all contained in the immutable c and g.
+            // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
+            // problems have not yet been updated to sign change.
             let g̃ = DVector::from_iterator(μ.len(),
                                            μ.iter_locations()
-                                            .map(|ζ| F::to_nalgebra_mixed(minus_τv.apply(ζ))));
+                                            .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ))));
             let mut x = μ.masses_dvector();
             let y = μ_base.masses_dvector();
 
@@ -122,7 +129,7 @@
         if i>0 {
             // Simple debugging test to see if more inserts would be needed. Doesn't seem so.
             //let n = μ.dist_matching(μ_base);
-            //println!("{:?}", reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n));
+            //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n));
             break 'i_and_w
         }
         
@@ -132,69 +139,23 @@
         // Find a spike to insert, if needed.
         // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ,
         // which are supposed to have been guaranteed by the finite-dimensional weight optimisation.
-        match reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n) {
+        match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) {
             None => { break 'i_and_w },
             Some((ξ, _v_ξ, _in_bounds)) => {
                 // Weight is found out by running the finite-dimensional optimisation algorithm
                 // above
                 *μ += DeltaMeasure { x : ξ, α : 0.0 };
                 *μ_base += DeltaMeasure { x : ξ, α : 0.0 };
+                stats.inserted += 1;
             }
         };
     }
 }
 
-#[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn prune_and_maybe_simple_merge<
-    'a, F, GA, BTA, S, Reg, State, const N : usize
->(
-    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-    minus_τv : &mut BTFN<F, GA, BTA, N>,
-    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
-    τ : F,
-    ε : F,
-    config : &FBGenericConfig<F>,
-    reg : &Reg,
-    state : &State,
-    stats : &mut IterInfo<F, N>,
-)
-where F : Float + ToNalgebraRealField,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
-      Reg : RegTerm<F, N>,
-      State : AlgIteratorState {
-
-    assert!(μ_base.len() <= μ.len());
-
-    if state.iteration() % config.merge_every == 0 {
-        stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
-            // Important: μ_candidate's new points are afterwards,
-            // and do not conflict with μ_base.
-            // TODO: could simplify to requiring μ_base instead of μ_radon.
-            // but may complicate with sliding base's exgtra points that need to be
-            // after μ_candidate's extra points.
-            // TODO: doesn't seem to work, maybe need to merge μ_base as well?
-            // Although that doesn't seem to make sense.
-            let μ_radon = μ_candidate.sub_matching(μ_base);
-            reg.verify_merge_candidate_radonsq(minus_τv, μ_candidate, τ, ε, &config, &μ_radon)
-            //let n = μ_candidate.dist_matching(μ_base);
-            //reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n).is_none()
-        });
-    }
-
-    let n_before_prune = μ.len();
-    μ.prune();
-    debug_assert!(μ.len() <= n_before_prune);
-    stats.pruned += n_before_prune - μ.len();
-}
-
 
 /// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting.
 ///
-/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
+/// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the
 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
 /// Finally, the `iterator` is an outer loop verbosity and iteration count control
 /// as documented in [`alg_tools::iterate`].
@@ -219,20 +180,17 @@
     fbconfig : &RadonFBConfig<F>,
     iterator : I,
     mut _plotter : SeqPlotter<F, N>,
-) -> DiscreteMeasure<Loc<F, N>, F>
+) -> RNDM<F, N>
 where F : Float + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
       for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
-      A::Observable : std::ops::MulAssign<F>,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
+      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
       Cube<F, N>: P2Minimise<Loc<F, N>, F>,
-      PlotLookup : Plotting<N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      RNDM<F, N> : SpikeMerging<F>,
       Reg : RegTerm<F, N> {
 
     // Set up parameters
@@ -240,7 +198,7 @@
     // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ
     // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such
     // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L.
-    let τ = fbconfig.τ0/opA.opnorm_bound().powi(2);
+    let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2);
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
     let tolerance = config.tolerance * τ * reg.tolerance_scaling();
@@ -249,71 +207,74 @@
     // Initialise iterates
     let mut μ = DiscreteMeasure::new();
     let mut residual = -b;
+    
+    // Statistics
+    let full_stats = |residual : &A::Observable,
+                      μ : &RNDM<F, N>,
+                      ε, stats| IterInfo {
+        value : residual.norm2_squared_div2() + reg.apply(μ),
+        n_spikes : μ.len(),
+        ε,
+        // postprocessing: config.postprocessing.then(|| μ.clone()),
+        .. stats
+    };
     let mut stats = IterInfo::new();
 
     // Run the algorithm
-    iterator.iterate(|state| {
+    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
-        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
-        // the residual and replacing it below before the end of this closure.
-        residual *= -τ;
-        let r = std::mem::replace(&mut residual, opA.empty_observable());
-        let mut minus_τv = opA.preadjoint().apply(r);
+        let mut τv = opA.preadjoint().apply(residual * τ);
 
         // Save current base point
         let mut μ_base = μ.clone();
             
         // Insert and reweigh
         insert_and_reweigh(
-            &mut μ, &mut minus_τv, &mut μ_base, None,
+            &mut μ, &mut τv, &mut μ_base, //None,
             τ, ε,
-            config, &reg, state, &mut stats
+            config, &reg, &state, &mut stats
         );
 
         // Prune and possibly merge spikes
-        prune_and_maybe_simple_merge(
-            &mut μ, &mut minus_τv, &μ_base,
-            τ, ε,
-            config, &reg, state, &mut stats
-        );
+        assert!(μ_base.len() <= μ.len());
+        if config.merge_now(&state) {
+            stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
+                // Important: μ_candidate's new points are afterwards,
+                // and do not conflict with μ_base.
+                // TODO: could simplify to requiring μ_base instead of μ_radon.
+                // but may complicate with sliding base's exgtra points that need to be
+                // after μ_candidate's extra points.
+                // TODO: doesn't seem to work, maybe need to merge μ_base as well?
+                // Although that doesn't seem to make sense.
+                let μ_radon = μ_candidate.sub_matching(&μ_base);
+                reg.verify_merge_candidate_radonsq(&mut τv, μ_candidate, τ, ε, &config, &μ_radon)
+                //let n = μ_candidate.dist_matching(μ_base);
+                //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()
+            });
+        }
+        stats.pruned += prune_with_stats(&mut μ);
 
         // Update residual
         residual = calculate_residual(&μ, opA, b);
 
-        // Update main tolerance for next iteration
-        let ε_prev = ε;
-        ε = tolerance.update(ε, state.iteration());
+        let iter = state.iteration();
         stats.this_iters += 1;
 
-        // Give function value if needed
+        // Give statistics if needed
         state.if_verbose(|| {
-            // Plot if so requested
-            // plotter.plot_spikes(
-            //     format!("iter {} end;", state.iteration()), &d,
-            //     "start".to_string(), Some(&minus_τv),
-            //     reg.target_bounds(τ, ε_prev), &μ,
-            // );
-            // Calculate mean inner iterations and reset relevant counters.
-            // Return the statistics
-            let res = IterInfo {
-                value : residual.norm2_squared_div2() + reg.apply(&μ),
-                n_spikes : μ.len(),
-                ε : ε_prev,
-                postprocessing: config.postprocessing.then(|| μ.clone()),
-                .. stats
-            };
-            stats = IterInfo::new();
-            res
-        })
-    });
+            full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
+        });
+
+        // Update main tolerance for next iteration
+        ε = tolerance.update(ε, iter);
+    }
 
     postprocess(μ, config, L2Squared, opA, b)
 }
 
 /// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting.
 ///
-/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
+/// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the
 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
 /// Finally, the `iterator` is an outer loop verbosity and iteration count control
 /// as documented in [`alg_tools::iterate`].
@@ -337,21 +298,19 @@
     reg : Reg,
     fbconfig : &RadonFBConfig<F>,
     iterator : I,
-    mut _plotter : SeqPlotter<F, N>,
-) -> DiscreteMeasure<Loc<F, N>, F>
+    mut plotter : SeqPlotter<F, N>,
+) -> RNDM<F, N>
 where F : Float + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
       for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
-      A::Observable : std::ops::MulAssign<F>,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
+      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
       Cube<F, N>: P2Minimise<Loc<F, N>, F>,
       PlotLookup : Plotting<N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      RNDM<F, N> : SpikeMerging<F>,
       Reg : RegTerm<F, N> {
 
     // Set up parameters
@@ -359,7 +318,7 @@
     // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ
     // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such
     // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L.
-    let τ = fbconfig.τ0/opA.opnorm_bound().powi(2);
+    let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2);
     let mut λ = 1.0;
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
@@ -370,31 +329,35 @@
     let mut μ = DiscreteMeasure::new();
     let mut μ_prev = DiscreteMeasure::new();
     let mut residual = -b;
+    let mut warned_merging = false;
+
+    // Statistics
+    let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo {
+        value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
+        n_spikes : ν.len(),
+        ε,
+        // postprocessing: config.postprocessing.then(|| ν.clone()),
+        .. stats
+    };
     let mut stats = IterInfo::new();
-    let mut warned_merging = false;
 
     // Run the algorithm
-    iterator.iterate(|state| {
+    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
-        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
-        // the residual and replacing it below before the end of this closure.
-        residual *= -τ;
-        let r = std::mem::replace(&mut residual, opA.empty_observable());
-        let mut minus_τv = opA.preadjoint().apply(r);
+        let mut τv = opA.preadjoint().apply(residual * τ);
 
         // Save current base point
         let mut μ_base = μ.clone();
             
         // Insert new spikes and reweigh
         insert_and_reweigh(
-            &mut μ, &mut minus_τv, &mut μ_base, None,
+            &mut μ, &mut τv, &mut μ_base, //None,
             τ, ε,
-            config, &reg, state, &mut stats
+            config, &reg, &state, &mut stats
         );
 
         // (Do not) merge spikes.
-        if state.iteration() % config.merge_every == 0 {
+        if config.merge_now(&state) {
             match config.merging {
                 SpikeMergingMethod::None => { },
                 _ => if !warned_merging {
@@ -423,33 +386,19 @@
 
         // Update residual
         residual = calculate_residual(&μ, opA, b);
-
-        // Update main tolerance for next iteration
-        let ε_prev = ε;
-        ε = tolerance.update(ε, state.iteration());
+       
+        let iter = state.iteration();
         stats.this_iters += 1;
 
-        // Give function value if needed
+        // Give statistics if needed
         state.if_verbose(|| {
-            // Plot if so requested
-            // plotter.plot_spikes(
-            //     format!("iter {} end;", state.iteration()), &d,
-            //     "start".to_string(), Some(&minus_τv),
-            //     reg.target_bounds(τ, ε_prev), &μ_prev,
-            // );
-            // Calculate mean inner iterations and reset relevant counters.
-            // Return the statistics
-            let res = IterInfo {
-                value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev),
-                n_spikes : μ_prev.len(),
-                ε : ε_prev,
-                postprocessing: config.postprocessing.then(|| μ_prev.clone()),
-                .. stats
-            };
-            stats = IterInfo::new();
-            res
-        })
-    });
+            plotter.plot_spikes(iter, Option::<&S>::None, Some(&τv), &μ_prev);
+            full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new()))
+        });
+
+        // Update main tolerance for next iteration
+        ε = tolerance.update(ε, iter);
+    }
 
     postprocess(μ_prev, config, L2Squared, opA, b)
 }

mercurial