utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. import io
  2. import os
  3. import sys
  4. import itertools as it
  5. from os.path import dirname, join, abspath, normpath, getctime
  6. from math import ceil, floor, sqrt
  7. from collections import deque
  8. from subprocess import run, PIPE
  9. from IPython.display import Image
  10. import pydotplus.graphviz as pdp
  11. import ROOT
  12. from graph_vals import parse
  13. PRJ_PATH = normpath(join(dirname(abspath(__file__)), "../"))
  14. EXE_PATH = join(PRJ_PATH, "build/main")
  15. PDG = {1: 'd', -1: 'd̄',
  16. 2: 'u', -2: 'ū',
  17. 3: 's', -3: 's̄',
  18. 4: 'c', -4: 'c̄',
  19. 5: 'b', -5: 'b̄',
  20. 6: 't', -6: 't̄',
  21. 11: 'e-', -11: 'e+',
  22. 12: 'ν_e', -12: 'ῡ_e',
  23. 13: 'μ-', -13: 'μ+',
  24. 14: 'ν_μ', -14: 'ῡ_μ',
  25. 15: 'τ-', -15: 'τ+',
  26. 16: 'ν_τ', -16: 'ῡ_τ',
  27. 21: 'g',
  28. 22: 'γ',
  29. 23: 'Z0',
  30. 24: 'W+', -24: 'W-',
  31. 25: 'H',
  32. }
