"""
In this script, we define a mesh in a unit cube,

    .. math::
        \\Omega = [0,1]^3.

The mesh consists of :math:`K^3` elements,

    .. math::
        \\Omega_{m}=\\Omega_{i+(j-1)K+(k-1)K^2}=\\Omega_{i,j,k},\\quad
        i,j,k \\in\\left\\lbrace1,2,\\cdots,K\\right\\rbrace.

The mapping :math:`\\Phi_{m}=\\Phi_{i,j,k}:\\Omega_{\\mathrm{ref}}\\to
\\Omega_{i,j,k}` is given as

    .. math::
        \\Phi_{i,j,k} = \\mathring{\\Phi}\\circ\\Xi_{i,j,k},

where :math:`\\Omega_{\\mathrm{ref}}=[-1,1]^3` is the reference element and
:math:`\\Xi_{i,j,k}` is a linear mapping,
:math:`\\Xi_{i,j,k}:\\Omega_{\\mathrm{ref}}\\to
\\left[\\dfrac{i-1}{K},\\dfrac{i}{K}\\right]\\times
\\left[\\dfrac{j-1}{K},\\dfrac{j}{K}\\right]\\times
\\left[\\dfrac{k-1}{K},\\dfrac{k}{K}\\right]`, i.e.,

    .. math::
        \\begin{pmatrix}
            r\\\\s\\\\t
        \\end{pmatrix} = \\Xi_{i,j,k}(\\xi,\\eta,\\varsigma)
        =
        \\dfrac{1}{K}
        \\begin{pmatrix}
            i-1 + (\\xi+1)/2\\\\
            j-1 + (\\eta+1)/2\\\\
            k-1 + (\\varsigma+1)/2
        \\end{pmatrix},

and :math:`\\mathring{\\Phi}` is a mapping,

    .. math::
        \\begin{pmatrix}
            x\\\\y\\\\z
        \\end{pmatrix} = \\mathring{\\Phi}(r,s,t) =
        \\begin{pmatrix}
        r + \\frac{1}{2}c \\sin(2\\pi r)\\sin(2\\pi s)\\sin(2\\pi t)\\\\
        s + \\frac{1}{2}c \\sin(2\\pi r)\\sin(2\\pi s)\\sin(2\\pi t)\\\\
        t + \\frac{1}{2}c \\sin(2\\pi r)\\sin(2\\pi s)\\sin(2\\pi t)
        \\end{pmatrix},

where :math:`0 \\leq c \\leq 0.25` is a deformation factor. When
:math:`c=0`, :math:`\\mathring{\\Phi}` is the identity mapping, so every
:math:`\\Phi_{i,j,k}` is linear and the mesh is uniform and orthogonal.
When :math:`c>0`, the mesh is curvilinear. Two examples (left:
:math:`c=0`, right: :math:`c=0.25`) of this mesh for :math:`K=3` are
shown below.

.. image:: workshop/crazy_mesh.png

⭕ To access the source code, click on the [source] button at the right
side or click on
:download:`[crazy_mesh.py]</contents/LIBRARY/ptc/mathischeap_ptc/crazy_mesh.py>`.
Dependencies may exist. In case of an error, check the imports and install
the required packages or download the required scripts. © mathischeap.com

"""

import numpy as np
from numpy import sin, cos, pi
from coordinate_transformation import CoordinateTransformation

