module PlotErrors
using CairoMakie
using DelimitedFiles

# Column indices into the `data` matrix loaded below. They are `const` so that the
# plotting functions do not pay for untyped global lookups.
# NOTE(review): the meaning of each column is inferred from the names only — confirm
# against the layout of the CSV file.
const j_nx = 1
const j_ny = 2
const j_nc = 3
const j_ϕmax = 4
const j_p = 5
const j_tfinal = 6
const j_L1 = 7
const j_L2 = 8
const j_Linf = 9

# Endless cycle of marker symbols; it is zipped with the plotted series so that every
# curve gets its own marker.
const markers = Iterators.cycle([
    :circle,
    :rect,
    :diamond,
    :hexagon,
    :cross,
    :xcross,
    :utriangle,
    :dtriangle,
    :ltriangle,
    :rtriangle,
    :pentagon,
    :star4,
    :star5,
    :star6,
    :star8,
    :vline,
    :hline,
])

# The table is read once, when the module is loaded, from a CSV file located next to
# this source file. Text following a '#' is treated as a comment by `readdlm`.
const filepath = joinpath(@__DIR__, "mesh-convergence-approx-ref.csv")
@show filepath
const data = readdlm(filepath, ','; comments=true, comment_char='#')

# Base font size passed to every `Figure` created in this module.
const fontsize = 20

"""
    all_errors()

For each value of `p` in `(0, 1)` (column `j_p` of `data`), build one figure with three
stacked log2-log2 axes showing the error columns `j_L1`, `j_L2` and `j_Linf` against
column `j_nx`, one curve per `ϕmax` in `0.2:0.1:0.9`.

A reference triangle of slope `-(p + 1)` is drawn between the two largest `j_nx` values
on every axis. Each figure is displayed and saved next to this file as
`fig-all-errors-p<p>` in PNG, SVG and PDF format.

A value of `p` without any matching row, or without any plotted series, is skipped with
a warning. The slope triangle is skipped (with a warning) when fewer than two distinct
mesh sizes are available. Returns `nothing`.
"""
function all_errors()
    with_theme(theme_latexfonts()) do
        for p in (0, 1)
            p_rows = findall(row -> Int(row[j_p]) == p, eachrow(data))
            if isempty(p_rows)
                @warn "didn't find entry for p = $p"
                continue
            end

            # Distinct mesh sizes available for this p; used for ticks and the triangle.
            all_x = sort(unique(Int.(data[p_rows, j_nx])))

            f = Figure(; size=(800, 500), fontsize)
            opts = (;
                xscale=log2,
                yscale=log2,
                xminorticksvisible=true,
                xminorgridvisible=true,
                xminorticks=IntervalsBetween(5),
                yminorticksvisible=true,
                yminorgridvisible=true,
                yminorticks=IntervalsBetween(5),
            )
            ax_L1 = Axis(f[1, 1]; ylabel=L"\log_2(L_1)", opts...)
            ax_L2 = Axis(f[2, 1]; ylabel=L"\log_2(L_2)", opts...)
            ax_Linf = Axis(
                f[3, 1];
                xlabel=L"$k$ such that $n_x=n_y=2^k$",
                ylabel=L"\log_2(L_\infty)",
                opts...,
            )
            error_axes = (ax_L1, ax_L2, ax_Linf)
            error_cols = (j_L1, j_L2, j_Linf)

            # Running extrema of every error column over all plotted series; they are
            # used to position the slope triangle just above the smallest error.
            lows = fill(Inf, length(error_cols))
            highs = fill(-Inf, length(error_cols))
            nseries = 0

            # One curve per ϕmax on each axis
            markersize = 12
            strokewidth = 1
            for (ϕmax, marker) in zip(0.2:0.1:0.9, markers)
                series_rows = findall(
                    row -> Int(row[j_p]) == p && abs(row[j_ϕmax] - ϕmax) < 1e-4,
                    eachrow(data),
                )
                if isempty(series_rows)
                    @warn "didn't find entry for ϕmax = $ϕmax"
                    continue
                end
                nseries += 1

                # Order the points by mesh size so the connecting line is monotone in x.
                sorted_rows = series_rows[sortperm(data[series_rows, j_nx])]
                x = data[sorted_rows, j_nx]

                label = "$ϕmax"
                for (k, (ax, col)) in enumerate(zip(error_axes, error_cols))
                    y = data[sorted_rows, col]
                    scatterlines!(ax, x, y; label, marker, markersize, strokewidth)
                    lo, hi = extrema(y)
                    lows[k] = min(lows[k], lo)
                    highs[k] = max(highs[k], hi)
                end
            end

            # Without any curve there is nothing to annotate and `Legend` has no entry.
            if nseries == 0
                @warn "no series plotted for p = $p, skipping figure"
                continue
            end

            # Add slope triangle (needs two distinct mesh sizes)
            if length(all_x) < 2
                @warn "need at least two mesh sizes to draw the slope triangle for p = $p"
            else
                x1 = all_x[end - 1]
                x2 = all_x[end]
                linewidth = 2
                a = -(p + 1)
                for (ax, m, M) in zip(error_axes, lows, highs)
                    # In log2 space the hypotenuse is y = a * x + b; b is chosen so the
                    # line passes 5% of the error range above the minimum at x2.
                    b = (log2(m) - a * log2(x2)) + 0.05 * (log2(M) - log2(m))
                    y1 = exp2(a * log2(x1) + b)
                    y2 = exp2(a * log2(x2) + b)

                    lines!(ax, [x1, x2, x2, x1], [y1, y2, y1, y1]; color=:black, linewidth)
                    text!(
                        ax,
                        x2,
                        (y1 + y2) / 2;
                        text=string(abs(a)),
                        align=(:left, :center),
                        offset=(5, 0),
                    )
                    text!(
                        ax,
                        (x1 + x2) / 2,
                        y1;
                        text="1",
                        align=(:center, :bottom),
                        offset=(0, 5),
                    )
                end
            end

            # xticks: one tick per available mesh size
            for ax in error_axes
                ax.xticks = (all_x, string.(all_x))
            end

            # Legend
            Legend(f[2, 2], ax_L1, L"d_{max}")

            # Fig title
            Label(
                f[0, :],
                L"$L_1$, $L_2$ and $L_\infty$ errors for $p=%$p$";
                fontsize=24,
                halign=:center,
            )

            display(f)
            save(joinpath(@__DIR__, "fig-all-errors-p$(p).png"), f)
            save(joinpath(@__DIR__, "fig-all-errors-p$(p).svg"), f)
            save(joinpath(@__DIR__, "fig-all-errors-p$(p).pdf"), f)
        end
    end

    return nothing
