# result_set.py (2.8 KB)
  1. import ROOT
  2. from filval.plotter import hist_plot, hist2d_plot
  3. from numpy import ceil
  4. class ResultSet:
  5. def __init__(self, sample_name, input_filename):
  6. self.sample_name = sample_name
  7. self.input_filename = input_filename
  8. self.load_objects()
  9. ResultSet.add_collection(self)
  10. def load_objects(self):
  11. file = ROOT.TFile.Open(self.input_filename)
  12. l = file.GetListOfKeys()
  13. self.map = {}
  14. try:
  15. self.values = dict(file.Get("_value_lookup"))
  16. except Exception:
  17. self.values = {}
  18. for i in range(l.GetSize()):
  19. name = l.At(i).GetName()
  20. new_name = ":".join((self.sample_name, name))
  21. obj = file.Get(name)
  22. try:
  23. obj.SetName(new_name)
  24. obj.SetDirectory(0) # disconnects Object from file
  25. except AttributeError:
  26. pass
  27. if 'ROOT.vector<int>' in str(type(obj)) and '_count' in name:
  28. obj = obj[0]
  29. self.map[name] = obj
  30. setattr(self, name, obj)
  31. file.Close()
  32. # Now add these histograms into the current ROOT directory (in memory)
  33. # and remove old versions if needed
  34. for obj in self.map.values():
  35. try:
  36. old_obj = ROOT.gDirectory.Get(obj.GetName())
  37. ROOT.gDirectory.Remove(old_obj)
  38. ROOT.gDirectory.Add(obj)
  39. except AttributeError:
  40. pass
  41. @classmethod
  42. def calc_shape(cls, n_plots):
  43. if n_plots > 3:
  44. return ceil(n_plots / 3), 3
  45. else:
  46. return 1, n_plots
  47. def draw(self, figure=None, shape=None):
  48. objs = [(name, obj) for name, obj in self.map.items() if isinstance(obj, ROOT.TH1)]
  49. shape = self.calc_shape(len(objs))
  50. if figure is None:
  51. import matplotlib.pyplot as plt
  52. figure = plt.gcf() if plt.gcf() is not None else plt.figure()
  53. figure.clear()
  54. for i, (name, obj) in enumerate(objs):
  55. axes = figure.add_subplot(*shape, i+1)
  56. if isinstance(obj, ROOT.TH2):
  57. hist2d_plot(obj, title=obj.GetTitle(), axes=axes)
  58. else:
  59. hist_plot(obj, title=obj.GetTitle(), axes=axes)
  60. figure.tight_layout()
  61. @classmethod
  62. def get_hist_set(cls, attrname):
  63. return [(sample_name, getattr(h, attrname))
  64. for sample_name, h in cls.collections.items()]
  65. @classmethod
  66. def add_collection(cls, hc):
  67. if not hasattr(cls, "collections"):
  68. cls.collections = {}
  69. cls.collections[hc.sample_name] = hc
  70. def __str__(self):
  71. return self.sample_name+"@"+self.input_filename
  72. def __repr__(self):
  73. return f"<ResultSet: input_filename: {self.input_filename}>"