Skip to content

Commit

Permalink
Fix local graph by properly dealing with bi-directional relationships
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Jul 9, 2023
1 parent 21a74f8 commit e7a831c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 28 deletions.
59 changes: 35 additions & 24 deletions pydatalab/pydatalab/routes/v0_1/graphs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Set

from flask import jsonify

Expand All @@ -13,7 +13,7 @@ def get_graph_cy_format(item_id: Optional[str] = None):
get_default_permissions(user_only=False),
projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1},
)
node_ids = {document["item_id"] for document in all_documents}
node_ids: Set[str] = {document["item_id"] for document in all_documents}
all_documents.rewind()

else:
Expand All @@ -27,24 +27,30 @@ def get_graph_cy_format(item_id: Optional[str] = None):
)
)

node_ids = {document["item_id"] for document in all_documents}
node_ids = {document["item_id"] for document in all_documents} | {
relationship["item_id"]
for document in all_documents
for relationship in document.get("relationships", [])
}
if len(node_ids) > 1:
query = [{"item_id": id} for id in node_ids if id != item_id]
# query.extend([{"relationships.item_id": id} for id in node_ids if id != item_id])
next_shell = flask_mongo.db.items.find(
{
"$or": [
*[{"item_id": id} for id in node_ids if id != item_id],
*[{"relationships.item_id": id} for id in node_ids if id != item_id],
],
"$or": query,
**get_default_permissions(user_only=False),
},
projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1},
)

node_ids = node_ids | {document["item_id"] for document in next_shell}
all_documents.extend(next_shell)
node_ids = node_ids | {document["item_id"] for document in all_documents}

nodes = []
edges = []

# Collect the elements that have already been added to the graph, to avoid duplication
drawn_elements = set()
for document in all_documents:

node_collections = set()
Expand Down Expand Up @@ -92,28 +98,33 @@ def get_graph_cy_format(item_id: Optional[str] = None):
source = relationship["item_id"]
if source not in node_ids:
continue
edges.append(
edge_id = f"{source}->{target}"
if edge_id not in drawn_elements:
drawn_elements.add(edge_id)
edges.append(
{
"data": {
"id": edge_id,
"source": source,
"target": target,
"value": 1,
}
}
)

if document["item_id"] not in drawn_elements:
drawn_elements.add(document["item_id"])
nodes.append(
{
"data": {
"id": f"{source}->{target}",
"source": source,
"target": target,
"value": 1,
"id": document["item_id"],
"name": document["name"],
"type": document["type"],
"collections": list(node_collections),
}
}
)

nodes.append(
{
"data": {
"id": document["item_id"],
"name": document["name"],
"type": document["type"],
"collections": list(node_collections),
}
}
)

# We want to filter out all the starting materials that don't have relationships since there are so many of them:
whitelist = {edge["data"]["source"] for edge in edges}

Expand Down
10 changes: 6 additions & 4 deletions pydatalab/tests/routers/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def test_simple_graph(client):
assert len(graph["edges"]) == 3

graph = client.get("/item-graph/child_1").json
# These values are currently incorrect: really want to traverse the graph but need to
# resolve relationships first
assert len(graph["nodes"]) == 1
assert len(graph["edges"]) == 0
assert len(graph["nodes"]) == 2
assert len(graph["edges"]) == 1

graph = client.get("/item-graph/parent").json
assert len(graph["nodes"]) == 4
assert len(graph["edges"]) == 3

0 comments on commit e7a831c

Please sign in to comment.