Procházet zdrojové kódy

Implements fully variadic Zip Value. Not yet tested.

Caleb Fangmeier před 7 roky
rodič
revize
0aed5ac4e7

+ 7 - 0
analysis/MiniTreeDataSet.hpp

@@ -77,6 +77,13 @@ class MiniTreeDataSet : public DataSet,
                  << " and type " << typeid(bref).name());
             return new PointerValue<T>(bname, bref);
         }
+
+        template <typename T>
+        WrapperVector<T>* track_branch_vec(const std::string& size_bname, const std::string& bname){
+            track_branch_ptr<T>(bname);
+            return new WrapperVector<T>(size_bname, bname, bname);
+        }
+
         void save_all(){
             output_file->cd();
             for(auto container : containers){

+ 67 - 52
analysis/TTTT_Analysis.cpp

@@ -29,16 +29,16 @@
  * SOFTWARE.
  *
  * @section DESCRIPTION
- * Main analysis routine file
+ * Main analysis routine file. This file declares the Histogram/Graph objects
+ * that will end up in the final root file. It also declares the values that
+ * are used to populate the histogram, as well as how these values are
+ * calculated. See the Fil-Val documentation for how the system works.
  */
 #include <iostream>
 #include <vector>
+#include <tuple>
 #include <utility>
 
-#include "TFile.h"
-#include "TTree.h"
-#include "TCanvas.h"
-
 #include "filval.hpp"
 #include "filval_root.hpp"
 
@@ -60,25 +60,28 @@ void enable_branches(MiniTreeDataSet& mt){
     mt.fChain->SetBranchStatus("*", false);
 
     mt.track_branch<int>("nLepGood");
-    mt.track_branch_ptr<float>("LepGood_pt");
-    mt.track_branch_ptr<float>("LepGood_eta");
-    mt.track_branch_ptr<float>("LepGood_phi");
-    mt.track_branch_ptr<float>("LepGood_mass");
-    mt.track_branch_ptr<float>("LepGood_mcPt");
-    mt.track_branch_ptr<int>("LepGood_charge");
+    mt.track_branch_vec<float>("nLepGood", "LepGood_pt");
+    mt.track_branch_vec<float>("nLepGood", "LepGood_eta");
+    mt.track_branch_vec<float>("nLepGood", "LepGood_phi");
+    mt.track_branch_vec<float>("nLepGood", "LepGood_mass");
+    mt.track_branch_vec<float>("nLepGood", "LepGood_mcPt");
+    mt.track_branch_vec<int>("nLepGood", "LepGood_charge");
 
     mt.track_branch<int>("nJet");
-    mt.track_branch_ptr<float>("Jet_pt");
-    mt.track_branch_ptr<float>("Jet_eta");
-    mt.track_branch_ptr<float>("Jet_phi");
-    mt.track_branch_ptr<float>("Jet_mass");
-    mt.track_branch_ptr<float>("Jet_btagCMVA");
+    mt.track_branch_vec<float>("nJet", "Jet_pt");
+    mt.track_branch_vec<float>("nJet", "Jet_eta");
+    mt.track_branch_vec<float>("nJet", "Jet_phi");
+    mt.track_branch_vec<float>("nJet", "Jet_mass");
+    mt.track_branch_vec<float>("nJet", "Jet_btagCMVA");
 
     mt.track_branch<int>("nGenTop");
+    mt.track_branch_vec<int>("nGenTop", "GenTop_pdgId");
 
     mt.track_branch<int>("ngenLep");
-    mt.track_branch_ptr<int>("genLep_sourceId");
-    mt.track_branch_ptr<float>("genLep_pt");
+    mt.track_branch_vec<int>("ngenLep", "genLep_sourceId");
+    mt.track_branch_vec<float>("ngenLep", "genLep_pt");
+
+    mt.track_branch<int>("nVert");
 }
 
 void declare_values(MiniTreeDataSet& mt){
@@ -89,17 +92,25 @@ void declare_values(MiniTreeDataSet& mt){
                 return t.E();
             })));
 
-    new WrapperVector<float>("nLepGood", "LepGood_pt", "LepGood_pt");
-    new WrapperVector<float>("nLepGood", "LepGood_eta", "LepGood_eta");
-    new WrapperVector<float>("nLepGood", "LepGood_phi", "LepGood_phi");
-    new WrapperVector<float>("nLepGood", "LepGood_mass", "LepGood_mass");
-
-    new WrapperVector<float>("nJet", "Jet_btagCMVA", "Jet_btagCMVA");
-
-
     new ZipMapFour<float, float>(get_energy, "LepGood_pt", "LepGood_eta", "LepGood_phi", "LepGood_mass",
                                  "lepton_energy");
 
+    typedef tuple<float,float,float,float> PtEtaPhiM;
+    auto& get_energy2 = GenFunction::register_function<float(PtEtaPhiM)>("get_energy",
+            FUNC(([](PtEtaPhiM ptetaphim){
+                float pt, eta, phi, m;
+                tie(pt, eta, phi, m) = ptetaphim;
+                TLorentzVector t;
+                t.SetPtEtaPhiM(pt, eta, phi, m);
+                return t.E();
+            })));
+
+    new Zip<float,float,float,float>(dynamic_cast<Value<vector<float>>*>(lookup("LepGood_pt")),
+                                     dynamic_cast<Value<vector<float>>*>(lookup("LepGood_eta")),
+                                     dynamic_cast<Value<vector<float>>*>(lookup("LepGood_phi")),
+                                     dynamic_cast<Value<vector<float>>*>(lookup("LepGood_mass")),
+                                     "lepton_kinematics");
+
     new Pair<vector<float>,vector<float>>("lepton_energy", "LepGood_pt", "lepton_energy_lepton_pt");
 
     new Max<float>("lepton_energy", "lepton_energy_max");
@@ -107,56 +118,60 @@ void declare_values(MiniTreeDataSet& mt){
     new Range<float>("lepton_energy", "lepton_energy_range");
     new Mean<float>("lepton_energy", "lepton_energy_mean");
 
-
     new Count<float>(GenFunction::register_function<bool(float)>("bJet_Selection", FUNC(([](float x){return x>0;}))),
                      "Jet_btagCMVA",  "b_jet_count");
 
+    
+
+
     new Filter("trilepton", FUNC(([nLepGood=lookup("nLepGood")](){
-            return dynamic_cast<Value<int>*>(nLepGood)->get_value() ==3;})));
+            return dynamic_cast<Value<int>*>(nLepGood)->get_value() == 3;})));
 
-    new Filter("os-dilepton", FUNC(([nLepGood=lookup("nLepGood"), LepGood_charge=lookup("LepGood_charge")](){
-                    bool is_dilepton = dynamic_cast<Value<int>*>(nLepGood)->get_value()==2;
-                    int* charge = dynamic_cast<Value<int*>*>(LepGood_charge)->get_value();
-                    return is_dilepton && (charge[0] != charge[1]);})));
+    new Filter("os-dilepton", FUNC(([LepGood_charge=lookup("LepGood_charge")](){
+                    auto& charges = static_cast<Value<vector<int>>*>(LepGood_charge)->get_value();
+                    return charges.size()==2 && (charges[0] != charges[1]);})));
 
