plotter.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. #!/usr/bin/env python3
  2. from collections import defaultdict
  3. from io import BytesIO
  4. from base64 import b64encode
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from markdown import Markdown
  8. import latexipy as lp
  9. from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
  10. hist_normalize, hist_stats)
  11. __all__ = ['Plot',
  12. 'decl_plot',
  13. 'grid_plot',
  14. 'render_plots',
  15. 'generate_dashboard',
  16. 'hist_plot',
  17. 'hist_plot_stack',
  18. 'hist2d_plot']
  19. class Plot:
  20. def __init__(self, subplots, name, title=None, docs="N/A", arg_dicts=None):
  21. self.subplots = subplots
  22. self.name = name
  23. self.title = title
  24. self.docs = docs
  25. self.arg_dicts = arg_dicts if arg_dicts is not None else {}
  26. MD = Markdown(extensions=['mdx_math'],
  27. extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
  28. lp.latexify(params={'pgf.texsystem': 'pdflatex',
  29. 'text.usetex': True,
  30. 'font.family': 'serif',
  31. 'pgf.preamble': [],
  32. 'font.size': 15,
  33. 'axes.labelsize': 15,
  34. 'axes.titlesize': 13,
  35. 'legend.fontsize': 13,
  36. 'xtick.labelsize': 11,
  37. 'ytick.labelsize': 11,
  38. 'figure.dpi': 150,
  39. 'savefig.transparent': False,
  40. },
  41. new_backend='TkAgg')
  42. def _fn_call_to_dict(fn, *args, **kwargs):
  43. from inspect import signature
  44. pnames = list(signature(fn).parameters)
  45. pvals = list(args) + list(kwargs.values())
  46. return {k: v for k, v in zip(pnames, pvals)}
  47. def _process_docs(fn):
  48. from inspect import getdoc
  49. raw = getdoc(fn)
  50. if raw:
  51. return MD.convert(raw)
  52. else:
  53. return None
  54. def decl_plot(fn):
  55. from functools import wraps
  56. @wraps(fn)
  57. def f(*args, **kwargs):
  58. fn(*args, **kwargs)
  59. argdict = _fn_call_to_dict(fn, *args, **kwargs)
  60. docs = _process_docs(fn)
  61. return argdict, docs
  62. return f
  63. def generate_dashboard(plots, title, output='dashboard.htm', source_file=None):
  64. from jinja2 import Environment, PackageLoader, select_autoescape
  65. from os.path import join
  66. from urllib.parse import quote
  67. env = Environment(
  68. loader=PackageLoader('plots', 'templates'),
  69. autoescape=select_autoescape(['htm', 'html', 'xml']),
  70. )
  71. env.globals.update({'quote': quote,
  72. 'enumerate': enumerate,
  73. 'zip': zip,
  74. })
  75. def render_to_file(template_name, **kwargs):
  76. with open(join('output', output), 'w') as tempout:
  77. template = env.get_template(template_name)
  78. tempout.write(template.render(**kwargs))
  79. def get_by_n(objs, n=2):
  80. objs = list(objs)
  81. while objs:
  82. yield objs[:n]
  83. objs = objs[n:]
  84. if source_file is not None:
  85. with open(source_file, 'r') as this_file:
  86. source = this_file.read()
  87. else:
  88. source = "# Not supplied!!"
  89. render_to_file('dashboard.htm', plots=get_by_n(plots, 3),
  90. title=title, source=source,
  91. outdir="figures/")
  92. def _add_stats(hist, title=''):
  93. fmt = r'''\begin{{eqnarray*}}
  94. \sum{{x_i}} &=& {sum:5.3f} \\
  95. \sum{{\Delta x_i \cdot x_i}} &=& {int:5.3G} \\
  96. \mu &=& {mean:5.3G} \\
  97. \sigma^2 &=& {var:5.3G} \\
  98. \sigma &=& {std:5.3G}
  99. \end{{eqnarray*}}'''
  100. txt = fmt.format(**hist_stats(hist), title=title)
  101. txt = txt.replace('\n', ' ')
  102. plt.text(0.7, 0.9, txt,
  103. bbox={'facecolor': 'white',
  104. 'alpha': 0.7,
  105. 'boxstyle': 'square,pad=0.8'},
  106. transform=plt.gca().transAxes,
  107. verticalalignment='top',
  108. horizontalalignment='left',
  109. size='small')
  110. if title:
  111. plt.text(0.72, 0.97, title,
  112. bbox={'facecolor': 'white',
  113. 'alpha': 0.8},
  114. transform=plt.gca().transAxes,
  115. verticalalignment='top',
  116. horizontalalignment='left')
  117. def grid_plot(subplots):
  118. if any(len(row) != len(subplots[0]) for row in subplots):
  119. raise ValueError("make_plot requires a rectangular list-of-lists as "
  120. "input. Fill empty slots with None")
  121. def calc_rowspan(fig, row, col):
  122. span = 1
  123. for r in range(row + 1, len(fig)):
  124. if fig[r][col] == "FU":
  125. span += 1
  126. else:
  127. break
  128. return span
  129. def calc_colspan(fig, row, col):
  130. span = 1
  131. for c in range(col + 1, len(fig[row])):
  132. if fig[row][c] == "FL":
  133. span += 1
  134. else:
  135. break
  136. return span
  137. rows = len(subplots)
  138. cols = len(subplots[0])
  139. argdicts = defaultdict(list)
  140. docs = defaultdict(list)
  141. for i in range(rows):
  142. for j in range(cols):
  143. cell = subplots[i][j]
  144. if cell in ("FL", "FU", None):
  145. continue
  146. if not isinstance(cell, list):
  147. cell = [cell]
  148. colspan = calc_colspan(subplots, i, j)
  149. rowspan = calc_rowspan(subplots, i, j)
  150. plt.subplot2grid((rows, cols), (i, j),
  151. colspan=colspan, rowspan=rowspan)
  152. for plot in cell:
  153. plot_fn, args, kwargs = plot
  154. this_args, this_docs = plot_fn(*args, **kwargs)
  155. argdicts[(i, j)].append(this_args)
  156. docs[(i, j)].append(this_docs)
  157. return argdicts, docs
  158. def render_plots(plots, exts=('png',), scale=1.0, to_disk=True):
  159. for plot in plots:
  160. print(f'Building plot {plot.name}')
  161. plot.data = None
  162. if to_disk:
  163. with lp.figure(plot.name, directory='output/figures',
  164. exts=exts,
  165. size=(scale * 10, scale * 10)):
  166. argdicts, docs = grid_plot(plot.subplots)
  167. else:
  168. out = BytesIO()
  169. with lp.mem_figure(out,
  170. ext=exts[0],
  171. size=(scale * 10, scale * 10)):
  172. argdicts, docs = grid_plot(plot.subplots)
  173. out.seek(0)
  174. plot.data = b64encode(out.read()).decode()
  175. plot.argdicts = argdicts
  176. plot.docs = docs
  177. def add_decorations(axes, luminosity, energy):
  178. cms_prelim = r'{\raggedright{}\textsf{\textbf{CMS}}\\ \emph{Preliminary}}'
  179. axes.text(0.01, 0.98, cms_prelim,
  180. horizontalalignment='left',
  181. verticalalignment='top',
  182. transform=axes.transAxes)
  183. lumi = ""
  184. energy_str = ""
  185. if luminosity is not None:
  186. lumi = r'${} \mathrm{{fb}}^{{-1}}$'.format(luminosity)
  187. if energy is not None:
  188. energy_str = r'({} TeV)'.format(energy)
  189. axes.text(1, 1, ' '.join([lumi, energy_str]),
  190. horizontalalignment='right',
  191. verticalalignment='bottom',
  192. transform=axes.transAxes)
  193. def hist_plot(h, *args, norm=None, include_errors=False,
  194. log=False, xlim=None, ylim=None, fit=None,
  195. grid=False, stats=True, **kwargs):
  196. """ Plots a 1D ROOT histogram object using matplotlib """
  197. from inspect import signature
  198. if norm:
  199. h = hist_normalize(h, norm)
  200. values, errors, edges = h
  201. scale = 1. if norm is None else norm / np.sum(values)
  202. values = [val * scale for val in values]
  203. errors = [val * scale for val in errors]
  204. left, right = np.array(edges[:-1]), np.array(edges[1:])
  205. x = np.array([left, right]).T.flatten()
  206. y = np.array([values, values]).T.flatten()
  207. ax = plt.gca()
  208. ax.set_xlabel(kwargs.pop('xlabel', ''))
  209. ax.set_ylabel(kwargs.pop('ylabel', ''))
  210. title = kwargs.pop('title', '')
  211. if xlim is not None:
  212. ax.set_xlim(xlim)
  213. if ylim is not None:
  214. ax.set_ylim(ylim)
  215. # elif not log:
  216. # axes.set_ylim((0, None))
  217. ax.plot(x, y, *args, linewidth=1, **kwargs)
  218. if include_errors:
  219. ax.errorbar(hist_bin_centers(h), values, yerr=errors,
  220. color='k', marker=None, linestyle='None',
  221. barsabove=True, elinewidth=.7, capsize=1)
  222. if log:
  223. ax.set_yscale('log')
  224. if fit:
  225. f, p0 = fit
  226. popt, pcov = hist_fit(h, f, p0)
  227. fit_xs = np.linspace(x[0], x[-1], 100)
  228. fit_ys = f(fit_xs, *popt)
  229. ax.plot(fit_xs, fit_ys, '--g')
  230. arglabels = list(signature(f).parameters)[1:]
  231. label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
  232. for label, value in zip(arglabels, popt))
  233. ax.text(0.60, 0.95, label_txt, va='top', transform=ax.transAxes,
  234. fontsize='medium', family='monospace', usetex=False)
  235. if stats:
  236. _add_stats(h, title)
  237. else:
  238. ax.set_title(title)
  239. ax.grid(grid, color='#E0E0E0')
  240. def hist2d_plot(h, **kwargs):
  241. """ Plots a 2D ROOT histogram object using matplotlib """
  242. try:
  243. values, errors, xs, ys = h
  244. except (TypeError, ValueError):
  245. values, errors, xs, ys = hist2d(h)
  246. plt.xlabel(kwargs.pop('xlabel', ''))
  247. plt.ylabel(kwargs.pop('ylabel', ''))
  248. plt.title(kwargs.pop('title', ''))
  249. plt.pcolormesh(xs, ys, values, )
  250. # axes.colorbar() TODO: Re-enable this
  251. def hist_plot_stack(hists: list, labels: list = None):
  252. """
  253. Creates a stacked histogram in the current axes.
  254. :param hists: list of histogram
  255. :param labels:
  256. :return:
  257. """
  258. if len(hists) == 0:
  259. return
  260. if len(set([len(hist[0]) for hist in hists])) != 1:
  261. raise ValueError("all histograms must have the same number of bins")
  262. if labels is None:
  263. labels = [None for _ in hists]
  264. if len(labels) != len(hists):
  265. raise ValueError("Label mismatch")
  266. bottoms = [0 for _ in hists[0][0]]
  267. for hist, label in zip(hists, labels):
  268. centers = []
  269. widths = []
  270. heights = []
  271. for left, right, content in zip(hist[2][:-1], hist[2][1:], hist[0]):
  272. centers.append((right + left) / 2)
  273. widths.append(right - left)
  274. heights.append(content)
  275. plt.bar(centers, heights, widths, bottoms, label=label)
  276. for i, content in enumerate(hist[0]):
  277. bottoms[i] += content