| 1 /*! |
|
| 2 Solver for the point source localisation problem using a simplified forward-backward splitting method. |
|
| 3 |
|
| 4 Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map. |
|
| 5 */ |
|
| 6 |
|
| 7 use numeric_literals::replace_float_literals; |
|
| 8 use serde::{Serialize, Deserialize}; |
|
| 9 use colored::Colorize; |
|
| 10 use nalgebra::DVector; |
|
| 11 |
|
| 12 use alg_tools::iterate::{ |
|
| 13 AlgIteratorFactory, |
|
| 14 AlgIteratorIteration, |
|
| 15 AlgIterator |
|
| 16 }; |
|
| 17 use alg_tools::euclidean::Euclidean; |
|
| 18 use alg_tools::linops::Mapping; |
|
| 19 use alg_tools::sets::Cube; |
|
| 20 use alg_tools::loc::Loc; |
|
| 21 use alg_tools::bisection_tree::{ |
|
| 22 BTFN, |
|
| 23 Bounds, |
|
| 24 BTNodeLookup, |
|
| 25 BTNode, |
|
| 26 BTSearch, |
|
| 27 P2Minimise, |
|
| 28 SupportGenerator, |
|
| 29 LocalAnalysis, |
|
| 30 }; |
|
| 31 use alg_tools::mapping::RealMapping; |
|
| 32 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
| 33 use alg_tools::norms::L2; |
|
| 34 |
|
| 35 use crate::types::*; |
|
| 36 use crate::measures::{ |
|
| 37 RNDM, |
|
| 38 DiscreteMeasure, |
|
| 39 DeltaMeasure, |
|
| 40 Radon, |
|
| 41 }; |
|
| 42 use crate::measures::merging::{ |
|
| 43 SpikeMergingMethod, |
|
| 44 SpikeMerging, |
|
| 45 }; |
|
| 46 use crate::forward_model::ForwardModel; |
|
| 47 use crate::plot::{ |
|
| 48 SeqPlotter, |
|
| 49 Plotting, |
|
| 50 PlotLookup |
|
| 51 }; |
|
| 52 use crate::regularisation::RegTerm; |
|
| 53 use crate::dataterm::{ |
|
| 54 calculate_residual, |
|
| 55 L2Squared, |
|
| 56 DataTerm, |
|
| 57 }; |
|
| 58 |
|
| 59 use crate::fb::{ |
|
| 60 FBGenericConfig, |
|
| 61 postprocess, |
|
| 62 prune_with_stats |
|
| 63 }; |
|
| 64 |
|
| 65 /// Settings for [`pointsource_radon_fb_reg`]. |
|
| 66 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
| 67 #[serde(default)] |
|
| 68 pub struct RadonFBConfig<F : Float> { |
|
| 69 /// Step length scaling |
|
| 70 pub τ0 : F, |
|
| 71 /// Generic parameters |
|
| 72 pub insertion : FBGenericConfig<F>, |
|
| 73 } |
|
| 74 |
|
| 75 #[replace_float_literals(F::cast_from(literal))] |
|
| 76 impl<F : Float> Default for RadonFBConfig<F> { |
|
| 77 fn default() -> Self { |
|
| 78 RadonFBConfig { |
|
| 79 τ0 : 0.99, |
|
| 80 insertion : Default::default() |
|
| 81 } |
|
| 82 } |
|
| 83 } |
|
| 84 |
|
| 85 #[replace_float_literals(F::cast_from(literal))] |
|
| 86 pub(crate) fn insert_and_reweigh< |
|
| 87 'a, F, GA, BTA, S, Reg, I, const N : usize |
|
| 88 >( |
|
| 89 μ : &mut RNDM<F, N>, |
|
| 90 τv : &mut BTFN<F, GA, BTA, N>, |
|
| 91 μ_base : &mut RNDM<F, N>, |
|
| 92 //_ν_delta: Option<&RNDM<F, N>>, |
|
| 93 τ : F, |
|
| 94 ε : F, |
|
| 95 config : &FBGenericConfig<F>, |
|
| 96 reg : &Reg, |
|
| 97 _state : &AlgIteratorIteration<I>, |
|
| 98 stats : &mut IterInfo<F, N>, |
|
| 99 ) |
|
| 100 where F : Float + ToNalgebraRealField, |
|
| 101 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
| 102 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 103 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 104 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 105 RNDM<F, N> : SpikeMerging<F>, |
|
| 106 Reg : RegTerm<F, N>, |
|
| 107 I : AlgIterator { |
|
| 108 |
|
| 109 'i_and_w: for i in 0..=1 { |
|
| 110 // Optimise weights |
|
| 111 if μ.len() > 0 { |
|
| 112 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
|
| 113 // from the beginning of the iteration are all contained in the immutable c and g. |
|
| 114 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional |
|
| 115 // problems have not yet been updated to sign change. |
|
| 116 let g̃ = DVector::from_iterator(μ.len(), |
|
| 117 μ.iter_locations() |
|
| 118 .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ)))); |
|
| 119 let mut x = μ.masses_dvector(); |
|
| 120 let y = μ_base.masses_dvector(); |
|
| 121 |
|
| 122 // Solve finite-dimensional subproblem. |
|
| 123 stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config); |
|
| 124 |
|
| 125 // Update masses of μ based on solution of finite-dimensional subproblem. |
|
| 126 μ.set_masses_dvector(&x); |
|
| 127 } |
|
| 128 |
|
| 129 if i>0 { |
|
| 130 // Simple debugging test to see if more inserts would be needed. Doesn't seem so. |
|
| 131 //let n = μ.dist_matching(μ_base); |
|
| 132 //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); |
|
| 133 break 'i_and_w |
|
| 134 } |
|
| 135 |
|
| 136 // Calculate ‖μ - μ_base‖_ℳ |
|
| 137 let n = μ.dist_matching(μ_base); |
|
| 138 |
|
| 139 // Find a spike to insert, if needed. |
|
| 140 // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, |
|
| 141 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. |
|
| 142 match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { |
|
| 143 None => { break 'i_and_w }, |
|
| 144 Some((ξ, _v_ξ, _in_bounds)) => { |
|
| 145 // Weight is found out by running the finite-dimensional optimisation algorithm |
|
| 146 // above |
|
| 147 *μ += DeltaMeasure { x : ξ, α : 0.0 }; |
|
| 148 *μ_base += DeltaMeasure { x : ξ, α : 0.0 }; |
|
| 149 stats.inserted += 1; |
|
| 150 } |
|
| 151 }; |
|
| 152 } |
|
| 153 } |
|
| 154 |
|
| 155 |
|
| 156 /// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting. |
|
| 157 /// |
|
| 158 /// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the |
|
| 159 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
|
| 160 /// Finally, the `iterator` is an outer loop verbosity and iteration count control |
|
| 161 /// as documented in [`alg_tools::iterate`]. |
|
| 162 /// |
|
| 163 /// For details on the mathematical formulation, see the [module level](self) documentation. |
|
| 164 /// |
|
| 165 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
|
| 166 /// sums of simple functions usign bisection trees, and the related |
|
| 167 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
|
| 168 /// active at a specific points, and to maximise their sums. Through the implementation of the |
|
| 169 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
|
| 170 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
|
| 171 /// |
|
| 172 /// Returns the final iterate. |
|
| 173 #[replace_float_literals(F::cast_from(literal))] |
|
| 174 pub fn pointsource_radon_fb_reg< |
|
| 175 'a, F, I, A, GA, BTA, S, Reg, const N : usize |
|
| 176 >( |
|
| 177 opA : &'a A, |
|
| 178 b : &A::Observable, |
|
| 179 reg : Reg, |
|
| 180 fbconfig : &RadonFBConfig<F>, |
|
| 181 iterator : I, |
|
| 182 mut _plotter : SeqPlotter<F, N>, |
|
| 183 ) -> RNDM<F, N> |
|
| 184 where F : Float + ToNalgebraRealField, |
|
| 185 I : AlgIteratorFactory<IterInfo<F, N>>, |
|
| 186 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
|
| 187 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
| 188 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, |
|
| 189 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 190 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 191 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 192 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
| 193 RNDM<F, N> : SpikeMerging<F>, |
|
| 194 Reg : RegTerm<F, N> { |
|
| 195 |
|
| 196 // Set up parameters |
|
| 197 let config = &fbconfig.insertion; |
|
| 198 // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ |
|
| 199 // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such |
|
| 200 // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. |
|
| 201 let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2); |
|
| 202 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
|
| 203 // by τ compared to the conditional gradient approach. |
|
| 204 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
|
| 205 let mut ε = tolerance.initial(); |
|
| 206 |
|
| 207 // Initialise iterates |
|
| 208 let mut μ = DiscreteMeasure::new(); |
|
| 209 let mut residual = -b; |
|
| 210 |
|
| 211 // Statistics |
|
| 212 let full_stats = |residual : &A::Observable, |
|
| 213 μ : &RNDM<F, N>, |
|
| 214 ε, stats| IterInfo { |
|
| 215 value : residual.norm2_squared_div2() + reg.apply(μ), |
|
| 216 n_spikes : μ.len(), |
|
| 217 ε, |
|
| 218 // postprocessing: config.postprocessing.then(|| μ.clone()), |
|
| 219 .. stats |
|
| 220 }; |
|
| 221 let mut stats = IterInfo::new(); |
|
| 222 |
|
| 223 // Run the algorithm |
|
| 224 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
|
| 225 // Calculate smooth part of surrogate model. |
|
| 226 let mut τv = opA.preadjoint().apply(residual * τ); |
|
| 227 |
|
| 228 // Save current base point |
|
| 229 let mut μ_base = μ.clone(); |
|
| 230 |
|
| 231 // Insert and reweigh |
|
| 232 insert_and_reweigh( |
|
| 233 &mut μ, &mut τv, &mut μ_base, //None, |
|
| 234 τ, ε, |
|
| 235 config, ®, &state, &mut stats |
|
| 236 ); |
|
| 237 |
|
| 238 // Prune and possibly merge spikes |
|
| 239 assert!(μ_base.len() <= μ.len()); |
|
| 240 if config.merge_now(&state) { |
|
| 241 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { |
|
| 242 // Important: μ_candidate's new points are afterwards, |
|
| 243 // and do not conflict with μ_base. |
|
| 244 // TODO: could simplify to requiring μ_base instead of μ_radon. |
|
| 245 // but may complicate with sliding base's exgtra points that need to be |
|
| 246 // after μ_candidate's extra points. |
|
| 247 // TODO: doesn't seem to work, maybe need to merge μ_base as well? |
|
| 248 // Although that doesn't seem to make sense. |
|
| 249 let μ_radon = μ_candidate.sub_matching(&μ_base); |
|
| 250 reg.verify_merge_candidate_radonsq(&mut τv, μ_candidate, τ, ε, &config, &μ_radon) |
|
| 251 //let n = μ_candidate.dist_matching(μ_base); |
|
| 252 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() |
|
| 253 }); |
|
| 254 } |
|
| 255 stats.pruned += prune_with_stats(&mut μ); |
|
| 256 |
|
| 257 // Update residual |
|
| 258 residual = calculate_residual(&μ, opA, b); |
|
| 259 |
|
| 260 let iter = state.iteration(); |
|
| 261 stats.this_iters += 1; |
|
| 262 |
|
| 263 // Give statistics if needed |
|
| 264 state.if_verbose(|| { |
|
| 265 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
|
| 266 }); |
|
| 267 |
|
| 268 // Update main tolerance for next iteration |
|
| 269 ε = tolerance.update(ε, iter); |
|
| 270 } |
|
| 271 |
|
| 272 postprocess(μ, config, L2Squared, opA, b) |
|
| 273 } |
|
| 274 |
|
| 275 /// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting. |
|
| 276 /// |
|
| 277 /// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the |
|
| 278 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
|
| 279 /// Finally, the `iterator` is an outer loop verbosity and iteration count control |
|
| 280 /// as documented in [`alg_tools::iterate`]. |
|
| 281 /// |
|
| 282 /// For details on the mathematical formulation, see the [module level](self) documentation. |
|
| 283 /// |
|
| 284 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
|
| 285 /// sums of simple functions usign bisection trees, and the related |
|
| 286 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
|
| 287 /// active at a specific points, and to maximise their sums. Through the implementation of the |
|
| 288 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
|
| 289 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
|
| 290 /// |
|
| 291 /// Returns the final iterate. |
|
| 292 #[replace_float_literals(F::cast_from(literal))] |
|
| 293 pub fn pointsource_radon_fista_reg< |
|
| 294 'a, F, I, A, GA, BTA, S, Reg, const N : usize |
|
| 295 >( |
|
| 296 opA : &'a A, |
|
| 297 b : &A::Observable, |
|
| 298 reg : Reg, |
|
| 299 fbconfig : &RadonFBConfig<F>, |
|
| 300 iterator : I, |
|
| 301 mut plotter : SeqPlotter<F, N>, |
|
| 302 ) -> RNDM<F, N> |
|
| 303 where F : Float + ToNalgebraRealField, |
|
| 304 I : AlgIteratorFactory<IterInfo<F, N>>, |
|
| 305 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
|
| 306 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
| 307 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, |
|
| 308 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 309 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 310 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 311 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
| 312 PlotLookup : Plotting<N>, |
|
| 313 RNDM<F, N> : SpikeMerging<F>, |
|
| 314 Reg : RegTerm<F, N> { |
|
| 315 |
|
| 316 // Set up parameters |
|
| 317 let config = &fbconfig.insertion; |
|
| 318 // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ |
|
| 319 // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such |
|
| 320 // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. |
|
| 321 let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2); |
|
| 322 let mut λ = 1.0; |
|
| 323 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
|
| 324 // by τ compared to the conditional gradient approach. |
|
| 325 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
|
| 326 let mut ε = tolerance.initial(); |
|
| 327 |
|
| 328 // Initialise iterates |
|
| 329 let mut μ = DiscreteMeasure::new(); |
|
| 330 let mut μ_prev = DiscreteMeasure::new(); |
|
| 331 let mut residual = -b; |
|
| 332 let mut warned_merging = false; |
|
| 333 |
|
| 334 // Statistics |
|
| 335 let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { |
|
| 336 value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
|
| 337 n_spikes : ν.len(), |
|
| 338 ε, |
|
| 339 // postprocessing: config.postprocessing.then(|| ν.clone()), |
|
| 340 .. stats |
|
| 341 }; |
|
| 342 let mut stats = IterInfo::new(); |
|
| 343 |
|
| 344 // Run the algorithm |
|
| 345 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
|
| 346 // Calculate smooth part of surrogate model. |
|
| 347 let mut τv = opA.preadjoint().apply(residual * τ); |
|
| 348 |
|
| 349 // Save current base point |
|
| 350 let mut μ_base = μ.clone(); |
|
| 351 |
|
| 352 // Insert new spikes and reweigh |
|
| 353 insert_and_reweigh( |
|
| 354 &mut μ, &mut τv, &mut μ_base, //None, |
|
| 355 τ, ε, |
|
| 356 config, ®, &state, &mut stats |
|
| 357 ); |
|
| 358 |
|
| 359 // (Do not) merge spikes. |
|
| 360 if config.merge_now(&state) { |
|
| 361 match config.merging { |
|
| 362 SpikeMergingMethod::None => { }, |
|
| 363 _ => if !warned_merging { |
|
| 364 let err = format!("Merging not supported for μFISTA"); |
|
| 365 println!("{}", err.red()); |
|
| 366 warned_merging = true; |
|
| 367 } |
|
| 368 } |
|
| 369 } |
|
| 370 |
|
| 371 // Update inertial prameters |
|
| 372 let λ_prev = λ; |
|
| 373 λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); |
|
| 374 let θ = λ / λ_prev - λ; |
|
| 375 |
|
| 376 // Perform inertial update on μ. |
|
| 377 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ |
|
| 378 // and μ_prev have zero weight. Since both have weights from the finite-dimensional |
|
| 379 // subproblem with a proximal projection step, this is likely to happen when the |
|
| 380 // spike is not needed. A copy of the pruned μ without artithmetic performed is |
|
| 381 // stored in μ_prev. |
|
| 382 let n_before_prune = μ.len(); |
|
| 383 μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); |
|
| 384 debug_assert!(μ.len() <= n_before_prune); |
|
| 385 stats.pruned += n_before_prune - μ.len(); |
|
| 386 |
|
| 387 // Update residual |
|
| 388 residual = calculate_residual(&μ, opA, b); |
|
| 389 |
|
| 390 let iter = state.iteration(); |
|
| 391 stats.this_iters += 1; |
|
| 392 |
|
| 393 // Give statistics if needed |
|
| 394 state.if_verbose(|| { |
|
| 395 plotter.plot_spikes(iter, Option::<&S>::None, Some(&τv), &μ_prev); |
|
| 396 full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) |
|
| 397 }); |
|
| 398 |
|
| 399 // Update main tolerance for next iteration |
|
| 400 ε = tolerance.update(ε, iter); |
|
| 401 } |
|
| 402 |
|
| 403 postprocess(μ_prev, config, L2Squared, opA, b) |
|
| 404 } |
|