src/prox_penalty/radon_squared.rs

branch
dev
changeset 37
c5d8bd1a7728
child 39
6316d68b58af
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/prox_penalty/radon_squared.rs	Thu Jan 23 23:35:28 2025 +0100
@@ -0,0 +1,170 @@
+/*!
+Solver for the point source localisation problem using a simplified forward-backward splitting method.
+
+Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map.
+*/
+
+use numeric_literals::replace_float_literals;
+use serde::{Serialize, Deserialize};
+use nalgebra::DVector;
+
+use alg_tools::iterate::{
+    AlgIteratorIteration,
+    AlgIterator
+};
+use alg_tools::norms::L2;
+use alg_tools::linops::Mapping;
+use alg_tools::bisection_tree::{
+    BTFN,
+    Bounds,
+    BTSearch,
+    SupportGenerator,
+    LocalAnalysis,
+};
+use alg_tools::mapping::RealMapping;
+use alg_tools::nalgebra_support::ToNalgebraRealField;
+
+use crate::types::*;
+use crate::measures::{
+    RNDM,
+    DeltaMeasure,
+    Radon,
+};
+use crate::measures::merging::SpikeMerging;
+use crate::regularisation::RegTerm;
+use crate::forward_model::{
+    ForwardModel,
+    AdjointProductBoundedBy
+};
+use super::{
+    FBGenericConfig,
+    ProxPenalty,
+};
+
+/// Radon-norm squared proximal penalty
+
+#[derive(Copy,Clone,Serialize,Deserialize)]
+pub struct RadonSquared;
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F, GA, BTA, S, Reg, const N : usize>
+ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for RadonSquared
+where
+    F : Float + ToNalgebraRealField,
+    GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
+    BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
+    S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+    Reg : RegTerm<F, N>,
+    RNDM<F, N> : SpikeMerging<F>,
+{
+    type ReturnMapping = BTFN<F, GA, BTA, N>;
+
+    fn insert_and_reweigh<I>(
+        &self,
+        μ : &mut RNDM<F, N>,
+        τv : &mut BTFN<F, GA, BTA, N>,
+        μ_base : &RNDM<F, N>,
+        ν_delta: Option<&RNDM<F, N>>,
+        τ : F,
+        ε : F,
+        config : &FBGenericConfig<F>,
+        reg : &Reg,
+        _state : &AlgIteratorIteration<I>,
+        stats : &mut IterInfo<F, N>,
+    ) -> (Option<Self::ReturnMapping>, bool)
+    where
+        I : AlgIterator
+    {
+        assert!(ν_delta.is_none(), "Transport not implemented for Radon-squared prox term");
+
+        let mut y = μ_base.masses_vec();
+
+        'i_and_w: for i in 0..=1 {
+            // Optimise weights
+            if μ.len() > 0 {
+                // Form finite-dimensional subproblem. The subproblem references to the original μ^k
+                // from the beginning of the iteration are all contained in the immutable c and g.
+                // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
+                // problems have not yet been updated to sign change.
+                let g̃ = DVector::from_iterator(μ.len(),
+                                               μ.iter_locations()
+                                                .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ))));
+                let mut x = μ.masses_dvector();
+                // Ugly hack because DVector::push doesn't push but copies.
+                let yvec = DVector::from_column_slice(y.as_slice());
+                // Solve finite-dimensional subproblem.
+                stats.inner_iters += reg.solve_findim_l1squared(&yvec, &g̃, τ, &mut x, ε, config);
+
+                // Update masses of μ based on solution of finite-dimensional subproblem.
+                μ.set_masses_dvector(&x);
+            }
+
+            if i>0 {
+                // Simple debugging test to see if more inserts would be needed. Doesn't seem so.
+                //let n = μ.dist_matching(μ_base);
+                //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n));
+                break 'i_and_w
+            }
+            
+            // Calculate ‖μ - μ_base‖_ℳ
+            let n = μ.dist_matching(μ_base);
+        
+            // Find a spike to insert, if needed.
+            // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ,
+            // which are supposed to have been guaranteed by the finite-dimensional weight optimisation.
+            match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) {
+                None => { break 'i_and_w },
+                Some((ξ, _v_ξ, _in_bounds)) => {
+                    // Weight is found out by running the finite-dimensional optimisation algorithm
+                    // above
+                    *μ += DeltaMeasure { x : ξ, α : 0.0 };
+                    //*μ_base += DeltaMeasure { x : ξ, α : 0.0 };
+                    y.push(0.0.to_nalgebra_mixed());
+                    stats.inserted += 1;
+                }
+            };
+        }
+
+        (None, true)
+    }
+
+    fn merge_spikes(
+        &self,
+        μ : &mut RNDM<F, N>,
+        τv : &mut BTFN<F, GA, BTA, N>,
+        μ_base : &RNDM<F, N>,
+        τ : F,
+        ε : F,
+        config : &FBGenericConfig<F>,
+        reg : &Reg,
+    ) -> usize
+    {
+        μ.merge_spikes(config.merging, |μ_candidate| {
+            // Important: μ_candidate's new points are afterwards,
+            // and do not conflict with μ_base.
+            // TODO: could simplify to requiring μ_base instead of μ_radon.
+            // but may complicate with sliding base's exgtra points that need to be
+            // after μ_candidate's extra points.
+            // TODO: doesn't seem to work, maybe need to merge μ_base as well?
+            // Although that doesn't seem to make sense.
+            let μ_radon = μ_candidate.sub_matching(μ_base);
+            reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon)
+            //let n = μ_candidate.dist_matching(μ_base);
+            //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()
+        })
+    }
+}
+
+
+impl<F, A, const N : usize> AdjointProductBoundedBy<RNDM<F, N>, RadonSquared>
+for A
+where
+    F : Float,
+    A : ForwardModel<RNDM<F, N>, F>
+{
+    type FloatType = F;
+
+    fn adjoint_product_bound(&self, _ : &RadonSquared) -> Option<Self::FloatType> {
+        self.opnorm_bound(Radon, L2).powi(2).into()
+    }
+}

mercurial