"""
In this script, we implement the MSEM for the Poisson problem with a
manufactured solution on the crazy mesh.

⭕ To access the source code, click on the [source] button at the right
side or click on
:download:`[Poisson_problem.py]</contents/LIBRARY/ptc/mathischeap_ptc/Poisson_problem.py>`.
This script depends on other scripts. In case of errors, check the imports
and install the required packages or download the required scripts.
© mathischeap.com

"""

from numpy import pi, sin, cos
import numpy as np
from crazy_mesh import CrazyMesh, CrazyMeshGlobalNumbering, CrazyMeshGlobalBoundaryDOFs
from mimetic_basis_polynomials import MimeticBasisPolynomials
from incidence_matrices import E_div
from mass_matrices import MassMatrices
from projection import Reduction
from scipy import sparse as spspa
from scipy.sparse import linalg as spspalinalg
from assembly import assemble
from L2_error import L2Error

# the manufactured solutions
def phi_exact(x, y, z):
    """Exact potential: sin(2 pi x) sin(2 pi y) sin(2 pi z)."""
    k = 2 * pi
    return sin(k * x) * sin(k * y) * sin(k * z)


def u_exact(x, y, z):
    """x-component of the exact flux, i.e. d(phi_exact)/dx."""
    k = 2 * pi
    return k * cos(k * x) * sin(k * y) * sin(k * z)


def v_exact(x, y, z):
    """y-component of the exact flux, i.e. d(phi_exact)/dy."""
    k = 2 * pi
    return k * sin(k * x) * cos(k * y) * sin(k * z)


def w_exact(x, y, z):
    """z-component of the exact flux, i.e. d(phi_exact)/dz."""
    k = 2 * pi
    return k * sin(k * x) * sin(k * y) * cos(k * z)


def f_exact(x, y, z):
    """Exact source term, equal to 12 pi^2 * phi_exact."""
    amplitude = 12 * pi**2
    k = 2 * pi
    return amplitude * sin(k * x) * sin(k * y) * sin(k * z)


def zero(x, y, z):
    """Exact value of div u + f, which vanishes identically.

    Multiplying by the inputs keeps the result array-shaped when the
    coordinates are arrays.
    """
    return 0 * x * y * z



