/*!
Basic definitions for data terms.
*/

use numeric_literals::replace_float_literals;

use alg_tools::loc::Loc;
use alg_tools::euclidean::Euclidean;
use alg_tools::linops::GEMV;
pub use alg_tools::norms::L1;
use alg_tools::norms::Norm;

use crate::types::*;
pub use crate::types::L2Squared;
use crate::measures::DiscreteMeasure;

/// Computes the residual $Aμ-b$ of the measure `μ` under the operator `opA`
/// against the data `b`.
///
/// A copy of `b` is updated in place through [`GEMV::gemv`], which evaluates
/// $y ← αAx + βy$; with $α=1$ and $β=-1$ this yields $Aμ-b$.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn calculate_residual<F, V, A, const N : usize>(
    μ : &DiscreteMeasure<Loc<F, N>, F>,
    opA : &A,
    b : &V,
) -> V
where
    F : Float,
    V : Euclidean<F> + Clone,
    A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>,
{
    let mut residual = b.clone();
    // residual ← 1·Aμ + (−1)·residual = Aμ − b
    opA.gemv(&mut residual, 1.0, μ, -1.0);
    residual
}

/// Calculates the residual $A(μ+μ_δ)-b$, where $μ_δ$ is `μ_delta`.
///
/// Two [`GEMV::gemv`] updates ($y ← αAx + βy$) are applied to a copy of `b`.
/// The first forms $Aμ-b$. The second adds $Aμ_δ$ on top of it. The second
/// update must keep the accumulated value ($β = 1$); using $β = -1$ would
/// instead produce $Aμ_δ - Aμ + b$.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn calculate_residual2<
    F : Float,
    V : Euclidean<F> + Clone,
    A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>,
    const N : usize
>(
    μ : &DiscreteMeasure<Loc<F, N>, F>,
    μ_delta : &DiscreteMeasure<Loc<F, N>, F>,
    opA : &A,
    b : &V
) -> V {
    let mut r = b.clone();
    // r ← 1·Aμ + (−1)·r = Aμ − b
    opA.gemv(&mut r, 1.0, μ, -1.0);
    // r ← 1·Aμ_δ + 1·r = A(μ + μ_δ) − b  (β = 1 keeps the accumulated residual)
    opA.gemv(&mut r, 1.0, μ_delta, 1.0);
    r
}


/// Data fidelity terms $F$ evaluated on residuals of type `V` for discrete
/// measures supported on `Loc<F, N>`.
#[replace_float_literals(F::cast_from(literal))]
pub trait DataTerm<F : Float, V, const N : usize> {
    /// Evaluates the data fidelity $F(y)$ at the residual $y$.
    fn calculate_fit(&self, _residual : &V) -> F;

    /// Evaluates $F(Aμ-b)$: forms the residual of `μ` under `opA` against `b`
    /// and passes it to [`DataTerm::calculate_fit`].
    fn calculate_fit_op<A>(
        &self,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
        opA : &A,
        b : &V,
    ) -> F
    where
        V : Euclidean<F> + Clone,
        A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>,
    {
        self.calculate_fit(&calculate_residual(μ, opA, b))
    }
}

impl<F, V, const N : usize> DataTerm<F, V, N> for L2Squared
where
    F : Float,
    V : Euclidean<F>,
{
    /// Squared-norm fidelity: returns half the squared Euclidean norm of the
    /// residual, as computed by `norm2_squared_div2`.
    fn calculate_fit(&self, residual : &V) -> F {
        let half_squared_norm = residual.norm2_squared_div2();
        half_squared_norm
    }
}


impl<F, V, const N : usize> DataTerm<F, V, N> for L1
where
    F : Float,
    V : Euclidean<F> + Norm<F, L1>,
{
    /// $\ell^1$ fidelity: returns the `L1` norm of the residual. The call is
    /// fully qualified so that the `Norm<F, L1>` implementation is selected.
    fn calculate_fit(&self, residual : &V) -> F {
        <V as Norm<F, L1>>::norm(residual, L1)
    }
}
