utils.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import io
  2. import sys
  3. import itertools as it
  4. from os.path import dirname, join, abspath, normpath
  5. from math import ceil, floor, sqrt
  6. from collections import deque
  7. from IPython.display import Image
  8. import ROOT
  9. from graph_vals import parse
  # Project root (the parent of the directory holding this file) and the
  # compiled analysis executable inside it.
  PRJ_PATH = normpath(
      join(dirname(abspath(__file__)), "../")
  )
  EXE_PATH = join(PRJ_PATH, "build/main")

  # PDG particle ID -> printable symbol; negative IDs are the antiparticles.
  # NOTE(review): the antineutrino entries appear to use a precomposed Greek
  # upsilon-with-macron glyph as a look-alike for nu-bar; confirm intended.
  PDG = {
      1: 'd', -1: 'd̄',
      2: 'u', -2: 'ū',
      3: 's', -3: 's̄',
      4: 'c', -4: 'c̄',
      5: 'b', -5: 'b̄',
      6: 't', -6: 't̄',
      11: 'e-', -11: 'e+',
      12: 'ν_e', -12: 'ῡ_e',
      13: 'μ-', -13: 'μ+',
      14: 'ν_μ', -14: 'ῡ_μ',
      15: 'τ-', -15: 'τ+',
      16: 'ν_τ', -16: 'ῡ_τ',
      21: 'g',
      22: 'γ',
      23: 'Z0',
      24: 'W+', -24: 'W-',
      25: 'H',
  }

  # Canvas geometry in pixels: one sub-plot, the widest row the grid layout
  # will use, and the default canvas size scaled down by SCALE.
  SINGLE_PLOT_SIZE = (600, 450)
  MAX_WIDTH = 1800
  SCALE = .75
  CAN_SIZE_DEF = tuple(int(side * SCALE) for side in (1600, 1200))

  # The single shared canvas every plotting helper in this module draws into.
  CANVAS = ROOT.TCanvas("c1", "", *CAN_SIZE_DEF)
  ROOT.gStyle.SetPalette(112)  # 112 selects the "viridis" color map

  # Value-graph descriptions, merged from each loaded file's "_value_lookup".
  VALUES = {}
  def clear():
      """Reset the shared canvas: remove its contents and restore the default size."""
      default_width, default_height = CAN_SIZE_DEF
      CANVAS.Clear()
      CANVAS.SetCanvasSize(default_width, default_height)
  def get_color(val, max_val, min_val=0):
      """Map ``val`` onto the active ROOT palette and return display colors.

      ``val`` is rescaled linearly from ``[min_val, max_val]`` to an index into
      ``ROOT.gStyle``'s current color palette.

      Args:
          val: Value to color.
          max_val: Value mapped to the last palette entry.
          min_val: Value mapped to the first palette entry (default 0).

      Returns:
          ``(background, text_color)``. ``background`` is a ``'#rrggbb'`` hex
          string. ``text_color`` is ``'white'`` for dark backgrounds and
          ``'black'`` otherwise. Values outside the range are clamped to the
          ends of the palette. A zero-width range maps to the first color.
      """
      span = max_val - min_val
      # A zero-width range would divide by zero; use the first palette color.
      frac = (val - min_val) / span if span else 0.0
      # Clamp to [0, 1]. ROOT's GetColorPalette appears to wrap indices past
      # the end back to the start (index modulo palette size) - verify - so
      # values above max_val would otherwise get a low-end color.
      frac = min(max(frac, 0.0), 1.0)
      idx = round(frac * (ROOT.gStyle.GetNumberOfColors() - 1))
      col = ROOT.gROOT.GetColor(ROOT.gStyle.GetColorPalette(idx))
      # Channels are floats in [0, 1]. 256 * 1.0 == 256 would print as three
      # hex digits ('100') and corrupt the '#rrggbb' string, so cap at 255.
      r, g, b = (min(255, floor(256 * channel))
                 for channel in (col.GetRed(), col.GetGreen(), col.GetBlue()))
      # Rough grey level decides which text color stays readable on top.
      gs = (r + g + b) // 3
      text_color = 'white' if gs < 100 else 'black'
      return '#{:02x}{:02x}{:02x}'.format(r, g, b), text_color
  def show_function(dataset, fname):
      """Render the implementation of one or more functions as Markdown.

      Args:
          dataset: Object with a ``_function_impl_lookup`` mapping from function
              name to its implementation text, which is shown in a ``cpp`` fence.
          fname: A single function name (``str``), or an iterable of names.

      Returns:
          An ``IPython.display.Markdown`` object with one section per function.

      Raises:
          KeyError: If a name is missing from ``dataset._function_impl_lookup``.
      """
      from IPython.display import Markdown

      def md_single(fname_):
          impl = dataset._function_impl_lookup[fname_]
          return '*{}*\n-----\n```cpp\n{}\n```\n\n---'.format(fname_, impl)

      # A str is itself iterable. Without this check a single name would be
      # split into characters and each letter looked up separately.
      if isinstance(fname, str):
          return Markdown(md_single(fname))
      try:
          names = iter(fname)
      except TypeError:
          # Not iterable: treat it as one (non-str) key.
          return Markdown(md_single(fname))
      return Markdown('\n'.join(md_single(fname_) for fname_ in names))
  def show_value(dataset, container):
      """Display the value graph and the function sources behind a value.

      Args:
          dataset: Passed through to ``show_function`` to look up sources.
          container: Either the bare value name (``str``), or an object whose
              ``GetName()`` has the ``"<sample>:<value>"`` form assigned by
              ``ResultSet.load_objects``.

      Returns:
          ``(Image, Markdown)`` on success. If rendering fails, the error and
          ``g.to_string()`` are printed and ``None`` is returned.

      Raises:
          KeyError: If the value name is not present in ``VALUES``.
      """
      if not isinstance(container, str):
          # Drop only the "<sample>:" prefix. maxsplit=1 keeps any ':' that
          # is part of the value name itself.
          container = container.GetName().split(':', 1)[1]
      g, functions = parse(VALUES[container], container)
      try:
          return Image(g.create_gif()), show_function(dataset, functions)
      except Exception as e:  # best-effort notebook display: report, don't raise
          print(e)
          print(g.to_string())
  class OutputCapture:
      """Context manager that redirects ``sys.stdout``/``sys.stderr`` into buffers.

      Example::

          with OutputCapture() as cap:
              noisy_call()
          text = cap.get_stdout()

      Captured output stays in in-memory buffers after the ``with`` block ends,
      so it can be read afterwards. Exceptions raised inside the block are not
      suppressed (``__exit__`` returns ``None``).
      """

      def __init__(self):
          self.my_stdout = io.StringIO()
          self.my_stderr = io.StringIO()

      def get_stdout(self):
          """Return everything written to stdout while capturing."""
          return self.my_stdout.getvalue()

      def get_stderr(self):
          """Return everything written to stderr while capturing."""
          return self.my_stderr.getvalue()

      def __enter__(self):
          # Remember the current streams so __exit__ can restore them.
          self.stdout = sys.stdout
          self.stderr = sys.stderr
          sys.stdout = self.my_stdout
          sys.stderr = self.my_stderr
          # Returning self makes ``with OutputCapture() as cap:`` usable.
          return self

      def __exit__(self, *args):
          sys.stdout = self.stdout
          sys.stderr = self.stderr
          self.stdout = None
          self.stderr = None
  def normalize_columns(hist2d):
      """Return a copy of a 2D histogram where each x-column sums to one.

      For every x bin, the in-range y-bin contents are divided by their column
      total. Columns whose total is zero are left as copied. Under/overflow
      bins are neither summed nor rescaled. Bin errors are not touched here.
      NOTE(review): errors are therefore not rescaled with the contents;
      confirm callers do not rely on them.
      """
      norm_hist = ROOT.TH2D(hist2d)
      y_bins = range(1, hist2d.GetNbinsY() + 1)
      for x_bin in range(1, hist2d.GetNbinsX() + 1):
          column = [hist2d.GetBinContent(x_bin, y_bin) for y_bin in y_bins]
          total = 0
          for content in column:
              total += content
          if total == 0:
              continue  # empty column: nothing to normalize
          for y_bin, content in zip(y_bins, column):
              norm_hist.SetBinContent(x_bin, y_bin, content / total)
      return norm_hist
  102. class ResultSet:
  def __init__(self, sample_name, input_filename):
      """Load all objects from ``input_filename`` and register this sample.

      Args:
          sample_name: Label of this sample. It is the key in
              ``ResultSet.collections`` and the prefix of every loaded
              object's name.
          input_filename: Path of the ROOT file to read.
      """
      self.sample_name, self.input_filename = sample_name, input_filename
      self.load_objects()
      # Add to the class-wide registry used by the multi-sample plotters.
      ResultSet.add_collection(self)
  def load_objects(self):
      """Read every key of ``self.input_filename`` into this ResultSet.

      For each key, the object is processed as follows:
      - it is renamed to ``"<sample_name>:<key>"`` and detached from the file
        where it supports that;
      - it is stored in ``self.map[key]``;
      - it is exposed as an attribute named ``key``.

      ``vector<int>`` objects whose key contains ``"_count"`` are replaced by
      their first element. The file's ``_value_lookup`` entries are merged
      into the module-level ``VALUES``. Finally, loaded objects replace any
      same-named objects in ``ROOT.gDirectory``.

      Raises:
          OSError: If the file cannot be opened.
      """
      tfile = ROOT.TFile.Open(self.input_filename)
      # TFile.Open yields a null object on failure. Report that clearly
      # instead of failing later with an AttributeError.
      if not tfile or tfile.IsZombie():
          raise OSError(
              "could not open ROOT file {!r}".format(self.input_filename))
      self.map = {}
      try:
          VALUES.update(dict(tfile.Get("_value_lookup")))
          keys = tfile.GetListOfKeys()
          for i in range(keys.GetSize()):
              name = keys.At(i).GetName()
              obj = tfile.Get(name)
              try:
                  obj.SetName(":".join((self.sample_name, name)))
                  obj.SetDirectory(0)  # disconnect from the file so it survives Close()
              except AttributeError:
                  pass  # objects lacking SetName/SetDirectory are kept as-is
              # Match the older PyROOT spelling ("ROOT.vector<int>") as well as
              # the cppyy one ("cppyy.gbl.std.vector<int>").
              if 'vector<int>' in str(type(obj)) and '_count' in name:
                  obj = obj[0]
              self.map[name] = obj
              setattr(self, name, obj)
      finally:
          # Release the file handle even if reading a key fails.
          tfile.Close()
      # Publish the objects in the current in-memory ROOT directory,
      # replacing older versions that have the same name.
      for obj in self.map.values():
          try:
              old_obj = ROOT.gDirectory.Get(obj.GetName())
              ROOT.gDirectory.Remove(old_obj)
              ROOT.gDirectory.Add(obj)
          except AttributeError:
              pass  # plain Python values (e.g. unwrapped counts) have no GetName
  @classmethod
  def calc_shape(cls, n_plots):
      """Return the ``(columns, rows)`` pad grid for ``n_plots`` sub-plots.

      All plots go in a single row while that row fits within ``MAX_WIDTH``.
      Otherwise rows hold ``MAX_WIDTH // SINGLE_PLOT_SIZE[0]`` plots each, and
      as many rows as needed are used.
      """
      plot_width = SINGLE_PLOT_SIZE[0]
      fits_in_one_row = n_plots * plot_width <= MAX_WIDTH
      if fits_in_one_row:
          return (n_plots, 1)
      per_row = MAX_WIDTH // plot_width
      return (per_row, ceil(n_plots / per_row))
  def draw(self, shape=None):
      """Draw every drawable object of this sample on the shared canvas.

      Args:
          shape: ``(columns, rows)`` pad grid. By default ``calc_shape`` picks
              one from the number of objects that have a ``Draw`` method.
      """
      objs = [obj for obj in self.map.values() if hasattr(obj, "Draw")]
      if shape is None:
          shape = self.calc_shape(len(objs))
      CANVAS.Clear()
      CANVAS.SetCanvasSize(shape[0]*SINGLE_PLOT_SIZE[0], shape[1]*SINGLE_PLOT_SIZE[1])
      CANVAS.Divide(*shape)
      for pad_number, hist in enumerate(objs, start=1):
          CANVAS.cd(pad_number)
          try:
              hist.SetStats(False)
          except AttributeError:
              pass  # objects without a statistics box
          # Exact type match, not isinstance: only plain 1D histograms, not
          # their subclasses, are forced to start at zero.
          if type(hist) in (ROOT.TH1I, ROOT.TH1F, ROOT.TH1D):
              hist.SetMinimum(0)
          # get_draw_option returns None for unknown types; "" is ROOT's
          # default draw option.
          hist.Draw(self.get_draw_option(hist) or "")
      CANVAS.Draw()
  @staticmethod
  def get_draw_option(obj):
      """Return the ROOT draw option used for ``obj``, keyed on its exact type.

      Plain 1D histograms get ``""``, plain 2D histograms ``"COLZ"``, and
      ``TGraph`` gets ``"A*"``. Every other type, including subclasses of
      these, gets ``None``.
      """
      obj_type = type(obj)
      rules = (
          ((ROOT.TH1F, ROOT.TH1I, ROOT.TH1D), ""),
          ((ROOT.TH2F, ROOT.TH2I, ROOT.TH2D), "COLZ"),
          ((ROOT.TGraph,), "A*"),
      )
      for types, option in rules:
          if obj_type in types:
              return option
      return None
  177. @classmethod
  178. def get_hist_set(cls, attrname):
  179. labels, hists = zip(*[(sample_name, getattr(h, attrname))
  180. for sample_name, h in cls.collections.items()])
  181. return labels, hists
  182. @classmethod
  183. def add_collection(cls, hc):
  184. if not hasattr(cls, "collections"):
  185. cls.collections = {}
  186. cls.collections[hc.sample_name] = hc
  @classmethod
  def stack_hist(cls,
                 hist_name,
                 title="",
                 enable_fill=False,
                 normalize_to=0,
                 draw=False,
                 draw_canvas=True,
                 draw_option="",
                 make_legend=False,
                 _stacks={}):
      """Overlay histogram ``hist_name`` from every registered sample in a THStack.

      Args:
          hist_name: Name of the histogram taken from each sample.
          title: Title of the resulting stack.
          enable_fill: Also fill each clone with its (75% alpha) color. The
              line color is always set.
          normalize_to: Target integral per clone. A single number applies to
              every sample; a sequence gives one value per sample. A falsy
              value leaves that clone unscaled, and so does an empty (zero
              integral) histogram.
          draw: Draw the stack with ``draw_option`` on the current pad.
          draw_canvas: Clear and resize the shared canvas first, and call
              ``CANVAS.Draw()`` at the end. Pass False when drawing into a
              sub-pad.
          draw_option: ROOT draw option; it is also appended to clone names.
          make_legend: When drawing, build a legend on ``CANVAS``.
          _stacks: Internal cache; do not pass. The shared mutable default is
              deliberate: it keeps every THStack referenced so Python does not
              garbage-collect stacks that are still displayed.

      Returns:
          The ``ROOT.THStack`` holding one clone per sample.
      """
      labels, hists = cls.get_hist_set(hist_name)
      if draw_canvas:
          CANVAS.Clear()
          CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0],
                               SINGLE_PLOT_SIZE[1])
      colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen, ROOT.kYellow])
      stack = ROOT.THStack(hist_name+"_stack", title)
      # isinstance also accepts bool (an int subclass), so True means "scale to 1".
      if isinstance(normalize_to, (int, float)):
          normalize_to = [normalize_to]*len(hists)
      for hist, label, color, norm in zip(hists, labels, colors, normalize_to):
          hist_copy = hist.Clone(hist.GetName()+"_clone" + draw_option)
          hist_copy.SetTitle(label)
          if enable_fill:
              hist_copy.SetFillColorAlpha(color, 0.75)
          hist_copy.SetLineColorAlpha(color, 0.75)
          if norm:
              integral = hist_copy.Integral()
              # An empty histogram cannot be rescaled; leave it unscaled
              # instead of dividing by zero.
              if integral:
                  hist_copy.Scale(norm/integral, "nosw2")
          hist_copy.SetStats(True)
          stack.Add(hist_copy)
      if draw:
          stack.Draw(draw_option)
          if make_legend:
              CANVAS.BuildLegend(0.75, 0.75, 0.95, 0.95, "")
      # Prevent the stack from being garbage collected while displayed.
      _stacks[stack.GetName()] = stack
      if draw_canvas:
          CANVAS.Draw()
      return stack
  @classmethod
  def stack_hist_array(cls,
                       hist_names,
                       titles,
                       shape=None, **kwargs):
      """Draw one sample-overlay stack per histogram name, one pad each.

      Args:
          hist_names: Histogram names; each gets its own pad.
          titles: Stack titles, matched to ``hist_names`` by position.
          shape: ``(columns, rows)`` pad grid. By default this is a single
              column for up to four histograms, otherwise the smallest square
              grid that holds them all.
          **kwargs: Forwarded to ``stack_hist``.
      """
      n_hist = len(hist_names)
      if shape is None:
          shape = (1, n_hist) if n_hist <= 4 else (ceil(sqrt(n_hist)),)*2
      # Clear first: TPad::Divide does not remove sub-pads from a previous
      # plot, so reusing the shared canvas would layer new pads over old ones.
      CANVAS.Clear()
      CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0]*shape[0],
                           SINGLE_PLOT_SIZE[1]*shape[1])
      CANVAS.Divide(*shape)
      for pad_number, (hist_name, title) in enumerate(zip(hist_names, titles), start=1):
          CANVAS.cd(pad_number)
          cls.stack_hist(hist_name, title=title, draw=True,
                         draw_canvas=False, **kwargs)
      # A single legend, placed on the last pad.
      CANVAS.cd(n_hist).BuildLegend(0.75, 0.75, 0.95, 0.95, "")
  # Bounded store (most recent 50) for the TPaveText labels drawn by
  # hist_array_single. Holding a reference presumably keeps them from being
  # garbage-collected while displayed, as _stacks does in stack_hist.
  pts = deque(maxlen=50)
  @classmethod
  def hist_array_single(cls,
                        hist_name,
                        title=None,
                        **kwargs):
      """Draw ``hist_name`` from every registered sample, one pad per sample.

      Each pad gets a small "Dataset: <sample>" label near its top-right
      corner.

      Args:
          hist_name: Name of the object to draw from each sample.
          title: Currently unused.
          **kwargs: Currently unused.
      """
      n_hist = len(cls.collections)
      shape = cls.calc_shape(n_hist)
      # Clear first: Divide() does not remove sub-pads from a previous plot.
      CANVAS.Clear()
      CANVAS.SetCanvasSize(SINGLE_PLOT_SIZE[0]*shape[0],
                           SINGLE_PLOT_SIZE[1]*shape[1])
      CANVAS.Divide(*shape)
      labels, hists = cls.get_hist_set(hist_name)
      for pad_number, (label, hist) in enumerate(zip(labels, hists), start=1):
          CANVAS.cd(pad_number)
          try:
              hist.SetStats(False)
          except AttributeError:
              pass  # objects without a statistics box (same guard as draw())
          # get_draw_option returns None for unknown types; "" is ROOT's default.
          hist.Draw(cls.get_draw_option(hist) or "")
          pt = ROOT.TPaveText(0.70, 0.87, 0.85, 0.95, "NDC")
          pt.AddText("Dataset: "+label)
          pt.Draw()
          # Keep a reference so the label outlives this call.
          cls.pts.append(pt)