| 72 \end{aligned} |
72 \end{aligned} |
| 73 $$ |
73 $$ |
| 74 </p> |
74 </p> |
| 75 |
75 |
| 76 We solve this with either SSN or FB as determined by |
76 We solve this with either SSN or FB as determined by |
| 77 [`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. |
77 [`crate::subproblem::InnerSettings`] in [`InsertionConfig::inner`]. |
| 78 */ |
78 */ |
| 79 |
79 |
| |
80 use crate::measures::merging::SpikeMerging; |
| |
81 use crate::measures::{DiscreteMeasure, RNDM}; |
| |
82 use crate::plot::Plotter; |
| |
83 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound}; |
| |
84 use crate::regularisation::RegTerm; |
| |
85 use crate::types::*; |
| |
86 use alg_tools::error::DynResult; |
| |
87 use alg_tools::instance::Instance; |
| |
88 use alg_tools::iterate::AlgIteratorFactory; |
| |
89 use alg_tools::mapping::DifferentiableMapping; |
| |
90 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 80 use colored::Colorize; |
91 use colored::Colorize; |
| 81 use numeric_literals::replace_float_literals; |
92 use numeric_literals::replace_float_literals; |
| 82 use serde::{Deserialize, Serialize}; |
93 use serde::{Deserialize, Serialize}; |
| 83 |
|
| 84 use alg_tools::euclidean::Euclidean; |
|
| 85 use alg_tools::instance::Instance; |
|
| 86 use alg_tools::iterate::AlgIteratorFactory; |
|
| 87 use alg_tools::linops::{Mapping, GEMV}; |
|
| 88 use alg_tools::mapping::RealMapping; |
|
| 89 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
| 90 |
|
| 91 use crate::dataterm::{calculate_residual, DataTerm, L2Squared}; |
|
| 92 use crate::forward_model::{AdjointProductBoundedBy, ForwardModel}; |
|
| 93 use crate::measures::merging::SpikeMerging; |
|
| 94 use crate::measures::{DiscreteMeasure, RNDM}; |
|
| 95 use crate::plot::{PlotLookup, Plotting, SeqPlotter}; |
|
| 96 pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty}; |
|
| 97 use crate::regularisation::RegTerm; |
|
| 98 use crate::types::*; |
|
| 99 |
94 |
| 100 /// Settings for [`pointsource_fb_reg`]. |
95 /// Settings for [`pointsource_fb_reg`]. |
| 101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
96 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 102 #[serde(default)] |
97 #[serde(default)] |
| 103 pub struct FBConfig<F: Float> { |
98 pub struct FBConfig<F: Float> { |
| 104 /// Step length scaling |
99 /// Step length scaling |
| 105 pub τ0: F, |
100 pub τ0: F, |
| |
101 // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`] |
| |
102 pub σp0: F, |
| 106 /// Generic parameters |
103 /// Generic parameters |
| 107 pub generic: FBGenericConfig<F>, |
104 pub insertion: InsertionConfig<F>, |
| 108 } |
105 } |
| 109 |
106 |
| 110 #[replace_float_literals(F::cast_from(literal))] |
107 #[replace_float_literals(F::cast_from(literal))] |
| 111 impl<F: Float> Default for FBConfig<F> { |
108 impl<F: Float> Default for FBConfig<F> { |
| 112 fn default() -> Self { |
109 fn default() -> Self { |
| 113 FBConfig { |
110 FBConfig { |
| 114 τ0: 0.99, |
111 τ0: 0.99, |
| 115 generic: Default::default(), |
112 σp0: 0.99, |
| |
113 insertion: Default::default(), |
| 116 } |
114 } |
| 117 } |
115 } |
| 118 } |
116 } |
| 119 |
117 |
| 120 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize { |
118 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize { |
| 121 let n_before_prune = μ.len(); |
119 let n_before_prune = μ.len(); |
| 122 μ.prune(); |
120 μ.prune(); |
| 123 debug_assert!(μ.len() <= n_before_prune); |
121 debug_assert!(μ.len() <= n_before_prune); |
| 124 n_before_prune - μ.len() |
122 n_before_prune - μ.len() |
| 125 } |
123 } |
| 126 |
124 |
| 127 #[replace_float_literals(F::cast_from(literal))] |
125 #[replace_float_literals(F::cast_from(literal))] |
| 128 pub(crate) fn postprocess< |
126 pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>( |
| 129 F: Float, |
127 mut μ: RNDM<N, F>, |
| 130 V: Euclidean<F> + Clone, |
128 config: &InsertionConfig<F>, |
| 131 A: GEMV<F, RNDM<F, N>, Codomain = V>, |
129 f: Dat, |
| 132 D: DataTerm<F, V, N>, |
130 ) -> DynResult<RNDM<N, F>> |
| 133 const N: usize, |
|
| 134 >( |
|
| 135 mut μ: RNDM<F, N>, |
|
| 136 config: &FBGenericConfig<F>, |
|
| 137 dataterm: D, |
|
| 138 opA: &A, |
|
| 139 b: &V, |
|
| 140 ) -> RNDM<F, N> |
|
| 141 where |
131 where |
| 142 RNDM<F, N>: SpikeMerging<F>, |
132 RNDM<N, F>: SpikeMerging<F>, |
| 143 for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>, |
133 for<'a> &'a RNDM<N, F>: Instance<RNDM<N, F>>, |
| 144 { |
134 { |
| 145 μ.merge_spikes_fitness( |
135 //μ.merge_spikes_fitness(config.final_merging_method(), |μ̃| f.apply(μ̃), |&v| v); |
| 146 config.final_merging_method(), |
136 μ.merge_spikes_fitness(config.final_merging_method(), f, |&v| v); |
| 147 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
|
| 148 |&v| v, |
|
| 149 ); |
|
| 150 μ.prune(); |
137 μ.prune(); |
| 151 μ |
138 Ok(μ) |
| 152 } |
139 } |
| 153 |
140 |
| 154 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
141 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
| 155 /// |
142 /// |
| 156 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
143 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
| 159 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
146 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
| 160 /// as documented in [`alg_tools::iterate`]. |
147 /// as documented in [`alg_tools::iterate`]. |
| 161 /// |
148 /// |
| 162 /// For details on the mathematical formulation, see the [module level](self) documentation. |
149 /// For details on the mathematical formulation, see the [module level](self) documentation. |
| 163 /// |
150 /// |
| 164 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
|
| 165 /// sums of simple functions usign bisection trees, and the related |
|
| 166 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
|
| 167 /// active at a specific points, and to maximise their sums. Through the implementation of the |
|
| 168 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
|
| 169 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
|
| 170 /// |
|
| 171 /// Returns the final iterate. |
151 /// Returns the final iterate. |
| 172 #[replace_float_literals(F::cast_from(literal))] |
152 #[replace_float_literals(F::cast_from(literal))] |
| 173 pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>( |
153 pub fn pointsource_fb_reg<F, I, Dat, Reg, P, Plot, const N: usize>( |
| 174 opA: &A, |
154 f: &Dat, |
| 175 b: &A::Observable, |
155 reg: &Reg, |
| 176 reg: Reg, |
|
| 177 prox_penalty: &P, |
156 prox_penalty: &P, |
| 178 fbconfig: &FBConfig<F>, |
157 fbconfig: &FBConfig<F>, |
| 179 iterator: I, |
158 iterator: I, |
| 180 mut plotter: SeqPlotter<F, N>, |
159 mut plotter: Plot, |
| 181 ) -> RNDM<F, N> |
160 μ0 : Option<RNDM<N, F>>, |
| |
161 ) -> DynResult<RNDM<N, F>> |
| 182 where |
162 where |
| 183 F: Float + ToNalgebraRealField, |
163 F: Float + ToNalgebraRealField, |
| 184 I: AlgIteratorFactory<IterInfo<F, N>>, |
164 I: AlgIteratorFactory<IterInfo<F>>, |
| 185 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
165 RNDM<N, F>: SpikeMerging<F>, |
| 186 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
166 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, |
| 187 A::PreadjointCodomain: RealMapping<F, N>, |
167 Dat::DerivativeDomain: ClosedMul<F>, |
| 188 PlotLookup: Plotting<N>, |
168 Reg: RegTerm<Loc<N, F>, F>, |
| 189 RNDM<F, N>: SpikeMerging<F>, |
169 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| 190 Reg: RegTerm<F, N>, |
170 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
| 191 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
| 192 { |
171 { |
| 193 // Set up parameters |
172 // Set up parameters |
| 194 let config = &fbconfig.generic; |
173 let config = &fbconfig.insertion; |
| 195 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
174 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; |
| 196 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
175 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 197 // by τ compared to the conditional gradient approach. |
176 // by τ compared to the conditional gradient approach. |
| 198 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
177 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 199 let mut ε = tolerance.initial(); |
178 let mut ε = tolerance.initial(); |
| 200 |
179 |
| 201 // Initialise iterates |
180 // Initialise iterates |
| 202 let mut μ = DiscreteMeasure::new(); |
181 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 203 let mut residual = -b; |
|
| 204 |
182 |
| 205 // Statistics |
183 // Statistics |
| 206 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { |
184 let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo { |
| 207 value: residual.norm2_squared_div2() + reg.apply(μ), |
185 value: f.apply(μ) + reg.apply(μ), |
| 208 n_spikes: μ.len(), |
186 n_spikes: μ.len(), |
| 209 ε, |
187 ε, |
| 210 //postprocessing: config.postprocessing.then(|| μ.clone()), |
188 //postprocessing: config.postprocessing.then(|| μ.clone()), |
| 211 ..stats |
189 ..stats |
| 212 }; |
190 }; |
| 213 let mut stats = IterInfo::new(); |
191 let mut stats = IterInfo::new(); |
| 214 |
192 |
| 215 // Run the algorithm |
193 // Run the algorithm |
| 216 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
194 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 217 // Calculate smooth part of surrogate model. |
195 // Calculate smooth part of surrogate model. |
| 218 let mut τv = opA.preadjoint().apply(residual * τ); |
196 // TODO: optimise τ to be applied to residual. |
| |
197 let mut τv = f.differential(&μ) * τ; |
| 219 |
198 |
| 220 // Save current base point |
199 // Save current base point |
| 221 let μ_base = μ.clone(); |
200 let μ_base = μ.clone(); |
| 222 |
201 |
| 223 // Insert and reweigh |
202 // Insert and reweigh |
| 224 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
203 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 225 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
204 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
| 226 ); |
205 )?; |
| 227 |
206 |
| 228 // Prune and possibly merge spikes |
207 // Prune and possibly merge spikes |
| 229 if config.merge_now(&state) { |
208 if config.merge_now(&state) { |
| 230 stats.merged += prox_penalty.merge_spikes( |
209 stats.merged += prox_penalty.merge_spikes( |
| 231 &mut μ, |
210 &mut μ, |
| 234 None, |
213 None, |
| 235 τ, |
214 τ, |
| 236 ε, |
215 ε, |
| 237 config, |
216 config, |
| 238 ®, |
217 ®, |
| 239 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
218 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), |
| 240 ); |
219 ); |
| 241 } |
220 } |
| 242 |
221 |
| 243 stats.pruned += prune_with_stats(&mut μ); |
222 stats.pruned += prune_with_stats(&mut μ); |
| 244 |
|
| 245 // Update residual |
|
| 246 residual = calculate_residual(&μ, opA, b); |
|
| 247 |
223 |
| 248 let iter = state.iteration(); |
224 let iter = state.iteration(); |
| 249 stats.this_iters += 1; |
225 stats.this_iters += 1; |
| 250 |
226 |
| 251 // Give statistics if needed |
227 // Give statistics if needed |
| 252 state.if_verbose(|| { |
228 state.if_verbose(|| { |
| 253 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
229 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
| 254 full_stats( |
230 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 255 &residual, |
|
| 256 &μ, |
|
| 257 ε, |
|
| 258 std::mem::replace(&mut stats, IterInfo::new()), |
|
| 259 ) |
|
| 260 }); |
231 }); |
| 261 |
232 |
| 262 // Update main tolerance for next iteration |
233 // Update main tolerance for next iteration |
| 263 ε = tolerance.update(ε, iter); |
234 ε = tolerance.update(ε, iter); |
| 264 } |
235 } |
| 265 |
236 |
| 266 postprocess(μ, config, L2Squared, opA, b) |
237 //postprocess(μ_prev, config, f) |
| |
238 postprocess(μ, config, |μ̃| f.apply(μ̃)) |
| 267 } |
239 } |
| 268 |
240 |
| 269 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
241 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
| 270 /// |
242 /// |
| 271 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
243 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
| 274 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
246 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
| 275 /// as documented in [`alg_tools::iterate`]. |
247 /// as documented in [`alg_tools::iterate`]. |
| 276 /// |
248 /// |
| 277 /// For details on the mathematical formulation, see the [module level](self) documentation. |
249 /// For details on the mathematical formulation, see the [module level](self) documentation. |
| 278 /// |
250 /// |
| 279 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
|
| 280 /// sums of simple functions usign bisection trees, and the related |
|
| 281 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
|
| 282 /// active at a specific points, and to maximise their sums. Through the implementation of the |
|
| 283 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
|
| 284 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
|
| 285 /// |
|
| 286 /// Returns the final iterate. |
251 /// Returns the final iterate. |
| 287 #[replace_float_literals(F::cast_from(literal))] |
252 #[replace_float_literals(F::cast_from(literal))] |
| 288 pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>( |
253 pub fn pointsource_fista_reg<F, I, Dat, Reg, P, Plot, const N: usize>( |
| 289 opA: &A, |
254 f: &Dat, |
| 290 b: &A::Observable, |
255 reg: &Reg, |
| 291 reg: Reg, |
|
| 292 prox_penalty: &P, |
256 prox_penalty: &P, |
| 293 fbconfig: &FBConfig<F>, |
257 fbconfig: &FBConfig<F>, |
| 294 iterator: I, |
258 iterator: I, |
| 295 mut plotter: SeqPlotter<F, N>, |
259 mut plotter: Plot, |
| 296 ) -> RNDM<F, N> |
260 μ0: Option<RNDM<N, F>> |
| |
261 ) -> DynResult<RNDM<N, F>> |
| 297 where |
262 where |
| 298 F: Float + ToNalgebraRealField, |
263 F: Float + ToNalgebraRealField, |
| 299 I: AlgIteratorFactory<IterInfo<F, N>>, |
264 I: AlgIteratorFactory<IterInfo<F>>, |
| 300 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
265 RNDM<N, F>: SpikeMerging<F>, |
| 301 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
266 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, |
| 302 A::PreadjointCodomain: RealMapping<F, N>, |
267 Dat::DerivativeDomain: ClosedMul<F>, |
| 303 PlotLookup: Plotting<N>, |
268 Reg: RegTerm<Loc<N, F>, F>, |
| 304 RNDM<F, N>: SpikeMerging<F>, |
269 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| 305 Reg: RegTerm<F, N>, |
270 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
| 306 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
| 307 { |
271 { |
| 308 // Set up parameters |
272 // Set up parameters |
| 309 let config = &fbconfig.generic; |
273 let config = &fbconfig.insertion; |
| 310 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
274 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; |
| 311 let mut λ = 1.0; |
275 let mut λ = 1.0; |
| 312 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
276 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 313 // by τ compared to the conditional gradient approach. |
277 // by τ compared to the conditional gradient approach. |
| 314 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
278 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 315 let mut ε = tolerance.initial(); |
279 let mut ε = tolerance.initial(); |
| 316 |
280 |
| 317 // Initialise iterates |
281 // Initialise iterates |
| 318 let mut μ = DiscreteMeasure::new(); |
282 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 319 let mut μ_prev = DiscreteMeasure::new(); |
283 let mut μ_prev = μ.clone(); |
| 320 let mut residual = -b; |
|
| 321 let mut warned_merging = false; |
284 let mut warned_merging = false; |
| 322 |
285 |
| 323 // Statistics |
286 // Statistics |
| 324 let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo { |
287 let full_stats = |ν: &RNDM<N, F>, ε, stats| IterInfo { |
| 325 value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
288 value: f.apply(ν) + reg.apply(ν), |
| 326 n_spikes: ν.len(), |
289 n_spikes: ν.len(), |
| 327 ε, |
290 ε, |
| 328 // postprocessing: config.postprocessing.then(|| ν.clone()), |
291 // postprocessing: config.postprocessing.then(|| ν.clone()), |
| 329 ..stats |
292 ..stats |
| 330 }; |
293 }; |
| 331 let mut stats = IterInfo::new(); |
294 let mut stats = IterInfo::new(); |
| 332 |
295 |
| 333 // Run the algorithm |
296 // Run the algorithm |
| 334 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
297 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 335 // Calculate smooth part of surrogate model. |
298 // Calculate smooth part of surrogate model. |
| 336 let mut τv = opA.preadjoint().apply(residual * τ); |
299 let mut τv = f.differential(&μ) * τ; |
| 337 |
300 |
| 338 // Save current base point |
301 // Save current base point |
| 339 let μ_base = μ.clone(); |
302 let μ_base = μ.clone(); |
| 340 |
303 |
| 341 // Insert new spikes and reweigh |
304 // Insert new spikes and reweigh |
| 342 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
305 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 343 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
306 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
| 344 ); |
307 )?; |
| 345 |
308 |
| 346 // (Do not) merge spikes. |
309 // (Do not) merge spikes. |
| 347 if config.merge_now(&state) && !warned_merging { |
310 if config.merge_now(&state) && !warned_merging { |
| 348 let err = format!("Merging not supported for μFISTA"); |
311 let err = format!("Merging not supported for μFISTA"); |
| 349 println!("{}", err.red()); |
312 println!("{}", err.red()); |