src/forward_pdps.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 35
b087e3eab191
child 39
6316d68b58af
--- a/src/forward_pdps.rs	Mon Jan 06 11:32:57 2025 -0500
+++ b/src/forward_pdps.rs	Thu Jan 23 23:35:28 2025 +0100
@@ -8,30 +8,15 @@
 
 use alg_tools::iterate::AlgIteratorFactory;
 use alg_tools::euclidean::Euclidean;
-use alg_tools::sets::Cube;
-use alg_tools::loc::Loc;
-use alg_tools::mapping::{Mapping, Instance};
+use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
 use alg_tools::norms::Norm;
 use alg_tools::direct_product::Pair;
-use alg_tools::bisection_tree::{
-    BTFN,
-    PreBTFN,
-    Bounds,
-    BTNodeLookup,
-    BTNode,
-    BTSearch,
-    P2Minimise,
-    SupportGenerator,
-    LocalAnalysis,
-    //Bounded,
-};
-use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::linops::{
     BoundedLinear, AXPY, GEMV, Adjointable, IdOp,
 };
 use alg_tools::convex::{Conjugable, Prox};
-use alg_tools::norms::{L2, Linfinity, PairNorm};
+use alg_tools::norms::{L2, PairNorm};
 
 use crate::types::*;
 use crate::measures::{DiscreteMeasure, Radon, RNDM};
@@ -40,7 +25,6 @@
     ForwardModel,
     AdjointProductPairBoundedBy,
 };
-use crate::seminorms::DiscreteMeasureOp;
 use crate::plot::{
     SeqPlotter,
     Plotting,
@@ -83,12 +67,12 @@
 /// using primal-dual proximal splitting with a forward step.
 #[replace_float_literals(F::cast_from(literal))]
 pub fn pointsource_forward_pdps_pair<
-    'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize
+    F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize
 >(
-    opA : &'a A,
+    opA : &A,
     b : &A::Observable,
     reg : Reg,
-    op𝒟 : &'a 𝒟,
+    prox_penalty : &P,
     config : &ForwardPDPSConfig<F>,
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
@@ -102,27 +86,19 @@
 where
     F : Float + ToNalgebraRealField,
     I : AlgIteratorFactory<IterInfo<F, N>>,
-    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
-    GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
     A : ForwardModel<
             MeasureZ<F, Z, N>,
             F,
             PairNorm<Radon, L2, L2>,
-            PreadjointCodomain = Pair<BTFN<F, GA, BTA, N>, Z>,
+            PreadjointCodomain = Pair<S, Z>,
         >
-        + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, 𝒟, IdOp<Z>, FloatType=F>,
-    BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-    G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-    𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>,
-                                        Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
-    BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-    S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-    K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-    BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-    Cube<F, N>: P2Minimise<Loc<F, N>, F>,
+        + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>,
+    S: DifferentiableRealMapping<F, N>,
+    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
     PlotLookup : Plotting<N>,
     RNDM<F, N> : SpikeMerging<F>,
     Reg : RegTerm<F, N>,
+    P : ProxPenalty<F, S, Reg, N>,
     KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y>
         + GEMV<F, Z>
         + Adjointable<Z, Y, AdjointCodomain = Z>,
@@ -150,11 +126,10 @@
     let mut residual = calculate_residual(Pair(&μ, &z), opA, b);
 
     // Set up parameters
-    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
-    let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt();
+    let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt();
     let nKz = opKz.opnorm_bound(L2, L2);
     let opIdZ = IdOp::new();
-    let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap();
+    let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap();
     // We need to satisfy
     //
     //     τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1
@@ -196,14 +171,13 @@
     // Run the algorithm
     for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) {
         // Calculate initial transport
-        let Pair(τv, τz) = opA.preadjoint().apply(residual * τ);
+        let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ);
         let z_base = z.clone();
         let μ_base = μ.clone();
 
         // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
-        let (d, _within_tolerances) = insert_and_reweigh(
-            &mut μ, &τv, &μ_base, None,
-            op𝒟, op𝒟norm,
+        let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
+            &mut μ, &mut τv, &μ_base, None,
             τ, ε, &config.insertion,
             &reg, &state, &mut stats,
         );
@@ -248,7 +222,7 @@
         stats.this_iters += 1;
 
         state.if_verbose(|| {
-            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
+            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
             full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
 

mercurial