src/bisection_tree/either.rs

Thu, 01 May 2025 13:06:58 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 01 May 2025 13:06:58 -0500
branch
dev
changeset 124
6aa955ad8122
parent 59
9226980e45a7
child 150
c4e394a9c84c
permissions
-rw-r--r--

Transpose loc parameters to allow f64 defaults

use std::iter::Chain;
use std::sync::Arc;

use crate::iter::{MapF, MapZ, Mappable};
use crate::loc::Loc;
use crate::mapping::{DifferentiableImpl, DifferentiableMapping, Instance, Mapping, Space};
use crate::sets::Cube;
use crate::types::*;

use super::aggregator::*;
use super::support::*;

/// A structure for storing two [`SupportGenerator`]s summed/chained together.
///
/// This is needed to work with sums of different types of [`Support`]s.
///
/// The combined generator produces supports of type [`EitherSupport`]. Ids `0..n0`, where `n0`
/// is the support count of the first component, refer to the first component; the remaining
/// ids refer to the second component, shifted down by `n0`.
///
/// Both components are held in an [`Arc`], so cloning a `BothGenerators` only bumps reference
/// counts. In-place scalar operations clone a component only if it is shared.
#[derive(Debug, Clone)]
pub struct BothGenerators<A, B>(pub(super) Arc<A>, pub(super) Arc<B>);

/// A [`Support`] that can be either `A` or `B`.
///
/// This is needed to work with sums of different types of [`Support`]s.
///
/// Note the order of the type parameters. The *second* parameter `A` is the payload of
/// [`Left`][EitherSupport::Left], and the *first* parameter `B` is the payload of
/// [`Right`][EitherSupport::Right]. Accordingly, [`BothGenerators<G1, G2>`][BothGenerators]
/// produces `EitherSupport<G2::SupportType, G1::SupportType>`.
#[derive(Debug, Clone)]
pub enum EitherSupport<B, A> {
    /// A support of type `A`, originating from the first (“left”) component.
    Left(A),
    /// A support of type `B`, originating from the second (“right”) component.
    Right(B),
}

// The iterator type of `<BothGenerators<G1, G2> as SupportGenerator<N, F>>::all_data`: the
// mapped data of `G1` chained with the mapped data of `G2`, both yielding
// `(usize, EitherSupport<G2::SupportType, G1::SupportType>)`.
//
// The associated types are spelled out as fully qualified paths, so the alias needs no trait
// bounds on its parameters. Such bounds would not be enforced anyway; see the
// `type_alias_bounds` lint.
type BothAllDataIter<'a, F, G1, G2, const N: usize> = Chain<
    MapF<
        <G1 as SupportGenerator<N, F>>::AllDataIter<'a>,
        (
            usize,
            EitherSupport<
                <G2 as SupportGenerator<N, F>>::SupportType,
                <G1 as SupportGenerator<N, F>>::SupportType,
            >,
        ),
    >,
    MapZ<
        <G2 as SupportGenerator<N, F>>::AllDataIter<'a>,
        usize,
        (
            usize,
            EitherSupport<
                <G2 as SupportGenerator<N, F>>::SupportType,
                <G1 as SupportGenerator<N, F>>::SupportType,
            >,
        ),
    >,
>;

impl<G1, G2> BothGenerators<G1, G2> {
    /// Wraps an item of the first generator's data as [`EitherSupport::Left`], keeping its id.
    ///
    /// Helper for [`Self::all_left_data`].
    #[inline]
    fn map_left<F: Float, const N: usize>(
        (id, support): (G1::Id, G1::SupportType),
    ) -> (usize, EitherSupport<G2::SupportType, G1::SupportType>)
    where
        G1: SupportGenerator<N, F, Id = usize>,
        G2: SupportGenerator<N, F, Id = usize>,
    {
        // `G1::Id = usize`, so the id is already of the combined id type.
        (id, EitherSupport::Left(support))
    }

    /// Wraps an item of the second generator's data as [`EitherSupport::Right`], offsetting its
    /// id by `n0`, the support count of the first generator.
    ///
    /// Helper for [`Self::all_right_data`].
    #[inline]
    fn map_right<F: Float, const N: usize>(
        n0: &usize,
        (id, support): (G2::Id, G2::SupportType),
    ) -> (usize, EitherSupport<G2::SupportType, G1::SupportType>)
    where
        G1: SupportGenerator<N, F, Id = usize>,
        G2: SupportGenerator<N, F, Id = usize>,
    {
        // `G2::Id = usize`; ids of the second generator follow on from those of the first.
        (*n0 + id, EitherSupport::Right(support))
    }

