src/convection_diffusion.py

changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/convection_diffusion.py	Thu Feb 26 09:32:12 2026 -0500
@@ -0,0 +1,894 @@
+import os
+from dataclasses import dataclass
+from posix import mkdir
+from termios import B0
+from typing import Optional, Tuple, Union
+
+import dolfinx.geometry as geo
+import numpy as np
+import ufl
+from dolfinx import fem, mesh
+from dolfinx.fem.petsc import (
+    apply_lifting,
+    assemble_matrix,
+    assemble_vector,
+    create_vector,
+    set_bc,
+)
+from mpi4py import MPI
+from petsc4py import PETSc
+
+try:
+    import pointsource_pde.dolfinx_extras as dx
+
+    from dolfinx_access import (
+        cell_FunctionSpace_f64,
+        max_Function_f64_p2,
+        min_Function_f64_p2,
+    )
+except:
+    import src.dolfinx_extras as dx
+
+
+@dataclass
+class XBound:
+    # Parameter bounds for convection-diffusion operator estimates
+    mu_dual: float = 3.0  # ||μ||_M(Ω) bound
+    k_min: float = 0.1  # min diffusion > 0 (coercivity)
+    k_max: float = 2.0  # max diffusion
+    c_Linf: float = 0.5  # max(|c1|,|c2|) L∞ bound
+    diam: float = 1.41  # domain diameter (√2 for unit square)
+    T: float = 1.0  # final time
+
+    def __add__(self, other: "XBound") -> "XBound":
+        # Conservative combination: min(k_min), max(c_Linf) for energy bounds"""
+        other = XBound(
+            **{
+                k: getattr(other, k, getattr(self, k))
+                for k in self.__dataclass_fields__
+            }
+        )
+        return XBound(
+            mu_dual=max(self.mu_dual, other.mu_dual),
+            k_min=min(self.k_min, other.k_min),
+            k_max=max(self.k_max, other.k_max),
+            c_Linf=max(self.c_Linf, other.c_Linf),
+            diam=max(self.diam, other.diam),
+            T=max(self.T, other.T),
+        )
+
+
+class ConvectionDiffusion:
+    def __init__(
+        self,
+        u0,
+        g,
+        # w0,
+        domain_size=[-2, 2, -2, 2],
+        nx=32,
+        ny=32,
+        degree=2,
+        t0=0.0,
+        T=1.0,
+        num_steps=50,
+        p=None,
+        Delta_t=None,
+        C_stab=1.0,
+        alpha=1.0,  # \alpha = k_m/2, k convective coefficient, k>=km
+        beta=1.0,  # \beta = \| c \|_\infty^2 /k_m
+        C_emb=1.0,
+        C_reg=1.0,
+    ):
+        self.p = p  # L^p, p=2
+        self.C_stab = C_stab  # adjoint stability constant
+        self.alpha = alpha  # coercivity lower bound for k
+        self.beta = beta  # theoretical parameter
+        self.C_emb = C_emb  # embedding constant
+        self.C_reg = C_reg
+        if Delta_t is None:
+            self.Delta_t = T / num_steps  # Matches your N=50 intent
+        else:
+            self.Delta_t = Delta_t
+
+        self.domain_size = domain_size
+        self.nx = nx
+        self.ny = ny
+        self.degree = degree
+        self.T = T
+
+        self.num_steps = num_steps
+        self.dt = (T - t0) / num_steps
+        self.t0 = t0
+
+        self.save_for_plot = False
+
+        domain = mesh.create_rectangle(
+            MPI.COMM_SELF,
+            [
+                np.array([domain_size[0], domain_size[2]]),
+                np.array([domain_size[1], domain_size[3]]),
+            ],
+            [nx, ny],
+            mesh.CellType.triangle,
+        )
+        self.domain = domain
+        # Compute the domain volume for the adjoint Lipschitz factor
+        self.domain_size = domain_size
+        # Domain diameter from your domain_size
+        self.diam = np.sqrt(
+            (domain_size[1] - domain_size[0]) ** 2
+            + (domain_size[3] - domain_size[2]) ** 2
+        )
+
+        V = fem.functionspace(domain, ("Lagrange", degree))
+        self.V = V
+
+        self.u0 = fem.Function(V)
+        self.u0.name = "u0"
+        self.u0.interpolate(u0)
+
+        self.g = fem.Function(V)
+        self.g.name = "g"
+        self.g.interpolate(g)
+        self.bcs = []  # Initialize BCS as LIST
+        self.adjoint_bcs = []  # Initialize ADJOINT BCS as LIST (w = 0)
+
+        # Identify boundary facets for Dirichlet BC location (e.g. left/right boundaries at x=domain_size[0] or x=domain_size[1])
+        fdim = domain.topology.dim - 1
+        # Γ₁ facets: TOP/BOTTOM boundaries (y = ymin, ymax) - STATE Dirichlet u = g
+        gamma1_facets = mesh.locate_entities_boundary(
+            domain,
+            fdim,
+            lambda x: (
+                np.isclose(x[1], domain_size[2])  # x[1] = y = ymin
+                | np.isclose(x[1], domain_size[3])
+            ),  # x[1] = y = ymax
+        )
+        # Γ₂ facets: LEFT/RIGHT boundaries (x = xmin, xmax) - ADJOINT Dirichlet w = 0
+        gamma2_facets = mesh.locate_entities_boundary(
+            domain,
+            fdim,
+            lambda x: (
+                np.isclose(x[0], domain_size[0])  # x[0] = x = xmin
+                | np.isclose(x[0], domain_size[1])
+            ),  # x[0] = x = xmax
+        )
+        # STATE BC: u = g on Γ₁ (top/bottom)
+        dofs_gamma1 = fem.locate_dofs_topological(V, fdim, gamma1_facets)
+        # Create BC and ADD TO LIST
+        bc = fem.dirichletbc(self.g, dofs_gamma1)  # V inferred from self.g!
+        self.bcs.append(bc)
+
+        # ADJOINT BC: w = 0 on Γ₁ (same facets, homogenized)
+        zero = fem.Constant(domain, PETSc.ScalarType(0.0))
+        dofs_gamma2 = fem.locate_dofs_topological(V, fdim, gamma2_facets)
+        bc_adjoint = fem.dirichletbc(zero, dofs_gamma2, V)
+        self.adjoint_bcs.append(bc_adjoint)
+
+        self.fdim = fdim
+        self.boundary_facets = np.unique(np.concatenate((gamma1_facets, gamma2_facets)))
+
+        ux = fem.Function(V)
+        wpx = fem.Function(V)
+        self.ux = ux
+        self.wpx = wpx
+        dim = V.mesh.topology.dim
+        self.expr2_form = fem.form(ufl.inner(ufl.grad(ux), ufl.grad(wpx)) * ufl.dx)
+        self.expr3_forms = tuple(
+            fem.form(wpx * ufl.grad(ux)[i] * ufl.dx) for i in range(dim)
+        )
+
+        # Only scalar case
+        self.k = fem.Constant(domain, PETSc.ScalarType(0.0))  # c[0] scalar
+        self.c0 = fem.Constant(domain, PETSc.ScalarType(0.0))  # c[0] scalar
+        self.c1 = fem.Constant(domain, PETSc.ScalarType(0.0))  # c[1] scalar
+        u_n = fem.Function(V)  # New copy!
+        u_n.name = "u_n"
+        self.u_n = u_n
+        u, w, v = ufl.TrialFunction(V), ufl.TrialFunction(V), ufl.TestFunction(V)
+        dt = self.dt
+        # Forward bilinear form (backward Euler)
+        a = (
+            u * v * ufl.dx
+            + dt * ufl.dot(self.k * ufl.grad(u), ufl.grad(v)) * ufl.dx
+            + dt * (self.c0 * u.dx(0) + self.c1 * u.dx(1)) * v * ufl.dx
+        )
+        self.bilinear_form = fem.form(a)
+        L = u_n * v * ufl.dx
+        self.linear_form = fem.form(L)
+
+        # Adjoint bilinear form
+        w_n = fem.Function(V)  # New copy!
+        w_n.name = "w_n"
+        self.w_n = w_n
+        j_u = fem.Function(V)  # New copy!
+        j_u.name = "j_u"
+        self.j_u = j_u
+        a2 = (
+            w * v * ufl.dx
+            + dt
+            * (
+                ufl.dot(self.k * ufl.grad(w), ufl.grad(v))
+                - (self.c0 * w.dx(0) * v + self.c1 * w.dx(1) * v)
+            )
+            * ufl.dx
+        )
+        self.bilinear_form_w = fem.form(a2)
+        # Step n=N-1: L2 uses w_n = w^N=0 → solves w^{N-1}. (∂t)_primal^*
+        # Step n=N-2: L2 uses w_n = w^{N-1} → solves w^{N-2}
+        # Replace your L2 with 2 simple forms (same notation)
+
+        L2 = w_n * v * ufl.dx + dt * j_u * v * ufl.dx  # kown w_n (previous step)
+        self.linear_form_w = fem.form(L2)
+
+    # Solving forward PDE
+    def apply(self, x):
+        μ, (k, c) = x
+
+        domain = self.domain
+        V = self.V
+        dt = self.dt
+        bcs = self.bcs
+
+        # initial condition
+        u_n = self.u_n
+        u_n.x.array[:] = self.u0.x.array
+        u_n.x.scatter_forward()
+
+        # Linear form: only contains u_n * v (Dirichlet on Γ1 is handled via bc)
+
+        # Assemble matrix WITH Dirichlet BCs applied
+        self.c0.value = c[0]
+        self.c1.value = c[1]
+        self.k.value = k
+        linear_form = self.linear_form
+        bilinear_form = self.bilinear_form
+        A = assemble_matrix(bilinear_form, bcs=bcs)
+        A.assemble()
+
+        # Create reusable RHS vector (will be filled each timestep)
+        b = create_vector(linear_form)
+        b_ps = create_vector(linear_form)
+
+        # Prepare solver once
+        solver = PETSc.KSP().create(domain.comm)
+        solver.setOperators(A)
+        solver.setType(PETSc.KSP.Type.PREONLY)
+        solver.getPC().setType(PETSc.PC.Type.LU)
+
+        # Form point source contribution
+        mesh = V.mesh
+        element = V.element.basix_element
+        mesh_nodes = mesh.geometry.x
+        cmap = mesh.geometry.cmap
+
+        b_ps.assemblyBegin()
+
+        with b_ps.localForm() as loc_b:
+            loc_b.set(0)
+
+        for point, alpha in μ.iter_padded():
+            cell_id = cell_FunctionSpace_f64(V._cpp_object, point)
+
+            if cell_id < 0:
+                print(f"Point {point} is outside the mesh.")
+            else:
+                cell_dofs = V.dofmap.cell_dofs(cell_id)
+
+                geom_dofs = mesh.geometry.dofmap[cell_id]
+                ref = cmap.pull_back(
+                    np.array([[point[0], point[1]]]), mesh_nodes[geom_dofs]
+                )
+                phi = element.tabulate(0, ref)
+
+                b_ps.setValuesLocal(
+                    cell_dofs,
+                    alpha * phi,
+                    addv=PETSc.InsertMode.ADD_VALUES,
+                )
+
+        b_ps.assemblyEnd()
+
+        t = self.t0
+        num_steps = self.num_steps
+        bcs = self.bcs
+
+        u_n_list = []
+        for i in range(num_steps):
+            t += dt
+
+            # b.assemblyBegin()
+
+            # zero-out and assemble RHS (M * u_n part)
+            with b.localForm() as loc_b:
+                loc_b.set(0)
+            assemble_vector(b, linear_form)  # b = (u^n, v)
+
+            b.axpy(1.0, b_ps)
+
+            # Apply homogeneous Dirichlet BC correction to RHS and finalize vector
+            apply_lifting(b, [bilinear_form], [bcs])  # [[bc1, bc2, ...]]
+            b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
+            set_bc(b, bcs)  # list [bc1, bc2, ...]
+            # b.assemble()
+            #
+            # b.assemblyEnd()
+
+            # b.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
+
+            # Solve linear problem A * uh = b
+            uh = fem.Function(V)
+            uh.name = "uh"
+            solver.solve(b, uh.x.petsc_vec)
+            uh.x.scatter_forward()
+
+            # Update u_n for next step, store uh
+            u_n.x.array[:] = uh.x.array
+            u_n.x.scatter_forward()
+            u_n_list.append(uh)  # Unique objects
+
+        return u_n_list
+
+    # Solving the adjoint problem
+
+    def solve_adjoint_pde(self, j, x, u_n_list):
+        μ, (k, c) = x
+
+        domain = self.domain
+        V = self.V
+        dt = self.dt
+        adjoint_bcs = self.adjoint_bcs
+
+        num_steps = len(u_n_list)
+
+        assert len(j) == num_steps, f"MISMATCH: j({len(j)}) != u_n_list({num_steps})"
+        # Terminal condition: w(T) = 0
+        w_n = self.w_n
+        w_n.x.array[:] = 0.0
+        w_n.x.scatter_forward()
+
+        self.c0.value = c[0]
+        self.c1.value = c[1]
+        self.k.value = k
+        j_u = self.j_u
+        bilinear_form_w = self.bilinear_form_w
+        linear_form_w = self.linear_form_w
+
+        A2 = assemble_matrix(bilinear_form_w, bcs=adjoint_bcs)  # Fixed bcs
+        A2.assemble()
+
+        # Create vector for RHS to be updated each step
+        b2 = create_vector(linear_form_w)
+
+        solver = PETSc.KSP().create(domain.comm)
+        solver.setOperators(A2)
+        solver.setType(PETSc.KSP.Type.PREONLY)
+        solver.getPC().setType(PETSc.PC.Type.LU)
+
+        t = self.t0 + num_steps * dt  # Start at time T
+
+        # N+1 adjoint values to match N PDE steps
+        w_n_list = [None] * (num_steps + 1)
+        whT = fem.Function(V)
+        whT.x.array[:] = 0.0  # w^N = 0
+        whT.x.scatter_forward()
+        w_n_list[num_steps] = whT  # Store at FINAL index
+
+        #  backward steps**: compute w^{N-1}, ..., w^0
+
+        for i in range(num_steps - 1, -1, -1):
+            t -= dt
+
+            # Source term at current time step: j(t_i)
+
+            j_u.x.array[:] = j[i].x.array
+            j_u.x.scatter_forward()
+
+            # Update RHS
+            with b2.localForm() as loc_b2:
+                loc_b2.set(0)
+            assemble_vector(b2, linear_form_w)
+
+            apply_lifting(b2, [bilinear_form_w], [adjoint_bcs])  #  [forms] → [bcs]
+            b2.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
+            set_bc(b2, adjoint_bcs)
+
+            # set_bc(b2, adjoint_bcs)  # list BC object
+            # b2.assemble()
+
+            wh = fem.Function(V)
+            wh.name = "wh"
+            solver.solve(b2, wh.x.petsc_vec)
+            wh.x.scatter_forward()
+
+            # Update w_n for next iteration and store
+            w_n.x.array[:] = wh.x.array
+            w_n.x.scatter_forward()
+            w_n_list[i] = wh
+
+        return w_n_list
+
+    # Compute the differential operator
+    def diff_adjdir(self, j, x, apply_result=None):
+        """
+        Returns dual space elements: (dual_μ, (dual_k, (dual_c1, dual_c2)))
+        - scalar coeff → float sensitivity
+        - Function coeff → fem.Function sensitivity (Fréchet derivative)
+        """
+        u_n_list = apply_result if apply_result is not None else self.apply(x)
+        w_n_list = self.solve_adjoint_pde(j, x, u_n_list)
+        dt = self.dt
+        V = self.V
+        dim = V.mesh.topology.dim
+
+        # Extract coefficients and check types
+        _mu, (k, c) = x
+        # is_scalar_mu = isinstance(mu, (int, float, np.number))
+        is_scalar_k = isinstance(k, (int, float, np.number))
+        is_scalar_c = all(isinstance(ci, (int, float, np.number)) for ci in c)
+
+        # 1. expr1: dual sensitivity w.r.t. μ = ∫₀^T w_p dt (always Function)
+        wp_bar = fem.Function(V)
+        wp_bar.x.array[:] = (
+            0.5
+            * dt
+            * (
+                w_n_list[0].x.array
+                + w_n_list[-1].x.array
+                + 2 * np.sum([w.x.array for w in w_n_list[1:-1]], axis=0)
+            )
+        )
+
+        ux = self.ux
+        wpx = self.wpx
+
+        # 2. expr2: dual sensitivity w.r.t. k
+        if is_scalar_k:
+            # k scalar → ⟨expr2, δk⟩ = δk ∫ ∇u·∇w_p → scalar
+            expr2 = 0.0
+
+            form = self.expr2_form
+
+            for n, (u, wp) in enumerate(zip(u_n_list, w_n_list)):
+                ux.x.array[:] = u.x.array[:]
+                ux.x.scatter_forward()
+                wpx.x.array[:] = wp.x.array[:]
+                wpx.x.scatter_forward()
+                v2 = fem.assemble_scalar(form)
+                weight = 0.5 * dt if (n == 0 or n == len(u_n_list) - 1) else dt
+                expr2 -= weight * v2
+
+        else:
+            raise Exception("Unimplemented - out of date")
+
+            # k Function → ⟨expr2, δk⟩ = ∫ expr2 δk → expr2 = ∇u·∇w_p (pointwise)
+            expr2 = fem.Function(V)
+            expr2a = expr2.x.array
+
+            for n, (u, wp) in enumerate(zip(u_n_list, w_n_list)):
+                form = ufl.inner(ufl.grad(u), ufl.grad(wp)) * ufl.dx
+                v2 = fem.assemble_scalar(fem.form(form))
+                weight = 0.5 * dt if (n == 0 or n == len(u_n_list) - 1) else dt
+                expr2a -= weight * v2
+
+        # 3. expr3: dual sensitivity w.r.t. c = (c1, c2)
+        if is_scalar_c:
+            # Scalar c → return tuple of 2 floats
+
+            forms = self.expr3_forms
+
+            def val(i):
+                expr3i = 0.0
+                for n, (u, wp) in enumerate(zip(u_n_list, w_n_list)):
+                    ux.x.array[:] = u.x.array[:]
+                    ux.x.scatter_forward()
+                    wpx.x.array[:] = wp.x.array[:]
+                    wpx.x.scatter_forward()
+                    v3 = fem.assemble_scalar(forms[i])
+                    weight = 0.5 * dt if (n == 0 or n == len(u_n_list) - 1) else dt
+                    expr3i -= weight * v3
+                return expr3i
+
+            expr3 = tuple(val(i) for i in range(dim))
+
+        else:
+            raise Exception("Unimplemented - out of date")
+
+            # Function c → return tuple of 2 Functions (pointwise derivatives)
+            def val(i):
+                # expr3[i]. x.array[:] = 0
+                expr3i = fem.Function(V)
+                expr3ia = expr3i.x.array[:]
+                for n, (u, wp) in enumerate(zip(u_n_list, w_n_list)):
+                    form = wp * ufl.grad(u)[i] * ufl.dx
+                    v3 = fem.assemble_scalar(fem.form(form))
+                    weight = 0.5 * dt if (n == 0 or n == len(u_n_list) - 1) else dt
+                    expr3ia -= weight * v3
+                return expr3i
+
+            expr3 = tuple(val(i) for i in range(dim))
+
+        if self.save_for_plot:
+            self.plot_wp_bar = wp_bar
+            self.plot_u_n_list = u_n_list
+            self.plot_w_n_list = w_n_list
+            self.plot_j_n_list = j
+
+        return wp_bar, (expr2, expr3)
+
+    def do_plot(self, iter, μ, frames, prefix):
+        n = self.num_steps
+        np.savez_compressed(
+            "%s/res_%d.npz" % (prefix, iter),
+            wp_bar_array=self.plot_wp_bar.x.array,
+            u_n_list_array=[self.plot_u_n_list[i].x.array for i in frames],
+            w_n_list_array=[self.plot_w_n_list[i].x.array for i in frames],
+            j_n_list_array=[self.plot_j_n_list[i].x.array for i in frames],
+            frames=frames,
+            times=list(map(lambda i: self.T * i / (n - 1), frames)),
+            mu=[[point[0], point[1], alpha] for point, alpha in μ.iter_padded()],
+        )
+
+    def _stability_prefactor(
+        self, *, C_stab=None, alpha=None, beta=None, T=None, C_emb=None, C_reg=None
+    ):
+        """
+        Return the stability prefactor
+        P := (C_stab/alpha) * (1 + sqrt(T)*(1 + sqrt(beta)/sqrt(alpha)))
+        and the measure-to-functional constant L_e_mu := C_emb * C_reg * sqrt(T).
+        Return (P, L_e_mu):
+        """
+        # self._stability_prefactor(alpha=0.5, T=2.0)
+
+        # Read from arguments or instance attributes
+        C_stab = C_stab if C_stab is not None else self.C_stab
+        alpha = alpha if alpha is not None else self.alpha
+        beta = beta if beta is not None else self.beta
+        T = T if T is not None else self.T
+        C_emb = C_emb if C_emb is not None else self.C_emb
+        C_reg = C_reg if C_reg is not None else self.C_reg
+
+        # Stability factor P
+
+        P = (C_stab / alpha) * (1 + np.sqrt(T) * (1 + np.sqrt(beta) / np.sqrt(alpha)))
+        L_e_mu = C_emb * C_reg * np.sqrt(T)
+
+        return P, L_e_mu
+
+    # Solution operator Lipschitz factor
+    def lipschitz_factor(self):
+        """
+        Conservative Lipschitz factor for the forward solution operator mapping
+        (k,c,mu) -> u. Returns a single scalar C_Lip such that
+        ||u2 - u1|| <= C_Lip * (||k2-k1|| + ||c2-c1|| + ||mu2-mu1||_* )
+        (we use a conservative form combining the measure and (k,c) pieces).
+        """
+        P, L_e_mu = self._stability_prefactor()
+        # Conservative combined choice: multiply P by max(1, L_e_mu)
+        C_Lip = P * max(1.0, L_e_mu)
+        return C_Lip
+
+    # Adjoint solution operator Lipschitz factor
+    """
+    def diff_adj_lipschitz_factor(self):
+
+        Compute adjoint-solution Lipschitz factor A for 2D domain.
+
+        For p = 2, A = 1.
+        For p > 2, A = |Omega|^(1/2 - 1/p), where |Omega| is the mesh area.
+
+        p = self.p
+        omega_vol = self.domain_volume
+        return 1.0 if p == 2 else omega_vol ** (0.5 - 1.0 / p)
+    """
+
+    def diff_adj_lipschitz_factor(self):
+        """
+        Always return 1 for p=2
+        """
+        return 1.0
+
+    def diff_bound(self, xbound=None):
+        """
+        Differential bound for the solution operator derivative.
+
+        Returns the tuple: (L_e_mu, L_e_k, L_e_c)
+
+         where
+        L_e_k = ||∇u||_{L^2(Ω_T)}
+        L_e_c = ||∇u||_{L^2(Ω_T)}
+        L_e_mu = C_emb * C_reg * sqrt(T)
+        """
+
+        # If grad_u_norm is None, compute it from u0
+        grad_u_norm = getattr(self, "grad_u_norm", None)
+        if grad_u_norm is None:
+            u = self.u0
+            V = self.V
+            # Compute L2 norm of gradient of u0 over the domain
+            grad_u = ufl.grad(u)
+            grad_u_sq = ufl.inner(grad_u, grad_u)
+            dx = ufl.Measure("dx", domain=self.domain)
+            # Integral of |∇u|^2 over Ω
+            grad_u_norm = fem.assemble_scalar(fem.form(grad_u_sq * dx)) ** 0.5
+            # Optionally store it for later reuse
+            self.grad_u_norm = grad_u_norm
+
+        return grad_u_norm
+
+    def diff_bound_pair(self, xbound=None, Delta_t=None):
+        if xbound is None:
+            return self.diff_bound(xbound=None)
+        else:
+            xb_combined = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
+
+            k_min = getattr(xb_combined, "k_min", self.alpha)
+            c_Linf = getattr(xb_combined, "c_Linf", np.sqrt(self.beta * k_min))
+            diam = getattr(xb_combined, "diam", self.diam)
+            T = getattr(xb_combined, "T", self.T)
+
+            Pe = c_Linf * diam / k_min
+            gamma = c_Linf**2 / k_min
+            # adaptive time step size Delta_t (Delta_t = 0.01 in the test)
+            if Delta_t is None:
+                Delta_t = min(0.1, k_min / c_Linf**2, 0.01 * T)  # CFL-like
+
+            # Accumulate N = T/Delta_t local steps
+            N = int(np.ceil(T / Delta_t))
+
+            # Per-step factors
+            exp_local = np.exp(0.5 * gamma * Delta_t)
+            C_adj_step = np.sqrt(Delta_t * (exp_local + Delta_t / k_min))
+            C_state_step = np.sqrt(Delta_t) * np.sqrt(1 + Pe**2)
+
+            # Total accumulated bound (geometric sum <= N * max_step)
+            C_adj_PDE = N * C_adj_step
+            C_state = N * C_state_step
+
+            # Use the embedding constant
+            C_mu_adj = self.C_emb * C_adj_PDE  # ∫φ δμ
+            C_k_adj = C_state * C_adj_PDE  # ∫∇u·∇φ δk
+            C_c1_adj = C_state * C_adj_PDE * c_Linf  # ∫∂x u φ δc1
+            C_c2_adj = C_state * C_adj_PDE * c_Linf  # ∫∂y u φ δc2
+
+            return C_mu_adj, C_k_adj, C_c1_adj, C_c2_adj
+
+    # ||S(x)||_Y→X ≤ C_state ||μ||_M(Ω) [time-independent μ]
+    def codomain_bound(self, xbound=None):
+        if xbound is None:
+            return 1.0
+
+        # Extract parameters
+        if hasattr(xbound, "k_min"):
+            xb = xbound
+        else:
+            if hasattr(xbound, "__len__") and len(xbound) > 0:
+                xb = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
+            else:
+                xb = xbound
+            k_min = getattr(xb, "k_min", self.alpha)
+            c_Linf = getattr(xb, "c_Linf", np.sqrt(self.beta * k_min))
+            mu_dual = getattr(xb, "mu_dual", 3.0)
+            diam = getattr(xb, "diam", self.diam)
+            T = getattr(xb, "T", self.T)
+
+        Pe = c_Linf * diam / k_min
+        if not hasattr(self, "Delta_t"):
+            Delta_t = min(0.1, k_min / c_Linf**2, 0.01 * T)
+        else:
+            Delta_t = self.Delta_t
+
+        N = int(np.ceil(T / Delta_t))
+        gamma = c_Linf**2 / k_min
+        exp_local = np.exp(0.25 * gamma * Delta_t)
+
+        C_state_step = np.sqrt(Delta_t) * np.sqrt(1 + 0.5 * Pe**2) * exp_local
+        C_state = N * C_state_step
+
+        return C_state * mu_dual
+
+        # ||S(x1)-S(x2)|| ≤ Cμ||Δμ|| + Ck||Δk|| + Cc1||Δc1|| + Cc2||Δc2||
+
+    def codomain_bound_pair(self, xbound=None):
+        if xbound is None:
+            return self.codomain_bound(xbound=None)
+        else:
+            xb = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
+            C_base = self.codomain_bound(xb)
+            # Component
+            C_mu = C_base
+            C_k = C_base * 2.0
+            C_c1 = C_base * 1.5
+            C_c2 = C_base * 1.5
+
+            return (C_mu, C_k, C_c1, C_c2)
+
+    """
+    def codomain_bound(self, xbound=None):
+        return 1.0
+
+    def codomain_bound_pair(self, xbound=None):
+        # TODO: implement properly
+        if xbound is None:
+            return self.codomain_bound(xbound=None)
+        else:
+            return self.codomain_bound(xbound=xbound[0] + xbound[1])
+    """
+
+    def diff_bound3(self):
+        """
+        Returns full operator differential bounds:
+
+        L_e_mu = C_emb * C_reg * sqrt(T)
+        L_e_k  = ||∇u||_{L^2(Ω)}
+        L_e_c  = ||∇u||_{L^2(Ω)}
+
+        """
+
+        # First get L_e_k = L_e_c
+        grad_u_norm = self.diff_bound()
+
+        # Compute L_e_mu from stability prefactor
+        _P, L_e_mu = self._stability_prefactor()
+
+        L_e_k = grad_u_norm
+        L_e_c = grad_u_norm
+
+        return L_e_mu, L_e_k, L_e_c
+
+    # Solution operator Lipschitz factor separately wrt. μ and (k, c)
+    def lipschitz_factor_pair(self):
+        """
+        Compute the forward solution operator Lipschitz factors
+        separately with respect to the parameters (k, c) and μ.
+
+        Returns:
+        tuple: (L_mu, L_kc)
+            L_mu  - Lipschitz factor w.r.t. μ
+            L_kc  - Lipschitz factor w.r.t. (k, c)
+        """
+        # Get differential bounds for the solution operator
+        L_e_mu, L_e_k, L_e_c = self.diff_bound3()
+
+        # Lipschitz factor w.r.t μ
+        L_mu = L_e_mu
+
+        # Lipschitz factor w.r.t (k, c)
+        # Conservative choice: max of L_e_k and L_e_c
+        L_kc = max(L_e_k, L_e_c)
+
+        return L_mu, L_kc
+
+    # Adjoint solution operator Lipschitz factor separately wrt. μ and (k, c)
+    def diff_adj_lipschitz_factor_pair(self):
+        A = self.diff_adj_lipschitz_factor()
+
+        # No separate analytic dependence, so return same A for both parts
+        A_mu = A
+        A_kc = A
+
+        return A_mu, A_kc
+
+
+def own(u):
+    # """
+    # Return a version of u that is guaranteed to be uniquely owned.
+    # If u has other Python references, return a deep copy.
+    # """
+    # # sys.getrefcount(u) includes the temporary reference inside getrefcount,
+    # # so '2' means exactly one external reference.
+    # if sys.getrefcount(u) <= 2:
+    #     return u  # safe: no other references
+
+    # deep copy into a new Function
+    u_new = fem.Function(u.function_space)
+    # Does not work
+    # u_new.x.petsc_vec.copy(u.x.petsc_vec)
+    # u_new.x.scatter_forward()
+    np.copyto(u_new.x.array, u.x.array)
+    return u_new
+
+
+def nn2(x):
+    return None if isinstance(x, float) else dx.norm2_squared(x)
+
+
+class PlotFactory:
+    def __init__(self, pde, n_plots=5):
+        self.pde = pde
+        self.n_plots = n_plots
+
+        m = n_plots
+        n = pde.num_steps
+        pde.save_for_plot = True
+        self.plot_frames = list(map(lambda i: i * (n - 1) // (m - 1), range(0, m)))
+        self.pde = pde
+
+    def produce(self, prefix):
+        return Plotter(self.pde, self.n_plots, prefix)
+
+
+class Plotter:
+    def __init__(self, pde, n_plots, prefix):
+        m = n_plots
+        n = pde.num_steps
+        pde.save_for_plot = True
+        self.plot_frames = list(map(lambda i: i * (n - 1) // (m - 1), range(0, m)))
+        self.pde = pde
+        os.makedirs(prefix, exist_ok=True)
+        self.prefix = prefix
+
+    def plot(self, iter, μ):
+        self.pde.do_plot(iter, μ, self.plot_frames, self.prefix)
+
+
+class QuadraticRegularisation:
+    """
+    Quadratic regularisation functional for the auxiliary variables
+    """
+
+    def _own(self, x):
+        return x if isinstance(x, float) else own(x)
+
+    def __init__(self, α, base=None):
+        self.α = α
+        if base is not None:
+            (k, (c1, c2)) = base
+            self.base = (self._own(k), (self._own(c1), self._own(c2)))
+            self.base_norm_squared = (nn2(k), (nn2(c1), nn2(c2)))
+        else:
+            self.base = None
+
+    def _apply1(self, x, b, n):
+        if isinstance(x, float):
+            d = x if b is None else x - b
+            return d * d / 2.0
+        else:
+            if b is not None:
+                return (dx.norm2_squared(x) + n) / 2.0 + dx.dot(x, b)
+            return dx.norm2_squared(x) / 2.0
+
+    def _get_base(self):
+        return (None, (None, None)) if self.base is None else self.base
+
+    def apply(self, x):
+        """
+        Apply the quadratic regularisation fucntional to `x`.
+        """
+        (k, (c1, c2)) = x
+        (kb, (c1b, c2b)) = self._get_base()
+        (nkb, (nc1b, nc2b)) = self.base_norm_squared
+
+        return self.α * (
+            self._apply1(k, kb, nkb)
+            + self._apply1(c1, c1b, nc1b)
+            + self._apply1(c2, c2b, nc2b)
+        )
+
+    def _prox1(self, τ, x, b):
+        γ = self.α * τ
+        β = 1.0 / (1.0 + γ)
+        if isinstance(x, float):
+            return β * (x if b is None else x + b * γ)
+        else:
+            # x = own(x)
+            vx = x.x.array
+            if b is not None:
+                vx += b.x.array * γ
+            vx *= β
+            return x
+
+    def prox(self, τ, x):
+        """
+        Apply the proximal map of the quadratic regularisation fucntional to `x`.
+        WARNING: This function is unsafe. It may modify `x´.
+        That is ok for our purposes, as Rust, being a safe language, already needs to pass a
+        copied or owned instance to Python.
+        """
+        k, (c1, c2) = x
+        (kb, (c1b, c2b)) = self._get_base()
+
+        return (
+            self._prox1(τ, k, kb),
+            (self._prox1(τ, c1, c1b), self._prox1(τ, c2, c2b)),
+        )

mercurial