Skip to content

Commit 0a1f294

Browse files
Do for each iteration
1 parent ef1b195 commit 0a1f294

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ Node2VecResult train() {
128128

129129
var lossPerIteration = new ArrayList<Double>();
130130

131+
AtomicInteger taskIndex = new AtomicInteger(0);
132+
131133
for (int iteration = 0; iteration < iterations; iteration++) {
132134
progressTracker.beginSubTask();
133135
progressTracker.setVolume(walks.size());
@@ -137,7 +139,7 @@ Node2VecResult train() {
137139
initialLearningRate - iteration * learningRateAlpha
138140
);
139141

140-
var tasks = createTrainingTasks(learningRate);
142+
var tasks = createTrainingTasks(learningRate, taskIndex);
141143

142144
RunWithConcurrency.builder()
143145
.concurrency(concurrency)
@@ -288,8 +290,7 @@ void addAll(FloatConsumer other) {
288290
}
289291
}
290292

291-
List<TrainingTask> createTrainingTasks(float learningRate){
292-
AtomicInteger taskIndex = new AtomicInteger(0);
293+
List<TrainingTask> createTrainingTasks(float learningRate, AtomicInteger taskIndex){
293294
return PartitionUtils.degreePartitionWithBatchSize(
294295
walks.size(),
295296
walks::walkLength,

0 commit comments

Comments
 (0)