utils.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import io
  2. import sys
  3. from math import ceil, sqrt
  4. import itertools as it
  5. import ROOT
  6. class OutputCapture:
  7. def __init__(self):
  8. self.my_stdout = io.StringIO()
  9. self.my_stderr = io.StringIO()
  10. def get_stdout(self):
  11. self.my_stdout.seek(0)
  12. return self.my_stdout.read()
  13. def get_stderr(self):
  14. self.my_stderr.seek(0)
  15. return self.my_stderr.read()
  16. def __enter__(self):
  17. self.stdout = sys.stdout
  18. self.stderr = sys.stderr
  19. sys.stdout = self.my_stdout
  20. sys.stderr = self.my_stderr
  21. def __exit__(self, *args):
  22. sys.stdout = self.stdout
  23. sys.stderr = self.stderr
  24. self.stdout = None
  25. self.stderr = None
  26. def bin_range(n, end=None):
  27. if end is None:
  28. return range(1, n+1)
  29. else:
  30. return range(n+1, end+1)
  31. def normalize_columns(hist2d):
  32. normHist = ROOT.TH2D(hist2d)
  33. cols, rows = hist2d.GetNbinsX(), hist2d.GetNbinsY()
  34. for col in bin_range(cols):
  35. sum_ = 0
  36. for row in bin_range(rows):
  37. sum_ += hist2d.GetBinContent(col, row)
  38. if sum_ == 0:
  39. continue
  40. for row in bin_range(rows):
  41. norm = hist2d.GetBinContent(col, row) / sum_
  42. normHist.SetBinContent(col, row, norm)
  43. return normHist
  44. def stack_hist(hists,
  45. labels=None, id_=None,
  46. title="", enable_fill=False,
  47. normalize_to=0, draw=False,
  48. draw_option="",
  49. _stacks={}):
  50. """hists should be a list of TH1D objects
  51. returns a new stacked histogram
  52. """
  53. colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen])
  54. stack = ROOT.THStack(id_, title)
  55. if labels is None:
  56. labels = [hist.GetName() for hist in hists]
  57. if type(normalize_to) in (int, float):
  58. normalize_to = [normalize_to]*len(hists)
  59. if id_ is None:
  60. id_ = hists[0].GetName() + "_stack"
  61. ens = enumerate(zip(hists, labels, colors, normalize_to))
  62. for i, (hist, label, color, norm) in ens:
  63. hist_copy = hist
  64. hist_copy = hist.Clone(hist.GetName()+"_clone")
  65. hist_copy.SetTitle(label)
  66. if enable_fill:
  67. hist_copy.SetFillColorAlpha(color, 0.75)
  68. hist_copy.SetLineColorAlpha(color, 0.75)
  69. if norm:
  70. integral = hist_copy.Integral()
  71. hist_copy.Scale(norm/integral, "nosw2")
  72. hist_copy.SetStats(False)
  73. stack.Add(hist_copy)
  74. if draw:
  75. stack.Draw(draw_option)
  76. _stacks[id_] = stack # prevent stack from getting garbage collected
  77. # needed for multipad plots :/
  78. return stack
  79. def stack_hist_array(canvas, histcollections, fields, titles,
  80. shape=None, **kwargs):
  81. def get_hist_set(attrname):
  82. hists, labels = zip(*[(getattr(h, attrname), h.get_sample_name())
  83. for h in histcollections])
  84. return hists, labels
  85. n_fields = len(fields)
  86. if shape is None:
  87. if n_fields <= 4:
  88. shape = (1, n_fields)
  89. else:
  90. shape = (ceil(sqrt(n_fields)),)*2
  91. canvas.Clear()
  92. canvas.Divide(*shape)
  93. for i, field, title in zip(bin_range(n_fields), fields, titles):
  94. canvas.cd(i)
  95. hists, labels = get_hist_set(field)
  96. stack_hist(hists, labels, id_=field, title=title, draw=True, **kwargs)
  97. canvas.cd(1).BuildLegend(0.75, 0.75, 0.95, 0.95, "")