-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Siti's files added hehe
- Loading branch information
Showing
6 changed files
with
552 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from langchain_core.tools import tool | ||
from langgraph.prebuilt import create_react_agent | ||
from langchain_google_vertexai import ChatVertexAI | ||
from langgraph.checkpoint import MemorySaver | ||
|
||
from langchain_core.prompts import ChatPromptTemplate | ||
from langchain_core.prompts import( | ||
ChatPromptTemplate, | ||
SystemMessagePromptTemplate, | ||
HumanMessagePromptTemplate, | ||
MessagesPlaceholder | ||
) | ||
import GraphRScustom | ||
from GraphRScustom import LLMGraphTransformer | ||
from pyvis.network import Network | ||
|
||
import os | ||
import requests | ||
import urllib | ||
|
||
|
||
from langchain_experimental.graph_transformers import LLMGraphTransformer | ||
from langchain_openai import ChatOpenAI | ||
from langchain_core.documents import Document | ||
|
||
from sqlalchemy.orm import Session | ||
from sqlalchemy import select | ||
import re | ||
import models | ||
from database import stakeholder_engine, user_engine, media_engine | ||
from qdrant_media import search_in_qdrant, vectorize_query | ||
|
||
|
||
#Implementation with Qdrant filter is in qdrant_media.py | ||
def derive_rs_from_media(db:Session, stakeholder_id: int= None, query: str=None): | ||
# from langchain_google_genai import ChatGoogleGenerativeAI | ||
llm = ChatVertexAI(model="gemini-1.5-flash") | ||
|
||
# llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro") | ||
llm = ChatVertexAI(model="gemini-1.5-flash") | ||
llm_transformer = LLMGraphTransformer(llm=llm) | ||
|
||
#get media ids from stakeholder ids | ||
# Get all media ids from stakeholder ids | ||
results = get_media_id_from_stakeholder(db, stakeholder_id=stakeholder_id) | ||
#list of media ids | ||
media_ids = [result.media_id for result in results] | ||
# query_vector = vectorize_query(query) | ||
# top_media = search_in_qdrant(query_vector) | ||
|
||
# List of media ids that stakeholder is mentioned in | ||
media_ids = [result.media_id for result in results] | ||
|
||
# Pass list of media_ids through to metadata filter function | ||
|
||
# Match list of media_ids and rank top 5 based on user query | ||
|
||
# Join all articles | ||
for id in media_ids: | ||
#get content | ||
results = get_content_from_media_id(db, media_id = id) #returns list of json dicts | ||
text = [result.content for result in results] #list of article content | ||
|
||
documents = [Document(page_content=' '.join(text))] | ||
# print(documents) ## for all content in medias | ||
|
||
# Derive relationships from filtered media ids | ||
graph_documents = llm_transformer.convert_to_graph_documents(documents) | ||
|
||
nodes = graph_documents[0].nodes | ||
rs = graph_documents[0].relationships | ||
# print(rs) | ||
|
||
# Initiating dict, dictionary of ids, list of relationships and id | ||
nodes_id = {} | ||
media_rs = [] | ||
node_id_map = {} | ||
|
||
def derive_rs_from_media(db:Session, stakeholder_id: int= None, query: str=None): | ||
node_id_map[node.id] = node_counter | ||
nodes_id[node_counter] = node.id | ||
|
||
# Relationships with ids | ||
# Format relationships with ids | ||
for relation in rs: | ||
source_id = node_id_map.get(relation.source.id) | ||
target_id = node_id_map.get(relation.target.id) | ||
|
||
if source_id is None: | ||
print(f"KeyError: '{relation.source.id}' not found in node_id_map") | ||
if target_id is None: | ||
print(f"KeyError: '{relation.target.id}' not found in node_id_map") | ||
|
||
if source_id is not None and target_id is not None: | ||
media_rs.append([source_id, relation.type, target_id]) | ||
else: | ||
# Handle the case where either source_id or target_id is None | ||
# Either source_id or target_id is None | ||
continue | ||
|
||
# Output: {nodes: {id:name}, edges:[id,str,id]} | ||
return {'nodes': nodes_id, 'edges': media_rs} | ||
# return media_ids | ||
|
||
if __name__ == "__main__": | ||
stakeholder_id = 28235 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import sqlite3 | ||
import os | ||
import networkx as nx | ||
from bokeh.io import show, output_file | ||
from bokeh.models import Plot, Range1d, MultiLine, Circle, HoverTool, TapTool, BoxSelectTool, NodesAndLinkedEdges | ||
from bokeh.plotting import from_networkx | ||
from bokeh.palettes import Spectral4 | ||
|
||
# Base URL of a remote service.
# NOTE(review): not referenced again in the visible part of this file - confirm it is still needed.
base_url= 'https://python-server-ohgaalojiq-de.a.run.app'

