#!/usr/bin/env python
import numpy as np
import matplotlib.pyplot as plt

from filval.result_set import ResultSet
from filval.histogram import hist, hist_integral, hist_rebin, hist_norm, hist2d, hist2d_percent_contour
from filval.plotting import (decl_plot, render_plots, hist_plot, hist2d_plot,
                             Plot, generate_dashboard, simple_plot)

# Matching-window cut sets used by calc_window, keyed by cut-set name.
# The same names are used by the plotting entry points below as ResultSet
# names and as the stem of the input file ../hists/<name>.root.
#
# Each value is a list of three dicts: entry 0 applies to hit 1, entry 1 to
# hit 2 and entry 2 to hit 3 (calc_window clamps higher hit numbers to it).
# For a variable V in {'dPhi', 'dRz'}, calc_window evaluates the window at
# transverse energy et as
#     VMaxHighEt + min(0, et - VMaxHighEtThres) * VMaxLowEtGrad
# i.e. a constant VMaxHighEt above the threshold; with a negative
# VMaxLowEtGrad the window widens linearly as et drops below the threshold.
# dPhi windows are in radians and dRz windows in cm (per the plot axis labels).
matching_cuts = {
    'extra-narrow-window': [
        dict(
            dPhiMaxHighEt=0.025,
            dPhiMaxHighEtThres=20.0,
            dPhiMaxLowEtGrad=-0.002,
            # 9999 with zero threshold and gradient: effectively no dRz cut on hit 1.
            dRzMaxHighEt=9999.0,
            dRzMaxHighEtThres=0.0,
            dRzMaxLowEtGrad=0.0,
        ),
        dict(
            dPhiMaxHighEt=0.0015,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.025,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        ),
        dict(
            dPhiMaxHighEt=0.0015,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.025,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        )
    ],
    'narrow-window': [
        dict(
            dPhiMaxHighEt=0.05,
            dPhiMaxHighEtThres=20.0,
            dPhiMaxLowEtGrad=-0.002,
            dRzMaxHighEt=9999.0,
            dRzMaxHighEtThres=0.0,
            dRzMaxLowEtGrad=0.0,
        ),
        dict(
            dPhiMaxHighEt=0.003,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.05,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        ),
        dict(
            dPhiMaxHighEt=0.003,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.05,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        )
    ],
    'wide-window': [
        dict(
            dPhiMaxHighEt=0.10,
            dPhiMaxHighEtThres=20.0,
            dPhiMaxLowEtGrad=-0.002,
            dRzMaxHighEt=9999.0,
            dRzMaxHighEtThres=0.0,
            dRzMaxLowEtGrad=0.0,
        ),
        dict(
            dPhiMaxHighEt=0.006,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.10,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        ),
        dict(
            dPhiMaxHighEt=0.006,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.10,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        )
    ],
    'extra-wide-window': [
        dict(
            dPhiMaxHighEt=0.15,
            dPhiMaxHighEtThres=20.0,
            dPhiMaxLowEtGrad=-0.002,
            dRzMaxHighEt=9999.0,
            dRzMaxHighEtThres=0.0,
            dRzMaxLowEtGrad=0.0,
        ),
        dict(
            dPhiMaxHighEt=0.009,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.15,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        ),
        dict(
            dPhiMaxHighEt=0.009,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.15,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        )
    ],
    'nwp-tight-window': [
        dict(
            dPhiMaxHighEt=0.025,
            dPhiMaxHighEtThres=20.0,
            dPhiMaxLowEtGrad=-0.002,
            dRzMaxHighEt=9999.0,
            dRzMaxHighEtThres=0.0,
            dRzMaxLowEtGrad=0.0,
        ),
        dict(
            dPhiMaxHighEt=0.005,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.07,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        ),
        dict(
            dPhiMaxHighEt=0.006,
            dPhiMaxHighEtThres=20.0,
            dPhiMaxLowEtGrad=-0.0001,
            dRzMaxHighEt=0.08,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        )
    ],
    'nwp-window': [
        dict(
            dPhiMaxHighEt=0.05,
            dPhiMaxHighEtThres=20.0,
            dPhiMaxLowEtGrad=-0.002,
            dRzMaxHighEt=9999.0,
            dRzMaxHighEtThres=0.0,
            dRzMaxLowEtGrad=0.0,
        ),
        dict(
            dPhiMaxHighEt=0.005,
            dPhiMaxHighEtThres=0.0,
            dPhiMaxLowEtGrad=0.0,
            dRzMaxHighEt=0.07,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        ),
        dict(
            dPhiMaxHighEt=0.006,
            dPhiMaxHighEtThres=20.0,
            dPhiMaxLowEtGrad=-0.0001,
            dRzMaxHighEt=0.08,
            dRzMaxHighEtThres=30.0,
            dRzMaxLowEtGrad=-0.002,
        )
    ],
    # Eta-binned variant: every cut is a per-bin list and 'etaBins' holds the
    # upper bin edges; eta values beyond the last edge use the final list entry.
    'nwp-eta-breakdown': [
        dict(
            dPhiMaxHighEt=[0.05, 0.07, 0.06],
            dPhiMaxHighEtThres=[25.0, 25.0, 25.0],
            dPhiMaxLowEtGrad=[-0.002, -0.006, -0.002],
            dRzMaxHighEt=[9999.0, 9999.0, 9999.0],
            dRzMaxHighEtThres=[0.0, 0.0, 0.0],
            dRzMaxLowEtGrad=[0.0, 0.0, 0.0],
            etaBins = [1.1, 1.8]
        ),
        dict(
            dPhiMaxHighEt=[0.0035, 0.006, 0.007],
            dPhiMaxHighEtThres=[0.0, 0.0, 0.0],
            dPhiMaxLowEtGrad=[0.0, 0.0, 0.0],
            dRzMaxHighEt=[0.045, 0.08, 0.045],
            dRzMaxHighEtThres=[30.0, 30.0, 30.0],
            dRzMaxLowEtGrad=[-0.002, -0.006, -0.002],
            etaBins=[1.4, 2.3]
        ),
        dict(
            dPhiMaxHighEt=[0.006, 0.007, 0.007],
            dPhiMaxHighEtThres=[0.0, 20, 20],
            dPhiMaxLowEtGrad=[0.0, -0.0002, -0.0002],
            # NOTE(review): 0.60 is far larger than the other bins (0.04, 0.10) and
            # than hit 3 in every other cut set -- confirm it is not a typo for 0.06.
            dRzMaxHighEt=[0.04, 0.10, 0.60],
            dRzMaxHighEtThres=[25.0, 25.0, 25.0],
            dRzMaxLowEtGrad=[-0.007, -0.007, -0.007],
            etaBins=[1.0, 2.0]
        )
    ],
}


