Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make IChatClient/IEmbeddingGenerator.GetService non-generic #5608

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@ namespace Microsoft.Extensions.AI;
/// <summary>Provides a collection of static methods for extending <see cref="IChatClient"/> instances.</summary>
public static class ChatClientExtensions
{
/// <summary>Asks the <see cref="IChatClient"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="client">The client.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the <see cref="IChatClient"/>,
/// including itself or any services it might be wrapping.
/// </remarks>
public static TService? GetService<TService>(this IChatClient client, object? serviceKey = null)
{
_ = Throw.IfNull(client);

return (TService?)client.GetService(typeof(TService), serviceKey);
}

/// <summary>Sends a user chat text message to the model and returns the response messages.</summary>
/// <param name="client">The chat client.</param>
/// <param name="chatMessage">The text content for the chat message to send.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ public virtual IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreaming
}

/// <inheritdoc />
public virtual TService? GetService<TService>(object? key = null)
where TService : class
public virtual object? GetService(Type serviceType, object? serviceKey = null)
{
#pragma warning disable S3060 // "is" should not be used with "this"
// If the key is non-null, we don't know what it means so pass through to the inner service
return key is null && this is TService service ? service : InnerClient.GetService<TService>(key);
#pragma warning restore S3060
_ = Throw.IfNull(serviceType);

// If the key is non-null, we don't know what it means so pass through to the inner service.
return
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
InnerClient.GetService(serviceType, serviceKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,13 @@ IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
/// <summary>Gets metadata that describes the <see cref="IChatClient"/>.</summary>
ChatClientMetadata Metadata { get; }

/// <summary>Asks the <see cref="IChatClient"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="key">An optional key that may be used to help identify the target service.</param>
/// <summary>Asks the <see cref="IChatClient"/> for an object of the specified type <paramref name="serviceType"/>.</summary>
/// <param name="serviceType">The type of object being requested.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the <see cref="IChatClient"/>,
/// including itself or any services it might be wrapping.
/// </remarks>
TService? GetService<TService>(object? key = null)
where TService : class;
object? GetService(Type serviceType, object? serviceKey = null);
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,13 @@ public virtual Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(IEnumerable<T
InnerGenerator.GenerateAsync(values, options, cancellationToken);

/// <inheritdoc />
public virtual TService? GetService<TService>(object? key = null)
where TService : class
public virtual object? GetService(Type serviceType, object? serviceKey = null)
{
#pragma warning disable S3060 // "is" should not be used with "this"
// If the key is non-null, we don't know what it means so pass through to the inner service
return key is null && this is TService service ? service : InnerGenerator.GetService<TService>(key);
#pragma warning restore S3060
_ = Throw.IfNull(serviceType);

// If the key is non-null, we don't know what it means so pass through to the inner service.
return
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
InnerGenerator.GetService(serviceType, serviceKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,43 @@ namespace Microsoft.Extensions.AI;
/// <summary>Provides a collection of static methods for extending <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> instances.</summary>
public static class EmbeddingGeneratorExtensions
{
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="generator">The generator.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
public static TService? GetService<TInput, TEmbedding, TService>(this IEmbeddingGenerator<TInput, TEmbedding> generator, object? serviceKey = null)
where TEmbedding : Embedding
{
_ = Throw.IfNull(generator);

return (TService?)generator.GetService(typeof(TService), serviceKey);
}

// The following overload exists purely to work around the lack of partial generic type inference.
// Given an IEmbeddingGenerator<TInput, TEmbedding> generator, to call GetService with TService, you still need
// to re-specify both TInput and TEmbedding, e.g. generator.GetService<string, Embedding<float>, TService>.
// The case of string/Embedding<float> is by far the most common case today, so this overload exists as an
// accelerator to allow it to be written simply as generator.GetService<TService>.

/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="generator">The generator.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
public static TService? GetService<TService>(this IEmbeddingGenerator<string, Embedding<float>> generator, object? serviceKey = null) =>
GetService<string, Embedding<float>, TService>(generator, serviceKey);

/// <summary>Generates an embedding vector from the specified <paramref name="value"/>.</summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,13 @@ Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
/// <summary>Gets metadata that describes the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>.</summary>
EmbeddingGeneratorMetadata Metadata { get; }

/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for an object of type <typeparamref name="TService"/>.</summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="key">An optional key that may be used to help identify the target service.</param>
/// <summary>Asks the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for an object of the specified type <paramref name="serviceType"/>.</summary>
/// <param name="serviceType">The type of object being requested.</param>
/// <param name="serviceKey">An optional key that may be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>,
/// including itself or any services it might be wrapping.
/// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
TService? GetService<TService>(object? key = null)
where TService : class;
object? GetService(Type serviceType, object? serviceKey = null);
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,16 @@ public AzureAIInferenceChatClient(ChatCompletionsClient chatCompletionsClient, s
public ChatClientMetadata Metadata { get; }

/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class =>
typeof(TService) == typeof(ChatCompletionsClient) ? (TService?)(object?)_chatCompletionsClient :
this as TService;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);

return
serviceKey is not null ? null :
serviceType == typeof(ChatCompletionsClient) ? _chatCompletionsClient :
serviceType.IsInstanceOfType(this) ? this :
null;
}

/// <inheritdoc />
public async Task<ChatCompletion> CompleteAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,16 @@ public AzureAIInferenceEmbeddingGenerator(
public EmbeddingGeneratorMetadata Metadata { get; }

/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class =>
typeof(TService) == typeof(EmbeddingsClient) ? (TService)(object)_embeddingsClient :
this as TService;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);

return
serviceKey is not null ? null :
serviceType == typeof(EmbeddingsClient) ? _embeddingsClient :
serviceType.IsInstanceOfType(this) ? this :
null;
}

/// <inheritdoc />
public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,14 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
}

/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class
=> key is null ? this as TService : null;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);

return
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
null;
}

/// <inheritdoc />
public void Dispose()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,14 @@ public OllamaEmbeddingGenerator(Uri endpoint, string? modelId = null, HttpClient
public EmbeddingGeneratorMetadata Metadata { get; }

/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class
=> key is null ? this as TService : null;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);

return
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
null;
}

