|
@@ -303,34 +303,70 @@ class CounterMany : public _Counter<std::vector<V>,V>{
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-// TMVA::DataLoader *dataloader=new TMVA::DataLoader("dataset");
|
|
|
|
|
|
template <typename... ArgTypes>
|
|
|
-class MVA : public Container<TMVA::DataLoader,std::tuple<ArgTypes...>>{
|
|
|
+class MVA : public Container<TMVA::DataLoader,MVAData<ArgTypes...>>{
|
|
|
private:
|
|
|
+ std::vector<std::pair<std::string,std::string>> methods;
|
|
|
|
|
|
void _fill(){
|
|
|
- this->container->push_back(this->value->get_value());
|
|
|
- std::vector<double>& v = t2v<double>(this->value->get_value());
|
|
|
- this->container->AddSignalTrainingEvent(v, 1);
|
|
|
- //TODO: Make custom Value type that includes explicit truth info as
|
|
|
- //well as weights and ablity to specify testing/training dataset.
|
|
|
+ std::tuple<ArgTypes...> t;
|
|
|
+ bool is_training;
|
|
|
+ bool is_signal;
|
|
|
+ double weight;
|
|
|
+ std::tie(t, is_training, is_signal, weight) = this->value->get_value();
|
|
|
+ std::vector<double> v = t2v<double>(t);
|
|
|
+ if (is_signal){
|
|
|
+ if (is_training){
|
|
|
+ this->container->AddSignalTrainingEvent(v, weight);
|
|
|
+ } else {
|
|
|
+ this->container->AddSignalTestingEvent(v, weight);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if (is_training){
|
|
|
+ this->container->AddBackgroundTrainingEvent(v, weight);
|
|
|
+ } else {
|
|
|
+ this->container->AddBackgroundTestingEvent(v, weight);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
public:
|
|
|
MVA(const std::string& name, Value<std::tuple<ArgTypes...>>* value, const std::vector<std::string>& labels=std::vector<std::string>())
|
|
|
- :Container<TMVA::DataLoader,std::tuple<ArgTypes...>(name, value){
|
|
|
- this->container = new DataLoader(name);
|
|
|
+ :Container<TMVA::DataLoader,std::tuple<ArgTypes...>>(name, value){
|
|
|
+ this->container = new TMVA::DataLoader(name);
|
|
|
if (labels.size() != sizeof...(ArgTypes)){
|
|
|
- CRITICAL("Length of labels vector ("<<labels.size()<<") not equal to number of MVA arguments ("<<sizeof...(ArgTypes)<<")",-1)
|
|
|
+ CRITICAL("Length of labels vector ("<<labels.size()<<") not equal to number of MVA arguments ("<<sizeof...(ArgTypes)<<")",-1);
|
|
|
}
|
|
|
- for(std::string& label : labels){
|
|
|
+ for(const std::string& label : labels){
|
|
|
this->container->AddVariable(label, 'F');
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ void add_method(const std::string& method_name, const std::string& method_params) {
|
|
|
+ methods.push_back(std::make_pair(method_name, method_params));
|
|
|
+ }
|
|
|
+
|
|
|
void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
|
|
|
- std::string type_name = "std::vector<"+fv::util::get_type_name(typeid(T))+">";
|
|
|
- util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
|
|
|
+ TFile* outputFile = gDirectory->GetFile();
|
|
|
+ TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
|
|
|
+ "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
|
|
|
+
|
|
|
+ TMVA::Types& types = TMVA::Types::Instance();
|
|
|
+ for(auto& p : methods){
|
|
|
+ std::string method_name, method_params;
|
|
|
+ std::tie(method_name, method_params) = p;
|
|
|
+ TMVA::Types::EMVA method_type = types.GetMethodType(method_name);
|
|
|
+ factory->BookMethod(this->container, method_type, method_name, method_params);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Train MVAs using the set of training events
|
|
|
+ factory->TrainAllMethods();
|
|
|
+ // Evaluate all MVAs using the set of test events
|
|
|
+ factory->TestAllMethods();
|
|
|
+ // Evaluate and compare performance of all configured MVAs
|
|
|
+ factory->EvaluateAllMethods();
|
|
|
+
|
|
|
+ delete factory;
|
|
|
}
|
|
|
};
|
|
|
}
|