jet_pt_studies.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. #!/usr/bin/env python
  2. import re
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from filval.plotting import (decl_plot, render_plots, hist_plot, Plot, generate_dashboard, hists_to_table)
  6. from filval.histogram import hist, hist_add
  7. X_TICKS = ['CRZ', 'CRW']+[f'SR{i}' for i in range(1, 17)]
  8. def read_yields(tag, postfit=False):
  9. import ROOT
  10. f = ROOT.TFile.Open(f'data/JetPtStudies/{tag}/mlfit.root')
  11. hists = {'tttt': hist(f.Get('shapes_prefit/SS/tttt')),
  12. 'ttw': hist(f.Get('shapes_prefit/SS/ttw')),
  13. 'ttz': hist(f.Get('shapes_prefit/SS/ttz')),
  14. 'tth': hist(f.Get('shapes_prefit/SS/tth'))
  15. }
  16. try:
  17. hists['fakes_mc'] = hist(f.Get('shapes_prefit/SS/fakes_mc'))
  18. except:
  19. print(tag, 'has no fakes_mc')
  20. hists['fakes_mc'] = None
  21. return hists
  22. def set_xticks():
  23. plt.xticks([x-0.5 for x in range(1, 19)], X_TICKS, rotation=60)
  24. @decl_plot
  25. def plot_yields(bJetPt):
  26. """\
  27. [SR Definitions](https://github.com/cfangmeier/FTAnalysis/blob/new_baseline/analysis/misc/signal_regions.h#L127)
  28. """
  29. _, ((ax_tttt, ax_ttw), (ax_ttz, ax_tth)) = plt.subplots(2, 2)
  30. if bJetPt == 25:
  31. cut_40 = read_yields('v1.0.4_JetPtCut40')
  32. cut_35 = read_yields('v1.0.4_JetPtCut35')
  33. cut_30 = read_yields('v1.0.4_JetPtCut30')
  34. cut_25 = read_yields('v1.0.4_JetPtCut25')
  35. cut_20 = read_yields('v1.0.4_JetPtCut20')
  36. else: # bJetPt == 20
  37. cut_40 = read_yields('v1.0.4_bJetPtCut20JetPtCut40')
  38. cut_35 = read_yields('v1.0.4_bJetPtCut20JetPtCut35')
  39. cut_30 = read_yields('v1.0.4_bJetPtCut20JetPtCut30')
  40. cut_25 = read_yields('v1.0.4_bJetPtCut20JetPtCut25')
  41. cut_20 = read_yields('v1.0.4_bJetPtCut20JetPtCut20')
  42. def do_plot(axis, key):
  43. plt.sca(axis)
  44. hist_plot(cut_40[key], label="Jet $P_T>$40", include_errors=True)
  45. hist_plot(cut_35[key], label="Jet $P_T>$35", include_errors=True)
  46. hist_plot(cut_30[key], label="Jet $P_T>$30", include_errors=True)
  47. hist_plot(cut_25[key], label="Jet $P_T>$25", include_errors=True)
  48. hist_plot(cut_20[key], label="Jet $P_T>$20", include_errors=True)
  49. plt.title(key)
  50. set_xticks()
  51. do_plot(ax_tttt, 'tttt')
  52. do_plot(ax_ttw, 'ttw')
  53. plt.legend()
  54. do_plot(ax_ttz, 'ttz')
  55. do_plot(ax_tth, 'tth')
  56. tables = []
  57. for cut, label in zip([cut_40, cut_35, cut_30, cut_25, cut_20],
  58. ["Jet $P_T>$40", "Jet $P_T>$35", "Jet $P_T>$30", "Jet $P_T>$25", "Jet $P_T>$20"]):
  59. tables.append(f'<h3>{label}</h3>')
  60. tables.append(hists_to_table([cut['tttt'], cut['ttw'], cut['tth'], cut['ttz']],
  61. column_labels=X_TICKS, row_labels=['TTTT', 'TTW', 'TTH', 'TTZ']))
  62. return ''.join(tables)
  63. @decl_plot
  64. def plot_yields_ratio():
  65. _, ((ax_tttt, ax_ttw), (ax_ttz, ax_tth)) = plt.subplots(2, 2)
  66. cut_40 = read_yields('v1.0.4_JetPtCut40')
  67. cut_35 = read_yields('v1.0.4_JetPtCut35')
  68. cut_30 = read_yields('v1.0.4_JetPtCut30')
  69. cut_25 = read_yields('v1.0.4_JetPtCut25')
  70. cut_20 = read_yields('v1.0.4_JetPtCut20')
  71. def do_plot(axis, key):
  72. plt.sca(axis)
  73. c35 = (cut_35[key][0] / cut_40[key][0], *cut_35[key][1:])
  74. c30 = (cut_30[key][0] / cut_40[key][0], *cut_30[key][1:])
  75. c25 = (cut_25[key][0] / cut_40[key][0], *cut_25[key][1:])
  76. c20 = (cut_20[key][0] / cut_40[key][0], *cut_25[key][1:])
  77. hist_plot(c35, label="Jet $P_T>$35")
  78. hist_plot(c30, label="Jet $P_T>$30")
  79. hist_plot(c25, label="Jet $P_T>$25")
  80. hist_plot(c20, label="Jet $P_T>$20")
  81. plt.title(key)
  82. axis.set_ylim((0, 10))
  83. set_xticks()
  84. do_plot(ax_tttt, 'tttt')
  85. plt.legend()
  86. do_plot(ax_ttw, 'ttw')
  87. do_plot(ax_ttz, 'ttz')
  88. do_plot(ax_tth, 'tth')
  89. @decl_plot
  90. def plot_signal_over_background(tags):
  91. cuts = [read_yields(tag) for tag in tags]
  92. def clean_label(label):
  93. return label.replace('_', r'\_')
  94. def do_plot_significance(hists, label):
  95. background = hist_add(hists['ttw'], hists['ttz'], hists['tth'])
  96. tttt = hists['tttt']
  97. ttw = hists['ttw'][0]
  98. ttz = hists['ttz'][0]
  99. tth = hists['tth'][0]
  100. sigma2_bg = (0.4*ttw)**2 + (0.4*ttz)**2 + (0.5*tth)**2
  101. if hists['fakes_mc'] is not None:
  102. fakes_mc = hists['fakes_mc']
  103. background = hist_add(background, fakes_mc)
  104. sigma2_bg += (0.2*fakes_mc[0])**2
  105. ratio = tttt[0] / np.sqrt(tttt[0] + background[0] + sigma2_bg), *tttt[1:]
  106. hist_plot(ratio, label=clean_label(label))
  107. plt.sca(plt.subplot(221))
  108. for cut, tag in zip(cuts, tags):
  109. signal = cut['tttt']
  110. hist_plot(signal, label=clean_label(tag))
  111. set_xticks()
  112. plt.ylabel('TTTT')
  113. plt.ylim((0, 10))
  114. plt.sca(plt.subplot(222))
  115. for cut, tag in zip(cuts, tags):
  116. background = hist_add(cut['ttw'], cut['ttz'], cut['tth'])
  117. hist_plot(background, label=clean_label(tag))
  118. set_xticks()
  119. plt.ylabel('TTW+TTZ+TTH+MCFakes')
  120. plt.ylim((0, 10))
  121. plt.sca(plt.subplot(212))
  122. for cut, tag in zip(cuts, tags):
  123. do_plot_significance(cut, tag)
  124. set_xticks()
  125. plt.ylabel(r'SIG / $\sqrt{SIG+BG + \sigma_{BG}^2}$')
  126. plt.legend()
  127. tables = []
  128. for cut, tag in zip(cuts, tags):
  129. tables.append(f'<h3>{tag}</h3>')
  130. tables.append(hists_to_table([cut['tttt'], cut['ttw'], cut['tth'], cut['ttz']],
  131. column_labels=X_TICKS, row_labels=['TTTT', 'TTW', 'TTH', 'TTZ']))
  132. return ''.join(tables)
  133. @decl_plot
  134. def plot_sigs(include_fakes):
  135. """\
  136. Significances above only include TTZ, TTW, TTH, and (optionally) MC-Fakes(TTBar) backgrounds.
  137. """
  138. def get_sig(tag):
  139. regex = re.compile('Sig: ([0-9.]+)')
  140. ext = '' if include_fakes else '_nofakes'
  141. with open(f'data/JetPtStudies/{tag}{ext}.txt') as f:
  142. sig = float(regex.findall(f.read())[0])
  143. return sig
  144. sigs_b25 = {n: get_sig(f'v1.0.4_JetPtCut{n}') for n in [40, 35, 30, 25, 20]}
  145. sigs_b20 = {n: get_sig(f'v1.0.4_bJetPtCut20JetPtCut{n}') for n in [40, 35, 30, 25, 20]}
  146. plt.scatter(*zip(*sigs_b25.items()), label="bJetMinPt=25")
  147. for key, val in sigs_b25.items():
  148. plt.text(key, val-0.1, f'{key}: {val:0.3f}')
  149. plt.scatter(*zip(*sigs_b20.items()), label="bJetMinPt=20")
  150. for key, val in sigs_b20.items():
  151. plt.text(key, val+0.1, f'{key}: {val:0.3f}')
  152. plt.ylabel('Significance')
  153. plt.xlabel('Jet $P_T$ Cut')
  154. plt.xticks([20, 25, 30, 35, 40])
  155. plt.xlim((18, 43))
  156. plt.ylim((0, 3.0))
  157. plt.legend()
  158. def main():
  159. sig_o_bg = plot_signal_over_background, (['v1.0.4_JetPtCut40', 'v1.0.4_JetPtCut25', 'v1.0.4_JetPtCut20'],)
  160. sig_o_bg2 = plot_signal_over_background, (['v1.0.4_JetPtCut25', 'v1.0.4_bJetPtCut20JetPtCut25'],)
  161. plots = [
  162. Plot((plot_yields, (25,)), 'Prefit Yields (bJetPt>25)'),
  163. Plot((plot_yields, (20,)), 'Prefit Yields (bJetPt>20)'),
  164. Plot(plot_yields_ratio, 'Prefit Yields Relative to Baseline'),
  165. Plot((plot_sigs, (True,)), 'Significances with Fakes'),
  166. Plot((plot_sigs, (False,)), 'Significances w/o Fakes'),
  167. Plot(sig_o_bg, 'Binned Yields with variable JetPt'),
  168. Plot(sig_o_bg2, 'Binned Yields with variable bJetPt'),
  169. ]
  170. render_plots(plots, to_disk=False)
  171. generate_dashboard(plots, 'New Baseline Yields - Various JetPt Cuts',
  172. output='new_baseline_yields_jet_pt_cuts.html',
  173. source=__file__,
  174. )
  175. if __name__ == '__main__':
  176. main()