Skip to content

Commit

Permalink
Merge pull request #7444 from FlorentinD/fix-userinputnegativesampler-23
Browse files Browse the repository at this point in the history
LP: Fix negative sampling bugs related to idmaps
  • Loading branch information
FlorentinD authored Apr 26, 2023
2 parents 068ec3d + 1f17224 commit 043e0b3
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;

import java.util.Collection;
import java.util.List;
import java.util.Optional;

public interface NegativeSampler {
Expand All @@ -36,6 +37,7 @@ public interface NegativeSampler {
static NegativeSampler of(
GraphStore graphStore,
Graph graph,
Collection<NodeLabel> sourceAndTargetNodeLabels,
Optional<String> negativeRelationshipType,
double negativeSamplingRatio,
long testPositiveCount,
Expand All @@ -47,7 +49,11 @@ static NegativeSampler of(
Optional<Long> randomSeed
) {
if (negativeRelationshipType.isPresent()) {
Graph negativeExampleGraph = graphStore.getGraph(RelationshipType.of(negativeRelationshipType.orElseThrow()));
Graph negativeExampleGraph = graphStore.getGraph(
sourceAndTargetNodeLabels,
List.of(RelationshipType.of(negativeRelationshipType.orElseThrow())),
Optional.empty()
);
double testTrainFraction = testPositiveCount / (double) (testPositiveCount + trainPositiveCount);

return new UserInputNegativeSampler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,16 @@ public void produceNegativeSamples(

negativeExampleGraph.forEachNode(nodeId -> {
negativeExampleGraph.forEachRelationship(nodeId, (s, t) -> {
// as we are adding the relationships to the GraphStore we need to operate over the rootNodeIds
long rootS = negativeExampleGraph.toRootNodeId(s);
long rootT = negativeExampleGraph.toRootNodeId(t);
if (s < t) {
if (sample(testRelationshipsToAdd.doubleValue()/(testRelationshipsToAdd.doubleValue() + trainRelationshipsToAdd.doubleValue()))) {
testRelationshipsToAdd.decrement();
testSetBuilder.add(s, t, NEGATIVE);
testSetBuilder.addFromInternal(rootS, rootT, NEGATIVE);
} else {
trainRelationshipsToAdd.decrement();
trainSetBuilder.add(s, t, NEGATIVE);
trainSetBuilder.addFromInternal(rootS, rootT, NEGATIVE);
}
}
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,22 @@ public class DirectedEdgeSplitter extends EdgeSplitter {

public DirectedEdgeSplitter(
Optional<Long> maybeSeed,
IdMap rootNodes,
IdMap sourceLabels,
IdMap targetLabels,
RelationshipType selectedRelationshipType,
RelationshipType remainingRelationshipType,
int concurrency
) {
super(maybeSeed, sourceLabels, targetLabels, selectedRelationshipType, remainingRelationshipType, concurrency);
super(
maybeSeed,
rootNodes,
sourceLabels,
targetLabels,
selectedRelationshipType,
remainingRelationshipType,
concurrency
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,20 @@ public abstract class EdgeSplitter {

protected final IdMap sourceNodes;
protected final IdMap targetNodes;
protected final IdMap rootNodes;

protected int concurrency;

EdgeSplitter(
Optional<Long> maybeSeed,
IdMap rootNodes,
IdMap sourceNodes,
IdMap targetNodes,
RelationshipType selectedRelationshipType,
RelationshipType remainingRelationshipType,
int concurrency
) {
this.rootNodes = rootNodes;
this.selectedRelationshipType = selectedRelationshipType;
this.remainingRelationshipType = remainingRelationshipType;
this.rng = new Random();
Expand All @@ -78,7 +81,7 @@ public SplitResult splitPositiveExamples(
LongLongPredicate isValidNodePair = (s, t) -> isValidSourceNode.apply(s) && isValidTargetNode.apply(t);

RelationshipsBuilder selectedRelsBuilder = newRelationshipsBuilder(
graph,
rootNodes,
selectedRelationshipType,
Direction.DIRECTED,
Optional.of(EdgeSplitter.RELATIONSHIP_PROPERTY)
Expand All @@ -89,7 +92,7 @@ public SplitResult splitPositiveExamples(
RelationshipsBuilder remainingRelsBuilder;
RelationshipWithPropertyConsumer remainingRelsConsumer;

remainingRelsBuilder = newRelationshipsBuilder(graph, remainingRelationshipType, remainingRelDirection, remainingRelPropertyKey);
remainingRelsBuilder = newRelationshipsBuilder(rootNodes, remainingRelationshipType, remainingRelDirection, remainingRelPropertyKey);
remainingRelsConsumer = (s, t, w) -> {
remainingRelsBuilder.addFromInternal(graph.toRootNodeId(s), graph.toRootNodeId(t), w);
return true;
Expand Down Expand Up @@ -153,15 +156,15 @@ protected long samplesPerNode(long maxSamples, double remainingSamples, long rem
}

private static RelationshipsBuilder newRelationshipsBuilder(
Graph graph,
IdMap rootNodes,
RelationshipType relationshipType,
Direction direction,
Optional<String> propertyKey
) {
return GraphFactory.initRelationshipsBuilder()
.relationshipType(relationshipType)
.aggregation(Aggregation.SINGLE)
.nodes(graph)
.nodes(rootNodes)
.orientation(direction.toOrientation())
.addAllPropertyConfigs(propertyKey
.map(key -> List.of(GraphFactory.PropertyConfig.of(key, Aggregation.SINGLE, DefaultValue.forDouble())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,19 @@ public final class SplitRelationships extends Algorithm<EdgeSplitter.SplitResult

private final SplitRelationshipsBaseConfig config;

private final IdMap rootNodes;

private final IdMap sourceNodes;

private final IdMap targetNodes;

private SplitRelationships(Graph graph, Graph masterGraph, IdMap sourceNodes, IdMap targetNodes, SplitRelationshipsBaseConfig config) {
private SplitRelationships(Graph graph, Graph masterGraph,
IdMap rootNodes,
IdMap sourceNodes, IdMap targetNodes, SplitRelationshipsBaseConfig config) {
super(ProgressTracker.NULL_TRACKER);
this.graph = graph;
this.masterGraph = masterGraph;
this.rootNodes = rootNodes;
this.config = config;
this.sourceNodes = sourceNodes;
this.targetNodes = targetNodes;
Expand All @@ -66,7 +71,7 @@ public static SplitRelationships of(GraphStore graphStore, SplitRelationshipsBas
IdMap sourceNodes = graphStore.getGraph(sourceLabels);
IdMap targetNodes = graphStore.getGraph(targetLabels);

return new SplitRelationships(graph, masterGraph, sourceNodes, targetNodes, config);
return new SplitRelationships(graph, masterGraph, graphStore.nodes(), sourceNodes, targetNodes, config);
}

public static MemoryEstimation estimate(SplitRelationshipsBaseConfig configuration) {
Expand Down Expand Up @@ -98,6 +103,7 @@ public EdgeSplitter.SplitResult compute() {
var splitter = isUndirected
? new UndirectedEdgeSplitter(
config.randomSeed(),
rootNodes,
sourceNodes,
targetNodes,
config.holdoutRelationshipType(),
Expand All @@ -106,6 +112,7 @@ public EdgeSplitter.SplitResult compute() {
)
: new DirectedEdgeSplitter(
config.randomSeed(),
rootNodes,
sourceNodes,
targetNodes,
config.holdoutRelationshipType(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,16 @@ public class UndirectedEdgeSplitter extends EdgeSplitter {

public UndirectedEdgeSplitter(
Optional<Long> maybeSeed,
IdMap rootNodes,
IdMap sourceNodes,
IdMap targetNodes,
RelationshipType selectedRelationshipType,
RelationshipType remainingRelationshipType,
int concurrency
) {
super(maybeSeed, sourceNodes, targetNodes, selectedRelationshipType, remainingRelationshipType, concurrency);
super(maybeSeed,
rootNodes,
sourceNodes, targetNodes, selectedRelationshipType, remainingRelationshipType, concurrency);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ void splitSkewedGraph() {
Optional.of(-1L),
skewedGraphStore.nodes(),
skewedGraphStore.nodes(),
skewedGraphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand All @@ -150,6 +151,7 @@ void splitMultiGraph() {
Optional.of(-1L),
multiGraphStore.nodes(),
multiGraphStore.nodes(),
multiGraphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand All @@ -170,6 +172,7 @@ void split() {
Optional.of(-1L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand Down Expand Up @@ -207,6 +210,7 @@ void negativeEdgesShouldNotOverlapMasterGraph() {
.generate();

var splitter = new DirectedEdgeSplitter(Optional.of(42L),
huuuuugeDenseGraph,
huuuuugeDenseGraph,
huuuuugeDenseGraph,
RelationshipType.of("SELECTED"),
Expand Down Expand Up @@ -241,6 +245,7 @@ void negativeEdgeSampling() {
Optional.of(42L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand All @@ -261,6 +266,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() {
Collection<NodeLabel> targetNodeLabels = List.of(NodeLabel.of("C"), NodeLabel.of("D"));
var splitter = new DirectedEdgeSplitter(
Optional.of(1337L),
multiLabelGraphStore.nodes(),
multiLabelGraphStore.getGraph(sourceNodeLabels),
multiLabelGraphStore.getGraph(targetNodeLabels),
RelationshipType.of("SELECTED"),
Expand Down Expand Up @@ -295,6 +301,7 @@ void samplesWithinBounds() {
Optional.of(42L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand All @@ -310,6 +317,7 @@ void shouldPreserveRelationshipWeights() {
Optional.of(42L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ void split() {
Optional.of(1337L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand Down Expand Up @@ -127,6 +128,7 @@ void splitMultiGraph() {
Optional.of(-1L),
multiGraphStore.nodes(),
multiGraphStore.nodes(),
multiGraphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand Down Expand Up @@ -157,6 +159,7 @@ void negativeEdgesShouldNotOverlapMasterGraph() {
Optional.of(42L),
huuuuugeDenseGraph,
huuuuugeDenseGraph,
huuuuugeDenseGraph,
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand Down Expand Up @@ -198,29 +201,31 @@ void shouldProduceDeterministicResult() {

var splitResult1 = new UndirectedEdgeSplitter(
Optional.of(12L),
graphStore.nodes(),
graphStore.nodes(),
graph.idMap(),
graph.idMap(),
graph.idMap(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
).splitPositiveExamples(graph, 0.5, Optional.empty());
var splitResult2 = new UndirectedEdgeSplitter(
Optional.of(12L),
graphStore.nodes(),
graphStore.nodes(),
graph.idMap(),
graph.idMap(),
graph.idMap(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
).splitPositiveExamples(graph, 0.5, Optional.empty());
var remainingAreEqual = relationshipsAreEqual(
graph,
graph.idMap(),
splitResult1.remainingRels().build(),
splitResult2.remainingRels().build()
);
assertTrue(remainingAreEqual);

var holdoutAreEqual = relationshipsAreEqual(
graph,
graph.idMap(),
splitResult1.selectedRels().build(),
splitResult2.selectedRels().build()
);
Expand All @@ -244,6 +249,7 @@ void shouldProduceNonDeterministicResult() {
Optional.of(42L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand All @@ -252,6 +258,7 @@ void shouldProduceNonDeterministicResult() {
Optional.of(117L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand All @@ -277,6 +284,7 @@ void negativeEdgeSampling() {
Optional.of(42L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand All @@ -299,6 +307,7 @@ void splitWithFilteringWithSameSourceTargetLabels() {
Optional.of(1337L),
graphStore.getGraph(NodeLabel.of("A")),
graphStore.getGraph(NodeLabel.of("A")),
graphStore.getGraph(NodeLabel.of("A")),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand Down Expand Up @@ -332,6 +341,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() {
Collection<NodeLabel> targetNodeLabels = List.of(NodeLabel.of("C"), NodeLabel.of("D"));
var splitter = new UndirectedEdgeSplitter(
Optional.of(1337L),
multiLabelGraphStore.nodes(),
multiLabelGraphStore.getGraph(sourceNodeLabels),
multiLabelGraphStore.getGraph(targetNodeLabels),
RelationshipType.of("SELECTED"),
Expand Down Expand Up @@ -367,6 +377,7 @@ void samplesWithinBounds() {
Optional.of(42L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand All @@ -382,6 +393,7 @@ void shouldPreserveRelationshipWeights() {
Optional.of(42L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand Down Expand Up @@ -411,6 +423,7 @@ void zeroNegativeSamples() {
Optional.of(1337L),
graphStore.nodes(),
graphStore.nodes(),
graphStore.nodes(),
RelationshipType.of("SELECTED"),
RelationshipType.of("REMAINING"),
4
Expand Down
Loading

0 comments on commit 043e0b3

Please sign in to comment.