diff -r a4137aedcb3a -r c3a4f4bb87f7 src/compose.py --- a/src/compose.py Thu Feb 26 09:32:12 2026 -0500 +++ b/src/compose.py Wed Apr 22 22:32:00 2026 -0500 @@ -32,19 +32,17 @@ res = max(res, f_i.diff_lipschitz_factor()) return res - def diff_bound(self, xbound=None): + def diff_bound(self): res = 0 for f_i in self.fnlist: - res = max(res, f_i.diff_bound(xbound=xbound)) + res = max(res, f_i.diff_bound()) return res class ComposeFnWithOperator: - def __init__(self, f, op, xbound=None, xbound_pair=None): + def __init__(self, f, op): self.f = f self.op = op - self.xbound = xbound - self.xbound_pair = xbound_pair def apply(self, *args): return self.f.apply(self.op.apply(*args)) @@ -61,40 +59,67 @@ (a, v) = self.f.apply_and_diff(w) return (a, self.op.diff_adjdir(v, *args, apply_result=w)) - def diff_lipschitz_factor(self): - # ‖∇A(x)^*∇F(A(x)) - ∇A(y)^*∇F(A(y))‖ - # = ‖[∇A(x)^*-∇A(y)^*]∇F(A(x)) - ∇A(y)^*[∇F(A(y))-∇F(A(x))]‖ - # ≤ L_{∇A(x)} M_{∇F} + M_{∇A(y)^*} L_{∇F}L_A. + def diff_bound(self): + mf = self.f.diff_bound() if hasattr(self.op, "opnorm"): - # Linear operator - lda = 0.0 # This is zero, - mdf = 0.0 # hence this not needed. + lda = self.op.opnorm() ** 2 else: - mdf = self.f.diff_bound(xbound=self.op.codomain_bound(xbound=self.xbound)) - lda = self.op.diff_adj_lipschitz_factor() + lda = self.op.diff_bound() + return lda * mf + + def diff_bound_pair(self): + mf = self.f.diff_bound() + lda1, lda2 = self.op.diff_bound_pair() + return lda1 * mf, lda2 * mf + + # def lipschitz_factor(self, xbound=None): + # if xbound is None: + # xbound = self.xbound + # lf = self.f.lipschitz_factor(xbound=self.op.codomain_bound(xbound=xbound)) + # la = self.op.lipschitz_factor(xbound=xbound) + # return lf * la - ldf = self.f.diff_lipschitz_factor() - la = self.op.lipschitz_factor() - mda = self.op.diff_bound(xbound=self.xbound) + # def lipschitz_factor_pair(self, xbound=None): + # if xbound is None: + # xbound = self.xbound + # lf = self.f.lipschitz_factor(xbound=self.op.codomain_bound(xbound=xbound)) + # la1, la2 = self.op.lipschitz_factor_pair(xbound=xbound) + # return lf * la1, lf * la2 - return lda * mdf + mda * ldf * la + def diff_lipschitz_factor(self): + """ + Calculate the Lipschitz factor of the differential of this composed function. + We assume that either the operator is linear and implementes `opnorm`, or it is nonlinear, and impliements `diff_chain_lipschitz_factor` to directly calculate a Lipschitz factor of $x ↦ ∇A(x)^*∇F(A(x))$, given a bound $M$ + on ∇F. The function `diff_chain_lipschitz_factor` should return the Lipschitz factor divided by $M$: we obtain $M$ through the `diff_bound` on $F$. + """ + + if hasattr(self.op, "opnorm"): + return self.f.diff_lipschitz_factor() * self.op.opnorm() ** 2 + else: + mdf = self.f.diff_bound() + lda = self.op.diff_chain_lipschitz_factor() + # print( + # "LDA %s %f; MDF %f; total %f" + # % (type(self.op).__name__, lda, mdf, mdf * lda), + # ) + return mdf * lda def diff_lipschitz_factor_pair(self): - if self.op.hasattr("opnorm"): - # Linear operator - lda1, lda2 = 0.0, 0.0 # This is zero, - mdf = 0.0 # hence this not needed. - else: - lda1, lda2 = self.op.diff_adj_lipschitz_factor_pair() - mdf = self.f.diff_bound( - xbound=self.op.codomain_bound_pair(xbound=self.xbound_pair) - ) + """ + This is similar to `diff_lipschitz_factor`, except separates the + factor for arguments pairs. + + This requires the operator to implement `diff_chain_lipschitz_factor`; + there is no special handling of linear opeartors. + """ - ldf = self.f.diff_lipschitz_factor() - la1, la2 = self.op.lipschitz_factor_pair() - mda = self.op.diff_bound_pair(xbound=self.xbound_pair) + mdf = self.f.diff_bound() + + lda1, lda2 = self.op.diff_chain_lipschitz_factor_pair() - return lda1 * mdf + mda * ldf * la1, lda2 * mdf + mda * ldf * la2 + # print("LDA %s %f %f; MDF %f; " % (type(self.op).__name__, lda1, lda2, mdf)) + + return mdf * lda1, mdf * lda2 class InjectSecond: @@ -112,17 +137,17 @@ def opnorm(self, *args): return 1.0 - def lipschitz_factor(self, *args): + # def lipschitz_factor(self): + # return 1.0 + + # def diff_chain_lipschitz_factor(self, *args): + # return 0.0 + + def diff_bound(self): return 1.0 - def diff_adj_lipschitz_factor(self, *args): - return 0.0 - - def diff_bound(self, xbound=None): - return 1.0 - - def codomain_bound(self, xbound=None): - if xbound is None: - raise Exception("Linear operators have unbounded range") - else: - return xbound + # def codomain_bound(self, xbound=None): + # if xbound is None: + # raise Exception("Linear operators have unbounded range") + # else: + # return xbound