Explorar el Código

Adds simple_plot function and handling of native config

Caleb Fangmeier hace 6 años
padre
commit
dc8c3cf99b
Se han modificado 2 ficheros con 42 adiciones y 12 borrados
  1. 29 5
      filval/plotting.py
  2. 13 7
      filval/result_set.py

+ 29 - 5
filval/plotting.py

@@ -24,11 +24,16 @@ __all__ = ['Plot',
            'hist_plot',
            'hist_plot_stack',
            'hist2d_plot',
-           'hists_to_table']
+           'hists_to_table',
+           'simple_plot']
 
 
 class Plot:
     def __init__(self, subplots, name, title=None, docs="N/A", arg_dicts=None):
+        if type(subplots) is not list:
+            subplots = [[subplots]]
+        elif len(subplots) > 0 and type(subplots[0]) is not list:
+            subplots = [subplots]
         self.subplots = subplots
         self.name = name
         self.title = title
@@ -89,6 +94,29 @@ def decl_plot(fn):
     return f
 
 
+def simple_plot(thx):
+    import ROOT
+
+    if isinstance(thx, ROOT.TH2):
+        def f(h):
+            hist2d_plot(hist2d(h))
+            plt.xlabel(h.GetXaxis().GetTitle())
+            plt.ylabel(h.GetYaxis().GetTitle())
+            return dict(), "", ""
+
+        return Plot([[(f, (thx,), {})]], thx.GetName())
+    elif isinstance(thx, ROOT.TH1):
+        def f(h):
+            hist_plot(hist(h))
+            plt.xlabel(h.GetXaxis().GetTitle())
+            plt.ylabel(h.GetYaxis().GetTitle())
+            return dict(), "", ""
+
+        return Plot([[(f, (thx,), {})]], thx.GetName())
+    else:
+        raise ValueError("must call simple_plot with a ROOT TH1 or TH2 object")
+
+
 def generate_dashboard(plots, title, output='dashboard.html', template='dashboard.j2',
                        source=None, ana_source=None, config=None):
     from jinja2 import Environment, PackageLoader, select_autoescape
@@ -115,10 +143,6 @@ def generate_dashboard(plots, title, output='dashboard.html', template='dashboar
         with open(source, 'r') as f:
             source = f.read()
 
-    if config is not None:
-        with open(config, 'r') as f:
-            config = f.read()
-
     if not isdir('output'):
         mkdir('output')
 

+ 13 - 7
filval/result_set.py

@@ -1,6 +1,6 @@
 import ROOT
 
-from filval.plotter import hist_plot, hist2d_plot
+from filval.plotting import hist_plot, hist2d_plot
 from numpy import ceil
 
 
@@ -9,20 +9,26 @@ class ResultSet:
     def __init__(self, sample_name, input_filename):
         self.sample_name = sample_name
         self.input_filename = input_filename
+        self.values = {}
+        self.map = {}
+        self.config = None
         self.load_objects()
 
         ResultSet.add_collection(self)
 
     def load_objects(self):
         file = ROOT.TFile.Open(self.input_filename)
-        l = file.GetListOfKeys()
-        self.map = {}
         try:
             self.values = dict(file.Get("_value_lookup"))
-        except Exception:
-            self.values = {}
-        for i in range(l.GetSize()):
-            name = l.At(i).GetName()
+        except TypeError:
+            pass
+        try:
+            self.config = str(file.Get("_config").GetString())
+        except TypeError:
+            pass
+        list_of_keys = file.GetListOfKeys()
+        for i in range(list_of_keys.GetSize()):
+            name = list_of_keys.At(i).GetName()
             new_name = ":".join((self.sample_name, name))
             obj = file.Get(name)
             try: