trade_utils.pyx 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # distutils: language=c++
  2. cimport cython
  3. import numpy as np
  4. cimport numpy as np
  5. from libcpp.queue cimport priority_queue
  6. from libcpp.pair cimport pair
# Search-frontier entry: (negated accumulated cost, flattened cell index j * width + i).
ctypedef pair[float, int] step
# std::priority_queue is a max-heap; costs are pushed negated so the cheapest cell pops first.
ctypedef priority_queue[step] pp_t
  9. @cython.boundscheck(False)
  10. @cython.wraparound(False)
  11. cdef hex_neighbors(const int i, const int j, int[:,:] out):
  12. cdef int ii, jj
  13. out[0][0], out[0][1] = i, j+1 # UU
  14. out[1][0], out[1][1] = i, j-1 # DD
  15. if i % 2 == 0: # even rows
  16. out[2][0], out[2][1] = i+1, j # UR
  17. out[3][0], out[3][1] = i+1, j-1 # DR
  18. out[4][0], out[4][1] = i-1, j-1 # DL
  19. out[5][0], out[5][1] = i-1, j # UL
  20. else: # odd rows
  21. out[2][0], out[2][1] = i+1, j+1 # UR
  22. out[3][0], out[3][1] = i+1, j # DR
  23. out[4][0], out[4][1] = i-1, j # DL
  24. out[5][0], out[5][1] = i-1, j+1 # UL
  25. @cython.boundscheck(True)
  26. @cython.wraparound(False)
  27. @cython.cdivision(True)
  28. cdef _trade_distance(int init_i, int init_j, float trade_range, const float[:, :] trade_traversal_cost,
  29. const float[:, :] trade_value, float[:, :] distance):
  30. """
  31. Function needs following properties:
  32. 1. Returns an array for the map where
  33. a) Shows trade distance where distance < trade_range
  34. b) Otherwise, filled with -1
  35. 2. Returns the province that is the export partner (if one exists)
  36. """
  37. cdef int width = trade_traversal_cost.shape[0]
  38. cdef int height = trade_traversal_cost.shape[1]
  39. cdef pp_t pp
  40. cdef step top
  41. cdef int i, j, ii, jj
  42. cdef int exp_i = -1, exp_j = -1
  43. cdef float exp_val = 0
  44. cdef float dist, dist_tmp
  45. cdef int[:, :] neighbors = np.zeros((6,2), dtype=np.int32)
  46. cdef char[:, :] visited = np.full((trade_traversal_cost.shape[0], trade_traversal_cost.shape[1]),
  47. False, dtype=np.int8)
  48. distance[init_i][init_j] = 0.0
  49. pp.push(step(0.0, init_j*width+init_i))
  50. while not pp.empty():
  51. top = pp.top()
  52. i = top.second % width
  53. j = top.second // width
  54. dist = -top.first
  55. pp.pop()
  56. if trade_value[i][j] > exp_val:
  57. exp_i, exp_j = i, j
  58. exp_val = trade_value[i][j]
  59. hex_neighbors(i, j, neighbors)
  60. for idx in range(6):
  61. ii = neighbors[idx][0]
  62. jj = neighbors[idx][1]
  63. if 0 <= ii < width and 0 <= jj < height and not visited[ii][jj]:
  64. visited[ii][jj] = True
  65. dist_tmp = dist + trade_traversal_cost[ii][jj]
  66. if dist_tmp <= trade_range:
  67. distance[ii][jj] = dist_tmp
  68. pp.push(step(-dist_tmp, jj*width+ii))
  69. return exp_i, exp_j
  70. def trade_distance(int i, int j, float trade_range, trade_traversal_cost: np.ndarray, trade_value: np.ndarray):
  71. distance = np.full_like(trade_traversal_cost, -1)
  72. exp = _trade_distance(i, j, trade_range, trade_traversal_cost, trade_value, distance)
  73. return exp, distance
  74. cpdef update_export_partner(trade_range: np.ndarray, trade_distance: np.ndarray, trade_value: np.ndarray, export_partner: np.ndarray):
  75. cdef int width = trade_range.shape[0]
  76. cdef int height = trade_range.shape[1]
  77. cdef float[:, :] distance
  78. for i in range(width):
  79. for j in range(height):
  80. distance = np.full_like(trade_range, -1)
  81. (exp_i, exp_j) = _trade_distance(i, j, trade_range[i, j], trade_distance, trade_value, distance)
  82. export_partner[i, j] = [exp_i, exp_j]