"""
An implementation of the projection for the trace spaces on a 3D surface
:math:`\\Gamma`. The projection can be divided into
two processes, *reduction* and *reconstruction*.

The *reduction* process will use the given scalar or vector to obtain
the coefficients of its projection in the discrete space.

The *reconstruction* process first assembles the coefficients with the
correct basis functions and evaluates the discrete variable at given
points.

⭕ To access the source code, click on the [source] button at the right
side or click on
:download:`[projection_trace.py]</contents/LIBRARY/ptc/mathischeap_ptc/projection_trace.py>`.
Dependencies may exist. In case of an error, check the imports and install
the required packages or download the required scripts. © mathischeap.com

"""

import numpy as np
from quadrature import Gauss
from mimetic_basis_polynomials_2d import grid2d



class ReductionTrace(object):
    """A wrapper of reduction functions.

    :param bf: The basis functions in :math:`\\Pi_{\\mathrm{ref}}`.
    :type bf: MimeticBasisPolynomials2D
    :param ct: The coordinate transformation representing the
        mapping :math:`\\Psi`,

            .. math::

                \\Psi: \\Pi_{\\mathrm{ref}}\\to\\Gamma

    :type ct: CoordinateTransformationSurface
    :param quad_degree: (default: ``None``) The degree used for the
        numerical integral. It should be a list or tuple of two
        positive integers, or a single positive integer which is then
        used in both directions. If it is ``None``, the degree
        ``bf.degree[i] + 2`` is used in each direction ``i``.
    :type quad_degree: list, tuple, int

    :example:

    >>> def p(x,y,z): return np.sin(np.pi*x) * np.sin(np.pi*y) * np.sin(np.pi*z)
    >>> from mimetic_basis_polynomials_2d import MimeticBasisPolynomials2D
    >>> from coordinate_transformation_surface import CoordinateTransformationSurface
    >>> from coordinate_transformation_surface import Psi, d_Psi
    >>> bf = MimeticBasisPolynomials2D('Lobatto-3', 'Lobatto-3')
    >>> ct = CoordinateTransformationSurface(Psi, d_Psi)
    >>> Rd = ReductionTrace(bf, ct)
    >>> Rd.TF(p)
    array([0.03873144, 0.02717954, 0.02752846, 0.02717954, 0.04952738,
           0.04752636, 0.02752846, 0.04752636, 0.01838736])

    """
    def __init__(self, bf, ct, quad_degree=None):
        assert bf.__class__.__name__ == 'MimeticBasisPolynomials2D'
        assert ct.__class__.__name__ == 'CoordinateTransformationSurface'
        self.bf = bf
        self.ct = ct
        if quad_degree is None:
            # Two degrees above the basis degree in each direction.
            bf_N = bf.degree
            quad_degree = [bf_N[i] + 2 for i in range(2)]
        elif isinstance(quad_degree, (int, np.integer)):
            # A single degree is shared by both reference directions.
            quad_degree = [quad_degree, quad_degree]
        self.quad_degree = quad_degree

    def TN(self, scalar):
        """Reduce a scalar on :math:`\\Gamma` to
        :math:`\\text{TN}_{N}(\\Gamma)`.

        :param scalar: A callable ``scalar(x, y, z)`` on
            :math:`\\Gamma`.
        :return: A 1d np.array representing the local coefficients/dofs
            of the discrete scalar.

        """
        raise NotImplementedError("Could you code it?")

    def TE(self, vector):
        """Reduce a vector on :math:`\\Gamma` to
        :math:`\\text{TE}_{N-1}(\\Gamma)`.

        :param vector: The vector on :math:`\\Gamma` to be reduced.
        :return: A 1d np.array representing the local coefficients/dofs
            of the discrete vector.

        """
        raise NotImplementedError("Could you code it?")

    def TF(self, scalar):
        """Reduce a scalar on :math:`\\Gamma` to
        :math:`\\text{TF}_{N-1}(\\Gamma)`.

        Each dof is a Gauss-quadrature integral of
        ``scalar * sqrt(ct.metric)`` over one element (cell) of the
        reference mesh spanned by ``bf.nodes``. Assuming ``ct.metric``
        returns the determinant :math:`g` of the surface metric, this is
        the integral of ``scalar`` over the image of that element under
        :math:`\\Psi`.

        :param scalar: A callable ``scalar(x, y, z)`` evaluated at the
            physical quadrature points ``ct.mapping(rho, tau)``.
        :return: A 1d np.array of length ``N[0]*N[1]``
            (``N = bf.degree``) representing the local coefficients/dofs
            of the discrete scalar. The dof of element ``(i, j)`` is
            stored at index ``i + j*N[0]``.

        """
        N = self.bf.degree
        NUM_basis = N[0] * N[1]

        nodes = self.bf.nodes

        p = self.quad_degree
        qn0, qw0 = Gauss(p[0])
        qn1, qw1 = Gauss(p[1])

        # For every element m = i + j*N[0]: the quadrature points mapped
        # into that element (reference coordinates rho, tau) and the
        # Jacobian (h0/2)*(h1/2) of the affine map from [-1, 1]^2 onto it.
        rho = np.zeros((NUM_basis, p[0] + 1, p[1] + 1))
        tau = np.zeros((NUM_basis, p[0] + 1, p[1] + 1))
        jacobian = np.zeros(NUM_basis)
        for j in range(N[1]):
            h1 = nodes[1][j + 1] - nodes[1][j]
            for i in range(N[0]):
                h0 = nodes[0][i + 1] - nodes[0][i]
                m = i + j * N[0]
                # Broadcast along the other direction's quadrature axis.
                rho[m, ...] = (qn0[:, np.newaxis] + 1) * h0 / 2 + nodes[0][i]
                tau[m, ...] = (qn1[np.newaxis, :] + 1) * h1 / 2 + nodes[1][j]
                jacobian[m] = h0 * h1 * 0.25

        g = self.ct.metric(rho, tau)
        xyz = self.ct.mapping(rho, tau)
        fxyz = scalar(*xyz)
        # Tensor-product quadrature per element: sum over both quadrature
        # axes (k, l), then scale by the element Jacobian.
        return np.einsum('jkl, k, l, j -> j',
                         fxyz * np.sqrt(g), qw0, qw1, jacobian,
                         optimize='greedy'
                         )



