plotter.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. #!/usr/bin/env python3
  2. from collections import namedtuple
  3. import matplotlib as mpl
  4. import numpy as np
  5. from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
  6. hist_normalize)
  7. # mpl.rc('text', usetex=True)
  8. # mpl.rc('figure', dpi=200)
  9. # mpl.rc('savefig', dpi=200)
  10. plot_registry = {}
  11. Plot = namedtuple('Plot', ['name', 'filename', 'title', 'desc', 'args'])
  12. def make_plot(filename=None, title='', scale=1):
  13. import matplotlib.pyplot as plt
  14. from functools import wraps
  15. from os.path import join
  16. from os import makedirs
  17. from inspect import signature, getdoc
  18. from markdown import Markdown
  19. def fn_call_to_dict(fn, *args, **kwargs):
  20. pnames = list(signature(fn).parameters)
  21. pvals = list(args)+list(kwargs.keys())
  22. return {k: v for k, v in zip(pnames, pvals)}
  23. def process_docs(fn):
  24. raw = getdoc(fn)
  25. if raw:
  26. md = Markdown(extensions=['mdx_math'],
  27. extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
  28. return md.convert(raw)
  29. else:
  30. return None
  31. def wrap(fn):
  32. @wraps(fn)
  33. def f(*args, **kwargs):
  34. nonlocal filename
  35. plt.clf()
  36. plt.gcf().set_size_inches(scale*10, scale*10)
  37. fn(*args, **kwargs)
  38. pdict = fn_call_to_dict(fn, *args, **kwargs)
  39. if filename is None:
  40. pstr = ','.join('{}:{}'.format(pname, pval)
  41. for pname, pval in pdict.items())
  42. filename = fn.__name__ + '::' + pstr
  43. filename = filename.replace('/', '_').replace('.', '_')+".png"
  44. plt.tight_layout()
  45. try:
  46. makedirs('output/figures')
  47. except FileExistsError:
  48. pass
  49. plt.savefig(join('output/figures', filename))
  50. plot_registry[fn.__name__] = Plot(fn.__name__, join('figures', filename),
  51. title, process_docs(fn), pdict)
  52. return f
  53. return wrap
  54. def add_decorations(axes, luminosity, energy):
  55. cms_prelim = r'{\raggedright{}\textsf{\textbf{CMS}}\\ \emph{Preliminary}}'
  56. axes.text(0.01, 0.98, cms_prelim,
  57. horizontalalignment='left',
  58. verticalalignment='top',
  59. transform=axes.transAxes)
  60. lumi = ""
  61. energy_str = ""
  62. if luminosity is not None:
  63. lumi = r'${} \mathrm{{fb}}^{{-1}}$'.format(luminosity)
  64. if energy is not None:
  65. energy_str = r'({} TeV)'.format(energy)
  66. axes.text(1, 1, ' '.join([lumi, energy_str]),
  67. horizontalalignment='right',
  68. verticalalignment='bottom',
  69. transform=axes.transAxes)
  70. def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
  71. log=False, fig=None, xlim=None, ylim=None, fit=None,
  72. **kwargs):
  73. """ Plots a 1D ROOT histogram object using matplotlib """
  74. from inspect import signature
  75. if norm:
  76. h = hist_normalize(h, norm)
  77. values, errors, edges = h
  78. scale = 1. if norm is None else norm/np.sum(values)
  79. values = [val*scale for val in values]
  80. errors = [val*scale for val in errors]
  81. left, right = np.array(edges[:-1]), np.array(edges[1:])
  82. X = np.array([left, right]).T.flatten()
  83. Y = np.array([values, values]).T.flatten()
  84. if axes is None:
  85. import matplotlib.pyplot as plt
  86. axes = plt.gca()
  87. axes.set_xlabel(kwargs.pop('xlabel', ''))
  88. axes.set_ylabel(kwargs.pop('ylabel', ''))
  89. axes.set_title(kwargs.pop('title', ''))
  90. if xlim is not None:
  91. axes.set_xlim(xlim)
  92. if ylim is not None:
  93. axes.set_ylim(ylim)
  94. # elif not log:
  95. # axes.set_ylim((0, None))
  96. axes.plot(X, Y, *args, linewidth=1, **kwargs)
  97. if include_errors:
  98. axes.errorbar(hist_bin_centers(h), values, yerr=errors,
  99. color='k', marker=None, linestyle='None',
  100. barsabove=True, elinewidth=.7, capsize=1)
  101. if log:
  102. axes.set_yscale('log')
  103. if fit:
  104. f, p0 = fit
  105. popt, pcov = hist_fit(h, f, p0)
  106. fit_xs = np.linspace(X[0], X[-1], 100)
  107. fit_ys = f(fit_xs, *popt)
  108. axes.plot(fit_xs, fit_ys, '--g')
  109. arglabels = list(signature(f).parameters)[1:]
  110. label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
  111. for label, value in zip(arglabels, popt))
  112. axes.text(0.60, 0.95, label_txt, va='top', transform=axes.transAxes,
  113. fontsize='x-small', family='monospace', usetex=False)
  114. axes.grid()
  115. def hist2d_plot(h, *args, axes=None, **kwargs):
  116. """ Plots a 2D ROOT histogram object using matplotlib """
  117. try:
  118. values, errors, xs, ys = h
  119. except (TypeError, ValueError):
  120. values, errors, xs, ys = hist2d(h)
  121. if axes is None:
  122. import matplotlib.pyplot as plt
  123. axes = plt.gca()
  124. axes.set_xlabel(kwargs.pop('xlabel', ''))
  125. axes.set_ylabel(kwargs.pop('ylabel', ''))
  126. axes.set_title(kwargs.pop('title', ''))
  127. axes.pcolormesh(xs, ys, values,)
  128. # axes.colorbar() TODO: Re-enable this
  129. class StackHist:
  130. def __init__(self, title=""):
  131. raise NotImplementedError("need to fix to not use to_bin_list")
  132. self.title = title
  133. self.xlabel = ""
  134. self.ylabel = ""
  135. self.xlim = (None, None)
  136. self.ylim = (None, None)
  137. self.logx = False
  138. self.logy = False
  139. self.backgrounds = []
  140. self.signal = None
  141. self.signal_stack = True
  142. self.data = None
  143. def add_mc_background(self, th1, label, lumi=None, plot_color=''):
  144. self.backgrounds.append((label, lumi, hist(th1), plot_color))
  145. def set_mc_signal(self, th1, label, lumi=None, stack=True, scale=1, plot_color=''):
  146. self.signal = (label, lumi, hist(th1), plot_color)
  147. self.signal_stack = stack
  148. self.signal_scale = scale
  149. def set_data(self, th1, lumi=None, plot_color=''):
  150. self.data = ('data', lumi, hist(th1), plot_color)
  151. self.luminosity = lumi
  152. def _verify_binning_match(self):
  153. bins_count = [len(bins) for _, _, bins, _ in self.backgrounds]
  154. if self.signal is not None:
  155. bins_count.append(len(self.signal[2]))
  156. if self.data is not None:
  157. bins_count.append(len(self.data[2]))
  158. n_bins = bins_count[0]
  159. if any(bin_count != n_bins for bin_count in bins_count):
  160. raise ValueError("all histograms must have the same number of bins")
  161. self.n_bins = n_bins
  162. def save(self, filename, **kwargs):
  163. import matplotlib.pyplot as plt
  164. plt.ioff()
  165. fig = plt.figure()
  166. ax = fig.gca()
  167. self.do_draw(ax, **kwargs)
  168. fig.savefig("figures/"+filename, transparent=True)
  169. plt.close(fig)
  170. plt.ion()
  171. def do_draw(self, axes):
  172. self.axeses = [axes]
  173. self._verify_binning_match()
  174. bottoms = [0]*self.n_bins
  175. if self.logx:
  176. axes.set_xscale('log')
  177. if self.logy:
  178. axes.set_yscale('log')
  179. def draw_bar(label, lumi, bins, plot_color, scale=1, stack=True, **kwargs):
  180. if stack:
  181. lefts = []
  182. widths = []
  183. heights = []
  184. for left, right, content in bins:
  185. lefts.append(left)
  186. widths.append(right-left)
  187. if lumi is not None:
  188. content *= self.luminosity/lumi
  189. content *= scale
  190. heights.append(content)
  191. axes.bar(lefts, heights, widths, bottoms, label=label, color=plot_color, **kwargs)
  192. for i, (_, _, content) in enumerate(bins):
  193. if lumi is not None:
  194. content *= self.luminosity/lumi
  195. content *= scale
  196. bottoms[i] += content
  197. else:
  198. xs = [bins[0][0] - (bins[0][1]-bins[0][0])/2]
  199. ys = [0]
  200. for left, right, content in bins:
  201. width2 = (right-left)/2
  202. if lumi is not None:
  203. content *= self.luminosity/lumi
  204. content *= scale
  205. xs.append(left-width2)
  206. ys.append(content)
  207. xs.append(right-width2)
  208. ys.append(content)
  209. xs.append(bins[-1][0] + (bins[-1][1]-bins[-1][0])/2)
  210. ys.append(0)
  211. axes.plot(xs, ys, label=label, color=plot_color, **kwargs)
  212. if self.signal is not None and self.signal_stack:
  213. label, lumi, bins, plot_color = self.signal
  214. if self.signal_scale != 1:
  215. label = r"{}$\times{:d}$".format(label, self.signal_scale)
  216. draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, hatch='/')
  217. for background in self.backgrounds:
  218. draw_bar(*background)
  219. if self.signal is not None and not self.signal_stack:
  220. # draw_bar(*self.signal, stack=False, color='k')
  221. label, lumi, bins, plot_color = self.signal
  222. if self.signal_scale != 1:
  223. label = r"{}$\times{:d}$".format(label, self.signal_scale)
  224. draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, stack=False)
  225. axes.set_title(self.title)
  226. axes.set_xlabel(self.xlabel)
  227. axes.set_ylabel(self.ylabel)
  228. axes.set_xlim(*self.xlim)
  229. # axes.set_ylim(*self.ylim)
  230. if self.logy:
  231. axes.set_ylim(None, np.exp(np.log(max(bottoms))*1.4))
  232. else:
  233. axes.set_ylim(None, max(bottoms)*1.2)
  234. axes.legend(frameon=True, ncol=2)
  235. add_decorations(axes, self.luminosity, self.energy)
  236. def draw(self, axes, save=False, filename=None, **kwargs):
  237. self.do_draw(axes, **kwargs)
  238. if save:
  239. if filename is None:
  240. filename = "".join(c for c in self.title if c.isalnum() or c in (' ._+-'))+".png"
  241. self.save(filename, **kwargs)
  242. class StackHistWithSignificance(StackHist):
  243. def __init__(self, *args, **kwargs):
  244. super().__init__(*args, **kwargs)
  245. def do_draw(self, axes, bin_significance=True, low_cut_significance=False, high_cut_significance=False):
  246. bottom_box, _, top_box = axes.get_position().splity(0.28, 0.30)
  247. axes.set_position(top_box)
  248. super().do_draw(axes)
  249. axes.set_xticks([])
  250. rhs_color = '#cc6600'
  251. bottom = axes.get_figure().add_axes(bottom_box)
  252. bottom_rhs = bottom.twinx()
  253. bgs = [0]*self.n_bins
  254. for (_, _, bins, _) in self.backgrounds:
  255. for i, (left, right, value) in enumerate(bins):
  256. bgs[i] += value
  257. sigs = [0]*self.n_bins
  258. if bin_significance:
  259. xs = []
  260. for i, (left, right, value) in enumerate(self.signal[2]):
  261. sigs[i] += value
  262. xs.append(left)
  263. xs, ys = zip(*[(x, sig/(sig+bg)) for x, sig, bg in zip(xs, sigs, bgs) if (sig+bg) > 0])
  264. bottom.plot(xs, ys, '.k')
  265. if high_cut_significance:
  266. # s/(s+b) for events passing a minimum cut requirement
  267. min_bg = [sum(bgs[i:]) for i in range(self.n_bins)]
  268. min_sig = [sum(sigs[i:]) for i in range(self.n_bins)]
  269. min_xs, min_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, min_sig, min_bg)
  270. if (sig+bg) > 0])
  271. bottom_rhs.plot(min_xs, min_ys, '->', color=rhs_color)
  272. if low_cut_significance:
  273. # s/(s+b) for events passing a maximum cut requirement
  274. max_bg = [sum(bgs[:i]) for i in range(self.n_bins)]
  275. max_sig = [sum(sigs[:i]) for i in range(self.n_bins)]
  276. max_xs, max_ys = zip(*[(x, sig/np.sqrt(sig+bg)) for x, sig, bg in zip(xs, max_sig, max_bg)
  277. if (sig+bg) > 0])
  278. bottom_rhs.plot(max_xs, max_ys, '-<', color=rhs_color)
  279. bottom.set_ylabel(r'$S/(S+B)$')
  280. bottom.set_xlim(axes.get_xlim())
  281. bottom.set_ylim((0, 1.1))
  282. if low_cut_significance or high_cut_significance:
  283. bottom_rhs.set_ylabel(r'$S/\sqrt{S+B}$')
  284. bottom_rhs.yaxis.label.set_color(rhs_color)
  285. bottom_rhs.tick_params(axis='y', colors=rhs_color, size=4, width=1.5)
  286. # bottom.grid()
  287. if __name__ == '__main__':
  288. import matplotlib.pyplot as plt
  289. from utils import ResultSet
  290. rs_TTZ = ResultSet("TTZ", "../data/TTZToLLNuNu_treeProducerSusyMultilepton_tree.root")
  291. rs_TTW = ResultSet("TTW", "../data/TTWToLNu_treeProducerSusyMultilepton_tree.root")
  292. rs_TTH = ResultSet("TTH", "../data/TTHnobb_mWCutfix_ext1_treeProducerSusyMultilepton_tree.root")
  293. rs_TTTT = ResultSet("TTTT", "../data/TTTT_ext_treeProducerSusyMultilepton_tree.root")
  294. sh = StackHist('B-Jet Multiplicity')
  295. sh.add_mc_background(rs_TTZ.b_jet_count, 'TTZ', lumi=40)
  296. sh.add_mc_background(rs_TTW.b_jet_count, 'TTW', lumi=40)
  297. sh.add_mc_background(rs_TTH.b_jet_count, 'TTH', lumi=40)
  298. sh.set_mc_signal(rs_TTTT.b_jet_count, 'TTTT', lumi=40, scale=10)
  299. sh.luminosity = 40
  300. sh.energy = 13
  301. sh.xlabel = 'B-Jet Count'
  302. sh.ylabel = r'\# Events'
  303. sh.xlim = (-.5, 9.5)
  304. sh.signal_stack = False
  305. fig = plt.figure()
  306. sh.draw(fig.gca())
  307. plt.show()
  308. # sh.add_data(rs_TTZ.b_jet_count, 'TTZ')