utils.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. import io
  2. import sys
  3. import itertools as it
  4. from os.path import dirname, join, abspath, normpath
  5. from math import ceil, sqrt
  6. from subprocess import run
  7. from IPython.display import Image
  8. import pydotplus.graphviz as pdp
  9. import ROOT
  # Project root: the parent of the directory that contains this module.
  PRJ_PATH = normpath(
      join(dirname(abspath(__file__)), "../")
  )
  # Executable that HistCollection runs when asked to rebuild histograms.
  EXE_PATH = join(PRJ_PATH, "build/main")
  # Display labels keyed by PDG particle id (negative ids are the
  # corresponding antiparticles / opposite charge states).
  PDG = {
      # quarks
      1: "d", -1: "d̄",
      2: "u", -2: "ū",
      3: "s", -3: "s̄",
      4: "c", -4: "c̄",
      5: "b", -5: "b̄",
      6: "t", -6: "t̄",
      # leptons
      11: "e-", -11: "e+",
      12: "ν_e", -12: "ῡ_e",
      13: "μ-", -13: "μ+",
      14: "ν_μ", -14: "ῡ_μ",
      15: "τ-", -15: "τ+",
      16: "ν_τ", -16: "ῡ_τ",
      # bosons
      21: "gluon",
      22: "γ",
      23: "Z0",
      24: "W+", -24: "W-",
      25: "H",
  }
  def show_event(dataset, idx):
      """Render the particle tree of one event as an image.

      Reads ``GenPart_pdgId``, ``GenPart_energy`` and ``GenPart_motherIndex``
      of ``dataset`` at position ``idx``. Every particle becomes a graph node
      labelled via ``PDG``; an edge is drawn from each particle's mother to
      the particle, and a mother index of -1 means "no mother".

      Args:
          dataset: object exposing the three ``GenPart_*`` sequences.
          idx: index of the event to draw.

      Returns:
          An ``IPython.display.Image`` built from the GIF rendering of the
          graph.
      """
      ids = list(dataset.GenPart_pdgId[idx])
      nrgs = list(dataset.GenPart_energy[idx])
      links = list(dataset.GenPart_motherIndex[idx])

      # Scale energies relative to the largest one in the event.  Fall back
      # to 1 for an empty event or an all-zero maximum so the division below
      # cannot fail.
      max_nrg = max(nrgs, default=0) or 1
      nrgs_scaled = [nrg/max_nrg for nrg in nrgs]

      g = pdp.Dot()
      for i, id_ in enumerate(ids):
          # Three comma-separated numbers used as the fill colour; the first
          # component carries the relative energy.
          color = ",".join(map(str, [nrgs_scaled[i], .7, .8]))
          # Particles missing from the PDG table are labelled with their
          # numeric id instead of aborting the whole drawing with KeyError.
          label = PDG.get(id_, str(id_))
          g.add_node(pdp.Node(str(i), label=label,
                              style="filled",
                              fillcolor=color))
      for i, mother in enumerate(links):
          if mother != -1:
              g.add_edge(pdp.Edge(str(mother), str(i)))
      return Image(g.create_gif())
  class OutputCapture:
      """Context manager that redirects ``sys.stdout``/``sys.stderr``.

      While the ``with`` block is active, everything written to the two
      streams is collected in in-memory buffers; the previous streams are put
      back on exit.  The captured text is available through ``get_stdout()``
      and ``get_stderr()``, both during and after the block.
      """

      def __init__(self):
          self.my_stdout = io.StringIO()
          self.my_stderr = io.StringIO()

      def get_stdout(self):
          """Return everything captured from stdout so far."""
          return self.my_stdout.getvalue()

      def get_stderr(self):
          """Return everything captured from stderr so far."""
          return self.my_stderr.getvalue()

      def __enter__(self):
          # Remember the current streams so __exit__ can restore them.
          self.stdout = sys.stdout
          self.stderr = sys.stderr
          sys.stdout = self.my_stdout
          sys.stderr = self.my_stderr
          # Return the instance so ``with OutputCapture() as cap:`` binds the
          # capture object instead of None.
          return self

      def __exit__(self, *args):
          sys.stdout = self.stdout
          sys.stderr = self.stderr
          self.stdout = None
          self.stderr = None
  def bin_range(n, end=None):
      """Return a 1-based, end-inclusive range.

      ``bin_range(n)`` covers 1..n; ``bin_range(n, end)`` covers n+1..end.
      """
      first, last = (1, n) if end is None else (n + 1, end)
      return range(first, last + 1)
  def normalize_columns(hist2d):
      """Return a ``ROOT.TH2D`` copy of ``hist2d`` with columns scaled to sum 1.

      For every x-bin the contents of y-bins 1..GetNbinsY() are divided by
      their sum.  Columns whose sum is zero keep the contents of the copy.
      The input histogram is only read, never modified.
      """
      normalized = ROOT.TH2D(hist2d)
      n_cols = hist2d.GetNbinsX()
      n_rows = hist2d.GetNbinsY()
      for x_bin in range(1, n_cols + 1):
          # Read the column once; it is needed for the sum and the scaling.
          column = [hist2d.GetBinContent(x_bin, y_bin)
                    for y_bin in range(1, n_rows + 1)]
          total = sum(column)
          if total == 0:
              continue
          for y_bin, content in enumerate(column, start=1):
              normalized.SetBinContent(x_bin, y_bin, content / total)
      return normalized
  class HistCollection:
      """Objects of one sample, loaded from the sample's result ROOT file.

      Every object stored in the result file is reachable as
      ``self.map[name]`` and as the attribute ``self.<name>``.  Each instance
      registers itself in the class-level ``collections`` dict under its
      sample name, which lets the ``stack_hist*`` class methods overlay the
      same histogram from every loaded sample.  All drawing goes through the
      single class-level ``canvas``.
      """

      single_plot_size = (600, 450)  # (width, height) of one pad
      max_width = 1800  # widest canvas draw() lays out on its own

      scale = .75
      x_size = int(1600*scale)
      y_size = int(1200*scale)
      # Created once, when the class body executes, and shared by everything.
      canvas = ROOT.TCanvas("c1", "", x_size, y_size)

      # sample_name -> HistCollection; filled by add_collection().
      collections = {}

      def __init__(self, sample_name, input_filename,
                   rebuild_hists=False):
          """Load every object from the result file of ``input_filename``.

          Args:
              sample_name: label of the sample; it prefixes the names of the
                  loaded ROOT objects and keys the class-level registry.
              input_filename: path of the input file.  Results are read from
                  the same path with ".root" replaced by "_result.root".
              rebuild_hists: when true, run the executable at ``EXE_PATH`` on
                  ``input_filename`` before reading the result file.

          Raises:
              OSError: if the result file cannot be opened.
          """
          self.sample_name = sample_name
          if rebuild_hists:
              # The exit status is not checked here.
              run([EXE_PATH, "-s", "-f", input_filename])
          output_filename = input_filename.replace(".root", "_result.root")
          root_file = ROOT.TFile.Open(output_filename)
          # Fail with a clear message instead of dereferencing a null file.
          if not root_file or root_file.IsZombie():
              raise OSError(
                  "cannot open result file: {}".format(output_filename))

          self.map = {}
          for key in root_file.GetListOfKeys():
              name = key.GetName()
              new_name = ":".join((sample_name, name))
              obj = root_file.Get(name)
              try:
                  obj.SetName(new_name)
                  obj.SetDirectory(0)  # disconnects Object from file
              except AttributeError:
                  # Not every object type offers both methods.
                  pass
              self.map[name] = obj
              setattr(self, name, obj)
          root_file.Close()

          # Now add these histograms into the current ROOT directory
          # (in memory) and remove old versions if needed.
          for obj in self.map.values():
              try:
                  old_obj = ROOT.gDirectory.Get(obj.GetName())
                  ROOT.gDirectory.Remove(old_obj)
                  ROOT.gDirectory.Add(obj)
              except AttributeError:
                  pass
          HistCollection.add_collection(self)

      @staticmethod
      def _draw_option(obj):
          """Return the Draw() option for ``obj``, or None if draw() skips it."""
          # Exact type match on purpose: subclasses of these are skipped.
          kind = type(obj)
          if kind in (ROOT.TH1F, ROOT.TH1I, ROOT.TH1D):
              return ""
          if kind in (ROOT.TH2F, ROOT.TH2I, ROOT.TH2D):
              return "COLZ"
          if kind in (ROOT.TGraph,):
              return "A*"
          return None  # Not a drawable type (probably)

      def draw(self, shape=None):
          """Draw every supported object of this collection on the canvas.

          Args:
              shape: ``(columns, rows)`` of the pad grid.  When omitted, the
                  grid gets as many columns as fit into ``max_width`` and as
                  many rows as the drawable objects then need.
          """
          if shape is None:
              n_plots = sum(1 for obj in self.map.values()
                            if self._draw_option(obj) is not None)
              per_row = max(1, self.max_width // self.single_plot_size[0])
              # A single row when everything fits, otherwise wrap.  Always at
              # least a 1x1 grid so the canvas keeps a valid size.
              shape_x = max(1, min(n_plots, per_row))
              shape_y = max(1, ceil(n_plots / shape_x))
              shape = (shape_x, shape_y)
          self.canvas.Clear()
          self.canvas.SetCanvasSize(shape[0]*self.single_plot_size[0],
                                    shape[1]*self.single_plot_size[1])
          self.canvas.Divide(*shape)
          i = 1
          for hist in self.map.values():
              try:
                  hist.SetStats(False)
              except AttributeError:
                  pass
              draw_option = self._draw_option(hist)
              if draw_option is None:
                  continue
              self.canvas.cd(i)
              hist.Draw(draw_option)
              i += 1
          self.canvas.Draw()

      @classmethod
      def get_hist_set(cls, attrname):
          """Collect attribute ``attrname`` from every registered collection.

          Returns:
              ``(labels, hists)``: two tuples of equal length holding the
              sample names and the matching objects.  Both are empty when no
              collection has been registered.
          """
          labels = tuple(cls.collections)
          hists = tuple(getattr(hc, attrname)
                        for hc in cls.collections.values())
          return labels, hists

      @classmethod
      def add_collection(cls, hc):
          """Register ``hc`` under its sample name, replacing an older entry."""
          cls.collections[hc.sample_name] = hc

      @classmethod
      def stack_hist(cls,
                     hist_name,
                     title="",
                     enable_fill=False,
                     normalize_to=0,
                     draw=False,
                     draw_canvas=True,
                     draw_option="",
                     make_legend=False,
                     _stacks={}):
          """Build a ``ROOT.THStack`` of ``hist_name`` across all samples.

          Args:
              hist_name: attribute name of the histogram in each collection.
              title: title of the stack.
              enable_fill: also set a fill colour, not only the line colour.
              normalize_to: target integral for the histograms; a single
                  number applies to all, a sequence gives one per sample.
                  A falsy value leaves that histogram unscaled.
              draw: draw the stack into the current pad.
              draw_canvas: clear and resize the shared canvas first, and draw
                  it at the end.
              draw_option: option string passed to ``THStack.Draw``.
              make_legend: build a legend on the canvas (only with ``draw``).
              _stacks: internal.  The shared mutable default is deliberate:
                  it keeps a reference to every stack.

          Returns:
              The created ``ROOT.THStack``.
          """
          labels, hists = cls.get_hist_set(hist_name)
          if draw_canvas:
              cls.canvas.Clear()
              cls.canvas.SetCanvasSize(cls.single_plot_size[0],
                                       cls.single_plot_size[1])

          colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen])
          stack = ROOT.THStack(hist_name+"_stack", title)
          if isinstance(normalize_to, (int, float)):
              normalize_to = [normalize_to]*len(hists)
          for hist, label, color, norm in zip(hists, labels, colors,
                                              normalize_to):
              hist_copy = hist.Clone(hist.GetName()+"_clone")
              hist_copy.SetTitle(label)
              if enable_fill:
                  hist_copy.SetFillColorAlpha(color, 0.75)
              hist_copy.SetLineColorAlpha(color, 0.75)
              if norm:
                  integral = hist_copy.Integral()
                  # A histogram with zero integral cannot be rescaled; leave
                  # it as is instead of dividing by zero.
                  if integral:
                      hist_copy.Scale(norm/integral, "nosw2")
              hist_copy.SetStats(False)
              stack.Add(hist_copy)
          if draw:
              stack.Draw(draw_option)
              if make_legend:
                  cls.canvas.BuildLegend(0.75, 0.75, 0.95, 0.95, "")
          # prevent stack from getting garbage collected
          _stacks[stack.GetName()] = stack
          if draw_canvas:
              cls.canvas.Draw()
          return stack

      @classmethod
      def stack_hist_array(cls,
                           hist_names,
                           titles,
                           shape=None, **kwargs):
          """Draw one stacked histogram per name on a grid of pads.

          Args:
              hist_names: histogram attribute names, one per pad.
              titles: stack titles, parallel to ``hist_names``.
              shape: ``(columns, rows)`` of the grid.  Defaults to a single
                  column for up to four histograms, else to a square grid.
              **kwargs: forwarded to ``stack_hist``.
          """
          n_hists = len(hist_names)
          if shape is None:
              if n_hists <= 4:
                  shape = (1, n_hists)
              else:
                  shape = (ceil(sqrt(n_hists)),)*2
          # Drop pads left over from an earlier layout, as draw() does,
          # before dividing the canvas again.
          cls.canvas.Clear()
          cls.canvas.SetCanvasSize(cls.single_plot_size[0]*shape[0],
                                   cls.single_plot_size[1]*shape[1])
          cls.canvas.Divide(*shape)
          pairs = zip(hist_names, titles)
          for pad, (hist_name, title) in enumerate(pairs, start=1):
              cls.canvas.cd(pad)
              cls.stack_hist(hist_name, title=title, draw=True,
                             draw_canvas=False, **kwargs)
          # The legend is attached to the last pad only.
          cls.canvas.cd(n_hists).BuildLegend(0.75, 0.75, 0.95, 0.95, "")