def calc_window(et, eta, hit, variable, cut_sel, cut_table=None):
    """Return the matching-window half-width for one seed hit.

    The window equals ``<variable>MaxHighEt`` for ``et`` at or above
    ``<variable>MaxHighEtThres`` and changes linearly below it::

        high_et + min(0, et - high_et_thres) * low_et_grad

    Args:
        et: transverse energy at which to evaluate the window (GeV).
        eta: pseudorapidity. Only consulted when the cut set defines
            ``etaBins`` (upper bin edges); it is binned on ``|eta|``.
        hit: 1-based hit number. Hits beyond the third reuse the third
            cut set.
        variable: cut-key prefix, ``'dPhi'`` or ``'dRz'``.
        cut_sel: name of the cut set to look up.
        cut_table: mapping of cut-set name to a list of per-hit cut dicts.
            Defaults to the module-level ``matching_cuts``.

    Returns:
        The window half-width as a number.

    Raises:
        ValueError: if ``hit`` is less than 1.
        KeyError: if ``cut_sel`` or a required cut key is missing.
    """
    if hit < 1:
        # Without this, hit=0 would index -1 and silently use the hit-3 cuts.
        raise ValueError(f'hit numbers are 1-based, got {hit}')
    table = matching_cuts if cut_table is None else cut_table
    # Hits 3, 4, ... all share the last (third) set of cuts.
    cuts = table[cut_sel][min(hit - 1, 2)]

    eta_idx = None
    if 'etaBins' in cuts:
        # Bin edges are given for |eta|; the detector is symmetric in eta, so
        # negative eta must not fall through to the first bin unconditionally.
        abs_eta = abs(eta)
        # First bin whose upper edge exceeds |eta|, otherwise the last entry.
        eta_idx = next((i for i, bin_high in enumerate(cuts['etaBins'])
                        if abs_eta < bin_high), -1)

    def param(suffix):
        # Scalar cuts are used directly; eta-binned cuts are per-bin lists.
        value = cuts[f'{variable}Max{suffix}']
        return value if eta_idx is None else value[eta_idx]

    high_et = param('HighEt')
    high_et_thres = param('HighEtThres')
    low_et_grad = param('LowEtGrad')
    return high_et + min(0, et - high_et_thres) * low_et_grad


def center_text(x, y, txt, **kwargs):
    """Draw ``txt`` centred on (x, y), given in axes-fraction coordinates.

    Extra keyword arguments are forwarded to ``plt.text``.
    """
    axes = plt.gca()
    centred = dict(horizontalalignment='center', verticalalignment='center')
    plt.text(x, y, txt, transform=axes.transAxes, **centred, **kwargs)


def hist_integral_ratio(num, den):
    """Return the ratio of two histogram integrals and its statistical error.

    Assumes an efficiency/purity-style ratio, i.e. the ``num`` entries are a
    subset of the ``den`` entries. The binomial (normal-approximation)
    uncertainty ``sqrt(r * (1 - r) / N_den)`` then applies. The previous
    ``sqrt(N_den) / N_den`` was only the relative Poisson error of the
    denominator, not an uncertainty on the ratio.

    Args:
        num: numerator histogram (as built by ``hist``).
        den: denominator histogram.

    Returns:
        ``(ratio, error)``. Both are NaN when the denominator integral is zero.
    """
    num_int = hist_integral(num, times_bin_width=False)
    den_int = hist_integral(den, times_bin_width=False)
    if den_int == 0:
        return float('nan'), float('nan')

    ratio = num_int / den_int
    # Clip at 0 so a ratio marginally outside [0, 1] from float rounding
    # yields a zero error rather than NaN.
    error = np.sqrt(max(ratio * (1 - ratio), 0) / den_int)
    return ratio, error


@decl_plot
def plot_residuals(rs, layer, hit, variable, subdet, cut_sel=None):
    # Residual-vs-E_T 2D histograms for one (subdet, layer, hit, variable),
    # read from rs.<variable>_<subdet>_L<layer>_H<hit>_v_Et_{TrackMatched,NoMatch}.
    # Left panel: truth-matched seeds; right panel: not truth-matched seeds.
    # Each panel gets the 90% and 99.5% contours and, when cut_sel names a
    # matching_cuts entry, the matching-window cut as a function of E_T.
    # (Comments rather than a docstring: decl_plot may surface __doc__.)

    h_real = hist2d(getattr(rs, f'{variable}_{subdet}_L{layer}_H{hit}_v_Et_TrackMatched'))
    h_fake = hist2d(getattr(rs, f'{variable}_{subdet}_L{layer}_H{hit}_v_Et_NoMatch'))

    def do_plot(h):
        hist2d_plot(h, colorbar=True)

        # Raw strings: '\%' is the LaTeX percent escape; in a normal string it
        # is an invalid Python escape sequence (SyntaxWarning on 3.12+).
        xs, ys = hist2d_percent_contour(h, .90, 'x')
        plt.plot(xs, ys, color='green', label=r'90\% contour')
        xs, ys = hist2d_percent_contour(h, .995, 'x')
        plt.plot(xs, ys, color='darkgreen', label=r'99.5\% contour')

        if cut_sel:
            # NOTE(review): assumes h[3] is the grid of y-bin (E_T) positions,
            # one row per bin -- confirm against filval.histogram.hist2d.
            ets = h[3][:, 0]
            # eta=0: for eta-binned cut sets only the first eta bin is drawn.
            cuts = [calc_window(et, 0, hit, variable, cut_sel) for et in ets]
            plt.plot(cuts, ets, color='red', label='Cut Value')
        plt.xlabel({'dPhi': r'$\delta \phi$ (rads)',
                    'dRz': r'$\delta R/z$ (cm)'}[variable])

    plt.sca(plt.subplot(1, 2, 1))
    do_plot(h_real)
    plt.title('Truth-Matched Seeds')
    plt.ylabel('$E_T$ (GeV)')

    plt.sca(plt.subplot(1, 2, 2))
    do_plot(h_fake)
    plt.title('Not Truth-Matched Seeds')
    plt.legend(loc='upper right')


