| |
1 import numpy as np |
| |
2 |
| |
3 |
| |
class SumOfSeparableFunctions:
    """Separable sum ``F(x) = sum_i f_i(x_i)`` over the components of ``x``.

    ``fnlist[i]`` acts only on ``x[i]``. Functions and components are paired
    with ``zip``, so surplus entries on the longer side are silently ignored.
    """

    def __init__(self, fnlist):
        self.fnlist = fnlist

    def apply(self, x):
        """Return ``sum_i f_i(x_i)``; ``0.0`` when there are no components."""
        total = 0.0
        for component_fn, component_x in zip(self.fnlist, x):
            total += component_fn.apply(component_x)
        return total

    def diff(self, x):
        """Return the list of per-component gradients ``[f_i.diff(x_i), ...]``."""
        return [
            component_fn.diff(component_x)
            for component_fn, component_x in zip(self.fnlist, x)
        ]

    def apply_and_diff(self, x):
        """Return ``(F(x), gradients)`` using each component's ``apply_and_diff``.

        ``gradients`` is a list with one entry per paired component.
        """
        grads = []
        total = 0.0
        for component_fn, component_x in zip(self.fnlist, x):
            value, grad = component_fn.apply_and_diff(component_x)
            total += value
            grads.append(grad)
        return (total, grads)

    def diff_lipschitz_factor(self):
        """Return the largest component gradient-Lipschitz factor, floored at 0.

        Each f_i sees only its own component, so the largest individual
        factor is also a factor for the whole separable sum.
        """
        return max([0, *(fn.diff_lipschitz_factor() for fn in self.fnlist)])

    def diff_bound(self, xbound=None):
        """Return the largest component gradient bound, floored at 0.

        ``xbound`` is forwarded unchanged to every component's ``diff_bound``.
        """
        # NOTE(review): the gradient is a list of component gradients; in a
        # Euclidean product norm its size can exceed the largest component
        # bound (by up to sqrt(len(fnlist))). Confirm which norm callers assume.
        return max([0, *(fn.diff_bound(xbound=xbound) for fn in self.fnlist)])
