Pārlūkot izejas kodu

updates to python tooling

Caleb Fangmeier 6 gadi atpakaļ
vecāks
revīzija
a35beda421
7 mainītis faili ar 106 papildinājumiem un 197 dzēšanām
  1. 0 66
      canvas_wrapper.py
  2. 4 2
      graph_vals.py
  3. 1 1
      filval_merge.py
  4. 2 2
      process_parallel.py
  5. 5 0
      requirements.txt
  6. 76 0
      result_set.py
  7. 18 126
      utils.py

+ 0 - 66
canvas_wrapper.py

@@ -1,66 +0,0 @@
-"""
-Helper module for displaying ROOT canvases in ipython notebooks
-
-Usage example:
-# Save this file as rootnotes.py to your working directory.
-import rootnotes
-c1 = rootnotes.canvas()
-fun1 = TF1( 'fun1', 'abs(sin(x)/x)', 0, 10)
-c1.SetGridx()
-c1.SetGridy()
-fun1.Draw()
-c1
-
-More examples: http://mazurov.github.io/webfest2013/
-
-@author alexander.mazurov@cern.ch
-@author andrey.ustyuzhanin@cern.ch
-@date 2013-08-09
-"""
-
-import ROOT
-ROOT.gROOT.SetBatch()
-
-import tempfile
-from IPython.core import display
-
-
-def canvas(name="icanvas", size=(800, 600)):
-    """Helper method for creating canvas"""
-
-    assert len(size) == 2
-    if len(size) != 2:
-        raise ValueError("Size must be a 2 element tuple")
-    # Check if icanvas already exists
-    canvas = ROOT.gROOT.FindObject(name)
-    if canvas:
-        return canvas
-    else:
-        return ROOT.TCanvas(name, name, size[0], size[1])
-
-
-def _display_canvas(canvas):
-    file = tempfile.NamedTemporaryFile(suffix=".png")
-    canvas.SaveAs(file.name)
-    ip_img = display.Image(filename=file.name, format='png', embed=True)
-    return ip_img._repr_png_()
-
-
-def _display_any(obj):
-    file = tempfile.NamedTemporaryFile(suffix=".png")
-    obj.Draw()
-    ROOT.gPad.SaveAs(file.name)
-    ip_img = display.Image(filename=file.name, format='png', embed=True)
-    return ip_img._repr_png_()
-
-
-# register display function with PNG formatter:
-png_formatter = get_ipython().display_formatter.formatters['image/png'] # noqa
-
-# Register ROOT types in ipython
-#
-# In [1]: canvas = canvas_wrapper.canvas()
-# In [2]: canvas
-# Out [2]: [image will be here]
-png_formatter.for_type(ROOT.TCanvas, _display_canvas)
-png_formatter.for_type(ROOT.TF1, _display_any)

+ 4 - 2
graph_vals.py

@@ -1,9 +1,9 @@
 import pydotplus.graphviz as pdp
-import sys
-import re
 
 
 def parse(str_in, alias=None):
+    """ Creates a call-tree for the supplied value name
+    """
     str_in = "("+str_in+")"
 
     functions = []
@@ -81,6 +81,8 @@ def parse(str_in, alias=None):
 
 
 if __name__ == '__main__':
+    import re
+    import sys
     aliases = {}
     ali_re = re.compile(r"ALIAS::\"([^\"]*)\" referring to \"([^\"]*)\"")
 

+ 1 - 1
filval_merge.py

@@ -1,4 +1,4 @@
-#!/usr/bin/env python3
+#!env/bin/python
 import argparse
 import re
 import os

+ 2 - 2
process_parallel.py

@@ -1,4 +1,4 @@
-#!/usr/bin/env python3
+#!env/bin/python
 from os import listdir
 from os.path import join, isdir
 import argparse
@@ -7,7 +7,7 @@ import subprocess
 import multiprocessing
 from multiprocessing.pool import Pool
 
-from filval_merge import merge_files
+from merge import merge_files
 
 
 def run_job(job_number, executable, files):

+ 5 - 0
requirements.txt

@@ -0,0 +1,5 @@
+numpy
+matplotlib
+ipython
+jupyter
+pydotplus

+ 76 - 0
result_set.py

