Tue, 31 Dec 2024 09:34:24 -0500
Early transport sketches
| 32 | 1 | /*! |
| 2 | Solver for the point source localisation problem using a sliding | |
| 3 | forward-backward splitting method. | |
| 4 | */ | |
| 5 | ||
| 6 | use numeric_literals::replace_float_literals; | |
| 7 | use serde::{Serialize, Deserialize}; | |
| 8 | //use colored::Colorize; | |
| 9 | //use nalgebra::{DVector, DMatrix}; | |
| 10 | use itertools::izip; | |
| 11 | use std::iter::{Map, Flatten}; | |
| 12 | ||
| 13 | use alg_tools::iterate::{ | |
| 14 | AlgIteratorFactory, | |
| 15 | AlgIteratorState | |
| 16 | }; | |
| 17 | use alg_tools::euclidean::{ | |
| 18 | Euclidean, | |
| 19 | Dot | |
| 20 | }; | |
| 21 | use alg_tools::sets::Cube; | |
| 22 | use alg_tools::loc::Loc; | |
| 23 | use alg_tools::mapping::{Apply, Differentiable}; | |
| 24 | use alg_tools::bisection_tree::{ | |
| 25 | BTFN, | |
| 26 | PreBTFN, | |
| 27 | Bounds, | |
| 28 | BTNodeLookup, | |
| 29 | BTNode, | |
| 30 | BTSearch, | |
| 31 | P2Minimise, | |
| 32 | SupportGenerator, | |
| 33 | LocalAnalysis, | |
| 34 | //Bounded, | |
| 35 | }; | |
| 36 | use alg_tools::mapping::RealMapping; | |
| 37 | use alg_tools::nalgebra_support::ToNalgebraRealField; | |
| 38 | ||
| 39 | use crate::types::*; | |
| 40 | use crate::measures::{ | |
| 41 | DiscreteMeasure, | |
| 42 | DeltaMeasure, | |
| 43 | }; | |
| 44 | use crate::measures::merging::{ | |
| 45 | //SpikeMergingMethod, | |
| 46 | SpikeMerging, | |
| 47 | }; | |
| 48 | use crate::forward_model::ForwardModel; | |
| 49 | use crate::seminorms::DiscreteMeasureOp; | |
| 50 | //use crate::tolerance::Tolerance; | |
| 51 | use crate::plot::{ | |
| 52 | SeqPlotter, | |
| 53 | Plotting, | |
| 54 | PlotLookup | |
| 55 | }; | |
| 56 | use crate::fb::*; | |
| 57 | use crate::regularisation::SlidingRegTerm; | |
| 58 | use crate::dataterm::{ | |
| 59 | L2Squared, | |
| 60 | //DataTerm, | |
| 61 | calculate_residual, | |
| 62 | calculate_residual2, | |
| 63 | }; | |
| 64 | use crate::transport::TransportLipschitz; | |
| 65 | ||
| 66 | /// Settings for [`pointsource_sliding_fb_reg`]. | |
| 67 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | |
| 68 | #[serde(default)] | |
| 69 | pub struct SlidingFBConfig<F : Float> { | |
| 70 | /// Step length scaling | |
| 71 | pub τ0 : F, | |
| 72 | /// Transport smoothness assumption | |
| 73 | pub ℓ0 : F, | |
| 74 | /// Inverse of the scaling factor $θ$ of the 2-norm-squared transport cost. | |
| 75 | /// This means that $τθ$ is the step length for the transport step. | |
| 76 | pub inverse_transport_scaling : F, | |
| 77 | /// Factor for deciding transport reduction based on smoothness assumption violation | |
| 78 | pub minimum_goodness_factor : F, | |
| 79 | /// Maximum rays to retain in transports from each source. | |
| 80 | pub maximum_rays : usize, | |
| 81 | /// Generic parameters | |
| 82 | pub insertion : FBGenericConfig<F>, | |
| 83 | } | |
| 84 | ||
| 85 | #[replace_float_literals(F::cast_from(literal))] | |
| 86 | impl<F : Float> Default for SlidingFBConfig<F> { | |
| 87 | fn default() -> Self { | |
| 88 | SlidingFBConfig { | |
| 89 | τ0 : 0.99, | |
| 90 | ℓ0 : 1.5, | |
| 91 | inverse_transport_scaling : 1.0, | |
| 92 | minimum_goodness_factor : 1.0, // TODO: totally arbitrary choice, | |
| 93 | // should be scaled by problem data? | |
| 94 | maximum_rays : 10, | |
| 95 | insertion : Default::default() | |
| 96 | } | |
| 97 | } | |
| 98 | } | |
| 99 | ||
| 100 | /// A transport ray (including various additional computational information). | |
| 101 | #[derive(Clone, Debug)] | |
| 102 | pub struct Ray<Domain, F : Num> { | |
| 103 | /// The destination of the ray, and the mass. The source is indicated in a [`RaySet`]. | |
| 104 | δ : DeltaMeasure<Domain, F>, | |
| 105 | /// Goodness of the data term for the aray: $v(z)-v(y)-⟨∇v(x), z-y⟩ + ℓ‖z-y‖^2$. | |
| 106 | goodness : F, | |
| 107 | /// Goodness of the regularisation term for the ray: $w(z)-w(y)$. | |
| 108 | /// Initially zero until $w$ can be constructed. | |
| 109 | reg_goodness : F, | |
| 110 | /// Indicates that this ray also forms a component in γ^{k+1} with the mass `to_return`. | |
| 111 | to_return : F, | |
| 112 | } | |
| 113 | ||
| 114 | /// A set of transport rays with the same source point. | |
| 115 | #[derive(Clone, Debug)] | |
| 116 | pub struct RaySet<Domain, F : Num> { | |
| 117 | /// Source of every ray in thset | |
| 118 | source : Domain, | |
| 119 | /// Mass of the diagonal ray, with destination the same as the source. | |
| 120 | diagonal: F, | |
| 121 | /// Goodness of the data term for the diagonal ray with $z=x$: | |
| 122 | /// $v(x)-v(y)-⟨∇v(x), x-y⟩ + ℓ‖x-y‖^2$. | |
| 123 | diagonal_goodness : F, | |
| 124 | /// Goodness of the data term for the diagonal ray with $z=x$: $w(x)-w(y)$. | |
| 125 | diagonal_reg_goodness : F, | |
| 126 | /// The non-diagonal rays. | |
| 127 | rays : Vec<Ray<Domain, F>>, | |
| 128 | } | |
| 129 | ||
| 130 | #[replace_float_literals(F::cast_from(literal))] | |
| 131 | impl<Domain, F : Float> RaySet<Domain, F> { | |
| 132 | fn non_diagonal_mass(&self) -> F { | |
| 133 | self.rays | |
| 134 | .iter() | |
| 135 | .map(|Ray{ δ : DeltaMeasure{ α, .. }, .. }| *α) | |
| 136 | .sum() | |
| 137 | } | |
| 138 | ||
| 139 | fn total_mass(&self) -> F { | |
| 140 | self.non_diagonal_mass() + self.diagonal | |
| 141 | } | |
| 142 | ||
| 143 | fn targets<'a>(&'a self) | |
| 144 | -> Map< | |
| 145 | std::slice::Iter<'a, Ray<Domain, F>>, | |
| 146 | fn(&'a Ray<Domain, F>) -> &'a DeltaMeasure<Domain, F> | |
| 147 | > { | |
| 148 | fn get_δ<'b, Domain, F : Float>(Ray{ δ, .. }: &'b Ray<Domain, F>) | |
| 149 | -> &'b DeltaMeasure<Domain, F> { | |
| 150 | δ | |
| 151 | } | |
| 152 | self.rays | |
| 153 | .iter() | |
| 154 | .map(get_δ) | |
| 155 | } | |
| 156 | ||
| 157 | // fn non_diagonal_goodness(&self) -> F { | |
| 158 | // self.rays | |
| 159 | // .iter() | |
| 160 | // .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| { | |
| 161 | // α * (goodness + reg_goodness) | |
| 162 | // }) | |
| 163 | // .sum() | |
| 164 | // } | |
| 165 | ||
| 166 | // fn total_goodness(&self) -> F { | |
| 167 | // self.non_diagonal_goodness() + (self.diagonal_goodness + self.diagonal_reg_goodness) | |
| 168 | // } | |
| 169 | ||
| 170 | fn non_diagonal_badness(&self) -> F { | |
| 171 | self.rays | |
| 172 | .iter() | |
| 173 | .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| { | |
| 174 | 0.0.max(- α * (goodness + reg_goodness)) | |
| 175 | }) | |
| 176 | .sum() | |
| 177 | } | |
| 178 | ||
| 179 | fn total_badness(&self) -> F { | |
| 180 | self.non_diagonal_badness() | |
| 181 | + 0.0.max(- self.diagonal * (self.diagonal_goodness + self.diagonal_reg_goodness)) | |
| 182 | } | |
| 183 | ||
| 184 | fn total_return(&self) -> F { | |
| 185 | self.rays | |
| 186 | .iter() | |
| 187 | .map(|&Ray{ to_return, .. }| to_return) | |
| 188 | .sum() | |
| 189 | } | |
| 190 | } | |
| 191 | ||
| 192 | #[replace_float_literals(F::cast_from(literal))] | |
| 193 | impl<Domain : Clone, F : Num> RaySet<Domain, F> { | |
| 194 | fn return_targets<'a>(&'a self) | |
| 195 | -> Flatten<Map< | |
| 196 | std::slice::Iter<'a, Ray<Domain, F>>, | |
| 197 | fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>> | |
| 198 | >> { | |
| 199 | fn get_return<'b, Domain : Clone, F : Num>(ray: &'b Ray<Domain, F>) | |
| 200 | -> Option<DeltaMeasure<Domain, F>> { | |
| 201 | (ray.to_return != 0.0).then_some( | |
| 202 | DeltaMeasure{x : ray.δ.x.clone(), α : ray.to_return} | |
| 203 | ) | |
| 204 | } | |
| 205 | let tmp : Map< | |
| 206 | std::slice::Iter<'a, Ray<Domain, F>>, | |
| 207 | fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>> | |
| 208 | > = self.rays | |
| 209 | .iter() | |
| 210 | .map(get_return); | |
| 211 | tmp.flatten() | |
| 212 | } | |
| 213 | } | |
| 214 | ||
| 215 | /// Iteratively solve the pointsource localisation problem using sliding forward-backward | |
| 216 | /// splitting | |
| 217 | /// | |
| 218 | /// The parametrisatio is as for [`pointsource_fb_reg`]. | |
| 219 | /// Inertia is currently not supported. | |
| 220 | #[replace_float_literals(F::cast_from(literal))] | |
| 221 | pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>( | |
| 222 | opA : &'a A, | |
| 223 | b : &A::Observable, | |
| 224 | reg : Reg, | |
| 225 | op𝒟 : &'a 𝒟, | |
| 226 | sfbconfig : &SlidingFBConfig<F>, | |
| 227 | iterator : I, | |
| 228 | mut plotter : SeqPlotter<F, N>, | |
| 229 | ) -> DiscreteMeasure<Loc<F, N>, F> | |
| 230 | where F : Float + ToNalgebraRealField, | |
| 231 | I : AlgIteratorFactory<IterInfo<F, N>>, | |
| 232 | for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, | |
| 233 | //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow | |
| 234 | A::Observable : std::ops::MulAssign<F>, | |
| 235 | A::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output=Loc<F, N>>, | |
| 236 | GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, | |
| 237 | A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> | |
| 238 | + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>, | |
| 239 | BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, | |
| 240 | G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, | |
| 241 | 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, | |
| 242 | 𝒟::Codomain : RealMapping<F, N>, | |
| 243 | S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> | |
| 244 | + Differentiable<Loc<F, N>, Output=Loc<F,N>>, | |
| 245 | K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
| 246 | //+ Differentiable<Loc<F, N>, Output=Loc<F,N>>, | |
| 247 | BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, | |
| 248 | Cube<F, N>: P2Minimise<Loc<F, N>, F>, | |
| 249 | PlotLookup : Plotting<N>, | |
| 250 | DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, | |
| 251 | Reg : SlidingRegTerm<F, N> { | |
| 252 | ||
| 253 | assert!(sfbconfig.τ0 > 0.0 && | |
| 254 | sfbconfig.inverse_transport_scaling > 0.0 && | |
| 255 | sfbconfig.ℓ0 > 0.0); | |
| 256 | ||
| 257 | // Set up parameters | |
| 258 | let config = &sfbconfig.insertion; | |
| 259 | let op𝒟norm = op𝒟.opnorm_bound(); | |
| 260 | let θ = sfbconfig.inverse_transport_scaling; | |
| 261 | let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap() | |
| 262 | .max(opA.transport_lipschitz_factor(L2Squared) * θ); | |
| 263 | let ℓ = sfbconfig.ℓ0; // TODO: v scaling? | |
| 264 | // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled | |
| 265 | // by τ compared to the conditional gradient approach. | |
| 266 | let tolerance = config.tolerance * τ * reg.tolerance_scaling(); | |
| 267 | let mut ε = tolerance.initial(); | |
| 268 | ||
| 269 | // Initialise iterates | |
| 270 | let mut μ : DiscreteMeasure<Loc<F, N>, F> = DiscreteMeasure::new(); | |
| 271 | let mut μ_transported_base = DiscreteMeasure::new(); | |
| 272 | let mut γ_hat : Vec<RaySet<Loc<F, N>, F>> = Vec::new(); // γ̂_k and extra info | |
| 273 | let mut residual = -b; | |
| 274 | let mut stats = IterInfo::new(); | |
| 275 | ||
| 276 | // Run the algorithm | |
| 277 | iterator.iterate(|state| { | |
| 278 | // Calculate smooth part of surrogate model. | |
| 279 | // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` | |
| 280 | // has no significant overhead. For some reosn Rust doesn't allow us simply moving | |
| 281 | // the residual and replacing it below before the end of this closure. | |
| 282 | residual *= -τ; | |
| 283 | let r = std::mem::replace(&mut residual, opA.empty_observable()); | |
| 284 | let minus_τv = opA.preadjoint().apply(r); | |
| 285 | ||
| 286 | // Save current base point and shift μ to new positions. | |
| 287 | let μ_base = μ.clone(); | |
| 288 | for δ in μ.iter_spikes_mut() { | |
| 289 | δ.x += minus_τv.differential(&δ.x) * θ; | |
| 290 | } | |
| 291 | let mut μ_transported = μ.clone(); | |
| 292 | ||
| 293 | assert_eq!(μ.len(), γ_hat.len()); | |
| 294 | ||
| 295 | // Calculate the goodness λ formed from γ_hat (≈ γ̂_k) and γ^{k+1}, where the latter | |
| 296 | // transports points x from μ_base to points y in μ as shifted above, or “returns” | |
| 297 | // them “home” to z given by the rays in γ_hat. Returning is necessary if the rays | |
| 298 | // are not “good” for the smoothness assumptions, or if γ_hat has more mass than | |
| 299 | // μ_base. | |
| 300 | let mut total_goodness = 0.0; // data term goodness | |
| 301 | let mut total_reg_goodness = 0.0; // regulariser goodness | |
| 302 | let minimum_goodness = - ε * sfbconfig.minimum_goodness_factor; | |
| 303 | ||
| 304 | for (δ, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) { | |
| 305 | // Calculate data term goodness for all rays. | |
| 306 | let &DeltaMeasure{ x : ref y, α : δ_mass } = δ; | |
| 307 | let x = &r.source; | |
| 308 | let mvy = minus_τv.apply(y); | |
| 309 | let mdvx = minus_τv.differential(x); | |
| 310 | let mut r_total_mass = 0.0; // Total mass of all rays with source r.source. | |
| 311 | let mut bad_mass = 0.0; | |
| 312 | let mut calc_goodness = |goodness : &mut F, reg_goodness : &mut F, α, z : &Loc<F, N>| { | |
| 313 | *reg_goodness = 0.0; // Initial guess | |
| 314 | *goodness = mvy - minus_τv.apply(z) + mdvx.dot(&(z-y)) | |
| 315 | + ℓ * z.dist2_squared(&y); | |
| 316 | total_goodness += *goodness * α; | |
| 317 | r_total_mass += α; // TODO: should this include to_return from staging? (Probably not) | |
| 318 | if *goodness < 0.0 { | |
| 319 | bad_mass += α; | |
| 320 | } | |
| 321 | }; | |
| 322 | for ray in r.rays.iter_mut() { | |
| 323 | calc_goodness(&mut ray.goodness, &mut ray.reg_goodness, ray.δ.α, &ray.δ.x); | |
| 324 | } | |
| 325 | calc_goodness(&mut r.diagonal_goodness, &mut r.diagonal_reg_goodness, r.diagonal, x); | |
| 326 | ||
| 327 | // If the total mass of the ray set is less than that of μ at the same source, | |
| 328 | // a diagonal component needs to be added to be able to (attempt to) transport | |
| 329 | // all mass of μ. In the opposite case, we need to construct γ_{k+1} to ‘return’ | |
| 330 | // the the extra mass of γ̂_k to the target z. We return mass from the oldest “bad” | |
| 331 | // rays in the set. | |
| 332 | if δ_mass >= r_total_mass { | |
| 333 | r.diagonal += δ_mass - r_total_mass; | |
| 334 | } else { | |
| 335 | let mut reduce_transport = r_total_mass - δ_mass; | |
| 336 | let mut good_needed = (bad_mass - reduce_transport).max(0.0); | |
| 337 | // NOTE: reg_goodness is zero at this point, so it is not used in this code. | |
| 338 | let mut reduce_ray = |goodness, to_return : Option<&mut F>, α : &mut F| { | |
| 339 | if reduce_transport > 0.0 { | |
| 340 | let return_amount = if goodness < 0.0 { | |
| 341 | α.min(reduce_transport) | |
| 342 | } else { | |
| 343 | let amount = α.min(good_needed); | |
| 344 | good_needed -= amount; | |
| 345 | amount | |
| 346 | }; | |
| 347 | ||
| 348 | if return_amount > 0.0 { | |
| 349 | reduce_transport -= return_amount; | |
| 350 | // Adjust total goodness by returned amount | |
| 351 | total_goodness -= goodness * return_amount; | |
| 352 | to_return.map(|tr| *tr += return_amount); | |
| 353 | *α -= return_amount; | |
| 354 | *α > 0.0 | |
| 355 | } else { | |
| 356 | true | |
| 357 | } | |
| 358 | } else { | |
| 359 | true | |
| 360 | } | |
| 361 | }; | |
| 362 | r.rays.retain_mut(|ray| { | |
| 363 | reduce_ray(ray.goodness, Some(&mut ray.to_return), &mut ray.δ.α) | |
| 364 | }); | |
| 365 | // A bad diagonal is simply reduced without any 'return'. | |
| 366 | // It was, after all, just added to match μ, but there is no need to match it. | |
| 367 | // It's just a heuristic. | |
| 368 | // TODO: Maybe a bad diagonal should be the first to go. | |
| 369 | reduce_ray(r.diagonal_goodness, None, &mut r.diagonal); | |
| 370 | } | |
| 371 | } | |
| 372 | ||
| 373 | // Solve finite-dimensional subproblem several times until the dual variable for the | |
| 374 | // regularisation term conforms to the assumptions made for the transport above. | |
| 375 | let (d, within_tolerances) = 'adapt_transport: loop { | |
| 376 | // If transport violates goodness requirements, shift it to ‘return’ mass to z, | |
| 377 | // forcing y = z. Based on the badness of each ray set (sum of bad rays' goodness), | |
| 378 | // we proportionally distribute the reductions to each ray set, and within each ray | |
| 379 | // set, prioritise reducing the oldest bad rays' weight. | |
| 380 | let tg = total_goodness + total_reg_goodness; | |
| 381 | let adaptation_needed = minimum_goodness - tg; | |
| 382 | if adaptation_needed > 0.0 { | |
| 383 | let total_badness = γ_hat.iter().map(|r| r.total_badness()).sum(); | |
| 384 | ||
| 385 | let mut return_ray = |goodness : F, | |
| 386 | reg_goodness : F, | |
| 387 | to_return : Option<&mut F>, | |
| 388 | α : &mut F, | |
| 389 | left_to_return : &mut F| { | |
| 390 | let g = goodness + reg_goodness; | |
| 391 | assert!(*α >= 0.0 && *left_to_return >= 0.0); | |
| 392 | if *left_to_return > 0.0 && g < 0.0 { | |
| 393 | let return_amount = (*left_to_return / (-g)).min(*α); | |
| 394 | *left_to_return -= (-g) * return_amount; | |
| 395 | total_goodness -= goodness * return_amount; | |
| 396 | total_reg_goodness -= reg_goodness * return_amount; | |
| 397 | to_return.map(|tr| *tr += return_amount); | |
| 398 | *α -= return_amount; | |
| 399 | *α > 0.0 | |
| 400 | } else { | |
| 401 | true | |
| 402 | } | |
| 403 | }; | |
| 404 | ||
| 405 | for r in γ_hat.iter_mut() { | |
| 406 | let mut left_to_return = adaptation_needed * r.total_badness() / total_badness; | |
| 407 | if left_to_return > 0.0 { | |
| 408 | for ray in r.rays.iter_mut() { | |
| 409 | return_ray(ray.goodness, ray.reg_goodness, | |
| 410 | Some(&mut ray.to_return), &mut ray.δ.α, &mut left_to_return); | |
| 411 | } | |
| 412 | return_ray(r.diagonal_goodness, r.diagonal_reg_goodness, | |
| 413 | None, &mut r.diagonal, &mut left_to_return); | |
| 414 | } | |
| 415 | } | |
| 416 | } | |
| 417 | ||
| 418 | // Construct μ_k + (π_#^1-π_#^0)γ_{k+1}. | |
| 419 | // This can be broken down into | |
| 420 | // | |
| 421 | // μ_transported_base = [μ - π_#^0 (γ_shift + γ_return)] + π_#^1 γ_return, and | |
| 422 | // μ_transported = π_#^1 γ_shift | |
| 423 | // | |
| 424 | // where γ_shift is our “true” γ_{k+1}, and γ_return is the return compoennt. | |
| 425 | // The former can be constructed from δ.x and δ_new.x for δ in μ_base and δ_new in μ | |
| 426 | // (which has already been shifted), and the mass stored in a γ_hat ray's δ measure | |
| 427 | // The latter can be constructed from γ_hat rays' source and destination with the | |
| 428 | // to_return mass. | |
| 429 | // | |
| 430 | // Note that μ_transported is constructed to have the same spike locations as μ, but | |
| 431 | // to have same length as μ_base. This loop does not iterate over the spikes of μ | |
| 432 | // (and corresponding transports of γ_hat) that have been newly added in the current | |
| 433 | // 'adapt_transport loop. | |
| 434 | for (δ, δ_transported, r) in izip!(μ_base.iter_spikes(), | |
| 435 | μ_transported.iter_spikes_mut(), | |
| 436 | γ_hat.iter()) { | |
| 437 | let &DeltaMeasure{ref x, α} = δ; | |
| 438 | debug_assert_eq!(*x, r.source); | |
| 439 | let shifted_mass = r.total_mass(); | |
| 440 | let ret_mass = r.total_return(); | |
| 441 | // μ - π_#^0 (γ_shift + γ_return) | |
| 442 | μ_transported_base += DeltaMeasure { x : *x, α : α - shifted_mass - ret_mass }; | |
| 443 | // π_#^1 γ_return | |
| 444 | μ_transported_base.extend(r.return_targets()); | |
| 445 | // π_#^1 γ_shift | |
| 446 | δ_transported.set_mass(shifted_mass); | |
| 447 | } | |
| 448 | // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b) | |
| 449 | let transported_residual = calculate_residual2(&μ_transported, | |
| 450 | &μ_transported_base, | |
| 451 | opA, b); | |
| 452 | let transported_minus_τv = opA.preadjoint() | |
| 453 | .apply(transported_residual); | |
| 454 | ||
| 455 | // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. | |
| 456 | let (mut d, within_tolerances) = insert_and_reweigh( | |
| 457 | &mut μ, &transported_minus_τv, &μ_transported, Some(&μ_transported_base), | |
| 458 | op𝒟, op𝒟norm, | |
| 459 | τ, ε, | |
| 460 | config, ®, state, &mut stats | |
| 461 | ); | |
| 462 | ||
| 463 | // We have d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv; more precisely | |
| 464 | // d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_transported, config)); | |
| 465 | // We “essentially” assume that the subdifferential w of the regularisation term | |
| 466 | // satisfies w'(y)=0, so for a “goodness” estimate τ[w(y)-w(z)-w'(y)(z-y)] | |
| 467 | // that incorporates the assumption, we need to calculate τ[w(z) - w(y)] for | |
| 468 | // some w in the subdifferential of the regularisation term, such that | |
| 469 | // -ε ≤ τw - d ≤ ε. This is done by [`RegTerm::goodness`]. | |
| 470 | for r in γ_hat.iter_mut() { | |
| 471 | for ray in r.rays.iter_mut() { | |
| 472 | ray.reg_goodness = reg.goodness(&mut d, &μ, &r.source, &ray.δ.x, τ, ε, config); | |
| 473 | total_reg_goodness += ray.reg_goodness * ray.δ.α; | |
| 474 | } | |
| 475 | } | |
| 476 | ||
| 477 | // If update of regularisation term goodness didn't invalidate minimum goodness | |
| 478 | // requirements, we have found our step. Otherwise we need to keep reducing | |
| 479 | // transport by repeating the loop. | |
| 480 | if total_goodness + total_reg_goodness >= minimum_goodness { | |
| 481 | break 'adapt_transport (d, within_tolerances) | |
| 482 | } | |
| 483 | }; | |
| 484 | ||
| 485 | // Update γ_hat to new location | |
| 486 | for (δ_new, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) { | |
| 487 | // Prune rays that only had a return component, as the return component becomes | |
| 488 | // a diagonal in γ̂^{k+1}. | |
| 489 | r.rays.retain(|ray| ray.δ.α != 0.0); | |
| 490 | // Otherwise zero out the return component, or stage rays for pruning | |
| 491 | // to keep memory and computational demands reasonable. | |
| 492 | let n_rays = r.rays.len(); | |
| 493 | for (ray, ir) in izip!(r.rays.iter_mut(), (0..n_rays).rev()) { | |
| 494 | if ir >= sfbconfig.maximum_rays { | |
| 495 | // Only keep sfbconfig.maximum_rays - 1 previous rays, staging others for | |
| 496 | // pruning in next step. | |
| 497 | ray.to_return = ray.δ.α; | |
| 498 | ray.δ.α = 0.0; | |
| 499 | } else { | |
| 500 | ray.to_return = 0.0; | |
| 501 | } | |
| 502 | ray.goodness = 0.0; // TODO: probably not needed | |
| 503 | ray.reg_goodness = 0.0; | |
| 504 | } | |
| 505 | // Add a new ray for the currently diagonal component | |
| 506 | if r.diagonal > 0.0 { | |
| 507 | r.rays.push(Ray{ | |
| 508 | δ : DeltaMeasure{x : r.source, α : r.diagonal}, | |
| 509 | goodness : 0.0, | |
| 510 | reg_goodness : 0.0, | |
| 511 | to_return : 0.0, | |
| 512 | }); | |
| 513 | // TODO: Maybe this does not need to be done here, and is sufficent to to do where | |
| 514 | // the goodness is calculated. | |
| 515 | r.diagonal = 0.0; | |
| 516 | } | |
| 517 | r.diagonal_goodness = 0.0; | |
| 518 | ||
| 519 | // Shift source | |
| 520 | r.source = δ_new.x; | |
| 521 | } | |
| 522 | // Extend to new spikes | |
| 523 | γ_hat.extend(μ[γ_hat.len()..].iter().map(|δ_new| { | |
| 524 | RaySet{ | |
| 525 | source : δ_new.x, | |
| 526 | rays : [].into(), | |
| 527 | diagonal : 0.0, | |
| 528 | diagonal_goodness : 0.0, | |
| 529 | diagonal_reg_goodness : 0.0 | |
| 530 | } | |
| 531 | })); | |
| 532 | ||
| 533 | // Prune spikes with zero weight. This also moves the marginal differences of corresponding | |
| 534 | // transports from γ_hat to γ_pruned_marginal_diff. | |
| 535 | // TODO: optimise standard prune with swap_remove. | |
| 536 | μ_transported_base.clear(); | |
| 537 | let mut i = 0; | |
| 538 | assert_eq!(μ.len(), γ_hat.len()); | |
| 539 | while i < μ.len() { | |
| 540 | if μ[i].α == F::ZERO { | |
| 541 | μ.swap_remove(i); | |
| 542 | let r = γ_hat.swap_remove(i); | |
| 543 | μ_transported_base.extend(r.targets().cloned()); | |
| 544 | μ_transported_base -= DeltaMeasure{ α : r.non_diagonal_mass(), x : r.source }; | |
| 545 | } else { | |
| 546 | i += 1; | |
| 547 | } | |
| 548 | } | |
| 549 | ||
| 550 | // TODO: how to merge? | |
| 551 | ||
| 552 | // Update residual | |
| 553 | residual = calculate_residual(&μ, opA, b); | |
| 554 | ||
| 555 | // Update main tolerance for next iteration | |
| 556 | let ε_prev = ε; | |
| 557 | ε = tolerance.update(ε, state.iteration()); | |
| 558 | stats.this_iters += 1; | |
| 559 | ||
| 560 | // Give function value if needed | |
| 561 | state.if_verbose(|| { | |
| 562 | // Plot if so requested | |
| 563 | plotter.plot_spikes( | |
| 564 | format!("iter {} end; {}", state.iteration(), within_tolerances), &d, | |
| 565 | "start".to_string(), Some(&minus_τv), | |
| 566 | reg.target_bounds(τ, ε_prev), &μ, | |
| 567 | ); | |
| 568 | // Calculate mean inner iterations and reset relevant counters. | |
| 569 | // Return the statistics | |
| 570 | let res = IterInfo { | |
| 571 | value : residual.norm2_squared_div2() + reg.apply(&μ), | |
| 572 | n_spikes : μ.len(), | |
| 573 | ε : ε_prev, | |
| 574 | postprocessing: config.postprocessing.then(|| μ.clone()), | |
| 575 | .. stats | |
| 576 | }; | |
| 577 | stats = IterInfo::new(); | |
| 578 | res | |
| 579 | }) | |
| 580 | }); | |
| 581 | ||
| 582 | postprocess(μ, config, L2Squared, opA, b) | |
| 583 | } |