-    new Filter("ss-dilepton", FUNC(([nLepGood=lookup("nLepGood"), LepGood_charge=lookup("LepGood_charge")](){
-                    bool is_dilepton = dynamic_cast<Value<int>*>(nLepGood)->get_value()==2;
-                    int* charge = dynamic_cast<Value<int*>*>(LepGood_charge)->get_value();
-                    return is_dilepton && (charge[0] == charge[1]);})));
+    new Filter("ss-dilepton", FUNC(([LepGood_charge=lookup("LepGood_charge")](){
+                    auto& charges = static_cast<Value<vector<int>>*>(LepGood_charge)->get_value();
+                    return charges.size()==2 && (charges[0] == charges[1]);})));
 
 }
 
 void declare_containers(MiniTreeDataSet& mt){
-    mt.register_container(new ContainerTH1I("lepton_count", "Lepton Multiplicity", lookup("nLepGood"), 8, 0, 8));
-    mt.register_container(new ContainerTH1I("top_quark_count", "Top Quark Multiplicity", lookup("nGenTop"), 8, 0, 8));
+    mt.register_container(new ContainerTH1<int>("lepton_count", "Lepton Multiplicity", lookup("nLepGood"), 8, 0, 8));
+    mt.register_container(new ContainerTH1<int>("top_quark_count", "Top Quark Multiplicity", lookup("nGenTop"), 8, 0, 8));
 
-    mt.register_container(new ContainerTH1FMany("lepton_energy_all", "Lepton Energy - All", lookup("lepton_energy"), 50, 0, 500));
-    mt.register_container(new ContainerTH1F("lepton_energy_max", "Lepton Energy - Max", lookup("lepton_energy_max"), 50, 0, 500));
-    mt.register_container(new ContainerTH1F("lepton_energy_min", "Lepton Energy - Min", lookup("lepton_energy_min"), 50, 0, 500));
-    mt.register_container(new ContainerTH1F("lepton_energy_rng", "Lepton Energy - Range", lookup("lepton_energy_range"), 50, 0, 500));
+    mt.register_container(new ContainerTH1Many<float>("lepton_energy_all", "Lepton Energy - All", lookup("lepton_energy"), 50, 0, 500));
+    mt.register_container(new ContainerTH1<float>("lepton_energy_max", "Lepton Energy - Max", lookup("lepton_energy_max"), 50, 0, 500));
+    mt.register_container(new ContainerTH1<float>("lepton_energy_min", "Lepton Energy - Min", lookup("lepton_energy_min"), 50, 0, 500));
+    mt.register_container(new ContainerTH1<float>("lepton_energy_rng", "Lepton Energy - Range", lookup("lepton_energy_range"), 50, 0, 500));
 
 
-    mt.register_container(new ContainerTGraph("nLepvsnJet", new Pair<int, int>("nLepGood", "nJet") ));
+    mt.register_container(new ContainerTGraph("nLepvsnJet", "Number of Leptons vs Number of Jets", new Pair<int, int>("nLepGood", "nJet") ));
 
-    mt.register_container(new ContainerTH2FMany("lepton_energy_vs_pt", "Lepton Energy - Range", lookup("lepton_energy_lepton_pt"),
-                                               50, 0, 500, 50, 0, 500));
+    mt.register_container(new ContainerTH2Many<float>("lepton_energy_vs_pt", "Lepton Energy - Range", lookup("lepton_energy_lepton_pt"),
+                                                50, 0, 500, 50, 0, 500));
 
-    mt.register_container(new ContainerTH1I("b_jet_count", "B-Jet Multiplicity", lookup("b_jet_count"), 10, 0, 10));
+    mt.register_container(new ContainerTH1<int>("b_jet_count", "B-Jet Multiplicity", lookup("b_jet_count"), 10, 0, 10));
 
 
-    mt.register_container(new ContainerTH1I("jet_count_os_dilepton", "Jet Multiplicity - OS Dilepton Events", lookup("nJet"), 14, 0, 14));
+    mt.register_container(new ContainerTH1<int>("jet_count_os_dilepton", "Jet Multiplicity - OS Dilepton Events", lookup("nJet"), 14, 0, 14));
     mt.get_container("jet_count_os_dilepton")->add_filter(lookup_filter("os-dilepton"));
-    mt.register_container(new ContainerTH1I("jet_count_ss_dilepton", "Jet Multiplicity - SS Dilepton Events", lookup("nJet"), 14, 0, 14));
+    mt.register_container(new ContainerTH1<int>("jet_count_ss_dilepton", "Jet Multiplicity - SS Dilepton Events", lookup("nJet"), 14, 0, 14));
     mt.get_container("jet_count_ss_dilepton")->add_filter(lookup_filter("ss-dilepton"));
-    mt.register_container(new ContainerTH1I("jet_count_trilepton", "Jet Multiplicity - Trilepton Events", lookup("nJet"), 14, 0, 14));
+    mt.register_container(new ContainerTH1<int>("jet_count_trilepton", "Jet Multiplicity - Trilepton Events",     lookup("nJet"), 14, 0, 14));
     mt.get_container("jet_count_trilepton")->add_filter(lookup_filter("trilepton"));
-}
 
-std::string replace_suffix(const std::string& input, const std::string& new_suffix){
-    return input.substr(0, input.find_last_of(".")) + new_suffix;
+    mt.register_container(new ContainerTH1<int>("primary_vert_count", "Number of Primary Vertices", lookup("nVert"), 50, 0, 50));
+
+    mt.register_container(new CounterMany<int>("GenTop_pdg_id", lookup("GenTop_pdgId")));
 }
 
+
 void run_analysis(const std::string& input_filename, bool silent){
+    auto replace_suffix = [](const std::string& input, const std::string& new_suffix){
+        return input.substr(0, input.find_last_of(".")) + new_suffix;
+    };
     string log_filename = replace_suffix(input_filename, "_result.log");
     Log::init_logger(log_filename, LogPriority::kLogDebug);
 

+ 26 - 5
filval/container.hpp

@@ -1,8 +1,27 @@
 #ifndef container_hpp
 #define container_hpp
+#include <typeindex>
+#include <vector>
+#include <map>
+
 #include "value.hpp"
 #include "filter.hpp"
-#include <vector>
+
+namespace fv::util{
+std::string& get_type_name(const std::type_index& index){
+    std::map<std::type_index, std::string> _map;
+    // Add to this list as needed :)
+    _map[typeid(int)]="int";
+    _map[typeid(unsigned int)]="unsigned int";
+    _map[typeid(float)]="float";
+    _map[typeid(double)]="double";
+
+    if (_map[index] == ""){
+        CRITICAL("Cannot lookup type name of \"" << index.name() << "\"",-1);
+    }
+    return _map[index];
+}
+}
 
 namespace fv{
 
@@ -81,8 +100,9 @@ class ContainerVector : public Container<std::vector<T> >{
            value(value){
             this->container = new std::vector<T>();
         }
-        void save_as(const std::string& fname) { }
-        virtual void save() { }
+        void save_as(const std::string& fname) {
+            WARNING("Saving of ContainerVector objects not supported");
+        }
 };
 
 template <typename T>
@@ -107,8 +127,9 @@ class ContainerMean : public Container<T>{
             *(this->container) = sum/count;
             return (this->container);
         }
-        void save_as(const std::string& fname) { }
-        virtual void save() { }
+        void save_as(const std::string& fname) {
+            WARNING("Saving of ContainerMean objects not supported");
+        }
 };
 
 }

