|
########################################################
# Abstract types for differentiable functions whose
# differentials are abstract LinOps. Also included are
# sums of functions and the (half) squared norm of a
# differentiable function.
########################################################
|
6 |
1 module DifferentiableFN |
7 module DifferentiableFN |
2 |
8 |
3 using ..LinOps |
9 using ..LinOps |
|
10 using ..Util |
4 |
11 |
5 export Func, |
12 export Func, |
6 DiffF, |
13 DiffF, |
7 value, |
14 value, |
8 differential, |
15 differential, |
9 adjoint_differential |
16 adjoint_differential, |
|
17 Squared, |
|
18 Summed, |
|
19 SummedVOnly |
10 |
20 |
"""
    Func{X, Y}

Abstract supertype of functions taking an argument of type `X` and producing
a value of type `Y`. Concrete subtypes are evaluated through `value`.
"""
abstract type Func{X, Y} end

"""
    DiffF{X, Y, T}

Abstract supertype of differentiable functions from `X` to `Y`; `T` is the
type of the linear operator returned by `differential`.
"""
abstract type DiffF{X, Y, T} <: Func{X, Y} end
13 |
23 |
14 function value(f :: F, x :: X) :: Y where {X, Y, F <: Func{X,Y}} |
24 function value(f :: F, x :: X) :: Y where {X, Y, F <: Func{X,Y}} |
29 |
39 |
"""
    adjoint_differential(f::DiffF{X, Y, T}, x::X) where {X, Y, T <: AdjointableOp{X, Y}}

Return the adjoint of the differential of `f` at `x`, obtained by wrapping
`differential(f, x)` in an `AdjointOp`.

The result is intentionally not annotated `:: T`. `T` is the type of the
differential itself, an operator from `X` to `Y`, while its adjoint acts in
the opposite direction and is returned as an `AdjointOp` wrapper. A `:: T`
annotation would make Julia attempt `convert(T, AdjointOp(...))`, which in
general has no sensible meaning.
"""
function adjoint_differential(f :: DiffF{X, Y, T}, x :: X) where {X, Y, T <: AdjointableOp{X, Y}}
    return AdjointOp(differential(f, x))
end
33 |
43 |
|
44 # Sum of differentiable functions |
|
45 |
|
"""
    Summed(a, b)

Pointwise sum `x ↦ a(x) + b(x)` of two differentiable functions that share
the argument type `X`, value type `Y` and differential type `T`. The sum is
itself a `DiffF{X, Y, T}`.
"""
struct Summed{X, Y, T,
              D₁ <: DiffF{X, Y, T},
              D₂ <: DiffF{X, Y, T}} <: DiffF{X, Y, T}
    a :: D₁   # first summand
    b :: D₂   # second summand
end
|
50 |
|
"""
    value(summed::Summed, x)

Evaluate the sum at `x`: `value(summed.a, x) + value(summed.b, x)`.
"""
function value(summed :: Summed{X}, x :: X) where {X}
    first_term = value(summed.a, x)
    second_term = value(summed.b, x)
    return first_term + second_term
end
|
54 |
|
"""
    differential(summed::Summed, x)

Differential of the sum at `x`, i.e. the sum of the two summands'
differentials (linearity of differentiation).
"""
function differential(summed :: Summed{X}, x :: X) where {X}
    da = differential(summed.a, x)
    db = differential(summed.b, x)
    return da + db
end
|
58 |
|
# Sum of (not necessarily differentiable) functions
|
60 |
|
"""
    SummedVOnly(a, b)

Pointwise sum `x ↦ a(x) + b(x)` of two plain `Func`s with a common argument
type `X` and value type `Y`. Only evaluation is supported; no differential
is provided.
"""
struct SummedVOnly{X, Y,
                   D₁ <: Func{X, Y},
                   D₂ <: Func{X, Y}} <: Func{X, Y}
    a :: D₁   # first summand
    b :: D₂   # second summand
end
|
65 |
|
"""
    value(summed::SummedVOnly, x)

Evaluate the value-only sum at `x`: `value(summed.a, x) + value(summed.b, x)`.
"""
function value(summed :: SummedVOnly{X}, x :: X) where {X}
    lhs = value(summed.a, x)
    rhs = value(summed.b, x)
    return lhs + rhs
end
|
69 |
|
# Half the squared norm, x ↦ ‖r(x)‖²/2, of a differentiable function r, itself as a differentiable function
|
71 |
|
# Shorthand type aliases used by `Squared` below. They are declared `const`
# so the bindings cannot be rebound after struct/method signatures that
# mention them have been defined, and so uses inside function bodies are
# type-stable (non-const globals are treated as `Any`).
const ℝ = Float64
const ℝⁿ = Vector{Float64}
|
74 |
|
"""
    Squared(r)

Differentiable function `x ↦ norm₂²(r(x)) / 2` (half the squared norm of
`r(x)`), where `r` is a differentiable function `ℝⁿ → ℝⁿ` whose differential
is a `MatrixOp{ℝ}`. The result maps `ℝⁿ → ℝ` with a `MatrixOp{ℝ}` differential.

The concrete type of `r` is kept as the parameter `R`, so the field is
concretely typed. Calls on `sq.r` can then be specialized instead of going
through a boxed, abstractly typed field. `Squared(r)` still constructs it
(the default constructor infers `R`), and `sq :: Squared` still matches
every instance.
"""
struct Squared{R <: DiffF{ℝⁿ, ℝⁿ, MatrixOp{ℝ}}} <: DiffF{ℝⁿ, ℝ, MatrixOp{ℝ}}
    r :: R
end
|
78 |
|
"""
    value(sq::Squared, x::ℝⁿ)

Return `norm₂²(r(x)) / 2` where `r = sq.r`.
"""
function value(sq :: Squared, x :: ℝⁿ)
    residual = value(sq.r, x)
    return norm₂²(residual) / 2
end
|
82 |
|
"""
    differential(sq::Squared, x::ℝⁿ)

Differential of `x ↦ norm₂²(r(x)) / 2` at `x`. By the chain rule (assuming
`norm₂²` is the squared Euclidean norm) this is the row `r(x)' * Dr(x)`,
returned wrapped in a `MatrixOp`. `Dr(x)` is read from the `m` field of
`differential(sq.r, x)`.
"""
function differential(sq :: Squared, x :: ℝⁿ)
    res = value(sq.r, x)
    jac = differential(sq.r, x).m
    return MatrixOp(res' * jac)
end
|
88 |
|
89 end |