src/convection_diffusion.py

Fri, 08 May 2026 17:28:21 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 08 May 2026 17:28:21 -0500
changeset 5
3e164c024a01
parent 3
c3a4f4bb87f7
permissions
-rw-r--r--

Change README title

import os
from dataclasses import dataclass
from posix import mkdir
from termios import B0
from typing import Optional, Tuple, Union

import dolfinx.geometry as geo
import numpy as np
import ufl
from dolfinx import fem, mesh
from dolfinx.fem.petsc import (
    apply_lifting,
    assemble_matrix,
    assemble_vector,
    create_vector,
    set_bc,
)
from mpi4py import MPI
from petsc4py import PETSc

try:
    import pointsource_pde.dolfinx_extras as dx

    from dolfinx_access import (
        cell_FunctionSpace_f64,
        max_Function_f64_p2,
        min_Function_f64_p2,
    )
except:
    import src.dolfinx_extras as dx


@dataclass
class XBound:
    """Parameter bounds used in the convection-diffusion operator estimates."""

    mu_dual: float = 3.0  # bound on ||μ||_M(Ω)
    k_min: float = 0.1  # lower diffusion bound, > 0 (coercivity)
    k_max: float = 2.0  # upper diffusion bound
    c_Linf: float = 0.5  # L∞ bound max(|c1|, |c2|)
    diam: float = 1.41  # domain diameter (√2 for the unit square)
    T: float = 1.0  # final time

    def __add__(self, other: "XBound") -> "XBound":
        """
        Combine two bounds conservatively.

        `k_min` takes the smaller of the two values; every other field takes
        the larger one. A field that `other` does not provide falls back to
        the value stored on `self`.
        """
        merged = {}
        for name in self.__dataclass_fields__:
            mine = getattr(self, name)
            theirs = getattr(other, name, mine)
            # Only the coercivity bound is combined downwards.
            pick = min if name == "k_min" else max
            merged[name] = pick(mine, theirs)
        return XBound(**merged)


