Skip to content

Commit d67611a

Browse files
authored
Merge pull request #9575 from lassewesth/pipe2
migrate nc pipelines
2 parents 8852e89 + 1c207b0 commit d67611a

23 files changed

+470
-232
lines changed

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddTrainerMethodProcs.java

Lines changed: 10 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,9 @@
1919
*/
2020
package org.neo4j.gds.ml.pipeline.node.classification;
2121

22-
import org.neo4j.gds.BaseProc;
23-
import org.neo4j.gds.core.ConfigKeyValidation;
24-
import org.neo4j.gds.ml.api.TrainingMethod;
25-
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
26-
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
27-
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
28-
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
29-
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
22+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
3023
import org.neo4j.gds.procedures.pipelines.NodePipelineInfoResult;
31-
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
24+
import org.neo4j.procedure.Context;
3225
import org.neo4j.procedure.Description;
3326
import org.neo4j.procedure.Internal;
3427
import org.neo4j.procedure.Name;
@@ -39,25 +32,17 @@
3932

4033
import static org.neo4j.procedure.Mode.READ;
4134

42-
public class NodeClassificationPipelineAddTrainerMethodProcs extends BaseProc {
35+
public class NodeClassificationPipelineAddTrainerMethodProcs {
36+
@Context
37+
public GraphDataScienceProcedures facade;
4338

4439
@Procedure(name = "gds.beta.pipeline.nodeClassification.addLogisticRegression", mode = READ)
4540
@Description("Add a logistic regression configuration to the parameter space of the node classification train pipeline.")
4641
public Stream<NodePipelineInfoResult> addLogisticRegression(
4742
@Name("pipelineName") String pipelineName,
4843
@Name(value = "config", defaultValue = "{}") Map<String, Object> logisticRegressionClassifierConfig
4944
) {
50-
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);
51-
52-
var allowedKeys = LogisticRegressionTrainConfig.DEFAULT.configKeys();
53-
ConfigKeyValidation.requireOnlyKeysFrom(allowedKeys, logisticRegressionClassifierConfig.keySet());
54-
55-
var tunableTrainerConfig = TunableTrainerConfig.of(logisticRegressionClassifierConfig, TrainingMethod.LogisticRegression);
56-
pipeline.addTrainerConfig(
57-
tunableTrainerConfig
58-
);
59-
60-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
45+
return facade.pipelines().addLogisticRegression(pipelineName, logisticRegressionClassifierConfig);
6146
}
6247

6348
@Procedure(name = "gds.beta.pipeline.nodeClassification.addRandomForest", mode = READ)
@@ -66,17 +51,7 @@ public Stream<NodePipelineInfoResult> addRandomForest(
6651
@Name("pipelineName") String pipelineName,
6752
@Name(value = "config") Map<String, Object> randomForestClassifierConfig
6853
) {
69-
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);
70-
71-
var allowedKeys = RandomForestClassifierTrainerConfig.DEFAULT.configKeys();
72-
ConfigKeyValidation.requireOnlyKeysFrom(allowedKeys, randomForestClassifierConfig.keySet());
73-
74-
var tunableTrainerConfig = TunableTrainerConfig.of(randomForestClassifierConfig, TrainingMethod.RandomForestClassification);
75-
pipeline.addTrainerConfig(
76-
tunableTrainerConfig
77-
);
78-
79-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
54+
return facade.pipelines().addRandomForest(pipelineName, randomForestClassifierConfig);
8055
}
8156

8257
@Procedure(name = "gds.alpha.pipeline.nodeClassification.addRandomForest", mode = READ, deprecatedBy = "gds.beta.pipeline.nodeClassification.addRandomForest")
@@ -87,14 +62,9 @@ public Stream<NodePipelineInfoResult> addRandomForestAlpha(
8762
@Name("pipelineName") String pipelineName,
8863
@Name(value = "config") Map<String, Object> randomForestClassifierConfig
8964
) {
90-
executionContext()
91-
.metricsFacade()
92-
.deprecatedProcedures().called("gds.alpha.pipeline.nodeClassification.addRandomForest");
65+
facade.deprecatedProcedures().called("gds.alpha.pipeline.nodeClassification.addRandomForest");
66+
facade.log().warn("Procedure `gds.alpha.pipeline.nodeClassification.addRandomForest` has been deprecated, please use `gds.beta.pipeline.nodeClassification.addRandomForest`.");
9367

94-
executionContext()
95-
.log()
96-
.warn(
97-
"Procedure `gds.alpha.pipeline.nodeClassification.addRandomForest` has been deprecated, please use `gds.beta.pipeline.nodeClassification.addRandomForest`.");
9868
return addRandomForest(pipelineName, randomForestClassifierConfig);
9969
}
10070