class ReconstructionTrace(object):
    """A wrapper of reconstruction functions.

    :param bf: The basis functions in :math:`\\Pi_{\\mathrm{ref}}`.
    :type bf: MimeticBasisPolynomials2D
    :param ct: The coordinate transformation representing the
        mapping :math:`\\Psi`,

            .. math::

                \\Psi: \\Pi_{\\mathrm{ref}}\\to\\Gamma

    :type ct: CoordinateTransformationSurface

    :example:

    >>> def p(x,y,z): return np.sin(np.pi*x) * np.sin(np.pi*y) * np.sin(np.pi*z)
    >>> from mimetic_basis_polynomials_2d import MimeticBasisPolynomials2D
    >>> from coordinate_transformation_surface import CoordinateTransformationSurface
    >>> from coordinate_transformation_surface import Psi, d_Psi
    >>> bf = MimeticBasisPolynomials2D('Lobatto-20', 'Lobatto-20')
    >>> ct = CoordinateTransformationSurface(Psi, d_Psi)
    >>> Rd = ReductionTrace(bf, ct)
    >>> dofs = Rd.TF(p)
    >>> Rc = ReconstructionTrace(bf, ct)
    >>> rho = np.linspace(-1,1,5)
    >>> tau = np.linspace(-1,1,5)
    >>> xyz, v = Rc.TF(dofs, rho, tau)
    >>> float(np.max(np.abs(p(*xyz) - v))) # the max error from the exact function # doctest: +ELLIPSIS
    0.0002989...

    """
    def __init__(self, bf, ct):
        assert bf.__class__.__name__ == 'MimeticBasisPolynomials2D'
        assert ct.__class__.__name__ == 'CoordinateTransformationSurface'
        self.bf = bf
        self.ct = ct

    def TN(self, loc_dofs, rho, tau, ravel=False):
        """Reconstruct a discrete trace polynomial in
        :math:`\\text{TN}_{N}(\\Gamma)`
        evaluated at
        :math:`\\Psi\\circ \\text{grid2d}(\\varrho, \\tau)``.

        :param loc_dofs: A 1d np.array containing the coefficients of
            the discrete variable in :math:`\\text{TN}_{N}(\\Gamma)`.
        :type loc_dofs: np.array
        :param rho: :math:`\\varrho`.
        :param tau: :math:`\\tau`. The reconstruction will be
            evaluated at :math:`\\Psi\\circ \\text{grid2d}(
            \\varrho, \\tau)``.
        :type rho: 1d np.array
        :type tau: 1d np.array
        :param ravel: (default: ``False``) If ``ravel`` is ``True``, we
            will flat the outputs (as a 1d array) according to local
            numbering. Otherwise, you get 2d outputs corresponding to
            the indexing.
        :returns: A tuple of two outputs:

            * :math:`(x,y,z)`: The reconstruction is evaluated at
                :math:`(x,y,z):=\\Psi\\circ \\text{grid2d}(\\varrho, \\tau)`.
            * values: The values of the reconstructed variable at :math:`(x,y,z)`.

        """
        raise NotImplementedError("Could you code it?")

    def TE(self, loc_dofs, rho, tau, ravel=False):
        """Reconstruct a discrete trace polynomial in
        :math:`\\text{TE}_{N-1}(\\Gamma)`
        evaluated at
        :math:`\\Psi\\circ \\text{grid2d}(\\varrho, \\tau)``.

        :param loc_dofs: A 1d np.array containing the coefficients of
            the discrete variable in :math:`\\text{TE}_{N-1}(\\Gamma)`.
        :type loc_dofs: np.array
        :param rho: :math:`\\varrho`.
        :param tau: :math:`\\tau`. The reconstruction will be
            evaluated at :math:`\\Psi\\circ \\text{grid2d}(
            \\varrho, \\tau)``.
        :type rho: 1d np.array
        :type tau: 1d np.array
        :param ravel: (default: ``False``) If ``ravel`` is ``True``, we
            will flat the outputs (as a 1d array) according to local
            numbering. Otherwise, you get 2d outputs corresponding to
            the indexing.
        :returns: A tuple of two outputs:

            * :math:`(x,y,z)`: The reconstruction is evaluated at
                :math:`(x,y,z):=\\Psi\\circ \\text{grid2d}(\\varrho, \\tau)`.
            * values: The values of the reconstructed variable at :math:`(x,y,z)`.

        """
        raise NotImplementedError("Could you code it?")

    def TF(self, loc_dofs, rho, tau, ravel=False):
        """Reconstruct a discrete trace polynomial in
        :math:`\\text{TF}_{N-1}(\\Gamma)`
        evaluated at
        :math:`\\Psi\\circ \\text{grid2d}(\\varrho, \\tau)``.

        :param loc_dofs: A 1d np.array containing the coefficients of
            the discrete variable in :math:`\\text{TF}_{N-1}(\\Gamma)`.
        :type loc_dofs: np.array
        :param rho: :math:`\\varrho`.
        :param tau: :math:`\\tau`. The reconstruction will be
            evaluated at :math:`\\Psi\\circ \\text{grid2d}(
            \\varrho, \\tau)``.
        :type rho: 1d np.array
        :type tau: 1d np.array
        :param ravel: (default: ``False``) If ``ravel`` is ``True``, we
            will flat the outputs (as a 1d array) according to local
            numbering. Otherwise, you get 2d outputs corresponding to
            the indexing.
        :returns: A tuple of two outputs:

            * :math:`(x,y,z)`: The reconstruction is evaluated at
                :math:`(x,y,z):=\\Psi\\circ \\text{grid2d}(\\varrho, \\tau)`.
            * values: The values of the reconstructed variable at :math:`(x,y,z)`.

        """
        basis = self.bf.face_polynomials(rho, tau)
        rho_tau = grid2d(rho, tau)
        xyz = self.ct.mapping(*rho_tau)
        g = self.ct.metric(*rho_tau)
        v = np.einsum('ij, i -> j', basis, loc_dofs,
                      optimize='greedy') * np.reciprocal(np.sqrt(g))

        if ravel:
            pass
        else:
            shape = [len(rho), len(tau)]
            xyz = [xyz[j].reshape(shape, order='F') for j in range(3)]
            v = v.reshape(shape, order='F')

        return xyz, v



