Explorar o código

Adds percent contours function for 2d histograms, simplifies hist_plot

Caleb Fangmeier %!s(int64=6) %!d(string=hai) anos
pai
achega
0cf8a45d13
Modificáronse 3 ficheiros con 60 adicións e 41 borrados
  1. 32 5
      filval/histogram.py
  2. 24 35
      filval/plotting.py
  3. 4 1
      filval/templates/dashboard.j2

+ 32 - 5
filval/histogram.py

@@ -130,10 +130,10 @@ def hist2d(th2, rescale_x=1.0, rescale_y=1.0, rescale_z=1.0):
     """
     nbins_x = th2.GetNbinsX()
     nbins_y = th2.GetNbinsY()
-    xs = np.zeros((nbins_y+1, nbins_x+1), np.float32)
-    ys = np.zeros((nbins_y+1, nbins_x+1), np.float32)
-    values = np.zeros((nbins_y, nbins_x), np.float32)
-    errors = np.zeros((nbins_y, nbins_x), np.float32)
+    xs = np.zeros((nbins_y+1, nbins_x+1), dtype=np.float32)
+    ys = np.zeros((nbins_y+1, nbins_x+1), dtype=np.float32)
+    values = np.zeros((nbins_y, nbins_x), dtype=np.float32)
+    errors = np.zeros((nbins_y, nbins_x), dtype=np.float32)
     for i in range(nbins_x):
         for j in range(nbins_y):
             xs[j][i] = th2.GetXaxis().GetBinLowEdge(i+1)
@@ -163,7 +163,8 @@ def hist2d_norm(h, norm=1, axis=None):
     :return: The normalized histogram
     """
     values, errors, xs, ys = h
-    with np.errstate(divide='ignore'):
+    with np.warnings.catch_warnings():
+        np.warnings.filterwarnings('ignore', 'invalid value encountered in true_divide')
         scale_values = norm / np.sum(values, axis=axis)
         scale_values[scale_values == np.inf] = 1
         scale_values[scale_values == -np.inf] = 1
@@ -174,3 +175,29 @@ def hist2d_norm(h, norm=1, axis=None):
     return values, errors, xs.copy(), ys.copy()
 
 
+def hist2d_percent_contour(h, percent: float, axis: str):
+    values, _, xs, ys = h
+
+    try:
+        axis = axis.lower()
+        axis_idx = {'x': 1, 'y': 0}[axis]
+    except KeyError:
+        raise ValueError('axis must be \'x\' or \'y\'')
+    if percent < 0 or percent > 1:
+        raise ValueError('percent must be in [0,1]')
+
+    with np.warnings.catch_warnings():
+        np.warnings.filterwarnings('ignore', 'invalid value encountered in true_divide')
+        values = values / np.sum(values, axis=axis_idx, keepdims=True)
+        np.nan_to_num(values, copy=False)
+    values = np.cumsum(values, axis=axis_idx)
+    idxs = np.argmax(values > percent, axis=axis_idx)
+
+    bins_y = (ys[:-1, 0] + ys[1:, 0])/2
+    bins_x = (xs[0, :-1] + xs[0, 1:])/2
+
+    if axis == 'x':
+        return bins_x[idxs], bins_y
+    else:
+        return bins_x, bins_y[idxs]
+

+ 24 - 35
filval/plotting.py

@@ -94,22 +94,34 @@ def decl_plot(fn):
     return f
 
 
-def simple_plot(thx):
+def simple_plot(thx, *args, log=None, **kwargs):
     import ROOT
 
     if isinstance(thx, ROOT.TH2):
         def f(h):
-            hist2d_plot(hist2d(h))
+            hist2d_plot(hist2d(h), *args, **kwargs)
             plt.xlabel(h.GetXaxis().GetTitle())
             plt.ylabel(h.GetYaxis().GetTitle())
+            if log == 'x':
+                plt.semilogx()
+            elif log == 'y':
+                plt.semilogy()
+            elif log == 'xy':
+                plt.loglog()
             return dict(), "", ""
 
         return Plot([[(f, (thx,), {})]], thx.GetName())
     elif isinstance(thx, ROOT.TH1):
         def f(h):
-            hist_plot(hist(h))
+            hist_plot(hist(h), *args, **kwargs)
             plt.xlabel(h.GetXaxis().GetTitle())
             plt.ylabel(h.GetYaxis().GetTitle())
