From 8aaf93ceb9be352a5aaa06caffa93774c4de21e7 Mon Sep 17 00:00:00 2001 From: Jacob Sznajdman Date: Wed, 26 Apr 2023 12:56:39 +0200 Subject: [PATCH 1/3] Fix sampling bugs related to idmaps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Florentin Dörre Co-Authored-By: Adam Schill Collberg --- .../ml/negativeSampling/NegativeSampler.java | 8 +- .../UserInputNegativeSampler.java | 7 +- .../ml/splitting/DirectedEdgeSplitter.java | 11 ++- .../neo4j/gds/ml/splitting/EdgeSplitter.java | 11 ++- .../gds/ml/splitting/SplitRelationships.java | 11 ++- .../ml/splitting/UndirectedEdgeSplitter.java | 5 +- .../splitting/DirectedEdgeSplitterTest.java | 10 ++- .../splitting/UndirectedEdgeSplitterTest.java | 25 ++++-- .../UserInputNegativeSamplerTest.java | 13 ++- .../LinkPredictionRelationshipSampler.java | 9 ++- ...LinkPredictionRelationshipSamplerTest.java | 79 ++++++++++++++++++- 11 files changed, 163 insertions(+), 26 deletions(-) diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/NegativeSampler.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/NegativeSampler.java index 0f222872440..40a0f92f6cd 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/NegativeSampler.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/NegativeSampler.java @@ -27,6 +27,7 @@ import org.neo4j.gds.core.loading.construction.RelationshipsBuilder; import java.util.Collection; +import java.util.List; import java.util.Optional; public interface NegativeSampler { @@ -36,6 +37,7 @@ public interface NegativeSampler { static NegativeSampler of( GraphStore graphStore, Graph graph, + Collection sourceAndTargetNodeLabels, Optional negativeRelationshipType, double negativeSamplingRatio, long testPositiveCount, @@ -47,7 +49,11 @@ static NegativeSampler of( Optional randomSeed ) { if (negativeRelationshipType.isPresent()) { - Graph negativeExampleGraph = graphStore.getGraph(RelationshipType.of(negativeRelationshipType.orElseThrow())); + Graph negativeExampleGraph = graphStore.getGraph( + sourceAndTargetNodeLabels, + List.of(RelationshipType.of(negativeRelationshipType.orElseThrow())), + Optional.empty() + ); double testTrainFraction = testPositiveCount / (double) (testPositiveCount + trainPositiveCount); return new UserInputNegativeSampler( diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.java index b19af6fc804..7a3db98a57f 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.java @@ -64,13 +64,16 @@ public void produceNegativeSamples( negativeExampleGraph.forEachNode(nodeId -> { negativeExampleGraph.forEachRelationship(nodeId, (s, t) -> { + // as we are adding the relationships to the GraphStore we need to operate over the rootNodeIds + long rootS = negativeExampleGraph.toRootNodeId(s); + long rootT = negativeExampleGraph.toRootNodeId(t); if (s < t) { if (sample(testRelationshipsToAdd.doubleValue()/(testRelationshipsToAdd.doubleValue() + trainRelationshipsToAdd.doubleValue()))) { testRelationshipsToAdd.decrement(); - testSetBuilder.add(s, t, NEGATIVE); + testSetBuilder.addFromInternal(rootS, rootT, NEGATIVE); } else { trainRelationshipsToAdd.decrement(); - trainSetBuilder.add(s, t, NEGATIVE); + trainSetBuilder.addFromInternal(rootS, rootT, NEGATIVE); } } return true; diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.java index 44b46d421ec..3c86dff2ba4 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.java @@ -36,13 +36,22 @@ public class DirectedEdgeSplitter extends EdgeSplitter { public DirectedEdgeSplitter( Optional maybeSeed, + IdMap rootNodes, IdMap sourceLabels, IdMap targetLabels, RelationshipType selectedRelationshipType, RelationshipType remainingRelationshipType, int concurrency ) { - super(maybeSeed, sourceLabels, targetLabels, selectedRelationshipType, remainingRelationshipType, concurrency); + super( + maybeSeed, + rootNodes, + sourceLabels, + targetLabels, + selectedRelationshipType, + remainingRelationshipType, + concurrency + ); } @Override diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/EdgeSplitter.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/EdgeSplitter.java index 8e6d34ece1e..db52570ae92 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/EdgeSplitter.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/EdgeSplitter.java @@ -47,17 +47,20 @@ public abstract class EdgeSplitter { protected final IdMap sourceNodes; protected final IdMap targetNodes; + protected final IdMap rootNodes; protected int concurrency; EdgeSplitter( Optional maybeSeed, + IdMap rootNodes, IdMap sourceNodes, IdMap targetNodes, RelationshipType selectedRelationshipType, RelationshipType remainingRelationshipType, int concurrency ) { + this.rootNodes = rootNodes; this.selectedRelationshipType = selectedRelationshipType; this.remainingRelationshipType = remainingRelationshipType; this.rng = new Random(); @@ -78,7 +81,7 @@ public SplitResult splitPositiveExamples( LongLongPredicate isValidNodePair = (s, t) -> isValidSourceNode.apply(s) && isValidTargetNode.apply(t); RelationshipsBuilder selectedRelsBuilder = newRelationshipsBuilder( - graph, + rootNodes, selectedRelationshipType, Direction.DIRECTED, Optional.of(EdgeSplitter.RELATIONSHIP_PROPERTY) @@ -89,7 +92,7 @@ public SplitResult splitPositiveExamples( RelationshipsBuilder remainingRelsBuilder; RelationshipWithPropertyConsumer remainingRelsConsumer; - remainingRelsBuilder = newRelationshipsBuilder(graph, remainingRelationshipType, remainingRelDirection, remainingRelPropertyKey); + remainingRelsBuilder = newRelationshipsBuilder(rootNodes, remainingRelationshipType, remainingRelDirection, remainingRelPropertyKey); remainingRelsConsumer = (s, t, w) -> { remainingRelsBuilder.addFromInternal(graph.toRootNodeId(s), graph.toRootNodeId(t), w); return true; @@ -153,7 +156,7 @@ protected long samplesPerNode(long maxSamples, double remainingSamples, long rem } private static RelationshipsBuilder newRelationshipsBuilder( - Graph graph, + IdMap rootNodes, RelationshipType relationshipType, Direction direction, Optional propertyKey @@ -161,7 +164,7 @@ private static RelationshipsBuilder newRelationshipsBuilder( return GraphFactory.initRelationshipsBuilder() .relationshipType(relationshipType) .aggregation(Aggregation.SINGLE) - .nodes(graph) + .nodes(rootNodes) .orientation(direction.toOrientation()) .addAllPropertyConfigs(propertyKey .map(key -> List.of(GraphFactory.PropertyConfig.of(key, Aggregation.SINGLE, DefaultValue.forDouble()))) diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java index f09e8e3e5df..86563c863c3 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java @@ -40,14 +40,19 @@ public final class SplitRelationships extends Algorithm maybeSeed, + IdMap rootNodes, IdMap sourceNodes, IdMap targetNodes, RelationshipType selectedRelationshipType, RelationshipType remainingRelationshipType, int concurrency ) { - super(maybeSeed, sourceNodes, targetNodes, selectedRelationshipType, remainingRelationshipType, concurrency); + super(maybeSeed, + rootNodes, + sourceNodes, targetNodes, selectedRelationshipType, remainingRelationshipType, concurrency); } @Override diff --git a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java index 025b48c97ed..2a4340c5f56 100644 --- a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java +++ b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java @@ -131,6 +131,7 @@ void splitSkewedGraph() { Optional.of(-1L), skewedGraphStore.nodes(), skewedGraphStore.nodes(), + skewedGraphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -150,6 +151,7 @@ void splitMultiGraph() { Optional.of(-1L), multiGraphStore.nodes(), multiGraphStore.nodes(), + multiGraphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -170,6 +172,7 @@ void split() { Optional.of(-1L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -207,6 +210,7 @@ void negativeEdgesShouldNotOverlapMasterGraph() { .generate(); var splitter = new DirectedEdgeSplitter(Optional.of(42L), + huuuuugeDenseGraph, huuuuugeDenseGraph, huuuuugeDenseGraph, RelationshipType.of("SELECTED"), @@ -241,6 +245,7 @@ void negativeEdgeSampling() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -261,6 +266,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() { Collection targetNodeLabels = List.of(NodeLabel.of("C"), NodeLabel.of("D")); var splitter = new DirectedEdgeSplitter( Optional.of(1337L), + multiLabelGraphStore.nodes(), multiLabelGraphStore.getGraph(sourceNodeLabels), multiLabelGraphStore.getGraph(targetNodeLabels), RelationshipType.of("SELECTED"), @@ -280,7 +286,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() { var selectedRelationships = result.selectedRels().build(); assertThat(selectedRelationships.topology()).satisfies(topology -> { - assertRelSamplingProperties(selectedRelationships, multiLabelGraph); + assertRelSamplingProperties(selectedRelationships, multiLabelGraphStore); assertThat(topology.elementCount()).isEqualTo(1); assertFalse(topology.isMultiGraph()); }); @@ -295,6 +301,7 @@ void samplesWithinBounds() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -310,6 +317,7 @@ void shouldPreserveRelationshipWeights() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 diff --git a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitterTest.java b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitterTest.java index 690731bcad0..2953286b5d6 100644 --- a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitterTest.java +++ b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitterTest.java @@ -97,6 +97,7 @@ void split() { Optional.of(1337L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -127,6 +128,7 @@ void splitMultiGraph() { Optional.of(-1L), multiGraphStore.nodes(), multiGraphStore.nodes(), + multiGraphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -157,6 +159,7 @@ void negativeEdgesShouldNotOverlapMasterGraph() { Optional.of(42L), huuuuugeDenseGraph, huuuuugeDenseGraph, + huuuuugeDenseGraph, RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -198,29 +201,31 @@ void shouldProduceDeterministicResult() { var splitResult1 = new UndirectedEdgeSplitter( Optional.of(12L), - graphStore.nodes(), - graphStore.nodes(), + graph.idMap(), + graph.idMap(), + graph.idMap(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 ).splitPositiveExamples(graph, 0.5, Optional.empty()); var splitResult2 = new UndirectedEdgeSplitter( Optional.of(12L), - graphStore.nodes(), - graphStore.nodes(), + graph.idMap(), + graph.idMap(), + graph.idMap(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 ).splitPositiveExamples(graph, 0.5, Optional.empty()); var remainingAreEqual = relationshipsAreEqual( - graph, + graph.idMap(), splitResult1.remainingRels().build(), splitResult2.remainingRels().build() ); assertTrue(remainingAreEqual); var holdoutAreEqual = relationshipsAreEqual( - graph, + graph.idMap(), splitResult1.selectedRels().build(), splitResult2.selectedRels().build() ); @@ -244,6 +249,7 @@ void shouldProduceNonDeterministicResult() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -252,6 +258,7 @@ void shouldProduceNonDeterministicResult() { Optional.of(117L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -277,6 +284,7 @@ void negativeEdgeSampling() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -299,6 +307,7 @@ void splitWithFilteringWithSameSourceTargetLabels() { Optional.of(1337L), graphStore.getGraph(NodeLabel.of("A")), graphStore.getGraph(NodeLabel.of("A")), + graphStore.getGraph(NodeLabel.of("A")), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -332,6 +341,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() { Collection targetNodeLabels = List.of(NodeLabel.of("C"), NodeLabel.of("D")); var splitter = new UndirectedEdgeSplitter( Optional.of(1337L), + multiLabelGraphStore.nodes(), multiLabelGraphStore.getGraph(sourceNodeLabels), multiLabelGraphStore.getGraph(targetNodeLabels), RelationshipType.of("SELECTED"), @@ -367,6 +377,7 @@ void samplesWithinBounds() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -382,6 +393,7 @@ void shouldPreserveRelationshipWeights() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -411,6 +423,7 @@ void zeroNegativeSamples() { Optional.of(1337L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 diff --git a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UserInputNegativeSamplerTest.java b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UserInputNegativeSamplerTest.java index 41edc76800b..de0e7e62550 100644 --- a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UserInputNegativeSamplerTest.java +++ b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UserInputNegativeSamplerTest.java @@ -31,6 +31,7 @@ import org.neo4j.gds.core.loading.construction.RelationshipsBuilderBuilder; import org.neo4j.gds.extension.GdlExtension; import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.IdFunction; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.ml.negativeSampling.UserInputNegativeSampler; @@ -40,11 +41,12 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.neo4j.gds.ml.negativeSampling.NegativeSampler.NEGATIVE; +import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; @GdlExtension class UserInputNegativeSamplerTest { - @GdlGraph(orientation = Orientation.UNDIRECTED) + @GdlGraph(orientation = Orientation.UNDIRECTED, idOffset = 5000) static String gdl = "(a1:A), " + "(a2:A), " + @@ -61,6 +63,9 @@ class UserInputNegativeSamplerTest { @Inject Graph graph; + @Inject + IdFunction idFunction; + @Test void generateNegativeSamples() { @@ -120,9 +125,11 @@ void shouldValidateNegativeExamplesRespectNodeLabels() { List.of(NodeLabel.of("A")) )) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("There is a relationship of negativeRelationshipType between nodes 0 and 1. " + + .hasMessageContaining(formatWithLocale("There is a relationship of negativeRelationshipType between nodes %d and %d. " + "The nodes have types [NodeLabel{name='A'}] and [NodeLabel{name='A'}]. " + - "However, they need to be between [NodeLabel{name='B'}] and [NodeLabel{name='A'}]." + "However, they need to be between [NodeLabel{name='B'}] and [NodeLabel{name='A'}].", + idFunction.of("a1"), idFunction.of("a2") + ) ); } diff --git a/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSampler.java b/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSampler.java index c51c61169b7..fcb42a5c8fd 100644 --- a/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSampler.java +++ b/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSampler.java @@ -21,6 +21,7 @@ import org.jetbrains.annotations.NotNull; import org.neo4j.gds.ElementProjection; +import org.neo4j.gds.NodeLabel; import org.neo4j.gds.RelationshipType; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.GraphStore; @@ -41,6 +42,7 @@ import org.neo4j.gds.ml.splitting.EdgeSplitter; import org.neo4j.gds.ml.splitting.UndirectedEdgeSplitter; +import java.util.Collection; import java.util.List; import java.util.Optional; @@ -102,8 +104,9 @@ public void splitAndSampleRelationships( var targetLabels = ElementTypeValidator.resolve(graphStore, List.of(trainConfig.targetNodeLabel())); IdMap sourceNodes = graphStore.getGraph(sourceLabels); IdMap targetNodes = graphStore.getGraph(targetLabels); + Collection sourceAndTargetNodeLabels = trainConfig.nodeLabelIdentifiers(graphStore); var graph = graphStore.getGraph( - trainConfig.nodeLabelIdentifiers(graphStore), + sourceAndTargetNodeLabels, trainConfig.internalRelationshipTypes(graphStore), relationshipWeightProperty); @@ -121,7 +124,7 @@ public void splitAndSampleRelationships( ); // 2. Split test-complement into (labeled) train and feature-input. var testComplementGraph = graphStore.getGraph( - trainConfig.nodeLabelIdentifiers(graphStore), + sourceAndTargetNodeLabels, List.of(splitConfig.testComplementRelationshipType()), relationshipWeightProperty ); @@ -141,6 +144,7 @@ public void splitAndSampleRelationships( NegativeSampler negativeSampler = NegativeSampler.of( graphStore, graph, + sourceAndTargetNodeLabels, splitConfig.negativeRelationshipType(), splitConfig.negativeSamplingRatio(), testSplitResult.selectedRelCount(), @@ -180,6 +184,7 @@ private EdgeSplitter.SplitResult split( } var splitter = new UndirectedEdgeSplitter( trainConfig.randomSeed(), + graphStore.nodes(), sourceNodes, targetNodes, selectedRelType, diff --git a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java index 6bca527a1da..0dd392bc1c8 100644 --- a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java +++ b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java @@ -24,6 +24,7 @@ import org.neo4j.gds.InspectableTestProgressTracker; import org.neo4j.gds.Orientation; import org.neo4j.gds.RelationshipType; +import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.GraphStore; import org.neo4j.gds.api.schema.ElementSchemaEntry; import org.neo4j.gds.assertj.Extractors; @@ -36,13 +37,17 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.extension.GdlExtension; import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.IdFunction; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfigImpl; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; @@ -54,16 +59,43 @@ @GdlExtension class LinkPredictionRelationshipSamplerTest { - @GdlGraph(orientation = Orientation.UNDIRECTED) + @GdlGraph(orientation = Orientation.UNDIRECTED, idOffset = 1337) private static final String GRAPH = "CREATE " + + "(x1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(x2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(x3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(x4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(x5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(x6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(x7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(x8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(x9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(y1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(y2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(y3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(y4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(y5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(y6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(y7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(y8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(y9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(z1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(z3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(z5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(z6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(z7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(z8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + "(z9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + "(a:N {scalar: 0, array: [-1.0, -2.0, 1.0, 1.0, 3.0]}), " + + "(z4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + "(b:N {scalar: 4, array: [2.0, 1.0, -2.0, 2.0, 1.0]}), " + "(c:N {scalar: 0, array: [-3.0, 4.0, 3.0, 3.0, 2.0]}), " + "(d:N {scalar: 3, array: [1.0, 3.0, 1.0, -1.0, -1.0]}), " + "(e:N {scalar: 1, array: [-2.0, 1.0, 2.0, 1.0, -1.0]}), " + "(f:N {scalar: 0, array: [-1.0, -3.0, 1.0, 2.0, 2.0]}), " + "(g:N {scalar: 1, array: [3.0, 1.0, -3.0, 3.0, 1.0]}), " + + "(z2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + "(h:N {scalar: 3, array: [-1.0, 3.0, 2.0, 1.0, -3.0]}), " + "(i:N {scalar: 3, array: [4.0, 1.0, 1.0, 2.0, 1.0]}), " + "(j:N {scalar: 4, array: [1.0, -4.0, 2.0, -2.0, 2.0]}), " + @@ -96,6 +128,9 @@ class LinkPredictionRelationshipSamplerTest { @Inject GraphStore graphStore; + @Inject + IdFunction idFunction; + @GdlGraph(graphNamePrefix = "multi", orientation = Orientation.UNDIRECTED) private static final String MULTI_GRAPH = "CREATE " + @@ -363,7 +398,7 @@ void splitWithSpecifiedNegativeRelationships() { .negativeRelationshipType("NEGATIVE") // 3 total .build(); - var trainConfig = createTrainConfig("REL", "*", "N", -1337L); + var trainConfig = createTrainConfig("REL", "N", "N", -1337L); var relationshipSplitter = new LinkPredictionRelationshipSampler( graphStore, @@ -386,6 +421,44 @@ void splitWithSpecifiedNegativeRelationships() { //8 * 0.5 = 4 positive, 1 negative assertThat(trainGraphSize).isEqualTo(5); assertThat(featureInputGraphSize).isEqualTo(8); - + Graph outGraph = graphStore.getGraph(trainConfig.nodeLabelIdentifiers(graphStore), List.of(splitConfig.testRelationshipType(), splitConfig.trainRelationshipType()), Optional.of("label")); + var positiveEdgesList = new ArrayList(); + var negativeEdgesList = new ArrayList(); + var idsToNames = IntStream + .range('a', 'o' + 1) + .mapToObj(i -> (char) i) + .collect(Collectors.toMap(c -> idFunction.of(String.valueOf(c)), String::valueOf)); + outGraph.forEachNode(nodeId -> { + outGraph.forEachRelationship(nodeId, -2, (s,t, w) -> { + var relationshipString = "(" + idsToNames.get(outGraph.toOriginalNodeId(s)) + "," + idsToNames.get(outGraph.toOriginalNodeId(t)) + ")"; + if (w == 1.0) { + positiveEdgesList.add(relationshipString); + } + if (w == 0.0) { + negativeEdgesList.add(relationshipString); + } + return true; + }); + return true; + } + ); + assertThat(String.join("\n", positiveEdgesList)) + .isEqualTo( + "(a,b)\n" + + "(a,c)\n" + + "(c,d)\n" + + "(e,g)\n" + + "(f,g)\n" + + "(h,i)\n" + + "(j,k)\n" + + "(j,l)\n" + + "(k,l)\n" + + "(m,n)\n" + + "(m,o)\n" + + "(n,o)"); + assertThat(String.join("\n", negativeEdgesList)) + .isEqualTo("(a,k)\n" + + "(b,k)\n" + + "(c,k)"); } } From e1ad623263835d99ee11ed447240e7ac71e03cca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Wed, 26 Apr 2023 14:10:00 +0200 Subject: [PATCH 2/3] Cleanup test Avoid string comparisons which are fixed for a specific seed --- ...LinkPredictionRelationshipSamplerTest.java | 71 +++---------------- 1 file changed, 11 insertions(+), 60 deletions(-) diff --git a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java index 0dd392bc1c8..51ed227b268 100644 --- a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java +++ b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java @@ -24,7 +24,6 @@ import org.neo4j.gds.InspectableTestProgressTracker; import org.neo4j.gds.Orientation; import org.neo4j.gds.RelationshipType; -import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.GraphStore; import org.neo4j.gds.api.schema.ElementSchemaEntry; import org.neo4j.gds.assertj.Extractors; @@ -41,13 +40,11 @@ import org.neo4j.gds.extension.Inject; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfigImpl; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.IntStream; import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; @@ -62,40 +59,15 @@ class LinkPredictionRelationshipSamplerTest { @GdlGraph(orientation = Orientation.UNDIRECTED, idOffset = 1337) private static final String GRAPH = "CREATE " + - "(x1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(x2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(x3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(x4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(x5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(x6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(x7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(x8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(x9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(y1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(y2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(y3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(y4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(y5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(y6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(y7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(y8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(y9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(z1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(z3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(z5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(z6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(z7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(z8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "(z9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + "(a:N {scalar: 0, array: [-1.0, -2.0, 1.0, 1.0, 3.0]}), " + - "(z4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + "(b:N {scalar: 4, array: [2.0, 1.0, -2.0, 2.0, 1.0]}), " + "(c:N {scalar: 0, array: [-3.0, 4.0, 3.0, 3.0, 2.0]}), " + "(d:N {scalar: 3, array: [1.0, 3.0, 1.0, -1.0, -1.0]}), " + "(e:N {scalar: 1, array: [-2.0, 1.0, 2.0, 1.0, -1.0]}), " + "(f:N {scalar: 0, array: [-1.0, -3.0, 1.0, 2.0, 2.0]}), " + "(g:N {scalar: 1, array: [3.0, 1.0, -3.0, 3.0, 1.0]}), " + - "(z2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + + // leaving some id gap between nodes + "(:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), ".repeat(20) + "(h:N {scalar: 3, array: [-1.0, 3.0, 2.0, 1.0, -3.0]}), " + "(i:N {scalar: 3, array: [4.0, 1.0, 1.0, 2.0, 1.0]}), " + "(j:N {scalar: 4, array: [1.0, -4.0, 2.0, -2.0, 2.0]}), " + @@ -104,7 +76,7 @@ class LinkPredictionRelationshipSamplerTest { "(m:N {scalar: 0, array: [4.0, 4.0, 1.0, 1.0, 1.0]}), " + "(n:N {scalar: 3, array: [1.0, -2.0, 3.0, 2.0, 3.0]}), " + "(o:N {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "" + + "(a)-[:REL {weight: 2.0}]->(b), " + "(a)-[:REL {weight: 1.0}]->(c), " + "(b)-[:REL {weight: 3.0}]->(c), " + @@ -421,44 +393,23 @@ void splitWithSpecifiedNegativeRelationships() { //8 * 0.5 = 4 positive, 1 negative assertThat(trainGraphSize).isEqualTo(5); assertThat(featureInputGraphSize).isEqualTo(8); - Graph outGraph = graphStore.getGraph(trainConfig.nodeLabelIdentifiers(graphStore), List.of(splitConfig.testRelationshipType(), splitConfig.trainRelationshipType()), Optional.of("label")); - var positiveEdgesList = new ArrayList(); - var negativeEdgesList = new ArrayList(); - var idsToNames = IntStream - .range('a', 'o' + 1) - .mapToObj(i -> (char) i) - .collect(Collectors.toMap(c -> idFunction.of(String.valueOf(c)), String::valueOf)); + var outGraph = graphStore.getGraph(trainConfig.nodeLabelIdentifiers(graphStore), List.of(splitConfig.testRelationshipType(), splitConfig.trainRelationshipType()), Optional.of("label")); + + var negativeRelSpace = graphStore.getGraph(RelationshipType.of("NEGATIVE")); + var positiveRelSpace = graphStore.getGraph(RelationshipType.of("REL")); + outGraph.forEachNode(nodeId -> { - outGraph.forEachRelationship(nodeId, -2, (s,t, w) -> { - var relationshipString = "(" + idsToNames.get(outGraph.toOriginalNodeId(s)) + "," + idsToNames.get(outGraph.toOriginalNodeId(t)) + ")"; + outGraph.forEachRelationship(nodeId, Double.NaN, (s,t, w) -> { if (w == 1.0) { - positiveEdgesList.add(relationshipString); + assertThat(positiveRelSpace.exists(outGraph.toRootNodeId(s), outGraph.toRootNodeId(t))).isTrue(); } if (w == 0.0) { - negativeEdgesList.add(relationshipString); + assertThat(negativeRelSpace.exists(outGraph.toRootNodeId(s), outGraph.toRootNodeId(t))).isTrue(); } return true; }); return true; } ); - assertThat(String.join("\n", positiveEdgesList)) - .isEqualTo( - "(a,b)\n" + - "(a,c)\n" + - "(c,d)\n" + - "(e,g)\n" + - "(f,g)\n" + - "(h,i)\n" + - "(j,k)\n" + - "(j,l)\n" + - "(k,l)\n" + - "(m,n)\n" + - "(m,o)\n" + - "(n,o)"); - assertThat(String.join("\n", negativeEdgesList)) - .isEqualTo("(a,k)\n" + - "(b,k)\n" + - "(c,k)"); } } From 1f17224351463a64697e93cc0e094579f6c3fe95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Wed, 26 Apr 2023 14:20:11 +0200 Subject: [PATCH 3/3] Fix cherry pick to 2.3 --- .../org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java index 2a4340c5f56..79ec73b7124 100644 --- a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java +++ b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java @@ -286,7 +286,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() { var selectedRelationships = result.selectedRels().build(); assertThat(selectedRelationships.topology()).satisfies(topology -> { - assertRelSamplingProperties(selectedRelationships, multiLabelGraphStore); + assertRelSamplingProperties(selectedRelationships, multiLabelGraph); assertThat(topology.elementCount()).isEqualTo(1); assertFalse(topology.isMultiGraph()); });