src/frank_wolfe.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
child 37
c5d8bd1a7728
--- a/src/frank_wolfe.rs	Thu Aug 29 00:00:00 2024 -0500
+++ b/src/frank_wolfe.rs	Tue Dec 31 09:25:45 2024 -0500
@@ -14,18 +14,18 @@
 */
 
 use numeric_literals::replace_float_literals;
+use nalgebra::{DMatrix, DVector};
 use serde::{Serialize, Deserialize};
 //use colored::Colorize;
 
 use alg_tools::iterate::{
     AlgIteratorFactory,
-    AlgIteratorState,
     AlgIteratorOptions,
     ValueIteratorFactory,
 };
 use alg_tools::euclidean::Euclidean;
 use alg_tools::norms::Norm;
-use alg_tools::linops::Apply;
+use alg_tools::linops::Mapping;
 use alg_tools::sets::Cube;
 use alg_tools::loc::Loc;
 use alg_tools::bisection_tree::{
@@ -40,9 +40,11 @@
 };
 use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
+use alg_tools::norms::L2;
 
 use crate::types::*;
 use crate::measures::{
+    RNDM,
     DiscreteMeasure,
     DeltaMeasure,
     Radon,
@@ -71,7 +73,7 @@
     RegTerm
 };
 
-/// Settings for [`pointsource_fw`].
+/// Settings for [`pointsource_fw_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
 pub struct FWConfig<F : Float> {
@@ -111,10 +113,20 @@
     }
 }
 