@decl_plot
def plot_residuals_eta(rs, hit, variable):
    # 2D histogram of `variable` residuals vs eta for hit number `hit`, read
    # from rs.<variable>_residuals_v_eta_H<hit>, with 90% and 99.5% contours.

    h = hist2d(getattr(rs, f'{variable}_residuals_v_eta_H{hit}'))

    hist2d_plot(h, colorbar=True)

    # Raw strings keep the LaTeX '\%' without an invalid Python escape.
    xs, ys = hist2d_percent_contour(h, .90, 'x')
    plt.plot(xs, ys, color='green', label=r'90\% contour')
    xs, ys = hist2d_percent_contour(h, .995, 'x')
    plt.plot(xs, ys, color='darkgreen', label=r'99.5\% contour')

    plt.xlabel({'dPhi': r'$\delta \phi$ (rads)',
                'dRz': r'$\delta R/z$ (cm)'}[variable])
    # Pseudorapidity is dimensionless; the old "(GeV)" unit was wrong.
    plt.ylabel(r'$\eta$')


@decl_plot
def plot_seed_eff(rs):
    r"""## ECAL-Driven Seeding Efficiency

    The proportion of gen-level electrons originating in the luminous region that have
    an associated Seed, matched via rechit-simhit associations in the pixel detector. Cuts are on simtrack quantities.
    """
    # Create all three panels first, then fill them in p_T, eta, phi order.
    axes = [plt.subplot(slot) for slot in (221, 222, 223)]
    panel_specs = zip(
        axes,
        ('seed_eff_v_pt', 'seed_eff_v_eta', 'seed_eff_v_phi'),
        (r'$|\eta|<2.4$', r'$p_T>20$', r'$p_T>20$ and $|\eta|<2.4$'),
        (r"Sim-Track $p_T$", r"Sim-Track $\eta$", r"Sim-Track $\phi$"),
    )
    for ax, hist_name, selection, x_title in panel_specs:
        plt.sca(ax)
        hist_plot(hist(getattr(rs, hist_name)), include_errors=True)
        center_text(0.5, 0.3, selection)
        plt.xlabel(x_title)
        plt.ylim((0, 1.1))


@decl_plot
def plot_tracking_eff(rs):
    r"""## GSF Tracking Efficiency

    The proportion of electrons originating in the luminous region that have
    an associated GSF track. Cuts are on simtrack quantities.
    """
    # Three 1D efficiency panels (p_T, eta, phi) plus a 2D eta-p_T map.
    ax_pt = plt.subplot(221)
    ax_eta = plt.subplot(222)
    ax_phi = plt.subplot(223)
    ax_eta_pt = plt.subplot(224)

    one_dim_panels = (
        (ax_pt, 'tracking_eff_v_pt', r'$|\eta|<2.4$', r"Sim-Track $p_T$"),
        (ax_eta, 'tracking_eff_v_eta', r'$p_T>20$', r"Sim-Track $\eta$"),
        (ax_phi, 'tracking_eff_v_phi', r'$p_T>20$ and $|\eta|<2.4$', r"Sim-Track $\phi$"),
    )
    for ax, hist_name, selection, x_title in one_dim_panels:
        plt.sca(ax)
        hist_plot(hist(getattr(rs, hist_name)), include_errors=True)
        center_text(0.5, 0.3, selection)
        plt.xlabel(x_title)
        plt.ylim((0, 1.1))

    plt.sca(ax_eta_pt)
    hist2d_plot(hist2d(rs.tracking_eff_v_eta_pt))
    plt.xlabel(r"Sim-Track $\eta$")
    plt.ylabel(r"Sim-Track $p_T$")
    plt.colorbar()


@decl_plot
def plot_seed_purity(rs, ext=""):
    r"""## ECAL-Driven Seed Purity

    The proportion of ECAL-driven seeds that have a matched gen-level electron originating in
    the luminous region. Cuts are on seed quantities.
    """
    axes = [plt.subplot(slot) for slot in (221, 222, 223)]
    # The fixed 0-1.1 y-range is only applied to the un-suffixed histograms.
    fixed_y_range = not ext

    panels = (
        ("seed_pur_v_pt", r'$|\eta|<2.4$', r"Seed $p_T$"),
        ("seed_pur_v_eta", r'$p_T>20$', r"Seed $\eta$"),
        ("seed_pur_v_phi", r'$p_T>20$ and $|\eta|<2.4$', r"Seed $\phi$"),
    )
    for ax, (base_name, selection, x_title) in zip(axes, panels):
        plt.sca(ax)
        hist_plot(hist(getattr(rs, base_name + ext)), include_errors=True)
        center_text(0.5, 0.3, selection)
        plt.xlabel(x_title)
        if fixed_y_range:
            plt.ylim((0, 1.1))