end

"""
    error_l2()

Plot the error stored in column `j_L2` of `data` against column `j_nx` for `p = 1`, on a
single log2-log2 axis, with one curve per `ϕmax` in `0.2:0.1:0.9`.

A reference triangle of slope `-(p + 1)` is drawn between the two largest `j_nx` values.
The figure is displayed and saved next to this file as `fig-L2-error-p1` in PNG, SVG and
PDF format.

If no row matches `p`, or no series could be plotted, a warning is emitted and `nothing`
is returned without producing a figure. The slope triangle is skipped (with a warning)
when fewer than two distinct mesh sizes are available.
"""
function error_l2()
    with_theme(theme_latexfonts()) do
        #################### Only L2 error for p=1
        p = 1
        p_rows = findall(row -> Int(row[j_p]) == p, eachrow(data))
        if isempty(p_rows)
            @warn "didn't find entry for p = $p"
            return nothing
        end

        # Distinct mesh sizes available for this p; used for ticks and the triangle.
        all_x = sort(unique(Int.(data[p_rows, j_nx])))

        f = Figure(; size=(800, 500), fontsize)
        opts = (;
            xscale=log2,
            yscale=log2,
            xminorticksvisible=true,
            xminorgridvisible=true,
            xminorticks=IntervalsBetween(5),
            yminorticksvisible=true,
            yminorgridvisible=true,
            yminorticks=IntervalsBetween(5),
        )
        ax = Axis(
            f[1, 1];
            xlabel=L"$k$ such that $n_x=n_y=2^k$",
            ylabel=L"\log_2(L_2)",
            opts...,
        )

        # Running extrema of the error over all plotted series; they are used to
        # position the slope triangle just above the smallest error.
        m = Inf
        M = -Inf
        nseries = 0

        # One curve per ϕmax
        markersize = 12
        strokewidth = 1
        for (ϕmax, marker) in zip(0.2:0.1:0.9, markers)
            series_rows = findall(
                row -> Int(row[j_p]) == p && abs(row[j_ϕmax] - ϕmax) < 1e-4,
                eachrow(data),
            )
            if isempty(series_rows)
                @warn "didn't find entry for ϕmax = $ϕmax"
                continue
            end
            nseries += 1

            # Order the points by mesh size so the connecting line is monotone in x.
            sorted_rows = series_rows[sortperm(data[series_rows, j_nx])]
            x = data[sorted_rows, j_nx]
            y = data[sorted_rows, j_L2]

            label = "$ϕmax"
            scatterlines!(ax, x, y; label, marker, markersize, strokewidth)

            lo, hi = extrema(y)
            m = min(m, lo)
            M = max(M, hi)
        end

        # Without any curve there is nothing to annotate and `Legend` has no entry.
        if nseries == 0
            @warn "no series plotted for p = $p, skipping figure"
            return nothing
        end

        # Add slope triangle (needs two distinct mesh sizes)
        if length(all_x) < 2
            @warn "need at least two mesh sizes to draw the slope triangle for p = $p"
        else
            x1 = all_x[end - 1]
            x2 = all_x[end]
            linewidth = 2

            # In log2 space the hypotenuse is y = a * x + b; b is chosen so the line
            # passes 5% of the error range above the minimum at x2.
            a = -(p + 1)
            b = (log2(m) - a * log2(x2)) + 0.05 * (log2(M) - log2(m))
            y1 = exp2(a * log2(x1) + b)
            y2 = exp2(a * log2(x2) + b)

            lines!(ax, [x1, x2, x2, x1], [y1, y2, y1, y1]; color=:black, linewidth)
            text!(
                ax,
                x2,
                (y1 + y2) / 2;
                text=string(abs(a)),
                align=(:left, :center),
                offset=(5, 0),
            )
            text!(
                ax,
                (x1 + x2) / 2,
                y1;
                text="1",
                align=(:center, :bottom),
                offset=(0, 5),
            )
        end

        # xticks: one tick per available mesh size
        ax.xticks = (all_x, string.(all_x))

        # Legend
        Legend(f[1, 2], ax, L"d_{max}")

        display(f)
        save(joinpath(@__DIR__, "fig-L2-error-p$(p).png"), f)
        save(joinpath(@__DIR__, "fig-L2-error-p$(p).svg"), f)
        save(joinpath(@__DIR__, "fig-L2-error-p$(p).pdf"), f)
    end
