src/prox_penalty/radon_squared.rs

branch
dev
changeset 61
4f468d35fa29
parent 39
6316d68b58af
child 62
32328a74c790
child 63
7a8a55fd41c0
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
2 Solver for the point source localisation problem using a simplified forward-backward splitting method. 2 Solver for the point source localisation problem using a simplified forward-backward splitting method.
3 3
4 Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map. 4 Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map.
5 */ 5 */
6 6
7 use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD};
8 use crate::dataterm::QuadraticDataTerm;
9 use crate::forward_model::ForwardModel;
10 use crate::measures::merging::SpikeMerging;
11 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon};
12 use crate::regularisation::RegTerm;
13 use crate::types::*;
14 use alg_tools::bounds::MinMaxMapping;
15 use alg_tools::error::DynResult;
16 use alg_tools::instance::{Instance, Space};
17 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration};
18 use alg_tools::linops::BoundedLinear;
19 use alg_tools::nalgebra_support::ToNalgebraRealField;
20 use alg_tools::norms::{Norm, L2};
21 use anyhow::ensure;
22 use nalgebra::DVector;
7 use numeric_literals::replace_float_literals; 23 use numeric_literals::replace_float_literals;
8 use serde::{Serialize, Deserialize}; 24 use serde::{Deserialize, Serialize};
9 use nalgebra::DVector;
10
11 use alg_tools::iterate::{
12 AlgIteratorIteration,
13 AlgIterator
14 };
15 use alg_tools::norms::{L2, Norm};
16 use alg_tools::linops::Mapping;
17 use alg_tools::bisection_tree::{
18 BTFN,
19 Bounds,
20 BTSearch,
21 SupportGenerator,
22 LocalAnalysis,
23 };
24 use alg_tools::mapping::RealMapping;
25 use alg_tools::nalgebra_support::ToNalgebraRealField;
26
27 use crate::types::*;
28 use crate::measures::{
29 RNDM,
30 DeltaMeasure,
31 Radon,
32 };
33 use crate::measures::merging::SpikeMerging;
34 use crate::regularisation::RegTerm;
35 use crate::forward_model::{
36 ForwardModel,
37 AdjointProductBoundedBy
38 };
39 use super::{
40 FBGenericConfig,
41 ProxPenalty,
42 };
43 25
44 /// Radon-norm squared proximal penalty 26 /// Radon-norm squared proximal penalty
45 27
46 #[derive(Copy,Clone,Serialize,Deserialize)] 28 #[derive(Copy, Clone, Serialize, Deserialize)]
47 pub struct RadonSquared; 29 pub struct RadonSquared;
48 30
// `ProxPenalty` for the squared Radon norm, generic over the location type `Domain`
// of the discrete measures, the float type `F`, the mapping type `M` of `τv`, and
// the regulariser `Reg`.
49 #[replace_float_literals(F::cast_from(literal))] 31 #[replace_float_literals(F::cast_from(literal))]
50 impl<F, GA, BTA, S, Reg, const N : usize> 32 impl<Domain, F, M, Reg> ProxPenalty<Domain, M, Reg, F> for RadonSquared
51 ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for RadonSquared
52 where 33 where
53 F : Float + ToNalgebraRealField, 34 Domain: Space + Clone + PartialEq + 'static,
54 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 35 for<'a> &'a Domain: Instance<Domain>,
55 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 36 F: Float + ToNalgebraRealField,
56 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 37 M: MinMaxMapping<Domain, F>,
57 Reg : RegTerm<F, N>, 38 Reg: RegTerm<Domain, F>,
58 RNDM<F, N> : SpikeMerging<F>, 39 DiscreteMeasure<Domain, F>: SpikeMerging<F>,
59 { 40 {
// `insert_and_reweigh` may return an updated mapping of the same type `M` as `τv`;
// this implementation always returns `None` for it.
60 type ReturnMapping = BTFN<F, GA, BTA, N>; 41 type ReturnMapping = M;
42
// Reports this penalty as `ProxTerm::RadonSquared`.
43 fn prox_type() -> ProxTerm {
44 ProxTerm::RadonSquared
45 }
61 46
62 fn insert_and_reweigh<I>( 47 fn insert_and_reweigh<I>(
63 &self, 48 &self,
64 μ : &mut RNDM<F, N>, 49 μ: &mut DiscreteMeasure<Domain, F>,
65 τv : &mut BTFN<F, GA, BTA, N>, 50 τv: &mut M,
66 μ_base : &RNDM<F, N>, 51 μ_base: &DiscreteMeasure<Domain, F>,
67 ν_delta: Option<&RNDM<F, N>>, 52 ν_delta: Option<&DiscreteMeasure<Domain, F>>,
68 τ : F, 53 τ: F,
69 ε : F, 54 ε: F,
70 config : &FBGenericConfig<F>, 55 config: &InsertionConfig<F>,
71 reg : &Reg, 56 reg: &Reg,
72 _state : &AlgIteratorIteration<I>, 57 _state: &AlgIteratorIteration<I>,
73 stats : &mut IterInfo<F, N>, 58 stats: &mut IterInfo<F>,
74 ) -> (Option<Self::ReturnMapping>, bool) 59 ) -> DynResult<(Option<Self::ReturnMapping>, bool)>
75 where 60 where
76 I : AlgIterator 61 I: AlgIterator,
77 { 62 {
78 let mut y = μ_base.masses_dvector(); 63 let mut y = μ_base.masses_dvector();
79 64
80 assert!(μ_base.len() <= μ.len()); 65 ensure!(μ_base.len() <= μ.len());
81 66
82 'i_and_w: for i in 0..=1 { 67 'i_and_w: for i in 0..=1 {
83 // Optimise weights 68 // Optimise weights
84 if μ.len() > 0 { 69 if μ.len() > 0 {
85 // Form finite-dimensional subproblem. The subproblem references to the original μ^k 70 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
86 // from the beginning of the iteration are all contained in the immutable c and g. 71 // from the beginning of the iteration are all contained in the immutable c and g.
87 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional 72 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
88 // problems have not yet been updated to sign change. 73 // problems have not yet been updated to sign change.
89 let g̃ = DVector::from_iterator(μ.len(), 74 let g̃ = DVector::from_iterator(
90 μ.iter_locations() 75 μ.len(),
91 .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ)))); 76 μ.iter_locations()
77 .map(|ζ| -F::to_nalgebra_mixed(τv.apply(ζ))),
78 );
92 let mut x = μ.masses_dvector(); 79 let mut x = μ.masses_dvector();
93 y.extend(std::iter::repeat(0.0.to_nalgebra_mixed()).take(0.max(x.len()-y.len()))); 80 y.extend(std::iter::repeat(0.0.to_nalgebra_mixed()).take(0.max(x.len() - y.len())));
94 assert_eq!(y.len(), x.len()); 81 assert_eq!(y.len(), x.len());
95 // Solve finite-dimensional subproblem. 82 // Solve finite-dimensional subproblem.
96 // TODO: This assumes that ν_delta has no common locations with μ-μ_base, to 83 // TODO: This assumes that ν_delta has no common locations with μ-μ_base, to
97 // ignore it. 84 // ignore it.
98 stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config); 85 stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config);
99 86
100 // Update masses of μ based on solution of finite-dimensional subproblem. 87 // Update masses of μ based on solution of finite-dimensional subproblem.
101 μ.set_masses_dvector(&x); 88 μ.set_masses_dvector(&x);
102 } 89 }
103 90
104 if i>0 { 91 if i > 0 {
105 // Simple debugging test to see if more inserts would be needed. Doesn't seem so. 92 // Simple debugging test to see if more inserts would be needed. Doesn't seem so.
106 //let n = μ.dist_matching(μ_base); 93 //let n = μ.dist_matching(μ_base);
107 //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); 94 //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n));
108 break 'i_and_w 95 break 'i_and_w;
109 } 96 }
110 97
111 // Calculate ‖μ - μ_base‖_ℳ 98 // Calculate ‖μ - μ_base‖_ℳ
112 // TODO: This assumes that ν_delta has no common locations with μ-μ_base. 99 // TODO: This assumes that ν_delta has no common locations with μ-μ_base.
113 let n = μ.dist_matching(μ_base) + ν_delta.map_or(0.0, |ν| ν.norm(Radon)); 100 let n = μ.dist_matching(μ_base) + ν_delta.map_or(0.0, |ν| ν.norm(Radon));
114 101
115 // Find a spike to insert, if needed. 102 // Find a spike to insert, if needed.
116 // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, 103 // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ,
117 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. 104 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation.
118 match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { 105 match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) {
119 None => { break 'i_and_w }, 106 None => break 'i_and_w,
120 Some((ξ, _v_ξ, _in_bounds)) => { 107 Some((ξ, _v_ξ, _in_bounds)) => {
121 // Weight is found out by running the finite-dimensional optimisation algorithm 108 // Weight is found out by running the finite-dimensional optimisation algorithm
122 // above 109 // above
123 *μ += DeltaMeasure { x : ξ, α : 0.0 }; 110 *μ += DeltaMeasure { x: ξ, α: 0.0 };
124 stats.inserted += 1; 111 stats.inserted += 1;
125 } 112 }
126 }; 113 };
127 } 114 }
128 115
129 (None, true) 116 Ok((None, true))
130 } 117 }
131 118
// Merges spikes of `μ` according to `config.merging` and returns the `usize` produced
// by the merging routine (presumably the number of merges — NOTE(review): confirm).
132 fn merge_spikes( 119 fn merge_spikes(
133 &self, 120 &self,
134 μ : &mut RNDM<F, N>, 121 μ: &mut DiscreteMeasure<Domain, F>,
135 τv : &mut BTFN<F, GA, BTA, N>, 122 τv: &mut M,
136 μ_base : &RNDM<F, N>, 123 μ_base: &DiscreteMeasure<Domain, F>,
137 ν_delta: Option<&RNDM<F, N>>, 124 ν_delta: Option<&DiscreteMeasure<Domain, F>>,
138 τ : F, 125 τ: F,
139 ε : F, 126 ε: F,
140 config : &FBGenericConfig<F>, 127 config: &InsertionConfig<F>,
141 reg : &Reg, 128 reg: &Reg,
142 fitness : Option<impl Fn(&RNDM<F, N>) -> F>, 129 fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>,
143 ) -> usize 130 ) -> usize {
144 {
// Fitness-based merging is used only when `config.fitness_merging` is set AND a
// `fitness` function was supplied; the second component of its result is returned.
145 if config.fitness_merging { 131 if config.fitness_merging {
146 if let Some(f) = fitness { 132 if let Some(f) = fitness {
147 return μ.merge_spikes_fitness(config.merging, f, |&v| v) 133 return μ.merge_spikes_fitness(config.merging, f, |&v| v).1;
148 .1
149 } 134 }
150 } 135 }
// Otherwise fall back to `μ.merge_spikes`, which evaluates the closure below on merge
// candidates `μ_candidate` (presumably deciding whether to accept each one —
// NOTE(review): confirm).
151 μ.merge_spikes(config.merging, |μ_candidate| { 136 μ.merge_spikes(config.merging, |μ_candidate| {
152 // Important: μ_candidate's new points are afterwards, 137 // Important: μ_candidate's new points are afterwards,
153 // and do not conflict with μ_base. 138 // and do not conflict with μ_base.
165 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() 150 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()
166 }) 151 })
167 } 152 }
168 } 153 }
169 154
170 155 #[replace_float_literals(F::cast_from(literal))]
171 impl<F, A, const N : usize> AdjointProductBoundedBy<RNDM<F, N>, RadonSquared> 156 impl<'a, F, A, Domain> StepLengthBound<F, QuadraticDataTerm<F, Domain, A>> for RadonSquared
172 for A
173 where 157 where
174 F : Float, 158 F: Float + ToNalgebraRealField,
175 A : ForwardModel<RNDM<F, N>, F> 159 Domain: Space + Norm<Radon, F>,
160 A: ForwardModel<Domain, F> + BoundedLinear<Domain, Radon, L2, F>,
176 { 161 {
177 type FloatType = F; 162 fn step_length_bound(&self, f: &QuadraticDataTerm<F, Domain, A>) -> DynResult<F> {
178 163 // TODO: direct squared calculation
179 fn adjoint_product_bound(&self, _ : &RadonSquared) -> Option<Self::FloatType> { 164 Ok(f.operator().opnorm_bound(Radon, L2)?.powi(2))
180 self.opnorm_bound(Radon, L2).powi(2).into()
181 } 165 }
182 } 166 }
167
168 #[replace_float_literals(F::cast_from(literal))]
169 impl<'a, F, A, Domain> StepLengthBoundPD<F, A, DiscreteMeasure<Domain, F>> for RadonSquared
170 where
171 Domain: Space + Clone + PartialEq + 'static,
172 F: Float + ToNalgebraRealField,
173 A: BoundedLinear<DiscreteMeasure<Domain, F>, Radon, L2, F>,
174 {
175 fn step_length_bound_pd(&self, opA: &A) -> DynResult<F> {
176 opA.opnorm_bound(Radon, L2)
177 }
178 }

mercurial