| |
1 ######################################################## |
| |
2 # Abstract types for differentiable functions that have |
| |
3 # abstract LinOps as differentials. Also included are |
| |
4 # sums of such functions and halved squared norms. |
| |
5 ######################################################## |
| |
6 |
| 1 module DifferentiableFN |
7 module DifferentiableFN |
| 2 |
8 |
| 3 using ..LinOps |
9 using ..LinOps |
| |
10 using ..Util |
| 4 |
11 |
| 5 export Func, |
12 export Func, |
| 6 DiffF, |
13 DiffF, |
| 7 value, |
14 value, |
| 8 differential, |
15 differential, |
| 9 adjoint_differential |
16 adjoint_differential, |
| |
17 Squared, |
| |
18 Summed, |
| |
19 SummedVOnly |
| 10 |
20 |
"""
    Func{X, Y}

Abstract supertype of functions from `X` to `Y` that can be evaluated with `value`.
"""
abstract type Func{X, Y} end

"""
    DiffF{X, Y, T}

Abstract supertype of differentiable functions from `X` to `Y`. The extra parameter `T`
is the type of the differential returned by `differential` (a linear operator, e.g.
`MatrixOp{ℝ}` for `Squared`).
"""
abstract type DiffF{X, Y, T} <: Func{X, Y} end
| 13 |
23 |
| 14 function value(f :: F, x :: X) :: Y where {X, Y, F <: Func{X,Y}} |
24 function value(f :: F, x :: X) :: Y where {X, Y, F <: Func{X,Y}} |
| 29 |
39 |
"""
    adjoint_differential(f::DiffF{X, Y, T}, x::X) where {X, Y, T <: AdjointableOp{X, Y}}

Return the adjoint of the differential of `f` at `x`, wrapped in an `AdjointOp`.

The method is deliberately not annotated `:: T`. `T` is the type of the differential
itself (an operator from `X` to `Y`), whereas the value returned here is an `AdjointOp`
wrapper around it. A `:: T` annotation would make Julia `convert` the wrapper back to
`T`, which fails unless LinOps happens to define such a conversion.
"""
function adjoint_differential(f :: DiffF{X, Y, T}, x :: X) where {X, Y, T <: AdjointableOp{X, Y}}
    return AdjointOp(differential(f, x))
end
| 33 |
43 |
| |
44 # Sum of differentiable functions |
| |
45 |
| |
"""
    Summed(a, b)

Pointwise sum `x ↦ a(x) + b(x)` of two differentiable functions that share the domain
`X`, codomain `Y` and differential type `T`.
"""
struct Summed{X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}} <: DiffF{X, Y, T}
    a :: D₁
    b :: D₂
end

# Julia only generates the default outer constructor when every type parameter occurs in
# a field type. X, Y and T occur only in the bounds of D₁/D₂, so without this method
# `Summed(a, b)` is a MethodError. Here they are recovered from the argument types.
function Summed(a :: D₁, b :: D₂) where {X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}}
    return Summed{X, Y, T, D₁, D₂}(a, b)
end
| |
50 |
| |
"""
    value(summed::Summed, x)

Evaluate the first summand, then the second, at `x` and return their sum.
"""
function value(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}}
    first_term = value(summed.a, x)
    second_term = value(summed.b, x)
    return first_term + second_term
end
| |
54 |
| |
"""
    differential(summed::Summed, x)

Differential of a sum: the `+` of the two summands' differentials at `x`. This relies on
`+` being defined for the operator types involved (presumably provided by LinOps).
"""
function differential(summed :: Summed{X, Y, T, D₁, D₂}, x :: X) where {X, Y, T, D₁ <: DiffF{X, Y, T}, D₂ <: DiffF{X, Y, T}}
    da = differential(summed.a, x)
    db = differential(summed.b, x)
    return da + db
end
| |
58 |
| |
| 59 # Sum of (not necessarily differentiable) functions |
| |
60 |
| |
"""
    SummedVOnly(a, b)

Pointwise sum `x ↦ a(x) + b(x)` of two functions from `X` to `Y`. Only `value` is
defined for it here; the summands need not be differentiable.
"""
struct SummedVOnly{X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}} <: Func{X, Y}
    a :: D₁
    b :: D₂
end

# X and Y occur only in the bounds of D₁/D₂, so Julia does not generate the default outer
# constructor and `SummedVOnly(a, b)` would be a MethodError. This method recovers X and Y
# from the argument types.
function SummedVOnly(a :: D₁, b :: D₂) where {X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}}
    return SummedVOnly{X, Y, D₁, D₂}(a, b)
end
| |
65 |
| |
"""
    value(summed::SummedVOnly, x)

Evaluate the first summand, then the second, at `x` and return their sum.
"""
function value(summed :: SummedVOnly{X, Y, D₁, D₂}, x :: X) where {X, Y, D₁ <: Func{X, Y}, D₂ <: Func{X, Y}}
    lhs = value(summed.a, x)
    rhs = value(summed.b, x)
    return lhs + rhs
end
| |
69 |
| |
| 70 # Half the squared norm, ‖r(x)‖²/2, of a differentiable function r, as a differentiable function |
| |
71 |
| |
# Type aliases for the scalar and vector spaces used by `Squared`. They are `const` so
# they cannot be rebound at run time and are treated as constant types by the compiler
# (a plain module global is untyped and reassignable).
const ℝ = Float64
const ℝⁿ = Vector{ℝ}
| |
74 |
| |
"""
    Squared(r)

Half the squared norm `x ↦ norm₂²(r(x)) / 2` of a vector-valued differentiable function
`r : ℝⁿ → ℝⁿ`, viewed as a scalar-valued differentiable function with a `MatrixOp{ℝ}`
differential.
"""
struct Squared{R <: DiffF{ℝⁿ, ℝⁿ, MatrixOp{ℝ}}} <: DiffF{ℝⁿ, ℝ, MatrixOp{ℝ}}
    # Typed through the parameter R rather than the abstract `DiffF{...}`, so calls to
    # `value(sq.r, ...)` / `differential(sq.r, ...)` can be resolved at compile time.
    # `Squared(r)` still works: R is inferred from `r`.
    r :: R
end
| |
78 |
| |
"""
    value(sq::Squared, x::ℝⁿ)

Evaluate the inner function `sq.r` at `x` and return half of `norm₂²` of the result.
"""
function value(sq :: Squared, x :: ℝⁿ)
    residual = value(sq.r, x)
    half_sq = norm₂²(residual) / 2
    return half_sq
end
| |
82 |
| |
"""
    differential(sq::Squared, x::ℝⁿ)

Differential at `x` of `x ↦ ‖r(x)‖²/2`, where `r = sq.r`. By the chain rule this is the
row vector `r(x)ᵀ J`, returned as a 1×n `MatrixOp`. Here `J` is `d.m`, presumably the
matrix stored in the `MatrixOp` differential of `r` at `x`.
"""
function differential(sq :: Squared, x :: ℝⁿ)
    v = value(sq.r, x)
    d = differential(sq.r, x)
    # `v' * d.m` is a lazy `Adjoint` row vector, not a `Matrix`. Materialize it as a plain
    # 1×n `Matrix` so `MatrixOp` receives an ordinary matrix, matching the declared
    # differential type `MatrixOp{ℝ}` of `Squared`.
    return MatrixOp(Matrix(v' * d.m))
end
| |
88 |
| |
89 end |