From 1bb991da96d53e8aeaf126a33a20caa9a13f3928 Mon Sep 17 00:00:00 2001 From: Jett Wang Date: Tue, 5 Dec 2023 20:52:02 +0800 Subject: [PATCH] ocr vision --- GPTStudio.py | 2 + libs/llms.py | 32 ++++++ pages/02_Knowledge_Search.py | 4 +- pages/{03_Ta365.py => 03_Ta365_Chatbot.py} | 14 +-- .../{04_Speech.py => 04_Speech_Transcribe.py} | 0 pages/06_OCR.py | 10 -- pages/06_OCR_Vision.py | 99 +++++++++++++++++++ pages/Elements.py | 7 ++ 8 files changed, 149 insertions(+), 19 deletions(-) rename pages/{03_Ta365.py => 03_Ta365_Chatbot.py} (92%) rename pages/{04_Speech.py => 04_Speech_Transcribe.py} (100%) delete mode 100644 pages/06_OCR.py create mode 100644 pages/06_OCR_Vision.py create mode 100644 pages/Elements.py diff --git a/GPTStudio.py b/GPTStudio.py index 353166b..d40d9cc 100644 --- a/GPTStudio.py +++ b/GPTStudio.py @@ -1,5 +1,7 @@ import streamlit as st +from libs.msal import msal_auth +msal_auth() def sidebar(): st.sidebar.markdown(""" diff --git a/libs/llms.py b/libs/llms.py index 8208053..3e0b4c8 100644 --- a/libs/llms.py +++ b/libs/llms.py @@ -1,3 +1,5 @@ +import base64 + from openai import OpenAI import os @@ -18,3 +20,33 @@ def openai_streaming(sysmsg, historys: list): ) for chunk in completion: yield chunk.choices[0].delta + + +# 定义函数来调用 OpenAI GPT-4 Vision API +def openai_analyze_image(prompt_str, imagefs): + client = OpenAI() + # 将图像转换为 Base64 编码,这里需要一些额外的处理 + # 假设已经将图像转换为 base64_string + base64_string = base64.b64encode(imagefs.getvalue()).decode('utf-8') + + response = client.chat.completions.create( + model="gpt-4-vision-preview", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt_str or "分析图片内容"}, + { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64," + base64_string, + "detail": "high" + }, + }, + ], + } + ], + max_tokens=300, + ) + + return response.choices[0].message.content diff --git a/pages/02_Knowledge_Search.py b/pages/02_Knowledge_Search.py index 335a2fd..8cfcd1c 100644 --- a/pages/02_Knowledge_Search.py +++ b/pages/02_Knowledge_Search.py @@ -23,7 +23,7 @@ if "knowledge_messages" not in st.session_state.keys(): st.session_state.knowledge_messages = [{"role": "assistant", "content": "欢迎使用知识库检索, 请输入主题"}] -collection = st.selectbox("选择知识库", knowledge_dictionary.keys()) +collection = st.sidebar.selectbox("选择知识库", knowledge_dictionary.keys()) collection_value = knowledge_dictionary[collection] for knowledge_messages in st.session_state.knowledge_messages: @@ -38,7 +38,7 @@ def clear_chat_history(): st.sidebar.button('清除历史', on_click=clear_chat_history) if collection_value == "": - st.error("请选择知识库") + st.warning("请选择知识库") st.stop() if prompt := st.chat_input("输入检索主题"): diff --git a/pages/03_Ta365.py b/pages/03_Ta365_Chatbot.py similarity index 92% rename from pages/03_Ta365.py rename to pages/03_Ta365_Chatbot.py index cd26b3c..557b185 100644 --- a/pages/03_Ta365.py +++ b/pages/03_Ta365_Chatbot.py @@ -18,8 +18,7 @@ st.sidebar.markdown("# 💡Ta365 AI 助手") st.title("💡Ta365 AI 助手") -st.markdown("> 一个通用型人工智能助手,可以帮助你解决各种问题。") -st.divider() +st.markdown("> 一个通用型人工智能助手,可以帮助你解决各种问题, 左侧栏可以选择知识库。") if "ta365_messages" not in st.session_state.keys(): st.session_state.ta365_messages = [{"role": "assistant", "content": "我是 Ta365 AI 助手,欢迎提问"}] @@ -37,7 +36,7 @@ def stop_streaming(): st.session_state.ta365_last_user_msg_processed = True -collection = st.selectbox("选择知识库", knowledge_dictionary.keys()) +collection = st.sidebar.selectbox("选择知识库", knowledge_dictionary.keys()) collection_value = knowledge_dictionary[collection] for ta365_messages in st.session_state.ta365_messages: @@ -49,8 +48,7 @@ def clear_chat_history(): st.session_state.ta365_messages = [{"role": "assistant", "content": "我是 Ta365 AI 助手,欢迎提问"}] -st.sidebar.button('清除历史', on_click=clear_chat_history) - +st.sidebar.button('清除对话历史', on_click=clear_chat_history) # 用户输入 if prompt := st.chat_input("输入你的问题"): @@ -66,6 +64,7 @@ def clear_chat_history(): if not st.session_state.ta365_streaming_end: stop_action.button('停止输出', on_click=stop_streaming, help="点击此按钮停止流式输出") + # 用户输入响应,如果上一条消息不是助手的消息,且上一条用户消息还没有处理完毕 if (st.session_state.ta365_messages[-1]["role"] != "assistant" and not st.session_state.ta365_last_user_msg_processed): @@ -75,6 +74,8 @@ def clear_chat_history(): kmsg = "" if collection_value not in "": kmsg = search_knowledge(collection_value, prompt) + if kmsg != "": + st.expander("📚 知识库检索结果", expanded=False).markdown(kmsg) sysmsg = get_ta365_sysmsg(kmsg) response = openai_streaming(sysmsg, st.session_state.ta365_messages[-10:]) # 流式输出 @@ -89,8 +90,7 @@ def clear_chat_history(): full_response += text placeholder.markdown(full_response) placeholder.markdown(full_response) - if kmsg != "": - st.expander("知识库检索结果", expanded=False).markdown(kmsg) + stop_action.empty() # 用于标记流式输出已经结束 diff --git a/pages/04_Speech.py b/pages/04_Speech_Transcribe.py similarity index 100% rename from pages/04_Speech.py rename to pages/04_Speech_Transcribe.py diff --git a/pages/06_OCR.py b/pages/06_OCR.py deleted file mode 100644 index 3f00287..0000000 --- a/pages/06_OCR.py +++ /dev/null @@ -1,10 +0,0 @@ -import streamlit as st -from components.streamlit_tesseract_scanner import tesseract_scanner - -img_file_buffer = st.camera_input("Take a picture") - -blacklist='@*|©_Ⓡ®¢§š' -data = tesseract_scanner(showimg=True, lang='chi_sim+eng', psm=11) - -if data is not None: - st.write(data) diff --git a/pages/06_OCR_Vision.py b/pages/06_OCR_Vision.py new file mode 100644 index 0000000..29e634a --- /dev/null +++ b/pages/06_OCR_Vision.py @@ -0,0 +1,99 @@ +import streamlit as st +from libs.llms import openai_analyze_image, openai_streaming +from libs.msal import msal_auth + +with st.sidebar: + value = msal_auth() + if value is None: + st.stop() + +if "ocr_vision_messages" not in st.session_state.keys(): + st.session_state.ocr_vision_messages = [] + +if "ocr_vision_last_user_msg_processed" not in st.session_state: + st.session_state.ocr_vision_last_user_msg_processed = True + +if "ocr_vision_analysis_result" not in st.session_state: + st.session_state.ocr_vision_analysis_result = "" + +st.sidebar.markdown("# 🔬视觉分析") + +st.title("🔬视觉分析") + + +def clear_result(): + st.session_state.ocr_vision_analysis_result = "" + st.session_state.ocr_vision_last_user_msg_processed = True + st.session_state.ocr_vision_messages = [] + + +def save_result(): + st.session_state.ocr_vision_analysis_result = st.session_state.ocr_vision_analysis_result_temp + + +# Streamlit 应用的主要部分 +col1, col2, = st.columns([3, 6]) + +# 摄像头输入获取图片 +image = col1.camera_input("点击按钮截图", on_change=clear_result) + +# 图像分析提示输入 +prompt = col2.text_input("图像分析提示", "识别分析图片内容") + +# 重新获取图像时触发图像分析 +if image is not None and not st.session_state.ocr_vision_analysis_result: + with col2: + with st.spinner("分析中..."): + st.session_state.ocr_vision_analysis_result = openai_analyze_image(prompt, image) + +# 使用文本区域组件显示分析结果, 支持手工修改 +if st.session_state.ocr_vision_analysis_result: + with col2: + st.text_area("识别结果(请手工修正识别错误)", + value=st.session_state.ocr_vision_analysis_result, + key="ocr_vision_analysis_result_temp", + on_change=save_result, + height=170) + + +for ocr_vision_messages in st.session_state.ocr_vision_messages: + with st.chat_message(ocr_vision_messages["role"]): + st.write(ocr_vision_messages["content"]) + +if uprompt := st.chat_input("输入你的问题"): + # 用于标记用户消息还没有处理 + st.session_state.ocr_vision_last_user_msg_processed = False + st.session_state.ocr_vision_messages.append({"role": "user", "content": uprompt}) + with st.chat_message("user"): + st.write(uprompt) + +# 用户输入响应,如果上一条消息不是助手的消息,且上一条用户消息还没有处理完毕 +if ((st.session_state.ocr_vision_messages and + st.session_state.ocr_vision_messages[-1]["role"] != "assistant" and + not st.session_state.ocr_vision_last_user_msg_processed) and + st.session_state.ocr_vision_analysis_result not in [""]): + with st.chat_message("assistant"): + with st.spinner("Thinking..."): + sysmsg = f"""" + 以下是来自一图片识别获取的内容结果: + ''' + {st.session_state.ocr_vision_analysis_result} + ''' + 我们将围绕这个内容进行深入讨论。 + """ + response = openai_streaming(sysmsg, st.session_state.ocr_vision_messages[-10:]) + # 流式输出 + placeholder = st.empty() + full_response = '' + for item in response: + text = item.content + if text is not None: + full_response += text + placeholder.markdown(full_response) + placeholder.markdown(full_response) + + # 用于标记上一条用户消息已经处理完毕 + st.session_state.ocr_vision_last_user_msg_processed = True + # 追加对话记录 + message = {"role": "assistant", "content": full_response} + st.session_state.ocr_vision_messages.append(message) diff --git a/pages/Elements.py b/pages/Elements.py new file mode 100644 index 0000000..3afad68 --- /dev/null +++ b/pages/Elements.py @@ -0,0 +1,7 @@ +from streamlit_player import st_player + +# Embed a youtube video +st_player("https://youtu.be/CmSKVW1v0xM") + +# Embed a music from SoundCloud +st_player("https://soundcloud.com/imaginedragons/demons")