DifferentiableFN sums and squares

Tue, 13 Apr 2021 15:51:51 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Tue, 13 Apr 2021 15:51:51 -0500
changeset 27
62c62f451a41
parent 26
f075aca8485b
child 28
ffd693c381f2

DifferentiableFN sums and squares

src/DifferentiableFN.jl file | annotate | diff | comparison | revisions
--- a/src/DifferentiableFN.jl	Tue Apr 13 15:51:28 2021 -0500
+++ b/src/DifferentiableFN.jl	Tue Apr 13 15:51:51 2021 -0500
@@ -1,12 +1,22 @@
+########################################################
+# Abstract class for differentiable functions that have
+# abstract LinOps as differentials. Also included are
+# sum composition and squaring.
+########################################################
+
 module DifferentiableFN
 
 using ..LinOps
+using ..Util
 
 export Func,
        DiffF,
        value,
        differential,
-       adjoint_differential
+       adjoint_differential,
+       Squared,
+       Summed,
+       SummedVOnly
 
 abstract type Func{X,Y} end
 abstract type DiffF{X,Y,T} <: Func{X, Y} end
@@ -31,4 +41,49 @@
     return AdjointOp(differential(f, x))
 end
 
+# Sum of differentiable functions
+
+struct Summed{X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}} <: DiffF{X, Y, T}
+    a :: D₁
+    b :: D₂
 end
+
+function value(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}}
+    return value(summed.a, x)+value(summed.b, x)
+end
+
+function differential(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}}
+    return differential(summed.a, x) + differential(summed.b, x)
+end
+
+# Sum of (non-differentiable) fucntions
+
+struct SummedVOnly{X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}} <: Func{X, Y}
+    a :: D₁
+    b :: D₂
+end
+
+function value(summed :: SummedVOnly{X, Y, D₁, D₂}, x :: X) where {X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}}
+    return value(summed.a, x)+value(summed.b, x)
+end
+
+# Squared norm of a differentiable function as a differentiable function
+
+ℝ = Float64
+ℝⁿ = Vector{Float64}
+
+struct Squared <: DiffF{ℝⁿ, ℝ, MatrixOp{ℝ}}
+    r :: DiffF{ℝⁿ, ℝⁿ, MatrixOp{ℝ}}
+end
+
+function value(sq :: Squared, x :: ℝⁿ)
+    return norm₂²(value(sq.r, x))/2
+end
+
+function differential(sq :: Squared, x :: ℝⁿ)
+    v = value(sq.r, x)
+    d = differential(sq.r, x)
+    return MatrixOp(v'*d.m)
+end
+
+end

mercurial