plotting.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. """
  2. plotting.py
  3. The functions in this module are meant for plotting the histogram objects created via
  4. filval.histogram
  5. """
  6. from collections import defaultdict
  7. from itertools import zip_longest
  8. from io import BytesIO
  9. from base64 import b64encode
  10. import numpy as np
  11. import matplotlib.pyplot as plt
  12. from markdown import Markdown
  13. import latexipy as lp
  14. from filval.histogram import (hist, hist2d, hist_bin_centers, hist_fit,
  15. hist_norm, hist_stats)
  16. __all__ = ['Plot',
  17. 'decl_plot',
  18. 'grid_plot',
  19. 'render_plots',
  20. 'generate_dashboard',
  21. 'hist_plot',
  22. 'hist_plot_stack',
  23. 'hist2d_plot',
  24. 'hists_to_table',
  25. 'simple_plot']
  26. class Plot:
  27. def __init__(self, subplots, name, title=None, docs="N/A", arg_dicts=None):
  28. if type(subplots) is not list:
  29. subplots = [[subplots]]
  30. elif len(subplots) > 0 and type(subplots[0]) is not list:
  31. subplots = [subplots]
  32. self.subplots = subplots
  33. self.name = name
  34. self.title = title
  35. self.docs = docs
  36. self.arg_dicts = arg_dicts if arg_dicts is not None else {}
  37. MD = Markdown(extensions=['mdx_math'],
  38. extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
  39. lp.latexify(params={'pgf.texsystem': 'pdflatex',
  40. 'text.usetex': True,
  41. 'font.family': 'serif',
  42. 'pgf.preamble': [],
  43. 'font.size': 15,
  44. 'axes.labelsize': 15,
  45. 'axes.titlesize': 13,
  46. 'legend.fontsize': 13,
  47. 'xtick.labelsize': 11,
  48. 'ytick.labelsize': 11,
  49. 'figure.dpi': 150,
  50. 'savefig.transparent': False,
  51. },
  52. new_backend='TkAgg')
  53. def _fn_call_to_dict(fn, *args, **kwargs):
  54. from inspect import signature
  55. from html import escape
  56. pnames = list(signature(fn).parameters)
  57. pvals = list(args) + list(kwargs.values())
  58. return {escape(str(k)): escape(str(v)) for k, v in zip(pnames, pvals)}
  59. def _process_docs(fn):
  60. from inspect import getdoc
  61. raw = getdoc(fn)
  62. if raw:
  63. return MD.convert(raw)
  64. else:
  65. return None
  66. def decl_plot(fn):
  67. from functools import wraps
  68. @wraps(fn)
  69. def f(*args, **kwargs):
  70. txt = fn(*args, **kwargs)
  71. argdict = _fn_call_to_dict(fn, *args, **kwargs)
  72. docs = _process_docs(fn)
  73. if not txt:
  74. txt = ''
  75. txt = MD.convert(txt)
  76. return argdict, docs, txt
  77. return f
  78. def simple_plot(thx):
  79. import ROOT
  80. if isinstance(thx, ROOT.TH2):
  81. def f(h):
  82. hist2d_plot(hist2d(h))
  83. plt.xlabel(h.GetXaxis().GetTitle())
  84. plt.ylabel(h.GetYaxis().GetTitle())
  85. return dict(), "", ""
  86. return Plot([[(f, (thx,), {})]], thx.GetName())
  87. elif isinstance(thx, ROOT.TH1):
  88. def f(h):
  89. hist_plot(hist(h))
  90. plt.xlabel(h.GetXaxis().GetTitle())
  91. plt.ylabel(h.GetYaxis().GetTitle())
  92. return dict(), "", ""
  93. return Plot([[(f, (thx,), {})]], thx.GetName())
  94. else:
  95. raise ValueError("must call simple_plot with a ROOT TH1 or TH2 object")
  96. def generate_dashboard(plots, title, output='dashboard.html', template='dashboard.j2',
  97. source=None, ana_source=None, config=None):
  98. from jinja2 import Environment, PackageLoader, select_autoescape
  99. from os.path import join, isdir
  100. from os import mkdir
  101. from urllib.parse import quote
  102. env = Environment(
  103. loader=PackageLoader('filval', 'templates'),
  104. autoescape=select_autoescape(['htm', 'html', 'xml']),
  105. )
  106. env.globals.update({'quote': quote,
  107. 'enumerate': enumerate,
  108. 'zip': zip,
  109. })
  110. def get_by_n(objects, n=2):
  111. objects = list(objects)
  112. while objects:
  113. yield objects[:n]
  114. objects = objects[n:]
  115. if source is not None:
  116. with open(source, 'r') as f:
  117. source = f.read()
  118. if not isdir('output'):
  119. mkdir('output')
  120. dashboard_path = join('output', output)
  121. with open(dashboard_path, 'w') as tempout:
  122. templ = env.get_template(template)
  123. tempout.write(templ.render(
  124. plots=get_by_n(plots, 3),
  125. title=title,
  126. source=source,
  127. ana_source=ana_source,
  128. config=config
  129. ))
  130. return dashboard_path
  131. def _add_stats(hist, title=''):
  132. fmt = r'''\begin{{eqnarray*}}
  133. \sum{{x_i}} &=& {sum:5.3f} \\
  134. \sum{{\Delta x_i \cdot x_i}} &=& {int:5.3G} \\
  135. \mu &=& {mean:5.3G} \\
  136. \sigma^2 &=& {var:5.3G} \\
  137. \sigma &=& {std:5.3G}
  138. \end{{eqnarray*}}'''
  139. txt = fmt.format(**hist_stats(hist), title=title)
  140. txt = txt.replace('\n', ' ')
  141. plt.text(0.7, 0.9, txt,
  142. bbox={'facecolor': 'white',
  143. 'alpha': 0.7,
  144. 'boxstyle': 'square,pad=0.8'},
  145. transform=plt.gca().transAxes,
  146. verticalalignment='top',
  147. horizontalalignment='left',
  148. size='small')
  149. if title:
  150. plt.text(0.72, 0.97, title,
  151. bbox={'facecolor': 'white',
  152. 'alpha': 0.8},
  153. transform=plt.gca().transAxes,
  154. verticalalignment='top',
  155. horizontalalignment='left')
  156. def grid_plot(subplots):
  157. if any(len(row) != len(subplots[0]) for row in subplots):
  158. raise ValueError('make_plot requires a rectangular list-of-lists as '
  159. 'input. Fill empty slots with None')
  160. def calc_row_span(fig, row, col):
  161. span = 1
  162. for r in range(row + 1, len(fig)):
  163. if fig[r][col] == 'FU':
  164. span += 1
  165. else:
  166. break
  167. return span
  168. def calc_column_span(fig, row, col):
  169. span = 1
  170. for c in range(col + 1, len(fig[row])):
  171. if fig[row][c] == 'FL':
  172. span += 1
  173. else:
  174. break
  175. return span
  176. rows = len(subplots)
  177. cols = len(subplots[0])
  178. argdicts = defaultdict(list)
  179. docs = defaultdict(list)
  180. txts = defaultdict(list)
  181. for i in range(rows):
  182. for j in range(cols):
  183. cell = subplots[i][j]
  184. if cell in ('FL', 'FU', None):
  185. continue
  186. if not isinstance(cell, list):
  187. cell = [cell]
  188. column_span = calc_column_span(subplots, i, j)
  189. row_span = calc_row_span(subplots, i, j)
  190. plt.subplot2grid((rows, cols), (i, j),
  191. colspan=column_span, rowspan=row_span)
  192. for plot in cell:
  193. if len(plot) == 1:
  194. plot_fn, args, kwargs = plot[0], (), {}
  195. elif len(plot) == 2:
  196. plot_fn, args, kwargs = plot[0], plot[1], {}
  197. elif len(plot) == 3:
  198. plot_fn, args, kwargs = plot[0], plot[1], plot[2]
  199. else:
  200. raise ValueError('Plot tuple must be of format (func), '
  201. f'or (func, tuple), or (func, tuple, dict). Got {plot}')
  202. this_args, this_docs, txt = plot_fn(*args, **kwargs)
  203. argdicts[(i, j)].append(this_args)
  204. docs[(i, j)].append(this_docs)
  205. txts[(i, j)].append(txt)
  206. return argdicts, docs, txts
  207. def render_plots(plots, exts=('png',), scale=1.0, to_disk=True):
  208. for plot in plots:
  209. print(f'Building plot {plot.name}')
  210. plot.data = None
  211. if to_disk:
  212. with lp.figure(plot.name.replace(' ', '_'), directory='output/figures',
  213. exts=exts,
  214. size=(scale * 10, scale * 10)):
  215. argdicts, docs, txts = grid_plot(plot.subplots)
  216. else:
  217. out = BytesIO()
  218. with lp.mem_figure(out,
  219. ext=exts[0],
  220. size=(scale * 10, scale * 10)):
  221. argdicts, docs, txts = grid_plot(plot.subplots)
  222. out.seek(0)
  223. plot.data = b64encode(out.read()).decode()
  224. plot.argdicts = argdicts
  225. plot.docs = docs
  226. plot.txts = txts
  227. def add_decorations(axes, luminosity, energy):
  228. cms_prelim = r'{\raggedright{}\textsf{\textbf{CMS}}\\ \emph{Preliminary}}'
  229. axes.text(0.01, 0.98, cms_prelim,
  230. horizontalalignment='left',
  231. verticalalignment='top',
  232. transform=axes.transAxes)
  233. lumi = ""
  234. energy_str = ""
  235. if luminosity is not None:
  236. lumi = r'${} \mathrm{{fb}}^{{-1}}$'.format(luminosity)
  237. if energy is not None:
  238. energy_str = r'({} TeV)'.format(energy)
  239. axes.text(1, 1, ' '.join([lumi, energy_str]),
  240. horizontalalignment='right',
  241. verticalalignment='bottom',
  242. transform=axes.transAxes)
  243. def hist_plot(h, *args, norm=None, include_errors=False,
  244. log=False, xlim=None, ylim=None, fit=None,
  245. grid=False, stats=False, **kwargs):
  246. """ Plots a 1D ROOT histogram object using matplotlib """
  247. from inspect import signature
  248. if norm:
  249. h = hist_norm(h, norm)
  250. values, errors, edges = h
  251. scale = 1. if norm is None else norm / np.sum(values)
  252. values = [val * scale for val in values]
  253. errors = [val * scale for val in errors]
  254. left, right = np.array(edges[:-1]), np.array(edges[1:])
  255. x = np.array([left, right]).T.flatten()
  256. y = np.array([values, values]).T.flatten()
  257. ax = plt.gca()
  258. ax.set_xlabel(kwargs.pop('xlabel', ''))
  259. ax.set_ylabel(kwargs.pop('ylabel', ''))
  260. title = kwargs.pop('title', '')
  261. if xlim is not None:
  262. ax.set_xlim(xlim)
  263. if ylim is not None:
  264. ax.set_ylim(ylim)
  265. # elif not log:
  266. # axes.set_ylim((0, None))
  267. ax.plot(x, y, *args, linewidth=1, **kwargs)
  268. if include_errors:
  269. ax.errorbar(hist_bin_centers(h), values, yerr=errors,
  270. color='k', marker=None, linestyle='None',
  271. barsabove=True, elinewidth=.7, capsize=1)
  272. if log:
  273. ax.set_yscale('log')
  274. if fit:
  275. f, p0 = fit
  276. popt, pcov = hist_fit(h, f, p0)
  277. fit_xs = np.linspace(x[0], x[-1], 100)
  278. fit_ys = f(fit_xs, *popt)
  279. ax.plot(fit_xs, fit_ys, '--g')
  280. arglabels = list(signature(f).parameters)[1:]
  281. label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
  282. for label, value in zip(arglabels, popt))
  283. ax.text(0.60, 0.95, label_txt, va='top', transform=ax.transAxes,
  284. fontsize='medium', family='monospace', usetex=False)
  285. if stats:
  286. _add_stats(h, title)
  287. else:
  288. ax.set_title(title)
  289. ax.grid(grid, color='#E0E0E0')
  290. def hist2d_plot(h, txt_format=None, colorbar=False, **kwargs):
  291. """ Plots a 2D ROOT histogram object using matplotlib """
  292. try:
  293. values, errors, xs, ys = h
  294. except (TypeError, ValueError):
  295. values, errors, xs, ys = hist2d(h)
  296. plt.xlabel(kwargs.pop('xlabel', ''))
  297. plt.ylabel(kwargs.pop('ylabel', ''))
  298. plt.title(kwargs.pop('title', ''))
  299. plt.pcolormesh(xs, ys, values, **kwargs)
  300. if txt_format is not None:
  301. cmap = plt.get_cmap()
  302. min_, max_ = float(np.min(values)), float(np.max(values))
  303. def get_intensity(val):
  304. cmap_idx = int((cmap.N-1) * (val - min_) / (max_-min_))
  305. color = cmap.colors[cmap_idx]
  306. return color[0]*0.25 + color[1]*0.5 + color[2]*0.25
  307. for idx_row in range(values.shape[0]):
  308. for idx_col in range(values.shape[1]):
  309. x_mid = (xs[idx_row, idx_col] + xs[idx_row, idx_col+1]) / 2
  310. y_mid = (ys[idx_row, idx_col] + ys[idx_row+1, idx_col]) / 2
  311. val = txt_format.format(values[idx_row, idx_col])
  312. txt_color = 'w' if get_intensity(values[idx_row, idx_col]) < 0.5 else 'k'
  313. plt.text(x_mid, y_mid, val, verticalalignment='center', horizontalalignment='center',
  314. color=txt_color, fontsize=12)
  315. if colorbar:
  316. plt.colorbar()
  317. def hist_plot_stack(hists: list, labels: list = None):
  318. """
  319. Creates a stacked histogram in the current axes.
  320. :param hists: list of histogram
  321. :param labels:
  322. :return:
  323. """
  324. if len(hists) == 0:
  325. return
  326. if len(set([len(hist[0]) for hist in hists])) != 1:
  327. raise ValueError("all histograms must have the same number of bins")
  328. if labels is None:
  329. labels = [None for _ in hists]
  330. if len(labels) != len(hists):
  331. raise ValueError("Label mismatch")
  332. bottoms = [0 for _ in hists[0][0]]
  333. for hist, label in zip(hists, labels):
  334. centers = []
  335. widths = []
  336. heights = []
  337. for left, right, content in zip(hist[2][:-1], hist[2][1:], hist[0]):
  338. centers.append((right + left) / 2)
  339. widths.append(right - left)
  340. heights.append(content)
  341. plt.bar(centers, heights, widths, bottoms, label=label)
  342. for i, content in enumerate(hist[0]):
  343. bottoms[i] += content
  344. def hists_to_table(hists, row_labels=(), column_labels=(), format="{:.2f}"):
  345. table = ['<table class="table table-condensed">']
  346. if column_labels:
  347. table.append('<thead><tr>')
  348. if row_labels:
  349. table.append('<th></th>')
  350. table.extend(f'<th>{label}</th>' for label in column_labels)
  351. table.append('</tr></thead>')
  352. table.append('<tbody>\n')
  353. for row_label, (vals, *_) in zip_longest(row_labels, hists):
  354. table.append('<tr>')
  355. if row_label:
  356. table.append(f'<td><strong>{row_label}</strong></td>')
  357. table.extend(('<td>'+format.format(val)+'</td>') for val in vals)
  358. table.append('</tr>\n')
  359. table.append('</tbody></table>')
  360. return ''.join(table)