src/subproblem.rs

changeset 4
5aa5c279e341
parent 0
eb3c7813b67a
--- a/src/subproblem.rs	Thu Dec 01 23:07:35 2022 +0200
+++ b/src/subproblem.rs	Sun Sep 25 21:45:56 2022 +0300
@@ -1,7 +1,19 @@
 //! Iterative algorithms for solving finite-dimensional subproblems.
 
 use serde::{Serialize, Deserialize};
-use nalgebra::{DVector, DMatrix};
+use nalgebra::{
+    Matrix,
+    Vector,
+    Storage,
+    StorageMut,
+    Const,
+    Dim,
+    DVector,
+    DMatrix,
+    DefaultAllocator,
+    Dynamic,
+};
+use nalgebra::base::allocator::Allocator;
 use numeric_literals::replace_float_literals;
 use itertools::{izip, Itertools};
 use colored::Colorize;
@@ -80,25 +92,33 @@
 /// The `λ` component of the model is handled in the proximal step instead of the gradient step
 /// for potential performance improvements.
 #[replace_float_literals(F::cast_from(literal).to_nalgebra_mixed())]
-pub fn quadratic_nonneg_fb<F, I>(
-    mA : &DMatrix<F::MixedType>,
-    g : &DVector<F::MixedType>,
+pub fn quadratic_nonneg_fb<F, I, S, SA, SG, D>(
+    mA : &Matrix<F::MixedType, D, D, SA>,
+    g : &Vector<F::MixedType, D, SG>,
     //c_ : F,
     λ_ : F,
-    x : &mut DVector<F::MixedType>,
+    x : &mut Vector<F::MixedType, D, S>,
     τ_ : F,
     iterator : I
 ) -> usize
 where F : Float + ToNalgebraRealField,
-      I : AlgIteratorFactory<F>
+      I : AlgIteratorFactory<F>,
+      D : Dim,
+      S : StorageMut<F::MixedType, D, Const::<1>>,
+      SA : Storage<F::MixedType, D, D>,
+      SG : Storage<F::MixedType, D, Const::<1>>,
+      DefaultAllocator : Allocator<F::MixedType, D>
 {
-    let mut xprev = x.clone();
+    let mut xprev = x.clone_owned();
     //let c = c_.to_nalgebra_mixed();
     let λ = λ_.to_nalgebra_mixed();
     let τ = τ_.to_nalgebra_mixed();
     let τλ = τ * λ;
-    let mut v = DVector::zeros(x.len());
     let mut iters = 0;
+    let mut v = {
+        let (r, c) = x.shape_generic();
+        Vector::zeros_generic(r, c)
+    };
 
     iterator.iterate(|state| {
         // Replace `x` with $x - τ[Ax-g]= [x + τg]- τAx$
@@ -195,26 +215,33 @@
 /// forward-backward step.
 /// </p>
 #[replace_float_literals(F::cast_from(literal).to_nalgebra_mixed())]
-pub fn quadratic_nonneg_ssn<F, I>(
-    mA : &DMatrix<F::MixedType>,
-    g : &DVector<F::MixedType>,
+pub fn quadratic_nonneg_ssn<F, I, S, SA, SG>(
+    mA : &Matrix<F::MixedType, Dynamic, Dynamic, SA>,
+    g : &Vector<F::MixedType, Dynamic, SG>,
     //c_ : F,
     λ_ : F,
-    x : &mut DVector<F::MixedType>,
+    x : &mut Vector<F::MixedType, Dynamic, S>,
     τ_ : F,
     iterator : I
 ) -> Result<usize, NumericalError>
-where F : Float + ToNalgebraRealField,
-      I : AlgIteratorFactory<F>
+where F : Float + ToNalgebraRealField<MixedType=F>,
+      I : AlgIteratorFactory<F>,
+      S : StorageMut<F::MixedType, Dynamic, Const::<1>>,
+      SA : Storage<F::MixedType, Dynamic, Dynamic>,
+      SG : Storage<F::MixedType, Dynamic, Const::<1>>,
 {
     let n = x.len();
-    let mut xprev = x.clone();
+    let mut xprev = x.clone_owned();
     let mut v = DVector::zeros(n);
     //let c = c_.to_nalgebra_mixed();
     let λ = λ_.to_nalgebra_mixed();
     let τ = τ_.to_nalgebra_mixed();
     let τλ = τ * λ;
     let mut inact : Vec<bool> = Vec::from_iter(std::iter::repeat(false).take(n));
+    let mut v = {
+        let (r, c) = x.shape_generic();
+        Vector::zeros_generic(r, c)
+    };
     let mut s = DVector::zeros(0);
     let mut decomp = nalgebra::linalg::LU::new(DMatrix::zeros(0, 0));
     let mut iters = 0;
@@ -346,18 +373,21 @@
 ///    cone_, <https://doi.org/10.1080/02331934.2014.929680>.
 ///
 /// This function returns the number of iterations taken.
-pub fn quadratic_nonneg<F, I>(
+pub fn quadratic_nonneg<F, I, S, SA, SG>(
     method : InnerMethod,
-    mA : &DMatrix<F::MixedType>,
-    g : &DVector<F::MixedType>,
+    mA : &Matrix<F::MixedType, Dynamic, Dynamic, SA>,
+    g : &Vector<F::MixedType, Dynamic, SG>,
     //c_ : F,
     λ : F,
-    x : &mut DVector<F::MixedType>,
+    x : &mut Vector<F::MixedType, Dynamic, S>,
     τ : F,
     iterator : I
 ) -> usize
-where F : Float + ToNalgebraRealField,
-      I : AlgIteratorFactory<F>
+where F : Float + ToNalgebraRealField<MixedType=F>,
+      I : AlgIteratorFactory<F>,
+      S : StorageMut<F::MixedType, Dynamic, Const::<1>>,
+      SA : Storage<F::MixedType, Dynamic, Dynamic>,
+      SG : Storage<F::MixedType, Dynamic, Const::<1>>,
 {
     
     match method {

mercurial