"""
In this script, we implement the hdMSEM for the Poisson problem with a
manufactured solution on the crazy mesh.

⭕ To access the source code, click on the [source] button at the right
side or click on
:download:`[Poisson_problem_hd.py]</contents/LIBRARY/ptc/mathischeap_ptc/Poisson_problem_hd.py>`.
Dependencies may exist. In case of errors, check the imports, install the
required packages, or download the required scripts. © mathischeap.com

"""
from numpy import pi, sin, cos
import numpy as np
from crazy_mesh_hybrid import CrazyMeshHybrid, CrazyMeshHybridGlobalNumbering, CrazyMeshHybridLocalBoundaryDOFs
from mimetic_basis_polynomials import MimeticBasisPolynomials
from incidence_matrices import E_div
from trace_matrices import TF
from mass_matrices import MassMatrices
from projection import Reduction
from scipy import sparse as spspa
from scipy.sparse import linalg as spspalinalg
from assembly import assemble
from L2_error import L2Error
from L2_error_dual import L2ErrorDual



# the manufactured solutions
# Manufactured solution: phi, its gradient (u, v, w) and the source
# f = -Laplacian(phi), chosen such that div(u) + f = 0.
_TWO_PI = 2 * pi
_F_AMPLITUDE = 12 * pi**2


def phi_exact(x, y, z):
    """Exact solution phi = sin(2*pi*x) sin(2*pi*y) sin(2*pi*z)."""
    sx = sin(_TWO_PI * x)
    sy = sin(_TWO_PI * y)
    sz = sin(_TWO_PI * z)
    return sx * sy * sz


def u_exact(x, y, z):
    """x-component of grad(phi)."""
    return _TWO_PI * cos(_TWO_PI * x) * sin(_TWO_PI * y) * sin(_TWO_PI * z)


def v_exact(x, y, z):
    """y-component of grad(phi)."""
    return _TWO_PI * sin(_TWO_PI * x) * cos(_TWO_PI * y) * sin(_TWO_PI * z)


def w_exact(x, y, z):
    """z-component of grad(phi)."""
    return _TWO_PI * sin(_TWO_PI * x) * sin(_TWO_PI * y) * cos(_TWO_PI * z)


def f_exact(x, y, z):
    """Source term f = -Laplacian(phi) = 12 pi^2 phi."""
    return _F_AMPLITUDE * sin(_TWO_PI * x) * sin(_TWO_PI * y) * sin(_TWO_PI * z)


def zero(x, y, z):
    """Zero function (exact value of div u + f), shaped like the inputs."""
    return 0 * x * y * z



