Skip to content

Commit

Permalink
search knowledge
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiesun committed Dec 3, 2023
1 parent 6f3bc28 commit de6cf9a
Show file tree
Hide file tree
Showing 12 changed files with 335 additions and 31 deletions.
Empty file added components/__init__.py
Empty file.
79 changes: 79 additions & 0 deletions components/streamlit_tesseract_scanner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
import pytesseract
from pytesseract import Output

import streamlit as st
import streamlit.components.v1 as components

# Register the "tesseract_scanner" custom component with Streamlit. Its
# frontend assets are served from the "frontend" folder located next to
# this module.
frontend_dir = Path(__file__).parent.joinpath("frontend").absolute()
_component_func = components.declare_component(
    name="tesseract_scanner",
    path=str(frontend_dir),
)


def tesseract_scanner(showimg: bool = False,
                      lang: str = 'eng',
                      blacklist: Optional[str] = None,
                      whitelist: Optional[str] = None,
                      psm: str = '3',
                      hrate: float = 0.2,
                      key: Optional[str] = None
                      ) -> Optional[str]:
    """Run Tesseract OCR on the image delivered by the frontend component.

    The frontend returns the captured frame as a base64 data URL. The frame
    is decoded, converted to grayscale, binarised with Otsu's threshold and
    passed to ``pytesseract.image_to_string``.

    Args:
        showimg: When true, display the pre-processed (binarised) image with
            ``st.image`` before running OCR.
        lang: Tesseract language code(s), e.g. ``'eng'`` or ``'vie+eng'``.
        blacklist: When truthy, page segmentation mode 11 is used instead of
            mode 3. NOTE(review): the characters themselves are not passed to
            Tesseract yet.
        whitelist: Accepted for API compatibility; currently unused.
        psm: Accepted for API compatibility; currently unused (the mode is
            chosen from ``blacklist``, see above). TODO: confirm whether
            callers expect this value to be honoured.
        hrate: Forwarded unchanged to the frontend component.
        key: Optional Streamlit widget key, forwarded to the component.

    Returns:
        The recognised text, or ``None`` when the component has not delivered
        a frame yet or the delivered data cannot be decoded as an image.
    """
    b64_data: Optional[str] = _component_func(hrate=hrate, key=key)

    if b64_data is None:
        return None

    # Strip the "data:<type>;base64," prefix. A payload without a prefix is
    # treated as plain base64 instead of raising IndexError.
    _, separator, payload = b64_data.partition(",")
    raw_data = payload if separator else b64_data

    # np.frombuffer is the supported replacement for the deprecated binary
    # mode of np.fromstring.
    encoded = np.frombuffer(base64.b64decode(raw_data), dtype=np.uint8)
    if encoded.size == 0:
        return None

    image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode signals undecodable data by returning None.
        return None

    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

    if showimg:
        st.image(image)

    # Same mode selection as before: psm 11 when a blacklist is supplied,
    # psm 3 otherwise.
    custom_config = '--oem 3 --psm 11' if blacklist else '--oem 3 --psm 3'

    return pytesseract.image_to_string(image, lang=lang, config=custom_config)


def main():
    """Render a small demo page that shows the text returned by the scanner."""
    st.write("## Example")

    blacklist = '@*|©_Ⓡ®¢§š'
    # psm is declared as a string parameter, so pass '3' rather than the int 3.
    data = tesseract_scanner(showimg=False, lang='vie+eng',
                             blacklist=blacklist, psm='3')

    if data is not None:
        st.write(data)


if __name__ == "__main__":
    main()