@@ -0,0 +1,76 @@
+import ROOT
+
+from plotter import plot_histogram, plot_histogram2d
+
+
+class ResultSet:
+
+    def __init__(self, sample_name, input_filename):
+        self.sample_name = sample_name
+        self.input_filename = input_filename
+        self.load_objects()
+
+        ResultSet.add_collection(self)
+
+    def load_objects(self):
+        file = ROOT.TFile.Open(self.input_filename)
+        l = file.GetListOfKeys()
+        self.map = {}
+        self.values = dict(file.Get("_value_lookup"))
+        for i in range(l.GetSize()):
+            name = l.At(i).GetName()
+            new_name = ":".join((self.sample_name, name))
+            obj = file.Get(name)
+            try:
+                obj.SetName(new_name)
+                obj.SetDirectory(0)  # disconnects Object from file
+            except AttributeError:
+                pass
+            if 'ROOT.vector<int>' in str(type(obj)) and '_count' in name:
+                obj = obj[0]
+            self.map[name] = obj
+            setattr(self, name, obj)
+        file.Close()
+
+        # Now add these histograms into the current ROOT directory (in memory)
+        # and remove old versions if needed
+        for obj in self.map.values():
+            try:
+                old_obj = ROOT.gDirectory.Get(obj.GetName())
+                ROOT.gDirectory.Remove(old_obj)
+                ROOT.gDirectory.Add(obj)
+            except AttributeError:
+                pass
+
+    @classmethod
+    def calc_shape(cls, n_plots):
+        if n_plots > 3:
+            return ceil(n_plots / 3), 3
+        else:
+            return 1, n_plots
+
+    def draw(self, figure=None, shape=None):
+        objs = [(name, obj) for name, obj in self.map.items() if isinstance(obj, ROOT.TH1)]
+        shape = self.calc_shape(len(objs))
+        if figure is None:
+            import matplotlib.pyplot as plt
+            figure = plt.gcf() if plt.gcf() is not None else plt.figure()
+        figure.clear()
+        for i, (name, obj) in enumerate(objs):
+            axes = figure.add_subplot(*shape, i+1)
+            if isinstance(obj, ROOT.TH2):
+                plot_histogram2d(obj, title=obj.GetTitle(), axes=axes)
+            else:
+                plot_histogram(obj, title=obj.GetTitle(), axes=axes)
+        figure.tight_layout()
+
+    @classmethod
+    def get_hist_set(cls, attrname):
+        return [(sample_name, getattr(h, attrname))
+                for sample_name, h in cls.collections.items()]
+
+    @classmethod
+    def add_collection(cls, hc):
+        if not hasattr(cls, "collections"):
+            cls.collections = {}
+        cls.collections[hc.sample_name] = hc

+ 18 - 126
utils.py

@@ -1,51 +1,29 @@
-
-import io
-import sys
-from os.path import dirname, join, abspath, normpath
-from math import floor, ceil, sqrt
-
 import ROOT
-from graph_vals import parse
-from plotter import plot_histogram, plot_histogram2d
 
-PRJ_PATH = normpath(join(dirname(abspath(__file__)), "../"))
-EXE_PATH = join(PRJ_PATH, "build/main")
+__all__ = ["pdg", "show_function", "show_value"]
 
-PDG = {1:   'd',   -1:  'd̄',
-       2:   'u',   -2:  'ū',
-       3:   's',   -3:  's̄',
-       4:   'c',   -4:  'c̄',
-       5:   'b',   -5:  'b̄',
-       6:   't',   -6:  't̄',
+db = ROOT.TDatabasePDG()
 
-       11:  'e-',  -11: 'e+',
-       12:  'ν_e', -12: 'ῡ_e',
 
-       13:  'μ-',  -13: 'μ+',
-       14:  'ν_μ', -14: 'ῡ_μ',
+class PDGParticle:
 
-       15:  'τ-',  -15: 'τ+',
-       16:  'ν_τ', -16: 'ῡ_τ',
+    def __init__(self, tPart):
+        self.pdgId = tPart.PdgCode()
+        self.name = tPart.GetName()
+        self.charge = tPart.Charge() / 3.0
+        self.mass = tPart.Mass()
+        self.spin = tPart.Spin()
 
-       21:  'g',
-       22:  'γ',
-       23:  'Z0',
-       24:  'W+',  -24: 'W-',
-       25:  'H',
-       }
+    def __repr__(self):
+        return (f"<PDGParticle {self.name}:"
+                f"pdgId={self.pdgId}, charge={self.charge}, mass={self.mass:5.4e} GeV, spin={self.spin}>")
 
 
-def get_color(val, max_val, min_val=0):
-    val = (val-min_val)/(max_val-min_val)
-    val = round(val * (ROOT.gStyle.GetNumberOfColors()-1))
-    col_idx = ROOT.gStyle.GetColorPalette(val)
-    col = ROOT.gROOT.GetColor(col_idx)
-    r = floor(256*col.GetRed())
-    g = floor(256*col.GetGreen())
-    b = floor(256*col.GetBlue())
-    gs = (r + g + b)//3
-    text_color = 'white' if gs < 100 else 'black'
-    return '#{:02x}{:02x}{:02x}'.format(r, g, b), text_color
+def pdg(pdg_id):
+    try:
+        return PDGParticle(db.GetParticle(pdg_id))
+    except ReferenceError:
+        raise ValueError(f"unknown pdgId: {pdg_id}")
 
 
 def show_function(dataset, fname):
