Skip to content

Commit

Permalink
add default subscriptions for the agent type - Implicitly created sub… (
Browse files Browse the repository at this point in the history
#4324)

* add default subscriptions for the agent type - Implicitly created subscription for agent RPC #4321

* add default sub for agenttype+id

* fix subscription implementation for in memory runtime
---------

Co-authored-by: XiaoYun Zhang <[email protected]>
  • Loading branch information
rysweet and LittleLittleCloud authored Nov 26, 2024
1 parent 3a1625f commit b94abb2
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 20 deletions.
43 changes: 28 additions & 15 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,40 @@ namespace Microsoft.AutoGen.Agents;
public abstract class AgentBase : IAgentBase, IHandle
{
public static readonly ActivitySource s_source = new("AutoGen.Agent");
public AgentId AgentId => _context.AgentId;
public AgentId AgentId => _runtime.AgentId;
private readonly object _lock = new();
private readonly Dictionary<string, TaskCompletionSource<RpcResponse>> _pendingRequests = [];

private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
private readonly IAgentRuntime _context;
private readonly IAgentRuntime _runtime;
public string Route { get; set; } = "base";

protected internal ILogger<AgentBase> _logger;
public IAgentRuntime Context => _context;
public IAgentRuntime Context => _runtime;
protected readonly EventTypes EventTypes;

protected AgentBase(
IAgentRuntime context,
IAgentRuntime runtime,
EventTypes eventTypes,
ILogger<AgentBase>? logger = null)
{
_context = context;
context.AgentInstance = this;
_runtime = runtime;
runtime.AgentInstance = this;
this.EventTypes = eventTypes;
_logger = logger ?? LoggerFactory.Create(builder => { }).CreateLogger<AgentBase>();
var subscriptionRequest = new AddSubscriptionRequest
{
RequestId = Guid.NewGuid().ToString(),
Subscription = new Subscription
{
TypeSubscription = new TypeSubscription
{
AgentType = this.AgentId.Type,
TopicType = this.AgentId.Type + "/" + this.AgentId.Key
}
}
};
_runtime.SendMessageAsync(new Message { AddSubscriptionRequest = subscriptionRequest }).AsTask().Wait();
Completion = Start();
}
internal Task Completion { get; }
Expand Down Expand Up @@ -131,19 +144,19 @@ public List<string> Subscribe(string topic)
}
}
};
_context.SendMessageAsync(message).AsTask().Wait();
_runtime.SendMessageAsync(message).AsTask().Wait();

