Skip to content

Commit

Permalink
Add!
Browse files Browse the repository at this point in the history
Siti's files added hehe
  • Loading branch information
siti34 committed Aug 7, 2024
1 parent 15f21a2 commit 4537359
Show file tree
Hide file tree
Showing 6 changed files with 552 additions and 0 deletions.
106 changes: 106 additions & 0 deletions Features Experimentation Files/add_derive_rs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from langchain_google_vertexai import ChatVertexAI
from langgraph.checkpoint import MemorySaver

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import(
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder
)
import GraphRScustom
from GraphRScustom import LLMGraphTransformer
from pyvis.network import Network

import os
import requests
import urllib


from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document

from sqlalchemy.orm import Session
from sqlalchemy import select
import re
import models
from database import stakeholder_engine, user_engine, media_engine
from qdrant_media import search_in_qdrant, vectorize_query


#Implementation with Qdrant filter is in qdrant_media.py
def derive_rs_from_media(db:Session, stakeholder_id: int= None, query: str=None):
# from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatVertexAI(model="gemini-1.5-flash")

# llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
llm = ChatVertexAI(model="gemini-1.5-flash")
llm_transformer = LLMGraphTransformer(llm=llm)

#get media ids from stakeholder ids
# Get all media ids from stakeholder ids
results = get_media_id_from_stakeholder(db, stakeholder_id=stakeholder_id)
#list of media ids
media_ids = [result.media_id for result in results]
# query_vector = vectorize_query(query)
# top_media = search_in_qdrant(query_vector)

# List of media ids that stakeholder is mentioned in
media_ids = [result.media_id for result in results]

# Pass list of media_ids through to metadata filter function

# Match list of media_ids and rank top 5 based on user query

# Join all articles
for id in media_ids:
#get content
results = get_content_from_media_id(db, media_id = id) #returns list of json dicts
text = [result.content for result in results] #list of article content

documents = [Document(page_content=' '.join(text))]
# print(documents) ## for all content in medias

# Derive relationships from filtered media ids
graph_documents = llm_transformer.convert_to_graph_documents(documents)

nodes = graph_documents[0].nodes
rs = graph_documents[0].relationships
# print(rs)

# Initiating dict, dictionary of ids, list of relationships and id
nodes_id = {}
media_rs = []
node_id_map = {}

def derive_rs_from_media(db:Session, stakeholder_id: int= None, query: str=None):
node_id_map[node.id] = node_counter
nodes_id[node_counter] = node.id

# Relationships with ids
# Format relationships with ids
for relation in rs:
source_id = node_id_map.get(relation.source.id)
target_id = node_id_map.get(relation.target.id)

if source_id is None:
print(f"KeyError: '{relation.source.id}' not found in node_id_map")
if target_id is None:
print(f"KeyError: '{relation.target.id}' not found in node_id_map")

if source_id is not None and target_id is not None:
media_rs.append([source_id, relation.type, target_id])
else:
# Handle the case where either source_id or target_id is None
# Either source_id or target_id is None
continue

# Output: {nodes: {id:name}, edges:[id,str,id]}
return {'nodes': nodes_id, 'edges': media_rs}
# return media_ids

if __name__ == "__main__":
stakeholder_id = 28235
99 changes: 99 additions & 0 deletions Features Experimentation Files/generate_bokeh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import sqlite3
import os
import networkx as nx
from bokeh.io import show, output_file
from bokeh.models import Plot, Range1d, MultiLine, Circle, HoverTool, TapTool, BoxSelectTool, NodesAndLinkedEdges
from bokeh.plotting import from_networkx
from bokeh.palettes import Spectral4

base_url= 'https://python-server-ohgaalojiq-de.a.run.app'

## File paths
data_dir = os.path.join(os.path.dirname(__file__), 'data')
csv_dir = os.path.join(data_dir, 'csv')
db_file = os.path.join(data_dir, 'stakeholders.db')
media_db_file = os.path.join(data_dir, 'media.db')

file1_path = os.path.join(os.path.dirname(__file__),"build_database.py")
file2_path = os.path.join(os.path.dirname(__file__), "database_query_function.py")

