diff --git a/algo/src/main/java/org/neo4j/gds/paths/yens/MutablePathResult.java b/algo/src/main/java/org/neo4j/gds/paths/yens/MutablePathResult.java index f973d792308..8d5e43eb081 100644 --- a/algo/src/main/java/org/neo4j/gds/paths/yens/MutablePathResult.java +++ b/algo/src/main/java/org/neo4j/gds/paths/yens/MutablePathResult.java @@ -31,6 +31,7 @@ */ final class MutablePathResult { + private final long[] EMPTY_ARRAY = new long[0]; private long index; private final long sourceNode; @@ -156,8 +157,7 @@ boolean matchesExactly(MutablePathResult path, int index) { * The cost value associated with the last value in this path, is added to * the costs for each node in the second path. */ - - + private void append(MutablePathResult path, long[] relationships) { // spur node is end of first and beginning of second path assert nodeIds[nodeIds.length - 1] == path.nodeIds[0]; @@ -213,7 +213,7 @@ void append(MutablePathResult path) { */ void appendWithoutRelationshipIds(MutablePathResult path) { // spur node is end of first and beginning of second path - append(path, new long[0]); + append(path, EMPTY_ARRAY); } diff --git a/algo/src/main/java/org/neo4j/gds/paths/yens/Yens.java b/algo/src/main/java/org/neo4j/gds/paths/yens/Yens.java index a2052a57f55..048878666fa 100644 --- a/algo/src/main/java/org/neo4j/gds/paths/yens/Yens.java +++ b/algo/src/main/java/org/neo4j/gds/paths/yens/Yens.java @@ -127,10 +127,9 @@ private Yens(Graph graph, Dijkstra dijkstra, ShortestPathYensBaseConfig config, } private boolean shouldAvoidRelationship(long source, long target, long relationshipId) { - long forbidden = target; - if (config.trackRelationships()) { - forbidden = relationshipId; - } + long forbidden = config.trackRelationships() + ? relationshipId + : target; return relationshipAvoidList.getOrDefault(source, EMPTY_SET).contains(forbidden); } diff --git a/algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java b/algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java index 4dc677e6e4f..6275662143e 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java @@ -166,9 +166,7 @@ static Stream> pathInput() { @ParameterizedTest @MethodSource("pathInput") void compute(Collection expectedPaths) { - assertResult(graph, idFunction, expectedPaths, false); - } @Test diff --git a/proc/path-finding/src/test/java/org/neo4j/gds/paths/sourcetarget/YensTestWithDifferentProjections.java b/proc/path-finding/src/test/java/org/neo4j/gds/paths/sourcetarget/YensTestWithDifferentProjections.java index 51f2e2b6045..a176c6aa5de 100644 --- a/proc/path-finding/src/test/java/org/neo4j/gds/paths/sourcetarget/YensTestWithDifferentProjections.java +++ b/proc/path-finding/src/test/java/org/neo4j/gds/paths/sourcetarget/YensTestWithDifferentProjections.java @@ -24,6 +24,8 @@ import org.junit.jupiter.params.provider.ValueSource; import org.neo4j.gds.BaseProcTest; import org.neo4j.gds.catalog.GraphProjectProc; +import org.neo4j.gds.extension.IdFunction; +import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.Neo4jGraph; import java.util.Collection; @@ -34,71 +36,76 @@ class YensTestWithDifferentProjections extends BaseProcTest { - @Neo4jGraph - private static final String DB_CYPHER = - "CREATE (a:CITY), " + - "(b:CITY), " + - "(c:CITY), " + - "(d:CITY), " + - "(e:CITY), " + - "(f:CITY), " + - "(a)-[:ROAD]->(b), " + - "(a)-[:ROAD]->(b), " + - "(b)-[:ROAD]->(c), " + - "(b)-[:ROAD]->(d), " + - "(c)-[:ROAD]->(f), " + - "(d)-[:ROAD]->(e), " + - "(e)-[:ROAD]->(c), " + - "(e)-[:ROAD]->(f), " + - "(a)-[:PATH]->(b), " + - "(d)-[:PATH]->(e), " + - "(d)-[:PATH]->(e)"; - - @BeforeEach - void setup() throws Exception { - registerProcedures( - ShortestPathYensStreamProc.class, - GraphProjectProc.class - ); - } - - - @ParameterizedTest - @ValueSource(strings = { - "CALL gds.graph.project('g', '*', {TYPE: {type: '*', aggregation: 'SINGLE'}})", - "CALL gds.graph.project.cypher('g', 'MATCH (n) RETURN id(n) AS id', 'MATCH (n)-[r]->(m) RETURN DISTINCT id(n) AS source, id(m) AS target')" - }) - void shouldWorkWithDifferentProjections(String projectionQuery) { - - runQuery(projectionQuery); - String yensQuery = "MATCH (source), (target) " + - "WHERE id(source)=0 AND id(target)=5 " + - "CALL gds.shortestPath.yens.stream(" + - " 'g', " + - " {sourceNode:source, targetNode:target, k:3} " + - ") " + - "YIELD nodeIds RETURN nodeIds "; - - Collection encounteredPaths = new HashSet<>(); - runQuery(yensQuery, result -> { - assertThat(result.columns()).containsExactlyInAnyOrder("nodeIds"); - - while (result.hasNext()) { - var next = result.next(); - var currentPath = (List) next.get("nodeIds"); - long[] pathToArray = currentPath.stream().mapToLong(l -> l).toArray(); - encounteredPaths.add(pathToArray); - } - - return true; - }); - - assertThat(encounteredPaths).containsExactlyInAnyOrder( - new long[]{0l, 1l, 3l, 4l, 2l, 5l}, - new long[]{0l, 1l, 3l, 4l, 5l}, - new long[]{0l, 1l, 2l, 5l} - ); - } + @Neo4jGraph + private static final String DB_CYPHER = + "CREATE (a:CITY {cityid:0}), " + + "(b:CITY {cityid:1}), " + + "(c:CITY {cityid:2}), " + + "(d:CITY {cityid:3}), " + + "(e:CITY {cityid:4}), " + + "(f:CITY {cityid:5}), " + + "(a)-[:ROAD]->(b), " + + "(a)-[:ROAD]->(b), " + + "(b)-[:ROAD]->(c), " + + "(b)-[:ROAD]->(d), " + + "(c)-[:ROAD]->(f), " + + "(d)-[:ROAD]->(e), " + + "(e)-[:ROAD]->(c), " + + "(e)-[:ROAD]->(f), " + + "(a)-[:PATH]->(b), " + + "(d)-[:PATH]->(e), " + + "(d)-[:PATH]->(e)"; + + @Inject + IdFunction idFunction; + + @BeforeEach + void setup() throws Exception { + registerProcedures( + ShortestPathYensStreamProc.class, + GraphProjectProc.class + ); + } + + + @ParameterizedTest + @ValueSource(strings = { + "CALL gds.graph.project('g', '*', {TYPE: {type: '*', aggregation: 'SINGLE'}})", + "CALL gds.graph.project.cypher('g', 'MATCH (n) RETURN id(n) AS id', 'MATCH (n)-[r]->(m) RETURN DISTINCT id(n) AS source, id(m) AS target')" + }) + void shouldWorkWithDifferentProjections(String projectionQuery) { + + runQuery(projectionQuery); + String yensQuery = "MATCH (source), (target) " + + "WHERE source.cityid=0 AND target.cityid=5 " + + "CALL gds.shortestPath.yens.stream(" + + " 'g', " + + " {sourceNode:source, targetNode:target, k:3} " + + ") " + + "YIELD nodeIds RETURN nodeIds "; + + Collection encounteredPaths = new HashSet<>(); + runQuery(yensQuery, result -> { + assertThat(result.columns()).containsExactlyInAnyOrder("nodeIds"); + + while (result.hasNext()) { + var next = result.next(); + var currentPath = (List) next.get("nodeIds"); + long[] pathToArray = currentPath.stream().mapToLong(l -> l).toArray(); + encounteredPaths.add(pathToArray); + } + + return true; + }); + + long[] nodes = new long[]{idFunction.of("a"), idFunction.of("b"), idFunction.of("c"), idFunction.of("d"), idFunction.of( + "e"), idFunction.of("f")}; + assertThat(encounteredPaths).containsExactlyInAnyOrder( + new long[]{nodes[0], nodes[1], nodes[3], nodes[4], nodes[2], nodes[5]}, + new long[]{nodes[0], nodes[1], nodes[3], nodes[4], nodes[5]}, + new long[]{nodes[0], nodes[1], nodes[2], nodes[5]} + ); + } }