src/sliding_fb.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 35
b087e3eab191
child 39
6316d68b58af
--- a/src/sliding_fb.rs	Mon Jan 06 11:32:57 2025 -0500
+++ b/src/sliding_fb.rs	Thu Jan 23 23:35:28 2025 +0100
@@ -12,38 +12,19 @@
 
 use alg_tools::iterate::AlgIteratorFactory;
 use alg_tools::euclidean::Euclidean;
-use alg_tools::sets::Cube;
-use alg_tools::loc::Loc;
-use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance};
+use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
 use alg_tools::norms::Norm;
-use alg_tools::bisection_tree::{
-    BTFN,
-    PreBTFN,
-    Bounds,
-    BTNodeLookup,
-    BTNode,
-    BTSearch,
-    P2Minimise,
-    SupportGenerator,
-    LocalAnalysis,
-    //Bounded,
-};
-use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
-use alg_tools::norms::{L2, Linfinity};
+use alg_tools::norms::L2;
 
 use crate::types::*;
 use crate::measures::{DiscreteMeasure, Radon, RNDM};
-use crate::measures::merging::{
-    SpikeMergingMethod,
-    SpikeMerging,
-};
+use crate::measures::merging::SpikeMerging;
 use crate::forward_model::{
     ForwardModel,
     AdjointProductBoundedBy,
     LipschitzValues,
 };
-use crate::seminorms::DiscreteMeasureOp;
 //use crate::tolerance::Tolerance;
 use crate::plot::{
     SeqPlotter,
@@ -151,7 +132,7 @@
     Observable : Euclidean<F, Output=Observable>,
     for<'a> &'a Observable : Instance<Observable>,
     //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
-    D : DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F, N>>,
+    D : DifferentiableRealMapping<F, N>,
 {
 
     use TransportStepLength::*;
@@ -353,40 +334,29 @@
 /// The parametrisation is as for [`pointsource_fb_reg`].
 /// Inertia is currently not supported.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>(
-    opA : &'a A,
+pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>(
+    opA : &A,
     b : &A::Observable,
     reg : Reg,
-    op𝒟 : &'a 𝒟,
+    prox_penalty : &P,
     config : &SlidingFBConfig<F>,
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
 ) -> RNDM<F, N>
-where F : Float + ToNalgebraRealField,
-      I : AlgIteratorFactory<IterInfo<F, N>>,
-      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
-      for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
-      A::PreadjointCodomain : DifferentiableMapping<
-        Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F
-      >,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
-          //+ TransportLipschitz<L2Squared, FloatType=F>,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>,
-                                          Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
-      BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
-         + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>,
-      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-         //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>,
-      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
-      PlotLookup : Plotting<N>,
-      RNDM<F, N> : SpikeMerging<F>,
-      Reg : SlidingRegTerm<F, N> {
+where
+    F : Float + ToNalgebraRealField,
+    I : AlgIteratorFactory<IterInfo<F, N>>,
+    A : ForwardModel<RNDM<F, N>, F>
+        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>,
+        //+ TransportLipschitz<L2Squared, FloatType=F>,
+    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
+    for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
+    A::PreadjointCodomain : DifferentiableRealMapping<F, N>,
+    RNDM<F, N> : SpikeMerging<F>,
+    Reg : SlidingRegTerm<F, N>,
+    P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+    PlotLookup : Plotting<N>,
+{
 
     // Check parameters
     assert!(config.τ0 > 0.0, "Invalid step length parameter");
@@ -398,13 +368,12 @@
     let mut residual = -b; // Has to equal $Aμ-b$.
 
     // Set up parameters
-    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
     let opAnorm = opA.opnorm_bound(Radon, L2);
     //let max_transport = config.max_transport.scale
     //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
     //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
     let ℓ = 0.0;
-    let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap();
+    let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
     let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v));
     let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() {
         // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v
@@ -446,15 +415,14 @@
 
         // Solve finite-dimensional subproblem several times until the dual variable for the
         // regularisation term conforms to the assumptions made for the transport above.
-        let (d, _within_tolerances, τv̆) = 'adapt_transport: loop {
+        let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop {
             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
             let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
-            let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
+            let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
 
             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
-            let (d, within_tolerances) = insert_and_reweigh(
-                &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0),
-                op𝒟, op𝒟norm,
+            let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
+                &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0),
                 τ, ε, &config.insertion,
                 &reg, &state, &mut stats,
             );
@@ -464,7 +432,7 @@
                 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
                 ε, &config.transport
             ) {
-                break 'adapt_transport (d, within_tolerances, τv̆)
+                break 'adapt_transport (maybe_d, within_tolerances, τv̆)
             }
         };
 
@@ -480,20 +448,20 @@
             (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
         });
 
-        // Merge spikes.
-        // This expects the prune below to prune γ.
-        // TODO: This may not work correctly in all cases.
-        let ins = &config.insertion;
-        if ins.merge_now(&state) {
-            if let SpikeMergingMethod::None = ins.merging {
-            } else {
-                stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| {
-                    let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0;
-                    let mut d = &τv̆ + op𝒟.preapply(ν);
-                    reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
-                });
-            }
-        }
+        // // Merge spikes.
+        // // This expects the prune below to prune γ.
+        // // TODO: This may not work correctly in all cases.
+        // let ins = &config.insertion;
+        // if ins.merge_now(&state) {
+        //     if let SpikeMergingMethod::None = ins.merging {
+        //     } else {
+        //         stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| {
+        //             let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0;
+        //             let mut d = &τv̆ + op𝒟.preapply(ν);
+        //             reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
+        //         });
+        //     }
+        // }
 
         // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
         // latter needs to be pruned when μ is.
@@ -514,7 +482,7 @@
 
         // Give statistics if requested
         state.if_verbose(|| {
-            plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ);
+            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
             full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
 

mercurial