def Poisson(K, N, c, save=False):
    """
    Solve the Poisson problem with the manufactured solution ``phi_exact``
    in mixed form with the MSEM on a crazy mesh, print the errors and
    return them.

    :param int K: We use a crazy mesh of :math:`K^3` elements.
    :param int N: We use mimetic polynomials of degree :math:`N`.
    :param float c: The deformation factor of the crazy mesh is
        :math:`c,\\ 0\\leq c\\leq 0.25`.
    :param bool save: If ``True``, save the local coefficients of
        :math:`\\boldsymbol{u}^h`, :math:`\\varphi^h` and
        :math:`\\nabla\\cdot\\boldsymbol{u}^h+f^h` as text files in the
        ``results`` directory (which is created when missing).
    :return: A tuple of several outputs:

        - The :math:`L^2\\text{-error}` of solution :math:`\\boldsymbol{u}^h`.
        - The :math:`H(\\mathrm{div})\\text{-error}` of solution :math:`\\boldsymbol{u}^h`.
        - The :math:`L^2\\text{-error}` of solution :math:`\\varphi^h`.
        - The :math:`L^2\\text{-error}` of the projection, :math:`f^h`.
        - The :math:`L^2\\text{-error}` of :math:`\\nabla\\cdot\\boldsymbol{u}^h+f^h`.
        - The :math:`L^\\infty\\text{-error}` of :math:`\\nabla\\cdot\\boldsymbol{u}^h+f^h`.

    :example:

    >>> K = 2
    >>> N = 3
    >>> c = 0
    >>> Poisson(K, N, c) # doctest: +ELLIPSIS
    MSEM
    L^2-error of u^h:  0.1535...

    """
    K = int(K)
    N = int(N)
    # Integer tag of c (in thousandths) used in output file names. round()
    # instead of int() truncation: c*1000 can land just below the intended
    # integer because of floating-point representation error.
    c1000 = int(round(c * 1000))

    # define the crazy mesh ...
    crazy_mesh = CrazyMesh(c, K)

    # generate the global numbering (gathering matrices) and find boundary dofs.
    GM_crazy_mesh = CrazyMeshGlobalNumbering(K, N)
    BD_crazy_mesh = CrazyMeshGlobalBoundaryDOFs(K, N)
    GM_FP = GM_crazy_mesh.FP
    GM_VP = GM_crazy_mesh.VP
    B_dofs_FP_dict = BD_crazy_mesh.FP

    # u is prescribed (from u_exact) on every boundary side except
    # 'x_minus'. NOTE(review): on x_minus nothing is imposed strongly and
    # b0 is zero, presumably a homogeneous natural condition on phi there;
    # confirm against the formulation.
    B_dofs_FP = list()
    for side in B_dofs_FP_dict:
        if side != 'x_minus':
            B_dofs_FP.extend(B_dofs_FP_dict[side])

    # define the basis functions
    _bfN_ = 'Lobatto-' + str(N)
    mbf = MimeticBasisPolynomials(_bfN_, _bfN_, _bfN_)

    # Coordinate transformation of every element, computed once. The order
    # of the comprehension (k outer, i inner) puts element (i, j, k) at
    # index m = i + j*K + k*K**2, the element numbering used below.
    cts = [crazy_mesh.CT_of_element_index(i, j, k)
           for k in range(K) for j in range(K) for i in range(K)]

    # generate incidence matrix and mass matrices
    E = E_div(N, N, N)
    MF = list()
    MV = list()
    for ct in cts:
        MM = MassMatrices(mbf, ct)
        MF.append(MM.FP)
        MV.append(MM.VP)

    # reduction of source term f_exact, and u_exact; also measure the
    # L2-error of the projection f^h of f_exact.
    f_exact_local = list()
    u_exact_local = list()
    f_L2 = list()
    for ct in cts:
        RD = Reduction(mbf, ct)
        f_dofs_local = RD.VP(f_exact)
        f_exact_local.append(f_dofs_local)
        f_L2.append(L2Error(mbf, ct).VP(f_dofs_local, f_exact)**2)
        u_dofs_local = RD.FP((u_exact, v_exact, w_exact))
        u_exact_local.append(spspa.csc_matrix(u_dofs_local[:, np.newaxis]))
    u_exact_global = assemble(u_exact_local, GM_FP)
    f_L2 = np.sum(f_L2)**0.5

    # Generate the local system of every element m:
    #
    #     ( A00[m]  A01[m] )   ( b0[m] )
    #     ( A10[m]         )   ( b1[m] )
    #
    # with A01[m] = A10[m].T.
    num_local_FP = 3 * (N + 1) * N**2  # local face dofs per element
    A00, A01, A10 = list(), list(), list()  # store blocks in list
    b0, b1 = list(), list()  # store vectors in list
    for m in range(K**3):
        A10_m = MV[m] @ E
        A00.append(MF[m])
        A01.append(A10_m.T)
        A10.append(A10_m)
        b0.append(spspa.csc_matrix((num_local_FP, 1)))
        b1_m = - MV[m] @ f_exact_local[m]
        b1.append(spspa.csc_matrix(b1_m[:, np.newaxis]))

    # assemble local systems into global system
    A00 = assemble(A00, GM_FP, GM_FP)
    A01 = assemble(A01, GM_FP, GM_VP)
    A10 = assemble(A10, GM_VP, GM_FP)
    A = spspa.bmat([(A00, A01),                # left hand side matrix A
                    (A10, None)], format='lil')  # of global system Ax = b
    del A00, A01, A10
    b0 = assemble(b0, GM_FP)
    b1 = assemble(b1, GM_VP)
    b = spspa.vstack((b0, b1), format='lil')  # right hand side vector b

    # Apply the boundary condition: replace each boundary row by the identity
    # row and put the reduced exact flux on the right hand side.
    A[B_dofs_FP, :] = 0
    A[B_dofs_FP, B_dofs_FP] = 1
    b[B_dofs_FP] = u_exact_global[B_dofs_FP]
    A = A.tocsc()  # scipy spsolve handles csc or csr matrix well

    shape_F = A.shape[0]
    # solve the global system using the direct solver provided by scipy
    x = spspalinalg.spsolve(A, b)  # solve Ax=b, obtain x
    del A, b

    # Post-process x into u and phi: the first max(GM_FP)+1 entries are the
    # global u dofs, the rest the global phi dofs.
    num_global_FP = int(np.max(GM_FP) + 1)
    u_global = x[:num_global_FP]
    phi_global = x[num_global_FP:]
    u_local = u_global[GM_FP]
    phi_local = phi_global[GM_VP]
    # -div u^h, compared against f_exact below (exact div u = -f).
    div_u_local = (- E @ u_local.T).T
    div_u_plus_f = (E @ u_local.T + np.array(f_exact_local).T).T

    # measure the L2-error of u, phi, and div u + f
    u_L2 = list()
    div_u_L2 = list()
    phi_L2 = list()
    div_L2 = list()
    div_Linf = list()
    for m, ct in enumerate(cts):
        L2E = L2Error(mbf, ct)
        u_L2.append(L2E.FP(u_local[m], (u_exact, v_exact, w_exact))**2)
        div_u_L2.append(L2E.VP(div_u_local[m], f_exact)**2)
        phi_L2.append(L2E.VP(phi_local[m], phi_exact)**2)
        div_L2.append(L2E.VP(div_u_plus_f[m], zero)**2)
        div_Linf.append(L2E.VP(div_u_plus_f[m], zero, n='infty'))

    u_L2 = np.sum(u_L2)**0.5
    div_u_L2 = np.sum(div_u_L2)**0.5
    phi_L2 = np.sum(phi_L2)**0.5
    div_L2 = np.sum(div_L2)**0.5
    div_Linf = np.max(div_Linf)

    u_Hdiv = np.sqrt(u_L2**2 + div_u_L2**2)
    # print info and return
    print('MSEM')
    print("L^2-error of u^h: ", u_L2)
    print("H(div)-error of u^h: ", u_Hdiv)
    print("L^2-error of phi^h: ", phi_L2)
    print("L^2-error of projection f^h-f: ", f_L2)
    print("L^2-error of div(u^h)+f^h: ", div_L2)
    print("L^inf-error of div(u^h)+f^h: ", div_Linf)

    # Sanity check of the global system size: all local dofs minus the N**2
    # face dofs counted twice on each of the I shared faces (presumably the
    # 3*K**2*(K-1) interior faces of the K**3 mesh).
    M = K**3
    I = 3*K**2*(K-1)
    shape_F_1 = M*(3*N**2*(N+1)+N**3) - I*N**2
    print(f'M={M}, I={I}, shape_F = {shape_F}, {shape_F_1}')

    if save:
        import os
        # np.savetxt does not create missing directories.
        os.makedirs('results', exist_ok=True)
        name_temp = f'results/MSEM_K{K}_N{N}_c{c1000}_'
        # noinspection PyTypeChecker
        np.savetxt(name_temp + 'u.txt', u_local)
        # noinspection PyTypeChecker
        np.savetxt(name_temp + 'phi.txt', phi_local)
        # noinspection PyTypeChecker
        np.savetxt(name_temp + 'du_plus_f.txt', div_u_plus_f)

    return u_L2, u_Hdiv, phi_L2, f_L2, div_L2, div_Linf



if __name__ == '__main__':
    # Check the docstring example first, then run a small demonstration:
    # 2**3 elements, degree-3 polynomials, deformation factor 0.
    import doctest
    doctest.testmod()

    K, N, c = 2, 3, 0.
    Poisson(K, N, c, save=False)