class CrazyMesh(object):
    """The crazy mesh.

    The mesh consists of :math:`K^3` elements covering the unit cube.
    With 0-based indices, element :math:`\\Omega_{i,j,k}` has the element
    number ``m = i + j*K + k*K**2``.

    :param c: The deformation factor, :math:`0 \\leq c \\leq 0.25`.
    :type c: float
    :param K: The crazy mesh is of :math:`K^3` elements.
    :type K: int
    :raises AssertionError: If ``c`` is outside ``[0, 0.25]``.

    :example:

    >>> cm = CrazyMesh(0.1, 2)
    >>> e0 = cm.CT_of_element_number(0)
    >>> e000 = cm.CT_of_element_index(0, 0, 0)
    >>> e0 is e000
    True
    >>> e7 = cm.CT_of_element_number(7)
    >>> e111 = cm.CT_of_element_index(1, 1, 1)
    >>> e7 is e111
    True
    >>> e000 # doctest: +ELLIPSIS
    <coordinate_transformation.CoordinateTransformation object at...
    >>> e7 # doctest: +ELLIPSIS
    <coordinate_transformation.CoordinateTransformation object at...

    """
    def __init__(self, c, K):
        assert 0 <= c <= 0.25, \
            "The deformation factor must be in [0, 0.25]."
        self.c = c
        self.K = K
        # Coordinate transformations already built, keyed by the index
        # tuple ``(i, j, k)``; a tuple key makes ``(1.0, 0, 0)`` and
        # ``(1, 0, 0)`` hit the same entry.
        self._cache_ = dict()

    def _element_mapping(self, i, j, k):
        """Build :math:`\\Phi_{i,j,k}` and its Jacobian for one element.

        :param i: Element index ``i`` (0-based).
        :param j: Element index ``j`` (0-based).
        :param k: Element index ``k`` (0-based).
        :return: ``(Phi, d_Phi)``. Both take reference coordinates
            ``(xi, et, sg)``. ``Phi`` returns ``(x, y, z)``; ``d_Phi``
            returns the Jacobian as three rows
            ``(x_xi, x_et, x_sg), (y_xi, y_et, y_sg), (z_xi, z_et, z_sg)``.
            All arithmetic is element-wise (NumPy ``sin``/``cos``), so
            NumPy arrays of coordinates are accepted.
        """
        c, K = self.c, self.K

        def Xi(xi, et, sg):
            # Linear map from the reference element onto the sub-cube
            # [i/K, (i+1)/K] x [j/K, (j+1)/K] x [k/K, (k+1)/K].
            r = ( 1 / K ) * ( i + (xi + 1) / 2 )
            s = ( 1 / K ) * ( j + (et + 1) / 2 )
            t = ( 1 / K ) * ( k + (sg + 1) / 2 )
            return r, s, t

        def Phi(xi, et, sg):
            r, s, t = Xi(xi, et, sg)
            # The same perturbation is added to all three coordinates, so
            # evaluate it once (same left-to-right order as before).
            bump = 0.5 * c * sin(2 * pi * r) * sin(2 * pi * s) * sin(
                2 * pi * t)
            return r + bump, s + bump, t + bump

        def d_Phi(xi, et, sg):
            r, s, t = Xi(xi, et, sg)
            sin_r, sin_s, sin_t = (sin(2 * pi * r), sin(2 * pi * s),
                                   sin(2 * pi * t))
            cos_r, cos_s, cos_t = (cos(2 * pi * r), cos(2 * pi * s),
                                   cos(2 * pi * t))
            # Derivatives of 2*bump with respect to r, s and t, divided by
            # 2*pi; shared by all three rows of the Jacobian.
            b_r = c * pi * cos_r * sin_s * sin_t
            b_s = c * pi * sin_r * cos_s * sin_t
            b_t = c * pi * sin_r * sin_s * cos_t
            # Chain rule: dr/dxi = ds/det = dt/dsg = 1 / (2K).
            two_K = 2 * K
            x_row = ((1 + b_r) / two_K, b_s / two_K, b_t / two_K)
            y_row = (b_r / two_K, (1 + b_s) / two_K, b_t / two_K)
            z_row = (b_r / two_K, b_s / two_K, (1 + b_t) / two_K)
            return x_row, y_row, z_row

        return Phi, d_Phi

    def CT_of_element_index(self, i, j, k):
        """Return a :class:`CoordinateTransformation` instance for
        element :math:`\\Omega_{i,j,k}`.

        Note that Python index starts from :math:`0`. So for a
        :class:`CrazyMesh` of :math:`K^3` elements, its indices
        :math:`i,j,k\\in\\left\\lbrace 0,1,\\cdots,K-1\\right\\rbrace`.

        The instance is cached: asking again for the same element returns
        the very same object.

        :param i: Element index ``i``.
        :param j: Element index ``j``.
        :param k: Element index ``k``.
        :type i: int
        :type j: int
        :type k: int
        :return: A :class:`CoordinateTransformation` instance.
        :raises AssertionError: If an index is not an integer in
            ``[0, K-1]``.
        """
        K = self.K
        for name, index in (('i', i), ('j', j), ('k', k)):
            assert 0 <= index < K and index % 1 == 0, (
                f"{name}={index} is wrong, must be an integer "
                f"between 0 and {K - 1}.")

        key = (i, j, k)
        if key not in self._cache_:
            Phi, d_Phi = self._element_mapping(i, j, k)
            self._cache_[key] = CoordinateTransformation(Phi, d_Phi)
        return self._cache_[key]

    def CT_of_element_number(self, m):
        """Return a :class:`CoordinateTransformation` instance for
        element :math:`\\Omega_{m}`.

        Note that Python index starts from :math:`0`. So for a
        :class:`CrazyMesh` of :math:`K^3` elements,
        :math:`m\\in\\left\\lbrace 0,1,\\cdots,K^3-1\\right\\rbrace`.

        :param m: Element No. ``m``.
        :type m: int
        :return: A :class:`CoordinateTransformation` instance.
        :raises AssertionError: If ``m`` is not an integer in
            ``[0, K**3-1]``.
        """
        K = self.K
        assert 0 <= m < K**3 and m % 1 == 0, (
            f"m={m} is wrong, must be an integer between 0 and {K**3 - 1}.")
        # Invert m = i + j*K + k*K**2.
        k, remainder = divmod(m, K**2)
        j, i = divmod(remainder, K)
        return self.CT_of_element_index(i, j, k)



