src/prox_penalty/wave.rs

branch
dev
changeset 61
4f468d35fa29
parent 39
6316d68b58af
child 62
32328a74c790
child 63
7a8a55fd41c0
--- a/src/prox_penalty/wave.rs	Sun Apr 27 15:03:51 2025 -0500
+++ b/src/prox_penalty/wave.rs	Thu Feb 26 11:38:43 2026 -0500
@@ -2,84 +2,70 @@
 Basic proximal penalty based on convolution operators $𝒟$.
  */
 
-use numeric_literals::replace_float_literals;
-use nalgebra::DVector;
-use colored::Colorize;
-
-use alg_tools::types::*;
-use alg_tools::loc::Loc;
-use alg_tools::mapping::{Mapping, RealMapping};
+use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD};
+use crate::dataterm::QuadraticDataTerm;
+use crate::forward_model::ForwardModel;
+use crate::measures::merging::SpikeMerging;
+use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon};
+use crate::regularisation::RegTerm;
+use crate::seminorms::DiscreteMeasureOp;
+use crate::types::IterInfo;
+use alg_tools::bounds::MinMaxMapping;
+use alg_tools::error::DynResult;
+use alg_tools::iterate::{AlgIterator, AlgIteratorIteration};
+use alg_tools::linops::BoundedLinear;
+use alg_tools::mapping::{Instance, Mapping, Space};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
-use alg_tools::norms::Linfinity;
-use alg_tools::iterate::{
-    AlgIteratorIteration,
-    AlgIterator,
-};
-use alg_tools::bisection_tree::{
-    BTFN,
-    PreBTFN,
-    Bounds,
-    BTSearch,
-    SupportGenerator,
-    LocalAnalysis,
-    BothGenerators,
-};
-use crate::measures::{
-    RNDM,
-    DeltaMeasure,
-    Radon,
-};
-use crate::measures::merging::{
-    SpikeMerging,
-};
-use crate::seminorms::DiscreteMeasureOp;
-use crate::types::{
-    IterInfo,
-};
-use crate::regularisation::RegTerm;
-use super::{ProxPenalty, FBGenericConfig};
+use alg_tools::norms::{Linfinity, Norm, NormExponent, L2};
+use alg_tools::types::*;
+use colored::Colorize;
+use nalgebra::DVector;
+use numeric_literals::replace_float_literals;
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F, GA, BTA, S, Reg, 𝒟, G𝒟, K, const N : usize>
-ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for 𝒟
+impl<F, M, Reg, 𝒟, O, Domain> ProxPenalty<Domain, M, Reg, F> for 𝒟
 where
-    F : Float + ToNalgebraRealField,
-    GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-    BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-    S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-    G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-    𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
-    𝒟::Codomain : RealMapping<F, N>,
-    K : RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-    Reg : RegTerm<F, N>,
-    RNDM<F, N> : SpikeMerging<F>,
+    Domain: Space + Clone + PartialEq + 'static,
+    for<'a> &'a Domain: Instance<Domain>,
+    F: Float + ToNalgebraRealField,
+    𝒟: DiscreteMeasureOp<Domain, F>,
+    𝒟::Codomain: Mapping<Domain, Codomain = F>,
+    M: Mapping<Domain, Codomain = F>,
+    for<'a> &'a M: std::ops::Add<𝒟::PreCodomain, Output = O>,
+    O: MinMaxMapping<Domain, F>,
+    Reg: RegTerm<Domain, F>,
+    DiscreteMeasure<Domain, F>: SpikeMerging<F>,
 {
-    type ReturnMapping = BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>;
+    type ReturnMapping = O;
+
+    fn prox_type() -> ProxTerm {
+        ProxTerm::Wave
+    }
 
     fn insert_and_reweigh<I>(
         &self,
-        μ : &mut RNDM<F, N>,
-        τv : &mut BTFN<F, GA, BTA, N>,
-        μ_base : &RNDM<F, N>,
-        ν_delta: Option<&RNDM<F, N>>,
-        τ : F,
-        ε : F,
-        config : &FBGenericConfig<F>,
-        reg : &Reg,
-        state : &AlgIteratorIteration<I>,
-        stats : &mut IterInfo<F, N>,
-    ) -> (Option<BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>>, bool)
+        μ: &mut DiscreteMeasure<Domain, F>,
+        τv: &mut M,
+        μ_base: &DiscreteMeasure<Domain, F>,
+        ν_delta: Option<&DiscreteMeasure<Domain, F>>,
+        τ: F,
+        ε: F,
+        config: &InsertionConfig<F>,
+        reg: &Reg,
+        state: &AlgIteratorIteration<I>,
+        stats: &mut IterInfo<F>,
+    ) -> DynResult<(Option<Self::ReturnMapping>, bool)>
     where
