Quellcode durchsuchen

Rewrites histogram stack to be much simpler (and less comprehensive).
Changes to latexipy dependency to use custom fork

Caleb Fangmeier vor 6 Jahren
Ursprung
Commit
ae2e2b0f3a
3 geänderte Dateien mit 82 neuen und 260 gelöschten Zeilen
  1. 5 14
      filval/histogram_utils.py
  2. 71 245
      filval/plotter.py
  3. 6 1
      setup.py

+ 5 - 14
filval/histogram_utils.py

@@ -76,20 +76,11 @@ def hist_slice(hist, range_):
             np.concatenate([edges[:-1][slice_], [edges[last]]]))
 
 
-def hist_add(hist1, hist2, w1=1.0, w2=1.0):
-    v1, e1, *lim1 = hist1
-    v2, e2, *lim2 = hist2
-    # print(hist1)
-    if v1.shape != v2.shape:
-        raise ValueError(f'Mismatched histograms to add {v1.shape} != {v2.shape}')
-    # print(lim1)
-    # print(lim2)
-    # nlims_equal = (lim1 != lim2).any()
-    # print(nlims_equal)
-    # if nlims_equal:
-    #     raise ValueError(f'Histograms have different limits!')
-    if len(v1.shape) == 1:  # 1D histograms
-        return ((v1+v2), np.sqrt(e1*e1 + e2*e2), *lim1)
+def hist_add(*hists):
+    if len(hists) == 0:
+        return np.zeros(0)
+    vals, errs, edges = zip(*hists)
+    return np.sum(vals, axis=0), np.sqrt(np.sum([err*err for err in errs], axis=0)), edges[0]
 
 
 def hist_integral(hist, times_bin_width=True):

+ 71 - 245
filval/plotter.py

@@ -10,21 +10,23 @@ import latexipy as lp
 
 from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
                                     hist_normalize, hist_stats)
+
 __all__ = ['Plot',
            'decl_plot',
            'grid_plot',
            'render_plots',
            'hist_plot',
+           'hist_plot_stack',
            'hist2d_plot']
 
 
 class Plot:
-    def __init__(self, subplots, name, title=None, docs="N/A", argdicts={}):
+    def __init__(self, subplots, name, title=None, docs="N/A", arg_dicts=None):
         self.subplots = subplots
         self.name = name
         self.title = title
         self.docs = docs
-        self.argdicts = argdicts
+        self.arg_dicts = arg_dicts if arg_dicts is not None else {}
 
 
 MD = Markdown(extensions=['mdx_math'],
@@ -49,7 +51,7 @@ lp.latexify(params={'pgf.texsystem': 'pdflatex',
 def _fn_call_to_dict(fn, *args, **kwargs):
     from inspect import signature
     pnames = list(signature(fn).parameters)
-    pvals = list(args)+list(kwargs.values())
+    pvals = list(args) + list(kwargs.values())
     return {k: v for k, v in zip(pnames, pvals)}
 
 
@@ -72,6 +74,7 @@ def decl_plot(fn):
         docs = _process_docs(fn)
 
         return argdict, docs
+
     return f
 
 
@@ -111,7 +114,7 @@ def grid_plot(subplots):
 
     def calc_rowspan(fig, row, col):
         span = 1
-        for r in range(row+1, len(fig)):
+        for r in range(row + 1, len(fig)):
             if fig[r][col] == "FU":
                 span += 1
             else:
@@ -120,7 +123,7 @@ def grid_plot(subplots):
 
     def calc_colspan(fig, row, col):
         span = 1
-        for c in range(col+1, len(fig[row])):
+        for c in range(col + 1, len(fig[row])):
             if fig[row][c] == "FL":
                 span += 1
             else:
@@ -151,20 +154,20 @@ def grid_plot(subplots):
     return argdicts, docs
 
 
-def render_plots(plots, exts=['png'], scale=1.0, to_disk=True):
+def render_plots(plots, exts=('png',), scale=1.0, to_disk=True):
     for plot in plots:
         print(f'Building plot {plot.name}')
         plot.data = None
         if to_disk:
             with lp.figure(plot.name, directory='output/figures',
                            exts=exts,
-                           size=(scale*10, scale*10)):
+                           size=(scale * 10, scale * 10)):
                 argdicts, docs = grid_plot(plot.subplots)
         else:
             out = BytesIO()
             with lp.mem_figure(out,
                                ext=exts[0],
-                               size=(scale*10, scale*10)):
+                               size=(scale * 10, scale * 10)):
                 argdicts, docs = grid_plot(plot.subplots)
             out.seek(0)
             plot.data = b64encode(out.read()).decode()
@@ -192,8 +195,8 @@ def add_decorations(axes, luminosity, energy):
               transform=axes.transAxes)
 
 
