utils.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import io
  2. from os.path import dirname, join, abspath, normpath
  3. import sys
  4. from math import ceil, sqrt
  5. from subprocess import run
  6. import itertools as it
  7. import ROOT
  8. PRJ_PATH = normpath(join(dirname(abspath(__file__)), "../"))
  9. EXE_PATH = join(PRJ_PATH, "build/main")
  10. class OutputCapture:
  11. def __init__(self):
  12. self.my_stdout = io.StringIO()
  13. self.my_stderr = io.StringIO()
  14. def get_stdout(self):
  15. self.my_stdout.seek(0)
  16. return self.my_stdout.read()
  17. def get_stderr(self):
  18. self.my_stderr.seek(0)
  19. return self.my_stderr.read()
  20. def __enter__(self):
  21. self.stdout = sys.stdout
  22. self.stderr = sys.stderr
  23. sys.stdout = self.my_stdout
  24. sys.stderr = self.my_stderr
  25. def __exit__(self, *args):
  26. sys.stdout = self.stdout
  27. sys.stderr = self.stderr
  28. self.stdout = None
  29. self.stderr = None
  30. def bin_range(n, end=None):
  31. if end is None:
  32. return range(1, n+1)
  33. else:
  34. return range(n+1, end+1)
  35. def normalize_columns(hist2d):
  36. normHist = ROOT.TH2D(hist2d)
  37. cols, rows = hist2d.GetNbinsX(), hist2d.GetNbinsY()
  38. for col in bin_range(cols):
  39. sum_ = 0
  40. for row in bin_range(rows):
  41. sum_ += hist2d.GetBinContent(col, row)
  42. if sum_ == 0:
  43. continue
  44. for row in bin_range(rows):
  45. norm = hist2d.GetBinContent(col, row) / sum_
  46. normHist.SetBinContent(col, row, norm)
  47. return normHist
  48. class HistCollection:
  49. def __init__(self, sample_name, input_filename,
  50. rebuild_hists=False):
  51. self.sample_name = sample_name
  52. if rebuild_hists:
  53. run([EXE_PATH, "-s", "-f", input_filename])
  54. output_filename = input_filename.replace(".root", "_result.root")
  55. file = ROOT.TFile.Open(output_filename)
  56. l = file.GetListOfKeys()
  57. self.map = {}
  58. for i in range(l.GetSize()):
  59. name = l.At(i).GetName()
  60. new_name = ":".join((sample_name, name))
  61. obj = file.Get(name)
  62. try:
  63. obj.SetName(new_name)
  64. obj.SetDirectory(0) # disconnects Object from file
  65. except AttributeError:
  66. pass
  67. self.map[name] = obj
  68. setattr(self, name, obj)
  69. file.Close()
  70. # Now add these histograms into the current ROOT directory (in memory)
  71. # and remove old versions if needed
  72. for obj in self.map.values():
  73. try:
  74. old_obj = ROOT.gDirectory.Get(obj.GetName())
  75. ROOT.gDirectory.Remove(old_obj)
  76. ROOT.gDirectory.Add(obj)
  77. except AttributeError:
  78. pass
  79. HistCollection.add_collection(self)
  80. def draw(self, shape=None):
  81. if shape is None:
  82. n = int(ceil(sqrt(len(self.map))))
  83. shape = (n, n)
  84. self.canvas.Clear()
  85. self.canvas.Divide(*shape)
  86. i = 1
  87. for hist in self.map.values():
  88. self.canvas.cd(i)
  89. try:
  90. hist.SetStats(False)
  91. except AttributeError:
  92. pass
  93. draw_option = ""
  94. if type(hist) in (ROOT.TH1F, ROOT.TH1I, ROOT.TH1D):
  95. draw_option = ""
  96. elif type(hist) in (ROOT.TH2F, ROOT.TH2I, ROOT.TH2D):
  97. draw_option = "COLZ"
  98. elif type(hist) in (ROOT.TGraph,):
  99. draw_option = "A*"
  100. else:
  101. print("cannot draw object", hist)
  102. continue # Not a drawable type(probably)
  103. hist.Draw(draw_option)
  104. i += 1
  105. self.canvas.Draw()
  106. @classmethod
  107. def get_hist_set(cls, attrname):
  108. labels, hists = zip(*[(sample_name, getattr(h, attrname))
  109. for sample_name, h in cls.collections.items()])
  110. return labels, hists
  111. @classmethod
  112. def add_collection(cls, hc):
  113. if not hasattr(cls, "collections"):
  114. cls.collections = {}
  115. cls.collections[hc.sample_name] = hc
  116. print("collection added: " + hc.sample_name)
  117. print("collections present: " + ', '.join(list(hc.collections.keys())))
  118. @property
  119. def canvas(self):
  120. cls = self.__class__
  121. if not hasattr(cls, "_canvas"):
  122. cls._canvas = ROOT.TCanvas("c1", "", 1600, 1200)
  123. return cls._canvas
  124. @canvas.setter
  125. def canvas(self, canvas):
  126. cls = self.__class__
  127. cls._canvas = canvas
  128. @classmethod
  129. def stack_hist(cls,
  130. hist_name,
  131. title="", enable_fill=False,
  132. normalize_to=0, draw=False,
  133. draw_option="",
  134. make_legend=False,
  135. _stacks={}):
  136. labels, hists = cls.get_hist_set(hist_name)
  137. colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen])
  138. stack = ROOT.THStack(hist_name+"_stack", title)
  139. if labels is None:
  140. labels = [hist.GetName() for hist in hists]
  141. if type(normalize_to) in (int, float):
  142. normalize_to = [normalize_to]*len(hists)
  143. ens = enumerate(zip(hists, labels, colors, normalize_to))
  144. for i, (hist, label, color, norm) in ens:
  145. hist_copy = hist
  146. hist_copy = hist.Clone(hist.GetName()+"_clone")
  147. hist_copy.SetTitle(label)
  148. if enable_fill:
  149. hist_copy.SetFillColorAlpha(color, 0.75)
  150. hist_copy.SetLineColorAlpha(color, 0.75)
  151. if norm:
  152. integral = hist_copy.Integral()
  153. hist_copy.Scale(norm/integral, "nosw2")
  154. hist_copy.SetStats(False)
  155. stack.Add(hist_copy)
  156. if draw:
  157. stack.Draw(draw_option)
  158. if make_legend:
  159. cls._canvas.BuildLegend(0.75, 0.75, 0.95, 0.95, "")
  160. # cls._canvas.Draw()
  161. # prevent stack from getting garbage collected
  162. _stacks[stack.GetName()] = stack
  163. return stack
  164. @classmethod
  165. def stack_hist_array(cls,
  166. hist_names,
  167. titles,
  168. shape=None, **kwargs):
  169. n_hists = len(hist_names)
  170. if shape is None:
  171. if n_hists <= 4:
  172. shape = (1, n_hists)
  173. else:
  174. shape = (ceil(sqrt(n_hists)),)*2
  175. cls._canvas.Divide(*shape)
  176. for i, hist_name, title in zip(bin_range(n_hists), hist_names, titles):
  177. cls._canvas.cd(i)
  178. hists, labels = cls.get_hist_set(hist_name)
  179. cls.stack_hist(hist_name, title=title, draw=True, **kwargs)
  180. cls._canvas.cd(1).BuildLegend(0.75, 0.75, 0.95, 0.95, "")