plotter.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. #!/usr/bin/env python3
  2. import math
  3. import matplotlib as mpl
  4. mpl.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})
  5. mpl.rc('font', **{'family': 'serif', 'serif': ['Palatino']})
  6. mpl.rc('text', usetex=True)
  7. mpl.rc('savefig', dpi=120)
  8. def add_decorations(axes, luminosity, energy):
  9. cms_prelim = r'{\raggedright{}\textsf{\textbf{CMS}}\\ \emph{Preliminary}}'
  10. axes.text(0.01, 0.98, cms_prelim,
  11. horizontalalignment='left',
  12. verticalalignment='top',
  13. transform=axes.transAxes)
  14. lumi = ""
  15. energy_str = ""
  16. if luminosity is not None:
  17. lumi = r'${} \mathrm{{fb}}^{{-1}}$'.format(luminosity)
  18. if energy is not None:
  19. energy_str = r'({} TeV)'.format(energy)
  20. axes.text(1, 1, ' '.join([lumi, energy_str]),
  21. horizontalalignment='right',
  22. verticalalignment='bottom',
  23. transform=axes.transAxes)
  24. class StackHist:
  25. def __init__(self, title=""):
  26. self.title = title
  27. self.xlabel = ""
  28. self.ylabel = ""
  29. self.xlim = (None, None)
  30. self.ylim = (None, None)
  31. self.logx = False
  32. self.logy = False
  33. self.backgrounds = []
  34. self.signal = None
  35. self.signal_stack = True
  36. self.data = None
  37. @staticmethod
  38. def to_bin_list(th1):
  39. bins = []
  40. for i in range(th1.GetNbinsX()):
  41. center = th1.GetBinCenter(i + 1)
  42. width = th1.GetBinWidth(i + 1)
  43. content = th1.GetBinContent(i + 1)
  44. bins.append((center-width/2, center+width/2, content))
  45. return bins
  46. def add_mc_background(self, th1, label, lumi=None, plot_color=''):
  47. self.backgrounds.append((label, lumi, self.to_bin_list(th1), plot_color))
  48. def set_mc_signal(self, th1, label, lumi=None, stack=True, scale=1, plot_color=''):
  49. self.signal = (label, lumi, self.to_bin_list(th1), plot_color)
  50. self.signal_stack = stack
  51. self.signal_scale = scale
  52. def set_data(self, th1, lumi=None, plot_color=''):
  53. self.data = ('data', lumi, self.to_bin_list(th1), plot_color)
  54. self.luminosity = lumi
  55. def _verify_binning_match(self):
  56. bins_count = [len(bins) for _, _, bins, _ in self.backgrounds]
  57. if self.signal is not None:
  58. bins_count.append(len(self.signal[2]))
  59. if self.data is not None:
  60. bins_count.append(len(self.data[2]))
  61. n_bins = bins_count[0]
  62. if any(bin_count != n_bins for bin_count in bins_count):
  63. raise ValueError("all histograms must have the same number of bins")
  64. self.n_bins = n_bins
  65. def save(self, fname):
  66. from matplotlib.transforms import Bbox
  67. def full_extent(ax, pad=0.0):
  68. """Get the full extent of an axes, including axes labels, tick labels, and
  69. titles."""
  70. # For text objects, we need to draw the figure first, otherwise the extents
  71. # are undefined.
  72. ax.figure.canvas.draw()
  73. items = ax.get_xticklabels() + ax.get_yticklabels()
  74. items += [ax, ax.title, ax.xaxis.label, ax.yaxis.label]
  75. # items += [ax, ax.title]
  76. bbox = Bbox.union([item.get_window_extent() for item in items])
  77. return bbox.expanded(1.0 + pad, 1.0 + pad)
  78. extents = []
  79. for axes in self.axeses:
  80. extents.append(full_extent(axes).transformed(axes.figure.dpi_scale_trans.inverted()))
  81. extent = Bbox.union(extents)
  82. axes.figure.savefig('figures/'+fname, bbox_inches=extent)
  83. def do_draw(self, axes):
  84. self.axeses = [axes]
  85. self._verify_binning_match()
  86. bottoms = [0]*self.n_bins
  87. if self.logx:
  88. axes.set_xscale('log')
  89. if self.logy:
  90. axes.set_yscale('log')
  91. def draw_bar(label, lumi, bins, plot_color, scale=1, stack=True, **kwargs):
  92. if stack:
  93. lefts = []
  94. widths = []
  95. heights = []
  96. for left, right, content in bins:
  97. lefts.append(left)
  98. widths.append(right-left)
  99. if lumi is not None:
  100. content *= self.luminosity/lumi
  101. content *= scale
  102. heights.append(content)
  103. axes.bar(lefts, heights, widths, bottoms, label=label, color=plot_color, **kwargs)
  104. for i, (_, _, content) in enumerate(bins):
  105. if lumi is not None:
  106. content *= self.luminosity/lumi
  107. content *= scale
  108. bottoms[i] += content
  109. else:
  110. xs = [bins[0][0] - (bins[0][1]-bins[0][0])/2]
  111. ys = [0]
  112. for left, right, content in bins:
  113. width2 = (right-left)/2
  114. if lumi is not None:
  115. content *= self.luminosity/lumi
  116. content *= scale
  117. xs.append(left-width2)
  118. ys.append(content)
  119. xs.append(right-width2)
  120. ys.append(content)
  121. xs.append(bins[-1][0] + (bins[-1][1]-bins[-1][0])/2)
  122. ys.append(0)
  123. axes.plot(xs, ys, label=label, color=plot_color, **kwargs)
  124. if self.signal is not None and self.signal_stack:
  125. label, lumi, bins, plot_color = self.signal
  126. if self.signal_scale != 1:
  127. label = r"{}$\times{:d}$".format(label, self.signal_scale)
  128. draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, hatch='/')
  129. for background in self.backgrounds:
  130. draw_bar(*background)
  131. if self.signal is not None and not self.signal_stack:
  132. # draw_bar(*self.signal, stack=False, color='k')
  133. label, lumi, bins, plot_color = self.signal
  134. if self.signal_scale != 1:
  135. label = r"{}$\times{:d}$".format(label, self.signal_scale)
  136. draw_bar(label, lumi, bins, plot_color, scale=self.signal_scale, stack=False)
  137. axes.set_title(self.title)
  138. axes.set_xlabel(self.xlabel)
  139. axes.set_ylabel(self.ylabel)
  140. axes.set_xlim(*self.xlim)
  141. # axes.set_ylim(*self.ylim)
  142. if self.logy:
  143. axes.set_ylim(None, math.exp(math.log(max(bottoms))*1.4))
  144. else:
  145. axes.set_ylim(None, max(bottoms)*1.2)
  146. axes.legend(frameon=True, ncol=2)
  147. add_decorations(axes, self.luminosity, self.energy)
  148. def draw(self, axes, save=True, **kwargs):
  149. self.do_draw(axes, **kwargs)
  150. if save:
  151. self.save(self.title+".png")
  152. class StackHistWithSignificance(StackHist):
  153. def __init__(self, *args, **kwargs):
  154. super().__init__(*args, **kwargs)
  155. def do_draw(self, axes, bin_significance=True, low_cut_significance=False, high_cut_significance=False):
  156. bottom_box, _, top_box = axes.get_position().splity(0.28, 0.30)
  157. axes.set_position(top_box)
  158. super().do_draw(axes)
  159. axes.set_xticks([])
  160. rhs_color = '#cc6600'
  161. bottom = axes.get_figure().add_axes(bottom_box)
  162. bottom_rhs = bottom.twinx()
  163. self.axeses = [axes, bottom, bottom_rhs]
  164. bgs = [0]*self.n_bins
  165. for (_, _, bins, _) in self.backgrounds:
  166. for i, (left, right, value) in enumerate(bins):
  167. bgs[i] += value
  168. sigs = [0]*self.n_bins
  169. if bin_significance:
  170. xs = []
  171. for i, (left, right, value) in enumerate(self.signal[2]):
  172. sigs[i] += value
  173. xs.append(left)
  174. xs, ys = zip(*[(x, sig/(sig+bg)) for x, sig, bg in zip(xs, sigs, bgs) if (sig+bg)>0])
  175. bottom.plot(xs, ys, '.k')
  176. if high_cut_significance:
  177. # s/(s+b) for events passing a minimum cut requirement
  178. min_bg = [sum(bgs[i:]) for i in range(self.n_bins)]
  179. min_sig = [sum(sigs[i:]) for i in range(self.n_bins)]
  180. min_xs, min_ys = zip(*[(x, sig/math.sqrt(sig+bg)) for x, sig, bg in zip(xs, min_sig, min_bg)
  181. if (sig+bg) > 0])
  182. bottom_rhs.plot(min_xs, min_ys, '->', color=rhs_color)
  183. if low_cut_significance:
  184. # s/(s+b) for events passing a maximum cut requirement
  185. max_bg = [sum(bgs[:i]) for i in range(self.n_bins)]
  186. max_sig = [sum(sigs[:i]) for i in range(self.n_bins)]
  187. max_xs, max_ys = zip(*[(x, sig/math.sqrt(sig+bg)) for x, sig, bg in zip(xs, max_sig, max_bg)
  188. if (sig+bg) > 0])
  189. bottom_rhs.plot(max_xs, max_ys, '-<', color=rhs_color)
  190. bottom.set_ylabel(r'$S/(S+B)$')
  191. bottom.set_xlim(axes.get_xlim())
  192. bottom.set_ylim((0, 1.1))
  193. if low_cut_significance or high_cut_significance:
  194. bottom_rhs.set_ylabel(r'$S/\sqrt{S+B}$')
  195. bottom_rhs.yaxis.label.set_color(rhs_color)
  196. bottom_rhs.tick_params(axis='y', colors=rhs_color, size=4, width=1.5)
  197. # bottom.grid()
  198. if __name__ == '__main__':
  199. import matplotlib.pyplot as plt
  200. from utils import ResultSet
  201. rs_TTZ = ResultSet("TTZ", "../data/TTZToLLNuNu_treeProducerSusyMultilepton_tree.root")
  202. rs_TTW = ResultSet("TTW", "../data/TTWToLNu_treeProducerSusyMultilepton_tree.root")
  203. rs_TTH = ResultSet("TTH", "../data/TTHnobb_mWCutfix_ext1_treeProducerSusyMultilepton_tree.root")
  204. rs_TTTT = ResultSet("TTTT", "../data/TTTT_ext_treeProducerSusyMultilepton_tree.root")
  205. sh = StackHist('B-Jet Multiplicity')
  206. sh.add_mc_background(rs_TTZ.b_jet_count, 'TTZ', lumi=40)
  207. sh.add_mc_background(rs_TTW.b_jet_count, 'TTW', lumi=40)
  208. sh.add_mc_background(rs_TTH.b_jet_count, 'TTH', lumi=40)
  209. sh.set_mc_signal(rs_TTTT.b_jet_count, 'TTTT', lumi=40, scale=10)
  210. sh.luminosity = 40
  211. sh.energy = 13
  212. sh.xlabel = 'B-Jet Count'
  213. sh.ylabel = r'\# Events'
  214. sh.xlim = (-.5, 9.5)
  215. sh.signal_stack = False
  216. fig = plt.figure()
  217. sh.draw(fig.gca())
  218. plt.show()
  219. # sh.add_data(rs_TTZ.b_jet_count, 'TTZ')