end

"""
    comp_fit_unfit()

Compare the error in column `j_L2` of `data` (labelled "unfitted", restricted to `p = 1`
and `ϕmax = 0.2`) with the fifth column of `mesh-convergence-fitted.csv` (labelled
"fitted"), on two side-by-side log2-log2 axes: one against a number of degrees of
freedom, one against a mesh size.

The figure is displayed and saved next to this file as `fig-L2-error-fit-unfit` in PNG,
SVG and PDF format. If no row of `data` matches the selected `p` and `ϕmax`, a warning is
emitted and `nothing` is returned without producing a figure.
"""
function comp_fit_unfit()
    with_theme(theme_latexfonts()) do
        # NOTE(review): the column layout of this file is not visible here; columns 1
        # and 5 are used as abscissa source and error respectively — confirm.
        fitted = readdlm(
            joinpath(@__DIR__, "mesh-convergence-fitted.csv"),
            ',';
            comments=true,
            comment_char='#',
        )

        p = 1
        ϕmax = 0.2
        rows = findall(
            row -> Int(row[j_p]) == p && abs(row[j_ϕmax] - ϕmax) < 1e-4,
            eachrow(data),
        )
        if isempty(rows)
            @warn "didn't find entry for p = $p and ϕmax = $ϕmax"
            return nothing
        end

        f = Figure(; size=(1000, 500), fontsize)
        opts = (;
            xscale=log2,
            yscale=log2,
            xminorticksvisible=true,
            xminorgridvisible=true,
            xminorticks=IntervalsBetween(5),
            yminorticksvisible=true,
            yminorgridvisible=true,
            yminorticks=IntervalsBetween(5),
        )

        # Left panel: error against the number of dofs
        ax1 = Axis(f[1, 1]; xlabel="Total number of dofs", ylabel=L"\log_2(L_2)", opts...)

        by_nc = rows[sortperm(data[rows, j_nc])]
        # NOTE(review): the factors 4 and 2 presumably convert the stored counts into
        # numbers of dofs — confirm.
        lines!(ax1, 4 .* data[by_nc, j_nc], data[by_nc, j_L2]; label="unfitted", color=:blue)
        lines!(ax1, 2 .* fitted[:, 1], fitted[:, 5]; label="fitted", color=:red)

        axislegend(ax1, position=:lb)

        # Right panel: error against the mesh size
        ax2 = Axis(
            f[1, 2];
            xlabel=L"Mesh characteristic size $h$",
            ylabel=L"\log_2(L_2)",
            opts...,
        )

        by_nx = rows[sortperm(data[rows, j_nx])]
        # NOTE(review): 3.0 and 2π look like the extents of the two computational
        # domains — confirm.
        lines!(
            ax2, 3.0 ./ (data[by_nx, j_nx] .+ 1), data[by_nx, j_L2];
            label="unfitted", color=:blue,
        )
        lines!(ax2, 2π ./ fitted[:, 1], fitted[:, 5]; label="fitted", color=:red)

        axislegend(ax2, position=:rb)

        display(f)
        save(joinpath(@__DIR__, "fig-L2-error-fit-unfit.png"), f)
        save(joinpath(@__DIR__, "fig-L2-error-fit-unfit.svg"), f)
        save(joinpath(@__DIR__, "fig-L2-error-fit-unfit.pdf"), f)
    end
end

# Executed when the module is loaded: builds, displays and saves the fitted/unfitted
# comparison figure. `all_errors` and `error_l2` are defined above but not called here.
comp_fit_unfit()

end