src/prox_penalty/radon_squared.rs

branch
dev
changeset 63
7a8a55fd41c0
parent 61
4f468d35fa29
equal deleted inserted replaced
61:4f468d35fa29 63:7a8a55fd41c0
6 6
7 use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD}; 7 use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD};
8 use crate::dataterm::QuadraticDataTerm; 8 use crate::dataterm::QuadraticDataTerm;
9 use crate::forward_model::ForwardModel; 9 use crate::forward_model::ForwardModel;
10 use crate::measures::merging::SpikeMerging; 10 use crate::measures::merging::SpikeMerging;
11 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon}; 11 use crate::measures::{DiscreteMeasure, Radon};
12 use crate::regularisation::RegTerm; 12 use crate::regularisation::RadonSquaredRegTerm;
13 use crate::types::*; 13 use crate::types::*;
14 use alg_tools::bounds::MinMaxMapping; 14 use alg_tools::bounds::MinMaxMapping;
15 use alg_tools::error::DynResult; 15 use alg_tools::error::DynResult;
16 use alg_tools::instance::{Instance, Space}; 16 use alg_tools::instance::{Instance, Space};
17 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration}; 17 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration};
18 use alg_tools::linops::BoundedLinear; 18 use alg_tools::linops::BoundedLinear;
19 use alg_tools::nalgebra_support::ToNalgebraRealField; 19 use alg_tools::nalgebra_support::ToNalgebraRealField;
20 use alg_tools::norms::{Norm, L2}; 20 use alg_tools::norms::{Norm, L2};
21 use anyhow::ensure;
22 use nalgebra::DVector;
23 use numeric_literals::replace_float_literals; 21 use numeric_literals::replace_float_literals;
24 use serde::{Deserialize, Serialize}; 22 use serde::{Deserialize, Serialize};
25 23
26 /// Radon-norm squared proximal penalty 24 /// Radon-norm squared proximal penalty
27 25
33 where 31 where
34 Domain: Space + Clone + PartialEq + 'static, 32 Domain: Space + Clone + PartialEq + 'static,
35 for<'a> &'a Domain: Instance<Domain>, 33 for<'a> &'a Domain: Instance<Domain>,
36 F: Float + ToNalgebraRealField, 34 F: Float + ToNalgebraRealField,
37 M: MinMaxMapping<Domain, F>, 35 M: MinMaxMapping<Domain, F>,
38 Reg: RegTerm<Domain, F>, 36 Reg: RadonSquaredRegTerm<Domain, F>,
39 DiscreteMeasure<Domain, F>: SpikeMerging<F>, 37 DiscreteMeasure<Domain, F>: SpikeMerging<F>,
40 { 38 {
41 type ReturnMapping = M; 39 type ReturnMapping = M;
42 40
43 fn prox_type() -> ProxTerm { 41 fn prox_type() -> ProxTerm {
46 44
47 fn insert_and_reweigh<I>( 45 fn insert_and_reweigh<I>(
48 &self, 46 &self,
49 μ: &mut DiscreteMeasure<Domain, F>, 47 μ: &mut DiscreteMeasure<Domain, F>,
50 τv: &mut M, 48 τv: &mut M,
51 μ_base: &DiscreteMeasure<Domain, F>,
52 ν_delta: Option<&DiscreteMeasure<Domain, F>>,
53 τ: F, 49 τ: F,
54 ε: F, 50 ε: F,
55 config: &InsertionConfig<F>, 51 config: &InsertionConfig<F>,
56 reg: &Reg, 52 reg: &Reg,
57 _state: &AlgIteratorIteration<I>, 53 _state: &AlgIteratorIteration<I>,
58 stats: &mut IterInfo<F>, 54 stats: &mut IterInfo<F>,
59 ) -> DynResult<(Option<Self::ReturnMapping>, bool)> 55 ) -> DynResult<(Option<Self::ReturnMapping>, bool)>
60 where 56 where
61 I: AlgIterator, 57 I: AlgIterator,
62 { 58 {
63 let mut y = μ_base.masses_dvector(); 59 let violation = reg.find_tolerance_violation(τv, τ, ε, true, config);
64 60 reg.solve_oc_radonsq(μ, τv, τ, ε, violation, config, stats);
65 ensure!(μ_base.len() <= μ.len());
66
67 'i_and_w: for i in 0..=1 {
68 // Optimise weights
69 if μ.len() > 0 {
70 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
71 // from the beginning of the iteration are all contained in the immutable c and g.
72 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
73 // problems have not yet been updated to sign change.
74 let g̃ = DVector::from_iterator(
75 μ.len(),
76 μ.iter_locations()
77 .map(|ζ| -F::to_nalgebra_mixed(τv.apply(ζ))),
78 );
79 let mut x = μ.masses_dvector();
80 y.extend(std::iter::repeat(0.0.to_nalgebra_mixed()).take(0.max(x.len() - y.len())));
81 assert_eq!(y.len(), x.len());
82 // Solve finite-dimensional subproblem.
83 // TODO: This assumes that ν_delta has no common locations with μ-μ_base, to
84 // ignore it.
85 stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config);
86
87 // Update masses of μ based on solution of finite-dimensional subproblem.
88 μ.set_masses_dvector(&x);
89 }
90
91 if i > 0 {
92 // Simple debugging test to see if more inserts would be needed. Doesn't seem so.
93 //let n = μ.dist_matching(μ_base);
94 //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n));
95 break 'i_and_w;
96 }
97
98 // Calculate ‖μ - μ_base‖_ℳ
99 // TODO: This assumes that ν_delta has no common locations with μ-μ_base.
100 let n = μ.dist_matching(μ_base) + ν_delta.map_or(0.0, |ν| ν.norm(Radon));
101
102 // Find a spike to insert, if needed.
103 // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ,
104 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation.
105 match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) {
106 None => break 'i_and_w,
107 Some((ξ, _v_ξ, _in_bounds)) => {
108 // Weight is found out by running the finite-dimensional optimisation algorithm
109 // above
110 *μ += DeltaMeasure { x: ξ, α: 0.0 };
111 stats.inserted += 1;
112 }
113 };
114 }
115 61
116 Ok((None, true)) 62 Ok((None, true))
117 } 63 }
118 64
119 fn merge_spikes( 65 fn merge_spikes(
120 &self, 66 &self,
121 μ: &mut DiscreteMeasure<Domain, F>, 67 μ: &mut DiscreteMeasure<Domain, F>,
122 τv: &mut M, 68 τv: &mut M,
123 μ_base: &DiscreteMeasure<Domain, F>, 69 μ_base: &DiscreteMeasure<Domain, F>,
124 ν_delta: Option<&DiscreteMeasure<Domain, F>>,
125 τ: F, 70 τ: F,
126 ε: F, 71 ε: F,
127 config: &InsertionConfig<F>, 72 config: &InsertionConfig<F>,
128 reg: &Reg, 73 reg: &Reg,
129 fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>, 74 fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>,
139 // TODO: could simplify to requiring μ_base instead of μ_radon. 84 // TODO: could simplify to requiring μ_base instead of μ_radon.
140 // but may complicate with sliding base's exgtra points that need to be 85 // but may complicate with sliding base's exgtra points that need to be
141 // after μ_candidate's extra points. 86 // after μ_candidate's extra points.
142 // TODO: doesn't seem to work, maybe need to merge μ_base as well? 87 // TODO: doesn't seem to work, maybe need to merge μ_base as well?
143 // Although that doesn't seem to make sense. 88 // Although that doesn't seem to make sense.
144 let μ_radon = match ν_delta { 89 let μ_radon = μ_candidate.sub_matching(μ_base);
145 None => μ_candidate.sub_matching(μ_base),
146 Some(ν) => μ_candidate.sub_matching(μ_base) - ν,
147 };
148 reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon) 90 reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon)
149 //let n = μ_candidate.dist_matching(μ_base); 91 //let n = μ_candidate.dist_matching(μ_base);
150 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() 92 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()
151 }) 93 })
152 } 94 }

mercurial