瀏覽代碼

Adds text-display to hist2d_plot and renames some modules.

Caleb Fangmeier 6 年之前
父節點
當前提交
5f45b0ec37
共有 2 個文件被更改,包括 115 次插入69 次删除
  1. 83 61
      filval/histogram_utils.py
  2. 32 8
      filval/plotter.py

+ 83 - 61
filval/histogram_utils.py

@@ -1,12 +1,12 @@
-'''
-    histogram_utils.py
+"""
+    histogram.py
     The functions in this module use a representation of a histogram that is a
     The functions in this module use a representation of a histogram that is a
     tuple containing an arr of N bin values, an array of N bin errors(symmetric)
     tuple containing an arr of N bin values, an array of N bin errors(symmetric)
     and an array of N+1 bin edges(N lower edges + 1 upper edge).
     and an array of N+1 bin edges(N lower edges + 1 upper edge).
 
 
     For 2d histograms, It is similar, but the arrays are two dimensional and
     For 2d histograms, It is similar, but the arrays are two dimensional and
     there are separate arrays for x-edges and y-edges.
     there are separate arrays for x-edges and y-edges.
-'''
+"""
 
 
 import numpy as np
 import numpy as np
 from scipy.optimize import curve_fit
 from scipy.optimize import curve_fit
@@ -36,39 +36,8 @@ def hist_bin_centers(h):
     return (edges[:-1] + edges[1:])/2.0
     return (edges[:-1] + edges[1:])/2.0
 
 
 
 
-def hist2d(th2, rescale_x=1.0, rescale_y=1.0, rescale_z=1.0):
-    """ Converts TH2 object to something amenable to
-        plotting w/ matplotlab's pcolormesh.
-    """
-    nbins_x = th2.GetNbinsX()
-    nbins_y = th2.GetNbinsY()
-    print(nbins_x, nbins_y)
-    xs = np.zeros((nbins_y+1, nbins_x+1), np.float32)
-    ys = np.zeros((nbins_y+1, nbins_x+1), np.float32)
-    values = np.zeros((nbins_y, nbins_x), np.float32)
-    errors = np.zeros((nbins_y, nbins_x), np.float32)
-    for i in range(nbins_x):
-        for j in range(nbins_y):
-            xs[j][i] = th2.GetXaxis().GetBinLowEdge(i+1)
-            ys[j][i] = th2.GetYaxis().GetBinLowEdge(j+1)
-            values[j][i] = th2.GetBinContent(i+1, j+1)
-            errors[j][i] = th2.GetBinError(i+1, j+1)
-        xs[nbins_y][i] = th2.GetXaxis().GetBinUpEdge(i)
-        ys[nbins_y][i] = th2.GetYaxis().GetBinUpEdge(nbins_y)
-    for j in range(nbins_y+1):
-        xs[j][nbins_x] = th2.GetXaxis().GetBinUpEdge(nbins_x)
-        ys[j][nbins_x] = th2.GetYaxis().GetBinUpEdge(j)
-
-    xs *= rescale_x
-    ys *= rescale_y
-    values *= rescale_z
-    errors *= rescale_z
-
-    return values, errors, xs, ys
-
-
-def hist_slice(hist, range_):
-    values, errors, edges = hist
+def hist_slice(h, range_):
+    values, errors, edges = h
     lim_low, lim_high = range_
     lim_low, lim_high = range_
     slice_ = np.logical_and(edges[:-1] > lim_low, edges[1:] < lim_high)
     slice_ = np.logical_and(edges[:-1] > lim_low, edges[1:] < lim_high)
     last = len(slice_) - np.argmax(slice_[::-1])
     last = len(slice_) - np.argmax(slice_[::-1])
@@ -77,55 +46,56 @@ def hist_slice(hist, range_):
             np.concatenate([edges[:-1][slice_], [edges[last]]]))
             np.concatenate([edges[:-1][slice_], [edges[last]]]))
 
 
 
 
-def hist_add(*hists):
-    if len(hists) == 0:
+def hist_add(*hs):
+    if len(hs) == 0:
         return np.zeros(0)
         return np.zeros(0)
-    vals, errs, edges = zip(*hists)
+    vals, errs, edges = zip(*hs)
     return np.sum(vals, axis=0), np.sqrt(np.sum([err*err for err in errs], axis=0)), edges[0]
     return np.sum(vals, axis=0), np.sqrt(np.sum([err*err for err in errs], axis=0)), edges[0]
 
 
 
 
-def hist_integral(hist, times_bin_width=True):
-    values, errors, edges = hist
+def hist_integral(h, times_bin_width=True):
+    values, errors, edges = h
     if times_bin_width:
     if times_bin_width:
         bin_widths = [abs(x2 - x1) for x1, x2 in zip(edges[:-1], edges[1:])]
         bin_widths = [abs(x2 - x1) for x1, x2 in zip(edges[:-1], edges[1:])]
         return sum(val*width for val, width in zip(values, bin_widths))
         return sum(val*width for val, width in zip(values, bin_widths))
     else:
     else:
         return sum(values)
         return sum(values)
 
 
-def hist_scale(hist, scale):
-    values, errors, edges = hist
+
+def hist_scale(h, scale):
+    values, errors, edges = h
     return values*scale, errors*scale, edges
     return values*scale, errors*scale, edges
 
 
-def hist_normalize(hist, norm = 1):
-    scale = norm/np.sum(hist[0])
-    return hist_scale(hist, scale)
 
 
+def hist_norm(h, norm=1):
+    scale = norm/np.sum(h[0])
+    return hist_scale(h, scale)
 
 
 
 
-def hist_mean(hist):
-    xs = hist_bin_centers(hist)
-    ys, _, _ = hist
+def hist_mean(h):
+    xs = hist_bin_centers(h)
+    ys, _, _ = h
     return sum(x*y for x, y in zip(xs, ys)) / sum(ys)
     return sum(x*y for x, y in zip(xs, ys)) / sum(ys)
 
 
 
 
-def hist_var(hist):
-    xs = hist_bin_centers(hist)
-    ys, _, _ = hist
+def hist_var(h):
+    xs = hist_bin_centers(h)
+    ys, _, _ = h
     mean = sum(x*y for x, y in zip(xs, ys)) / sum(ys)
     mean = sum(x*y for x, y in zip(xs, ys)) / sum(ys)
     mean2 = sum((x**2)*y for x, y in zip(xs, ys)) / sum(ys)
     mean2 = sum((x**2)*y for x, y in zip(xs, ys)) / sum(ys)
     return mean2 - mean**2
     return mean2 - mean**2
 
 
 
 
-def hist_std(hist):
-    return np.sqrt(hist_var(hist))
+def hist_std(h):
+    return np.sqrt(hist_var(h))
 
 
 
 
-def hist_stats(hist):
-    return {'int': hist_integral(hist),
-            'sum': hist_integral(hist, False),
-            'mean': hist_mean(hist),
-            'var': hist_var(hist),
-            'std': hist_std(hist)}
+def hist_stats(h):
+    return {'int': hist_integral(h),
+            'sum': hist_integral(h, False),
+            'mean': hist_mean(h),
+            'var': hist_var(h),
+            'std': hist_std(h)}
 
 
 
 
 # def hist_slice2d(h, range_):
 # def hist_slice2d(h, range_):
@@ -150,5 +120,57 @@ def hist_fit(h, f, p0=None):
     return popt, pcov
     return popt, pcov
 
 
 
 
-def hist_rebin(hist, range_, nbins):
+def hist_rebin(h, range_, nbins):
     raise NotImplementedError()
     raise NotImplementedError()
+
+
+def hist2d(th2, rescale_x=1.0, rescale_y=1.0, rescale_z=1.0):
+    """ Converts TH2 object to something amenable to
+        plotting w/ matplotlab's pcolormesh.
+    """
+    nbins_x = th2.GetNbinsX()
+    nbins_y = th2.GetNbinsY()
+    xs = np.zeros((nbins_y+1, nbins_x+1), np.float32)
+    ys = np.zeros((nbins_y+1, nbins_x+1), np.float32)
+    values = np.zeros((nbins_y, nbins_x), np.float32)
+    errors = np.zeros((nbins_y, nbins_x), np.float32)
+    for i in range(nbins_x):
+        for j in range(nbins_y):
+            xs[j][i] = th2.GetXaxis().GetBinLowEdge(i+1)
+            ys[j][i] = th2.GetYaxis().GetBinLowEdge(j+1)
+            values[j][i] = th2.GetBinContent(i+1, j+1)
+            errors[j][i] = th2.GetBinError(i+1, j+1)
+        xs[nbins_y][i] = th2.GetXaxis().GetBinUpEdge(i)
+        ys[nbins_y][i] = th2.GetYaxis().GetBinUpEdge(nbins_y)
+    for j in range(nbins_y+1):
+        xs[j][nbins_x] = th2.GetXaxis().GetBinUpEdge(nbins_x)
+        ys[j][nbins_x] = th2.GetYaxis().GetBinUpEdge(j)
+
+    xs *= rescale_x
+    ys *= rescale_y
+    values *= rescale_z
+    errors *= rescale_z
+
+    return values, errors, xs, ys
+
+
+def hist2d_norm(h, norm=1, axis=None):
+    """
+
+    :param h:
+    :param norm: value to normalize the sum of axis to
+    :param axis: which axis to normalize None is the sum over all bins, 0 is columns, 1 is rows.
+    :return: The normalized histogram
+    """
+    values, errors, xs, ys = h
+    with np.errstate(divide='ignore'):
+        scale_values = norm / np.sum(values, axis=axis)
+        scale_values[scale_values == np.inf] = 1
+        scale_values[scale_values == -np.inf] = 1
+    if axis == 1:
+        scale_values.shape = (scale_values.shape[0], 1)
+    values = values * scale_values
+    errors = errors * scale_values
+    return values, errors, xs.copy(), ys.copy()
+
+

+ 32 - 8
filval/plotter.py

@@ -1,4 +1,8 @@
-#!/usr/bin/env python3
+"""
+    plotting.py
+    The functions in this module are meant for plotting the histogram objects created via
+    filval.histogram
+"""
 
 
 from collections import defaultdict
 from collections import defaultdict
 from itertools import zip_longest
 from itertools import zip_longest
@@ -9,8 +13,8 @@ import matplotlib.pyplot as plt
 from markdown import Markdown
 from markdown import Markdown
 import latexipy as lp
 import latexipy as lp
 
 
-from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
-                                    hist_normalize, hist_stats)
+from filval.histogram import (hist, hist2d, hist_bin_centers, hist_fit,
+                              hist_norm, hist_stats)
 
 
 __all__ = ['Plot',
 __all__ = ['Plot',
            'decl_plot',
            'decl_plot',
@@ -115,7 +119,8 @@ def generate_dashboard(plots, title, output='dashboard.html', template='dashboar
     if not isdir('output'):
     if not isdir('output'):
         mkdir('output')
         mkdir('output')
 
 
-    with open(join('output', output), 'w') as tempout:
+    dashboard_path = join('output', output)
+    with open(dashboard_path, 'w') as tempout:
         templ = env.get_template(template)
         templ = env.get_template(template)
         tempout.write(templ.render(
         tempout.write(templ.render(
             plots=get_by_n(plots, 3),
             plots=get_by_n(plots, 3),
@@ -123,6 +128,7 @@ def generate_dashboard(plots, title, output='dashboard.html', template='dashboar
             source=source,
             source=source,
             ana_source=ana_source
             ana_source=ana_source
         ))
         ))
+    return dashboard_path
 
 
 
 
 def _add_stats(hist, title=''):
 def _add_stats(hist, title=''):
@@ -259,7 +265,7 @@ def hist_plot(h, *args, norm=None, include_errors=False,
     """ Plots a 1D ROOT histogram object using matplotlib """
     """ Plots a 1D ROOT histogram object using matplotlib """
     from inspect import signature
     from inspect import signature
     if norm:
     if norm:
-        h = hist_normalize(h, norm)
+        h = hist_norm(h, norm)
     values, errors, edges = h
     values, errors, edges = h
 
 
     scale = 1. if norm is None else norm / np.sum(values)
     scale = 1. if norm is None else norm / np.sum(values)
@@ -307,7 +313,7 @@ def hist_plot(h, *args, norm=None, include_errors=False,
     ax.grid(grid, color='#E0E0E0')
     ax.grid(grid, color='#E0E0E0')
 
 
 
 
-def hist2d_plot(h, **kwargs):
+def hist2d_plot(h, txt_format=None, colorbar=False, **kwargs):
     """ Plots a 2D ROOT histogram object using matplotlib """
     """ Plots a 2D ROOT histogram object using matplotlib """
     try:
     try:
         values, errors, xs, ys = h
         values, errors, xs, ys = h
@@ -317,8 +323,26 @@ def hist2d_plot(h, **kwargs):
     plt.xlabel(kwargs.pop('xlabel', ''))
     plt.xlabel(kwargs.pop('xlabel', ''))
     plt.ylabel(kwargs.pop('ylabel', ''))
     plt.ylabel(kwargs.pop('ylabel', ''))
     plt.title(kwargs.pop('title', ''))
     plt.title(kwargs.pop('title', ''))
-    plt.pcolormesh(xs, ys, values, )
-    # axes.colorbar() TODO: Re-enable this
+    plt.pcolormesh(xs, ys, values, **kwargs)
+    if txt_format is not None:
+        cmap = plt.get_cmap()
+        min_, max_ = float(np.min(values)), float(np.max(values))
+
+        def get_intensity(val):
+            cmap_idx = int((cmap.N-1) * (val - min_) / (max_-min_))
+            color = cmap.colors[cmap_idx]
+            return color[0]*0.25 + color[1]*0.5 + color[2]*0.25
+
+        for idx_row in range(values.shape[0]):
+            for idx_col in range(values.shape[1]):
+                x_mid = (xs[idx_row, idx_col] + xs[idx_row, idx_col+1]) / 2
+                y_mid = (ys[idx_row, idx_col] + ys[idx_row+1, idx_col]) / 2
+                val = txt_format.format(values[idx_row, idx_col])
+                txt_color = 'w' if get_intensity(values[idx_row, idx_col]) < 0.5 else 'k'
+                plt.text(x_mid, y_mid, val, verticalalignment='center', horizontalalignment='center',
+                         color=txt_color, fontsize=12)
+    if colorbar:
+        plt.colorbar()
 
 
 
 
 def hist_plot_stack(hists: list, labels: list = None):
 def hist_plot_stack(hists: list, labels: list = None):