Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Avoid generating empty tool messages when there's no content for it #138

Merged
merged 3 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 44 additions & 39 deletions src/MicrosoftAi/AbstractionMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,37 +70,38 @@ public static ChatRequest ToOllamaSharpChatRequest(IList<ChatMessage> chatMessag

if (options?.AdditionalProperties?.Any() ?? false)
{
TryAddOllamaOption<bool?>(options, OllamaOption.F16kv, v => request.Options.F16kv = v);
TryAddOllamaOption<float?>(options, OllamaOption.FrequencyPenalty, v => request.Options.FrequencyPenalty = v);
TryAddOllamaOption<bool?>(options, OllamaOption.LogitsAll, v => request.Options.LogitsAll = v);
TryAddOllamaOption<bool?>(options, OllamaOption.LowVram, v => request.Options.LowVram = v);
TryAddOllamaOption<int?>(options, OllamaOption.MainGpu, v => request.Options.MainGpu = v);
TryAddOllamaOption<float?>(options, OllamaOption.MinP, v => request.Options.MinP = v);
TryAddOllamaOption<int?>(options, OllamaOption.MiroStat, v => request.Options.MiroStat = v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatEta, v => request.Options.MiroStatEta = v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatTau, v => request.Options.MiroStatTau = v);
TryAddOllamaOption<bool?>(options, OllamaOption.Numa, v => request.Options.Numa = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumBatch, v => request.Options.NumBatch = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumCtx, v => request.Options.NumCtx = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGpu, v => request.Options.NumGpu = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGqa, v => request.Options.NumGqa = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumKeep, v => request.Options.NumKeep = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumPredict, v => request.Options.NumPredict = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumThread, v => request.Options.NumThread = v);
TryAddOllamaOption<bool?>(options, OllamaOption.PenalizeNewline, v => request.Options.PenalizeNewline = v);
TryAddOllamaOption<float?>(options, OllamaOption.PresencePenalty, v => request.Options.PresencePenalty = v);
TryAddOllamaOption<int?>(options, OllamaOption.RepeatLastN, v => request.Options.RepeatLastN = v);
TryAddOllamaOption<float?>(options, OllamaOption.RepeatPenalty, v => request.Options.RepeatPenalty = v);
TryAddOllamaOption<int?>(options, OllamaOption.Seed, v => request.Options.Seed = v);
TryAddOllamaOption<string[]?>(options, OllamaOption.Stop, v => request.Options.Stop = v);
TryAddOllamaOption<float?>(options, OllamaOption.Temperature, v => request.Options.Temperature = v);
TryAddOllamaOption<float?>(options, OllamaOption.TfsZ, v => request.Options.TfsZ = v);
TryAddOllamaOption<int?>(options, OllamaOption.TopK, v => request.Options.TopK = v);
TryAddOllamaOption<float?>(options, OllamaOption.TopP, v => request.Options.TopP = v);
TryAddOllamaOption<float?>(options, OllamaOption.TypicalP, v => request.Options.TypicalP = v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMlock, v => request.Options.UseMlock = v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMmap, v => request.Options.UseMmap = v);
TryAddOllamaOption<bool?>(options, OllamaOption.VocabOnly, v => request.Options.VocabOnly = v);
TryAddOllamaOption<bool?>(options, OllamaOption.F16kv, v => request.Options.F16kv = (bool?)v);
TryAddOllamaOption<float?>(options, OllamaOption.FrequencyPenalty, v => request.Options.FrequencyPenalty = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.LogitsAll, v => request.Options.LogitsAll = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.LowVram, v => request.Options.LowVram = (bool?)v);
TryAddOllamaOption<int?>(options, OllamaOption.MainGpu, v => request.Options.MainGpu = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MinP, v => request.Options.MinP = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.MiroStat, v => request.Options.MiroStat = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatEta, v => request.Options.MiroStatEta = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatTau, v => request.Options.MiroStatTau = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.Numa, v => request.Options.Numa = (bool?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumBatch, v => request.Options.NumBatch = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumCtx, v => request.Options.NumCtx = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGpu, v => request.Options.NumGpu = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGqa, v => request.Options.NumGqa = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumKeep, v => request.Options.NumKeep = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumPredict, v => request.Options.NumPredict = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumThread, v => request.Options.NumThread = (int?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.PenalizeNewline, v => request.Options.PenalizeNewline = (bool?)v);
TryAddOllamaOption<float?>(options, OllamaOption.PresencePenalty, v => request.Options.PresencePenalty = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.RepeatLastN, v => request.Options.RepeatLastN = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.RepeatPenalty, v => request.Options.RepeatPenalty = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.Seed, v => request.Options.Seed = (int?)v);
TryAddOllamaOption<string[]?>(options, OllamaOption.Stop,
v => request.Options.Stop = (v as IEnumerable<string>)?.ToArray());
awaescher marked this conversation as resolved.
Show resolved Hide resolved
TryAddOllamaOption<float?>(options, OllamaOption.Temperature, v => request.Options.Temperature = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TfsZ, v => request.Options.TfsZ = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.TopK, v => request.Options.TopK = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TopP, v => request.Options.TopP = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TypicalP, v => request.Options.TypicalP = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMlock, v => request.Options.UseMlock = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMmap, v => request.Options.UseMmap = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.VocabOnly, v => request.Options.VocabOnly = (bool?)v);
}

return request;
Expand All @@ -113,10 +114,10 @@ public static ChatRequest ToOllamaSharpChatRequest(IList<ChatMessage> chatMessag
/// <param name="microsoftChatOptions">The chat options from the Microsoft abstraction</param>
/// <param name="option">The Ollama setting to add</param>
/// <param name="optionSetter">The setter to set the Ollama option if available in the chat options</param>
private static void TryAddOllamaOption<T>(ChatOptions microsoftChatOptions, OllamaOption option, Action<T> optionSetter)
private static void TryAddOllamaOption<T>(ChatOptions microsoftChatOptions, OllamaOption option, Action<object?> optionSetter)
{
if ((microsoftChatOptions?.AdditionalProperties?.TryGetValue(option.Name, out var value) ?? false) && value is not null)
optionSetter((T)value);
optionSetter(value);
}

/// <summary>
Expand Down Expand Up @@ -200,13 +201,17 @@ private static IEnumerable<Message> ToOllamaSharpMessages(IList<ChatMessage> cha
var images = cm.Contents.OfType<ImageContent>().Select(ToOllamaImage).Where(s => !string.IsNullOrEmpty(s)).ToArray();
var toolCalls = cm.Contents.OfType<FunctionCallContent>().Select(ToOllamaSharpToolCall).ToArray();

yield return new Message
// Only generates a message if there is text/content, images or tool calls
if (cm.Text is not null || images.Length > 0 || toolCalls.Length > 0)
{
Content = cm.Text,
Images = images.Length > 0 ? images : null,
Role = ToOllamaSharpRole(cm.Role),
ToolCalls = toolCalls.Length > 0 ? toolCalls : null,
};
yield return new Message
{
Content = cm.Text,
Images = images.Length > 0 ? images : null,
Role = ToOllamaSharpRole(cm.Role),
ToolCalls = toolCalls.Length > 0 ? toolCalls : null,
};
}

// If the message contains a function result, add it as a separate tool message
foreach (var frc in cm.Contents.OfType<FunctionResultContent>())
Expand Down
88 changes: 88 additions & 0 deletions test/AbstractionMapperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,53 @@ public void Maps_Messages_With_Tools()
tool.Type.Should().Be("function");
}

[TestCaseSource(nameof(StopSequencesTestData))]
public void Maps_Messages_With_IEnumerable_StopSequences(object? enumerable)
{
var chatMessages = new List<Microsoft.Extensions.AI.ChatMessage>
{
new()
{
AdditionalProperties = [],
AuthorName = "a1",
RawRepresentation = null,
Role = Microsoft.Extensions.AI.ChatRole.User,
Text = "What's the weather in Honululu?"
}
};

var options = new ChatOptions()
{
AdditionalProperties = new AdditionalPropertiesDictionary() { ["stop"] = enumerable }
};

var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: true, JsonSerializerOptions.Default);

var stopSequences = chatRequest.Options.Stop;
var typedEnumerable = (IEnumerable<string>?)enumerable;

if (typedEnumerable == null)
{
stopSequences.Should().BeNull();
return;
}
stopSequences.Should().HaveCount(typedEnumerable?.Count() ?? 0);
}

public static IEnumerable<TestCaseData> StopSequencesTestData
{
get
{
yield return new TestCaseData((object?)(JsonSerializer.Deserialize<JsonElement>("[\"stop1\", \"stop2\"]")).EnumerateArray().Select(e => e.GetString()));
yield return new TestCaseData((object?)(IEnumerable<string>?)null);
yield return new TestCaseData((object?)new List<string> { "stop1", "stop2", "stop3", "stop4" });
yield return new TestCaseData((object?)new string[] { "stop1", "stop2", "stop3" });
yield return new TestCaseData((object?)new HashSet<string> { "stop1", "stop2", });
yield return new TestCaseData((object?)new Stack<string>(new[] { "stop1" }));
yield return new TestCaseData((object?)new Queue<string>(new[] { "stop1" }));
}
}

[Test]
public void Maps_Messages_With_ToolResponse()
{
Expand Down Expand Up @@ -316,6 +363,47 @@ public void Maps_Messages_With_MultipleToolResponse()
user.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.User);
}

[Test]
public void Maps_Messages_WithoutContent_MultipleToolResponse()
{
var aiChatMessages = new List<Microsoft.Extensions.AI.ChatMessage>
{
new()
{
AdditionalProperties = [],
AuthorName = "a1",
RawRepresentation = null,
Role = Microsoft.Extensions.AI.ChatRole.User,
Contents = [
new FunctionResultContent(
callId: "123",
name: "Function1",
result: new { Temperature = 40 }),

new FunctionResultContent(
callId: "456",
name: "Function2",
result: new { Summary = "This is a tool result test" }
),
]
}
};

var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(aiChatMessages, new(), stream: true, JsonSerializerOptions.Default);
var chatMessages = chatRequest.Messages?.ToList();

chatMessages.Should().HaveCount(2);

var tool1 = chatMessages[0];
var tool2 = chatMessages[1];
tool1.Content.Should().Contain("\"Temperature\":40");
tool1.Content.Should().Contain("\"CallId\":\"123\"");
tool1.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.Tool);
tool2.Content.Should().Contain("\"Summary\":\"This is a tool result test\"");
tool2.Content.Should().Contain("\"CallId\":\"456\"");
tool2.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.Tool);
}

[Test]
public void Maps_Options()
{
Expand Down
Loading