result_set.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import ROOT
  2. from plotter import plot_histogram, plot_histogram2d
  3. class ResultSet:
  4. def __init__(self, sample_name, input_filename):
  5. self.sample_name = sample_name
  6. self.input_filename = input_filename
  7. self.load_objects()
  8. ResultSet.add_collection(self)
  9. def load_objects(self):
  10. file = ROOT.TFile.Open(self.input_filename)
  11. l = file.GetListOfKeys()
  12. self.map = {}
  13. self.values = dict(file.Get("_value_lookup"))
  14. for i in range(l.GetSize()):
  15. name = l.At(i).GetName()
  16. new_name = ":".join((self.sample_name, name))
  17. obj = file.Get(name)
  18. try:
  19. obj.SetName(new_name)
  20. obj.SetDirectory(0) # disconnects Object from file
  21. except AttributeError:
  22. pass
  23. if 'ROOT.vector<int>' in str(type(obj)) and '_count' in name:
  24. obj = obj[0]
  25. self.map[name] = obj
  26. setattr(self, name, obj)
  27. file.Close()
  28. # Now add these histograms into the current ROOT directory (in memory)
  29. # and remove old versions if needed
  30. for obj in self.map.values():
  31. try:
  32. old_obj = ROOT.gDirectory.Get(obj.GetName())
  33. ROOT.gDirectory.Remove(old_obj)
  34. ROOT.gDirectory.Add(obj)
  35. except AttributeError:
  36. pass
  37. @classmethod
  38. def calc_shape(cls, n_plots):
  39. if n_plots > 3:
  40. return ceil(n_plots / 3), 3
  41. else:
  42. return 1, n_plots
  43. def draw(self, figure=None, shape=None):
  44. objs = [(name, obj) for name, obj in self.map.items() if isinstance(obj, ROOT.TH1)]
  45. shape = self.calc_shape(len(objs))
  46. if figure is None:
  47. import matplotlib.pyplot as plt
  48. figure = plt.gcf() if plt.gcf() is not None else plt.figure()
  49. figure.clear()
  50. for i, (name, obj) in enumerate(objs):
  51. axes = figure.add_subplot(*shape, i+1)
  52. if isinstance(obj, ROOT.TH2):
  53. plot_histogram2d(obj, title=obj.GetTitle(), axes=axes)
  54. else:
  55. plot_histogram(obj, title=obj.GetTitle(), axes=axes)
  56. figure.tight_layout()
  57. @classmethod
  58. def get_hist_set(cls, attrname):
  59. return [(sample_name, getattr(h, attrname))
  60. for sample_name, h in cls.collections.items()]
  61. @classmethod
  62. def add_collection(cls, hc):
  63. if not hasattr(cls, "collections"):
  64. cls.collections = {}
  65. cls.collections[hc.sample_name] = hc