Skip to content

Commit

Permalink
new updates
Browse files (browse the repository at this point in the history)
  • Loading branch information
geemi725 committed Feb 28, 2024
1 parent 400889c commit 2fd3a18
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 22 deletions.
2 changes: 1 addition & 1 deletion xpertai/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
- Then, provide a summary of everything you described previously to describe the relationship between these features and the {observation}. You must sound like a scientist.
Give scientific evidence for these answers and provide citations.
- Finally, provide the list of references used to answer. Use APA citation style for referencing. \n
- Finally, provide the list of references only used to answer. DO NOT make up references. Use APA style for referencing. \n
Eg: References: \n
1. reference 1 \n
2. reference 2 \n
Expand Down
7 changes: 6 additions & 1 deletion xpertai/tools/explain_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def get_modelsummary(arg_dict):
"model_type":<classifier or regressor>,
"top_k":<Number of features to explain>,
"XAI_tool": <SHAP, LIME or Both>
"persist_directory":<path to vectordb>
}.
Example:
Expand All @@ -22,6 +23,9 @@ def get_modelsummary(arg_dict):
'''
save_dir = './data'
global persist_directory ; persist_directory = None
global clean ; clean = True

if not os.path.exists(save_dir):
os.mkdir(save_dir)
# arg_dict = json.loads(json_request)
Expand Down Expand Up @@ -70,7 +74,8 @@ def get_modelsummary(arg_dict):
f.close()
metadata = {'Authors': 'XpertAI', 'Year': '2023', 'Title': 'XAI Summary'}

vector_db(lit_file=f'{save_dir}/XAI_summary.txt', clean=True,
vector_db(persist_directory=persist_directory,
lit_file=f'{save_dir}/XAI_summary.txt', clean=clean,
metadatas=metadata)

# Step 5: Generate summary of model explanation
Expand Down
23 changes: 12 additions & 11 deletions xpertai/tools/generate_nle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def gen_nle(arg_dict):
{"observation":<target property>,
"XAI_tool": <SHAP, LIME or Both>,
"top_k":<maximum number of features to explain>
"persist_directory":<path to vectordb>
}
Example:
Expand All @@ -20,6 +21,7 @@ def gen_nle(arg_dict):
'''

save_dir = './data'
global persist_directory ; persist_directory = "./data/chroma/"
# arg_dict = json.loads(json_request)
for k, val in arg_dict.items():
globals()[k] = val
Expand Down Expand Up @@ -61,7 +63,7 @@ def gen_nle(arg_dict):
# get human interpretable feature labels
# #initiate retriever, chain
llm = ChatOpenAI(
temperature=0.1,
temperature=0.0,
model_name="gpt-4",
request_timeout=1000)

Expand All @@ -78,30 +80,29 @@ def gen_nle(arg_dict):
# generate NLEs with citations

features = ','.join(new_labels)
db = Chroma(persist_directory=persist_directory,
embedding_function=embedding)
docs = []
# first collect docs for each feature
for feature in new_labels:
initial_question = f"""It has been identified by XAI analysis {feature} have an impact on the {observation}. \n
Your task is to explain how the {observation} is affected by the {feature}. \
How does each of these features impact the {observation}?
"""
initial_question = f"""How does the {feature} impact the {observation}?"""
# Get relevant docs

db = Chroma(persist_directory="./data/chroma/",
embedding_function=embedding)

docs.append(db.max_marginal_relevance_search(initial_question))
fetched = db.max_marginal_relevance_search(initial_question,k=4)
docs.append(fetched)

# flatten list of docs
docs = [item for sublist in docs for item in sublist]

# add citations from metadata
documents = ""
for i in range(len(docs)):
doc = docs[i].page_content
try:
authors = docs[i].metadata["authors"]
year = docs[i].metadata["year"]
documents += f"{doc} ({authors},{year}) \n\n"
title = docs[i].metadata["source"]
documents += f"{doc} REFERENCE:({authors},{year},{title}) \n\n"

except BaseException:
documents += f"{doc} \n\n"

Expand Down
20 changes: 11 additions & 9 deletions xpertai/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def explain_lime(df_init, model_path, model_type, top_k=2,

def load_split_docs(filename, meta_data=None):
r_splitter = RecursiveCharacterTextSplitter(
chunk_size=1500,
chunk_overlap=200,
chunk_size=500,
chunk_overlap=50,
length_function=len
)
docs = None
Expand All @@ -293,8 +293,8 @@ def load_split_docs(filename, meta_data=None):
return docs_split


def _create_vecdb(docs_split, persist_directory):
embedding = OpenAIEmbeddings()
def _create_vecdb(docs_split, persist_directory, embedding=None):
if embedding is None: embedding = OpenAIEmbeddings()

vectordb = Chroma.from_documents(
documents=docs_split,
Expand All @@ -304,8 +304,9 @@ def _create_vecdb(docs_split, persist_directory):
vectordb.persist()


def _update_vecdb(docs_split, persist_directory):
embedding = OpenAIEmbeddings()
def _update_vecdb(docs_split, persist_directory,embedding=None):
if embedding is None: embedding = OpenAIEmbeddings()

vectordb = Chroma(persist_directory=persist_directory,
embedding_function=embedding)

Expand Down Expand Up @@ -340,7 +341,8 @@ def _get_metadata(lit_file):
def vector_db(persist_directory=None,
lit_file=None, clean=False,
try_meta_data=False,
metadatas=None):
metadatas=None,
embedding=None):

if persist_directory is None:
persist_directory = "./data/chroma/"
Expand All @@ -355,7 +357,7 @@ def vector_db(persist_directory=None,
if os.path.exists(persist_directory):
shutil.rmtree(persist_directory)
os.mkdir(persist_directory)
_create_vecdb(text_split, persist_directory)
_create_vecdb(text_split, persist_directory,embedding=embedding)

else:
_update_vecdb(text_split, persist_directory)
_update_vecdb(text_split, persist_directory,embedding=embedding)

0 comments on commit 2fd3a18

Please sign in to comment.