    /// Calls [`SupportGenerator::all_data`] on the “left” support generator.
    ///
    /// Converts both the id and the [`Support`] into a form that corresponds to `BothGenerators`.
    /// The ids are passed through unchanged.
    #[inline]
    pub(super) fn all_left_data<F: Float, const N: usize>(
        &self,
    ) -> MapF<G1::AllDataIter<'_>, (usize, EitherSupport<G2::SupportType, G1::SupportType>)>
    where
        G1: SupportGenerator<N, F, Id = usize>,
        G2: SupportGenerator<N, F, Id = usize>,
    {
        self.0.all_data().mapF(Self::map_left)
    }

    /// Calls [`SupportGenerator::all_data`] on the “right” support generator.
    ///
    /// Converts both the id and the [`Support`] into a form that corresponds to `BothGenerators`.
    /// The ids are shifted up by the support count of the “left” generator.
    #[inline]
    pub(super) fn all_right_data<F: Float, const N: usize>(
        &self,
    ) -> MapZ<G2::AllDataIter<'_>, usize, (usize, EitherSupport<G2::SupportType, G1::SupportType>)>
    where
        G1: SupportGenerator<N, F, Id = usize>,
        G2: SupportGenerator<N, F, Id = usize>,
    {
        let n0 = self.0.support_count();
        self.1.all_data().mapZ(n0, Self::map_right)
    }
}

impl<F: Float, G1, G2, const N: usize> SupportGenerator<N, F> for BothGenerators<G1, G2>
where
    G1: SupportGenerator<N, F, Id = usize>,
    G2: SupportGenerator<N, F, Id = usize>,
{
    type Id = usize;
    type SupportType = EitherSupport<G2::SupportType, G1::SupportType>;
    type AllDataIter<'a>
        = BothAllDataIter<'a, F, G1, G2, N>
    where
        G1: 'a,
        G2: 'a;

    /// Returns the support with the combined id `id`.
    ///
    /// Ids below the support count `n0` of the first generator are forwarded to it unchanged.
    /// Other ids are shifted down by `n0` and forwarded to the second generator.
    #[inline]
    fn support_for(&self, id: Self::Id) -> Self::SupportType {
        let n0 = self.0.support_count();
        // Both component id types are `usize`, so no conversion is needed.
        if id < n0 {
            EitherSupport::Left(self.0.support_for(id))
        } else {
            EitherSupport::Right(self.1.support_for(id - n0))
        }
    }

    /// The total number of supports of both component generators.
    #[inline]
    fn support_count(&self) -> usize {
        self.0.support_count() + self.1.support_count()
    }

    /// Iterates over the data of the first generator followed by that of the second.
    ///
    /// Ids are numbered consistently with [`Self::support_for`].
    #[inline]
    fn all_data(&self) -> Self::AllDataIter<'_> {
        self.all_left_data().chain(self.all_right_data())
    }
}

/// An [`EitherSupport`] behaves as the [`Support`] of whichever variant it holds.
///
/// The `match`es bind the payload by reference through match ergonomics. An explicit `ref`
/// there would be redundant, and the Rust 2024 edition rejects it.
impl<F: Float, S1, S2, const N: usize> Support<N, F> for EitherSupport<S2, S1>
where
    S1: Support<N, F>,
    S2: Support<N, F>,
{
    #[inline]
    fn support_hint(&self) -> Cube<N, F> {
        match self {
            EitherSupport::Left(a) => a.support_hint(),
            EitherSupport::Right(b) => b.support_hint(),
        }
    }

    #[inline]
    fn in_support(&self, x: &Loc<N, F>) -> bool {
        match self {
            EitherSupport::Left(a) => a.in_support(x),
            EitherSupport::Right(b) => b.in_support(x),
        }
    }

    #[inline]
    fn bisection_hint(&self, cube: &Cube<N, F>) -> [Option<F>; N] {
        match self {
            EitherSupport::Left(a) => a.bisection_hint(cube),
            EitherSupport::Right(b) => b.bisection_hint(cube),
        }
    }
}

/// Local analysis of an [`EitherSupport`] delegates to the variant it holds.
///
/// The payload is bound by reference through match ergonomics. An explicit `ref` would be
/// redundant, and the Rust 2024 edition rejects it.
impl<F: Float, A, S1, S2, const N: usize> LocalAnalysis<F, A, N> for EitherSupport<S2, S1>
where
    A: Aggregator,
    S1: LocalAnalysis<F, A, N>,
    S2: LocalAnalysis<F, A, N>,
{
    #[inline]
    fn local_analysis(&self, cube: &Cube<N, F>) -> A {
        match self {
            EitherSupport::Left(a) => a.local_analysis(cube),
            EitherSupport::Right(b) => b.local_analysis(cube),
        }
    }
}

