src/convection_diffusion.py

Thu, 26 Feb 2026 09:32:12 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 26 Feb 2026 09:32:12 -0500
changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
permissions
-rw-r--r--

Initial working version for known convectivity and diffusivity

import os
from dataclasses import dataclass
from posix import mkdir
from termios import B0
from typing import Optional, Tuple, Union

import dolfinx.geometry as geo
import numpy as np
import ufl
from dolfinx import fem, mesh
from dolfinx.fem.petsc import (
    apply_lifting,
    assemble_matrix,
    assemble_vector,
    create_vector,
    set_bc,
)
from mpi4py import MPI
from petsc4py import PETSc

try:
    import pointsource_pde.dolfinx_extras as dx

    from dolfinx_access import (
        cell_FunctionSpace_f64,
        max_Function_f64_p2,
        min_Function_f64_p2,
    )
except:
    import src.dolfinx_extras as dx


@dataclass
class XBound:
    """Parameter bounds used in the convection-diffusion operator estimates."""

    mu_dual: float = 3.0  # bound on ||μ||_M(Ω)
    k_min: float = 0.1  # smallest diffusion, > 0 (coercivity)
    k_max: float = 2.0  # largest diffusion
    c_Linf: float = 0.5  # L∞ bound, max(|c1|, |c2|)
    diam: float = 1.41  # domain diameter (√2 for the unit square)
    T: float = 1.0  # final time

    def __add__(self, other: "XBound") -> "XBound":
        """
        Combine two bound sets conservatively.

        `k_min` takes the smaller of the two values; every other field takes
        the larger one. A field that `other` does not provide falls back to
        the value stored on `self`.
        """
        combined = {}
        for name in self.__dataclass_fields__:
            mine = getattr(self, name)
            theirs = getattr(other, name, mine)
            # Only the coercivity bound shrinks; all others grow.
            pick = min if name == "k_min" else max
            combined[name] = pick(mine, theirs)
        return XBound(**combined)


