plotter.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. #!/usr/bin/env python3
  2. from collections import defaultdict
  3. from itertools import zip_longest
  4. from io import BytesIO
  5. from base64 import b64encode
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from markdown import Markdown
  9. import latexipy as lp
  10. from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
  11. hist_normalize, hist_stats)
  12. __all__ = ['Plot',
  13. 'decl_plot',
  14. 'grid_plot',
  15. 'render_plots',
  16. 'generate_dashboard',
  17. 'hist_plot',
  18. 'hist_plot_stack',
  19. 'hist2d_plot',
  20. 'hists_to_table']
  21. class Plot:
  22. def __init__(self, subplots, name, title=None, docs="N/A", arg_dicts=None):
  23. self.subplots = subplots
  24. self.name = name
  25. self.title = title
  26. self.docs = docs
  27. self.arg_dicts = arg_dicts if arg_dicts is not None else {}
  28. MD = Markdown(extensions=['mdx_math'],
  29. extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
  30. lp.latexify(params={'pgf.texsystem': 'pdflatex',
  31. 'text.usetex': True,
  32. 'font.family': 'serif',
  33. 'pgf.preamble': [],
  34. 'font.size': 15,
  35. 'axes.labelsize': 15,
  36. 'axes.titlesize': 13,
  37. 'legend.fontsize': 13,
  38. 'xtick.labelsize': 11,
  39. 'ytick.labelsize': 11,
  40. 'figure.dpi': 150,
  41. 'savefig.transparent': False,
  42. },
  43. new_backend='TkAgg')
  44. def _fn_call_to_dict(fn, *args, **kwargs):
  45. from inspect import signature
  46. from html import escape
  47. pnames = list(signature(fn).parameters)
  48. pvals = list(args) + list(kwargs.values())
  49. return {escape(str(k)): escape(str(v)) for k, v in zip(pnames, pvals)}
  50. def _process_docs(fn):
  51. from inspect import getdoc
  52. raw = getdoc(fn)
  53. if raw:
  54. return MD.convert(raw)
  55. else:
  56. return None
  57. def decl_plot(fn):
  58. from functools import wraps
  59. @wraps(fn)
  60. def f(*args, **kwargs):
  61. txt = fn(*args, **kwargs)
  62. argdict = _fn_call_to_dict(fn, *args, **kwargs)
  63. docs = _process_docs(fn)
  64. if not txt:
  65. txt = ''
  66. txt = MD.convert(txt)
  67. return argdict, docs, txt
  68. return f
  69. def generate_dashboard(plots, title, output='dashboard.html', template='dashboard.j2', source_file=None, ana_source=None):
  70. from jinja2 import Environment, PackageLoader, select_autoescape
  71. from os.path import join, isdir
  72. from os import mkdir
  73. from urllib.parse import quote
  74. env = Environment(
  75. loader=PackageLoader('filval', 'templates'),
  76. autoescape=select_autoescape(['htm', 'html', 'xml']),
  77. )
  78. env.globals.update({'quote': quote,
  79. 'enumerate': enumerate,
  80. 'zip': zip,
  81. })
  82. def get_by_n(objects, n=2):
  83. objects = list(objects)
  84. while objects:
  85. yield objects[:n]
  86. objects = objects[n:]
  87. if source_file is not None:
  88. with open(source_file, 'r') as this_file:
  89. source = this_file.read()
  90. else:
  91. source = "# Not supplied!!"
  92. if not isdir('output'):
  93. mkdir('output')
  94. with open(join('output', output), 'w') as tempout:
  95. templ = env.get_template(template)
  96. tempout.write(templ.render(
  97. plots=get_by_n(plots, 3),
  98. title=title,
  99. source=source,
  100. ana_source=ana_source
  101. ))
  102. def _add_stats(hist, title=''):
  103. fmt = r'''\begin{{eqnarray*}}
  104. \sum{{x_i}} &=& {sum:5.3f} \\
  105. \sum{{\Delta x_i \cdot x_i}} &=& {int:5.3G} \\
  106. \mu &=& {mean:5.3G} \\
  107. \sigma^2 &=& {var:5.3G} \\
  108. \sigma &=& {std:5.3G}
  109. \end{{eqnarray*}}'''
  110. txt = fmt.format(**hist_stats(hist), title=title)
  111. txt = txt.replace('\n', ' ')
  112. plt.text(0.7, 0.9, txt,
  113. bbox={'facecolor': 'white',
  114. 'alpha': 0.7,
  115. 'boxstyle': 'square,pad=0.8'},
  116. transform=plt.gca().transAxes,
  117. verticalalignment='top',
  118. horizontalalignment='left',
  119. size='small')
  120. if title:
  121. plt.text(0.72, 0.97, title,
  122. bbox={'facecolor': 'white',
  123. 'alpha': 0.8},
  124. transform=plt.gca().transAxes,
  125. verticalalignment='top',
  126. horizontalalignment='left')
  127. def grid_plot(subplots):
  128. if any(len(row) != len(subplots[0]) for row in subplots):
  129. raise ValueError('make_plot requires a rectangular list-of-lists as '
  130. 'input. Fill empty slots with None')
  131. def calc_row_span(fig, row, col):
  132. span = 1
  133. for r in range(row + 1, len(fig)):
  134. if fig[r][col] == 'FU':
  135. span += 1
  136. else:
  137. break
  138. return span
  139. def calc_column_span(fig, row, col):
  140. span = 1
  141. for c in range(col + 1, len(fig[row])):
  142. if fig[row][c] == 'FL':
  143. span += 1
  144. else:
  145. break
  146. return span
  147. rows = len(subplots)
  148. cols = len(subplots[0])
  149. argdicts = defaultdict(list)
  150. docs = defaultdict(list)
  151. txts = defaultdict(list)
  152. for i in range(rows):
  153. for j in range(cols):
  154. cell = subplots[i][j]
  155. if cell in ('FL', 'FU', None):
  156. continue
  157. if not isinstance(cell, list):
  158. cell = [cell]
  159. column_span = calc_column_span(subplots, i, j)
  160. row_span = calc_row_span(subplots, i, j)
  161. plt.subplot2grid((rows, cols), (i, j),
  162. colspan=column_span, rowspan=row_span)
  163. for plot in cell:
  164. if len(plot) == 1:
  165. plot_fn, args, kwargs = plot[0], (), {}
  166. elif len(plot) == 2:
  167. plot_fn, args, kwargs = plot[0], plot[1], {}
  168. elif len(plot) == 3:
  169. plot_fn, args, kwargs = plot[0], plot[1], plot[2]
  170. else:
  171. raise ValueError('Plot tuple must be of format (func), '
  172. f'or (func, tuple), or (func, tuple, dict). Got {plot}')
  173. this_args, this_docs, txt = plot_fn(*args, **kwargs)
  174. argdicts[(i, j)].append(this_args)
  175. docs[(i, j)].append(this_docs)
  176. txts[(i, j)].append(txt)
  177. return argdicts, docs, txts
  178. def render_plots(plots, exts=('png',), scale=1.0, to_disk=True):
  179. for plot in plots:
  180. print(f'Building plot {plot.name}')
  181. plot.data = None
  182. if to_disk:
  183. with lp.figure(plot.name.replace(' ', '_'), directory='output/figures',
  184. exts=exts,
  185. size=(scale * 10, scale * 10)):
  186. argdicts, docs, txts = grid_plot(plot.subplots)
  187. else:
  188. out = BytesIO()
  189. with lp.mem_figure(out,
  190. ext=exts[0],
  191. size=(scale * 10, scale * 10)):
  192. argdicts, docs, txts = grid_plot(plot.subplots)
  193. out.seek(0)
  194. plot.data = b64encode(out.read()).decode()
  195. plot.argdicts = argdicts
  196. plot.docs = docs
  197. plot.txts = txts
  198. def add_decorations(axes, luminosity, energy):
  199. cms_prelim = r'{\raggedright{}\textsf{\textbf{CMS}}\\ \emph{Preliminary}}'
  200. axes.text(0.01, 0.98, cms_prelim,
  201. horizontalalignment='left',
  202. verticalalignment='top',
  203. transform=axes.transAxes)
  204. lumi = ""
  205. energy_str = ""
  206. if luminosity is not None:
  207. lumi = r'${} \mathrm{{fb}}^{{-1}}$'.format(luminosity)
  208. if energy is not None:
  209. energy_str = r'({} TeV)'.format(energy)
  210. axes.text(1, 1, ' '.join([lumi, energy_str]),
  211. horizontalalignment='right',
  212. verticalalignment='bottom',
  213. transform=axes.transAxes)
  214. def hist_plot(h, *args, norm=None, include_errors=False,
  215. log=False, xlim=None, ylim=None, fit=None,
  216. grid=False, stats=False, **kwargs):
  217. """ Plots a 1D ROOT histogram object using matplotlib """
  218. from inspect import signature
  219. if norm:
  220. h = hist_normalize(h, norm)
  221. values, errors, edges = h
  222. scale = 1. if norm is None else norm / np.sum(values)
  223. values = [val * scale for val in values]
  224. errors = [val * scale for val in errors]
  225. left, right = np.array(edges[:-1]), np.array(edges[1:])
  226. x = np.array([left, right]).T.flatten()
  227. y = np.array([values, values]).T.flatten()
  228. ax = plt.gca()
  229. ax.set_xlabel(kwargs.pop('xlabel', ''))
  230. ax.set_ylabel(kwargs.pop('ylabel', ''))
  231. title = kwargs.pop('title', '')
  232. if xlim is not None:
  233. ax.set_xlim(xlim)
  234. if ylim is not None:
  235. ax.set_ylim(ylim)
  236. # elif not log:
  237. # axes.set_ylim((0, None))
  238. ax.plot(x, y, *args, linewidth=1, **kwargs)
  239. if include_errors:
  240. ax.errorbar(hist_bin_centers(h), values, yerr=errors,
  241. color='k', marker=None, linestyle='None',
  242. barsabove=True, elinewidth=.7, capsize=1)
  243. if log:
  244. ax.set_yscale('log')
  245. if fit:
  246. f, p0 = fit
  247. popt, pcov = hist_fit(h, f, p0)
  248. fit_xs = np.linspace(x[0], x[-1], 100)
  249. fit_ys = f(fit_xs, *popt)
  250. ax.plot(fit_xs, fit_ys, '--g')
  251. arglabels = list(signature(f).parameters)[1:]
  252. label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
  253. for label, value in zip(arglabels, popt))
  254. ax.text(0.60, 0.95, label_txt, va='top', transform=ax.transAxes,
  255. fontsize='medium', family='monospace', usetex=False)
  256. if stats:
  257. _add_stats(h, title)
  258. else:
  259. ax.set_title(title)
  260. ax.grid(grid, color='#E0E0E0')
  261. def hist2d_plot(h, **kwargs):
  262. """ Plots a 2D ROOT histogram object using matplotlib """
  263. try:
  264. values, errors, xs, ys = h
  265. except (TypeError, ValueError):
  266. values, errors, xs, ys = hist2d(h)
  267. plt.xlabel(kwargs.pop('xlabel', ''))
  268. plt.ylabel(kwargs.pop('ylabel', ''))
  269. plt.title(kwargs.pop('title', ''))
  270. plt.pcolormesh(xs, ys, values, )
  271. # axes.colorbar() TODO: Re-enable this
  272. def hist_plot_stack(hists: list, labels: list = None):
  273. """
  274. Creates a stacked histogram in the current axes.
  275. :param hists: list of histogram
  276. :param labels:
  277. :return:
  278. """
  279. if len(hists) == 0:
  280. return
  281. if len(set([len(hist[0]) for hist in hists])) != 1:
  282. raise ValueError("all histograms must have the same number of bins")
  283. if labels is None:
  284. labels = [None for _ in hists]
  285. if len(labels) != len(hists):
  286. raise ValueError("Label mismatch")
  287. bottoms = [0 for _ in hists[0][0]]
  288. for hist, label in zip(hists, labels):
  289. centers = []
  290. widths = []
  291. heights = []
  292. for left, right, content in zip(hist[2][:-1], hist[2][1:], hist[0]):
  293. centers.append((right + left) / 2)
  294. widths.append(right - left)
  295. heights.append(content)
  296. plt.bar(centers, heights, widths, bottoms, label=label)
  297. for i, content in enumerate(hist[0]):
  298. bottoms[i] += content
  299. def hists_to_table(hists, row_labels=(), column_labels=(), format="{:.2f}"):
  300. table = ['<table class="table table-condensed">']
  301. if column_labels:
  302. table.append('<thead><tr>')
  303. if row_labels:
  304. table.append('<th></th>')
  305. table.extend(f'<th>{label}</th>' for label in column_labels)
  306. table.append('</tr></thead>')
  307. table.append('<tbody>\n')
  308. for row_label, (vals, *_) in zip_longest(row_labels, hists):
  309. table.append('<tr>')
  310. if row_label:
  311. table.append(f'<td><strong>{row_label}</strong></td>')
  312. table.extend(('<td>'+format.format(val)+'</td>') for val in vals)
  313. table.append('</tr>\n')
  314. table.append('</tbody></table>')
  315. return ''.join(table)