# examine_seeds.py
from collections import Counter, defaultdict

from uproot import open as root_open
  3. def main():
  4. f_old = root_open('trackingNtuple_old_default.root')['trackingNtuple/tree']
  5. f_new = root_open('trackingNtuple_new_default.root')['trackingNtuple/tree']
  6. keys = [b'see_sclIdx', b'see_trkIdx',
  7. b'scl_e', b'scl_px', b'scl_py', b'scl_pz', b'scl_hoe',
  8. b'trk_q']
  9. arrs_old = f_old.arrays(keys)
  10. arrs_new = f_new.arrays(keys)
  11. def dump_event(event, name):
  12. print('-'*20 + f'{name:10}' + '-'*20)
  13. # print(event[b'scl_hoe'] <= 0.15)
  14. def get_cols(*strs):
  15. en = enumerate(zip(*[event[s] for s in strs]))
  16. return en
  17. print('Seed Info')
  18. for idx, (sclIdx, trkIdx) in get_cols('see_sclIdx', 'see_trkIdx'):
  19. if sclIdx < 0: continue
  20. if event['scl_hoe'][sclIdx] > 0.15: continue
  21. trk_q = '-'
  22. if trkIdx>=0:
  23. trk_q = str(event["trk_q"][trkIdx])
  24. print(f'{idx:3d}) {sclIdx:10d} {trk_q:10s}')
  25. # print(event[b'see_sclIdx'])
  26. def dump_scl(event):
  27. def get_cols(*strs):
  28. en = enumerate(zip(*[event[s] for s in strs]))
  29. return en
  30. print('Supercluster Info')
  31. for idx, (e, px, py, pz, hoe) in get_cols('scl_e', 'scl_px', 'scl_py', 'scl_pz', 'scl_hoe'):
  32. print(f'{idx:3d}) {hoe:10.2f} {e:10.2f}')
  33. def seed_summary(event_old, event_new):
  34. def get_cols(event, *strs):
  35. en = enumerate(zip(*[event[s] for s in strs]))
  36. return en
  37. counts_old = defaultdict(int)
  38. counts_new = defaultdict(int)
  39. # print('Supercluster Info')
  40. # for idx, (e, px, py, pz, hoe) in get_cols('scl_e', 'scl_px', 'scl_py', 'scl_pz', 'scl_hoe'):
  41. # print(f'{idx:3d}) {hoe:10.2f} {e:10.2f}')
  42. print('Seed Info')
  43. for _, (sclIdx,) in get_cols(event_old, 'see_sclIdx'):
  44. if sclIdx >= 0:
  45. # if event_old['scl_hoe'][sclIdx] > 0.15: continue
  46. counts_old[sclIdx] += 1
  47. for _, (sclIdx,) in get_cols(event_new, 'see_sclIdx'):
  48. if sclIdx >= 0:
  49. # if event_new['scl_hoe'][sclIdx] > 0.15: continue
  50. counts_new[sclIdx] += 1
  51. for idx, (e, px, py, pz, hoe) in get_cols(event_old, 'scl_e', 'scl_px', 'scl_py', 'scl_pz', 'scl_hoe'):
  52. if hoe > 0.15: continue
  53. print(f'{idx:3d}) {hoe:10.2f} {e:10.2f} {counts_old[idx]:10d} {counts_new[idx]:10d}')
  54. nevt = len(arrs_old[keys[0]])
  55. nevt = 5
  56. for eIdx in range(nevt):
  57. print(f'NEW EVENT: {eIdx}')
  58. old = {key.decode(): arrs_old[key][eIdx] for key in keys}
  59. new = {key.decode(): arrs_new[key][eIdx] for key in keys}
  60. # dump_scl(old)
  61. # dump_event(old, 'OLD')
  62. # dump_event(new, 'NEW')
  63. seed_summary(old, new)
  64. # print(new[b'see_sclIdx'])
  65. if __name__ == '__main__':
  66. main()