Skip to content

Commit

Permalink
.Net: AzureOpenAI Connector - Bugfix AsyncFilter Null Reference Exception when using Function Calling. (#8654)
Browse files Browse the repository at this point in the history

### Motivation and Context

Closes #8629 

Workaround bug fix for an issue in the OpenAI SDK:

- openai/openai-dotnet#198
  • Loading branch information
RogerBarreto authored Sep 10, 2024
1 parent d03d294 commit 07eb921
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,98 @@ public async Task GetStreamingChatMessageContentsWithFunctionCallAsync()
Assert.Equal(2, functionCallCount);
}

[Fact]
public async Task GetStreamingChatMessageContentsWithFunctionCallAsyncFilterAsync()
{
    // Arrange: two auto-invocable plugin functions — one returns a value, the other throws.
    int functionCallCount = 0;

    var kernel = Kernel.CreateBuilder().Build();
    var weatherFunction = KernelFunctionFactory.CreateFromMethod((string location) =>
    {
        functionCallCount++;
        return "Some weather";
    }, "GetCurrentWeather");

    var throwingFunction = KernelFunctionFactory.CreateFromMethod((string argument) =>
    {
        functionCallCount++;
        throw new ArgumentException("Some exception");
    }, "FunctionWithException");

    kernel.Plugins.Add(KernelPluginFactory.CreateFromFunctions("MyPlugin", [weatherFunction, throwingFunction]));

    var service = new AzureOpenAIChatCompletionService("deployment", "https://endpoint", "api-key", "model-id", this._httpClient, this._mockLoggerFactory.Object);
    var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };

    // First response streams the async-filter payload with tool calls; the second is the follow-up completion.
    using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = AzureOpenAITestHelper.GetTestResponseAsStream("chat_completion_streaming_multiple_function_calls_test_async_filter_response.txt") };
    using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = AzureOpenAITestHelper.GetTestResponseAsStream("chat_completion_streaming_test_response.txt") };

    this._messageHandlerStub.ResponsesToReturn = [response1, response2];

    // Act & Assert
    var enumerator = service.GetStreamingChatMessageContentsAsync([], settings, kernel).GetAsyncEnumerator();

    // The leading chunk carries the prompt content-filter annotations.
    await enumerator.MoveNextAsync();
    var message = enumerator.Current;

#pragma warning disable AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
    var update = Assert.IsType<StreamingChatCompletionUpdate>(message.InnerContent);
    var promptResults = update.GetContentFilterResultForPrompt();
    Assert.Equal(ContentFilterSeverity.Safe, promptResults.Hate.Severity);
    Assert.Equal(ContentFilterSeverity.Safe, promptResults.Sexual.Severity);
    Assert.Equal(ContentFilterSeverity.Safe, promptResults.Violence.Severity);
    Assert.Equal(ContentFilterSeverity.Safe, promptResults.SelfHarm.Severity);
    Assert.False(promptResults.Jailbreak.Detected);

    // The first tool-call chunk also carries the streamed text content.
    await enumerator.MoveNextAsync();
    message = enumerator.Current;
    Assert.Equal("Test chat streaming response", message.Content);
    Assert.Equal("ToolCalls", message.Metadata?["FinishReason"]);

    // Three more tool-call chunks follow, each reporting the same finish reason.
    for (var i = 0; i < 3; i++)
    {
        await enumerator.MoveNextAsync();
        message = enumerator.Current;
        Assert.Equal("ToolCalls", message.Metadata?["FinishReason"]);
    }

    // Async Filter Final Chunks
    await enumerator.MoveNextAsync();
    message = enumerator.Current;

    update = Assert.IsType<StreamingChatCompletionUpdate>(message.InnerContent);

    var filterResults = update.GetContentFilterResultForResponse();
    Assert.Equal(ContentFilterSeverity.Safe, filterResults.Hate.Severity);
    Assert.Equal(ContentFilterSeverity.Safe, filterResults.Sexual.Severity);
    Assert.Equal(ContentFilterSeverity.Safe, filterResults.SelfHarm.Severity);
    Assert.Equal(ContentFilterSeverity.Safe, filterResults.Violence.Severity);

    await enumerator.MoveNextAsync();
    message = enumerator.Current;

    update = Assert.IsType<StreamingChatCompletionUpdate>(message.InnerContent);
    filterResults = update.GetContentFilterResultForResponse();
    Assert.False(filterResults.ProtectedMaterialCode.Detected);
    Assert.False(filterResults.ProtectedMaterialText.Detected);

    // Keep looping until the end of stream
    while (await enumerator.MoveNextAsync())
    {
    }

    // Each of the two plugin functions should have been invoked exactly once.
    Assert.Equal(2, functionCallCount);
#pragma warning restore AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
}

[Fact]
public async Task GetStreamingChatMessageContentsWithFunctionCallMaximumAutoInvokeAttemptsAsync()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
data: {"choices":[],"created":0,"id":"","model":"","object":"","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"jailbreak":{"filtered":false,"detected":false},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}]}

data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin-GetCurrentWeather","arguments":"{\n\"location\": \"Boston, MA\"\n}"}}]},"finish_reason":"tool_calls"}]}

data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin-FunctionWithException","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]}

data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":2,"id":"3","type":"function","function":{"name":"MyPlugin-NonExistentFunction","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]}

data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":3,"id":"4","type":"function","function":{"name":"MyPlugin-InvalidArguments","arguments":"invalid_arguments_format"}}]},"finish_reason":"tool_calls"}]}

data: {"choices":[{"content_filter_offsets":{"check_offset":1576,"start_offset":1576,"end_offset":2318},"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"finish_reason":null,"index":0}],"created":0,"id":"","model":"","object":""}

data: {"choices":[{"content_filter_offsets":{"check_offset":1576,"start_offset":1576,"end_offset":2318},"content_filter_results":{"protected_material_code":{"filtered":false,"detected":false},"protected_material_text":{"filtered":false,"detected":false}},"finish_reason":null,"index":0}],"created":0,"id":"","model":"","object":""}

data: [DONE]
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,22 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// If we're intending to invoke function calls, we need to consume that function call information.
if (toolCallingConfig.AutoInvoke)
{
foreach (var contentPart in chatCompletionUpdate.ContentUpdate)
try
{
if (contentPart.Kind == ChatMessageContentPartKind.Text)
foreach (var contentPart in chatCompletionUpdate.ContentUpdate)
{
(contentBuilder ??= new()).Append(contentPart.Text);
if (contentPart.Kind == ChatMessageContentPartKind.Text)
{
(contentBuilder ??= new()).Append(contentPart.Text);
}
}
OpenAIFunctionToolCall.TrackStreamingToolingUpdate(chatCompletionUpdate.ToolCallUpdates, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
}
catch (NullReferenceException)
{
// Temporary workaround for OpenAI SDK Bug here: https://github.com/openai/openai-dotnet/issues/198
// TODO: Remove this try-catch block once the bug is fixed.
}

OpenAIFunctionToolCall.TrackStreamingToolingUpdate(chatCompletionUpdate.ToolCallUpdates, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
}

var openAIStreamingChatMessageContent = new OpenAIStreamingChatMessageContent(chatCompletionUpdate, 0, targetModel, metadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ internal OpenAIStreamingChatMessageContent(
}
catch (NullReferenceException)
{
// Temporary bugfix for: https://github.com/openai/openai-dotnet/issues/198
// Temporary workaround for OpenAI SDK Bug here: https://github.com/openai/openai-dotnet/issues/198
// TODO: Remove this try-catch block once the bug is fixed.
}
}
Expand Down

0 comments on commit 07eb921

Please sign in to comment.