Browse Source

Rewrite of plotting routines to revolve around subplots instead of full
figures.

Caleb Fangmeier 6 years ago
parent
commit
c92f9ce6d3
3 changed files with 117 additions and 59 deletions
  1. 1 1
      .flake8
  2. 11 3
      filval/histogram_utils.py
  3. 105 55
      filval/plotter.py

+ 1 - 1
.flake8

@@ -1,4 +1,4 @@
 
 [flake8]
-ignore = E248, E241, E226, E402, E701, E402
+ignore = E248, E241, E226, E402, E701, E402, E202
 max_line_length=120

+ 11 - 3
filval/histogram_utils.py

@@ -12,7 +12,7 @@ import numpy as np
 from scipy.optimize import curve_fit
 
 
-def hist(th1):
+def hist(th1, rescale_x=1.0, rescale_y=1.0):
     nbins = th1.GetNbinsX()
 
     edges = np.zeros(nbins+1, np.float32)
@@ -25,18 +25,21 @@ def hist(th1):
         errors[i] = th1.GetBinError(i+1)
 
     edges[nbins] = th1.GetXaxis().GetBinUpEdge(nbins)
+    edges *= rescale_x
+    values *= rescale_y
+    errors *= rescale_y
     return values, errors, edges
 
+
 def hist_bin_centers(h):
     _, _, edges = h
     return (edges[:-1] + edges[1:])/2.0
 
 
-def hist2d(th2, include_errors=False):
+def hist2d(th2, rescale_x=1.0, rescale_y=1.0, rescale_z=1.0):
     """ Converts TH2 object to something amenable to
         plotting w/ matplotlab's pcolormesh.
     """
-    import numpy as np
     nbins_x = th2.GetNbinsX()
     nbins_y = th2.GetNbinsY()
     xs = np.zeros((nbins_y+1, nbins_x+1), np.float32)
@@ -55,6 +58,11 @@ def hist2d(th2, include_errors=False):
         xs[j][nbins_x] = th2.GetXaxis().GetBinUpEdge(nbins_x+1)
         ys[j][nbins_x] = th2.GetYaxis().GetBinUpEdge(j+1)
 
+    xs *= rescale_x
+    ys *= rescale_y
+    values *= rescale_z
+    errors *= rescale_z
+
     return values, errors, xs, ys
 
 

+ 105 - 55
filval/plotter.py

@@ -1,79 +1,129 @@
 #!/usr/bin/env python3
-from collections import namedtuple
 
 import numpy as np
+import matplotlib.pyplot as plt
 from markdown import Markdown
 import latexipy as lp
 
 from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
                                     hist_normalize)
