| 6 |
6 |
| 7 use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD}; |
7 use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD}; |
| 8 use crate::dataterm::QuadraticDataTerm; |
8 use crate::dataterm::QuadraticDataTerm; |
| 9 use crate::forward_model::ForwardModel; |
9 use crate::forward_model::ForwardModel; |
| 10 use crate::measures::merging::SpikeMerging; |
10 use crate::measures::merging::SpikeMerging; |
| 11 use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon}; |
11 use crate::measures::{DiscreteMeasure, Radon}; |
| 12 use crate::regularisation::RegTerm; |
12 use crate::regularisation::RadonSquaredRegTerm; |
| 13 use crate::types::*; |
13 use crate::types::*; |
| 14 use alg_tools::bounds::MinMaxMapping; |
14 use alg_tools::bounds::MinMaxMapping; |
| 15 use alg_tools::error::DynResult; |
15 use alg_tools::error::DynResult; |
| 16 use alg_tools::instance::{Instance, Space}; |
16 use alg_tools::instance::{Instance, Space}; |
| 17 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration}; |
17 use alg_tools::iterate::{AlgIterator, AlgIteratorIteration}; |
| 18 use alg_tools::linops::BoundedLinear; |
18 use alg_tools::linops::BoundedLinear; |
| 19 use alg_tools::nalgebra_support::ToNalgebraRealField; |
19 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 20 use alg_tools::norms::{Norm, L2}; |
20 use alg_tools::norms::{Norm, L2}; |
| 21 use anyhow::ensure; |
|
| 22 use nalgebra::DVector; |
|
| 23 use numeric_literals::replace_float_literals; |
21 use numeric_literals::replace_float_literals; |
| 24 use serde::{Deserialize, Serialize}; |
22 use serde::{Deserialize, Serialize}; |
| 25 |
23 |
| 26 /// Radon-norm squared proximal penalty |
24 /// Radon-norm squared proximal penalty |
| 27 |
25 |
| 33 where |
31 where |
| 34 Domain: Space + Clone + PartialEq + 'static, |
32 Domain: Space + Clone + PartialEq + 'static, |
| 35 for<'a> &'a Domain: Instance<Domain>, |
33 for<'a> &'a Domain: Instance<Domain>, |
| 36 F: Float + ToNalgebraRealField, |
34 F: Float + ToNalgebraRealField, |
| 37 M: MinMaxMapping<Domain, F>, |
35 M: MinMaxMapping<Domain, F>, |
| 38 Reg: RegTerm<Domain, F>, |
36 Reg: RadonSquaredRegTerm<Domain, F>, |
| 39 DiscreteMeasure<Domain, F>: SpikeMerging<F>, |
37 DiscreteMeasure<Domain, F>: SpikeMerging<F>, |
| 40 { |
38 { |
| 41 type ReturnMapping = M; |
39 type ReturnMapping = M; |
| 42 |
40 |
| 43 fn prox_type() -> ProxTerm { |
41 fn prox_type() -> ProxTerm { |
| 46 |
44 |
| 47 fn insert_and_reweigh<I>( |
45 fn insert_and_reweigh<I>( |
| 48 &self, |
46 &self, |
| 49 μ: &mut DiscreteMeasure<Domain, F>, |
47 μ: &mut DiscreteMeasure<Domain, F>, |
| 50 τv: &mut M, |
48 τv: &mut M, |
| 51 μ_base: &DiscreteMeasure<Domain, F>, |
|
| 52 ν_delta: Option<&DiscreteMeasure<Domain, F>>, |
|
| 53 τ: F, |
49 τ: F, |
| 54 ε: F, |
50 ε: F, |
| 55 config: &InsertionConfig<F>, |
51 config: &InsertionConfig<F>, |
| 56 reg: &Reg, |
52 reg: &Reg, |
| 57 _state: &AlgIteratorIteration<I>, |
53 _state: &AlgIteratorIteration<I>, |
| 58 stats: &mut IterInfo<F>, |
54 stats: &mut IterInfo<F>, |
| 59 ) -> DynResult<(Option<Self::ReturnMapping>, bool)> |
55 ) -> DynResult<(Option<Self::ReturnMapping>, bool)> |
| 60 where |
56 where |
| 61 I: AlgIterator, |
57 I: AlgIterator, |
| 62 { |
58 { |
| 63 let mut y = μ_base.masses_dvector(); |
59 let violation = reg.find_tolerance_violation(τv, τ, ε, true, config); |
| 64 |
60 reg.solve_oc_radonsq(μ, τv, τ, ε, violation, config, stats); |
| 65 ensure!(μ_base.len() <= μ.len()); |
|
| 66 |
|
| 67 'i_and_w: for i in 0..=1 { |
|
| 68 // Optimise weights |
|
| 69 if μ.len() > 0 { |
|
| 70 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
|
| 71 // from the beginning of the iteration are all contained in the immutable c and g. |
|
| 72 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional |
|
| 73 // problems have not yet been updated to sign change. |
|
| 74 let g̃ = DVector::from_iterator( |
|
| 75 μ.len(), |
|
| 76 μ.iter_locations() |
|
| 77 .map(|ζ| -F::to_nalgebra_mixed(τv.apply(ζ))), |
|
| 78 ); |
|
| 79 let mut x = μ.masses_dvector(); |
|
| 80 y.extend(std::iter::repeat(0.0.to_nalgebra_mixed()).take(0.max(x.len() - y.len()))); |
|
| 81 assert_eq!(y.len(), x.len()); |
|
| 82 // Solve finite-dimensional subproblem. |
|
| 83 // TODO: This assumes that ν_delta has no common locations with μ-μ_base, to |
|
| 84 // ignore it. |
|
| 85 stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config); |
|
| 86 |
|
| 87 // Update masses of μ based on solution of finite-dimensional subproblem. |
|
| 88 μ.set_masses_dvector(&x); |
|
| 89 } |
|
| 90 |
|
| 91 if i > 0 { |
|
| 92 // Simple debugging test to see if more inserts would be needed. Doesn't seem so. |
|
| 93 //let n = μ.dist_matching(μ_base); |
|
| 94 //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); |
|
| 95 break 'i_and_w; |
|
| 96 } |
|
| 97 |
|
| 98 // Calculate ‖μ - μ_base‖_ℳ |
|
| 99 // TODO: This assumes that ν_delta has no common locations with μ-μ_base. |
|
| 100 let n = μ.dist_matching(μ_base) + ν_delta.map_or(0.0, |ν| ν.norm(Radon)); |
|
| 101 |
|
| 102 // Find a spike to insert, if needed. |
|
| 103 // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, |
|
| 104 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. |
|
| 105 match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { |
|
| 106 None => break 'i_and_w, |
|
| 107 Some((ξ, _v_ξ, _in_bounds)) => { |
|
| 108 // Weight is found out by running the finite-dimensional optimisation algorithm |
|
| 109 // above |
|
| 110 *μ += DeltaMeasure { x: ξ, α: 0.0 }; |
|
| 111 stats.inserted += 1; |
|
| 112 } |
|
| 113 }; |
|
| 114 } |
|
| 115 |
61 |
| 116 Ok((None, true)) |
62 Ok((None, true)) |
| 117 } |
63 } |
| 118 |
64 |
| 119 fn merge_spikes( |
65 fn merge_spikes( |
| 120 &self, |
66 &self, |
| 121 μ: &mut DiscreteMeasure<Domain, F>, |
67 μ: &mut DiscreteMeasure<Domain, F>, |
| 122 τv: &mut M, |
68 τv: &mut M, |
| 123 μ_base: &DiscreteMeasure<Domain, F>, |
69 μ_base: &DiscreteMeasure<Domain, F>, |
| 124 ν_delta: Option<&DiscreteMeasure<Domain, F>>, |
|
| 125 τ: F, |
70 τ: F, |
| 126 ε: F, |
71 ε: F, |
| 127 config: &InsertionConfig<F>, |
72 config: &InsertionConfig<F>, |
| 128 reg: &Reg, |
73 reg: &Reg, |
| 129 fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>, |
74 fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>, |
| 139 // TODO: could simplify to requiring μ_base instead of μ_radon. |
84 // TODO: could simplify to requiring μ_base instead of μ_radon. |
| 140 // but may complicate with sliding base's exgtra points that need to be |
85 // but may complicate with sliding base's exgtra points that need to be |
| 141 // after μ_candidate's extra points. |
86 // after μ_candidate's extra points. |
| 142 // TODO: doesn't seem to work, maybe need to merge μ_base as well? |
87 // TODO: doesn't seem to work, maybe need to merge μ_base as well? |
| 143 // Although that doesn't seem to make sense. |
88 // Although that doesn't seem to make sense. |
| 144 let μ_radon = match ν_delta { |
89 let μ_radon = μ_candidate.sub_matching(μ_base); |
| 145 None => μ_candidate.sub_matching(μ_base), |
|
| 146 Some(ν) => μ_candidate.sub_matching(μ_base) - ν, |
|
| 147 }; |
|
| 148 reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon) |
90 reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon) |
| 149 //let n = μ_candidate.dist_matching(μ_base); |
91 //let n = μ_candidate.dist_matching(μ_base); |
| 150 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() |
92 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() |
| 151 }) |
93 }) |
| 152 } |
94 } |