#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct SlidingFBConfig<F : Float> {
    /// Step length scaling
    pub τ0 : F,
    /// Transport smoothness assumption
    pub ℓ0 : F,
    /// Inverse of the scaling factor $θ$ of the 2-norm-squared transport cost.
    /// This means that $τθ$ is the step length for the transport step.
    pub inverse_transport_scaling : F,
    /// Factor for deciding transport reduction based on smoothness assumption violation
    pub minimum_goodness_factor : F,
    /// Maximum rays to retain in transports from each source.
    pub maximum_rays : usize,
    /// Generic parameters
    pub insertion : FBGenericConfig<F>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for SlidingFBConfig<F> {
    fn default() -> Self {
        SlidingFBConfig {
            τ0 : 0.99,
            ℓ0 : 1.5,
            inverse_transport_scaling : 1.0,
            minimum_goodness_factor : 1.0, // TODO: totally arbitrary choice,
                                           // should be scaled by problem data?
            maximum_rays : 10,
            insertion : Default::default()
        }
    }
}
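
// Illustrative sketch (not part of the original code): one way a caller might override a couple
// of the parameters above while keeping the remaining defaults. The function name is
// hypothetical; it only assumes that `f64` implements this crate's `Float` trait.
#[allow(dead_code)]
fn example_sliding_fb_config() -> SlidingFBConfig<f64> {
    SlidingFBConfig {
        τ0 : 0.95,         // slightly more conservative step length scaling
        maximum_rays : 20, // retain more transport rays per source
        .. Default::default()
    }
}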

/// A transport ray (including various additional computational information).
#[derive(Clone, Debug)]
pub struct Ray<Domain, F : Num> {
    /// The destination of the ray, and the mass. The source is indicated in a [`RaySet`].
    δ : DeltaMeasure<Domain, F>,
    /// Goodness of the data term for the ray: $v(z)-v(y)-⟨∇v(x), z-y⟩ + ℓ‖z-y‖^2$.
    goodness : F,
    /// Goodness of the regularisation term for the ray: $w(z)-w(y)$.
    /// Initially zero until $w$ can be constructed.
    reg_goodness : F,
    /// Indicates that this ray also forms a component in γ^{k+1} with the mass `to_return`.
    to_return : F,
}

/// A set of transport rays with the same source point.
#[derive(Clone, Debug)]
pub struct RaySet<Domain, F : Num> {
    /// Source of every ray in the set
    source : Domain,
    /// Mass of the diagonal ray, with destination the same as the source.
    diagonal : F,
    /// Goodness of the data term for the diagonal ray with $z=x$:
    /// $v(x)-v(y)-⟨∇v(x), x-y⟩ + ℓ‖x-y‖^2$.
    diagonal_goodness : F,
    /// Goodness of the regularisation term for the diagonal ray with $z=x$: $w(x)-w(y)$.
    diagonal_reg_goodness : F,
    /// The non-diagonal rays.
    rays : Vec<Ray<Domain, F>>,
}
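
// Illustrative sketch (hypothetical helper, not used by the algorithm): a fresh ray set for a
// given source with no transports yet. This mirrors the initialisation used further below when
// γ_hat is extended to newly inserted spikes.
#[allow(dead_code)]
fn empty_ray_set<Domain, F : Float>(source : Domain) -> RaySet<Domain, F> {
    RaySet {
        source,
        rays : Vec::new(),
        diagonal : F::ZERO,
        diagonal_goodness : F::ZERO,
        diagonal_reg_goodness : F::ZERO,
    }
}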

#[replace_float_literals(F::cast_from(literal))]
impl<Domain, F : Float> RaySet<Domain, F> {
    fn non_diagonal_mass(&self) -> F {
        self.rays
            .iter()
            .map(|Ray{ δ : DeltaMeasure{ α, .. }, .. }| *α)
            .sum()
    }

    fn total_mass(&self) -> F {
        self.non_diagonal_mass() + self.diagonal
    }

    fn targets<'a>(&'a self)
    -> Map<
        std::slice::Iter<'a, Ray<Domain, F>>,
        fn(&'a Ray<Domain, F>) -> &'a DeltaMeasure<Domain, F>
    > {
        fn get_δ<'b, Domain, F : Float>(Ray{ δ, .. }: &'b Ray<Domain, F>)
        -> &'b DeltaMeasure<Domain, F> {
            δ
        }
        self.rays
            .iter()
            .map(get_δ)
    }

    // fn non_diagonal_goodness(&self) -> F {
    //     self.rays
    //         .iter()
    //         .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| {
    //             α * (goodness + reg_goodness)
    //         })
    //         .sum()
    // }

    // fn total_goodness(&self) -> F {
    //     self.non_diagonal_goodness() + (self.diagonal_goodness + self.diagonal_reg_goodness)
    // }

    fn non_diagonal_badness(&self) -> F {
        self.rays
            .iter()
            .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| {
                0.0.max(- α * (goodness + reg_goodness))
            })
            .sum()
    }

    fn total_badness(&self) -> F {
        self.non_diagonal_badness()
        + 0.0.max(- self.diagonal * (self.diagonal_goodness + self.diagonal_reg_goodness))
    }

    fn total_return(&self) -> F {
        self.rays
            .iter()
            .map(|&Ray{ to_return, .. }| to_return)
            .sum()
    }
}
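
// Minimal numeric sketch of the badness rule used in `non_diagonal_badness` and `total_badness`
// above: a ray contributes only when the signed product α·(goodness + reg_goodness) is negative.
// Hypothetical helper for illustration only; e.g. α = 0.5, goodness = −0.2, reg_goodness = 0.0
// gives a contribution of 0.1.
#[allow(dead_code)]
fn badness_contribution_sketch(α : f64, goodness : f64, reg_goodness : f64) -> f64 {
    0.0f64.max(-α * (goodness + reg_goodness))
}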

#[replace_float_literals(F::cast_from(literal))]
impl<Domain : Clone, F : Num> RaySet<Domain, F> {
    fn return_targets<'a>(&'a self)
    -> Flatten<Map<
        std::slice::Iter<'a, Ray<Domain, F>>,
        fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>>
    >> {
        fn get_return<'b, Domain : Clone, F : Num>(ray: &'b Ray<Domain, F>)
        -> Option<DeltaMeasure<Domain, F>> {
            (ray.to_return != 0.0).then_some(
                DeltaMeasure{x : ray.δ.x.clone(), α : ray.to_return}
            )
        }
        let tmp : Map<
            std::slice::Iter<'a, Ray<Domain, F>>,
            fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>>
        > = self.rays
            .iter()
            .map(get_return);
        tmp.flatten()
    }
}
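
// Illustrative sketch (hypothetical helper): the total mass scheduled for return equals the sum
// of the masses carried by `return_targets`, i.e. it coincides with `total_return` above.
#[allow(dead_code)]
fn total_returned_mass_sketch<Domain : Clone, F : Float>(r : &RaySet<Domain, F>) -> F {
    r.return_targets().map(|δ| δ.α).sum()
}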

/// Iteratively solve the pointsource localisation problem using sliding forward-backward
/// splitting
///
/// The parametrisation is as for [`pointsource_fb_reg`].
/// Inertia is currently not supported.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    op𝒟 : &'a 𝒟,
    sfbconfig : &SlidingFBConfig<F>,
    PlotLookup : Plotting<N>,
    DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
    Reg : SlidingRegTerm<F, N> {

    assert!(sfbconfig.τ0 > 0.0 &&
            sfbconfig.inverse_transport_scaling > 0.0 &&
            sfbconfig.ℓ0 > 0.0);

    // Set up parameters
    let config = &sfbconfig.insertion;
    let op𝒟norm = op𝒟.opnorm_bound();
    let θ = sfbconfig.inverse_transport_scaling;
    let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap()
                         .max(opA.transport_lipschitz_factor(L2Squared) * θ);
    let ℓ = sfbconfig.ℓ0; // TODO: v scaling?
    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ : DiscreteMeasure<Loc<F, N>, F> = DiscreteMeasure::new();
    let mut μ_transported_base = DiscreteMeasure::new();
    let mut γ_hat : Vec<RaySet<Loc<F, N>, F>> = Vec::new(); // γ̂_k and extra info
    let mut residual = -b;
    let mut stats = IterInfo::new();

    // Run the algorithm
    iterator.iterate(|state| {
        // Calculate smooth part of surrogate model.
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us to simply move
        // the residual and replace it below before the end of this closure.
        residual *= -τ;
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let minus_τv = opA.preadjoint().apply(r);

        // Save current base point and shift μ to new positions.
        let μ_base = μ.clone();
        for δ in μ.iter_spikes_mut() {
            δ.x += minus_τv.differential(&δ.x) * θ;
        }
        let mut μ_transported = μ.clone();
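        // Each spike is thus moved by θ·(−τ∇v(x)), i.e. the transport step has effective length
        // τθ, consistent with the documentation of `inverse_transport_scaling` above.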

        assert_eq!(μ.len(), γ_hat.len());

        // Calculate the goodness of λ formed from γ_hat (≈ γ̂_k) and γ^{k+1}, where the latter
        // transports points x from μ_base to points y in μ as shifted above, or “returns”
        // them “home” to z given by the rays in γ_hat. Returning is necessary if the rays
        // are not “good” for the smoothness assumptions, or if γ_hat has more mass than
        // μ_base.
        let mut total_goodness = 0.0; // data term goodness
        let mut total_reg_goodness = 0.0; // regulariser goodness
        let minimum_goodness = - ε * sfbconfig.minimum_goodness_factor;

        for (δ, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) {
            // Calculate data term goodness for all rays.
            let &DeltaMeasure{ x : ref y, α : δ_mass } = δ;
            let x = &r.source;
            let mvy = minus_τv.apply(y);
            let mdvx = minus_τv.differential(x);
            let mut r_total_mass = 0.0; // Total mass of all rays with source r.source.
            let mut bad_mass = 0.0;
            let mut calc_goodness = |goodness : &mut F, reg_goodness : &mut F, α, z : &Loc<F, N>| {
                *reg_goodness = 0.0; // Initial guess
                *goodness = mvy - minus_τv.apply(z) + mdvx.dot(&(z-y))
                            + ℓ * z.dist2_squared(&y);
                total_goodness += *goodness * α;
                r_total_mass += α; // TODO: should this include to_return from staging? (Probably not)
                if *goodness < 0.0 {
                    bad_mass += α;
                }
            };
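            // Numeric sketch of the goodness formula (illustrative values only): if
            // mvy = 0.3, minus_τv(z) = 0.1, mdvx·(z−y) = 0.1 and ℓ‖z−y‖² = 0.25, then
            // goodness = 0.3 − 0.1 + 0.1 + 0.25 = 0.55 ≥ 0, so no bad_mass is accrued;
            // a negative value would instead count the ray's mass towards bad_mass.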
            for ray in r.rays.iter_mut() {
                calc_goodness(&mut ray.goodness, &mut ray.reg_goodness, ray.δ.α, &ray.δ.x);
            }
            calc_goodness(&mut r.diagonal_goodness, &mut r.diagonal_reg_goodness, r.diagonal, x);

            // If the total mass of the ray set is less than that of μ at the same source,
            // a diagonal component needs to be added to be able to (attempt to) transport
            // all mass of μ. In the opposite case, we need to construct γ_{k+1} to ‘return’
            // the extra mass of γ̂_k to the target z. We return mass from the oldest “bad”
            // rays in the set.
            if δ_mass >= r_total_mass {
                r.diagonal += δ_mass - r_total_mass;
            } else {
                let mut reduce_transport = r_total_mass - δ_mass;
                let mut good_needed = (bad_mass - reduce_transport).max(0.0);
                // NOTE: reg_goodness is zero at this point, so it is not used in this code.
                let mut reduce_ray = |goodness, to_return : Option<&mut F>, α : &mut F| {
                    if reduce_transport > 0.0 {
                        let return_amount = if goodness < 0.0 {
                            α.min(reduce_transport)
                        } else {
                            let amount = α.min(good_needed);
                            good_needed -= amount;
                            amount
                        };

                        if return_amount > 0.0 {
                            reduce_transport -= return_amount;
                            // Adjust total goodness by returned amount
                            total_goodness -= goodness * return_amount;
                            to_return.map(|tr| *tr += return_amount);
                            *α -= return_amount;
                            *α > 0.0
                        } else {
                            true
                        }
                    } else {
                        true
                    }
                };
                r.rays.retain_mut(|ray| {
                    reduce_ray(ray.goodness, Some(&mut ray.to_return), &mut ray.δ.α)
                });
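                // Note: `reduce_ray` returns `false` exactly when a ray's remaining mass drops
                // to zero, so the `retain_mut` call above also prunes fully-returned rays.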
                // A bad diagonal is simply reduced without any 'return'.
                // It was, after all, just added to match μ, but there is no need to match it.
                // It's just a heuristic.
                // TODO: Maybe a bad diagonal should be the first to go.
                reduce_ray(r.diagonal_goodness, None, &mut r.diagonal);
            }
        }

        // Solve finite-dimensional subproblem several times until the dual variable for the
        // regularisation term conforms to the assumptions made for the transport above.
        let (d, within_tolerances) = 'adapt_transport: loop {
            // If transport violates goodness requirements, shift it to ‘return’ mass to z,
            // forcing y = z. Based on the badness of each ray set (sum of bad rays' goodness),
            // we proportionally distribute the reductions to each ray set, and within each ray
            // set, prioritise reducing the oldest bad rays' weight.
            let tg = total_goodness + total_reg_goodness;
            let adaptation_needed = minimum_goodness - tg;
            if adaptation_needed > 0.0 {
                let total_badness = γ_hat.iter().map(|r| r.total_badness()).sum();

                let mut return_ray = |goodness : F,
                                      reg_goodness : F,
                                      to_return : Option<&mut F>,
                                      α : &mut F,
                                      left_to_return : &mut F| {
                    let g = goodness + reg_goodness;
                    assert!(*α >= 0.0 && *left_to_return >= 0.0);
                    if *left_to_return > 0.0 && g < 0.0 {
                        let return_amount = (*left_to_return / (-g)).min(*α);
                        *left_to_return -= (-g) * return_amount;
                        total_goodness -= goodness * return_amount;
                        total_reg_goodness -= reg_goodness * return_amount;
                        to_return.map(|tr| *tr += return_amount);
                        *α -= return_amount;
                        *α > 0.0
                    } else {
                        true
                    }
                };

                for r in γ_hat.iter_mut() {
                    let mut left_to_return = adaptation_needed * r.total_badness() / total_badness;
                    if left_to_return > 0.0 {
                        for ray in r.rays.iter_mut() {
                            return_ray(ray.goodness, ray.reg_goodness,
                                       Some(&mut ray.to_return), &mut ray.δ.α, &mut left_to_return);
                        }
                        return_ray(r.diagonal_goodness, r.diagonal_reg_goodness,
                                   None, &mut r.diagonal, &mut left_to_return);
                    }
                }
            }
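            // For instance (illustrative values only), if adaptation_needed = 0.3 and a ray set
            // accounts for half of total_badness, that set's left_to_return starts at 0.15; its
            // bad rays with combined goodness g < 0 then give up min(α, left_to_return/(−g)) of
            // their mass in turn, each reducing left_to_return by (−g) times the amount returned.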

            // Construct μ_k + (π_#^1-π_#^0)γ_{k+1}.
            // This can be broken down into
            //
            //  μ_transported_base = [μ - π_#^0 (γ_shift + γ_return)] + π_#^1 γ_return, and
            //  μ_transported = π_#^1 γ_shift
            //
            // where γ_shift is our “true” γ_{k+1}, and γ_return is the return component.
            // The former can be constructed from δ.x and δ_new.x for δ in μ_base and δ_new in μ
            // (which has already been shifted), and the mass stored in a γ_hat ray's δ measure.
            // The latter can be constructed from γ_hat rays' source and destination with the
            // to_return mass.
            //
            // Note that μ_transported is constructed to have the same spike locations as μ, but
            // to have the same length as μ_base. This loop does not iterate over the spikes of μ
            // (and corresponding transports of γ_hat) that have been newly added in the current
            // 'adapt_transport loop.
            for (δ, δ_transported, r) in izip!(μ_base.iter_spikes(),
                                               μ_transported.iter_spikes_mut(),
                                               γ_hat.iter()) {
                let &DeltaMeasure{ref x, α} = δ;
                debug_assert_eq!(*x, r.source);
                let shifted_mass = r.total_mass();
                let ret_mass = r.total_return();
                // μ - π_#^0 (γ_shift + γ_return)
                μ_transported_base += DeltaMeasure { x : *x, α : α - shifted_mass - ret_mass };
                // π_#^1 γ_return
                μ_transported_base.extend(r.return_targets());
                // π_#^1 γ_shift
                δ_transported.set_mass(shifted_mass);
            }
            // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b)
            let transported_residual = calculate_residual2(&μ_transported,
                                                           &μ_transported_base,
                                                           opA, b);
            let transported_minus_τv = opA.preadjoint()
                                          .apply(transported_residual);

            // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
            let (mut d, within_tolerances) = insert_and_reweigh(
                &mut μ, &transported_minus_τv, &μ_transported, Some(&μ_transported_base),
                op𝒟, op𝒟norm,
                τ, ε,
                config, &reg, state, &mut stats
            );

            // We have d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv; more precisely
            // d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_transported, config));
            // We “essentially” assume that the subdifferential w of the regularisation term
            // satisfies w'(y)=0, so for a “goodness” estimate τ[w(y)-w(z)-w'(y)(z-y)]
            // that incorporates the assumption, we need to calculate τ[w(z) - w(y)] for
            // some w in the subdifferential of the regularisation term, such that
            // -ε ≤ τw - d ≤ ε. This is done by [`RegTerm::goodness`].
            for r in γ_hat.iter_mut() {
                for ray in r.rays.iter_mut() {
                    ray.reg_goodness = reg.goodness(&mut d, &μ, &r.source, &ray.δ.x, τ, ε, config);
                    total_reg_goodness += ray.reg_goodness * ray.δ.α;
                }
            }
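            // Roughly speaking: since −ε ≤ τw − d ≤ ε pointwise, the estimate τ[w(z) − w(y)]
            // computed above lies within 2ε of d(z) − d(y).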

            // If the update of the regularisation term goodness didn't invalidate the minimum
            // goodness requirement, we have found our step. Otherwise we need to keep reducing
            // the transport by repeating the loop.
            if total_goodness + total_reg_goodness >= minimum_goodness {
                break 'adapt_transport (d, within_tolerances)
            }
        };

        // Update γ_hat to new location
        for (δ_new, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) {
            // Prune rays that only had a return component, as the return component becomes
            // a diagonal in γ̂^{k+1}.
            r.rays.retain(|ray| ray.δ.α != 0.0);
            // Otherwise zero out the return component, or stage rays for pruning
            // to keep memory and computational demands reasonable.
            let n_rays = r.rays.len();
            for (ray, ir) in izip!(r.rays.iter_mut(), (0..n_rays).rev()) {
                if ir >= sfbconfig.maximum_rays {
                    // Only keep sfbconfig.maximum_rays - 1 previous rays, staging others for
                    // pruning in next step.
                    ray.to_return = ray.δ.α;
                    ray.δ.α = 0.0;
                } else {
                    ray.to_return = 0.0;
                }
                ray.goodness = 0.0; // TODO: probably not needed
                ray.reg_goodness = 0.0;
            }
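            // For instance, with sfbconfig.maximum_rays = 10 and 12 retained rays, the two oldest
            // rays (paired with ir = 11 and 10 above) are staged for return and pruned next step.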
343 γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); |
505 // Add a new ray for the currently diagonal component |
344 μ = μ_new; |
506 if r.diagonal > 0.0 { |
|
507 r.rays.push(Ray{ |
|
508 δ : DeltaMeasure{x : r.source, α : r.diagonal}, |
|
509 goodness : 0.0, |
|
510 reg_goodness : 0.0, |
|
511 to_return : 0.0, |
|
512 }); |
|
513 // TODO: Maybe this does not need to be done here, and is sufficent to to do where |
|
514 // the goodness is calculated. |
|
515 r.diagonal = 0.0; |
|
516 } |
|
517 r.diagonal_goodness = 0.0; |
|
518 |
|
519 // Shift source |
|
520 r.source = δ_new.x; |
|
521 } |
|
522 // Extend to new spikes |
|
523 γ_hat.extend(μ[γ_hat.len()..].iter().map(|δ_new| { |
|
524 RaySet{ |
|
525 source : δ_new.x, |
|
526 rays : [].into(), |
|
527 diagonal : 0.0, |
|
528 diagonal_goodness : 0.0, |
|
529 diagonal_reg_goodness : 0.0 |
|
530 } |
|
531 })); |
|
532 |
|
533 // Prune spikes with zero weight. This also moves the marginal differences of corresponding |
|
534 // transports from γ_hat to γ_pruned_marginal_diff. |
|
535 // TODO: optimise standard prune with swap_remove. |
|
536 μ_transported_base.clear(); |
|
537 let mut i = 0; |
|
538 assert_eq!(μ.len(), γ_hat.len()); |
|
539 while i < μ.len() { |
|
540 if μ[i].α == F::ZERO { |
|
541 μ.swap_remove(i); |
|
542 let r = γ_hat.swap_remove(i); |
|
543 μ_transported_base.extend(r.targets().cloned()); |
|
544 μ_transported_base -= DeltaMeasure{ α : r.non_diagonal_mass(), x : r.source }; |
|
545 } else { |
|
546 i += 1; |
|
547 } |
|
548 } |
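        // In other words, the marginal difference π_#^1γ − π_#^0γ of a pruned spike's rays is
        // folded into μ_transported_base, so the next iterate still accounts for that transport.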

        // TODO: how to merge?

        // Update residual