with open(file1_path, "r") as file1, open(file2_path, "r") as file2:
build_database_content = file1.read()
database_query_function_content = file2.read()

from database_query_function import get_relationships_with_names
from database_query_function import extract_after_last_slash

def get_data(query, params=None):
"""handle database queries and return results."""
try:
with sqlite3.connect(db_file) as conn:
cursor = conn.cursor()
cursor.execute(query, params or [])
results = cursor.fetchall()
except sqlite3.Error as e:
print(f"An error occurred: {e}")
return []
return results

def get_relationships_with_names_by_name(stakeholder_name: str):
"""Fetches relationships from the database and returns them with names for a given stakeholder."""
query = '''
WITH RECURSIVE
StakeholderConnections(level, subject, predicate, object) AS (
SELECT 1, s1.name, r.predicate, s2.name
FROM relationships r
JOIN stakeholders s1 ON s1.stakeholder_id = r.subject
JOIN stakeholders s2 ON s2.stakeholder_id = r.object
WHERE s1.name = ? OR s2.name = ?
UNION ALL
SELECT sc.level + 1, s1.name, r.predicate, s2.name
FROM StakeholderConnections sc
JOIN relationships r ON r.subject = (SELECT stakeholder_id FROM stakeholders WHERE name = sc.object)
JOIN stakeholders s1 ON s1.stakeholder_id = r.subject
JOIN stakeholders s2 ON s2.stakeholder_id = r.object
WHERE sc.level < 3
)
SELECT subject, predicate, object
FROM StakeholderConnections
'''
params = (stakeholder_name, stakeholder_name)
return get_data(query, params)

def generate_network_graph(stakeholder_name):
# Fetch connections up to a depth of 3 for a given stakeholder
relationships = get_relationships_with_names_by_name(stakeholder_name)

G = nx.MultiDiGraph() # Directed graph
for subject_name, predicate, object_name in relationships:
extracted_info = extract_after_last_slash(predicate)
G.add_edge(subject_name, object_name, label=extracted_info)

plot = Plot(width=800, height=800,
x_range=Range1d(-1.1, 1.1), y_range=Range1d(-1.1, 1.1))

graph_renderer = from_networkx(G, nx.spring_layout, scale=2, center=(0, 0))

graph_renderer.node_renderer.glyph = Circle(radius=0.1, fill_color=Spectral4[0])
graph_renderer.edge_renderer.glyph = MultiLine(line_color="black", line_alpha=0.8, line_width=1)

plot.renderers.append(graph_renderer)

node_hover_tool = HoverTool(tooltips=[("name", "@index")])
plot.add_tools(node_hover_tool, TapTool(), BoxSelectTool())

graph_renderer.selection_policy = NodesAndLinkedEdges()
graph_renderer.inspection_policy = NodesAndLinkedEdges()

output_dir = os.path.join(os.path.dirname(__file__), 'graphs')
if not os.path.exists(output_dir):
os.makedirs(output_dir)

output_file_path = os.path.join(output_dir, f"{stakeholder_name}_network.html")
output_file(output_file_path)
show(plot)


if __name__ == "__main__":
stakeholder_name = "Ben Carson" # Specify the stakeholder name here
generate_network_graph(stakeholder_name)
79 changes: 79 additions & 0 deletions Features Experimentation Files/generate_dash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from dash import Dash, html
import dash_cytoscape as cyto
import requests

base_url= 'https://python-server-ohgaalojiq-de.a.run.app'

app = Dash(__name__)

def get_relationships_recr(subject, depth = 1):
response_ids = json.loads(requests.get(rf'https://python-server-ohgaalojiq-de.a.run.app/relationships/?subject={subject}').content)
response_names = json.loads(requests.get(rf'https://python-server-ohgaalojiq-de.a.run.app/relationships-with-names/?subject={subject}').content)

relationships = {}
for id_dict, name_ls in zip(response_ids, response_names):
rs = {
"subject": name_ls[0],
"predicate": name_ls[1],
"object": name_ls[2]
}
# print(rs)
relationships[tuple(id_dict.values())] = rs
if depth > 1:
relationships.update(get_relationships_recr(id_dict['object'], depth=depth-1))
# print(id_dict)
return relationships