## File paths
# Everything is resolved relative to this module's directory, so the script
# works regardless of the current working directory.
data_dir = os.path.join(os.path.dirname(__file__), 'data')
csv_dir = os.path.join(data_dir, 'csv')
db_file = os.path.join(data_dir, 'stakeholders.db')  # database opened by get_data()
media_db_file = os.path.join(data_dir, 'media.db')

# Paths of two sibling source files that live next to this module.
file1_path = os.path.join(os.path.dirname(__file__),"build_database.py")
file2_path = os.path.join(os.path.dirname(__file__), "database_query_function.py")

# Read the *source text* of both sibling modules into module-level strings.
# NOTE(review): neither string is used in the visible part of this file, and
# this raises FileNotFoundError at import time if either file is missing -
# confirm whether these reads are still required.
with open(file1_path, "r") as file1, open(file2_path, "r") as file2:
    build_database_content = file1.read()
    database_query_function_content = file2.read()

# Helpers imported from the sibling query module; extract_after_last_slash is
# used by generate_network_graph() below.
from database_query_function import get_relationships_with_names
from database_query_function import extract_after_last_slash
|
||
def get_data(query, params=None, db_path=None):
    """Run a SQL statement against a SQLite database and return all rows.

    Args:
        query: SQL text, optionally containing ``?`` placeholders.
        params: Sequence of values bound to the placeholders. ``None`` (or any
            empty value) means the statement takes no parameters.
        db_path: Path of the database to open. When omitted, the module-level
            ``db_file`` is used, which keeps existing callers unchanged.

    Returns:
        The list of row tuples produced by ``cursor.fetchall()``, or an empty
        list if SQLite reported an error (the error is printed, not raised).
    """
    from contextlib import closing

    path = db_file if db_path is None else db_path
    try:
        # A sqlite3 connection used as a context manager only commits or rolls
        # back the transaction; it does NOT close the connection. closing()
        # guarantees the handle is released, while the inner `conn` context
        # keeps the original commit-on-success / rollback-on-error behaviour.
        with closing(sqlite3.connect(path)) as conn, conn:
            cursor = conn.cursor()
            cursor.execute(query, params or [])
            return cursor.fetchall()
    except sqlite3.Error as e:
        print(f"An error occurred: {e}")
        return []
|
||
def get_relationships_with_names_by_name(stakeholder_name: str, max_depth: int = 3):
    """Fetch the named relationships reachable from a stakeholder.

    The recursive query starts from every relationship in which the
    stakeholder appears as subject or object (level 1) and then repeatedly
    follows relationships whose subject is the object of an already-collected
    row, stopping once ``max_depth`` levels have been gathered.

    Args:
        stakeholder_name: Name to look up in ``stakeholders.name``.
        max_depth: Maximum number of levels to traverse. Defaults to 3, the
            depth that was previously hard-coded in the SQL.

    Returns:
        A list of ``(subject_name, predicate, object_name)`` tuples as returned
        by ``get_data``; rows reachable along several paths may appear more
        than once because the levels are combined with ``UNION ALL``.
    """
    query = '''
    WITH RECURSIVE
    StakeholderConnections(level, subject, predicate, object) AS (
        SELECT 1, s1.name, r.predicate, s2.name
        FROM relationships r
        JOIN stakeholders s1 ON s1.stakeholder_id = r.subject
        JOIN stakeholders s2 ON s2.stakeholder_id = r.object
        WHERE s1.name = ? OR s2.name = ?
        UNION ALL
        SELECT sc.level + 1, s1.name, r.predicate, s2.name
        FROM StakeholderConnections sc
        JOIN relationships r ON r.subject = (SELECT stakeholder_id FROM stakeholders WHERE name = sc.object)
        JOIN stakeholders s1 ON s1.stakeholder_id = r.subject
        JOIN stakeholders s2 ON s2.stakeholder_id = r.object
        WHERE sc.level < ?
    )
    SELECT subject, predicate, object
    FROM StakeholderConnections
    '''
    # The depth limit is bound as a parameter rather than formatted into the
    # SQL text, so it cannot alter the statement itself.
    params = (stakeholder_name, stakeholder_name, max_depth)
    return get_data(query, params)
