Browse Source

updates to plotter to save pngs in memory

Caleb Fangmeier 6 years ago
parent
commit
99164f00c9
3 changed files with 57 additions and 46 deletions
  1. 2 0
      .gitignore
  2. 51 45
      filval/plotter.py
  3. 4 1
      filval/result_set.py

+ 2 - 0
.gitignore

@@ -1,5 +1,7 @@
 figures/
 env/
+build/
+dist/
 
 __pycache__/
 .ipynb_checkpoints/

+ 51 - 45
filval/plotter.py

@@ -1,6 +1,8 @@
 #!/usr/bin/env python3
 
 from collections import defaultdict
+from io import BytesIO
+from base64 import b64encode
 import numpy as np
 import matplotlib.pyplot as plt
 from markdown import Markdown
@@ -11,7 +13,7 @@ from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
 __all__ = ['Plot',
            'decl_plot',
            'grid_plot',
-           'save_plots',
+           'render_plots',
            'hist_plot',
            'hist2d_plot']
 
@@ -149,15 +151,25 @@ def grid_plot(subplots):
     return argdicts, docs
 
 
-def save_plots(plots, exts=['png'], scale=1.0):
+def render_plots(plots, exts=['png'], scale=1.0, to_disk=True):
     for plot in plots:
         print(f'Building plot {plot.name}')
-        with lp.figure(plot.name, directory='output/figures',
-                       exts=exts,
-                       size=(scale*10, scale*10)):
-            argdicts, docs = grid_plot(plot.subplots)
-            plot.argdicts = argdicts
-            plot.docs = docs
+        plot.data = None
+        if to_disk:
+            with lp.figure(plot.name, directory='output/figures',
+                           exts=exts,
+                           size=(scale*10, scale*10)):
+                argdicts, docs = grid_plot(plot.subplots)
+        else:
+            out = BytesIO()
+            with lp.mem_figure(out,
+                               ext=exts[0],
+                               size=(scale*10, scale*10)):
+                argdicts, docs = grid_plot(plot.subplots)
+            out.seek(0)
+            plot.data = b64encode(out.read()).decode()
+        plot.argdicts = argdicts
+        plot.docs = docs
 
 
 def add_decorations(axes, luminosity, energy):
@@ -253,10 +265,9 @@ def hist2d_plot(h, *args, axes=None, **kwargs):
     # axes.colorbar() TODO: Re-enable this
 
 
-class StackHist:
-
+def hist_plot_stack(hists):
     def __init__(self, title=""):
-        raise NotImplementedError("need to fix to not use to_bin_list")
+        # raise NotImplementedError("need to fix to not use to_bin_list")
         self.title = title
         self.xlabel = ""
         self.ylabel = ""
@@ -269,17 +280,17 @@ class StackHist:
         self.signal_stack = True
         self.data = None
 
-    def add_mc_background(self, th1, label, lumi=None, plot_color=''):
-        self.backgrounds.append((label, lumi, hist(th1), plot_color))
+    def add_mc_background(self, h, label, lumi=None, plot_color=''):
+        self.backgrounds.append((label, lumi, h, plot_color))
 
-    def set_mc_signal(self, th1, label, lumi=None, stack=True, scale=1, plot_color=''):
-        self.signal = (label, lumi, hist(th1), plot_color)
+    def set_mc_signal(self, h, label, lumi=None, stack=True, scale=1, plot_color=''):
+        self.signal = (label, lumi, h, plot_color)
         self.signal_stack = stack
         self.signal_scale = scale
 
-    def set_data(self, th1, lumi=None, plot_color=''):
-        self.data = ('data', lumi, hist(th1), plot_color)
-        self.luminosity = lumi
+    # def set_data(self, th1, lumi=None, plot_color=''):
+    #     self.data = ('data', lumi, hist(th1), plot_color)
+    #     self.luminosity = lumi
 
     def _verify_binning_match(self):
         bins_count = [len(bins) for _, _, bins, _ in self.backgrounds]
@@ -302,22 +313,24 @@ class StackHist:
         plt.close(fig)
         plt.ion()
 
-    def do_draw(self, axes):
-        self.axeses = [axes]
+    def draw(self, ax=None):
+        if ax is None:
+            ax = plt.gca()
+        self.axeses = [ax]
         self._verify_binning_match()
         bottoms = [0]*self.n_bins
 
         if self.logx:
