-
Notifications
You must be signed in to change notification settings - Fork 497
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Embedded apps for chat tutorials (#1218)
* Add chat tutorial app sources * Update tutorial-chat-feedback.py * Update chat app sources
- Loading branch information
1 parent
f04fc75
commit 85cc4e6
Showing
3 changed files
with
193 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
lorem |
45 changes: 45 additions & 0 deletions
45
python/api-examples-source/tutorials/chat/tutorial-chat-feedback.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import streamlit as st | ||
import time | ||
|
||
|
||
def chat_stream(prompt): | ||
response = f'You said, "{prompt}" ...interesting.' | ||
for char in response: | ||
yield char | ||
time.sleep(0.02) | ||
|
||
|
||
def save_feedback(index): | ||
st.session_state.history[index]["feedback"] = st.session_state[f"feedback_{index}"] | ||
|
||
|
||
if "history" not in st.session_state: | ||
st.session_state.history = [] | ||
|
||
for i, message in enumerate(st.session_state.history): | ||
with st.chat_message(message["role"]): | ||
st.write(message["content"]) | ||
if message["role"] == "assistant": | ||
feedback = message.get("feedback", None) | ||
st.session_state[f"feedback_{i}"] = feedback | ||
st.feedback( | ||
"thumbs", | ||
key=f"feedback_{i}", | ||
disabled=feedback is not None, | ||
on_change=save_feedback, | ||
args=[i], | ||
) | ||
|
||
if prompt := st.chat_input("Say something"): | ||
with st.chat_message("user"): | ||
st.write(prompt) | ||
st.session_state.history.append({"role": "user", "content": prompt}) | ||
with st.chat_message("assistant"): | ||
response = st.write_stream(chat_stream(prompt)) | ||
st.feedback( | ||
"thumbs", | ||
key=f"feedback_{len(st.session_state.history)}", | ||
on_change=save_feedback, | ||
args=[len(st.session_state.history)], | ||
) | ||
st.session_state.history.append({"role": "assistant", "content": response}) |
147 changes: 147 additions & 0 deletions
147
python/api-examples-source/tutorials/chat/tutorial-chat-revision.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import streamlit as st | ||
import lorem | ||
from random import randint | ||
import time | ||
|
||
if "stage" not in st.session_state: | ||
st.session_state.stage = "user" | ||
st.session_state.history = [] | ||
st.session_state.pending = None | ||
st.session_state.validation = {} | ||
|
||
|
||
def chat_stream(): | ||
for i in range(randint(3, 9)): | ||
yield lorem.sentence() + " " | ||
time.sleep(0.2) | ||
|
||
|
||
def validate(response): | ||
response_sentences = response.split(". ") | ||
response_sentences = [ | ||
sentence.strip(". ") + "." | ||
for sentence in response_sentences | ||
if sentence.strip(". ") != "" | ||
] | ||
validation_list = [ | ||
True if sentence.count(" ") > 4 else False for sentence in response_sentences | ||
] | ||
return response_sentences, validation_list | ||
|
||
|
||
def add_highlights(response_sentences, validation_list, bg="red", text="red"): | ||
return [ | ||
f":{text}[:{bg}-background[" + sentence + "]]" if not is_valid else sentence | ||
for sentence, is_valid in zip(response_sentences, validation_list) | ||
] | ||
|
||
|
||
for message in st.session_state.history: | ||
with st.chat_message(message["role"]): | ||
st.write(message["content"]) | ||
|
||
if st.session_state.stage == "user": | ||
if user_input := st.chat_input("Enter a prompt"): | ||
st.session_state.history.append({"role": "user", "content": user_input}) | ||
with st.chat_message("user"): | ||
st.write(user_input) | ||
with st.chat_message("assistant"): | ||
response = st.write_stream(chat_stream()) | ||
st.session_state.pending = response | ||
st.session_state.stage = "validate" | ||
st.rerun() | ||
|
||
elif st.session_state.stage == "validate": | ||
st.chat_input("Accept, correct, or rewrite the answer above.", disabled=True) | ||
response_sentences, validation_list = validate(st.session_state.pending) | ||
highlighted_sentences = add_highlights(response_sentences, validation_list) | ||
with st.chat_message("assistant"): | ||
st.markdown(" ".join(highlighted_sentences)) | ||
st.divider() | ||
cols = st.columns(3) | ||
if cols[0].button( | ||
"Correct errors", type="primary", disabled=all(validation_list) | ||
): | ||
st.session_state.validation = { | ||
"sentences": response_sentences, | ||
"valid": validation_list, | ||
} | ||
st.session_state.stage = "correct" | ||
st.rerun() | ||
if cols[1].button("Accept"): | ||
st.session_state.history.append( | ||
{"role": "assistant", "content": st.session_state.pending} | ||
) | ||
st.session_state.pending = None | ||
st.session_state.validation = {} | ||
st.session_state.stage = "user" | ||
st.rerun() | ||
if cols[2].button("Rewrite answer", type="tertiary"): | ||
st.session_state.stage = "rewrite" | ||
st.rerun() | ||
|
||
elif st.session_state.stage == "correct": | ||
st.chat_input("Accept, correct, or rewrite the answer above.", disabled=True) | ||
response_sentences = st.session_state.validation["sentences"] | ||
validation_list = st.session_state.validation["valid"] | ||
highlighted_sentences = add_highlights( | ||
response_sentences, validation_list, "gray", "gray" | ||
) | ||
if not all(validation_list): | ||
focus = validation_list.index(False) | ||
highlighted_sentences[focus] = ":red[:red" + highlighted_sentences[focus][11:] | ||
else: | ||
focus = None | ||
with st.chat_message("assistant"): | ||
st.markdown(" ".join(highlighted_sentences)) | ||
st.divider() | ||
if focus is not None: | ||
new_sentence = st.text_input( | ||
"Replacement text:", value=response_sentences[focus] | ||
) | ||
cols = st.columns(2) | ||
if cols[0].button( | ||
"Update", type="primary", disabled=len(new_sentence.strip()) < 1 | ||
): | ||
st.session_state.validation["sentences"][focus] = ( | ||
new_sentence.strip(". ") + "." | ||
) | ||
st.session_state.validation["valid"][focus] = True | ||
st.session_state.pending = " ".join( | ||
st.session_state.validation["sentences"] | ||
) | ||
st.rerun() | ||
if cols[1].button("Remove"): | ||
st.session_state.validation["sentences"].pop(focus) | ||
st.session_state.validation["valid"].pop(focus) | ||
st.session_state.pending = " ".join( | ||
st.session_state.validation["sentences"] | ||
) | ||
st.rerun() | ||
else: | ||
cols = st.columns(2) | ||
if cols[0].button("Accept", type="primary"): | ||
st.session_state.history.append( | ||
{"role": "assistant", "content": st.session_state.pending} | ||
) | ||
st.session_state.pending = None | ||
st.session_state.validation = {} | ||
st.session_state.stage = "user" | ||
st.rerun() | ||
if cols[1].button("Re-validate"): | ||
st.session_state.validation = {} | ||
st.session_state.stage = "validate" | ||
st.rerun() | ||
|
||
elif st.session_state.stage == "rewrite": | ||
st.chat_input("Accept, correct, or rewrite the answer above.", disabled=True) | ||
with st.chat_message("assistant"): | ||
new = st.text_area("Rewrite the answer", value=st.session_state.pending) | ||
if st.button( | ||
"Update", type="primary", disabled=new is None or new.strip(". ") == "" | ||
): | ||
st.session_state.history.append({"role": "assistant", "content": new}) | ||
st.session_state.pending = None | ||
st.session_state.validation = {} | ||
st.session_state.stage = "user" | ||
st.rerun() |