class ConvectionDiffusion:
    def __init__(
        self,
        u0,
        g,
        # w0,
        domain_size=[-2, 2, -2, 2],
        nx=32,
        ny=32,
        degree=2,
        t0=0.0,
        T=1.0,
        num_steps=50,
        p=None,
        Delta_t=None,
        C_stab=1.0,
        alpha=1.0,  # \alpha = k_m/2, where k is the diffusion coefficient and k >= k_m
        beta=1.0,  # \beta = \| c \|_\infty^2 / k_m
        C_emb=1.0,
        C_reg=1.0,
    ):
        """
        Set up the mesh, function space, boundary conditions and the compiled
        forms for the forward (backward Euler) and adjoint time stepping.

        Parameters:
        u0          - initial condition, passed to `fem.Function.interpolate`
        g           - Dirichlet data, passed to `fem.Function.interpolate`
        domain_size - [xmin, xmax, ymin, ymax] of the rectangular domain
        nx, ny      - number of mesh cells in each direction
        degree      - degree of the Lagrange element
        t0, T       - initial and final time
        num_steps   - number of time steps; the step is (T - t0) / num_steps
        p           - stored as `self.p`; not used by the code in this method
        Delta_t     - step size used by the bound estimates; T / num_steps if None
        C_stab, alpha, beta, C_emb, C_reg - constants of the stability estimates

        NOTE(review): `domain_size` has a mutable list default, and the list is
        stored on the instance without copying. It is not modified here, but a
        caller mutating `self.domain_size` would change the shared default.
        """
        self.p = p  # exponent of L^p (default None; only stored here)
        self.C_stab = C_stab  # adjoint stability constant
        self.alpha = alpha  # coercivity lower bound derived from k
        self.beta = beta  # theoretical parameter
        self.C_emb = C_emb  # embedding constant
        self.C_reg = C_reg
        if Delta_t is None:
            # NOTE(review): this uses T / num_steps whereas self.dt below uses
            # (T - t0) / num_steps; the two agree only when t0 == 0.
            self.Delta_t = T / num_steps
        else:
            self.Delta_t = Delta_t

        self.domain_size = domain_size
        self.nx = nx
        self.ny = ny
        self.degree = degree
        self.T = T

        self.num_steps = num_steps
        self.dt = (T - t0) / num_steps
        self.t0 = t0

        # When True, diff_adjdir keeps its intermediate results for do_plot.
        self.save_for_plot = False

        # Rectangular triangle mesh, created on MPI.COMM_SELF.
        domain = mesh.create_rectangle(
            MPI.COMM_SELF,
            [
                np.array([domain_size[0], domain_size[2]]),
                np.array([domain_size[1], domain_size[3]]),
            ],
            [nx, ny],
            mesh.CellType.triangle,
        )
        self.domain = domain
        # (Redundant: domain_size was already stored above; no volume is computed.)
        self.domain_size = domain_size
        # Domain diameter = length of the diagonal of the rectangle
        self.diam = np.sqrt(
            (domain_size[1] - domain_size[0]) ** 2
            + (domain_size[3] - domain_size[2]) ** 2
        )

        V = fem.functionspace(domain, ("Lagrange", degree))
        self.V = V

        self.u0 = fem.Function(V)
        self.u0.name = "u0"
        self.u0.interpolate(u0)

        self.g = fem.Function(V)
        self.g.name = "g"
        self.g.interpolate(g)
        self.bcs = []  # list of Dirichlet BCs for the state equation
        self.adjoint_bcs = []  # list of Dirichlet BCs for the adjoint equation (w = 0)

        # Identify the boundary facets that carry the Dirichlet conditions
        fdim = domain.topology.dim - 1
        # Γ₁ facets: TOP/BOTTOM boundaries (y = ymin, ymax) - STATE Dirichlet u = g
        gamma1_facets = mesh.locate_entities_boundary(
            domain,
            fdim,
            lambda x: (
                np.isclose(x[1], domain_size[2])  # x[1] = y = ymin
                | np.isclose(x[1], domain_size[3])
            ),  # x[1] = y = ymax
        )
        # Γ₂ facets: LEFT/RIGHT boundaries (x = xmin, xmax) - ADJOINT Dirichlet w = 0
        gamma2_facets = mesh.locate_entities_boundary(
            domain,
            fdim,
            lambda x: (
                np.isclose(x[0], domain_size[0])  # x[0] = x = xmin
                | np.isclose(x[0], domain_size[1])
            ),  # x[0] = x = xmax
        )
        # STATE BC: u = g on Γ₁ (top/bottom)
        dofs_gamma1 = fem.locate_dofs_topological(V, fdim, gamma1_facets)
        # No function space argument: the BC value self.g is already a Function on V
        bc = fem.dirichletbc(self.g, dofs_gamma1)
        self.bcs.append(bc)

        # ADJOINT BC: w = 0 on Γ₂ (left/right).
        # NOTE(review): the state BC lives on Γ₁ but the adjoint BC on Γ₂ —
        # confirm that using different boundary parts is intended.
        zero = fem.Constant(domain, PETSc.ScalarType(0.0))
        dofs_gamma2 = fem.locate_dofs_topological(V, fdim, gamma2_facets)
        bc_adjoint = fem.dirichletbc(zero, dofs_gamma2, V)
        self.adjoint_bcs.append(bc_adjoint)

        self.fdim = fdim
        self.boundary_facets = np.unique(np.concatenate((gamma1_facets, gamma2_facets)))

        # Scratch functions and precompiled forms used by diff_adjdir:
        # expr2_form = ∫ ∇ux·∇wpx dx, expr3_forms[i] = ∫ wpx ∂_i ux dx
        ux = fem.Function(V)
        wpx = fem.Function(V)
        self.ux = ux
        self.wpx = wpx
        dim = V.mesh.topology.dim
        self.expr2_form = fem.form(ufl.inner(ufl.grad(ux), ufl.grad(wpx)) * ufl.dx)
        self.expr3_forms = tuple(
            fem.form(wpx * ufl.grad(ux)[i] * ufl.dx) for i in range(dim)
        )

        # Only the scalar (spatially constant) coefficient case is supported.
        # The values are placeholders; apply/solve_adjoint_pde overwrite them.
        self.k = fem.Constant(domain, PETSc.ScalarType(0.0))  # diffusion coefficient k
        self.c0 = fem.Constant(domain, PETSc.ScalarType(0.0))  # c[0] scalar
        self.c1 = fem.Constant(domain, PETSc.ScalarType(0.0))  # c[1] scalar
        u_n = fem.Function(V)  # state at the previous time step
        u_n.name = "u_n"
        self.u_n = u_n
        # u and w are both trial functions (forward and adjoint problem); v is the test function
        u, w, v = ufl.TrialFunction(V), ufl.TrialFunction(V), ufl.TestFunction(V)
        dt = self.dt
        # Forward bilinear form (backward Euler)
        a = (
            u * v * ufl.dx
            + dt * ufl.dot(self.k * ufl.grad(u), ufl.grad(v)) * ufl.dx
            + dt * (self.c0 * u.dx(0) + self.c1 * u.dx(1)) * v * ufl.dx
        )
        self.bilinear_form = fem.form(a)
        L = u_n * v * ufl.dx
        self.linear_form = fem.form(L)

        # Adjoint bilinear form (same diffusion term, convection term with opposite sign)
        w_n = fem.Function(V)  # adjoint value from the previously solved step
        w_n.name = "w_n"
        self.w_n = w_n
        j_u = fem.Function(V)  # adjoint source term for the current step
        j_u.name = "j_u"
        self.j_u = j_u
        a2 = (
            w * v * ufl.dx
            + dt
            * (
                ufl.dot(self.k * ufl.grad(w), ufl.grad(v))
                - (self.c0 * w.dx(0) * v + self.c1 * w.dx(1) * v)
            )
            * ufl.dx
        )
        self.bilinear_form_w = fem.form(a2)
        # Step n=N-1: L2 uses w_n = w^N = 0 and the solve yields w^{N-1}.
        # Step n=N-2: L2 uses w_n = w^{N-1} and the solve yields w^{N-2}, and so on.

        L2 = w_n * v * ufl.dx + dt * j_u * v * ufl.dx  # known w_n (previously solved step)
        self.linear_form_w = fem.form(L2)

    # Solving forward PDE
    def apply(self, x):
        """
        Solve the forward problem by backward Euler time stepping.

        Parameters:
        x - nested tuple `(μ, (k, c))`: `μ.iter_padded()` yields `(point, alpha)`
            pairs describing the point sources, `k` is the scalar diffusion
            coefficient and `c[0]`, `c[1]` are the scalar convection components.

        Returns:
        A list with one new `fem.Function` per time step, in time order.
        """
        μ, (k, c) = x

        domain = self.domain
        V = self.V
        dt = self.dt
        bcs = self.bcs

        # initial condition
        u_n = self.u_n
        u_n.x.array[:] = self.u0.x.array
        u_n.x.scatter_forward()

        # Linear form: only contains u_n * v (Dirichlet on Γ1 is handled via bc)

        # Update the coefficient constants, then assemble the system matrix
        # once with the Dirichlet BCs applied
        self.c0.value = c[0]
        self.c1.value = c[1]
        self.k.value = k
        linear_form = self.linear_form
        bilinear_form = self.bilinear_form
        A = assemble_matrix(bilinear_form, bcs=bcs)
        A.assemble()

        # b is the RHS refilled each timestep; b_ps holds the point-source
        # contribution, which is assembled once below
        b = create_vector(linear_form)
        b_ps = create_vector(linear_form)

        # Prepare solver once (direct LU solve, no Krylov iterations)
        solver = PETSc.KSP().create(domain.comm)
        solver.setOperators(A)
        solver.setType(PETSc.KSP.Type.PREONLY)
        solver.getPC().setType(PETSc.PC.Type.LU)

        # Form point source contribution
        # (the local name `mesh` shadows the imported dolfinx `mesh` module
        # inside this method)
        mesh = V.mesh
        element = V.element.basix_element
        mesh_nodes = mesh.geometry.x
        cmap = mesh.geometry.cmap

        # NOTE(review): assemblyBegin is called before any value is set and
        # assemblyEnd after the loop — confirm this ordering is intended.
        b_ps.assemblyBegin()

        with b_ps.localForm() as loc_b:
            loc_b.set(0)

        for point, alpha in μ.iter_padded():
            # NOTE(review): cell_FunctionSpace_f64 is only imported in the
            # `try` branch of the module imports; on the fallback import path
            # this name is undefined.
            cell_id = cell_FunctionSpace_f64(V._cpp_object, point)

            if cell_id < 0:
                # Points outside the mesh are reported and otherwise ignored
                print(f"Point {point} is outside the mesh.")
            else:
                cell_dofs = V.dofmap.cell_dofs(cell_id)

                # Map the physical point to reference coordinates of its cell
                # and evaluate the basis functions there
                geom_dofs = mesh.geometry.dofmap[cell_id]
                ref = cmap.pull_back(
                    np.array([[point[0], point[1]]]), mesh_nodes[geom_dofs]
                )
                phi = element.tabulate(0, ref)

                # Add alpha * phi_i(point) to the entries of the cell's dofs
                b_ps.setValuesLocal(
                    cell_dofs,
                    alpha * phi,
                    addv=PETSc.InsertMode.ADD_VALUES,
                )

        b_ps.assemblyEnd()

        # t is advanced in the loop but not otherwise used
        t = self.t0
        num_steps = self.num_steps
        bcs = self.bcs

        u_n_list = []
        for i in range(num_steps):
            t += dt

            # b.assemblyBegin()

            # zero-out and assemble RHS (M * u_n part)
            with b.localForm() as loc_b:
                loc_b.set(0)
            assemble_vector(b, linear_form)  # b = (u^n, v)

            # NOTE(review): the point source is added with factor 1.0, whereas
            # the adjoint right-hand side scales its source by dt — confirm
            # whether a dt factor is missing here.
            b.axpy(1.0, b_ps)

            # Lift the Dirichlet data (u = g) into the RHS and finalize the vector
            apply_lifting(b, [bilinear_form], [bcs])  # [[bc1, bc2, ...]]
            b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
            set_bc(b, bcs)  # list [bc1, bc2, ...]
            # b.assemble()
            #
            # b.assemblyEnd()

            # b.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)

            # Solve linear problem A * uh = b
            uh = fem.Function(V)
            uh.name = "uh"
            solver.solve(b, uh.x.petsc_vec)
            uh.x.scatter_forward()

            # Update u_n for next step, store uh
            u_n.x.array[:] = uh.x.array
            u_n.x.scatter_forward()
            u_n_list.append(uh)  # a fresh Function per step, so entries are distinct

        return u_n_list

    # Solving the adjoint problem

    def solve_adjoint_pde(self, j, x, u_n_list):
        """
        Solve the adjoint problem backwards in time with terminal value zero.

        Parameters:
        j        - sequence of functions; `j[i].x.array` is the adjoint source
                   used when solving for step `i`. Must have `len(u_n_list)` entries.
        x        - nested tuple `(μ, (k, c))`; only `k` and `c` are used here.
        u_n_list - forward states; only its length is used here.

        Returns:
        A list of `len(u_n_list) + 1` functions; the last entry is the zero
        terminal value and entry `i` is the adjoint computed at backward step `i`.
        """
        μ, (k, c) = x

        domain = self.domain
        V = self.V
        dt = self.dt
        adjoint_bcs = self.adjoint_bcs

        num_steps = len(u_n_list)

        assert len(j) == num_steps, f"MISMATCH: j({len(j)}) != u_n_list({num_steps})"
        # Terminal condition: w(T) = 0
        w_n = self.w_n
        w_n.x.array[:] = 0.0
        w_n.x.scatter_forward()

        # Update the coefficient constants shared with the forward forms
        self.c0.value = c[0]
        self.c1.value = c[1]
        self.k.value = k
        j_u = self.j_u
        bilinear_form_w = self.bilinear_form_w
        linear_form_w = self.linear_form_w

        # The adjoint matrix is assembled once with the adjoint BCs applied
        A2 = assemble_matrix(bilinear_form_w, bcs=adjoint_bcs)
        A2.assemble()

        # Create vector for RHS to be updated each step
        b2 = create_vector(linear_form_w)

        # Direct LU solve, no Krylov iterations
        solver = PETSc.KSP().create(domain.comm)
        solver.setOperators(A2)
        solver.setType(PETSc.KSP.Type.PREONLY)
        solver.getPC().setType(PETSc.PC.Type.LU)

        # Start at the final time; t is decremented in the loop but not otherwise used
        t = self.t0 + num_steps * dt

        # N+1 adjoint values to match N PDE steps
        w_n_list = [None] * (num_steps + 1)
        whT = fem.Function(V)
        whT.x.array[:] = 0.0  # w^N = 0
        whT.x.scatter_forward()
        w_n_list[num_steps] = whT  # Store at FINAL index

        # Backward steps: compute w^{N-1}, ..., w^0

        for i in range(num_steps - 1, -1, -1):
            t -= dt

            # Source term for this step: copy j[i] into the function used by the RHS form

            j_u.x.array[:] = j[i].x.array
            j_u.x.scatter_forward()

            # Update RHS
            with b2.localForm() as loc_b2:
                loc_b2.set(0)
            assemble_vector(b2, linear_form_w)

            # Lift the adjoint BCs into the RHS, accumulate ghost contributions,
            # then set the BC values
            apply_lifting(b2, [bilinear_form_w], [adjoint_bcs])  #  [forms] → [bcs]
            b2.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
            set_bc(b2, adjoint_bcs)

            # set_bc(b2, adjoint_bcs)  # list BC object
            # b2.assemble()

            wh = fem.Function(V)
            wh.name = "wh"
            solver.solve(b2, wh.x.petsc_vec)
            wh.x.scatter_forward()

            # Update w_n for next iteration and store
            w_n.x.array[:] = wh.x.array
            w_n.x.scatter_forward()
            w_n_list[i] = wh

        return w_n_list

    # Compute the differential operator
    def diff_adjdir(self, j, x, apply_result=None):
        """
        Evaluate the adjoint of the differential of the solution operator.

        Parameters:
        j            - adjoint source terms, one per time step (see
                       `solve_adjoint_pde`).
        x            - nested tuple `(μ, (k, c))` with scalar `k` and scalar
                       components of `c`.
        apply_result - optional precomputed result of `self.apply(x)`; the
                       forward problem is solved here when it is None.

        Returns:
        `(wp_bar, (expr2, expr3))`, where `wp_bar` is a `fem.Function` holding
        the trapezoidal time integral of the adjoint, `expr2` is a float
        (sensitivity w.r.t. `k`) and `expr3` is a tuple of floats, one per
        space dimension (sensitivity w.r.t. the components of `c`).

        Raises:
        NotImplementedError - if `k` or a component of `c` is not a scalar.
        """
        _mu, (k, c) = x
        scalar_types = (int, float, np.number)
        # Function-valued coefficients are not supported; fail before doing
        # any PDE solve.
        if not isinstance(k, scalar_types):
            raise NotImplementedError("Unimplemented - out of date")
        if not all(isinstance(ci, scalar_types) for ci in c):
            raise NotImplementedError("Unimplemented - out of date")

        u_n_list = apply_result if apply_result is not None else self.apply(x)
        w_n_list = self.solve_adjoint_pde(j, x, u_n_list)
        dt = self.dt
        V = self.V

        # 1. Sensitivity w.r.t. μ: trapezoidal rule in time over all adjoint
        #    values (always a Function).
        wp_bar = fem.Function(V)
        wp_bar.x.array[:] = (
            0.5
            * dt
            * (
                w_n_list[0].x.array
                + w_n_list[-1].x.array
                + 2 * np.sum([w.x.array for w in w_n_list[1:-1]], axis=0)
            )
        )

        ux = self.ux
        wpx = self.wpx
        expr2_form = self.expr2_form
        expr3_forms = self.expr3_forms

        # 2./3. Sensitivities w.r.t. k and c. The state/adjoint pair is copied
        #       into the scratch functions once per step and every precompiled
        #       form is assembled on it.
        # NOTE(review): zip pairs u_n_list[n] with w_n_list[n]; w_n_list has one
        # more entry than u_n_list, so its last (terminal) entry is not used.
        # Confirm that this time-index pairing is the intended one.
        expr2 = 0.0
        expr3_acc = [0.0] * len(expr3_forms)
        last = len(u_n_list) - 1
        for n, (u, wp) in enumerate(zip(u_n_list, w_n_list)):
            ux.x.array[:] = u.x.array[:]
            ux.x.scatter_forward()
            wpx.x.array[:] = wp.x.array[:]
            wpx.x.scatter_forward()

            # Half weight on the first and last paired step, full weight otherwise
            weight = 0.5 * dt if n in (0, last) else dt

            expr2 -= weight * fem.assemble_scalar(expr2_form)
            for i, form in enumerate(expr3_forms):
                expr3_acc[i] -= weight * fem.assemble_scalar(form)

        expr3 = tuple(expr3_acc)

        if self.save_for_plot:
            # Keep the intermediate results for do_plot
            self.plot_wp_bar = wp_bar
            self.plot_u_n_list = u_n_list
            self.plot_w_n_list = w_n_list
            self.plot_j_n_list = j

        return wp_bar, (expr2, expr3)

    def do_plot(self, iter, μ, frames, prefix):
        n = self.num_steps
        np.savez_compressed(
            "%s/res_%d.npz" % (prefix, iter),
            wp_bar_array=self.plot_wp_bar.x.array,
            u_n_list_array=[self.plot_u_n_list[i].x.array for i in frames],
            w_n_list_array=[self.plot_w_n_list[i].x.array for i in frames],
            j_n_list_array=[self.plot_j_n_list[i].x.array for i in frames],
            frames=frames,
            times=list(map(lambda i: self.T * i / (n - 1), frames)),
            mu=[[point[0], point[1], alpha] for point, alpha in μ.iter_padded()],
        )

    def _stability_prefactor(
        self, *, C_stab=None, alpha=None, beta=None, T=None, C_emb=None, C_reg=None
    ):
        """
        Return the stability prefactor
        P := (C_stab/alpha) * (1 + sqrt(T)*(1 + sqrt(beta)/sqrt(alpha)))
        and the measure-to-functional constant L_e_mu := C_emb * C_reg * sqrt(T).
        Return (P, L_e_mu):
        """
        # self._stability_prefactor(alpha=0.5, T=2.0)

        # Read from arguments or instance attributes
        C_stab = C_stab if C_stab is not None else self.C_stab
        alpha = alpha if alpha is not None else self.alpha
        beta = beta if beta is not None else self.beta
        T = T if T is not None else self.T
        C_emb = C_emb if C_emb is not None else self.C_emb
        C_reg = C_reg if C_reg is not None else self.C_reg

        # Stability factor P

        P = (C_stab / alpha) * (1 + np.sqrt(T) * (1 + np.sqrt(beta) / np.sqrt(alpha)))
        L_e_mu = C_emb * C_reg * np.sqrt(T)

        return P, L_e_mu

    # Solution operator Lipschitz factor
    def lipschitz_factor(self):
        """
        Conservative Lipschitz factor for the forward solution operator mapping
        (k,c,mu) -> u. Returns a single scalar C_Lip such that
        ||u2 - u1|| <= C_Lip * (||k2-k1|| + ||c2-c1|| + ||mu2-mu1||_* )
        (we use a conservative form combining the measure and (k,c) pieces).
        """
        P, L_e_mu = self._stability_prefactor()
        # Conservative combined choice: multiply P by max(1, L_e_mu)
        C_Lip = P * max(1.0, L_e_mu)
        return C_Lip

    # Adjoint solution operator Lipschitz factor
    """
    def diff_adj_lipschitz_factor(self):

        Compute adjoint-solution Lipschitz factor A for 2D domain.

        For p = 2, A = 1.
        For p > 2, A = |Omega|^(1/2 - 1/p), where |Omega| is the mesh area.

        p = self.p
        omega_vol = self.domain_volume
        return 1.0 if p == 2 else omega_vol ** (0.5 - 1.0 / p)
    """

    def diff_adj_lipschitz_factor(self):
        """
        Always return 1 for p=2
        """
        return 1.0

    def diff_bound(self, xbound=None):
        """
        Differential bound for the solution operator derivative.

        Returns the tuple: (L_e_mu, L_e_k, L_e_c)

         where
        L_e_k = ||∇u||_{L^2(Ω_T)}
        L_e_c = ||∇u||_{L^2(Ω_T)}
        L_e_mu = C_emb * C_reg * sqrt(T)
        """

        # If grad_u_norm is None, compute it from u0
        grad_u_norm = getattr(self, "grad_u_norm", None)
        if grad_u_norm is None:
            u = self.u0
            V = self.V
            # Compute L2 norm of gradient of u0 over the domain
            grad_u = ufl.grad(u)
            grad_u_sq = ufl.inner(grad_u, grad_u)
            dx = ufl.Measure("dx", domain=self.domain)
            # Integral of |∇u|^2 over Ω
            grad_u_norm = fem.assemble_scalar(fem.form(grad_u_sq * dx)) ** 0.5
            # Optionally store it for later reuse
            self.grad_u_norm = grad_u_norm

        return grad_u_norm

    def diff_bound_pair(self, xbound=None, Delta_t=None):
        if xbound is None:
            return self.diff_bound(xbound=None)
        else:
            xb_combined = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]

            k_min = getattr(xb_combined, "k_min", self.alpha)
            c_Linf = getattr(xb_combined, "c_Linf", np.sqrt(self.beta * k_min))
            diam = getattr(xb_combined, "diam", self.diam)
            T = getattr(xb_combined, "T", self.T)

            Pe = c_Linf * diam / k_min
            gamma = c_Linf**2 / k_min
            # adaptive time step size Delta_t (Delta_t = 0.01 in the test)
            if Delta_t is None:
                Delta_t = min(0.1, k_min / c_Linf**2, 0.01 * T)  # CFL-like

            # Accumulate N = T/Delta_t local steps
            N = int(np.ceil(T / Delta_t))

            # Per-step factors
            exp_local = np.exp(0.5 * gamma * Delta_t)
            C_adj_step = np.sqrt(Delta_t * (exp_local + Delta_t / k_min))
            C_state_step = np.sqrt(Delta_t) * np.sqrt(1 + Pe**2)

            # Total accumulated bound (geometric sum <= N * max_step)
            C_adj_PDE = N * C_adj_step
            C_state = N * C_state_step

            # Use the embedding constant
            C_mu_adj = self.C_emb * C_adj_PDE  # ∫φ δμ
            C_k_adj = C_state * C_adj_PDE  # ∫∇u·∇φ δk
            C_c1_adj = C_state * C_adj_PDE * c_Linf  # ∫∂x u φ δc1
            C_c2_adj = C_state * C_adj_PDE * c_Linf  # ∫∂y u φ δc2

            return C_mu_adj, C_k_adj, C_c1_adj, C_c2_adj

    # ||S(x)||_Y→X ≤ C_state ||μ||_M(Ω) [time-independent μ]
    def codomain_bound(self, xbound=None):
        if xbound is None:
            return 1.0

        # Extract parameters
        if hasattr(xbound, "k_min"):
            xb = xbound
        else:
            if hasattr(xbound, "__len__") and len(xbound) > 0:
                xb = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
            else:
                xb = xbound
            k_min = getattr(xb, "k_min", self.alpha)
            c_Linf = getattr(xb, "c_Linf", np.sqrt(self.beta * k_min))
            mu_dual = getattr(xb, "mu_dual", 3.0)
            diam = getattr(xb, "diam", self.diam)
            T = getattr(xb, "T", self.T)

        Pe = c_Linf * diam / k_min
        if not hasattr(self, "Delta_t"):
            Delta_t = min(0.1, k_min / c_Linf**2, 0.01 * T)
        else:
            Delta_t = self.Delta_t

        N = int(np.ceil(T / Delta_t))
        gamma = c_Linf**2 / k_min
        exp_local = np.exp(0.25 * gamma * Delta_t)

        C_state_step = np.sqrt(Delta_t) * np.sqrt(1 + 0.5 * Pe**2) * exp_local
        C_state = N * C_state_step

        return C_state * mu_dual

        # ||S(x1)-S(x2)|| ≤ Cμ||Δμ|| + Ck||Δk|| + Cc1||Δc1|| + Cc2||Δc2||

    def codomain_bound_pair(self, xbound=None):
        if xbound is None:
            return self.codomain_bound(xbound=None)
        else:
            xb = xbound[0] + xbound[1] if len(xbound) > 1 else xbound[0]
            C_base = self.codomain_bound(xb)
            # Component
            C_mu = C_base
            C_k = C_base * 2.0
            C_c1 = C_base * 1.5
            C_c2 = C_base * 1.5

            return (C_mu, C_k, C_c1, C_c2)

    """
    def codomain_bound(self, xbound=None):
        return 1.0

    def codomain_bound_pair(self, xbound=None):
        # TODO: implement properly
        if xbound is None:
            return self.codomain_bound(xbound=None)
        else:
            return self.codomain_bound(xbound=xbound[0] + xbound[1])
    """

    def diff_bound3(self):
        """
        Returns full operator differential bounds:

        L_e_mu = C_emb * C_reg * sqrt(T)
        L_e_k  = ||∇u||_{L^2(Ω)}
        L_e_c  = ||∇u||_{L^2(Ω)}

        """

        # First get L_e_k = L_e_c
        grad_u_norm = self.diff_bound()

        # Compute L_e_mu from stability prefactor
        _P, L_e_mu = self._stability_prefactor()

        L_e_k = grad_u_norm
        L_e_c = grad_u_norm

        return L_e_mu, L_e_k, L_e_c

    # Solution operator Lipschitz factor separately wrt. μ and (k, c)
    def lipschitz_factor_pair(self):
        """
        Compute the forward solution operator Lipschitz factors
        separately with respect to the parameters (k, c) and μ.

        Returns:
        tuple: (L_mu, L_kc)
            L_mu  - Lipschitz factor w.r.t. μ
            L_kc  - Lipschitz factor w.r.t. (k, c)
        """
        # Get differential bounds for the solution operator
        L_e_mu, L_e_k, L_e_c = self.diff_bound3()

        # Lipschitz factor w.r.t μ
        L_mu = L_e_mu

        # Lipschitz factor w.r.t (k, c)
        # Conservative choice: max of L_e_k and L_e_c
        L_kc = max(L_e_k, L_e_c)

        return L_mu, L_kc

    # Adjoint solution operator Lipschitz factor separately wrt. μ and (k, c)
    def diff_adj_lipschitz_factor_pair(self):
        A = self.diff_adj_lipschitz_factor()

        # No separate analytic dependence, so return same A for both parts
        A_mu = A
        A_kc = A

        return A_mu, A_kc