class ConvectionDiffusion:
    def __init__(
        self,
        u0,
        g,
        # w0,
        domain_size=[-2, 2, -2, 2],
        nx=32,
        ny=32,
        degree=2,
        t0=0.0,
        T=1.0,
        num_steps=50,
        p=None,
        Delta_t=None,
        C_stab=1.0,
        alpha=1.0,  # coercivity constant; used as fallback for k_min in the bounds
        beta=1.0,  # used in the bounds as fallback: c_Linf = sqrt(beta * k_min)
        C_emb=1.0,
        C_reg=1.0,
        xbound=XBound,
        override_lipschitz=None,
        override_lipschitz_pair=None,
    ):
        """
        Set up the mesh, function space, boundary conditions and the
        precompiled forms for the forward (backward Euler) and adjoint problems.

        Parameters:
        - u0, g: objects passed to `fem.Function.interpolate` to obtain the
          initial condition and the Dirichlet boundary data.
        - domain_size: `[xmin, xmax, ymin, ymax]` of the rectangular domain.
        - nx, ny: subdivision counts passed to `mesh.create_rectangle`.
        - degree: degree of the Lagrange function space.
        - t0, T, num_steps: time interval and number of steps; the step used by
          the solvers is `dt = (T - t0) / num_steps`.
        - Delta_t: step used by `codomain_bound`; defaults to `T / num_steps`.
        - C_stab, alpha, beta, C_emb, C_reg, p: constants stored for the
          bound/Lipschitz estimates.
        - xbound: parameter bounds; note that the default is the `XBound`
          class itself (its class-level defaults are read via `getattr`).
        - override_lipschitz, override_lipschitz_pair: if not None, returned
          verbatim by the corresponding Lipschitz factor methods.
        """
        self.p = p  # stored only; not read elsewhere in this class
        self.C_stab = C_stab  # stored only; not read by the active methods
        self.alpha = alpha  # coercivity constant used by the Lipschitz factors
        self.beta = beta  # see the fallback for c_Linf in the bound methods
        self.C_emb = C_emb  # embedding constant (scales the μ bound in diff_bound3)
        self.C_reg = C_reg
        if Delta_t is None:
            # NOTE: unlike self.dt below, this does not subtract t0.
            self.Delta_t = T / num_steps
        else:
            self.Delta_t = Delta_t

        self.domain_size = domain_size
        self.nx = nx
        self.ny = ny
        self.degree = degree
        self.T = T
        self.xbound = xbound

        self.num_steps = num_steps
        self.dt = (T - t0) / num_steps  # time step used by both solvers
        self.t0 = t0

        # When True, diff_adjdir keeps its intermediate results for do_plot.
        self.save_for_plot = False

        self.override_lipschitz = override_lipschitz
        self.override_lipschitz_pair = override_lipschitz_pair

        # Triangular mesh of the rectangle, created on MPI.COMM_SELF.
        domain = mesh.create_rectangle(
            MPI.COMM_SELF,
            [
                np.array([domain_size[0], domain_size[2]]),
                np.array([domain_size[1], domain_size[3]]),
            ],
            [nx, ny],
            mesh.CellType.triangle,
        )
        self.domain = domain
        self.domain_size = domain_size  # (already assigned above; kept as is)
        # Domain diameter: length of the rectangle's diagonal
        self.diam = np.sqrt(
            (domain_size[1] - domain_size[0]) ** 2
            + (domain_size[3] - domain_size[2]) ** 2
        )

        V = fem.functionspace(domain, ("Lagrange", degree))
        self.V = V

        # Initial condition interpolated into V
        self.u0 = fem.Function(V)
        self.u0.name = "u0"
        self.u0.interpolate(u0)

        # Dirichlet boundary data interpolated into V
        self.g = fem.Function(V)
        self.g.name = "g"
        self.g.interpolate(g)
        self.bcs = []  # Dirichlet BCs of the state equation
        self.adjoint_bcs = []  # Dirichlet BCs of the adjoint equation (w = 0)

        # Boundary facets have dimension (topological dimension - 1)
        fdim = domain.topology.dim - 1
        # Γ₁ facets: TOP/BOTTOM boundaries (y = ymin, ymax) - STATE Dirichlet u = g
        gamma1_facets = mesh.locate_entities_boundary(
            domain,
            fdim,
            lambda x: (
                np.isclose(x[1], domain_size[2])  # x[1] = y = ymin
                | np.isclose(x[1], domain_size[3])
            ),  # x[1] = y = ymax
        )
        # Γ₂ facets: LEFT/RIGHT boundaries (x = xmin, xmax) - ADJOINT Dirichlet w = 0
        gamma2_facets = mesh.locate_entities_boundary(
            domain,
            fdim,
            lambda x: (
                np.isclose(x[0], domain_size[0])  # x[0] = x = xmin
                | np.isclose(x[0], domain_size[1])
            ),  # x[0] = x = xmax
        )
        # STATE BC: u = g on Γ₁ (top/bottom)
        dofs_gamma1 = fem.locate_dofs_topological(V, fdim, gamma1_facets)
        bc = fem.dirichletbc(self.g, dofs_gamma1)  # function space taken from self.g
        self.bcs.append(bc)

        # ADJOINT BC: w = 0 on Γ₂ (left/right); note these are NOT the state BC facets
        zero = fem.Constant(domain, PETSc.ScalarType(0.0))
        dofs_gamma2 = fem.locate_dofs_topological(V, fdim, gamma2_facets)
        bc_adjoint = fem.dirichletbc(zero, dofs_gamma2, V)
        self.adjoint_bcs.append(bc_adjoint)

        self.fdim = fdim
        self.boundary_facets = np.unique(np.concatenate((gamma1_facets, gamma2_facets)))

        # Scratch functions and precompiled forms for the sensitivities in
        # diff_adjdir: ∫ ∇u·∇w dx and, per coordinate i, ∫ w ∂_i u dx.
        ux = fem.Function(V)
        wpx = fem.Function(V)
        self.ux = ux
        self.wpx = wpx
        dim = V.mesh.topology.dim
        self.expr2_form = fem.form(ufl.inner(ufl.grad(ux), ufl.grad(wpx)) * ufl.dx)
        self.expr3_forms = tuple(
            fem.form(wpx * ufl.grad(ux)[i] * ufl.dx) for i in range(dim)
        )

        # Scalar coefficients only; the values are set at the start of each solve.
        self.k = fem.Constant(domain, PETSc.ScalarType(0.0))  # diffusion coefficient k
        self.c0 = fem.Constant(domain, PETSc.ScalarType(0.0))  # convection component c[0]
        self.c1 = fem.Constant(domain, PETSc.ScalarType(0.0))  # convection component c[1]
        u_n = fem.Function(V)  # state at the previous time step
        u_n.name = "u_n"
        self.u_n = u_n
        u, w, v = ufl.TrialFunction(V), ufl.TrialFunction(V), ufl.TestFunction(V)
        dt = self.dt
        # Forward bilinear form (backward Euler): mass + dt * (diffusion + convection)
        a = (
            u * v * ufl.dx
            + dt * ufl.dot(self.k * ufl.grad(u), ufl.grad(v)) * ufl.dx
            + dt * (self.c0 * u.dx(0) + self.c1 * u.dx(1)) * v * ufl.dx
        )
        self.bilinear_form = fem.form(a)
        L = u_n * v * ufl.dx
        self.linear_form = fem.form(L)

        # Adjoint problem: previous adjoint iterate and the source term j
        w_n = fem.Function(V)
        w_n.name = "w_n"
        self.w_n = w_n
        j_u = fem.Function(V)
        j_u.name = "j_u"
        self.j_u = j_u
        # Adjoint bilinear form: same mass and diffusion terms, convection
        # term with the opposite sign.
        a2 = (
            w * v * ufl.dx
            + dt
            * (
                ufl.dot(self.k * ufl.grad(w), ufl.grad(v))
                - (self.c0 * w.dx(0) * v + self.c1 * w.dx(1) * v)
            )
            * ufl.dx
        )
        self.bilinear_form_w = fem.form(a2)
        # Step n=N-1: L2 uses w_n = w^N = 0 and yields w^{N-1};
        # step n=N-2: L2 uses w_n = w^{N-1} and yields w^{N-2}; and so on.

        L2 = w_n * v * ufl.dx + dt * j_u * v * ufl.dx  # known w_n (previous step)
        self.linear_form_w = fem.form(L2)

    # Solving forward PDE
    def apply(self, x):
        """
        Solve the forward problem with backward Euler time stepping.

        `x` is `(μ, (k, c))`: `μ` provides `iter_padded()` yielding
        `(point, alpha)` pairs (point source locations and weights), `k` is
        the scalar diffusion coefficient and `c` the two convection components.

        Returns a list with one new `fem.Function` per time step (the initial
        condition itself is not included).
        """
        μ, (k, c) = x

        domain = self.domain
        V = self.V
        dt = self.dt
        bcs = self.bcs

        # initial condition
        u_n = self.u_n
        u_n.x.array[:] = self.u0.x.array
        u_n.x.scatter_forward()

        # Linear form: only contains u_n * v (Dirichlet on Γ1 is handled via bc)

        # Update the coefficient constants, then assemble the system matrix
        # with the Dirichlet BCs applied.
        self.c0.value = c[0]
        self.c1.value = c[1]
        self.k.value = k
        linear_form = self.linear_form
        bilinear_form = self.bilinear_form
        A = assemble_matrix(bilinear_form, bcs=bcs)
        A.assemble()

        # Reusable RHS vector (refilled each time step) and a second vector
        # holding the time-independent point source contribution.
        b = create_vector(linear_form.function_spaces)
        b_ps = create_vector(linear_form.function_spaces)

        # Direct solver (LU, no Krylov iterations), set up once for all steps
        solver = PETSc.KSP().create(domain.comm)
        solver.setOperators(A)
        solver.setType(PETSc.KSP.Type.PREONLY)
        solver.getPC().setType(PETSc.PC.Type.LU)

        # Form point source contribution
        # NOTE: the local name `mesh` shadows the imported `dolfinx.mesh`
        # module for the rest of this method.
        mesh = V.mesh
        element = V.element.basix_element
        mesh_nodes = mesh.geometry.x
        cmap = mesh.geometry.cmap

        b_ps.assemblyBegin()

        with b_ps.localForm() as loc_b:
            loc_b.set(0)

        for point, alpha in μ.iter_padded():
            # Look up the cell containing the point; a negative id means "not found".
            cell_id = cell_FunctionSpace_f64(V._cpp_object, point)

            if cell_id < 0:
                print(f"Point {point} is outside the mesh.")
            else:
                cell_dofs = V.dofmap.cell_dofs(cell_id)

                # Pull the point back to the reference cell and evaluate the
                # basis functions there.
                geom_dofs = mesh.geometry.dofmap[cell_id]
                ref = cmap.pull_back(
                    np.array([[point[0], point[1]]]), mesh_nodes[geom_dofs]
                )
                phi = element.tabulate(0, ref)

                # Add alpha * (basis function values) to the dofs of that cell.
                b_ps.setValuesLocal(
                    cell_dofs,
                    alpha * phi,
                    addv=PETSc.InsertMode.ADD_VALUES,
                )

        b_ps.assemblyEnd()

        t = self.t0  # time is tracked but not otherwise used below
        num_steps = self.num_steps
        bcs = self.bcs

        u_n_list = []
        for i in range(num_steps):
            t += dt

            # zero-out and assemble RHS (M * u_n part)
            with b.localForm() as loc_b:
                loc_b.set(0)
            assemble_vector(b, linear_form)  # b = (u^n, v)

            # NOTE(review): the point source is added with weight 1.0 while the
            # adjoint source term is scaled by dt — confirm whether a factor
            # dt is intended here.
            b.axpy(1.0, b_ps)

            # Apply the Dirichlet lifting to the RHS, then set the BC values
            apply_lifting(b, [bilinear_form], [bcs])  # one list of BCs per form
            b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
            set_bc(b, bcs)

            # Solve linear problem A * uh = b into a fresh Function
            uh = fem.Function(V)
            uh.name = "uh"
            solver.solve(b, uh.x.petsc_vec)
            uh.x.scatter_forward()

            # Update u_n for next step, store uh
            u_n.x.array[:] = uh.x.array
            u_n.x.scatter_forward()
            u_n_list.append(uh)  # each entry is a distinct Function object

        return u_n_list

    # Solving the adjoint problem

    def solve_adjoint_pde(self, j, x, u_n_list):
        """
        Solve the adjoint problem backwards in time from the terminal
        condition w = 0.

        `j` holds one source term per time step (each with `.x.array`), `x` is
        `(μ, (k, c))` (only `k` and `c` are used here), and `u_n_list` is the
        forward solution, of which only the length is used.

        Returns a list of `len(u_n_list) + 1` Functions; the last entry is the
        zero terminal value and entry `i` is the solution of backward step `i`.
        """
        μ, (k, c) = x

        domain = self.domain
        V = self.V
        dt = self.dt
        adjoint_bcs = self.adjoint_bcs

        num_steps = len(u_n_list)

        # One source term is required per forward step.
        assert len(j) == num_steps, f"MISMATCH: j({len(j)}) != u_n_list({num_steps})"
        # Terminal condition: w(T) = 0
        w_n = self.w_n
        w_n.x.array[:] = 0.0
        w_n.x.scatter_forward()

        # Update the coefficient constants shared with the forward forms
        self.c0.value = c[0]
        self.c1.value = c[1]
        self.k.value = k
        j_u = self.j_u
        bilinear_form_w = self.bilinear_form_w
        linear_form_w = self.linear_form_w

        # System matrix with the adjoint Dirichlet BCs applied
        A2 = assemble_matrix(bilinear_form_w, bcs=adjoint_bcs)
        A2.assemble()

        # Create vector for RHS to be updated each step
        b2 = create_vector(linear_form_w.function_spaces)

        # Direct solver (LU, no Krylov iterations), set up once for all steps
        solver = PETSc.KSP().create(domain.comm)
        solver.setOperators(A2)
        solver.setType(PETSc.KSP.Type.PREONLY)
        solver.getPC().setType(PETSc.PC.Type.LU)

        t = self.t0 + num_steps * dt  # start at the final time (tracked only)

        # N+1 adjoint values to match N PDE steps
        w_n_list = [None] * (num_steps + 1)
        whT = fem.Function(V)
        whT.x.array[:] = 0.0  # w^N = 0
        whT.x.scatter_forward()
        w_n_list[num_steps] = whT  # stored at the final index

        # Backward steps: compute w^{N-1}, ..., w^0

        for i in range(num_steps - 1, -1, -1):
            t -= dt

            # Source term for this step: copy j[i] into the form coefficient

            j_u.x.array[:] = j[i].x.array
            j_u.x.scatter_forward()

            # Re-assemble the RHS from w_n and j_u
            with b2.localForm() as loc_b2:
                loc_b2.set(0)
            assemble_vector(b2, linear_form_w)

            # Apply the Dirichlet lifting to the RHS, then set the BC values
            apply_lifting(b2, [bilinear_form_w], [adjoint_bcs])  # one list of BCs per form
            b2.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
            set_bc(b2, adjoint_bcs)

            # Solve into a fresh Function
            wh = fem.Function(V)
            wh.name = "wh"
            solver.solve(b2, wh.x.petsc_vec)
            wh.x.scatter_forward()

            # Update w_n for next iteration and store
            w_n.x.array[:] = wh.x.array
            w_n.x.scatter_forward()
            w_n_list[i] = wh

        return w_n_list

    # Compute the differential operator
    def diff_adjdir(self, j, x, apply_result=None):
        """
        Compute the adjoint-based sensitivities for the source terms `j`.

        Parameters:
        - j: one source term per time step, passed on to `solve_adjoint_pde`.
        - x: `(μ, (k, c))`; `k` and every component of `c` must be scalars.
        - apply_result: optional precomputed result of `self.apply(x)`; the
          forward problem is solved here when it is None.

        Returns `(wp_bar, (expr2, expr3))`:
        - wp_bar: `fem.Function` holding the trapezoidal time integral of the
          adjoint states (sensitivity with respect to μ),
        - expr2: float, minus the time-weighted sum of ∫ ∇u·∇w dx
          (sensitivity with respect to k),
        - expr3: tuple of floats, minus the time-weighted sums of ∫ w ∂_i u dx
          (sensitivity with respect to each component of c).

        Raises `NotImplementedError` for non-scalar `k` or `c`; the
        function-valued variants are out of date and not supported.
        """
        _mu, (k, c) = x
        scalar_types = (int, float, np.number)
        is_scalar_k = isinstance(k, scalar_types)
        is_scalar_c = all(isinstance(ci, scalar_types) for ci in c)
        # Fail before running the expensive forward/adjoint solves.
        if not (is_scalar_k and is_scalar_c):
            raise NotImplementedError("Unimplemented - out of date")

        u_n_list = apply_result if apply_result is not None else self.apply(x)
        w_n_list = self.solve_adjoint_pde(j, x, u_n_list)
        dt = self.dt
        V = self.V
        dim = V.mesh.topology.dim

        # 1. Sensitivity w.r.t. μ: trapezoidal rule over all adjoint states
        wp_bar = fem.Function(V)
        wp_bar.x.array[:] = (
            0.5
            * dt
            * (
                w_n_list[0].x.array
                + w_n_list[-1].x.array
                + 2 * np.sum([w.x.array for w in w_n_list[1:-1]], axis=0)
            )
        )

        # 2. and 3. Sensitivities w.r.t. k and c. The precompiled forms read
        # the scratch functions ux / wpx, so each (state, adjoint) pair is
        # copied in once and all forms are evaluated for it.
        ux = self.ux
        wpx = self.wpx
        expr2_form = self.expr2_form
        expr3_forms = self.expr3_forms
        last = len(u_n_list) - 1

        expr2 = 0.0
        expr3_acc = [0.0] * dim
        # zip pairs u_n_list[n] with w_n_list[n]; the extra terminal adjoint
        # entry is dropped.
        for n, (u, wp) in enumerate(zip(u_n_list, w_n_list)):
            ux.x.array[:] = u.x.array[:]
            ux.x.scatter_forward()
            wpx.x.array[:] = wp.x.array[:]
            wpx.x.scatter_forward()

            # Half weight on the first and last pair
            weight = 0.5 * dt if (n == 0 or n == last) else dt
            expr2 -= weight * fem.assemble_scalar(expr2_form)
            for i in range(dim):
                expr3_acc[i] -= weight * fem.assemble_scalar(expr3_forms[i])

        expr3 = tuple(expr3_acc)

        if self.save_for_plot:
            # Keep the intermediate results around for do_plot
            self.plot_wp_bar = wp_bar
            self.plot_u_n_list = u_n_list
            self.plot_w_n_list = w_n_list
            self.plot_j_n_list = j
            self.plot_aux = (k, c)

        return wp_bar, (expr2, expr3)

    def do_plot(self, iter, μ, frames, prefix):
        """
        Save the data stored by the latest `diff_adjdir` call (made with
        `save_for_plot` enabled) to `<prefix>/res_<iter>.npz`.

        Parameters:
        - iter: iteration number used in the file name.
        - μ: object whose `iter_padded()` yields `(point, alpha)` pairs.
        - frames: indices of the time steps to store.
        - prefix: output directory.
        """
        n = self.num_steps
        k, (c1, c2) = self.plot_aux
        # With a single time step `n - 1` is zero; avoid dividing by it.
        span = max(n - 1, 1)
        np.savez_compressed(
            "%s/res_%d.npz" % (prefix, iter),
            wp_bar_array=self.plot_wp_bar.x.array,
            u_n_list_array=[self.plot_u_n_list[i].x.array for i in frames],
            w_n_list_array=[self.plot_w_n_list[i].x.array for i in frames],
            j_n_list_array=[self.plot_j_n_list[i].x.array for i in frames],
            k=k,
            c1=c1,
            c2=c2,
            frames=frames,
            # Frame index mapped linearly onto [0, T]
            times=[self.T * i / span for i in frames],
            mu=[[point[0], point[1], alpha] for point, alpha in μ.iter_padded()],
        )

    # def _stability_prefactor(
    #     self, *, C_stab=None, alpha=None, beta=None, T=None, C_emb=None, C_reg=None
    # ):
    #     """
    #     Return the stability prefactor
    #     P := (C_stab/alpha) * (1 + sqrt(T)*(1 + sqrt(beta)/sqrt(alpha)))
    #     and the measure-to-functional constant L_e_mu := C_emb * C_reg * sqrt(T).
    #     Return (P, L_e_mu):
    #     """
    #     # self._stability_prefactor(alpha=0.5, T=2.0)

    #     # Read from arguments or instance attributes
    #     C_stab = C_stab if C_stab is not None else self.C_stab
    #     alpha = alpha if alpha is not None else self.alpha
    #     beta = beta if beta is not None else self.beta
    #     T = T if T is not None else self.T
    #     C_emb = C_emb if C_emb is not None else self.C_emb
    #     C_reg = C_reg if C_reg is not None else self.C_reg

    #     # Stability factor P

    #     P = (C_stab / alpha) * (1 + np.sqrt(T) * (1 + np.sqrt(beta) / np.sqrt(alpha)))
    #     L_e_mu = C_emb * C_reg * np.sqrt(T)

    #     return P, L_e_mu

    # # Solution operator Lipschitz factor
    # def lipschitz_factor(self, xbound=None):
    #     """
    #     Conservative Lipschitz factor for the forward solution operator mapping
    #     (k,c,mu) -> u. Returns a single scalar C_Lip such that
    #     ||u2 - u1|| <= C_Lip * (||k2-k1|| + ||c2-c1|| + ||mu2-mu1||_* )
    #     (we use a conservative form combining the measure and (k,c) pieces).
    #     """
    #     P, L_e_mu = self._stability_prefactor()
    #     # Conservative combined choice: multiply P by max(1, L_e_mu)
    #     C_Lip = P * max(1.0, L_e_mu)
    #     return C_Lip
    def diff_chain_lipschitz_factor(self, Delta_t=None):
        """
        Return the combined Lipschitz factor C_tot^Lip built from the bounds
        of `diff_bound3`, with M_J = 1.

        `override_lipschitz`, when set, is returned unchanged. `Delta_t` is
        forwarded to `diff_bound3`; for the step count used here it defaults
        to `self.dt`.
        """
        if self.override_lipschitz is not None:
            return self.override_lipschitz

        bound_mu, bound_k, bound_c = self.diff_bound3(Delta_t=Delta_t)

        step = self.dt if Delta_t is None else Delta_t
        n_steps = int(np.ceil(self.T / step))

        # Per-step shares of the accumulated bounds
        per_step_mu = bound_mu / n_steps
        per_step_k = bound_k / n_steps
        per_step_c = bound_c / n_steps

        # Scale the per-step quantities back up to global constants
        lip_e = (per_step_mu + per_step_k + per_step_c) * n_steps
        bound_e = 1.2 * lip_e
        lip_a = 0.5 * per_step_k * n_steps

        # Coercivity constant, kept away from zero (at least 0.1)
        coercivity = max(0.1, self.alpha)
        functional_bound = 1.0  # M_J

        recip = 1.0 / coercivity
        return recip * (lip_e + recip * recip * lip_a * bound_e) * functional_bound

    def diff_chain_lipschitz_factor_pair(self, Delta_t=None):
        """
        Return the pair (C^Lip_μ / M_J, C^Lip_aux / M_J).

        `override_lipschitz_pair`, when set, is returned unchanged. The
        auxiliary factor uses the larger of the k and c bounds from
        `diff_bound3`; both factors are divided by `self.alpha`.
        """
        if self.override_lipschitz_pair is not None:
            return self.override_lipschitz_pair

        bound_mu, bound_k, bound_c = self.diff_bound3(Delta_t=Delta_t)

        step = self.dt if Delta_t is None else Delta_t
        n_steps = int(np.ceil(self.T / step))

        # Per-step shares, then scaled back up by the number of steps
        lip_mu = (bound_mu / n_steps) * n_steps
        lip_aux = (max(bound_k, bound_c) / n_steps) * n_steps

        # Coercivity scaling (no lower guard on alpha here)
        recip = 1.0 / self.alpha

        return lip_mu * recip, lip_aux * recip

    def diff_bound3(self, Delta_t=None):
        """
        Return the accumulated bounds `(C_mu_adj, C_k_adj, C_c_adj)` for the
        sensitivities with respect to μ, k and c.

        The parameters `k_min`, `c_Linf`, `diam` and `T` are read from
        `self.xbound`, falling back to values derived from this instance when
        an attribute is missing. If `Delta_t` is None, a step
        `min(0.1, k_min / c_Linf**2, 0.01 * T)` is used; for `c_Linf == 0`
        the convective limit is dropped instead of dividing by zero.
        """
        xb = self.xbound

        k_min = getattr(xb, "k_min", self.alpha)
        c_Linf = getattr(xb, "c_Linf", np.sqrt(self.beta * k_min))
        diam = getattr(xb, "diam", self.diam)
        T = getattr(xb, "T", self.T)

        Pe = c_Linf * diam / k_min
        gamma = c_Linf**2 / k_min
        if Delta_t is None:
            # CFL-like adaptive step; without convection there is no
            # convective restriction on the step.
            convective_limit = k_min / c_Linf**2 if c_Linf != 0 else np.inf
            Delta_t = min(0.1, convective_limit, 0.01 * T)

        # Number of local steps covering [0, T]
        N = int(np.ceil(T / Delta_t))

        # Per-step factors
        exp_local = np.exp(0.5 * gamma * Delta_t)
        C_adj_step = np.sqrt(Delta_t * (exp_local + Delta_t / k_min))
        C_state_step = np.sqrt(Delta_t) * np.sqrt(1 + Pe**2)

        # Total accumulated bounds: N times the per-step factor
        C_adj_PDE = N * C_adj_step
        C_state = N * C_state_step

        C_mu_adj = self.C_emb * C_adj_PDE  # μ term, scaled by the embedding constant
        C_k_adj = C_state * C_adj_PDE  # k term
        C_c_adj = C_state * C_adj_PDE * c_Linf  # c term (both components)

        return C_mu_adj, C_k_adj, C_c_adj

    # ||S(x)||_Y→X ≤ C_state ||μ||_M(Ω) [time-independent μ]
    def codomain_bound(self):
        """
        Return `C_state * mu_dual`, a bound on the forward solution for a
        time-independent μ.

        `self.xbound` may be a single bounds object or a non-empty sequence of
        them; for a sequence, the first two entries are combined with `+`
        (or the only entry is used). Missing attributes fall back to values
        derived from this instance.
        """
        bounds = self.xbound
        is_sequence = (
            not hasattr(bounds, "k_min")
            and hasattr(bounds, "__len__")
            and len(bounds) > 0
        )
        if is_sequence:
            bounds = bounds[0] + bounds[1] if len(bounds) > 1 else bounds[0]

        k_min = getattr(bounds, "k_min", self.alpha)
        c_Linf = getattr(bounds, "c_Linf", np.sqrt(self.beta * k_min))
        mu_dual = getattr(bounds, "mu_dual", 3.0)
        diam = getattr(bounds, "diam", self.diam)
        T = getattr(bounds, "T", self.T)

        Pe = c_Linf * diam / k_min
        # Use the stored step when available, otherwise a CFL-like choice.
        Delta_t = (
            self.Delta_t
            if hasattr(self, "Delta_t")
            else min(0.1, k_min / c_Linf**2, 0.01 * T)
        )

        N = int(np.ceil(T / Delta_t))
        gamma = c_Linf**2 / k_min
        exp_local = np.exp(0.25 * gamma * Delta_t)

        # Per-step factor, accumulated over N steps
        C_state_step = np.sqrt(Delta_t) * np.sqrt(1 + 0.5 * Pe**2) * exp_local
        C_state = N * C_state_step

        return C_state * mu_dual