class CrazyMeshGlobalNumbering(object):
    """A wrapper of global numberings for discrete variables in the
    crazy mesh.

    Row ``m`` of each numbering lists the global numbers of the local
    dofs of element ``m``, where ``m = i + j*K + k*K**2`` for the element
    of (0-based) indices ``(i, j, k)``.

    :param K: The crazy mesh is of :math:`K^3` elements.
    :param N: The degree :math:`N` of the mimetic polynomial basis
        functions to be used.
    :type K: int
    :type N: int

    :example:

    >>> K = 2
    >>> N = 1
    >>> GM = CrazyMeshGlobalNumbering(K, N)
    >>> GM.FP
    array([[ 0,  1, 12, 14, 24, 28],
           [ 1,  2, 13, 15, 25, 29],
           [ 3,  4, 14, 16, 26, 30],
           [ 4,  5, 15, 17, 27, 31],
           [ 6,  7, 18, 20, 28, 32],
           [ 7,  8, 19, 21, 29, 33],
           [ 9, 10, 20, 22, 30, 34],
           [10, 11, 21, 23, 31, 35]])
    >>> GM.VP
    array([[0],
           [1],
           [2],
           [3],
           [4],
           [5],
           [6],
           [7]])

    """
    def __init__(self, K, N):
        self.K, self.N = K, N

    def _element_rows(self, dofs_3d, extra):
        """Cut a global 3-D array of dof numbers into per-element rows.

        :param dofs_3d: Global dof numbers of one component, laid out on
            a ``(K*N + e0, K*N + e1, K*N + e2)`` lattice.
        :param extra: Tuple ``(e0, e1, e2)`` of 0/1 flags. Along an axis
            with flag 1 an element takes ``N + 1`` layers (including both
            of its bounding faces, shared with neighbours), otherwise ``N``.
        :return: An int array of shape
            ``(K**3, (N + e0) * (N + e1) * (N + e2))``; row ``m`` is the
            block of element ``m`` raveled in Fortran order.
        """
        K, N = self.K, self.N
        e0, e1, e2 = extra
        rows = np.zeros((K**3, (N + e0) * (N + e1) * (N + e2)), dtype='int')
        for k in range(K):
            for j in range(K):
                for i in range(K):
                    m = i + j * K + k * K**2
                    rows[m] = dofs_3d[i*N:(i+1)*N + e0,
                                      j*N:(j+1)*N + e1,
                                      k*N:(k+1)*N + e2].ravel('F')
        return rows

    @property
    def NP(self):
        """Generate a global numbering for the dofs of
        :math:`\\text{NP}_{N}(\\Omega)` on a crazy mesh of
        :math:`K^3` elements.
        """
        raise NotImplementedError("Could you complete this property?")

    @property
    def EP(self):
        """Generate a global numbering for the dofs of
        :math:`\\text{EP}_{N-1}(\\Omega)` on a crazy mesh of
        :math:`K^3` elements.
        """
        raise NotImplementedError("Could you complete this property?")

    @property
    def FP(self):
        """Generate a global numbering for the dofs of
        :math:`\\text{FP}_{N-1}(\\Omega)` on a crazy mesh of
        :math:`K^3` elements.

        Global numbers are assigned component by component (all x-face
        dofs, then y, then z); within a component they run fastest along
        x, then y, then z (Fortran order). A dof on a face shared by two
        elements has one global number that appears in both rows.

        :return: An int array of shape ``(K**3, 3*(N+1)*N**2)``.
        """
        KN = self.K * self.N
        size = (KN + 1) * KN**2  # number of dofs of one component

        def numbered(start, shape):
            # Consecutive global numbers from ``start``, Fortran-ordered.
            return np.arange(start, start + size, dtype='int').reshape(
                shape, order='F')

        x_dofs = numbered(0, (KN + 1, KN, KN))
        y_dofs = numbered(size, (KN, KN + 1, KN))
        z_dofs = numbered(2 * size, (KN, KN, KN + 1))
        return np.hstack((self._element_rows(x_dofs, (1, 0, 0)),
                          self._element_rows(y_dofs, (0, 1, 0)),
                          self._element_rows(z_dofs, (0, 0, 1))))

    @property
    def VP(self):
        """Generate a global numbering for the dofs of
        :math:`\\text{VP}_{N-1}(\\Omega)` on a crazy mesh of
        :math:`K^3` elements.

        Volume dofs are never shared, so every global number appears in
        exactly one row.

        :return: An int array of shape ``(K**3, N**3)``.
        """
        KN = self.K * self.N
        volume_dofs = np.arange(KN**3, dtype='int').reshape(
            (KN, KN, KN), order='F')
        return self._element_rows(volume_dofs, (0, 0, 0))



