| 74 = -\frac{τ}{2} \|Aμ^k-b\|^2 + τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ_k\|_{𝒟}^2. |
71 = -\frac{τ}{2} \|Aμ^k-b\|^2 + τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ_k\|_{𝒟}^2. |
| 75 \end{aligned} |
72 \end{aligned} |
| 76 $$ |
73 $$ |
| 77 </p> |
74 </p> |
| 78 |
75 |
| 79 We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by |
76 We solve this with either SSN or FB as determined by |
| 80 [`InnerSettings`] in [`FBGenericConfig::inner`]. |
77 [`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. |
| 81 */ |
78 */ |
| 82 |
79 |
| |
80 use colored::Colorize; |
| 83 use numeric_literals::replace_float_literals; |
81 use numeric_literals::replace_float_literals; |
| 84 use serde::{Serialize, Deserialize}; |
82 use serde::{Deserialize, Serialize}; |
| 85 use colored::Colorize; |
83 |
| 86 use nalgebra::{DVector, DMatrix}; |
|
| 87 |
|
| 88 use alg_tools::iterate::{ |
|
| 89 AlgIteratorFactory, |
|
| 90 AlgIteratorState, |
|
| 91 }; |
|
| 92 use alg_tools::euclidean::Euclidean; |
84 use alg_tools::euclidean::Euclidean; |
| 93 use alg_tools::linops::Apply; |
85 use alg_tools::instance::Instance; |
| 94 use alg_tools::sets::Cube; |
86 use alg_tools::iterate::AlgIteratorFactory; |
| 95 use alg_tools::loc::Loc; |
87 use alg_tools::linops::{Mapping, GEMV}; |
| 96 use alg_tools::mapping::Mapping; |
|
| 97 use alg_tools::bisection_tree::{ |
|
| 98 BTFN, |
|
| 99 PreBTFN, |
|
| 100 Bounds, |
|
| 101 BTNodeLookup, |
|
| 102 BTNode, |
|
| 103 BTSearch, |
|
| 104 P2Minimise, |
|
| 105 SupportGenerator, |
|
| 106 LocalAnalysis, |
|
| 107 Bounded, |
|
| 108 }; |
|
| 109 use alg_tools::mapping::RealMapping; |
88 use alg_tools::mapping::RealMapping; |
| 110 use alg_tools::nalgebra_support::ToNalgebraRealField; |
89 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 111 |
90 |
| |
91 use crate::dataterm::{calculate_residual, DataTerm, L2Squared}; |
| |
92 use crate::forward_model::{AdjointProductBoundedBy, ForwardModel}; |
| |
93 use crate::measures::merging::SpikeMerging; |
| |
94 use crate::measures::{DiscreteMeasure, RNDM}; |
| |
95 use crate::plot::{PlotLookup, Plotting, SeqPlotter}; |
| |
96 pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty}; |
| |
97 use crate::regularisation::RegTerm; |
| 112 use crate::types::*; |
98 use crate::types::*; |
| 113 use crate::measures::{ |
|
| 114 DiscreteMeasure, |
|
| 115 DeltaMeasure, |
|
| 116 }; |
|
| 117 use crate::measures::merging::{ |
|
| 118 SpikeMergingMethod, |
|
| 119 SpikeMerging, |
|
| 120 }; |
|
| 121 use crate::forward_model::ForwardModel; |
|
| 122 use crate::seminorms::{ |
|
| 123 DiscreteMeasureOp, Lipschitz |
|
| 124 }; |
|
| 125 use crate::subproblem::{ |
|
| 126 nonneg::quadratic_nonneg, |
|
| 127 unconstrained::quadratic_unconstrained, |
|
| 128 InnerSettings, |
|
| 129 InnerMethod, |
|
| 130 }; |
|
| 131 use crate::tolerance::Tolerance; |
|
| 132 use crate::plot::{ |
|
| 133 SeqPlotter, |
|
| 134 Plotting, |
|
| 135 PlotLookup |
|
| 136 }; |
|
| 137 use crate::regularisation::{ |
|
| 138 NonnegRadonRegTerm, |
|
| 139 RadonRegTerm, |
|
| 140 }; |
|
| 141 |
|
| 142 /// Method for constructing $μ$ on each iteration |
|
| 143 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
| 144 #[allow(dead_code)] |
|
| 145 pub enum InsertionStyle { |
|
| 146 /// Resuse previous $μ$ from previous iteration, optimising weights |
|
| 147 /// before inserting new spikes. |
|
| 148 Reuse, |
|
| 149 /// Start each iteration with $μ=0$. |
|
| 150 Zero, |
|
| 151 } |
|
| 152 |
|
| 153 /// Meta-algorithm type |
|
| 154 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
| 155 #[allow(dead_code)] |
|
| 156 pub enum FBMetaAlgorithm { |
|
| 157 /// No meta-algorithm |
|
| 158 None, |
|
| 159 /// FISTA-style inertia |
|
| 160 InertiaFISTA, |
|
| 161 } |
|
| 162 |
99 |
| 163 /// Settings for [`pointsource_fb_reg`]. |
100 /// Settings for [`pointsource_fb_reg`]. |
| 164 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 165 #[serde(default)] |
102 #[serde(default)] |
| 166 pub struct FBConfig<F : Float> { |
103 pub struct FBConfig<F: Float> { |
| 167 /// Step length scaling |
104 /// Step length scaling |
| 168 pub τ0 : F, |
105 pub τ0: F, |
| 169 /// Meta-algorithm to apply |
|
| 170 pub meta : FBMetaAlgorithm, |
|
| 171 /// Generic parameters |
106 /// Generic parameters |
| 172 pub insertion : FBGenericConfig<F>, |
107 pub generic: FBGenericConfig<F>, |
| 173 } |
|
| 174 |
|
| 175 /// Settings for the solution of the stepwise optimality condition in algorithms based on |
|
| 176 /// [`generic_pointsource_fb_reg`]. |
|
| 177 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
| 178 #[serde(default)] |
|
| 179 pub struct FBGenericConfig<F : Float> { |
|
| 180 /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`]. |
|
| 181 pub insertion_style : InsertionStyle, |
|
| 182 /// Tolerance for point insertion. |
|
| 183 pub tolerance : Tolerance<F>, |
|
| 184 /// Stop looking for predual maximum (where to isert a new point) below |
|
| 185 /// `tolerance` multiplied by this factor. |
|
| 186 pub insertion_cutoff_factor : F, |
|
| 187 /// Settings for branch and bound refinement when looking for predual maxima |
|
| 188 pub refinement : RefinementSettings<F>, |
|
| 189 /// Maximum insertions within each outer iteration |
|
| 190 pub max_insertions : usize, |
|
| 191 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. |
|
| 192 pub bootstrap_insertions : Option<(usize, usize)>, |
|
| 193 /// Inner method settings |
|
| 194 pub inner : InnerSettings<F>, |
|
| 195 /// Spike merging method |
|
| 196 pub merging : SpikeMergingMethod<F>, |
|
| 197 /// Tolerance multiplier for merges |
|
| 198 pub merge_tolerance_mult : F, |
|
| 199 /// Spike merging method after the last step |
|
| 200 pub final_merging : SpikeMergingMethod<F>, |
|
| 201 /// Iterations between merging heuristic tries |
|
| 202 pub merge_every : usize, |
|
| 203 /// Save $μ$ for postprocessing optimisation |
|
| 204 pub postprocessing : bool |
|
| 205 } |
108 } |
| 206 |
109 |
| 207 #[replace_float_literals(F::cast_from(literal))] |
110 #[replace_float_literals(F::cast_from(literal))] |
| 208 impl<F : Float> Default for FBConfig<F> { |
111 impl<F: Float> Default for FBConfig<F> { |
| 209 fn default() -> Self { |
112 fn default() -> Self { |
| 210 FBConfig { |
113 FBConfig { |
| 211 τ0 : 0.99, |
114 τ0: 0.99, |
| 212 meta : FBMetaAlgorithm::None, |
115 generic: Default::default(), |
| 213 insertion : Default::default() |
|
| 214 } |
116 } |
| 215 } |
117 } |
| 216 } |
118 } |
| 217 |
119 |
| |
120 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize { |
| |
121 let n_before_prune = μ.len(); |
| |
122 μ.prune(); |
| |
123 debug_assert!(μ.len() <= n_before_prune); |
| |
124 n_before_prune - μ.len() |
| |
125 } |
| |
126 |
| 218 #[replace_float_literals(F::cast_from(literal))] |
127 #[replace_float_literals(F::cast_from(literal))] |
| 219 impl<F : Float> Default for FBGenericConfig<F> { |
128 pub(crate) fn postprocess< |
| 220 fn default() -> Self { |
129 F: Float, |
| 221 FBGenericConfig { |
130 V: Euclidean<F> + Clone, |
| 222 insertion_style : InsertionStyle::Reuse, |
131 A: GEMV<F, RNDM<F, N>, Codomain = V>, |
| 223 tolerance : Default::default(), |
132 D: DataTerm<F, V, N>, |
| 224 insertion_cutoff_factor : 1.0, |
133 const N: usize, |
| 225 refinement : Default::default(), |
134 >( |
| 226 max_insertions : 100, |
135 mut μ: RNDM<F, N>, |
| 227 //bootstrap_insertions : None, |
136 config: &FBGenericConfig<F>, |
| 228 bootstrap_insertions : Some((10, 1)), |
137 dataterm: D, |
| 229 inner : InnerSettings { |
138 opA: &A, |
| 230 method : InnerMethod::SSN, |
139 b: &V, |
| 231 .. Default::default() |
140 ) -> RNDM<F, N> |
| 232 }, |
141 where |
| 233 merging : SpikeMergingMethod::None, |
142 RNDM<F, N>: SpikeMerging<F>, |
| 234 //merging : Default::default(), |
143 for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>, |
| 235 final_merging : Default::default(), |
144 { |
| 236 merge_every : 10, |
145 μ.merge_spikes_fitness( |
| 237 merge_tolerance_mult : 2.0, |
146 config.final_merging_method(), |
| 238 postprocessing : false, |
147 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
| |
148 |&v| v, |
| |
149 ); |
| |
150 μ.prune(); |
| |
151 μ |
| |
152 } |
| |
153 |
| |
154 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
| |
155 /// |
| |
156 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
| |
157 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
| |
158 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
| |
159 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
| |
160 /// as documented in [`alg_tools::iterate`]. |
| |
161 /// |
| |
162 /// For details on the mathematical formulation, see the [module level](self) documentation. |
| |
163 /// |
| |
164 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
| |
165 /// sums of simple functions usign bisection trees, and the related |
| |
166 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
| |
167 /// active at a specific points, and to maximise their sums. Through the implementation of the |
| |
168 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
| |
169 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
| |
170 /// |
| |
171 /// Returns the final iterate. |
| |
172 #[replace_float_literals(F::cast_from(literal))] |
| |
173 pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>( |
| |
174 opA: &A, |
| |
175 b: &A::Observable, |
| |
176 reg: Reg, |
| |
177 prox_penalty: &P, |
| |
178 fbconfig: &FBConfig<F>, |
| |
179 iterator: I, |
| |
180 mut plotter: SeqPlotter<F, N>, |
| |
181 ) -> RNDM<F, N> |
| |
182 where |
| |
183 F: Float + ToNalgebraRealField, |
| |
184 I: AlgIteratorFactory<IterInfo<F, N>>, |
| |
185 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
| |
186 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
| |
187 A::PreadjointCodomain: RealMapping<F, N>, |
| |
188 PlotLookup: Plotting<N>, |
| |
189 RNDM<F, N>: SpikeMerging<F>, |
| |
190 Reg: RegTerm<F, N>, |
| |
191 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
| |
192 { |
| |
193 // Set up parameters |
| |
194 let config = &fbconfig.generic; |
| |
195 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
| |
196 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| |
197 // by τ compared to the conditional gradient approach. |
| |
198 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| |
199 let mut ε = tolerance.initial(); |
| |
200 |
| |
201 // Initialise iterates |
| |
202 let mut μ = DiscreteMeasure::new(); |
| |
203 let mut residual = -b; |
| |
204 |
| |
205 // Statistics |
| |
206 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { |
| |
207 value: residual.norm2_squared_div2() + reg.apply(μ), |
| |
208 n_spikes: μ.len(), |
| |
209 ε, |
| |
210 //postprocessing: config.postprocessing.then(|| μ.clone()), |
| |
211 ..stats |
| |
212 }; |
| |
213 let mut stats = IterInfo::new(); |
| |
214 |
| |
215 // Run the algorithm |
| |
216 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
| |
217 // Calculate smooth part of surrogate model. |
| |
218 let mut τv = opA.preadjoint().apply(residual * τ); |
| |
219 |
| |
220 // Save current base point |
| |
221 let μ_base = μ.clone(); |
| |
222 |
| |
223 // Insert and reweigh |
| |
224 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| |
225 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
| |
226 ); |
| |
227 |
| |
228 // Prune and possibly merge spikes |
| |
229 if config.merge_now(&state) { |
| |
230 stats.merged += prox_penalty.merge_spikes( |
| |
231 &mut μ, |
| |
232 &mut τv, |
| |
233 &μ_base, |
| |
234 None, |
| |
235 τ, |
| |
236 ε, |
| |
237 config, |
| |
238 ®, |
| |
239 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
| |
240 ); |
| 239 } |
241 } |
| |
242 |
| |
243 stats.pruned += prune_with_stats(&mut μ); |
| |
244 |
| |
245 // Update residual |
| |
246 residual = calculate_residual(&μ, opA, b); |
| |
247 |
| |
248 let iter = state.iteration(); |
| |
249 stats.this_iters += 1; |
| |
250 |
| |
251 // Give statistics if needed |
| |
252 state.if_verbose(|| { |
| |
253 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
| |
254 full_stats( |
| |
255 &residual, |
| |
256 &μ, |
| |
257 ε, |
| |
258 std::mem::replace(&mut stats, IterInfo::new()), |
| |
259 ) |
| |
260 }); |
| |
261 |
| |
262 // Update main tolerance for next iteration |
| |
263 ε = tolerance.update(ε, iter); |
| 240 } |
264 } |
| 241 } |
265 |
| 242 |
266 postprocess(μ, config, L2Squared, opA, b) |
| 243 /// Trait for specialisation of [`generic_pointsource_fb_reg`] to basic FB, FISTA. |
267 } |
| 244 /// |
268 |
| 245 /// The idea is that the residual $Aμ - b$ in the forward step can be replaced by an arbitrary |
269 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
| 246 /// value. For example, to implement [primal-dual proximal splitting][crate::pdps] we replace it |
270 /// |
| 247 /// with the dual variable $y$. We can then also implement alternative data terms, as the |
271 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
| 248 /// (pre)differential of $F(μ)=F\_0(Aμ-b)$ is $F\'(μ) = A\_*F\_0\'(Aμ-b)$. In the case of the |
272 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
| 249 /// quadratic fidelity $F_0(y)=\frac{1}{2}\\|y\\|_2^2$ in a Hilbert space, of course, |
273 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
| 250 /// $F\_0\'(Aμ-b)=Aμ-b$ is the residual. |
274 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
| 251 pub trait FBSpecialisation<F : Float, Observable : Euclidean<F>, const N : usize> : Sized { |
275 /// as documented in [`alg_tools::iterate`]. |
| 252 /// Updates the residual and does any necessary pruning of `μ`. |
276 /// |
| 253 /// |
277 /// For details on the mathematical formulation, see the [module level](self) documentation. |
| 254 /// Returns the new residual and possibly a new step length. |
278 /// |
| 255 /// |
279 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
| 256 /// The measure `μ` may also be modified to apply, e.g., inertia to it. |
280 /// sums of simple functions usign bisection trees, and the related |
| 257 /// The updated residual should correspond to the residual at `μ`. |
281 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
| 258 /// See the [trait documentation][FBSpecialisation] for the use and meaning of the residual. |
282 /// active at a specific points, and to maximise their sums. Through the implementation of the |
| 259 /// |
283 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
| 260 /// The parameter `μ_base` is the base point of the iteration, typically the previous iterate, |
284 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
| 261 /// but for, e.g., FISTA has inertia applied to it. |
285 /// |
| 262 fn update( |
286 /// Returns the final iterate. |
| 263 &mut self, |
|
| 264 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
| 265 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 266 ) -> (Observable, Option<F>); |
|
| 267 |
|
| 268 /// Calculates the data term value corresponding to iterate `μ` and available residual. |
|
| 269 /// |
|
| 270 /// Inertia and other modifications, as deemed, necessary, should be applied to `μ`. |
|
| 271 /// |
|
| 272 /// The blanket implementation correspondsn to the 2-norm-squared data fidelity |
|
| 273 /// $\\|\text{residual}\\|\_2^2/2$. |
|
| 274 fn calculate_fit( |
|
| 275 &self, |
|
| 276 _μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 277 residual : &Observable |
|
| 278 ) -> F { |
|
| 279 residual.norm2_squared_div2() |
|
| 280 } |
|
| 281 |
|
| 282 /// Calculates the data term value at $μ$. |
|
| 283 /// |
|
| 284 /// Unlike [`Self::calculate_fit`], no inertia, etc., should be applied to `μ`. |
|
| 285 fn calculate_fit_simple( |
|
| 286 &self, |
|
| 287 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 288 ) -> F; |
|
| 289 |
|
| 290 /// Returns the final iterate after any necessary postprocess pruning, merging, etc. |
|
| 291 fn postprocess(self, mut μ : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>) |
|
| 292 -> DiscreteMeasure<Loc<F, N>, F> |
|
| 293 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { |
|
| 294 μ.merge_spikes_fitness(merging, |
|
| 295 |μ̃| self.calculate_fit_simple(μ̃), |
|
| 296 |&v| v); |
|
| 297 μ.prune(); |
|
| 298 μ |
|
| 299 } |
|
| 300 |
|
| 301 /// Returns measure to be used for value calculations, which may differ from μ. |
|
| 302 fn value_μ<'c, 'b : 'c>(&'b self, μ : &'c DiscreteMeasure<Loc<F, N>, F>) |
|
| 303 -> &'c DiscreteMeasure<Loc<F, N>, F> { |
|
| 304 μ |
|
| 305 } |
|
| 306 } |
|
| 307 |
|
| 308 /// Specialisation of [`generic_pointsource_fb_reg`] to basic μFB. |
|
| 309 struct BasicFB< |
|
| 310 'a, |
|
| 311 F : Float + ToNalgebraRealField, |
|
| 312 A : ForwardModel<Loc<F, N>, F>, |
|
| 313 const N : usize |
|
| 314 > { |
|
| 315 /// The data |
|
| 316 b : &'a A::Observable, |
|
| 317 /// The forward operator |
|
| 318 opA : &'a A, |
|
| 319 } |
|
| 320 |
|
| 321 /// Implementation of [`FBSpecialisation`] for basic μFB forward-backward splitting. |
|
| 322 #[replace_float_literals(F::cast_from(literal))] |
287 #[replace_float_literals(F::cast_from(literal))] |
| 323 impl<'a, F : Float + ToNalgebraRealField , A : ForwardModel<Loc<F, N>, F>, const N : usize> |
288 pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>( |
| 324 FBSpecialisation<F, A::Observable, N> for BasicFB<'a, F, A, N> { |
289 opA: &A, |
| 325 fn update( |
290 b: &A::Observable, |
| 326 &mut self, |
291 reg: Reg, |
| 327 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
292 prox_penalty: &P, |
| 328 _μ_base : &DiscreteMeasure<Loc<F, N>, F> |
293 fbconfig: &FBConfig<F>, |
| 329 ) -> (A::Observable, Option<F>) { |
294 iterator: I, |
| 330 μ.prune(); |
295 mut plotter: SeqPlotter<F, N>, |
| 331 //*residual = self.opA.apply(μ) - self.b; |
296 ) -> RNDM<F, N> |
| 332 let mut residual = self.b.clone(); |
297 where |
| 333 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
298 F: Float + ToNalgebraRealField, |
| 334 (residual, None) |
299 I: AlgIteratorFactory<IterInfo<F, N>>, |
| 335 } |
300 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
| 336 |
301 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
| 337 fn calculate_fit_simple( |
302 A::PreadjointCodomain: RealMapping<F, N>, |
| 338 &self, |
303 PlotLookup: Plotting<N>, |
| 339 μ : &DiscreteMeasure<Loc<F, N>, F>, |
304 RNDM<F, N>: SpikeMerging<F>, |
| 340 ) -> F { |
305 Reg: RegTerm<F, N>, |
| 341 let mut residual = self.b.clone(); |
306 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
| 342 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
307 { |
| 343 residual.norm2_squared_div2() |
308 // Set up parameters |
| 344 } |
309 let config = &fbconfig.generic; |
| 345 } |
310 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
| 346 |
311 let mut λ = 1.0; |
| 347 /// Specialisation of [`generic_pointsource_fb_reg`] to FISTA. |
312 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 348 struct FISTA< |
313 // by τ compared to the conditional gradient approach. |
| 349 'a, |
314 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 350 F : Float + ToNalgebraRealField, |
315 let mut ε = tolerance.initial(); |
| 351 A : ForwardModel<Loc<F, N>, F>, |
316 |
| 352 const N : usize |
317 // Initialise iterates |
| 353 > { |
318 let mut μ = DiscreteMeasure::new(); |
| 354 /// The data |
319 let mut μ_prev = DiscreteMeasure::new(); |
| 355 b : &'a A::Observable, |
320 let mut residual = -b; |
| 356 /// The forward operator |
321 let mut warned_merging = false; |
| 357 opA : &'a A, |
322 |
| 358 /// Current inertial parameter |
323 // Statistics |
| 359 λ : F, |
324 let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo { |
| 360 /// Previous iterate without inertia applied. |
325 value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
| 361 /// We need to store this here because `μ_base` passed to [`FBSpecialisation::update`] will |
326 n_spikes: ν.len(), |
| 362 /// have inertia applied to it, so is not useful to use. |
327 ε, |
| 363 μ_prev : DiscreteMeasure<Loc<F, N>, F>, |
328 // postprocessing: config.postprocessing.then(|| ν.clone()), |
| 364 } |
329 ..stats |
| 365 |
330 }; |
| 366 /// Implementation of [`FBSpecialisation`] for μFISTA inertial forward-backward splitting. |
331 let mut stats = IterInfo::new(); |
| 367 #[replace_float_literals(F::cast_from(literal))] |
332 |
| 368 impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F>, const N : usize> |
333 // Run the algorithm |
| 369 FBSpecialisation<F, A::Observable, N> for FISTA<'a, F, A, N> { |
334 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 370 fn update( |
335 // Calculate smooth part of surrogate model. |
| 371 &mut self, |
336 let mut τv = opA.preadjoint().apply(residual * τ); |
| 372 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
337 |
| 373 _μ_base : &DiscreteMeasure<Loc<F, N>, F> |
338 // Save current base point |
| 374 ) -> (A::Observable, Option<F>) { |
339 let μ_base = μ.clone(); |
| 375 // Update inertial parameters |
340 |
| 376 let λ_prev = self.λ; |
341 // Insert new spikes and reweigh |
| 377 self.λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); |
342 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 378 let θ = self.λ / λ_prev - self.λ; |
343 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
| |
344 ); |
| |
345 |
| |
346 // (Do not) merge spikes. |
| |
347 if config.merge_now(&state) && !warned_merging { |
| |
348 let err = format!("Merging not supported for μFISTA"); |
| |
349 println!("{}", err.red()); |
| |
350 warned_merging = true; |
| |
351 } |
| |
352 |
| |
353 // Update inertial prameters |
| |
354 let λ_prev = λ; |
| |
355 λ = 2.0 * λ_prev / (λ_prev + (4.0 + λ_prev * λ_prev).sqrt()); |
| |
356 let θ = λ / λ_prev - λ; |
| |
357 |
| 379 // Perform inertial update on μ. |
358 // Perform inertial update on μ. |
| 380 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ |
359 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ |
| 381 // and μ_prev have zero weight. Since both have weights from the finite-dimensional |
360 // and μ_prev have zero weight. Since both have weights from the finite-dimensional |
| 382 // subproblem with a proximal projection step, this is likely to happen when the |
361 // subproblem with a proximal projection step, this is likely to happen when the |
| 383 // spike is not needed. A copy of the pruned μ without artithmetic performed is |
362 // spike is not needed. A copy of the pruned μ without artithmetic performed is |
| 384 // stored in μ_prev. |
363 // stored in μ_prev. |
| 385 μ.pruning_sub(1.0 + θ, θ, &mut self.μ_prev); |
364 let n_before_prune = μ.len(); |
| 386 |
365 μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); |
| 387 //*residual = self.opA.apply(μ) - self.b; |
366 //let μ_new = (&μ * (1.0 + θ)).sub_matching(&(&μ_prev * θ)); |
| 388 let mut residual = self.b.clone(); |
367 // μ_prev = μ; |
| 389 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
368 // μ = μ_new; |
| 390 (residual, None) |
369 debug_assert!(μ.len() <= n_before_prune); |
| |
370 stats.pruned += n_before_prune - μ.len(); |
| |
371 |
| |
372 // Update residual |
| |
373 residual = calculate_residual(&μ, opA, b); |
| |
374 |
| |
375 let iter = state.iteration(); |
| |
376 stats.this_iters += 1; |
| |
377 |
| |
378 // Give statistics if needed |
| |
379 state.if_verbose(|| { |
| |
380 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ_prev); |
| |
381 full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| |
382 }); |
| |
383 |
| |
384 // Update main tolerance for next iteration |
| |
385 ε = tolerance.update(ε, iter); |
| 391 } |
386 } |
| 392 |
387 |
| 393 fn calculate_fit_simple( |
388 postprocess(μ_prev, config, L2Squared, opA, b) |
| 394 &self, |
389 } |
| 395 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 396 ) -> F { |
|
| 397 let mut residual = self.b.clone(); |
|
| 398 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
| 399 residual.norm2_squared_div2() |
|
| 400 } |
|
| 401 |
|
| 402 fn calculate_fit( |
|
| 403 &self, |
|
| 404 _μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 405 _residual : &A::Observable |
|
| 406 ) -> F { |
|
| 407 self.calculate_fit_simple(&self.μ_prev) |
|
| 408 } |
|
| 409 |
|
| 410 // For FISTA we need to do a final pruning as well, due to the limited |
|
| 411 // pruning that can be done on each step. |
|
| 412 fn postprocess(mut self, μ_base : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>) |
|
| 413 -> DiscreteMeasure<Loc<F, N>, F> |
|
| 414 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { |
|
| 415 let mut μ = self.μ_prev; |
|
| 416 self.μ_prev = μ_base; |
|
| 417 μ.merge_spikes_fitness(merging, |
|
| 418 |μ̃| self.calculate_fit_simple(μ̃), |
|
| 419 |&v| v); |
|
| 420 μ.prune(); |
|
| 421 μ |
|
| 422 } |
|
| 423 |
|
| 424 fn value_μ<'c, 'b : 'c>(&'c self, _μ : &'c DiscreteMeasure<Loc<F, N>, F>) |
|
| 425 -> &'c DiscreteMeasure<Loc<F, N>, F> { |
|
| 426 &self.μ_prev |
|
| 427 } |
|
| 428 } |
|
| 429 |
|
| 430 |
|
| 431 /// Abstraction of regularisation terms for [`generic_pointsource_fb_reg`]. |
|
| 432 pub trait RegTerm<F : Float + ToNalgebraRealField, const N : usize> |
|
| 433 : for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> { |
|
| 434 /// Approximately solve the problem |
|
| 435 /// <div>$$ |
|
| 436 /// \min_{x ∈ ℝ^n} \frac{1}{2} x^⊤Ax - g^⊤ x + τ G(x) |
|
| 437 /// $$</div> |
|
| 438 /// for $G$ depending on the trait implementation. |
|
| 439 /// |
|
| 440 /// The parameter `mA` is $A$. An estimate for its opeator norm should be provided in |
|
| 441 /// `mA_normest`. The initial iterate and output is `x`. The current main tolerance is `ε`. |
|
| 442 /// |
|
| 443 /// Returns the number of iterations taken. |
|
| 444 fn solve_findim( |
|
| 445 &self, |
|
| 446 mA : &DMatrix<F::MixedType>, |
|
| 447 g : &DVector<F::MixedType>, |
|
| 448 τ : F, |
|
| 449 x : &mut DVector<F::MixedType>, |
|
| 450 mA_normest : F, |
|
| 451 ε : F, |
|
| 452 config : &FBGenericConfig<F> |
|
| 453 ) -> usize; |
|
| 454 |
|
| 455 /// Find a point where `d` may violate the tolerance `ε`. |
|
| 456 /// |
|
| 457 /// If `skip_by_rough_check` is set, do not find the point if a rough check indicates that we |
|
| 458 /// are in bounds. `ε` is the current main tolerance and `τ` a scaling factor for the |
|
| 459 /// regulariser. |
|
| 460 /// |
|
| 461 /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check |
|
| 462 /// terminating early. Otherwise returns a possibly violating point, the value of `d` there, |
|
| 463 /// and a boolean indicating whether the found point is in bounds. |
|
| 464 fn find_tolerance_violation<G, BT>( |
|
| 465 &self, |
|
| 466 d : &mut BTFN<F, G, BT, N>, |
|
| 467 τ : F, |
|
| 468 ε : F, |
|
| 469 skip_by_rough_check : bool, |
|
| 470 config : &FBGenericConfig<F>, |
|
| 471 ) -> Option<(Loc<F, N>, F, bool)> |
|
| 472 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
| 473 G : SupportGenerator<F, N, Id=BT::Data>, |
|
| 474 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
| 475 + LocalAnalysis<F, Bounds<F>, N>; |
|
| 476 |
|
| 477 /// Verify that `d` is in bounds `ε` for a merge candidate `μ` |
|
| 478 /// |
|
| 479 /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser. |
|
| 480 fn verify_merge_candidate<G, BT>( |
|
| 481 &self, |
|
| 482 d : &mut BTFN<F, G, BT, N>, |
|
| 483 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 484 τ : F, |
|
| 485 ε : F, |
|
| 486 config : &FBGenericConfig<F>, |
|
| 487 ) -> bool |
|
| 488 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
| 489 G : SupportGenerator<F, N, Id=BT::Data>, |
|
| 490 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
| 491 + LocalAnalysis<F, Bounds<F>, N>; |
|
| 492 |
|
| 493 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>>; |
|
| 494 |
|
| 495 /// Returns a scaling factor for the tolerance sequence. |
|
| 496 /// |
|
| 497 /// Typically this is the regularisation parameter. |
|
| 498 fn tolerance_scaling(&self) -> F; |
|
| 499 } |
|
| 500 |
|
| 501 #[replace_float_literals(F::cast_from(literal))] |
|
| 502 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for NonnegRadonRegTerm<F> |
|
| 503 where Cube<F, N> : P2Minimise<Loc<F, N>, F> { |
|
| 504 fn solve_findim( |
|
| 505 &self, |
|
| 506 mA : &DMatrix<F::MixedType>, |
|
| 507 g : &DVector<F::MixedType>, |
|
| 508 τ : F, |
|
| 509 x : &mut DVector<F::MixedType>, |
|
| 510 mA_normest : F, |
|
| 511 ε : F, |
|
| 512 config : &FBGenericConfig<F> |
|
| 513 ) -> usize { |
|
| 514 let inner_tolerance = ε * config.inner.tolerance_mult; |
|
| 515 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); |
|
| 516 let inner_τ = config.inner.τ0 / mA_normest; |
|
| 517 quadratic_nonneg(config.inner.method, mA, g, τ * self.α(), x, |
|
| 518 inner_τ, inner_it) |
|
| 519 } |
|
| 520 |
|
| 521 #[inline] |
|
| 522 fn find_tolerance_violation<G, BT>( |
|
| 523 &self, |
|
| 524 d : &mut BTFN<F, G, BT, N>, |
|
| 525 τ : F, |
|
| 526 ε : F, |
|
| 527 skip_by_rough_check : bool, |
|
| 528 config : &FBGenericConfig<F>, |
|
| 529 ) -> Option<(Loc<F, N>, F, bool)> |
|
| 530 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
| 531 G : SupportGenerator<F, N, Id=BT::Data>, |
|
| 532 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
| 533 + LocalAnalysis<F, Bounds<F>, N> { |
|
| 534 let τα = τ * self.α(); |
|
| 535 let keep_below = τα + ε; |
|
| 536 let maximise_above = τα + ε * config.insertion_cutoff_factor; |
|
| 537 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
|
| 538 |
|
| 539 // If preliminary check indicates that we are in bonds, and if it otherwise matches |
|
| 540 // the insertion strategy, skip insertion. |
|
| 541 if skip_by_rough_check && d.bounds().upper() <= keep_below { |
|
| 542 None |
|
| 543 } else { |
|
| 544 // If the rough check didn't indicate no insertion needed, find maximising point. |
|
| 545 d.maximise_above(maximise_above, refinement_tolerance, config.refinement.max_steps) |
|
| 546 .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ <= keep_below)) |
|
| 547 } |
|
| 548 } |
|
| 549 |
|
| 550 fn verify_merge_candidate<G, BT>( |
|
| 551 &self, |
|
| 552 d : &mut BTFN<F, G, BT, N>, |
|
| 553 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 554 τ : F, |
|
| 555 ε : F, |
|
| 556 config : &FBGenericConfig<F>, |
|
| 557 ) -> bool |
|
| 558 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
| 559 G : SupportGenerator<F, N, Id=BT::Data>, |
|
| 560 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
| 561 + LocalAnalysis<F, Bounds<F>, N> { |
|
| 562 let τα = τ * self.α(); |
|
| 563 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
|
| 564 let merge_tolerance = config.merge_tolerance_mult * ε; |
|
| 565 let keep_below = τα + merge_tolerance; |
|
| 566 let keep_supp_above = τα - merge_tolerance; |
|
| 567 let bnd = d.bounds(); |
|
| 568 |
|
| 569 return ( |
|
| 570 bnd.lower() >= keep_supp_above |
|
| 571 || |
|
| 572 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| { |
|
| 573 (β == 0.0) || d.apply(x) >= keep_supp_above |
|
| 574 }).all(std::convert::identity) |
|
| 575 ) && ( |
|
| 576 bnd.upper() <= keep_below |
|
| 577 || |
|
| 578 d.has_upper_bound(keep_below, refinement_tolerance, config.refinement.max_steps) |
|
| 579 ) |
|
| 580 } |
|
| 581 |
|
| 582 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> { |
|
| 583 let τα = τ * self.α(); |
|
| 584 Some(Bounds(τα - ε, τα + ε)) |
|
| 585 } |
|
| 586 |
|
| 587 fn tolerance_scaling(&self) -> F { |
|
| 588 self.α() |
|
| 589 } |
|
| 590 } |
|
| 591 |
|
| 592 #[replace_float_literals(F::cast_from(literal))] |
|
| 593 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for RadonRegTerm<F> |
|
| 594 where Cube<F, N> : P2Minimise<Loc<F, N>, F> { |
|
| 595 fn solve_findim( |
|
| 596 &self, |
|
| 597 mA : &DMatrix<F::MixedType>, |
|
| 598 g : &DVector<F::MixedType>, |
|
| 599 τ : F, |
|
| 600 x : &mut DVector<F::MixedType>, |
|
| 601 mA_normest: F, |
|
| 602 ε : F, |
|
| 603 config : &FBGenericConfig<F> |
|
| 604 ) -> usize { |
|
| 605 let inner_tolerance = ε * config.inner.tolerance_mult; |
|
| 606 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); |
|
| 607 let inner_τ = config.inner.τ0 / mA_normest; |
|
| 608 quadratic_unconstrained(config.inner.method, mA, g, τ * self.α(), x, |
|
| 609 inner_τ, inner_it) |
|
| 610 } |
|
| 611 |
|
| 612 fn find_tolerance_violation<G, BT>( |
|
| 613 &self, |
|
| 614 d : &mut BTFN<F, G, BT, N>, |
|
| 615 τ : F, |
|
| 616 ε : F, |
|
| 617 skip_by_rough_check : bool, |
|
| 618 config : &FBGenericConfig<F>, |
|
| 619 ) -> Option<(Loc<F, N>, F, bool)> |
|
| 620 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
| 621 G : SupportGenerator<F, N, Id=BT::Data>, |
|
| 622 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
| 623 + LocalAnalysis<F, Bounds<F>, N> { |
|
| 624 let τα = τ * self.α(); |
|
| 625 let keep_below = τα + ε; |
|
| 626 let keep_above = -τα - ε; |
|
| 627 let maximise_above = τα + ε * config.insertion_cutoff_factor; |
|
| 628 let minimise_below = -τα - ε * config.insertion_cutoff_factor; |
|
| 629 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
|
| 630 |
|
| 631 // If preliminary check indicates that we are in bonds, and if it otherwise matches |
|
| 632 // the insertion strategy, skip insertion. |
|
| 633 if skip_by_rough_check && Bounds(keep_above, keep_below).superset(&d.bounds()) { |
|
| 634 None |
|
| 635 } else { |
|
| 636 // If the rough check didn't indicate no insertion needed, find maximising point. |
|
| 637 let mx = d.maximise_above(maximise_above, refinement_tolerance, |
|
| 638 config.refinement.max_steps); |
|
| 639 let mi = d.minimise_below(minimise_below, refinement_tolerance, |
|
| 640 config.refinement.max_steps); |
|
| 641 |
|
| 642 match (mx, mi) { |
|
| 643 (None, None) => None, |
|
| 644 (Some((ξ, v_ξ)), None) => Some((ξ, v_ξ, keep_below >= v_ξ)), |
|
| 645 (None, Some((ζ, v_ζ))) => Some((ζ, v_ζ, keep_above <= v_ζ)), |
|
| 646 (Some((ξ, v_ξ)), Some((ζ, v_ζ))) => { |
|
| 647 if v_ξ - τα > τα - v_ζ { |
|
| 648 Some((ξ, v_ξ, keep_below >= v_ξ)) |
|
| 649 } else { |
|
| 650 Some((ζ, v_ζ, keep_above <= v_ζ)) |
|
| 651 } |
|
| 652 } |
|
| 653 } |
|
| 654 } |
|
| 655 } |
|
| 656 |
|
| 657 fn verify_merge_candidate<G, BT>( |
|
| 658 &self, |
|
| 659 d : &mut BTFN<F, G, BT, N>, |
|
| 660 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 661 τ : F, |
|
| 662 ε : F, |
|
| 663 config : &FBGenericConfig<F>, |
|
| 664 ) -> bool |
|
| 665 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
| 666 G : SupportGenerator<F, N, Id=BT::Data>, |
|
| 667 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
| 668 + LocalAnalysis<F, Bounds<F>, N> { |
|
| 669 let τα = τ * self.α(); |
|
| 670 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
|
| 671 let merge_tolerance = config.merge_tolerance_mult * ε; |
|
| 672 let keep_below = τα + merge_tolerance; |
|
| 673 let keep_above = -τα - merge_tolerance; |
|
| 674 let keep_supp_pos_above = τα - merge_tolerance; |
|
| 675 let keep_supp_neg_below = -τα + merge_tolerance; |
|
| 676 let bnd = d.bounds(); |
|
| 677 |
|
| 678 return ( |
|
| 679 (bnd.lower() >= keep_supp_pos_above && bnd.upper() <= keep_supp_neg_below) |
|
| 680 || |
|
| 681 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| { |
|
| 682 use std::cmp::Ordering::*; |
|
| 683 match β.partial_cmp(&0.0) { |
|
| 684 Some(Greater) => d.apply(x) >= keep_supp_pos_above, |
|
| 685 Some(Less) => d.apply(x) <= keep_supp_neg_below, |
|
| 686 _ => true, |
|
| 687 } |
|
| 688 }).all(std::convert::identity) |
|
| 689 ) && ( |
|
| 690 bnd.upper() <= keep_below |
|
| 691 || |
|
| 692 d.has_upper_bound(keep_below, refinement_tolerance, |
|
| 693 config.refinement.max_steps) |
|
| 694 ) && ( |
|
| 695 bnd.lower() >= keep_above |
|
| 696 || |
|
| 697 d.has_lower_bound(keep_above, refinement_tolerance, |
|
| 698 config.refinement.max_steps) |
|
| 699 ) |
|
| 700 } |
|
| 701 |
|
| 702 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> { |
|
| 703 let τα = τ * self.α(); |
|
| 704 Some(Bounds(-τα - ε, τα + ε)) |
|
| 705 } |
|
| 706 |
|
| 707 fn tolerance_scaling(&self) -> F { |
|
| 708 self.α() |
|
| 709 } |
|
| 710 } |
|
| 711 |
|
| 712 |
|
| 713 /// Generic implementation of [`pointsource_fb_reg`]. |
|
| 714 /// |
|
| 715 /// The method can be specialised to even primal-dual proximal splitting through the |
|
| 716 /// [`FBSpecialisation`] parameter `specialisation`. |
|
| 717 /// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the |
|
| 718 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
|
| 719 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
|
| 720 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
|
| 721 /// as documented in [`alg_tools::iterate`]. |
|
| 722 /// |
|
| 723 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
|
| 724 /// sums of simple functions usign bisection trees, and the related |
|
| 725 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
|
| 726 /// active at a specific points, and to maximise their sums. Through the implementation of the |
|
| 727 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
|
| 728 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
|
| 729 /// |
|
| 730 /// Returns the final iterate. |
|
| 731 #[replace_float_literals(F::cast_from(literal))] |
|
| 732 pub fn generic_pointsource_fb_reg< |
|
| 733 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, Reg, const N : usize |
|
| 734 >( |
|
| 735 opA : &'a A, |
|
| 736 reg : Reg, |
|
| 737 op𝒟 : &'a 𝒟, |
|
| 738 mut τ : F, |
|
| 739 config : &FBGenericConfig<F>, |
|
| 740 iterator : I, |
|
| 741 mut plotter : SeqPlotter<F, N>, |
|
| 742 mut residual : A::Observable, |
|
| 743 mut specialisation : Spec |
|
| 744 ) -> DiscreteMeasure<Loc<F, N>, F> |
|
| 745 where F : Float + ToNalgebraRealField, |
|
| 746 I : AlgIteratorFactory<IterInfo<F, N>>, |
|
| 747 Spec : FBSpecialisation<F, A::Observable, N>, |
|
| 748 A::Observable : std::ops::MulAssign<F>, |
|
| 749 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
| 750 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
|
| 751 + Lipschitz<𝒟, FloatType=F>, |
|
| 752 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 753 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
| 754 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
| 755 𝒟::Codomain : RealMapping<F, N>, |
|
| 756 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 757 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 758 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 759 PlotLookup : Plotting<N>, |
|
| 760 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
| 761 Reg : RegTerm<F, N> { |
|
| 762 |
|
| 763 // Set up parameters |
|
| 764 let quiet = iterator.is_quiet(); |
|
| 765 let op𝒟norm = op𝒟.opnorm_bound(); |
|
| 766 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
|
| 767 // by τ compared to the conditional gradient approach. |
|
| 768 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
|
| 769 let mut ε = tolerance.initial(); |
|
| 770 |
|
| 771 // Initialise operators |
|
| 772 let preadjA = opA.preadjoint(); |
|
| 773 |
|
| 774 // Initialise iterates |
|
| 775 let mut μ = DiscreteMeasure::new(); |
|
| 776 |
|
| 777 let mut inner_iters = 0; |
|
| 778 let mut this_iters = 0; |
|
| 779 let mut pruned = 0; |
|
| 780 let mut merged = 0; |
|
| 781 |
|
| 782 let μ_diff = |μ_new : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 783 μ_base : &DiscreteMeasure<Loc<F, N>, F>| { |
|
| 784 let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style { |
|
| 785 InsertionStyle::Reuse => { |
|
| 786 μ_new.iter_spikes() |
|
| 787 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0))) |
|
| 788 .map(|(δ, α_base)| (δ.x, α_base - δ.α)) |
|
| 789 .collect() |
|
| 790 }, |
|
| 791 InsertionStyle::Zero => { |
|
| 792 μ_new.iter_spikes() |
|
| 793 .map(|δ| -δ) |
|
| 794 .chain(μ_base.iter_spikes().copied()) |
|
| 795 .collect() |
|
| 796 } |
|
| 797 }; |
|
| 798 ν.prune(); // Potential small performance improvement |
|
| 799 ν |
|
| 800 }; |
|
| 801 |
|
| 802 // Run the algorithm |
|
| 803 iterator.iterate(|state| { |
|
| 804 // Maximum insertion count and measure difference calculation depend on insertion style. |
|
| 805 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
|
| 806 (i, Some((l, k))) if i <= l => (k, false), |
|
| 807 _ => (config.max_insertions, !quiet), |
|
| 808 }; |
|
| 809 let max_insertions = match config.insertion_style { |
|
| 810 InsertionStyle::Zero => { |
|
| 811 todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled."); |
|
| 812 // let n = μ.len(); |
|
| 813 // μ = DiscreteMeasure::new(); |
|
| 814 // n + m |
|
| 815 }, |
|
| 816 InsertionStyle::Reuse => m, |
|
| 817 }; |
|
| 818 |
|
| 819 // Calculate smooth part of surrogate model. |
|
| 820 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
|
| 821 // has no significant overhead. For some reosn Rust doesn't allow us simply moving |
|
| 822 // the residual and replacing it below before the end of this closure. |
|
| 823 residual *= -τ; |
|
| 824 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
|
| 825 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b) |
|
| 826 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. |
|
| 827 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k |
|
| 828 //let g = &minus_τv + ω0; // Linear term of surrogate model |
|
| 829 |
|
| 830 // Save current base point |
|
| 831 let μ_base = μ.clone(); |
|
| 832 |
|
| 833 // Add points to support until within error tolerance or maximum insertion count reached. |
|
| 834 let mut count = 0; |
|
| 835 let (within_tolerances, d) = 'insertion: loop { |
|
| 836 if μ.len() > 0 { |
|
| 837 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
|
| 838 // from the beginning of the iteration are all contained in the immutable c and g. |
|
| 839 let à = op𝒟.findim_matrix(μ.iter_locations()); |
|
| 840 let g̃ = DVector::from_iterator(μ.len(), |
|
| 841 μ.iter_locations() |
|
| 842 .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ)) |
|
| 843 .map(F::to_nalgebra_mixed)); |
|
| 844 let mut x = μ.masses_dvector(); |
|
| 845 |
|
| 846 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
|
| 847 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
|
| 848 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ |
|
| 849 // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 |
|
| 850 // = n |𝒟| |x|_2, where n is the number of points. Therefore |
|
| 851 let Ã_normest = op𝒟norm * F::cast_from(μ.len()); |
|
| 852 |
|
| 853 // Solve finite-dimensional subproblem. |
|
| 854 inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); |
|
| 855 |
|
| 856 // Update masses of μ based on solution of finite-dimensional subproblem. |
|
| 857 μ.set_masses_dvector(&x); |
|
| 858 } |
|
| 859 |
|
| 860 // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality |
|
| 861 // conditions in the predual space, and finding new points for insertion, if necessary. |
|
| 862 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base)); |
|
| 863 |
|
| 864 // If no merging heuristic is used, let's be more conservative about spike insertion, |
|
| 865 // and skip it after first round. If merging is done, being more greedy about spike |
|
| 866 // insertion also seems to improve performance. |
|
| 867 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { |
|
| 868 false |
|
| 869 } else { |
|
| 870 count > 0 |
|
| 871 }; |
|
| 872 |
|
| 873 // Find a spike to insert, if needed |
|
| 874 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( |
|
| 875 &mut d, τ, ε, skip_by_rough_check, config |
|
| 876 ) { |
|
| 877 None => break 'insertion (true, d), |
|
| 878 Some(res) => res, |
|
| 879 }; |
|
| 880 |
|
| 881 // Break if maximum insertion count reached |
|
| 882 if count >= max_insertions { |
|
| 883 break 'insertion (in_bounds, d) |
|
| 884 } |
|
| 885 |
|
| 886 // No point in optimising the weight here; the finite-dimensional algorithm is fast. |
|
| 887 μ += DeltaMeasure { x : ξ, α : 0.0 }; |
|
| 888 count += 1; |
|
| 889 }; |
|
| 890 |
|
| 891 if !within_tolerances && warn_insertions { |
|
| 892 // Complain (but continue) if we failed to get within tolerances |
|
| 893 // by inserting more points. |
|
| 894 let err = format!("Maximum insertions reached without achieving \ |
|
| 895 subproblem solution tolerance"); |
|
| 896 println!("{}", err.red()); |
|
| 897 } |
|
| 898 |
|
| 899 // Merge spikes |
|
| 900 if state.iteration() % config.merge_every == 0 { |
|
| 901 let n_before_merge = μ.len(); |
|
| 902 μ.merge_spikes(config.merging, |μ_candidate| { |
|
| 903 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base)); |
|
| 904 |
|
| 905 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
|
| 906 .then_some(()) |
|
| 907 }); |
|
| 908 debug_assert!(μ.len() >= n_before_merge); |
|
| 909 merged += μ.len() - n_before_merge; |
|
| 910 } |
|
| 911 |
|
| 912 let n_before_prune = μ.len(); |
|
| 913 (residual, τ) = match specialisation.update(&mut μ, &μ_base) { |
|
| 914 (r, None) => (r, τ), |
|
| 915 (r, Some(new_τ)) => (r, new_τ) |
|
| 916 }; |
|
| 917 debug_assert!(μ.len() <= n_before_prune); |
|
| 918 pruned += n_before_prune - μ.len(); |
|
| 919 |
|
| 920 this_iters += 1; |
|
| 921 |
|
| 922 // Update main tolerance for next iteration |
|
| 923 let ε_prev = ε; |
|
| 924 ε = tolerance.update(ε, state.iteration()); |
|
| 925 |
|
| 926 // Give function value if needed |
|
| 927 state.if_verbose(|| { |
|
| 928 let value_μ = specialisation.value_μ(&μ); |
|
| 929 // Plot if so requested |
|
| 930 plotter.plot_spikes( |
|
| 931 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
|
| 932 "start".to_string(), Some(&minus_τv), |
|
| 933 reg.target_bounds(τ, ε_prev), value_μ, |
|
| 934 ); |
|
| 935 // Calculate mean inner iterations and reset relevant counters. |
|
| 936 // Return the statistics |
|
| 937 let res = IterInfo { |
|
| 938 value : specialisation.calculate_fit(&μ, &residual) + reg.apply(value_μ), |
|
| 939 n_spikes : value_μ.len(), |
|
| 940 inner_iters, |
|
| 941 this_iters, |
|
| 942 merged, |
|
| 943 pruned, |
|
| 944 ε : ε_prev, |
|
| 945 postprocessing: config.postprocessing.then(|| value_μ.clone()), |
|
| 946 }; |
|
| 947 inner_iters = 0; |
|
| 948 this_iters = 0; |
|
| 949 merged = 0; |
|
| 950 pruned = 0; |
|
| 951 res |
|
| 952 }) |
|
| 953 }); |
|
| 954 |
|
| 955 specialisation.postprocess(μ, config.final_merging) |
|
| 956 } |
|
| 957 |
|
| 958 /// Iteratively solve the pointsource localisation problem using forward-backward splitting |
|
| 959 /// |
|
| 960 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
|
| 961 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
|
| 962 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
|
| 963 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
|
| 964 /// as documented in [`alg_tools::iterate`]. |
|
| 965 /// |
|
| 966 /// For details on the mathematical formulation, see the [module level](self) documentation. |
|
| 967 /// |
|
| 968 /// Returns the final iterate. |
|
| 969 #[replace_float_literals(F::cast_from(literal))] |
|
| 970 pub fn pointsource_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>( |
|
| 971 opA : &'a A, |
|
| 972 b : &A::Observable, |
|
| 973 reg : Reg, |
|
| 974 op𝒟 : &'a 𝒟, |
|
| 975 config : &FBConfig<F>, |
|
| 976 iterator : I, |
|
| 977 plotter : SeqPlotter<F, N>, |
|
| 978 ) -> DiscreteMeasure<Loc<F, N>, F> |
|
| 979 where F : Float + ToNalgebraRealField, |
|
| 980 I : AlgIteratorFactory<IterInfo<F, N>>, |
|
| 981 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
|
| 982 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
|
| 983 A::Observable : std::ops::MulAssign<F>, |
|
| 984 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
| 985 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
|
| 986 + Lipschitz<𝒟, FloatType=F>, |
|
| 987 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 988 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
| 989 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
| 990 𝒟::Codomain : RealMapping<F, N>, |
|
| 991 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 992 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 993 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 994 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
| 995 PlotLookup : Plotting<N>, |
|
| 996 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
| 997 Reg : RegTerm<F, N> { |
|
| 998 |
|
| 999 let initial_residual = -b; |
|
| 1000 let τ = config.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); |
|
| 1001 |
|
| 1002 match config.meta { |
|
| 1003 FBMetaAlgorithm::None => generic_pointsource_fb_reg( |
|
| 1004 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, |
|
| 1005 BasicFB{ b, opA }, |
|
| 1006 ), |
|
| 1007 FBMetaAlgorithm::InertiaFISTA => generic_pointsource_fb_reg( |
|
| 1008 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, |
|
| 1009 FISTA{ b, opA, λ : 1.0, μ_prev : DiscreteMeasure::new() }, |
|
| 1010 ), |
|
| 1011 } |
|
| 1012 } |
|
| 1013 |
|
| 1014 // |
|
| 1015 // Deprecated interfaces |
|
| 1016 // |
|
| 1017 |
|
| 1018 #[deprecated(note = "Use `pointsource_fb_reg`")] |
|
| 1019 pub fn pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, const N : usize>( |
|
| 1020 opA : &'a A, |
|
| 1021 b : &A::Observable, |
|
| 1022 α : F, |
|
| 1023 op𝒟 : &'a 𝒟, |
|
| 1024 config : &FBConfig<F>, |
|
| 1025 iterator : I, |
|
| 1026 plotter : SeqPlotter<F, N> |
|
| 1027 ) -> DiscreteMeasure<Loc<F, N>, F> |
|
| 1028 where F : Float + ToNalgebraRealField, |
|
| 1029 I : AlgIteratorFactory<IterInfo<F, N>>, |
|
| 1030 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
|
| 1031 A::Observable : std::ops::MulAssign<F>, |
|
| 1032 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
| 1033 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
|
| 1034 + Lipschitz<𝒟, FloatType=F>, |
|
| 1035 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 1036 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
| 1037 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
| 1038 𝒟::Codomain : RealMapping<F, N>, |
|
| 1039 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 1040 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 1041 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 1042 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
| 1043 PlotLookup : Plotting<N>, |
|
| 1044 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { |
|
| 1045 |
|
| 1046 pointsource_fb_reg(opA, b, NonnegRadonRegTerm(α), op𝒟, config, iterator, plotter) |
|
| 1047 } |
|
| 1048 |
|
| 1049 |
|
| 1050 #[deprecated(note = "Use `generic_pointsource_fb_reg`")] |
|
| 1051 pub fn generic_pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, const N : usize>( |
|
| 1052 opA : &'a A, |
|
| 1053 α : F, |
|
| 1054 op𝒟 : &'a 𝒟, |
|
| 1055 τ : F, |
|
| 1056 config : &FBGenericConfig<F>, |
|
| 1057 iterator : I, |
|
| 1058 plotter : SeqPlotter<F, N>, |
|
| 1059 residual : A::Observable, |
|
| 1060 specialisation : Spec, |
|
| 1061 ) -> DiscreteMeasure<Loc<F, N>, F> |
|
| 1062 where F : Float + ToNalgebraRealField, |
|
| 1063 I : AlgIteratorFactory<IterInfo<F, N>>, |
|
| 1064 Spec : FBSpecialisation<F, A::Observable, N>, |
|
| 1065 A::Observable : std::ops::MulAssign<F>, |
|
| 1066 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
| 1067 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
|
| 1068 + Lipschitz<𝒟, FloatType=F>, |
|
| 1069 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 1070 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
| 1071 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
| 1072 𝒟::Codomain : RealMapping<F, N>, |
|
| 1073 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 1074 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 1075 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 1076 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
| 1077 PlotLookup : Plotting<N>, |
|
| 1078 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { |
|
| 1079 |
|
| 1080 generic_pointsource_fb_reg(opA, NonnegRadonRegTerm(α), op𝒟, τ, config, iterator, plotter, |
|
| 1081 residual, specialisation) |
|
| 1082 } |
|