@decl_plot
def plot_track_purity(rs, ext=""):
    r"""## GSF Track Purity

    The proportion of GSF-tracks with ECAL-driven seeds that have a matched gen-level electron originating in
    the luminous region. Cuts are on GSF track quantities.
    """
    # `ext` selects a histogram variant by name suffix (single_cut_plots passes
    # '2' for the seed-truth-matched purity). The 0-1.1 y-range is applied for
    # every variant.
    axes = [plt.subplot(slot) for slot in (221, 222, 223)]

    def get_hist(base_name):
        return hist(getattr(rs, base_name + ext))

    panels = (
        ("tracking_pur_v_pt", r'$|\eta|<2.4$', r"GSF-Track $p_T$"),
        ("tracking_pur_v_eta", r'$p_T>20$', r"GSF-Track $\eta$"),
        ("tracking_pur_v_phi", r'$p_T>20$ and $|\eta|<2.4$', r"GSF-Track $\phi$"),
    )
    for ax, (base_name, selection, x_title) in zip(axes, panels):
        plt.sca(ax)
        hist_plot(get_hist(base_name), include_errors=True)
        center_text(0.5, 0.3, selection)
        plt.xlabel(x_title)
        plt.ylim((0, 1.1))


@decl_plot
def plot_hit_vs_layer(rs, region):
    # Hit number vs layer number for `region` ('barrel' or 'forward' in this
    # file), read from rs.hit_vs_layer_<region>; cell text presumably uses
    # the '{:2.0f}' format passed to hist2d_plot.

    h = hist2d(getattr(rs, f'hit_vs_layer_{region}'))

    hist2d_plot(h, txt_format='{:2.0f}')
    # Raw strings: '\#' is the LaTeX escape for '#' and an invalid Python
    # escape sequence in a normal string literal.
    plt.xlabel(r'Layer \#')
    plt.ylabel(r'Hit \#')


