src/compose.py

changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/compose.py	Thu Feb 26 09:32:12 2026 -0500
@@ -0,0 +1,128 @@
+import numpy as np
+
+
+class SumOfSeparableFunctions:
+    def __init__(self, fnlist):
+        self.fnlist = fnlist
+
+    def apply(self, x):
+        val = 0.0
+        for f_i, x_i in zip(self.fnlist, x):
+            val += f_i.apply(x_i)
+        return val
+
+    def diff(self, x):
+        d = []
+        for f_i, x_i in zip(self.fnlist, x):
+            d.append(f_i.diff(x_i))
+        return d
+
+    def apply_and_diff(self, x):
+        d = []
+        val = 0.0
+        for f_i, x_i in zip(self.fnlist, x):
+            (a, v) = f_i.apply_and_diff(x_i)
+            val += a
+            d.append(v)
+        return (val, d)
+
+    def diff_lipschitz_factor(self):
+        res = 0
+        for f_i in self.fnlist:
+            res = max(res, f_i.diff_lipschitz_factor())
+        return res
+
+    def diff_bound(self, xbound=None):
+        res = 0
+        for f_i in self.fnlist:
+            res = max(res, f_i.diff_bound(xbound=xbound))
+        return res
+
+
+class ComposeFnWithOperator:
+    def __init__(self, f, op, xbound=None, xbound_pair=None):
+        self.f = f
+        self.op = op
+        self.xbound = xbound
+        self.xbound_pair = xbound_pair
+
+    def apply(self, *args):
+        return self.f.apply(self.op.apply(*args))
+
+    def diff(self, *args):
+        # TODO: precalculations in apply should be used in diff_adjdir
+        w = self.op.apply(*args)
+        v = self.f.diff(w)
+        return self.op.diff_adjdir(v, *args, apply_result=w)
+
+    def apply_and_diff(self, *args):
+        # TODO: precalculations in apply should be used in diff_adjdir
+        w = self.op.apply(*args)
+        (a, v) = self.f.apply_and_diff(w)
+        return (a, self.op.diff_adjdir(v, *args, apply_result=w))
+
+    def diff_lipschitz_factor(self):
+        # ‖∇A(x)^*∇F(A(x)) - ∇A(y)^*∇F(A(y))‖
+        # = ‖[∇A(x)^*-∇A(y)^*]∇F(A(x)) - ∇A(y)^*[∇F(A(y))-∇F(A(x))]‖
+        # ≤ L_{∇A(x)} M_{∇F} + M_{∇A(y)^*} L_{∇F}L_A.
+        if hasattr(self.op, "opnorm"):
+            # Linear operator
+            lda = 0.0  # This is zero,
+            mdf = 0.0  # hence this not needed.
+        else:
+            mdf = self.f.diff_bound(xbound=self.op.codomain_bound(xbound=self.xbound))
+            lda = self.op.diff_adj_lipschitz_factor()
+
+        ldf = self.f.diff_lipschitz_factor()
+        la = self.op.lipschitz_factor()
+        mda = self.op.diff_bound(xbound=self.xbound)
+
+        return lda * mdf + mda * ldf * la
+
+    def diff_lipschitz_factor_pair(self):
+        if self.op.hasattr("opnorm"):
+            # Linear operator
+            lda1, lda2 = 0.0, 0.0  # This is zero,
+            mdf = 0.0  # hence this not needed.
+        else:
+            lda1, lda2 = self.op.diff_adj_lipschitz_factor_pair()
+            mdf = self.f.diff_bound(
+                xbound=self.op.codomain_bound_pair(xbound=self.xbound_pair)
+            )
+
+        ldf = self.f.diff_lipschitz_factor()
+        la1, la2 = self.op.lipschitz_factor_pair()
+        mda = self.op.diff_bound_pair(xbound=self.xbound_pair)
+
+        return lda1 * mdf + mda * ldf * la1, lda2 * mdf + mda * ldf * la2
+
+
+class InjectSecond:
+    def __init__(self, y):
+        self.y = y
+
+    def apply(self, x):
+        return (x, self.y)
+
+    def diff_adjdir(self, j, _x, apply_result=None):
+        return j[0]
+
+    # This is not really a linear operator, but for our purposes affine behave essentially
+    # the same
+    def opnorm(self, *args):
+        return 1.0
+
+    def lipschitz_factor(self, *args):
+        return 1.0
+
+    def diff_adj_lipschitz_factor(self, *args):
+        return 0.0
+
+    def diff_bound(self, xbound=None):
+        return 1.0
+
+    def codomain_bound(self, xbound=None):
+        if xbound is None:
+            raise Exception("Linear operators have unbounded range")
+        else:
+            return xbound

mercurial