# HG changeset patch # User Tuomo Valkonen # Date 1618347111 18000 # Node ID 62c62f451a410394aedde8056b0d68565f23e788 # Parent f075aca8485be9056e7142206ecaa947c62783f0 DifferentiableFN sums and squares diff -r f075aca8485b -r 62c62f451a41 src/DifferentiableFN.jl --- a/src/DifferentiableFN.jl Tue Apr 13 15:51:28 2021 -0500 +++ b/src/DifferentiableFN.jl Tue Apr 13 15:51:51 2021 -0500 @@ -1,12 +1,22 @@ +######################################################## +# Abstract class for differentiable functions that have +# abstract LinOps as differentials. Also included are +# sum composition and squaring. +######################################################## + module DifferentiableFN using ..LinOps +using ..Util export Func, DiffF, value, differential, - adjoint_differential + adjoint_differential, + Squared, + Summed, + SummedVOnly abstract type Func{X,Y} end abstract type DiffF{X,Y,T} <: Func{X, Y} end @@ -31,4 +41,49 @@ return AdjointOp(differential(f, x)) end +# Sum of differentiable functions + +struct Summed{X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}} <: DiffF{X, Y, T} + a :: D₁ + b :: D₂ end + +function value(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}} + return value(summed.a, x)+value(summed.b, x) +end + +function differential(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}} + return differential(summed.a, x) + differential(summed.b, x) +end + +# Sum of (non-differentiable) fucntions + +struct SummedVOnly{X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}} <: Func{X, Y} + a :: D₁ + b :: D₂ +end + +function value(summed :: SummedVOnly{X, Y, D₁, D₂}, x :: X) where {X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}} + return value(summed.a, x)+value(summed.b, x) +end + +# Squared norm of a differentiable function as a differentiable function + +ℝ = Float64 +ℝⁿ = Vector{Float64} + +struct Squared <: DiffF{ℝⁿ, ℝ, MatrixOp{ℝ}} + r :: DiffF{ℝⁿ, ℝⁿ, MatrixOp{ℝ}} +end + +function value(sq :: Squared, x :: ℝⁿ) + return norm₂²(value(sq.r, x))/2 +end + +function differential(sq :: Squared, x :: ℝⁿ) + v = value(sq.r, x) + d = differential(sq.r, x) + return MatrixOp(v'*d.m) +end + +end