src/DifferentiableFN.jl

Wed, 22 Dec 2021 11:14:38 +0200

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 22 Dec 2021 11:14:38 +0200
changeset 35
d881275c6564
parent 27
62c62f451a41
child 37
f8be66557e0f
permissions
-rw-r--r--

Add metaprogramming tools and fast multidimensional loops.

########################################################
# Abstract types for differentiable functions whose
# differentials are abstract LinOps. Also included are
# sums of functions and (halved) squared norms.
########################################################

module DifferentiableFN

using ..LinOps
using ..Util

export Func,
       DiffF,
       value,
       differential,
       adjoint_differential,
       Squared,
       Summed,
       SummedVOnly

"""
    Func{X, Y}

Abstract supertype of functions from `X` to `Y`. Concrete subtypes provide a
`value(f, x)` method.
"""
abstract type Func{X, Y} end

"""
    DiffF{X, Y, T}

Abstract supertype of differentiable functions from `X` to `Y`. The third
parameter `T` is the type of their differentials, expected to be a
`LinOp{X, Y}` (not enforced by this declaration). Concrete subtypes provide
`value` and `differential` methods.
"""
abstract type DiffF{X, Y, T} <: Func{X, Y} end

"""
    value(f, x)

Return the value of the function `f` at `x`.

This is the fallback for `Func` subtypes that do not define their own
`value` method. It always throws an `ErrorException` naming the offending
type.
"""
function value(f :: F, x :: X) :: Y where {X, Y, F <: Func{X,Y}}
    # Throw instead of logging with `@error`: logging falls through and
    # returns `nothing`, which then fails the `:: Y` conversion with a
    # misleading error far from the real cause.
    error("`value` unimplemented for $(typeof(f))")
end

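# Disabled sketch of call syntax `f(x)` for differentiable functions;
# note that the body must pass `f` itself on to `value`:
# function (f :: D)(x::X) where {X, Y, T <: LinOp{X, Y}, D <: DiffF{X, Y, T}}
#     return value(f, x)
# end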

"""
    differential(f, x)

Return the differential of `f` at `x`, an operator of type `T`.

This is the fallback for `DiffF` subtypes that do not define their own
`differential` method. It always throws an `ErrorException` naming the
offending type.
"""
function differential(f :: DiffF{X,Y,T}, x :: X) :: T where {X, Y, T <: LinOp{X, Y}}
    # Throw instead of logging with `@error`: logging falls through and
    # returns `nothing`, which then fails the `:: T` conversion with a
    # misleading error.
    error("`differential` unimplemented for $(typeof(f))")
end

"""
    adjoint_differential(f, x)

Return the adjoint of the differential of `f` at `x`.

This is the fallback for differential types `T` that are not covered by a
more specific method. It always throws an `ErrorException` naming the
offending type.
"""
function adjoint_differential(f :: DiffF{X,Y,T}, x :: X) where {X, Y, T <: LinOp{X, Y}}
    # No `:: T` return annotation: the adjoint maps Y -> X, so it is not
    # generally of the differential's type `T`.
    # Throw instead of logging with `@error`, which would fall through and
    # return `nothing`.
    error("`adjoint_differential` unimplemented for $(typeof(f))")
end

"""
    adjoint_differential(f, x)

For functions whose differential type `T` is an `AdjointableOp`, return the
adjoint of `differential(f, x)` by wrapping it in `AdjointOp`.
"""
function adjoint_differential(f :: DiffF{X,Y, T}, x :: X) where {X, Y, T <: AdjointableOp{X, Y}}
    # No `:: T` return annotation. The wrapped adjoint maps Y -> X and is
    # presumably not itself a `T`, so forcing `convert(T, AdjointOp(...))`
    # would throw a MethodError. NOTE(review): confirm against LinOps.
    return AdjointOp(differential(f, x))
end

# Sum of differentiable functions

"""
    Summed(a, b)

Differentiable function whose value and differential at `x` are the sums of
those of `a` and `b`. Both summands share the domain `X`, codomain `Y` and
differential type `T`.
"""
struct Summed{X, Y, T,
              D₁ <: DiffF{X, Y, T},
              D₂ <: DiffF{X, Y, T}} <: DiffF{X, Y, T}
    a :: D₁   # first summand
    b :: D₂   # second summand
end

"""
    value(summed::Summed, x)

Evaluate both summands at `x` (first `a`, then `b`) and return the sum of
the two values.
"""
function value(summed :: Summed{X, Y, T, D₁, D₂},
               x :: X) where {X, Y, T,
                              D₁ <: DiffF{X, Y, T},
                              D₂ <: DiffF{X, Y, T}}
    first_term = value(summed.a, x)
    second_term = value(summed.b, x)
    return first_term + second_term
end

"""
    differential(summed::Summed, x)

Return the differential of a sum: the sum, via `+` on the operators, of the
summands' differentials at `x`.
"""
function differential(summed :: Summed{X, Y, T, D₁, D₂},
                      x :: X) where {X, Y, T,
                                     D₁ <: DiffF{X, Y, T},
                                     D₂ <: DiffF{X, Y, T}}
    da = differential(summed.a, x)
    db = differential(summed.b, x)
    return da + db
end

# Sum of (not necessarily differentiable) functions; only `value` is provided

"""
    SummedVOnly(a, b)

Sum of two functions `X -> Y` that need not be differentiable. Only its
value can be evaluated.
"""
struct SummedVOnly{X, Y,
                   D₁ <: Func{X, Y},
                   D₂ <: Func{X, Y}} <: Func{X, Y}
    a :: D₁   # first summand
    b :: D₂   # second summand
end

"""
    value(summed::SummedVOnly, x)

Evaluate both summands at `x` (first `a`, then `b`) and return the sum of
the two values.
"""
function value(summed :: SummedVOnly{X, Y, D₁, D₂},
               x :: X) where {X, Y,
                              D₁ <: Func{X, Y},
                              D₂ <: Func{X, Y}}
    first_term = value(summed.a, x)
    second_term = value(summed.b, x)
    return first_term + second_term
end

# Half the squared norm, ‖r(x)‖²/2, of a differentiable function r : ℝⁿ → ℝⁿ,
# as a differentiable function

# Shorthand type aliases used in the `Squared` definitions below. They are
# declared `const` so that the bindings cannot be rebound. Without `const`
# they would be untyped global variables, which the compiler must treat as `Any`.
const ℝ = Float64
const ℝⁿ = Vector{Float64}

"""
    Squared(r)

Differentiable function `ℝⁿ -> ℝ` whose value at `x` is half of
`norm₂²(value(r, x))`. Here `r` is a differentiable function `ℝⁿ -> ℝⁿ`
whose differentials are `MatrixOp{ℝ}`.

The field type is a type parameter (`R`) rather than the abstract type
`DiffF{ℝⁿ, ℝⁿ, MatrixOp{ℝ}}`. This lets code that accesses `sq.r`
specialize on the concrete residual type instead of working with an
abstractly typed field. `Squared(r)` still constructs an instance as before.
"""
struct Squared{R <: DiffF{ℝⁿ, ℝⁿ, MatrixOp{ℝ}}} <: DiffF{ℝⁿ, ℝ, MatrixOp{ℝ}}
    r :: R
end

"""
    value(sq::Squared, x)

Evaluate the inner function `sq.r` at `x` and return half of `norm₂²` of
the result.
"""
function value(sq :: Squared, x :: ℝⁿ)
    residual = value(sq.r, x)
    return norm₂²(residual) / 2
end

"""
    differential(sq::Squared, x)

Return the differential of `x ↦ norm₂²(r(x))/2` as a `MatrixOp`. This is the
chain-rule row `r(x)' * J`, where `J` is the matrix (field `m`) of the
differential of `r` at `x`. It takes `norm₂²(v)` to be `v'v`.
"""
function differential(sq :: Squared, x :: ℝⁿ)
    residual = value(sq.r, x)
    jacobian = differential(sq.r, x).m   # matrix of r's differential at x
    return MatrixOp(residual' * jacobian)
end

end

mercurial