container.hpp 18 KB


  1. #ifndef root_container_hpp
  2. #define root_container_hpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <typeinfo>
#include <utility>
#include <vector>
#include "TROOT.h"
#include "TFile.h"
#include "TCanvas.h"
#include "TGraph.h"
#include "TH1.h"
#include "TH2.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/DataSetInfo.h"
#include "filval/container.hpp"
  16. namespace fv::root::util{
  17. /**
  18. * Save a TObject. The TObject will typically be a Histogram or Graph object,
  19. * but can really be any TObject. The SaveOption can be used to specify how to
  20. * save the file.
  21. */
  22. void save_as(TObject* container, const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  23. auto save_img = [](TObject* container, const std::string& fname){
  24. TCanvas* c1 = new TCanvas("c1");
  25. container->Draw();
  26. c1->Draw();
  27. c1->SaveAs(fname.c_str());
  28. delete c1;
  29. };
  30. auto save_bin = [](TObject* container){
  31. INFO("Saving object: " << container->GetName() << " into file " << gDirectory->GetName());
  32. container->Write(container->GetName(), TObject::kOverwrite);
  33. };
  34. switch(option){
  35. case PNG:
  36. save_img(container, fname+".png"); break;
  37. case PDF:
  38. save_img(container, fname+".pdf"); break;
  39. case ROOT:
  40. save_bin(container); break;
  41. default:
  42. break;
  43. }
  44. }
  45. /**
  46. * Saves an STL container into a ROOT file. ROOT knows how to serialize STL
  47. * containers, but it needs the *name* of the type of the container, eg.
  48. * std::map<int,int> to be able to do this. In order to generate this name at
  49. * run-time, the fv::util::get_type_name function uses RTTI to get type info
  50. * and use it to look up the proper name.
  51. *
  52. * For nexted containers, it is necessary to generate the CLING dictionaries
  53. * for each type at compile time to enable serialization. To do this, add the
  54. * type definition into the LinkDef.hpp header file.
  55. */
  56. void save_as_stl(void* container, const std::string& type_name,
  57. const std::string& obj_name,
  58. const SaveOption& option = SaveOption::PNG) {
  59. switch(option){
  60. case PNG:
  61. INFO("Cannot save STL container " << type_name <<" as png");
  62. break;
  63. case PDF:
  64. INFO("Cannot save STL container " << type_name <<" as pdf");
  65. break;
  66. case ROOT:
  67. /* DEBUG("Writing object \"" << obj_name << "\" of type \"" << type_name << "\"\n"); */
  68. gDirectory->WriteObjectAny(container, type_name.c_str(), obj_name.c_str());
  69. break;
  70. default:
  71. break;
  72. }
  73. }
  74. }
  75. namespace fv::root {
  76. struct TH1Params{
  77. std::string label_x;
  78. int nbins;
  79. double low;
  80. double high;
  81. std::string label_y;
  82. static TH1Params lookup(const std::string&& param_key){
  83. auto hist_params = fv::util::the_config->get("hist-params");
  84. if(!hist_params[param_key]){
  85. CRITICAL("Key \"" << param_key << "\" does not exist under hist-params in supplied config file. Add it!", true);
  86. }
  87. else{
  88. auto params = hist_params[param_key];
  89. return TH1Params({params["label_x"].as<std::string>(),
  90. params["nbins"].as<int>(),
  91. params["low"].as<double>(),
  92. params["high"].as<double>(),
  93. params["label_y"].as<std::string>()
  94. });
  95. }
  96. }
  97. };
  98. template <typename V>
  99. class _ContainerTH1 : public Container<TH1,V>{
  100. private:
  101. void _fill(){
  102. if (this->container == nullptr){
  103. if (this->value == nullptr){
  104. CRITICAL("Container: \"" << this->get_name() << "\" has a null Value object. "
  105. << "Probably built with imcompatible type",-1);
  106. }
  107. this->container = new TH1D(this->get_name().c_str(), this->title.c_str(),
  108. params.nbins, params.low, params.high);
  109. this->container->SetXTitle(params.label_x.c_str());
  110. this->container->SetYTitle(params.label_y.c_str());
  111. }
  112. _do_fill();
  113. }
  114. protected:
  115. std::string title;
  116. TH1Params params;
  117. virtual void _do_fill() = 0;
  118. public:
  119. explicit _ContainerTH1(const std::string& name, Value<V>* value,
  120. const std::string& title,
  121. const TH1Params& params)
  122. :Container<TH1,V>(name, value),
  123. title(title),
  124. params(params) { }
  125. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  126. util::save_as(this->get_container(), fname, option);
  127. }
  128. };
  129. template <typename V>
  130. class ContainerTH1 : public _ContainerTH1<V>{
  131. using _ContainerTH1<V>::_ContainerTH1;
  132. void _do_fill(){
  133. this->container->Fill(this->value->get_value());
  134. }
  135. public:
  136. GenContainer* clone_as(const std::string& new_name){
  137. return new ContainerTH1<V>(new_name, this->value, this->title, this->params);
  138. }
  139. };
  140. template <typename V>
  141. class ContainerTH1Many : public _ContainerTH1<std::vector<V>>{
  142. using _ContainerTH1<std::vector<V>>::_ContainerTH1;
  143. void _do_fill(){
  144. for(const V &x : this->value->get_value())
  145. this->container->Fill(x);
  146. }
  147. public:
  148. GenContainer* clone_as(const std::string& new_name){
  149. return new ContainerTH1Many<V>(new_name, this->value, this->title, this->params);
  150. }
  151. };
  152. struct TH2Params{
  153. std::string label_x;
  154. int nbins_x;
  155. double low_x;
  156. double high_x;
  157. std::string label_y;
  158. int nbins_y;
  159. double low_y;
  160. double high_y;
  161. static TH2Params lookup(const std::string&& param_key){
  162. auto hist_params = fv::util::the_config->get("hist-params");
  163. if(!hist_params[param_key]){
  164. CRITICAL("Key \"" << param_key << "\" does not exist under hist-params in supplied config file. Add it!", true);
  165. }
  166. else{
  167. auto params = hist_params[param_key];
  168. return TH2Params({params["label_x"].as<std::string>(),
  169. params["nbins_x"].as<int>(),
  170. params["low_x"].as<double>(),
  171. params["high_x"].as<double>(),
  172. params["label_y"].as<std::string>(),
  173. params["nbins_y"].as<int>(),
  174. params["low_y"].as<double>(),
  175. params["high_y"].as<double>()
  176. });
  177. }
  178. }
  179. };
  180. template <typename V>
  181. class _ContainerTH2 : public Container<TH2,std::pair<V,V>>{
  182. private:
  183. void _fill(){
  184. if (this->container == nullptr){
  185. if (this->value == nullptr){
  186. CRITICAL("Container: \"" << this->get_name() << "\" has a null Value object. "
  187. << "Probably built with imcompatible type",-1);
  188. }
  189. this->container = new TH2D(this->get_name().c_str(), this->title.c_str(),
  190. params.nbins_x, params.low_x, params.high_x,
  191. params.nbins_y, params.low_y, params.high_y);
  192. this->container->SetXTitle(params.label_x.c_str());
  193. this->container->SetYTitle(params.label_y.c_str());
  194. }
  195. _do_fill(this->value->get_value());
  196. }
  197. protected:
  198. std::string title;
  199. TH2Params params;
  200. virtual void _do_fill(const std::pair<V,V>& val) = 0;
  201. public:
  202. explicit _ContainerTH2(const std::string& name, Value<std::pair<V, V>>* value,
  203. const std::string& title,
  204. TH2Params params)
  205. :Container<TH2,std::pair<V,V>>(name, value),
  206. title(title),
  207. params(params) { }
  208. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  209. util::save_as(this->get_container(), fname, option);
  210. }
  211. };
  212. template <typename V>
  213. class ContainerTH2 : public _ContainerTH2<V>{
  214. using _ContainerTH2<V>::_ContainerTH2;
  215. void _do_fill(const std::pair<V,V>& val){
  216. this->container->Fill(val.first,val.second);
  217. }
  218. public:
  219. GenContainer* clone_as(const std::string& new_name){
  220. return new ContainerTH2<V>(new_name, this->value, this->title, this->params);
  221. }
  222. };
  223. template <typename V>
  224. class ContainerTH2Many : public _ContainerTH2<std::vector<V>>{
  225. using _ContainerTH2<std::vector<V>>::_ContainerTH2;
  226. void _do_fill(const std::pair<std::vector<V>,std::vector<V>>& val){
  227. int min_size = std::min(val.first.size(), val.second.size());
  228. for(int i=0; i<min_size; i++)
  229. this->container->Fill(val.first[i],val.second[i]);
  230. }
  231. public:
  232. GenContainer* clone_as(const std::string& new_name){
  233. return new ContainerTH2Many<V>(new_name, this->value, this->title, this->params);
  234. }
  235. };
  236. template <typename V>
  237. class ContainerTGraph : public Container<TGraph,std::pair<V,V>>{
  238. private:
  239. std::vector<V> x_data;
  240. std::vector<V> y_data;
  241. std::string title;
  242. bool data_modified;
  243. void _fill(){
  244. auto val = this->value->get_value();
  245. x_data.push_back(val.first);
  246. y_data.push_back(val.second);
  247. data_modified = true;
  248. }
  249. public:
  250. ContainerTGraph(const std::string& name, const std::string& title, Value<std::pair<V, V>>* value)
  251. :Container<TGraph,std::pair<V,V>>(name, value),
  252. data_modified(false){
  253. this->container = new TGraph();
  254. }
  255. TGraph* get_container(){
  256. if (data_modified){
  257. delete this->container;
  258. this->container = new TGraph(x_data.size(), x_data.data(), y_data.data());
  259. this->container->SetName(this->get_name().c_str());
  260. this->container->SetTitle(title.c_str());
  261. data_modified = false;
  262. }
  263. return this->container;
  264. }
  265. GenContainer* clone_as(const std::string& new_name){
  266. return new ContainerTGraph<V>(new_name, this->title, this->value);
  267. }
  268. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  269. util::save_as(get_container(), fname, option);
  270. }
  271. };
  272. template <typename V>
  273. class Vector : public Container<std::vector<V>,V>{
  274. private:
  275. void _fill(){
  276. this->container->push_back(this->value->get_value());
  277. }
  278. public:
  279. Vector(const std::string& name, Value<V>* value)
  280. :Container<std::vector<V>,V>(name, value){
  281. this->container = new std::vector<V>;
  282. }
  283. GenContainer* clone_as(const std::string& new_name){
  284. return new Vector<V>(new_name, this->value);
  285. }
  286. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  287. std::string type_name = "std::vector<"+fv::util::get_type_name(typeid(V))+">";
  288. util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
  289. }
  290. };
  291. template <typename V>
  292. class VectorMany : public Container<std::vector<V>,std::vector<V>>{
  293. private:
  294. void _fill(){
  295. for(const V& val: this->value->get_value())
  296. this->container->push_back(val);
  297. }
  298. public:
  299. VectorMany(const std::string& name, Value<std::vector<V>>* value)
  300. :Container<std::vector<V>,std::vector<V>>(name, value){
  301. this->container = new std::vector<V>;
  302. }
  303. GenContainer* clone_as(const std::string& new_name){
  304. return new VectorMany<V>(new_name, this->value);
  305. }
  306. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  307. std::string type_name = "std::vector<"+fv::util::get_type_name(typeid(V))+">";
  308. util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
  309. }
  310. };
  311. template <typename V, typename D>
  312. class _Counter : public Container<std::map<D,int>,V>{
  313. public:
  314. explicit _Counter(const std::string& name, Value<V>* value)
  315. :Container<std::map<D,int>,V>(name, value) {
  316. this->container = new std::map<D,int>;
  317. }
  318. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  319. std::string type_name = "std::map<"+fv::util::get_type_name(typeid(D))+",int>";
  320. util::save_as_stl(this->get_container(), type_name, this->get_name(), option);
  321. }
  322. };
  323. /**
  324. * A Counter that keeps a mapping of the number of occurances of each input
  325. * value.
  326. */
  327. template <typename V>
  328. class Counter : public _Counter<V,V>{
  329. using _Counter<V,V>::_Counter;
  330. void _fill(){
  331. (*this->container)[this->value->get_value()]++;
  332. }
  333. public:
  334. GenContainer* clone_as(const std::string& new_name){
  335. return new Counter<V>(new_name, this->value);
  336. }
  337. };
  338. /**
  339. * Same as Counter but accepts multiple values per fill.
  340. */
  341. template <typename V>
  342. class CounterMany : public _Counter<std::vector<V>,V>{
  343. using _Counter<std::vector<V>,V>::_Counter;
  344. void _fill(){
  345. for(V& val : this->value->get_value())
  346. (*this->container)[val]++;
  347. }
  348. public:
  349. GenContainer* clone_as(const std::string& new_name){
  350. return new CounterMany<V>(new_name, this->value);
  351. }
  352. };
  353. class PassCount : public Container<int,bool>{
  354. private:
  355. void _fill(){
  356. if(this->value->get_value()){
  357. (*this->container)++;
  358. }
  359. }
  360. public:
  361. PassCount(const std::string& name, Value<bool>* value)
  362. :Container<int,bool>(name, value){
  363. this->container = new int(0);
  364. }
  365. GenContainer* clone_as(const std::string& new_name){
  366. return new PassCount(new_name, this->value);
  367. }
  368. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  369. //ROOT(hilariously) cannot serialize basic data types, we wrap this
  370. //in a vector.
  371. std::vector<int> v({*this->get_container()});
  372. util::save_as_stl(&v, "std::vector<int>", this->get_name(), option);
  373. }
  374. };
  375. template <typename... ArgTypes>
  376. class MVA : public Container<TMVA::DataLoader,typename MVAData<ArgTypes...>::type>{
  377. private:
  378. std::vector<std::pair<std::string,std::string>> methods;
  379. std::string cut;
  380. std::string opt;
  381. void _fill(){
  382. std::tuple<ArgTypes...> t;
  383. typename MVAData<ArgTypes...>::type& event = this->value->get_value();
  384. bool is_training, is_signal;
  385. double weight;
  386. std::tie(is_training, is_signal, weight, t) = event;
  387. std::vector<double> v = t2v<double>(t);
  388. if (is_signal){
  389. if (is_training){
  390. this->container->AddSignalTrainingEvent(v, weight);
  391. } else {
  392. this->container->AddSignalTestEvent(v, weight);
  393. }
  394. } else {
  395. if (is_training){
  396. this->container->AddBackgroundTrainingEvent(v, weight);
  397. } else {
  398. this->container->AddBackgroundTestEvent(v, weight);
  399. }
  400. }
  401. }
  402. public:
  403. MVA(const std::string& name, MVAData<ArgTypes...>* value, const std::string& cut = "", const std::string& opt = "")
  404. :Container<TMVA::DataLoader,typename MVAData<ArgTypes...>::type>(name, value),
  405. cut(cut), opt(opt) {
  406. this->container = new TMVA::DataLoader(name);
  407. for (std::pair<std::string,char>& p : value->get_label_types()){
  408. this->container->AddVariable(p.first, p.second);
  409. }
  410. }
  411. void add_method(const std::string& method_name, const std::string& method_params) {
  412. methods.push_back(std::make_pair(method_name, method_params));
  413. }
  414. GenContainer* clone_as(const std::string& new_name){
  415. auto mva = new MVA<ArgTypes...>(new_name, (MVAData<ArgTypes...>*)this->value, this->cut, this->opt);
  416. mva->methods = methods;
  417. return mva;
  418. }
  419. void save_as(const std::string& fname, const SaveOption& option = SaveOption::PNG) {
  420. TFile* outputFile = gDirectory->GetFile();
  421. this->container->PrepareTrainingAndTestTree(cut.c_str(), opt.c_str());
  422. TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
  423. "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
  424. TMVA::Types& types = TMVA::Types::Instance();
  425. for(auto& p : methods){
  426. std::string method_name, method_params;
  427. std::tie(method_name, method_params) = p;
  428. TMVA::Types::EMVA method_type = types.GetMethodType(method_name);
  429. factory->BookMethod(this->container, method_type, method_name, method_params);
  430. }
  431. // Train MVAs using the set of training events
  432. factory->TrainAllMethods();
  433. // Evaluate all MVAs using the set of test events
  434. factory->TestAllMethods();
  435. // Evaluate and compare performance of all configured MVAs
  436. factory->EvaluateAllMethods();
  437. delete factory;
  438. }
  439. };
  440. }
  441. #endif // root_container_hpp