@@ -41,7 +41,7 @@ class Base:
41
41
42
42
"""
43
43
44
- def __init__ (self , p , q , workers , verbose = False , extend = False ):
44
+ def __init__ (self , p , q , workers , verbose = False , extend = False , gamma = 0 ):
45
45
"""Initializ node2vec base class.
46
46
47
47
Args:
@@ -53,7 +53,11 @@ def __init__(self, p, q, workers, verbose=False, extend=False):
53
53
workers (int): number of threads to be spawned for runing node2vec
54
54
including walk generation and word2vec embedding.
55
55
verbose (bool): show progress bar for walk generation.
56
- extend (bool): ``True`` if use node2vec+ extension, default is ``False``
56
+ extend (bool): use node2vec+ extension if set to :obj:`True`
57
+ (default: :obj:`False`).
58
+ gamma (float): Multiplication factor for the std term of edge
59
+ weights added to the average edge weights as the noisy edge
60
+ threashold, only used by node2vec+ (default: 0)
57
61
58
62
"""
59
63
super ().__init__ ()
@@ -62,6 +66,7 @@ def __init__(self, p, q, workers, verbose=False, extend=False):
62
66
self .workers = workers
63
67
self .verbose = verbose
64
68
self .extend = extend
69
+ self .gamma = gamma
65
70
66
71
def _map_walk (self , walk_idx_ary ):
67
72
"""Map walk from node index to node ID.
@@ -148,16 +153,16 @@ def setup_get_normalized_probs(self):
148
153
probability computation function ``get_extended_normalized_probs``,
149
154
if node2vec+ is used. Otherwise, return the normal transition function
150
155
``get_noramlized_probs`` with a trivial placeholder for average edge
151
- weights array ``avg_wts ``.
156
+ weights array ``noise_thresholds ``.
152
157
153
158
"""
154
159
if self .extend : # use n2v+
155
160
get_normalized_probs = self .get_extended_normalized_probs
156
- avg_wts = self .get_average_weights ()
161
+ noise_thresholds = self .get_noise_thresholds ()
157
162
else : # use normal n2v
158
163
get_normalized_probs = self .get_normalized_probs
159
- avg_wts = None
160
- return get_normalized_probs , avg_wts
164
+ noise_thresholds = None
165
+ return get_normalized_probs , noise_thresholds
161
166
162
167
def preprocess_transition_probs (self ):
163
168
"""Null default preprocess method."""
@@ -221,9 +226,9 @@ def embed(
221
226
class FirstOrderUnweighted (Base , SparseRWGraph ):
222
227
"""Directly sample edges for first order random walks."""
223
228
224
- def __init__ (self , p , q , workers , verbose = False , extend = False ):
229
+ def __init__ (self , * args , ** kwargs ):
225
230
"""Initialize FirstOrderUnweighted mode."""
226
- Base .__init__ (self , p , q , workers , verbose , extend )
231
+ Base .__init__ (self , * args , ** kwargs )
227
232
228
233
def get_move_forward (self ):
229
234
"""Wrap ``move_forward``."""
@@ -241,9 +246,9 @@ def move_forward(cur_idx, prev_idx=None):
241
246
class PreCompFirstOrder (Base , SparseRWGraph ):
242
247
"""Precompute transition probabilities for first order random walks."""
243
248
244
- def __init__ (self , p , q , workers , verbose = False , extend = False ):
249
+ def __init__ (self , * args , ** kwargs ):
245
250
"""Initialize PreCompFirstOrder mode."""
246
- Base .__init__ (self , p , q , workers , verbose , extend )
251
+ Base .__init__ (self , * args , ** kwargs )
247
252
self .alias_j = self .alias_q = None
248
253
249
254
def get_move_forward (self ):
@@ -304,9 +309,9 @@ class PreComp(Base, SparseRWGraph):
304
309
305
310
"""
306
311
307
- def __init__ (self , p , q , workers , verbose = False , extend = False ):
312
+ def __init__ (self , * args , ** kwargs ):
308
313
"""Initialize PreComp mode node2vec."""
309
- Base .__init__ (self , p , q , workers , verbose , extend )
314
+ Base .__init__ (self , * args , ** kwargs )
310
315
self .alias_j = self .alias_q = self .alias_indptr = self .alias_dim = None
311
316
312
317
def get_move_forward (self ):
@@ -390,7 +395,7 @@ def preprocess_transition_probs(self):
390
395
q = self .q
391
396
392
397
# Retrieve transition probability computation callback function
393
- get_normalized_probs , avg_wts = self .setup_get_normalized_probs ()
398
+ get_normalized_probs , noise_thresholds = self .setup_get_normalized_probs ()
394
399
395
400
# Determine the dimensionality of the 2nd order transition probs
396
401
n_nodes = self .indptr .size - 1 # number of nodes
@@ -423,7 +428,7 @@ def compute_all_transition_probs():
423
428
q ,
424
429
idx ,
425
430
nbr ,
426
- avg_wts ,
431
+ noise_thresholds ,
427
432
)
428
433
429
434
start = offset + dim * nbr_idx
@@ -444,9 +449,9 @@ class SparseOTF(Base, SparseRWGraph):
444
449
445
450
"""
446
451
447
- def __init__ (self , p , q , workers , verbose = False , extend = False ):
452
+ def __init__ (self , * args , ** kwargs ):
448
453
"""Initialize PreComp mode node2vec."""
449
- Base .__init__ (self , p , q , workers , verbose , extend )
454
+ Base .__init__ (self , * args , ** kwargs )
450
455
451
456
def get_move_forward (self ):
452
457
"""Wrap ``move_forward``.
@@ -467,7 +472,7 @@ def get_move_forward(self):
467
472
p = self .p
468
473
q = self .q
469
474
470
- get_normalized_probs , avg_wts = self .setup_get_normalized_probs ()
475
+ get_normalized_probs , noise_thresholds = self .setup_get_normalized_probs ()
471
476
472
477
@njit (nogil = True )
473
478
def move_forward (cur_idx , prev_idx = None ):
@@ -480,7 +485,7 @@ def move_forward(cur_idx, prev_idx=None):
480
485
q ,
481
486
cur_idx ,
482
487
prev_idx ,
483
- avg_wts ,
488
+ noise_thresholds ,
484
489
)
485
490
cdf = np .cumsum (normalized_probs )
486
491
choice = np .searchsorted (cdf , np .random .random ())
@@ -499,9 +504,9 @@ class DenseOTF(Base, DenseRWGraph):
499
504
500
505
"""
501
506
502
- def __init__ (self , p , q , workers , verbose = False , extend = False ):
507
+ def __init__ (self , * args , ** kwargs ):
503
508
"""Initialize DenseOTF mode node2vec."""
504
- Base .__init__ (self , p , q , workers , verbose , extend )
509
+ Base .__init__ (self , * args , ** kwargs )
505
510
506
511
def get_move_forward (self ):
507
512
"""Wrap ``move_forward``.
@@ -521,7 +526,7 @@ def get_move_forward(self):
521
526
p = self .p
522
527
q = self .q
523
528
524
- get_normalized_probs , avg_wts = self .setup_get_normalized_probs ()
529
+ get_normalized_probs , noise_thresholds = self .setup_get_normalized_probs ()
525
530
526
531
@njit (nogil = True )
527
532
def move_forward (cur_idx , prev_idx = None ):
@@ -533,7 +538,7 @@ def move_forward(cur_idx, prev_idx=None):
533
538
q ,
534
539
cur_idx ,
535
540
prev_idx ,
536
- avg_wts ,
541
+ noise_thresholds ,
537
542
)
538
543
cdf = np .cumsum (normalized_probs )
539
544
choice = np .searchsorted (cdf , np .random .random ())
0 commit comments