import numpy as np
from dolfinx import fem, geometry
from dolfinx_access import cell_FunctionSpace_f64
from mpi4py import MPI
from petsc4py import PETSc
from pointsource_pde.dolfinx_extras import get_mass_matrix
from scipy.sparse.linalg import svds
from slepc4py import SLEPc


class LaserSampling:
    """Linear observation operator sampling a concentration along laser beams.

    Each beam is a straight segment ``(start_pt, end_pt)``. The operator maps
    a ``fem.Function`` ``x`` on ``V`` to the vector ``R @ x.x.array`` with one
    entry per beam. Row ``i`` of the dense matrix ``R`` is assembled by
    splitting beam ``i`` into ``num_segments`` equal pieces. Each piece
    contributes its length times the average of the DOF values of the cell
    containing its midpoint (see :meth:`matrix`).

    Only serial runs are supported. ``R`` is indexed by global DOF numbers,
    while :meth:`apply` multiplies it with the process-local DOF array.
    """

    def __init__(
        self,
        V,
        domain,
        nx,
        ny,
        beams,
        num_segments=100,
        domain_size=None,
    ):
        """Build the sampling matrix, a mass-matrix solver and the operator norm.

        Args:
            V: DOLFINx function space of the sampled functions.
            domain: The mesh; stored as ``self.domain``.
            nx, ny: Stored as ``self.nx`` / ``self.ny``; not used in this class.
            beams: Sequence of ``(start_pt, end_pt)`` pairs, one per beam.
                Only the first two coordinates of each point are used.
            num_segments: Number of midpoint-rule segments per beam.
            domain_size: ``[xmin, xmax, ymin, ymax]``. Defaults to
                ``[-2, 2, -2, 2]`` and is stored as ``self.domain_size``.
        """
        # Fresh list per instance instead of a shared mutable default.
        if domain_size is None:
            domain_size = [-2, 2, -2, 2]

        self.V = V
        self.domain = domain
        self.nx, self.ny = nx, ny
        self.M = len(beams)
        self.domain_size = domain_size
        self.num_segments = num_segments

        self.R = self.matrix(beams)

        # LU solver for the mass matrix A. The live code paths below do not
        # use it, but it is kept as public state for the V-adjoint variant
        # described in diff_adjdir.
        A = get_mass_matrix(V)
        solver = PETSc.KSP().create(A.comm)
        solver.setOperators(A)
        solver.setType(PETSc.KSP.Type.PREONLY)
        solver.getPC().setType(PETSc.PC.Type.LU)
        self.solver = solver

        # Norm with respect to the Euclidean norm of the DOF array.
        # A bound in the function-space norm would additionally need the
        # mass-matrix factor ‖A^{-1}‖ = 1/λ_min(A), from
        # ‖RA^{-1}‖ ≤ ‖R‖‖A^{-1}‖. The eigensolve for λ_min used to run here,
        # but its result was never used, so it has been dropped.
        self._opnorm = self._spectral_norm(self.R)

    @staticmethod
    def _spectral_norm(R):
        """Return the largest singular value of the 2-D array ``R``.

        ``scipy.sparse.linalg.svds`` (arpack) requires ``k < min(R.shape)``,
        so it cannot handle a single beam (one row) or a single DOF (one
        column). Those cases are handled directly, as are all-zero and empty
        matrices.
        """
        R = np.asarray(R, dtype=np.float64)
        if not np.any(R):
            # Empty or identically zero: every singular value is 0.
            return 0.0
        if min(R.shape) == 1:
            # A single row/column has exactly one singular value: its
            # Euclidean length.
            return float(np.linalg.norm(R))
        _u, s, _vt = svds(R, k=1)
        return float(s[0])

    def matrix(self, beams):
        """Assemble the dense ``(M, n_dofs)`` beam-sampling matrix.

        Each beam is split into ``self.num_segments`` equal segments. For
        every segment midpoint that ``cell_FunctionSpace_f64`` locates in a
        cell (non-negative index), the segment length is divided equally
        among the distinct global DOFs of that cell and added to the beam's
        row. Midpoints outside the mesh contribute nothing.

        Args:
            beams: Sequence of ``(start_pt, end_pt)`` array-likes. Only the
                first two coordinates are used, with z padded to 0.

        Returns:
            ``numpy.ndarray`` of shape ``(len(beams), size_global)``.
            NOTE(review): the width is ``index_map.size_global`` without the
            block size, so this presumably assumes a scalar space. Confirm.
        """
        dofmap = self.V.dofmap
        index_map = dofmap.index_map
        R = np.zeros((len(beams), index_map.size_global))

        # Parameter values t ∈ (0, 1) of the segment midpoints along a beam.
        ts = (np.arange(self.num_segments) + 0.5) / self.num_segments

        for beam_idx, (start_pt, end_pt) in enumerate(beams):
            start = np.asarray(start_pt, dtype=np.float64)
            vec = np.asarray(end_pt, dtype=np.float64) - start
            seg_len = np.linalg.norm(vec) / self.num_segments
            midpoints = start + np.outer(ts, vec)

            for point in midpoints:
                pad_point = np.array([point[0], point[1], 0.0])
                cell_idx = cell_FunctionSpace_f64(self.V._cpp_object, pad_point)
                if cell_idx < 0:  # midpoint not located in any cell
                    continue
                cell_dofs_local = dofmap.cell_dofs(cell_idx)
                # np.unique matters: fancy-index "+=" would silently drop
                # repeated indices.
                cell_dofs = np.unique(index_map.local_to_global(cell_dofs_local))
                # Accumulate, since several segments (and beams) may hit a cell.
                R[beam_idx, cell_dofs] += seg_len / len(cell_dofs)

        return R

    def apply(self, x):
        """Return ``R @ x.x.array``, one value per beam.

        Ghost values of ``x`` are refreshed first. This does not work with
        MPI: ``R`` uses global DOF indices, ``x.x.array`` local ones.
        """
        x.x.scatter_forward()
        return self.R @ x.x.array

    def diff_adjdir(self, j, _x):
        """Return ``R^T j`` as a new ``fem.Function`` on ``V``.

        ``_x`` is unused because ``R`` is constant (the operator is linear).
        This does not work with MPI.

        NOTE: Let ``A`` be the mass matrix. The adjoint in the V inner product
        must satisfy ``⟨x.array, R^T j⟩ = ⟨A x.array, [R^* j].array⟩``, i.e.
        ``[R^* j].array = A^{-1} R^T j``, which could be obtained with
        ``self.solver``. This method returns the Euclidean adjoint ``R^T j``
        instead. TODO: confirm that the caller (convection_diffusion) works
        with norms in ℝ^n.
        """
        tmp = fem.Function(self.V)
        np.matmul(self.R.T, j, out=tmp.x.array)
        tmp.x.scatter_forward()
        return tmp

    def opnorm(self, *args):
        """Return ``‖R‖₂`` (largest singular value), computed in ``__init__``."""
        return self._opnorm

    def lipschitz_factor(self, *args):
        """Lipschitz factor of the linear operator; equals :meth:`opnorm`."""
        return self.opnorm(*args)

    def diff_chain_lipschitz_factor(self, *args):
        """Return :meth:`opnorm`; the derivative of a linear map is itself."""
        return self.opnorm(*args)

    def diff_bound(self, *args):
        """Bound on the derivative's norm; equals :meth:`opnorm`."""
        return self.opnorm(*args)

    def noise(self, noise_level, rng=None):
        """Draw i.i.d. zero-mean Gaussian observation noise, one per beam.

        Args:
            noise_level: Standard deviation of the noise.
            rng: ``numpy.random.Generator``. A fresh default generator is
                used when None.

        Returns:
            ``numpy.ndarray`` of shape ``(self.M,)``.
        """
        if rng is None:
            rng = np.random.default_rng()
        return rng.normal(0, noise_level, size=(self.M,))
