Skip to content

Commit 9da0e1d

Browse files
Include the total distance in the EMST result
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent 7810aac commit 9da0e1d

File tree

4 files changed

+62
-11
lines changed

4 files changed

+62
-11
lines changed

algo/src/main/java/org/neo4j/gds/hdbscan/DualTreeMSTAlgorithm.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import org.neo4j.gds.core.utils.paged.dss.HugeAtomicDisjointSetStruct;
3232
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3333

34-
public class DualTreeMSTAlgorithm extends Algorithm<HugeObjectArray<Edge>> {
34+
public class DualTreeMSTAlgorithm extends Algorithm<DualTreeMSTResult> {
3535

3636
private final NodePropertyValues nodePropertyValues;
3737
private final KdTree kdTree;
@@ -44,6 +44,7 @@ public class DualTreeMSTAlgorithm extends Algorithm<HugeObjectArray<Edge>> {
4444

4545
private final HugeObjectArray<Edge> edges;
4646
private long edgeCount = 0;
47+
private double totalEdgeSum = 0d;
4748

4849
public DualTreeMSTAlgorithm(
4950
NodePropertyValues nodePropertyValues,
@@ -69,15 +70,15 @@ public DualTreeMSTAlgorithm(
6970

7071

7172
/**
 * Runs the dual-tree MST computation.
 *
 * Keeps calling {@code performIteration()} until {@code kdNodeSingleComponent}
 * reports {@code true} for the id of the kd-tree root, then returns the collected
 * edges. After this change the edges are wrapped in a {@code DualTreeMSTResult}
 * together with {@code totalEdgeSum}, the running sum of the distances of the
 * edges added in {@code mergeComponents()}.
 *
 * @return the collected edges and the sum of their distances
 */
@Override
72-
public HugeObjectArray<Edge> compute() {
73+
public DualTreeMSTResult compute() {
7374

7475
var kdRoot = kdTree.root();
7576
var rootId = kdRoot.id();
7677
// Loop for as long as the root kd-node is not yet flagged in kdNodeSingleComponent.
while (!kdNodeSingleComponent.get(rootId)) {
7778
// Reset every kd-node bound to the largest double before each iteration.
kdNodeBound.fill(Double.MAX_VALUE);
7879
performIteration();
7980
}
80-
return edges;
81+
// New return value: the edge array plus the accumulated distance total.
return new DualTreeMSTResult(edges, totalEdgeSum);
8182
}
8283

8384
void baseCase(long p0, long p1, MutableDouble maxBound) {
@@ -252,11 +253,13 @@ void mergeComponents() {
252253
continue;
253254
}
254255

256+
var distance = Math.sqrt(closestDistanceTracker.componentClosestDistance(componentId));
255257
this.edges.set(
256258
edgeCount,
257-
new Edge(u, v, Math.sqrt(closestDistanceTracker.componentClosestDistance(componentId)))
259+
new Edge(u, v, distance)
258260
);
259261
this.edgeCount++;
262+
this.totalEdgeSum += distance;
260263

261264
unionFind.union(uComponent, vComponent);
262265
}
DualTreeMSTResult.java (new file, package org.neo4j.gds.hdbscan)

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.hdbscan;
21+
22+
import org.neo4j.gds.collections.ha.HugeObjectArray;
23+
24+
/**
 * Result of {@code DualTreeMSTAlgorithm#compute()}: the computed edges together
 * with the sum of their distances.
 *
 * @param edges         the edges collected by the algorithm
 * @param totalDistance the sum of the distances of the edges added by the algorithm
 */
public record DualTreeMSTResult(HugeObjectArray<Edge> edges, double totalDistance) {
25+
}

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScan.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ HugeObjectArray<Edge> dualTreeMSTPhase(KdTree kdTree, HugeDoubleArray coreValues
7979
kdTree,coreValues,
8080
nodes.nodeCount()
8181
);
82-
return dualTreeMst.compute();
82+
return dualTreeMst.compute().edges();
8383
}
8484
KdTree buildKDTree(){
8585
var builder =new KdTreeBuilder(nodes,nodePropertyValues,concurrency.value(),leafSize);

algo/src/test/java/org/neo4j/gds/hdbscan/DualTreeMSTAlgorithmTest.java

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.hdbscan;
2121

22+
import org.assertj.core.data.Offset;
2223
import org.junit.jupiter.api.Nested;
2324
import org.junit.jupiter.api.Test;
2425
import org.neo4j.gds.collections.ha.HugeDoubleArray;
@@ -62,7 +63,7 @@ void shouldReturnEuclideanMSTWithZeroCoreValues() {
6263
var cores = HugeDoubleArray.newArray(graph.nodeCount());
6364

6465
var dualTree = new DualTreeMSTAlgorithm(nodePropertyValues, kdTree, cores, graph.nodeCount());
65-
var edges = dualTree.compute();
66+
var result = dualTree.compute();
6667

6768
var expected = List.of(
6869
new Edge(graph.toMappedNodeId("g"), graph.toMappedNodeId("h"), 2.23606797749979),
@@ -75,9 +76,15 @@ void shouldReturnEuclideanMSTWithZeroCoreValues() {
7576
new Edge(graph.toMappedNodeId("b"), graph.toMappedNodeId("d"), 3.1622776601683795)
7677
);
7778

78-
assertThat(edges.toArray())
79+
assertThat(result.edges().toArray())
7980
.usingElementComparator(new UndirectedEdgeComparator())
8081
.containsExactlyInAnyOrderElementsOf(expected);
82+
83+
assertThat(result.totalDistance())
84+
.isEqualTo(expected.stream()
85+
.mapToDouble(Edge::distance)
86+
.sum()
87+
);
8188
}
8289

8390
}
@@ -107,7 +114,7 @@ void shouldReturnEuclideanMSTWithZeroCoreValues() {
107114
var cores = HugeDoubleArray.newArray(graph.nodeCount());
108115

109116
var dualTree = new DualTreeMSTAlgorithm(nodePropertyValues, kdTree, cores, graph.nodeCount());
110-
var edges = dualTree.compute();
117+
var result = dualTree.compute();
111118

112119
var expected = List.of(
113120
new Edge(graph.toMappedNodeId("a"), graph.toMappedNodeId("b"), 3.1622776601683795),
@@ -117,9 +124,17 @@ void shouldReturnEuclideanMSTWithZeroCoreValues() {
117124
new Edge(graph.toMappedNodeId("e"), graph.toMappedNodeId("f"), 1.4142135623730951)
118125
);
119126

120-
assertThat(edges.toArray())
127+
assertThat(result.edges().toArray())
121128
.usingElementComparator(new UndirectedEdgeComparator())
122129
.containsExactlyInAnyOrderElementsOf(expected);
130+
131+
assertThat(result.totalDistance())
132+
.isCloseTo(
133+
expected.stream()
134+
.mapToDouble(Edge::distance)
135+
.sum(),
136+
Offset.offset(1e-12)
137+
);
123138
}
124139

125140
}
@@ -150,7 +165,7 @@ void shouldReturnEuclideanMSTWithZeroCoreValues() {
150165
var cores = HugeDoubleArray.newArray(graph.nodeCount());
151166

152167
var dualTree = new DualTreeMSTAlgorithm(nodePropertyValues, kdTree, cores, graph.nodeCount());
153-
var edges = dualTree.compute();
168+
var result = dualTree.compute();
154169

155170
var expected = List.of(
156171
new Edge(graph.toMappedNodeId("a"), graph.toMappedNodeId("d"), 0.5),
@@ -160,9 +175,17 @@ void shouldReturnEuclideanMSTWithZeroCoreValues() {
160175
new Edge(graph.toMappedNodeId("c"), graph.toMappedNodeId("e"), 997.0045135304052)
161176
);
162177

163-
assertThat(edges.toArray())
178+
assertThat(result.edges().toArray())
164179
.usingElementComparator(new UndirectedEdgeComparator())
165180
.containsExactlyInAnyOrderElementsOf(expected);
181+
182+
assertThat(result.totalDistance())
183+
.isCloseTo(
184+
expected.stream()
185+
.mapToDouble(Edge::distance)
186+
.sum(),
187+
Offset.offset(1e-12)
188+
);
166189
}
167190

168191
}

0 commit comments

Comments
 (0)