utils.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import io
  2. import sys
  3. from math import ceil, sqrt
  4. from subprocess import run
  5. import itertools as it
  6. import ROOT
  7. class HistCollection:
  8. def __init__(self, sample_name, input_filename,
  9. exe_path="../build/main",
  10. rebuild_hists = False):
  11. self.sample_name = sample_name
  12. if rebuild_hists:
  13. run([exe_path, "-s", "-f", input_filename])
  14. output_filename = input_filename.replace(".root", "_result.root")
  15. self._file = ROOT.TFile.Open(output_filename)
  16. l = self._file.GetListOfKeys()
  17. self.map = {}
  18. for i in range(l.GetSize()):
  19. name = l.At(i).GetName()
  20. self.map[name] = self._file.Get(name)
  21. setattr(self, name, self.map[name])
  22. def draw(self, canvas, shape=None):
  23. if shape is None:
  24. n = int(ceil(sqrt(len(self.map))))
  25. shape = (n, n)
  26. print(shape)
  27. canvas.Clear()
  28. canvas.Divide(*shape)
  29. for i, hist in enumerate(self.map.values()):
  30. canvas.cd(i+1)
  31. try:
  32. hist.SetStats(False)
  33. except AttributeError:
  34. pass
  35. print(i, hist, str(type(hist)))
  36. draw_option = ""
  37. if (type(hist) == ROOT.TH2F):
  38. draw_option = "COLZ"
  39. hist.Draw(draw_option)
  40. class OutputCapture:
  41. def __init__(self):
  42. self.my_stdout = io.StringIO()
  43. self.my_stderr = io.StringIO()
  44. def get_stdout(self):
  45. self.my_stdout.seek(0)
  46. return self.my_stdout.read()
  47. def get_stderr(self):
  48. self.my_stderr.seek(0)
  49. return self.my_stderr.read()
  50. def __enter__(self):
  51. self.stdout = sys.stdout
  52. self.stderr = sys.stderr
  53. sys.stdout = self.my_stdout
  54. sys.stderr = self.my_stderr
  55. def __exit__(self, *args):
  56. sys.stdout = self.stdout
  57. sys.stderr = self.stderr
  58. self.stdout = None
  59. self.stderr = None
  60. def bin_range(n, end=None):
  61. if end is None:
  62. return range(1, n+1)
  63. else:
  64. return range(n+1, end+1)
  65. def normalize_columns(hist2d):
  66. normHist = ROOT.TH2D(hist2d)
  67. cols, rows = hist2d.GetNbinsX(), hist2d.GetNbinsY()
  68. for col in bin_range(cols):
  69. sum_ = 0
  70. for row in bin_range(rows):
  71. sum_ += hist2d.GetBinContent(col, row)
  72. if sum_ == 0:
  73. continue
  74. for row in bin_range(rows):
  75. norm = hist2d.GetBinContent(col, row) / sum_
  76. normHist.SetBinContent(col, row, norm)
  77. return normHist
  78. def stack_hist(hists,
  79. labels=None, id_=None,
  80. title="", enable_fill=False,
  81. normalize_to=0, draw=False,
  82. draw_option="",
  83. _stacks={}):
  84. """hists should be a list of TH1D objects
  85. returns a new stacked histogram
  86. """
  87. colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen])
  88. stack = ROOT.THStack(id_, title)
  89. if labels is None:
  90. labels = [hist.GetName() for hist in hists]
  91. if type(normalize_to) in (int, float):
  92. normalize_to = [normalize_to]*len(hists)
  93. if id_ is None:
  94. id_ = hists[0].GetName() + "_stack"
  95. ens = enumerate(zip(hists, labels, colors, normalize_to))
  96. for i, (hist, label, color, norm) in ens:
  97. hist_copy = hist
  98. hist_copy = hist.Clone(hist.GetName()+"_clone")
  99. hist_copy.SetTitle(label)
  100. if enable_fill:
  101. hist_copy.SetFillColorAlpha(color, 0.75)
  102. hist_copy.SetLineColorAlpha(color, 0.75)
  103. if norm:
  104. integral = hist_copy.Integral()
  105. hist_copy.Scale(norm/integral, "nosw2")
  106. hist_copy.SetStats(False)
  107. stack.Add(hist_copy)
  108. if draw:
  109. stack.Draw(draw_option)
  110. _stacks[id_] = stack # prevent stack from getting garbage collected
  111. # needed for multipad plots :/
  112. return stack
  113. def stack_hist_array(canvas, histcollections, fields, titles,
  114. shape=None, **kwargs):
  115. def get_hist_set(attrname):
  116. hists, labels = zip(*[(getattr(h, attrname), h.sample_name)
  117. for h in histcollections])
  118. return hists, labels
  119. n_fields = len(fields)
  120. if shape is None:
  121. if n_fields <= 4:
  122. shape = (1, n_fields)
  123. else:
  124. shape = (ceil(sqrt(n_fields)),)*2
  125. canvas.Clear()
  126. canvas.Divide(*shape)
  127. for i, field, title in zip(bin_range(n_fields), fields, titles):
  128. canvas.cd(i)
  129. hists, labels = get_hist_set(field)
  130. stack_hist(hists, labels, id_=field, title=title, draw=True, **kwargs)
  131. canvas.cd(1).BuildLegend(0.75, 0.75, 0.95, 0.95, "")