-        I : AlgIterator
+        I: AlgIterator,
     {
-
-        let op𝒟norm = self.opnorm_bound(Radon, Linfinity);
+        let op𝒟norm = self.opnorm_bound(Radon, Linfinity)?;
 
         // Maximum insertion count and measure difference calculation depend on insertion style.
-        let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
-            (i, Some((l, k))) if i <= l => (k, false),
-            _ => (config.max_insertions, !state.is_quiet()),
-        };
+        let (max_insertions, warn_insertions) =
+            match (state.iteration(), config.bootstrap_insertions) {
+                (i, Some((l, k))) if i <= l => (k, false),
+                _ => (config.max_insertions, !state.is_quiet()),
+            };
 
         let ω0 = match ν_delta {
             None => self.apply(μ_base),
@@ -95,10 +81,12 @@
                 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
                 // problems have not yet been updated to sign change.
                 let à = self.findim_matrix(μ.iter_locations());
-                let g̃ = DVector::from_iterator(μ.len(),
-                                               μ.iter_locations()
-                                                .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
-                                                .map(F::to_nalgebra_mixed));
+                let g̃ = DVector::from_iterator(
+                    μ.len(),
+                    μ.iter_locations()
+                        .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
+                        .map(F::to_nalgebra_mixed),
+                );
                 let mut x = μ.masses_dvector();
 
                 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
@@ -117,10 +105,11 @@
 
             // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality
             // conditions in the predual space, and finding new points for insertion, if necessary.
-            let mut d = &*τv + match ν_delta {
-                None => self.preapply(μ.sub_matching(μ_base)),
-                Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν)
-            };
+            let mut d = &*τv
+                + match ν_delta {
+                    None => self.preapply(μ.sub_matching(μ_base)),
+                    Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν),
+                };
 
             // If no merging heuristic is used, let's be more conservative about spike insertion,
             // and skip it after first round. If merging is done, being more greedy about spike
@@ -132,20 +121,19 @@
             };
 
             // Find a spike to insert, if needed
-            let (ξ, _v_ξ, in_bounds) =  match reg.find_tolerance_violation(
-                &mut d, τ, ε, skip_by_rough_check, config
-            ) {
-                None => break 'insertion (true, d),
-                Some(res) => res,
-            };
+            let (ξ, _v_ξ, in_bounds) =
+                match reg.find_tolerance_violation(&mut d, τ, ε, skip_by_rough_check, config) {
+                    None => break 'insertion (true, d),
+                    Some(res) => res,
+                };
 
             // Break if maximum insertion count reached
             if count >= max_insertions {
-                break 'insertion (in_bounds, d)
+                break 'insertion (in_bounds, d);
             }
 
             // No point in optimising the weight here; the finite-dimensional algorithm is fast.
-            *μ += DeltaMeasure { x : ξ, α : 0.0 };
+            *μ += DeltaMeasure { x: ξ, α: 0.0 };
             count += 1;
             stats.inserted += 1;
         };