|
||
def generate_network_graph(stakeholder_name):
    """Render the relationship network around a stakeholder as an HTML page.

    Fetches the stakeholder's relationships with
    ``get_relationships_with_names_by_name``, builds a directed multigraph from
    them, positions it with a spring layout, and writes the Bokeh plot to
    ``graphs/<stakeholder_name>_network.html`` next to this module before
    opening it via ``show``.

    Args:
        stakeholder_name: Name of the stakeholder to look up; it is also used
            verbatim in the output file name.
    """
    # Fetch connections up to a depth of 3 for a given stakeholder
    relationships = get_relationships_with_names_by_name(stakeholder_name)

    G = nx.MultiDiGraph()  # Directed graph; parallel edges are allowed
    for subject_name, predicate, object_name in relationships:
        # Label the edge with the last path segment of the predicate.
        extracted_info = extract_after_last_slash(predicate)
        G.add_edge(subject_name, object_name, label=extracted_info)

    plot = Plot(width=800, height=800,
                x_range=Range1d(-1.1, 1.1), y_range=Range1d(-1.1, 1.1))

    graph_renderer = from_networkx(G, nx.spring_layout, scale=2, center=(0, 0))

    graph_renderer.node_renderer.glyph = Circle(radius=0.1, fill_color=Spectral4[0])
    graph_renderer.edge_renderer.glyph = MultiLine(line_color="black", line_alpha=0.8, line_width=1)

    plot.renderers.append(graph_renderer)

    # "@index" shows the node key, which here is the stakeholder name.
    node_hover_tool = HoverTool(tooltips=[("name", "@index")])
    plot.add_tools(node_hover_tool, TapTool(), BoxSelectTool())

    # Selecting or hovering a node also highlights the edges attached to it.
    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = NodesAndLinkedEdges()

    output_dir = os.path.join(os.path.dirname(__file__), 'graphs')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    output_file_path = os.path.join(output_dir, f"{stakeholder_name}_network.html")
    output_file(output_file_path)
    show(plot)