def single_cut_plots(cut_sel):
    """Render the diagnostic plots for one cut set and, if not writing to disk, a dashboard.

    Args:
        cut_sel: cut-set name; histograms are read from ``../hists/<cut_sel>.root``.

    Note:
        Reads the module-level ``to_disk`` flag. In this file that flag is only
        assigned in the ``__main__`` block, so callers importing this module must
        set ``<module>.to_disk`` first.
    """
    rs = ResultSet(f'{cut_sel}', f'../hists/{cut_sel}.root')

    # Each entry is (plot function, args[, kwargs]) as consumed by Plot.
    seed_eff = plot_seed_eff, (rs,)
    tracking_eff = plot_tracking_eff, (rs,)

    seed_pur = plot_seed_purity, (rs,)

    track_pur = plot_track_purity, (rs,)
    track_pur_seed_match = plot_track_purity, (rs,), dict(ext='2')

    BPIX_residuals_L1_H1_dPhi = plot_residuals, (rs, 1, 1, 'dPhi', 'BPIX'), dict(cut_sel=cut_sel)
    BPIX_residuals_L2_H2_dPhi = plot_residuals, (rs, 2, 2, 'dPhi', 'BPIX'), dict(cut_sel=cut_sel)
    BPIX_residuals_L3_H3_dPhi = plot_residuals, (rs, 3, 3, 'dPhi', 'BPIX'), dict(cut_sel=cut_sel)

    # Hit-1 dRz plots are drawn without a cut line (titled "without cuts").
    BPIX_residuals_L1_H1_dRz = plot_residuals, (rs, 1, 1, 'dRz', 'BPIX')
    BPIX_residuals_L2_H2_dRz = plot_residuals, (rs, 2, 2, 'dRz', 'BPIX'), dict(cut_sel=cut_sel)
    BPIX_residuals_L3_H3_dRz = plot_residuals, (rs, 3, 3, 'dRz', 'BPIX'), dict(cut_sel=cut_sel)

    BPIX_residuals_L2_H1_dPhi = plot_residuals, (rs, 2, 1, 'dPhi', 'BPIX'), dict(cut_sel=cut_sel)
    BPIX_residuals_L3_H2_dPhi = plot_residuals, (rs, 3, 2, 'dPhi', 'BPIX'), dict(cut_sel=cut_sel)
    BPIX_residuals_L4_H3_dPhi = plot_residuals, (rs, 4, 3, 'dPhi', 'BPIX'), dict(cut_sel=cut_sel)

    BPIX_residuals_L2_H1_dRz = plot_residuals, (rs, 2, 1, 'dRz', 'BPIX')
    BPIX_residuals_L3_H2_dRz = plot_residuals, (rs, 3, 2, 'dRz', 'BPIX'), dict(cut_sel=cut_sel)
    BPIX_residuals_L4_H3_dRz = plot_residuals, (rs, 4, 3, 'dRz', 'BPIX'), dict(cut_sel=cut_sel)

    FPIX_residuals_L1_H1_dPhi = plot_residuals, (rs, 1, 1, 'dPhi', 'FPIX'), dict(cut_sel=cut_sel)
    FPIX_residuals_L2_H2_dPhi = plot_residuals, (rs, 2, 2, 'dPhi', 'FPIX'), dict(cut_sel=cut_sel)
    FPIX_residuals_L3_H3_dPhi = plot_residuals, (rs, 3, 3, 'dPhi', 'FPIX'), dict(cut_sel=cut_sel)

    FPIX_residuals_L1_H1_dRz = plot_residuals, (rs, 1, 1, 'dRz', 'FPIX')
    FPIX_residuals_L2_H2_dRz = plot_residuals, (rs, 2, 2, 'dRz', 'FPIX'), dict(cut_sel=cut_sel)
    FPIX_residuals_L3_H3_dRz = plot_residuals, (rs, 3, 3, 'dRz', 'FPIX'), dict(cut_sel=cut_sel)

    FPIX_residuals_L1_H2_dPhi = plot_residuals, (rs, 1, 2, 'dPhi', 'FPIX'), dict(cut_sel=cut_sel)
    FPIX_residuals_L1_H3_dPhi = plot_residuals, (rs, 1, 3, 'dPhi', 'FPIX'), dict(cut_sel=cut_sel)
    FPIX_residuals_L2_H3_dPhi = plot_residuals, (rs, 2, 3, 'dPhi', 'FPIX'), dict(cut_sel=cut_sel)

    # These previously passed 'dPhi', so the "dR Residuals" plots showed dPhi.
    FPIX_residuals_L1_H2_dRz = plot_residuals, (rs, 1, 2, 'dRz', 'FPIX'), dict(cut_sel=cut_sel)
    FPIX_residuals_L1_H3_dRz = plot_residuals, (rs, 1, 3, 'dRz', 'FPIX'), dict(cut_sel=cut_sel)
    FPIX_residuals_L2_H3_dRz = plot_residuals, (rs, 2, 3, 'dRz', 'FPIX'), dict(cut_sel=cut_sel)

    hit_vs_layer_barrel = plot_hit_vs_layer, (rs, 'barrel')
    hit_vs_layer_forward = plot_hit_vs_layer, (rs, 'forward')

    dRz_residuals_v_eta_H1 = plot_residuals_eta, (rs, 1, 'dRz')
    dRz_residuals_v_eta_H2 = plot_residuals_eta, (rs, 2, 'dRz')
    dRz_residuals_v_eta_H3 = plot_residuals_eta, (rs, 3, 'dRz')

    dPhi_residuals_v_eta_H1 = plot_residuals_eta, (rs, 1, 'dPhi')
    dPhi_residuals_v_eta_H2 = plot_residuals_eta, (rs, 2, 'dPhi')
    dPhi_residuals_v_eta_H3 = plot_residuals_eta, (rs, 3, 'dPhi')

    plots = [
        Plot(BPIX_residuals_L1_H1_dPhi, 'Phi Residuals Layer 1 Hit 1 - BPIX'),
        Plot(BPIX_residuals_L2_H2_dPhi, 'Phi Residuals Layer 2 Hit 2 - BPIX'),
        Plot(BPIX_residuals_L3_H3_dPhi, 'Phi Residuals Layer 3 Hit 3 - BPIX'),
        Plot(BPIX_residuals_L1_H1_dRz, 'dZ Residuals Layer 1 Hit 1 without cuts - BPIX'),
        Plot(BPIX_residuals_L2_H2_dRz, 'dZ Residuals Layer 2 Hit 2 - BPIX'),
        Plot(BPIX_residuals_L3_H3_dRz, 'dZ Residuals Layer 3 Hit 3 - BPIX'),

        Plot(BPIX_residuals_L2_H1_dPhi, 'Phi Residuals Layer 2 Hit 1 - BPIX'),
        Plot(BPIX_residuals_L3_H2_dPhi, 'Phi Residuals Layer 3 Hit 2 - BPIX'),
        Plot(BPIX_residuals_L4_H3_dPhi, 'Phi Residuals Layer 4 Hit 3 - BPIX'),
        Plot(BPIX_residuals_L2_H1_dRz, 'dZ Residuals Layer 2 Hit 1 without cuts - BPIX'),
        Plot(BPIX_residuals_L3_H2_dRz, 'dZ Residuals Layer 3 Hit 2 - BPIX'),
        Plot(BPIX_residuals_L4_H3_dRz, 'dZ Residuals Layer 4 Hit 3 - BPIX'),

        Plot(FPIX_residuals_L1_H1_dPhi, 'Phi Residuals Layer 1 Hit 1 - FPIX'),
        Plot(FPIX_residuals_L2_H2_dPhi, 'Phi Residuals Layer 2 Hit 2 - FPIX'),
        Plot(FPIX_residuals_L3_H3_dPhi, 'Phi Residuals Layer 3 Hit 3 - FPIX'),
        Plot(FPIX_residuals_L1_H1_dRz, 'dR Residuals Layer 1 Hit 1 without cuts - FPIX'),
        Plot(FPIX_residuals_L2_H2_dRz, 'dR Residuals Layer 2 Hit 2 - FPIX'),
        Plot(FPIX_residuals_L3_H3_dRz, 'dR Residuals Layer 3 Hit 3 - FPIX'),

        Plot(FPIX_residuals_L1_H2_dPhi, 'Phi Residuals Layer 1 Hit 2 - FPIX'),
        Plot(FPIX_residuals_L1_H3_dPhi, 'Phi Residuals Layer 1 Hit 3 - FPIX'),
        Plot(FPIX_residuals_L2_H3_dPhi, 'Phi Residuals Layer 2 Hit 3 - FPIX'),
        Plot(FPIX_residuals_L1_H2_dRz, 'dR Residuals Layer 1 Hit 2 - FPIX'),
        Plot(FPIX_residuals_L1_H3_dRz, 'dR Residuals Layer 1 Hit 3 - FPIX'),
        Plot(FPIX_residuals_L2_H3_dRz, 'dR Residuals Layer 2 Hit 3 - FPIX'),

        Plot(seed_eff, 'ECAL-Driven Seeding Efficiency'),
        Plot(tracking_eff, 'GSF Tracking Efficiency'),
        Plot(hit_vs_layer_barrel, 'Hit vs Layer - Barrel'),
        Plot(hit_vs_layer_forward, 'Hit vs Layer - Forward'),
        Plot(seed_pur, 'ECAL-Driven Seeding Purity'),
        Plot(track_pur, 'GSF Track Purity'),
        Plot(track_pur_seed_match, 'GSF Track Purity (Seed Truth Match)'),
        simple_plot(rs.gsf_tracks_nmatch_sim_tracks, log='y'),
        Plot(dRz_residuals_v_eta_H1, 'dRz Hit 1 Residuals v eta'),
        Plot(dRz_residuals_v_eta_H2, 'dRz Hit 2 Residuals v eta'),
        Plot(dRz_residuals_v_eta_H3, 'dRz Hit 3 Residuals v eta'),
        Plot(dPhi_residuals_v_eta_H1, 'dPhi Hit 1 Residuals v eta'),
        Plot(dPhi_residuals_v_eta_H2, 'dPhi Hit 2 Residuals v eta'),
        Plot(dPhi_residuals_v_eta_H3, 'dPhi Hit 3 Residuals v eta'),
    ]

    render_plots(plots, directory='output/figures/'+rs.sample_name, to_disk=to_disk)
    if not to_disk:
        generate_dashboard(plots, 'Seeding Efficiency',
                           output=f'{rs.sample_name}.html',
                           source=__file__,
                           config=rs.config)


