| 41 # ≤ L_{∇A(x)}(M_A+M_b) + M_{∇A(y)^*} L_A. |
41 # ≤ L_{∇A(x)}(M_A+M_b) + M_{∇A(y)^*} L_A. |
| 42 la = self.opA.lipschitz_factor() |
42 la = self.opA.lipschitz_factor() |
| 43 ma = self.opA.bound(xbound=xbound) |
43 ma = self.opA.bound(xbound=xbound) |
| 44 mb = self.b.norm() |
44 mb = self.b.norm() |
| 45 mda = self.opA.diff_bound(xbound=xbound) |
45 mda = self.opA.diff_bound(xbound=xbound) |
| 46 lda = self.opA.diff_adj_lipschitz_factor() |
46 lda = self.opA.diff_chain_lipschitz_factor() |
| 47 return self.λ * (lda * (ma + mb) + mda * la) |
47 return self.λ * (lda * (ma + mb) + mda * la) |
| 48 |
48 |
| 49 # def diff_lipschitz_factor_pair(self, *args): |
49 def diff_bound(self): |
| 50 # return self.diff_lipschitz_factor() |
50 if hasattr(self.opA, "opnorm"): |
| 51 |
51 opn = self.opA.opnorm() |
| 52 def diff_bound(self, xbound=None): |
52 return self.λ * opn * (opn * self.xbound + self.b_norm) |
| 53 return self.λ * (self.opA.codomain_bound(xbound=xbound) + self.b_norm) |
53 else: |
| |
54 raise Exception("Unimplemented") |