+ 3 - 0
filval/filter.hpp

@@ -48,6 +48,9 @@ class Filter : public DerivedValue<bool>{
         void update_value(){
             value = filter_function();
         }
+
+        void verify_integrity(){ };
+
     public:
         Filter(const std::string& name, std::function<bool()> filter_function, const std::string& impl="")
           :DerivedValue<bool>(name),

+ 194 - 11
filval/value.hpp

@@ -47,6 +47,7 @@
 #include <utility>
 #include <algorithm>
 #include <map>
+#include <limits>
 #include <vector>
 #include <tuple>
 #include <initializer_list>
@@ -80,9 +81,8 @@ class GenFunction {
         inline static std::map<const std::string, GenFunction*> function_registry;
 
         GenFunction(const std::string& name, const std::string& impl)
-          :impl(impl),
-           name(name){
-        }
+          :impl(impl), name(name){ }
+
         virtual ~GenFunction() { };
 
         std::string& get_name(){
@@ -189,6 +189,7 @@ class GenValue{
          * values based on their name via GenValue::get_value.
          */
         std::string name;
+
     protected:
         /**
          * Mark the internal value as invalid. This is needed for DerivedValue
@@ -213,6 +214,17 @@ class GenValue{
          * over a name with that value.
          */
         inline static std::map<const std::string, GenValue*> aliases;
+
+        /**
+         * This function serves to check that this Value has been created with
+         * real, i.e. non null, arguments. This is to avoid segfaulting when a
+         * dynamic_cast fails. If no checks need to be made, simple override
+         * this method with a no-op. If checks fail, the function should
+         * utilize the CRITICAL macro with a meaningfull error message stating
+         * what failed and especially the name of the current value.
+         */
+        virtual void verify_integrity() = 0;
+
     public:
         GenValue(const std::string& name, const std::string& alias)
           :name(name){
@@ -239,7 +251,7 @@ class GenValue{
             else{
                 ERROR("Could not find alias or value \"" << name << "\". I'll tell you the ones I know about." << std::endl
                         << summary());
-                CRITICAL("Aborting... :(", -1);
+                CRITICAL("Aborting... :(",-1);
             }
         }
 
@@ -320,6 +332,12 @@ class ObservedValue : public Value<T>{
     private:
         T *val_ref;
         void _reset(){ }
+
+        void verify_integrity() {
+            if (val_ref == nullptr)
+                CRITICAL("ObservedValue " << this->get_name() << " created with null pointer",-1);
+        }
+
     public:
         ObservedValue(const std::string& name, T* val_ref, const std::string& alias="")
           :Value<T>(name, alias),
@@ -400,6 +418,13 @@ class WrapperVector : public DerivedValue<std::vector<T> >{
             this->value.assign(data_ref, data_ref+n);
         }
 
+        void verify_integrity() {
+            if (size == nullptr)
+                CRITICAL("WrapperVector " << this->get_name() << " created with invalid size.",-1);
+            if (data == nullptr)
+                CRITICAL("WrapperVector " << this->get_name() << " created with invalid value.",-1);
+        }
+
     public:
         WrapperVector(Value<int>* size, Value<T*>* data, const std::string& alias="")
           :DerivedValue<std::vector<T> >("vectorOf("+size->get_name()+","+data->get_name()+")", alias),
@@ -421,6 +446,14 @@ class Pair : public DerivedValue<std::pair<T1, T2> >{
             this->value.first = value_pair.first->get_value();
             this->value.second = value_pair.second->get_value();
         }
+
+        void verify_integrity() {
+            if (value_pair.first == nullptr)
+                CRITICAL("Pair " << this->get_name() << " created with invalid first value.",-1);
+            if (value_pair.second == nullptr)
+                CRITICAL("Pair " << this->get_name() << " created with invalid second value.",-1);
+        }
+
     public:
         Pair(Value<T1> *value1, Value<T2> *value2, const std::string alias="")
           :DerivedValue<std::pair<T1, T2> >("pair("+value1->get_name()+","+value2->get_name()+")", alias),
@@ -431,6 +464,72 @@ class Pair : public DerivedValue<std::pair<T1, T2> >{
                 alias){ }
 };
 
+template<typename... T> class _Zip;
+template<>
+class _Zip<> {
+    protected:
+
+        int _get_size(){
+            return std::numeric_limits<int>::max();
+        }
+
+        std::tuple<> _get_at(int idx){
+            return std::make_tuple();
+        }
+    public:
+        _Zip() { }
+};
+
+template<typename Head, typename... Tail>
+class _Zip<Head, Tail...> : private _Zip<Tail...> {
+    protected:
+        Value<std::vector<Head>>* head;
+
+        int _get_size(){
+            int this_size = head->get_value().size();
+            int rest_size = _Zip<Tail...>::_get_size();
+            return std::min(this_size, rest_size);
+        }
+
+        typename std::tuple<Head,Tail...> _get_at(int idx){
+            auto tail_tuple = _Zip<Tail...>::_get_at(idx);
+            return std::tuple_cat(std::make_tuple(head->get_value()[idx]),tail_tuple);
+        }
+    public:
+        _Zip() { }
+
+        _Zip(Value<std::vector<Head>>* head, Value<std::vector<Tail>>*... tail)
+          : _Zip<Tail...>(tail...),
+            head(head) { }
+};
+
+/**
+ * Zips a series of observations together
+ */
+template <typename... ArgTypes>
+class Zip : public DerivedValue<std::vector<std::tuple<ArgTypes...>>>,
+             private _Zip<ArgTypes...>{
+    protected:
+        void update_value(){
+            /* auto tuple_of_vectors this->_get_value(); */
+            this->value.clear();
+            int size = _Zip<ArgTypes...>::_get_size();
+            for(int i=0; i<size; i++){
+                this->value.push_back(_Zip<ArgTypes...>::_get_at(i));
+            }
+        }
+
+        /**
+         * /todo Implement this.
+         */
+        void verify_integrity() { }
+
+    public:
+        Zip(Value<std::vector<ArgTypes>>*... args, const std::string alias="")
+          :DerivedValue<std::vector<std::tuple<ArgTypes...>>>("a kickin zip", ""),
+           _Zip<ArgTypes...>(args...) { }
+};
+
 /**
  * Takes a set of four Value<std::vector<T> > objects and a function of four Ts
  * and returns a std::vector<R>. This is used in, for instance, calculating the
@@ -467,6 +566,17 @@ class ZipMapFour : public DerivedValue<std::vector<R> >{
             }
         }
 
+        void verify_integrity() {
+            if (v1 == nullptr)
+                CRITICAL("ZipMapFour " << this->get_name() << " created with invalid first value.",-1);
+            if (v2 == nullptr)
+                CRITICAL("ZipMapFour " << this->get_name() << " created with invalid second value.",-1);
+            if (v3 == nullptr)
+                CRITICAL("ZipMapFour " << this->get_name() << " created with invalid third value.",-1);
+            if (v4 == nullptr)
+                CRITICAL("ZipMapFour " << this->get_name() << " created with invalid fourth value.",-1);
+        }
+
     public:
         ZipMapFour(Function<R(T, T, T, T)>& f,
                    Value<std::vector<T> >* v1, Value<std::vector<T> >* v2,
@@ -487,13 +597,14 @@ class ZipMapFour : public DerivedValue<std::vector<R> >{
 };
 
 /**
- *
+ * Returns the count of elements in the input vector passing a test function.
  */
 template<typename T>
 class Count : public DerivedValue<int>{
     private:
         Function<bool(T)>& selector;
         Value<std::vector<T> >* v;
+
         void update_value(){
             value = 0;
             for(auto val : v->get_value()){
@@ -501,6 +612,12 @@ class Count : public DerivedValue<int>{
                     value++;
             }
         }
+
+        void verify_integrity() {
+            if (v == nullptr)
+                CRITICAL("Count " << this->get_name() << " created with invalid value.",-1);
+        }
+
     public:
         Count(Function<bool(T)>& selector, Value<std::vector<T>>* v, const std::string alias="")
           :DerivedValue<int>("count("+selector.get_name()+":"+v->get_name()+")", alias),
@@ -521,10 +638,19 @@ template <typename T>
 class Reduce : public DerivedValue<T>{
     private:
         Function<T(std::vector<T>)>& reduce;
-        Value<std::vector<T> >* v;
+
         void update_value(){
             this->value = reduce(v->get_value());
         }
+
+        virtual void verify_integrity() {
+            if (v == nullptr)
+                CRITICAL("Reduce " << this->get_name() << " created with invalid value.",-1);
+        }
+
+    protected:
+        Value<std::vector<T> >* v;
+
     public:
         Reduce(Function<T(std::vector<T>)>& reduce, Value<std::vector<T> >* v, const std::string alias="")
           :DerivedValue<T>("reduceWith("+reduce.get_name()+":"+v->get_name()+")", alias),
@@ -539,6 +665,11 @@ class Reduce : public DerivedValue<T>{
  */
 template <typename T>
 class Max : public Reduce<T>{
+    private:
+        void verify_integrity() {
+            if (this->v == nullptr)
+                CRITICAL("Max " << this->get_name() << " created with invalid value.",-1);
+        }
     public:
         Max(const std::string& v_name, const std::string alias="")
           :Reduce<T>(GenFunction::register_function<T(std::vector<T>)>("max",
@@ -552,6 +683,11 @@ class Max : public Reduce<T>{
  */
 template <typename T>
 class Min : public Reduce<T>{
+    private:
+        void verify_integrity() {
+            if (this->v == nullptr)
+                CRITICAL("Min " << this->get_name() << " created with invalid value.",-1);
+        }
     public:
         Min(const std::string& v_name, const std::string alias="")
           :Reduce<T>(GenFunction::register_function<T(std::vector<T>)>("min",
@@ -565,6 +701,12 @@ class Min : public Reduce<T>{
  */
 template <typename T>
 class Mean : public Reduce<T>{
+    private:
+        void verify_integrity() {
+            if (this->v == nullptr)
+                CRITICAL("Mean " << this->get_name() << " created with invalid value.",-1);
+        }
+
     public:
         Mean(const std::string& v_name, const std::string alias="")
           :Reduce<T>(GenFunction::register_function<T(std::vector<T>)>("mean",
@@ -580,6 +722,12 @@ class Mean : public Reduce<T>{
  */
 template <typename T>
 class Range : public Reduce<T>{
+    private:
+        void verify_integrity() {
+            if (this->v == nullptr)
+                CRITICAL("Range " << this->get_name() << " created with invalid value.",-1);
+        }
+
     public:
         Range(const std::string& v_name, const std::string alias="")
           :Reduce<T>(GenFunction::register_function<T(std::vector<T>)>("range",
@@ -594,6 +742,12 @@ class Range : public Reduce<T>{
  */
 template <typename T>
 class ElementOf : public Reduce<T>{
+    private:
+        void verify_integrity() {
+            if (this->v == nullptr)
+                CRITICAL("ElementOf " << this->get_name() << " created with invalid value.",-1);
+        }
+
     public:
         ElementOf(Value<int>* index, const std::string& v_name, const std::string alias="")
           :Reduce<T>(GenFunction::register_function<T(std::vector<T>)>("elementOf",
@@ -613,9 +767,16 @@ class ReduceIndex : public DerivedValue<std::pair<T, int> >{
     private:
         Function<std::pair<T,int>(std::vector<T>)>& reduce;
         Value<std::vector<T> >* v;
+
         void update_value(){
             this->value = reduce(v->get_value());
         }
+
+        virtual void verify_integrity() {
+            if (v == nullptr)
+                CRITICAL("ReduceIndex " << this->get_name() << " created with invalid value.",-1);
+        }
+
     public:
         ReduceIndex(Function<std::pair<T,int>(std::vector<T>)>& reduce, Value<std::vector<T> >* v, const std::string alias="")
           :DerivedValue<T>("reduceIndexWith("+reduce.get_name()+":"+v->get_name()+")", alias),
@@ -630,6 +791,12 @@ class ReduceIndex : public DerivedValue<std::pair<T, int> >{
  */
 template <typename T>
 class MaxIndex : public ReduceIndex<T>{
+    private:
+        void verify_integrity() {
+            if (this->v == nullptr)
+                CRITICAL("MaxIndex " << this->get_name() << " created with invalid value.",-1);
+        }
+
     public:
         MaxIndex(const std::string& v_name, const std::string alias="")
           :ReduceIndex<T>(GenFunction::register_function<T(std::vector<T>)>("maxIndex",
@@ -644,6 +811,12 @@ class MaxIndex : public ReduceIndex<T>{
  */
 template <typename T>
 class MinIndex : public ReduceIndex<T>{
+    private:
+        void verify_integrity() {
+            if (this->v == nullptr)
+                CRITICAL("MinIndex " << this->get_name() << " created with invalid value.",-1);
+        }
+
     public:
         MinIndex(const std::string& v_name, const std::string alias="")
           :ReduceIndex<T>(GenFunction::register_function<T(std::vector<T>)>("minIndex",
@@ -660,6 +833,8 @@ class MinIndex : public ReduceIndex<T>{
  */
 template <typename T>
 class BoundValue : public DerivedValue<T>{
+    private:
+        void verify_integrity() { }
     protected:
         Function<T()>& f;
         void update_value(){
@@ -677,8 +852,15 @@ class BoundValue : public DerivedValue<T>{
  */
 template <typename T>
 class PointerValue : public DerivedValue<T*>{
+    private:
+        void verify_integrity() {
+            if(this->value == nullptr)
+                CRITICAL("PointerValue " << this->get_name() << " created with null pointer",-1);
+        }
+
     protected:
         void update_value(){ }
+
     public:
         PointerValue(const std::string& name, T* ptr, const std::string alias="")
           :DerivedValue<T*>(name, alias){
@@ -691,15 +873,16 @@ class PointerValue : public DerivedValue<T*>{
  */
 template <typename T>
 class ConstantValue : public DerivedValue<T>{
+    private:
+        void verify_integrity() { }
+
     protected:
-        T const_value;
-        void update_value(){
-            this->value = const_value;
-        }
+        void update_value(){ }
+
     public:
         ConstantValue(const std::string& name, T const_value, const std::string alias="")
             :DerivedValue<T>("const::"+name, alias),
-             const_value(const_value) { }
+             Value<T>::value(const_value) { }
 };
 }
 #endif // value_hpp

+ 5 - 1
filval_root/README.md

@@ -1,2 +1,6 @@
-ROOT compatability layer for FILter-VALue System
+ROOT compatability layer for FilVal
 ================================================
+See [FilVal](../filval/README.md) for details on FilVal. This layer provides
+container classes wrapping ROOT histograms and Graph objects. It also provides
+the ability to write these containers, as well as a variety of STL containers
+to ROOT files.

+ 139 - 176
filval_root/container.hpp

@@ -1,39 +1,73 @@
 #ifndef root_container_hpp
 #define root_container_hpp
-#include <utility>
 #include <iostream>
+#include <utility>
+#include <map>
+
+#include "TROOT.h"
 #include "TFile.h"
+#include "TCanvas.h"
+#include "TGraph.h"
 #include "TH1.h"
 #include "TH2.h"
-#include "TGraph.h"
-#include "TROOT.h"
 
 #include "filval.hpp"
 
 namespace fv::root::util{
-void _save_img(TObject* container, const std::string& fname){
-    TCanvas* c1 = new TCanvas("c1");
-    container->Draw();
-    c1->Draw();
-    c1->SaveAs(fname.c_str());
-    delete c1;
-}
 
-void _save_bin(TObject* container){
-    INFO("Saving object: " << container->GetName() << " into file " << gDirectory->GetName());
-    /* TFile* f = TFile::Open(fname.c_str(), "UPDATE"); */
-    container->Write(container->GetName(), TObject::kOverwrite);
-    /* f->Close(); */
+/**
+ * Save a TObject. The TObject will typically be a Histogram or Graph object,
+ * but can really be any TObject. The SaveOption can be used to specify how to
+ * save the file.
+ */
+void save_as(TObject* container, const std::string& fname, const SaveOption& option = SaveOption::PNG) {
+
+    auto save_img = [](TObject* container, const std::string& fname){
+        TCanvas* c1 = new TCanvas("c1");
+        container->Draw();
+        c1->Draw();
+        c1->SaveAs(fname.c_str());
+        delete c1;
+    };
+
+    auto save_bin = [](TObject* container){
+        INFO("Saving object: " << container->GetName() << " into file " << gDirectory->GetName());
+        container->Write(container->GetName(), TObject::kOverwrite);
+    };
+
+    switch(option){
+        case PNG:
+            save_img(container, fname+".png"); break;
+        case PDF:
+            save_img(container, fname+".pdf"); break;
+        case ROOT:
+            save_bin(container); break;
+        default:
+            break;
+    }
 }
 
-void _save_as(TObject* container, const std::string& fname, const SaveOption& option = SaveOption::PNG) {
+
+/**
+ * Saves an STL container into a ROOT file. ROOT knows how to serialize STL
+ * containers, but it needs the *name* of the type of the container, eg.
+ * std::map<int,int> to be able to do this. In order to generate this name at
+ * run-time, the fv::util::get_type_name function uses RTTI to get type info
+ * and use it to look up the proper name.
+ */
+void save_as_stl(void* container, const std::string& type_name,
+                  const std::string& obj_name,
+                  const SaveOption& option = SaveOption::PNG) {
     switch(option){
         case PNG:
-            _save_img(container, fname+".png"); break;
+            INFO("Cannot save STL container " << type_name <<" as png");
+            break;
         case PDF:
-            _save_img(container, fname+".pdf"); break;
+            INFO("Cannot save STL container " << type_name <<" as pdf");
+            break;
         case ROOT:
-            _save_bin(container); break;
+            gDirectory->WriteObjectAny(container, type_name.c_str(), obj_name.c_str());
+            break;
         default:
             break;
     }
@@ -42,8 +76,8 @@ void _save_as(TObject* container, const std::string& fname, const SaveOption& op
 
 namespace fv::root {
 
-template <typename V, typename D>
-class ContainerTH1 : public Container<TH1>{
+template <typename V>
+class _ContainerTH1 : public Container<TH1>{
     private:
         void _fill(){
             if (container == nullptr){
@@ -51,109 +85,54 @@ class ContainerTH1 : public Container<TH1>{
                     CRITICAL("Container: \"" << get_name() << "\" has a null Value object. "
                              << "Probably built with imcompatible type",-1);
                 }
-                init_TH1();
+                this->container = new TH1D(this->get_name().c_str(), this->title.c_str(),
+                                           this->nbins, this->low, this->high);
             }
-            _do_fill(value->get_value());
+            _do_fill();
         }
 
     protected:
         std::string title;
         int nbins;
-        D low;
-        D high;
+        double low;
+        double high;
         Value<V> *value;
-        virtual void init_TH1() = 0;
-        virtual void _do_fill(V& val) = 0;
+
+        virtual void _do_fill() = 0;
 
     public:
-        explicit ContainerTH1(const std::string &name, const std::string& title, GenValue *value,
-                     int nbins, D low, D high)
+        explicit _ContainerTH1(const std::string &name, const std::string& title, GenValue *value,
+                               int nbins, double low, double high)
           :Container<TH1>(name, nullptr),
            title(title), nbins(nbins), low(low), high(high),
            value(dynamic_cast<Value<V>*>(value)) { }
 
         void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
-            util::_save_as(get_container(), fname, option);
+            util::save_as(get_container(), fname, option);
         }
 };
 
 template <typename V>
-class _ContainerTH1D : public ContainerTH1<V, double>{
-    using ContainerTH1<V, double>::ContainerTH1;
-    void init_TH1(){
-        this->container = new TH1D(this->get_name().c_str(), this->title.c_str(),
-                                   this->nbins, this->low, this->high);
+class ContainerTH1 : public _ContainerTH1<V>{
+    using _ContainerTH1<V>::_ContainerTH1;
+    void _do_fill(){
+        this->container->Fill(this->value->get_value());
     }
 };
 
-class ContainerTH1D : public _ContainerTH1D<double>{
-    using _ContainerTH1D<double>::_ContainerTH1D;
-    void _do_fill(double& val){
-        this->container->Fill(val);
-    }
-};
-
-class ContainerTH1DMany : public _ContainerTH1D<std::vector<double>>{
-    using _ContainerTH1D<std::vector<double>>::_ContainerTH1D;
-    void _do_fill(std::vector<double>& val){
-        for(double x : val)
-            this->container->Fill(x);
-    }
-};
-
-
 template <typename V>
-class _ContainerTH1F : public ContainerTH1<V, float>{
-    using ContainerTH1<V,float>::ContainerTH1;
-    void init_TH1(){
-        this->container = new TH1F(this->get_name().c_str(), this->title.c_str(),
-                                   this->nbins, this->low, this->high);
-    }
-};
-
-class ContainerTH1F : public _ContainerTH1F<float>{
-    using _ContainerTH1F<float>::_ContainerTH1F;
-    void _do_fill(float& val){
-        this->container->Fill(val);
-    }
-};
-
-class ContainerTH1FMany : public _ContainerTH1F<std::vector<float>>{
-    using _ContainerTH1F<std::vector<float>>::_ContainerTH1F;
-    void _do_fill(std::vector<float>& val){
-        for(float x : val)
+class ContainerTH1Many : public _ContainerTH1<std::vector<V>>{
+    using _ContainerTH1<std::vector<V>>::_ContainerTH1;
+    void _do_fill(){
+        for(V x : this->value->get_value())
             this->container->Fill(x);
     }
 };
 
 
-template <typename V>
-class _ContainerTH1I : public ContainerTH1<V, int>{
-    using ContainerTH1<V,int>::ContainerTH1;
-    void init_TH1(){
-        this->container = new TH1I(this->get_name().c_str(), this->title.c_str(),
-                                   this->nbins, this->low, this->high);
-    }
-};
-
-class ContainerTH1I : public _ContainerTH1I<int>{
-    using _ContainerTH1I<int>::_ContainerTH1I;
-    void _do_fill(int& val){
-        this->container->Fill(val);
-    }
-};
-
-class ContainerTH1IMany : public _ContainerTH1I<std::vector<int>>{
-    using _ContainerTH1I<std::vector<int>>::_ContainerTH1I;
-    void _do_fill(std::vector<int>& val){
-        for(int x : val)
-            this->container->Fill(x);
-    }
-};
 
-
-template <typename V, typename D>
-class ContainerTH2 : public Container<TH2>{
+template <typename V>
+class _ContainerTH2 : public Container<TH2>{
     private:
         void _fill(){
             if (container == nullptr){
@@ -161,7 +140,9 @@ class ContainerTH2 : public Container<TH2>{
                     CRITICAL("Container: \"" << get_name() << "\" has a null Value object. "
                              << "Probably built with imcompatible type",-1);
                 }
-                init_TH2();
+                this->container = new TH2D(this->get_name().c_str(), this->title.c_str(),
+                                           this->nbins_x, this->low_x, this->high_x,
+                                           this->nbins_y, this->low_y, this->high_y);
             }
             _do_fill(value->get_value());
         }
@@ -170,101 +151,41 @@ class ContainerTH2 : public Container<TH2>{
         std::string title;
         int nbins_x;
         int nbins_y;
-        D low_x;
-        D low_y;
-        D high_x;
-        D high_y;
+        double low_x;
+        double low_y;
+        double high_x;
+        double high_y;
         Value<std::pair<V,V>> *value;
-        virtual void init_TH2() = 0;
+
         virtual void _do_fill(std::pair<V,V>& val) = 0;
 
     public:
-        explicit ContainerTH2(const std::string& name, const std::string& title,
-                              GenValue* value,
-                              int nbins_x, D low_x, D high_x,
-                              int nbins_y, D low_y, D high_y)
+        explicit _ContainerTH2(const std::string& name, const std::string& title,
+                               GenValue* value,
+                               int nbins_x, double low_x, double high_x,
+                               int nbins_y, double low_y, double high_y)
           :Container<TH2>(name, nullptr),
            nbins_x(nbins_x), low_x(low_x), high_x(high_x),
            nbins_y(nbins_y), low_y(low_y), high_y(high_y),
            value(dynamic_cast<Value<std::pair<V, V>>*>(value)) { }
 
         void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
-            util::_save_as(get_container(), fname, option);
+            util::save_as(get_container(), fname, option);
         }
 };
 
 template <typename V>
-class _ContainerTH2D : public ContainerTH2<V, double>{
-    using ContainerTH2<V, double>::ContainerTH2;
-    void init_TH2(){
-        this->container = new TH2D(this->get_name().c_str(), this->title.c_str(),
-                                   this->nbins_x, this->low_x, this->high_x,
-                                   this->nbins_y, this->low_y, this->high_y);
-    }
-};
-
-class ContainerTH2D : public _ContainerTH2D<double>{
-    using _ContainerTH2D<double>::_ContainerTH2D;
-    void _do_fill(std::pair<double,double>& val){
-        this->container->Fill(val.first, val.second);
-    }
-};
-
-class ContainerTH2DMany : public _ContainerTH2D<std::vector<double>>{
-    using _ContainerTH2D<std::vector<double>>::_ContainerTH2D;
-    void _do_fill(std::pair<std::vector<double>,std::vector<double>>& val){
-        int min_size = std::min(val.first.size(), val.second.size());
-        for(int i=0; i<min_size; i++)
-            this->container->Fill(val.first[i],val.second[i]);
+class ContainerTH2 : public _ContainerTH2<std::vector<V>>{
+    using _ContainerTH2<std::vector<V>>::_ContainerTH2;
+    void _do_fill(std::pair<V,V>& val){
+        this->container->Fill(val.first,val.second);
     }
 };
 
 template <typename V>
-class _ContainerTH2F : public ContainerTH2<V, float>{
-    using ContainerTH2<V, float>::ContainerTH2;
-    void init_TH2(){
-        this->container = new TH2F(this->get_name().c_str(), this->title.c_str(),
-                                   this->nbins_x, this->low_x, this->high_x,
-                                   this->nbins_y, this->low_y, this->high_y);
-    }
-};
-
-class ContainerTH2F : public _ContainerTH2F<float>{
-    using _ContainerTH2F<float>::_ContainerTH2F;
-    void _do_fill(std::pair<float,float>& val){
-        this->container->Fill(val.first, val.second);
-    }
-};
-
-class ContainerTH2FMany : public _ContainerTH2F<std::vector<float>>{
-    using _ContainerTH2F<std::vector<float>>::_ContainerTH2F;
-    void _do_fill(std::pair<std::vector<float>,std::vector<float>>& val){
-        int min_size = std::min(val.first.size(), val.second.size());
-        for(int i=0; i<min_size; i++)
-            this->container->Fill(val.first[i],val.second[i]);
-    }
-};
-
-template <typename V>
-class _ContainerTH2I : public ContainerTH2<V, int>{
-    using ContainerTH2<V, int>::ContainerTH2;
-    void init_TH2(){
-        this->container = new TH2I(this->get_name().c_str(), this->title.c_str(),
-                                   this->nbins_x, this->low_x, this->high_x,
-                                   this->nbins_y, this->low_y, this->high_y);
-    }
-};
-
-class ContainerTH2I : public _ContainerTH2I<int>{
-    using _ContainerTH2I<int>::_ContainerTH2I;
-    void _do_fill(std::pair<int,int>& val){
-        this->container->Fill(val.first, val.second);
-    }
-};
-
-class ContainerTH2IMany : public _ContainerTH2I<std::vector<int>>{
-    using _ContainerTH2I<std::vector<int>>::_ContainerTH2I;
-    void _do_fill(std::pair<std::vector<int>,std::vector<int>>& val){
+class ContainerTH2Many : public _ContainerTH2<std::vector<V>>{
+    using _ContainerTH2<std::vector<V>>::_ContainerTH2;
+    void _do_fill(std::pair<std::vector<V>,std::vector<V>>& val){
         int min_size = std::min(val.first.size(), val.second.size());
         for(int i=0; i<min_size; i++)
             this->container->Fill(val.first[i],val.second[i]);
@@ -276,6 +197,7 @@ class ContainerTGraph : public Container<TGraph>{
         Value<std::pair<int, int> > *value;
         std::vector<int> x_data;
         std::vector<int> y_data;
+        std::string title;
         bool data_modified;
         void _fill(){
             auto val = value->get_value();
@@ -284,7 +206,7 @@ class ContainerTGraph : public Container<TGraph>{
             data_modified = true;
         }
     public:
-        ContainerTGraph(const std::string &name, GenValue* value)
+        ContainerTGraph(const std::string& name, const std::string& title, GenValue* value)
           :Container<TGraph>(name, new TGraph()),
            value(dynamic_cast<Value<std::pair<int, int> >*>(value)),
            data_modified(false){ }
@@ -294,14 +216,55 @@ class ContainerTGraph : public Container<TGraph>{
                 delete container;
                 container = new TGraph(x_data.size(), x_data.data(), y_data.data());
                 container->SetName(get_name().c_str());
+                container->SetTitle(title.c_str());
                 data_modified = false;
             }
             return container;
         }
         void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
-            util::_save_as(get_container(), fname, option);
+            util::save_as(get_container(), fname, option);
         }
 };
 
+
+template <typename V, typename D>
+class _Counter : public Container<std::map<D,int>>{
+    protected:
+        Value<V>* value;
+    public:
+        explicit _Counter(const std::string& name, GenValue* value)
+          :Container<std::map<D,int>>(name, new std::map<D,int>()),
+           value(dynamic_cast<Value<V>*>(value)) { }
+
+        void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
+            std::string type_name = "std::map<"+fv::util::get_type_name(typeid(D))+",int>";
+            util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
+        }
+
+};
+
+/**
+ * A Counter that keeps a mapping of the number of occurances of each input
+ * value.
+ */
+template <typename V>
+class Counter : public _Counter<V,V>{
+    using _Counter<V,V>::_Counter;
+        void _fill(){
+            (*this->container)[this->value->get_value()]++;
+        }
+};
+
+/**
+ * Same as Counter but accepts multiple values per fill.
+ */
+template <typename V>
+class CounterMany : public _Counter<std::vector<V>,V>{
+    using _Counter<std::vector<V>,V>::_Counter;
+        void _fill(){
+            for(V& val : this->value->get_value())
+                (*this->container)[val]++;
+        }
+};
 }
 #endif // root_container_hpp

+ 19 - 0
filval_root/value.hpp

@@ -11,9 +11,22 @@ class LorentzVector : public DerivedValue<TLorentzVector>{
         Value<double> *eta;
         Value<double> *phi;
         Value<double> *m;
+
         void update_value(){
             value.SetPtEtaPhiM(pt->get_value(), eta->get_value(), phi->get_value(), m->get_value());
         }
+
+        void verify_integrity(){
+            if (pt == nullptr)
+                CRITICAL("LorentzVector " << this->get_name() << " created with invalid pt", -1);
+            if (eta == nullptr)
+                CRITICAL("LorentzVector " << this->get_name() << " created with invalid eta", -1);
+            if (phi == nullptr)
+                CRITICAL("LorentzVector " << this->get_name() << " created with invalid phi", -1);
+            if (m == nullptr)
+                CRITICAL("LorentzVector " << this->get_name() << " created with invalid mass", -1);
+        }
+
     public:
         LorentzVector(const std::string& name,
                       Value<double>* pt,
@@ -42,6 +55,12 @@ class LorentzVectorEnergy : public DerivedValue<double>{
         void update_value(){
             value = vector->get_value().E();
         }
+
+        void verify_integrity(){
+            if (vector == nullptr)
+                CRITICAL("LorentzVectorEnergy " << this->get_name() << " created with invalid vector", -1);
+        }
+
     public:
         LorentzVectorEnergy(const std::string& name, Value<TLorentzVector>* vector)
           :DerivedValue<double>(name),

Rozdílová data souboru nebyla zobrazena, protože soubor je příliš velký
+ 119 - 241
python/TTTT_Analysis.ipynb


+ 124 - 93
python/utils.py

@@ -5,43 +5,6 @@ from subprocess import run
 import itertools as it
 import ROOT
 
-
-class HistCollection:
-    def __init__(self, sample_name, input_filename,
-                 exe_path="../build/main",
-                 rebuild_hists = False):
-        self.sample_name = sample_name
-        if rebuild_hists:
-            run([exe_path, "-s", "-f", input_filename])
-        output_filename = input_filename.replace(".root", "_result.root")
-        self._file = ROOT.TFile.Open(output_filename)
-        l = self._file.GetListOfKeys()
-        self.map = {}
-        for i in range(l.GetSize()):
-            name = l.At(i).GetName()
-            self.map[name] = self._file.Get(name)
-            setattr(self, name, self.map[name])
-
-    def draw(self, canvas, shape=None):
-        if shape is None:
-            n = int(ceil(sqrt(len(self.map))))
-            shape = (n, n)
-        print(shape)
-        canvas.Clear()
-        canvas.Divide(*shape)
-        for i, hist in enumerate(self.map.values()):
-            canvas.cd(i+1)
-            try:
-                hist.SetStats(False)
-            except AttributeError:
-                pass
-            print(i, hist, str(type(hist)))
-            draw_option = ""
-            if (type(hist) == ROOT.TH2F):
-                draw_option = "COLZ"
-            hist.Draw(draw_option)
-
-
 class OutputCapture:
     def __init__(self):
         self.my_stdout = io.StringIO()
@@ -89,60 +52,128 @@ def normalize_columns(hist2d):
             normHist.SetBinContent(col, row, norm)
     return normHist
 
+class HistCollection:
+    def __init__(self, sample_name, input_filename,
+                 exe_path="../build/main",
+                 rebuild_hists=False):
+        self.sample_name = sample_name
+        if rebuild_hists:
+            run([exe_path, "-s", "-f", input_filename])
+        output_filename = input_filename.replace(".root", "_result.root")
+        file = ROOT.TFile.Open(output_filename)
+        l = file.GetListOfKeys()
+        self.map = {}
+        for i in range(l.GetSize()):
+            name = l.At(i).GetName()
+            new_name = ":".join((sample_name, name))
+            obj = file.Get(name)
+            try:
+                obj.SetName(new_name)
+                obj.SetDirectory(0)  # disconnects Object from file
+            except AttributeError:
+                pass
+            self.map[name] = obj
+            setattr(self, name, obj)
+        file.Close()
+        # Now add these histograms into the current ROOT directory (in memory)
+        # and remove old versions if needed
+        for obj in self.map.values():
+            try:
+                old_obj = ROOT.gDirectory.Get(obj.GetName())
+                ROOT.gDirectory.Remove(old_obj)
+                ROOT.gDirectory.Add(obj)
+            except AttributeError:
+                pass
+        HistCollection.add_collection(self)
 
-def stack_hist(hists,
-               labels=None, id_=None,
-               title="", enable_fill=False,
-               normalize_to=0, draw=False,
-               draw_option="",
-               _stacks={}):
-    """hists should be a list of TH1D objects
-    returns a new stacked histogram
-    """
-    colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen])
-    stack = ROOT.THStack(id_, title)
-    if labels is None:
-        labels = [hist.GetName() for hist in hists]
-    if type(normalize_to) in (int, float):
-        normalize_to = [normalize_to]*len(hists)
-    if id_ is None:
-        id_ = hists[0].GetName() + "_stack"
-    ens = enumerate(zip(hists, labels, colors, normalize_to))
-    for i, (hist, label, color, norm) in ens:
-        hist_copy = hist
-        hist_copy = hist.Clone(hist.GetName()+"_clone")
-        hist_copy.SetTitle(label)
-        if enable_fill:
-            hist_copy.SetFillColorAlpha(color, 0.75)
-            hist_copy.SetLineColorAlpha(color, 0.75)
-        if norm:
-            integral = hist_copy.Integral()
-            hist_copy.Scale(norm/integral, "nosw2")
-            hist_copy.SetStats(False)
-        stack.Add(hist_copy)
-    if draw:
-        stack.Draw(draw_option)
-    _stacks[id_] = stack  # prevent stack from getting garbage collected
-                          # needed for multipad plots :/
-    return stack
-
-
-def stack_hist_array(canvas, histcollections, fields, titles,
-                     shape=None, **kwargs):
-    def get_hist_set(attrname):
-        hists, labels = zip(*[(getattr(h, attrname), h.sample_name)
-                              for h in histcollections])
-        return hists, labels
-    n_fields = len(fields)
-    if shape is None:
-        if n_fields <= 4:
-            shape = (1, n_fields)
-        else:
-            shape = (ceil(sqrt(n_fields)),)*2
-    canvas.Clear()
-    canvas.Divide(*shape)
-    for i, field, title in zip(bin_range(n_fields), fields, titles):
-        canvas.cd(i)
-        hists, labels = get_hist_set(field)
-        stack_hist(hists, labels, id_=field, title=title, draw=True, **kwargs)
-    canvas.cd(1).BuildLegend(0.75, 0.75, 0.95, 0.95, "")
+    def draw(self, canvas, shape=None):
+        if shape is None:
+            n = int(ceil(sqrt(len(self.map))))
+            shape = (n, n)
+        canvas.Clear()
+        canvas.Divide(*shape)
+        i = 1
+        for hist in self.map.values():
+            canvas.cd(i)
+            try:
+                hist.SetStats(False)
+            except AttributeError:
+                pass
+            draw_option = ""
+            if type(hist) in (ROOT.TH1F, ROOT.TH1I, ROOT.TH1D):
+                draw_option = ""
+            elif type(hist) in (ROOT.TH2F, ROOT.TH2I, ROOT.TH2D):
+                draw_option = "COLZ"
+            elif type(hist) in (ROOT.TGraph,):
+                draw_option = "A*"
+            else:
+                print("cannot draw object", hist)
+                continue  # Not a drawable type(probably)
+            hist.Draw(draw_option)
+            i += 1
+
+    @classmethod
+    def add_collection(cls, hc):
+        if not hasattr(cls, "collections"):
+            cls.collections = {}
+        cls.collections[hc.sample_name] = hc
+
+
+    @classmethod
+    def stack_hist(hists,
+                   labels=None, id_=None,
+                   title="", enable_fill=False,
+                   normalize_to=0, draw=False,
+                   draw_option="",
+                   _stacks={}):
+        """hists should be a list of TH1D objects
+        returns a new stacked histogram
+        """
+        colors = it.cycle([ROOT.kRed, ROOT.kBlue, ROOT.kGreen])
+        stack = ROOT.THStack(id_, title)
+        if labels is None:
+            labels = [hist.GetName() for hist in hists]
+        if type(normalize_to) in (int, float):
+            normalize_to = [normalize_to]*len(hists)
+        if id_ is None:
+            id_ = hists[0].GetName() + "_stack"
+        ens = enumerate(zip(hists, labels, colors, normalize_to))
+        for i, (hist, label, color, norm) in ens:
+            hist_copy = hist
+            hist_copy = hist.Clone(hist.GetName()+"_clone")
+            hist_copy.SetTitle(label)
+            if enable_fill:
+                hist_copy.SetFillColorAlpha(color, 0.75)
+                hist_copy.SetLineColorAlpha(color, 0.75)
+            if norm:
+                integral = hist_copy.Integral()
+                hist_copy.Scale(norm/integral, "nosw2")
+                hist_copy.SetStats(False)
+            stack.Add(hist_copy)
+        if draw:
+            stack.Draw(draw_option)
+        _stacks[id_] = stack  # prevent stack from getting garbage collected
+                              # needed for multipad plots :/
+        return stack
+
+
+    @classmethod
+    def stack_hist_array(canvas, histcollections, fields, titles,
+                         shape=None, **kwargs):
+        def get_hist_set(attrname):
+            hists, labels = zip(*[(getattr(h, attrname), h.sample_name)
+                                  for h in histcollections])
+            return hists, labels
+        n_fields = len(fields)
+        if shape is None:
+            if n_fields <= 4:
+                shape = (1, n_fields)
+            else:
+                shape = (ceil(sqrt(n_fields)),)*2
+        canvas.Clear()
+        canvas.Divide(*shape)
+        for i, field, title in zip(bin_range(n_fields), fields, titles):
+            canvas.cd(i)
+            hists, labels = get_hist_set(field)
+            stack_hist(hists, labels, id_=field, title=title, draw=True, **kwargs)
+        canvas.cd(1).BuildLegend(0.75, 0.75, 0.95, 0.95, "")