-def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
-              log=False, fig=None, xlim=None, ylim=None, fit=None,
+def hist_plot(h, *args, norm=None, include_errors=False,
+              log=False, xlim=None, ylim=None, fit=None,
               grid=False, stats=True, **kwargs):
     """ Plots a 1D ROOT histogram object using matplotlib """
     from inspect import signature
@@ -201,271 +204,94 @@ def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
         h = hist_normalize(h, norm)
     values, errors, edges = h
 
-    scale = 1. if norm is None else norm/np.sum(values)
-    values = [val*scale for val in values]
-    errors = [val*scale for val in errors]
+    scale = 1. if norm is None else norm / np.sum(values)
+    values = [val * scale for val in values]
+    errors = [val * scale for val in errors]
 
     left, right = np.array(edges[:-1]), np.array(edges[1:])
-    X = np.array([left, right]).T.flatten()
-    Y = np.array([values, values]).T.flatten()
+    x = np.array([left, right]).T.flatten()
+    y = np.array([values, values]).T.flatten()
 
-    if axes is None:
-        import matplotlib.pyplot as plt
-        axes = plt.gca()
+    ax = plt.gca()
 
-    axes.set_xlabel(kwargs.pop('xlabel', ''))
-    axes.set_ylabel(kwargs.pop('ylabel', ''))
+    ax.set_xlabel(kwargs.pop('xlabel', ''))
+    ax.set_ylabel(kwargs.pop('ylabel', ''))
     title = kwargs.pop('title', '')
     if xlim is not None:
-        axes.set_xlim(xlim)
+        ax.set_xlim(xlim)
     if ylim is not None:
-        axes.set_ylim(ylim)
+        ax.set_ylim(ylim)
     # elif not log:
     #     axes.set_ylim((0, None))
 
-    axes.plot(X, Y, *args, linewidth=1, **kwargs)
+    ax.plot(x, y, *args, linewidth=1, **kwargs)
     if include_errors:
-        axes.errorbar(hist_bin_centers(h), values, yerr=errors,
-                      color='k', marker=None, linestyle='None',
-                      barsabove=True, elinewidth=.7, capsize=1)
+        ax.errorbar(hist_bin_centers(h), values, yerr=errors,
+                    color='k', marker=None, linestyle='None',
+                    barsabove=True, elinewidth=.7, capsize=1)
     if log:
-        axes.set_yscale('log')
+        ax.set_yscale('log')
     if fit:
         f, p0 = fit
         popt, pcov = hist_fit(h, f, p0)
-        fit_xs = np.linspace(X[0], X[-1], 100)
+        fit_xs = np.linspace(x[0], x[-1], 100)
         fit_ys = f(fit_xs, *popt)
-        axes.plot(fit_xs, fit_ys, '--g')
+        ax.plot(fit_xs, fit_ys, '--g')
         arglabels = list(signature(f).parameters)[1:]
         label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
                               for label, value in zip(arglabels, popt))