class CrazyMeshGlobalBoundaryDOFs(object):
    """We find the global numbering of the dofs on each boundary of the
    crazy mesh.

    The global numbers are the same as those produced by
    :class:`CrazyMeshGlobalNumbering` for the same ``K`` and ``N``.

    :param K: The crazy mesh is of :math:`K^3` elements.
    :param N: The degree :math:`N` of the mimetic polynomial basis
        functions to be used.
    :type K: int
    :type N: int

    :example:

    >>> K = 2
    >>> N = 1
    >>> B_DOFs = CrazyMeshGlobalBoundaryDOFs(K, N)
    >>> FB_dofs = B_DOFs.FP
    >>> FB_dofs['x_minus'] # x=0 face
    array([0, 3, 6, 9])
    >>> FB_dofs['x_plus'] # x=1 face
    array([ 2,  5,  8, 11])
    >>> FB_dofs['y_minus'] # y=0 face
    array([12, 13, 18, 19])
    >>> FB_dofs['y_plus'] # y=1 face
    array([16, 17, 22, 23])
    >>> FB_dofs['z_minus'] # z=0 face
    array([24, 25, 26, 27])
    >>> FB_dofs['z_plus'] # z=1 face
    array([32, 33, 34, 35])
    """
    def __init__(self, K, N):
        self.K, self.N = K, N

    @property
    def NP(self):
        """Find the global dofs of :math:`\\text{NP}_{N}(\\Omega)` which
        are on the boundary of the crazy mesh of :math:`K^3` elements.
        """
        raise NotImplementedError("Could you complete this property?")

    @property
    def EP(self):
        """Find the global dofs of :math:`\\text{EP}_{N-1}(\\Omega)` which
        are on the boundary of the crazy mesh of :math:`K^3` elements.
        """
        raise NotImplementedError("Could you complete this property?")

    @property
    def FP(self):
        """Find the global dofs of :math:`\\text{FP}_{N-1}(\\Omega)` which
        are on the boundary of the crazy mesh of :math:`K^3` elements.

        :returns: A dict whose keys are 'x_minus', 'x_plus', 'y_minus',
            'y_plus', 'z_minus', 'z_plus', and whose values are the
            global numbering of the dofs on the boundaries indicated
            by the keys (each an int array of length ``(K*N)**2``).
        """
        KN = self.K * self.N
        size = (KN + 1) * KN**2  # number of dofs of one component
        # Same layout as the global numbering: x-face dofs first, then y,
        # then z, each Fortran-ordered.
        x_dofs = np.arange(0, size, dtype='int').reshape(
            (KN + 1, KN, KN), order='F')
        y_dofs = np.arange(size, 2 * size, dtype='int').reshape(
            (KN, KN + 1, KN), order='F')
        z_dofs = np.arange(2 * size, 3 * size, dtype='int').reshape(
            (KN, KN, KN + 1), order='F')

        # A boundary face x=0 / x=1 is the first / last x-layer of the
        # x-face dofs; likewise for y and z.
        return {
            'x_minus': x_dofs[0, :, :].ravel('F'),
            'x_plus': x_dofs[-1, :, :].ravel('F'),
            'y_minus': y_dofs[:, 0, :].ravel('F'),
            'y_plus': y_dofs[:, -1, :].ravel('F'),
            'z_minus': z_dofs[:, :, 0].ravel('F'),
            'z_plus': z_dofs[:, :, -1].ravel('F'),
        }



