Browse Source

Continuing work on integrating MVA

Caleb Fangmeier 7 years ago
parent
commit
6c1ddf6ba7
3 changed files with 91 additions and 13 deletions
  1. 13 0
      analysis/TTTT_Analysis.cpp
  2. 49 13
      filval_root/container.hpp
  3. 29 0
      filval_root/value.hpp

+ 13 - 0
analysis/TTTT_Analysis.cpp

@@ -37,6 +37,7 @@
 #include <iostream>
 #include <vector>
 #include <utility>
+#include <numeric>
 #include <limits>
 
 #include "filval/filval.hpp"
@@ -251,6 +252,18 @@ void declare_values(MiniTreeDataSet& mt){
     fv::pair<int, int>("genMu_count",  "recMu_count",  "genMu_count_v_recMu_count");
     fv::pair<int, int>("genLep_count", "recLep_count", "genLep_count_v_recLep_count");
 
+    /* auto& sum = GenFunction::register_function<float(std::vector<float>)>("sum", */
+    /*     FUNC(([](const std::vector<float>& v){ */
+    /*         return std::accumulate(v.begin(), v.end(), 0.0f); */
+    /*     }))); */
+
+    /* auto sum_jet_pt = fv::apply(sum, lookup<std::vector<float>>("Jet_pt")) */
+
+    /* fv::tuple(lookup<float>("nJet"), */
+    /*           lookup<float>("nLepGood"), */
+    /*           lookup<float>("Jet_phi"), */
+    /*           lookup<std::vector<float>>("Jet_mass"), */
+
     obs_filter("trilepton", FUNC(([nLepGood=lookup<int>("nLepGood")]()
         {
             return dynamic_cast<Value<int>*>(nLepGood)->get_value() == 3;

+ 49 - 13
filval_root/container.hpp

@@ -303,34 +303,70 @@ class CounterMany : public _Counter<std::vector<V>,V>{
         }
 };
 
-// TMVA::DataLoader *dataloader=new TMVA::DataLoader("dataset");
 
 template <typename... ArgTypes>
-class MVA : public Container<TMVA::DataLoader,std::tuple<ArgTypes...>>{
+/**
+ * Container that feeds events into a TMVA::DataLoader and, when save_as()
+ * is called, books, trains, tests and evaluates every registered method.
+ *
+ * Each event read from the observed value is the tuple
+ * (inputs, is_training, is_signal, weight).
+ *
+ * NOTE(review): this assumes Container<H,V> keeps a Value<V>* in `value`
+ * and an H* in `container`, as the previous revision of this class
+ * implies -- confirm against the Container definition.
+ */
+class MVA : public Container<TMVA::DataLoader,std::tuple<std::tuple<ArgTypes...>,bool,bool,double>>{
     private:
+        // Per-event payload: (inputs, is_training, is_signal, weight).
+        using event_type = std::tuple<std::tuple<ArgTypes...>,bool,bool,double>;
+
+        // (method name, option string) pairs that save_as() books.
+        std::vector<std::pair<std::string,std::string>> methods;
 
         void _fill(){
-            this->container->push_back(this->value->get_value());
-            std::vector<double>& v = t2v<double>(this->value->get_value());
-            this->container->AddSignalTrainingEvent(v, 1);
-            //TODO: Make custom Value type that includes explicit truth info as
-            //well as weights and ablity to specify testing/training dataset.
+            // Unpack the current event and route it to the matching
+            // signal/background x training/testing sample of the loader.
+            std::tuple<ArgTypes...> t;
+            bool is_training;
+            bool is_signal;
+            double weight;
+            std::tie(t, is_training, is_signal, weight) = this->value->get_value();
+            // The loader's Add*Event calls below take the inputs as doubles.
+            std::vector<double> v = t2v<double>(t);
+            if (is_signal){
+                if (is_training){
+                    this->container->AddSignalTrainingEvent(v, weight);
+                } else {
+                    this->container->AddSignalTestingEvent(v, weight);
+                }
+            } else {
+                if (is_training){
+                    this->container->AddBackgroundTrainingEvent(v, weight);
+                } else {
+                    this->container->AddBackgroundTestingEvent(v, weight);
+                }
+            }
         }
     public:
-        MVA(const std::string& name, Value<std::tuple<ArgTypes...>>* value, const std::vector<std::string>& labels=std::vector<std::string>())
-          :Container<TMVA::DataLoader,std::tuple<ArgTypes...>(name, value){
-            this->container = new DataLoader(name);
+        /**
+         * @param name   name of the container and of the TMVA::DataLoader
+         * @param value  source of (inputs, is_training, is_signal, weight)
+         *               events; its type now matches the base class and
+         *               what _fill() unpacks
+         * @param labels one variable label per entry of ArgTypes; a size
+         *               mismatch is reported through CRITICAL
+         */
+        MVA(const std::string& name, Value<event_type>* value, const std::vector<std::string>& labels=std::vector<std::string>())
+          :Container<TMVA::DataLoader,event_type>(name, value){
+            this->container = new TMVA::DataLoader(name);
             if (labels.size() != sizeof...(ArgTypes)){
-                CRITICAL("Length of labels vector ("<<labels.size()<<") not equal to number of MVA arguments ("<<sizeof...(ArgTypes)<<")",-1)
+                CRITICAL("Length of labels vector ("<<labels.size()<<") not equal to number of MVA arguments ("<<sizeof...(ArgTypes)<<")",-1);
             }
-            for(std::string& label : labels){
+            // Every input is declared to the loader as type 'F'.
+            for(const std::string& label : labels){
                 this->container->AddVariable(label, 'F');
             }
         }
 
+        /**
+         * Queue a method to be booked by save_as().
+         * @param method_name   passed to TMVA::Types::GetMethodType and
+         *                      also used as the booked method's title
+         * @param method_params option string handed to BookMethod
+         */
+        void add_method(const std::string& method_name, const std::string& method_params) {
+            methods.push_back(std::make_pair(method_name, method_params));
+        }
+
         void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
-            std::string type_name = "std::vector<"+fv::util::get_type_name(typeid(T))+">";
-            util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
+            // fname and option are not used here; the factory is handed the
+            // file behind the current ROOT directory instead.
+            // NOTE(review): presumably GetFile() is null when gDirectory is
+            // not file-backed -- confirm TMVA::Factory tolerates that.
+            TFile* outputFile = gDirectory->GetFile();
+            // Scoped object: destroyed on every exit path from this function.
+            TMVA::Factory factory("TMVAClassification", outputFile,
+                                  "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
+
+            TMVA::Types& types = TMVA::Types::Instance();
+            for(const auto& p : methods){
+                const std::string& method_name = p.first;
+                const std::string& method_params = p.second;
+                TMVA::Types::EMVA method_type = types.GetMethodType(method_name);
+                factory.BookMethod(this->container, method_type, method_name, method_params);
+            }
+
+            // Train MVAs using the set of training events
+            factory.TrainAllMethods();
+            // Evaluate all MVAs using the set of test events
+            factory.TestAllMethods();
+            // Evaluate and compare performance of all configured MVAs
+            factory.EvaluateAllMethods();
         }
 };
 }

+ 29 - 0
filval_root/value.hpp

@@ -65,5 +65,34 @@ class Energies : public DerivedValue<std::vector<float>>{
            vectors(vectors) { }
 };
 
+/**
+ * Derived value describing one MVA event: the tuple of input variables
+ * plus whether the event is for training, whether it is signal, and its
+ * weight. The derived value type is
+ * std::tuple<std::tuple<DataTypes...>, bool, bool, double>.
+ *
+ * NOTE(review): nothing in this class combines the four inputs into the
+ * derived value yet; if DerivedValue expects an update hook from its
+ * subclasses, it still has to be added here.
+ */
+template<typename... DataTypes>
+class MVAData : public DerivedValue<std::tuple<std::tuple<DataTypes...>, bool, bool, double>>{
+    private:
+        // Input values this value is derived from.
+        Value<std::tuple<DataTypes...>>* data;
+        Value<bool>* is_training;
+        Value<bool>* is_signal;
+        Value<double>* weight;
+
+    public:
+        /**
+         * Build the name "mva_data(<data>,<is_training>,<is_signal>,<weight>)"
+         * from the names of the four input values.
+         */
+        static std::string fmt_name(Value<std::tuple<DataTypes...>>* data,
+                                    Value<bool>* is_training,
+                                    Value<bool>* is_signal,
+                                    Value<double>* weight){
+            return "mva_data("+data->get_name()+","
+                              +is_training->get_name()+","
+                              +is_signal->get_name()+","
+                              +weight->get_name()+")";
+        }
+
+        /**
+         * @param data        value providing the tuple of MVA inputs
+         * @param is_training value flagging training events
+         * @param is_signal   value flagging signal events
+         * @param weight      value providing the event weight
+         * @param alias       forwarded to DerivedValue together with the
+         *                    name built by fmt_name()
+         */
+        MVAData(Value<std::tuple<DataTypes...>>* data,
+                Value<bool>* is_training,
+                Value<bool>* is_signal,
+                Value<double>* weight, const std::string& alias="")
+          // The initializer must name the base exactly as declared above,
+          // with both bool slots.
+          :DerivedValue<std::tuple<std::tuple<DataTypes...>, bool, bool, double>>(fmt_name(data, is_training, is_signal, weight), alias),
+           data(data),
+           is_training(is_training),
+           is_signal(is_signal),
+           weight(weight) { }
+};
 }
 #endif // root_value_hpp