def get_relationships(subject, depth):
res = get_relationships_recr(subject, depth=depth)
return list(res.values())

def extract_after_last_slash(predicate: str):
"""Extracts the last part after the last slash."""
return predicate.split('/')[-1]

# Define the base URL of the FastAPI
base_url = 'https://stakeholder-api-hafh6z44mq-de.a.run.app'

# Fetch stakeholders
stakeholders_response = requests.get(f'{base_url}/stakeholders/')
stakeholders = stakeholders_response.json()

# Fetch relationships
relationships_response = requests.get(f'{base_url}/relationships/')
relationships = relationships_response.json()

# Fetch relationships with names
relationships_with_names_response = requests.get(f'{base_url}/relationships-with-names/?subject=1')
relationships_with_names = relationships_with_names_response.json()

# Create elements for Cytoscape
elements = []

# Add stakeholders as nodes
for stakeholder in stakeholders:
elements.append({'data': {'id': stakeholder["id"], 'label': stakeholder["name"]}})

# Add relationships as edges
for relationship in relationships:
elements.append({'data': {'source': relationship["subject"], 'target': relationship["object"], 'label': relationship["predicate"]}})

# Add relationships with names as edges with additional information
for relation in relationships_with_names:
elements.append({'data': {'source': relation[0], 'target': relation[2], 'label': relation[1]}})

# Define the layout for the Dash app
app.layout = html.Div([
html.P("Dash Cytoscape:"),
cyto.Cytoscape(
id='cytoscape',
elements=elements,
layout={'name': 'breadthfirst'},
style={'width': '800px', 'height': '600px'}
)
])

# Run the Dash app
if __name__ == '__main__':
app.run_server(debug=True)

64 changes: 64 additions & 0 deletions Features Experimentation Files/generate_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import requests
from pyvis.network import Network
import json

base_url = 'https://python-server-ohgaalojiq-de.a.run.app'

def get_rs(subject):
response_names = json.loads(requests.get(rf'{base_url}/relationships-with-names/?subject={subject}').content)
return (response_names)
##print(get_rs(1))

# def get_photo(name):
# ref_pic = json.loads(requests.get(rf'{base_url}/stakeholders/?name={name}&summary=true&headline=true&photo=true').content)
# pic_url = ref_pic[0].get("photo")
# return pic_url

# print(get_photo("Ben Carson"))

def get_photo(name):
response = requests.get(rf'{base_url}/stakeholders/?name={name}&summary=true&headline=true&photo=true').content
ref_pic = json.loads(response)
if ref_pic and isinstance(ref_pic, list) and len(ref_pic) > 0:
photo_field = ref_pic[0].get("photo")
if photo_field:
pic_url = photo_field.split("||")[0].strip() # Split by "||" and take the first URL
return pic_url
else:
print(f"No photo field for get_photo({name}): {response}")
return None
else:
print(f"Unexpected response for get_photo({name}): {response}")
return None

def map_algs(g, alg="barnes"):
if alg=="barnes":
g.barnes_hut()
if alg=="force":
g.force_atlas_2based()
if alg=="hr":
g.hrepulsion()

def map_data(relationships, subj_color="#77E4C8", obj_color="#3DC2EC", edge_color="#96C9F4",subj_shape="image",obj_shape="image", alg="hr", buttons=False):
g = Network(height="1024px", width="100%",font_color="black")
if buttons == True:
g.width = "75%"
g.show_buttons(filter_=["edges", "physics"])
for rs in relationships:
subj = rs[0]
pred = rs[1]
obj = rs[2]
s_pic = get_photo(subj)
o_pic = get_photo(obj)
g.add_node(subj, color=subj_color, shape=subj_shape, image=s_pic)
g.add_node(obj, color=obj_color, shape=obj_shape, image=o_pic)
g.add_edge(subj,obj,label=pred, color=edge_color, smooth=False)
map_algs(g,alg=alg)
g.toggle_physics(False)
g.set_edge_smooth("dynamic")
g.show("network1.html")

if __name__ == '__main__':
subject = 1
nw = get_rs(subject)
map_data(relationships= nw, subj_shape="circularImage", alg="hr", buttons=False)
Loading

0 comments on commit 4537359

Please sign in to comment.