Integrate Google search with Superbooga #3021

Open · wants to merge 5 commits into base: main
117 changes: 117 additions & 0 deletions extensions/superbooga/script.py
@@ -1,8 +1,12 @@
import json
import re
import textwrap

import gradio as gr
import requests
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from modules import chat, shared
from modules.logging_colors import logger
@@ -17,6 +21,8 @@
'chunk_length': 700,
'chunk_separator': '',
'strong_cleanup': False,
'semantic_cleanup': True,
'semantic_weight': 0.5,
'threads': 4,
}

@@ -84,6 +90,77 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads):
yield i


def calculate_semantic_similarity(query_embedding, target_embedding):
# Calculate cosine similarity between the query embedding and the target embedding
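# Both embeddings are 1-D vectors from SentenceTransformer.encode, so reshape them to (1, n) for sklearn's cosine_similarity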
similarity = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))
return similarity[0][0]


def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
# Load parameters from the config file
with open('custom_search_engine_keys.json') as key_file:
key = json.load(key_file)

model = SentenceTransformer('all-MiniLM-L6-v2')
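# NOTE: the embedding model is loaded on every search call; the query is embedded once so each result's title/snippet can be scored against it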
query_embedding = model.encode([query])[0]

# Set up API endpoint and parameters
url = "https://www.googleapis.com/customsearch/v1"

# Retrieve the values from the config dictionary
params = {
"key": key.get("key", "default_key_value"),
"cx": key.get("cx", "default_custom_engine_value"),
"q": str(query),
}

if "default_key_value" in str(params):
print("You need to provide an API key, by modifying the custom_search_engine_keys.json in oobabooga_windows \ text-generation-webui.\nSkipping search")
return query

if "default_custom_engine_value" in str(params):
print("You need to provide an CSE ID, by modifying the script.py in oobabooga_windows \ text-generation-webui.\nSkipping search")
return query

# Send API request
response = requests.get(url, params=params)
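# A rejected request (bad key/cx or exceeded quota) still returns JSON, but with an 'error' object instead of 'items'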

# Parse JSON response
data = response.json()

# get the result items (default to an empty list so a response with no results doesn't crash the loop below)
search_items = data.get("items", [])

# iterate over the returned results (the Custom Search API returns at most 10 per request)
urls = ""
for i, search_item in enumerate(search_items, start=1):
if semantic_cleanup:
# get titles and descriptions and use that to semantically weight the search result
# get the page title
title = search_item.get("title")
# page snippet
snippet = search_item.get("snippet")

target_sentence = str(title) + " " + str(snippet)
target_embedding = model.encode([target_sentence])[0]

similarity_score = calculate_semantic_similarity(query_embedding, target_embedding)

if similarity_score < semantic_requirement:
continue

# extract the page url and add it to the urls to download
link = search_item.get("link")
urls += link + "\n"

# Call the original feed_url_into_collector function instead of duplicating the code
result_generator = feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)

# Consume the yielded values
for result in result_generator:
yield result


def apply_settings(chunk_count, chunk_count_initial, time_weight):
global params
params['chunk_count'] = int(chunk_count)
@@ -102,6 +179,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
user_input += additional_context
logger.info(f'\n\n=== === ===\nAdding the following new context:\n{additional_context}\n=== === ===\n')
else:

def make_single_exchange(id_):
@@ -240,6 +318,44 @@ def ui():
file_input = gr.File(label='Input file', type='binary')
update_file = gr.Button('Load data')

with gr.Tab("Search input"):
search_term = gr.Textbox(lines=1, label='Search Input', info='Enter a Google search query; the returned results will be fed into the DB')
search_strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps html elements that look like long-form text.')

semantic_cleanup = gr.Checkbox(value=params['semantic_cleanup'], label='Require semantic similarity', info='Only download pages whose titles/snippets are semantically similar to the search query')
semantic_requirement = gr.Slider(0, 1, value=params['semantic_weight'], label='Semantic similarity requirement', info='Minimum similarity a result must reach to be downloaded. 0 = no culling of dissimilar pages.')

search_threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
update_search = gr.Button('Load data')

with gr.Accordion("Click for more information...", open=False):
gr.Markdown(textwrap.dedent("""
### Instructions for Installation/Setup:

To set up a custom search engine with Google, please follow the instructions provided in this guide:
https://www.thepythoncode.com/article/use-google-custom-search-engine-api-in-python

Create a new file called "custom_search_engine_keys.json".

Open the "custom_search_engine_keys.json" file and paste the following text into it:

json

{
"key": "Custom search engine key",
"cx": "Custom search engine cx number"
}

Replace the placeholders "Custom search engine key" and "Custom search engine cx number" with the respective values you obtained from the previous step.

### Usage:

Enter your desired search query in the search box above.

Press the "Load Data" button. This will add the retrieved data to the local chromaDB, which will be read into the context during runtime.

"""))

with gr.Tab("Generation settings"):
chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)')
@@ -256,4 +372,5 @@ def ui():
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated, show_progress=False)
update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
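
For anyone who wants to verify their custom_search_engine_keys.json outside the webui, below is a minimal standalone sketch of the same Custom Search request this PR performs. It is not part of the PR itself: the search query is a placeholder, and it assumes the JSON file sits in the working directory, as the extension does.

import json

import requests

# Load the same key file the extension reads
with open('custom_search_engine_keys.json') as key_file:
    key = json.load(key_file)

params = {
    "key": key["key"],             # API key from the Programmable Search Engine setup
    "cx": key["cx"],               # custom search engine ID
    "q": "text generation webui",  # example query; replace with anything
}

response = requests.get("https://www.googleapis.com/customsearch/v1", params=params)
data = response.json()

if "error" in data:
    # Invalid key/cx or exceeded quota; the API reports details here
    print("Search failed:", data["error"].get("message"))
else:
    for item in data.get("items", []):
        print(item.get("title"), "->", item.get("link"))

If the keys are valid, this prints up to ten title/link pairs, which is exactly what feed_search_into_collector filters and feeds into the collector.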