-/// Helper struct for pre-initialising the finite-dimensional subproblems solver
-/// [`prepare_optimise_weights`].
-///
-/// The pre-initialisation is done by [`prepare_optimise_weights`].
+pub trait FindimQuadraticModel<Domain, F> : ForwardModel<DiscreteMeasure<Domain, F>, F>
+where
+    F : Float + ToNalgebraRealField,
+    Domain : Clone + PartialEq,
+{
+    /// Return $A\_\*A$ and $A\_\* b$.
+    fn findim_quadratic_model(
+        &self,
+        μ : &DiscreteMeasure<Domain, F>,
+        b : &Self::Observable
+    ) -> (DMatrix<F::MixedType>, DVector<F::MixedType>);
+}
+
+/// Helper struct for pre-initialising the finite-dimensional subproblem solver.
 pub struct FindimData<F : Float> {
     /// ‖A‖^2
     opAnorm_squared : F,
@@ -125,7 +137,7 @@
 /// Trait for finite dimensional weight optimisation.
 pub trait WeightOptim<
     F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
+    A : ForwardModel<RNDM<F, N>, F>,
     I : AlgIteratorFactory<F>,
     const N : usize
 > {
@@ -154,7 +166,7 @@
     /// Returns the number of iterations taken by the method configured in `inner`.
     fn optimise_weights<'a>(
         &self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
+        μ : &mut RNDM<F, N>,
         opA : &'a A,
         b : &A::Observable,
         findim_data : &FindimData<F>,
@@ -166,12 +178,12 @@
 /// Trait for regularisation terms supported by [`pointsource_fw_reg`].
 pub trait RegTermFW<
     F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
+    A : ForwardModel<RNDM<F, N>, F>,
     I : AlgIteratorFactory<F>,
     const N : usize
 > : RegTerm<F, N>
     + WeightOptim<F, A, I, N>
-    + for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> {
+    + Mapping<RNDM<F, N>, Codomain = F> {
 
     /// With $g = A\_\*(Aμ-b)$, returns $(x, g(x))$ for $x$ a new point to be inserted
     /// into $μ$, as determined by the regulariser.
@@ -188,7 +200,7 @@
     /// Insert point `ξ` into `μ` for the relaxed algorithm from Bredies–Pikkarainen.
     fn relaxed_insert<'a>(
         &self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
+        μ : &mut RNDM<F, N>,
         g : &A::PreadjointCodomain,
         opA : &'a A,
         ξ : Loc<F, N>,
@@ -201,18 +213,18 @@
 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N>
 for RadonRegTerm<F>
 where I : AlgIteratorFactory<F>,
-      A : ForwardModel<Loc<F, N>, F> {
+      A : FindimQuadraticModel<Loc<F, N>, F>  {
 
     fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> {
         FindimData{
-            opAnorm_squared : opA.opnorm_bound().powi(2),
+            opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2),
             m0 : b.norm2_squared() / (2.0 * self.α()),
         }
     }
 
     fn optimise_weights<'a>(
         &self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
+        μ : &mut RNDM<F, N>,
         opA : &'a A,
         b : &A::Observable,
         findim_data : &FindimData<F>,
@@ -245,12 +257,19 @@
 #[replace_float_literals(F::cast_from(literal))]
 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N>
 for RadonRegTerm<F>
-where Cube<F, N> : P2Minimise<Loc<F, N>, F>,
-      I : AlgIteratorFactory<F>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>> {
+where
+    Cube<F, N> : P2Minimise<Loc<F, N>, F>,
+    I : AlgIteratorFactory<F>,
+    S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+    GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
+    A : FindimQuadraticModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
+    BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
+    // FIXME: the following *should not* be needed; they are already implied
+    RNDM<F, N> : Mapping<A::PreadjointCodomain, Codomain = F>,
+    DeltaMeasure<Loc<F, N>, F> : Mapping<A::PreadjointCodomain, Codomain = F>,
+    //A : Mapping<RNDM<F, N>, Codomain = A::Observable>,
+    //A : Mapping<DeltaMeasure<Loc<F, N>, F>, Codomain = A::Observable>,
+{
 
     fn find_insertion(
         &self,
@@ -269,7 +288,7 @@
 
     fn relaxed_insert<'a>(
         &self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
+        μ : &mut RNDM<F, N>,
         g : &A::PreadjointCodomain,
         opA : &'a A,
         ξ : Loc<F, N>,
@@ -282,7 +301,7 @@
         let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ };
         let δ = DeltaMeasure { x : ξ, α : v };
         let dp = μ.apply(g) - δ.apply(g);
-        let d = opA.apply(&*μ) - opA.apply(&δ);
+        let d = opA.apply(&*μ) - opA.apply(δ);
         let r = d.norm2_squared();
         let s = if r == 0.0 {
             1.0
@@ -298,18 +317,18 @@
 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N>
 for NonnegRadonRegTerm<F>
 where I : AlgIteratorFactory<F>,
-      A : ForwardModel<Loc<F, N>, F> {
+      A : FindimQuadraticModel<Loc<F, N>, F> {
 
     fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> {
         FindimData{
-            opAnorm_squared : opA.opnorm_bound().powi(2),
+            opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2),
             m0 : b.norm2_squared() / (2.0 * self.α()),
         }
     }
 
     fn optimise_weights<'a>(
         &self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
+        μ : &mut RNDM<F, N>,
         opA : &'a A,
         b : &A::Observable,
         findim_data : &FindimData<F>,
@@ -342,12 +361,17 @@
 #[replace_float_literals(F::cast_from(literal))]
 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N>
 for NonnegRadonRegTerm<F>
-where Cube<F, N> : P2Minimise<Loc<F, N>, F>,
-      I : AlgIteratorFactory<F>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>> {
+where
+    Cube<F, N> : P2Minimise<Loc<F, N>, F>,
+    I : AlgIteratorFactory<F>,
+    S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+    GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
+    A : FindimQuadraticModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
+    BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
+    // FIXME: the following *should not* be needed; they are already implied
+    RNDM<F, N> : Mapping<A::PreadjointCodomain, Codomain = F>,
+    DeltaMeasure<Loc<F, N>, F> : Mapping<A::PreadjointCodomain, Codomain = F>,
+{
 
     fn find_insertion(
         &self,
@@ -361,7 +385,7 @@
 
     fn relaxed_insert<'a>(
         &self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
+        μ : &mut RNDM<F, N>,
         g : &A::PreadjointCodomain,
         opA : &'a A,
         ξ : Loc<F, N>,
@@ -409,20 +433,18 @@
     config : &FWConfig<F>,
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
-) -> DiscreteMeasure<Loc<F, N>, F>
+) -> RNDM<F, N>
 where F : Float + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
       for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
-      A::Observable : std::ops::MulAssign<F>,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
+      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
       Cube<F, N>: P2Minimise<Loc<F, N>, F>,
       PlotLookup : Plotting<N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      RNDM<F, N> : SpikeMerging<F>,
       Reg : RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N> {
 
     // Set up parameters
@@ -438,26 +460,24 @@
     let mut μ = DiscreteMeasure::new();
     let mut residual = -b;
 
-    let mut inner_iters = 0;
-    let mut this_iters = 0;
-    let mut pruned = 0;
-    let mut merged = 0;
+    // Statistics
+    let full_stats = |residual : &A::Observable,
+                      ν : &RNDM<F, N>,
+                      ε, stats| IterInfo {
+        value : residual.norm2_squared_div2() + reg.apply(ν),
+        n_spikes : ν.len(),
+        ε,
+        .. stats
+    };
+    let mut stats = IterInfo::new();
 
     // Run the algorithm
-    iterator.iterate(|state| {
-        // Update tolerance
+    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
         let inner_tolerance = ε * config.inner.tolerance_mult;
         let refinement_tolerance = ε * config.refinement.tolerance_mult;
-        let ε_prev = ε;
-        ε = tolerance.update(ε, state.iteration());
 
         // Calculate smooth part of surrogate model.
-        //
-        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
-        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
-        // the residual and replacing it below before the end of this closure.
-        let r = std::mem::replace(&mut residual, opA.empty_observable());
-        let mut g = -preadjA.apply(r);
+        let mut g = preadjA.apply(residual * (-1.0));
 
         // Find absolute value maximising point
         let (ξ, v_ξ) = reg.find_insertion(&mut g, refinement_tolerance,
@@ -467,60 +487,46 @@
             FWVariant::FullyCorrective => {
                 // No point in optimising the weight here: the finite-dimensional algorithm is fast.
                 μ += DeltaMeasure { x : ξ, α : 0.0 };
+                stats.inserted += 1;
                 config.inner.iterator_options.stop_target(inner_tolerance)
             },
             FWVariant::Relaxed => {
                 // Perform a relaxed initialisation of μ
                 reg.relaxed_insert(&mut μ, &g, opA, ξ, v_ξ, &findim_data);
+                stats.inserted += 1;
                 // The stop_target is only needed for the type system.
                 AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0)
             }
         };
 
-        inner_iters += reg.optimise_weights(&mut μ, opA, b, &findim_data, &config.inner, inner_it);
+        stats.inner_iters += reg.optimise_weights(&mut μ, opA, b, &findim_data,
+                                                  &config.inner, inner_it);
    
         // Merge spikes and update residual for next step and `if_verbose` below.
         let (r, count) = μ.merge_spikes_fitness(config.merging,
                                                 |μ̃| opA.apply(μ̃) - b,
                                                 A::Observable::norm2_squared);
         residual = r;
-        merged += count;
-
+        stats.merged += count;
 
         // Prune points with zero mass
         let n_before_prune = μ.len();
         μ.prune();
         debug_assert!(μ.len() <= n_before_prune);
-        pruned += n_before_prune - μ.len();
+        stats.pruned += n_before_prune - μ.len();
 
-        this_iters +=1;
+        stats.this_iters += 1;
+        let iter = state.iteration();
 
-        // Give function value if needed
+        // Give statistics if needed
         state.if_verbose(|| {
-            plotter.plot_spikes(
-                format!("iter {} start", state.iteration()), &g,
-                "".to_string(), None::<&A::PreadjointCodomain>,
-                None, &μ
-            );
-            let res = IterInfo {
-                value : residual.norm2_squared_div2() + reg.apply(&μ),
-                n_spikes : μ.len(),
-                inner_iters,
-                this_iters,
-                merged,
-                pruned,
-                ε : ε_prev,
-                postprocessing : None,
-                untransported_fraction : None,
-                transport_error : None,
-            };
-            inner_iters = 0;
-            this_iters = 0;
-            merged = 0;
-            pruned = 0;
-            res
-        })
-    });
+            plotter.plot_spikes(iter, Some(&g), Option::<&S>::None, &μ);
+            full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
+        });
+
+        // Update tolerance
+        ε = tolerance.update(ε, iter);
+    }
 
     // Return final iterate
     μ

mercurial