Skip to content

Commit

Permalink
ocr vision
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiesun committed Dec 5, 2023
1 parent 75980e1 commit 1bb991d
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 19 deletions.
2 changes: 2 additions & 0 deletions GPTStudio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import streamlit as st
from libs.msal import msal_auth

msal_auth()

def sidebar():
st.sidebar.markdown("""
Expand Down
32 changes: 32 additions & 0 deletions libs/llms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64

from openai import OpenAI
import os

Expand All @@ -18,3 +20,33 @@ def openai_streaming(sysmsg, historys: list):
)
for chunk in completion:
yield chunk.choices[0].delta


def openai_analyze_image(prompt_str, imagefs, *, model="gpt-4-vision-preview",
                         mime_type="image/jpeg", detail="high", max_tokens=300):
    """Ask an OpenAI vision chat model to analyze an image.

    Args:
        prompt_str: Instruction sent alongside the image. When empty or
            ``None`` the built-in default prompt is used instead.
        imagefs: In-memory image buffer; only ``getvalue()`` is called on it
            and the result is treated as the raw image bytes.
        model: Name of the chat model to call.
        mime_type: MIME type written into the data URL. Defaults to JPEG,
            which is what this function always sent before.
        detail: Value forwarded as ``image_url.detail`` in the request.
        max_tokens: Upper bound on the length of the generated answer.

    Returns:
        The message content of the first choice in the API response.
    """
    client = OpenAI()
    # The image is sent inline as a base64 data URL.
    base64_string = base64.b64encode(imagefs.getvalue()).decode('utf-8')

    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_str or "分析图片内容"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64," + base64_string,
                            "detail": detail,
                        },
                    },
                ],
            }
        ],
        max_tokens=max_tokens,
    )

    return response.choices[0].message.content
4 changes: 2 additions & 2 deletions pages/02_Knowledge_Search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
if "knowledge_messages" not in st.session_state.keys():
st.session_state.knowledge_messages = [{"role": "assistant", "content": "欢迎使用知识库检索, 请输入主题"}]

collection = st.selectbox("选择知识库", knowledge_dictionary.keys())
collection = st.sidebar.selectbox("选择知识库", knowledge_dictionary.keys())
collection_value = knowledge_dictionary[collection]

for knowledge_messages in st.session_state.knowledge_messages:
Expand All @@ -38,7 +38,7 @@ def clear_chat_history():
st.sidebar.button('清除历史', on_click=clear_chat_history)

if collection_value == "":
st.error("请选择知识库")
st.warning("请选择知识库")
st.stop()

if prompt := st.chat_input("输入检索主题"):
Expand Down
14 changes: 7 additions & 7 deletions pages/03_Ta365.py → pages/03_Ta365_Chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
st.sidebar.markdown("# 💡Ta365 AI 助手")

st.title("💡Ta365 AI 助手")
st.markdown("> 一个通用型人工智能助手,可以帮助你解决各种问题。")
st.divider()
st.markdown("> 一个通用型人工智能助手,可以帮助你解决各种问题, 左侧栏可以选择知识库。")

if "ta365_messages" not in st.session_state.keys():
st.session_state.ta365_messages = [{"role": "assistant", "content": "我是 Ta365 AI 助手,欢迎提问"}]
Expand All @@ -37,7 +36,7 @@ def stop_streaming():
st.session_state.ta365_last_user_msg_processed = True


collection = st.selectbox("选择知识库", knowledge_dictionary.keys())
collection = st.sidebar.selectbox("选择知识库", knowledge_dictionary.keys())
collection_value = knowledge_dictionary[collection]

for ta365_messages in st.session_state.ta365_messages:
Expand All @@ -49,8 +48,7 @@ def clear_chat_history():
st.session_state.ta365_messages = [{"role": "assistant", "content": "我是 Ta365 AI 助手,欢迎提问"}]


st.sidebar.button('清除历史', on_click=clear_chat_history)

st.sidebar.button('清除对话历史', on_click=clear_chat_history)

# 用户输入
if prompt := st.chat_input("输入你的问题"):
Expand All @@ -66,6 +64,7 @@ def clear_chat_history():
if not st.session_state.ta365_streaming_end:
stop_action.button('停止输出', on_click=stop_streaming, help="点击此按钮停止流式输出")


# 用户输入响应,如果上一条消息不是助手的消息,且上一条用户消息还没有处理完毕
if (st.session_state.ta365_messages[-1]["role"] != "assistant"
and not st.session_state.ta365_last_user_msg_processed):
Expand All @@ -75,6 +74,8 @@ def clear_chat_history():
kmsg = ""
if collection_value not in "":
kmsg = search_knowledge(collection_value, prompt)
if kmsg != "":
st.expander("📚 知识库检索结果", expanded=False).markdown(kmsg)
sysmsg = get_ta365_sysmsg(kmsg)
response = openai_streaming(sysmsg, st.session_state.ta365_messages[-10:])
# 流式输出
Expand All @@ -89,8 +90,7 @@ def clear_chat_history():
full_response += text
placeholder.markdown(full_response)
placeholder.markdown(full_response)
if kmsg != "":
st.expander("知识库检索结果", expanded=False).markdown(kmsg)


stop_action.empty()
# 用于标记流式输出已经结束
Expand Down
File renamed without changes.
10 changes: 0 additions & 10 deletions pages/06_OCR.py

This file was deleted.

99 changes: 99 additions & 0 deletions pages/06_OCR_Vision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import streamlit as st
from libs.llms import openai_analyze_image, openai_streaming
from libs.msal import msal_auth

# Run the MSAL auth helper inside the sidebar and halt the script when it
# returns None (presumably: nobody is signed in yet — confirm in libs.msal).
with st.sidebar:
    auth_result = msal_auth()
    if auth_result is None:
        st.stop()

# Seed the per-session keys this page relies on, without overwriting values
# that survived from an earlier rerun.
session_defaults = {
    "ocr_vision_messages": [],
    "ocr_vision_last_user_msg_processed": True,
    "ocr_vision_analysis_result": "",
}
for state_key, default_value in session_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value

st.sidebar.markdown("# 🔬视觉分析")

st.title("🔬视觉分析")


def clear_result():
    """Reset the page state: chat history, pending flag and analysis result."""
    state = st.session_state
    state["ocr_vision_messages"] = []
    state["ocr_vision_last_user_msg_processed"] = True
    state["ocr_vision_analysis_result"] = ""


def save_result():
    """Persist the edited text stored under the widget key as the analysis result."""
    edited_text = st.session_state.ocr_vision_analysis_result_temp
    st.session_state["ocr_vision_analysis_result"] = edited_text


# Main layout: two columns with a 3:6 width ratio.
col1, col2 = st.columns([3, 6])

# Camera snapshot; a change of this widget fires clear_result.
image = col1.camera_input("点击按钮截图", on_change=clear_result)

# Prompt that is sent to the model together with the snapshot.
prompt = col2.text_input("图像分析提示", "识别分析图片内容")

# Analyze only when a snapshot exists and no result is stored yet.
needs_analysis = image is not None and not st.session_state.ocr_vision_analysis_result
if needs_analysis:
    with col2, st.spinner("分析中..."):
        st.session_state.ocr_vision_analysis_result = openai_analyze_image(prompt, image)

# Show the result in an editable text area; edits are written back through
# save_result.
if st.session_state.ocr_vision_analysis_result:
    col2.text_area(
        "识别结果(请手工修正识别错误)",
        value=st.session_state.ocr_vision_analysis_result,
        key="ocr_vision_analysis_result_temp",
        on_change=save_result,
        height=170,
    )


# Replay the stored conversation in order.
for past_message in st.session_state.ocr_vision_messages:
    with st.chat_message(past_message["role"]):
        st.write(past_message["content"])

uprompt = st.chat_input("输入你的问题")
if uprompt:
    # Flag the new user message as not yet answered.
    st.session_state.ocr_vision_last_user_msg_processed = False
    st.session_state.ocr_vision_messages.append({"role": "user", "content": uprompt})
    st.chat_message("user").write(uprompt)

# Answer the pending user message. This runs only when the history is not
# empty, its last entry was not written by the assistant, and that entry has
# not been processed yet.
has_pending_user_message = (
    st.session_state.ocr_vision_messages
    and st.session_state.ocr_vision_messages[-1]["role"] != "assistant"
    and not st.session_state.ocr_vision_last_user_msg_processed
)
# A stored analysis result is also required. The truthiness test rejects both
# "" and None, so a missing result is never interpolated into the prompt.
if has_pending_user_message and st.session_state.ocr_vision_analysis_result:
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # The literal previously opened with four quotes, which put a
            # stray '"' at the very start of the system message.
            sysmsg = f"""
以下是来自一图片识别获取的内容结果:
'''
{st.session_state.ocr_vision_analysis_result}
'''
我们将围绕这个内容进行深入讨论。
"""
            # Only the 10 most recent messages are sent along as history.
            response = openai_streaming(sysmsg, st.session_state.ocr_vision_messages[-10:])
            # Stream the answer into one placeholder as the deltas arrive.
            placeholder = st.empty()
            full_response = ''
            for item in response:
                text = item.content
                if text is not None:
                    full_response += text
                    placeholder.markdown(full_response)
            placeholder.markdown(full_response)

    # Mark the last user message as processed.
    st.session_state.ocr_vision_last_user_msg_processed = True
    # Append the assistant reply to the conversation history.
    message = {"role": "assistant", "content": full_response}
    st.session_state.ocr_vision_messages.append(message)
7 changes: 7 additions & 0 deletions pages/Elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from streamlit_player import st_player

# Media embedded on this page, in display order: a YouTube video followed by
# a SoundCloud track.
MEDIA_URLS = (
    "https://youtu.be/CmSKVW1v0xM",
    "https://soundcloud.com/imaginedragons/demons",
)

for media_url in MEDIA_URLS:
    st_player(media_url)

0 comments on commit 1bb991d

Please sign in to comment.