container.hpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. #ifndef root_container_hpp
  2. #define root_container_hpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <typeinfo>
#include <utility>
#include <vector>
#include "TROOT.h"
#include "TFile.h"
#include "TCanvas.h"
#include "TGraph.h"
#include "TH1.h"
#include "TH2.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "filval/container.hpp"
  15. namespace fv::root::util{
  16. /**
  17. * Save a TObject. The TObject will typically be a Histogram or Graph object,
  18. * but can really be any TObject. The SaveOption can be used to specify how to
  19. * save the file.
  20. */
  21. void save_as(TObject* container, const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  22. auto save_img = [](TObject* container, const std::string& fname){
  23. TCanvas* c1 = new TCanvas("c1");
  24. container->Draw();
  25. c1->Draw();
  26. c1->SaveAs(fname.c_str());
  27. delete c1;
  28. };
  29. auto save_bin = [](TObject* container){
  30. INFO("Saving object: " << container->GetName() << " into file " << gDirectory->GetName());
  31. container->Write(container->GetName(), TObject::kOverwrite);
  32. };
  33. switch(option){
  34. case PNG:
  35. save_img(container, fname+".png"); break;
  36. case PDF:
  37. save_img(container, fname+".pdf"); break;
  38. case ROOT:
  39. save_bin(container); break;
  40. default:
  41. break;
  42. }
  43. }
  44. /**
  45. * Saves an STL container into a ROOT file. ROOT knows how to serialize STL
  46. * containers, but it needs the *name* of the type of the container, eg.
  47. * std::map<int,int> to be able to do this. In order to generate this name at
  48. * run-time, the fv::util::get_type_name function uses RTTI to get type info
  49. * and use it to look up the proper name.
  50. *
  51. * For nexted containers, it is necessary to generate the CLING dictionaries
  52. * for each type at compile time to enable serialization. To do this, add the
  53. * type definition into the LinkDef.hpp header file.
  54. */
  55. void save_as_stl(void* container, const std::string& type_name,
  56. const std::string& obj_name,
  57. const SaveOption& option = SaveOption::PNG) {
  58. switch(option){
  59. case PNG:
  60. INFO("Cannot save STL container " << type_name <<" as png");
  61. break;
  62. case PDF:
  63. INFO("Cannot save STL container " << type_name <<" as pdf");
  64. break;
  65. case ROOT:
  66. gDirectory->WriteObjectAny(container, type_name.c_str(), obj_name.c_str());
  67. break;
  68. default:
  69. break;
  70. }
  71. }
  72. }
  73. namespace fv::root {
  74. template <typename V>
  75. class _ContainerTH1 : public Container<TH1,V>{
  76. private:
  77. void _fill(){
  78. if (this->container == nullptr){
  79. if (this->value == nullptr){
  80. CRITICAL("Container: \"" << this->get_name() << "\" has a null Value object. "
  81. << "Probably built with imcompatible type",-1);
  82. }
  83. this->container = new TH1D(this->get_name().c_str(), this->title.c_str(),
  84. this->nbins, this->low, this->high);
  85. this->container->SetXTitle(label_x.c_str());
  86. this->container->SetYTitle(label_y.c_str());
  87. }
  88. _do_fill();
  89. }
  90. protected:
  91. std::string title;
  92. std::string label_x;
  93. std::string label_y;
  94. int nbins;
  95. double low;
  96. double high;
  97. virtual void _do_fill() = 0;
  98. public:
  99. explicit _ContainerTH1(const std::string &name, const std::string& title, Value<V>* value,
  100. int nbins, double low, double high,
  101. const std::string& label_x = "",
  102. const std::string& label_y = "")
  103. :Container<TH1,V>(name, value),
  104. title(title), nbins(nbins), low(low), high(high),
  105. label_x(label_x), label_y(label_y) { }
  106. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  107. util::save_as(this->get_container(), fname, option);
  108. }
  109. };
  110. template <typename V>
  111. class ContainerTH1 : public _ContainerTH1<V>{
  112. using _ContainerTH1<V>::_ContainerTH1;
  113. void _do_fill(){
  114. this->container->Fill(this->value->get_value());
  115. }
  116. };
  117. template <typename V>
  118. class ContainerTH1Many : public _ContainerTH1<std::vector<V>>{
  119. using _ContainerTH1<std::vector<V>>::_ContainerTH1;
  120. void _do_fill(){
  121. for(V x : this->value->get_value())
  122. this->container->Fill(x);
  123. }
  124. };
  125. template <typename V>
  126. class _ContainerTH2 : public Container<TH2,std::pair<V,V>>{
  127. private:
  128. void _fill(){
  129. if (this->container == nullptr){
  130. if (this->value == nullptr){
  131. CRITICAL("Container: \"" << this->get_name() << "\" has a null Value object. "
  132. << "Probably built with imcompatible type",-1);
  133. }
  134. this->container = new TH2D(this->get_name().c_str(), this->title.c_str(),
  135. this->nbins_x, this->low_x, this->high_x,
  136. this->nbins_y, this->low_y, this->high_y);
  137. this->container->SetXTitle(label_x.c_str());
  138. this->container->SetYTitle(label_y.c_str());
  139. }
  140. _do_fill(this->value->get_value());
  141. }
  142. protected:
  143. std::string title;
  144. std::string label_x;
  145. std::string label_y;
  146. int nbins_x;
  147. int nbins_y;
  148. double low_x;
  149. double low_y;
  150. double high_x;
  151. double high_y;
  152. virtual void _do_fill(std::pair<V,V>& val) = 0;
  153. public:
  154. explicit _ContainerTH2(const std::string& name, const std::string& title,
  155. Value<std::pair<V, V>>* value,
  156. int nbins_x, double low_x, double high_x,
  157. int nbins_y, double low_y, double high_y,
  158. const std::string& label_x = "",
  159. const std::string& label_y = "")
  160. :Container<TH2,std::pair<V,V>>(name, value),
  161. title(title),
  162. nbins_x(nbins_x), low_x(low_x), high_x(high_x),
  163. nbins_y(nbins_y), low_y(low_y), high_y(high_y),
  164. label_x(label_x), label_y(label_y) { }
  165. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  166. util::save_as(this->get_container(), fname, option);
  167. }
  168. };
  169. template <typename V>
  170. class ContainerTH2 : public _ContainerTH2<V>{
  171. using _ContainerTH2<V>::_ContainerTH2;
  172. void _do_fill(std::pair<V,V>& val){
  173. this->container->Fill(val.first,val.second);
  174. }
  175. };
  176. template <typename V>
  177. class ContainerTH2Many : public _ContainerTH2<std::vector<V>>{
  178. using _ContainerTH2<std::vector<V>>::_ContainerTH2;
  179. void _do_fill(std::pair<std::vector<V>,std::vector<V>>& val){
  180. int min_size = std::min(val.first.size(), val.second.size());
  181. for(int i=0; i<min_size; i++)
  182. this->container->Fill(val.first[i],val.second[i]);
  183. }
  184. };
  185. template <typename V>
  186. class ContainerTGraph : public Container<TGraph,std::pair<V,V>>{
  187. private:
  188. std::vector<V> x_data;
  189. std::vector<V> y_data;
  190. std::string title;
  191. bool data_modified;
  192. void _fill(){
  193. auto val = this->value->get_value();
  194. x_data.push_back(val.first);
  195. y_data.push_back(val.second);
  196. data_modified = true;
  197. }
  198. public:
  199. ContainerTGraph(const std::string& name, const std::string& title, Value<std::pair<V, V>>* value)
  200. :Container<TGraph,std::pair<V,V>>(name, value),
  201. data_modified(false){
  202. this->container = new TGraph();
  203. }
  204. TGraph* get_container(){
  205. if (data_modified){
  206. delete this->container;
  207. this->container = new TGraph(x_data.size(), x_data.data(), y_data.data());
  208. this->container->SetName(this->get_name().c_str());
  209. this->container->SetTitle(title.c_str());
  210. data_modified = false;
  211. }
  212. return this->container;
  213. }
  214. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  215. util::save_as(get_container(), fname, option);
  216. }
  217. };
  218. template <typename T>
  219. class Vector : public Container<std::vector<T>,T>{
  220. private:
  221. void _fill(){
  222. this->container->push_back(this->value->get_value());
  223. }
  224. public:
  225. Vector(const std::string& name, Value<T>* value)
  226. :Container<std::vector<T>,T>(name, value){
  227. this->container = new std::vector<T>;
  228. }
  229. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  230. std::string type_name = "std::vector<"+fv::util::get_type_name(typeid(T))+">";
  231. util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
  232. }
  233. };
  234. template <typename V, typename D>
  235. class _Counter : public Container<std::map<D,int>,V>{
  236. public:
  237. explicit _Counter(const std::string& name, Value<V>* value)
  238. :Container<std::map<D,int>,V>(name, value) {
  239. this->container = new std::map<D,int>;
  240. }
  241. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  242. std::string type_name = "std::map<"+fv::util::get_type_name(typeid(D))+",int>";
  243. util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
  244. }
  245. };
  246. /**
  247. * A Counter that keeps a mapping of the number of occurances of each input
  248. * value.
  249. */
  250. template <typename V>
  251. class Counter : public _Counter<V,V>{
  252. using _Counter<V,V>::_Counter;
  253. void _fill(){
  254. (*this->container)[this->value->get_value()]++;
  255. }
  256. };
  257. /**
  258. * Same as Counter but accepts multiple values per fill.
  259. */
  260. template <typename V>
  261. class CounterMany : public _Counter<std::vector<V>,V>{
  262. using _Counter<std::vector<V>,V>::_Counter;
  263. void _fill(){
  264. for(V& val : this->value->get_value())
  265. (*this->container)[val]++;
  266. }
  267. };
  268. template <typename... ArgTypes>
  269. class MVA : public Container<TMVA::DataLoader,MVAData<ArgTypes...>>{
  270. private:
  271. std::vector<std::pair<std::string,std::string>> methods;
  272. void _fill(){
  273. std::tuple<ArgTypes...> t;
  274. bool is_training;
  275. bool is_signal;
  276. double weight;
  277. std::tie(t, is_training, is_signal, weight) = this->value->get_value();
  278. std::vector<double> v = t2v<double>(t);
  279. if (is_signal){
  280. if (is_training){
  281. this->container->AddSignalTrainingEvent(v, weight);
  282. } else {
  283. this->container->AddSignalTestingEvent(v, weight);
  284. }
  285. } else {
  286. if (is_training){
  287. this->container->AddBackgroundTrainingEvent(v, weight);
  288. } else {
  289. this->container->AddBackgroundTestingEvent(v, weight);
  290. }
  291. }
  292. }
  293. public:
  294. MVA(const std::string& name, Value<std::tuple<ArgTypes...>>* value, const std::vector<std::string>& labels=std::vector<std::string>())
  295. :Container<TMVA::DataLoader,std::tuple<ArgTypes...>>(name, value){
  296. this->container = new TMVA::DataLoader(name);
  297. if (labels.size() != sizeof...(ArgTypes)){
  298. CRITICAL("Length of labels vector ("<<labels.size()<<") not equal to number of MVA arguments ("<<sizeof...(ArgTypes)<<")",-1);
  299. }
  300. for(const std::string& label : labels){
  301. this->container->AddVariable(label, 'F');
  302. }
  303. }
  304. void add_method(const std::string& method_name, const std::string& method_params) {
  305. methods.push_back(std::make_pair(method_name, method_params));
  306. }
  307. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  308. TFile* outputFile = gDirectory->GetFile();
  309. TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
  310. "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
  311. TMVA::Types& types = TMVA::Types::Instance();
  312. for(auto& p : methods){
  313. std::string method_name, method_params;
  314. std::tie(method_name, method_params) = p;
  315. TMVA::Types::EMVA method_type = types.GetMethodType(method_name);
  316. factory->BookMethod(this->container, method_type, method_name, method_params);
  317. }
  318. // Train MVAs using the set of training events
  319. factory->TrainAllMethods();
  320. // Evaluate all MVAs using the set of test events
  321. factory->TestAllMethods();
  322. // Evaluate and compare performance of all configured MVAs
  323. factory->EvaluateAllMethods();
  324. delete factory;
  325. }
  326. };
  327. }
  328. #endif // root_container_hpp