Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Prevent the bot from replying multiple text messages #40

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
23 changes: 12 additions & 11 deletions cogs/messagehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,18 @@ async def handle_text_message(self, message, mode=""):
message,
)
await self.add_message_to_dict(message, message.clean_content)
async with message.channel.typing():
# If the response is more than 2000 characters, split it
chunks = [response[i : i + 1998] for i in range(0, len(response), 1998)]
for chunk in chunks:
print(chunk)
response_obj = await message.channel.send(chunk)
await self.add_message_to_dict(
response_obj, response_obj.clean_content
)
# self.bot.sent_last_message[str(message.channel.id)] = True
# await log_message(response_obj)
if response:
async with message.channel.typing():
# If the response is more than 2000 characters, split it
chunks = [response[i : i + 1998] for i in range(0, len(response), 1998)]
for chunk in chunks:
print(chunk)
response_obj = await message.channel.send(chunk)
await self.add_message_to_dict(
response_obj, response_obj.clean_content
)
# self.bot.sent_last_message[str(message.channel.id)] = True
# await log_message(response_obj)

async def set_listen_only_mode_timer(self, channel_id):
# Start the timer
Expand Down
47 changes: 41 additions & 6 deletions cogs/pygbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from discord import app_commands
from discord.ext import commands
import os
import asyncio


# load environment STOP_SEQUENCES variables and split them into a list by comma
Expand Down Expand Up @@ -127,7 +128,7 @@ async def get_memory_for_channel(self, channel_id):
name = message[0]
channel_ids = str(message[1])
message = message[2]
print(f"{name}: {message}")
#print(f"{name}: {message}")
await self.add_history(name, channel_ids, message)

# self.memory = self.histories[channel_id]
Expand Down Expand Up @@ -160,7 +161,7 @@ async def generate_response(self, message, message_content) -> None:
name = message.author.display_name
memory = await self.get_memory_for_channel(str(channel_id))
stop_sequence = await self.get_stop_sequence_for_channel(channel_id, name)
print(f"stop sequences: {stop_sequence}")
#print(f"stop sequences: {stop_sequence}")
formatted_message = f"{name}: {message_content}"
MAIN_TEMPLATE = f"""
{self.top_character_info}
Expand All @@ -180,7 +181,13 @@ async def generate_response(self, message, message_content) -> None:
memory=memory,
)
input_dict = {"input": formatted_message, "stop": stop_sequence}
response_text = conversation(input_dict)

# Run the conversation chain
if self.bot.koboldcpp_version >= 1.29:
response_text = await conversation.acall(input_dict,channel_id)
else:
response_text = await conversation.acall(input_dict)

response = await self.detect_and_replace_out(response_text["response"])
with open(self.convo_filename, "a", encoding="utf-8") as f:
f.write(f"{message.author.display_name}: {message_content}\n")
Expand All @@ -199,7 +206,7 @@ async def add_history(self, name, channel_id, message_content) -> None:
formatted_message = f"{name}: {message_content}"

# add the message to the memory
print(f"adding message to memory: {formatted_message}")
#print(f"adding message to memory: {formatted_message}")
memory.add_input_only(formatted_message)
return None

Expand All @@ -210,6 +217,10 @@ def __init__(self, bot):
self.chatlog_dir = bot.chatlog_dir
self.chatbot = Chatbot(bot)

# Store current task and last message here
self.current_tasks = {}
self.last_messages = {}

# create chatlog directory if it doesn't exist
if not os.path.exists(self.chatlog_dir):
os.makedirs(self.chatlog_dir)
Expand All @@ -233,8 +244,32 @@ async def chat_command(self, name, channel_id, message_content, message) -> None
and self.chatbot.convo_filename != chatlog_filename
):
await self.chatbot.set_convo_filename(chatlog_filename)
response = await self.chatbot.generate_response(message, message_content)
return response

# Check if the task is still running by channel ID
#print(f"The current task is: {self.current_tasks[channel_id]}") # for debugging purposes
if channel_id in self.current_tasks:
task = self.current_tasks[channel_id]

if task is not None and not task.done():
# Cancelling previous task, add last message to the history
await self.chatbot.add_history(name, str(channel_id), self.last_messages[channel_id])

# If the endpoint is koboldcpp, stop the generation by channel ID
if self.bot.koboldcpp_version >= 1.29:
await self.bot.llm._stop(channel_id)

self.current_task.cancel()

# Create a new task and last message bounded to the channel ID
self.last_messages[channel_id] = message_content
self.current_tasks[channel_id] = asyncio.create_task(self.chatbot.generate_response(message, message_content))

try:
response = await self.current_tasks[channel_id]
return response
except asyncio.CancelledError:
print(f"Cancelled {self.chatbot.char_name}'s current response, regenerate another reply...")
return None

# No Response Handler
@commands.command(name="chatnr")
Expand Down
10 changes: 9 additions & 1 deletion discordbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from pathlib import Path
import base64
from helpers.textgen import TextGen
from langchain.llms import KoboldApiLLM, OpenAI
from helpers.koboldai import KoboldApiLLM
from langchain.llms import OpenAI
from discord import app_commands
from discord.ext import commands
from discord.ext.commands import Bot
Expand Down Expand Up @@ -237,6 +238,13 @@ async def on_ready():
"\n\n\n\nERROR: Unable to retrieve channel from .env \nPlease make sure you're using a valid channel ID, not a server ID."
)

# Check if the endpoint is connected to koboldcpp
if bot.llm._llm_type == "koboldai":
bot.koboldcpp_version = bot.llm.check_version()
print(f"KoboldCPP Version: {bot.koboldcpp_version}")
else:
bot.koboldcpp_version = 0.0


# COG LOADER
async def load_cogs() -> None:
Expand Down
Loading