diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs new file mode 100644 index 000000000000..5deba58ae62b --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -0,0 +1,597 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GrpcAgentRuntime.cs + +using System.Collections.Concurrent; +using System.Threading.Channels; +using Google.Protobuf; +using Grpc.Core; +using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +public sealed class GrpcAgentRuntime( + AgentRpc.AgentRpcClient client, + IHostApplicationLifetime hostApplicationLifetime, + IServiceProvider serviceProvider, + ILogger logger + ) : IAgentRuntime, IDisposable +{ + private readonly object _channelLock = new(); + + // Request ID -> + private readonly ConcurrentDictionary> _pendingRequests = new(); + private Dictionary>> agentFactories = new(); + private Dictionary agentInstances = new(); + + private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024) + { + AllowSynchronousContinuations = true, + SingleReader = true, + SingleWriter = false, + FullMode = BoundedChannelFullMode.Wait + }); + + private readonly AgentRpc.AgentRpcClient _client = client; + public readonly IServiceProvider ServiceProvider = serviceProvider; + + private readonly ILogger _logger = logger; + private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping); + private AsyncDuplexStreamingCall? _channel; + private Task? _readTask; + private Task? _writeTask; + + private string _clientId = Guid.NewGuid().ToString(); + private CallOptions CallOptions + { + get + { + var metadata = new Metadata + { + { "client-id", this._clientId } + }; + return new CallOptions(headers: metadata); + } + } + + public IProtoSerializationRegistry SerializationRegistry { get; } = new ProtoSerializationRegistry(); + + public void Dispose() + { + _outboundMessagesChannel.Writer.TryComplete(); + _channel?.Dispose(); + } + + private async Task RunReadPump() + { + var channel = GetChannel(); + while (!_shutdownCts.Token.IsCancellationRequested) + { + try + { + await foreach (var message in channel.ResponseStream.ReadAllAsync(_shutdownCts.Token)) + { + // next if message is null + if (message == null) + { + continue; + } + switch (message.MessageCase) + { + case Message.MessageOneofCase.Request: + var request = message.Request ?? throw new InvalidOperationException("Request is null."); + await HandleRequest(request); + break; + case Message.MessageOneofCase.Response: + var response = message.Response ?? throw new InvalidOperationException("Response is null."); + await HandleResponse(response); + break; + case Message.MessageOneofCase.CloudEvent: + var cloudEvent = message.CloudEvent ?? throw new InvalidOperationException("CloudEvent is null."); + await HandlePublish(cloudEvent); + break; + default: + throw new InvalidOperationException($"Unexpected message '{message}'."); + } + } + } + catch (OperationCanceledException) + { + // Time to shut down. + break; + } + catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) + { + _logger.LogError(ex, "Error reading from channel."); + channel = RecreateChannel(channel); + } + catch + { + // Shutdown requested. + break; + } + } + } + + private async ValueTask HandleRequest(RpcRequest request, CancellationToken cancellationToken = default) + { + if (request is null) + { + throw new InvalidOperationException("Request is null."); + } + if (request.Payload is null) + { + throw new InvalidOperationException("Payload is null."); + } + if (request.Target is null) + { + throw new InvalidOperationException("Target is null."); + } + if (request.Source is null) + { + throw new InvalidOperationException("Source is null."); + } + + var agentId = request.Target; + var agent = await EnsureAgentAsync(agentId.FromProtobuf()); + + // Convert payload back to object + var payload = request.Payload; + var message = PayloadToObject(payload); + + var messageContext = new MessageContext(request.RequestId, cancellationToken) + { + Sender = request.Source.FromProtobuf(), + Topic = null, + IsRpc = true + }; + + var result = await agent.OnMessageAsync(message, messageContext); + + if (result is not null) + { + var response = new RpcResponse + { + RequestId = request.RequestId, + Payload = ObjectToPayload(result) + }; + + var responseMessage = new Message + { + Response = response + }; + + await WriteChannelAsync(responseMessage, cancellationToken); + } + } + + private async ValueTask HandleResponse(RpcResponse request, CancellationToken _ = default) + { + if (request is null) + { + throw new InvalidOperationException("Request is null."); + } + if (request.Payload is null) + { + throw new InvalidOperationException("Payload is null."); + } + if (request.RequestId is null) + { + throw new InvalidOperationException("RequestId is null."); + } + + if (_pendingRequests.TryRemove(request.RequestId, out var resultSink)) + { + var payload = request.Payload; + var message = PayloadToObject(payload); + resultSink.SetResult(message); + } + } + + private async ValueTask HandlePublish(CloudEvent evt, CancellationToken cancellationToken = default) + { + if (evt is null) + { + throw new InvalidOperationException("CloudEvent is null."); + } + if (evt.ProtoData is null) + { + throw new InvalidOperationException("ProtoData is null."); + } + if (evt.Attributes is null) + { + throw new InvalidOperationException("Attributes is null."); + } + + var topic = new TopicId(evt.Type, evt.Source); + var sender = new Contracts.AgentId + { + Type = evt.Attributes["agagentsendertype"].CeString, + Key = evt.Attributes["agagentsenderkey"].CeString + }; + + var messageId = evt.Id; + var typeName = evt.Attributes["dataschema"].CeString; + var serializer = SerializationRegistry.GetSerializer(typeName) ?? throw new Exception(); + var message = serializer.Deserialize(evt.ProtoData); + + var messageContext = new MessageContext(messageId, cancellationToken) + { + Sender = sender, + Topic = topic, + IsRpc = false + }; + var agent = await EnsureAgentAsync(sender); + await agent.OnMessageAsync(message, messageContext); + } + + private async Task RunWritePump() + { + var channel = GetChannel(); + var outboundMessages = _outboundMessagesChannel.Reader; + while (!_shutdownCts.IsCancellationRequested) + { + (Message Message, TaskCompletionSource WriteCompletionSource) item = default; + try + { + await outboundMessages.WaitToReadAsync().ConfigureAwait(false); + + // Read the next message if we don't already have an unsent message + // waiting to be sent. + if (!outboundMessages.TryRead(out item)) + { + break; + } + + while (!_shutdownCts.IsCancellationRequested) + { + await channel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false); + item.WriteCompletionSource.TrySetResult(); + break; + } + } + catch (OperationCanceledException) + { + // Time to shut down. + item.WriteCompletionSource?.TrySetCanceled(); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) + { + // we could not connect to the endpoint - most likely we have the wrong port or failed ssl + // we need to let the user know what port we tried to connect to and then do backoff and retry + _logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST")); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.OK) + { + _logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", channel.ToString()); + break; + } + catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) + { + item.WriteCompletionSource?.TrySetException(ex); + _logger.LogError(ex, $"Error writing to channel.{ex}"); + channel = RecreateChannel(channel); + continue; + } + catch + { + // Shutdown requested. + item.WriteCompletionSource?.TrySetCanceled(); + break; + } + } + + while (outboundMessages.TryRead(out var item)) + { + item.WriteCompletionSource.TrySetCanceled(); + } + } + + // private override async ValueTask SendMessageAsync(Payload message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default) + // { + // var request = new RpcRequest + // { + // RequestId = Guid.NewGuid().ToString(), + // Source = agent, + // Target = agentId, + // Payload = message, + // }; + + // // Actually send it and wait for the response + // throw new NotImplementedException(); + // } + + // new is intentional + + // public new async ValueTask RuntimeSendRequestAsync(IAgent agent, RpcRequest request, CancellationToken cancellationToken = default) + // { + // var requestId = Guid.NewGuid().ToString(); + // _pendingRequests[requestId] = ((Agent)agent, request.RequestId); + // request.RequestId = requestId; + // await WriteChannelAsync(new Message { Request = request }, cancellationToken).ConfigureAwait(false); + // } + + private async Task WriteChannelAsync(Message message, CancellationToken cancellationToken = default) + { + var tcs = new TaskCompletionSource(); + await _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellationToken).ConfigureAwait(false); + } + private AsyncDuplexStreamingCall GetChannel() + { + if (_channel is { } channel) + { + return channel; + } + + lock (_channelLock) + { + if (_channel is not null) + { + return _channel; + } + + return RecreateChannel(null); + } + } + + private AsyncDuplexStreamingCall RecreateChannel(AsyncDuplexStreamingCall? channel) + { + if (_channel is null || _channel == channel) + { + lock (_channelLock) + { + if (_channel is null || _channel == channel) + { + _channel?.Dispose(); + _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token); + } + } + } + + return _channel; + } + public async Task StartAsync(CancellationToken cancellationToken) + { + _channel = GetChannel(); + _logger.LogInformation("Starting " + GetType().Name + ",connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST")); + var didSuppress = false; + if (!ExecutionContext.IsFlowSuppressed()) + { + didSuppress = true; + ExecutionContext.SuppressFlow(); + } + + try + { + _readTask = Task.Run(RunReadPump, cancellationToken); + _writeTask = Task.Run(RunWritePump, cancellationToken); + } + finally + { + if (didSuppress) + { + ExecutionContext.RestoreFlow(); + } + } + } + + public async Task StopAsync(CancellationToken cancellationToken) + { + _shutdownCts.Cancel(); + + _outboundMessagesChannel.Writer.TryComplete(); + + if (_readTask is { } readTask) + { + await readTask.ConfigureAwait(false); + } + + if (_writeTask is { } writeTask) + { + await writeTask.ConfigureAwait(false); + } + lock (_channelLock) + { + _channel?.Dispose(); + } + } + + private async ValueTask EnsureAgentAsync(Contracts.AgentId agentId) + { + if (!this.agentInstances.TryGetValue(agentId, out IHostableAgent? agent)) + { + if (!this.agentFactories.TryGetValue(agentId.Type, out Func>? factoryFunc)) + { + throw new Exception($"Agent with name {agentId.Type} not found."); + } + + agent = await factoryFunc(agentId, this); + this.agentInstances.Add(agentId, agent); + } + + return this.agentInstances[agentId]; + } + + private Payload ObjectToPayload(object message) { + if (!SerializationRegistry.Exists(message.GetType())) + { + SerializationRegistry.RegisterSerializer(message.GetType()); + } + var rpcMessage = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + + var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message); + const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; + + // Protobuf any to byte array + Payload payload = new() + { + DataType = typeName, + DataContentType = PAYLOAD_DATA_CONTENT_TYPE, + Data = rpcMessage.ToByteString() + }; + + return payload; + } + + private object PayloadToObject(Payload payload) { + var typeName = payload.DataType; + var data = payload.Data; + var type = SerializationRegistry.TypeNameResolver.ResolveTypeName(typeName); + var serializer = SerializationRegistry.GetSerializer(type) ?? throw new Exception(); + var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data); + return serializer.Deserialize(any); + } + + public async ValueTask SendMessageAsync(object message, Contracts.AgentId recepient, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default) + { + if (!SerializationRegistry.Exists(message.GetType())) + { + SerializationRegistry.RegisterSerializer(message.GetType()); + } + + var payload = ObjectToPayload(message); + var request = new RpcRequest + { + RequestId = Guid.NewGuid().ToString(), + Source = (sender ?? new Contracts.AgentId() ).ToProtobuf(), + Target = recepient.ToProtobuf(), + Payload = payload, + }; + + Message msg = new() + { + Request = request + }; + // Create a future that will be completed when the response is received + var resultSink = new ResultSink(); + this._pendingRequests.TryAdd(request.RequestId, resultSink); + await WriteChannelAsync(msg, cancellationToken); + + return await resultSink.Future; + } + + private CloudEvent CreateCloudEvent(Google.Protobuf.WellKnownTypes.Any payload, TopicId topic, string dataType, Contracts.AgentId sender, string messageId) + { + const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; + return new CloudEvent + { + ProtoData = payload, + Type = topic.Type, + Source = topic.Source, + Id = messageId, + Attributes = { + { + "datacontenttype", new CloudEvent.Types.CloudEventAttributeValue { CeString = PAYLOAD_DATA_CONTENT_TYPE } + }, + { + "dataschema", new CloudEvent.Types.CloudEventAttributeValue { CeString = dataType } + }, + { + "agagentsendertype", new CloudEvent.Types.CloudEventAttributeValue { CeString = sender.Type } + }, + { + "agagentsenderkey", new CloudEvent.Types.CloudEventAttributeValue { CeString = sender.Key } + }, + { + "agmsgkind", new CloudEvent.Types.CloudEventAttributeValue { CeString = "publish" } + } + } + }; + } + + public async ValueTask PublishMessageAsync(object message, TopicId topic, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default) + { + if (!SerializationRegistry.Exists(message.GetType())) + { + SerializationRegistry.RegisterSerializer(message.GetType()); + } + var protoAny = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message); + + var cloudEvent = CreateCloudEvent(protoAny, topic, typeName, sender ?? new Contracts.AgentId(), messageId ?? Guid.NewGuid().ToString()); + + Message msg = new() + { + CloudEvent = cloudEvent + }; + await WriteChannelAsync(msg, cancellationToken); + } + + public ValueTask GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) + { + throw new NotImplementedException(); + } + + public ValueTask GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true) + { + throw new NotImplementedException(); + } + + public ValueTask GetAgentAsync(string agent, string key = "default", bool lazy = true) + { + throw new NotImplementedException(); + } + + public ValueTask> SaveAgentStateAsync(Contracts.AgentId agentId) + { + throw new NotImplementedException(); + } + + public ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary state) + { + throw new NotImplementedException(); + } + + public ValueTask GetAgentMetadataAsync(Contracts.AgentId agentId) + { + throw new NotImplementedException(); + } + + public async ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription) + { + var _ = await this._client.AddSubscriptionAsync(new AddSubscriptionRequest{ + Subscription = subscription.ToProtobuf() + },this.CallOptions); + } + + public ValueTask RemoveSubscriptionAsync(string subscriptionId) + { + throw new NotImplementedException(); + } + + public ValueTask RegisterAgentFactoryAsync(AgentType type, Func> factoryFunc) + { + if (this.agentFactories.ContainsKey(type)) + { + throw new Exception($"Agent with type {type} already exists."); + } + this.agentFactories.Add(type, async (agentId, runtime) => await factoryFunc(agentId, runtime)); + + this._client.RegisterAgentAsync(new RegisterAgentTypeRequest + { + Type = type.Name, + + }, this.CallOptions); + return ValueTask.FromResult(type); + } + + public ValueTask TryGetAgentProxyAsync(Contracts.AgentId agentId) + { + throw new NotImplementedException(); + } + + public ValueTask> SaveStateAsync() + { + throw new NotImplementedException(); + } + + public ValueTask LoadStateAsync(IDictionary state) + { + throw new NotImplementedException(); + } +} + diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs new file mode 100644 index 000000000000..0cc422d54d85 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IAgentMessageSerializer.cs + +namespace Microsoft.AutoGen.Core.Grpc; +/// +/// Interface for serializing and deserializing agent messages. +/// +public interface IAgentMessageSerializer +{ + /// + /// Serialize an agent message. + /// + /// The message to serialize. + /// The serialized message. + Google.Protobuf.WellKnownTypes.Any Serialize(object message); + + /// + /// Deserialize an agent message. + /// + /// The message to deserialize. + /// The deserialized message. + object Deserialize(Google.Protobuf.WellKnownTypes.Any message); +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs new file mode 100644 index 000000000000..c820baa527c7 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IAgentRuntimeExtensions.cs + +using System.Diagnostics; +using Google.Protobuf.Collections; +using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.DependencyInjection; +using static Microsoft.AutoGen.Contracts.CloudEvent.Types; + +namespace Microsoft.AutoGen.Core.Grpc; + +public static class IAgentRuntimeExtensions +{ + public static (string?, string?) GetTraceIdAndState(IAgentRuntime runtime, IDictionary metadata) + { + var dcp = runtime.RuntimeServiceProvider.GetRequiredService(); + dcp.ExtractTraceIdAndState(metadata, + static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (IDictionary)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out fieldValue); + }, + out var traceParent, + out var traceState); + return (traceParent, traceState); + } + public static (string?, string?) GetTraceIdAndState(IAgentRuntime worker, MapField metadata) + { + var dcp = worker.RuntimeServiceProvider.GetRequiredService(); + dcp.ExtractTraceIdAndState(metadata, + static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (MapField)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out var ceValue); + fieldValue = ceValue?.CeString; + }, + out var traceParent, + out var traceState); + return (traceParent, traceState); + } + public static void Update(IAgentRuntime worker, RpcRequest request, Activity? activity = null) + { + var dcp = worker.RuntimeServiceProvider.GetRequiredService(); + dcp.Inject(activity, request.Metadata, static (carrier, key, value) => + { + var metadata = (IDictionary)carrier!; + if (metadata.TryGetValue(key, out _)) + { + metadata[key] = value; + } + else + { + metadata.Add(key, value); + } + }); + } + public static void Update(IAgentRuntime worker, CloudEvent cloudEvent, Activity? activity = null) + { + var dcp = worker.RuntimeServiceProvider.GetRequiredService(); + dcp.Inject(activity, cloudEvent.Attributes, static (carrier, key, value) => + { + var mapField = (MapField)carrier!; + if (mapField.TryGetValue(key, out var ceValue)) + { + mapField[key] = new CloudEventAttributeValue { CeString = value }; + } + else + { + mapField.Add(key, new CloudEventAttributeValue { CeString = value }); + } + }); + } + + public static IDictionary ExtractMetadata(IAgentRuntime worker, IDictionary metadata) + { + var dcp = worker.RuntimeServiceProvider.GetRequiredService(); + var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (IDictionary)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out fieldValue); + }); + + return baggage as IDictionary ?? new Dictionary(); + } + public static IDictionary ExtractMetadata(IAgentRuntime worker, MapField metadata) + { + var dcp = worker.RuntimeServiceProvider.GetRequiredService(); + var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (MapField)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out var ceValue); + fieldValue = ceValue?.CeString; + }); + + return baggage as IDictionary ?? new Dictionary(); + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs new file mode 100644 index 000000000000..ca690e508d2b --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IProtoMessageSerializer.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public interface IProtoMessageSerializer +{ + Google.Protobuf.WellKnownTypes.Any Serialize(object input); + object Deserialize(Google.Protobuf.WellKnownTypes.Any input); +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs new file mode 100644 index 000000000000..190ed3ec239d --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ISerializationRegistry.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public interface IProtoSerializationRegistry +{ + /// + /// Registers a serializer for the specified type. + /// + /// The type to register. + void RegisterSerializer(System.Type type) => RegisterSerializer(type, new ProtobufMessageSerializer(type)); + + void RegisterSerializer(System.Type type, IProtoMessageSerializer serializer); + + /// + /// Gets the serializer for the specified type. + /// + /// The type to get the serializer for. + /// The serializer for the specified type. + IProtoMessageSerializer? GetSerializer(System.Type type) => GetSerializer(TypeNameResolver.ResolveTypeName(type)); + IProtoMessageSerializer? GetSerializer(string typeName); + + ITypeNameResolver TypeNameResolver { get; } + + bool Exists(System.Type type); +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs new file mode 100644 index 000000000000..24de4cb8b449 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ITypeNameResolver.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public interface ITypeNameResolver +{ + string ResolveTypeName(object input); +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs new file mode 100644 index 000000000000..808116139ba6 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ITypeNameResolver.cs + +using Google.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +public class ProtoTypeNameResolver : ITypeNameResolver +{ + public string ResolveTypeName(object input) + { + if (input is IMessage protoMessage) + { + return protoMessage.Descriptor.FullName; + } + else + { + throw new ArgumentException("Input must be a protobuf message."); + } + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs new file mode 100644 index 000000000000..4850b7825afe --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtobufConversionExtensions.cs + +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +public static class ProtobufConversionExtensions +{ + // Convert an ISubscrptionDefinition to a Protobuf Subscription + public static Subscription? ToProtobuf(this ISubscriptionDefinition subscriptionDefinition) + { + // Check if is a TypeSubscription + if (subscriptionDefinition is Contracts.TypeSubscription typeSubscription) + { + return new Subscription + { + Id = typeSubscription.Id, + TypeSubscription = new Protobuf.TypeSubscription + { + TopicType = typeSubscription.TopicType, + AgentType = typeSubscription.AgentType + } + }; + } + + // Check if is a TypePrefixSubscription + if (subscriptionDefinition is Contracts.TypePrefixSubscription typePrefixSubscription) + { + return new Subscription + { + Id = typePrefixSubscription.Id, + TypePrefixSubscription = new Protobuf.TypePrefixSubscription + { + TopicTypePrefix = typePrefixSubscription.TopicTypePrefix, + AgentType = typePrefixSubscription.AgentType + } + }; + } + + return null; + } + + // Convert AgentId from Protobuf to AgentId + public static Contracts.AgentId FromProtobuf(this Protobuf.AgentId agentId) + { + return new Contracts.AgentId(agentId.Type, agentId.Key); + } + + // Convert AgentId from AgentId to Protobuf + public static Protobuf.AgentId ToProtobuf(this Contracts.AgentId agentId) + { + return new Protobuf.AgentId + { + Type = agentId.Type, + Key = agentId.Key + }; + } + +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs new file mode 100644 index 000000000000..55c1aebfa47d --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtobufMessageSerializer.cs + +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; + +namespace Microsoft.AutoGen.Core.Grpc; + +/// +/// Interface for serializing and deserializing agent messages. +/// +public class ProtobufMessageSerializer : IProtoMessageSerializer +{ + private System.Type _concreteType; + + public ProtobufMessageSerializer(System.Type concreteType) + { + _concreteType = concreteType; + } + + public object Deserialize(Any message) + { + // Check if the concrete type is a proto IMessage + if (typeof(IMessage).IsAssignableFrom(_concreteType)) + { + var nameOfMethod = nameof(Any.Unpack); + var result = message.GetType().GetMethods().Where(m => m.Name == nameOfMethod && m.IsGenericMethod).First().MakeGenericMethod(_concreteType).Invoke(message, null); + return result as IMessage ?? throw new ArgumentException("Failed to deserialize", nameof(message)); + } + + // Raise an exception if the concrete type is not a proto IMessage + throw new ArgumentException("Concrete type must be a proto IMessage", nameof(_concreteType)); + } + + public Any Serialize(object message) + { + // Check if message is a proto IMessage + if (message is IMessage protoMessage) + { + return Any.Pack(protoMessage); + } + + // Raise an exception if the message is not a proto IMessage + throw new ArgumentException("Message must be a proto IMessage", nameof(message)); + } +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/SerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/SerializationRegistry.cs new file mode 100644 index 000000000000..d7bf3a37325c --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/SerializationRegistry.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SerializationRegistry.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public class ProtoSerializationRegistry : IProtoSerializationRegistry +{ + private readonly Dictionary _serializers + = new Dictionary(); + + public bool Exists(Type type) + { + return _serializers.ContainsKey(type); + } + + public IProtoMessageSerializer? GetSerializer(Type type) + { + _serializers.TryGetValue(type, out var serializer); + return serializer; + } + + public void RegisterSerializer(Type type, IProtoMessageSerializer serializer) + { + if (_serializers.ContainsKey(type)) + { + throw new InvalidOperationException($"Serializer already registered for {type.FullName}"); + } + _serializers[type] = serializer; + } +}