/*!
Implementations of the distance-to-a-point function, and of half its square, on manifolds.
*/

use alg_tools::mapping::{Mapping, Instance};
use alg_tools::euclidean::Euclidean;
use crate::manifold::ManifoldPoint;
use crate::fb::{Grad, Desc, Prox};

/// The distance function `x ↦ d(p, x)` to a fixed base point `p`.
///
/// The base point `p` is stored in the public tuple field `0`.
pub struct DistTo<M>(pub M) where M : ManifoldPoint;

impl<M : ManifoldPoint> Mapping<M> for DistTo<M> {
    type Codomain = f64;

    /// Evaluates `d(p, x)`, where `p` is the stored base point `self.0`.
    fn apply<I : Instance<M>>(&self, x : I) -> Self::Codomain {
        let base = &self.0;
        x.eval(|y| base.dist_to(y))
    }
}

/// Half the squared distance to a fixed base point: `x ↦ d(p, x)² / 2`.
///
/// The base point `p` is stored in the public tuple field `0`. This module gives
/// the type an evaluation (`Mapping`), a gradient (`Grad`), and a gradient
/// descent step (`Desc`).
pub struct DistToSquaredDiv2<M : ManifoldPoint>(pub M);

impl<M : ManifoldPoint> Mapping<M> for DistToSquaredDiv2<M> {
    type Codomain = f64;

    /// Evaluates `d(p, x)² / 2`, where `p` is the stored base point `self.0`.
    fn apply<I : Instance<M>>(&self, x : I) -> Self::Codomain {
        let base = &self.0;
        let dist = x.eval(|y| base.dist_to(y));
        // Square first, then halve. Writing `0.5 * dist * dist` would change the
        // result when `dist * dist` overflows.
        dist * dist / 2.0
    }
}

impl<M : ManifoldPoint> Desc<M> for DistToSquaredDiv2<M> {
    /// Takes one Riemannian gradient-descent step of length parameter `τ` from `x`.
    ///
    /// The result is `exp_x(-τ · grad f(x))`, where `f(x) = d(x, p)² / 2`. With
    /// `grad f(x) = -log_x(p)` this equals `exp_x(τ · log_x(p))`, so `x` moves
    /// towards the base point `p`.
    fn desc(&self, τ : f64, x : M) -> M {
        // Descent steps go *against* the gradient.
        let t = self.grad(&x) * (-τ);
        x.exp(&t)
    }
}

impl<M : ManifoldPoint> Grad<M> for DistToSquaredDiv2<M> {
    /// Riemannian gradient of `f(x) = d(x, p)² / 2` at `x`, namely `-log_x(p)`.
    ///
    /// In the Euclidean case this reduces to the familiar `x - p`.
    fn grad(&self, x : &M) -> M::Tangent {
        // `x.log(&p)` points from `x` towards `p`, which is the direction in which
        // `f` decreases. The gradient is therefore its negation. Scaling by -1 is
        // exact in floating point.
        x.log(&self.0) * (-1.0_f64)
    }
}

impl<M : ManifoldPoint> Prox<M> for DistTo<M> {
    /// Proximal map of `τ · d(·, p)`: a manifold analogue of soft-thresholding.
    ///
    /// Moves `x` towards the base point `p` along `log_x(p)` by a tangent length
    /// of `τ`. If `log_x(p)` is no longer than `τ`, returns (a clone of) `p`.
    fn prox(&self, τ : f64, x : M) -> M {
        let towards_base = x.log(&self.0);
        let dist = towards_base.norm2();
        // The base point is within reach: the threshold swallows the whole distance.
        if dist <= τ {
            return self.0.clone();
        }
        let step = towards_base * (τ / dist);
        x.exp(&step)
    }
}
