    }
}
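
/// Insert a new spike into `μ` if needed and re-optimise the weights of the existing spikes.
///
/// The `'i_and_w` loop below makes at most two passes: the first pass optimises the weights
/// of the current spikes by solving a finite-dimensional subproblem and, if the regulariser
/// still reports a tolerance violation at some point ξ, inserts a zero-weight spike there
/// (into both `μ` and `μ_base`, to keep the two measures aligned); the second pass then
/// re-optimises the weights with the new spike included.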
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn insert_and_reweigh<
    'a, F, GA, BTA, S, Reg, I, const N : usize
>(
    μ : &mut RNDM<F, N>,
    τv : &mut BTFN<F, GA, BTA, N>,
    μ_base : &mut RNDM<F, N>,
    //_ν_delta: Option<&RNDM<F, N>>,
    τ : F,
    ε : F,
    config : &FBGenericConfig<F>,
    reg : &Reg,
    _state : &AlgIteratorIteration<I>,
    stats : &mut IterInfo<F, N>,
)
where F : Float + ToNalgebraRealField,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      RNDM<F, N> : SpikeMerging<F>,
      Reg : RegTerm<F, N>,
      I : AlgIterator {

    'i_and_w: for i in 0..=1 {
        // Optimise weights
        if μ.len() > 0 {
            // Form the finite-dimensional subproblem. The subproblem's references to the
            // original μ^k from the beginning of the iteration are all contained in the
            // immutable c and g.
            // TODO: observe the negation of -τv after the switch from minus_τv: the
            // finite-dimensional problems have not yet been updated to the sign change.
            let g̃ = DVector::from_iterator(μ.len(),
                                           μ.iter_locations()
                                            .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ))));
            let mut x = μ.masses_dvector();
            let y = μ_base.masses_dvector();
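            // Here g̃ collects the values -τv(ζ) at the current spike locations, x is the
            // weight vector to be optimised (initialised to the current masses of μ), and
            // y holds the masses of the base point μ_base.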

            // Solve finite-dimensional subproblem.
            stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config);

            // Update the masses of μ based on the solution of the finite-dimensional subproblem.
            μ.set_masses_dvector(&x);
        }

        if i > 0 {
            // Simple debugging test to see if more inserts would be needed. Doesn't seem so.
            //let n = μ.dist_matching(μ_base);
            //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n));
            break 'i_and_w
        }

        // Calculate ‖μ - μ_base‖_ℳ
        let n = μ.dist_matching(μ_base);

        // Find a spike to insert, if needed.
        // This only checks the overall tolerances, not the tolerances on the support of
        // μ-μ_base or μ, which are supposed to have been guaranteed by the finite-dimensional
        // weight optimisation.
        match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) {
            None => { break 'i_and_w },
            Some((ξ, _v_ξ, _in_bounds)) => {
                // The weight is determined by the finite-dimensional optimisation above,
                // so the new spike is inserted with zero weight.
                *μ += DeltaMeasure { x : ξ, α : 0.0 };
                *μ_base += DeltaMeasure { x : ξ, α : 0.0 };
                stats.inserted += 1;
            }
        };
    }
}

/// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting.
///
/// The settings in `fbconfig` have their [respective documentation][RadonFBConfig]. `opA` is the
/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
/// Finally, `iterator` controls the outer loop verbosity and iteration count,
/// as documented in [`alg_tools::iterate`].
///
/// For details on the mathematical formulation, see the [module level](self) documentation.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_radon_fb_reg<
    'a, F, I, A, GA, BTA, S, Reg, const N : usize
>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    fbconfig : &RadonFBConfig<F>,
    iterator : I,
    mut _plotter : SeqPlotter<F, N>,
) -> RNDM<F, N>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      RNDM<F, N> : SpikeMerging<F>,
      Reg : RegTerm<F, N> {

    // Set up parameters
    let config = &fbconfig.insertion;
    // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ), ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ
    // holds for all ν, μ. Since for F(μ) = (1/2)‖Aμ-b‖₂² the left hand side equals
    // (1/2)‖A(ν-μ)‖₂², this amounts to requiring ‖Aμ‖₂² ≤ L‖μ‖²_ℳ for all μ.
    // Thus `opnorm_bound(Radon, L2)` gives the square root of L.
    let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2);
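    // For example, if the operator norm bound ‖A‖_{ℳ→L²} is 2, then L = 4 and, with
    // τ0 = 0.5, the step length becomes τ = 0.5/4 = 0.125 (numbers purely illustrative).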
    // We multiply the tolerance by τ for FB, since our tolerance-dependent subproblems are
    // scaled by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();
    let mut residual = -b;

    // Statistics
    let full_stats = |residual : &A::Observable,
                      μ : &RNDM<F, N>,
                      ε, stats| IterInfo {
        value : residual.norm2_squared_div2() + reg.apply(μ),
        n_spikes : μ.len(),
        ε,
        // postprocessing: config.postprocessing.then(|| μ.clone()),
        .. stats
    };
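    // `full_stats` assembles the per-iteration report: the objective value
    // (1/2)‖Aμ-b‖₂² + G(μ) (data fit plus regularisation term), the spike count, and the
    // current tolerance ε; the remaining counters are carried over from `stats`.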
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
        // Calculate smooth part of surrogate model.
        let mut τv = opA.preadjoint().apply(residual * τ);
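        // Here τv = τ A^*(Aμ - b): the preadjoint of the forward operator applied to the
        // scaled residual, i.e. τ times the gradient of the smooth data term at the current μ.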

        // Save current base point
        let mut μ_base = μ.clone();

        // Insert and reweigh
        insert_and_reweigh(
            &mut μ, &mut τv, &mut μ_base, //None,
            τ, ε,
            config, &reg, &state, &mut stats
        );

        // Prune and possibly merge spikes
        assert!(μ_base.len() <= μ.len());
        if config.merge_now(&state) {
            stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
                // Important: μ_candidate's new points come afterwards and do not conflict
                // with μ_base.
                // TODO: could simplify to requiring μ_base instead of μ_radon, but that may
                // complicate things with the sliding base's extra points, which need to come
                // after μ_candidate's extra points.
                // TODO: doesn't seem to work; maybe μ_base needs to be merged as well?
                // Although that doesn't seem to make sense.
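                // μ_radon below is the location-matched difference μ_candidate - μ_base;
                // the candidate is accepted only if the regulariser's Radon-norm-squared
                // merge check passes for it.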
                let μ_radon = μ_candidate.sub_matching(&μ_base);
                reg.verify_merge_candidate_radonsq(&mut τv, μ_candidate, τ, ε, &config, &μ_radon)
                //let n = μ_candidate.dist_matching(μ_base);
                //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()
            });
        }
        stats.pruned += prune_with_stats(&mut μ);

        // Update residual
        residual = calculate_residual(&μ, opA, b);

        let iter = state.iteration();
        stats.this_iters += 1;

        // Give statistics if needed
        state.if_verbose(|| {
            full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        // Update main tolerance for next iteration
        ε = tolerance.update(ε, iter);
    }

    postprocess(μ, config, L2Squared, opA, b)
}

/// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting.
///
/// The settings in `fbconfig` have their [respective documentation][RadonFBConfig]. `opA` is the
/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
/// Finally, `iterator` controls the outer loop verbosity and iteration count,
/// as documented in [`alg_tools::iterate`].
///
/// For details on the mathematical formulation, see the [module level](self) documentation.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_radon_fista_reg<
    'a, F, I, A, GA, BTA, S, Reg, const N : usize
>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    fbconfig : &RadonFBConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> RNDM<F, N>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      RNDM<F, N> : SpikeMerging<F>,
      Reg : RegTerm<F, N> {

    // Set up parameters
    let config = &fbconfig.insertion;
    // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ), ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ
    // holds for all ν, μ. Since for F(μ) = (1/2)‖Aμ-b‖₂² the left hand side equals
    // (1/2)‖A(ν-μ)‖₂², this amounts to requiring ‖Aμ‖₂² ≤ L‖μ‖²_ℳ for all μ.
    // Thus `opnorm_bound(Radon, L2)` gives the square root of L.
    let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2);
    let mut λ = 1.0;
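    // λ is the inertial (momentum) parameter of the μFISTA updates.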
    // We multiply the tolerance by τ for FB, since our tolerance-dependent subproblems are
    // scaled by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();
    let mut μ_prev = DiscreteMeasure::new();
    let mut residual = -b;
    let mut warned_merging = false;

    // Statistics
    let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo {
        value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
        n_spikes : ν.len(),
        ε,
        // postprocessing: config.postprocessing.then(|| ν.clone()),
        .. stats
    };
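    // Unlike in the non-inertial variant above, the objective value is recomputed from ν via
    // `calculate_fit_op`, since statistics are reported for μ_prev, whose residual is not
    // kept around.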
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
        // Calculate smooth part of surrogate model.
        let mut τv = opA.preadjoint().apply(residual * τ);

        // Save current base point
        let mut μ_base = μ.clone();

        // Insert new spikes and reweigh
        insert_and_reweigh(
            &mut μ, &mut τv, &mut μ_base, //None,
            τ, ε,
            config, &reg, &state, &mut stats
        );

        // (Do not) merge spikes: merging is not currently supported by this μFISTA variant,
        // so only warn (once) if a merging method was nevertheless configured.
        if config.merge_now(&state) {
            match config.merging {
                SpikeMergingMethod::None => { },
                _ => if !warned_merging {
                    let err = format!("Merging not supported for μFISTA");
                    println!("{}", err.red());
                    warned_merging = true;
                }
            }
        }

        // Update the inertial parameters
        let λ_prev = λ;
        λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
        let θ = λ / λ_prev - λ;

        // Perform the inertial update on μ: this computes μ ← (1 + θ)μ - θμ_prev, pruning
        // spikes where both μ and μ_prev have zero weight, and stores a copy of the pruned
        // (un-extrapolated) μ in μ_prev.
        let n_before_prune = μ.len();
        μ.pruning_sub(1.0 + θ, θ, &mut μ_prev);
        debug_assert!(μ.len() <= n_before_prune);
        stats.pruned += n_before_prune - μ.len();

        // Update residual
        residual = calculate_residual(&μ, opA, b);

        let iter = state.iteration();
        stats.this_iters += 1;

        // Give statistics if needed
        state.if_verbose(|| {
            plotter.plot_spikes(iter, Option::<&S>::None, Some(&τv), &μ_prev);
            full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        // Update main tolerance for next iteration
        ε = tolerance.update(ε, iter);
    }
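
    // Note that the returned measure is μ_prev, the last un-extrapolated iterate; μ itself
    // carries the inertial extrapolation used to build the surrogate model.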
    postprocess(μ_prev, config, L2Squared, opA, b)
}