@@ -153,39 +141,76 @@
         if !within_tolerances && warn_insertions {
             // Complain (but continue) if we failed to get within tolerances
             // by inserting more points.
-            let err = format!("Maximum insertions reached without achieving \
-                                subproblem solution tolerance");
+            let err = format!(
+                "Maximum insertions reached without achieving \
+                                subproblem solution tolerance"
+            );
             println!("{}", err.red());
         }
 
-        (Some(d), within_tolerances)
+        Ok((Some(d), within_tolerances))
     }
 
     fn merge_spikes(
         &self,
-        μ : &mut RNDM<F, N>,
-        τv : &mut BTFN<F, GA, BTA, N>,
-        μ_base : &RNDM<F, N>,
-        ν_delta: Option<&RNDM<F, N>>,
-        τ : F,
-        ε : F,
-        config : &FBGenericConfig<F>,
-        reg : &Reg,
-        fitness : Option<impl Fn(&RNDM<F, N>) -> F>,
-    ) -> usize
-    {
+        μ: &mut DiscreteMeasure<Domain, F>,
+        τv: &mut M,
+        μ_base: &DiscreteMeasure<Domain, F>,
+        ν_delta: Option<&DiscreteMeasure<Domain, F>>,
+        τ: F,
+        ε: F,
+        config: &InsertionConfig<F>,
+        reg: &Reg,
+        fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>,
+    ) -> usize {
         if config.fitness_merging {
             if let Some(f) = fitness {
-                return μ.merge_spikes_fitness(config.merging, f, |&v| v)
-                        .1
+                return μ.merge_spikes_fitness(config.merging, f, |&v| v).1;
             }
         }
         μ.merge_spikes(config.merging, |μ_candidate| {
-            let mut d = &*τv + self.preapply(match ν_delta {
-                None => μ_candidate.sub_matching(μ_base),
-                Some(ν) => μ_candidate.sub_matching(μ_base) - ν,
-            });
+            let mut d = &*τv
+                + self.preapply(match ν_delta {
+                    None => μ_candidate.sub_matching(μ_base),
+                    Some(ν) => μ_candidate.sub_matching(μ_base) - ν,
+                });
             reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config)
         })
     }
 }
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<'a, F, A, 𝒟, Domain> StepLengthBound<F, QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>>
+    for 𝒟
+where
+    Domain: Space + Clone + PartialEq + 'static,
+    F: Float + ToNalgebraRealField,
+    𝒟: DiscreteMeasureOp<Domain, F>,
+    A: ForwardModel<DiscreteMeasure<Domain, F>, F>
+        + for<'b> BoundedLinear<DiscreteMeasure<Domain, F>, &'b 𝒟, L2, F>,
+    DiscreteMeasure<Domain, F>: for<'b> Norm<&'b 𝒟, F>,
+    for<'b> &'b 𝒟: NormExponent,
+{
+    fn step_length_bound(
+        &self,
+        f: &QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>,
+    ) -> DynResult<F> {
+        // TODO: direct squared calculation
+        Ok(f.operator().opnorm_bound(self, L2)?.powi(2))
+    }
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F, A, 𝒟, Domain> StepLengthBoundPD<F, A, DiscreteMeasure<Domain, F>> for 𝒟
+where
+    Domain: Space + Clone + PartialEq + 'static,
+    F: Float + ToNalgebraRealField,
+    𝒟: DiscreteMeasureOp<Domain, F>,
+    A: for<'a> BoundedLinear<DiscreteMeasure<Domain, F>, &'a 𝒟, L2, F>,
+    DiscreteMeasure<Domain, F>: for<'a> Norm<&'a 𝒟, F>,
+    for<'b> &'b 𝒟: NormExponent,
+{
+    fn step_length_bound_pd(&self, opA: &A) -> DynResult<F> {
+        opA.opnorm_bound(self, L2)
+    }
+}

mercurial