generate_class.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. #!/usr/bin/env python3
  2. def generate_collection_class(obj_name, obj_attrs):
  3. src = []
  4. src += f'''\
  5. struct {obj_name};
  6. class {obj_name}Collection {{
  7. public:
  8. class iter {{
  9. public:
  10. iter(const {obj_name}Collection* collection, size_t idx)
  11. :collection(collection), idx(idx) {{ }}
  12. iter operator++() {{ ++idx; return *this; }}
  13. bool operator!=(const iter & other) {{ return idx != other.idx; }}
  14. const {obj_name} operator*() const;
  15. private:
  16. const {obj_name}Collection* collection;
  17. size_t idx;
  18. }};
  19. '''.splitlines()
  20. for field in obj_attrs['fields']:
  21. name = field['name']
  22. type_ = field['type']
  23. src.append(f' Value<vector<{type_}>>* val_{name};')
  24. src.append(f'\n {obj_name}Collection() {{ }}\n')
  25. src.append(' void init(TrackingDataSet& tds){')
  26. for field in obj_attrs['fields']:
  27. name = field['name']
  28. type_ = field['type']
  29. prefix = obj_attrs['treename_prefix']+'_'
  30. src.append(f' val_{name} = tds.track_branch_obj<vector<{type_}>>("{prefix}{name}");')
  31. src.append(' }\n')
  32. first_obj_name = list(obj_attrs['fields'])[0]['name']
  33. src.append(f' size_t size() const {{ return val_{first_obj_name}->get_value().size();}}\n')
  34. src.append(f' const {obj_name} operator[](size_t) const;')
  35. src.append(' iter begin() const { return iter(this, 0); }')
  36. src.append(' iter end() const { return iter(this, size()); }')
  37. src.append('};')
  38. src += f'''
  39. struct {obj_name} {{
  40. const {obj_name}Collection* collection;
  41. const size_t idx;
  42. {obj_name}(const {obj_name}Collection* collection, const size_t idx)
  43. :collection(collection), idx(idx) {{ }}\n
  44. '''.splitlines()
  45. for field in obj_attrs['fields']:
  46. name = field['name']
  47. type_ = field['type']
  48. src.append(f' const {type_}& {name}() const {{return collection->val_{name}->get_value().at(idx);}}')
  49. src.append('};')
  50. src.append(f'''
  51. const {obj_name} {obj_name}Collection::iter::operator*() const {{
  52. return {{collection, idx}};
  53. }}
  54. const {obj_name} {obj_name}Collection::operator[](size_t idx) const {{
  55. return {{this, idx}};
  56. }}
  57. ''')
  58. return '\n'.join(src)
  59. def generate_header(input_filename, output_filename):
  60. from datetime import datetime
  61. return f'''\
  62. /** {output_filename} created on {datetime.now()} by generate_class.py
  63. * AVOID EDITING THIS FILE BY HAND!! Instead edit {input_filename} and re-run
  64. * generate_class.py
  65. */
  66. #include "filval/filval.hpp"
  67. #include "filval/root/filval.hpp"
  68. #include<cmath>
  69. #include "TrackingNtuple.h"
  70. using namespace std;
  71. using namespace fv;
  72. using namespace fv::root;
  73. typedef TreeDataSet<TrackingNtuple> TrackingDataSet;
  74. '''
  75. if __name__ == '__main__':
  76. import argparse
  77. import yaml
  78. parser = argparse.ArgumentParser()
  79. add = parser.add_argument
  80. add('input_file', help='An input YAML file defining the objects to generate')
  81. args = parser.parse_args()
  82. classes = []
  83. with open(args.input_file) as fi:
  84. for obj, attrs in yaml.load(fi).items():
  85. classes.append(generate_collection_class(obj, attrs))
  86. output_filename = args.input_file.replace('.yaml', '.hpp')
  87. with open(output_filename, 'w') as fo:
  88. fo.write(generate_header(args.input_file, output_filename))
  89. for class_ in classes:
  90. fo.write(class_)