src/convection_diffusion.py

changeset 3
c3a4f4bb87f7
parent 1
a4137aedcb3a
--- a/src/convection_diffusion.py	Thu Feb 26 09:32:12 2026 -0500
+++ b/src/convection_diffusion.py	Wed Apr 22 22:32:00 2026 -0500
@@ -78,6 +78,9 @@
         beta=1.0,  # \beta = \| c \|_\infty^2 /k_m
         C_emb=1.0,
         C_reg=1.0,
+        xbound=XBound,
+        override_lipschitz=None,
+        override_lipschitz_pair=None,
     ):
         self.p = p  # L^p, p=2
         self.C_stab = C_stab  # adjoint stability constant
@@ -95,6 +98,7 @@
         self.ny = ny
         self.degree = degree
         self.T = T
+        self.xbound = xbound
 
         self.num_steps = num_steps
         self.dt = (T - t0) / num_steps
@@ -102,6 +106,9 @@
 
         self.save_for_plot = False
 
+        self.override_lipschitz = override_lipschitz
+        self.override_lipschitz_pair = override_lipschitz_pair
+
         domain = mesh.create_rectangle(
             MPI.COMM_SELF,
             [
@@ -247,8 +254,8 @@
         A.assemble()
 
         # Create reusable RHS vector (will be filled each timestep)
-        b = create_vector(linear_form)
-        b_ps = create_vector(linear_form)
+        b = create_vector(linear_form.function_spaces)
+        b_ps = create_vector(linear_form.function_spaces)
 
         # Prepare solver once
         solver = PETSc.KSP().create(domain.comm)
@@ -358,7 +365,7 @@
         A2.assemble()
 
         # Create vector for RHS to be updated each step
-        b2 = create_vector(linear_form_w)
+        b2 = create_vector(linear_form_w.function_spaces)
 
         solver = PETSc.KSP().create(domain.comm)
         solver.setOperators(A2)
@@ -513,164 +520,190 @@
             self.plot_u_n_list = u_n_list
             self.plot_w_n_list = w_n_list
             self.plot_j_n_list = j
+            self.plot_aux = (k, c)
 
         return wp_bar, (expr2, expr3)
 
     def do_plot(self, iter, μ, frames, prefix):
         n = self.num_steps
+        k, (c1, c2) = self.plot_aux
         np.savez_compressed(
             "%s/res_%d.npz" % (prefix, iter),
             wp_bar_array=self.plot_wp_bar.x.array,
             u_n_list_array=[self.plot_u_n_list[i].x.array for i in frames],
             w_n_list_array=[self.plot_w_n_list[i].x.array for i in frames],
             j_n_list_array=[self.plot_j_n_list[i].x.array for i in frames],
+            k=k,
+            c1=c1,
+            c2=c2,
             frames=frames,
             times=list(map(lambda i: self.T * i / (n - 1), frames)),
             mu=[[point[0], point[1], alpha] for point, alpha in μ.iter_padded()],
         )
+        # print("Aux variables: k: %f, c1: %f, c2: %f" % (k, c1, c2))
 
-    def _stability_prefactor(
-        self, *, C_stab=None, alpha=None, beta=None, T=None, C_emb=None, C_reg=None
-    ):
-        """
-        Return the stability prefactor
-        P := (C_stab/alpha) * (1 + sqrt(T)*(1 + sqrt(beta)/sqrt(alpha)))
-        and the measure-to-functional constant L_e_mu := C_emb * C_reg * sqrt(T).
-        Return (P, L_e_mu):
-        """
-        # self._stability_prefactor(alpha=0.5, T=2.0)
+    # def _stability_prefactor(
+    #     self, *, C_stab=None, alpha=None, beta=None, T=None, C_emb=None, C_reg=None
+    # ):
+    #     """
+    #     Return the stability prefactor
+    #     P := (C_stab/alpha) * (1 + sqrt(T)*(1 + sqrt(beta)/sqrt(alpha)))
+    #     and the measure-to-functional constant L_e_mu := C_emb * C_reg * sqrt(T).
+    #     Return (P, L_e_mu):
+    #     """
+    #     # self._stability_prefactor(alpha=0.5, T=2.0)
+
+    #     # Read from arguments or instance attributes
+    #     C_stab = C_stab if C_stab is not None else self.C_stab
+    #     alpha = alpha if alpha is not None else self.alpha
+    #     beta = beta if beta is not None else self.beta
+    #     T = T if T is not None else self.T
+    #     C_emb = C_emb if C_emb is not None else self.C_emb
+    #     C_reg = C_reg if C_reg is not None else self.C_reg
 
-        # Read from arguments or instance attributes
-        C_stab = C_stab if C_stab is not None else self.C_stab
-        alpha = alpha if alpha is not None else self.alpha
-        beta = beta if beta is not None else self.beta
-        T = T if T is not None else self.T
-        C_emb = C_emb if C_emb is not None else self.C_emb
-        C_reg = C_reg if C_reg is not None else self.C_reg
+    #     # Stability factor P
 
-        # Stability factor P
+    #     P = (C_stab / alpha) * (1 + np.sqrt(T) * (1 + np.sqrt(beta) / np.sqrt(alpha)))
+    #     L_e_mu = C_emb * C_reg * np.sqrt(T)
+
+    #     return P, L_e_mu
 
-        P = (C_stab / alpha) * (1 + np.sqrt(T) * (1 + np.sqrt(beta) / np.sqrt(alpha)))
-        L_e_mu = C_emb * C_reg * np.sqrt(T)
-
-        return P, L_e_mu
-
-    # Solution operator Lipschitz factor
-    def lipschitz_factor(self):
-        """
-        Conservative Lipschitz factor for the forward solution operator mapping
-        (k,c,mu) -> u. Returns a single scalar C_Lip such that
-        ||u2 - u1|| <= C_Lip * (||k2-k1|| + ||c2-c1|| + ||mu2-mu1||_* )
-        (we use a conservative form combining the measure and (k,c) pieces).
+    # # Solution operator Lipschitz factor
+    # def lipschitz_factor(self, xbound=None):
+    #     """
+    #     Conservative Lipschitz factor for the forward solution operator mapping
+    #     (k,c,mu) -> u. Returns a single scalar C_Lip such that
+    #     ||u2 - u1|| <= C_Lip * (||k2-k1|| + ||c2-c1|| + ||mu2-mu1||_* )
+    #     (we use a conservative form combining the measure and (k,c) pieces).
+    #     """
+    #     P, L_e_mu = self._stability_prefactor()
+    #     # Conservative combined choice: multiply P by max(1, L_e_mu)
+    #     C_Lip = P * max(1.0, L_e_mu)
+    #     return C_Lip
+    def diff_chain_lipschitz_factor(self, Delta_t=None):
         """
-        P, L_e_mu = self._stability_prefactor()
-        # Conservative combined choice: multiply P by max(1, L_e_mu)
-        C_Lip = P * max(1.0, L_e_mu)
-        return C_Lip
-
-    # Adjoint solution operator Lipschitz factor
-    """
-    def diff_adj_lipschitz_factor(self):
-
-        Compute adjoint-solution Lipschitz factor A for 2D domain.
-
-        For p = 2, A = 1.
-        For p > 2, A = |Omega|^(1/2 - 1/p), where |Omega| is the mesh area.
-
-        p = self.p
-        omega_vol = self.domain_volume
-        return 1.0 if p == 2 else omega_vol ** (0.5 - 1.0 / p)
-    """
-
-    def diff_adj_lipschitz_factor(self):
-        """
-        Always return 1 for p=2
-        """
-        return 1.0
-
-    def diff_bound(self, xbound=None):
-        """
-        Differential bound for the solution operator derivative.
-
-        Returns the tuple: (L_e_mu, L_e_k, L_e_c)
-
-         where
-        L_e_k = ||∇u||_{L^2(Ω_T)}
-        L_e_c = ||∇u||_{L^2(Ω_T)}
-        L_e_mu = C_emb * C_reg * sqrt(T)
+        Compute a Lipschitz factor C_tot^Lip using locally accumulated bounds
+         M_J = 1.
         """
 
-        # If grad_u_norm is None, compute it from u0
-        grad_u_norm = getattr(self, "grad_u_norm", None)
-        if grad_u_norm is None:
-            u = self.u0
-            V = self.V
-            # Compute L2 norm of gradient of u0 over the domain
-            grad_u = ufl.grad(u)
-            grad_u_sq = ufl.inner(grad_u, grad_u)
-            dx = ufl.Measure("dx", domain=self.domain)
-            # Integral of |∇u|^2 over Ω
-            grad_u_norm = fem.assemble_scalar(fem.form(grad_u_sq * dx)) ** 0.5
-            # Optionally store it for later reuse
-            self.grad_u_norm = grad_u_norm
+        if self.override_lipschitz is not None:
+            return self.override_lipschitz
+
+        #  Use existing diff_bound3
+        C_mu_adj, C_k_adj, C_c_adj = self.diff_bound3(Delta_t=Delta_t)
+
+        if Delta_t is None:
+            # fall back to the scheme used in diff_bound3
+            Delta_t = self.dt
+
+        N = int(np.ceil(self.T / Delta_t))
+
+        # Use local per‑step bounds
+        C_adj_step = C_mu_adj / N
+        C_state_step = C_k_adj / N
+        C_c_step = C_c_adj / N
+
+        # Per‑step "Lipschitz"
+        l_e_local = C_adj_step + C_state_step + C_c_step
+
+        # Global Lipschitz constants
+        L_e = l_e_local * N
+        M_e = 1.2 * L_e
+        L_A = 0.5 * C_state_step * N
 
-        return grad_u_norm
+        # PDE/coercivity
+        alpha = max(0.1, self.alpha)  # guard α ≥ 0.1
+        M_J = 1.0
+
+        # C_tot^Lip
+        inv_alpha = 1.0 / alpha
+        inv_alpha_sq = inv_alpha * inv_alpha
+        C_tot = inv_alpha * (L_e + inv_alpha_sq * L_A * M_e) * M_J
+        return C_tot
 
-    def diff_bound_pair(self, xbound=None, Delta_t=None):
-        if xbound is None:
-            return self.diff_bound(xbound=None)
-        else:
-            xb_combined = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
+    def diff_chain_lipschitz_factor_pair(self, Delta_t=None):
+        """
+        Return  (C^Lip_μ/M_J, C^Lip_aux/M_J)
+        """
+
+        if self.override_lipschitz_pair is not None:
+            return self.override_lipschitz_pair
+
+        C_mu_adj, C_k_adj, C_c_adj = self.diff_bound3(Delta_t=Delta_t)
+
+        if Delta_t is None:
+            Delta_t = self.dt
+
+        N = int(np.ceil(self.T / Delta_t))
 
-            k_min = getattr(xb_combined, "k_min", self.alpha)
-            c_Linf = getattr(xb_combined, "c_Linf", np.sqrt(self.beta * k_min))
-            diam = getattr(xb_combined, "diam", self.diam)
-            T = getattr(xb_combined, "T", self.T)
+        # Extract per-component per-step bounds (smaller scaling)
+        C_mu_step = C_mu_adj / N
+        C_aux_step = max(C_k_adj, C_c_adj) / N
+
+        # TINY Lipschitz factors
+        C_Lip_mu = C_mu_step * N
+        C_Lip_aux = C_aux_step * N
 
-            Pe = c_Linf * diam / k_min
-            gamma = c_Linf**2 / k_min
-            # adaptive time step size Delta_t (Delta_t = 0.01 in the test)
-            if Delta_t is None:
-                Delta_t = min(0.1, k_min / c_Linf**2, 0.01 * T)  # CFL-like
+        # Light coercivity scaling
+        alpha = self.alpha
+        inv_alpha = 1.0 / alpha
 
-            # Accumulate N = T/Delta_t local steps
-            N = int(np.ceil(T / Delta_t))
+        C_mu_tot = C_Lip_mu * inv_alpha
+        C_aux_tot = C_Lip_aux * inv_alpha
+
+        return C_mu_tot, C_aux_tot
+
+    def diff_bound3(self, Delta_t=None):
+        xbound = self.xbound
+        xb_combined = xbound  # xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
 
-            # Per-step factors
-            exp_local = np.exp(0.5 * gamma * Delta_t)
-            C_adj_step = np.sqrt(Delta_t * (exp_local + Delta_t / k_min))
-            C_state_step = np.sqrt(Delta_t) * np.sqrt(1 + Pe**2)
+        k_min = getattr(xb_combined, "k_min", self.alpha)
+        c_Linf = getattr(xb_combined, "c_Linf", np.sqrt(self.beta * k_min))
+        diam = getattr(xb_combined, "diam", self.diam)
+        T = getattr(xb_combined, "T", self.T)
+
+        Pe = c_Linf * diam / k_min
+        gamma = c_Linf**2 / k_min
+        # adaptive time step size Delta_t (Delta_t = 0.01 in the test)
+        if Delta_t is None:
+            Delta_t = min(0.1, k_min / c_Linf**2, 0.01 * T)  # CFL-like
+
+        # Accumulate N = T/Delta_t local steps
+        N = int(np.ceil(T / Delta_t))
 
-            # Total accumulated bound (geometric sum <= N * max_step)
-            C_adj_PDE = N * C_adj_step
-            C_state = N * C_state_step
+        # Per-step factors
+        exp_local = np.exp(0.5 * gamma * Delta_t)
+        C_adj_step = np.sqrt(Delta_t * (exp_local + Delta_t / k_min))
+        C_state_step = np.sqrt(Delta_t) * np.sqrt(1 + Pe**2)
 
-            # Use the embedding constant
-            C_mu_adj = self.C_emb * C_adj_PDE  # ∫φ δμ
-            C_k_adj = C_state * C_adj_PDE  # ∫∇u·∇φ δk
-            C_c1_adj = C_state * C_adj_PDE * c_Linf  # ∫∂x u φ δc1
-            C_c2_adj = C_state * C_adj_PDE * c_Linf  # ∫∂y u φ δc2
+        # Total accumulated bound (geometric sum <= N * max_step)
+        C_adj_PDE = N * C_adj_step
+        C_state = N * C_state_step
 
-            return C_mu_adj, C_k_adj, C_c1_adj, C_c2_adj
+        # Use the embedding constant
+        C_mu_adj = self.C_emb * C_adj_PDE  # ∫φ δμ
+        C_k_adj = C_state * C_adj_PDE  # ∫∇u·∇φ δk
+        C_c_adj = C_state * C_adj_PDE * c_Linf  # ∫∂x u φ δc1 and c2
+
+        return C_mu_adj, C_k_adj, C_c_adj
 
     # ||S(x)||_Y→X ≤ C_state ||μ||_M(Ω) [time-independent μ]
-    def codomain_bound(self, xbound=None):
-        if xbound is None:
-            return 1.0
+    def codomain_bound(self):
+        xbound = self.xbound
 
         # Extract parameters
         if hasattr(xbound, "k_min"):
             xb = xbound
+        elif hasattr(xbound, "__len__") and len(xbound) > 0:
+            xb = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
         else:
-            if hasattr(xbound, "__len__") and len(xbound) > 0:
-                xb = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
-            else:
-                xb = xbound
-            k_min = getattr(xb, "k_min", self.alpha)
-            c_Linf = getattr(xb, "c_Linf", np.sqrt(self.beta * k_min))
-            mu_dual = getattr(xb, "mu_dual", 3.0)
-            diam = getattr(xb, "diam", self.diam)
-            T = getattr(xb, "T", self.T)
+            xb = xbound
+
+        k_min = getattr(xb, "k_min", self.alpha)
+        c_Linf = getattr(xb, "c_Linf", np.sqrt(self.beta * k_min))
+        mu_dual = getattr(xb, "mu_dual", 3.0)
+        diam = getattr(xb, "diam", self.diam)
+        T = getattr(xb, "T", self.T)
 
         Pe = c_Linf * diam / k_min
         if not hasattr(self, "Delta_t"):
@@ -687,88 +720,6 @@
 
         return C_state * mu_dual
 
-        # ||S(x1)-S(x2)|| ≤ Cμ||Δμ|| + Ck||Δk|| + Cc1||Δc1|| + Cc2||Δc2||
-
-    def codomain_bound_pair(self, xbound=None):
-        if xbound is None:
-            return self.codomain_bound(xbound=None)
-        else:
-            xb = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
-            C_base = self.codomain_bound(xb)
-            # Component
-            C_mu = C_base
-            C_k = C_base * 2.0
-            C_c1 = C_base * 1.5
-            C_c2 = C_base * 1.5
-
-            return (C_mu, C_k, C_c1, C_c2)
-
-    """
-    def codomain_bound(self, xbound=None):
-        return 1.0
-
-    def codomain_bound_pair(self, xbound=None):
-        # TODO: implement properly
-        if xbound is None:
-            return self.codomain_bound(xbound=None)
-        else:
-            return self.codomain_bound(xbound=xbound[0] + xbound[1])
-    """
-
-    def diff_bound3(self):
-        """
-        Returns full operator differential bounds:
-
-        L_e_mu = C_emb * C_reg * sqrt(T)
-        L_e_k  = ||∇u||_{L^2(Ω)}
-        L_e_c  = ||∇u||_{L^2(Ω)}
-
-        """
-
-        # First get L_e_k = L_e_c
-        grad_u_norm = self.diff_bound()
-
-        # Compute L_e_mu from stability prefactor
-        _P, L_e_mu = self._stability_prefactor()
-
-        L_e_k = grad_u_norm
-        L_e_c = grad_u_norm
-
-        return L_e_mu, L_e_k, L_e_c
-
-    # Solution operator Lipschitz factor separately wrt. μ and (k, c)
-    def lipschitz_factor_pair(self):
-        """
-        Compute the forward solution operator Lipschitz factors
-        separately with respect to the parameters (k, c) and μ.
-
-        Returns:
-        tuple: (L_mu, L_kc)
-            L_mu  - Lipschitz factor w.r.t. μ
-            L_kc  - Lipschitz factor w.r.t. (k, c)
-        """
-        # Get differential bounds for the solution operator
-        L_e_mu, L_e_k, L_e_c = self.diff_bound3()
-
-        # Lipschitz factor w.r.t μ
-        L_mu = L_e_mu
-
-        # Lipschitz factor w.r.t (k, c)
-        # Conservative choice: max of L_e_k and L_e_c
-        L_kc = max(L_e_k, L_e_c)
-
-        return L_mu, L_kc
-
-    # Adjoint solution operator Lipschitz factor separately wrt. μ and (k, c)
-    def diff_adj_lipschitz_factor_pair(self):
-        A = self.diff_adj_lipschitz_factor()
-
-        # No separate analytic dependence, so return same A for both parts
-        A_mu = A
-        A_kc = A
-
-        return A_mu, A_kc
-
 
 def own(u):
     # """
@@ -853,20 +804,22 @@
 
     def apply(self, x):
         """
-        Apply the quadratic regularisation fucntional to `x`.
+        Apply the quadratic regularisation functional to `x`.
         """
         (k, (c1, c2)) = x
         (kb, (c1b, c2b)) = self._get_base()
         (nkb, (nc1b, nc2b)) = self.base_norm_squared
 
-        return self.α * (
-            self._apply1(k, kb, nkb)
-            + self._apply1(c1, c1b, nc1b)
-            + self._apply1(c2, c2b, nc2b)
+        if isinstance(self.α, float):
+            αk, αc = self.α, self.α
+        else:
+            αk, αc = self.α
+
+        return αk * self._apply1(k, kb, nkb) + αc * (
+            self._apply1(c1, c1b, nc1b) + self._apply1(c2, c2b, nc2b)
         )
 
-    def _prox1(self, τ, x, b):
-        γ = self.α * τ
+    def _prox1(self, γ, x, b):
         β = 1.0 / (1.0 + γ)
         if isinstance(x, float):
             return β * (x if b is None else x + b * γ)
@@ -880,7 +833,7 @@
 
     def prox(self, τ, x):
         """
-        Apply the proximal map of the quadratic regularisation fucntional to `x`.
+        Apply the proximal map of the quadratic regularisation functional to `x`.
         WARNING: This function is unsafe. It may modify `x´.
         That is ok for our purposes, as Rust, being a safe language, already needs to pass a
         copied or owned instance to Python.
@@ -888,7 +841,166 @@
         k, (c1, c2) = x
         (kb, (c1b, c2b)) = self._get_base()
 
+        if isinstance(self.α, float):
+            αk, αc = self.α, self.α
+        else:
+            αk, αc = self.α
+
         return (
-            self._prox1(τ, k, kb),
-            (self._prox1(τ, c1, c1b), self._prox1(τ, c2, c2b)),
+            self._prox1(τ * αk, k, kb),
+            (self._prox1(τ * αc, c1, c1b), self._prox1(τ * αc, c2, c2b)),
+        )
+
+
+class LogPowerRegularisation:
+    """
+    Logarithmic barrier + quadratic regularisation functional for the auxiliary variables;
+    $-α\\log(x) + ωx^2$.
+    """
+
+    def _own(self, x):
+        return x if isinstance(x, float) else own(x)
+
+    def __init__(self, α, ω):
+        self.α = α
+        self.ω = ω
+
+    def _apply1(self, x, α, ω):
+        if isinstance(x, float):
+            return -α * np.log(x) + ω * x**2
+        else:
+            # Numpy is Matlab-grade junk. No reduce! Temporary arrays! In 2026! WTF?!?!?!?
+            return (-α * np.log(x) + ω * x**2).sum()
+
+    def apply(self, x):
+        (k, (c1, c2)) = x
+
+        if isinstance(self.α, float):
+            αk, αc = self.α, self.α
+        else:
+            αk, αc = self.α
+        if isinstance(self.ω, float):
+            ωk, ωc = self.ω, self.ω
+        else:
+            ωk, ωc = self.ω
+
+        return (
+            self._apply1(k, αk, ωk)
+            + self._apply1(c1, αc, ωc)
+            + self._apply1(c2, αc, ωc)
+        )
+
+    def _prox1(self, α, ω, x):
+        # OC: 0 = -α/z + 2ωz + z-x ⟺ -α + (2ω+1)z^2 - xz = 0.
+        # ⟺ z = -x ± √(x^2 + 4α(2ω+1))/(2(2ω+1))
+        # Numpy is Matlab-grade junk. No reduce! Temporary arrays! In 2026! WTF?!?!?!?
+        return (x + np.sqrt(x**2 + 4 * α * (2 * ω + 1))) / (2 * (2 * ω + 1))
+
+    def prox(self, τ, x):
+        """
+        Apply the proximal map of the logarithmic regularisation functional to `x`.
+        WARNING: This function is unsafe. It may modify `x´.
+        That is ok for our purposes, as Rust, being a safe language, already needs to pass a
+        copied or owned instance to Python.
+        """
+        k, (c1, c2) = x
+
+        if isinstance(self.α, float):
+            αk, αc = self.α, self.α
+        else:
+            αk, αc = self.α
+        if isinstance(self.ω, float):
+            ωk, ωc = self.ω, self.ω
+        else:
+            ωk, ωc = self.ω
+
+        return (
+            self._prox1(τ * αk, τ * ωk, k),
+            (self._prox1(τ * αc, τ * ωc, c1), self._prox1(τ * αc, τ * ωc, c2)),
         )
+
+
+class BoxConstraints:
+    """
+    Simple box constraints on the  auxiliary variables
+    """
+
+    def _own(self, x):
+        return x if isinstance(x, float) else own(x)
+
+    def __init__(self, α, ω):
+        self.α = α
+        self.ω = ω
+
+    def _apply1(self, x, α, ω):
+        if isinstance(x, float):
+            if α <= x and x <= ω:
+                return 0.0
+            else:
+                return np.inf
+        else:
+            raise Exception("Unimplemented")
+
+    def apply(self, x):
+        (k, (c1, c2)) = x
+
+        if isinstance(self.α, float):
+            αk, αc = self.α, self.α
+        else:
+            αk, αc = self.α
+        if isinstance(self.ω, float):
+            ωk, ωc = self.ω, self.ω
+        else:
+            ωk, ωc = self.ω
+
+        return (
+            self._apply1(k, αk, ωk)
+            + self._apply1(c1, αc, ωc)
+            + self._apply1(c2, αc, ωc)
+        )
+
+    def _prox1(self, α, ω, x):
+        # OC: 0 = -α/z + 2ωz + z-x ⟺ -α + (2ω+1)z^2 - xz = 0.
+        # ⟺ z = -x ± √(x^2 + 4α(2ω+1))/(2(2ω+1))
+        # Numpy is Matlab-grade junk. No reduce! Temporary arrays! In 2026! WTF?!?!?!?
+        return max(min(ω, x), α)
+
+    def prox(self, τ, x):
+        """
+        Apply the proximal map of the box constraints to `x`.
+        WARNING: This function is unsafe. It may modify `x´.
+        That is ok for our purposes, as Rust, being a safe language, already needs to pass a
+        copied or owned instance to Python.
+        """
+        k, (c1, c2) = x
+
+        if isinstance(self.α, float):
+            αk, αc = self.α, self.α
+        else:
+            αk, αc = self.α
+        if isinstance(self.ω, float):
+            ωk, ωc = self.ω, self.ω
+        else:
+            ωk, ωc = self.ω
+
+        return (
+            self._prox1(αk, ωk, k),
+            (self._prox1(αc, ωc, c1), self._prox1(αc, ωc, c2)),
+        )
+
+
+"""
+Quadratic regularisation with box contraints
+"""
+
+
+class BoxedQuadraticRegularisation:
+    def __init__(self, ω0, ω1, α, base=None):
+        self.box = BoxConstraints(ω0, ω1)
+        self.quadratic = QuadraticRegularisation(α, base)
+
+    def apply(self, x):
+        return self.box.apply(x) + self.quadratic.apply(x)
+
+    def prox(self, τ, x):
+        return self.box.prox(1.0, self.quadratic.prox(τ, x))

mercurial