generate_class.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. #!/usr/bin/env python3
  2. def generate_collection_class(obj_name, obj_attrs):
  3. src = []
  4. src += f'''\
  5. struct {obj_name};
  6. class {obj_name}Collection {{
  7. public:
  8. class iter {{
  9. public:
  10. iter({obj_name}Collection* collection, size_t idx)
  11. :collection(collection), idx(idx) {{ }}
  12. iter operator++() {{ ++idx; return *this; }}
  13. bool operator!=(const iter & other) {{ return idx != other.idx; }}
  14. {obj_name} operator*();
  15. private:
  16. {obj_name}Collection* collection;
  17. size_t idx;
  18. }};
  19. TrackingDataSet* tds;
  20. '''.splitlines()
  21. for field in obj_attrs['fields']:
  22. name = field['name']
  23. type_ = field['type']
  24. src.append(f' Value<vector<{type_}>>* val_{name};')
  25. src.append(f' bool {name}_loaded;')
  26. src.append(f'\n {obj_name}Collection() {{ }}\n')
  27. src.append(' void init(TrackingDataSet* tds){')
  28. src.append(' this->tds = tds;')
  29. src.append(' }\n')
  30. first_obj_name = list(obj_attrs['fields'])[0]['name']
  31. first_obj_type = list(obj_attrs['fields'])[0]['type']
  32. prefix = obj_attrs['treename_prefix']
  33. src.append(f'''\
  34. size_t size() {{
  35. if (!this->{first_obj_name}_loaded) {{
  36. this->val_{first_obj_name} = this->tds->track_branch_obj<vector<{first_obj_type}>>("{prefix}_{first_obj_name}");
  37. this->{first_obj_name}_loaded = true;
  38. }}
  39. return (*this->val_{first_obj_name})().size();
  40. }}
  41. \n''')
  42. src.append(f' {obj_name} operator[](size_t);')
  43. src.append(' iter begin() { return iter(this, 0); }')
  44. src.append(' iter end() { return iter(this, size()); }')
  45. src.append('};')
  46. src += f'''
  47. struct {obj_name} {{
  48. {obj_name}Collection* collection;
  49. size_t idx;
  50. {obj_name}({obj_name}Collection* collection, const size_t idx)
  51. :collection(collection), idx(idx) {{ }}\n
  52. '''.splitlines()
  53. for field in obj_attrs['fields']:
  54. name = field['name']
  55. type_ = field['type']
  56. # Because vector<bool> is packed, a temporary object is created so we can't return a reference.
  57. ret_type = f'{type_}&' if type_ != 'bool' else type_
  58. src.append(f'''\
  59. const {ret_type} {name}() const {{
  60. if (!collection->{name}_loaded) {{
  61. collection->val_{name} = collection->tds->track_branch_obj<vector<{type_}>>("{prefix}_{name}");
  62. collection->{name}_loaded = true;
  63. }}
  64. return (*collection->val_{name})().at(idx);
  65. }}
  66. ''')
  67. src.append('};')
  68. src.append(f'''
  69. bool operator==(const {obj_name}& obj1, const {obj_name}& obj2) {{
  70. return obj1.idx == obj2.idx;
  71. }}
  72. {obj_name} {obj_name}Collection::iter::operator*() {{
  73. return {{collection, idx}};
  74. }}
  75. {obj_name} {obj_name}Collection::operator[](size_t idx) {{
  76. return {{this, idx}};
  77. }}
  78. ''')
  79. return '\n'.join(src)
  80. def generate_header(input_filename, output_filename):
  81. from datetime import datetime
  82. return f'''\
  83. /** {output_filename} created on {datetime.now()} by generate_class.py
  84. * AVOID EDITING THIS FILE BY HAND!! Instead edit {input_filename} and re-run
  85. * generate_class.py
  86. */
  87. #include "filval.hpp"
  88. #include "root_filval.hpp"
  89. #include<cmath>
  90. #include "TrackingNtuple.h"
  91. using namespace std;
  92. using namespace fv;
  93. using namespace fv_root;
  94. typedef TreeDataSet<TrackingNtuple> TrackingDataSet;
  95. '''
  96. if __name__ == '__main__':
  97. import argparse
  98. import yaml
  99. parser = argparse.ArgumentParser()
  100. add = parser.add_argument
  101. add('input_file', help='An input YAML file defining the objects to generate')
  102. args = parser.parse_args()
  103. classes = []
  104. with open(args.input_file) as fi:
  105. for obj, attrs in yaml.load(fi).items():
  106. classes.append(generate_collection_class(obj, attrs))
  107. output_filename = args.input_file.replace('.yaml', '.hpp')
  108. with open(output_filename, 'w') as fo:
  109. fo.write(generate_header(args.input_file, output_filename))
  110. for class_ in classes:
  111. fo.write(class_)