def own(u):
    """
    Return a copy of the function `u` that shares no storage with it.

    A new `fem.Function` is created on the function space of `u` and the
    coefficient array of `u` is copied into it. A copy is always made; no
    reference counting is attempted.
    """
    duplicate = fem.Function(u.function_space)
    # The values are copied through the NumPy views of the coefficient
    # arrays; copying via the PETSc vectors was tried and did not work.
    source_values = u.x.array
    np.copyto(duplicate.x.array, source_values)
    return duplicate


def nn2(x):
    """
    Squared norm of `x` via `dx.norm2_squared`, or None when `x` is a float.
    """
    if isinstance(x, float):
        return None
    return dx.norm2_squared(x)


class PlotFactory:
    """
    Factory for `Plotter` objects that share one PDE and one frame count.

    Constructing the factory switches on `pde.save_for_plot`.
    """

    def __init__(self, pde, n_plots=5):
        """
        Parameters:
        pde     - PDE object providing `num_steps` and `save_for_plot`
        n_plots - number of time frames to store per plot
        """
        self.pde = pde
        self.n_plots = n_plots

        m = n_plots
        n = pde.num_steps
        pde.save_for_plot = True
        if m == 1:
            # A single frame cannot be spread over the time steps; dividing
            # by m - 1 would fail, so the first step is used.
            self.plot_frames = [0]
        else:
            # m indices spread evenly over 0, ..., n - 1
            self.plot_frames = [i * (n - 1) // (m - 1) for i in range(m)]

    def produce(self, prefix):
        """Create a `Plotter` that writes into the directory `prefix`."""
        return Plotter(self.pde, self.n_plots, prefix)


class Plotter:
    """
    Writes plot data of a PDE object into a fixed output directory.

    Constructing the plotter switches on `pde.save_for_plot` and creates the
    output directory if it does not exist.
    """

    def __init__(self, pde, n_plots, prefix):
        """
        Parameters:
        pde     - PDE object providing `num_steps`, `save_for_plot` and `do_plot`
        n_plots - number of time frames to store per plot
        prefix  - output directory
        """
        m = n_plots
        n = pde.num_steps
        pde.save_for_plot = True
        if m == 1:
            # A single frame cannot be spread over the time steps; dividing
            # by m - 1 would fail, so the first step is used.
            self.plot_frames = [0]
        else:
            # m indices spread evenly over 0, ..., n - 1
            self.plot_frames = [i * (n - 1) // (m - 1) for i in range(m)]
        self.pde = pde
        os.makedirs(prefix, exist_ok=True)
        self.prefix = prefix

    def plot(self, iter, μ):
        """Store the data of iteration `iter` for the measure `μ`."""
        self.pde.do_plot(iter, μ, self.plot_frames, self.prefix)


class QuadraticRegularisation:
    """
    Quadratic regularisation functional for the auxiliary variables (k, (c1, c2)):
    α/2 times the sum of the squared distances of each component to its base
    value, or of the squared norms when no base is given.
    """

    def _own(self, x):
        # Floats are immutable; anything else is copied so the base cannot be
        # changed from outside.
        return x if isinstance(x, float) else own(x)

    def __init__(self, α, base=None):
        """
        Parameters:
        α    - regularisation weight
        base - optional reference point (k, (c1, c2)); None means zero
        """
        self.α = α
        if base is not None:
            (k, (c1, c2)) = base
            self.base = (self._own(k), (self._own(c1), self._own(c2)))
            self.base_norm_squared = (nn2(k), (nn2(c1), nn2(c2)))
        else:
            self.base = None
            # Always defined, so that `apply` also works without a base.
            self.base_norm_squared = (None, (None, None))

    def _apply1(self, x, b, n):
        """
        Half the squared distance of `x` to `b` (to zero if `b` is None).
        `n` is the precomputed squared norm of `b` for non-float `x`.
        """
        if isinstance(x, float):
            d = x if b is None else x - b
            return d * d / 2.0
        else:
            if b is not None:
                # ½‖x − b‖² = ½(‖x‖² + ‖b‖²) − ⟨x, b⟩
                return (dx.norm2_squared(x) + n) / 2.0 - dx.dot(x, b)
            return dx.norm2_squared(x) / 2.0

    def _get_base(self):
        return (None, (None, None)) if self.base is None else self.base

    def apply(self, x):
        """
        Apply the quadratic regularisation functional to `x` = (k, (c1, c2)).
        """
        (k, (c1, c2)) = x
        (kb, (c1b, c2b)) = self._get_base()
        (nkb, (nc1b, nc2b)) = self.base_norm_squared

        return self.α * (
            self._apply1(k, kb, nkb)
            + self._apply1(c1, c1b, nc1b)
            + self._apply1(c2, c2b, nc2b)
        )

    def _prox1(self, τ, x, b):
        """
        Proximal map for one component: (x + γ b) / (1 + γ) with γ = α τ.
        A non-float `x` is modified in place and returned.
        """
        γ = self.α * τ
        β = 1.0 / (1.0 + γ)
        if isinstance(x, float):
            return β * (x if b is None else x + b * γ)
        else:
            vx = x.x.array
            if b is not None:
                vx += b.x.array * γ
            vx *= β
            return x

    def prox(self, τ, x):
        """
        Apply the proximal map of the quadratic regularisation functional to `x`.
        WARNING: This function is unsafe. It may modify `x`.
        That is ok for our purposes, as Rust, being a safe language, already needs to pass a
        copied or owned instance to Python.
        """
        k, (c1, c2) = x
        (kb, (c1b, c2b)) = self._get_base()

        return (
            self._prox1(τ, k, kb),
            (self._prox1(τ, c1, c1b), self._prox1(τ, c2, c2b)),
        )

mercurial