diff --git a/src/EFCore/Extensions/EntityFrameworkServiceCollectionExtensions.cs b/src/EFCore/Extensions/EntityFrameworkServiceCollectionExtensions.cs index fc1c107aa2f..f6e642afe11 100644 --- a/src/EFCore/Extensions/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/EFCore/Extensions/EntityFrameworkServiceCollectionExtensions.cs @@ -912,7 +912,7 @@ public static IServiceCollection AddDbContextFactory serviceCollection.TryAdd( new ServiceDescriptor( typeof(TContext), - typeof(TContext), + sp => sp.GetRequiredService>().CreateDbContext(), lifetime == ServiceLifetime.Transient ? ServiceLifetime.Transient : ServiceLifetime.Scoped)); @@ -1024,6 +1024,7 @@ public static IServiceCollection AddPooledDbContextFactory serviceCollection.TryAddSingleton, DbContextPool>(); serviceCollection.TryAddSingleton>( sp => new PooledDbContextFactory(sp.GetRequiredService>())); + serviceCollection.TryAddScoped(sp => sp.GetRequiredService>().CreateDbContext()); return serviceCollection; } diff --git a/test/EFCore.Tests/DbContextFactoryTest.cs b/test/EFCore.Tests/DbContextFactoryTest.cs index e620b82d4b8..9b64b8e685f 100644 --- a/test/EFCore.Tests/DbContextFactoryTest.cs +++ b/test/EFCore.Tests/DbContextFactoryTest.cs @@ -80,6 +80,42 @@ private static void ContextFactoryTest(ServiceLifetime lifetime) Assert.Throws(() => context2.Model); } + [ConditionalTheory] + [InlineData(true)] + [InlineData(false)] + public void Registering_factory_also_registers_a_scoped_context(bool pooled) + { + using var serviceProvider = RegisterContextFactory(new ServiceCollection(), pooled).BuildServiceProvider(); + + WoolacombeContext context1; + + using (var scope1 = serviceProvider.CreateScope()) + { + context1 = scope1.ServiceProvider.GetRequiredService(); + var context2 = scope1.ServiceProvider.GetRequiredService(); + + Assert.Same(context1, context2); // Assert not transient + } + + Assert.Throws(() => context1.Model); // Assert not singleton + } + + [ConditionalTheory] + [InlineData(true)] + [InlineData(false)] + public async Task Implicitly_registered_context_gets_created_via_user_factory_type(bool pooled) + { + await using var serviceProvider = new ServiceCollection() + .AddDbContextFactory( + b => b.UseInMemoryDatabase(nameof(WoolacombeContext))) + .BuildServiceProvider(); + + var factory = (WoolacombeContextFactory)serviceProvider.GetRequiredService>(); + using var scope = serviceProvider.CreateScope(); + await using var context = scope.ServiceProvider.GetRequiredService(); + Assert.True(factory.WasCalled); + } + [ConditionalFact] public void Factory_can_use_pool() { @@ -153,6 +189,26 @@ public void Factory_can_use_shared_pool() Assert.NotSame(context2b, context3b); } + [ConditionalFact] + public void Implicitly_registered_context_gets_pooled() + { + var serviceProvider = RegisterContextFactory(new ServiceCollection(), pooled: true).BuildServiceProvider(); + + WoolacombeContext context1, context2; + + using (var scope = serviceProvider.CreateScope()) + { + context1 = scope.ServiceProvider.GetRequiredService(); + } + + using (var scope = serviceProvider.CreateScope()) + { + context2 = scope.ServiceProvider.GetRequiredService(); + } + + Assert.Same(context1, context2); + } + [ConditionalTheory] [InlineData(ServiceLifetime.Singleton)] [InlineData(ServiceLifetime.Scoped)] @@ -566,10 +622,13 @@ public void Application_can_register_factory_implementation_in_AddDbContextFacto private class WoolacombeContextFactory(DbContextOptions options) : IDbContextFactory { - private readonly DbContextOptions _options = options; + public bool WasCalled { get; private set; } public WoolacombeContext CreateDbContext() - => new(_options); + { + WasCalled = true; + return new WoolacombeContext(options); + } } [ConditionalTheory] @@ -851,6 +910,21 @@ private class CustomModelCustomizer(ModelCustomizerDependencies dependencies) : private class FactoryModelCustomizer(ModelCustomizerDependencies dependencies) : ModelCustomizer(dependencies); + private IServiceCollection RegisterContextFactory(IServiceCollection serviceCollection, bool pooled) + where TContext : DbContext + { + if (pooled) + { + serviceCollection.AddPooledDbContextFactory(b => b.UseInMemoryDatabase(nameof(TContext))); + } + else + { + serviceCollection.AddDbContextFactory(b => b.UseInMemoryDatabase(nameof(TContext))); + } + + return serviceCollection; + } + private static string GetStoreName(DbContext context1) => context1.GetService().FindExtension().StoreName; }