/// <inheritdoc />
public void Dispose()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using OpenAI;
using OpenAI.Chat;

#pragma warning disable S1067 // Expressions should not be too complex
#pragma warning disable S1135 // Track uses of "TODO" tags
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable SA1204 // Static elements should appear before instance elements
Expand Down Expand Up @@ -85,11 +86,17 @@ public OpenAIChatClient(ChatClient chatClient)
public ChatClientMetadata Metadata { get; }

/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class =>
typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient :
typeof(TService) == typeof(ChatClient) ? (TService)(object)_chatClient :
this as TService;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);

return
serviceKey is not null ? null :
serviceType == typeof(OpenAIClient) ? _openAIClient :
serviceType == typeof(ChatClient) ? _chatClient :
serviceType.IsInstanceOfType(this) ? this :
null;
}

/// <inheritdoc />
public async Task<ChatCompletion> CompleteAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using OpenAI;
using OpenAI.Embeddings;

#pragma warning disable S1067 // Expressions should not be too complex
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields

namespace Microsoft.Extensions.AI;
Expand Down Expand Up @@ -95,12 +96,17 @@ private static EmbeddingGeneratorMetadata CreateMetadata(string providerName, st
public EmbeddingGeneratorMetadata Metadata { get; }

/// <inheritdoc />
public TService? GetService<TService>(object? key = null)
where TService : class
=>
typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient :
typeof(TService) == typeof(EmbeddingClient) ? (TService)(object)_embeddingClient :
this as TService;
public object? GetService(Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(serviceType);

return
serviceKey is not null ? null :
serviceType == typeof(OpenAIClient) ? _openAIClient :
serviceType == typeof(EmbeddingClient) ? _embeddingClient :
serviceType.IsInstanceOfType(this) ? this :
null;
}

/// <inheritdoc />
public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ namespace Microsoft.Extensions.AI;

public class ChatClientExtensionsTests
{
[Fact]
public void GetService_InvalidArgs_Throws()
{
Assert.Throws<ArgumentNullException>("client", () => ChatClientExtensions.GetService<object>(null!));
}

[Fact]
public void CompleteAsync_InvalidArgs_Throws()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ public async Task ChatStreamingAsyncDefaultsToInnerClientAsync()
Assert.False(await enumerator.MoveNextAsync());
}

[Fact]
public void GetServiceThrowsForNullType()
{
using var inner = new TestChatClient();
using var delegating = new NoOpDelegatingChatClient(inner);
Assert.Throws<ArgumentNullException>("serviceType", () => delegating.GetService(null!));
}

[Fact]
public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ public async Task GenerateEmbeddingsDefaultsToInnerServiceAsync()
Assert.Same(expectedEmbedding, await resultTask);
}

[Fact]
public void GetServiceThrowsForNullType()
{
using var inner = new TestEmbeddingGenerator();
using var delegating = new NoOpDelegatingEmbeddingGenerator(inner);
Assert.Throws<ArgumentNullException>("serviceType", () => delegating.GetService(null!));
}

[Fact]
public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ namespace Microsoft.Extensions.AI;

public class EmbeddingGeneratorExtensionsTests
{
[Fact]
public void GetService_InvalidArgs_Throws()
{
Assert.Throws<ArgumentNullException>("generator", () => EmbeddingGeneratorExtensions.GetService<object>(null!));
Assert.Throws<ArgumentNullException>("generator", () => EmbeddingGeneratorExtensions.GetService<string, Embedding<double>, object>(null!));
}

[Fact]
public async Task GenerateAsync_InvalidArgs_ThrowsAsync()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ public Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatO
public IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
=> CompleteStreamingAsyncCallback!.Invoke(chatMessages, options, cancellationToken);

public TService? GetService<TService>(object? key = null)
where TService : class
=> (TService?)GetServiceCallback!(typeof(TService), key);
public object? GetService(Type serviceType, object? serviceKey = null)
=> GetServiceCallback!(serviceType, serviceKey);

void IDisposable.Dispose()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ public sealed class TestEmbeddingGenerator : IEmbeddingGenerator<string, Embeddi
public Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
=> GenerateAsyncCallback!.Invoke(values, options, cancellationToken);

public TService? GetService<TService>(object? key = null)
where TService : class
=> (TService?)GetServiceCallback!(typeof(TService), key);
public object? GetService(Type serviceType, object? serviceKey = null)
=> GetServiceCallback!(serviceType, serviceKey);

void IDisposable.Dispose()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ public QuantizationEmbeddingGenerator(IEmbeddingGenerator<string, Embedding<floa

void IDisposable.Dispose() => _floatService.Dispose();

public TService? GetService<TService>(object? key = null)
where TService : class =>
key is null && this is TService ? (TService?)(object)this :
_floatService.GetService<TService>(key);
public object? GetService(Type serviceType, object? serviceKey = null) =>
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
_floatService.GetService(serviceType, serviceKey);

async Task<GeneratedEmbeddings<BinaryEmbedding>> IEmbeddingGenerator<string, BinaryEmbedding>.GenerateAsync(
IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
Expand Down
Loading