19 changes: 19 additions & 0 deletions components/streamlit_tesseract_scanner/frontend/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<!DOCTYPE html>
<html lang="en">
<head>
<!-- Static frontend page of the tesseract_scanner Streamlit component. -->
<meta charset="UTF-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<!-- NOTE(review): the title still reads "streamlit-camera-input-live"; presumably a leftover from the project this page was adapted from - confirm. -->
<title>streamlit-camera-input-live</title>
<!-- The component library script is loaded before main.js; presumably main.js uses the helpers it defines - confirm against main.js. -->
<script src="./streamlit-component-lib.js"></script>
<script src="./main.js"></script>
<!--link rel="stylesheet" href="./style.css" / -->
</head>
<body>
<div id="container">
<!-- Range slider (1-100, initial value 20). Presumably read by main.js to set the height of the captured region - confirm. -->
<input id="videoheight" type="range" min="1" max="100" value="20" style="width:100%">
<!-- Video element with autoplay enabled; presumably fed with the camera stream by main.js - confirm. -->
<video id="video" autoplay="true"></video>
<!-- Canvas element; presumably used by main.js to grab frames from the video - confirm. -->
<canvas id="canvas"></canvas>
</div>
</body>
</html>
84 changes: 84 additions & 0 deletions components/streamlit_tesseract_scanner/frontend/main.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

// Borrowed minimalistic Streamlit API from Thiago
// https://discuss.streamlit.io/t/code-snippet-create-components-without-any-frontend-tooling-no-react-babel-webpack-etc/13064
// Post a message of the given type to the parent window (the Streamlit host
// page). The fields of `data` are merged into the message next to the
// `isStreamlitMessage` marker and the `type`.
function sendMessageToStreamlitClient(type, data) {
  console.log(type, data)
  const message = {
    isStreamlitMessage: true,
    type: type,
    ...data,
  };
  window.parent.postMessage(message, "*");
}

// Minimal stand-in for the Streamlit component API: every method posts the
// corresponding "streamlit:*" message to the host page.
const Streamlit = {
  // Announce that the component has loaded (API version 1).
  setComponentReady: function() {
    sendMessageToStreamlitClient("streamlit:componentReady", {apiVersion: 1});
  },
  // Ask the host page to resize the component iframe to `height`.
  setFrameHeight: function(height) {
    sendMessageToStreamlitClient("streamlit:setFrameHeight", {height: height});
  },
  // Send `value` back to the host page as the component's value.
  setComponentValue: function(value) {
    sendMessageToStreamlitClient("streamlit:setComponentValue", {value: value});
  },
  RENDER_EVENT: "streamlit:render",
  events: {
    // Invoke `callback` for every window "message" event whose data has the
    // given `type`; the message data is exposed to the callback as
    // `event.detail`.
    addEventListener: function(type, callback) {
      window.addEventListener("message", function(event) {
        const data = event.data;
        // Messages from other senders may carry null or undefined data;
        // guard so that reading `.type` cannot throw.
        if (data && data.type === type) {
          event.detail = data;
          callback(event);
        }
      });
    }
  }
}

1 change: 1 addition & 0 deletions components/streamlit_tesseract_scanner/frontend/test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

28 changes: 28 additions & 0 deletions libs/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import requests
import os


def search_knowledge(collection, query, timeout=30):
    """Search a knowledge collection through the GPT service.

    The service address and bearer token are read from the
    ``GPT_SERVICE_ADDRESS`` and ``GPT_SERVICE_TOKEN`` environment variables.

    Args:
        collection: Identifier of the collection to search.
        query: Search text.
        timeout: Seconds to wait for the service before giving up, so that a
            stalled service cannot block the caller indefinitely.

    Returns:
        A Markdown string with one "score + content" section per hit,
        ``None`` when the service returned no hits, or a string starting
        with ``"Error searching knowledge:"`` when the request failed.
    """
    gpt_address = os.getenv("GPT_SERVICE_ADDRESS")
    if not gpt_address:
        # Without this check the request would go to "None/knowledge/search".
        return "Error searching knowledge: GPT_SERVICE_ADDRESS is not set"
    api_token = os.getenv("GPT_SERVICE_TOKEN")
    url = f"{gpt_address}/knowledge/search"
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_token}"
    }
    payload = {
        "collection": collection,
        "query": query
    }

    try:
        response = requests.post(url, headers=headers, json=payload,
                                 timeout=timeout)
    except requests.RequestException as exc:
        # Report network failures the same way as non-200 responses.
        return f"Error searching knowledge: {exc}"
    if response.status_code != 200:
        return f"Error searching knowledge: {response.text}"

    try:
        results = response.json()["result"]["data"]
    except (ValueError, KeyError, TypeError) as exc:
        return f"Error searching knowledge: unexpected response ({exc!r})"

    if not results:
        # None (rather than an empty string) lets callers show a
        # "nothing found" message.
        return None

    return "\n\n".join(
        f'**Score**: {item["score"]}\n\n{item["content"]}\n\n---\n\n'
        for item in results
    )

