sim_track_viz.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. #!/usr/bin/env python
  2. import numpy as np
  3. from mpl_toolkits.mplot3d import Axes3D
  4. import matplotlib.pyplot as plt
  5. from uproot import open as root_open
  6. # plt.interactive(True)
  7. tree = None
  8. bsp = None
  9. def pdg_color(pdgId):
  10. return {
  11. 11: 'b',
  12. -11: 'b',
  13. 22: 'g',
  14. }.get(pdgId, 'k')
  15. def get_roots(sourceSimIdxs):
  16. roots = []
  17. for idx, sourceSimIdx in enumerate(sourceSimIdxs):
  18. if len(sourceSimIdx) == 0:
  19. roots.append(idx)
  20. return np.array(roots[:len(roots)//2])
  21. def p(px, py, pz):
  22. return np.sqrt(px**2 + py**2 + pz**2)
  23. def plot_all_roots(sim_vtxs):
  24. # ax = plt.gca(projection='3d')
  25. ax = plt.gca()
  26. xs = sim_vtxs[b'simvtx_x'] - bsp[b'bsp_x']
  27. ys = sim_vtxs[b'simvtx_y'] - bsp[b'bsp_y']
  28. zs = sim_vtxs[b'simvtx_z'] - bsp[b'bsp_z']
  29. rs = np.sqrt(xs*xs + ys*ys)
  30. roots = get_roots(sim_vtxs[b'simvtx_sourceSimIdx'])
  31. bound_z = bsp[b'bsp_sigmaz']*5
  32. # bound_z = 1
  33. bound_r = np.sqrt(bsp[b'bsp_sigmax']**2 + bsp[b'bsp_sigmay']**2)
  34. lumi_roots = []
  35. for root in roots:
  36. if abs(zs[root]) < bound_z and np.sqrt(xs[root]**2 + ys[root]**2) < bound_r:
  37. lumi_roots.append(root)
  38. # nvtx = len(xs)//2
  39. # ax.plot(xs[roots], ys[roots], zs[roots], '.')
  40. # ax.plot(zs[roots], rs[roots], '.')
  41. ax.plot(zs[lumi_roots], rs[lumi_roots], '.')
  42. # ax.plot([-bound_z, -bound_z, bound_z, bound_z], [0, bound_r, bound_r, 0], 'r')
  43. ax.set_xlabel('z')
  44. ax.set_ylabel('r')
  45. ax.set_aspect('equal')
  46. # ax.set_zlabel('z')
  47. def plot_all_vtx(sim_vtxs):
  48. ax = plt.gca(projection='3d')
  49. xs = sim_vtxs[b'simvtx_x']
  50. ys = sim_vtxs[b'simvtx_y']
  51. zs = sim_vtxs[b'simvtx_z']
  52. # nvtx = len(xs)//2
  53. ax.plot(xs, ys, zs, '.')
  54. ax.set_xlabel('x')
  55. ax.set_ylabel('y')
  56. ax.set_zlabel('z')
  57. def plot_event_tree(sim_tracks, sim_vtxs, pv_idx):
  58. print('='*80)
  59. ax = plt.gca(projection='3d')
  60. pdgIds = sim_tracks[b'sim_pdgId']
  61. decayVtxIdxs = sim_tracks[b'sim_decayVtxIdx']
  62. px = sim_tracks[b'sim_px']
  63. py = sim_tracks[b'sim_py']
  64. pz = sim_tracks[b'sim_pz']
  65. xs = sim_vtxs[b'simvtx_x']
  66. ys = sim_vtxs[b'simvtx_y']
  67. zs = sim_vtxs[b'simvtx_z']
  68. daughterSimIdxs = sim_vtxs[b'simvtx_daughterSimIdx']
  69. vtx_idxs = [(0, pv_idx)]
  70. while vtx_idxs:
  71. depth, vtx_idx = vtx_idxs.pop()
  72. print(' '*depth + str(vtx_idx), end=' -> ', flush=True)
  73. ax.plot([xs[vtx_idx]], [ys[vtx_idx]], [zs[vtx_idx]], 'r.')
  74. start_x = xs[vtx_idx]
  75. start_y = ys[vtx_idx]
  76. start_z = zs[vtx_idx]
  77. for sim_idx in daughterSimIdxs[vtx_idx]:
  78. pdgId = pdgIds[sim_idx]
  79. # print(pdgId)
  80. # if abs(pdgId) != 11:
  81. # continue
  82. # if pdgId == 22:
  83. # continue
  84. # if pdgId == 22 and p(px[sim_idx], py[sim_idx], pz[sim_idx]) < 1.0:
  85. # continue
  86. for decay_vtx_idx in decayVtxIdxs[sim_idx]:
  87. end_x = xs[decay_vtx_idx]
  88. end_y = ys[decay_vtx_idx]
  89. end_z = zs[decay_vtx_idx]
  90. # if abs(pdgId) == 11:
  91. if True:
  92. ax.plot([start_x, end_x], [start_y, end_y], [start_z, end_z],
  93. pdg_color(pdgId))
  94. vtx_idxs.append((depth+1, decay_vtx_idx))
  95. print(str(decay_vtx_idx) + ', ', end='')
  96. print()
  97. ax.plot([xs[pv_idx]], [ys[pv_idx]], [zs[pv_idx]], 'k.')
  98. # ax.set_xlim((-5, 5))
  99. # ax.set_ylim((-5, 5))
  100. # ax.set_zlim((-5, 5))
  101. ax.set_xlabel('x')
  102. ax.set_ylabel('y')
  103. ax.set_zlabel('z')
  104. def print_sim_vtxs(sim_vtxs, sim_tracks, sim_pvs, gens, gsf_tracks):
  105. daughterSimIdxs = sim_vtxs[b'simvtx_daughterSimIdx']
  106. sourceSimIdxs = sim_vtxs[b'simvtx_sourceSimIdx']
  107. processTypes = sim_vtxs[b'simvtx_processType']
  108. xs = sim_vtxs[b'simvtx_x'] - bsp[b'bsp_x']
  109. ys = sim_vtxs[b'simvtx_y'] - bsp[b'bsp_y']
  110. zs = sim_vtxs[b'simvtx_z'] - bsp[b'bsp_z']
  111. parentVtxIdxs = sim_tracks[b'sim_parentVtxIdx']
  112. decayVtxIdxs = sim_tracks[b'sim_decayVtxIdx']
  113. sim_pdgIds = sim_tracks[b'sim_pdgId']
  114. sim_pxs = sim_tracks[b'sim_px']
  115. sim_pys = sim_tracks[b'sim_py']
  116. sim_pzs = sim_tracks[b'sim_pz']
  117. gen_pdgIds = gens[b'gen_pdgId']
  118. vxs = gens[b'gen_vx'] - bsp[b'bsp_x']
  119. vys = gens[b'gen_vy'] - bsp[b'bsp_y']
  120. vzs = gens[b'gen_vz'] - bsp[b'bsp_z']
  121. gen_pxs = gens[b'gen_px']
  122. gen_pys = gens[b'gen_py']
  123. gen_pzs = gens[b'gen_pz']
  124. gsf_pdgIds = -11 * gsf_tracks[b'trk_q']
  125. gsf_vxs = gsf_tracks[b'trk_vtxx'] - bsp[b'bsp_x']
  126. gsf_vys = gsf_tracks[b'trk_vtxy'] - bsp[b'bsp_y']
  127. gsf_vzs = gsf_tracks[b'trk_vtxz'] - bsp[b'bsp_z']
  128. gsf_pxs = gsf_tracks[b'trk_px']
  129. gsf_pys = gsf_tracks[b'trk_py']
  130. gsf_pzs = gsf_tracks[b'trk_pz']
  131. # print('VTX')
  132. # for (idx, (sourceSimIdx, daughterSimIdx, processType)) in enumerate(zip(sourceSimIdxs, daughterSimIdxs, processTypes)):
  133. # if idx in sim_pvs:
  134. # print(f'*{idx}|{processType}', sourceSimIdx, daughterSimIdx, sep=" - ")
  135. # else:
  136. # print(f'{idx}|{processType}', sourceSimIdx, daughterSimIdx, sep=" - ")
  137. print('GEN')
  138. for (idx, (px, py, pz, vx, vy, vz, pdgId)) in enumerate(zip(gen_pxs, gen_pys, gen_pzs, vxs, vys, vzs, gen_pdgIds)):
  139. if abs(pdgId) != 11: continue
  140. p = np.sqrt(px**2 + py**2 + pz**2)
  141. theta = np.arctan2(np.hypot(px, py), pz)
  142. phi = np.arctan2(px, py)
  143. print(f'{idx: 4d}|{pdgId: 3d} - ({vx:8.2f},{vy:8.2f},{vz:8.2f}) ({theta:5.2f},{phi:5.2f}) {p:.2f}GeV')
  144. print('SIM')
  145. for (idx, (px, py, pz, parentVtxIdx, decayVtxIdx, pdgId)) in enumerate(zip(gen_pxs, gen_pys, gen_pzs, parentVtxIdxs, decayVtxIdxs, sim_pdgIds)):
  146. if abs(pdgId) != 11: continue
  147. if len(sourceSimIdxs[parentVtxIdx]) > 0: continue
  148. vx = xs[parentVtxIdx]
  149. vy = ys[parentVtxIdx]
  150. vz = zs[parentVtxIdx]
  151. p = np.sqrt(px**2 + py**2 + pz**2)
  152. theta = np.arctan2(np.hypot(px, py), pz)
  153. phi = np.arctan2(px, py)
  154. print(f'{idx: 4d}|{pdgId: 3d} - ({vx:8.2f},{vy:8.2f},{vz:8.2f}) ({theta:5.2f},{phi:5.2f}) {p:.2f}GeV')
  155. print('RECO')
  156. for (idx, (px, py, pz, vx, vy, vz, pdgId)) in enumerate(zip(gsf_pxs, gsf_pys, gsf_pzs, gsf_vxs, gsf_vys, gsf_vzs, gsf_pdgIds)):
  157. p = np.sqrt(px**2 + py**2 + pz**2)
  158. theta = np.arctan2(np.hypot(px, py), pz)
  159. phi = np.arctan2(px, py)
  160. print(f'{idx: 4d}|{pdgId: 3d} - ({vx:8.2f},{vy:8.2f},{vz:8.2f}) ({theta:5.2f},{phi:5.2f}) {p:.2f}GeV')
  161. input()
  162. def plot_event(sim_tracks, sim_vtxs, sim_pvs, gens, gsf_tracks, event_idx):
  163. print(f"Processing event {event_idx}")
  164. sim_tracks = {k: v[event_idx] for k, v in sim_tracks.items()}
  165. sim_vtxs = {k: v[event_idx] for k, v in sim_vtxs.items()}
  166. gens = {k: v[event_idx] for k, v in gens.items()}
  167. gsf_tracks = {k: v[event_idx] for k, v in gsf_tracks.items()}
  168. sim_pvs = sim_pvs[event_idx]
  169. print_sim_vtxs(sim_vtxs, sim_tracks, sim_pvs, gens, gsf_tracks)
  170. # print(len(sim_pvs[event_idx]))
  171. # plt.clf()
  172. # plot_all_roots(sim_vtxs)
  173. # plt.show()
  174. # for pv_idx in get_roots(sim_vtxs[b'simvtx_sourceSimIdx']):
  175. # plt.clf()
  176. # plot_event_tree(sim_tracks, sim_vtxs, pv_idx)
  177. # # plot_all_vtx(sim_vtxs)
  178. # plt.show()
  179. def main():
  180. global tree
  181. global bsp
  182. f = root_open('../data/zee.root')
  183. tree = f['trackingNtuple/tree']
  184. gsf_tracks = tree.arrays([
  185. 'trk_q',
  186. 'trk_vtxx',
  187. 'trk_vtxy',
  188. 'trk_vtxz',
  189. 'trk_px',
  190. 'trk_py',
  191. 'trk_pz',
  192. ])
  193. sim_tracks = tree.arrays([
  194. 'sim_pdgId',
  195. 'sim_parentVtxIdx',
  196. 'sim_decayVtxIdx',
  197. 'sim_pt',
  198. 'sim_eta',
  199. 'sim_px',
  200. 'sim_py',
  201. 'sim_pz',
  202. ])
  203. sim_vtxs = tree.arrays([
  204. 'simvtx_x',
  205. 'simvtx_y',
  206. 'simvtx_z',
  207. 'simvtx_sourceSimIdx',
  208. 'simvtx_daughterSimIdx',
  209. 'simvtx_processType',
  210. ])
  211. bsp = tree.arrays([
  212. 'bsp_x',
  213. 'bsp_y',
  214. 'bsp_z',
  215. 'bsp_sigmax',
  216. 'bsp_sigmay',
  217. 'bsp_sigmaz',
  218. ])
  219. gens = tree.arrays([
  220. 'gen_vx',
  221. 'gen_vy',
  222. 'gen_vz',
  223. 'gen_px',
  224. 'gen_py',
  225. 'gen_pz',
  226. 'gen_pdgId',
  227. ])
  228. bsp = {k: v[0] for k, v in bsp.items()}
  229. sim_pvs = tree.array('simpv_idx')
  230. for event_idx in range(tree.fEntries):
  231. plot_event(sim_tracks, sim_vtxs, sim_pvs, gens, gsf_tracks, event_idx)
  232. # break
  233. if __name__ == '__main__':
  234. main()