src/DifferentiableFN.jl

changeset 27
62c62f451a41
parent 25
90f92ee9cb81
child 37
f8be66557e0f
equal deleted inserted replaced
26:f075aca8485b 27:62c62f451a41
1 ########################################################
2 # Abstract types for differentiable functions that have
3 # abstract LinOps as differentials. Also included are
4 # sums of functions and the (half) squared norm.
5 ########################################################
6
1 module DifferentiableFN 7 module DifferentiableFN
2 8
3 using ..LinOps 9 using ..LinOps
10 using ..Util
4 11
5 export Func, 12 export Func,
6 DiffF, 13 DiffF,
7 value, 14 value,
8 differential, 15 differential,
9 adjoint_differential 16 adjoint_differential,
17 Squared,
18 Summed,
19 SummedVOnly
10 20
"""
    Func{X, Y}

Abstract supertype of functions taking an argument of type `X` and producing
a result of type `Y`. Concrete subtypes are evaluated through `value(f, x)`.
"""
abstract type Func{X, Y} end

"""
    DiffF{X, Y, T} <: Func{X, Y}

Abstract supertype of differentiable functions from `X` to `Y`. The extra
parameter `T` is the type of their differential (presumably a `LinOps`
operator from `X` to `Y` — NOTE(review): confirm against `differential`).
"""
abstract type DiffF{X, Y, T} <: Func{X, Y} end
13 23
14 function value(f :: F, x :: X) :: Y where {X, Y, F <: Func{X,Y}} 24 function value(f :: F, x :: X) :: Y where {X, Y, F <: Func{X,Y}}
29 39
"""
    adjoint_differential(f::DiffF{X, Y, T}, x::X) where {X, Y, T <: AdjointableOp{X, Y}}

Return the adjoint of the differential of `f` at `x`, obtained by wrapping
`differential(f, x)` in `AdjointOp`. Only applicable when the differential
type `T` of `f` is an `AdjointableOp{X, Y}`.
"""
function adjoint_differential(f :: DiffF{X,Y, T}, x :: X) where {X, Y,
        T <: AdjointableOp{X, Y}}
    # Deliberately no `:: T` return annotation. `T` is the type of the
    # differential itself (an operator X → Y). Its adjoint goes Y → X and is
    # built by `AdjointOp`, so converting the result to `T` is
    # type-inconsistent and would generally throw.
    # NOTE(review): confirm against the `AdjointOp` definition in LinOps.
    return AdjointOp(differential(f, x))
end
33 43
44 # Sum of differentiable functions
45
"""
    Summed(a, b)

Sum of two differentiable functions `a` and `b`. Both share the domain `X`,
the codomain `Y` and the differential type `T`, and the sum is itself a
`DiffF{X, Y, T}`.
"""
struct Summed{X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}} <: DiffF{X, Y, T}
    a :: D₁  # first summand
    b :: D₂  # second summand
end
50
# Value of a sum: evaluate each summand at `x`, then add the two results.
function value(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T,
        D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}}
    first_term = value(summed.a, x)
    second_term = value(summed.b, x)
    return first_term + second_term
end
54
# Differential of a sum is the sum of the summands' differentials at `x`
# (linearity of the derivative). This relies on `+` being defined for the
# operator type involved.
function differential(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T,
        D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}}
    da = differential(summed.a, x)
    db = differential(summed.b, x)
    return da + db
end
58
59 # Sum of (not necessarily differentiable) functions; value only
60
"""
    SummedVOnly(a, b)

Sum of two functions `a` and `b` from `X` to `Y` that need not be
differentiable. Only a `value` method accompanies it here, so the sum is a
plain `Func{X, Y}` rather than a `DiffF`.
"""
struct SummedVOnly{X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}} <: Func{X, Y}
    a :: D₁  # first summand
    b :: D₂  # second summand
end
65
# Value of a value-only sum: add the values of both summands at `x`.
function value(summed :: SummedVOnly{X, Y, D₁, D₂}, x :: X) where {X, Y,
        D₁ <: Func{X, Y}, D₂ <: Func{X, Y}}
    lhs = value(summed.a, x)
    rhs = value(summed.b, x)
    return lhs + rhs
end
69
70 # Half the squared norm, ½‖r(x)‖², of a differentiable function r, as a differentiable function
71
# Shorthand aliases for the real scalar and real vector types used by
# `Squared`. Declared `const` so that references to them are type-stable and
# the names cannot be rebound by accident.
const ℝ = Float64
const ℝⁿ = Vector{Float64}
74
"""
    Squared(r)

Differentiable function from `ℝⁿ` to `ℝ` built from a differentiable
`r : ℝⁿ → ℝⁿ` whose differential is a `MatrixOp{ℝ}`. Its value is
`norm₂²(r(x)) / 2`, i.e. half the squared 2-norm of `r(x)` (assuming
`norm₂²` is the squared Euclidean norm).

The field type is the parameter `R` rather than the abstract
`DiffF{ℝⁿ, ℝⁿ, MatrixOp{ℝ}}`, so calls through `sq.r` can specialize on the
concrete inner function. `Squared(r)` still constructs as before, and
`::Squared` in method signatures still matches every instance.
"""
struct Squared{R <: DiffF{ℝⁿ, ℝⁿ, MatrixOp{ℝ}}} <: DiffF{ℝⁿ, ℝ, MatrixOp{ℝ}}
    r :: R
end
78
# Value of `Squared`: evaluate the inner function at `x`, then return half of
# `norm₂²` of that result (`norm₂²` presumably comes from Util).
function value(sq :: Squared, x :: ℝⁿ)
    residual = value(sq.r, x)
    return norm₂²(residual) / 2
end
82
# Differential of ½‖r(x)‖² at `x`: the row vector r(x)ᵀ·J, where J is the
# matrix held in the `m` field of `differential(r, x)`, wrapped as a
# `MatrixOp`.
# NOTE(review): `rx' * jac` is a lazy 1×n `Adjoint`; confirm that `MatrixOp`
# accepts it (otherwise wrap it with `Matrix(...)`).
function differential(sq :: Squared, x :: ℝⁿ)
    rx = value(sq.r, x)
    jac = differential(sq.r, x).m
    return MatrixOp(rx' * jac)
end
88
89 end

mercurial