| |
40 |
| |
41 |
| |
class ComposeFnWithOperator:
    """Composition ``x -> f(op(x))`` of a function ``f`` with an operator ``op``.

    ``f`` must provide ``apply``, ``diff``, ``apply_and_diff``,
    ``diff_lipschitz_factor`` and ``diff_bound``. ``op`` must provide ``apply``
    and ``diff_adjdir``, plus the Lipschitz/bound estimates used by the
    ``diff_lipschitz_factor*`` methods.

    An ``op`` that has an ``opnorm`` attribute is treated as linear. Its
    adjoint derivative is then taken not to vary with the point, so the
    corresponding Lipschitz term is dropped.

    Parameters
    ----------
    f : object
        Outer function.
    op : object
        Inner operator.
    xbound : optional
        Input bound forwarded to ``op.diff_bound`` and ``op.codomain_bound``.
    xbound_pair : optional
        Bound forwarded to ``op.diff_bound_pair`` and ``op.codomain_bound_pair``.
    """

    def __init__(self, f, op, xbound=None, xbound_pair=None):
        self.f = f
        self.op = op
        self.xbound = xbound
        self.xbound_pair = xbound_pair

    def _op_is_linear(self):
        """Return True when ``op`` marks itself as linear by exposing ``opnorm``."""
        return hasattr(self.op, "opnorm")

    def apply(self, *args):
        """Return ``f(op(*args))``."""
        return self.f.apply(self.op.apply(*args))

    def diff(self, *args):
        """Return the chain-rule gradient ``grad op(x)^* grad f(op(x))``."""
        # TODO: precalculations in apply should be used in diff_adjdir
        w = self.op.apply(*args)
        v = self.f.diff(w)
        # Pass the already computed op(x) so diff_adjdir need not redo it.
        return self.op.diff_adjdir(v, *args, apply_result=w)

    def apply_and_diff(self, *args):
        """Return ``(f(op(x)), grad op(x)^* grad f(op(x)))`` in one pass."""
        # TODO: precalculations in apply should be used in diff_adjdir
        w = self.op.apply(*args)
        (a, v) = self.f.apply_and_diff(w)
        return (a, self.op.diff_adjdir(v, *args, apply_result=w))

    def diff_lipschitz_factor(self):
        """Return a Lipschitz factor for the gradient of the composition.

        Combines the estimates provided by ``f`` and ``op`` according to the
        bound derived in the comment below.
        """
        # ‖∇A(x)^*∇F(A(x)) - ∇A(y)^*∇F(A(y))‖
        # = ‖[∇A(x)^*-∇A(y)^*]∇F(A(x)) - ∇A(y)^*[∇F(A(y))-∇F(A(x))]‖
        # ≤ L_{∇A(x)} M_{∇F} + M_{∇A(y)^*} L_{∇F}L_A.
        if self._op_is_linear():
            # For a linear operator L_{∇A^*} is zero, so M_{∇F} is not needed.
            lda = 0.0
            mdf = 0.0
        else:
            mdf = self.f.diff_bound(xbound=self.op.codomain_bound(xbound=self.xbound))
            lda = self.op.diff_adj_lipschitz_factor()

        ldf = self.f.diff_lipschitz_factor()
        la = self.op.lipschitz_factor()
        mda = self.op.diff_bound(xbound=self.xbound)

        return lda * mdf + mda * ldf * la

    def diff_lipschitz_factor_pair(self):
        """Return two gradient-Lipschitz factors built from ``op``'s pair estimates.

        Each factor uses the same bound as ``diff_lipschitz_factor``, with the
        matching entry of ``op.diff_adj_lipschitz_factor_pair()`` and
        ``op.lipschitz_factor_pair()``. NOTE(review): presumably one factor
        per argument of ``op``; confirm. ``op`` must provide
        ``lipschitz_factor_pair`` and ``diff_bound_pair`` even when linear.
        """
        # Previously ``self.op.hasattr("opnorm")``, which raised AttributeError
        # because operators have no ``hasattr`` method.
        if self._op_is_linear():
            # For a linear operator both adjoint-derivative factors are zero,
            # so the bound on ∇F is not needed.
            lda1, lda2 = 0.0, 0.0
            mdf = 0.0
        else:
            lda1, lda2 = self.op.diff_adj_lipschitz_factor_pair()
            mdf = self.f.diff_bound(
                xbound=self.op.codomain_bound_pair(xbound=self.xbound_pair)
            )

        ldf = self.f.diff_lipschitz_factor()
        la1, la2 = self.op.lipschitz_factor_pair()
        mda = self.op.diff_bound_pair(xbound=self.xbound_pair)

        return lda1 * mdf + mda * ldf * la1, lda2 * mdf + mda * ldf * la2
| |
98 |
| |
99 |
| |
class InjectSecond:
    """Operator ``x -> (x, y)`` that pairs its argument with a fixed value ``y``.

    The map is affine rather than linear, but its derivative is the constant
    injection ``h -> (h, 0)``. It exposes ``opnorm`` so that code treating
    operators with ``opnorm`` as linear also applies to it.
    """

    def __init__(self, y):
        self.y = y

    def apply(self, x):
        """Return the pair ``(x, y)``."""
        return (x, self.y)

    def diff_adjdir(self, j, _x, apply_result=None):
        """Apply the adjoint derivative to ``j``: keep only its first component.

        The point ``_x`` and ``apply_result`` are unused because the
        derivative is constant.
        """
        return j[0]

    # This is not really a linear operator, but for our purposes an affine
    # operator behaves essentially the same.
    def opnorm(self, *args):
        """Return the norm of the linear part ``h -> (h, 0)``, i.e. 1."""
        return 1.0

    def lipschitz_factor(self, *args):
        """Return the Lipschitz factor of the map, 1."""
        return 1.0

    def diff_adj_lipschitz_factor(self, *args):
        """Return 0: the adjoint derivative does not depend on the point."""
        return 0.0

    def diff_bound(self, xbound=None):
        """Return the bound 1 on the derivative norm, independent of ``xbound``."""
        return 1.0

    def codomain_bound(self, xbound=None):
        """Return a bound on the operator's output given an input bound.

        Raises
        ------
        ValueError
            If ``xbound`` is None, since no output bound exists then.
            ValueError subclasses Exception, so existing handlers still apply.
        """
        if xbound is None:
            raise ValueError("Linear operators have unbounded range")
        # NOTE(review): only the x component is bounded here; ``self.y`` is
        # ignored. Confirm this matches how callers interpret the bound.
        return xbound