@@ -104,13 +74,6 @@ public Stream<NodePipelineInfoResult> addMLP(
10474
@Name("pipelineName") String pipelineName,
10575
@Name(value = "config", defaultValue = "{}") Map<String, Object> mlpClassifierConfig
10676
) {
107-
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);
108-
109-
var allowedKeys = MLPClassifierTrainConfig.DEFAULT.configKeys();
110-
ConfigKeyValidation.requireOnlyKeysFrom(allowedKeys, mlpClassifierConfig.keySet());
111-
112-
pipeline.addTrainerConfig(TunableTrainerConfig.of(mlpClassifierConfig, TrainingMethod.MLPClassification));
113-
114-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
77+
return facade.pipelines().addMLP(pipelineName, mlpClassifierConfig);
11578
}
11679
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineConfigureAutoTuningProc.java

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
*/
2020
package org.neo4j.gds.ml.pipeline.node.classification;
2121

22-
import org.neo4j.gds.BaseProc;
23-
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
24-
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
22+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
2523
import org.neo4j.gds.procedures.pipelines.NodePipelineInfoResult;
26-
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
24+
import org.neo4j.procedure.Context;
2725
import org.neo4j.procedure.Description;
2826
import org.neo4j.procedure.Name;
2927
import org.neo4j.procedure.Procedure;
@@ -33,18 +31,13 @@
3331

3432
import static org.neo4j.procedure.Mode.READ;
3533

