Skip to content
This repository has been archived by the owner on Apr 20, 2024. It is now read-only.

Fix/update child entities meta #82

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,85 @@ public void Setup()
_repository = new TestEntityRepository(_context);
}

[Test]
public async Task SubEntities1()
{
var entity = new TestEntity
{
Property = "whatever",
TestSubEntities = new List<TestSubEntity>
{
new TestSubEntity
{
Property = "sub 1"
},
new TestSubEntity
{
Property = "sub 2"
}
},
TestSubEntity = new SingleTestSubEntity
{
Property = "sub 999"
}
};

var updatedEntity = await _repository.AddAsync(entity);
await _unitOfWork.CommitAsync();

Assert.AreNotEqual(default(Guid), updatedEntity.TestSubEntities.ElementAt(0).Id);
Assert.AreNotEqual(default(DateTime), updatedEntity.TestSubEntities.ElementAt(0).Created);
Assert.AreNotEqual(default(DateTime), updatedEntity.TestSubEntities.ElementAt(0).Updated);
Assert.AreEqual(updatedEntity.TestSubEntities.ElementAt(0).Created, updatedEntity.TestSubEntities.ElementAt(0).Updated);
Assert.AreNotEqual(default(DateTime), updatedEntity.TestSubEntities.ElementAt(1).Created);
Assert.AreNotEqual(default(DateTime), updatedEntity.TestSubEntities.ElementAt(1).Updated);
Assert.AreEqual(updatedEntity.TestSubEntities.ElementAt(1).Created, updatedEntity.TestSubEntities.ElementAt(1).Updated);
Assert.AreNotEqual(default(DateTime), updatedEntity.TestSubEntity.Created);
Assert.AreNotEqual(default(DateTime), updatedEntity.TestSubEntity.Updated);
Assert.AreEqual(updatedEntity.TestSubEntity.Created, updatedEntity.TestSubEntity.Updated);
}

[Test]
public async Task SubEntities2()
{
var (oldUpdated, oldCreated) = (_entity.Updated, _entity.Created);
var entity = new TestEntity
{
Id = _entity.Id,
TestSubEntities = new List<TestSubEntity>
{
new TestSubEntity
{
Property = "sub 1"
},
new TestSubEntity
{
Property = "sub 2"
},
new TestSubEntity
{
Property = "sub 3"
},
new TestSubEntity
{
Property = "sub 4"
}
}
};

var updatedEntity = await _repository.UpdateAsync(entity);
await _unitOfWork.CommitAsync();

updatedEntity.TestSubEntities.First().Property = "salam";
updatedEntity = await _repository.UpdateAsync(entity);
await _unitOfWork.CommitAsync();

Assert.AreNotEqual(default(Guid), updatedEntity.TestSubEntities.ElementAt(0).Id);
Assert.AreNotEqual(default(DateTime), updatedEntity.TestSubEntities.ElementAt(0).Created);
Assert.AreNotEqual(default(DateTime), updatedEntity.TestSubEntities.ElementAt(0).Updated);
}


#region Add
[Test]
public async Task AddAddsEntityAndSetsAttributes()
Expand Down
14 changes: 13 additions & 1 deletion Monstarlab.EntityFramework.Extension.Tests/Mocks/TestContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ public class TestContext : DbContext
public TestContext(DbContextOptions options) : base(options) { }

public DbSet<TestEntity> Table { get; set; }

public DbSet<TestSubEntity> SubEntityTable { get; set; }
public DbSet<SingleTestSubEntity> SingleSubEntityTable { get; set; }

public DbSet<TestSoftDeleteEntity> SoftDeleteTable { get; set; }
}

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
base.OnModelCreating(modelBuilder);

modelBuilder.ApplyConfigurationsFromAssembly(GetType().Assembly);
}


}
16 changes: 16 additions & 0 deletions Monstarlab.EntityFramework.Extension.Tests/Mocks/TestEntity.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,20 @@ public class TestEntity : EntityBase<Guid>

[ReadOnly(true)]
public string ReadOnlyProperty { get; set; }

public virtual IEnumerable<TestSubEntity> TestSubEntities { get; set; }

public virtual SingleTestSubEntity TestSubEntity { get; set; }
}

public class TestSubEntity : EntityBase<Guid>
{
public string Property { get; set; }
public virtual TestEntity TestEntity { get; set; }
}

