/*!
Solver for the point source localisation problem using a simplified forward-backward splitting method.

Instead of the $𝒟$-norm of `fb.rs`, this uses the squared Radon norm for the proximal map.
*/

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
use nalgebra::DVector;

use alg_tools::iterate::{
    AlgIteratorIteration,
    AlgIterator
};
use alg_tools::norms::L2;
use alg_tools::linops::Mapping;
use alg_tools::bisection_tree::{
    BTFN,
    Bounds,
    BTSearch,
    SupportGenerator,
    LocalAnalysis,
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;

use crate::types::*;
use crate::measures::{
    RNDM,
    DeltaMeasure,
    Radon,
};
use crate::measures::merging::SpikeMerging;
use crate::regularisation::RegTerm;
use crate::forward_model::{
    ForwardModel,
    AdjointProductBoundedBy
};
use super::{
    FBGenericConfig,
    ProxPenalty,
};

/// Radon-norm squared proximal penalty.
///
/// A zero-sized marker type. Selecting it as the [`ProxPenalty`] makes the forward-backward
/// step use a squared Radon norm proximal term instead of the $𝒟$-norm.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct RadonSquared;

#[replace_float_literals(F::cast_from(literal))]
impl<F, GA, BTA, S, Reg, const N : usize>
ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for RadonSquared
where
    F : Float + ToNalgebraRealField,
    GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
    BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
    S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
    Reg : RegTerm<F, N>,
    RNDM<F, N> : SpikeMerging<F>,
{
    type ReturnMapping = BTFN<F, GA, BTA, N>;

    /// Re-optimise the masses of `μ` and insert at most one new spike.
    ///
    /// Runs at most two passes:
    /// 1. Solve the finite-dimensional Radon-squared subproblem
    ///    (`reg.solve_findim_l1squared`) for the masses of `μ`, using the masses of `μ_base`
    ///    as the base point. Then ask `reg.find_tolerance_violation_slack` for a point where
    ///    the tolerance is violated. If there is none, stop. Otherwise insert a zero-mass
    ///    spike there.
    /// 2. Re-solve the finite-dimensional subproblem so the new spike receives its weight.
    ///
    /// Inner iterations and insertions are accumulated into `stats`. Always returns
    /// `(None, true)`, i.e. no replacement mapping is produced.
    ///
    /// # Panics
    ///
    /// If `ν_delta` is `Some`: transport is not supported by this penalty.
    fn insert_and_reweigh<I>(
        &self,
        μ : &mut RNDM<F, N>,
        τv : &mut BTFN<F, GA, BTA, N>,
        μ_base : &RNDM<F, N>,
        ν_delta: Option<&RNDM<F, N>>,
        τ : F,
        ε : F,
        config : &FBGenericConfig<F>,
        reg : &Reg,
        _state : &AlgIteratorIteration<I>,
        stats : &mut IterInfo<F, N>,
    ) -> (Option<Self::ReturnMapping>, bool)
    where
        I : AlgIterator
    {
        assert!(ν_delta.is_none(), "Transport not implemented for Radon-squared prox term");

        // Base masses of the finite-dimensional subproblem, index-matched with the spikes of μ.
        // Kept as a plain vector because `DVector::push` copies instead of pushing in place.
        let mut y = μ_base.masses_vec();

        'i_and_w: for i in 0..=1 {
            // Optimise weights
            if μ.len() > 0 {
                // Spikes of μ beyond those of μ_base come after the base points and have zero
                // base mass. Pad `y` so that it lines up with the masses `x` of μ below.
                while y.len() < μ.len() {
                    y.push(0.0.to_nalgebra_mixed());
                }
                // Form finite-dimensional subproblem. The subproblem references to the original μ^k
                // from the beginning of the iteration are all contained in the immutable c and g.
                // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
                // problems have not yet been updated to sign change.
                let g̃ = DVector::from_iterator(μ.len(),
                                               μ.iter_locations()
                                                .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ))));
                let mut x = μ.masses_dvector();
                let yvec = DVector::from_column_slice(y.as_slice());
                // Solve finite-dimensional subproblem.
                stats.inner_iters += reg.solve_findim_l1squared(&yvec, &g̃, τ, &mut x, ε, config);

                // Update masses of μ based on solution of finite-dimensional subproblem.
                μ.set_masses_dvector(&x);
            }

            if i > 0 {
                // Simple debugging test to see if more inserts would be needed. Doesn't seem so.
                //let n = μ.dist_matching(μ_base);
                //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n));
                break 'i_and_w
            }

            // Calculate ‖μ - μ_base‖_ℳ
            let n = μ.dist_matching(μ_base);

            // Find a spike to insert, if needed.
            // This only checks the overall tolerances, not tolerances on support of μ-μ_base or μ,
            // which are supposed to have been guaranteed by the finite-dimensional weight optimisation.
            match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) {
                None => { break 'i_and_w },
                Some((ξ, _v_ξ, _in_bounds)) => {
                    // The weight is found by running the finite-dimensional optimisation
                    // above on the next pass; the new spike has zero base mass.
                    *μ += DeltaMeasure { x : ξ, α : 0.0 };
                    y.push(0.0.to_nalgebra_mixed());
                    stats.inserted += 1;
                }
            };
        }

        (None, true)
    }

    /// Merge spikes of `μ` according to `config.merging`.
    ///
    /// A merge candidate is accepted only if `reg.verify_merge_candidate_radonsq` accepts it.
    /// The check uses the index-matched difference of the candidate from `μ_base`.
    /// Returns the value produced by `SpikeMerging::merge_spikes`.
    /// NOTE(review): presumably the number of merged spikes; confirm.
    fn merge_spikes(
        &self,
        μ : &mut RNDM<F, N>,
        τv : &mut BTFN<F, GA, BTA, N>,
        μ_base : &RNDM<F, N>,
        τ : F,
        ε : F,
        config : &FBGenericConfig<F>,
        reg : &Reg,
    ) -> usize
    {
        μ.merge_spikes(config.merging, |μ_candidate| {
            // Important: μ_candidate's new points are afterwards,
            // and do not conflict with μ_base.
            // TODO: could simplify to requiring μ_base instead of μ_radon.
            // but may complicate with sliding base's extra points that need to be
            // after μ_candidate's extra points.
            // TODO: doesn't seem to work, maybe need to merge μ_base as well?
            // Although that doesn't seem to make sense.
            let μ_radon = μ_candidate.sub_matching(μ_base);
            reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, config, &μ_radon)
            //let n = μ_candidate.dist_matching(μ_base);
            //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()
        })
    }
}


/// Adjoint product bound for forward models paired with the [`RadonSquared`] penalty.
///
/// The bound is the square of the forward model's operator norm bound from the Radon
/// norm to L².
impl<F, A, const N : usize> AdjointProductBoundedBy<RNDM<F, N>, RadonSquared>
for A
where
    F : Float,
    A : ForwardModel<RNDM<F, N>, F>
{
    type FloatType = F;

    /// Returns `Some(‖A‖²)`, where `‖A‖` is `opnorm_bound(Radon, L2)`.
    fn adjoint_product_bound(&self, _penalty : &RadonSquared) -> Option<Self::FloatType> {
        let radon_to_l2 = self.opnorm_bound(Radon, L2);
        Some(radon_to_l2.powi(2))
    }
}
