123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- # distutils: language=c++
- cimport cython
- import numpy as np
- cimport numpy as np
- from libcpp.queue cimport priority_queue
- from libcpp.pair cimport pair
# One priority-queue entry: (negated accumulated distance, flattened cell
# index j * width + i). Distances are pushed negated so that the
# largest-first std::priority_queue pops the nearest cell first.
ctypedef pair[float, int] step
# Frontier queue used by _trade_distance.
ctypedef priority_queue[step] pp_t
@cython.boundscheck(False)
@cython.wraparound(False)
cdef hex_neighbors(const int i, const int j, int[:,:] out):
    """
    Write the six hex neighbours of cell (i, j) into rows 0..5 of `out`,
    each row holding an (i, j) pair, in the order UU, DD, UR, DR, DL, UL.

    j is the vertical axis (UU/DD are j + 1 / j - 1). The neighbours in the
    adjacent columns i + 1 / i - 1 sit on rows (j, j - 1) when i is even and
    on rows (j + 1, j) when i is odd.

    The returned coordinates may lie outside the map; callers must
    bounds-check them. `out` must be at least 6x2 because bounds checking
    is disabled here.
    """
    cdef int upper_row, lower_row

    # Same column, one step up / down.
    out[0, 0] = i
    out[0, 1] = j + 1  # UU
    out[1, 0] = i
    out[1, 1] = j - 1  # DD

    # Adjacent columns: which two rows they touch depends on the parity of i.
    if i % 2 == 0:  # even columns
        upper_row = j
    else:           # odd columns
        upper_row = j + 1
    lower_row = upper_row - 1

    out[2, 0] = i + 1
    out[2, 1] = upper_row  # UR
    out[3, 0] = i + 1
    out[3, 1] = lower_row  # DR
    out[4, 0] = i - 1
    out[4, 1] = lower_row  # DL
    out[5, 0] = i - 1
    out[5, 1] = upper_row  # UL
@cython.boundscheck(True)
@cython.wraparound(False)
@cython.cdivision(True)
cdef _trade_distance(int init_i, int init_j, float trade_range, const float[:, :] trade_traversal_cost,
                     const float[:, :] trade_value, float[:, :] distance):
    """
    Expand outward from (init_i, init_j) over the hex grid in order of
    increasing accumulated traversal cost. Record every reachable distance
    and the most valuable reachable cell.

    Entering cell (i, j) costs trade_traversal_cost[i, j]; the origin costs 0.
    The cost is charged on the cell being entered, and cells are expanded
    nearest-first. So the first time a cell is discovered, its distance is
    already minimal, and it is marked visited immediately. This assumes
    non-negative traversal costs (TODO confirm).

    Args:
        init_i, init_j: origin cell. i indexes axis 0 (width), j indexes
            axis 1 (height).
        trade_range: maximum accumulated cost. A cell whose distance exceeds
            it is neither recorded nor expanded.
        trade_traversal_cost: per-cell cost of entering that cell.
        trade_value: per-cell value used to choose the export partner.
        distance: output buffer, written in place. The origin is set to 0,
            and every other cell with distance <= trade_range gets its
            distance. All remaining cells are left untouched; callers
            pre-fill them with -1.

    Returns:
        (exp_i, exp_j): the reached cell with the highest trade_value that
        is strictly greater than 0. The origin itself is a candidate. On
        ties, the cell popped first is kept. Returns (-1, -1) if no reached
        cell has a positive value.
    """
    cdef int width = trade_traversal_cost.shape[0]
    cdef int height = trade_traversal_cost.shape[1]
    cdef pp_t pp
    cdef step top
    cdef int i, j, ii, jj, idx
    cdef int exp_i = -1, exp_j = -1
    cdef float exp_val = 0
    cdef float dist, dist_tmp
    cdef int[:, :] neighbors = np.zeros((6, 2), dtype=np.int32)
    # uint8 <-> unsigned char is an unambiguous buffer-format match.
    cdef unsigned char[:, :] visited = np.zeros((width, height), dtype=np.uint8)

    distance[init_i, init_j] = 0.0
    # Settle the origin up front. Otherwise its first expanded neighbour
    # would rediscover it, overwrite its 0 distance and queue it again.
    visited[init_i, init_j] = 1
    # Cells are flattened as j * width + i. Distances are pushed negated
    # because std::priority_queue pops the largest element first.
    pp.push(step(0.0, init_j * width + init_i))
    while not pp.empty():
        top = pp.top()
        pp.pop()
        i = top.second % width
        j = top.second // width
        dist = -top.first

        if trade_value[i, j] > exp_val:
            exp_i, exp_j = i, j
            exp_val = trade_value[i, j]

        hex_neighbors(i, j, neighbors)
        for idx in range(6):
            ii = neighbors[idx, 0]
            jj = neighbors[idx, 1]
            if not (0 <= ii < width and 0 <= jj < height) or visited[ii, jj]:
                continue
            # First discovery already yields the minimal distance (see
            # docstring). A cell found out of range is still marked visited,
            # because any later route to it would be at least as long.
            visited[ii, jj] = 1
            dist_tmp = dist + trade_traversal_cost[ii, jj]
            if dist_tmp <= trade_range:
                distance[ii, jj] = dist_tmp
                pp.push(step(-dist_tmp, jj * width + ii))
    return exp_i, exp_j
def trade_distance(int i, int j, float trade_range, trade_traversal_cost: np.ndarray, trade_value: np.ndarray):
    """
    Python entry point for a single-origin trade-distance query.

    Returns a pair (partner, distances):
      * partner   -- the (exp_i, exp_j) tuple produced by _trade_distance.
      * distances -- an array shaped and typed like trade_traversal_cost.
                     Cells reached within trade_range hold their accumulated
                     distance; every other cell holds -1.
    """
    distances = np.full_like(trade_traversal_cost, -1)
    partner = _trade_distance(
        i, j, trade_range, trade_traversal_cost, trade_value, distances
    )
    return partner, distances
cpdef update_export_partner(trade_range: np.ndarray, trade_distance: np.ndarray, trade_value: np.ndarray, export_partner: np.ndarray):
    """
    Recompute the export partner of every cell on the map.

    For each origin (i, j), run _trade_distance using that cell's own range
    trade_range[i, j], then store the resulting [exp_i, exp_j] in
    export_partner[i, j]. The value is [-1, -1] when no partner is found.

    Args:
        trade_range: per-cell trade range. Its shape sets the extent that
            is iterated.
        trade_distance: per-cell traversal cost, forwarded to
            _trade_distance as trade_traversal_cost. It must be float32 to
            match that function's `const float[:, :]` view. Inside this body
            the name shadows the module-level trade_distance() function.
        trade_value: per-cell trade value (float32 for the same reason).
        export_partner: written in place. Presumably shaped
            (width, height, 2); TODO confirm against callers.
    """
    cdef int width = trade_range.shape[0]
    cdef int height = trade_range.shape[1]
    cdef int i, j
    # _trade_distance only writes into this buffer, and the distances are
    # not needed here. So one scratch buffer is allocated once instead of
    # once per cell. An explicit float32 dtype (instead of inheriting
    # trade_range's dtype) guarantees it matches the float[:, :] view.
    cdef float[:, :] scratch = np.full((width, height), -1, dtype=np.float32)
    for i in range(width):
        for j in range(height):
            exp_i, exp_j = _trade_distance(i, j, trade_range[i, j], trade_distance, trade_value, scratch)
            export_partner[i, j] = [exp_i, exp_j]
|