Skip to content

Commit

Permalink
migrate FastRP write
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Jun 26, 2024
1 parent 0b955f4 commit 0bf49a6
Show file tree
Hide file tree
Showing 19 changed files with 295 additions and 154 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
package org.neo4j.gds.algorithms.embeddings;

import org.neo4j.gds.algorithms.estimation.AlgorithmEstimator;
import org.neo4j.gds.embeddings.fastrp.FastRPBaseConfig;
import org.neo4j.gds.embeddings.fastrp.FastRPMemoryEstimateDefinition;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageMemoryEstimateDefinition;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
Expand All @@ -31,7 +30,6 @@
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
import org.neo4j.gds.embeddings.node2vec.Node2VecMemoryEstimateDefinition;
import org.neo4j.gds.modelcatalogservices.ModelCatalogService;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;

import java.util.Optional;

Expand Down Expand Up @@ -92,18 +90,6 @@ public <C extends GraphSageTrainConfig> MemoryEstimateResult graphSageTrain(
);
}

public <C extends FastRPBaseConfig> MemoryEstimateResult fastRP(
Object graphNameOrConfiguration,
C configuration
) {
return algorithmEstimator.estimate(
graphNameOrConfiguration,
configuration,
configuration.relationshipWeightProperty(),
new FastRPMemoryEstimateDefinition(configuration.toParameters())
);
}

public <C extends HashGNNConfig> MemoryEstimateResult hashGNN(
Object graphNameOrConfiguration,
C configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
import org.neo4j.gds.algorithms.runner.AlgorithmRunner;
import org.neo4j.gds.algorithms.validation.AfterLoadValidation;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.embeddings.fastrp.FastRPBaseConfig;
import org.neo4j.gds.embeddings.fastrp.FastRPFactory;
import org.neo4j.gds.embeddings.fastrp.FastRPResult;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageAlgorithmFactory;
Expand Down Expand Up @@ -110,18 +107,6 @@ AlgorithmComputationResult<Model<ModelData, GraphSageTrainConfig, GraphSageModel
);
}

AlgorithmComputationResult<FastRPResult> fastRP(
String graphName,
FastRPBaseConfig config
) {
return algorithmRunner.run(
graphName,
config,
config.relationshipWeightProperty(),
new FastRPFactory<>()
);
}

AlgorithmComputationResult<HashGNNResult> hashGNN(
String graphName,
HashGNNConfig config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@
import org.neo4j.gds.algorithms.NodePropertyWriteResult;
import org.neo4j.gds.algorithms.embeddings.specificfields.Node2VecSpecificFields;
import org.neo4j.gds.algorithms.runner.AlgorithmRunner;
import org.neo4j.gds.applications.algorithms.machinery.WriteNodePropertyService;
import org.neo4j.gds.api.ResultStore;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.api.properties.nodes.NodePropertyValuesAdapter;
import org.neo4j.gds.applications.algorithms.machinery.WriteNodePropertyService;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.ArrowConnectionInfo;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.embeddings.fastrp.FastRPWriteConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageWriteConfig;
import org.neo4j.gds.embeddings.node2vec.Node2VecWriteConfig;

Expand Down Expand Up @@ -102,31 +101,6 @@ public NodePropertyWriteResult<Long> graphSage(
);
}

public NodePropertyWriteResult<Long> fastRP(
String graphName,
FastRPWriteConfig configuration
) {
// 1. Run the algorithm and time the execution
var intermediateResult = AlgorithmRunner.runWithTiming(
() -> nodeEmbeddingsAlgorithmsFacade.fastRP(graphName, configuration)
);

return writeToDatabase(
intermediateResult.algorithmResult,
configuration,
(result) -> NodePropertyValuesAdapter.adapt(result.embeddings()),
(result) -> intermediateResult.algorithmResult.graph().nodeCount(),
intermediateResult.computeMilliseconds,
() -> 0L,
"FastRPWrite",
configuration.writeConcurrency(),
configuration.writeProperty(),
configuration.arrowConnectionInfo(),

configuration.resolveResultStore(intermediateResult.algorithmResult.resultStore())
);
}

