diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/NegativeSampler.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/NegativeSampler.java index 0f222872440..40a0f92f6cd 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/NegativeSampler.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/NegativeSampler.java @@ -27,6 +27,7 @@ import org.neo4j.gds.core.loading.construction.RelationshipsBuilder; import java.util.Collection; +import java.util.List; import java.util.Optional; public interface NegativeSampler { @@ -36,6 +37,7 @@ public interface NegativeSampler { static NegativeSampler of( GraphStore graphStore, Graph graph, + Collection sourceAndTargetNodeLabels, Optional negativeRelationshipType, double negativeSamplingRatio, long testPositiveCount, @@ -47,7 +49,11 @@ static NegativeSampler of( Optional randomSeed ) { if (negativeRelationshipType.isPresent()) { - Graph negativeExampleGraph = graphStore.getGraph(RelationshipType.of(negativeRelationshipType.orElseThrow())); + Graph negativeExampleGraph = graphStore.getGraph( + sourceAndTargetNodeLabels, + List.of(RelationshipType.of(negativeRelationshipType.orElseThrow())), + Optional.empty() + ); double testTrainFraction = testPositiveCount / (double) (testPositiveCount + trainPositiveCount); return new UserInputNegativeSampler( diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.java index b19af6fc804..7a3db98a57f 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.java @@ -64,13 +64,16 @@ public void produceNegativeSamples( negativeExampleGraph.forEachNode(nodeId -> { negativeExampleGraph.forEachRelationship(nodeId, (s, t) -> { + // as we are adding the relationships to the GraphStore we need to operate over the rootNodeIds + long rootS = negativeExampleGraph.toRootNodeId(s); + long rootT = negativeExampleGraph.toRootNodeId(t); if (s < t) { if (sample(testRelationshipsToAdd.doubleValue()/(testRelationshipsToAdd.doubleValue() + trainRelationshipsToAdd.doubleValue()))) { testRelationshipsToAdd.decrement(); - testSetBuilder.add(s, t, NEGATIVE); + testSetBuilder.addFromInternal(rootS, rootT, NEGATIVE); } else { trainRelationshipsToAdd.decrement(); - trainSetBuilder.add(s, t, NEGATIVE); + trainSetBuilder.addFromInternal(rootS, rootT, NEGATIVE); } } return true; diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.java index 44b46d421ec..3c86dff2ba4 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.java @@ -36,13 +36,22 @@ public class DirectedEdgeSplitter extends EdgeSplitter { public DirectedEdgeSplitter( Optional maybeSeed, + IdMap rootNodes, IdMap sourceLabels, IdMap targetLabels, RelationshipType selectedRelationshipType, RelationshipType remainingRelationshipType, int concurrency ) { - super(maybeSeed, sourceLabels, targetLabels, selectedRelationshipType, remainingRelationshipType, concurrency); + super( + maybeSeed, + rootNodes, + sourceLabels, + targetLabels, + selectedRelationshipType, + remainingRelationshipType, + concurrency + ); } @Override diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/EdgeSplitter.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/EdgeSplitter.java index 8e6d34ece1e..db52570ae92 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/EdgeSplitter.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/EdgeSplitter.java @@ -47,17 +47,20 @@ public abstract class EdgeSplitter { protected final IdMap sourceNodes; protected final IdMap targetNodes; + protected final IdMap rootNodes; protected int concurrency; EdgeSplitter( Optional maybeSeed, + IdMap rootNodes, IdMap sourceNodes, IdMap targetNodes, RelationshipType selectedRelationshipType, RelationshipType remainingRelationshipType, int concurrency ) { + this.rootNodes = rootNodes; this.selectedRelationshipType = selectedRelationshipType; this.remainingRelationshipType = remainingRelationshipType; this.rng = new Random(); @@ -78,7 +81,7 @@ public SplitResult splitPositiveExamples( LongLongPredicate isValidNodePair = (s, t) -> isValidSourceNode.apply(s) && isValidTargetNode.apply(t); RelationshipsBuilder selectedRelsBuilder = newRelationshipsBuilder( - graph, + rootNodes, selectedRelationshipType, Direction.DIRECTED, Optional.of(EdgeSplitter.RELATIONSHIP_PROPERTY) @@ -89,7 +92,7 @@ public SplitResult splitPositiveExamples( RelationshipsBuilder remainingRelsBuilder; RelationshipWithPropertyConsumer remainingRelsConsumer; - remainingRelsBuilder = newRelationshipsBuilder(graph, remainingRelationshipType, remainingRelDirection, remainingRelPropertyKey); + remainingRelsBuilder = newRelationshipsBuilder(rootNodes, remainingRelationshipType, remainingRelDirection, remainingRelPropertyKey); remainingRelsConsumer = (s, t, w) -> { remainingRelsBuilder.addFromInternal(graph.toRootNodeId(s), graph.toRootNodeId(t), w); return true; @@ -153,7 +156,7 @@ protected long samplesPerNode(long maxSamples, double remainingSamples, long rem } private static RelationshipsBuilder newRelationshipsBuilder( - Graph graph, + IdMap rootNodes, RelationshipType relationshipType, Direction direction, Optional propertyKey @@ -161,7 +164,7 @@ private static RelationshipsBuilder newRelationshipsBuilder( return GraphFactory.initRelationshipsBuilder() .relationshipType(relationshipType) .aggregation(Aggregation.SINGLE) - .nodes(graph) + .nodes(rootNodes) .orientation(direction.toOrientation()) .addAllPropertyConfigs(propertyKey .map(key -> List.of(GraphFactory.PropertyConfig.of(key, Aggregation.SINGLE, DefaultValue.forDouble()))) diff --git a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java index f09e8e3e5df..86563c863c3 100644 --- a/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java +++ b/ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java @@ -40,14 +40,19 @@ public final class SplitRelationships extends Algorithm maybeSeed, + IdMap rootNodes, IdMap sourceNodes, IdMap targetNodes, RelationshipType selectedRelationshipType, RelationshipType remainingRelationshipType, int concurrency ) { - super(maybeSeed, sourceNodes, targetNodes, selectedRelationshipType, remainingRelationshipType, concurrency); + super(maybeSeed, + rootNodes, + sourceNodes, targetNodes, selectedRelationshipType, remainingRelationshipType, concurrency); } @Override diff --git a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java index 025b48c97ed..79ec73b7124 100644 --- a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java +++ b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java @@ -131,6 +131,7 @@ void splitSkewedGraph() { Optional.of(-1L), skewedGraphStore.nodes(), skewedGraphStore.nodes(), + skewedGraphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -150,6 +151,7 @@ void splitMultiGraph() { Optional.of(-1L), multiGraphStore.nodes(), multiGraphStore.nodes(), + multiGraphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -170,6 +172,7 @@ void split() { Optional.of(-1L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -207,6 +210,7 @@ void negativeEdgesShouldNotOverlapMasterGraph() { .generate(); var splitter = new DirectedEdgeSplitter(Optional.of(42L), + huuuuugeDenseGraph, huuuuugeDenseGraph, huuuuugeDenseGraph, RelationshipType.of("SELECTED"), @@ -241,6 +245,7 @@ void negativeEdgeSampling() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -261,6 +266,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() { Collection targetNodeLabels = List.of(NodeLabel.of("C"), NodeLabel.of("D")); var splitter = new DirectedEdgeSplitter( Optional.of(1337L), + multiLabelGraphStore.nodes(), multiLabelGraphStore.getGraph(sourceNodeLabels), multiLabelGraphStore.getGraph(targetNodeLabels), RelationshipType.of("SELECTED"), @@ -295,6 +301,7 @@ void samplesWithinBounds() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -310,6 +317,7 @@ void shouldPreserveRelationshipWeights() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 diff --git a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitterTest.java b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitterTest.java index 690731bcad0..2953286b5d6 100644 --- a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitterTest.java +++ b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitterTest.java @@ -97,6 +97,7 @@ void split() { Optional.of(1337L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -127,6 +128,7 @@ void splitMultiGraph() { Optional.of(-1L), multiGraphStore.nodes(), multiGraphStore.nodes(), + multiGraphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -157,6 +159,7 @@ void negativeEdgesShouldNotOverlapMasterGraph() { Optional.of(42L), huuuuugeDenseGraph, huuuuugeDenseGraph, + huuuuugeDenseGraph, RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -198,29 +201,31 @@ void shouldProduceDeterministicResult() { var splitResult1 = new UndirectedEdgeSplitter( Optional.of(12L), - graphStore.nodes(), - graphStore.nodes(), + graph.idMap(), + graph.idMap(), + graph.idMap(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 ).splitPositiveExamples(graph, 0.5, Optional.empty()); var splitResult2 = new UndirectedEdgeSplitter( Optional.of(12L), - graphStore.nodes(), - graphStore.nodes(), + graph.idMap(), + graph.idMap(), + graph.idMap(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 ).splitPositiveExamples(graph, 0.5, Optional.empty()); var remainingAreEqual = relationshipsAreEqual( - graph, + graph.idMap(), splitResult1.remainingRels().build(), splitResult2.remainingRels().build() ); assertTrue(remainingAreEqual); var holdoutAreEqual = relationshipsAreEqual( - graph, + graph.idMap(), splitResult1.selectedRels().build(), splitResult2.selectedRels().build() ); @@ -244,6 +249,7 @@ void shouldProduceNonDeterministicResult() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -252,6 +258,7 @@ void shouldProduceNonDeterministicResult() { Optional.of(117L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -277,6 +284,7 @@ void negativeEdgeSampling() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -299,6 +307,7 @@ void splitWithFilteringWithSameSourceTargetLabels() { Optional.of(1337L), graphStore.getGraph(NodeLabel.of("A")), graphStore.getGraph(NodeLabel.of("A")), + graphStore.getGraph(NodeLabel.of("A")), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -332,6 +341,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() { Collection targetNodeLabels = List.of(NodeLabel.of("C"), NodeLabel.of("D")); var splitter = new UndirectedEdgeSplitter( Optional.of(1337L), + multiLabelGraphStore.nodes(), multiLabelGraphStore.getGraph(sourceNodeLabels), multiLabelGraphStore.getGraph(targetNodeLabels), RelationshipType.of("SELECTED"), @@ -367,6 +377,7 @@ void samplesWithinBounds() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -382,6 +393,7 @@ void shouldPreserveRelationshipWeights() { Optional.of(42L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 @@ -411,6 +423,7 @@ void zeroNegativeSamples() { Optional.of(1337L), graphStore.nodes(), graphStore.nodes(), + graphStore.nodes(), RelationshipType.of("SELECTED"), RelationshipType.of("REMAINING"), 4 diff --git a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UserInputNegativeSamplerTest.java b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UserInputNegativeSamplerTest.java index 41edc76800b..de0e7e62550 100644 --- a/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UserInputNegativeSamplerTest.java +++ b/ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UserInputNegativeSamplerTest.java @@ -31,6 +31,7 @@ import org.neo4j.gds.core.loading.construction.RelationshipsBuilderBuilder; import org.neo4j.gds.extension.GdlExtension; import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.IdFunction; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.ml.negativeSampling.UserInputNegativeSampler; @@ -40,11 +41,12 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.neo4j.gds.ml.negativeSampling.NegativeSampler.NEGATIVE; +import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; @GdlExtension class UserInputNegativeSamplerTest { - @GdlGraph(orientation = Orientation.UNDIRECTED) + @GdlGraph(orientation = Orientation.UNDIRECTED, idOffset = 5000) static String gdl = "(a1:A), " + "(a2:A), " + @@ -61,6 +63,9 @@ class UserInputNegativeSamplerTest { @Inject Graph graph; + @Inject + IdFunction idFunction; + @Test void generateNegativeSamples() { @@ -120,9 +125,11 @@ void shouldValidateNegativeExamplesRespectNodeLabels() { List.of(NodeLabel.of("A")) )) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("There is a relationship of negativeRelationshipType between nodes 0 and 1. " + + .hasMessageContaining(formatWithLocale("There is a relationship of negativeRelationshipType between nodes %d and %d. " + "The nodes have types [NodeLabel{name='A'}] and [NodeLabel{name='A'}]. " + - "However, they need to be between [NodeLabel{name='B'}] and [NodeLabel{name='A'}]." + "However, they need to be between [NodeLabel{name='B'}] and [NodeLabel{name='A'}].", + idFunction.of("a1"), idFunction.of("a2") + ) ); } diff --git a/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSampler.java b/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSampler.java index c51c61169b7..fcb42a5c8fd 100644 --- a/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSampler.java +++ b/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSampler.java @@ -21,6 +21,7 @@ import org.jetbrains.annotations.NotNull; import org.neo4j.gds.ElementProjection; +import org.neo4j.gds.NodeLabel; import org.neo4j.gds.RelationshipType; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.GraphStore; @@ -41,6 +42,7 @@ import org.neo4j.gds.ml.splitting.EdgeSplitter; import org.neo4j.gds.ml.splitting.UndirectedEdgeSplitter; +import java.util.Collection; import java.util.List; import java.util.Optional; @@ -102,8 +104,9 @@ public void splitAndSampleRelationships( var targetLabels = ElementTypeValidator.resolve(graphStore, List.of(trainConfig.targetNodeLabel())); IdMap sourceNodes = graphStore.getGraph(sourceLabels); IdMap targetNodes = graphStore.getGraph(targetLabels); + Collection sourceAndTargetNodeLabels = trainConfig.nodeLabelIdentifiers(graphStore); var graph = graphStore.getGraph( - trainConfig.nodeLabelIdentifiers(graphStore), + sourceAndTargetNodeLabels, trainConfig.internalRelationshipTypes(graphStore), relationshipWeightProperty); @@ -121,7 +124,7 @@ public void splitAndSampleRelationships( ); // 2. Split test-complement into (labeled) train and feature-input. var testComplementGraph = graphStore.getGraph( - trainConfig.nodeLabelIdentifiers(graphStore), + sourceAndTargetNodeLabels, List.of(splitConfig.testComplementRelationshipType()), relationshipWeightProperty ); @@ -141,6 +144,7 @@ public void splitAndSampleRelationships( NegativeSampler negativeSampler = NegativeSampler.of( graphStore, graph, + sourceAndTargetNodeLabels, splitConfig.negativeRelationshipType(), splitConfig.negativeSamplingRatio(), testSplitResult.selectedRelCount(), @@ -180,6 +184,7 @@ private EdgeSplitter.SplitResult split( } var splitter = new UndirectedEdgeSplitter( trainConfig.randomSeed(), + graphStore.nodes(), sourceNodes, targetNodes, selectedRelType, diff --git a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java index 6bca527a1da..51ed227b268 100644 --- a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java +++ b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java @@ -36,9 +36,11 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.extension.GdlExtension; import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.IdFunction; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfigImpl; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -54,7 +56,7 @@ @GdlExtension class LinkPredictionRelationshipSamplerTest { - @GdlGraph(orientation = Orientation.UNDIRECTED) + @GdlGraph(orientation = Orientation.UNDIRECTED, idOffset = 1337) private static final String GRAPH = "CREATE " + "(a:N {scalar: 0, array: [-1.0, -2.0, 1.0, 1.0, 3.0]}), " + @@ -64,6 +66,8 @@ class LinkPredictionRelationshipSamplerTest { "(e:N {scalar: 1, array: [-2.0, 1.0, 2.0, 1.0, -1.0]}), " + "(f:N {scalar: 0, array: [-1.0, -3.0, 1.0, 2.0, 2.0]}), " + "(g:N {scalar: 1, array: [3.0, 1.0, -3.0, 3.0, 1.0]}), " + + // leaving some id gap between nodes + "(:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), ".repeat(20) + "(h:N {scalar: 3, array: [-1.0, 3.0, 2.0, 1.0, -3.0]}), " + "(i:N {scalar: 3, array: [4.0, 1.0, 1.0, 2.0, 1.0]}), " + "(j:N {scalar: 4, array: [1.0, -4.0, 2.0, -2.0, 2.0]}), " + @@ -72,7 +76,7 @@ class LinkPredictionRelationshipSamplerTest { "(m:N {scalar: 0, array: [4.0, 4.0, 1.0, 1.0, 1.0]}), " + "(n:N {scalar: 3, array: [1.0, -2.0, 3.0, 2.0, 3.0]}), " + "(o:N {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " + - "" + + "(a)-[:REL {weight: 2.0}]->(b), " + "(a)-[:REL {weight: 1.0}]->(c), " + "(b)-[:REL {weight: 3.0}]->(c), " + @@ -96,6 +100,9 @@ class LinkPredictionRelationshipSamplerTest { @Inject GraphStore graphStore; + @Inject + IdFunction idFunction; + @GdlGraph(graphNamePrefix = "multi", orientation = Orientation.UNDIRECTED) private static final String MULTI_GRAPH = "CREATE " + @@ -363,7 +370,7 @@ void splitWithSpecifiedNegativeRelationships() { .negativeRelationshipType("NEGATIVE") // 3 total .build(); - var trainConfig = createTrainConfig("REL", "*", "N", -1337L); + var trainConfig = createTrainConfig("REL", "N", "N", -1337L); var relationshipSplitter = new LinkPredictionRelationshipSampler( graphStore, @@ -386,6 +393,23 @@ void splitWithSpecifiedNegativeRelationships() { //8 * 0.5 = 4 positive, 1 negative assertThat(trainGraphSize).isEqualTo(5); assertThat(featureInputGraphSize).isEqualTo(8); - + var outGraph = graphStore.getGraph(trainConfig.nodeLabelIdentifiers(graphStore), List.of(splitConfig.testRelationshipType(), splitConfig.trainRelationshipType()), Optional.of("label")); + + var negativeRelSpace = graphStore.getGraph(RelationshipType.of("NEGATIVE")); + var positiveRelSpace = graphStore.getGraph(RelationshipType.of("REL")); + + outGraph.forEachNode(nodeId -> { + outGraph.forEachRelationship(nodeId, Double.NaN, (s,t, w) -> { + if (w == 1.0) { + assertThat(positiveRelSpace.exists(outGraph.toRootNodeId(s), outGraph.toRootNodeId(t))).isTrue(); + } + if (w == 0.0) { + assertThat(negativeRelSpace.exists(outGraph.toRootNodeId(s), outGraph.toRootNodeId(t))).isTrue(); + } + return true; + }); + return true; + } + ); } }