import weakref

import dolfinx.fem.petsc  # noqa: F401  -- makes the ``fem.petsc`` submodule available
import numpy as np
import ufl
from dolfinx import fem
from dolfinx.fem.function import Function, FunctionSpace

# _mass_matrices = weakref.WeakKeyDictionary()
_mass_matrices = {}


def get_mass_matrix(space: FunctionSpace):
    """Return the assembled mass matrix of ``space``, cached per space.

    The matrix is the PETSc matrix assembled from the bilinear form
    ``inner(u, v) * dx``, where ``u`` and ``v`` are the trial and test
    functions of ``space``. It is cached so repeated inner-product
    evaluations on the same space do not re-assemble it.

    The cache is keyed by ``id(space)``, because using the space itself as a
    (weak) dictionary key reportedly does not work with dolfinx
    FunctionSpace objects. An id is only unique while its object is alive.
    A finalizer therefore evicts the entry when the space is
    garbage-collected. This frees the matrix, and it stops a later space
    that reuses the same id from being handed a stale matrix of the wrong
    size. If the space cannot be weakly referenced, the matrix is returned
    without caching rather than risking such a stale hit.
    """
    key = id(space)
    cached = _mass_matrices.get(key)
    # ``is not None`` avoids relying on the truthiness of a PETSc object.
    if cached is not None:
        return cached

    u = ufl.TrialFunction(space)
    v = ufl.TestFunction(space)
    a = fem.form(ufl.inner(u, v) * ufl.dx)
    m = fem.petsc.assemble_matrix(a)
    m.assemble()

    try:
        # The finalizer holds only a weak reference to ``space``, so it does
        # not keep the space alive; it drops the cache entry on collection.
        weakref.finalize(space, _mass_matrices.pop, key, None)
    except TypeError:
        # Not weak-referenceable: an id-keyed entry could outlive the space
        # and be returned for an unrelated one, so skip caching.
        return m
    _mass_matrices[key] = m
    return m


def dot(f: Function, g: Function):
    """Return the mass-matrix weighted inner product of ``f`` and ``g``.

    Evaluates ``g . (M f)``, where ``M`` is the mass matrix of the shared
    function space (assembled from ``inner(u, v) * dx``). In other words,
    this is the L2 inner product of the two finite element functions.

    Raises:
        ValueError: if ``f`` and ``g`` live on different function spaces.
    """
    space = f.function_space
    if g.function_space != space:
        raise ValueError("Function spaces need to agree")
    mass_times_f = get_mass_matrix(space) @ f.x.petsc_vec
    return g.x.petsc_vec.dot(mass_times_f)


def norm2_squared(f: Function):
    """Return the squared L2 norm of ``f``, i.e. ``dot(f, f)``."""
    squared = dot(f, f)
    return squared


def norm2(f: Function):
    """Return the L2 norm of ``f``.

    ``norm2_squared`` yields the scalar returned by PETSc ``Vec.dot``, which
    is a plain Python number with no ``.sqrt()`` method. ``np.sqrt`` is
    used instead, matching ``dist2``.
    """
    return np.sqrt(norm2_squared(f))


def dist2_squared(f: Function, g: Function):
    """Return the squared L2 distance ``||f - g||^2`` between ``f`` and ``g``.

    Computed as ``d . (M d)`` with ``d = f - g``, where ``M`` is the mass
    matrix of the shared function space. This is the correct value of the
    expansion ``f.Mf - 2 g.Mf + g.Mg``. It needs a single mat-vec product
    and avoids the cancellation that expansion suffers when ``f`` is close
    to ``g``.

    Raises:
        ValueError: if ``f`` and ``g`` live on different function spaces.
    """
    if g.function_space != f.function_space:
        raise ValueError("Function spaces need to agree")
    m = get_mass_matrix(f.function_space)
    # Work on a copy so neither input function's coefficients are modified.
    diff = f.x.petsc_vec.copy()
    diff.axpy(-1.0, g.x.petsc_vec)  # diff <- f - g
    return diff.dot(m @ diff)


def dist2(f: Function, g: Function):
    """Return the L2 distance between ``f`` and ``g``.

    This is the square root of ``dist2_squared(f, g)``.
    """
    squared_distance = dist2_squared(f, g)
    return np.sqrt(squared_distance)