-__all__ = ['make_plot',
-           'plot_registry',
+__all__ = ['Plot',
+           'decl_plot',
+           'grid_plot',
+           'save_plots',
            'hist_plot',
            'hist2d_plot']
 
-plot_registry = {}
-Plot = namedtuple('Plot', ['name', 'filename', 'title', 'desc', 'args'])
+
+class Plot:
+    def __init__(self, subplots, name, title=None, docs="N/A", argdict={}):
+        self.subplots = subplots
+        self.name = name
+        self.title = title
+        self.docs = docs
+        self.argdict = argdict
+
 
 MD = Markdown(extensions=['mdx_math'],
               extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
 
-lp.latexify(params={'pgf.texsystem': 'xelatex',
+lp.latexify(params={'pgf.texsystem': 'pdflatex',
                     'text.usetex': True,
                     'font.family': 'serif',
-                    'pgf.preamble': [r'\usepackage[utf8x]{inputenc}',
-                                     r'\usepackage[T1]{fontenc}'],
-                    'font.size': 8,
-                    'axes.labelsize': 8,
-                    'axes.titlesize': 8,
-                    'legend.fontsize': 8,
-                    'xtick.labelsize': 8,
-                    'ytick.labelsize': 8,
+                    'pgf.preamble': [],
+                    'font.size': 15,
+                    'axes.labelsize': 15,
+                    'axes.titlesize': 13,
+                    'legend.fontsize': 13,
+                    'xtick.labelsize': 11,
+                    'ytick.labelsize': 11,
                     'figure.dpi': 150,
-                    'savefig.transparent': True,
+                    'savefig.transparent': False,
                     },
             new_backend='TkAgg')
 
 
-def make_plot(plot_name=None, title='', scale=1, exts=['png', 'pgf']):
-    import matplotlib.pyplot as plt
+def _fn_call_to_dict(fn, *args, **kwargs):
+    from inspect import signature
+    pnames = list(signature(fn).parameters)
+    pvals = list(args)+list(kwargs.values())
+    return {k: v for k, v in zip(pnames, pvals)}
+
+
+def _process_docs(fn):
+    from inspect import getdoc
+    raw = getdoc(fn)
+    if raw:
+        return MD.convert(raw)
+    else:
+        return None
+
+
+def decl_plot(fn):
     from functools import wraps
-    from os.path import join
-    from inspect import signature, getdoc
-
-    def fn_call_to_dict(fn, *args, **kwargs):
-        pnames = list(signature(fn).parameters)
-        pvals = list(args)+list(kwargs.keys())
-        return {k: v for k, v in zip(pnames, pvals)}
-
-    def process_docs(fn):
-        raw = getdoc(fn)
-        if raw:
-            return MD.convert(raw)
-        else:
-            return None
-
-    def wrap(fn):
-        @wraps(fn)
-        def f(*args, **kwargs):
-            nonlocal plot_name
-            pdict = fn_call_to_dict(fn, *args, **kwargs)
-            if plot_name is None:
-                pstr = ','.join('{}:{}'.format(pname, pval)
-                                for pname, pval in pdict.items())
-                plot_name = fn.__name__ + '::' + pstr
-                plot_name = plot_name.replace('/', '_').replace('.', '_')
-            with lp.figure(plot_name, directory='output/figures',
-                           exts=exts,
-                           size=(scale*10, scale*10)):
-                fn(*args, **kwargs)
-                plt.tight_layout()
-            filename = plot_name+".png"
-            plot_registry[fn.__name__] = Plot(fn.__name__, join('figures', filename),
-                                              title, process_docs(fn), pdict)
-        return f
-
-    return wrap
+
+    @wraps(fn)
+    def f(*args, **kwargs):
+        fn(*args, **kwargs)
+        argdict = _fn_call_to_dict(fn, *args, **kwargs)
+        docs = _process_docs(fn)
+
+        return argdict, docs
+    return f
+
+
+def grid_plot(subplots):
+    if any(len(row) != len(subplots[0]) for row in subplots):
+        raise ValueError("make_plot requires a rectangular list-of-lists as "
+                         "input. Fill empty slots with None")
+
+    def calc_rowspan(fig, row, col):
+        span = 1
+        for r in range(row+1, len(fig)):
+            if fig[r][col] == "FU":
+                span += 1
+            else:
+                break
+        return span
+
+    def calc_colspan(fig, row, col):
+        span = 1
+        for c in range(col+1, len(fig[row])):
+            if fig[row][c] == "FL":
+                span += 1
+            else:
+                break
+        return span
+
+    rows = len(subplots)
+    cols = len(subplots[0])
+
+    argdicts = {}
+    docs = {}
+    for i in range(rows):
+        for j in range(cols):
+            plot = subplots[i][j]
+            if plot in ("FL", "FU", None):
+                continue
+            plot_fn, args, kwargs = plot
+            colspan = calc_colspan(subplots, i, j)
+            rowspan = calc_rowspan(subplots, i, j)
+            plt.subplot2grid((rows, cols), (i, j),
+                             colspan=colspan, rowspan=rowspan)
+            this_args, this_docs = plot_fn(*args, **kwargs)
+            argdicts[(i, j)] = this_args
+            docs[(i, j)] = this_docs
+    return argdicts, docs
+
+
+def save_plots(plots, exts=['png'], scale=1.0):
+    for plot in plots:
+        with lp.figure(plot.name, directory='output/figures',
+                       exts=exts,
+                       size=(scale*10, scale*10)):
+            argdict, docs = grid_plot(plot.subplots)
+            plot.argdict = argdict
+            plot.docs = docs
 
 
 def add_decorations(axes, luminosity, energy):
@@ -98,7 +148,7 @@ def add_decorations(axes, luminosity, energy):
 
 def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
               log=False, fig=None, xlim=None, ylim=None, fit=None,
-              **kwargs):
+              grid=False, **kwargs):
     """ Plots a 1D ROOT histogram object using matplotlib """
     from inspect import signature
     if norm:
@@ -145,7 +195,7 @@ def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
                               for label, value in zip(arglabels, popt))
         axes.text(0.60, 0.95, label_txt, va='top', transform=axes.transAxes,
                   fontsize='medium', family='monospace', usetex=False)
-    axes.grid()
+    axes.grid(grid, color='#E0E0E0')
 
 
 def hist2d_plot(h, *args, axes=None, **kwargs):