src/prox_penalty/radon_squared.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
equal deleted inserted replaced
37:c5d8bd1a7728 39:6316d68b58af
10 10
11 use alg_tools::iterate::{ 11 use alg_tools::iterate::{
12 AlgIteratorIteration, 12 AlgIteratorIteration,
13 AlgIterator 13 AlgIterator
14 }; 14 };
15 use alg_tools::norms::L2; 15 use alg_tools::norms::{L2, Norm};
16 use alg_tools::linops::Mapping; 16 use alg_tools::linops::Mapping;
17 use alg_tools::bisection_tree::{ 17 use alg_tools::bisection_tree::{
18 BTFN, 18 BTFN,
19 Bounds, 19 Bounds,
20 BTSearch, 20 BTSearch,
73 stats : &mut IterInfo<F, N>, 73 stats : &mut IterInfo<F, N>,
74 ) -> (Option<Self::ReturnMapping>, bool) 74 ) -> (Option<Self::ReturnMapping>, bool)
75 where 75 where
76 I : AlgIterator 76 I : AlgIterator
77 { 77 {
78 assert!(ν_delta.is_none(), "Transport not implemented for Radon-squared prox term"); 78 let mut y = μ_base.masses_dvector();
79 79
80 let mut y = μ_base.masses_vec(); 80 assert!(μ_base.len() <= μ.len());
81 81
82 'i_and_w: for i in 0..=1 { 82 'i_and_w: for i in 0..=1 {
83 // Optimise weights 83 // Optimise weights
84 if μ.len() > 0 { 84 if μ.len() > 0 {
85 // Form finite-dimensional subproblem. The subproblem references to the original μ^k 85 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
86 // from the beginning of the iteration are all contained in the immutable c and g. 86 // from the beginning of the iteration are all contained in the immutable c and g.
88 // problems have not yet been updated to sign change. 88 // problems have not yet been updated to sign change.
89 let g̃ = DVector::from_iterator(μ.len(), 89 let g̃ = DVector::from_iterator(μ.len(),
90 μ.iter_locations() 90 μ.iter_locations()
91 .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ)))); 91 .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ))));
92 let mut x = μ.masses_dvector(); 92 let mut x = μ.masses_dvector();
93 // Ugly hack because DVector::push doesn't push but copies. 93 y.extend(std::iter::repeat(0.0.to_nalgebra_mixed()).take(0.max(x.len()-y.len())));
94 let yvec = DVector::from_column_slice(y.as_slice()); 94 assert_eq!(y.len(), x.len());
95 // Solve finite-dimensional subproblem. 95 // Solve finite-dimensional subproblem.
96 stats.inner_iters += reg.solve_findim_l1squared(&yvec, &g̃, τ, &mut x, ε, config); 96 // TODO: This assumes that ν_delta has no common locations with μ-μ_base, to
97 // ignore it.
98 stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config);
97 99
98 // Update masses of μ based on solution of finite-dimensional subproblem. 100 // Update masses of μ based on solution of finite-dimensional subproblem.
99 μ.set_masses_dvector(&x); 101 μ.set_masses_dvector(&x);
100 } 102 }
101 103
105 //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); 107 //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n));
106 break 'i_and_w 108 break 'i_and_w
107 } 109 }
108 110
109 // Calculate ‖μ - μ_base‖_ℳ 111 // Calculate ‖μ - μ_base‖_ℳ
110 let n = μ.dist_matching(μ_base); 112 // TODO: This assumes that ν_delta has no common locations with μ-μ_base.
113 let n = μ.dist_matching(μ_base) + ν_delta.map_or(0.0, |ν| ν.norm(Radon));
111 114
112 // Find a spike to insert, if needed. 115 // Find a spike to insert, if needed.
113 // This only checks the overall tolerances, not tolerances on support of μ-μ_base or μ, 116 // This only checks the overall tolerances, not tolerances on support of μ-μ_base or μ,
114 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. 117 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation.
115 match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { 118 match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) {
116 None => { break 'i_and_w }, 119 None => { break 'i_and_w },
117 Some((ξ, _v_ξ, _in_bounds)) => { 120 Some((ξ, _v_ξ, _in_bounds)) => {
118 // Weight is found out by running the finite-dimensional optimisation algorithm 121 // Weight is found out by running the finite-dimensional optimisation algorithm
119 // above 122 // above
120 *μ += DeltaMeasure { x : ξ, α : 0.0 }; 123 *μ += DeltaMeasure { x : ξ, α : 0.0 };
121 //*μ_base += DeltaMeasure { x : ξ, α : 0.0 };
122 y.push(0.0.to_nalgebra_mixed());
123 stats.inserted += 1; 124 stats.inserted += 1;
124 } 125 }
125 }; 126 };
126 } 127 }
127 128
131 fn merge_spikes( 132 fn merge_spikes(
132 &self, 133 &self,
133 μ : &mut RNDM<F, N>, 134 μ : &mut RNDM<F, N>,
134 τv : &mut BTFN<F, GA, BTA, N>, 135 τv : &mut BTFN<F, GA, BTA, N>,
135 μ_base : &RNDM<F, N>, 136 μ_base : &RNDM<F, N>,
137 ν_delta: Option<&RNDM<F, N>>,
136 τ : F, 138 τ : F,
137 ε : F, 139 ε : F,
138 config : &FBGenericConfig<F>, 140 config : &FBGenericConfig<F>,
139 reg : &Reg, 141 reg : &Reg,
142 fitness : Option<impl Fn(&RNDM<F, N>) -> F>,
140 ) -> usize 143 ) -> usize
141 { 144 {
145 if config.fitness_merging {
146 if let Some(f) = fitness {
147 return μ.merge_spikes_fitness(config.merging, f, |&v| v)
148 .1
149 }
150 }
142 μ.merge_spikes(config.merging, |μ_candidate| { 151 μ.merge_spikes(config.merging, |μ_candidate| {
143 // Important: μ_candidate's new points are afterwards, 152 // Important: μ_candidate's new points are afterwards,
144 // and do not conflict with μ_base. 153 // and do not conflict with μ_base.
145 // TODO: could simplify to requiring μ_base instead of μ_radon. 154 // TODO: could simplify to requiring μ_base instead of μ_radon.
146 // but may complicate with sliding base's extra points that need to be 155 // but may complicate with sliding base's extra points that need to be
147 // after μ_candidate's extra points. 156 // after μ_candidate's extra points.
148 // TODO: doesn't seem to work, maybe need to merge μ_base as well? 157 // TODO: doesn't seem to work, maybe need to merge μ_base as well?
149 // Although that doesn't seem to make sense. 158 // Although that doesn't seem to make sense.
150 let μ_radon = μ_candidate.sub_matching(μ_base); 159 let μ_radon = match ν_delta {
160 None => μ_candidate.sub_matching(μ_base),
161 Some(ν) => μ_candidate.sub_matching(μ_base) - ν,
162 };
151 reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon) 163 reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon)
152 //let n = μ_candidate.dist_matching(μ_base); 164 //let n = μ_candidate.dist_matching(μ_base);
153 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() 165 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()
154 }) 166 })
155 } 167 }

mercurial