container.hpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. #ifndef root_container_hpp
  2. #define root_container_hpp
  3. #include <iostream>
  4. #include <utility>
  5. #include <map>
  6. #include "TROOT.h"
  7. #include "TFile.h"
  8. #include "TCanvas.h"
  9. #include "TGraph.h"
  10. #include "TH1.h"
  11. #include "TH2.h"
  12. #include "TMVA/Factory.h"
  13. #include "TMVA/DataLoader.h"
  14. #include "TMVA/DataSetInfo.h"
  15. #include "filval/container.hpp"
  16. namespace fv::root::util{
  17. /**
  18. * Save a TObject. The TObject will typically be a Histogram or Graph object,
  19. * but can really be any TObject. The SaveOption can be used to specify how to
  20. * save the file.
  21. */
  22. void save_as(TObject* container, const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  23. auto save_img = [](TObject* container, const std::string& fname){
  24. TCanvas* c1 = new TCanvas("c1");
  25. container->Draw();
  26. c1->Draw();
  27. c1->SaveAs(fname.c_str());
  28. delete c1;
  29. };
  30. auto save_bin = [](TObject* container){
  31. INFO("Saving object: " << container->GetName() << " into file " << gDirectory->GetName());
  32. container->Write(container->GetName(), TObject::kOverwrite);
  33. };
  34. switch(option){
  35. case PNG:
  36. save_img(container, fname+".png"); break;
  37. case PDF:
  38. save_img(container, fname+".pdf"); break;
  39. case ROOT:
  40. save_bin(container); break;
  41. default:
  42. break;
  43. }
  44. }
  45. /**
  46. * Saves an STL container into a ROOT file. ROOT knows how to serialize STL
  47. * containers, but it needs the *name* of the type of the container, eg.
  48. * std::map<int,int> to be able to do this. In order to generate this name at
  49. * run-time, the fv::util::get_type_name function uses RTTI to get type info
  50. * and use it to look up the proper name.
  51. *
  52. * For nexted containers, it is necessary to generate the CLING dictionaries
  53. * for each type at compile time to enable serialization. To do this, add the
  54. * type definition into the LinkDef.hpp header file.
  55. */
  56. void save_as_stl(void* container, const std::string& type_name,
  57. const std::string& obj_name,
  58. const SaveOption& option = SaveOption::PNG) {
  59. switch(option){
  60. case PNG:
  61. INFO("Cannot save STL container " << type_name <<" as png");
  62. break;
  63. case PDF:
  64. INFO("Cannot save STL container " << type_name <<" as pdf");
  65. break;
  66. case ROOT:
  67. /* DEBUG("Writing object \"" << obj_name << "\" of type \"" << type_name << "\"\n"); */
  68. gDirectory->WriteObjectAny(container, type_name.c_str(), obj_name.c_str());
  69. break;
  70. default:
  71. break;
  72. }
  73. }
  74. }
  75. namespace fv::root {
  /**
   * Axis labels and binning for a 1-D histogram container (_ContainerTH1).
   * Still an aggregate: TH1Params{label_x, nbins, low, high, label_y}.
   * Numeric members default to zero so a partially or default-initialized
   * instance never carries indeterminate values.
   */
  struct TH1Params{
    std::string label_x;  // x-axis title
    int nbins{0};         // number of bins
    double low{0.0};      // lower edge of the histogram range
    double high{0.0};     // upper edge of the histogram range
    std::string label_y;  // y-axis title
  };
  83. template <typename V>
  84. class _ContainerTH1 : public Container<TH1,V>{
  85. private:
  86. void _fill(){
  87. if (this->container == nullptr){
  88. if (this->value == nullptr){
  89. CRITICAL("Container: \"" << this->get_name() << "\" has a null Value object. "
  90. << "Probably built with imcompatible type",-1);
  91. }
  92. this->container = new TH1D(this->get_name().c_str(), this->title.c_str(),
  93. params.nbins, params.low, params.high);
  94. this->container->SetXTitle(params.label_x.c_str());
  95. this->container->SetYTitle(params.label_y.c_str());
  96. }
  97. _do_fill();
  98. }
  99. protected:
  100. std::string title;
  101. TH1Params params;
  102. virtual void _do_fill() = 0;
  103. public:
  104. explicit _ContainerTH1(const std::string& name, Value<V>* value,
  105. const std::string& title,
  106. const TH1Params& params)
  107. :Container<TH1,V>(name, value),
  108. title(title),
  109. params(params) { }
  110. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  111. util::save_as(this->get_container(), fname, option);
  112. }
  113. };
  114. template <typename V>
  115. class ContainerTH1 : public _ContainerTH1<V>{
  116. using _ContainerTH1<V>::_ContainerTH1;
  117. void _do_fill(){
  118. this->container->Fill(this->value->get_value());
  119. }
  120. public:
  121. GenContainer* clone_as(const std::string& new_name){
  122. return new ContainerTH1<V>(new_name, this->value, this->title, this->params);
  123. }
  124. };
  125. template <typename V>
  126. class ContainerTH1Many : public _ContainerTH1<std::vector<V>>{
  127. using _ContainerTH1<std::vector<V>>::_ContainerTH1;
  128. void _do_fill(){
  129. for(const V &x : this->value->get_value())
  130. this->container->Fill(x);
  131. }
  132. public:
  133. GenContainer* clone_as(const std::string& new_name){
  134. return new ContainerTH1Many<V>(new_name, this->value, this->title, this->params);
  135. }
  136. };
  /**
   * Axis labels and binning for a 2-D histogram container (_ContainerTH2).
   * Still an aggregate:
   *   TH2Params{label_x, nbins_x, low_x, high_x, label_y, nbins_y, low_y, high_y}.
   * Numeric members default to zero so a partially or default-initialized
   * instance never carries indeterminate values.
   */
  struct TH2Params{
    std::string label_x;  // x-axis title
    int nbins_x{0};       // number of x bins
    double low_x{0.0};    // lower edge of the x range
    double high_x{0.0};   // upper edge of the x range
    std::string label_y;  // y-axis title
    int nbins_y{0};       // number of y bins
    double low_y{0.0};    // lower edge of the y range
    double high_y{0.0};   // upper edge of the y range
  };
  147. template <typename V>
  148. class _ContainerTH2 : public Container<TH2,std::pair<V,V>>{
  149. private:
  150. void _fill(){
  151. if (this->container == nullptr){
  152. if (this->value == nullptr){
  153. CRITICAL("Container: \"" << this->get_name() << "\" has a null Value object. "
  154. << "Probably built with imcompatible type",-1);
  155. }
  156. this->container = new TH2D(this->get_name().c_str(), this->title.c_str(),
  157. params.nbins_x, params.low_x, params.high_x,
  158. params.nbins_y, params.low_y, params.high_y);
  159. this->container->SetXTitle(params.label_x.c_str());
  160. this->container->SetYTitle(params.label_y.c_str());
  161. }
  162. _do_fill(this->value->get_value());
  163. }
  164. protected:
  165. std::string title;
  166. TH2Params params;
  167. virtual void _do_fill(const std::pair<V,V>& val) = 0;
  168. public:
  169. explicit _ContainerTH2(const std::string& name, Value<std::pair<V, V>>* value,
  170. const std::string& title,
  171. TH2Params params)
  172. :Container<TH2,std::pair<V,V>>(name, value),
  173. title(title),
  174. params(params) { }
  175. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  176. util::save_as(this->get_container(), fname, option);
  177. }
  178. };
  179. template <typename V>
  180. class ContainerTH2 : public _ContainerTH2<V>{
  181. using _ContainerTH2<V>::_ContainerTH2;
  182. void _do_fill(const std::pair<V,V>& val){
  183. this->container->Fill(val.first,val.second);
  184. }
  185. public:
  186. GenContainer* clone_as(const std::string& new_name){
  187. return new ContainerTH2<V>(new_name, this->value, this->title, this->params);
  188. }
  189. };
  190. template <typename V>
  191. class ContainerTH2Many : public _ContainerTH2<std::vector<V>>{
  192. using _ContainerTH2<std::vector<V>>::_ContainerTH2;
  193. void _do_fill(const std::pair<std::vector<V>,std::vector<V>>& val){
  194. int min_size = std::min(val.first.size(), val.second.size());
  195. for(int i=0; i<min_size; i++)
  196. this->container->Fill(val.first[i],val.second[i]);
  197. }
  198. public:
  199. GenContainer* clone_as(const std::string& new_name){
  200. return new ContainerTH2Many<V>(new_name, this->value, this->title, this->params);
  201. }
  202. };
  203. template <typename V>
  204. class ContainerTGraph : public Container<TGraph,std::pair<V,V>>{
  205. private:
  206. std::vector<V> x_data;
  207. std::vector<V> y_data;
  208. std::string title;
  209. bool data_modified;
  210. void _fill(){
  211. auto val = this->value->get_value();
  212. x_data.push_back(val.first);
  213. y_data.push_back(val.second);
  214. data_modified = true;
  215. }
  216. public:
  217. ContainerTGraph(const std::string& name, const std::string& title, Value<std::pair<V, V>>* value)
  218. :Container<TGraph,std::pair<V,V>>(name, value),
  219. data_modified(false){
  220. this->container = new TGraph();
  221. }
  222. TGraph* get_container(){
  223. if (data_modified){
  224. delete this->container;
  225. this->container = new TGraph(x_data.size(), x_data.data(), y_data.data());
  226. this->container->SetName(this->get_name().c_str());
  227. this->container->SetTitle(title.c_str());
  228. data_modified = false;
  229. }
  230. return this->container;
  231. }
  232. GenContainer* clone_as(const std::string& new_name){
  233. return new ContainerTGraph<V>(new_name, this->title, this->value);
  234. }
  235. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  236. util::save_as(get_container(), fname, option);
  237. }
  238. };
  239. template <typename V>
  240. class Vector : public Container<std::vector<V>,V>{
  241. private:
  242. void _fill(){
  243. this->container->push_back(this->value->get_value());
  244. }
  245. public:
  246. Vector(const std::string& name, Value<V>* value)
  247. :Container<std::vector<V>,V>(name, value){
  248. this->container = new std::vector<V>;
  249. }
  250. GenContainer* clone_as(const std::string& new_name){
  251. return new Vector<V>(new_name, this->value);
  252. }
  253. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  254. std::string type_name = "std::vector<"+fv::util::get_type_name(typeid(V))+">";
  255. util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
  256. }
  257. };
  258. template <typename V>
  259. class VectorMany : public Container<std::vector<V>,std::vector<V>>{
  260. private:
  261. void _fill(){
  262. for(const V& val: this->value->get_value())
  263. this->container->push_back(val);
  264. }
  265. public:
  266. VectorMany(const std::string& name, Value<std::vector<V>>* value)
  267. :Container<std::vector<V>,std::vector<V>>(name, value){
  268. this->container = new std::vector<V>;
  269. }
  270. GenContainer* clone_as(const std::string& new_name){
  271. return new VectorMany<V>(new_name, this->value);
  272. }
  273. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  274. std::string type_name = "std::vector<"+fv::util::get_type_name(typeid(V))+">";
  275. util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
  276. }
  277. };
  278. template <typename V, typename D>
  279. class _Counter : public Container<std::map<D,int>,V>{
  280. public:
  281. explicit _Counter(const std::string& name, Value<V>* value)
  282. :Container<std::map<D,int>,V>(name, value) {
  283. this->container = new std::map<D,int>;
  284. }
  285. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  286. std::string type_name = "std::map<"+fv::util::get_type_name(typeid(D))+",int>";
  287. util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
  288. }
  289. };
  290. /**
  291. * A Counter that keeps a mapping of the number of occurances of each input
  292. * value.
  293. */
  294. template <typename V>
  295. class Counter : public _Counter<V,V>{
  296. using _Counter<V,V>::_Counter;
  297. void _fill(){
  298. (*this->container)[this->value->get_value()]++;
  299. }
  300. public:
  301. GenContainer* clone_as(const std::string& new_name){
  302. return new Counter<V>(new_name, this->value);
  303. }
  304. };
  305. /**
  306. * Same as Counter but accepts multiple values per fill.
  307. */
  308. template <typename V>
  309. class CounterMany : public _Counter<std::vector<V>,V>{
  310. using _Counter<std::vector<V>,V>::_Counter;
  311. void _fill(){
  312. for(V& val : this->value->get_value())
  313. (*this->container)[val]++;
  314. }
  315. public:
  316. GenContainer* clone_as(const std::string& new_name){
  317. return new CounterMany<V>(new_name, this->value);
  318. }
  319. };
  320. class PassCount : public Container<int,bool>{
  321. private:
  322. void _fill(){
  323. if(this->value->get_value()){
  324. (*this->container)++;
  325. }
  326. }
  327. public:
  328. PassCount(const std::string& name, Value<bool>* value)
  329. :Container<int,bool>(name, value){
  330. this->container = new int(0);
  331. }
  332. GenContainer* clone_as(const std::string& new_name){
  333. return new PassCount(new_name, this->value);
  334. }
  335. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  336. //ROOT(hilariously) cannot serialize basic data types, we wrap this
  337. //in a vector.
  338. std::vector<int> v({*this->get_container()});
  339. util::save_as_stl(&v, "std::vector<int>", this->get_name(), option);
  340. }
  341. };
  342. template <typename... ArgTypes>
  343. class MVA : public Container<TMVA::DataLoader,typename MVAData<ArgTypes...>::type>{
  344. private:
  345. std::vector<std::pair<std::string,std::string>> methods;
  346. std::string cut;
  347. std::string opt;
  348. void _fill(){
  349. std::tuple<ArgTypes...> t;
  350. typename MVAData<ArgTypes...>::type& event = this->value->get_value();
  351. bool is_training, is_signal;
  352. double weight;
  353. std::tie(is_training, is_signal, weight, t) = event;
  354. std::vector<double> v = t2v<double>(t);
  355. if (is_signal){
  356. if (is_training){
  357. this->container->AddSignalTrainingEvent(v, weight);
  358. } else {
  359. this->container->AddSignalTestEvent(v, weight);
  360. }
  361. } else {
  362. if (is_training){
  363. this->container->AddBackgroundTrainingEvent(v, weight);
  364. } else {
  365. this->container->AddBackgroundTestEvent(v, weight);
  366. }
  367. }
  368. }
  369. public:
  370. MVA(const std::string& name, MVAData<ArgTypes...>* value, const std::string& cut = "", const std::string& opt = "")
  371. :Container<TMVA::DataLoader,typename MVAData<ArgTypes...>::type>(name, value),
  372. cut(cut), opt(opt) {
  373. this->container = new TMVA::DataLoader(name);
  374. for (std::pair<std::string,char>& p : value->get_label_types()){
  375. this->container->AddVariable(p.first, p.second);
  376. }
  377. }
  378. void add_method(const std::string& method_name, const std::string& method_params) {
  379. methods.push_back(std::make_pair(method_name, method_params));
  380. }
  381. GenContainer* clone_as(const std::string& new_name){
  382. auto mva = new MVA<ArgTypes...>(new_name, (MVAData<ArgTypes...>*)this->value, this->cut, this->opt);
  383. mva->methods = methods;
  384. return mva;
  385. }
  386. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  387. TFile* outputFile = gDirectory->GetFile();
  388. this->container->PrepareTrainingAndTestTree(cut.c_str(), opt.c_str());
  389. TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
  390. "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
  391. TMVA::Types& types = TMVA::Types::Instance();
  392. for(auto& p : methods){
  393. std::string method_name, method_params;
  394. std::tie(method_name, method_params) = p;
  395. TMVA::Types::EMVA method_type = types.GetMethodType(method_name);
  396. factory->BookMethod(this->container, method_type, method_name, method_params);
  397. }
  398. // Train MVAs using the set of training events
  399. factory->TrainAllMethods();
  400. // Evaluate all MVAs using the set of test events
  401. factory->TestAllMethods();
  402. // Evaluate and compare performance of all configured MVAs
  403. factory->EvaluateAllMethods();
  404. delete factory;
  405. }
  406. };
  407. }
  408. #endif // root_container_hpp