src/prox_penalty/wave.rs

changeset 52
f0e8704d3f0e
parent 39
6316d68b58af
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/prox_penalty/wave.rs	Mon Feb 17 13:54:53 2025 -0500
@@ -0,0 +1,191 @@
+/*!
+Basic proximal penalty based on convolution operators $𝒟$.
+ */
+
+use numeric_literals::replace_float_literals;
+use nalgebra::DVector;
+use colored::Colorize;
+
+use alg_tools::types::*;
+use alg_tools::loc::Loc;
+use alg_tools::mapping::{Mapping, RealMapping};
+use alg_tools::nalgebra_support::ToNalgebraRealField;
+use alg_tools::norms::Linfinity;
+use alg_tools::iterate::{
+    AlgIteratorIteration,
+    AlgIterator,
+};
+use alg_tools::bisection_tree::{
+    BTFN,
+    PreBTFN,
+    Bounds,
+    BTSearch,
+    SupportGenerator,
+    LocalAnalysis,
+    BothGenerators,
+};
+use crate::measures::{
+    RNDM,
+    DeltaMeasure,
+    Radon,
+};
+use crate::measures::merging::{
+    SpikeMerging,
+};
+use crate::seminorms::DiscreteMeasureOp;
+use crate::types::{
+    IterInfo,
+};
+use crate::regularisation::RegTerm;
+use super::{ProxPenalty, FBGenericConfig};
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F, GA, BTA, S, Reg, 𝒟, G𝒟, K, const N : usize>
+ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for 𝒟
+where
+    F : Float + ToNalgebraRealField,
+    GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
+    BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
+    S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+    G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
+    𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
+    𝒟::Codomain : RealMapping<F, N>,
+    K : RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+    Reg : RegTerm<F, N>,
+    RNDM<F, N> : SpikeMerging<F>,
+{
+    type ReturnMapping = BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>;
+
+    fn insert_and_reweigh<I>(
+        &self,
+        μ : &mut RNDM<F, N>,
+        τv : &mut BTFN<F, GA, BTA, N>,
+        μ_base : &RNDM<F, N>,
+        ν_delta: Option<&RNDM<F, N>>,
+        τ : F,
+        ε : F,
+        config : &FBGenericConfig<F>,
+        reg : &Reg,
+        state : &AlgIteratorIteration<I>,
+        stats : &mut IterInfo<F, N>,
+    ) -> (Option<BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>>, bool)
+    where
+        I : AlgIterator
+    {
+
+        let op𝒟norm = self.opnorm_bound(Radon, Linfinity);
+
+        // Maximum insertion count and measure difference calculation depend on insertion style.
+        let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
+            (i, Some((l, k))) if i <= l => (k, false),
+            _ => (config.max_insertions, !state.is_quiet()),
+        };
+
+        let ω0 = match ν_delta {
+            None => self.apply(μ_base),
+            Some(ν) => self.apply(μ_base + ν),
+        };
+
+        // Add points to support until within error tolerance or maximum insertion count reached.
+        let mut count = 0;
+        let (within_tolerances, d) = 'insertion: loop {
+            if μ.len() > 0 {
+                // Form finite-dimensional subproblem. The subproblem references to the original μ^k
+                // from the beginning of the iteration are all contained in the immutable c and g.
+                // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
+                // problems have not yet been updated to sign change.
+                let à = self.findim_matrix(μ.iter_locations());
+                let g̃ = DVector::from_iterator(μ.len(),
+                                               μ.iter_locations()
+                                                .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
+                                                .map(F::to_nalgebra_mixed));
+                let mut x = μ.masses_dvector();
+
+                // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
+                // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
+                // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤  sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
+                // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
+                // = n |𝒟| |x|_2, where n is the number of points. Therefore
+                let Ã_normest = op𝒟norm * F::cast_from(μ.len());
+
+                // Solve finite-dimensional subproblem.
+                stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);
+
+                // Update masses of μ based on solution of finite-dimensional subproblem.
+                μ.set_masses_dvector(&x);
+            }
+
+            // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality
+            // conditions in the predual space, and finding new points for insertion, if necessary.
+            let mut d = &*τv + match ν_delta {
+                None => self.preapply(μ.sub_matching(μ_base)),
+                Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν)
+            };
+
+            // If no merging heuristic is used, let's be more conservative about spike insertion,
+            // and skip it after first round. If merging is done, being more greedy about spike
+            // insertion also seems to improve performance.
+            let skip_by_rough_check = if config.merging.enabled {
+                false
+            } else {
+                count > 0
+            };
+
+            // Find a spike to insert, if needed
+            let (ξ, _v_ξ, in_bounds) =  match reg.find_tolerance_violation(
+                &mut d, τ, ε, skip_by_rough_check, config
+            ) {
+                None => break 'insertion (true, d),
+                Some(res) => res,
+            };
+
+            // Break if maximum insertion count reached
+            if count >= max_insertions {
+                break 'insertion (in_bounds, d)
+            }
+
+            // No point in optimising the weight here; the finite-dimensional algorithm is fast.
+            *μ += DeltaMeasure { x : ξ, α : 0.0 };
+            count += 1;
+            stats.inserted += 1;
+        };
+
+        if !within_tolerances && warn_insertions {
+            // Complain (but continue) if we failed to get within tolerances
+            // by inserting more points.
+            let err = format!("Maximum insertions reached without achieving \
+                                subproblem solution tolerance");
+            println!("{}", err.red());
+        }
+
+        (Some(d), within_tolerances)
+    }
+
+    fn merge_spikes(
+        &self,
+        μ : &mut RNDM<F, N>,
+        τv : &mut BTFN<F, GA, BTA, N>,
+        μ_base : &RNDM<F, N>,
+        ν_delta: Option<&RNDM<F, N>>,
+        τ : F,
+        ε : F,
+        config : &FBGenericConfig<F>,
+        reg : &Reg,
+        fitness : Option<impl Fn(&RNDM<F, N>) -> F>,
+    ) -> usize
+    {
+        if config.fitness_merging {
+            if let Some(f) = fitness {
+                return μ.merge_spikes_fitness(config.merging, f, |&v| v)
+                        .1
+            }
+        }
+        μ.merge_spikes(config.merging, |μ_candidate| {
+            let mut d = &*τv + self.preapply(match ν_delta {
+                None => μ_candidate.sub_matching(μ_base),
+                Some(ν) => μ_candidate.sub_matching(μ_base) - ν,
+            });
+            reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config)
+        })
+    }
+}

mercurial