plotter.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. #!/usr/bin/env python3
  2. from collections import defaultdict
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from markdown import Markdown
  6. import latexipy as lp
  7. from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
  8. hist_normalize)
  9. __all__ = ['Plot',
  10. 'decl_plot',
  11. 'grid_plot',
  12. 'save_plots',
  13. 'hist_plot',
  14. 'hist2d_plot']
  15. class Plot:
  16. def __init__(self, subplots, name, title=None, docs="N/A", argdicts={}):
  17. self.subplots = subplots
  18. self.name = name
  19. self.title = title
  20. self.docs = docs
  21. self.argdicts = argdicts
  22. MD = Markdown(extensions=['mdx_math'],
  23. extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
  24. lp.latexify(params={'pgf.texsystem': 'pdflatex',
  25. 'text.usetex': True,
  26. 'font.family': 'serif',
  27. 'pgf.preamble': [],
  28. 'font.size': 15,
  29. 'axes.labelsize': 15,
  30. 'axes.titlesize': 13,
  31. 'legend.fontsize': 13,
  32. 'xtick.labelsize': 11,
  33. 'ytick.labelsize': 11,
  34. 'figure.dpi': 150,
  35. 'savefig.transparent': False,
  36. },
  37. new_backend='TkAgg')
  38. def _fn_call_to_dict(fn, *args, **kwargs):
  39. from inspect import signature
  40. pnames = list(signature(fn).parameters)
  41. pvals = list(args)+list(kwargs.values())
  42. return {k: v for k, v in zip(pnames, pvals)}
  43. def _process_docs(fn):
  44. from inspect import getdoc
  45. raw = getdoc(fn)
  46. if raw:
  47. return MD.convert(raw)
  48. else:
  49. return None
  50. def decl_plot(fn):
  51. from functools import wraps
  52. @wraps(fn)
  53. def f(*args, **kwargs):
  54. fn(*args, **kwargs)
  55. argdict = _fn_call_to_dict(fn, *args, **kwargs)
  56. docs = _process_docs(fn)
  57. return argdict, docs
  58. return f
  59. def grid_plot(subplots):
  60. if any(len(row) != len(subplots[0]) for row in subplots):
  61. raise ValueError("make_plot requires a rectangular list-of-lists as "
  62. "input. Fill empty slots with None")
  63. def calc_rowspan(fig, row, col):
  64. span = 1
  65. for r in range(row+1, len(fig)):
  66. if fig[r][col] == "FU":
  67. span += 1
  68. else:
  69. break
  70. return span
  71. def calc_colspan(fig, row, col):
  72. span = 1
  73. for c in range(col+1, len(fig[row])):
  74. if fig[row][c] == "FL":
  75. span += 1
  76. else:
  77. break
  78. return span
  79. rows = len(subplots)
  80. cols = len(subplots[0])
  81. argdicts = defaultdict(list)
  82. docs = defaultdict(list)
  83. for i in range(rows):
  84. for j in range(cols):
  85. cell = subplots[i][j]
  86. if cell in ("FL", "FU", None):
  87. continue
  88. if not isinstance(cell, list):
  89. cell = [cell]
  90. colspan = calc_colspan(subplots, i, j)
  91. rowspan = calc_rowspan(subplots, i, j)
  92. plt.subplot2grid((rows, cols), (i, j),
  93. colspan=colspan, rowspan=rowspan)
  94. for plot in cell:
  95. plot_fn, args, kwargs = plot
  96. this_args, this_docs = plot_fn(*args, **kwargs)
  97. argdicts[(i, j)].append(this_args)
  98. docs[(i, j)].append(this_docs)
  99. return argdicts, docs
  100. def save_plots(plots, exts=['png'], scale=1.0):
  101. for plot in plots:
  102. print(f'Building plot {plot.name}')
  103. with lp.figure(plot.name, directory='output/figures',
  104. exts=exts,
  105. size=(scale*10, scale*10)):
  106. argdicts, docs = grid_plot(plot.subplots)
  107. plot.argdicts = argdicts
  108. plot.docs = docs
  109. def add_decorations(axes, luminosity, energy):
  110. cms_prelim = r'{\raggedright{}\textsf{\textbf{CMS}}\\ \emph{Preliminary}}'
  111. axes.text(0.01, 0.98, cms_prelim,
  112. horizontalalignment='left',
  113. verticalalignment='top',
  114. transform=axes.transAxes)
  115. lumi = ""
  116. energy_str = ""
  117. if luminosity is not None:
  118. lumi = r'${} \mathrm{{fb}}^{{-1}}$'.format(luminosity)
  119. if energy is not None:
  120. energy_str = r'({} TeV)'.format(energy)
  121. axes.text(1, 1, ' '.join([lumi, energy_str]),
  122. horizontalalignment='right',
  123. verticalalignment='bottom',
  124. transform=axes.transAxes)
  125. def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
  126. log=False, fig=None, xlim=None, ylim=None, fit=None,
  127. grid=False, **kwargs):
  128. """ Plots a 1D ROOT histogram object using matplotlib """
  129. from inspect import signature
  130. if norm:
  131. h = hist_normalize(h, norm)
  132. values, errors, edges = h
  133. scale = 1. if norm is None else norm/np.sum(values)
  134. values = [val*scale for val in values]
  135. errors = [val*scale for val in errors]
  136. left, right = np.array(edges[:-1]), np.array(edges[1:])
  137. X = np.array([left, right]).T.flatten()
  138. Y = np.array([values, values]).T.flatten()
  139. if axes is None:
  140. import matplotlib.pyplot as plt
  141. axes = plt.gca()
  142. axes.set_xlabel(kwargs.pop('xlabel', ''))
  143. axes.set_ylabel(kwargs.pop('ylabel', ''))
  144. axes.set_title(kwargs.pop('title', ''))
  145. if xlim is not None:
  146. axes.set_xlim(xlim)
  147. if ylim is not None:
  148. axes.set_ylim(ylim)
  149. # elif not log:
  150. # axes.set_ylim((0, None))
  151. axes.plot(X, Y, *args, linewidth=1, **kwargs)
  152. if include_errors:
  153. axes.errorbar(hist_bin_centers(h), values, yerr=errors,
  154. color='k', marker=None, linestyle='None',
  155. barsabove=True, elinewidth=.7, capsize=1)
  156. if log:
  157. axes.set_yscale('log')
  158. if fit:
  159. f, p0 = fit
  160. popt, pcov = hist_fit(h, f, p0)
  161. fit_xs = np.linspace(X[0], X[-1], 100)
  162. fit_ys = f(fit_xs, *popt)
  163. axes.plot(fit_xs, fit_ys, '--g')
  164. arglabels = list(signature(f).parameters)[1:]
  165. label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
  166. for label, value in zip(arglabels, popt))
  167. axes.text(0.60, 0.95, label_txt, va='top', transform=axes.transAxes,
  168. fontsize='medium', family='monospace', usetex=False)
  169. axes.grid(grid, color='#E0E0E0')
  170. def hist2d_plot(h, *args, axes=None, **kwargs):
  171. """ Plots a 2D ROOT histogram object using matplotlib """
  172. try:
  173. values, errors, xs, ys = h
  174. except (TypeError, ValueError):
  175. values, errors, xs, ys = hist2d(h)
  176. if axes is None:
  177. import matplotlib.pyplot as plt
  178. axes = plt.gca()
  179. axes.set_xlabel(kwargs.pop('xlabel', ''))
  180. axes.set_ylabel(kwargs.pop('ylabel', ''))
  181. axes.set_title(kwargs.pop('title', ''))
  182. axes.pcolormesh(xs, ys, values,)
  183. # axes.colorbar() TODO: Re-enable this
  184. class StackHist:
  185. def __init__(self, title=""):
  186. raise NotImplementedError("need to fix to not use to_bin_list")
  187. self.title = title
  188. self.xlabel = ""
  189. self.ylabel = ""
  190. self.xlim = (None, None)
  191. self.ylim = (None, None)
  192. self.logx = False
  193. self.logy = False
  194. self.backgrounds = []
  195. self.signal = None
  196. self.signal_stack = True
  197. self.data = None
  198. def add_mc_background(self, th1, label, lumi=None, plot_color=''):
  199. self.backgrounds.append((label, lumi, hist(th1), plot_color))
  200. def set_mc_signal(self, th1, label, lumi=None, stack=True, scale=1, plot_color=''):
  201. self.signal = (label, lumi, hist(th1), plot_color)
  202. self.signal_stack = stack
  203. self.signal_scale = scale
  204. def set_data(self, th1, lumi=None, plot_color=''):
  205. self.data = ('data', lumi, hist(th1), plot_color)
  206. self.luminosity = lumi
  207. def _verify_binning_match(self):
  208. bins_count = [len(bins) for _, _, bins, _ in self.backgrounds]
  209. if self.signal is not None:
  210. bins_count.append(len(self.signal[2]))
  211. if self.data is not None:
  212. bins_count.append(len(self.data[2]))
  213. n_bins = bins_count[0]
  214. if any(bin_count != n_bins for bin_count in bins_count):
  215. raise ValueError("all histograms must have the same number of bins")
  216. self.n_bins = n_bins
  217. def save(self, filename, **kwargs):
  218. import matplotlib.pyplot as plt
  219. plt.ioff()
  220. fig = plt.figure()
  221. ax = fig.gca()
  222. self.do_draw(ax, **kwargs)
  223. fig.savefig("figures/"+filename, transparent=True)
  224. plt.close(fig)
  225. plt.ion()
  226. def do_draw(self, axes):
  227. self.axeses = [axes]
  228. self._verify_binning_match()
  229. bottoms = [0]*self.n_bins
  230. if self.logx:
  231. axes.set_xscale('log')
  232. if self.logy:
  233. axes.set_yscale('log')
  234. def draw_bar(label, lumi, bins, plot_color, scale=1, stack=True, **kwargs):
  235. if stack:
  236. lefts = []
  237. widths = []
  238. heights = []
  239. for left, right, content in bins:
  240. lefts.append(left)
  241. widths.append(right-left)
  242. if lumi is not None:
  243. content *= self.luminosity/lumi
  244. content *= scale
  245. heights.append(content)
  246. axes.bar(lefts, heights, widths, bottoms, label=label, color=plot_color, **kwargs)
  247. for i, (_, _, content) in enumerate(bins):
  248. if lumi is not None:
  249. content *= self.luminosity/lumi
  250. content *= scale
  251. bottoms[i] += content
  252. else:
  253. xs = [bins[0][0] - (bins[0][1]-bins[0][0])/2]
  254. ys = [0]
  255. for left, right, content in bins:
  256. width2 = (right-left)/2
  257. if lumi is not None:
  258. content *= self.luminosity/lumi
  259. content *= scale
  260. xs.append(left-width2)
  261. ys.append(content)
  262. xs.append(right-width2)
  263. ys.append(content)
  264. xs.append(bins[-1][0] + (bins[-1][1]-bins[-1][0])/2)
  265. ys.append(0)
  266. axes.plot(xs, ys, label=label, color=plot_color, **kwargs)
  267. if self.signal is not None and self.signal_stack:
  268. label, lumi, bins, plot_color = self.signal
  269. if self.signal_scale != 1:
  270. label = r"{}$\times{:d}$".format(label, self.signal_scale)
  271. draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, hatch='/')
  272. for background in self.backgrounds:
  273. draw_bar(*background)
  274. if self.signal is not None and not self.signal_stack:
  275. # draw_bar(*self.signal, stack=False, color='k')
  276. label, lumi, bins, plot_color = self.signal
  277. if self.signal_scale != 1:
  278. label = r"{}$\times{:d}$".format(label, self.signal_scale)
  279. draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, stack=False)
  280. axes.set_title(self.title)
  281. axes.set_xlabel(self.xlabel)
  282. axes.set_ylabel(self.ylabel)
  283. axes.set_xlim(*self.xlim)
  284. # axes.set_ylim(*self.ylim)
  285. if self.logy:
  286. axes.set_ylim(None, np.exp(np.log(max(bottoms))*1.4))
  287. else:
  288. axes.set_ylim(None, max(bottoms)*1.2)
  289. axes.legend(frameon=True, ncol=2)
  290. add_decorations(axes, self.luminosity, self.energy)
  291. def draw(self, axes, save=False, filename=None, **kwargs):
  292. self.do_draw(axes, **kwargs)
  293. if save:
  294. if filename is None:
  295. filename = "".join(c for c in self.title if c.isalnum() or c in (' ._+-'))+".png"
  296. self.save(filename, **kwargs)
  297. class StackHistWithSignificance(StackHist):
  298. def __init__(self, *args, **kwargs):
  299. super().__init__(*args, **kwargs)
  300. def do_draw(self, axes, bin_significance=True, low_cut_significance=False, high_cut_significance=False):
  301. bottom_box, _, top_box = axes.get_position().splity(0.28, 0.30)
  302. axes.set_position(top_box)
  303. super().do_draw(axes)
  304. axes.set_xticks([])
  305. rhs_color = '#cc6600'
  306. bottom = axes.get_figure().add_axes(bottom_box)
  307. bottom_rhs = bottom.twinx()
  308. bgs = [0]*self.n_bins
  309. for (_, _, bins, _) in self.backgrounds:
  310. for i, (left, right, value) in enumerate(bins):
  311. bgs[i] += value
  312. sigs = [0]*self.n_bins
  313. if bin_significance:
  314. xs = []
  315. for i, (left, right, value) in enumerate(self.signal[2]):
  316. sigs[i] += value
  317. xs.append(left)
  318. xs, ys = zip(*[(x, sig/(sig+bg)) for x, sig, bg in zip(xs, sigs, bgs) if (sig+bg) > 0])
  319. bottom.plot(xs, ys, '.k')
  320. if high_cut_significance:
  321. # s/(s+b) for events passing a minimum cut requirement
  322. min_bg = [sum(bgs[i:]) for i in range(self.n_bins)]
  323. min_sig = [sum(sigs[i:]) for i in range(self.n_bins)]
  324. min_xs, min_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, min_sig, min_bg)
  325. if (sig+bg) > 0])
  326. bottom_rhs.plot(min_xs, min_ys, '->', color=rhs_color)
  327. if low_cut_significance:
  328. # s/(s+b) for events passing a maximum cut requirement
  329. max_bg = [sum(bgs[:i]) for i in range(self.n_bins)]
  330. max_sig = [sum(sigs[:i]) for i in range(self.n_bins)]
  331. max_xs, max_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, max_sig, max_bg)
  332. if (sig+bg) > 0])
  333. bottom_rhs.plot(max_xs, max_ys, '-<', color=rhs_color)
  334. bottom.set_ylabel(r'$S/(S+B)$')
  335. bottom.set_xlim(axes.get_xlim())
  336. bottom.set_ylim((0, 1.1))
  337. if low_cut_significance or high_cut_significance:
  338. bottom_rhs.set_ylabel(r'$S/\sqrt{S+B}$')
  339. bottom_rhs.yaxis.label.set_color(rhs_color)
  340. bottom_rhs.tick_params(axis='y', colors=rhs_color, size=4, width=1.5)
  341. # bottom.grid()
  342. if __name__ == '__main__':
  343. import matplotlib.pyplot as plt
  344. from utils import ResultSet
  345. rs_TTZ = ResultSet("TTZ", "../data/TTZToLLNuNu_treeProducerSusyMultilepton_tree.root")
  346. rs_TTW = ResultSet("TTW", "../data/TTWToLNu_treeProducerSusyMultilepton_tree.root")
  347. rs_TTH = ResultSet("TTH", "../data/TTHnobb_mWCutfix_ext1_treeProducerSusyMultilepton_tree.root")
  348. rs_TTTT = ResultSet("TTTT", "../data/TTTT_ext_treeProducerSusyMultilepton_tree.root")
  349. sh = StackHist('B-Jet Multiplicity')
  350. sh.add_mc_background(rs_TTZ.b_jet_count, 'TTZ', lumi=40)
  351. sh.add_mc_background(rs_TTW.b_jet_count, 'TTW', lumi=40)
  352. sh.add_mc_background(rs_TTH.b_jet_count, 'TTH', lumi=40)
  353. sh.set_mc_signal(rs_TTTT.b_jet_count, 'TTTT', lumi=40, scale=10)
  354. sh.luminosity = 40
  355. sh.energy = 13
  356. sh.xlabel = 'B-Jet Count'
  357. sh.ylabel = r'\# Events'
  358. sh.xlim = (-.5, 9.5)
  359. sh.signal_stack = False
  360. fig = plt.figure()
  361. sh.draw(fig.gca())
  362. plt.show()
  363. # sh.add_data(rs_TTZ.b_jet_count, 'TTZ')