Parcourir la source

Adds text-display to hist2d_plot and renames some modules.

Caleb Fangmeier il y a 7 ans
Parent
commit
5f45b0ec37
2 fichiers modifiés avec 115 ajouts et 69 suppressions
  1. 83 61
      filval/histogram_utils.py
  2. 32 8
      filval/plotter.py

+ 83 - 61
filval/histogram_utils.py

@@ -1,12 +1,12 @@
-'''
-    histogram_utils.py
+"""
+    histogram.py
     The functions in this module use a representation of a histogram that is a
     tuple containing an array of N bin values, an array of N bin errors (symmetric),
     and an array of N+1 bin edges (N lower edges + 1 upper edge).
 
     For 2d histograms, it is similar, but the arrays are two-dimensional and
     there are separate arrays for x-edges and y-edges.
-'''
+"""
 
 import numpy as np
 from scipy.optimize import curve_fit
@@ -36,39 +36,8 @@ def hist_bin_centers(h):
     return (edges[:-1] + edges[1:])/2.0
 
 
-def hist2d(th2, rescale_x=1.0, rescale_y=1.0, rescale_z=1.0):
-    """ Converts TH2 object to something amenable to
-        plotting w/ matplotlib's pcolormesh.
-    """
-    nbins_x = th2.GetNbinsX()
-    nbins_y = th2.GetNbinsY()
-    print(nbins_x, nbins_y)
-    xs = np.zeros((nbins_y+1, nbins_x+1), np.float32)
-    ys = np.zeros((nbins_y+1, nbins_x+1), np.float32)
-    values = np.zeros((nbins_y, nbins_x), np.float32)
-    errors = np.zeros((nbins_y, nbins_x), np.float32)
-    for i in range(nbins_x):
-        for j in range(nbins_y):
-            xs[j][i] = th2.GetXaxis().GetBinLowEdge(i+1)
-            ys[j][i] = th2.GetYaxis().GetBinLowEdge(j+1)
-            values[j][i] = th2.GetBinContent(i+1, j+1)
-            errors[j][i] = th2.GetBinError(i+1, j+1)
-        xs[nbins_y][i] = th2.GetXaxis().GetBinUpEdge(i)
-        ys[nbins_y][i] = th2.GetYaxis().GetBinUpEdge(nbins_y)
-    for j in range(nbins_y+1):
-        xs[j][nbins_x] = th2.GetXaxis().GetBinUpEdge(nbins_x)
-        ys[j][nbins_x] = th2.GetYaxis().GetBinUpEdge(j)
-
-    xs *= rescale_x
-    ys *= rescale_y
-    values *= rescale_z
-    errors *= rescale_z
-
-    return values, errors, xs, ys
-
-
-def hist_slice(hist, range_):
-    values, errors, edges = hist
+def hist_slice(h, range_):
+    values, errors, edges = h
     lim_low, lim_high = range_
     slice_ = np.logical_and(edges[:-1] > lim_low, edges[1:] < lim_high)
     last = len(slice_) - np.argmax(slice_[::-1])
@@ -77,55 +46,56 @@ def hist_slice(hist, range_):
             np.concatenate([edges[:-1][slice_], [edges[last]]]))
 
 
-def hist_add(*hists):
-    if len(hists) == 0:
+def hist_add(*hs):
+    if len(hs) == 0:
         return np.zeros(0)
-    vals, errs, edges = zip(*hists)
+    vals, errs, edges = zip(*hs)
     return np.sum(vals, axis=0), np.sqrt(np.sum([err*err for err in errs], axis=0)), edges[0]
 
 
-def hist_integral(hist, times_bin_width=True):
-    values, errors, edges = hist
+def hist_integral(h, times_bin_width=True):
+    values, errors, edges = h
     if times_bin_width:
         bin_widths = [abs(x2 - x1) for x1, x2 in zip(edges[:-1], edges[1:])]
         return sum(val*width for val, width in zip(values, bin_widths))
     else:
         return sum(values)
 
-def hist_scale(hist, scale):
-    values, errors, edges = hist
+
+def hist_scale(h, scale):
+    values, errors, edges = h
     return values*scale, errors*scale, edges
 
-def hist_normalize(hist, norm = 1):
-    scale = norm/np.sum(hist[0])
-    return hist_scale(hist, scale)
 
+def hist_norm(h, norm=1):
+    scale = norm/np.sum(h[0])
+    return hist_scale(h, scale)
 
 
-def hist_mean(hist):
-    xs = hist_bin_centers(hist)
-    ys, _, _ = hist
+def hist_mean(h):
+    xs = hist_bin_centers(h)
+    ys, _, _ = h
     return sum(x*y for x, y in zip(xs, ys)) / sum(ys)
 
 
-def hist_var(hist):
-    xs = hist_bin_centers(hist)
-    ys, _, _ = hist
+def hist_var(h):
+    xs = hist_bin_centers(h)
+    ys, _, _ = h
     mean = sum(x*y for x, y in zip(xs, ys)) / sum(ys)
     mean2 = sum((x**2)*y for x, y in zip(xs, ys)) / sum(ys)
     return mean2 - mean**2
 
 
-def hist_std(hist):
-    return np.sqrt(hist_var(hist))
+def hist_std(h):
+    return np.sqrt(hist_var(h))
 
 
-def hist_stats(hist):
-    return {'int': hist_integral(hist),
-            'sum': hist_integral(hist, False),
-            'mean': hist_mean(hist),
-            'var': hist_var(hist),
-            'std': hist_std(hist)}
+def hist_stats(h):
+    return {'int': hist_integral(h),
+            'sum': hist_integral(h, False),
+            'mean': hist_mean(h),
+            'var': hist_var(h),
+            'std': hist_std(h)}
 
 
 # def hist_slice2d(h, range_):
@@ -150,5 +120,57 @@ def hist_fit(h, f, p0=None):
     return popt, pcov
 
 
-def hist_rebin(hist, range_, nbins):
+def hist_rebin(h, range_, nbins):
     raise NotImplementedError()
+
+
+def hist2d(th2, rescale_x=1.0, rescale_y=1.0, rescale_z=1.0):
+    """ Converts TH2 object to something amenable to
+        plotting w/ matplotlib's pcolormesh.
+    """
+    nbins_x = th2.GetNbinsX()
+    nbins_y = th2.GetNbinsY()
+    xs = np.zeros((nbins_y+1, nbins_x+1), np.float32)
+    ys = np.zeros((nbins_y+1, nbins_x+1), np.float32)
+    values = np.zeros((nbins_y, nbins_x), np.float32)
+    errors = np.zeros((nbins_y, nbins_x), np.float32)
+    for i in range(nbins_x):
+        for j in range(nbins_y):
+            xs[j][i] = th2.GetXaxis().GetBinLowEdge(i+1)
+            ys[j][i] = th2.GetYaxis().GetBinLowEdge(j+1)
+            values[j][i] = th2.GetBinContent(i+1, j+1)
+            errors[j][i] = th2.GetBinError(i+1, j+1)
+        xs[nbins_y][i] = th2.GetXaxis().GetBinUpEdge(i)
+        ys[nbins_y][i] = th2.GetYaxis().GetBinUpEdge(nbins_y)
+    for j in range(nbins_y+1):
+        xs[j][nbins_x] = th2.GetXaxis().GetBinUpEdge(nbins_x)
+        ys[j][nbins_x] = th2.GetYaxis().GetBinUpEdge(j)
+
+    xs *= rescale_x
+    ys *= rescale_y
+    values *= rescale_z
+    errors *= rescale_z
+
+    return values, errors, xs, ys
+
+
+def hist2d_norm(h, norm=1, axis=None):
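+    """ Scales a 2d histogram so that its bin values sum to ``norm`` along ``axis``.
+
+    :param h: 2d histogram as a (values, errors, xs, ys) tuple
+    :param norm: value to normalize the sum along ``axis`` to
+    :param axis: which axis to sum over when normalizing: None is the sum over all bins, 0 is columns, 1 is rows.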
+    :return: The normalized histogram
+    """
+    values, errors, xs, ys = h
+    with np.errstate(divide='ignore'):
+        scale_values = norm / np.sum(values, axis=axis)
+        scale_values[scale_values == np.inf] = 1
+        scale_values[scale_values == -np.inf] = 1
+    if axis == 1:
+        scale_values.shape = (scale_values.shape[0], 1)
+    values = values * scale_values
+    errors = errors * scale_values
+    return values, errors, xs.copy(), ys.copy()
+
+

+ 32 - 8
filval/plotter.py

@@ -1,4 +1,8 @@
-#!/usr/bin/env python3
+"""
+    plotting.py
+    The functions in this module are meant for plotting the histogram objects created via
+    filval.histogram
+"""
 
 from collections import defaultdict
 from itertools import zip_longest
@@ -9,8 +13,8 @@ import matplotlib.pyplot as plt
 from markdown import Markdown
 import latexipy as lp
 
-from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
-                                    hist_normalize, hist_stats)
+from filval.histogram import (hist, hist2d, hist_bin_centers, hist_fit,
+                              hist_norm, hist_stats)
 
 __all__ = ['Plot',
            'decl_plot',
@@ -115,7 +119,8 @@ def generate_dashboard(plots, title, output='dashboard.html', template='dashboar
     if not isdir('output'):
         mkdir('output')
 
-    with open(join('output', output), 'w') as tempout:
+    dashboard_path = join('output', output)
+    with open(dashboard_path, 'w') as tempout:
         templ = env.get_template(template)
         tempout.write(templ.render(
             plots=get_by_n(plots, 3),
@@ -123,6 +128,7 @@ def generate_dashboard(plots, title, output='dashboard.html', template='dashboar
             source=source,
             ana_source=ana_source
         ))
+    return dashboard_path
 
 
 def _add_stats(hist, title=''):
@@ -259,7 +265,7 @@ def hist_plot(h, *args, norm=None, include_errors=False,
     """ Plots a 1D ROOT histogram object using matplotlib """
     from inspect import signature
     if norm:
-        h = hist_normalize(h, norm)
+        h = hist_norm(h, norm)
     values, errors, edges = h
 
     scale = 1. if norm is None else norm / np.sum(values)
@@ -307,7 +313,7 @@ def hist_plot(h, *args, norm=None, include_errors=False,
     ax.grid(grid, color='#E0E0E0')
 
 
-def hist2d_plot(h, **kwargs):
+def hist2d_plot(h, txt_format=None, colorbar=False, **kwargs):
     """ Plots a 2D ROOT histogram object using matplotlib """
     try:
         values, errors, xs, ys = h
@@ -317,8 +323,26 @@ def hist2d_plot(h, **kwargs):
     plt.xlabel(kwargs.pop('xlabel', ''))
     plt.ylabel(kwargs.pop('ylabel', ''))
     plt.title(kwargs.pop('title', ''))
-    plt.pcolormesh(xs, ys, values, )
-    # axes.colorbar() TODO: Re-enable this
+    plt.pcolormesh(xs, ys, values, **kwargs)
+    if txt_format is not None:
+        cmap = plt.get_cmap()
+        min_, max_ = float(np.min(values)), float(np.max(values))
+
+        def get_intensity(val):
+            cmap_idx = int((cmap.N-1) * (val - min_) / (max_-min_))
+            color = cmap.colors[cmap_idx]
+            return color[0]*0.25 + color[1]*0.5 + color[2]*0.25
+
+        for idx_row in range(values.shape[0]):
+            for idx_col in range(values.shape[1]):
+                x_mid = (xs[idx_row, idx_col] + xs[idx_row, idx_col+1]) / 2
+                y_mid = (ys[idx_row, idx_col] + ys[idx_row+1, idx_col]) / 2
+                val = txt_format.format(values[idx_row, idx_col])
+                txt_color = 'w' if get_intensity(values[idx_row, idx_col]) < 0.5 else 'k'
+                plt.text(x_mid, y_mid, val, verticalalignment='center', horizontalalignment='center',
+                         color=txt_color, fontsize=12)
+    if colorbar:
+        plt.colorbar()
 
 
 def hist_plot_stack(hists: list, labels: list = None):