|
||
|
||
# Script entry point: render the network graph for one hard-coded stakeholder.
if __name__ == "__main__":
    stakeholder_name = "Ben Carson"  # Specify the stakeholder name here
    generate_network_graph(stakeholder_name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from dash import Dash, html | ||
import dash_cytoscape as cyto | ||
import requests | ||
|
||
base_url= 'https://python-server-ohgaalojiq-de.a.run.app' | ||
|
||
app = Dash(__name__) | ||
|
||
def get_relationships_recr(subject, depth = 1):
    """Collect the relationships of ``subject``, following objects recursively.

    Two endpoints are queried for the subject: one returning relationships as
    id dictionaries and one returning them as name triples. The two responses
    are paired positionally.

    Args:
        subject: Identifier passed as the ``subject`` query parameter.
        depth: Number of levels to follow. With ``depth > 1`` the ``object`` of
            every relationship found is queried in turn with ``depth - 1``.

    Returns:
        A dict keyed by the tuple of each id dictionary's values, mapping to
        ``{"subject": ..., "predicate": ..., "object": ...}`` built from the
        matching name triple.
    """
    # Kept as a literal on purpose: the module rebinds ``base_url`` to a
    # different host further down, so reading it at call time would silently
    # redirect these requests.
    server = 'https://python-server-ohgaalojiq-de.a.run.app'
    query = {'subject': subject}

    # Response.json() replaces json.loads(...): this module never imports
    # ``json``, so the previous form raised NameError on the first call.
    # ``params=`` lets requests URL-encode the subject.
    response_ids = requests.get(f'{server}/relationships/', params=query).json()
    response_names = requests.get(f'{server}/relationships-with-names/', params=query).json()

    relationships = {}
    # NOTE(review): assumes both endpoints return rows in the same order -
    # confirm against the API.
    for id_dict, name_ls in zip(response_ids, response_names):
        rs = {
            "subject": name_ls[0],
            "predicate": name_ls[1],
            "object": name_ls[2]
        }
        relationships[tuple(id_dict.values())] = rs
        if depth > 1:
            relationships.update(get_relationships_recr(id_dict['object'], depth=depth-1))
    return relationships
|
||
def get_relationships(subject, depth):
    """Return the relationship dicts gathered by ``get_relationships_recr`` as a list.

    The id-tuple keys of the recursive result are dropped; only its values
    (the ``subject``/``predicate``/``object`` dicts) are returned.
    """
    collected = get_relationships_recr(subject, depth=depth)
    return [*collected.values()]
|
||
def extract_after_last_slash(predicate: str):
    """Return the text after the final '/' in ``predicate``.

    If ``predicate`` contains no '/', it is returned unchanged.
    """
    _head, _sep, tail = predicate.rpartition('/')
    return tail
|
||
# Define the base URL of the FastAPI
# NOTE(review): this rebinds the module-level ``base_url`` that was assigned a
# different host earlier in the file - confirm which service is intended.
base_url = 'https://stakeholder-api-hafh6z44mq-de.a.run.app'

# Everything below runs at import time, including three blocking HTTP requests.

# Fetch stakeholders
stakeholders_response = requests.get(f'{base_url}/stakeholders/')
stakeholders = stakeholders_response.json()

# Fetch relationships
relationships_response = requests.get(f'{base_url}/relationships/')
relationships = relationships_response.json()

# Fetch relationships with names (only for the hard-coded subject 1)
relationships_with_names_response = requests.get(f'{base_url}/relationships-with-names/?subject=1')
relationships_with_names = relationships_with_names_response.json()

# Create elements for Cytoscape
elements = []

# Add stakeholders as nodes
# Presumably each stakeholder is a dict with "id" and "name" keys - verify against the API.
for stakeholder in stakeholders:
    elements.append({'data': {'id': stakeholder["id"], 'label': stakeholder["name"]}})

# Add relationships as edges
for relationship in relationships:
    elements.append({'data': {'source': relationship["subject"], 'target': relationship["object"], 'label': relationship["predicate"]}})

# Add relationships with names as edges with additional information
# NOTE(review): these rows are indexed positionally (0 = source, 1 = label,
# 2 = target). If positions 0 and 2 hold names rather than stakeholder ids,
# the edges will not match the node ids added above - confirm.
for relation in relationships_with_names:
    elements.append({'data': {'source': relation[0], 'target': relation[2], 'label': relation[1]}})

# Define the layout for the Dash app
app.layout = html.Div([
    html.P("Dash Cytoscape:"),
    cyto.Cytoscape(
        id='cytoscape',
        elements=elements,
        layout={'name': 'breadthfirst'},
        style={'width': '800px', 'height': '600px'}
    )
])

# Run the Dash app
if __name__ == '__main__':
    app.run_server(debug=True)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import requests | ||
from pyvis.network import Network | ||
import json | ||
|
||
base_url = 'https://python-server-ohgaalojiq-de.a.run.app' | ||
|
||
def get_rs(subject):
    """Fetch the relationships-with-names payload for ``subject``.

    Returns the JSON-decoded body of
    ``{base_url}/relationships-with-names/?subject=<subject>``.
    """
    endpoint = rf'{base_url}/relationships-with-names/?subject={subject}'
    raw_body = requests.get(endpoint).content
    return json.loads(raw_body)
##print(get_rs(1)) | ||
|
||
# def get_photo(name): | ||
# ref_pic = json.loads(requests.get(rf'{base_url}/stakeholders/?name={name}&summary=true&headline=true&photo=true').content) | ||
# pic_url = ref_pic[0].get("photo") | ||
# return pic_url | ||
|
||
# print(get_photo("Ben Carson")) | ||
|
||
def get_photo(name):
    """Look up the first photo URL recorded for a stakeholder.

    Args:
        name: Stakeholder name sent as the ``name`` query parameter.

    Returns:
        The first URL of the ``photo`` field of the first returned record
        (the field may hold several URLs separated by ``"||"``), or ``None``
        when the response is not a non-empty list or the record has no photo.
        A diagnostic line is printed in both ``None`` cases.
    """
    # Passing the query through ``params=`` lets requests URL-encode the name;
    # interpolating it into the URL broke on names containing '&', '#' or '+'.
    response = requests.get(
        f'{base_url}/stakeholders/',
        params={'name': name, 'summary': 'true', 'headline': 'true', 'photo': 'true'},
    ).content
    ref_pic = json.loads(response)

    # An empty list is falsy, so no separate length check is needed.
    if not (isinstance(ref_pic, list) and ref_pic):
        print(f"Unexpected response for get_photo({name}): {response}")
        return None

    photo_field = ref_pic[0].get("photo")
    if not photo_field:
        print(f"No photo field for get_photo({name}): {response}")
        return None

    # Split by "||" and take the first URL
    return photo_field.split("||")[0].strip()
|
||
def map_algs(g, alg="barnes"):
    """Apply the physics layout selected by ``alg`` to network ``g``.

    ``"barnes"`` calls ``g.barnes_hut()``, ``"force"`` calls
    ``g.force_atlas_2based()`` and ``"hr"`` calls ``g.hrepulsion()``. Any other
    value leaves ``g`` untouched.
    """
    layout_methods = (
        ("barnes", "barnes_hut"),
        ("force", "force_atlas_2based"),
        ("hr", "hrepulsion"),
    )
    for key, method_name in layout_methods:
        if alg == key:
            getattr(g, method_name)()
|
||
def map_data(relationships, subj_color="#77E4C8", obj_color="#3DC2EC", edge_color="#96C9F4",subj_shape="image",obj_shape="image", alg="hr", buttons=False):
    """Draw relationship triples as a pyvis network and write ``network1.html``.

    Args:
        relationships: Iterable of ``(subject, predicate, object)`` sequences;
            subject and object become nodes, the predicate the edge label.
        subj_color: Node colour used for subjects.
        obj_color: Node colour used for objects.
        edge_color: Colour of every edge.
        subj_shape: pyvis node shape for subjects.
        obj_shape: pyvis node shape for objects.
        alg: Layout key forwarded to ``map_algs``.
        buttons: When truthy, narrow the canvas to 75% and show the pyvis
            "edges" and "physics" configuration panels.
    """
    g = Network(height="1024px", width="100%",font_color="black")
    if buttons:
        g.width = "75%"
        g.show_buttons(filter_=["edges", "physics"])

    # Each photo lookup is an HTTP request, and the same name usually appears
    # in many triples, so remember the result per name for this call.
    photo_by_name = {}

    def photo_for(name):
        if name not in photo_by_name:
            photo_by_name[name] = get_photo(name)
        return photo_by_name[name]

    for rs in relationships:
        subj = rs[0]
        pred = rs[1]
        obj = rs[2]
        g.add_node(subj, color=subj_color, shape=subj_shape, image=photo_for(subj))
        g.add_node(obj, color=obj_color, shape=obj_shape, image=photo_for(obj))
        g.add_edge(subj,obj,label=pred, color=edge_color, smooth=False)

    map_algs(g,alg=alg)
    g.toggle_physics(False)
    g.set_edge_smooth("dynamic")
    g.show("network1.html")
|
||
# Script entry point: fetch the relationships of one hard-coded subject and
# render them with circular image nodes using the "hr" layout.
if __name__ == '__main__':
    subject = 1
    nw = get_rs(subject)
    map_data(relationships= nw, subj_shape="circularImage", alg="hr", buttons=False)
Oops, something went wrong.