Browse Source

Moves more plotting code to be mpl based

Caleb Fangmeier 7 years ago
parent
commit
6206f9a6cb
2 changed files with 67 additions and 191 deletions
  1. 42 8
      plotter.py
  2. 25 183
      utils.py

+ 42 - 8
plotter.py

@@ -53,6 +53,32 @@ def histogram(th1, include_errors=False):
     return values, edges
 
 
+def plot_histogram(h1, *args, axes=None, norm=None, include_errors=False, **kwargs):
+    """ Plots a 1D ROOT histogram object using matplotlib """
+    import numpy as np
+    bins, edges = histogram(h1, include_errors=include_errors)
+
+    if norm is not None:
+        scale = norm/np.sum(bins)
+        bins = [(bin*scale, err*scale) for (bin, err) in bins]
+    bins, errs = list(zip(*bins))
+
+    left, right = np.array(edges[:-1]), np.array(edges[1:])
+    X = np.array([left, right]).T.flatten()
+    Y = np.array([bins, bins]).T.flatten()
+    if axes is None:
+        import matplotlib.pyplot as plt
+        axes = plt.gca()
+    axes.set_xlabel(kwargs.pop('xlabel', ''))
+    axes.set_ylabel(kwargs.pop('ylabel', ''))
+    axes.set_title(kwargs.pop('title', ''))
+    axes.plot(X, Y, *args, linewidth=1, **kwargs)
+    if include_errors:
+        axes.errorbar(0.5*(left+right), bins, yerr=errs,
+                      color='k', marker=None, linestyle='None',
+                      barsabove=True, elinewidth=.7, capsize=1)
+
+
 def histogram2d(th2, include_errors=False):
     """ converts TH2 object to something amenable to
     plotting w/ matplotlab's pcolormesh
@@ -68,14 +94,22 @@ def histogram2d(th2, include_errors=False):
             xs[j][i] = th2.GetXaxis().GetBinLowEdge(i+1)
             ys[j][i] = th2.GetYaxis().GetBinLowEdge(j+1)
             zs[j][i] = th2.GetBinContent(i+1, j+1)
-    # just_xs = np.array([th2.GetXaxes().GetBinLowEdge(i) for i in range(1,nbins_x)] +
-    #                     [th2.GetXaxes().GetBinHighEdge(nbins_x-1)])
-    # just_ys = np.array([th2.GetYaxes().GetBinLowEdge(i) for i in range(1,nbins_y)] +
-    #                     [th2.GetYaxes().GetBinHighEdge(nbins_y-1)])
 
     return xs, ys, zs
 
 
+def plot_histogram2d(th2, *args, axes=None, **kwargs):
+    """ Plots a 2D ROOT histogram object using matplotlib """
+    if axes is None:
+        import matplotlib.pyplot as plt
+        axes = plt.gca()
+    axes.set_xlabel(kwargs.pop('xlabel', ''))
+    axes.set_ylabel(kwargs.pop('ylabel', ''))
+    axes.set_title(kwargs.pop('title', ''))
+    axes.pcolormesh(*histogram2d(th2))
+    # axes.colorbar() TODO: Re-enable this
+
+
 class StackHist:
 
     def __init__(self, title=""):
@@ -230,7 +264,7 @@ class StackHistWithSignificance(StackHist):
             for i, (left, right, value) in enumerate(self.signal[2]):
                 sigs[i] += value
                 xs.append(left)
-            xs, ys = zip(*[(x, sig/(sig+bg)) for x, sig, bg in zip(xs, sigs, bgs) if (sig+bg)>0])
+            xs, ys = zip(*[(x, sig/(sig+bg)) for x, sig, bg in zip(xs, sigs, bgs) if (sig+bg) > 0])
             bottom.plot(xs, ys, '.k')
 
         if high_cut_significance:
@@ -263,9 +297,9 @@ if __name__ == '__main__':
     import matplotlib.pyplot as plt
     from utils import ResultSet
 
-    rs_TTZ =  ResultSet("TTZ",  "../data/TTZToLLNuNu_treeProducerSusyMultilepton_tree.root")
-    rs_TTW  = ResultSet("TTW",  "../data/TTWToLNu_treeProducerSusyMultilepton_tree.root")
-    rs_TTH  = ResultSet("TTH", "../data/TTHnobb_mWCutfix_ext1_treeProducerSusyMultilepton_tree.root")
+    rs_TTZ = ResultSet("TTZ",  "../data/TTZToLLNuNu_treeProducerSusyMultilepton_tree.root")
+    rs_TTW = ResultSet("TTW",  "../data/TTWToLNu_treeProducerSusyMultilepton_tree.root")
+    rs_TTH = ResultSet("TTH", "../data/TTHnobb_mWCutfix_ext1_treeProducerSusyMultilepton_tree.root")
     rs_TTTT = ResultSet("TTTT", "../data/TTTT_ext_treeProducerSusyMultilepton_tree.root")
 
     sh = StackHist('B-Jet Multiplicity')

+ 25 - 183
utils.py

@@ -1,14 +1,12 @@
 
 import io
 import sys
-import itertools as it
 from os.path import dirname, join, abspath, normpath
-from math import ceil, floor, sqrt
-from collections import deque
-from IPython.display import Image
+from math import floor, ceil, sqrt
 
 import ROOT
 from graph_vals import parse
+from plotter import plot_histogram, plot_histogram2d
 
 PRJ_PATH = normpath(join(dirname(abspath(__file__)), "../"))
 EXE_PATH = join(PRJ_PATH, "build/main")
@@ -36,21 +34,6 @@ PDG = {1:   'd',   -1:  'd̄',
        25:  'H',
        }
 
-SINGLE_PLOT_SIZE = (600, 450)
-MAX_WIDTH = 1800
-
-SCALE = .75
-CAN_SIZE_DEF = (int(1600*SCALE), int(1200*SCALE))
-CANVAS = ROOT.TCanvas("c1", "", *CAN_SIZE_DEF)
-ROOT.gStyle.SetPalette(112)  # set the "virdidis" color map
-
-VALUES = {}
-
-
-def clear():
-    CANVAS.Clear()
-    CANVAS.SetCanvasSize(*CAN_SIZE_DEF)
-
 
 def get_color(val, max_val, min_val=0):
     val = (val-min_val)/(max_val-min_val)
@@ -78,42 +61,16 @@ def show_function(dataset, fname):
 
 
 def show_value(dataset, container):
+    from IPython.display import Image
     if type(container) != str:
         container = container.GetName().split(':')[1]
-    g, functions = parse(VALUES[container], container)
+    g, functions = parse(dataset.values[container], container)
     try:
         return Image(g.create_gif()), show_function(dataset, functions)
     except Exception as e:
         print(e)
         print(g.to_string())
 
-
-class OutputCapture:
-    def __init__(self):
-        self.my_stdout = io.StringIO()
-        self.my_stderr = io.StringIO()
-
-    def get_stdout(self):
-        self.my_stdout.seek(0)
-        return self.my_stdout.read()
-
-    def get_stderr(self):
-        self.my_stderr.seek(0)
-        return self.my_stderr.read()
-
-    def __enter__(self):
-        self.stdout = sys.stdout
-        self.stderr = sys.stderr
-        sys.stdout = self.my_stdout
-        sys.stderr = self.my_stderr
-
-    def __exit__(self, *args):
-        sys.stdout = self.stdout
-        sys.stderr = self.stderr
-        self.stdout = None
-        self.stderr = None
-
-
 def normalize_columns(hist2d):
     normHist = ROOT.TH2D(hist2d)
     cols, rows = hist2d.GetNbinsX(), hist2d.GetNbinsY()
@@ -134,8 +91,6 @@ class ResultSet:
     def __init__(self, sample_name, input_filename):
         self.sample_name = sample_name
         self.input_filename = input_filename
-        # self.output_filename = self.input_filename.replace(".root", "_result.root")
-        # self.conditional_recompute()
         self.load_objects()
 
         ResultSet.add_collection(self)
@@ -144,7 +99,7 @@ class ResultSet:
         file = ROOT.TFile.Open(self.input_filename)
         l = file.GetListOfKeys()
         self.map = {}
-        VALUES.update(dict(file.Get("_value_lookup")))
+        self.values = dict(file.Get("_value_lookup"))
         for i in range(l.GetSize()):
             name = l.At(i).GetName()
             new_name = ":".join((self.sample_name, name))
@@ -172,146 +127,33 @@ class ResultSet:
 
     @classmethod
     def calc_shape(cls, n_plots):
-        if n_plots*SINGLE_PLOT_SIZE[0] > MAX_WIDTH:
-            shape_x = MAX_WIDTH//SINGLE_PLOT_SIZE[0]
-            shape_y = ceil(n_plots / shape_x)
-            return (shape_x, shape_y)
-        else:
-            return (n_plots, 1)
-
-    def draw(self, shape=None):
-        objs = [obj for obj in self.map.values() if hasattr(obj, "Draw")]
-        if shape is None:
-            n_plots = len(objs)
-            shape = self.calc_shape(n_plots)
-        CANVAS.Clear()
-        CANVAS.SetCanvasSize(shape[0]*SINGLE_PLOT_SIZE[0], shape[1]*SINGLE_PLOT_SIZE[1])
-        CANVAS.Divide(*shape)
-        i = 1
-        for hist in objs:
-            CANVAS.cd(i)
-            try:
-                hist.SetStats(True)
-            except AttributeError:
-                pass
-            if type(hist) in (ROOT.TH1I, ROOT.TH1F, ROOT.TH1D):
-                hist.SetMinimum(0)
-            hist.Draw(self.get_draw_option(hist))
-            i += 1
-        CANVAS.Draw()
-
-    @staticmethod
-    def get_draw_option(obj):
-        obj_type = type(obj)
-        if obj_type in (ROOT.TH1F, ROOT.TH1I, ROOT.TH1D):
-            return ""
-        elif obj_type in (ROOT.TH2F, ROOT.TH2I, ROOT.TH2D):
-            return "COLZ"
-        elif obj_type in (ROOT.TGraph,):
-            return "A*"
+        if n_plots > 3:
+            return ceil(n_plots / 3), 3
         else:
-            return None
+            return 1, n_plots
+
+    def draw(self, figure=None, shape=None):
+        objs = [(name, obj) for name, obj in self.map.items() if isinstance(obj, ROOT.TH1)]
+        shape = self.calc_shape(len(objs))
+        if figure is None:
+            import matplotlib.pyplot as plt
+            figure = plt.gcf() if plt.gcf() is not None else plt.figure()
+        figure.clear()
+        for i, (name, obj) in enumerate(objs):
+            axes = figure.add_subplot(*shape, i+1)
+            if isinstance(obj, ROOT.TH2):
+                plot_histogram2d(obj, title=obj.GetTitle(), axes=axes)
+            else:
+                plot_histogram(obj, title=obj.GetTitle(), axes=axes)
+        figure.tight_layout()
 
     @classmethod
     def get_hist_set(cls, attrname):
-        labels, hists = zip(*[(sample_name, getattr(h, attrname))
-                              for sample_name, h in cls.collections.items()])
-        return labels, hists
+        return [(sample_name, getattr(h, attrname))
+                for sample_name, h in cls.collections.items()]
 
     @classmethod
     def add_collection(cls, hc):
         if not hasattr(cls, "collections"):
             cls.collections = {}
         cls.collections[hc.sample_name] = hc
-
-    @classmethod
-    def stack_hist(cls,
-                   hist_name,
-                   title="",
-                   enable_fill=False,
-                   normalize_to=0,
-                   draw=False,
-                   draw_canvas=True,
-                   draw_option="",
-                   make_legend=False,
-                   _stacks={}):
-        labels, hists = cls.get_hist_set(hist_name)
-        if draw_canvas:
-            CANVAS.Clear()
-            CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0],
-                                 SINGLE_PLOT_SIZE[1])
-
-        colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen, ROOT.kYellow])
-        stack = ROOT.THStack(hist_name+"_stack", title)
-        if labels is None:
-            labels = [hist.GetName() for hist in hists]
-        if type(normalize_to) in (int, float):
-            normalize_to = [normalize_to]*len(hists)
-        ens = enumerate(zip(hists, labels, colors, normalize_to))
-        for i, (hist, label, color, norm) in ens:
-            hist_copy = hist
-            hist_copy = hist.Clone(hist.GetName()+"_clone" + draw_option)
-            hist_copy.SetTitle(label)
-            if enable_fill:
-                hist_copy.SetFillColorAlpha(color, 0.75)
-                hist_copy.SetLineColorAlpha(color, 0.75)
-            if norm:
-                integral = hist_copy.Integral()
-                hist_copy.Scale(norm/integral, "nosw2")
-                hist_copy.SetStats(True)
-            stack.Add(hist_copy)
-        if draw:
-            stack.Draw(draw_option)
-            if make_legend:
-                CANVAS.BuildLegend(0.75, 0.75, 0.95, 0.95, "")
-        # prevent stack from getting garbage collected
-        _stacks[stack.GetName()] = stack
-        if draw_canvas:
-            CANVAS.Draw()
-        return stack
-
-    @classmethod
-    def stack_hist_array(cls,
-                         hist_names,
-                         titles,
-                         shape=None, **kwargs):
-        n_hist = len(hist_names)
-        if shape is None:
-            if n_hist <= 4:
-                shape = (1, n_hist)
-            else:
-                shape = (ceil(sqrt(n_hist)),)*2
-        CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0]*shape[0],
-                             SINGLE_PLOT_SIZE[1]*shape[1])
-        CANVAS.Divide(*shape)
-        for i, hist_name, title in zip(range(1, n_hist+1), hist_names, titles):
-            CANVAS.cd(i)
-            cls.stack_hist(hist_name, title=title, draw=True,
-                           draw_canvas=False, **kwargs)
-        CANVAS.cd(n_hist).BuildLegend(0.75, 0.75, 0.95, 0.95, "")
-
-    pts = deque([], 50)
-
-    @classmethod
-    def hist_array_single(cls,
-                          hist_name,
-                          title=None,
-                          **kwargs):
-        n_hist = len(cls.collections)
-        shape = cls.calc_shape(n_hist)
-        CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0]*shape[0],
-                             SINGLE_PLOT_SIZE[1]*shape[1])
-        CANVAS.Divide(*shape)
-        labels, hists = cls.get_hist_set(hist_name)
-
-        def pave_loc():
-            hist.Get
-        for i, label, hist in zip(range(1, n_hist+1), labels, hists):
-            CANVAS.cd(i)
-            hist.SetStats(False)
-            hist.Draw(cls.get_draw_option(hist))
-
-            pt = ROOT.TPaveText(0.70, 0.87, 0.85, 0.95, "NDC")
-            pt.AddText("Dataset: "+label)
-            pt.Draw()
-            cls.pts.append(pt)