20 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
20 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
21 use crate::measures::merging::SpikeMerging; |
21 use crate::measures::merging::SpikeMerging; |
22 use crate::forward_model::{ |
22 use crate::forward_model::{ |
23 ForwardModel, |
23 ForwardModel, |
24 AdjointProductBoundedBy, |
24 AdjointProductBoundedBy, |
25 LipschitzValues, |
25 BoundedCurvature, |
26 }; |
26 }; |
27 //use crate::tolerance::Tolerance; |
27 //use crate::tolerance::Tolerance; |
28 use crate::plot::{ |
28 use crate::plot::{ |
29 SeqPlotter, |
29 SeqPlotter, |
30 Plotting, |
30 Plotting, |
97 } |
97 } |
98 |
98 |
99 /// Internal type of adaptive transport step length calculation |
99 /// Internal type of adaptive transport step length calculation |
100 pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> { |
100 pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> { |
101 /// Fixed, known step length |
101 /// Fixed, known step length |
|
102 #[allow(dead_code)] |
102 Fixed(F), |
103 Fixed(F), |
103 /// Adaptive step length, only wrt. maximum transport. |
104 /// Adaptive step length, only wrt. maximum transport. |
104 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. |
105 /// Content of `l` depends on use case, while `g` calculates the step length from `l`. |
105 AdaptiveMax{ l : F, max_transport : F, g : G }, |
106 AdaptiveMax{ l : F, max_transport : F, g : G }, |
106 /// Adaptive step length. |
107 /// Adaptive step length. |
156 let θτ = τ * θ; |
157 let θτ = τ * θ; |
157 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
158 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
158 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); |
159 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); |
159 } |
160 } |
160 }, |
161 }, |
161 AdaptiveMax{ l : ℓ_v, ref mut max_transport, g : ref calculate_θ } => { |
162 AdaptiveMax{ l : ℓ_F, ref mut max_transport, g : ref calculate_θ } => { |
162 *max_transport = max_transport.max(γ1.norm(Radon)); |
163 *max_transport = max_transport.max(γ1.norm(Radon)); |
163 let θτ = τ * calculate_θ(ℓ_v, *max_transport); |
164 let θτ = τ * calculate_θ(ℓ_F, *max_transport); |
164 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
165 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
165 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); |
166 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); |
166 } |
167 } |
167 }, |
168 }, |
168 FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => { |
169 FullyAdaptive{ l : ref mut adaptive_ℓ_F, ref mut max_transport, g : ref calculate_θ } => { |
169 *max_transport = max_transport.max(γ1.norm(Radon)); |
170 *max_transport = max_transport.max(γ1.norm(Radon)); |
170 let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport); |
171 let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport); |
171 // Do two runs through the spikes to update θ, breaking if first run did not cause |
172 // Do two runs through the spikes to update θ, breaking if first run did not cause |
172 // a change. |
173 // a change. |
173 for _i in 0..=1 { |
174 for _i in 0..=1 { |
174 let mut changes = false; |
175 let mut changes = false; |
175 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
176 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
177 let g = &dv_x * (ρ.α.signum() * θ * τ); |
178 let g = &dv_x * (ρ.α.signum() * θ * τ); |
178 ρ.x = δ.x - g; |
179 ρ.x = δ.x - g; |
179 let n = g.norm2(); |
180 let n = g.norm2(); |
180 if n >= F::EPSILON { |
181 if n >= F::EPSILON { |
181 // Estimate Lipschitz factor of ∇v |
182 // Estimate Lipschitz factor of ∇v |
182 let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2() / n; |
183 let this_ℓ_F = (dv_x - v.differential(&ρ.x)).norm2() / n; |
183 *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v); |
184 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); |
184 θ = calculate_θ(*adaptive_ℓ_v, *max_transport); |
185 θ = calculate_θ(*adaptive_ℓ_F, *max_transport); |
185 changes = true |
186 changes = true |
186 } |
187 } |
187 } |
188 } |
188 if !changes { |
189 if !changes { |
189 break |
190 break |
272 ) -> RNDM<F, N> |
273 ) -> RNDM<F, N> |
273 where |
274 where |
274 F : Float + ToNalgebraRealField, |
275 F : Float + ToNalgebraRealField, |
275 I : AlgIteratorFactory<IterInfo<F, N>>, |
276 I : AlgIteratorFactory<IterInfo<F, N>>, |
276 A : ForwardModel<RNDM<F, N>, F> |
277 A : ForwardModel<RNDM<F, N>, F> |
277 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
278 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F> |
278 //+ TransportLipschitz<L2Squared, FloatType=F>, |
279 + BoundedCurvature<FloatType=F>, |
279 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
280 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
280 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, |
|
281 A::PreadjointCodomain : DifferentiableRealMapping<F, N>, |
281 A::PreadjointCodomain : DifferentiableRealMapping<F, N>, |
282 RNDM<F, N> : SpikeMerging<F>, |
282 RNDM<F, N> : SpikeMerging<F>, |
283 Reg : SlidingRegTerm<F, N>, |
283 Reg : SlidingRegTerm<F, N>, |
284 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
284 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
285 PlotLookup : Plotting<N>, |
285 PlotLookup : Plotting<N>, |
299 //let max_transport = config.max_transport.scale |
299 //let max_transport = config.max_transport.scale |
300 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
300 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
301 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
301 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
302 let ℓ = 0.0; |
302 let ℓ = 0.0; |
303 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
303 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
304 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); |
304 let (maybe_ℓ_v0, maybe_transport_lip) = opA.curvature_bound_components(); |
305 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { |
305 let transport_lip = maybe_transport_lip.unwrap(); |
306 // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v |
306 let calculate_θ = |ℓ_v, max_transport| { |
307 // (the uniform Lipschitz factor of ∇v). |
307 let ℓ_F = ℓ_v + transport_lip * max_transport; |
308 // We assume that the residual is decreasing. |
308 config.transport.θ0 / (τ*(ℓ + ℓ_F)) |
309 Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)), |
309 }; |
|
310 let mut θ_or_adaptive = match maybe_ℓ_v0 { |
|
311 //Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)), |
|
312 Some(ℓ_v0) => TransportStepLength::AdaptiveMax { |
|
313 l: ℓ_v0 * b.norm2(), // TODO: could estimate computing the real reesidual |
|
314 max_transport : 0.0, |
|
315 g : calculate_θ |
|
316 }, |
310 None => TransportStepLength::FullyAdaptive { |
317 None => TransportStepLength::FullyAdaptive { |
311 l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials |
318 l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials |
312 max_transport : 0.0, |
319 max_transport : 0.0, |
313 g : calculate_θ |
320 g : calculate_θ |
314 }, |
321 }, |