src/sliding_fb.rs

branch
dev
changeset 44
03251c546744
parent 41
b6bdb6cb4d44
child 45
5200e7090e06
equal deleted inserted replaced
43:aacd9af21b3a 44:03251c546744
20 use crate::measures::{DiscreteMeasure, Radon, RNDM}; 20 use crate::measures::{DiscreteMeasure, Radon, RNDM};
21 use crate::measures::merging::SpikeMerging; 21 use crate::measures::merging::SpikeMerging;
22 use crate::forward_model::{ 22 use crate::forward_model::{
23 ForwardModel, 23 ForwardModel,
24 AdjointProductBoundedBy, 24 AdjointProductBoundedBy,
25 LipschitzValues, 25 BoundedCurvature,
26 }; 26 };
27 //use crate::tolerance::Tolerance; 27 //use crate::tolerance::Tolerance;
28 use crate::plot::{ 28 use crate::plot::{
29 SeqPlotter, 29 SeqPlotter,
30 Plotting, 30 Plotting,
97 } 97 }
98 98
99 /// Internal type of adaptive transport step length calculation 99 /// Internal type of adaptive transport step length calculation
100 pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> { 100 pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> {
101 /// Fixed, known step length 101 /// Fixed, known step length
102 #[allow(dead_code)]
102 Fixed(F), 103 Fixed(F),
103 /// Adaptive step length, only wrt. maximum transport. 104 /// Adaptive step length, only wrt. maximum transport.
104 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. 105 /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
105 AdaptiveMax{ l : F, max_transport : F, g : G }, 106 AdaptiveMax{ l : F, max_transport : F, g : G },
106 /// Adaptive step length. 107 /// Adaptive step length.
156 let θτ = τ * θ; 157 let θτ = τ * θ;
157 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 158 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
158 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 159 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
159 } 160 }
160 }, 161 },
161 AdaptiveMax{ l : ℓ_v, ref mut max_transport, g : ref calculate_θ } => { 162 AdaptiveMax{ l : ℓ_F, ref mut max_transport, g : ref calculate_θ } => {
162 *max_transport = max_transport.max(γ1.norm(Radon)); 163 *max_transport = max_transport.max(γ1.norm(Radon));
163 let θτ = τ * calculate_θ(ℓ_v, *max_transport); 164 let θτ = τ * calculate_θ(ℓ_F, *max_transport);
164 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 165 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
165 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); 166 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
166 } 167 }
167 }, 168 },
168 FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => { 169 FullyAdaptive{ l : ref mut adaptive_ℓ_F, ref mut max_transport, g : ref calculate_θ } => {
169 *max_transport = max_transport.max(γ1.norm(Radon)); 170 *max_transport = max_transport.max(γ1.norm(Radon));
170 let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport); 171 let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
171 // Do two runs through the spikes to update θ, breaking if first run did not cause 172 // Do two runs through the spikes to update θ, breaking if first run did not cause
172 // a change. 173 // a change.
173 for _i in 0..=1 { 174 for _i in 0..=1 {
174 let mut changes = false; 175 let mut changes = false;
175 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 176 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
177 let g = &dv_x * (ρ.α.signum() * θ * τ); 178 let g = &dv_x * (ρ.α.signum() * θ * τ);
178 ρ.x = δ.x - g; 179 ρ.x = δ.x - g;
179 let n = g.norm2(); 180 let n = g.norm2();
180 if n >= F::EPSILON { 181 if n >= F::EPSILON {
181 // Estimate Lipschitz factor of ∇v 182 // Estimate Lipschitz factor of ∇v
182 let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2() / n; 183 let this_ℓ_F = (dv_x - v.differential(&ρ.x)).norm2() / n;
183 *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v); 184 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F);
184 θ = calculate_θ(*adaptive_ℓ_v, *max_transport); 185 θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
185 changes = true 186 changes = true
186 } 187 }
187 } 188 }
188 if !changes { 189 if !changes {
189 break 190 break
272 ) -> RNDM<F, N> 273 ) -> RNDM<F, N>
273 where 274 where
274 F : Float + ToNalgebraRealField, 275 F : Float + ToNalgebraRealField,
275 I : AlgIteratorFactory<IterInfo<F, N>>, 276 I : AlgIteratorFactory<IterInfo<F, N>>,
276 A : ForwardModel<RNDM<F, N>, F> 277 A : ForwardModel<RNDM<F, N>, F>
277 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, 278 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>
278 //+ TransportLipschitz<L2Squared, FloatType=F>, 279 + BoundedCurvature<FloatType=F>,
279 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, 280 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
280 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
281 A::PreadjointCodomain : DifferentiableRealMapping<F, N>, 281 A::PreadjointCodomain : DifferentiableRealMapping<F, N>,
282 RNDM<F, N> : SpikeMerging<F>, 282 RNDM<F, N> : SpikeMerging<F>,
283 Reg : SlidingRegTerm<F, N>, 283 Reg : SlidingRegTerm<F, N>,
284 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, 284 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
285 PlotLookup : Plotting<N>, 285 PlotLookup : Plotting<N>,
299 //let max_transport = config.max_transport.scale 299 //let max_transport = config.max_transport.scale
300 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); 300 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
301 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; 301 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
302 let ℓ = 0.0; 302 let ℓ = 0.0;
303 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); 303 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
304 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); 304 let (maybe_ℓ_v0, maybe_transport_lip) = opA.curvature_bound_components();
305 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { 305 let transport_lip = maybe_transport_lip.unwrap();
306 // We only estimate w (the uniform Lipschitz factor of v), if we also estimate ℓ_v 306 let calculate_θ = |ℓ_v, max_transport| {
307 // (the uniform Lipschitz factor of ∇v). 307 let ℓ_F = ℓ_v + transport_lip * max_transport;
308 // We assume that the residual is decreasing. 308 config.transport.θ0 / (τ*(ℓ + ℓ_F))
309 Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)), 309 };
310 let mut θ_or_adaptive = match maybe_ℓ_v0 {
311 //Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)),
312 Some(ℓ_v0) => TransportStepLength::AdaptiveMax {
313 l: ℓ_v0 * b.norm2(), // TODO: could estimate by computing the real residual
314 max_transport : 0.0,
315 g : calculate_θ
316 },
310 None => TransportStepLength::FullyAdaptive { 317 None => TransportStepLength::FullyAdaptive {
311 l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials 318 l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials
312 max_transport : 0.0, 319 max_transport : 0.0,
313 g : calculate_θ 320 g : calculate_θ
314 }, 321 },

mercurial