36-
public class NodeClassificationPipelineConfigureAutoTuningProc extends BaseProc {
34+
public class NodeClassificationPipelineConfigureAutoTuningProc {
35+
@Context
36+
public GraphDataScienceProcedures facade;
3737

3838
@Procedure(name = "gds.alpha.pipeline.nodeClassification.configureAutoTuning", mode = READ)
3939
@Description("Configures the auto-tuning of the node classification pipeline.")
4040
public Stream<NodePipelineInfoResult> configureAutoTuning(@Name("pipelineName") String pipelineName, @Name("configuration") Map<String, Object> configMap) {
41-
PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);
42-
return PipelineCompanion.configureAutoTuning(
43-
username(),
44-
pipelineName,
45-
configMap,
46-
pipeline -> new NodePipelineInfoResult(pipelineName, (NodeClassificationTrainingPipeline) pipeline)
47-
);
41+
return facade.pipelines().configureAutoTuning(pipelineName, configMap);
4842
}
49-
5043
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineConfigureSplitProc.java

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,9 @@
1919
*/
2020
package org.neo4j.gds.ml.pipeline.node.classification;
2121

22-
import org.neo4j.gds.BaseProc;
23-
import org.neo4j.gds.core.CypherMapWrapper;
24-
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
22+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
2523
import org.neo4j.gds.procedures.pipelines.NodePipelineInfoResult;
26-
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
27-
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
24+
import org.neo4j.procedure.Context;
2825
import org.neo4j.procedure.Description;
2926
import org.neo4j.procedure.Name;
3027
import org.neo4j.procedure.Procedure;
@@ -34,20 +31,13 @@
3431

3532
import static org.neo4j.procedure.Mode.READ;
3633

37-
public class NodeClassificationPipelineConfigureSplitProc extends BaseProc {
34+
public class NodeClassificationPipelineConfigureSplitProc {
35+
@Context
36+
public GraphDataScienceProcedures facade;
3837

3938
@Procedure(name = "gds.beta.pipeline.nodeClassification.configureSplit", mode = READ)
4039
@Description("Configures the split of the node classification training pipeline.")
4140
public Stream<NodePipelineInfoResult> configureSplit(@Name("pipelineName") String pipelineName, @Name("configuration") Map<String, Object> configMap) {
42-
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);
43-
44-
var cypherConfig = CypherMapWrapper.create(configMap);
45-
var config = NodePropertyPredictionSplitConfig.of(cypherConfig);
46-
47-
cypherConfig.requireOnlyKeysFrom(config.configKeys());
48-
49-
pipeline.setSplitConfig(config);
50-
51-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
41+
return facade.pipelines().configureSplit(pipelineName, configMap);
5242
}
5343
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineCreateProc.java

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
*/
2020
package org.neo4j.gds.ml.pipeline.node.classification;
2121

22-
import org.neo4j.gds.BaseProc;
23-
import org.neo4j.gds.core.StringIdentifierValidations;
24-
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
22+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
2523
import org.neo4j.gds.procedures.pipelines.NodePipelineInfoResult;
26-
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
24+
import org.neo4j.procedure.Context;
2725
import org.neo4j.procedure.Description;
2826
import org.neo4j.procedure.Name;
2927
import org.neo4j.procedure.Procedure;
@@ -32,22 +30,13 @@
3230

3331
import static org.neo4j.procedure.Mode.READ;
3432

35-
@SuppressWarnings("immutables:subtype")
36-
public class NodeClassificationPipelineCreateProc extends BaseProc {
37-
38-
public static NodePipelineInfoResult create(String username, String pipelineName) {
39-
StringIdentifierValidations.validateNoWhiteCharacter(pipelineName, "pipelineName");
40-
41-
var pipeline = new NodeClassificationTrainingPipeline();
42-
43-
PipelineCatalog.set(username, pipelineName, pipeline);
44-
45-
return new NodePipelineInfoResult(pipelineName, pipeline);
46-
}
33+
public class NodeClassificationPipelineCreateProc {
34+
@Context
35+
public GraphDataScienceProcedures facade;
4736

4837
@Procedure(name = "gds.beta.pipeline.nodeClassification.create", mode = READ)
4938
@Description("Creates a node classification training pipeline in the pipeline catalog.")
5039
public Stream<NodePipelineInfoResult> create(@Name("pipelineName") String pipelineName) {
51-
return Stream.of(create(username(), pipelineName));
40+
return facade.pipelines().createPipeline(pipelineName);
5241
}
5342
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/configure/NodeRegressionPipelineAddStepProcs.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public Stream<NodePipelineInfoResult> addNodeProperty(
4848

4949
pipeline.addNodePropertyStep(createNodePropertyStep(taskName, procedureConfig));
5050

51-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
51+
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
5252
}
5353

5454
@Procedure(name = "gds.alpha.pipeline.nodeRegression.selectFeatures", mode = READ)
@@ -74,6 +74,6 @@ public Stream<NodePipelineInfoResult> selectFeatures(
7474
throw new IllegalArgumentException("The value of `featureProperties` is required to be a list of strings.");
7575
}
7676

77-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
77+
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
7878
}
7979
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/configure/NodeRegressionPipelineAddTrainerMethodProcs.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public Stream<NodePipelineInfoResult> addLogisticRegression(
5252

5353
pipeline.addTrainerConfig(TunableTrainerConfig.of(configuration, TrainingMethod.LinearRegression));
5454

55-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
55+
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
5656
}
5757

5858
@Procedure(name = "gds.alpha.pipeline.nodeRegression.addRandomForest", mode = READ)
@@ -68,6 +68,6 @@ public Stream<NodePipelineInfoResult> addRandomForest(
6868

6969
pipeline.addTrainerConfig(TunableTrainerConfig.of(configuration, TrainingMethod.RandomForestRegression));
7070

71-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
71+
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
7272
}
7373
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/configure/NodeRegressionPipelineConfigureAutoTuningProc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public Stream<NodePipelineInfoResult> configureAutoTuning(@Name("pipelineName")
4343
username(),
4444
pipelineName,
4545
configMap,
46-
pipeline -> new NodePipelineInfoResult(pipelineName, (NodeRegressionTrainingPipeline) pipeline)
46+
pipeline -> NodePipelineInfoResult.create(pipelineName, (NodeRegressionTrainingPipeline) pipeline)
4747
);
4848
}
4949

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/configure/NodeRegressionPipelineConfigureSplitProc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ public Stream<NodePipelineInfoResult> configureSplit(
5050

5151
pipeline.setSplitConfig(config);
5252

53-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
53+
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
5454
}
5555
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/configure/NodeRegressionPipelineCreateProc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,6 @@ public Stream<NodePipelineInfoResult> create(@Name("pipelineName") String pipeli
4343

4444
PipelineCatalog.set(username(), pipelineName, pipeline);
4545

46-
return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
46+
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
4747
}
4848
}

0 commit comments

Comments
 (0)