Feature request: Allow specifying chat history for LMs #1435

Open
ShaojieJiang opened this issue Aug 28, 2024 · 8 comments

@ShaojieJiang

Hi DSPy developers,

First of all, thanks a lot for this great work!

Recently I've been trying to integrate DSPy into my work, but I got stuck on specifying the chat history. My task is to design a professional interviewer chatbot, so my optimisation target is the conversation flow in diverse cases (such as with both cooperative and uncooperative interviewees). This is different from the existing DSPy modules, which mostly focus on answer generation.

Although my task is also roughly answer generation, being able to specify the chat history is important. So I'm wondering: is support for history specification already planned? If not, do you think this aligns well with your development agenda?

For the moment I'll add the support in my inherited LM class. Looking forward to your response!

Best regards,
Shaojie Jiang

@ShaojieJiang
Author

Update: I quickly realised that it's more complicated than I thought. After some investigation, I managed to make it work for my case. Please see the MWE:
chatbot.py.zip

I'm not sure I did it in an ideal way; the solution looks quite hacky to me. Any suggestions for a better approach are much appreciated!

@okhat
Collaborator

okhat commented Sep 1, 2024

Hey @ShaojieJiang ! What does specifying chat history mean? Do you mean passing multiple user/assistant turns to chat LMs?

@ktzsh

ktzsh commented Nov 5, 2024

Hi @okhat, I have a similar scenario. What would be the best way to emulate these multiple user/assistant turns within the input, so that the end user can ask follow-up questions in a conversational flow?

One obvious way is to convert question -> answer into question, history -> answer, but that would lose the nuances of multi-turn conversation. An ideal way, I think, would be for the history to be parsed into the LLM chat template, similar to what @ShaojieJiang has done; see the sketch below.
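For reference, a minimal sketch of that obvious conversion (the signature string and contents are illustrative, and an LM is assumed to be configured already):

import dspy

# The "obvious" conversion: history becomes just another string input field,
# so it is flattened into the user message instead of being sent as native chat turns.
qa = dspy.Predict("question, history -> answer")

history = "User: What is the capital of France?\nAssistant: Paris."
prediction = qa(question="What is its population?", history=history)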

@ShaojieJiang
Author

Hi @okhat , apologies for the late response. I've found a solution for TextGrad and didn't pay much attention to the GitHub notifications.

Basically, I want to be able to specify the chat history in its original format ({"role": "user", "content": "..."}, ...) instead of the stringified version of the history (Here is the chat history: User: ... Assistant: ...). This is to optimise the LLM's behaviour under given contexts (good or bad), so as to train it to handle different situations.
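For concreteness, a minimal sketch of the two representations (the turn contents are illustrative):

# Native chat format: each turn is its own message with a role.
native_history = [
    {"role": "user", "content": "Hi, I'm ready for the interview."},
    {"role": "assistant", "content": "Great! Let's start with your background."},
]

# Stringified format: the whole history is flattened into one string that
# ends up embedded inside a single user message.
stringified_history = (
    "Here is the chat history:\n"
    "User: Hi, I'm ready for the interview.\n"
    "Assistant: Great! Let's start with your background."
)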

Here you can find an MWE of my solution for TextGrad. I tried the same for DSPy, but it's much hackier, as you might have seen from my attachments above.

@svenhimmelvarg

I am in a similar boat, where I want to be able to do something like this:

messages = []
resp = dspy.Predict(MyQA)("What is the capital of France...")
messages.append(resp)
resp = dspy.Predict(MyQA)("What is the native language name for the country", messages)
messages.append(resp)
resp = dspy.Predict(MyQA)("What is the main district of the city", messages)
...

Being able to simulate a conversation would be useful, and being able to evaluate it would be useful too.

Is this possible with the current dspy implementation?

@itay1542
Contributor

itay1542 commented Dec 26, 2024

You can pass the chat history properly by overriding the default dspy.ChatAdapter and its format method.
Then you can configure it with dspy.settings.configure(adapter=your_adapter_instance)

Here is an example:

from typing import Any

import dspy
# format_turn is available from dspy.adapters.chat_adapter in the dspy
# version current at the time of this thread.
from dspy.adapters.chat_adapter import format_turn

class ChatModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict("question -> answer")
        self.conversation_history = []
    
    def forward(self, question: str):
        prediction = self.predict(question=question, history_messages=self.conversation_history)
        self.conversation_history.append(
            {
                "inputs": dict(question=question),
                "outputs": prediction
            }
        )
        return prediction

class MultiTurnChatAdapter(dspy.ChatAdapter):    
    def format(self, signature, demos, inputs) -> list[dict[str, Any]]:
        history_messages = inputs.pop("history_messages", [])
        inputs_ = super().format(signature, demos, inputs) # system prompt, demos, user prompt
        formatted_history = []
        for turn in history_messages:
            formatted_history.append(format_turn(signature, turn["inputs"], role="user"))
            formatted_history.append(format_turn(signature, turn["outputs"], role="assistant"))
        
        return inputs_[:-1] + formatted_history + [inputs_[-1]] # concat system and demos with past turns and finally the user prompt

Usage:

# assuming you already have an LM configured
dspy.settings.configure(adapter=MultiTurnChatAdapter())
chat_module = ChatModule()
chat_module("My name is Itay")
chat_module("What is my name?")
# -> Prediction(answer='Your name is Itay.')

@vanakema

vanakema commented Jan 31, 2025

Is there no plan for implementing multi-turn conversation natively in DSPy? For any kind of chatbot, being able to pass chat history is necessary. Of course you can pass it in as an input field, but then the history isn't represented in the model's native syntax, so embedding it into a single user message, as DSPy currently does, seems almost certain to be less performant. I could be wrong about my assumptions, and maybe LLMs handle this equally well, but chat models are trained to receive chat history as separate turns, so the current implementation (a system message plus the entire chat history embedded in a single user message) takes the input further out of distribution.

If my assumptions are true (definitely worth testing first), I think it makes a lot of sense for multi-turn conversation to be a first-class citizen of DSPy.

@hung-phan
Contributor

hung-phan commented Jan 31, 2025

I think you can do something like this:

from itertools import chain
from typing import Any, Callable, TypedDict

from dspy import Module, Signature
from dspy.adapters.chat_adapter import (
    BuiltInCompletedOutputFieldInfo,
    ChatAdapter,
    FieldInfoWithName,
    format_fields,
)
from dspy.adapters.json_adapter import get_annotation_name


class PersistedMessage(TypedDict):
    inputs: dict[str, Any]
    outputs: dict[str, Any]


class PersistedChatHistory:
    def __init__(
        self,
        module: Module,
        *,
        selected_inputs: list[str] | None = None,
        selected_outputs: list[str] | None = None,
    ):
        self.module: Module = module
        self.selected_inputs: list[str] | None = selected_inputs
        self.selected_outputs: list[str] | None = selected_outputs
        self.historical_messages: list[PersistedMessage] = []

    def __call__(self, **kwargs):
        result = self.module(**kwargs, historical_messages=self.historical_messages)

        self.historical_messages.append(
            PersistedMessage(
                inputs=(
                    {key: kwargs[key] for key in self.selected_inputs}
                    if self.selected_inputs
                    else kwargs
                ),
                outputs=(
                    {key: result[key] for key in self.selected_outputs}
                    if self.selected_outputs
                    else dict(result)
                ),
            )
        )

        return result

    def inject_history(self, messages: list[PersistedMessage]) -> None:
        self.historical_messages = messages

    def reset_history(self) -> None:
        self.historical_messages = []


class PersistedChatAdapter(ChatAdapter):
    def format(
        self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]
    ) -> list[dict[str, Any]]:
        # Pop the historical messages first so the extra key isn't passed
        # through to the parent formatter.
        historical_messages = inputs.pop("historical_messages", [])

        # system prompt, demos, user prompt
        formatted_inputs = super().format(signature, demos, inputs)

        formatted_history = [
            format_historical_turn(signature, turn[key], role=role)
            for turn in historical_messages
            for key, role in [("inputs", "user"), ("outputs", "assistant")]
        ]

        # concat system and demos with past turns and finally the user prompt
        return formatted_inputs[:-1] + formatted_history + [formatted_inputs[-1]]


def format_historical_turn(signature, values, role):
    fields_to_collapse = []

    if role == "user":
        fields = signature.input_fields
        fields_to_collapse.append(
            {
                "type": "text",
                "text": "This is a past message of the task, though some input or output fields are not supplied.",
            }
        )
    else:
        # Copy the output fields so the signature itself isn't mutated, and
        # add the built-in field indicating that the chat turn has been completed.
        fields = {
            **signature.output_fields,
            BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info,
        }
        values = {**values, BuiltInCompletedOutputFieldInfo.name: ""}

    fields_to_collapse.extend(
        format_fields(
            fields_with_values={
                FieldInfoWithName(name=field_name, info=field_info): values.get(
                    field_name, "Not supplied for this particular message."
                )
                for field_name, field_info in fields.items()
            },
            assume_text=False,
        )
    )

    if role == "user":
        output_fields = list(signature.output_fields.keys())

        def type_info(v):
            return (
                f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
                if v.annotation is not str
                else ""
            )

        if output_fields:
            fields_to_collapse.append(
                {
                    "type": "text",
                    "text": "Respond with the corresponding output fields, starting with the field "
                    + ", then ".join(
                        f"`[[ ## {f} ## ]]`{type_info(v)}"
                        for f, v in signature.output_fields.items()
                    )
                    + ", and then ending with the marker for `[[ ## completed ## ]]`.",
                }
            )

    # flatmap the list if any items are lists otherwise keep the item
    flattened_list = list(
        chain.from_iterable(
            item if isinstance(item, list) else [item] for item in fields_to_collapse
        )
    )

    if all(message.get("type", None) == "text" for message in flattened_list):
        content = "\n\n".join(message.get("text", "") for message in flattened_list)
        return {"role": role, "content": content}

    # Collapse all consecutive text messages into a single message.
    collapsed_messages: list[dict] = []
    for item in flattened_list:
        # First item is always added
        if not collapsed_messages:
            collapsed_messages.append(item)
            continue

        # If the current item is image, add to collapsed_messages
        if item.get("type") == "image_url":
            if collapsed_messages[-1].get("type") == "text":
                collapsed_messages[-1]["text"] += "\n"
            collapsed_messages.append(item)
        # If the previous item is text and current item is text, append to the previous item
        elif collapsed_messages[-1].get("type") == "text":
            collapsed_messages[-1]["text"] += "\n\n" + item["text"]
        # If the previous item is not text(aka image), add the current item as a new item
        else:
            item["text"] = "\n\n" + item["text"]
            collapsed_messages.append(item)

    return {"role": role, "content": collapsed_messages}

Some usages of it:

# Example 1: a multi-turn conversation with persisted history
lm, adapter = dspy.LM(**config), PersistedChatAdapter()

module = PersistedChatHistory(dspy.ChainOfThought("question -> answer"))

with dspy.context(lm=lm, adapter=adapter):
    print(module(question="What is the capital of France?"))
    print(module(question="How about Vietnam?"))

lm.inspect_history()

# Example 2: injecting a pre-existing history
lm, adapter = dspy.LM(**config), PersistedChatAdapter()

module = PersistedChatHistory(dspy.ChainOfThought("question -> answer"))

module.inject_history([
    PersistedMessage(
        inputs={"question": "Who is Pikachu?"},
        outputs={"answer": "Pikachu is a pokemon"}
    )
])

with dspy.context(lm=lm, adapter=adapter):
    print(module(question="What color is it?"))

# Example 3: resetting the history mid-conversation
lm, adapter = dspy.LM(**config), PersistedChatAdapter()

module = PersistedChatHistory(dspy.ChainOfThought("question -> answer"))

with dspy.context(lm=lm, adapter=adapter):
    print(module(question="Who is Pikachu?"))
    module.reset_history()
    print(module(question="What color is it?"))

selected_inputs and selected_outputs give you more control over which fields are picked out of the input and output objects; see the sketch below.
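For example, to persist only the question and the final answer, dropping ChainOfThought's intermediate reasoning field from the stored history (a minimal sketch reusing the classes above):

module = PersistedChatHistory(
    dspy.ChainOfThought("question -> answer"),
    selected_inputs=["question"],
    # ChainOfThought also produces a `reasoning` output field; listing only
    # `answer` keeps the reasoning out of the persisted history.
    selected_outputs=["answer"],
)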
