# Copyright (C) 2015-2025 Garth N. Wells and Jørgen S. Dokken
#
# This file is part of DOLFINx (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later
"""Unit tests for the discrete operators."""

import sys

from mpi4py import MPI

import numpy as np
import pytest

import dolfinx.la
import ufl
from basix.ufl import element
from dolfinx.fem import Expression, Function, discrete_curl, discrete_gradient, functionspace
from dolfinx.mesh import CellType, GhostMode, create_unit_cube, create_unit_square


@pytest.mark.parametrize(
    "mesh",
    [
        create_unit_square(MPI.COMM_WORLD, 11, 6, ghost_mode=GhostMode.none, dtype=np.float32),
        create_unit_square(
            MPI.COMM_WORLD, 11, 6, ghost_mode=GhostMode.shared_facet, dtype=np.float64
        ),
        create_unit_cube(MPI.COMM_WORLD, 4, 3, 7, ghost_mode=GhostMode.none, dtype=np.float64),
        create_unit_cube(
            MPI.COMM_WORLD, 4, 3, 7, ghost_mode=GhostMode.shared_facet, dtype=np.float32
        ),
    ],
)
def test_gradient(mesh):
    """Check the lowest-order discrete gradient from P1 Lagrange to N1curl.

    The operator must have one global row per mesh edge and one global
    column per mesh vertex, and its squared norm (as reported by
    ``squared_norm``) must equal twice the number of edges.
    """
    lagrange_space = functionspace(mesh, ("Lagrange", 1))
    nedelec_space = functionspace(mesh, ("Nedelec 1st kind H(curl)", 1))

    # The operator is intentionally NOT reverse-scattered: scatter_rev
    # would ship rows to other processes and sum them there, which
    # corrupts the matrix.
    grad_op = discrete_gradient(lagrange_space, nedelec_space)

    topology = mesh.topology
    n_edges = topology.index_map(1).size_global
    row_map, col_map = grad_op.index_map(0), grad_op.index_map(1)
    n_vertices = topology.index_map(0).size_global

    assert row_map.size_global == n_edges
    assert col_map.size_global == n_vertices
    assert np.isclose(grad_op.squared_norm(), 2.0 * n_edges)


@pytest.mark.parametrize("cell", [CellType.triangle, CellType.quadrilateral])
def test_discrete_curl_gdim_raises(cell):
    """Ensure discrete_curl raises a RuntimeError on a 2D (gdim != 3) mesh.

    A degree-2 N1curl space and a degree-2 RT space are built on a unit
    square, and requesting the discrete curl between them must fail.
    """
    msh = create_unit_square(MPI.COMM_WORLD, 3, 3, cell_type=cell, dtype=np.float64)
    basix_cell = msh.basix_cell()
    spaces = [
        functionspace(msh, element(family, basix_cell, 2, dtype=np.float64))
        for family in ("N1curl", "RT")
    ]
    with pytest.raises(RuntimeError):
        discrete_curl(*spaces)


@pytest.mark.parametrize(
    "elements",
    [
        (
            element("Lagrange", "tetrahedron", 2, shape=(3,), dtype=np.float64),
            element("Lagrange", "tetrahedron", 2, shape=(3,), dtype=np.float64),
        ),
        (
            element("N1curl", "tetrahedron", 2, dtype=np.float64),
            element("N1curl", "tetrahedron", 2, dtype=np.float64),
        ),
        (
            element("RT", "tetrahedron", 2, dtype=np.float64),
            element("N1curl", "tetrahedron", 2, dtype=np.float64),
        ),
        (
            element("RT", "tetrahedron", 2, dtype=np.float64),
            element("RT", "tetrahedron", 2, dtype=np.float64),
        ),
    ],
)
def test_discrete_curl_map_raises(elements):
    """Ensure discrete_curl raises a RuntimeError for unsupported space pairs.

    None of the parametrised ``(V0, V1)`` element pairs is an
    N1curl -> RT pair, so every case is expected to be rejected.
    """
    comm = MPI.COMM_WORLD
    msh = create_unit_cube(comm, 3, 3, 3, cell_type=CellType.tetrahedron, dtype=np.float64)
    V0, V1 = (functionspace(msh, el) for el in elements)
    with pytest.raises(RuntimeError):
        discrete_curl(V0, V1)