# (width, height) in pixels given to each individual plot/pad when the shared
# canvas is resized by the drawing helpers below.
SINGLE_PLOT_SIZE = (600, 450)
# Widest canvas, in pixels, that ResultSet.calc_shape allows before it wraps
# plots onto additional rows.
MAX_WIDTH = 1800
SCALE = .75
# Default canvas size: 1600x1200 scaled by SCALE; restored by clear().
CAN_SIZE_DEF = (int(1600*SCALE), int(1200*SCALE))
# One module-wide canvas, created at import time and shared by every drawing
# helper in this file.
CANVAS = ROOT.TCanvas("c1", "", *CAN_SIZE_DEF)
ROOT.gStyle.SetPalette(112) # set the "viridis" color map
# Filled by ResultSet.load_objects from each result file's "_value_lookup"
# entry and read by show_value.
VALUES = {}
  40. def clear():
  41. CANVAS.Clear()
  42. CANVAS.SetCanvasSize(*CAN_SIZE_DEF)
  43. def get_color(val, max_val, min_val = 0):
  44. val = (val-min_val)/(max_val-min_val)
  45. val = round(val * (ROOT.gStyle.GetNumberOfColors()-1))
  46. col_idx = ROOT.gStyle.GetColorPalette(val)
  47. col = ROOT.gROOT.GetColor(col_idx)
  48. r = floor(256*col.GetRed())
  49. g = floor(256*col.GetGreen())
  50. b = floor(256*col.GetBlue())
  51. gs = (r + g + b)//3
  52. text_color = 'white' if gs < 100 else 'black'
  53. return '#{:02x}{:02x}{:02x}'.format(r, g, b), text_color
  54. def show_event(dataset, idx):
  55. ids = list(dataset.GenPart_pdgId[idx])
  56. stats = list(dataset.GenPart_status[idx])
  57. energies = list(dataset.GenPart_energy[idx])
  58. links = list(dataset.GenPart_motherIndex[idx])
  59. max_energy = max(energies)
  60. g = pdp.Dot()
  61. for i, id_ in enumerate(ids):
  62. color, text_color = get_color(energies[i], max_energy)
  63. shape = "ellipse" if stats[i] in (1, 23) else "invhouse"
  64. label = "{}({})".format(PDG[id_], stats[i])
  65. # label = PDG[id_]+"({:03e})".format(energies[i])
  66. g.add_node(pdp.Node(str(i), label=label,
  67. style="filled",
  68. shape=shape,
  69. fontcolor=text_color,
  70. fillcolor=color))
  71. for i, mother in enumerate(links):
  72. if mother != -1:
  73. g.add_edge(pdp.Edge(str(mother), str(i)))
  74. return Image(g.create_gif())
  75. def show_function(dataset, fname):
  76. from IPython.display import Markdown
  77. def md_single(fname_):
  78. impl = dataset._function_impl_lookup[fname_]
  79. return '*{}*\n-----\n```cpp\n{}\n```\n\n---'.format(fname_, impl)
  80. try:
  81. return Markdown('\n'.join(md_single(fname_) for fname_ in iter(fname)))
  82. except TypeError:
  83. return Markdown(md_single(fname))
  84. def show_value(dataset, container):
  85. if type(container) != str:
  86. container = container.GetName().split(':')[1]
  87. g, functions = parse(VALUES[container], container)
  88. try:
  89. return Image(g.create_gif()), show_function(dataset, functions)
  90. except Exception as e:
  91. print(e)
  92. print(g.to_string())
  93. class OutputCapture:
  94. def __init__(self):
  95. self.my_stdout = io.StringIO()
  96. self.my_stderr = io.StringIO()
  97. def get_stdout(self):
  98. self.my_stdout.seek(0)
  99. return self.my_stdout.read()
  100. def get_stderr(self):
  101. self.my_stderr.seek(0)
  102. return self.my_stderr.read()
  103. def __enter__(self):
  104. self.stdout = sys.stdout
  105. self.stderr = sys.stderr
  106. sys.stdout = self.my_stdout
  107. sys.stderr = self.my_stderr
  108. def __exit__(self, *args):
  109. sys.stdout = self.stdout
  110. sys.stderr = self.stderr
  111. self.stdout = None
  112. self.stderr = None
  113. def normalize_columns(hist2d):
  114. normHist = ROOT.TH2D(hist2d)
  115. cols, rows = hist2d.GetNbinsX(), hist2d.GetNbinsY()
  116. for col in range(1, cols+1):
  117. sum_ = 0
  118. for row in range(1, rows+1):
  119. sum_ += hist2d.GetBinContent(col, row)
  120. if sum_ == 0:
  121. continue
  122. for row in range(1, rows+1):
  123. norm = hist2d.GetBinContent(col, row) / sum_
  124. normHist.SetBinContent(col, row, norm)
  125. return normHist
  126. class ResultSet:
  127. def __init__(self, sample_name, input_filename):
  128. self.sample_name = sample_name
  129. self.input_filename = input_filename
  130. self.output_filename = self.input_filename.replace(".root", "_result.root")
  131. self.conditional_recompute()
  132. self.load_objects()
  133. ResultSet.add_collection(self)
  134. def conditional_recompute(self):
  135. def recompute():
  136. print("Running analysis for sample: ", self.sample_name)
  137. if run([EXE_PATH, "-s", "-f", self.input_filename, "-l", self.sample_name]).returncode != 0:
  138. raise RuntimeError(("Failed running analysis code."
  139. " See log file for more information"))
  140. if run(["make", "main"], cwd=join(PRJ_PATH, "build"), stdout=PIPE, stderr=PIPE).returncode != 0:
  141. raise RuntimeError("Failed recompiling analysis code")
  142. if (not os.path.isfile(self.output_filename) or (getctime(EXE_PATH) > getctime(self.output_filename))):
  143. recompute()
  144. else:
  145. print("Loading unchanged result file ", self.output_filename)
    def load_objects(self):
        """Load every object stored in the result file onto this instance.

        Each key of ``self.output_filename`` is read, renamed to
        ``"<sample_name>:<key>"`` when the object supports it, and stored both
        in ``self.map`` (under the original key) and as an attribute named
        after the key. The file's ``_value_lookup`` entry is merged into the
        module-level ``VALUES`` dict. Afterwards ``xsec``, ``event_count`` and
        ``lumi`` are derived for this sample.
        """
        file = ROOT.TFile.Open(self.output_filename)
        l = file.GetListOfKeys()
        self.map = {}
        VALUES.update(dict(file.Get("_value_lookup")))
        for i in range(l.GetSize()):
            name = l.At(i).GetName()
            new_name = ":".join((self.sample_name, name))
            obj = file.Get(name)
            try:
                obj.SetName(new_name)
                obj.SetDirectory(0) # disconnects Object from file
            except AttributeError:
                # Objects without SetName/SetDirectory are kept as read.
                pass
            if 'ROOT.vector<int>' in str(type(obj)) and '_count' in name:
                # Integer vectors stored under a key containing "_count" are
                # replaced by their first element.
                obj = obj[0]
            self.map[name] = obj
            setattr(self, name, obj)
        file.Close()

        # Now add these histograms into the current ROOT directory (in memory)
        # and remove old versions if needed
        for obj in self.map.values():
            try:
                old_obj = ROOT.gDirectory.Get(obj.GetName())
                ROOT.gDirectory.Remove(old_obj)
                ROOT.gDirectory.Add(obj)
            except AttributeError:
                # Entries that lack GetName (e.g. the plain values unpacked
                # above) are skipped.
                pass

        # NOTE(review): _xsecs and _event_counts are not assigned explicitly
        # in this method; presumably they come from the setattr() loop above
        # via file keys of the same names -- confirm against the result file.
        self.xsec = self._xsecs[self.sample_name]*1000 # convert to fb
        self.event_count = self._event_counts[self.sample_name]
        self.lumi = self.event_count / self.xsec
  177. @classmethod
  178. def calc_shape(cls, n_plots):
  179. if n_plots*SINGLE_PLOT_SIZE[0] > MAX_WIDTH:
  180. shape_x = MAX_WIDTH//SINGLE_PLOT_SIZE[0]
  181. shape_y = ceil(n_plots / shape_x)
  182. return (shape_x, shape_y)
  183. else:
  184. return (n_plots, 1)
  185. def draw(self, shape=None):
  186. objs = [obj for obj in self.map.values() if hasattr(obj, "Draw")]
  187. if shape is None:
  188. n_plots = len(objs)
  189. shape = self.calc_shape(n_plots)
  190. CANVAS.Clear()
  191. CANVAS.SetCanvasSize(shape[0]*SINGLE_PLOT_SIZE[0], shape[1]*SINGLE_PLOT_SIZE[1])
  192. CANVAS.Divide(*shape)
  193. i = 1
  194. for hist in objs:
  195. CANVAS.cd(i)
  196. try:
  197. hist.SetStats(False)
  198. except AttributeError:
  199. pass
  200. if type(hist) in (ROOT.TH1I, ROOT.TH1F, ROOT.TH1D):
  201. hist.SetMinimum(0)
  202. hist.Draw(self.get_draw_option(hist))
  203. i += 1
  204. CANVAS.Draw()
  205. @staticmethod
  206. def get_draw_option(obj):
  207. obj_type = type(obj)
  208. if obj_type in (ROOT.TH1F, ROOT.TH1I, ROOT.TH1D):
  209. return ""
  210. elif obj_type in (ROOT.TH2F, ROOT.TH2I, ROOT.TH2D):
  211. return "COLZ"
  212. elif obj_type in (ROOT.TGraph,):
  213. return "A*"
  214. else:
  215. return None
  216. @classmethod
  217. def get_hist_set(cls, attrname):
  218. labels, hists = zip(*[(sample_name, getattr(h, attrname))
  219. for sample_name, h in cls.collections.items()])
  220. return labels, hists
  221. @classmethod
  222. def add_collection(cls, hc):
  223. if not hasattr(cls, "collections"):
  224. cls.collections = {}
  225. cls.collections[hc.sample_name] = hc
  226. @classmethod
  227. def stack_hist(cls,
  228. hist_name,
  229. title="",
  230. enable_fill=False,
  231. normalize_to=0,
  232. draw=False,
  233. draw_canvas=True,
  234. draw_option="",
  235. make_legend=False,
  236. _stacks={}):
  237. labels, hists = cls.get_hist_set(hist_name)
  238. if draw_canvas:
  239. CANVAS.Clear()
  240. CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0],
  241. SINGLE_PLOT_SIZE[1])
  242. colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen, ROOT.kYellow])
  243. stack = ROOT.THStack(hist_name+"_stack", title)
  244. if labels is None:
  245. labels = [hist.GetName() for hist in hists]
  246. if type(normalize_to) in (int, float):
  247. normalize_to = [normalize_to]*len(hists)
  248. ens = enumerate(zip(hists, labels, colors, normalize_to))
  249. for i, (hist, label, color, norm) in ens:
  250. hist_copy = hist
  251. hist_copy = hist.Clone(hist.GetName()+"_clone" + draw_option)
  252. hist_copy.SetTitle(label)
  253. if enable_fill:
  254. hist_copy.SetFillColorAlpha(color, 0.75)
  255. hist_copy.SetLineColorAlpha(color, 0.75)
  256. if norm:
  257. integral = hist_copy.Integral()
  258. hist_copy.Scale(norm/integral, "nosw2")
  259. hist_copy.SetStats(True)
  260. stack.Add(hist_copy)
  261. if draw:
  262. stack.Draw(draw_option)
  263. if make_legend:
  264. CANVAS.BuildLegend(0.75, 0.75, 0.95, 0.95, "")
  265. # prevent stack from getting garbage collected
  266. _stacks[stack.GetName()] = stack
  267. if draw_canvas:
  268. CANVAS.Draw()
  269. return stack
  270. @classmethod
  271. def stack_hist_array(cls,
  272. hist_names,
  273. titles,
  274. shape=None, **kwargs):
  275. n_hist = len(hist_names)
  276. if shape is None:
  277. if n_hist <= 4:
  278. shape = (1, n_hist)
  279. else:
  280. shape = (ceil(sqrt(n_hist)),)*2
  281. CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0]*shape[0],
  282. SINGLE_PLOT_SIZE[1]*shape[1])
  283. CANVAS.Divide(*shape)
  284. for i, hist_name, title in zip(range(1, n_hist+1), hist_names, titles):
  285. CANVAS.cd(i)
  286. cls.stack_hist(hist_name, title=title, draw=True,
  287. draw_canvas=False, **kwargs)
  288. CANVAS.cd(n_hist).BuildLegend(0.75, 0.75, 0.95, 0.95, "")
    # Bounded buffer (at most 50 entries) that hist_array_single appends its
    # TPaveText labels to -- presumably so the labels stay referenced while
    # they are displayed; confirm before changing the size.
    pts = deque([], 50)
  290. @classmethod
  291. def hist_array_single(cls,
  292. hist_name,
  293. title=None,
  294. **kwargs):
  295. n_hist = len(cls.collections)
  296. shape = cls.calc_shape(n_hist)
  297. CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0]*shape[0],
  298. SINGLE_PLOT_SIZE[1]*shape[1])
  299. CANVAS.Divide(*shape)
  300. labels, hists = cls.get_hist_set(hist_name)
  301. def pave_loc():
  302. hist.Get
  303. for i, label, hist in zip(range(1, n_hist+1), labels, hists):
  304. CANVAS.cd(i)
  305. hist.SetStats(False)
  306. hist.Draw(cls.get_draw_option(hist))
  307. pt = ROOT.TPaveText(0.70, 0.87, 0.85, 0.95, "NDC")
  308. pt.AddText("Dataset: "+label)
  309. pt.Draw()
  310. cls.pts.append(pt)