Tue, 31 Dec 2024 09:25:45 -0500
New version of sliding.
| 0 | 1 | /*! |
| 2 | Solver for the point source localisation problem with primal-dual proximal splitting. | |
| 3 | ||
| 4 | This corresponds to the manuscript | |
| 5 | ||
|
13
bdc57366d4f5
arXiv links, README beautification
Tuomo Valkonen <tuomov@iki.fi>
parents:
0
diff
changeset
|
6 | * Valkonen T. - _Proximal methods for point source localisation_, |
|
bdc57366d4f5
arXiv links, README beautification
Tuomo Valkonen <tuomov@iki.fi>
parents:
0
diff
changeset
|
7 | [arXiv:2212.02991](https://arxiv.org/abs/2212.02991). |
| 0 | 8 | |
| 35 | 9 | The main routine is [`pointsource_pdps_reg`]. |
| 0 | 10 | Both norm-2-squared and norm-1 data terms are supported. That is, implemented are solvers for |
| 11 | <div> | |
| 12 | $$ | |
| 13 | \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ - b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ), | |
| 14 | $$ | |
| 15 | for both $F_0(y)=\frac{1}{2}\|y\|_2^2$ and $F_0(y)=\|y\|_1$ with the forward operator | |
| 16 | $A \in 𝕃(ℳ(Ω); ℝ^n)$. | |
| 17 | </div> | |
| 18 | ||
| 19 | ## Approach | |
| 20 | ||
| 21 | <p> | |
| 22 | The problem above can be written as | |
| 23 | $$ | |
| 24 | \min_μ \max_y G(μ) + ⟨y, Aμ-b⟩ - F_0^*(y), | |
| 25 | $$ | |
| 26 | where $G(μ) = α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ)$. | |
| 27 | The Fenchel–Rockafellar optimality conditions, employing the predual in $ℳ(Ω)$, are | |
| 28 | $$ | |
| 29 | 0 ∈ A_*y + ∂G(μ) | |
| 30 | \quad\text{and}\quad | |
| 31 | Aμ - b ∈ ∂ F_0^*(y). | |
| 32 | $$ | |
| 33 | The solution of the first part is as for forward-backward, treated in the manuscript. | |
| 34 | In the code below this is the task of <code>insert_and_reweigh</code> from the <code>fb</code> module, | |
| 35 | where the specific residual $Aμ-b$ is replaced by the dual variable $y$. | |
| 36 | For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$. | |
| 37 | For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$. | |
| 38 | </p> | |
| 39 | */ | |
| 40 | ||
| 41 | use numeric_literals::replace_float_literals; | |
| 42 | use serde::{Serialize, Deserialize}; | |
| 43 | use nalgebra::DVector; | |
| 44 | use clap::ValueEnum; | |
| 45 | ||
| 35 | 46 | use alg_tools::iterate::AlgIteratorFactory; |
| 0 | 47 | use alg_tools::loc::Loc; |
| 48 | use alg_tools::euclidean::Euclidean; | |
| 35 | 49 | use alg_tools::linops::Mapping; |
| 0 | 50 | use alg_tools::norms::{ |
| 32 | 51 | Linfinity, |
| 52 | Projection, | |
| 0 | 53 | }; |
| 54 | use alg_tools::bisection_tree::{ | |
| 55 | BTFN, | |
| 56 | PreBTFN, | |
| 57 | Bounds, | |
| 58 | BTNodeLookup, | |
| 59 | BTNode, | |
| 60 | BTSearch, | |
| 61 | SupportGenerator, | |
| 62 | LocalAnalysis, | |
| 63 | }; | |
| 35 | 64 | use alg_tools::mapping::{RealMapping, Instance}; |
| 0 | 65 | use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 66 | use alg_tools::linops::AXPY; | |
| 67 | ||
| 68 | use crate::types::*; | |
| 35 | 69 | use crate::measures::{DiscreteMeasure, RNDM, Radon}; |
| 32 | 70 | use crate::measures::merging::SpikeMerging; |
| 35 | 71 | use crate::forward_model::{ |
| 72 | AdjointProductBoundedBy, | |
| 73 | ForwardModel | |
| 74 | }; | |
| 32 | 75 | use crate::seminorms::DiscreteMeasureOp; |
| 0 | 76 | use crate::plot::{ |
| 77 | SeqPlotter, | |
| 78 | Plotting, | |
| 79 | PlotLookup | |
| 80 | }; | |
| 81 | use crate::fb::{ | |
| 82 | FBGenericConfig, | |
| 32 | 83 | insert_and_reweigh, |
| 84 | postprocess, | |
| 35 | 85 | prune_with_stats |
| 32 | 86 | }; |
| 87 | use crate::regularisation::RegTerm; | |
| 88 | use crate::dataterm::{ | |
| 89 | DataTerm, | |
| 90 | L2Squared, | |
| 91 | L1 | |
| 0 | 92 | }; |
| 93 | ||
/// Acceleration
///
/// Selects how the private `accelerate` method of this type rescales the primal and dual
/// step lengths `τ` and `σ` between iterations of [`pointsource_pdps_reg`].
/// The formulas quoted on the variants are written without the strong convexity factor `γ`;
/// `accelerate` itself uses `1 + γσ` and `1 + 2γσ` under the square root.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)]
pub enum Acceleration {
    /// No acceleration
    #[clap(name = "none")]
    None,
    // The explicit `help` text below repeats the doc comment without LaTeX markup.
    /// Partial acceleration, $ω = 1/\sqrt{1+σ}$
    #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")]
    Partial,
    // The explicit `help` text below repeats the doc comment without LaTeX markup.
    /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed
    #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")]
    Full
}
| 107 | ||
| 35 | 108 | #[replace_float_literals(F::cast_from(literal))] |
| 109 | impl Acceleration { | |
| 110 | /// PDPS parameter acceleration. Updates τ and σ and returns ω. | |
| 111 | /// This uses dual strong convexity, not primal. | |
| 112 | fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F { | |
| 113 | match self { | |
| 114 | Acceleration::None => 1.0, | |
| 115 | Acceleration::Partial => { | |
| 116 | let ω = 1.0 / (1.0 + γ * (*σ)).sqrt(); | |
| 117 | *σ *= ω; | |
| 118 | *τ /= ω; | |
| 119 | ω | |
| 120 | }, | |
| 121 | Acceleration::Full => { | |
| 122 | let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt(); | |
| 123 | *σ *= ω; | |
| 124 | *τ /= ω; | |
| 125 | ω | |
| 126 | }, | |
| 127 | } | |
| 128 | } | |
| 129 | } | |
| 130 | ||
/// Settings for [`pointsource_pdps_reg`].
///
/// Because of `#[serde(default)]`, fields missing from serialised input are taken from the
/// [`Default`] implementation of this type.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct PDPSConfig<F : Float> {
    /// Primal step length scaling. The documented requirement is `τ0 * σ0 < 1`;
    /// [`pointsource_pdps_reg`] asserts `τ0 > 0`, `σ0 > 0` and the slightly weaker
    /// `τ0 * σ0 <= 1`, and divides this value by an operator bound to get the initial `τ`.
    pub τ0 : F,
    /// Dual step length scaling. The documented requirement is `τ0 * σ0 < 1`;
    /// see `τ0` for what [`pointsource_pdps_reg`] actually asserts. This value is divided
    /// by the same operator bound to get the initial `σ`.
    pub σ0 : F,
    /// Accelerate if available
    pub acceleration : Acceleration,
    /// Generic parameters
    pub generic : FBGenericConfig<F>,
}
| 144 | ||
| 145 | #[replace_float_literals(F::cast_from(literal))] | |
| 146 | impl<F : Float> Default for PDPSConfig<F> { | |
| 147 | fn default() -> Self { | |
| 148 | let τ0 = 0.5; | |
| 149 | PDPSConfig { | |
| 150 | τ0, | |
| 151 | σ0 : 0.99/τ0, | |
| 152 | acceleration : Acceleration::Partial, | |
|
34
efa60bc4f743
Radon FB + sliding improvements
Tuomo Valkonen <tuomov@iki.fi>
parents:
32
diff
changeset
|
153 | generic : Default::default(), |
| 0 | 154 | } |
| 155 | } | |
| 156 | } | |
| 157 | ||
/// Trait for data terms for the PDPS
///
/// Extends [`DataTerm`] with the operations that [`pointsource_pdps_reg`] needs for the
/// dual variable: an initial value, the strong convexity factor used for acceleration,
/// and the dual step itself.
#[replace_float_literals(F::cast_from(literal))]
pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> {
    /// Calculate some subdifferential at `x` for the conjugate
    ///
    /// [`pointsource_pdps_reg`] calls this with `-b` to initialise the dual variable.
    fn some_subdifferential(&self, x : V) -> V;

    /// Factor of strong convexity of the conjugate
    ///
    /// The default implementation returns zero. The value is passed as `γ` to the
    /// step length acceleration in [`pointsource_pdps_reg`].
    #[inline]
    fn factor_of_strong_convexity(&self) -> F {
        0.0
    }

    /// Perform dual update
    ///
    /// In [`pointsource_pdps_reg`], `_y` holds on entry the over-relaxed residual
    /// `A[(1+ω)μ^{k+1} - ω μ^k] - b` and must be overwritten with the new dual variable;
    /// `_y_prev` is the previous dual variable and `_σ` the current dual step length.
    fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F);
}
| 173 | ||
| 32 | 174 | |
| 175 | #[replace_float_literals(F::cast_from(literal))] | |
| 35 | 176 | impl<F, V, const N : usize> PDPSDataTerm<F, V, N> |
| 177 | for L2Squared | |
| 178 | where | |
| 179 | F : Float, | |
| 180 | V : Euclidean<F> + AXPY<F>, | |
| 181 | for<'b> &'b V : Instance<V>, | |
| 182 | { | |
| 32 | 183 | fn some_subdifferential(&self, x : V) -> V { x } |
| 0 | 184 | |
| 32 | 185 | fn factor_of_strong_convexity(&self) -> F { |
| 186 | 1.0 | |
| 187 | } | |
| 188 | ||
| 189 | #[inline] | |
| 190 | fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { | |
| 35 | 191 | y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ)); |
| 32 | 192 | } |
| 0 | 193 | } |
| 194 | ||
| 32 | 195 | #[replace_float_literals(F::cast_from(literal))] |
| 196 | impl<F : Float + nalgebra::RealField, const N : usize> | |
| 197 | PDPSDataTerm<F, DVector<F>, N> | |
| 198 | for L1 { | |
| 0 | 199 | fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> { |
| 200 | // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. | |
| 201 | x.iter_mut() | |
| 202 | .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) }); | |
| 203 | x | |
| 204 | } | |
| 205 | ||
| 32 | 206 | #[inline] |
| 207 | fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) { | |
| 208 | y.axpy(1.0, y_prev, σ); | |
| 0 | 209 | y.proj_ball_mut(1.0, Linfinity); |
| 210 | } | |
| 211 | } | |
| 212 | ||
/// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting.
///
/// The `dataterm` should be either [`L1`] for norm-1 data term or [`L2Squared`] for norm-2-squared.
/// The settings in `pdpsconfig` have their [respective documentation](PDPSConfig). `opA` is the
/// forward operator $A$, `b` the observable, and `reg` the regularisation term.
/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
/// operator. The `iterator` is an outer loop verbosity and iteration count control
/// as documented in [`alg_tools::iterate`]. Finally, `plotter` is used to plot the spikes
/// whenever the iterator requests verbose output.
///
/// For the mathematical formulation, see the [module level](self) documentation and the manuscript.
///
/// Returns the final iterate, after it has been passed through `postprocess`.
///
/// # Panics
///
/// Panics if `pdpsconfig.τ0` or `pdpsconfig.σ0` is not positive, if their product exceeds one,
/// or if `opA.adjoint_product_bound(..)` returns `None`.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>(
    opA : &'a A,
    b : &'a A::Observable,
    reg : Reg,
    op𝒟 : &'a 𝒟,
    pdpsconfig : &PDPSConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
    dataterm : D,
) -> RNDM<F, N>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      PlotLookup : Plotting<N>,
      RNDM<F, N> : SpikeMerging<F>,
      D : PDPSDataTerm<F, A::Observable, N>,
      Reg : RegTerm<F, N> {

    // Check parameters
    assert!(pdpsconfig.τ0 > 0.0 &&
            pdpsconfig.σ0 > 0.0 &&
            pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
            "Invalid step length parameters");

    // Set up parameters
    let config = &pdpsconfig.generic;
    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
    // Both step length scalings are divided by the same bound `l`.
    let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt();
    let mut τ = pdpsconfig.τ0 / l;
    let mut σ = pdpsconfig.σ0 / l;
    let γ = dataterm.factor_of_strong_convexity();

    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    // Note that this uses the initial τ; later acceleration of τ does not change `tolerance`.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();
    let mut y = dataterm.some_subdifferential(-b);
    let mut y_prev = y.clone();
    // Completes the partially collected statistics with the current function value,
    // spike count, and tolerance.
    let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo {
        value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ),
        n_spikes : μ.len(),
        ε,
        // postprocessing: config.postprocessing.then(|| μ.clone()),
        .. stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
        // Calculate smooth part of surrogate model.
        // This consumes `y`; it is rebuilt from `b` in the dual update below.
        let τv = opA.preadjoint().apply(y * τ);

        // Save current base point
        let μ_base = μ.clone();

        // Insert and reweigh
        let (d, _within_tolerances) = insert_and_reweigh(
            &mut μ, &τv, &μ_base, None,
            op𝒟, op𝒟norm,
            τ, ε,
            config, &reg, &state, &mut stats
        );

        // Prune and possibly merge spikes
        if config.merge_now(&state) {
            stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
                let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
                reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
            });
        }
        stats.pruned += prune_with_stats(&mut μ);

        // Update step length parameters
        let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);

        // Do dual update
        y = b.clone();                          // y = b
        opA.gemv(&mut y, 1.0 + ω, &μ, -1.0);    // y = A[(1+ω)μ^{k+1}]-b
        opA.gemv(&mut y, -ω, &μ_base, 1.0);     // y = A[(1+ω)μ^{k+1} - ω μ^k]-b
        dataterm.dual_update(&mut y, &y_prev, σ);
        y_prev.copy_from(&y);

        // Give statistics if requested
        let iter = state.iteration();
        stats.this_iters += 1;

        state.if_verbose(|| {
            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
            // Hand over the collected statistics and start collecting afresh.
            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        ε = tolerance.update(ε, iter);
    }

    postprocess(μ, config, dataterm, opA, b)
}
|
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
335 |