utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. import io
  2. import os
  3. import sys
  4. import itertools as it
  5. from os.path import dirname, join, abspath, normpath, getctime
  6. from math import ceil, floor, sqrt
  7. from subprocess import run, PIPE
  8. from IPython.display import Image
  9. import pydotplus.graphviz as pdp
  10. import ROOT
  11. PRJ_PATH = normpath(join(dirname(abspath(__file__)), "../"))
  12. EXE_PATH = join(PRJ_PATH, "build/main")
  13. PDG = {1: 'd', -1: 'd̄',
  14. 2: 'u', -2: 'ū',
  15. 3: 's', -3: 's̄',
  16. 4: 'c', -4: 'c̄',
  17. 5: 'b', -5: 'b̄',
  18. 6: 't', -6: 't̄',
  19. 11: 'e-', -11: 'e+',
  20. 12: 'ν_e', -12: 'ῡ_e',
  21. 13: 'μ-', -13: 'μ+',
  22. 14: 'ν_μ', -14: 'ῡ_μ',
  23. 15: 'τ-', -15: 'τ+',
  24. 16: 'ν_τ', -16: 'ῡ_τ',
  25. 21: 'gluon',
  26. 22: 'γ',
  27. 23: 'Z0',
  28. 24: 'W+', -24: 'W-',
  29. 25: 'H',
  30. }
  31. SINGLE_PLOT_SIZE = (600, 450)
  32. MAX_WIDTH = 1800
  33. SCALE = .75
  34. CAN_SIZE_DEF = (int(1600*SCALE), int(1200*SCALE))
  35. CANVAS = ROOT.TCanvas("c1", "", *CAN_SIZE_DEF)
  36. ROOT.gStyle.SetPalette(112) # set the "virdidis" color map
  37. def clear():
  38. CANVAS.Clear()
  39. CANVAS.SetCanvasSize(*CAN_SIZE_DEF)
  40. def get_color(val, max_val, min_val = 0):
  41. val = (val-min_val)/(max_val-min_val)
  42. val = round(val * (ROOT.gStyle.GetNumberOfColors()-1))
  43. col_idx = ROOT.gStyle.GetColorPalette(val)
  44. col = ROOT.gROOT.GetColor(col_idx)
  45. r = floor(256*col.GetRed())
  46. g = floor(256*col.GetGreen())
  47. b = floor(256*col.GetBlue())
  48. gs = (r + g + b)//3
  49. text_color = 'white' if gs < 100 else 'black'
  50. return '#{:02x}{:02x}{:02x}'.format(r, g, b), text_color
  51. def show_event(dataset, idx):
  52. ids = list(dataset.GenPart_pdgId[idx])
  53. stats = list(dataset.GenPart_status[idx])
  54. energies = list(dataset.GenPart_energy[idx])
  55. links = list(dataset.GenPart_motherIndex[idx])
  56. max_energy = max(energies)
  57. g = pdp.Dot()
  58. for i, id_ in enumerate(ids):
  59. color, text_color = get_color(energies[i], max_energy)
  60. shape = "ellipse" if stats[i] in (1, 23) else "invhouse"
  61. label = "{}({})".format(PDG[id_], stats[i])
  62. # label = PDG[id_]+"({:03e})".format(energies[i])
  63. g.add_node(pdp.Node(str(i), label=label,
  64. style="filled",
  65. shape=shape,
  66. fontcolor=text_color,
  67. fillcolor=color))
  68. for i, mother in enumerate(links):
  69. if mother != -1:
  70. g.add_edge(pdp.Edge(str(mother), str(i)))
  71. return Image(g.create_gif())
  72. class OutputCapture:
  73. def __init__(self):
  74. self.my_stdout = io.StringIO()
  75. self.my_stderr = io.StringIO()
  76. def get_stdout(self):
  77. self.my_stdout.seek(0)
  78. return self.my_stdout.read()
  79. def get_stderr(self):
  80. self.my_stderr.seek(0)
  81. return self.my_stderr.read()
  82. def __enter__(self):
  83. self.stdout = sys.stdout
  84. self.stderr = sys.stderr
  85. sys.stdout = self.my_stdout
  86. sys.stderr = self.my_stderr
  87. def __exit__(self, *args):
  88. sys.stdout = self.stdout
  89. sys.stderr = self.stderr
  90. self.stdout = None
  91. self.stderr = None
  92. def normalize_columns(hist2d):
  93. normHist = ROOT.TH2D(hist2d)
  94. cols, rows = hist2d.GetNbinsX(), hist2d.GetNbinsY()
  95. for col in range(1, cols+1):
  96. sum_ = 0
  97. for row in range(1, rows+1):
  98. sum_ += hist2d.GetBinContent(col, row)
  99. if sum_ == 0:
  100. continue
  101. for row in range(1, rows+1):
  102. norm = hist2d.GetBinContent(col, row) / sum_
  103. normHist.SetBinContent(col, row, norm)
  104. return normHist
  105. class HistCollection:
  106. def __init__(self, sample_name, input_filename):
  107. self.sample_name = sample_name
  108. self.input_filename = input_filename
  109. self.output_filename = self.input_filename.replace(".root", "_result.root")
  110. self.conditional_recompute()
  111. self.load_objects()
  112. # Now add these histograms into the current ROOT directory (in memory)
  113. # and remove old versions if needed
  114. for obj in self.map.values():
  115. try:
  116. old_obj = ROOT.gDirectory.Get(obj.GetName())
  117. ROOT.gDirectory.Remove(old_obj)
  118. ROOT.gDirectory.Add(obj)
  119. except AttributeError:
  120. pass
  121. HistCollection.add_collection(self)
  122. def conditional_recompute(self):
  123. def recompute():
  124. print("Running analysis for sample: ", self.sample_name)
  125. if run([EXE_PATH, "-s", "-f", self.input_filename]).returncode != 0:
  126. raise RuntimeError(("Failed running analysis code."
  127. " See log file for more information"))
  128. if run(["make"], cwd=join(PRJ_PATH, "build"), stdout=PIPE, stderr=PIPE).returncode != 0:
  129. raise RuntimeError("Failed recompiling analysis code")
  130. if (not os.path.isfile(self.output_filename) or (getctime(EXE_PATH) > getctime(self.output_filename))):
  131. recompute()
  132. else:
  133. print("Loading unchanged result file ", self.output_filename)
  134. def load_objects(self):
  135. file = ROOT.TFile.Open(self.output_filename)
  136. l = file.GetListOfKeys()
  137. self.map = {}
  138. for i in range(l.GetSize()):
  139. name = l.At(i).GetName()
  140. new_name = ":".join((self.sample_name, name))
  141. obj = file.Get(name)
  142. try:
  143. obj.SetName(new_name)
  144. obj.SetDirectory(0) # disconnects Object from file
  145. except AttributeError:
  146. pass
  147. self.map[name] = obj
  148. setattr(self, name, obj)
  149. file.Close()
  150. @classmethod
  151. def calc_shape(cls, n_plots):
  152. if n_plots*SINGLE_PLOT_SIZE[0] > MAX_WIDTH:
  153. shape_x = MAX_WIDTH//SINGLE_PLOT_SIZE[0]
  154. shape_y = ceil(n_plots / shape_x)
  155. return (shape_x, shape_y)
  156. else:
  157. return (n_plots, 1)
  158. def draw(self, shape=None):
  159. objs = [obj for obj in self.map.values() if hasattr(obj, "Draw")]
  160. if shape is None:
  161. n_plots = len(objs)
  162. shape = self.calc_shape(n_plots)
  163. CANVAS.Clear()
  164. CANVAS.SetCanvasSize(shape[0]*SINGLE_PLOT_SIZE[0], shape[1]*SINGLE_PLOT_SIZE[1])
  165. CANVAS.Divide(*shape)
  166. i = 1
  167. for hist in objs:
  168. CANVAS.cd(i)
  169. try:
  170. hist.SetStats(False)
  171. except AttributeError:
  172. pass
  173. hist.Draw(self.get_draw_option(hist))
  174. i += 1
  175. CANVAS.Draw()
  176. @staticmethod
  177. def get_draw_option(obj):
  178. obj_type = type(obj)
  179. if obj_type in (ROOT.TH1F, ROOT.TH1I, ROOT.TH1D):
  180. return ""
  181. elif obj_type in (ROOT.TH2F, ROOT.TH2I, ROOT.TH2D):
  182. return "COLZ"
  183. elif obj_type in (ROOT.TGraph,):
  184. return "A*"
  185. else:
  186. return None
  187. @classmethod
  188. def get_hist_set(cls, attrname):
  189. labels, hists = zip(*[(sample_name, getattr(h, attrname))
  190. for sample_name, h in cls.collections.items()])
  191. return labels, hists
  192. @classmethod
  193. def add_collection(cls, hc):
  194. if not hasattr(cls, "collections"):
  195. cls.collections = {}
  196. cls.collections[hc.sample_name] = hc
  197. @classmethod
  198. def stack_hist(cls,
  199. hist_name,
  200. title="",
  201. enable_fill=False,
  202. normalize_to=0,
  203. draw=False,
  204. draw_canvas=True,
  205. draw_option="",
  206. make_legend=False,
  207. _stacks={}):
  208. labels, hists = cls.get_hist_set(hist_name)
  209. if draw_canvas:
  210. CANVAS.Clear()
  211. CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0],
  212. SINGLE_PLOT_SIZE[1])
  213. colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen])
  214. stack = ROOT.THStack(hist_name+"_stack", title)
  215. if labels is None:
  216. labels = [hist.GetName() for hist in hists]
  217. if type(normalize_to) in (int, float):
  218. normalize_to = [normalize_to]*len(hists)
  219. ens = enumerate(zip(hists, labels, colors, normalize_to))
  220. for i, (hist, label, color, norm) in ens:
  221. hist_copy = hist
  222. hist_copy = hist.Clone(hist.GetName()+"_clone" + draw_option)
  223. hist_copy.SetTitle(label)
  224. if enable_fill:
  225. hist_copy.SetFillColorAlpha(color, 0.75)
  226. hist_copy.SetLineColorAlpha(color, 0.75)
  227. if norm:
  228. integral = hist_copy.Integral()
  229. hist_copy.Scale(norm/integral, "nosw2")
  230. hist_copy.SetStats(True)
  231. stack.Add(hist_copy)
  232. if draw:
  233. stack.Draw(draw_option)
  234. if make_legend:
  235. CANVAS.BuildLegend(0.75, 0.75, 0.95, 0.95, "")
  236. # prevent stack from getting garbage collected
  237. _stacks[stack.GetName()] = stack
  238. if draw_canvas:
  239. CANVAS.Draw()
  240. return stack
  241. @classmethod
  242. def stack_hist_array(cls,
  243. hist_names,
  244. titles,
  245. shape=None, **kwargs):
  246. n_hists = len(hist_names)
  247. if shape is None:
  248. if n_hists <= 4:
  249. shape = (1, n_hists)
  250. else:
  251. shape = (ceil(sqrt(n_hists)),)*2
  252. CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0]*shape[0],
  253. SINGLE_PLOT_SIZE[1]*shape[1])
  254. CANVAS.Divide(*shape)
  255. for i, hist_name, title in zip(range(1, n_hists+1), hist_names, titles):
  256. CANVAS.cd(i)
  257. cls.stack_hist(hist_name, title=title, draw=True,
  258. draw_canvas=False, **kwargs)
  259. CANVAS.cd(n_hists).BuildLegend(0.75, 0.75, 0.95, 0.95, "")
    # Class-wide list that hist_array_single appends each TPaveText label to;
    # holding a Python reference presumably keeps the labels from being
    # garbage-collected while the canvas displays them.
    pts = []
  261. @classmethod
  262. def hist_array_single(cls,
  263. hist_name,
  264. title=None,
  265. **kwargs):
  266. n_hists = len(cls.collections)
  267. shape = cls.calc_shape(n_hists)
  268. CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0]*shape[0],
  269. SINGLE_PLOT_SIZE[1]*shape[1])
  270. CANVAS.Divide(*shape)
  271. labels, hists = cls.get_hist_set(hist_name)
  272. for i, label, hist in zip(range(1, n_hists+1), labels, hists):
  273. pt = ROOT.TPaveText(300, 3, 400, 3.5)
  274. CANVAS.cd(i)
  275. hist.SetStats(False)
  276. hist.Draw(cls.get_draw_option(hist))
  277. pt.AddText("Dataset: "+label)
  278. pt.Draw()
  279. cls.pts.append(pt)