plotter.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. #!/usr/bin/env python3
  2. from collections import defaultdict
  3. from io import BytesIO
  4. from base64 import b64encode
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from markdown import Markdown
  8. import latexipy as lp
  9. from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
  10. hist_normalize, hist_stats)
# Names exported by a star-import of this module.
__all__ = ['Plot',
           'decl_plot',
           'grid_plot',
           'render_plots',
           'hist_plot',
           'hist2d_plot']
  17. class Plot:
  18. def __init__(self, subplots, name, title=None, docs="N/A", argdicts={}):
  19. self.subplots = subplots
  20. self.name = name
  21. self.title = title
  22. self.docs = docs
  23. self.argdicts = argdicts
  24. MD = Markdown(extensions=['mdx_math'],
  25. extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
  26. lp.latexify(params={'pgf.texsystem': 'pdflatex',
  27. 'text.usetex': True,
  28. 'font.family': 'serif',
  29. 'pgf.preamble': [],
  30. 'font.size': 15,
  31. 'axes.labelsize': 15,
  32. 'axes.titlesize': 13,
  33. 'legend.fontsize': 13,
  34. 'xtick.labelsize': 11,
  35. 'ytick.labelsize': 11,
  36. 'figure.dpi': 150,
  37. 'savefig.transparent': False,
  38. },
  39. new_backend='TkAgg')
  40. def _fn_call_to_dict(fn, *args, **kwargs):
  41. from inspect import signature
  42. pnames = list(signature(fn).parameters)
  43. pvals = list(args)+list(kwargs.values())
  44. return {k: v for k, v in zip(pnames, pvals)}
  45. def _process_docs(fn):
  46. from inspect import getdoc
  47. raw = getdoc(fn)
  48. if raw:
  49. return MD.convert(raw)
  50. else:
  51. return None
  52. def decl_plot(fn):
  53. from functools import wraps
  54. @wraps(fn)
  55. def f(*args, **kwargs):
  56. fn(*args, **kwargs)
  57. argdict = _fn_call_to_dict(fn, *args, **kwargs)
  58. docs = _process_docs(fn)
  59. return argdict, docs
  60. return f
  61. def _add_stats(hist, title=''):
  62. fmt = r'''\begin{{eqnarray*}}
  63. \sum{{x_i}} &=& {sum:5.3f} \\
  64. \sum{{\Delta x_i \cdot x_i}} &=& {int:5.3G} \\
  65. \mu &=& {mean:5.3G} \\
  66. \sigma^2 &=& {var:5.3G} \\
  67. \sigma &=& {std:5.3G}
  68. \end{{eqnarray*}}'''
  69. txt = fmt.format(**hist_stats(hist), title=title)
  70. txt = txt.replace('\n', ' ')
  71. plt.text(0.7, 0.9, txt,
  72. bbox={'facecolor': 'white',
  73. 'alpha': 0.7,
  74. 'boxstyle': 'square,pad=0.8'},
  75. transform=plt.gca().transAxes,
  76. verticalalignment='top',
  77. horizontalalignment='left',
  78. size='small')
  79. if title:
  80. plt.text(0.72, 0.97, title,
  81. bbox={'facecolor': 'white',
  82. 'alpha': 0.8},
  83. transform=plt.gca().transAxes,
  84. verticalalignment='top',
  85. horizontalalignment='left')
  86. def grid_plot(subplots):
  87. if any(len(row) != len(subplots[0]) for row in subplots):
  88. raise ValueError("make_plot requires a rectangular list-of-lists as "
  89. "input. Fill empty slots with None")
  90. def calc_rowspan(fig, row, col):
  91. span = 1
  92. for r in range(row+1, len(fig)):
  93. if fig[r][col] == "FU":
  94. span += 1
  95. else:
  96. break
  97. return span
  98. def calc_colspan(fig, row, col):
  99. span = 1
  100. for c in range(col+1, len(fig[row])):
  101. if fig[row][c] == "FL":
  102. span += 1
  103. else:
  104. break
  105. return span
  106. rows = len(subplots)
  107. cols = len(subplots[0])
  108. argdicts = defaultdict(list)
  109. docs = defaultdict(list)
  110. for i in range(rows):
  111. for j in range(cols):
  112. cell = subplots[i][j]
  113. if cell in ("FL", "FU", None):
  114. continue
  115. if not isinstance(cell, list):
  116. cell = [cell]
  117. colspan = calc_colspan(subplots, i, j)
  118. rowspan = calc_rowspan(subplots, i, j)
  119. plt.subplot2grid((rows, cols), (i, j),
  120. colspan=colspan, rowspan=rowspan)
  121. for plot in cell:
  122. plot_fn, args, kwargs = plot
  123. this_args, this_docs = plot_fn(*args, **kwargs)
  124. argdicts[(i, j)].append(this_args)
  125. docs[(i, j)].append(this_docs)
  126. return argdicts, docs
  127. def render_plots(plots, exts=['png'], scale=1.0, to_disk=True):
  128. for plot in plots:
  129. print(f'Building plot {plot.name}')
  130. plot.data = None
  131. if to_disk:
  132. with lp.figure(plot.name, directory='output/figures',
  133. exts=exts,
  134. size=(scale*10, scale*10)):
  135. argdicts, docs = grid_plot(plot.subplots)
  136. else:
  137. out = BytesIO()
  138. with lp.mem_figure(out,
  139. ext=exts[0],
  140. size=(scale*10, scale*10)):
  141. argdicts, docs = grid_plot(plot.subplots)
  142. out.seek(0)
  143. plot.data = b64encode(out.read()).decode()
  144. plot.argdicts = argdicts
  145. plot.docs = docs
  146. def add_decorations(axes, luminosity, energy):
  147. cms_prelim = r'{\raggedright{}\textsf{\textbf{CMS}}\\ \emph{Preliminary}}'
  148. axes.text(0.01, 0.98, cms_prelim,
  149. horizontalalignment='left',
  150. verticalalignment='top',
  151. transform=axes.transAxes)
  152. lumi = ""
  153. energy_str = ""
  154. if luminosity is not None:
  155. lumi = r'${} \mathrm{{fb}}^{{-1}}$'.format(luminosity)
  156. if energy is not None:
  157. energy_str = r'({} TeV)'.format(energy)
  158. axes.text(1, 1, ' '.join([lumi, energy_str]),
  159. horizontalalignment='right',
  160. verticalalignment='bottom',
  161. transform=axes.transAxes)
  162. def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
  163. log=False, fig=None, xlim=None, ylim=None, fit=None,
  164. grid=False, stats=True, **kwargs):
  165. """ Plots a 1D ROOT histogram object using matplotlib """
  166. from inspect import signature
  167. if norm:
  168. h = hist_normalize(h, norm)
  169. values, errors, edges = h
  170. scale = 1. if norm is None else norm/np.sum(values)
  171. values = [val*scale for val in values]
  172. errors = [val*scale for val in errors]
  173. left, right = np.array(edges[:-1]), np.array(edges[1:])
  174. X = np.array([left, right]).T.flatten()
  175. Y = np.array([values, values]).T.flatten()
  176. if axes is None:
  177. import matplotlib.pyplot as plt
  178. axes = plt.gca()
  179. axes.set_xlabel(kwargs.pop('xlabel', ''))
  180. axes.set_ylabel(kwargs.pop('ylabel', ''))
  181. title = kwargs.pop('title', '')
  182. if xlim is not None:
  183. axes.set_xlim(xlim)
  184. if ylim is not None:
  185. axes.set_ylim(ylim)
  186. # elif not log:
  187. # axes.set_ylim((0, None))
  188. axes.plot(X, Y, *args, linewidth=1, **kwargs)
  189. if include_errors:
  190. axes.errorbar(hist_bin_centers(h), values, yerr=errors,
  191. color='k', marker=None, linestyle='None',
  192. barsabove=True, elinewidth=.7, capsize=1)
  193. if log:
  194. axes.set_yscale('log')
  195. if fit:
  196. f, p0 = fit
  197. popt, pcov = hist_fit(h, f, p0)
  198. fit_xs = np.linspace(X[0], X[-1], 100)
  199. fit_ys = f(fit_xs, *popt)
  200. axes.plot(fit_xs, fit_ys, '--g')
  201. arglabels = list(signature(f).parameters)[1:]
  202. label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
  203. for label, value in zip(arglabels, popt))
  204. axes.text(0.60, 0.95, label_txt, va='top', transform=axes.transAxes,
  205. fontsize='medium', family='monospace', usetex=False)
  206. if stats:
  207. _add_stats(h, title)
  208. else:
  209. axes.set_title(title)
  210. axes.grid(grid, color='#E0E0E0')
  211. def hist2d_plot(h, *args, axes=None, **kwargs):
  212. """ Plots a 2D ROOT histogram object using matplotlib """
  213. try:
  214. values, errors, xs, ys = h
  215. except (TypeError, ValueError):
  216. values, errors, xs, ys = hist2d(h)
  217. if axes is None:
  218. import matplotlib.pyplot as plt
  219. axes = plt.gca()
  220. axes.set_xlabel(kwargs.pop('xlabel', ''))
  221. axes.set_ylabel(kwargs.pop('ylabel', ''))
  222. axes.set_title(kwargs.pop('title', ''))
  223. axes.pcolormesh(xs, ys, values,)
  224. # axes.colorbar() TODO: Re-enable this
  225. def hist_plot_stack(hists):
  226. def __init__(self, title=""):
  227. # raise NotImplementedError("need to fix to not use to_bin_list")
  228. self.title = title
  229. self.xlabel = ""
  230. self.ylabel = ""
  231. self.xlim = (None, None)
  232. self.ylim = (None, None)
  233. self.logx = False
  234. self.logy = False
  235. self.backgrounds = []
  236. self.signal = None
  237. self.signal_stack = True
  238. self.data = None
  239. def add_mc_background(self, h, label, lumi=None, plot_color=''):
  240. self.backgrounds.append((label, lumi, h, plot_color))
  241. def set_mc_signal(self, h, label, lumi=None, stack=True, scale=1, plot_color=''):
  242. self.signal = (label, lumi, h, plot_color)
  243. self.signal_stack = stack
  244. self.signal_scale = scale
  245. # def set_data(self, th1, lumi=None, plot_color=''):
  246. # self.data = ('data', lumi, hist(th1), plot_color)
  247. # self.luminosity = lumi
  248. def _verify_binning_match(self):
  249. bins_count = [len(bins) for _, _, bins, _ in self.backgrounds]
  250. if self.signal is not None:
  251. bins_count.append(len(self.signal[2]))
  252. if self.data is not None:
  253. bins_count.append(len(self.data[2]))
  254. n_bins = bins_count[0]
  255. if any(bin_count != n_bins for bin_count in bins_count):
  256. raise ValueError("all histograms must have the same number of bins")
  257. self.n_bins = n_bins
  258. def save(self, filename, **kwargs):
  259. import matplotlib.pyplot as plt
  260. plt.ioff()
  261. fig = plt.figure()
  262. ax = fig.gca()
  263. self.do_draw(ax, **kwargs)
  264. fig.savefig("figures/"+filename, transparent=True)
  265. plt.close(fig)
  266. plt.ion()
  267. def draw(self, ax=None):
  268. if ax is None:
  269. ax = plt.gca()
  270. self.axeses = [ax]
  271. self._verify_binning_match()
  272. bottoms = [0]*self.n_bins
  273. if self.logx:
  274. ax.set_xscale('log')
  275. if self.logy:
  276. ax.set_yscale('log')
  277. def draw_bar(label, lumi, hist, plot_color, scale=1, stack=True, **kwargs):
  278. if stack:
  279. lefts = []
  280. widths = []
  281. heights = []
  282. for left, right, content in zip(hist[2][:-1], hist[2][1:], hist[0])
  283. lefts.append(left)
  284. widths.append(right-left)
  285. if lumi is not None:
  286. content *= self.luminosity/lumi
  287. content *= scale
  288. heights.append(content)
  289. ax.bar(lefts, heights, widths, bottoms, label=label, color=plot_color, **kwargs)
  290. for i, content in enumerate(hist[0]):
  291. if lumi is not None:
  292. content *= self.luminosity/lumi
  293. content *= scale
  294. bottoms[i] += content
  295. else:
  296. raise NotImplementedError('only supports stacks')
  297. xs = [bins[0][0] - (bins[0][1]-bins[0][0])/2]
  298. ys = [0]
  299. for left, right, content in bins:
  300. width2 = (right-left)/2
  301. if lumi is not None:
  302. content *= self.luminosity/lumi
  303. content *= scale
  304. xs.append(left-width2)
  305. ys.append(content)
  306. xs.append(right-width2)
  307. ys.append(content)
  308. xs.append(bins[-1][0] + (bins[-1][1]-bins[-1][0])/2)
  309. ys.append(0)
  310. ax.plot(xs, ys, label=label, color=plot_color, **kwargs)
  311. if self.signal is not None and self.signal_stack:
  312. label, lumi, hist, plot_color = self.signal
  313. if self.signal_scale != 1:
  314. label = r"{}$\times{:d}$".format(label, self.signal_scale)
  315. draw_bar(label, lumi, hist, plot_color, scale=self.signal_scale, hatch='/')
  316. for background in self.backgrounds:
  317. draw_bar(*background)
  318. if self.signal is not None and not self.signal_stack:
  319. # draw_bar(*self.signal, stack=False, color='k')
  320. label, lumi, bins, plot_color = self.signal
  321. if self.signal_scale != 1:
  322. label = r"{}$\times{:d}$".format(label, self.signal_scale)
  323. draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, stack=False)
  324. ax.set_title(self.title)
  325. ax.set_xlabel(self.xlabel)
  326. ax.set_ylabel(self.ylabel)
  327. ax.set_xlim(*self.xlim)
  328. if self.logy:
  329. ax.set_ylim(None, np.exp(np.log(max(bottoms))*1.4))
  330. else:
  331. ax.set_ylim(None, max(bottoms)*1.2)
  332. ax.legend(frameon=True, ncol=2)
  333. add_decorations(ax, self.luminosity, self.energy)
  334. class StackHistWithSignificance(StackHist):
  335. def __init__(self, *args, **kwargs):
  336. super().__init__(*args, **kwargs)
  337. def do_draw(self, axes, bin_significance=True, low_cut_significance=False, high_cut_significance=False):
  338. bottom_box, _, top_box = axes.get_position().splity(0.28, 0.30)
  339. axes.set_position(top_box)
  340. super().do_draw(axes)
  341. axes.set_xticks([])
  342. rhs_color = '#cc6600'
  343. bottom = axes.get_figure().add_axes(bottom_box)
  344. bottom_rhs = bottom.twinx()
  345. bgs = [0]*self.n_bins
  346. for (_, _, bins, _) in self.backgrounds:
  347. for i, (left, right, value) in enumerate(bins):
  348. bgs[i] += value
  349. sigs = [0]*self.n_bins
  350. if bin_significance:
  351. xs = []
  352. for i, (left, right, value) in enumerate(self.signal[2]):
  353. sigs[i] += value
  354. xs.append(left)
  355. xs, ys = zip(*[(x, sig/(sig+bg)) for x, sig, bg in zip(xs, sigs, bgs) if (sig+bg) > 0])
  356. bottom.plot(xs, ys, '.k')
  357. if high_cut_significance:
  358. # s/(s+b) for events passing a minimum cut requirement
  359. min_bg = [sum(bgs[i:]) for i in range(self.n_bins)]
  360. min_sig = [sum(sigs[i:]) for i in range(self.n_bins)]
  361. min_xs, min_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, min_sig, min_bg)
  362. if (sig+bg) > 0])
  363. bottom_rhs.plot(min_xs, min_ys, '->', color=rhs_color)
  364. if low_cut_significance:
  365. # s/(s+b) for events passing a maximum cut requirement
  366. max_bg = [sum(bgs[:i]) for i in range(self.n_bins)]
  367. max_sig = [sum(sigs[:i]) for i in range(self.n_bins)]
  368. max_xs, max_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, max_sig, max_bg)
  369. if (sig+bg) > 0])
  370. bottom_rhs.plot(max_xs, max_ys, '-<', color=rhs_color)
  371. bottom.set_ylabel(r'$S/(S+B)$')
  372. bottom.set_xlim(axes.get_xlim())
  373. bottom.set_ylim((0, 1.1))
  374. if low_cut_significance or high_cut_significance:
  375. bottom_rhs.set_ylabel(r'$S/\sqrt{S+B}$')
  376. bottom_rhs.yaxis.label.set_color(rhs_color)
  377. bottom_rhs.tick_params(axis='y', colors=rhs_color, size=4, width=1.5)
  378. # bottom.grid()
  379. if __name__ == '__main__':
  380. import matplotlib.pyplot as plt
  381. from utils import ResultSet
  382. rs_TTZ = ResultSet("TTZ", "../data/TTZToLLNuNu_treeProducerSusyMultilepton_tree.root")
  383. rs_TTW = ResultSet("TTW", "../data/TTWToLNu_treeProducerSusyMultilepton_tree.root")
  384. rs_TTH = ResultSet("TTH", "../data/TTHnobb_mWCutfix_ext1_treeProducerSusyMultilepton_tree.root")
  385. rs_TTTT = ResultSet("TTTT", "../data/TTTT_ext_treeProducerSusyMultilepton_tree.root")
  386. sh = StackHist('B-Jet Multiplicity')
  387. sh.add_mc_background(rs_TTZ.b_jet_count, 'TTZ', lumi=40)
  388. sh.add_mc_background(rs_TTW.b_jet_count, 'TTW', lumi=40)
  389. sh.add_mc_background(rs_TTH.b_jet_count, 'TTH', lumi=40)
  390. sh.set_mc_signal(rs_TTTT.b_jet_count, 'TTTT', lumi=40, scale=10)
  391. sh.luminosity = 40
  392. sh.energy = 13
  393. sh.xlabel = 'B-Jet Count'
  394. sh.ylabel = r'\# Events'
  395. sh.xlim = (-.5, 9.5)
  396. sh.signal_stack = False
  397. fig = plt.figure()
  398. sh.draw(fig.gca())
  399. plt.show()
  400. # sh.add_data(rs_TTZ.b_jet_count, 'TTZ')