utils.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import io
  2. import sys
  3. from math import ceil, sqrt
  4. from subprocess import run
  5. import itertools as it
  6. import ROOT
  class OutputCapture:
      """Context manager that redirects ``sys.stdout``/``sys.stderr`` to memory.

      Usage::

          with OutputCapture() as capture:
              noisy_call()
          text = capture.get_stdout()

      The capture buffers outlive the ``with`` block, so the collected text
      can be read afterwards.  Buffers are never reset, so re-entering the
      same instance appends to what was captured before.
      """

      def __init__(self):
          self.my_stdout = io.StringIO()
          self.my_stderr = io.StringIO()
          # The real streams, saved only while the context is active.
          self.stdout = None
          self.stderr = None

      def get_stdout(self):
          """Return all text written to stdout while capturing."""
          return self.my_stdout.getvalue()

      def get_stderr(self):
          """Return all text written to stderr while capturing."""
          return self.my_stderr.getvalue()

      def __enter__(self):
          self.stdout, self.stderr = sys.stdout, sys.stderr
          sys.stdout, sys.stderr = self.my_stdout, self.my_stderr
          # Return self so ``with OutputCapture() as c`` binds the capturer
          # (the previous implicit ``None`` made the ``as`` target useless).
          return self

      def __exit__(self, *args):
          sys.stdout, sys.stderr = self.stdout, self.stderr
          self.stdout = None
          self.stderr = None
          # Falsy return: exceptions raised inside the block still propagate.
          return False
  def bin_range(n, end=None):
      """Return a ``range`` of 1-based bin numbers.

      ROOT numbers the regular bins of an axis 1..N (0 and N+1 are the
      under/overflow bins), hence the shift relative to ``range``:

      * ``bin_range(n)``      -> 1, 2, ..., n
      * ``bin_range(n, end)`` -> n + 1, n + 2, ..., end
        (the bins after bin ``n``, up to and including bin ``end``)
      """
      first, last = (1, n) if end is None else (n + 1, end)
      return range(first, last + 1)
  def normalize_columns(hist2d):
      """Return a TH2D copy of *hist2d* in which each column sums to one.

      For every in-range x bin (column), the in-range y-bin contents are
      divided by that column's total.  Columns whose total is zero are left
      as copied.  Under/overflow bins are neither summed nor rescaled, and
      *hist2d* itself is only read, never modified.

      NOTE(review): only bin contents are rescaled; bin errors on the copy
      are not touched -- confirm that is intended if the result is drawn
      with error bars.
      """
      normalized = ROOT.TH2D(hist2d)
      n_cols, n_rows = hist2d.GetNbinsX(), hist2d.GetNbinsY()
      for col in bin_range(n_cols):
          column = [hist2d.GetBinContent(col, row) for row in bin_range(n_rows)]
          total = sum(column)
          if total == 0:
              # Empty column: nothing to normalize (and avoids dividing by 0).
              continue
          for row, content in zip(bin_range(n_rows), column):
              normalized.SetBinContent(col, row, content / total)
      return normalized
  class HistCollection:
      """A named set of ROOT objects loaded from one analysis result file.

      Every key of the ``*_result.root`` file that belongs to the input file
      is read into memory and detached from the file.  Where the object
      supports it, it is renamed to ``"<sample_name>:<key>"``.  Objects are
      reachable through the ``map`` dict (keyed by the original key name) and,
      where that does not clash with an existing attribute, as an attribute of
      the same name.  Every new collection is registered in the class-level
      ``HistCollection.collections`` dict under its ``sample_name``.
      """

      # Keeps THStacks built by stack_hist() referenced from Python so they are
      # not garbage-collected while still drawn (needed for multi-pad plots).
      _stack_registry = {}

      def __init__(self, sample_name, input_filename,
                   exe_path="../build/main",
                   rebuild_hists=False):
          """Load the result file that belongs to *input_filename*.

          sample_name:    label of this collection; used as the object-name
                          prefix and as the key in ``HistCollection.collections``.
          input_filename: path of a ``.root`` input file; objects are read from
                          the same path with its last ".root" replaced by
                          "_result.root".
          exe_path:       executable invoked when *rebuild_hists* is true.
          rebuild_hists:  if true, first run ``exe_path -s -f input_filename``.

          Raises subprocess.CalledProcessError if the rebuild command fails and
          OSError if the result file cannot be opened.
          """
          self.sample_name = sample_name
          if rebuild_hists:
              # check=True: a failed rebuild must not silently fall back to a
              # stale (or missing) result file.
              run([exe_path, "-s", "-f", input_filename], check=True)
          output_filename = self._result_filename(input_filename)
          self.map = self._load_objects(output_filename, sample_name)
          for name, obj in self.map.items():
              # Attribute access is a convenience only; never let a key such as
              # "map" or "draw" clobber instance data or a method.
              if hasattr(self, name):
                  continue
              setattr(self, name, obj)
          self._register_in_memory(self.map.values())
          HistCollection.add_collection(self)

      @staticmethod
      def _result_filename(input_filename):
          """Return *input_filename* with its LAST ".root" turned into "_result.root"."""
          # rpartition (not str.replace) so a ".root" elsewhere in the path,
          # e.g. in a directory name, is left alone.
          head, sep, tail = input_filename.rpartition(".root")
          if not sep:
              return input_filename
          return head + "_result.root" + tail

      @staticmethod
      def _load_objects(filename, sample_name):
          """Read every key of ROOT file *filename* into a dict of detached objects."""
          tfile = ROOT.TFile.Open(filename)
          if not tfile or tfile.IsZombie():
              raise OSError("could not open ROOT file: " + filename)
          objects = {}
          try:
              for key in tfile.GetListOfKeys():
                  name = key.GetName()
                  if name in objects:
                      # Another cycle of a key already read; Get(name) returns
                      # the highest cycle anyway.
                      continue
                  obj = tfile.Get(name)
                  try:
                      obj.SetName(":".join((sample_name, name)))
                      obj.SetDirectory(0)  # detach so it survives Close()
                  except AttributeError:
                      # Not every stored type has these (e.g. TGraph has no
                      # SetDirectory); keep such objects as they are.
                      pass
                  objects[name] = obj
          finally:
              tfile.Close()
          return objects

      @staticmethod
      def _register_in_memory(objects):
          """Add *objects* to the current ROOT directory, replacing same-named ones."""
          for obj in objects:
              try:
                  old_obj = ROOT.gDirectory.Get(obj.GetName())
                  ROOT.gDirectory.Remove(old_obj)
                  ROOT.gDirectory.Add(obj)
              except AttributeError:
                  pass

      def draw(self, canvas, shape=None):
          """Draw each drawable object of the collection on its own pad.

          canvas: the TCanvas/TPad to clear and divide.
          shape:  (nx, ny) passed to ``Divide``; defaults to the smallest
                  square grid with one pad per stored object.

          2D histograms are drawn with "COLZ", other histograms with the
          default option and TGraphs with "A*"; anything else is reported with
          a message and skipped.
          """
          if shape is None:
              n = int(ceil(sqrt(len(self.map))))
              shape = (n, n)
          canvas.Clear()
          canvas.Divide(*shape)
          pad = 1
          for obj in self.map.values():
              canvas.cd(pad)
              try:
                  obj.SetStats(False)
              except AttributeError:
                  pass  # e.g. TGraph has no statistics box
              draw_option = self._draw_option(obj)
              if draw_option is None:
                  print("cannot draw object", obj)
                  continue  # Not a drawable type (probably)
              obj.Draw(draw_option)
              pad += 1

      @staticmethod
      def _draw_option(obj):
          """Return the Draw() option for *obj*, or None if it is not drawable."""
          # isinstance (not exact type) so subclasses such as TProfile, TH1S or
          # TGraphErrors are drawn too.  TH2 derives from TH1: test it first.
          if isinstance(obj, ROOT.TH2):
              return "COLZ"
          if isinstance(obj, ROOT.TH1):
              return ""
          if isinstance(obj, ROOT.TGraph):
              return "A*"
          return None

      @classmethod
      def add_collection(cls, hc):
          """Register *hc* in ``cls.collections`` under ``hc.sample_name``.

          A later collection with the same sample name replaces the earlier one.
          """
          if not hasattr(cls, "collections"):
              cls.collections = {}
          cls.collections[hc.sample_name] = hc

      @staticmethod
      def stack_hist(hists,
                     labels=None, id_=None,
                     title="", enable_fill=False,
                     normalize_to=0, draw=False,
                     draw_option="",
                     _stacks=None,
                     *, colors=None):
          """Build a THStack from clones of *hists*; the inputs are not modified.

          hists:        sequence of TH1 histograms.
          labels:       title for each clone; defaults to the histogram names.
          id_:          name of the THStack and its key in *_stacks*; defaults
                        to ``<first histogram name>_stack``.
          title:        THStack title.
          enable_fill:  also fill each clone with its color (alpha 0.75).
          normalize_to: one number for all histograms, or one entry per
                        histogram.  A falsy entry leaves that clone unscaled;
                        otherwise the clone is scaled so its integral equals
                        the entry.  Clones with a zero integral are not scaled.
          draw:         if true, draw the stack on the current pad using
                        *draw_option*.
          _stacks:      dict that keeps the stack referenced; defaults to a
                        registry shared by all calls.
          colors:       iterable of ROOT colors, cycled over the histograms;
                        defaults to red, blue, green.

          Returns the THStack.
          """
          # id_ must be resolved BEFORE the THStack is constructed; the stack
          # was previously created as THStack(None, title) when id_ was omitted.
          if id_ is None:
              id_ = hists[0].GetName() + "_stack"
          if labels is None:
              labels = [hist.GetName() for hist in hists]
          if isinstance(normalize_to, (int, float)):
              normalize_to = [normalize_to] * len(hists)
          if colors is None:
              colors = [ROOT.kRed, ROOT.kBlue, ROOT.kGreen]
          if _stacks is None:
              _stacks = HistCollection._stack_registry

          stack = ROOT.THStack(id_, title)
          for hist, label, color, norm in zip(hists, labels,
                                              it.cycle(colors), normalize_to):
              hist_copy = hist.Clone(hist.GetName() + "_clone")
              hist_copy.SetTitle(label)
              if enable_fill:
                  hist_copy.SetFillColorAlpha(color, 0.75)
              hist_copy.SetLineColorAlpha(color, 0.75)
              if norm:
                  integral = hist_copy.Integral()
                  if integral != 0:  # an empty histogram cannot be rescaled
                      hist_copy.Scale(norm / integral, "nosw2")
              hist_copy.SetStats(False)
              stack.Add(hist_copy)
          if draw:
              stack.Draw(draw_option)
          _stacks[id_] = stack  # prevent stack from getting garbage collected
          return stack

      @staticmethod
      def stack_hist_array(canvas, histcollections, fields, titles,
                           shape=None, **kwargs):
          """Draw one stacked plot per field, each on its own pad of *canvas*.

          For each name in *fields* (paired with the same-index entry of
          *titles*), the object stored under that key in every collection's
          ``map`` is stacked, labelled by the collection's sample name.  This
          is done via ``HistCollection.stack_hist(..., draw=True)``, and extra
          keyword arguments are forwarded to it.  A legend is built on pad 1.

          shape: (nx, ny) passed to ``Divide``; defaults to a single column
                 (1, n) for up to four fields, otherwise the smallest square
                 grid.
          """
          n_fields = len(fields)
          if shape is None:
              if n_fields <= 4:
                  shape = (1, n_fields)
              else:
                  side = int(ceil(sqrt(n_fields)))
                  shape = (side, side)
          canvas.Clear()
          canvas.Divide(*shape)
          labels = [hc.sample_name for hc in histcollections]
          for pad, (field, title) in enumerate(zip(fields, titles), start=1):
              canvas.cd(pad)
              # Look up through ``map`` rather than getattr: keys that clash
              # with existing attributes (e.g. "draw") are not set as attributes.
              hists = [hc.map[field] for hc in histcollections]
              HistCollection.stack_hist(hists, labels, id_=field, title=title,
                                        draw=True, **kwargs)
          canvas.cd(1).BuildLegend(0.75, 0.75, 0.95, 0.95, "")


  # Module-level aliases so the stacking helpers also work as plain functions.
  stack_hist = HistCollection.stack_hist
  stack_hist_array = HistCollection.stack_hist_array