@pytest.mark.parametrize(
    "dtype",
    [
        np.float32,
        pytest.param(
            np.complex64,
            marks=pytest.mark.xfail(
                sys.platform.startswith("win32"),
                raises=NotImplementedError,
                reason="missing _Complex",
            ),
        ),
        np.float64,
        pytest.param(
            np.complex128,
            marks=pytest.mark.xfail(
                sys.platform.startswith("win32"),
                raises=NotImplementedError,
                reason="missing _Complex",
            ),
        ),
    ],
)
@pytest.mark.parametrize("p", range(2, 5))
@pytest.mark.parametrize(
    "element_data",
    [
        (CellType.tetrahedron, "Nedelec 1st kind H(curl)", "Raviart-Thomas"),
        (CellType.hexahedron, "Nedelec 1st kind H(curl)", "Raviart-Thomas"),
    ],
)
def test_discrete_curl(element_data, p, dtype):
    """Compute discrete curl operator, with verification using Expression.

    A vector field is interpolated into the ``E0`` space of degree ``p``.
    Applying the discrete curl matrix to its coefficients must reproduce
    the result of interpolating ``ufl.curl`` of the same function into
    the ``E1`` space of degree ``p - 1``.
    """
    # Geometry dtype: the real counterpart of the (possibly complex) scalar dtype
    xdtype = dtype(0).real.dtype

    celltype, E0, E1 = element_data
    N = 3
    msh = create_unit_cube(
        MPI.COMM_WORLD,
        N,
        N // 2,
        2 * N,
        ghost_mode=GhostMode.none,
        cell_type=celltype,
        dtype=xdtype,
    )

    # Perturb mesh (making hexahedral cells no longer affine) in serial.
    # Do not perturb in parallel: each process perturbs its own copy of the
    # geometry, so nodes shared between processes would be moved
    # differently, making the mesh non-conforming.
    rng = np.random.default_rng(0)
    delta_x = 1 / (2 * N - 1) if MPI.COMM_WORLD.size == 1 else 0
    msh.geometry.x[:] = msh.geometry.x + 0.2 * delta_x * (rng.random(msh.geometry.x.shape) - 0.5)

    V0 = functionspace(msh, (E0, p))
    V1 = functionspace(msh, (E1, p - 1))

    u0 = Function(V0, dtype=dtype)
    u0.interpolate(
        lambda x: np.vstack(
            (
                x[1] ** 4 + 3 * x[2] ** 2 + (x[1] * x[2]) ** 3,
                3 * x[0] ** 4 + 3 * x[2] ** 2,
                x[0] ** 3 + x[1] ** 4,
            )
        )
    )

    # Create discrete curl operator and get local part of G (including
    # ghost rows) as a SciPy sparse matrix
    # Note: do not 'assemble' (scatter_rev) G. This would wrongly
    # accumulate data for ghost entries.
    G = discrete_curl(V0, V1)
    Glocal = G.to_scipy(ghosted=True)

    # Apply discrete curl operator to the u0 vector. Only the columns
    # covered by u0's local (owned + ghost) coefficients are used.
    u1 = Function(V1, dtype=dtype)
    x0 = u0.x.array
    u1.x.array[:] = Glocal[:, : x0.shape[0]] @ x0

    # Interpolate curl of u0 using Expression (reference solution)
    curl_u = Expression(ufl.curl(u0), V1.element.interpolation_points, dtype=dtype)
    u1_expr = Function(V1, dtype=dtype)
    u1_expr.interpolate(curl_u)

    atol = 1000 * np.finfo(dtype).resolution
    assert np.allclose(u1_expr.x.array, u1.x.array, atol=atol)