public class SingleTestSubEntity : EntityBase<Guid>
{
public string Property { get; set; }
public virtual TestEntity TestEntity { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Microsoft.EntityFrameworkCore.Metadata.Builders;

namespace Monstarlab.EntityFramework.Extension.Tests.Mocks;

public class TestEntityConfiguration : IEntityTypeConfiguration<TestEntity>
{
public void Configure(EntityTypeBuilder<TestEntity> builder)
{
builder.HasOne(e => e.TestSubEntity)
.WithOne(e => e.TestEntity).HasForeignKey(nameof(SingleTestSubEntity));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ public class TestEntityRepository : EntityRepository<TestContext, TestEntity, Gu
public TestEntityRepository(TestContext context) : base(context)
{
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
namespace Monstarlab.EntityFramework.Extension.Tests.Mocks;

public class TestSubEntityRepository : EntityRepository<TestContext, TestSubEntity, Guid>
{
public TestSubEntityRepository(TestContext context) : base(context)
{
}
}

public class SingleTestSubEntityRepository : EntityRepository<TestContext, SingleTestSubEntity, Guid>
{
public SingleTestSubEntityRepository(TestContext context) : base(context)
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="AutoFixture.NUnit3" Version="4.17.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="6.0.6" />
<PackageReference Include="AutoFixture.NUnit3" Version="4.18.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="6.0.15" />
<PackageReference Include="nunit" Version="3.13.3" />
<PackageReference Include="NUnit3TestAdapter" Version="4.2.1">
<PackageReference Include="NUnit3TestAdapter" Version="4.4.2">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.2.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.5.0" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<PackageId>Monstarlab.EntityFramework.Extension</PackageId>
<Version>3.0.2</Version>
<Version>4.0.0</Version>
<Authors>Monstarlab</Authors>
<Company>Monstarlab</Company>
<Product>Entity Framework Extension</Product>
Expand All @@ -15,10 +15,11 @@
<RepositoryType>Github</RepositoryType>
<PackageLicenseExpression>MIT</PackageLicenseExpression>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="6.0.6" />
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="6.0.15" />
<PackageReference Include="System.ComponentModel.Annotations" Version="5.0.0" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
namespace Monstarlab.EntityFramework.Extension.Repositories;
using System.Collections;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.ChangeTracking.Internal;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Monstarlab.EntityFramework.Extension.Utils;

namespace Monstarlab.EntityFramework.Extension.Repositories;

public abstract class BaseEntityRepository<TContext, TEntity, TId> : IBaseEntityRepository<TEntity, TId>
where TEntity : EntityBase<TId>
Expand All @@ -18,13 +24,10 @@ public virtual async Task<TEntity> AddAsync(TEntity entity)
if (entity == null)
throw new ArgumentNullException(nameof(entity));

var now = DateTime.UtcNow;

entity.Created = now;
entity.Updated = now;

var addedEntity = await Context.Set<TEntity>().AddAsync(entity);
UpdateSubEntities(Context.Entry(entity), new HashSet<object>(), true);

var addedEntity = await Context.Set<TEntity>().AddAsync(entity);

return addedEntity.Entity;
}

Expand All @@ -47,17 +50,59 @@ public virtual async Task<TEntity> UpdateAsync(TEntity entity)
var initialValue = prop.GetValue(originalEntity);
var potentialNewValue = prop.GetValue(entity);

if (potentialNewValue != null && potentialNewValue != initialValue && !PropertyIsReadOnly(prop))
if (potentialNewValue != null && potentialNewValue != initialValue && !prop.IsReadOnly())
prop.SetValue(originalEntity, potentialNewValue);
}
}

var entry = Context.Entry(originalEntity);
UpdateSubEntities(entry, new HashSet<object>());

var updatedEntity = Context.Set<TEntity>().Update(originalEntity);

return await GetAsync(updatedEntity.Entity.Id);
}

private bool PropertyIsReadOnly(PropertyInfo prop) => (prop.GetCustomAttribute(typeof(ReadOnlyAttribute), true) as ReadOnlyAttribute)?.IsReadOnly ?? false;
private void UpdateSubEntities(EntityEntry entry, ISet<object> visited, bool updateSelf = false)
{
if (visited.Contains(entry.Entity))
return;

visited.Add(entry.Entity);

if (updateSelf && entry.Entity.GetType().IsAssignableToGenericType(typeof(EntityBase<>)))
{
entry.DetectChanges();
if (entry.State is EntityState.Detached or EntityState.Added)
{
entry.State = EntityState.Added;
var now = DateTime.UtcNow;

entry.Property(nameof(EntityBase<object>.Created)).CurrentValue = now;
entry.Property(nameof(EntityBase<object>.Updated)).CurrentValue = now;
}
else if (entry.State == EntityState.Modified)
{
entry.Property(nameof(EntityBase<object>.Updated)).CurrentValue = DateTime.UtcNow;
}
}

foreach (var subEntry in entry.Navigations)
{
if (subEntry is CollectionEntry {CurrentValue: { }} entryItems)
{
foreach (var subEntryItem in entryItems.CurrentValue)
{
var subEntryItemEntry = Context.Entry(subEntryItem);
UpdateSubEntities(subEntryItemEntry, visited, true);
}
}
else if (subEntry is ReferenceEntry {TargetEntry: {}} entryItem)
{
UpdateSubEntities(entryItem.TargetEntry, visited, true);
}
}
}

public virtual Task<bool> DeleteAsync(TEntity entity)
{
Expand Down Expand Up @@ -86,6 +131,7 @@ protected IQueryable<T> Paginate<T>(IQueryable<T> query, [Range(1, int.MaxValue)
{
if (page < 1)
throw new ArgumentException($"{nameof(page)} was below 1. Received: {page}", nameof(page));

if (pageSize < 1)
throw new ArgumentException($"{nameof(pageSize)} was below 1. Received: {pageSize}", nameof(pageSize));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ public class EntityRepository<TContext, TEntity, TId> : BaseEntityRepository<TCo
{
public EntityRepository(TContext context) : base(context) { }

public Task<TEntity?> GetAsync(Expression<Func<TEntity, bool>> where)
=> BaseIncludes().FirstOrDefaultAsync(where);

public virtual Task<ListWrapper<TEntity>> GetListAsync(
[Range(1, int.MaxValue)] int page,
[Range(1, int.MaxValue)] int pageSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ public interface IEntityRepository<TEntity, TId> : IBaseEntityRepository<TEntity
/// <param name="id">The ID of the entity to fetch.</param>
Task<TEntity> GetAsync(TId id);

/// <summary>
/// Get the entity, filtered by <paramref name="where"/>
/// </summary>
/// <param name="where">The filter expression</param>
Task<TEntity?> GetAsync(Expression<Func<TEntity, bool>> where);

/// <summary>
/// Get multiple entities paginated.
/// </summary>
Expand Down
58 changes: 58 additions & 0 deletions Monstarlab.EntityFramework.Extension/Utils/TypeExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using System.Collections;
using System.Collections.Concurrent;

namespace Monstarlab.EntityFramework.Extension.Utils;

internal static class TypeExtensions
{
private static readonly ConcurrentDictionary<Type, object?> TypeDefaults = new();

internal static object? GetDefaultValue(this Type type)
{
return type.GetTypeInfo().IsValueType
? TypeDefaults.GetOrAdd(type, Activator.CreateInstance)
: null;
}

internal static bool IsAssignableToGenericType(this Type givenType, Type genericType)
{
var interfaceTypes = givenType.GetInterfaces();

if (interfaceTypes.Any(it => it.IsGenericType && it.GetGenericTypeDefinition() == genericType))
{
return true;
}

if (givenType.IsGenericType && givenType.GetGenericTypeDefinition() == genericType)
{
return true;
}

if (givenType.BaseType is null)
{
return false;
}

return IsAssignableToGenericType(givenType.BaseType, genericType);
}
}

internal static class ReflectionExtensions
{
internal static object? GetPropertyValue(this object obj, string name) =>
obj?.GetType().GetProperty(name)?.GetValue(obj);
}

internal static class PropertyInfoExtensions
{
internal static bool IsNavigationProperty(this PropertyInfo prop) =>
prop.PropertyType.IsAssignableToGenericType(typeof(EntityBase<>));

internal static bool IsCollectionNavigationProperty(this PropertyInfo prop) =>
prop.PropertyType.IsAssignableTo(typeof(IEnumerable)) &&
prop.PropertyType != typeof(string) &&
prop.PropertyType.GenericTypeArguments[0].IsAssignableToGenericType(typeof(EntityBase<>));

internal static bool IsReadOnly(this PropertyInfo prop) =>
(prop.GetCustomAttribute(typeof(ReadOnlyAttribute), true) as ReadOnlyAttribute)?.IsReadOnly ?? false;
}