def own(u):
    """
    Return a new Function on the function space of `u` holding a copy of
    its values.

    A copy is always made; no reference counting is done to detect whether
    `u` is already uniquely owned.
    """
    duplicate = fem.Function(u.function_space)
    # Copy through the NumPy arrays; copying via `x.petsc_vec` did not work.
    np.copyto(duplicate.x.array, u.x.array)
    return duplicate


def nn2(x):
    """Return None for a float `x`, otherwise `dx.norm2_squared(x)`."""
    if isinstance(x, float):
        return None
    return dx.norm2_squared(x)


class PlotFactory:
    """
    Factory for `Plotter` objects that share one PDE object and frame count.

    Constructing the factory enables `save_for_plot` on `pde`.
    """

    def __init__(self, pde, n_plots=5):
        """
        Store `pde` and `n_plots`, and compute `plot_frames`: `n_plots` step
        indices spread evenly over `0 .. pde.num_steps - 1`.
        """
        self.pde = pde
        self.n_plots = n_plots
        pde.save_for_plot = True

        last = pde.num_steps - 1
        # A single plot has no gaps between frames; avoid dividing by zero.
        gaps = max(n_plots - 1, 1)
        self.plot_frames = [i * last // gaps for i in range(n_plots)]

    def produce(self, prefix):
        """Create a `Plotter` that writes into the directory `prefix`."""
        return Plotter(self.pde, self.n_plots, prefix)


class Plotter:
    """
    Writes plot data for selected time steps of a PDE object into a directory.
    """

    def __init__(self, pde, n_plots, prefix):
        """
        Enable `save_for_plot` on `pde`, choose `n_plots` step indices spread
        evenly over `0 .. pde.num_steps - 1`, and create the output directory
        `prefix` if it does not exist yet.
        """
        pde.save_for_plot = True

        last = pde.num_steps - 1
        # A single plot has no gaps between frames; avoid dividing by zero.
        gaps = max(n_plots - 1, 1)
        self.plot_frames = [i * last // gaps for i in range(n_plots)]

        self.pde = pde
        os.makedirs(prefix, exist_ok=True)
        self.prefix = prefix

    def plot(self, iter, μ):
        """Delegate to `pde.do_plot` for iteration `iter` and measure `μ`."""
        self.pde.do_plot(iter, μ, self.plot_frames, self.prefix)


class QuadraticRegularisation:
    """
    Quadratic regularisation functional for the auxiliary variables
    `(k, (c1, c2))`, optionally centred at a base point.

    `α` is either a single float weight or a pair `(αk, αc)` of weights for
    `k` and for the components of `c`.
    """

    def _own(self, x):
        """Return floats unchanged and an owned copy of anything else."""
        return x if isinstance(x, float) else own(x)

    def __init__(self, α, base=None):
        """
        Store the weight(s) `α` and, if given, an owned copy of the base
        point `base = (k, (c1, c2))` together with its squared norms.
        """
        self.α = α
        if base is not None:
            (k, (c1, c2)) = base
            self.base = (self._own(k), (self._own(c1), self._own(c2)))
            self.base_norm_squared = (nn2(k), (nn2(c1), nn2(c2)))
        else:
            self.base = None
            # Keep the attribute defined so that `apply` can unpack it even
            # when there is no base point.
            self.base_norm_squared = (None, (None, None))

    def _apply1(self, x, b, n):
        """
        Return half the squared distance of `x` from the base `b` (from zero
        if `b` is None); `n` is the precomputed squared norm of `b`.
        """
        if isinstance(x, float):
            d = x if b is None else x - b
            return d * d / 2.0
        else:
            if b is not None:
                # ‖x − b‖²/2 = (‖x‖² + ‖b‖²)/2 − ⟨x, b⟩, matching the scalar
                # branch above and `_prox1` — assumes `dx.dot` is the inner
                # product belonging to `dx.norm2_squared`.
                return (dx.norm2_squared(x) + n) / 2.0 - dx.dot(x, b)
            return dx.norm2_squared(x) / 2.0

    def _get_base(self):
        """Return the base point, or a same-shaped tuple of None values."""
        return (None, (None, None)) if self.base is None else self.base

    def apply(self, x):
        """
        Apply the quadratic regularisation functional to `x`.
        """
        (k, (c1, c2)) = x
        (kb, (c1b, c2b)) = self._get_base()
        (nkb, (nc1b, nc2b)) = self.base_norm_squared

        if isinstance(self.α, float):
            αk, αc = self.α, self.α
        else:
            αk, αc = self.α

        return αk * self._apply1(k, kb, nkb) + αc * (
            self._apply1(c1, c1b, nc1b) + self._apply1(c2, c2b, nc2b)
        )

    def _prox1(self, γ, x, b):
        """
        Proximal map of one component with step `γ`: `(x + γ b) / (1 + γ)`.
        Non-float `x` is modified in place and returned.
        """
        β = 1.0 / (1.0 + γ)
        if isinstance(x, float):
            return β * (x if b is None else x + b * γ)
        else:
            vx = x.x.array
            if b is not None:
                vx += b.x.array * γ
            vx *= β
            return x

    def prox(self, τ, x):
        """
        Apply the proximal map of the quadratic regularisation functional to `x`.
        WARNING: This function is unsafe. It may modify `x`.
        That is ok for our purposes, as Rust, being a safe language, already needs to pass a
        copied or owned instance to Python.
        """
        k, (c1, c2) = x
        (kb, (c1b, c2b)) = self._get_base()

        if isinstance(self.α, float):
            αk, αc = self.α, self.α
        else:
            αk, αc = self.α

        return (
            self._prox1(τ * αk, k, kb),
            (self._prox1(τ * αc, c1, c1b), self._prox1(τ * αc, c2, c2b)),
        )


class LogPowerRegularisation:
    """
    Logarithmic barrier + quadratic regularisation functional for the auxiliary variables;
    $-α\\log(x) + ωx^2$.

    `α` and `ω` are each either a single float or a pair of values for `k`
    and for the components of `c`.
    """

    @staticmethod
    def _split(value):
        """Expand a float into a `(k, c)` pair; unpack anything else as a pair."""
        if isinstance(value, float):
            return value, value
        for_k, for_c = value
        return for_k, for_c

    def _own(self, x):
        """Return floats unchanged and an owned copy of anything else."""
        if isinstance(x, float):
            return x
        return own(x)

    def __init__(self, α, ω):
        self.α = α
        self.ω = ω

    def _apply1(self, x, α, ω):
        """Evaluate `-α log(x) + ω x²`; non-float input is summed over its entries."""
        value = -α * np.log(x) + ω * x**2
        # NumPy builds temporary arrays here before the reduction.
        return value if isinstance(x, float) else value.sum()

    def apply(self, x):
        """Apply the functional to `x = (k, (c1, c2))` and add up the parts."""
        (k, (c1, c2)) = x

        αk, αc = self._split(self.α)
        ωk, ωc = self._split(self.ω)

        return (
            self._apply1(k, αk, ωk)
            + self._apply1(c1, αc, ωc)
            + self._apply1(c2, αc, ωc)
        )

    def _prox1(self, α, ω, x):
        """Proximal map of `-α log(z) + ω z²` at `x` (positive root)."""
        # Optimality condition: 0 = -α/z + 2ωz + z - x
        #   ⟺ (2ω+1) z² - x z - α = 0
        #   ⟺ z = (x ± √(x² + 4α(2ω+1))) / (2(2ω+1)); the "+" root is taken.
        scale = 2 * ω + 1
        return (x + np.sqrt(x**2 + 4 * α * scale)) / (2 * scale)

    def prox(self, τ, x):
        """
        Apply the proximal map of the logarithmic regularisation functional to `x`.
        WARNING: This function is unsafe. It may modify `x`.
        That is ok for our purposes, as Rust, being a safe language, already needs to pass a
        copied or owned instance to Python.
        """
        k, (c1, c2) = x

        αk, αc = self._split(self.α)
        ωk, ωc = self._split(self.ω)

        prox_k = self._prox1(τ * αk, τ * ωk, k)
        prox_c1 = self._prox1(τ * αc, τ * ωc, c1)
        prox_c2 = self._prox1(τ * αc, τ * ωc, c2)
        return (prox_k, (prox_c1, prox_c2))


class BoxConstraints:
    """
    Simple box constraints on the auxiliary variables.

    `α` holds the lower and `ω` the upper bound; each is either a single
    float or a pair of values for `k` and for the components of `c`.
    """

    @staticmethod
    def _split(value):
        """Expand a float into a `(k, c)` pair; unpack anything else as a pair."""
        if isinstance(value, float):
            return value, value
        for_k, for_c = value
        return for_k, for_c

    def _own(self, x):
        """Return floats unchanged and an owned copy of anything else."""
        if isinstance(x, float):
            return x
        return own(x)

    def __init__(self, α, ω):
        self.α = α
        self.ω = ω

    def _apply1(self, x, α, ω):
        """Indicator of `[α, ω]`: 0 inside, infinity outside (floats only)."""
        if not isinstance(x, float):
            raise Exception("Unimplemented")
        return 0.0 if α <= x <= ω else np.inf

    def apply(self, x):
        """Sum the interval indicators over `x = (k, (c1, c2))`."""
        (k, (c1, c2)) = x

        αk, αc = self._split(self.α)
        ωk, ωc = self._split(self.ω)

        return (
            self._apply1(k, αk, ωk)
            + self._apply1(c1, αc, ωc)
            + self._apply1(c2, αc, ωc)
        )

    def _prox1(self, α, ω, x):
        """Clamp `x` to the interval `[α, ω]`."""
        capped = min(ω, x)
        return max(capped, α)

    def prox(self, τ, x):
        """
        Apply the proximal map of the box constraints to `x`.
        The step length `τ` is not used.
        WARNING: This function is unsafe. It may modify `x`.
        That is ok for our purposes, as Rust, being a safe language, already needs to pass a
        copied or owned instance to Python.
        """
        k, (c1, c2) = x

        αk, αc = self._split(self.α)
        ωk, ωc = self._split(self.ω)

        clamped_k = self._prox1(αk, ωk, k)
        clamped_c1 = self._prox1(αc, ωc, c1)
        clamped_c2 = self._prox1(αc, ωc, c2)
        return (clamped_k, (clamped_c1, clamped_c2))


"""
Quadratic regularisation with box contraints
"""


class BoxedQuadraticRegularisation:
    """
    Sum of box constraints with bounds `ω0`, `ω1` and a quadratic
    regularisation with weight(s) `α` and optional base point `base`.
    """

    def __init__(self, ω0, ω1, α, base=None):
        self.box = BoxConstraints(ω0, ω1)
        self.quadratic = QuadraticRegularisation(α, base)

    def apply(self, x):
        """Return the box constraint value plus the quadratic value at `x`."""
        box_value = self.box.apply(x)
        quadratic_value = self.quadratic.apply(x)
        return box_value + quadratic_value

    def prox(self, τ, x):
        """
        Apply the quadratic proximal map with step `τ`, then the box
        proximal map with step 1.0, to `x`.
        """
        after_quadratic = self.quadratic.prox(τ, x)
        return self.box.prox(1.0, after_quadratic)

mercurial