@@ -62,6 +40,7 @@ def show_function(dataset, fname):
 
 def show_value(dataset, container):
     from IPython.display import Image
+    from graph_vals import parse
     if type(container) != str:
         container = container.GetName().split(':')[1]
     g, functions = parse(dataset.values[container], container)
@@ -70,90 +49,3 @@ def show_value(dataset, container):
     except Exception as e:
         print(e)
         print(g.to_string())
-
-def normalize_columns(hist2d):
-    normHist = ROOT.TH2D(hist2d)
-    cols, rows = hist2d.GetNbinsX(), hist2d.GetNbinsY()
-    for col in range(1, cols+1):
-        sum_ = 0
-        for row in range(1, rows+1):
-            sum_ += hist2d.GetBinContent(col, row)
-        if sum_ == 0:
-            continue
-        for row in range(1, rows+1):
-            norm = hist2d.GetBinContent(col, row) / sum_
-            normHist.SetBinContent(col, row, norm)
-    return normHist
-
-
-class ResultSet:
-
-    def __init__(self, sample_name, input_filename):
-        self.sample_name = sample_name
-        self.input_filename = input_filename
-        self.load_objects()
-
-        ResultSet.add_collection(self)
-
-    def load_objects(self):
-        file = ROOT.TFile.Open(self.input_filename)
-        l = file.GetListOfKeys()
-        self.map = {}
-        self.values = dict(file.Get("_value_lookup"))
-        for i in range(l.GetSize()):
-            name = l.At(i).GetName()
-            new_name = ":".join((self.sample_name, name))
-            obj = file.Get(name)
-            try:
-                obj.SetName(new_name)
-                obj.SetDirectory(0)  # disconnects Object from file
-            except AttributeError:
-                pass
-            if 'ROOT.vector<int>' in str(type(obj)) and '_count' in name:
-                obj = obj[0]
-            self.map[name] = obj
-            setattr(self, name, obj)
-        file.Close()
-
-        # Now add these histograms into the current ROOT directory (in memory)
-        # and remove old versions if needed
-        for obj in self.map.values():
-            try:
-                old_obj = ROOT.gDirectory.Get(obj.GetName())
-                ROOT.gDirectory.Remove(old_obj)
-                ROOT.gDirectory.Add(obj)
-            except AttributeError:
-                pass
-
-    @classmethod
-    def calc_shape(cls, n_plots):
-        if n_plots > 3:
-            return ceil(n_plots / 3), 3
-        else:
-            return 1, n_plots
-
-    def draw(self, figure=None, shape=None):
-        objs = [(name, obj) for name, obj in self.map.items() if isinstance(obj, ROOT.TH1)]
-        shape = self.calc_shape(len(objs))
-        if figure is None:
-            import matplotlib.pyplot as plt
-            figure = plt.gcf() if plt.gcf() is not None else plt.figure()
-        figure.clear()
-        for i, (name, obj) in enumerate(objs):
-            axes = figure.add_subplot(*shape, i+1)
-            if isinstance(obj, ROOT.TH2):
-                plot_histogram2d(obj, title=obj.GetTitle(), axes=axes)
-            else:
-                plot_histogram(obj, title=obj.GetTitle(), axes=axes)
-        figure.tight_layout()
-
-    @classmethod
-    def get_hist_set(cls, attrname):
-        return [(sample_name, getattr(h, attrname))
-                for sample_name, h in cls.collections.items()]
-
-    @classmethod
-    def add_collection(cls, hc):
-        if not hasattr(cls, "collections"):
-            cls.collections = {}
-        cls.collections[hc.sample_name] = hc