@pytest.mark.parametrize("p", range(1, 4))
@pytest.mark.parametrize("q", range(1, 4))
@pytest.mark.parametrize(
    "cell_type",
    [
        (
            create_unit_square(
                MPI.COMM_WORLD,
                11,
                6,
                ghost_mode=GhostMode.none,
                cell_type=CellType.triangle,
                dtype=np.float32,
            ),
            "Lagrange",
            "Nedelec 1st kind H(curl)",
        ),
        (
            create_unit_square(
                MPI.COMM_WORLD,
                11,
                6,
                ghost_mode=GhostMode.none,
                cell_type=CellType.triangle,
                dtype=np.float64,
            ),
            "Lagrange",
            "Nedelec 1st kind H(curl)",
        ),
        (
            create_unit_square(
                MPI.COMM_WORLD,
                11,
                6,
                ghost_mode=GhostMode.none,
                cell_type=CellType.quadrilateral,
                dtype=np.float32,
            ),
            "Q",
            "RTCE",
        ),
        (
            create_unit_square(
                MPI.COMM_WORLD,
                11,
                6,
                ghost_mode=GhostMode.none,
                cell_type=CellType.quadrilateral,
                dtype=np.float64,
            ),
            "Q",
            "RTCE",
        ),
        (
            create_unit_cube(
                MPI.COMM_WORLD,
                3,
                3,
                2,
                ghost_mode=GhostMode.none,
                cell_type=CellType.tetrahedron,
                dtype=np.float32,
            ),
            "Lagrange",
            "Nedelec 1st kind H(curl)",
        ),
        (
            create_unit_cube(
                MPI.COMM_WORLD,
                3,
                3,
                2,
                ghost_mode=GhostMode.none,
                cell_type=CellType.tetrahedron,
                dtype=np.float64,
            ),
            "Lagrange",
            "Nedelec 1st kind H(curl)",
        ),
        (
            create_unit_cube(
                MPI.COMM_WORLD,
                3,
                3,
                2,
                ghost_mode=GhostMode.none,
                cell_type=CellType.hexahedron,
                dtype=np.float32,
            ),
            "Q",
            "NCE",
        ),
        (
            create_unit_cube(
                MPI.COMM_WORLD,
                3,
                2,
                2,
                ghost_mode=GhostMode.none,
                cell_type=CellType.hexahedron,
                dtype=np.float64,
            ),
            "Q",
            "NCE",
        ),
    ],
)
def test_gradient_interpolation(cell_type, p, q):
    """Check the discrete gradient against interpolating ``ufl.grad`` via Expression.

    Despite its name, ``cell_type`` is a ``(mesh, family0, family1)`` tuple.
    A function in the degree-``p`` ``family0`` space is mapped by the discrete
    gradient into the degree-``q`` ``family1`` space, and the result is
    compared with a direct interpolation of its gradient.
    """
    msh, src_family, dst_family = cell_type
    real_t = msh.geometry.x.dtype

    V = functionspace(msh, (src_family, p))
    W = functionspace(msh, (dst_family, q))

    # Deliberately no scatter_rev on G: it would transfer rows to other
    # processes, where they would be summed into an incorrect matrix.
    G = discrete_gradient(V, W)

    # u's coefficients must live on G's column map, since that map
    # defines the additional ghosts the product needs.
    u = Function(V, dolfinx.la.vector(G.index_map(1), dtype=real_t), dtype=real_t)
    u.interpolate(lambda x: 2 * x[0] ** p + 3 * x[1] ** p)

    # Reference result: interpolate grad(u) into W through an Expression.
    grad_expr = Expression(ufl.grad(u), W.element.interpolation_points, dtype=real_t)
    w_expr = Function(W, dtype=real_t)
    w_expr.interpolate(grad_expr)

    # Matrix-vector product using only the owned rows of G, then refresh
    # W's ghost entries from their owners.
    w = Function(W, dtype=real_t)
    G_owned = G.to_scipy(ghosted=False)
    w.x.array[: G_owned.shape[0]] = G_owned @ u.x.array
    w.x.scatter_forward()

    tol = 1000 * np.finfo(real_t).resolution
    assert np.allclose(w_expr.x.array, w.x.array, atol=tol)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("p", range(1, 4))