+            if log == 'x':
+                plt.semilogx()
+            elif log == 'y':
+                plt.semilogy()
+            elif log == 'xy':
+                plt.loglog()
             return dict(), "", ""
 
         return Plot([[(f, (thx,), {})]], thx.GetName())
@@ -287,58 +299,35 @@ def add_decorations(axes, luminosity, energy):
               transform=axes.transAxes)
 
 
-def hist_plot(h, *args, norm=None, include_errors=False,
-              log=False, xlim=None, ylim=None, fit=None,
-              grid=False, stats=False, **kwargs):
+def hist_plot(h, *args, include_errors=False, fit=None, stats=False, **kwargs):
     """ Plots a 1D ROOT histogram object using matplotlib """
     from inspect import signature
-    if norm:
-        h = hist_norm(h, norm)
     values, errors, edges = h
 
-    scale = 1. if norm is None else norm / np.sum(values)
-    values = [val * scale for val in values]
-    errors = [val * scale for val in errors]
-
     left, right = np.array(edges[:-1]), np.array(edges[1:])
     x = np.array([left, right]).T.flatten()
     y = np.array([values, values]).T.flatten()
 
-    ax = plt.gca()
-
-    ax.set_xlabel(kwargs.pop('xlabel', ''))
-    ax.set_ylabel(kwargs.pop('ylabel', ''))
     title = kwargs.pop('title', '')
-    if xlim is not None:
-        ax.set_xlim(xlim)
-    if ylim is not None:
-        ax.set_ylim(ylim)
-    # elif not log:
-    #     axes.set_ylim((0, None))
-
-    ax.plot(x, y, *args, linewidth=1, **kwargs)
+
+    plt.plot(x, y, *args, linewidth=1, **kwargs)
     if include_errors:
-        ax.errorbar(hist_bin_centers(h), values, yerr=errors,
-                    color='k', marker=None, linestyle='None',
-                    barsabove=True, elinewidth=.7, capsize=1)
-    if log:
-        ax.set_yscale('log')
+        plt.errorbar(hist_bin_centers(h), values, yerr=errors,
+                     color='k', marker=None, linestyle='None',
+                     barsabove=True, elinewidth=.7, capsize=1)
     if fit:
         f, p0 = fit
         popt, pcov = hist_fit(h, f, p0)
         fit_xs = np.linspace(x[0], x[-1], 100)
         fit_ys = f(fit_xs, *popt)
-        ax.plot(fit_xs, fit_ys, '--g')
+        plt.plot(fit_xs, fit_ys, '--g')
         arglabels = list(signature(f).parameters)[1:]
         label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
                               for label, value in zip(arglabels, popt))
-        ax.text(0.60, 0.95, label_txt, va='top', transform=ax.transAxes,
-                fontsize='medium', family='monospace', usetex=False)
+        plt.text(0.60, 0.95, label_txt, va='top', transform=plt.gca().transAxes,
+                 fontsize='medium', family='monospace', usetex=False)
     if stats:
         _add_stats(h, title)
-    else:
-        ax.set_title(title)
-    ax.grid(grid, color='#E0E0E0')
 
 
 def hist2d_plot(h, txt_format=None, colorbar=False, **kwargs):

+ 4 - 1
filval/templates/dashboard.j2

@@ -10,7 +10,7 @@
 
   <script src="https://tttt.fangmeier.tech/hl/shCore.js"         type="text/javascript"></script>
   <script src="https://tttt.fangmeier.tech/hl/shBrushPython.js" type="text/javascript"></script>
-    <script src="https://tttt.fangmeier.tech/hl/shBrushPlain.js" type="text/javascript"></script>
+  <script src="https://tttt.fangmeier.tech/hl/shBrushPlain.js" type="text/javascript"></script>
   <link href="https://tttt.fangmeier.tech/hl/shCore.css"          rel="stylesheet" type="text/css" />
   <link href="https://tttt.fangmeier.tech/hl/shThemeDefault.css"  rel="stylesheet" type="text/css" />
   <script src="https://tttt.fangmeier.tech/hl/shAutoloader.js" type="text/javascript"></script>
@@ -26,6 +26,9 @@ MathJax.Hub.Config({
 </head>
 <body>
 <div class="container-fluid">
+    <div class="row">
+        <a href="./" class="button"> Go Up <span class="glyphicon glyphicon-circle-arrow-left"></span> </a>
+    </div>
 {% for r, plot_row in enumerate(plots) %}
   <div class="row">
   {% for c, plot in enumerate(plot_row) %}