if __name__ == '__main__':
    # Demo: reduce a scalar to TF_{N-1}(Gamma), reconstruct it on a fine
    # reference grid and plot the reconstruction on the surface.
    def p(x, y, z):
        # The ``0 * y * z`` term makes p formally depend on y and z; use
        # ``* np.sin(np.pi*y) * np.sin(np.pi*z)`` for a fully 3D variant.
        return np.sin(np.pi*x) + 0 * y * z

    from mimetic_basis_polynomials_2d import MimeticBasisPolynomials2D
    from coordinate_transformation_surface import CoordinateTransformationSurface
    from coordinate_transformation_surface import Psi, d_Psi

    bf = MimeticBasisPolynomials2D('Lobatto-10', 'Lobatto-10')
    ct = CoordinateTransformationSurface(Psi, d_Psi)

    Rd = ReductionTrace(bf, ct)
    dofs = Rd.TF(p)

    Rc = ReconstructionTrace(bf, ct)
    rho = np.linspace(-1, 1, 100)
    tau = np.linspace(-1, 1, 100)

    xyz, v = Rc.TF(dofs, rho, tau)

    import matplotlib.pyplot as plt
    from matplotlib import cm
    from matplotlib.colors import Normalize

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    # One fixed norm for both the surface face colors and the colorbar,
    # so they always agree. Without it the colorbar would autoscale to
    # the data range while the faces used a fixed [-1, 1] mapping.
    # Values outside [-1, 1] get the colormap's end colors.
    cmap = cm.bwr
    norm = Normalize(vmin=-1, vmax=1)

    ax.plot_surface(*xyz, facecolors=cmap(norm(v)))

    mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
    mappable.set_array(v)
    plt.colorbar(mappable, ax=ax, ticks=[-1, 0, 1], shrink=1, aspect=20,
                 extend='both',
                 orientation='vertical', label='Some Units')

    ax.set_xlabel(r'$x$')
    ax.set_ylabel(r'$y$')
    ax.set_zlabel(r'$z$')
    plt.show()