| 72 \end{aligned} |
72 \end{aligned} |
| 73 $$ |
73 $$ |
| 74 </p> |
74 </p> |
| 75 |
75 |
| 76 We solve this with either SSN or FB as determined by |
76 We solve this with either SSN or FB as determined by |
| 77 [`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. |
77 [`crate::subproblem::InnerSettings`] in [`InsertionConfig::inner`]. |
| 78 */ |
78 */ |
| 79 |
79 |
| |
80 use crate::measures::merging::SpikeMerging; |
| |
81 use crate::measures::{DiscreteMeasure, RNDM}; |
| |
82 use crate::plot::Plotter; |
| |
83 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound}; |
| |
84 use crate::regularisation::RegTerm; |
| |
85 use crate::types::*; |
| |
86 use alg_tools::error::DynResult; |
| |
87 use alg_tools::instance::Instance; |
| |
88 use alg_tools::iterate::AlgIteratorFactory; |
| |
89 use alg_tools::mapping::DifferentiableMapping; |
| |
90 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 80 use colored::Colorize; |
91 use colored::Colorize; |
| 81 use numeric_literals::replace_float_literals; |
92 use numeric_literals::replace_float_literals; |
| 82 use serde::{Deserialize, Serialize}; |
93 use serde::{Deserialize, Serialize}; |
| 83 |
|
| 84 use alg_tools::euclidean::Euclidean; |
|
| 85 use alg_tools::instance::Instance; |
|
| 86 use alg_tools::iterate::AlgIteratorFactory; |
|
| 87 use alg_tools::linops::{Mapping, GEMV}; |
|
| 88 use alg_tools::mapping::RealMapping; |
|
| 89 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
| 90 |
|
| 91 use crate::dataterm::{calculate_residual, DataTerm, L2Squared}; |
|
| 92 use crate::forward_model::{AdjointProductBoundedBy, ForwardModel}; |
|
| 93 use crate::measures::merging::SpikeMerging; |
|
| 94 use crate::measures::{DiscreteMeasure, RNDM}; |
|
| 95 use crate::plot::{PlotLookup, Plotting, SeqPlotter}; |
|
| 96 pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty}; |
|
| 97 use crate::regularisation::RegTerm; |
|
| 98 use crate::types::*; |
|
| 99 |
94 |
| 100 /// Settings for [`pointsource_fb_reg`]. |
95 /// Settings for [`pointsource_fb_reg`]. |
| 101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
96 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 102 #[serde(default)] |
97 #[serde(default)] |
| 103 pub struct FBConfig<F: Float> { |
98 pub struct FBConfig<F: Float> { |
| 104 /// Step length scaling |
99 /// Step length scaling |
| 105 pub τ0: F, |
100 pub τ0: F, |
| |
101 // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`] |
| |
102 pub σp0: F, |
| 106 /// Generic parameters |
103 /// Generic parameters |
| 107 pub generic: FBGenericConfig<F>, |
104 pub insertion: InsertionConfig<F>, |
| 108 } |
105 } |
| 109 |
106 |
| 110 #[replace_float_literals(F::cast_from(literal))] |
107 #[replace_float_literals(F::cast_from(literal))] |
| 111 impl<F: Float> Default for FBConfig<F> { |
108 impl<F: Float> Default for FBConfig<F> { |
| 112 fn default() -> Self { |
109 fn default() -> Self { |
| 113 FBConfig { |
110 FBConfig { τ0: 0.99, σp0: 0.99, insertion: Default::default() } |
| 114 τ0: 0.99, |
|
| 115 generic: Default::default(), |
|
| 116 } |
|
| 117 } |
111 } |
| 118 } |
112 } |
| 119 |
113 |
| 120 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize { |
114 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize { |
| 121 let n_before_prune = μ.len(); |
115 let n_before_prune = μ.len(); |
| 122 μ.prune(); |
116 μ.prune(); |
| 123 debug_assert!(μ.len() <= n_before_prune); |
117 debug_assert!(μ.len() <= n_before_prune); |
| 124 n_before_prune - μ.len() |
118 n_before_prune - μ.len() |
| 125 } |
119 } |
| 126 |
120 |
| 127 #[replace_float_literals(F::cast_from(literal))] |
121 #[replace_float_literals(F::cast_from(literal))] |
| 128 pub(crate) fn postprocess< |
122 pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>( |
| 129 F: Float, |
123 mut μ: RNDM<N, F>, |
| 130 V: Euclidean<F> + Clone, |
124 config: &InsertionConfig<F>, |
| 131 A: GEMV<F, RNDM<F, N>, Codomain = V>, |
125 f: Dat, |
| 132 D: DataTerm<F, V, N>, |
126 ) -> DynResult<RNDM<N, F>> |
| 133 const N: usize, |
|
| 134 >( |
|
| 135 mut μ: RNDM<F, N>, |
|
| 136 config: &FBGenericConfig<F>, |
|
| 137 dataterm: D, |
|
| 138 opA: &A, |
|
| 139 b: &V, |
|
| 140 ) -> RNDM<F, N> |
|
| 141 where |
127 where |
| 142 RNDM<F, N>: SpikeMerging<F>, |
128 RNDM<N, F>: SpikeMerging<F>, |
| 143 for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>, |
129 for<'a> &'a RNDM<N, F>: Instance<RNDM<N, F>>, |
| 144 { |
130 { |
| 145 μ.merge_spikes_fitness( |
131 //μ.merge_spikes_fitness(config.final_merging_method(), |μ̃| f.apply(μ̃), |&v| v); |
| 146 config.final_merging_method(), |
132 μ.merge_spikes_fitness(config.final_merging_method(), f, |&v| v); |
| 147 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
|
| 148 |&v| v, |
|
| 149 ); |
|
| 150 μ.prune(); |
133 μ.prune(); |
| 151 μ |
134 Ok(μ) |
| 152 } |
135 } |
| 153 |
136 |
| 154 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
137 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
| 155 /// |
138 /// |
| 156 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
139 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
| 159 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
142 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
| 160 /// as documented in [`alg_tools::iterate`]. |
143 /// as documented in [`alg_tools::iterate`]. |
| 161 /// |
144 /// |
| 162 /// For details on the mathematical formulation, see the [module level](self) documentation. |
145 /// For details on the mathematical formulation, see the [module level](self) documentation. |
| 163 /// |
146 /// |
| 164 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
|
| 165 /// sums of simple functions usign bisection trees, and the related |
|
| 166 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
|
| 167 /// active at a specific points, and to maximise their sums. Through the implementation of the |
|
| 168 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
|
| 169 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
|
| 170 /// |
|
| 171 /// Returns the final iterate. |
147 /// Returns the final iterate. |
| 172 #[replace_float_literals(F::cast_from(literal))] |
148 #[replace_float_literals(F::cast_from(literal))] |
| 173 pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>( |
149 pub fn pointsource_fb_reg<F, I, Dat, Reg, P, Plot, const N: usize>( |
| 174 opA: &A, |
150 f: &Dat, |
| 175 b: &A::Observable, |
151 reg: &Reg, |
| 176 reg: Reg, |
|
| 177 prox_penalty: &P, |
152 prox_penalty: &P, |
| 178 fbconfig: &FBConfig<F>, |
153 fbconfig: &FBConfig<F>, |
| 179 iterator: I, |
154 iterator: I, |
| 180 mut plotter: SeqPlotter<F, N>, |
155 mut plotter: Plot, |
| 181 ) -> RNDM<F, N> |
156 μ0: Option<RNDM<N, F>>, |
| |
157 ) -> DynResult<RNDM<N, F>> |
| 182 where |
158 where |
| 183 F: Float + ToNalgebraRealField, |
159 F: Float + ToNalgebraRealField, |
| 184 I: AlgIteratorFactory<IterInfo<F, N>>, |
160 I: AlgIteratorFactory<IterInfo<F>>, |
| 185 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
161 RNDM<N, F>: SpikeMerging<F>, |
| 186 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
162 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, |
| 187 A::PreadjointCodomain: RealMapping<F, N>, |
163 Dat::DerivativeDomain: ClosedMul<F>, |
| 188 PlotLookup: Plotting<N>, |
164 Reg: RegTerm<Loc<N, F>, F>, |
| 189 RNDM<F, N>: SpikeMerging<F>, |
165 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| 190 Reg: RegTerm<F, N>, |
166 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
| 191 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
| 192 { |
167 { |
| 193 // Set up parameters |
168 // Set up parameters |
| 194 let config = &fbconfig.generic; |
169 let config = &fbconfig.insertion; |
| 195 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
170 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; |
| 196 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
171 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 197 // by τ compared to the conditional gradient approach. |
172 // by τ compared to the conditional gradient approach. |
| 198 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
173 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 199 let mut ε = tolerance.initial(); |
174 let mut ε = tolerance.initial(); |
| 200 |
175 |
| 201 // Initialise iterates |
176 // Initialise iterates |
| 202 let mut μ = DiscreteMeasure::new(); |
177 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 203 let mut residual = -b; |
|
| 204 |
178 |
| 205 // Statistics |
179 // Statistics |
| 206 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { |
180 let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo { |
| 207 value: residual.norm2_squared_div2() + reg.apply(μ), |
181 value: f.apply(μ) + reg.apply(μ), |
| 208 n_spikes: μ.len(), |
182 n_spikes: μ.len(), |
| 209 ε, |
183 ε, |
| 210 //postprocessing: config.postprocessing.then(|| μ.clone()), |
184 //postprocessing: config.postprocessing.then(|| μ.clone()), |
| 211 ..stats |
185 ..stats |
| 212 }; |
186 }; |
| 213 let mut stats = IterInfo::new(); |
187 let mut stats = IterInfo::new(); |
| 214 |
188 |
| 215 // Run the algorithm |
189 // Run the algorithm |
| 216 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
190 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 217 // Calculate smooth part of surrogate model. |
191 // Calculate smooth part of surrogate model. |
| 218 let mut τv = opA.preadjoint().apply(residual * τ); |
192 // TODO: optimise τ to be applied to residual. |
| 219 |
193 let mut τv = f.differential(&μ) * τ; |
| 220 // Save current base point |
194 |
| 221 let μ_base = μ.clone(); |
195 // Save current base point for merge |
| |
196 let μ_base_len = μ.len(); |
| |
197 let maybe_μ_base = config.merge_now(&state).then(|| μ.clone()); |
| 222 |
198 |
| 223 // Insert and reweigh |
199 // Insert and reweigh |
| 224 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
200 let (maybe_d, _within_tolerances) = prox_penalty |
| 225 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
201 .insert_and_reweigh(&mut μ, &mut τv, τ, ε, config, ®, &state, &mut stats)?; |
| 226 ); |
202 |
| |
203 stats.inserted += μ.len() - μ_base_len; |
| 227 |
204 |
| 228 // Prune and possibly merge spikes |
205 // Prune and possibly merge spikes |
| 229 if config.merge_now(&state) { |
206 if let Some(μ_base) = maybe_μ_base { |
| 230 stats.merged += prox_penalty.merge_spikes( |
207 stats.merged += prox_penalty.merge_spikes( |
| 231 &mut μ, |
208 &mut μ, |
| 232 &mut τv, |
209 &mut τv, |
| 233 &μ_base, |
210 &μ_base, |
| 234 None, |
|
| 235 τ, |
211 τ, |
| 236 ε, |
212 ε, |
| 237 config, |
213 config, |
| 238 ®, |
214 ®, |
| 239 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
215 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), |
| 240 ); |
216 ); |
| 241 } |
217 } |
| 242 |
218 |
| 243 stats.pruned += prune_with_stats(&mut μ); |
219 stats.pruned += prune_with_stats(&mut μ); |
| 244 |
|
| 245 // Update residual |
|
| 246 residual = calculate_residual(&μ, opA, b); |
|
| 247 |
220 |
| 248 let iter = state.iteration(); |
221 let iter = state.iteration(); |
| 249 stats.this_iters += 1; |
222 stats.this_iters += 1; |
| 250 |
223 |
| 251 // Give statistics if needed |
224 // Give statistics if needed |
| 252 state.if_verbose(|| { |
225 state.if_verbose(|| { |
| 253 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
226 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
| 254 full_stats( |
227 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 255 &residual, |
|
| 256 &μ, |
|
| 257 ε, |
|
| 258 std::mem::replace(&mut stats, IterInfo::new()), |
|
| 259 ) |
|
| 260 }); |
228 }); |
| 261 |
229 |
| 262 // Update main tolerance for next iteration |
230 // Update main tolerance for next iteration |
| 263 ε = tolerance.update(ε, iter); |
231 ε = tolerance.update(ε, iter); |
| 264 } |
232 } |
| 265 |
233 |
| 266 postprocess(μ, config, L2Squared, opA, b) |
234 //postprocess(μ_prev, config, f) |
| |
235 postprocess(μ, config, |μ̃| f.apply(μ̃)) |
| 267 } |
236 } |
| 268 |
237 |
| 269 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
238 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
| 270 /// |
239 /// |
| 271 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
240 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
| 274 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
243 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
| 275 /// as documented in [`alg_tools::iterate`]. |
244 /// as documented in [`alg_tools::iterate`]. |
| 276 /// |
245 /// |
| 277 /// For details on the mathematical formulation, see the [module level](self) documentation. |
246 /// For details on the mathematical formulation, see the [module level](self) documentation. |
| 278 /// |
247 /// |
| 279 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
|
| 280 /// sums of simple functions usign bisection trees, and the related |
|
| 281 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
|
| 282 /// active at a specific points, and to maximise their sums. Through the implementation of the |
|
| 283 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
|
| 284 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
|
| 285 /// |
|
| 286 /// Returns the final iterate. |
248 /// Returns the final iterate. |
| 287 #[replace_float_literals(F::cast_from(literal))] |
249 #[replace_float_literals(F::cast_from(literal))] |
| 288 pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>( |
250 pub fn pointsource_fista_reg<F, I, Dat, Reg, P, Plot, const N: usize>( |
| 289 opA: &A, |
251 f: &Dat, |
| 290 b: &A::Observable, |
252 reg: &Reg, |
| 291 reg: Reg, |
|
| 292 prox_penalty: &P, |
253 prox_penalty: &P, |
| 293 fbconfig: &FBConfig<F>, |
254 fbconfig: &FBConfig<F>, |
| 294 iterator: I, |
255 iterator: I, |
| 295 mut plotter: SeqPlotter<F, N>, |
256 mut plotter: Plot, |
| 296 ) -> RNDM<F, N> |
257 μ0: Option<RNDM<N, F>>, |
| |
258 ) -> DynResult<RNDM<N, F>> |
| 297 where |
259 where |
| 298 F: Float + ToNalgebraRealField, |
260 F: Float + ToNalgebraRealField, |
| 299 I: AlgIteratorFactory<IterInfo<F, N>>, |
261 I: AlgIteratorFactory<IterInfo<F>>, |
| 300 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, |
262 RNDM<N, F>: SpikeMerging<F>, |
| 301 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, |
263 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, |
| 302 A::PreadjointCodomain: RealMapping<F, N>, |
264 Dat::DerivativeDomain: ClosedMul<F>, |
| 303 PlotLookup: Plotting<N>, |
265 Reg: RegTerm<Loc<N, F>, F>, |
| 304 RNDM<F, N>: SpikeMerging<F>, |
266 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| 305 Reg: RegTerm<F, N>, |
267 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
| 306 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
|
| 307 { |
268 { |
| 308 // Set up parameters |
269 // Set up parameters |
| 309 let config = &fbconfig.generic; |
270 let config = &fbconfig.insertion; |
| 310 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
271 let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; |
| 311 let mut λ = 1.0; |
272 let mut λ = 1.0; |
| 312 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
273 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 313 // by τ compared to the conditional gradient approach. |
274 // by τ compared to the conditional gradient approach. |
| 314 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
275 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 315 let mut ε = tolerance.initial(); |
276 let mut ε = tolerance.initial(); |
| 316 |
277 |
| 317 // Initialise iterates |
278 // Initialise iterates |
| 318 let mut μ = DiscreteMeasure::new(); |
279 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 319 let mut μ_prev = DiscreteMeasure::new(); |
280 let mut μ_prev = μ.clone(); |
| 320 let mut residual = -b; |
|
| 321 let mut warned_merging = false; |
281 let mut warned_merging = false; |
| 322 |
282 |
| 323 // Statistics |
283 // Statistics |
| 324 let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo { |
284 let full_stats = |ν: &RNDM<N, F>, ε, stats| IterInfo { |
| 325 value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
285 value: f.apply(ν) + reg.apply(ν), |
| 326 n_spikes: ν.len(), |
286 n_spikes: ν.len(), |
| 327 ε, |
287 ε, |
| 328 // postprocessing: config.postprocessing.then(|| ν.clone()), |
288 // postprocessing: config.postprocessing.then(|| ν.clone()), |
| 329 ..stats |
289 ..stats |
| 330 }; |
290 }; |
| 331 let mut stats = IterInfo::new(); |
291 let mut stats = IterInfo::new(); |
| 332 |
292 |
| 333 // Run the algorithm |
293 // Run the algorithm |
| 334 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
294 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 335 // Calculate smooth part of surrogate model. |
295 // Calculate smooth part of surrogate model. |
| 336 let mut τv = opA.preadjoint().apply(residual * τ); |
296 let mut τv = f.differential(&μ) * τ; |
| 337 |
297 |
| 338 // Save current base point |
298 let μ_base_len = μ.len(); |
| 339 let μ_base = μ.clone(); |
|
| 340 |
299 |
| 341 // Insert new spikes and reweigh |
300 // Insert new spikes and reweigh |
| 342 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
301 let (maybe_d, _within_tolerances) = prox_penalty |
| 343 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
302 .insert_and_reweigh(&mut μ, &mut τv, τ, ε, config, ®, &state, &mut stats)?; |
| 344 ); |
303 |
| |
304 stats.inserted += μ.len() - μ_base_len; |
| 345 |
305 |
| 346 // (Do not) merge spikes. |
306 // (Do not) merge spikes. |
| 347 if config.merge_now(&state) && !warned_merging { |
307 if config.merge_now(&state) && !warned_merging { |
| 348 let err = format!("Merging not supported for μFISTA"); |
308 let err = format!("Merging not supported for μFISTA"); |
| 349 println!("{}", err.red()); |
309 println!("{}", err.red()); |