/// Global analysis of an [`EitherSupport`] delegates to the variant it holds.
///
/// The payload is bound by reference through match ergonomics. An explicit `ref` would be
/// redundant, and the Rust 2024 edition rejects it.
impl<F: Float, A, S1, S2> GlobalAnalysis<F, A> for EitherSupport<S2, S1>
where
    A: Aggregator,
    S1: GlobalAnalysis<F, A>,
    S2: GlobalAnalysis<F, A>,
{
    #[inline]
    fn global_analysis(&self) -> A {
        match self {
            EitherSupport::Left(a) => a.global_analysis(),
            EitherSupport::Right(b) => b.global_analysis(),
        }
    }
}

/// An [`EitherSupport`] is applied as a [`Mapping`] by applying the variant it holds.
///
/// Both variants must share the codomain `F`. The payload is bound by reference through match
/// ergonomics. An explicit `ref` would be redundant, and the Rust 2024 edition rejects it.
impl<F, S1, S2, X> Mapping<X> for EitherSupport<S2, S1>
where
    F: Space,
    X: Space,
    S1: Mapping<X, Codomain = F>,
    S2: Mapping<X, Codomain = F>,
{
    type Codomain = F;

    #[inline]
    fn apply<I: Instance<X>>(&self, x: I) -> F {
        match self {
            EitherSupport::Left(a) => a.apply(x),
            EitherSupport::Right(b) => b.apply(x),
        }
    }
}

/// The differential of an [`EitherSupport`] is that of the variant it holds.
///
/// Both variants must share the derivative domain `O`. The payload is bound by reference
/// through match ergonomics. An explicit `ref` would be redundant, and the Rust 2024 edition
/// rejects it.
impl<X, S1, S2, O> DifferentiableImpl<X> for EitherSupport<S2, S1>
where
    O: Space,
    X: Space,
    S1: DifferentiableMapping<X, DerivativeDomain = O>,
    S2: DifferentiableMapping<X, DerivativeDomain = O>,
{
    type Derivative = O;

    #[inline]
    fn differential_impl<I: Instance<X>>(&self, x: I) -> O {
        match self {
            EitherSupport::Left(a) => a.differential(x),
            EitherSupport::Right(b) => b.differential(x),
        }
    }
}

/// Implements a scalar binary operator and its compound-assignment counterpart for
/// [`BothGenerators`]. The scalar is applied to both component generators.
///
/// * The compound-assignment impl uses [`Arc::make_mut`] on each component. A component
///   shared with another `BothGenerators` is therefore cloned before being modified, which is
///   why `Clone` is required.
/// * The by-reference impl computes `&G1 op t` and `&G2 op t` and wraps each result in a
///   new [`Arc`].
macro_rules! make_either_scalarop_rhs {
    ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => {
        impl<F: Float, G1, G2> std::ops::$trait_assign<F> for BothGenerators<G1, G2>
        where
            G1: std::ops::$trait_assign<F> + Clone,
            G2: std::ops::$trait_assign<F> + Clone,
        {
            #[inline]
            fn $fn_assign(&mut self, t: F) {
                let left: &mut G1 = Arc::make_mut(&mut self.0);
                left.$fn_assign(t);
                let right: &mut G2 = Arc::make_mut(&mut self.1);
                right.$fn_assign(t);
            }
        }

        impl<'a, F: Float, G1, G2> std::ops::$trait<F> for &'a BothGenerators<G1, G2>
        where
            &'a G1: std::ops::$trait<F, Output = G1>,
            &'a G2: std::ops::$trait<F, Output = G2>,
        {
            type Output = BothGenerators<G1, G2>;

            #[inline]
            fn $fn(self, t: F) -> Self::Output {
                let left: &'a G1 = &self.0;
                let right: &'a G2 = &self.1;
                BothGenerators(Arc::new(left.$fn(t)), Arc::new(right.$fn(t)))
            }
        }
    };
}

make_either_scalarop_rhs!(Mul, mul, MulAssign, mul_assign);
make_either_scalarop_rhs!(Div, div, DivAssign, div_assign);

impl<G1, G2> std::ops::Neg for BothGenerators<G1, G2>
where
    G1: std::ops::Neg + Clone,
    G2: std::ops::Neg + Clone,
{
    type Output = BothGenerators<G1::Output, G2::Output>;
    #[inline]
    fn neg(self) -> Self::Output {
        BothGenerators(
            Arc::new(Arc::unwrap_or_clone(self.0).neg()),
            Arc::new(Arc::unwrap_or_clone(self.1).neg()),
        )
    }
}
/*
impl<'a, G1, G2> std::ops::Neg for &'a BothGenerators<G1, G2>
where &'a G1 : std::ops::Neg, &'a G2 : std::ops::Neg, {
    type Output = BothGenerators<<&'a G1 as std::ops::Neg>::Output,
                                 <&'a G2 as std::ops::Neg>::Output>;
    fn neg(self) -> Self::Output {
        BothGenerators(self.0.neg(), self.1.neg())
    }
}
*/

mercurial