-        axes.text(0.60, 0.95, label_txt, va='top', transform=axes.transAxes,
-                  fontsize='medium', family='monospace', usetex=False)
+        ax.text(0.60, 0.95, label_txt, va='top', transform=ax.transAxes,
+                fontsize='medium', family='monospace', usetex=False)
     if stats:
         _add_stats(h, title)
     else:
-        axes.set_title(title)
-    axes.grid(grid, color='#E0E0E0')
+        ax.set_title(title)
+    ax.grid(grid, color='#E0E0E0')
 
 
-def hist2d_plot(h, *args, axes=None, **kwargs):
+def hist2d_plot(h, **kwargs):
     """ Plots a 2D ROOT histogram object using matplotlib """
     try:
         values, errors, xs, ys = h
     except (TypeError, ValueError):
         values, errors, xs, ys = hist2d(h)
 
-    if axes is None:
-        import matplotlib.pyplot as plt
-        axes = plt.gca()
-    axes.set_xlabel(kwargs.pop('xlabel', ''))
-    axes.set_ylabel(kwargs.pop('ylabel', ''))
-    axes.set_title(kwargs.pop('title', ''))
-    axes.pcolormesh(xs, ys, values,)
+    plt.xlabel(kwargs.pop('xlabel', ''))
+    plt.ylabel(kwargs.pop('ylabel', ''))
+    plt.title(kwargs.pop('title', ''))
+    plt.pcolormesh(xs, ys, values, )
     # axes.colorbar() TODO: Re-enable this
 
 
