Sat, 20 Feb 2021 16:26:02 -0500
Add DifferentiableFN
"""
    DifferentiableFN

Interface for differentiable functions `f : X → Y` whose differential at a
point is represented by a linear operator of type `T <: LinOp{X,Y}` (types
come from the sibling `LinOps` module).

Concrete subtypes of `DiffF` are expected to add methods for `value` and
`differential`. They must also add `adjoint_differential`, unless their
differential type is an `AdjointableOp`, in which case a default is provided.
"""
module DifferentiableFN

using ..LinOps

export DiffF,
       value,
       differential,
       adjoint_differential

"""
    DiffF{X,Y,T}

Abstract supertype of differentiable functions from `X` to `Y` whose
differential is a linear operator of type `T`.
"""
abstract type DiffF{X,Y,T} end

"""
    value(f::DiffF{X,Y,T}, x::X) -> Y

Evaluate `f` at `x`.

This is the fallback method: it throws an `ErrorException`. Concrete subtypes
of `DiffF` must add their own method.
"""
function value(f :: DiffF{X,Y,T}, x :: X) :: Y where {X, Y, T <: LinOp{X,Y}}
    # `error` (not the logging macro `@error`) so the call actually fails here,
    # instead of returning `nothing` into the `:: Y` conversion.
    error("`value` unimplemented")
end

# Uncommenting this would make every `DiffF` callable as `f(x)`, forwarding
# to `value(f, x)`:
#
# function (f :: DiffF{X,Y,T})(x :: X) where {X, Y, T <: LinOp{X,Y}}
#     return value(f, x)
# end

"""
    differential(f::DiffF{X,Y,T}, x::X) -> T

Return the differential of `f` at `x` as a linear operator of type `T`.

This is the fallback method: it throws an `ErrorException`. Concrete subtypes
of `DiffF` must add their own method.
"""
function differential(f :: DiffF{X,Y,T}, x :: X) :: T where {X, Y, T <: LinOp{X,Y}}
    error("`differential` unimplemented")
end

"""
    adjoint_differential(f::DiffF{X,Y,T}, x::X)

Return the adjoint of the differential of `f` at `x`.

This is the fallback for arbitrary `T <: LinOp{X,Y}`: it throws an
`ErrorException`. When `T <: AdjointableOp{X,Y}`, a more specific method
instead returns `AdjointOp(differential(f, x))`.
"""
function adjoint_differential(f :: DiffF{X,Y,T}, x :: X) where {X, Y, T <: LinOp{X,Y}}
    # Deliberately no `:: T` return annotation: the adjoint maps `Y` back to
    # `X`, so it is generally not of the forward operator type `T`.
    error("`adjoint_differential` unimplemented")
end

# Default adjoint for adjointable differentials: wrap the forward differential.
# No `:: T` annotation: `AdjointOp(...)` is (presumably) not itself a `T`, and
# forcing `convert(T, AdjointOp(...))` would fail at runtime.
# NOTE(review): this method is only preferred over the fallback above if
# `AdjointableOp{X,Y} <: LinOp{X,Y}` in `LinOps` — confirm there.
function adjoint_differential(f :: DiffF{X,Y,T}, x :: X) where {X, Y, T <: AdjointableOp{X,Y}}
    return AdjointOp(differential(f, x))
end

end