|
1 | 1 | """Different strategies for generating node2vec walks."""
|
| 2 | +from typing import Any |
| 3 | +from typing import Callable |
2 | 4 | from typing import List
|
3 | 5 | from typing import Optional
|
4 | 6 |
|
5 | 7 | import numpy as np
|
6 | 8 | from gensim.models import Word2Vec
|
| 9 | +from nptyping import NDArray |
7 | 10 | from numba import njit
|
8 | 11 | from numba import prange
|
9 | 12 | from numba.np.ufunc.parallel import _get_thread_id
|
|
14 | 17 | from .rw import SparseRWGraph
|
15 | 18 | from .wrappers import Timer
|
16 | 19 |
|
| 20 | +HasNbrs = Callable[[np.uint32], bool] |
| 21 | +MoveForward = Callable[..., np.uint32] |
| 22 | + |
17 | 23 |
|
18 | 24 | class Base(BaseGraph):
|
19 | 25 | """Base node2vec object.
|
@@ -137,51 +143,68 @@ def simulate_walks(
|
137 | 143 | has_nbrs = self.get_has_nbrs()
|
138 | 144 | verbose = self.verbose
|
139 | 145 |
|
140 |
| - @njit(parallel=True, nogil=True) |
141 |
| - def node2vec_walks(num_iter, progress_proxy): |
142 |
| - """Simulate a random walk starting from start node.""" |
143 |
| - # Seed the random number generator |
144 |
| - if random_state is not None: |
145 |
| - np.random.seed(random_state + _get_thread_id()) |
146 |
| - |
147 |
| - # use the last entry of each walk index array to keep track of the |
148 |
| - # effective walk length |
149 |
| - walk_idx_mat = np.zeros((num_iter, walk_length + 2), dtype=np.uint32) |
150 |
| - walk_idx_mat[:, 0] = start_node_idx_ary # initialize seeds |
151 |
| - walk_idx_mat[:, -1] = walk_length + 1 # set to full walk length by default |
152 |
| - |
153 |
| - for i in prange(num_iter): |
154 |
| - # initialize first step as normal random walk |
155 |
| - start_node_idx = walk_idx_mat[i, 0] |
156 |
| - if has_nbrs(start_node_idx): |
157 |
| - walk_idx_mat[i, 1] = move_forward(start_node_idx) |
158 |
| - else: |
159 |
| - walk_idx_mat[i, -1] = 1 |
160 |
| - continue |
161 |
| - |
162 |
| - # start bias random walk |
163 |
| - for j in range(2, walk_length + 1): |
164 |
| - cur_idx = walk_idx_mat[i, j - 1] |
165 |
| - if has_nbrs(cur_idx): |
166 |
| - prev_idx = walk_idx_mat[i, j - 2] |
167 |
| - walk_idx_mat[i, j] = move_forward(cur_idx, prev_idx) |
168 |
| - else: |
169 |
| - walk_idx_mat[i, -1] = j |
170 |
| - break |
171 |
| - |
172 |
| - progress_proxy.update(1) |
173 |
| - |
174 |
| - return walk_idx_mat |
175 |
| - |
176 | 146 | # Acquire numba progress proxy for displaying the progress bar
|
177 | 147 | with ProgressBar(total=tot_num_jobs, disable=not verbose) as progress:
|
178 |
| - walk_idx_mat = node2vec_walks(tot_num_jobs, progress) |
| 148 | + walk_idx_mat = self._random_walks( |
| 149 | + tot_num_jobs, |
| 150 | + walk_length, |
| 151 | + random_state, |
| 152 | + start_node_idx_ary, |
| 153 | + has_nbrs, |
| 154 | + move_forward, |
| 155 | + progress, |
| 156 | + ) |
179 | 157 |
|
180 | 158 | # Map node index back to node ID
|
181 | 159 | walks = [self._map_walk(walk_idx_ary) for walk_idx_ary in walk_idx_mat]
|
182 | 160 |
|
183 | 161 | return walks
|
184 | 162 |
|
@staticmethod
@njit(parallel=True, nogil=True)
def _random_walks(
    tot_num_jobs: int,
    walk_length: int,
    random_state: Optional[int],
    start_node_idx_ary: NDArray[(Any,), np.uint32],
    has_nbrs: HasNbrs,
    move_forward: MoveForward,
    progress_proxy: ProgressBar,
) -> NDArray[(Any, Any), np.uint32]:
    """Simulate one random walk per start node inside a parallel loop.

    Args:
        tot_num_jobs: Number of walks to simulate; also the number of rows
            of the returned matrix.
        walk_length: Maximum number of steps taken after the start node.
        random_state: Seed offset for ``np.random.seed``; seeding is skipped
            when it is ``None``.
        start_node_idx_ary: Start node index of every walk, written into
            column 0 of the returned matrix.
        has_nbrs: Callable telling whether the node at a given index has
            any neighbor to move to.
        move_forward: Callable returning the next node index. It is called
            with the current index only for the first step, and with
            ``(current index, previous index)`` for every later step.
        progress_proxy: Progress bar proxy; ``update(1)`` is called once
            per simulated walk.

    Returns:
        A ``uint32`` matrix of shape ``(tot_num_jobs, walk_length + 2)``.
        Row ``i`` holds the node indices visited by walk ``i`` starting at
        column 0. The last column holds the effective walk length (number
        of valid leading entries); entries past it are left at zero.

    """
    # Seed the random number generator
    # NOTE(review): this runs once, before the parallel loop. Confirm
    # against Numba's per-thread RNG semantics that every worker thread
    # ends up seeded as intended.
    if random_state is not None:
        np.random.seed(random_state + _get_thread_id())

    # use the last entry of each walk index array to keep track of the
    # effective walk length
    walk_idx_mat = np.zeros((tot_num_jobs, walk_length + 2), dtype=np.uint32)
    walk_idx_mat[:, 0] = start_node_idx_ary  # initialize seeds
    walk_idx_mat[:, -1] = walk_length + 1  # set to full walk length by default

    for i in prange(tot_num_jobs):
        start_node_idx = walk_idx_mat[i, 0]

        # With walk_length < 1 column 1 IS the length slot, so taking a
        # first step would overwrite the effective length; treat it like a
        # dead-end start and keep only the seed.
        if walk_length < 1 or not has_nbrs(start_node_idx):
            walk_idx_mat[i, -1] = 1
        else:
            # initialize first step as normal random walk
            walk_idx_mat[i, 1] = move_forward(start_node_idx)

            # start bias random walk
            for j in range(2, walk_length + 1):
                cur_idx = walk_idx_mat[i, j - 1]
                if not has_nbrs(cur_idx):
                    # dead end: record how many entries are valid and stop
                    walk_idx_mat[i, -1] = j
                    break
                prev_idx = walk_idx_mat[i, j - 2]
                walk_idx_mat[i, j] = move_forward(cur_idx, prev_idx)

        # Report every finished walk, including the ones that stop at the
        # start node, so the bar can reach its total of tot_num_jobs.
        progress_proxy.update(1)

    return walk_idx_mat
| 207 | + |
185 | 208 | def setup_get_normalized_probs(self):
|
186 | 209 | """Transition probability computation setup.
|
187 | 210 |
|
@@ -260,10 +283,7 @@ def embed(
|
260 | 283 | seed=self.random_state,
|
261 | 284 | )
|
262 | 285 |
|
263 |
| - # index mapping back to node IDs |
264 |
| - idx_list = [w2v.wv.get_index(i) for i in self.nodes] |
265 |
| - |
266 |
| - return w2v.wv.vectors[idx_list] |
| 286 | + return w2v.wv[self.nodes] |
267 | 287 |
|
268 | 288 |
|
269 | 289 | class FirstOrderUnweighted(Base, SparseRWGraph):
|
|
0 commit comments