123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445 |
- #ifndef root_container_hpp
- #define root_container_hpp
- #include <iostream>
- #include <utility>
- #include <map>
- #include "TROOT.h"
- #include "TFile.h"
- #include "TCanvas.h"
- #include "TGraph.h"
- #include "TH1.h"
- #include "TH2.h"
- #include "TMVA/Factory.h"
- #include "TMVA/DataLoader.h"
- #include "TMVA/DataSetInfo.h"
- #include "filval/container.hpp"
- namespace fv::root::util{
- /**
- * Save a TObject. The TObject will typically be a Histogram or Graph object,
- * but can really be any TObject. The SaveOption can be used to specify how to
- * save the file.
- */
- void save_as(TObject* container, const std::string& fname, const SaveOption& option = SaveOption::PNG) {
- auto save_img = [](TObject* container, const std::string& fname){
- TCanvas* c1 = new TCanvas("c1");
- container->Draw();
- c1->Draw();
- c1->SaveAs(fname.c_str());
- delete c1;
- };
- auto save_bin = [](TObject* container){
- INFO("Saving object: " << container->GetName() << " into file " << gDirectory->GetName());
- container->Write(container->GetName(), TObject::kOverwrite);
- };
- switch(option){
- case PNG:
- save_img(container, fname+".png"); break;
- case PDF:
- save_img(container, fname+".pdf"); break;
- case ROOT:
- save_bin(container); break;
- default:
- break;
- }
- }
- /**
- * Saves an STL container into a ROOT file. ROOT knows how to serialize STL
- * containers, but it needs the *name* of the type of the container, eg.
- * std::map<int,int> to be able to do this. In order to generate this name at
- * run-time, the fv::util::get_type_name function uses RTTI to get type info
- * and use it to look up the proper name.
- *
- * For nexted containers, it is necessary to generate the CLING dictionaries
- * for each type at compile time to enable serialization. To do this, add the
- * type definition into the LinkDef.hpp header file.
- */
- void save_as_stl(void* container, const std::string& type_name,
- const std::string& obj_name,
- const SaveOption& option = SaveOption::PNG) {
- switch(option){
- case PNG:
- INFO("Cannot save STL container " << type_name <<" as png");
- break;
- case PDF:
- INFO("Cannot save STL container " << type_name <<" as pdf");
- break;
- case ROOT:
- /* DEBUG("Writing object \"" << obj_name << "\" of type \"" << type_name << "\"\n"); */
- gDirectory->WriteObjectAny(container, type_name.c_str(), obj_name.c_str());
- break;
- default:
- break;
- }
- }
- }
- namespace fv::root {
- template <typename V>
- class _ContainerTH1 : public Container<TH1,V>{
- private:
- void _fill(){
- if (this->container == nullptr){
- if (this->value == nullptr){
- CRITICAL("Container: \"" << this->get_name() << "\" has a null Value object. "
- << "Probably built with imcompatible type",-1);
- }
- this->container = new TH1D(this->get_name().c_str(), this->title.c_str(),
- this->nbins, this->low, this->high);
- this->container->SetXTitle(label_x.c_str());
- this->container->SetYTitle(label_y.c_str());
- }
- _do_fill();
- }
- protected:
- std::string title;
- std::string label_x;
- std::string label_y;
- int nbins;
- double low;
- double high;
- virtual void _do_fill() = 0;
- public:
- explicit _ContainerTH1(const std::string &name, const std::string& title, Value<V>* value,
- int nbins, double low, double high,
- const std::string& label_x = "",
- const std::string& label_y = "")
- :Container<TH1,V>(name, value),
- title(title), nbins(nbins), low(low), high(high),
- label_x(label_x), label_y(label_y) { }
- void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
- util::save_as(this->get_container(), fname, option);
- }
- };
- template <typename V>
- class ContainerTH1 : public _ContainerTH1<V>{
- using _ContainerTH1<V>::_ContainerTH1;
- void _do_fill(){
- this->container->Fill(this->value->get_value());
- }
- public:
- GenContainer* clone_as(const std::string& new_name){
- return new ContainerTH1<V>(new_name, this->title, this->value, this->nbins, this->low, this->high, this->label_x, this->label_y);
- }
- };
- template <typename V>
- class ContainerTH1Many : public _ContainerTH1<std::vector<V>>{
- using _ContainerTH1<std::vector<V>>::_ContainerTH1;
- void _do_fill(){
- for(V x : this->value->get_value())
- this->container->Fill(x);
- }
- public:
- GenContainer* clone_as(const std::string& new_name){
- return new ContainerTH1Many<V>(new_name, this->title, this->value, this->nbins, this->low, this->high, this->label_x, this->label_y);
- }
- };
- template <typename V>
- class _ContainerTH2 : public Container<TH2,std::pair<V,V>>{
- private:
- void _fill(){
- if (this->container == nullptr){
- if (this->value == nullptr){
- CRITICAL("Container: \"" << this->get_name() << "\" has a null Value object. "
- << "Probably built with imcompatible type",-1);
- }
- this->container = new TH2D(this->get_name().c_str(), this->title.c_str(),
- this->nbins_x, this->low_x, this->high_x,
- this->nbins_y, this->low_y, this->high_y);
- this->container->SetXTitle(label_x.c_str());
- this->container->SetYTitle(label_y.c_str());
- }
- _do_fill(this->value->get_value());
- }
- protected:
- std::string title;
- std::string label_x;
- std::string label_y;
- int nbins_x;
- int nbins_y;
- double low_x;
- double low_y;
- double high_x;
- double high_y;
- virtual void _do_fill(std::pair<V,V>& val) = 0;
- public:
- explicit _ContainerTH2(const std::string& name, const std::string& title,
- Value<std::pair<V, V>>* value,
- int nbins_x, double low_x, double high_x,
- int nbins_y, double low_y, double high_y,
- const std::string& label_x = "",
- const std::string& label_y = "")
- :Container<TH2,std::pair<V,V>>(name, value),
- title(title),
- nbins_x(nbins_x), low_x(low_x), high_x(high_x),
- nbins_y(nbins_y), low_y(low_y), high_y(high_y),
- label_x(label_x), label_y(label_y) { }
- void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
- util::save_as(this->get_container(), fname, option);
- }
- };
- template <typename V>
- class ContainerTH2 : public _ContainerTH2<V>{
- using _ContainerTH2<V>::_ContainerTH2;
- void _do_fill(std::pair<V,V>& val){
- this->container->Fill(val.first,val.second);
- }
- public:
- GenContainer* clone_as(const std::string& new_name){
- return new ContainerTH2<V>(new_name, this->title, this->value, this->nbins_x, this->low_x, this->high_x,
- this->nbins_y, this->low_y, this->high_y, this->label_x, this->label_y);
- }
- };
- template <typename V>
- class ContainerTH2Many : public _ContainerTH2<std::vector<V>>{
- using _ContainerTH2<std::vector<V>>::_ContainerTH2;
- void _do_fill(std::pair<std::vector<V>,std::vector<V>>& val){
- int min_size = std::min(val.first.size(), val.second.size());
- for(int i=0; i<min_size; i++)
- this->container->Fill(val.first[i],val.second[i]);
- }
- public:
- GenContainer* clone_as(const std::string& new_name){
- return new ContainerTH2Many<V>(new_name, this->title, this->value, this->nbins_x, this->low_x, this->high_x,
- this->nbins_y, this->low_y, this->high_y, this->label_x, this->label_y);
- }
- };
- template <typename V>
- class ContainerTGraph : public Container<TGraph,std::pair<V,V>>{
- private:
- std::vector<V> x_data;
- std::vector<V> y_data;
- std::string title;
- bool data_modified;
- void _fill(){
- auto val = this->value->get_value();
- x_data.push_back(val.first);
- y_data.push_back(val.second);
- data_modified = true;
- }
- public:
- ContainerTGraph(const std::string& name, const std::string& title, Value<std::pair<V, V>>* value)
- :Container<TGraph,std::pair<V,V>>(name, value),
- data_modified(false){
- this->container = new TGraph();
- }
- TGraph* get_container(){
- if (data_modified){
- delete this->container;
- this->container = new TGraph(x_data.size(), x_data.data(), y_data.data());
- this->container->SetName(this->get_name().c_str());
- this->container->SetTitle(title.c_str());
- data_modified = false;
- }
- return this->container;
- }
- GenContainer* clone_as(const std::string& new_name){
- return new ContainerTGraph<V>(new_name, this->title, this->value);
- }
- void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
- util::save_as(get_container(), fname, option);
- }
- };
- template <typename V>
- class Vector : public Container<std::vector<V>,V>{
- private:
- void _fill(){
- this->container->push_back(this->value->get_value());
- }
- public:
- Vector(const std::string& name, Value<V>* value)
- :Container<std::vector<V>,V>(name, value){
- this->container = new std::vector<V>;
- }
- GenContainer* clone_as(const std::string& new_name){
- return new Vector<V>(new_name, this->value);
- }
- void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
- std::string type_name = "std::vector<"+fv::util::get_type_name(typeid(V))+">";
- util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
- }
- };
- template <typename V, typename D>
- class _Counter : public Container<std::map<D,int>,V>{
- public:
- explicit _Counter(const std::string& name, Value<V>* value)
- :Container<std::map<D,int>,V>(name, value) {
- this->container = new std::map<D,int>;
- }
- void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
- std::string type_name = "std::map<"+fv::util::get_type_name(typeid(D))+",int>";
- util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
- }
- };
- /**
- * A Counter that keeps a mapping of the number of occurances of each input
- * value.
- */
- template <typename V>
- class Counter : public _Counter<V,V>{
- using _Counter<V,V>::_Counter;
- void _fill(){
- (*this->container)[this->value->get_value()]++;
- }
- public:
- GenContainer* clone_as(const std::string& new_name){
- return new Counter<V>(new_name, this->value);
- }
- };
- /**
- * Same as Counter but accepts multiple values per fill.
- */
- template <typename V>
- class CounterMany : public _Counter<std::vector<V>,V>{
- using _Counter<std::vector<V>,V>::_Counter;
- void _fill(){
- for(V& val : this->value->get_value())
- (*this->container)[val]++;
- }
- public:
- GenContainer* clone_as(const std::string& new_name){
- return new CounterMany<V>(new_name, this->value);
- }
- };
- class PassCount : public Container<int,bool>{
- private:
- void _fill(){
- if(this->value->get_value()){
- (*this->container)++;
- }
- }
- public:
- PassCount(const std::string& name, Value<bool>* value)
- :Container<int,bool>(name, value){
- this->container = new int(0);
- }
- GenContainer* clone_as(const std::string& new_name){
- return new PassCount(new_name, this->value);
- }
- void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
- //ROOT(hilariously) cannot serialize basic data types, we wrap this
- //in a vector.
- std::vector<int> v({*this->get_container()});
- util::save_as_stl(&v, "std::vector<int>", this->get_name(), option);
- }
- };
- template <typename... ArgTypes>
- class MVA : public Container<TMVA::DataLoader,typename MVAData<ArgTypes...>::type>{
- private:
- std::vector<std::pair<std::string,std::string>> methods;
- std::string cut;
- std::string opt;
- void _fill(){
- std::tuple<ArgTypes...> t;
- typename MVAData<ArgTypes...>::type& event = this->value->get_value();
- bool is_training, is_signal;
- double weight;
- std::tie(is_training, is_signal, weight, t) = event;
- std::vector<double> v = t2v<double>(t);
- if (is_signal){
- if (is_training){
- this->container->AddSignalTrainingEvent(v, weight);
- } else {
- this->container->AddSignalTestEvent(v, weight);
- }
- } else {
- if (is_training){
- this->container->AddBackgroundTrainingEvent(v, weight);
- } else {
- this->container->AddBackgroundTestEvent(v, weight);
- }
- }
- }
- public:
- MVA(const std::string& name, MVAData<ArgTypes...>* value, const std::string& cut = "", const std::string& opt = "")
- :Container<TMVA::DataLoader,typename MVAData<ArgTypes...>::type>(name, value),
- cut(cut), opt(opt) {
- this->container = new TMVA::DataLoader(name);
- for (std::pair<std::string,char>& p : value->get_label_types()){
- this->container->AddVariable(p.first, p.second);
- }
- }
- void add_method(const std::string& method_name, const std::string& method_params) {
- methods.push_back(std::make_pair(method_name, method_params));
- }
- GenContainer* clone_as(const std::string& new_name){
- auto mva = new MVA<ArgTypes...>(new_name, (MVAData<ArgTypes...>*)this->value, this->cut, this->opt);
- mva->methods = methods;
- return mva;
- }
- void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
- TFile* outputFile = gDirectory->GetFile();
- this->container->PrepareTrainingAndTestTree(cut.c_str(), opt.c_str());
- TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
- "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
- TMVA::Types& types = TMVA::Types::Instance();
- for(auto& p : methods){
- std::string method_name, method_params;
- std::tie(method_name, method_params) = p;
- TMVA::Types::EMVA method_type = types.GetMethodType(method_name);
- factory->BookMethod(this->container, method_type, method_name, method_params);
- }
- // Train MVAs using the set of training events
- factory->TrainAllMethods();
- // Evaluate all MVAs using the set of test events
- factory->TestAllMethods();
- // Evaluate and compare performance of all configured MVAs
- factory->EvaluateAllMethods();
- delete factory;
- }
- };
- }
- #endif // root_container_hpp
|