src/subproblem.rs

changeset 4
5aa5c279e341
parent 0
eb3c7813b67a
equal deleted inserted replaced
0:eb3c7813b67a 4:5aa5c279e341
1 //! Iterative algorithms for solving finite-dimensional subproblems. 1 //! Iterative algorithms for solving finite-dimensional subproblems.
2 2
3 use serde::{Serialize, Deserialize}; 3 use serde::{Serialize, Deserialize};
4 use nalgebra::{DVector, DMatrix}; 4 use nalgebra::{
5 Matrix,
6 Vector,
7 Storage,
8 StorageMut,
9 Const,
10 Dim,
11 DVector,
12 DMatrix,
13 DefaultAllocator,
14 Dynamic,
15 };
16 use nalgebra::base::allocator::Allocator;
5 use numeric_literals::replace_float_literals; 17 use numeric_literals::replace_float_literals;
6 use itertools::{izip, Itertools}; 18 use itertools::{izip, Itertools};
7 use colored::Colorize; 19 use colored::Colorize;
8 20
9 use alg_tools::iter::Mappable; 21 use alg_tools::iter::Mappable;
78 /// For detailed documentation of the inputs and outputs, refer to there. 90 /// For detailed documentation of the inputs and outputs, refer to there.
79 /// 91 ///
80 /// The `λ` component of the model is handled in the proximal step instead of the gradient step 92 /// The `λ` component of the model is handled in the proximal step instead of the gradient step
81 /// for potential performance improvements. 93 /// for potential performance improvements.
82 #[replace_float_literals(F::cast_from(literal).to_nalgebra_mixed())] 94 #[replace_float_literals(F::cast_from(literal).to_nalgebra_mixed())]
83 pub fn quadratic_nonneg_fb<F, I>( 95 pub fn quadratic_nonneg_fb<F, I, S, SA, SG, D>(
84 mA : &DMatrix<F::MixedType>, 96 mA : &Matrix<F::MixedType, D, D, SA>,
85 g : &DVector<F::MixedType>, 97 g : &Vector<F::MixedType, D, SG>,
86 //c_ : F, 98 //c_ : F,
87 λ_ : F, 99 λ_ : F,
88 x : &mut DVector<F::MixedType>, 100 x : &mut Vector<F::MixedType, D, S>,
89 τ_ : F, 101 τ_ : F,
90 iterator : I 102 iterator : I
91 ) -> usize 103 ) -> usize
92 where F : Float + ToNalgebraRealField, 104 where F : Float + ToNalgebraRealField,
93 I : AlgIteratorFactory<F> 105 I : AlgIteratorFactory<F>,
106 D : Dim,
107 S : StorageMut<F::MixedType, D, Const::<1>>,
108 SA : Storage<F::MixedType, D, D>,
109 SG : Storage<F::MixedType, D, Const::<1>>,
110 DefaultAllocator : Allocator<F::MixedType, D>
94 { 111 {
95 let mut xprev = x.clone(); 112 let mut xprev = x.clone_owned();
96 //let c = c_.to_nalgebra_mixed(); 113 //let c = c_.to_nalgebra_mixed();
97 let λ = λ_.to_nalgebra_mixed(); 114 let λ = λ_.to_nalgebra_mixed();
98 let τ = τ_.to_nalgebra_mixed(); 115 let τ = τ_.to_nalgebra_mixed();
99 let τλ = τ * λ; 116 let τλ = τ * λ;
100 let mut v = DVector::zeros(x.len());
101 let mut iters = 0; 117 let mut iters = 0;
118 let mut v = {
119 let (r, c) = x.shape_generic();
120 Vector::zeros_generic(r, c)
121 };
102 122
103 iterator.iterate(|state| { 123 iterator.iterate(|state| {
104 // Replace `x` with $x - τ[Ax-g]= [x + τg]- τAx$ 124 // Replace `x` with $x - τ[Ax-g]= [x + τg]- τAx$
105 v.copy_from(g); // v = g 125 v.copy_from(g); // v = g
106 v.axpy(1.0, x, τ); // v = x + τ*g 126 v.axpy(1.0, x, τ); // v = x + τ*g
193 /// We need to detect stopping by a subdifferential and return $x$ satisfying $x ≥ 0$, 213 /// We need to detect stopping by a subdifferential and return $x$ satisfying $x ≥ 0$,
194 /// which is in general not true for the SSN. We therefore use that $[G ∘ F](x^k)$ is a valid 214 /// which is in general not true for the SSN. We therefore use that $[G ∘ F](x^k)$ is a valid
195 /// forward-backward step. 215 /// forward-backward step.
196 /// </p> 216 /// </p>
197 #[replace_float_literals(F::cast_from(literal).to_nalgebra_mixed())] 217 #[replace_float_literals(F::cast_from(literal).to_nalgebra_mixed())]
198 pub fn quadratic_nonneg_ssn<F, I>( 218 pub fn quadratic_nonneg_ssn<F, I, S, SA, SG>(
199 mA : &DMatrix<F::MixedType>, 219 mA : &Matrix<F::MixedType, Dynamic, Dynamic, SA>,
200 g : &DVector<F::MixedType>, 220 g : &Vector<F::MixedType, Dynamic, SG>,
201 //c_ : F, 221 //c_ : F,
202 λ_ : F, 222 λ_ : F,
203 x : &mut DVector<F::MixedType>, 223 x : &mut Vector<F::MixedType, Dynamic, S>,
204 τ_ : F, 224 τ_ : F,
205 iterator : I 225 iterator : I
206 ) -> Result<usize, NumericalError> 226 ) -> Result<usize, NumericalError>
207 where F : Float + ToNalgebraRealField, 227 where F : Float + ToNalgebraRealField<MixedType=F>,
208 I : AlgIteratorFactory<F> 228 I : AlgIteratorFactory<F>,
229 S : StorageMut<F::MixedType, Dynamic, Const::<1>>,
230 SA : Storage<F::MixedType, Dynamic, Dynamic>,
231 SG : Storage<F::MixedType, Dynamic, Const::<1>>,
209 { 232 {
210 let n = x.len(); 233 let n = x.len();
211 let mut xprev = x.clone(); 234 let mut xprev = x.clone_owned();
212 let mut v = DVector::zeros(n); 235 let mut v = DVector::zeros(n);
213 //let c = c_.to_nalgebra_mixed(); 236 //let c = c_.to_nalgebra_mixed();
214 let λ = λ_.to_nalgebra_mixed(); 237 let λ = λ_.to_nalgebra_mixed();
215 let τ = τ_.to_nalgebra_mixed(); 238 let τ = τ_.to_nalgebra_mixed();
216 let τλ = τ * λ; 239 let τλ = τ * λ;
217 let mut inact : Vec<bool> = Vec::from_iter(std::iter::repeat(false).take(n)); 240 let mut inact : Vec<bool> = Vec::from_iter(std::iter::repeat(false).take(n));
241 let mut v = {
242 let (r, c) = x.shape_generic();
243 Vector::zeros_generic(r, c)
244 };
218 let mut s = DVector::zeros(0); 245 let mut s = DVector::zeros(0);
219 let mut decomp = nalgebra::linalg::LU::new(DMatrix::zeros(0, 0)); 246 let mut decomp = nalgebra::linalg::LU::new(DMatrix::zeros(0, 0));
220 let mut iters = 0; 247 let mut iters = 0;
221 248
222 let res = iterator.iterate_fallible(|state| { 249 let res = iterator.iterate_fallible(|state| {
344 /// 371 ///
345 /// * Valkonen T. - _A method for weighted projections to the positive definite 372 /// * Valkonen T. - _A method for weighted projections to the positive definite
346 /// cone_, <https://doi.org/10.1080/02331934.2014.929680>. 373 /// cone_, <https://doi.org/10.1080/02331934.2014.929680>.
347 /// 374 ///
348 /// This function returns the number of iterations taken. 375 /// This function returns the number of iterations taken.
349 pub fn quadratic_nonneg<F, I>( 376 pub fn quadratic_nonneg<F, I, S, SA, SG>(
350 method : InnerMethod, 377 method : InnerMethod,
351 mA : &DMatrix<F::MixedType>, 378 mA : &Matrix<F::MixedType, Dynamic, Dynamic, SA>,
352 g : &DVector<F::MixedType>, 379 g : &Vector<F::MixedType, Dynamic, SG>,
353 //c_ : F, 380 //c_ : F,
354 λ : F, 381 λ : F,
355 x : &mut DVector<F::MixedType>, 382 x : &mut Vector<F::MixedType, Dynamic, S>,
356 τ : F, 383 τ : F,
357 iterator : I 384 iterator : I
358 ) -> usize 385 ) -> usize
359 where F : Float + ToNalgebraRealField, 386 where F : Float + ToNalgebraRealField<MixedType=F>,
360 I : AlgIteratorFactory<F> 387 I : AlgIteratorFactory<F>,
388 S : StorageMut<F::MixedType, Dynamic, Const::<1>>,
389 SA : Storage<F::MixedType, Dynamic, Dynamic>,
390 SG : Storage<F::MixedType, Dynamic, Const::<1>>,
361 { 391 {
362 392
363 match method { 393 match method {
364 InnerMethod::FB => 394 InnerMethod::FB =>
365 quadratic_nonneg_fb(mA, g, λ, x, τ, iterator), 395 quadratic_nonneg_fb(mA, g, λ, x, τ, iterator),

mercurial