result_set.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import ROOT
  2. from filval.plotter import hist_plot, hist2d_plot
  3. from numpy import ceil
  4. class ResultSet:
  5. def __init__(self, sample_name, input_filename):
  6. self.sample_name = sample_name
  7. self.input_filename = input_filename
  8. self.load_objects()
  9. ResultSet.add_collection(self)
  10. def load_objects(self):
  11. file = ROOT.TFile.Open(self.input_filename)
  12. l = file.GetListOfKeys()
  13. self.map = {}
  14. self.values = dict(file.Get("_value_lookup"))
  15. for i in range(l.GetSize()):
  16. name = l.At(i).GetName()
  17. new_name = ":".join((self.sample_name, name))
  18. obj = file.Get(name)
  19. try:
  20. obj.SetName(new_name)
  21. obj.SetDirectory(0) # disconnects Object from file
  22. except AttributeError:
  23. pass
  24. if 'ROOT.vector<int>' in str(type(obj)) and '_count' in name:
  25. obj = obj[0]
  26. self.map[name] = obj
  27. setattr(self, name, obj)
  28. file.Close()
  29. # Now add these histograms into the current ROOT directory (in memory)
  30. # and remove old versions if needed
  31. for obj in self.map.values():
  32. try:
  33. old_obj = ROOT.gDirectory.Get(obj.GetName())
  34. ROOT.gDirectory.Remove(old_obj)
  35. ROOT.gDirectory.Add(obj)
  36. except AttributeError:
  37. pass
  38. @classmethod
  39. def calc_shape(cls, n_plots):
  40. if n_plots > 3:
  41. return ceil(n_plots / 3), 3
  42. else:
  43. return 1, n_plots
  44. def draw(self, figure=None, shape=None):
  45. objs = [(name, obj) for name, obj in self.map.items() if isinstance(obj, ROOT.TH1)]
  46. shape = self.calc_shape(len(objs))
  47. if figure is None:
  48. import matplotlib.pyplot as plt
  49. figure = plt.gcf() if plt.gcf() is not None else plt.figure()
  50. figure.clear()
  51. for i, (name, obj) in enumerate(objs):
  52. axes = figure.add_subplot(*shape, i+1)
  53. if isinstance(obj, ROOT.TH2):
  54. hist2d_plot(obj, title=obj.GetTitle(), axes=axes)
  55. else:
  56. hist_plot(obj, title=obj.GetTitle(), axes=axes)
  57. figure.tight_layout()
  58. @classmethod
  59. def get_hist_set(cls, attrname):
  60. return [(sample_name, getattr(h, attrname))
  61. for sample_name, h in cls.collections.items()]
  62. @classmethod
  63. def add_collection(cls, hc):
  64. if not hasattr(cls, "collections"):
  65. cls.collections = {}
  66. cls.collections[hc.sample_name] = hc
  67. def __str__(self):
  68. return self.sample_name+"@"+self.input_filename
  69. def __repr__(self):
  70. return f"<ResultSet: input_filename: {self.input_filename}>"