if __name__ == '__main__':
    # Run the doctests above, then an h-convergence study on the crazy
    # mesh: reduce exact fields onto the node, edge, face and volume
    # spaces element by element and plot the global L2 errors.
    import doctest
    doctest.testmod()


    def phi_exact(x, y, z):
        return sin(2 * pi * x) * sin(2 * pi * y) * sin(2 * pi * z)


    # (u, v, w) is the gradient of phi_exact.
    def u_exact(x, y, z):
        return 2 * pi * cos(2 * pi * x) * sin(2 * pi * y) * sin(
            2 * pi * z)


    def v_exact(x, y, z):
        return 2 * pi * sin(2 * pi * x) * cos(2 * pi * y) * sin(
            2 * pi * z)


    def w_exact(x, y, z):
        return 2 * pi * sin(2 * pi * x) * sin(2 * pi * y) * cos(
            2 * pi * z)


    # f_exact equals minus the Laplacian of phi_exact.
    def f_exact(x, y, z):
        return 12 * pi ** 2 * sin(2 * pi * x) * sin(2 * pi * y) * sin(
            2 * pi * z)


    def convergence_order(h, errors):
        """Slope of log10(error) vs log10(h) over the two finest meshes."""
        return ((np.log10(errors[-1]) - np.log10(errors[-2]))
                / (np.log10(h[-1]) - np.log10(h[-2])))


    from projection import Reduction
    from L2_error import L2Error
    from mimetic_basis_polynomials import MimeticBasisPolynomials

    N = 2  # polynomial degree of the mimetic basis
    KK = np.arange(1, 11)  # K = 1, 2, ..., 10 elements per direction
    c = 0.1  # deformation factor of the crazy mesh
    vector_field = (u_exact, v_exact, w_exact)

    # The basis depends only on N (it is already shared by all elements
    # of a mesh), so build it once instead of once per K.
    _bfN_ = 'Lobatto-' + str(N)
    mbf = MimeticBasisPolynomials(_bfN_, _bfN_, _bfN_)

    Error_p = list()
    Error_u = list()
    Error_w = list()
    Error_f = list()
    for K in KK:
        crazy_mesh = CrazyMesh(c, K)

        # Squared local L2 errors of the K^3 elements.
        p_L2, w_L2, u_L2, f_L2 = [], [], [], []
        for k in range(K):
            for j in range(K):
                for i in range(K):
                    ct = crazy_mesh.CT_of_element_index(i, j, k)
                    RD = Reduction(mbf, ct)
                    L2e = L2Error(mbf, ct)

                    p_L2.append(L2e.NP(RD.NP(phi_exact), phi_exact) ** 2)
                    w_L2.append(
                        L2e.EP(RD.EP(vector_field), vector_field) ** 2)
                    u_L2.append(
                        L2e.FP(RD.FP(vector_field), vector_field) ** 2)
                    f_L2.append(L2e.VP(RD.VP(f_exact), f_exact) ** 2)

        # Global L2 error = square root of the sum of local squared errors.
        Error_p.append(np.sum(p_L2)**0.5)
        Error_w.append(np.sum(w_L2)**0.5)
        Error_u.append(np.sum(u_L2)**0.5)
        Error_f.append(np.sum(f_L2)**0.5)


    import matplotlib.pyplot as plt
    h = 1 / KK
    plt.figure(figsize=(12, 12))
    panels = (
        ((0, 0), 'Node', Error_p),
        ((0, 1), 'Edge', Error_w),
        ((1, 0), 'Face', Error_u),
        ((1, 1), 'Volume form', Error_f),
    )
    for position, space, errors in panels:
        plt.subplot2grid((2, 2), position)
        order = convergence_order(h, errors)
        plt.loglog(h, errors,
                   label=f'{space}, h-convergence, c={c}, order={order:.2f}')
        plt.legend()

    plt.show()