def eta_region_plots(cut_sel):
    """Compare 99% residual contours across eta regions 1-3 for one cut set.

    Builds one plot per (variable, hit) pair, for dPhi/dRz and hits 1-3, from
    ``../hists/<cut_sel>.root``. Renders them and, unless writing to disk,
    generates an HTML dashboard.

    Note:
        Reads the module-level ``to_disk`` flag, which this file only assigns
        in the ``__main__`` block.
    """
    rs = ResultSet(f'{cut_sel}', f'../hists/{cut_sel}.root')

    @decl_plot
    def residual_in_region(var, hit):
        # Overlay the 99% contour of each region's residual-vs-E_T histogram,
        # read from rs.<var>_residuals_H<hit>_R<region>.
        for region in (1, 2, 3):
            h = hist2d(getattr(rs, f'{var}_residuals_H{hit}_R{region}'))

            xs, ys = hist2d_percent_contour(h, .99, 'x')
            # rf-string: '\%' is the LaTeX percent escape, not a Python escape.
            plt.plot(xs, ys, label=rf'99\%, Region {region}')
        plt.legend()

        plt.xlabel({'dPhi': r'$\delta \phi$ (rads)',
                    'dRz': r'$\delta R/z$ (cm)'}[var])
        plt.ylabel('$E_T$ (GeV)')

    plots = []
    for hit in (1, 2, 3):
        plt_tup = residual_in_region, ('dPhi', hit)
        plots.append(Plot(plt_tup, f'dPhi residuals, hit {hit}'))

        plt_tup = residual_in_region, ('dRz', hit)
        plots.append(Plot(plt_tup, f'dRz residuals, hit {hit}'))

    render_plots(plots, directory='output/figures/'+rs.sample_name, to_disk=to_disk)
    if not to_disk:
        generate_dashboard(plots, 'Breakdown by Eta Region',
                           output=f'{rs.sample_name}-eta-regions.html',
                           source=__file__,
                           config=rs.config)


@decl_plot
def plot_seed_roc_curve(rss):
    # Seeding efficiency vs purity: one point with error bars per ResultSet,
    # from the integrated seed_{eff,pur}_v_phi numerator/denominator
    # histograms. 'old-seeding' is drawn in black under its full name.
    window_suffix = '-window'

    def get_num_den(rs, basename):
        num = hist(getattr(rs, f'{basename}_num'))
        den = hist(getattr(rs, f'{basename}_den'))
        return hist_integral_ratio(num, den)

    def legend_label(name):
        # Strip only a genuine '-window' suffix. A fixed 7-character slice
        # would mangle names such as 'nwp-eta-breakdown'.
        return name[:-len(window_suffix)] if name.endswith(window_suffix) else name

    for rs in rss:
        eff, eff_err = get_num_den(rs, 'seed_eff_v_phi')
        pur, pur_err = get_num_den(rs, 'seed_pur_v_phi')
        if rs.sample_name == 'old-seeding':
            style = dict(label=rs.sample_name, color='k')
        else:
            style = dict(label=legend_label(rs.sample_name))
        plt.errorbar([pur], [eff], xerr=[pur_err], yerr=[eff_err], marker='o', **style)
    center_text(0.5, 0.3, r'$p_T>20$ and $|\eta|<2.4$')
    plt.axis('equal')
    plt.xlim((0.8, 1.0))
    plt.ylim((0.8, 1.0))
    plt.xlabel('ECAL-Driven Seeding Purity')
    plt.ylabel('ECAL-Driven Seeding Efficiency')
    plt.grid()
    plt.legend()


@decl_plot
def plot_seed_eff_all(rss):
    # Overlay the seeding-efficiency histograms of every ResultSet in rss,
    # one panel each for p_T, eta and phi; the legend sits in the eta panel.
    layout = (
        (221, 'pt', r'$|\eta|<2.4$', r"Sim-Track $p_T$"),
        (222, 'eta', r'$p_T>20$', r"Sim-Track $\eta$"),
        (223, 'phi', r'$p_T>20$ and $|\eta|<2.4$', r"Sim-Track $\phi$"),
    )
    axes = {var: plt.subplot(slot) for slot, var, _, _ in layout}

    for rs in rss:
        for _, var, _, _ in layout:
            plt.sca(axes[var])
            h = hist(getattr(rs, f'seed_eff_v_{var}'))
            hist_plot(h, include_errors=True, label=rs.sample_name)

    for _, var, selection, x_title in layout:
        plt.sca(axes[var])
        center_text(0.5, 0.3, selection)
        plt.xlabel(x_title)
        plt.ylim((0, 1.1))
        if var == 'eta':
            plt.legend(loc='lower right')


@decl_plot
def plot_seed_pur_all(rss):
    # Overlay the seed-purity histograms of every ResultSet in rss, one
    # panel each for p_T, eta and phi; the legend sits in the eta panel.
    ax_by_var = {
        'pt': plt.subplot(221),
        'eta': plt.subplot(222),
        'phi': plt.subplot(223),
    }

    for result_set in rss:
        for var, ax in ax_by_var.items():
            plt.sca(ax)
            hist_plot(hist(getattr(result_set, 'seed_pur_v_' + var)),
                      include_errors=True, label=result_set.sample_name)

    decorations = {
        'pt': (r'$|\eta|<2.4$', r"Seed $p_T$"),
        'eta': (r'$p_T>20$', r"Seed $\eta$"),
        'phi': (r'$p_T>20$ and $|\eta|<2.4$', r"Seed $\phi$"),
    }
    for var, ax in ax_by_var.items():
        selection, x_title = decorations[var]
        plt.sca(ax)
        center_text(0.5, 0.3, selection)
        plt.xlabel(x_title)
        plt.ylim((0, 1.1))
        if var == 'eta':
            plt.legend(loc='lower right')