return new List<string> { topic };
}
public async Task StoreAsync(AgentState state, CancellationToken cancellationToken = default)
{
await _context.StoreAsync(state, cancellationToken).ConfigureAwait(false);
await _runtime.StoreAsync(state, cancellationToken).ConfigureAwait(false);
return;
}
public async Task<T> ReadAsync<T>(AgentId agentId, CancellationToken cancellationToken = default) where T : IMessage, new()
{
var agentState = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
return agentState.FromAgentState<T>();
var agentstate = await _runtime.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
return agentstate.FromAgentState<T>();
}
private void OnResponseCore(RpcResponse response)
{
Expand Down Expand Up @@ -171,7 +184,7 @@ private async Task OnRequestCoreAsync(RpcRequest request, CancellationToken canc
{
response = new RpcResponse { Error = ex.Message };
}
await _context.SendResponseAsync(request, response, cancellationToken).ConfigureAwait(false);
await _runtime.SendResponseAsync(request, response, cancellationToken).ConfigureAwait(false);
}

protected async Task<RpcResponse> RequestAsync(AgentId target, string method, Dictionary<string, string> parameters)
Expand All @@ -195,7 +208,7 @@ protected async Task<RpcResponse> RequestAsync(AgentId target, string method, Di
activity?.SetTag("peer.service", target.ToString());

var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
_context.Update(request, activity);
_runtime.Update(request, activity);
await this.InvokeWithActivityAsync(
static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state, CancellationToken ct) =>
{
Expand All @@ -206,7 +219,7 @@ static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResp
self._pendingRequests[request.RequestId] = completion;
}

await state.Agent._context.SendRequestAsync(state.Agent, state.Request, ct).ConfigureAwait(false);
await state.Agent._runtime.SendRequestAsync(state.Agent, state.Request).ConfigureAwait(false);

await completion.Task.ConfigureAwait(false);
},
Expand All @@ -231,11 +244,11 @@ public async ValueTask PublishEventAsync(CloudEvent item, CancellationToken canc
activity?.SetTag("peer.service", $"{item.Type}/{item.Source}");

// TODO: fix activity
_context.Update(item, activity);
_runtime.Update(item, activity);
await this.InvokeWithActivityAsync(
static async ((AgentBase Agent, CloudEvent Event) state, CancellationToken ct) =>
{
await state.Agent._context.PublishEventAsync(state.Event, ct).ConfigureAwait(false);
await state.Agent._runtime.PublishEventAsync(state.Event).ConfigureAwait(false);
},
(this, item),
activity,
Expand Down
32 changes: 27 additions & 5 deletions dotnet/src/Microsoft.AutoGen/Agents/Services/AgentWorker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public class AgentWorker :
private readonly CancellationTokenSource _shutdownCts;
private readonly IServiceProvider _serviceProvider;
private readonly IEnumerable<Tuple<string, Type>> _configuredAgentTypes;
private readonly ConcurrentDictionary<string, Subscription> _subscriptionsByAgentType = new();
private readonly ConcurrentDictionary<string, List<string>> _subscriptionsByTopic = new();
private readonly DistributedContextPropagator _distributedContextPropagator;
private readonly CancellationTokenSource _shutdownCancellationToken = new();
private Task? _mailboxTask;
Expand Down Expand Up @@ -96,11 +98,7 @@ public async Task RunMessagePump()
if (message == null) { continue; }
switch (message)
{
case Message.MessageOneofCase.AddSubscriptionResponse:
break;
case Message.MessageOneofCase.RegisterAgentTypeResponse:
break;
case Message msg:
case Message msg when msg.CloudEvent != null:

var item = msg.CloudEvent;

Expand All @@ -110,6 +108,13 @@ public async Task RunMessagePump()
agentToInvoke.ReceiveMessage(msg);
}
break;
case Message msg when msg.AddSubscriptionRequest != null:
await AddSubscriptionRequestAsync(msg.AddSubscriptionRequest).ConfigureAwait(true);
break;
case Message msg when msg.AddSubscriptionResponse != null:
break;
case Message msg when msg.RegisterAgentTypeResponse != null:
break;
default:
throw new InvalidOperationException($"Unexpected message '{message}'.");
}
Expand All @@ -123,6 +128,23 @@ public async Task RunMessagePump()
}
}
}
private async ValueTask AddSubscriptionRequestAsync(AddSubscriptionRequest subscription)
{
var topic = subscription.Subscription.TypeSubscription.TopicType;
var agentType = subscription.Subscription.TypeSubscription.AgentType;
_subscriptionsByAgentType[agentType] = subscription.Subscription;
_subscriptionsByTopic.GetOrAdd(topic, _ => []).Add(agentType);
Message response = new()
{
AddSubscriptionResponse = new()
{
RequestId = subscription.RequestId,
Error = "",
Success = true
}
};
await _mailbox.Writer.WriteAsync(response).ConfigureAwait(false);
}

public async Task StartAsync(CancellationToken cancellationToken)
{
Expand Down
16 changes: 16 additions & 0 deletions dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,22 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection,
Success = true
}
};
// add a default subscription for the agent type
//TODO: we should consider having constraints on the namespace or at least migrate all our examples to use well typed namesspaces like com.microsoft.autogen/hello/HelloAgents etc
var subscriptionRequest = new AddSubscriptionRequest
{
RequestId = Guid.NewGuid().ToString(),
Subscription = new Subscription
{
TypeSubscription = new TypeSubscription
{
AgentType = msg.Type,
TopicType = msg.Type
}
}
};
await AddSubscriptionAsync(connection, subscriptionRequest).ConfigureAwait(true);

await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false);
}
private async ValueTask DispatchEventAsync(CloudEvent evt)
Expand Down
4 changes: 4 additions & 0 deletions dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ public class AgentBaseTests(InMemoryAgentRuntimeFixture fixture)
public async Task ItInvokeRightHandlerTestAsync()
{
var mockContext = new Mock<IAgentRuntime>();
mockContext.SetupGet(x => x.AgentId).Returns(new AgentId("test", "test"));
// mock SendMessageAsync
mockContext.Setup(x => x.SendMessageAsync(It.IsAny<Message>(), It.IsAny<CancellationToken>()))
.Returns(new ValueTask());
var agent = new TestAgent(mockContext.Object, new EventTypes(TypeRegistry.Empty, [], []), new Logger<AgentBase>(new LoggerFactory()));

await agent.HandleObject("hello world");
Expand Down

0 comments on commit b94abb2

Please sign in to comment.