plotter.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. #!/usr/bin/env python3
  2. from collections import namedtuple
  3. import numpy as np
  4. from markdown import Markdown
  5. import latexipy as lp
  6. from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
  7. hist_normalize)
  8. __all__ = ['make_plot',
  9. 'plot_registry',
  10. 'hist_plot',
  11. 'hist2d_plot']
  12. plot_registry = {}
  13. Plot = namedtuple('Plot', ['name', 'filename', 'title', 'desc', 'args'])
  14. MD = Markdown(extensions=['mdx_math'],
  15. extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
  16. lp.latexify(params={'pgf.texsystem': 'xelatex',
  17. 'text.usetex': True,
  18. 'font.family': 'serif',
  19. 'pgf.preamble': [r'\usepackage[utf8x]{inputenc}',
  20. r'\usepackage[T1]{fontenc}'],
  21. 'font.size': 8,
  22. 'axes.labelsize': 8,
  23. 'axes.titlesize': 8,
  24. 'legend.fontsize': 8,
  25. 'xtick.labelsize': 8,
  26. 'ytick.labelsize': 8,
  27. 'figure.dpi': 150,
  28. 'savefig.transparent': True,
  29. },
  30. new_backend='TkAgg')
  31. def make_plot(plot_name=None, title='', scale=1, exts=['png', 'pgf']):
  32. import matplotlib.pyplot as plt
  33. from functools import wraps
  34. from os.path import join
  35. from inspect import signature, getdoc
  36. def fn_call_to_dict(fn, *args, **kwargs):
  37. pnames = list(signature(fn).parameters)
  38. pvals = list(args)+list(kwargs.keys())
  39. return {k: v for k, v in zip(pnames, pvals)}
  40. def process_docs(fn):
  41. raw = getdoc(fn)
  42. if raw:
  43. return MD.convert(raw)
  44. else:
  45. return None
  46. def wrap(fn):
  47. @wraps(fn)
  48. def f(*args, **kwargs):
  49. nonlocal plot_name
  50. pdict = fn_call_to_dict(fn, *args, **kwargs)
  51. if plot_name is None:
  52. pstr = ','.join('{}:{}'.format(pname, pval)
  53. for pname, pval in pdict.items())
  54. plot_name = fn.__name__ + '::' + pstr
  55. plot_name = plot_name.replace('/', '_').replace('.', '_')
  56. with lp.figure(plot_name, directory='output/figures',
  57. exts=exts,
  58. size=(scale*10, scale*10)):
  59. fn(*args, **kwargs)
  60. plt.tight_layout()
  61. filename = plot_name+".png"
  62. plot_registry[fn.__name__] = Plot(fn.__name__, join('figures', filename),
  63. title, process_docs(fn), pdict)
  64. return f
  65. return wrap
  66. def add_decorations(axes, luminosity, energy):
  67. cms_prelim = r'{\raggedright{}\textsf{\textbf{CMS}}\\ \emph{Preliminary}}'
  68. axes.text(0.01, 0.98, cms_prelim,
  69. horizontalalignment='left',
  70. verticalalignment='top',
  71. transform=axes.transAxes)
  72. lumi = ""
  73. energy_str = ""
  74. if luminosity is not None:
  75. lumi = r'${} \mathrm{{fb}}^{{-1}}$'.format(luminosity)
  76. if energy is not None:
  77. energy_str = r'({} TeV)'.format(energy)
  78. axes.text(1, 1, ' '.join([lumi, energy_str]),
  79. horizontalalignment='right',
  80. verticalalignment='bottom',
  81. transform=axes.transAxes)
  82. def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
  83. log=False, fig=None, xlim=None, ylim=None, fit=None,
  84. **kwargs):
  85. """ Plots a 1D ROOT histogram object using matplotlib """
  86. from inspect import signature
  87. if norm:
  88. h = hist_normalize(h, norm)
  89. values, errors, edges = h
  90. scale = 1. if norm is None else norm/np.sum(values)
  91. values = [val*scale for val in values]
  92. errors = [val*scale for val in errors]
  93. left, right = np.array(edges[:-1]), np.array(edges[1:])
  94. X = np.array([left, right]).T.flatten()
  95. Y = np.array([values, values]).T.flatten()
  96. if axes is None:
  97. import matplotlib.pyplot as plt
  98. axes = plt.gca()
  99. axes.set_xlabel(kwargs.pop('xlabel', ''))
  100. axes.set_ylabel(kwargs.pop('ylabel', ''))
  101. axes.set_title(kwargs.pop('title', ''))
  102. if xlim is not None:
  103. axes.set_xlim(xlim)
  104. if ylim is not None:
  105. axes.set_ylim(ylim)
  106. # elif not log:
  107. # axes.set_ylim((0, None))
  108. axes.plot(X, Y, *args, linewidth=1, **kwargs)
  109. if include_errors:
  110. axes.errorbar(hist_bin_centers(h), values, yerr=errors,
  111. color='k', marker=None, linestyle='None',
  112. barsabove=True, elinewidth=.7, capsize=1)
  113. if log:
  114. axes.set_yscale('log')
  115. if fit:
  116. f, p0 = fit
  117. popt, pcov = hist_fit(h, f, p0)
  118. fit_xs = np.linspace(X[0], X[-1], 100)
  119. fit_ys = f(fit_xs, *popt)
  120. axes.plot(fit_xs, fit_ys, '--g')
  121. arglabels = list(signature(f).parameters)[1:]
  122. label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
  123. for label, value in zip(arglabels, popt))
  124. axes.text(0.60, 0.95, label_txt, va='top', transform=axes.transAxes,
  125. fontsize='medium', family='monospace', usetex=False)
  126. axes.grid()
  127. def hist2d_plot(h, *args, axes=None, **kwargs):
  128. """ Plots a 2D ROOT histogram object using matplotlib """
  129. try:
  130. values, errors, xs, ys = h
  131. except (TypeError, ValueError):
  132. values, errors, xs, ys = hist2d(h)
  133. if axes is None:
  134. import matplotlib.pyplot as plt
  135. axes = plt.gca()
  136. axes.set_xlabel(kwargs.pop('xlabel', ''))
  137. axes.set_ylabel(kwargs.pop('ylabel', ''))
  138. axes.set_title(kwargs.pop('title', ''))
  139. axes.pcolormesh(xs, ys, values,)
  140. # axes.colorbar() TODO: Re-enable this