@decl_plot
def plot_tracking_roc_curve(rss, ext=''):
    # GSF-track efficiency vs purity: one point with error bars per ResultSet,
    # from the integrated tracking_{eff,pur}_v_phi<ext> numerator/denominator
    # histograms. 'old-seeding' is drawn in black under its full name.
    window_suffix = '-window'

    def get_num_den(rs, basename):
        num = hist(getattr(rs, f'{basename}{ext}_num'))
        den = hist(getattr(rs, f'{basename}{ext}_den'))
        return hist_integral_ratio(num, den)

    def legend_label(name):
        # Strip only a genuine '-window' suffix. A fixed 7-character slice
        # would mangle names such as 'nwp-eta-breakdown'.
        return name[:-len(window_suffix)] if name.endswith(window_suffix) else name

    for rs in rss:
        eff, eff_err = get_num_den(rs, 'tracking_eff_v_phi')
        pur, pur_err = get_num_den(rs, 'tracking_pur_v_phi')
        if rs.sample_name == 'old-seeding':
            style = dict(label=rs.sample_name, color='k')
        else:
            style = dict(label=legend_label(rs.sample_name))
        plt.errorbar([pur], [eff], xerr=[pur_err], yerr=[eff_err], marker='o', **style)
    center_text(0.5, 0.3, r'$p_T>20$ and $|\eta|<2.4$')
    plt.axis('equal')
    plt.xlim((0.8, 1.0))
    plt.ylim((0.8, 1.0))
    plt.xlabel('GSF-Track Purity')
    plt.ylabel('GSF-Track Efficiency')
    plt.grid()
    plt.legend()


@decl_plot
def plot_tracking_eff_all(rss, ext=''):
    # Overlay GSF tracking efficiency vs p_T, eta and phi for every ResultSet.
    # `ext` selects a histogram variant by name suffix (all_cut_plots passes
    # '2'). The legend is placed in the eta panel.
    layout = (
        (221, 'pt', r'$|\eta|<2.4$', r"Sim-Track $p_T$"),
        (222, 'eta', r'$p_T>20$', r"Sim-Track $\eta$"),
        (223, 'phi', r'$p_T>20$ and $|\eta|<2.4$', r"Sim-Track $\phi$"),
    )
    axes = {var: plt.subplot(slot) for slot, var, _, _ in layout}

    for rs in rss:
        for _, var, _, _ in layout:
            plt.sca(axes[var])
            h = hist(getattr(rs, f'tracking_eff_v_{var}{ext}'))
            hist_plot(h, include_errors=True, label=rs.sample_name)

    for _, var, selection, x_title in layout:
        plt.sca(axes[var])
        center_text(0.5, 0.3, selection)
        plt.xlabel(x_title)
        plt.ylim((0, 1.1))
        if var == 'eta':
            plt.legend(loc='lower right')


@decl_plot
def plot_tracking_pur_all(rss, ext=''):
    # Overlay GSF-track purity vs p_T, eta and phi for every ResultSet.
    # `ext` selects a histogram variant by name suffix (all_cut_plots passes
    # '2'). The legend is placed in the eta panel.
    ax_by_var = {
        'pt': plt.subplot(221),
        'eta': plt.subplot(222),
        'phi': plt.subplot(223),
    }

    for result_set in rss:
        for var, ax in ax_by_var.items():
            plt.sca(ax)
            hist_plot(hist(getattr(result_set, 'tracking_pur_v_' + var + ext)),
                      include_errors=True, label=result_set.sample_name)

    decorations = {
        'pt': (r'$|\eta|<2.4$', r"GSF-Track $p_T$"),
        'eta': (r'$p_T>20$', r"GSF-Track $\eta$"),
        'phi': (r'$p_T>20$ and $|\eta|<2.4$', r"GSF-Track $\phi$"),
    }
    for var, ax in ax_by_var.items():
        selection, x_title = decorations[var]
        plt.sca(ax)
        center_text(0.5, 0.3, selection)
        plt.xlabel(x_title)
        plt.ylim((0, 1.1))
        if var == 'eta':
            plt.legend(loc='lower right')


@decl_plot
def plot_ecal_rel_res(rss):
    # Overlay each ResultSet's ecal_energy_resolution histogram.
    for result_set in rss:
        resolution = hist(result_set.ecal_energy_resolution)
        hist_plot(resolution, label=result_set.sample_name)
    plt.xlabel(r"ECAL $E_T$ relative error")
    plt.legend()


@decl_plot
def plot_res_contour(rss, hit_number, var, layers, ext='_TrackMatched'):
    # One panel per ResultSet (at most six, on a 2x3 grid). Each panel overlays
    # the 99% residual contours of `var` for hit `hit_number` in each requested
    # layer, plus the matching-window cut line when the sample has an entry in
    # matching_cuts. Each `layers` entry is like 'B2' / 'F1': the first
    # character selects BPIX ('B') or FPIX (anything else), the second is the
    # layer number. All panels share the largest upper x-limit.
    from itertools import chain
    _, axs = plt.subplots(2, 3)
    axs_all = list(chain(*axs))

    def do_plot(ax, rs):
        plt.sca(ax)
        plt.title(rs.sample_name)
        h = None
        for layer in layers:
            subdet = 'BPIX' if layer[0] == 'B' else 'FPIX'
            h = hist2d(getattr(rs, f'{var}_{subdet}_L{layer[1]}_H{hit_number}_v_Et{ext}'))
            xs, ys = hist2d_percent_contour(h, .99, 'x')
            plt.plot(xs, ys, label=f'{subdet} - L{layer[1]}')

        # Samples without matching-window cuts (e.g. 'old-seeding') get the
        # contours only. Previously calc_window's KeyError aborted the panel
        # and excluded it from the shared x-limit. Empty `layers` has no
        # histogram to take the E_T bins from.
        if h is None or rs.sample_name not in matching_cuts:
            return
        # NOTE(review): assumes h[3] holds the y-bin (E_T) positions, one row
        # per bin -- confirm against filval.histogram.hist2d.
        ets = h[3][:, 0]
        cuts = [calc_window(et, 0, hit_number, var, rs.sample_name) for et in ets]
        plt.plot(cuts, ets, color='red', label='Cut Value')

    max_x = 0
    for ax, rs in zip(axs_all, rss):
        try:
            do_plot(ax, rs)
        except KeyError:
            # NOTE(review): presumably raised when rs lacks a requested
            # histogram -- confirm. Such panels don't affect the shared x-limit.
            continue
        _, x_up = ax.get_xlim()
        max_x = max(max_x, x_up)

    plt.sca(axs[0][-1])
    plt.legend(loc='best')

    for ax in axs_all:
        ax.set_xlim((None, max_x))


