utils.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import io
  2. import sys
  3. from os.path import dirname, join, abspath, normpath
  4. from math import floor, ceil, sqrt
  5. import ROOT
  6. from graph_vals import parse
  7. from plotter import plot_histogram, plot_histogram2d
  8. PRJ_PATH = normpath(join(dirname(abspath(__file__)), "../"))
  9. EXE_PATH = join(PRJ_PATH, "build/main")
  10. PDG = {1: 'd', -1: 'd̄',
  11. 2: 'u', -2: 'ū',
  12. 3: 's', -3: 's̄',
  13. 4: 'c', -4: 'c̄',
  14. 5: 'b', -5: 'b̄',
  15. 6: 't', -6: 't̄',
  16. 11: 'e-', -11: 'e+',
  17. 12: 'ν_e', -12: 'ῡ_e',
  18. 13: 'μ-', -13: 'μ+',
  19. 14: 'ν_μ', -14: 'ῡ_μ',
  20. 15: 'τ-', -15: 'τ+',
  21. 16: 'ν_τ', -16: 'ῡ_τ',
  22. 21: 'g',
  23. 22: 'γ',
  24. 23: 'Z0',
  25. 24: 'W+', -24: 'W-',
  26. 25: 'H',
  27. }
  28. def get_color(val, max_val, min_val=0):
  29. val = (val-min_val)/(max_val-min_val)
  30. val = round(val * (ROOT.gStyle.GetNumberOfColors()-1))
  31. col_idx = ROOT.gStyle.GetColorPalette(val)
  32. col = ROOT.gROOT.GetColor(col_idx)
  33. r = floor(256*col.GetRed())
  34. g = floor(256*col.GetGreen())
  35. b = floor(256*col.GetBlue())
  36. gs = (r + g + b)//3
  37. text_color = 'white' if gs < 100 else 'black'
  38. return '#{:02x}{:02x}{:02x}'.format(r, g, b), text_color
  39. def show_function(dataset, fname):
  40. from IPython.display import Markdown
  41. def md_single(fname_):
  42. impl = dataset._function_impl_lookup[fname_]
  43. return '*{}*\n-----\n```cpp\n{}\n```\n\n---'.format(fname_, impl)
  44. try:
  45. return Markdown('\n'.join(md_single(fname_) for fname_ in iter(fname)))
  46. except TypeError:
  47. return Markdown(md_single(fname))
  48. def show_value(dataset, container):
  49. from IPython.display import Image
  50. if type(container) != str:
  51. container = container.GetName().split(':')[1]
  52. g, functions = parse(dataset.values[container], container)
  53. try:
  54. return Image(g.create_gif()), show_function(dataset, functions)
  55. except Exception as e:
  56. print(e)
  57. print(g.to_string())
  58. def normalize_columns(hist2d):
  59. normHist = ROOT.TH2D(hist2d)
  60. cols, rows = hist2d.GetNbinsX(), hist2d.GetNbinsY()
  61. for col in range(1, cols+1):
  62. sum_ = 0
  63. for row in range(1, rows+1):
  64. sum_ += hist2d.GetBinContent(col, row)
  65. if sum_ == 0:
  66. continue
  67. for row in range(1, rows+1):
  68. norm = hist2d.GetBinContent(col, row) / sum_
  69. normHist.SetBinContent(col, row, norm)
  70. return normHist
  71. class ResultSet:
  72. def __init__(self, sample_name, input_filename):
  73. self.sample_name = sample_name
  74. self.input_filename = input_filename
  75. self.load_objects()
  76. ResultSet.add_collection(self)
  77. def load_objects(self):
  78. file = ROOT.TFile.Open(self.input_filename)
  79. l = file.GetListOfKeys()
  80. self.map = {}
  81. self.values = dict(file.Get("_value_lookup"))
  82. for i in range(l.GetSize()):
  83. name = l.At(i).GetName()
  84. new_name = ":".join((self.sample_name, name))
  85. obj = file.Get(name)
  86. try:
  87. obj.SetName(new_name)
  88. obj.SetDirectory(0) # disconnects Object from file
  89. except AttributeError:
  90. pass
  91. if 'ROOT.vector<int>' in str(type(obj)) and '_count' in name:
  92. obj = obj[0]
  93. self.map[name] = obj
  94. setattr(self, name, obj)
  95. file.Close()
  96. # Now add these histograms into the current ROOT directory (in memory)
  97. # and remove old versions if needed
  98. for obj in self.map.values():
  99. try:
  100. old_obj = ROOT.gDirectory.Get(obj.GetName())
  101. ROOT.gDirectory.Remove(old_obj)
  102. ROOT.gDirectory.Add(obj)
  103. except AttributeError:
  104. pass
  105. @classmethod
  106. def calc_shape(cls, n_plots):
  107. if n_plots > 3:
  108. return ceil(n_plots / 3), 3
  109. else:
  110. return 1, n_plots
  111. def draw(self, figure=None, shape=None):
  112. objs = [(name, obj) for name, obj in self.map.items() if isinstance(obj, ROOT.TH1)]
  113. shape = self.calc_shape(len(objs))
  114. if figure is None:
  115. import matplotlib.pyplot as plt
  116. figure = plt.gcf() if plt.gcf() is not None else plt.figure()
  117. figure.clear()
  118. for i, (name, obj) in enumerate(objs):
  119. axes = figure.add_subplot(*shape, i+1)
  120. if isinstance(obj, ROOT.TH2):
  121. plot_histogram2d(obj, title=obj.GetTitle(), axes=axes)
  122. else:
  123. plot_histogram(obj, title=obj.GetTitle(), axes=axes)
  124. figure.tight_layout()
  125. @classmethod
  126. def get_hist_set(cls, attrname):
  127. return [(sample_name, getattr(h, attrname))
  128. for sample_name, h in cls.collections.items()]
  129. @classmethod
  130. def add_collection(cls, hc):
  131. if not hasattr(cls, "collections"):
  132. cls.collections = {}
  133. cls.collections[hc.sample_name] = hc