24 changes: 24 additions & 0 deletions libs/msal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from msal_streamlit_authentication import msal_authentication
import os


def msal_auth():
    """Render the Microsoft login component and return its result.

    The application (client) id and tenant id are read from the
    ``MSAL_APPID`` and ``MSAL_TENANTID`` environment variables. The value
    returned by ``msal_authentication`` is passed through unchanged.
    """
    client_id = os.getenv("MSAL_APPID")
    tenant = os.getenv("MSAL_TENANTID")

    auth_settings = {
        "clientId": client_id,
        "authority": f"https://login.microsoftonline.com/{tenant}",
        "redirectUri": "/",
        "postLogoutRedirectUri": "/",
    }
    cache_settings = {
        "cacheLocation": "sessionStorage",
        "storeAuthStateInCookie": False,
    }
    request_settings = {"scopes": [f"{client_id}/.default"]}

    return msal_authentication(
        auth=auth_settings,
        cache=cache_settings,
        login_button_text="Microsoft Account Login",
        login_request=request_settings,
        key="msal_token",
    )
53 changes: 48 additions & 5 deletions pages/02_Knowledge_Search.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,56 @@
# Streamlit page: chat-style search over a selectable knowledge base.
import streamlit as st
import os
import sys
from dotenv import load_dotenv

# On other pages
# Session-state guard: show an error and stop the script unless
# st.session_state['authenticated'] is present and truthy.
# NOTE(review): the MSAL check further down also gates this page; presumably
# only one of the two guards is meant to remain - confirm.
if 'authenticated' not in st.session_state or not st.session_state['authenticated']:
st.error("请先登录。")
st.stop() # Block unauthenticated users from accessing the page content
# Append the absolute form of '..' (resolved against the current working
# directory) to sys.path before importing from the `libs` package, and call
# load_dotenv() before the environment variables are read below.
sys.path.append(os.path.abspath('..'))
load_dotenv()
from libs.http import search_knowledge
from libs.msal import msal_auth

# Unless DEV_MODE is "true", "1" or "on", run the Microsoft login; the script
# stops here for as long as msal_auth() returns None.
if os.getenv("DEV_MODE") not in ["true", "1", "on"]:
value = msal_auth()
if value is None:
st.stop()

# Display name shown in the select box -> collection identifier passed to
# search_knowledge().
knowledges = {
"青少年编程": "codeboy",
"对数课堂": "logbot",
}

st.sidebar.markdown("# 知识库搜索")

st.title("知识库搜索")
st.subheader("搜索知识库内容")
st.divider()

# Seed the chat history with an assistant greeting the first time the page
# runs in a session.
if "messages" not in st.session_state.keys():
st.session_state.messages = [{"role": "assistant", "content": "欢迎使用知识库检索, 请输入主题"}]

collection = st.selectbox("选择知识库", knowledges.keys())
collection_value = knowledges[collection]

# Replay the stored conversation on every script run.
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.write(message["content"])


# Button callback: reset the history to the single assistant greeting.
def clear_chat_history():
st.session_state.messages = [{"role": "assistant", "content": "欢迎使用知识库检索,请输入主题"}]


st.sidebar.button('清除历史', on_click=clear_chat_history)

# Store and echo a newly submitted prompt.
if prompt := st.chat_input("输入检索主题"):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.write(prompt)

# Answer only when the latest stored message is not from the assistant.
# NOTE(review): `prompt` is whatever st.chat_input returned on this run;
# presumably it is always set when this branch is reached - confirm.
if st.session_state.messages[-1]["role"] != "assistant":
with st.chat_message("assistant"):
with st.spinner("Thinking..."):
response = search_knowledge(collection_value, prompt)
# Fall back to a fixed message when the search yields None.
if response is None:
response = "没有找到相关知识"
st.markdown(response)
message = {"role": "assistant", "content": response}
st.session_state.messages.append(message)
Loading

0 comments on commit de6cf9a

Please sign in to comment.