yields.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #!/usr/bin/env python
  2. import matplotlib.pyplot as plt
  3. from filval.result_set import ResultSet
  4. from filval.histogram_utils import hist
  5. from filval.plotter import (decl_plot, render_plots, hist_plot, Plot)
  6. def generate_dashboard(plots, output='dashboard.htm'):
  7. from jinja2 import Environment, PackageLoader, select_autoescape
  8. from os.path import join
  9. from urllib.parse import quote
  10. env = Environment(
  11. loader=PackageLoader('plots', 'templates'),
  12. autoescape=select_autoescape(['htm', 'html', 'xml']),
  13. )
  14. env.globals.update({'quote': quote,
  15. 'enumerate': enumerate,
  16. 'zip': zip,
  17. })
  18. def render_to_file(template_name, **kwargs):
  19. with open(join('output', template_name), 'w') as tempout:
  20. template = env.get_template(template_name)
  21. tempout.write(template.render(**kwargs))
  22. def get_by_n(objs, n=2):
  23. objs = list(objs)
  24. while objs:
  25. yield objs[:n]
  26. objs = objs[n:]
  27. render_to_file(output, plots=get_by_n(plots, 3),
  28. outdir="figures/")
  29. @decl_plot
  30. def plot_yield(rss):
  31. r''' '''
  32. from filval.plotter import StackHist
  33. ft, ttw, ttz, tth = map(lambda rs: hist(rs.SRs), rss)
  34. # ft10 = ft[0]*10, ft[1]*10, ft[2]
  35. sh = StackHist()
  36. sh.add_mc_background(rss[1].SRs, 'TTW')
  37. sh.set_mc_signal(rss[0].SRs, 'TTTT')
  38. sh.draw(plt.gca())
  39. # hist_plot(ft10, include_errors=False, stats=False,
  40. # color='k', label='TTTT (x10)')
  41. # bg =
  42. # plt.hist(
  43. # hist_plot(ttw, include_errors=False, stats=False,
  44. # color='g', label='TTW')
  45. # hist_plot(ttz, include_errors=False, stats=False,
  46. # color='r', label='TTZ')
  47. # hist_plot(tth, include_errors=False, stats=False,
  48. # color='b', label='TTH')
  49. # plt.xlabel('Signal Region')
  50. # plt.legend()
  51. if __name__ == '__main__':
  52. # First create a ResultSet object which loads all of the objects from output.root
  53. # into memory and makes them available as attributes
  54. # if len(sys.argv) != 2:
  55. # raise ValueError("please supply root file")
  56. rss = (ResultSet("ft", 'yield_ft.root'),
  57. ResultSet("ttw", 'yield_ttw.root'),
  58. ResultSet("ttz", 'yield_ttz.root'),
  59. ResultSet("tth", 'yield_tth.root'))
  60. rss_notau = (ResultSet("ft_notau", 'yield_ft_notau.root'),
  61. ResultSet("ttw_notau", 'yield_ttw_notau.root'),
  62. ResultSet("ttz_notau", 'yield_ttz_notau.root'),
  63. ResultSet("tth_notau", 'yield_tth_notau.root'))
  64. # Next, declare all of the (sub)plots that will be assembled into full
  65. # figures later
  66. yield_tau = (plot_yield, (rss,), {})
  67. yield_notau = (plot_yield, (rss_notau,), {})
  68. # Now assemble the plots into figures.
  69. plots = [
  70. Plot([[yield_tau]],
  71. 'Yield With Tau'),
  72. Plot([[yield_notau]],
  73. 'Yield Without Tau'),
  74. ]
  75. # Finally, render and save the plots and generate the html+bootstrap
  76. # dashboard to view them
  77. render_plots(plots, to_disk=False)
  78. generate_dashboard(plots)