Explorar o código

Adds optional stats to histogram plotter

Caleb Fangmeier %!s(int64=6) %!d(string=hai) anos
pai
achega
d010e5c54f
Modificáronse 2 ficheiros con 89 adicións e 5 borrados
  1. 53 2
      filval/histogram_utils.py
  2. 36 3
      filval/plotter.py

+ 53 - 2
filval/histogram_utils.py

@@ -76,12 +76,63 @@ def hist_slice(hist, range_):
             np.concatenate([edges[:-1][slice_], [edges[last]]]))
 
 
-def hist_normalize(h, norm):
-    values, errors, edges = h
+def hist_add(hist1, hist2, w1=1.0, w2=1.0):
+    v1, e1, *lim1 = hist1
+    v2, e2, *lim2 = hist2
+    # print(hist1)
+    if v1.shape != v2.shape:
+        raise ValueError(f'Mismatched histograms to add {v1.shape} != {v2.shape}')
+    # print(lim1)
+    # print(lim2)
+    # nlims_equal = (lim1 != lim2).any()
+    # print(nlims_equal)
+    # if nlims_equal:
+    #     raise ValueError(f'Histograms have different limits!')
+    if len(v1.shape) == 1:  # 1D histograms
+        return ((v1+v2), np.sqrt(e1*e1 + e2*e2), *lim1)
+
+
+def hist_integral(hist, times_bin_width=True):
+    values, errors, edges = hist
+    if times_bin_width:
+        bin_widths = [abs(x2 - x1) for x1, x2 in zip(edges[:-1], edges[1:])]
+        return sum(val*width for val, width in zip(values, bin_widths))
+    else:
+        return sum(values)
+
+
+def hist_normalize(hist, norm):
+    values, errors, edges = hist
     scale = norm/np.sum(values)
     return values*scale, errors*scale, edges
 
 
+def hist_mean(hist):
+    xs = hist_bin_centers(hist)
+    ys, _, _ = hist
+    return sum(x*y for x, y in zip(xs, ys)) / sum(ys)
+
+
+def hist_var(hist):
+    xs = hist_bin_centers(hist)
+    ys, _, _ = hist
+    mean = sum(x*y for x, y in zip(xs, ys)) / sum(ys)
+    mean2 = sum((x**2)*y for x, y in zip(xs, ys)) / sum(ys)
+    return mean2 - mean**2
+
+
+def hist_std(hist):
+    return np.sqrt(hist_var(hist))
+
+
+def hist_stats(hist):
+    return {'int': hist_integral(hist),
+            'sum': hist_integral(hist, False),
+            'mean': hist_mean(hist),
+            'var': hist_var(hist),
+            'std': hist_std(hist)}
+
+
 # def hist_slice2d(h, range_):
 #     values, errors, xs, ys = h
 

+ 36 - 3
filval/plotter.py

@@ -7,7 +7,7 @@ from markdown import Markdown
 import latexipy as lp
 
 from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
-                                    hist_normalize)
+                                    hist_normalize, hist_stats)
 __all__ = ['Plot',
            'decl_plot',
            'grid_plot',
@@ -73,6 +73,35 @@ def decl_plot(fn):
     return f
 
 
+def _add_stats(hist, title=''):
+    fmt = r'''\begin{{eqnarray*}}
+\sum{{x_i}} &=& {sum:5.3f}                  \\
+\sum{{\Delta x_i \cdot x_i}} &=& {int:5.3G} \\
+\mu &=& {mean:5.3G}                         \\
+\sigma^2 &=& {var:5.3G}                     \\
+\sigma &=& {std:5.3G}
+\end{{eqnarray*}}'''
+
+    txt = fmt.format(**hist_stats(hist), title=title)
+    txt = txt.replace('\n', ' ')
+
+    plt.text(0.7, 0.9, txt,
+             bbox={'facecolor': 'white',
+                   'alpha': 0.7,
+                   'boxstyle': 'square,pad=0.8'},
+             transform=plt.gca().transAxes,
+             verticalalignment='top',
+             horizontalalignment='left',
+             size='small')
+    if title:
+        plt.text(0.72, 0.97, title,
+                 bbox={'facecolor': 'white',
+                       'alpha': 0.8},
+                 transform=plt.gca().transAxes,
+                 verticalalignment='top',
+                 horizontalalignment='left')
+
+
 def grid_plot(subplots):
     if any(len(row) != len(subplots[0]) for row in subplots):
         raise ValueError("make_plot requires a rectangular list-of-lists as "
@@ -153,7 +182,7 @@ def add_decorations(axes, luminosity, energy):
 
 def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
               log=False, fig=None, xlim=None, ylim=None, fit=None,
-              grid=False, **kwargs):
+              grid=False, stats=True, **kwargs):
     """ Plots a 1D ROOT histogram object using matplotlib """
     from inspect import signature
     if norm:
@@ -174,7 +203,7 @@ def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
 
     axes.set_xlabel(kwargs.pop('xlabel', ''))
     axes.set_ylabel(kwargs.pop('ylabel', ''))
-    axes.set_title(kwargs.pop('title', ''))
+    title = kwargs.pop('title', '')
     if xlim is not None:
         axes.set_xlim(xlim)
     if ylim is not None:
@@ -200,6 +229,10 @@ def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
                               for label, value in zip(arglabels, popt))
         axes.text(0.60, 0.95, label_txt, va='top', transform=axes.transAxes,
                   fontsize='medium', family='monospace', usetex=False)
+    if stats:
+        _add_stats(h, title)
+    else:
+        axes.set_title(title)
     axes.grid(grid, color='#E0E0E0')