-            axes.set_xscale('log')
+            ax.set_xscale('log')
         if self.logy:
-            axes.set_yscale('log')
+            ax.set_yscale('log')
 
-        def draw_bar(label, lumi, bins, plot_color, scale=1, stack=True, **kwargs):
+        def draw_bar(label, lumi, hist, plot_color, scale=1, stack=True, **kwargs):
             if stack:
                 lefts = []
                 widths = []
                 heights = []
-                for left, right, content in bins:
+                for left, right, content in zip(hist[2][:-1], hist[2][1:], hist[0])
                     lefts.append(left)
                     widths.append(right-left)
                     if lumi is not None:
@@ -325,13 +338,14 @@ class StackHist:
                     content *= scale
                     heights.append(content)
 
-                axes.bar(lefts, heights, widths, bottoms, label=label, color=plot_color, **kwargs)
-                for i, (_, _, content) in enumerate(bins):
+                ax.bar(lefts, heights, widths, bottoms, label=label, color=plot_color, **kwargs)
+                for i, content in enumerate(hist[0]):
                     if lumi is not None:
                         content *= self.luminosity/lumi
                     content *= scale
                     bottoms[i] += content
             else:
+                raise NotImplementedError('only supports stacks')
                 xs = [bins[0][0] - (bins[0][1]-bins[0][0])/2]
                 ys = [0]
                 for left, right, content in bins:
@@ -345,13 +359,13 @@ class StackHist:
                     ys.append(content)
                 xs.append(bins[-1][0] + (bins[-1][1]-bins[-1][0])/2)
                 ys.append(0)
-                axes.plot(xs, ys, label=label, color=plot_color, **kwargs)
+                ax.plot(xs, ys, label=label, color=plot_color, **kwargs)
 
         if self.signal is not None and self.signal_stack:
-            label, lumi, bins, plot_color = self.signal
+            label, lumi, hist, plot_color = self.signal
             if self.signal_scale != 1:
                 label = r"{}$\times{:d}$".format(label, self.signal_scale)
-            draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, hatch='/')
+            draw_bar(label, lumi, hist, plot_color, scale=self.signal_scale, hatch='/')
 
         for background in self.backgrounds:
             draw_bar(*background)
@@ -363,24 +377,16 @@ class StackHist:
                 label = r"{}$\times{:d}$".format(label, self.signal_scale)
             draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, stack=False)
 
-        axes.set_title(self.title)
-        axes.set_xlabel(self.xlabel)
-        axes.set_ylabel(self.ylabel)
-        axes.set_xlim(*self.xlim)
-        # axes.set_ylim(*self.ylim)
+        ax.set_title(self.title)
+        ax.set_xlabel(self.xlabel)
+        ax.set_ylabel(self.ylabel)
+        ax.set_xlim(*self.xlim)
         if self.logy:
-            axes.set_ylim(None, np.exp(np.log(max(bottoms))*1.4))
+            ax.set_ylim(None, np.exp(np.log(max(bottoms))*1.4))
         else:
-            axes.set_ylim(None, max(bottoms)*1.2)
-        axes.legend(frameon=True, ncol=2)
-        add_decorations(axes, self.luminosity, self.energy)
-
-    def draw(self, axes, save=False, filename=None, **kwargs):
-        self.do_draw(axes, **kwargs)
-        if save:
-            if filename is None:
-                filename = "".join(c for c in self.title if c.isalnum() or c in (' ._+-'))+".png"
-            self.save(filename, **kwargs)
+            ax.set_ylim(None, max(bottoms)*1.2)
+        ax.legend(frameon=True, ncol=2)
+        add_decorations(ax, self.luminosity, self.energy)
 
 
 class StackHistWithSignificance(StackHist):

+ 4 - 1
filval/result_set.py

@@ -17,7 +17,10 @@ class ResultSet:
         file = ROOT.TFile.Open(self.input_filename)
         l = file.GetListOfKeys()
         self.map = {}
-        self.values = dict(file.Get("_value_lookup"))
+        try:
+            self.values = dict(file.Get("_value_lookup"))
+        except Exception:
+            self.values = {}
         for i in range(l.GetSize()):
             name = l.At(i).GetName()
             new_name = ":".join((self.sample_name, name))