@decl_plot
def number_of_seeds(rss):
    # Rebinned, normalised distribution of rs.n_seeds for each ResultSet,
    # with the mean (truncated to an int) shown in the legend.
    from filval.histogram import hist_mean

    for result_set in rss:
        normed = hist_norm(hist_rebin(hist(result_set.n_seeds), 50, -0.5, 200.5))
        mu = int(hist_mean(normed))
        hist_plot(normed, label=f'{result_set.sample_name} ($\\mu={mu:d}$)')
    plt.xlabel('Number of Seeds Produced')
    plt.legend()


def all_cut_plots(cuts):
    """Render comparison plots across several cut sets and, if not writing to disk, a dashboard.

    Args:
        cuts: iterable of cut-set / sample names; each is read from
            ``../hists/<name>.root``.

    Note:
        Reads the module-level ``to_disk`` flag, which this file only assigns
        in the ``__main__`` block. The dashboard config is taken from the first
        ResultSet.
    """
    rss = [ResultSet(f'{cut_sel}', f'../hists/{cut_sel}.root') for cut_sel in cuts]

    # Each entry is (plot function, args) as consumed by Plot.
    tracking_roc_curve = plot_tracking_roc_curve, (rss,)
    tracking_eff_all = plot_tracking_eff_all, (rss,)
    tracking_pur_all = plot_tracking_pur_all, (rss,)
    tracking_roc_curve2 = plot_tracking_roc_curve, (rss, '2')
    tracking_eff_all2 = plot_tracking_eff_all, (rss, '2')
    tracking_pur_all2 = plot_tracking_pur_all, (rss, '2')

    seed_roc_curve = plot_seed_roc_curve, (rss,)
    seed_eff_all = plot_seed_eff_all, (rss,)
    seed_pur_all = plot_seed_pur_all, (rss,)

    ecal_rel_res = plot_ecal_rel_res, (rss,)

    res_contour_dphi_H1 = plot_res_contour, (rss, 1, 'dPhi', ['B1', 'B2', 'F1'])
    res_contour_dphi_H2 = plot_res_contour, (rss, 2, 'dPhi', ['B2', 'B3', 'B4', 'F1', 'F2'])
    res_contour_dRz_H2 = plot_res_contour, (rss, 2, 'dRz', ['B2', 'B3', 'B4', 'F1', 'F2'])

    # 'B4' and 'F1' must be separate items: adjacent string literals without a
    # comma concatenate to 'B4F1', which silently dropped the FPIX L1 contour.
    res_contour_dphi_H3 = plot_res_contour, (rss, 3, 'dPhi', ['B3', 'B4', 'F1', 'F2', 'F3'])
    res_contour_dRz_H3 = plot_res_contour, (rss, 3, 'dRz', ['B3', 'B4', 'F1', 'F2', 'F3'])

    plots = [
        Plot(tracking_roc_curve, 'Tracking ROC Curve'),
        Plot(tracking_eff_all, 'Tracking Efficiency'),
        Plot(tracking_pur_all, 'Tracking Purity'),
        Plot(tracking_roc_curve2, 'Tracking ROC Curve (Seed Matched)'),
        Plot(tracking_eff_all2, 'Tracking Efficiency (Seed Matched)'),
        Plot(tracking_pur_all2, 'Tracking Purity (Seed Matched)'),
        Plot(seed_roc_curve, 'Seeding ROC Curve'),
        Plot(seed_eff_all, 'ECAL-Driven Seeding Efficiency'),
        Plot(seed_pur_all, 'ECAL-Driven Seeding Purity'),
        Plot(ecal_rel_res, 'ECAL ET Relative Resolution'),

        Plot(res_contour_dphi_H1, 'dPhi Residual 99% Contours - Hit 1'),
        Plot(res_contour_dphi_H2, 'dPhi Residual 99% Contours - Hit 2'),
        Plot(res_contour_dRz_H2, 'dRz  Residual 99% Contours - Hit 2'),
        Plot(res_contour_dphi_H3, 'dPhi Residual 99% Contours - Hit 3'),
        Plot(res_contour_dRz_H3, 'dRz  Residual 99% Contours - Hit 3'),
        Plot((number_of_seeds, (rss,)), 'Number of Electron Seeds'),
    ]

    render_plots(plots, to_disk=to_disk)
    if not to_disk:
        generate_dashboard(plots, 'Comparisons',
                           output='comparisons.html',
                           source=__file__,
                           config=rss[0].config)


if __name__ == '__main__':
    # Module-level flag read by the plotting entry points above: it is passed
    # to render_plots(to_disk=...), and HTML dashboards are only generated
    # when it is False.
    to_disk = False
    all_cuts = [
        'extra-narrow-window',
        'narrow-window',
        'wide-window',
        'extra-wide-window',
        # 'nwp-window',
        # 'nwp-tight-window',
        # 'nwp-eta-breakdown',
    ]
    # Compare every selected cut set together with the 'old-seeding' reference.
    all_cut_plots([*all_cuts, 'old-seeding'])
    # for cut in all_cuts:
    #     single_cut_plots(cut)

    # eta_region_plots('extra-wide-window')