<RESULT, CONFIG extends AlgoBaseConfig, ASF> NodePropertyWriteResult<ASF> writeToDatabase(
AlgorithmComputationResult<RESULT> algorithmResult,
CONFIG configuration,
Expand Down
1 change: 1 addition & 0 deletions applications/algorithms/node-embeddings/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies {
implementation project(":algorithms-machinery")
implementation project(":config-api")
implementation project(":core")
implementation project(":logging")
implementation project(":memory-usage")
implementation project(":ml-core")
implementation project(":progress-tracking")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.algorithms.embeddings;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.ResultStore;
import org.neo4j.gds.api.properties.nodes.NodePropertyValuesAdapter;
import org.neo4j.gds.applications.algorithms.machinery.MutateOrWriteStep;
import org.neo4j.gds.applications.algorithms.machinery.WriteToDatabase;
import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten;
import org.neo4j.gds.core.utils.progress.JobId;
import org.neo4j.gds.embeddings.fastrp.FastRPResult;
import org.neo4j.gds.embeddings.fastrp.FastRPWriteConfig;

import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.WCC;

class FastRPWriteStep implements MutateOrWriteStep<FastRPResult, NodePropertiesWritten> {
private final WriteToDatabase writeToDatabase;
private final FastRPWriteConfig configuration;

FastRPWriteStep(WriteToDatabase writeToDatabase, FastRPWriteConfig configuration) {
this.writeToDatabase = writeToDatabase;
this.configuration = configuration;
}

@Override
public NodePropertiesWritten execute(
Graph graph,
GraphStore graphStore,
ResultStore resultStore,
FastRPResult result,
JobId jobId
) {
var nodePropertyValues = NodePropertyValuesAdapter.adapt(result.embeddings());

return writeToDatabase.perform(
graph,
graphStore,
resultStore,
configuration,
configuration,
WCC,
jobId,
nodePropertyValues
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.algorithms.embeddings;

import org.neo4j.gds.api.GraphName;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTemplate;
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
import org.neo4j.gds.applications.algorithms.machinery.ResultBuilder;
import org.neo4j.gds.applications.algorithms.machinery.WriteNodePropertyService;
import org.neo4j.gds.applications.algorithms.machinery.WriteToDatabase;
import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten;
import org.neo4j.gds.embeddings.fastrp.FastRPResult;
import org.neo4j.gds.embeddings.fastrp.FastRPWriteConfig;
import org.neo4j.gds.logging.Log;

import java.util.Optional;

import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.FastRP;

public final class NodeEmbeddingAlgorithmsWriteModeBusinessFacade {
private final NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationFacade;
private final NodeEmbeddingAlgorithms algorithms;
private final AlgorithmProcessingTemplate algorithmProcessingTemplate;
private final WriteToDatabase writeToDatabase;

private NodeEmbeddingAlgorithmsWriteModeBusinessFacade(
NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationFacade,
NodeEmbeddingAlgorithms algorithms,
AlgorithmProcessingTemplate algorithmProcessingTemplate,
WriteToDatabase writeToDatabase
) {
this.estimationFacade = estimationFacade;
this.algorithms = algorithms;
this.algorithmProcessingTemplate = algorithmProcessingTemplate;
this.writeToDatabase = writeToDatabase;
}

public static NodeEmbeddingAlgorithmsWriteModeBusinessFacade create(
Log log,
RequestScopedDependencies requestScopedDependencies,
NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationFacade,
NodeEmbeddingAlgorithms algorithms,
AlgorithmProcessingTemplate algorithmProcessingTemplate
) {
var writeNodePropertyService = new WriteNodePropertyService(log, requestScopedDependencies);
var writeToDatabase = new WriteToDatabase(writeNodePropertyService);

return new NodeEmbeddingAlgorithmsWriteModeBusinessFacade(
estimationFacade,
algorithms,
algorithmProcessingTemplate,
writeToDatabase
);
}

public <RESULT> RESULT fastRP(
GraphName graphName,
FastRPWriteConfig configuration,
ResultBuilder<FastRPWriteConfig, FastRPResult, RESULT, NodePropertiesWritten> resultBuilder
) {
var writeStep = new FastRPWriteStep(writeToDatabase, configuration);

return algorithmProcessingTemplate.processAlgorithm(
graphName,
configuration,
FastRP,
() -> estimationFacade.fastRP(configuration),
graph -> algorithms.fastRP(graph, configuration),
Optional.of(writeStep),
resultBuilder
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ public static ApplicationsFacade create(
);

var nodeEmbeddingApplications = NodeEmbeddingApplications.create(
log,
requestScopedDependencies,
algorithmEstimationTemplate,
algorithmProcessingTemplate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,37 @@
import org.neo4j.gds.applications.algorithms.embeddings.NodeEmbeddingAlgorithmsMutateModeBusinessFacade;
import org.neo4j.gds.applications.algorithms.embeddings.NodeEmbeddingAlgorithmsStatsModeBusinessFacade;
import org.neo4j.gds.applications.algorithms.embeddings.NodeEmbeddingAlgorithmsStreamModeBusinessFacade;
import org.neo4j.gds.applications.algorithms.embeddings.NodeEmbeddingAlgorithmsWriteModeBusinessFacade;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmEstimationTemplate;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTemplate;
import org.neo4j.gds.applications.algorithms.machinery.MutateNodeProperty;
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
import org.neo4j.gds.logging.Log;

public final class NodeEmbeddingApplications {
private final NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationMode;
private final NodeEmbeddingAlgorithmsMutateModeBusinessFacade mutateMode;
private final NodeEmbeddingAlgorithmsStatsModeBusinessFacade statsMode;
private final NodeEmbeddingAlgorithmsStreamModeBusinessFacade streamMode;
private final NodeEmbeddingAlgorithmsWriteModeBusinessFacade writeMode;

private NodeEmbeddingApplications(
NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationMode,
NodeEmbeddingAlgorithmsMutateModeBusinessFacade mutateMode,
NodeEmbeddingAlgorithmsStatsModeBusinessFacade statsMode,
NodeEmbeddingAlgorithmsStreamModeBusinessFacade streamMode
NodeEmbeddingAlgorithmsStreamModeBusinessFacade streamMode,
NodeEmbeddingAlgorithmsWriteModeBusinessFacade writeMode
) {
this.estimationMode = estimationMode;
this.mutateMode = mutateMode;
this.statsMode = statsMode;
this.streamMode = streamMode;
this.writeMode = writeMode;
}

static NodeEmbeddingApplications create(
Log log,
RequestScopedDependencies requestScopedDependencies,
AlgorithmEstimationTemplate algorithmEstimationTemplate,
AlgorithmProcessingTemplate algorithmProcessingTemplate,
Expand Down Expand Up @@ -77,8 +83,15 @@ static NodeEmbeddingApplications create(
algorithms,
algorithmProcessingTemplate
);
var writeMode = NodeEmbeddingAlgorithmsWriteModeBusinessFacade.create(
log,
requestScopedDependencies,
estimationMode,
algorithms,
algorithmProcessingTemplate
);

return new NodeEmbeddingApplications(estimationMode, mutateMode, statsMode, streamMode);
return new NodeEmbeddingApplications(estimationMode, mutateMode, statsMode, streamMode, writeMode);
}

public NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimate() {
Expand All @@ -96,4 +109,8 @@ public NodeEmbeddingAlgorithmsStatsModeBusinessFacade stats() {
public NodeEmbeddingAlgorithmsStreamModeBusinessFacade stream() {
return streamMode;
}

public NodeEmbeddingAlgorithmsWriteModeBusinessFacade write() {
return writeMode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.embeddings.results.DefaultNodeEmbeddingsWriteResult;
import org.neo4j.gds.procedures.algorithms.embeddings.DefaultNodeEmbeddingsWriteResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
Expand All @@ -45,7 +45,7 @@ public Stream<DefaultNodeEmbeddingsWriteResult> write(
@Name(value = "graphName") String graphName,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
) {
return facade.oldNodeEmbeddings().fastRP().write(graphName, configuration);
return facade.algorithms().nodeEmbeddings().fastRPWrite(graphName, configuration);
}

@Procedure(value = "gds.fastRP.write.estimate", mode = READ)
Expand All @@ -54,6 +54,6 @@ public Stream<MemoryEstimateResult> estimate(
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
) {
return facade.oldNodeEmbeddings().fastRP().writeEstimate(graphNameOrConfiguration, algoConfiguration);
return facade.algorithms().nodeEmbeddings().fastRPWriteEstimate(graphNameOrConfiguration, algoConfiguration);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
import org.neo4j.gds.procedures.embeddings.results.DefaultNodeEmbeddingsWriteResult;
import org.neo4j.gds.procedures.algorithms.embeddings.DefaultNodeEmbeddingsWriteResult;

import java.util.stream.Stream;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
package org.neo4j.gds.embeddings.graphsage;

import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.embeddings.results.DefaultNodeEmbeddingsWriteResult;
import org.neo4j.gds.procedures.algorithms.embeddings.DefaultNodeEmbeddingsWriteResult;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
Expand Down
Loading

0 comments on commit 0bf49a6

Please sign in to comment.