def Poisson_hd(K, N, c, save=False):
    """
    Solve the Poisson problem with the hdMSEM on the crazy mesh and
    measure the errors with respect to the manufactured solution.

    The exact solution is ``phi_exact``, the exact flux is
    ``(u_exact, v_exact, w_exact)`` and the source is ``f_exact``. Each
    element system is condensed onto its trace unknowns by a local Schur
    complement. The condensed systems are assembled and solved globally,
    after which the element unknowns :math:`\\boldsymbol{u}^h` and
    :math:`\\varphi^h` are recovered element by element.

    Boundary conditions as imposed here: the trace dofs on the
    ``'x_minus'`` side are fixed to zero. On all other sides, the trace
    equations are replaced by equations that set the corresponding local
    flux dofs to the reduction of the exact flux.

    :param int K: We use a crazy mesh of :math:`K^3` elements. The
        domain decomposition is based on this crazy mesh.
    :param int N: We use mimetic polynomials of degree :math:`N`.
    :param float c: The deformation factor of the crazy mesh is
        :math:`c,\\ 0\\leq c\\leq 0.25`.
    :param bool save: If ``True``, save the coefficients of the variables
        as text files in the ``results/`` directory (created if missing).
    :return: A tuple of eight outputs, in this order:

        - The :math:`L^2\\text{-error}` of solution :math:`\\boldsymbol{u}^h`.
        - The :math:`H(\\mathrm{div})\\text{-error}` of solution :math:`\\boldsymbol{u}^h`.
        - The :math:`L^2\\text{-error}` of solution :math:`\\varphi^h`.
        - The :math:`L^2\\text{-error}` of the projection, :math:`f^h`.
        - The :math:`L^2\\text{-error}` of :math:`\\nabla\\cdot\\boldsymbol{u}^h+f^h`.
        - The :math:`L^\\infty\\text{-error}` of :math:`\\nabla\\cdot\\boldsymbol{u}^h+f^h`.
        - The :math:`\\widetilde{H}^1\\text{-error}` of solution :math:`\\varphi^h`.
        - The :math:`L^2\\text{-error}` of the dual gradient of
          :math:`\\varphi^h` minus :math:`\\boldsymbol{u}^h`.

    :example:

    >>> K = 2
    >>> N = 3
    >>> c = 0
    >>> Poisson_hd(K, N, c) # doctest: +ELLIPSIS
    hdMSEM
    L^2-error of u^h:  0.1535...

    """
    import os  # only used to create the output directory when saving

    K = int(K)
    N = int(N)
    # File-name tag: c in thousandths. int(round(...)) instead of int(...)
    # so that float noise (a product landing just below an integer) is not
    # truncated to the wrong tag.
    c1000 = int(round(c * 1000))

    def _elements():
        """Yield (m, i, j, k) for every element, with m = i + j*K + k*K**2."""
        for k in range(K):
            for j in range(K):
                for i in range(K):
                    yield i + j * K + k * K**2, i, j, k

    # define the crazy mesh ...
    crazy_mesh = CrazyMeshHybrid(c, K)

    # generate the global numbering (gathering matrix) of the trace dofs.
    GM_crazy_mesh = CrazyMeshHybridGlobalNumbering(K, N)
    GM_TF = GM_crazy_mesh.TF

    # local boundary dofs, grouped per element index m.
    BD_crazy_mesh = CrazyMeshHybridLocalBoundaryDOFs(K, N)

    # flux (FP) dofs on every boundary side except 'x_minus'.
    B_dofs_FP_dict = BD_crazy_mesh.FP
    B_dofs_FP = dict()
    for side in B_dofs_FP_dict:
        if side == 'x_minus':
            continue
        dofs_on_side = B_dofs_FP_dict[side]
        for m in dofs_on_side:
            B_dofs_FP.setdefault(m, list()).extend(dofs_on_side[m])

    # trace (TF) dofs: those on 'x_minus' go to B_dofs_TF_EN (lambda is
    # fixed to 0), those on the other sides go to B_dofs_TF_NA (flux is
    # prescribed).
    B_dofs_TF_dict = BD_crazy_mesh.TF
    B_dofs_TF_EN = dict()
    B_dofs_TF_NA = dict()
    for side in B_dofs_TF_dict:
        target = B_dofs_TF_EN if side == 'x_minus' else B_dofs_TF_NA
        dofs_on_side = B_dofs_TF_dict[side]
        for m in dofs_on_side:
            target.setdefault(m, list()).extend(dofs_on_side[m])

    # define the basis functions
    _bfN_ = 'Lobatto-' + str(N)
    mbf = MimeticBasisPolynomials(_bfN_, _bfN_, _bfN_)

    # incidence matrix and trace matrix (identical for all elements)
    E = E_div(N, N, N)
    T = TF(N, N, N)

    # per element: mass matrix, reduction of f_exact and of the exact flux,
    # and the squared L2 projection error of f.
    MF = list()
    f_exact_local = list()
    u_exact_local = list()
    f_L2 = list()
    for _, i, j, k in _elements():
        ct = crazy_mesh.CT_of_element_index(i, j, k)
        MF.append(MassMatrices(mbf, ct).FP)

        RD = Reduction(mbf, ct)
        f_dofs_local = RD.VP(f_exact)
        f_exact_local.append(f_dofs_local)
        f_L2.append(L2Error(mbf, ct).VP(f_dofs_local, f_exact)**2)

        _temp_ = RD.FP((u_exact, v_exact, w_exact))
        # stored as a sparse column so boundary rows can be sliced out below
        u_exact_local.append(spspa.csc_matrix(_temp_[:, np.newaxis]))

    f_L2 = np.sum(f_L2)**0.5

    # generate local systems and condense them onto the trace unknowns
    invA_List = list()
    g_List = list()
    B_List = list()

    RSA = list()  # condensed local matrices
    RSb = list()  # condensed local right-hand sides

    num_basis_FP = 3 * (N + 1) * N**2
    num_basis_VP = N**3
    num_basis_TF = 6 * N**2
    for m, _, _, _ in _elements():
        # local block system
        #   [A00 A01 A02] [u     ]   [b0]
        #   [A10 A11 A12] [phi   ] = [b1]
        #   [A20 A21 A22] [lambda]   [b2]
        A00 = MF[m]
        A01 = E.T
        A02 = - T.T

        A10 = E
        A11 = spspa.csc_matrix((num_basis_VP, num_basis_VP))
        A12 = spspa.csc_matrix((num_basis_VP, num_basis_TF))

        A20 = A02.T.tolil()  # LIL, because boundary rows are overwritten
        A21 = spspa.csc_matrix((num_basis_TF, num_basis_VP))
        A22 = spspa.lil_matrix((num_basis_TF, num_basis_TF))

        b0 = spspa.csc_matrix((num_basis_FP, 1))
        b1 = spspa.csc_matrix(- f_exact_local[m][:, np.newaxis])
        b2 = spspa.lil_matrix((num_basis_TF, 1))

        # 'x_minus' trace dofs: equation lambda = 0
        if m in B_dofs_TF_EN:
            EN_dofs = B_dofs_TF_EN[m]
            A20[EN_dofs, :] = 0
            A22[EN_dofs, EN_dofs] = 1

        # other boundary trace dofs: equation flux dof = reduced exact flux.
        # NOTE(review): relies on NA_dofs and B_dofs_FP[m] being paired
        # one-to-one in the same order (element-wise fancy indexing).
        if m in B_dofs_TF_NA:
            NA_dofs = B_dofs_TF_NA[m]
            A20[NA_dofs, :] = 0
            A20[NA_dofs, B_dofs_FP[m]] = 1
            b2[NA_dofs] = u_exact_local[m][B_dofs_FP[m]]

        # Schur complement onto lambda:
        #   (D - C A^{-1} B) lambda = h - C A^{-1} g
        A = spspa.bmat(([A00, A01],
                        [A10, A11]), format='csc')
        invA = spspalinalg.inv(A)
        invA_List.append(invA)
        B = spspa.bmat(([A02], [A12]), format='csc')
        B_List.append(B)
        C = spspa.bmat(([A20, A21],), format='csc')
        D = A22.tocsc()

        g = spspa.bmat(([b0], [b1]), format='csc')
        g_List.append(g)
        h = b2.tocsc()

        C_invA = C @ invA
        RSA.append(D - C_invA @ B)
        RSb.append(h - C_invA @ g)

    # assemble the reduced systems into a global system
    A = assemble(RSA, GM_TF, GM_TF)
    b = assemble(RSb, GM_TF)

    # solve the global system using the direct solver provided by scipy
    lamb = spspalinalg.spsolve(A, b)  # solve Ax=b, obtain x
    del A, b

    # back-substitution per element: [u; phi] = A^{-1} (g - B lambda)
    lambda_local = lamb[GM_TF]
    u_local = list()
    phi_local = list()
    for m in range(K**3):
        u_phi_local_m = invA_List[m] @ (
            g_List[m].toarray().ravel('F') - B_List[m] @ lambda_local[m])
        u_local.append(u_phi_local_m[:num_basis_FP])
        phi_local.append(u_phi_local_m[num_basis_FP:])

    u_local = np.array(u_local)
    phi_local = np.array(phi_local)

    # -E u is compared against f (div u + f = 0); E u + f^h should vanish
    div_u_local = (- E @ u_local.T).T
    div_u_plus_f = (E @ u_local.T + np.array(f_exact_local).T).T

    # dual gradient of phi: -E^T phi + T^T lambda
    dual_gradient = (-E.T, T.T)
    dG_phi = np.array([
        dual_gradient[0] @ phi_local[m] + dual_gradient[1] @ lambda_local[m]
        for m in range(K**3)
    ])

    # element-wise errors (squared where they are summed afterwards)
    u_L2 = list()
    div_u_L2 = list()
    phi_L2 = list()
    div_L2 = list()
    div_Linf = list()
    phi_dH1 = list()
    dH1_error = list()
    for m, i, j, k in _elements():
        ct = crazy_mesh.CT_of_element_index(i, j, k)
        L2E = L2Error(mbf, ct)
        L2Ed = L2ErrorDual(mbf, ct)

        u_l2_m = L2E.FP(u_local[m], (u_exact, v_exact, w_exact))
        phi_dH1_m = L2Ed.FP(dG_phi[m], (u_exact, v_exact, w_exact))
        dH1_error_m = L2Ed._FP_diff_(dG_phi[m], u_local[m], (zero, zero, zero))
        div_u_l2_m = L2E.VP(div_u_local[m], f_exact)
        phi_l2_m = L2Ed.VP(phi_local[m], phi_exact)
        div_L2_m = L2E.VP(div_u_plus_f[m], zero)
        div_Linf_m = L2E.VP(div_u_plus_f[m], zero, n='infty')

        u_L2.append(u_l2_m**2)
        div_u_L2.append(div_u_l2_m**2)
        phi_L2.append(phi_l2_m**2)
        div_L2.append(div_L2_m**2)
        div_Linf.append(div_Linf_m)
        phi_dH1.append(phi_dH1_m**2)
        dH1_error.append(dH1_error_m**2)

    u_L2 = np.sum(u_L2)**0.5
    div_u_L2 = np.sum(div_u_L2)**0.5
    phi_L2 = np.sum(phi_L2)**0.5
    div_L2 = np.sum(div_L2)**0.5
    div_Linf = np.max(div_Linf)
    phi_dH1 = np.sum(phi_dH1)**0.5
    dH1_error = np.sum(dH1_error)**0.5

    u_Hdiv = np.sqrt(u_L2**2 + div_u_L2**2)
    phi_dH1 = np.sqrt(phi_L2**2 + phi_dH1**2)

    # print info and return
    print('hdMSEM')
    print("L^2-error of u^h: ", u_L2)
    print("H(div)-error of u^h: ", u_Hdiv)
    print("L^2-error of phi^h: ", phi_L2)
    print("L^2-error of projection f^h-f: ", f_L2)
    print("L^2-error of div(u^h)+f^h: ", div_L2)
    print("L^inf-error of div(u^h)+f^h: ", div_Linf)
    print("dual_H^1-error of phi^h: ", phi_dH1)
    print("L^2-error of dual gradient phi^h minus u: ", dH1_error)

    if save:
        # np.savetxt does not create missing directories
        os.makedirs('results', exist_ok=True)
        name_temp = f'results/hdMSEM_K{K}_N{N}_c{c1000}_'
        # noinspection PyTypeChecker
        np.savetxt(name_temp + 'u.txt', u_local)
        # noinspection PyTypeChecker
        np.savetxt(name_temp + 'phi.txt', phi_local)
        # noinspection PyTypeChecker
        np.savetxt(name_temp + 'du_plus_f.txt', div_u_plus_f)
        # noinspection PyTypeChecker
        np.savetxt(name_temp + 'lambda.txt', lambda_local)
        # noinspection PyTypeChecker
        np.savetxt(name_temp + 'dGphi.txt', dG_phi)

    return u_L2, u_Hdiv, phi_L2, f_L2, div_L2, div_Linf, phi_dH1, dH1_error


if __name__ == '__main__':
    import doctest

    # First check the examples embedded in the docstrings ...
    doctest.testmod()

    # ... then run one demonstration case on a deformed crazy mesh.
    K, N, c = 2, 3, 0.25
    Poisson_hd(K, N, c, save=False)