class StackHist:
    """Stacked plot of MC background histograms plus optional signal and data.

    NOTE(review): ``__init__`` raises NotImplementedError unconditionally, so
    this class (and every subclass) cannot currently be instantiated; the
    remaining code is kept for reference. The drawing code iterates each
    stored histogram as ``(left, right, content)`` triples, while the stored
    objects come from ``hist(th1)``. Per the error message, these presumably
    no longer agree (the old ``to_bin_list`` layout) -- confirm before
    re-enabling.
    """

    def __init__(self, title=""):
        # Deliberately disabled: every statement below this raise is unreachable.
        raise NotImplementedError("need to fix to not use to_bin_list")
        self.title = title
        self.xlabel = ""
        self.ylabel = ""
        self.xlim = (None, None)  # unpacked into axes.set_xlim in do_draw
        self.ylim = (None, None)  # currently unused (set_ylim call is commented out)
        self.logx = False
        self.logy = False
        self.backgrounds = []  # list of (label, lumi, hist, plot_color) tuples
        self.signal = None  # (label, lumi, hist, plot_color) or None
        self.signal_stack = True  # stack the signal on the backgrounds vs. draw an outline
        self.data = None  # ('data', lumi, hist, plot_color) or None
        # NOTE(review): luminosity, energy and signal_scale are not initialised
        # here; do_draw reads luminosity/energy, so they must be set elsewhere
        # (set_data sets luminosity; callers set both directly).

    def add_mc_background(self, th1, label, lumi=None, plot_color=''):
        """Append a background histogram to the stack.

        ``lumi`` is the luminosity the histogram corresponds to; when not
        None, contents are rescaled by ``self.luminosity / lumi`` at draw time.
        """
        self.backgrounds.append((label, lumi, hist(th1), plot_color))

    def set_mc_signal(self, th1, label, lumi=None, stack=True, scale=1, plot_color=''):
        """Set the signal histogram.

        ``stack`` chooses between stacked bars and an unstacked outline.
        ``scale`` multiplies its contents; it is formatted with ``{:d}`` in
        the legend label, so a non-int value other than 1 would raise there.
        """
        self.signal = (label, lumi, hist(th1), plot_color)
        self.signal_stack = stack
        self.signal_scale = scale

    def set_data(self, th1, lumi=None, plot_color=''):
        """Set the data histogram; its lumi also becomes the plot's luminosity.

        NOTE(review): do_draw never draws self.data; it only contributes to
        the bin-count check.
        """
        self.data = ('data', lumi, hist(th1), plot_color)
        self.luminosity = lumi

    def _verify_binning_match(self):
        """Check that every stored histogram has the same bin count; store it in self.n_bins.

        Raises ValueError on a mismatch. With no histograms at all,
        ``bins_count[0]`` raises IndexError.
        """
        bins_count = [len(bins) for _, _, bins, _ in self.backgrounds]
        if self.signal is not None:
            bins_count.append(len(self.signal[2]))
        if self.data is not None:
            bins_count.append(len(self.data[2]))
        n_bins = bins_count[0]
        if any(bin_count != n_bins for bin_count in bins_count):
            raise ValueError("all histograms must have the same number of bins")
        self.n_bins = n_bins

    def save(self, filename, **kwargs):
        """Draw onto a fresh non-interactive figure and save it as figures/<filename>.

        ``kwargs`` are forwarded to ``do_draw``. The base do_draw accepts
        none, so only subclasses with extra options can use them.
        """
        import matplotlib.pyplot as plt
        plt.ioff()
        fig = plt.figure()
        ax = fig.gca()
        self.do_draw(ax, **kwargs)
        fig.savefig("figures/"+filename, transparent=True)
        plt.close(fig)
        plt.ion()

    def do_draw(self, axes):
        """Render the stack, signal, labels, limits, legend and decorations on *axes*."""
        self.axeses = [axes]
        self._verify_binning_match()
        # Running per-bin top of the stack; each stacked bar starts here.
        bottoms = [0]*self.n_bins

        if self.logx:
            axes.set_xscale('log')
        if self.logy:
            axes.set_yscale('log')

        def draw_bar(label, lumi, bins, plot_color, scale=1, stack=True, **kwargs):
            # Each bin's content is rescaled by self.luminosity/lumi (when lumi
            # is given) and then by scale.
            if stack:
                # Stacked: draw bars on top of the current bottoms ...
                lefts = []
                widths = []
                heights = []
                for left, right, content in bins:
                    lefts.append(left)
                    widths.append(right-left)
                    if lumi is not None:
                        content *= self.luminosity/lumi
                    content *= scale
                    heights.append(content)

                axes.bar(lefts, heights, widths, bottoms, label=label, color=plot_color, **kwargs)
                # ... then raise bottoms by the same rescaled contents.
                for i, (_, _, content) in enumerate(bins):
                    if lumi is not None:
                        content *= self.luminosity/lumi
                    content *= scale
                    bottoms[i] += content
            else:
                # Unstacked: a step outline starting and ending at y=0, with
                # every vertex shifted left by half a bin width.
                # NOTE(review): presumably to centre integer-valued bins -- confirm.
                xs = [bins[0][0] - (bins[0][1]-bins[0][0])/2]
                ys = [0]
                for left, right, content in bins:
                    width2 = (right-left)/2
                    if lumi is not None:
                        content *= self.luminosity/lumi
                    content *= scale
                    xs.append(left-width2)
                    ys.append(content)
                    xs.append(right-width2)
                    ys.append(content)
                xs.append(bins[-1][0] + (bins[-1][1]-bins[-1][0])/2)
                ys.append(0)
                axes.plot(xs, ys, label=label, color=plot_color, **kwargs)

        # Stacked signal goes at the bottom of the stack, hatched.
        if self.signal is not None and self.signal_stack:
            label, lumi, bins, plot_color = self.signal
            if self.signal_scale != 1:
                label = r"{}$\times{:d}$".format(label, self.signal_scale)
            draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, hatch='/')

        for background in self.backgrounds:
            draw_bar(*background)

        # Unstacked signal is drawn last, as an outline over the stack.
        if self.signal is not None and not self.signal_stack:
            # draw_bar(*self.signal, stack=False, color='k')
            label, lumi, bins, plot_color = self.signal
            if self.signal_scale != 1:
                label = r"{}$\times{:d}$".format(label, self.signal_scale)
            draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, stack=False)

        axes.set_title(self.title)
        axes.set_xlabel(self.xlabel)
        axes.set_ylabel(self.ylabel)
        axes.set_xlim(*self.xlim)
        # axes.set_ylim(*self.ylim)
        # Headroom above the tallest stack: max**1.4 on a log axis, 1.2*max on a
        # linear one. NOTE(review): only the stack (bottoms) is considered, so
        # an unstacked signal taller than the stack may be clipped.
        if self.logy:
            axes.set_ylim(None, np.exp(np.log(max(bottoms))*1.4))
        else:
            axes.set_ylim(None, max(bottoms)*1.2)
        axes.legend(frameon=True, ncol=2)
        add_decorations(axes, self.luminosity, self.energy)

    def draw(self, axes, save=False, filename=None, **kwargs):
        """Draw onto *axes*; if *save*, also render and save a separate copy.

        Without an explicit filename, the title is reduced to alphanumerics
        plus ' ._+-' and '.png' is appended.
        """
        self.do_draw(axes, **kwargs)
        if save:
            if filename is None:
                filename = "".join(c for c in self.title if c.isalnum() or c in (' ._+-'))+".png"
            self.save(filename, **kwargs)
  254. class StackHistWithSignificance(StackHist):
  255. def __init__(self, *args, **kwargs):
  256. super().__init__(*args, **kwargs)
  257. def do_draw(self, axes, bin_significance=True, low_cut_significance=False, high_cut_significance=False):
  258. bottom_box, _, top_box = axes.get_position().splity(0.28, 0.30)
  259. axes.set_position(top_box)
  260. super().do_draw(axes)
  261. axes.set_xticks([])
  262. rhs_color = '#cc6600'
  263. bottom = axes.get_figure().add_axes(bottom_box)
  264. bottom_rhs = bottom.twinx()
  265. bgs = [0]*self.n_bins
  266. for (_, _, bins, _) in self.backgrounds:
  267. for i, (left, right, value) in enumerate(bins):
  268. bgs[i] += value
  269. sigs = [0]*self.n_bins
  270. if bin_significance:
  271. xs = []
  272. for i, (left, right, value) in enumerate(self.signal[2]):
  273. sigs[i] += value
  274. xs.append(left)
  275. xs, ys = zip(*[(x, sig/(sig+bg)) for x, sig, bg in zip(xs, sigs, bgs) if (sig+bg) > 0])
  276. bottom.plot(xs, ys, '.k')
  277. if high_cut_significance:
  278. # s/(s+b) for events passing a minimum cut requirement
  279. min_bg = [sum(bgs[i:]) for i in range(self.n_bins)]
  280. min_sig = [sum(sigs[i:]) for i in range(self.n_bins)]
  281. min_xs, min_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, min_sig, min_bg)
  282. if (sig+bg) > 0])
  283. bottom_rhs.plot(min_xs, min_ys, '->', color=rhs_color)
  284. if low_cut_significance:
  285. # s/(s+b) for events passing a maximum cut requirement
  286. max_bg = [sum(bgs[:i]) for i in range(self.n_bins)]
  287. max_sig = [sum(sigs[:i]) for i in range(self.n_bins)]
  288. max_xs, max_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, max_sig, max_bg)
  289. if (sig+bg) > 0])
  290. bottom_rhs.plot(max_xs, max_ys, '-<', color=rhs_color)
  291. bottom.set_ylabel(r'$S/(S+B)$')
  292. bottom.set_xlim(axes.get_xlim())
  293. bottom.set_ylim((0, 1.1))
  294. if low_cut_significance or high_cut_significance:
  295. bottom_rhs.set_ylabel(r'$S/\sqrt{S+B}$')
  296. bottom_rhs.yaxis.label.set_color(rhs_color)
  297. bottom_rhs.tick_params(axis='y', colors=rhs_color, size=4, width=1.5)
  298. # bottom.grid()
  299. if __name__ == '__main__':
  300. import matplotlib.pyplot as plt
  301. from utils import ResultSet
  302. rs_TTZ = ResultSet("TTZ", "../data/TTZToLLNuNu_treeProducerSusyMultilepton_tree.root")
  303. rs_TTW = ResultSet("TTW", "../data/TTWToLNu_treeProducerSusyMultilepton_tree.root")
  304. rs_TTH = ResultSet("TTH", "../data/TTHnobb_mWCutfix_ext1_treeProducerSusyMultilepton_tree.root")
  305. rs_TTTT = ResultSet("TTTT", "../data/TTTT_ext_treeProducerSusyMultilepton_tree.root")
  306. sh = StackHist('B-Jet Multiplicity')
  307. sh.add_mc_background(rs_TTZ.b_jet_count, 'TTZ', lumi=40)
  308. sh.add_mc_background(rs_TTW.b_jet_count, 'TTW', lumi=40)
  309. sh.add_mc_background(rs_TTH.b_jet_count, 'TTH', lumi=40)
  310. sh.set_mc_signal(rs_TTTT.b_jet_count, 'TTTT', lumi=40, scale=10)
  311. sh.luminosity = 40
  312. sh.energy = 13
  313. sh.xlabel = 'B-Jet Count'
  314. sh.ylabel = r'\# Events'
  315. sh.xlim = (-.5, 9.5)
  316. sh.signal_stack = False
  317. fig = plt.figure()
  318. sh.draw(fig.gca())
  319. plt.show()
  320. # sh.add_data(rs_TTZ.b_jet_count, 'TTZ')