-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathneo4j_client.py
183 lines (150 loc) · 6.08 KB
/
neo4j_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
from dotenv import load_dotenv
from neo4j import GraphDatabase
# Load settings from a .env file at import time so the NEO4J_URI,
# NEO4J_USERNAME and NEO4J_PASSWORD values read by Neo4jClient.__enter__
# can be supplied without exporting them in the shell.
load_dotenv()
def clear_db() -> None:
    """Delete every node (and its relationships) from the Neo4j database.

    Opens a short-lived Neo4jClient connection for the duration of the call.
    """
    with Neo4jClient() as client:
        client.clear_graph()
def clear_db_artists() -> None:
    """Delete all Artist nodes (and their relationships) from Neo4j.

    Opens a short-lived Neo4jClient connection for the duration of the call.
    """
    with Neo4jClient() as client:
        client.clear_artists()
def clear_db_tracks() -> None:
    """Delete all Track nodes (and their relationships) from Neo4j.

    Opens a short-lived Neo4jClient connection for the duration of the call.
    """
    with Neo4jClient() as client:
        client.clear_tracks()
class Neo4jClient:
    """Context-managed client for the artist/track graph stored in Neo4j.

    The driver is opened by ``__enter__`` and closed by ``__exit__``, so the
    client is meant to be used in a ``with`` block. While no driver is open,
    every query method is a silent no-op (``shortest_path`` returns ``[]``).
    """

    def __init__(self: "Neo4jClient") -> None:
        # No connection is made until the context manager is entered.
        self._driver = None

    def __enter__(self: "Neo4jClient") -> "Neo4jClient":
        """Open the Neo4j driver from the NEO4J_* environment variables.

        Returns:
            Neo4jClient: This instance, with an open driver.

        Raises:
            ValueError: If NEO4J_URI, NEO4J_USERNAME or NEO4J_PASSWORD is
                not set.
        """
        uri = os.getenv("NEO4J_URI")
        username = os.getenv("NEO4J_USERNAME")
        password = os.getenv("NEO4J_PASSWORD")
        if uri is None or username is None or password is None:
            raise ValueError("Missing Neo4j environment variables")
        self._driver = GraphDatabase.driver(uri, auth=(username, password))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the driver, if one is open, and forget it."""
        if self._driver is not None:
            try:
                self._driver.close()
            finally:
                # Drop the closed handle so later calls become no-ops
                # instead of running against a closed driver.
                self._driver = None

    def _execute(self: "Neo4jClient", *statements: tuple) -> None:
        """Run write statements in one session, consuming every result.

        Args:
            *statements: ``(query, parameters)`` pairs, executed in order.
                ``parameters`` is a dict of query parameters or ``None``.
        """
        if self._driver is None:
            return
        with self._driver.session() as session:
            for query, parameters in statements:
                # consume() waits for the statement to finish, so a
                # server-side failure is raised here rather than left
                # pending on an unconsumed result.
                session.run(query, parameters).consume()

    def clear_graph(self: "Neo4jClient") -> None:
        """Delete every node and relationship in the database."""
        self._execute(("MATCH (n) DETACH DELETE n", None))

    def clear_artists(self: "Neo4jClient") -> None:
        """Delete all Artist nodes together with their relationships."""
        self._execute(("MATCH (n: Artist) DETACH DELETE n", None))

    def clear_tracks(self: "Neo4jClient") -> None:
        """Delete all Track nodes together with their relationships."""
        self._execute(("MATCH (n: Track) DETACH DELETE n", None))

    def verify_conn(self: "Neo4jClient") -> None:
        """Verify connectivity to the Neo4j database.

        Delegates to the driver's ``verify_connectivity``; does nothing when
        no driver is open.
        """
        if self._driver is not None:
            self._driver.verify_connectivity()

    def create_artist_node(self: "Neo4jClient", artist: dict) -> None:
        """Create or update the Artist node for ``artist``.

        Args:
            artist (dict): Artist information; the ``"id"`` and ``"name"``
                keys are read.
        """
        if self._driver is None:
            return
        self._execute(
            (
                "CREATE CONSTRAINT unique_artist_id IF NOT EXISTS "
                "FOR (n: Artist) REQUIRE n.id IS UNIQUE",
                None,
            ),
            (
                # Merge on the unique key only. Merging on the whole
                # property map would try to create a second node with the
                # same id (and hit the constraint) whenever the name
                # differs from the stored one.
                "MERGE (n:Artist {id: $id}) SET n.name = $name",
                {"name": artist["name"], "id": artist["id"]},
            ),
        )

    def create_track_node(self: "Neo4jClient", track: dict) -> None:
        """Create or update the Track node for ``track``.

        Args:
            track (dict): Track information; the ``"id"``, ``"name"`` and
                ``"artists"`` keys are read. ``"artists"`` is stored on the
                node and matched against Artist ids by
                ``create_relationships``.
        """
        if self._driver is None:
            return
        self._execute(
            (
                "CREATE CONSTRAINT unique_track_id IF NOT EXISTS "
                "FOR (n: Track) REQUIRE n.id IS UNIQUE",
                None,
            ),
            (
                # Merge on the unique key only, then refresh the mutable
                # properties (see create_artist_node for the rationale).
                "MERGE (n:Track {id: $id}) "
                "SET n.name = $name, n.artists = $artists",
                {
                    "name": track["name"],
                    "id": track["id"],
                    "artists": track["artists"],
                },
            ),
        )

    def create_relationships(self: "Neo4jClient") -> None:
        """Link every artist to the tracks that list its id.

        Creates an ``APPEARS_ON`` relationship from each Artist to each
        Track whose ``artists`` property contains the artist's id.
        Existing relationships are left as they are.
        """
        self._execute(
            (
                # Expand each track's artist ids and look the artists up by
                # id, instead of filtering the full Artist x Track product.
                "MATCH (t: Track) "
                "UNWIND t.artists AS artist_id "
                "MATCH (a: Artist {id: artist_id}) "
                "MERGE (a)-[:APPEARS_ON]->(t)",
                None,
            )
        )

    def shortest_path(self: "Neo4jClient", start_id: str, end_id: str) -> list:
        """Find the shortest path between two artists, if it exists.

        Args:
            start_id (str): id of the starting artist
            end_id (str): id of the ending artist

        Returns:
            list: The ``id`` of every node along the path, in order from
                start to end; empty when no path is found or no driver is
                open.
        """
        # NOTE(review): Neo4j's shortestPath is documented to reject a
        # pattern whose start and end are the same node under the default
        # server setting - confirm callers never pass start_id == end_id.
        if self._driver is None:
            return []
        with self._driver.session() as session:
            path_query = (
                "MATCH (start:Artist {id: $start_id}), (end:Artist {id: $end_id}), "
                "p = shortestPath((start)-[:APPEARS_ON*]-(end)) "
                "UNWIND nodes(p) AS node "
                "RETURN node.id, node.name"
            )
            result = session.run(
                path_query, start_id=start_id, end_id=end_id
            )
            # peek() looks at the first record without consuming it.
            if result.peek() is None:
                print("No path found")
                return []
            print("Path found")
            return [record["node.id"] for record in result]