Wed, 22 Dec 2021 11:14:38 +0200
Add metaprogramming tools and fast multidimensional loops.
########################################################
# Abstract class for differentiable functions that have
# abstract LinOps as differentials. Also included are
# sum composition and squaring.
########################################################

module DifferentiableFN

using ..LinOps
using ..Util

export Func,
       DiffF,
       value,
       differential,
       adjoint_differential,
       Squared,
       Summed,
       SummedVOnly

# A function with argument type `X` and value type `Y`.
abstract type Func{X,Y} end

# A differentiable function with argument type `X` and value type `Y`,
# whose differential at a point is an operator of type `T`.
abstract type DiffF{X,Y,T} <: Func{X, Y} end

# Evaluate `f` at `x`.
#
# This is the generic fallback: concrete subtypes of `Func` are expected to
# provide their own method. It throws (rather than merely logging) so that a
# missing implementation fails at the call site with a clear message instead
# of returning `nothing` into the `:: Y` return-type conversion.
function value(f :: F, x :: X) :: Y where {X, Y, F <: Func{X,Y}}
    error("`value` unimplemented")
end

# Disabled call syntax for differentiable functions.
# NOTE(review): if re-enabled, the body should presumably be `value(f, x)`.
#
# function (f :: D)(x::X) where {X, Y, T <: LinOp{X, Y}, D <: DiffF{X, Y, T}}
#     return value(x)
# end

# Differential of `f` at `x`, as an operator of type `T`.
#
# Generic fallback; throws because concrete subtypes of `DiffF` must
# implement this themselves.
function differential(f :: DiffF{X,Y,T}, x :: X) :: T where {X, Y, T <: LinOp{X, Y}}
    error("`differential` unimplemented")
end

# Adjoint of the differential of `f` at `x`.
#
# Generic fallback for operator types that are not `AdjointableOp`s; throws
# because there is no generic way to form the adjoint here.
function adjoint_differential(f :: DiffF{X,Y,T}, x :: X) :: T where {X, Y, T <: LinOp{X, Y}}
    error("`adjoint_differential` unimplemented")
end

# Adjoint of the differential of `f` at `x` when the differential is an
# `AdjointableOp`: wrap `differential(f, x)` in an `AdjointOp`.
#
# NOTE(review): the `:: T` return annotation converts the `AdjointOp` to the
# type of the forward differential. Whether that conversion exists depends on
# definitions in `LinOps` that are not visible here -- confirm, and drop the
# annotation if `AdjointOp(...)` is not a `T`.
function adjoint_differential(f :: DiffF{X,Y, T}, x :: X) :: T where {X, Y, T <: AdjointableOp{X, Y}}
    return AdjointOp(differential(f, x))
end

# Sum of differentiable functions

# Pointwise sum `a + b` of two differentiable functions sharing the same
# argument, value and differential types.
struct Summed{X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}} <: DiffF{X, Y, T}
    a :: D₁
    b :: D₂
end

# Value of the sum: the sum of the values of both terms at `x`.
function value(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}}
    return value(summed.a, x) + value(summed.b, x)
end

# Differential of the sum: the sum of the differentials of both terms at `x`.
# Relies on `+` being defined for the differential type.
function differential(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}}
    return differential(summed.a, x) + differential(summed.b, x)
end

# Sum of (non-differentiable) functions

# Pointwise sum `a + b` of two functions for which only `value` is provided.
struct SummedVOnly{X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}} <: Func{X, Y}
    a :: D₁
    b :: D₂
end

# Value of the sum: the sum of the values of both terms at `x`.
function value(summed :: SummedVOnly{X, Y, D₁, D₂}, x :: X) where {X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}}
    return value(summed.a, x) + value(summed.b, x)
end

# Squared norm of a differentiable function as a differentiable function

# Type aliases; `const` so that they are true constants of the module rather
# than reassignable untyped globals.
const ℝ = Float64
const ℝⁿ = Vector{Float64}

# Half the squared norm of a vector-valued differentiable function `r`,
# i.e. x ↦ norm₂²(r(x))/2. The concrete type of `r` is a type parameter,
# as in `Summed`, so that the field is concretely typed; `Squared(r)`
# constructs it as before.
struct Squared{D <: DiffF{ℝⁿ, ℝⁿ, MatrixOp{ℝ}}} <: DiffF{ℝⁿ, ℝ, MatrixOp{ℝ}}
    r :: D
end

# Value at `x`: `norm₂²` applied to `r(x)`, divided by two.
function value(sq :: Squared, x :: ℝⁿ)
    return norm₂²(value(sq.r, x))/2
end

# Differential at `x`: the operator with matrix `r(x)' * m`, where `m` is
# read from the `m` field of the differential of `r` at `x`.
function differential(sq :: Squared, x :: ℝⁿ)
    residual = value(sq.r, x)
    dr = differential(sq.r, x)
    return MatrixOp(residual' * dr.m)
end

end