@pytest.mark.parametrize("q", range(1, 4))
@pytest.mark.parametrize("from_lagrange", [True, False])
@pytest.mark.parametrize(
    "cell_type",
    [CellType.quadrilateral, CellType.triangle, CellType.tetrahedron, CellType.hexahedron],
)
def test_interpolation_matrix(dtype, cell_type, p, q, from_lagrange):
    """Test that discrete interpolation matrix yields the same result as interpolation.

    A vector-valued (continuous or discontinuous) Lagrange space of degree
    ``p`` is paired with an H(curl) space of degree ``q``. ``from_lagrange``
    selects the direction of the map. Multiplying the interpolation matrix by
    the source coefficients must match ``Function.interpolate``.
    """
    from dolfinx.fem import interpolation_matrix

    comm = MPI.COMM_WORLD
    if cell_type == CellType.triangle:
        mesh = create_unit_square(
            comm, 7, 5, ghost_mode=GhostMode.none, cell_type=cell_type, dtype=dtype
        )
        lagrange = "Lagrange" if from_lagrange else "DG"
        nedelec = "Nedelec 1st kind H(curl)"
    elif cell_type == CellType.quadrilateral:
        mesh = create_unit_square(
            comm, 11, 6, ghost_mode=GhostMode.none, cell_type=cell_type, dtype=dtype
        )
        lagrange = "Q" if from_lagrange else "DQ"
        nedelec = "RTCE"
    elif cell_type == CellType.hexahedron:
        mesh = create_unit_cube(
            comm, 3, 2, 1, ghost_mode=GhostMode.none, cell_type=cell_type, dtype=dtype
        )
        lagrange = "Q" if from_lagrange else "DQ"
        nedelec = "NCE"
    elif cell_type == CellType.tetrahedron:
        mesh = create_unit_cube(
            comm, 3, 2, 2, ghost_mode=GhostMode.none, cell_type=cell_type, dtype=dtype
        )
        lagrange = "Lagrange" if from_lagrange else "DG"
        nedelec = "Nedelec 1st kind H(curl)"
    else:
        raise ValueError(f"Unsupported {cell_type=}")

    v_el = element(lagrange, mesh.basix_cell(), p, shape=(mesh.geometry.dim,), dtype=dtype)
    s_el = element(nedelec, mesh.basix_cell(), q, dtype=dtype)
    if from_lagrange:
        el0 = v_el
        el1 = s_el
    else:
        el0 = s_el
        el1 = v_el

    V, W = functionspace(mesh, el0), functionspace(mesh, el1)
    G = interpolation_matrix(V, W)
    G.scatter_reverse()
    G_sp = G.to_scipy()

    def f(x):
        if mesh.geometry.dim == 2:
            return (x[1] ** p, x[0] ** p)
        else:
            return (x[0] ** p, x[2] ** p, x[1] ** p)

    u = Function(V, dtype=dtype)
    u.interpolate(f)
    w_vec = Function(W, dtype=dtype)
    w_vec.interpolate(u)

    # Build the input vector on the column index map of G rather than
    # zero-padding u.x.array. G's column map is not guaranteed to have the
    # same ghost entries as V's dofmap. Copying the owned values and
    # scattering forward gives every ghost column its owner's value (not
    # zero), and the vector uses the test's dtype.
    ux = dolfinx.la.vector(G.index_map(1), G.block_size[1], dtype=dtype)
    num_owned = G.index_map(1).size_local * G.block_size[1]
    ux.array[:num_owned] = u.x.array[:num_owned]
    ux.scatter_forward()

    # Compute global matrix vector product (owned rows only), then update
    # W's ghost entries from their owners
    w = Function(W, dtype=dtype)
    w.x.array[: G_sp.shape[0]] = G_sp @ ux.array
    w.x.scatter_forward()

    atol = 100 * np.finfo(dtype).resolution
    assert np.allclose(w_vec.x.array, w.x.array, atol=atol)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize(
    "cell_type",
    [CellType.triangle, CellType.quadrilateral, CellType.tetrahedron, CellType.hexahedron],
)
def test_discrete_interpolation(cell_type, dtype):
    """Test that the DG -> Lagrange interpolation matrix matches direct interpolation.

    The cubic ``x**3 - y**3`` is interpolated into a degree-3 DG space.
    Applying the interpolation matrix to its coefficients must reproduce
    interpolating the same DG function directly into degree-3 Lagrange.
    """
    tdim = dolfinx.cpp.mesh.cell_dim(cell_type)
    if tdim == 2:
        mesh = create_unit_square(MPI.COMM_WORLD, 4, 4, cell_type=cell_type, dtype=dtype)
    elif tdim == 3:
        mesh = create_unit_cube(MPI.COMM_WORLD, 3, 2, 4, cell_type=cell_type, dtype=dtype)
    else:
        raise ValueError(f"Unsupported {cell_type=}")

    V = functionspace(mesh, ("DG", 3))
    Q = functionspace(mesh, ("Lagrange", 3))

    u = Function(V, dtype=dtype)
    u.interpolate(lambda x: x[0] ** 3 - x[1] ** 3)

    int_matrix = dolfinx.fem.interpolation_matrix(V, Q)
    int_matrix.scatter_reverse()

    # Input vector on the matrix's column index map: copy the owned DG
    # coefficients, then scatter forward so ghost columns hold owner values.
    u_col = dolfinx.la.vector(int_matrix.index_map(1), int_matrix.block_size[1], dtype=dtype)
    num_owned_dofs = V.dofmap.index_map.size_local * V.dofmap.index_map_bs
    u_col.array[:num_owned_dofs] = u.x.array[:num_owned_dofs]
    u_col.scatter_forward()

    q = Function(Q, dtype=dtype)
    int_matrix.mult(u_col, q.x)
    q.x.scatter_forward()

    q_ref = Function(Q, dtype=dtype)
    q_ref.interpolate(u)

    atol = 100 * np.finfo(dtype).resolution
    np.testing.assert_allclose(q.x.array, q_ref.x.array, atol=atol)
