src/compose.py

changeset 3
c3a4f4bb87f7
parent 1
a4137aedcb3a
--- a/src/compose.py	Thu Feb 26 09:32:12 2026 -0500
+++ b/src/compose.py	Wed Apr 22 22:32:00 2026 -0500
@@ -32,19 +32,17 @@
             res = max(res, f_i.diff_lipschitz_factor())
         return res
 
-    def diff_bound(self, xbound=None):
+    def diff_bound(self):
         res = 0
         for f_i in self.fnlist:
-            res = max(res, f_i.diff_bound(xbound=xbound))
+            res = max(res, f_i.diff_bound())
         return res
 
 
 class ComposeFnWithOperator:
-    def __init__(self, f, op, xbound=None, xbound_pair=None):
+    def __init__(self, f, op):
         self.f = f
         self.op = op
-        self.xbound = xbound
-        self.xbound_pair = xbound_pair
 
     def apply(self, *args):
         return self.f.apply(self.op.apply(*args))
@@ -61,40 +59,67 @@
         (a, v) = self.f.apply_and_diff(w)
         return (a, self.op.diff_adjdir(v, *args, apply_result=w))
 
-    def diff_lipschitz_factor(self):
-        # ‖∇A(x)^*∇F(A(x)) - ∇A(y)^*∇F(A(y))‖
-        # = ‖[∇A(x)^*-∇A(y)^*]∇F(A(x)) - ∇A(y)^*[∇F(A(y))-∇F(A(x))]‖
-        # ≤ L_{∇A(x)} M_{∇F} + M_{∇A(y)^*} L_{∇F}L_A.
+    def diff_bound(self):
+        mf = self.f.diff_bound()
         if hasattr(self.op, "opnorm"):
-            # Linear operator
-            lda = 0.0  # This is zero,
-            mdf = 0.0  # hence this not needed.
+            lda = self.op.opnorm() ** 2
         else:
-            mdf = self.f.diff_bound(xbound=self.op.codomain_bound(xbound=self.xbound))
-            lda = self.op.diff_adj_lipschitz_factor()
+            lda = self.op.diff_bound()
+        return lda * mf
+
+    def diff_bound_pair(self):
+        mf = self.f.diff_bound()
+        lda1, lda2 = self.op.diff_bound_pair()
+        return lda1 * mf, lda2 * mf
+
+    # def lipschitz_factor(self, xbound=None):
+    #     if xbound is None:
+    #         xbound = self.xbound
+    #     lf = self.f.lipschitz_factor(xbound=self.op.codomain_bound(xbound=xbound))
+    #     la = self.op.lipschitz_factor(xbound=xbound)
+    #     return lf * la
 
-        ldf = self.f.diff_lipschitz_factor()
-        la = self.op.lipschitz_factor()
-        mda = self.op.diff_bound(xbound=self.xbound)
+    # def lipschitz_factor_pair(self, xbound=None):
+    #     if xbound is None:
+    #         xbound = self.xbound
+    #     lf = self.f.lipschitz_factor(xbound=self.op.codomain_bound(xbound=xbound))
+    #     la1, la2 = self.op.lipschitz_factor_pair(xbound=xbound)
+    #     return lf * la1, lf * la2
 
-        return lda * mdf + mda * ldf * la
+    def diff_lipschitz_factor(self):
+        """
+        Calculate the Lipschitz factor of the differential of this composed function.
+        We assume that either the operator is linear and implementes `opnorm`, or it is nonlinear, and impliements `diff_chain_lipschitz_factor` to directly calculate a Lipschitz factor of $x ↦ ∇A(x)^*∇F(A(x))$, given a bound $M$
+        on ∇F. The function `diff_chain_lipschitz_factor` should return the Lipschitz factor divided by $M$: we obtain $M$ through the `diff_bound` on $F$.
+        """
+
+        if hasattr(self.op, "opnorm"):
+            return self.f.diff_lipschitz_factor() * self.op.opnorm() ** 2
+        else:
+            mdf = self.f.diff_bound()
+            lda = self.op.diff_chain_lipschitz_factor()
+            # print(
+            #     "LDA %s %f; MDF %f; total %f"
+            #     % (type(self.op).__name__, lda, mdf, mdf * lda),
+            # )
+            return mdf * lda
 
     def diff_lipschitz_factor_pair(self):
-        if self.op.hasattr("opnorm"):
-            # Linear operator
-            lda1, lda2 = 0.0, 0.0  # This is zero,
-            mdf = 0.0  # hence this not needed.
-        else:
-            lda1, lda2 = self.op.diff_adj_lipschitz_factor_pair()
-            mdf = self.f.diff_bound(
-                xbound=self.op.codomain_bound_pair(xbound=self.xbound_pair)
-            )
+        """
+        This is similar to `diff_lipschitz_factor`, except separates the
+        factor for arguments pairs.
+
+        This requires the operator to implement `diff_chain_lipschitz_factor`;
+        there is no special handling of linear opeartors.
+        """
 
-        ldf = self.f.diff_lipschitz_factor()
-        la1, la2 = self.op.lipschitz_factor_pair()
-        mda = self.op.diff_bound_pair(xbound=self.xbound_pair)
+        mdf = self.f.diff_bound()
+
+        lda1, lda2 = self.op.diff_chain_lipschitz_factor_pair()
 
-        return lda1 * mdf + mda * ldf * la1, lda2 * mdf + mda * ldf * la2
+        # print("LDA %s %f %f; MDF %f;  " % (type(self.op).__name__, lda1, lda2, mdf))
+
+        return mdf * lda1, mdf * lda2
 
 
 class InjectSecond:
@@ -112,17 +137,17 @@
     def opnorm(self, *args):
         return 1.0
 
-    def lipschitz_factor(self, *args):
+    # def lipschitz_factor(self):
+    #     return 1.0
+
+    # def diff_chain_lipschitz_factor(self, *args):
+    #     return 0.0
+
+    def diff_bound(self):
         return 1.0
 
-    def diff_adj_lipschitz_factor(self, *args):
-        return 0.0
-
-    def diff_bound(self, xbound=None):
-        return 1.0
-
-    def codomain_bound(self, xbound=None):
-        if xbound is None:
-            raise Exception("Linear operators have unbounded range")
-        else:
-            return xbound
+    # def codomain_bound(self, xbound=None):
+    #     if xbound is None:
+    #         raise Exception("Linear operators have unbounded range")
+    #     else:
+    #         return xbound

mercurial