trade_utils.pyx 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # distutils: language=c++
  2. cimport cython
  3. import numpy as np
  4. cimport numpy as np
  5. from libcpp.queue cimport priority_queue
  6. from libcpp.pair cimport pair
  7. ctypedef pair[float, int] step
  8. ctypedef priority_queue[step] pp_t
  9. @cython.boundscheck(False)
  10. @cython.wraparound(False)
  11. cdef hex_neighbors(const int i, const int j, int[:,:] out):
  12. cdef int ii, jj
  13. out[0][0], out[0][1] = i, j+1 # UU
  14. out[1][0], out[1][1] = i, j-1 # DD
  15. if i % 2 == 0: # even rows
  16. out[2][0], out[2][1] = i+1, j # UR
  17. out[3][0], out[3][1] = i+1, j-1 # DR
  18. out[4][0], out[4][1] = i-1, j-1 # DL
  19. out[5][0], out[5][1] = i-1, j # UL
  20. else: # odd rows
  21. out[2][0], out[2][1] = i+1, j+1 # UR
  22. out[3][0], out[3][1] = i+1, j # DR
  23. out[4][0], out[4][1] = i-1, j # DL
  24. out[5][0], out[5][1] = i-1, j+1 # UL
  25. @cython.boundscheck(True)
  26. @cython.wraparound(False)
  27. @cython.cdivision(True)
  28. cdef _trade_distance(int init_i, int init_j, float trade_range, const float[:, :] trade_impedance,
  29. const float[:, :] trade_value, float[:, :] distance):
  30. cdef int width = trade_impedance.shape[0]
  31. cdef int height = trade_impedance.shape[1]
  32. cdef pp_t pp
  33. cdef step top
  34. cdef int i, j, ii, jj
  35. cdef int exp_i = -1, exp_j = -1
  36. cdef float exp_val = 0
  37. cdef float dist, dist_tmp
  38. cdef int[:, :] neighbors = np.zeros((6,2), dtype=np.int32)
  39. cdef char[:, :] visited = np.full((trade_impedance.shape[0], trade_impedance.shape[1]),
  40. False, dtype=np.int8)
  41. distance[init_i][init_j] = 0.0
  42. pp.push(step(0.0, init_j*width+init_i))
  43. while not pp.empty():
  44. top = pp.top()
  45. i = top.second % width
  46. j = top.second // width
  47. dist = -top.first
  48. pp.pop()
  49. if trade_value[i][j] > exp_val:
  50. exp_i, exp_j = i, j
  51. exp_val = trade_value[i][j]
  52. hex_neighbors(i, j, neighbors)
  53. for idx in range(6):
  54. ii = neighbors[idx][0]
  55. jj = neighbors[idx][1]
  56. if 0 <= ii < width and 0 <= jj < height and not visited[ii][jj]:
  57. visited[ii][jj] = True
  58. dist_tmp = dist + trade_impedance[ii][jj]
  59. if dist_tmp <= trade_range:
  60. distance[ii][jj] = dist_tmp
  61. pp.push(step(-dist_tmp, jj*width+ii))
  62. return exp_i, exp_j
  63. def trade_distance(int i, int j, float trade_range, trade_impedance: np.ndarray, trade_value: np.ndarray):
  64. distance = np.full_like(trade_impedance, -1)
  65. exp = _trade_distance(i, j, trade_range, trade_impedance, trade_value, distance)
  66. return exp, distance
  67. @cython.boundscheck(False)
  68. @cython.wraparound(False)
  69. @cython.cdivision(True)
  70. cpdef update_export_partner(trade_range: np.ndarray, trade_distance: np.ndarray, trade_value: np.ndarray, export_partner: np.ndarray):
  71. cdef int width = trade_range.shape[0]
  72. cdef int height = trade_range.shape[1]
  73. cdef int exp_i, exp_j
  74. print(trade_range)
  75. cdef float[:, :] distance
  76. for i in range(width):
  77. for j in range(height):
  78. distance = np.full_like(trade_range, -1)
  79. (exp_i, exp_j) = _trade_distance(i, j, trade_range[i, j], trade_distance, trade_value, distance)
  80. export_partner[i, j] = [exp_i, exp_j]
  81. @cython.boundscheck(True)
  82. @cython.wraparound(False)
  83. @cython.cdivision(True)
  84. cpdef share_food(food_produced: np.ndarray, food_consumed: np.ndarray, food_stored: np.ndarray,
  85. food_stored_capacity: np.ndarray, export_partner: np.ndarray,
  86. spread_factor_neighbor: float):
  87. """
  88. If a province produces more food that it consumes *and* it's food storage is full, a portion of it's excess
  89. can be transferred to neighboring hex's provided that they don't also have excess production. Any leftover excess
  90. will get shipped to the export partner.
  91. """
  92. cdef int width = food_produced.shape[0]
  93. cdef int height = food_produced.shape[1]
  94. cdef float[:, :] food_excess = food_produced - food_consumed
  95. cdef int ii, jj
  96. cdef int[:, :] neighbors = np.zeros((6,2), dtype=np.int32)
  97. food_stored_new = np.copy(food_stored)
  98. cdef float[:, :] food_stored_new_ = food_stored_new
  99. cdef float spread_amount
  100. cdef int spread_count
  101. for i in range(width):
  102. for j in range(height):
  103. if food_excess[i][j] <= 0 or food_stored[i][j] < food_stored_capacity[i][j]:
  104. continue
  105. hex_neighbors(i, j, neighbors)
  106. spread_amount_neighbor = spread_factor_neighbor * food_excess[i][j]
  107. spread_count = 0
  108. for idx in range(6):
  109. ii = neighbors[idx][0]
  110. jj = neighbors[idx][1]
  111. if not (0 <= ii < width and 0 <= jj < height):
  112. continue
  113. if food_excess[ii][jj] < food_excess[i][j] and food_stored[ii][jj] < food_stored_capacity[ii][jj]:
  114. food_stored_new_[ii][jj] += spread_amount_neighbor
  115. food_stored_new_[i][j] -= spread_amount_neighbor
  116. spread_count += 1
  117. spread_amount = food_excess[i][j] * (1 - spread_amount_neighbor*spread_count)
  118. try:
  119. food_stored_new_[export_partner[i][j][0], export_partner[i][j][1]] += spread_amount
  120. except IndexError:
  121. print(export_partner[i][j])
  122. food_stored_new_[i][j] -= spread_amount
  123. np.copyto(food_stored, food_stored_new)