Skip to content

Commit d11571f

Browse files
authored
Refactor n2v walk (#81)
* Use id indexing in wv
* Refactor _random_walks
* Update deps, use numpy typing
* Fix typing
1 parent 76bd6a1 commit d11571f

File tree

5 files changed: +74 -46 lines changed

pyproject.toml

+3
Original file line number | Diff line number | Diff line change
@@ -5,3 +5,6 @@ build-backend = "setuptools.build_meta"
55
[tool.mypy]
66
ignore_missing_imports = true
77
follow_imports = "skip"
8+
plugins = [
9+
"numpy.typing.mypy_plugin",
10+
]

requirements.txt

+2-1
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
gensim==4.1.2
22
numpy==1.21.5
3-
numba==0.55.0
3+
numba==0.55.1
44
numba-progress==0.0.2
5+
nptyping==1.4.4

src/pecanpy/graph.py

+5-3
Original file line number | Diff line number | Diff line change
@@ -425,8 +425,10 @@ def read_npz(self, path: str, weighted: bool):
425425
raw = np.load(path)
426426
self.set_node_ids(raw["IDs"].tolist())
427427
self.data = raw["data"]
428-
if not weighted: # overwrite edge weights with constant
429-
self.data[:] = 1.0
428+
if self.data is None:
429+
raise ValueError("Adjacency matrix data not found.")
430+
elif not weighted:
431+
self.data[:] = 1.0 # overwrite edge weights with constant
430432
self.indptr = raw["indptr"]
431433
self.indices = raw["indices"]
432434

@@ -523,7 +525,7 @@ def data(self) -> Optional[np.ndarray]:
523525
def data(self, data: np.ndarray):
524526
"""Set adjacency matrix and the corresponding nonzero matrix."""
525527
self._data = data.astype(float)
526-
self._nonzero = self._data != 0
528+
self._nonzero = np.array(self._data != 0, dtype=bool)
527529

528530
@property
529531
def nonzero(self) -> Optional[np.ndarray]:

src/pecanpy/pecanpy.py

+61-41
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,12 @@
11
"""Different strategies for generating node2vec walks."""
2+
from typing import Any
3+
from typing import Callable
24
from typing import List
35
from typing import Optional
46

57
import numpy as np
68
from gensim.models import Word2Vec
9+
from nptyping import NDArray
710
from numba import njit
811
from numba import prange
912
from numba.np.ufunc.parallel import _get_thread_id
@@ -14,6 +17,9 @@
1417
from .rw import SparseRWGraph
1518
from .wrappers import Timer
1619

20+
HasNbrs = Callable[[np.uint32], bool]
21+
MoveForward = Callable[..., np.uint32]
22+
1723

1824
class Base(BaseGraph):
1925
"""Base node2vec object.
@@ -137,51 +143,68 @@ def simulate_walks(
137143
has_nbrs = self.get_has_nbrs()
138144
verbose = self.verbose
139145

140-
@njit(parallel=True, nogil=True)
141-
def node2vec_walks(num_iter, progress_proxy):
142-
"""Simulate a random walk starting from start node."""
143-
# Seed the random number generator
144-
if random_state is not None:
145-
np.random.seed(random_state + _get_thread_id())
146-
147-
# use the last entry of each walk index array to keep track of the
148-
# effective walk length
149-
walk_idx_mat = np.zeros((num_iter, walk_length + 2), dtype=np.uint32)
150-
walk_idx_mat[:, 0] = start_node_idx_ary # initialize seeds
151-
walk_idx_mat[:, -1] = walk_length + 1 # set to full walk length by default
152-
153-
for i in prange(num_iter):
154-
# initialize first step as normal random walk
155-
start_node_idx = walk_idx_mat[i, 0]
156-
if has_nbrs(start_node_idx):
157-
walk_idx_mat[i, 1] = move_forward(start_node_idx)
158-
else:
159-
walk_idx_mat[i, -1] = 1
160-
continue
161-
162-
# start bias random walk
163-
for j in range(2, walk_length + 1):
164-
cur_idx = walk_idx_mat[i, j - 1]
165-
if has_nbrs(cur_idx):
166-
prev_idx = walk_idx_mat[i, j - 2]
167-
walk_idx_mat[i, j] = move_forward(cur_idx, prev_idx)
168-
else:
169-
walk_idx_mat[i, -1] = j
170-
break
171-
172-
progress_proxy.update(1)
173-
174-
return walk_idx_mat
175-
176146
# Acquire numba progress proxy for displaying the progress bar
177147
with ProgressBar(total=tot_num_jobs, disable=not verbose) as progress:
178-
walk_idx_mat = node2vec_walks(tot_num_jobs, progress)
148+
walk_idx_mat = self._random_walks(
149+
tot_num_jobs,
150+
walk_length,
151+
random_state,
152+
start_node_idx_ary,
153+
has_nbrs,
154+
move_forward,
155+
progress,
156+
)
179157

180158
# Map node index back to node ID
181159
walks = [self._map_walk(walk_idx_ary) for walk_idx_ary in walk_idx_mat]
182160

183161
return walks
184162

163+
@staticmethod
164+
@njit(parallel=True, nogil=True)
165+
def _random_walks(
166+
tot_num_jobs: int,
167+
walk_length: int,
168+
random_state: Optional[int],
169+
start_node_idx_ary: NDArray[(Any,), np.uint32],
170+
has_nbrs: HasNbrs,
171+
move_forward: MoveForward,
172+
progress_proxy: ProgressBar,
173+
):
174+
"""Simulate a random walk starting from start node."""
175+
# Seed the random number generator
176+
if random_state is not None:
177+
np.random.seed(random_state + _get_thread_id())
178+
179+
# use the last entry of each walk index array to keep track of the
180+
# effective walk length
181+
walk_idx_mat = np.zeros((tot_num_jobs, walk_length + 2), dtype=np.uint32)
182+
walk_idx_mat[:, 0] = start_node_idx_ary # initialize seeds
183+
walk_idx_mat[:, -1] = walk_length + 1 # set to full walk length by default
184+
185+
for i in prange(tot_num_jobs):
186+
# initialize first step as normal random walk
187+
start_node_idx = walk_idx_mat[i, 0]
188+
if has_nbrs(start_node_idx):
189+
walk_idx_mat[i, 1] = move_forward(start_node_idx)
190+
else:
191+
walk_idx_mat[i, -1] = 1
192+
continue
193+
194+
# start bias random walk
195+
for j in range(2, walk_length + 1):
196+
cur_idx = walk_idx_mat[i, j - 1]
197+
if has_nbrs(cur_idx):
198+
prev_idx = walk_idx_mat[i, j - 2]
199+
walk_idx_mat[i, j] = move_forward(cur_idx, prev_idx)
200+
else:
201+
walk_idx_mat[i, -1] = j
202+
break
203+
204+
progress_proxy.update(1)
205+
206+
return walk_idx_mat
207+
185208
def setup_get_normalized_probs(self):
186209
"""Transition probability computation setup.
187210
@@ -260,10 +283,7 @@ def embed(
260283
seed=self.random_state,
261284
)
262285

263-
# index mapping back to node IDs
264-
idx_list = [w2v.wv.get_index(i) for i in self.nodes]
265-
266-
return w2v.wv.vectors[idx_list]
286+
return w2v.wv[self.nodes]
267287

268288

269289
class FirstOrderUnweighted(Base, SparseRWGraph):

tox.ini

+3-1
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,9 @@ commands =
1919

2020
[testenv:mypy]
2121
skip_install = true
22-
deps = mypy
22+
deps =
23+
mypy
24+
numpy
2325
commands = mypy src/pecanpy
2426

2527
[testenv:flake8]

0 commit comments

Comments
 (0)