-def hist_plot_stack(hists):
-    def __init__(self, title=""):
-        # raise NotImplementedError("need to fix to not use to_bin_list")
-        self.title = title
-        self.xlabel = ""
-        self.ylabel = ""
-        self.xlim = (None, None)
-        self.ylim = (None, None)
-        self.logx = False
-        self.logy = False
-        self.backgrounds = []
-        self.signal = None
-        self.signal_stack = True
-        self.data = None
-
-    def add_mc_background(self, h, label, lumi=None, plot_color=''):
-        self.backgrounds.append((label, lumi, h, plot_color))
-
-    def set_mc_signal(self, h, label, lumi=None, stack=True, scale=1, plot_color=''):
-        self.signal = (label, lumi, h, plot_color)
-        self.signal_stack = stack
-        self.signal_scale = scale
-
-    # def set_data(self, th1, lumi=None, plot_color=''):
-    #     self.data = ('data', lumi, hist(th1), plot_color)
-    #     self.luminosity = lumi
-
-    def _verify_binning_match(self):
-        bins_count = [len(bins) for _, _, bins, _ in self.backgrounds]
-        if self.signal is not None:
-            bins_count.append(len(self.signal[2]))
-        if self.data is not None:
-            bins_count.append(len(self.data[2]))
-        n_bins = bins_count[0]
-        if any(bin_count != n_bins for bin_count in bins_count):
-            raise ValueError("all histograms must have the same number of bins")
-        self.n_bins = n_bins
-
-    def save(self, filename, **kwargs):
-        import matplotlib.pyplot as plt
-        plt.ioff()
-        fig = plt.figure()
-        ax = fig.gca()
-        self.do_draw(ax, **kwargs)
-        fig.savefig("figures/"+filename, transparent=True)
-        plt.close(fig)
-        plt.ion()
-
-    def draw(self, ax=None):
-        if ax is None:
-            ax = plt.gca()
-        self.axeses = [ax]
-        self._verify_binning_match()
-        bottoms = [0]*self.n_bins
-
-        if self.logx:
-            ax.set_xscale('log')
-        if self.logy:
-            ax.set_yscale('log')
-
-        def draw_bar(label, lumi, hist, plot_color, scale=1, stack=True, **kwargs):
-            if stack:
-                lefts = []
-                widths = []
-                heights = []
-                for left, right, content in zip(hist[2][:-1], hist[2][1:], hist[0])
-                    lefts.append(left)
-                    widths.append(right-left)
-                    if lumi is not None:
-                        content *= self.luminosity/lumi
-                    content *= scale
-                    heights.append(content)
-
-                ax.bar(lefts, heights, widths, bottoms, label=label, color=plot_color, **kwargs)
-                for i, content in enumerate(hist[0]):
-                    if lumi is not None:
-                        content *= self.luminosity/lumi
-                    content *= scale
-                    bottoms[i] += content
-            else:
-                raise NotImplementedError('only supports stacks')
-                xs = [bins[0][0] - (bins[0][1]-bins[0][0])/2]
-                ys = [0]
-                for left, right, content in bins:
-                    width2 = (right-left)/2
-                    if lumi is not None:
-                        content *= self.luminosity/lumi
-                    content *= scale
-                    xs.append(left-width2)
-                    ys.append(content)
-                    xs.append(right-width2)
-                    ys.append(content)
-                xs.append(bins[-1][0] + (bins[-1][1]-bins[-1][0])/2)
-                ys.append(0)
-                ax.plot(xs, ys, label=label, color=plot_color, **kwargs)
-
-        if self.signal is not None and self.signal_stack:
-            label, lumi, hist, plot_color = self.signal
-            if self.signal_scale != 1:
-                label = r"{}$\times{:d}$".format(label, self.signal_scale)
-            draw_bar(label, lumi, hist, plot_color, scale=self.signal_scale, hatch='/')
-
-        for background in self.backgrounds:
-            draw_bar(*background)
-
-        if self.signal is not None and not self.signal_stack:
-            # draw_bar(*self.signal, stack=False, color='k')
-            label, lumi, bins, plot_color = self.signal
-            if self.signal_scale != 1:
-                label = r"{}$\times{:d}$".format(label, self.signal_scale)
-            draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, stack=False)
-
-        ax.set_title(self.title)
-        ax.set_xlabel(self.xlabel)
-        ax.set_ylabel(self.ylabel)
-        ax.set_xlim(*self.xlim)
-        if self.logy:
-            ax.set_ylim(None, np.exp(np.log(max(bottoms))*1.4))
-        else:
-            ax.set_ylim(None, max(bottoms)*1.2)
-        ax.legend(frameon=True, ncol=2)
-        add_decorations(ax, self.luminosity, self.energy)
-
-
-class StackHistWithSignificance(StackHist):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def do_draw(self, axes, bin_significance=True, low_cut_significance=False, high_cut_significance=False):
-        bottom_box, _, top_box = axes.get_position().splity(0.28, 0.30)
-        axes.set_position(top_box)
-        super().do_draw(axes)
-        axes.set_xticks([])
-        rhs_color = '#cc6600'
-
-        bottom = axes.get_figure().add_axes(bottom_box)
-        bottom_rhs = bottom.twinx()
-        bgs = [0]*self.n_bins
-        for (_, _, bins, _) in self.backgrounds:
-            for i, (left, right, value) in enumerate(bins):
-                bgs[i] += value
-
-        sigs = [0]*self.n_bins
-        if bin_significance:
-            xs = []
-            for i, (left, right, value) in enumerate(self.signal[2]):
-                sigs[i] += value
-                xs.append(left)
-            xs, ys = zip(*[(x, sig/(sig+bg)) for x, sig, bg in zip(xs, sigs, bgs) if (sig+bg) > 0])
-            bottom.plot(xs, ys, '.k')
-
-        if high_cut_significance:
-            # s/(s+b) for events passing a minimum cut requirement
-            min_bg = [sum(bgs[i:]) for i in range(self.n_bins)]
-            min_sig = [sum(sigs[i:]) for i in range(self.n_bins)]
-            min_xs, min_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, min_sig, min_bg)
-                                   if (sig+bg) > 0])
-            bottom_rhs.plot(min_xs, min_ys, '->', color=rhs_color)
-
-        if low_cut_significance:
-            # s/(s+b) for events passing a maximum cut requirement
-            max_bg = [sum(bgs[:i]) for i in range(self.n_bins)]
-            max_sig = [sum(sigs[:i]) for i in range(self.n_bins)]
-            max_xs, max_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, max_sig, max_bg)
-                                   if (sig+bg) > 0])
-            bottom_rhs.plot(max_xs, max_ys, '-<', color=rhs_color)
-
-        bottom.set_ylabel(r'$S/(S+B)$')
-        bottom.set_xlim(axes.get_xlim())
-        bottom.set_ylim((0, 1.1))
-        if low_cut_significance or high_cut_significance:
-            bottom_rhs.set_ylabel(r'$S/\sqrt{S+B}$')
-            bottom_rhs.yaxis.label.set_color(rhs_color)
-            bottom_rhs.tick_params(axis='y', colors=rhs_color, size=4, width=1.5)
-        # bottom.grid()
-
-
-if __name__ == '__main__':
-    import matplotlib.pyplot as plt
-    from utils import ResultSet
-
-    rs_TTZ = ResultSet("TTZ",  "../data/TTZToLLNuNu_treeProducerSusyMultilepton_tree.root")
-    rs_TTW = ResultSet("TTW",  "../data/TTWToLNu_treeProducerSusyMultilepton_tree.root")
-    rs_TTH = ResultSet("TTH", "../data/TTHnobb_mWCutfix_ext1_treeProducerSusyMultilepton_tree.root")
-    rs_TTTT = ResultSet("TTTT", "../data/TTTT_ext_treeProducerSusyMultilepton_tree.root")
-
-    sh = StackHist('B-Jet Multiplicity')
-    sh.add_mc_background(rs_TTZ.b_jet_count, 'TTZ', lumi=40)
-    sh.add_mc_background(rs_TTW.b_jet_count, 'TTW', lumi=40)
-    sh.add_mc_background(rs_TTH.b_jet_count, 'TTH', lumi=40)
-    sh.set_mc_signal(rs_TTTT.b_jet_count, 'TTTT', lumi=40, scale=10)
-
-    sh.luminosity = 40
-    sh.energy = 13
-    sh.xlabel = 'B-Jet Count'
-    sh.ylabel = r'\# Events'
-    sh.xlim = (-.5, 9.5)
-    sh.signal_stack = False
-
-    fig = plt.figure()
-    sh.draw(fig.gca())
-    plt.show()
-    # sh.add_data(rs_TTZ.b_jet_count, 'TTZ')
+def hist_plot_stack(hists: list, labels: list = None):
+    """
+    Creates a stacked histogram in the current axes.
+
+    :param hists: list of histogram
+    :param labels:
+    :return:
+    """
+    if len(hists) == 0:
+        return
+
+    if len(set([len(hist[0]) for hist in hists])) != 1:
+        raise ValueError("all histograms must have the same number of bins")
+    if labels is None:
+        labels = [None for _ in hists]
+    if len(labels) != len(hists):
+        raise ValueError("Label mismatch")
+
+    bottoms = [0 for _ in hists[0][0]]
+
+    for hist, label in zip(hists, labels):
+        centers = []
+        widths = []
+        heights = []
+        for left, right, content in zip(hist[2][:-1], hist[2][1:], hist[0]):
+            centers.append((right+left)/2)
+            widths.append(right - left)
+            heights.append(content)
+
+        plt.bar(centers, heights, widths, bottoms, label=label)
+        for i, content in enumerate(hist[0]):
+            bottoms[i] += content

+ 6 - 1
setup.py

@@ -1,12 +1,17 @@
 from setuptools import setup
 
 with open('requirements.txt') as req:
-    install_requires = req.readlines()
+    install_requires = [l.strip() for l in req.readlines()]
+
+print(install_requires)
 
 setup(
     name='filval',
     version='0.1',
     install_requires=install_requires,
+    dependency_links=[
+        "git+ssh://git@github.com/cfangmeier/latexipy.git#egg=latexipy"
+    ],
     packages=['filval'],
     scripts=['scripts/merge.py',
              'scripts/process_parallel.py'