/*!
Proximal penalty abstraction
*/

use numeric_literals::replace_float_literals;
use alg_tools::types::*;
use serde::{Serialize, Deserialize};

use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::iterate::{
    AlgIteratorIteration,
    AlgIterator,
};
use crate::measures::RNDM;
use crate::types::{
    RefinementSettings,
    IterInfo,
};
use crate::subproblem::InnerSettings;
use crate::tolerance::Tolerance;
use crate::measures::merging::SpikeMergingMethod;
use crate::regularisation::RegTerm;

pub mod wave;
pub mod radon_squared;
pub use radon_squared::RadonSquared;

/// Settings for the solution of the stepwise optimality condition.
///
/// Because of `#[serde(default)]`, fields missing from a serialised configuration
/// are filled in from the [`Default`] implementation.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBGenericConfig<F : Float> {
    /// Tolerance for point insertion.
    pub tolerance : Tolerance<F>,

    /// Stop looking for a predual maximum (where to insert a new point) below
    /// `tolerance` multiplied by this factor.
    ///
    /// Not used by [`super::radon_fb`].
    pub insertion_cutoff_factor : F,

    /// Settings for branch and bound refinement when looking for predual maxima.
    pub refinement : RefinementSettings<F>,

    /// Maximum insertions within each outer iteration.
    ///
    /// Not used by [`super::radon_fb`].
    pub max_insertions : usize,

    /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
    ///
    /// Not used by [`super::radon_fb`].
    pub bootstrap_insertions : Option<(usize, usize)>,

    /// Inner method settings.
    pub inner : InnerSettings<F>,

    /// Spike merging method.
    pub merging : SpikeMergingMethod<F>,

    /// Tolerance multiplier for merges.
    pub merge_tolerance_mult : F,

    /// Merge spikes after last step (even if merging not generally enabled).
    pub final_merging : bool,

    /// Use fitness as merging criterion. Implies worse convergence guarantees.
    pub fitness_merging : bool,

    /// Iterations between merging heuristic tries.
    pub merge_every : usize,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for FBGenericConfig<F> {
    fn default() -> Self {
        FBGenericConfig {
            tolerance : Default::default(),
            insertion_cutoff_factor : 1.0,
            refinement : Default::default(),
            max_insertions : 100,
            //bootstrap_insertions : None,
            bootstrap_insertions : Some((10, 1)),
            inner : Default::default(),
            merging : Default::default(),
            final_merging : true,
            fitness_merging : false,
            merge_every : 10,
            merge_tolerance_mult : 2.0,
            // postprocessing : false,
        }
    }
}

impl<F : Float> FBGenericConfig<F> {
    /// Check if merging should be attempted this iteration.
    ///
    /// Returns `true` when `self.merging.enabled` is set and the current iteration
    /// number is a multiple of `self.merge_every`.
    ///
    /// A `merge_every` of zero disables periodic merging rather than panicking on a
    /// division by zero. The configuration can be deserialised from user input, so
    /// this value cannot be assumed to be non-zero.
    pub fn merge_now<I : AlgIterator>(&self, state : &AlgIteratorIteration<I>) -> bool {
        self.merging.enabled
            && state.iteration().checked_rem(self.merge_every) == Some(0)
    }

    /// Returns the final merging method.
    ///
    /// This is a copy of `self.merging` with `enabled` replaced by `self.final_merging`.
    /// The final merge can therefore run even if merging is not generally enabled.
    pub fn final_merging_method(&self) -> SpikeMergingMethod<F> {
        SpikeMergingMethod{ enabled : self.final_merging, ..self.merging}
    }
}


/// Trait for proximal penalties
pub trait ProxPenalty<F, PreadjointCodomain, Reg, const N : usize>
where
    F : Float + ToNalgebraRealField,
    Reg : RegTerm<F, N>,
{
    /// Type of the mapping optionally returned by [`Self::insert_and_reweigh`].
    type ReturnMapping : RealMapping<F, N>;

    /// Insert new spikes into `μ` to approximately satisfy optimality conditions
    /// with the forward step term fixed to `τv`.
    ///
    /// May return `τv + w` for `w` a subdifferential of the regularisation term `reg`,
    /// as well as an indication of whether the tolerance bounds `ε` are satisfied.
    ///
    /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same
    /// spike locations, while `ν_delta` may have different locations.
    ///
    /// `τv` is mutable to allow [`alg_tools::bisection_tree::btfn::BTFN`] refinement.
    /// Actual values of `τv` are not supposed to be mutated.
    fn insert_and_reweigh<I>(
        &self,
        μ : &mut RNDM<F, N>,
        τv : &mut PreadjointCodomain,
        μ_base : &RNDM<F, N>,
        ν_delta: Option<&RNDM<F, N>>,
        τ : F,
        ε : F,
        config : &FBGenericConfig<F>,
        reg : &Reg,
        state : &AlgIteratorIteration<I>,
        stats : &mut IterInfo<F, N>,
    ) -> (Option<Self::ReturnMapping>, bool)
    where
        I : AlgIterator;


    /// Merge spikes, if possible.
    ///
    /// Either optimality condition merging or objective value (fitness) merging
    /// may be used, the latter only if `fitness` is provided and `config.fitness_merging`
    /// is set.
    ///
    /// NOTE(review): the returned `usize` is presumably the number of merges
    /// performed; confirm against implementors.
    fn merge_spikes(
        &self,
        μ : &mut RNDM<F, N>,
        τv : &mut PreadjointCodomain,
        μ_base : &RNDM<F, N>,
        ν_delta: Option<&RNDM<F, N>>,
        τ : F,
        ε : F,
        config : &FBGenericConfig<F>,
        reg : &Reg,
        fitness : Option<impl Fn(&RNDM<F, N>) -> F>,
    ) -> usize;

    /// Merge spikes, if possible.
    ///
    /// Unlike [`Self::merge_spikes`], this variant only supports optimality condition
    /// based merging. It forwards to [`Self::merge_spikes`] with `fitness` set to `None`.
    #[inline]
    fn merge_spikes_no_fitness(
        &self,
        μ : &mut RNDM<F, N>,
        τv : &mut PreadjointCodomain,
        μ_base : &RNDM<F, N>,
        ν_delta: Option<&RNDM<F, N>>,
        τ : F,
        ε : F,
        config : &FBGenericConfig<F>,
        reg : &Reg,
    ) -> usize {
        // The `fitness` parameter of `merge_spikes` is `Option<impl Fn(..)>`, so a bare
        // `None` has no inferable type. Annotating it with a function-pointer type
        // (which implements `Fn`) supplies one without constructing a dummy closure.
        let no_fitness : Option<fn(&RNDM<F, N>) -> F> = None;
        self.merge_spikes(μ, τv, μ_base, ν_delta, τ, ε, config, reg, no_fitness)
    }
}
