result_set.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import ROOT
  2. from filval.plotting import hist_plot, hist2d_plot
  3. from numpy import ceil
  4. class ResultSet:
  5. def __init__(self, sample_name, input_filename):
  6. self.sample_name = sample_name
  7. self.input_filename = input_filename
  8. self.values = {}
  9. self.map = {}
  10. self.config = None
  11. self.load_objects()
  12. ResultSet.add_collection(self)
  13. def load_objects(self):
  14. file = ROOT.TFile.Open(self.input_filename)
  15. try:
  16. self.values = dict(file.Get("_value_lookup"))
  17. except TypeError:
  18. pass
  19. try:
  20. self.config = str(file.Get("_config").GetString())
  21. except TypeError:
  22. pass
  23. list_of_keys = file.GetListOfKeys()
  24. for i in range(list_of_keys.GetSize()):
  25. name = list_of_keys.At(i).GetName()
  26. new_name = ":".join((self.sample_name, name))
  27. obj = file.Get(name)
  28. try:
  29. obj.SetName(new_name)
  30. obj.SetDirectory(0) # disconnects Object from file
  31. except AttributeError:
  32. pass
  33. if 'ROOT.vector<int>' in str(type(obj)) and '_count' in name:
  34. obj = obj[0]
  35. self.map[name] = obj
  36. setattr(self, name, obj)
  37. file.Close()
  38. # Now add these histograms into the current ROOT directory (in memory)
  39. # and remove old versions if needed
  40. for obj in self.map.values():
  41. try:
  42. old_obj = ROOT.gDirectory.Get(obj.GetName())
  43. ROOT.gDirectory.Remove(old_obj)
  44. ROOT.gDirectory.Add(obj)
  45. except AttributeError:
  46. pass
  47. @classmethod
  48. def calc_shape(cls, n_plots):
  49. if n_plots > 3:
  50. return ceil(n_plots / 3), 3
  51. else:
  52. return 1, n_plots
  53. def draw(self, figure=None, shape=None):
  54. objs = [(name, obj) for name, obj in self.map.items() if isinstance(obj, ROOT.TH1)]
  55. shape = self.calc_shape(len(objs))
  56. if figure is None:
  57. import matplotlib.pyplot as plt
  58. figure = plt.gcf() if plt.gcf() is not None else plt.figure()
  59. figure.clear()
  60. for i, (name, obj) in enumerate(objs):
  61. axes = figure.add_subplot(*shape, i+1)
  62. if isinstance(obj, ROOT.TH2):
  63. hist2d_plot(obj, title=obj.GetTitle(), axes=axes)
  64. else:
  65. hist_plot(obj, title=obj.GetTitle(), axes=axes)
  66. figure.tight_layout()
  67. @classmethod
  68. def get_hist_set(cls, attrname):
  69. return [(sample_name, getattr(h, attrname))
  70. for sample_name, h in cls.collections.items()]
  71. @classmethod
  72. def add_collection(cls, hc):
  73. if not hasattr(cls, "collections"):
  74. cls.collections = {}
  75. cls.collections[hc.sample_name] = hc
  76. def __str__(self):
  77. return self.sample_name+"@"+self.input_filename
  78. def __repr__(self):
  79. return f"<ResultSet: input_filename: {self.input_filename}>"