Skip to content

Commit

Permalink
CSHARP-5321: Change snippet deserialization strategy to remove limita…
Browse files Browse the repository at this point in the history
…tion on number of snippets.
  • Loading branch information
rstam committed Jan 14, 2025
1 parent fa39baf commit d175be3
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 170 deletions.
127 changes: 26 additions & 101 deletions src/MongoDB.Driver/ClientSideProjectionSnippetsDeserializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,134 +24,59 @@ namespace MongoDB.Driver
{
internal static class ClientSideProjectionSnippetsDeserializer
{
private static readonly Type[] __deserializerGenericTypeDefinitions =
[
null,
typeof(ClientSideProjectionSnippetsDeserializer<,>),
typeof(ClientSideProjectionSnippetsDeserializer<,,>),
typeof(ClientSideProjectionSnippetsDeserializer<,,,>),
typeof(ClientSideProjectionSnippetsDeserializer<,,,,>)
];

public const int MaxNumberOfSnippets = 4; // could be expanded up to 16

public static IBsonSerializer Create(
Type projectionType,
IBsonSerializer[] snippetDeserializers,
Delegate projector)
{
var snippetTypes = snippetDeserializers.Select(s => s.ValueType).ToArray();
var deserializerGenericTypeDefinition = __deserializerGenericTypeDefinitions[snippetTypes.Length];
var deserializerGenericTypeArguments = snippetTypes.Append(projectionType).ToArray();
var deserializerType = deserializerGenericTypeDefinition.MakeGenericType(deserializerGenericTypeArguments);
var deserializerType = typeof(ClientSideProjectionSnippetsDeserializer<>).MakeGenericType(projectionType);
return (IBsonSerializer)Activator.CreateInstance(deserializerType, [snippetDeserializers, projector]);
}
}

internal abstract class ClientSideProjectionSnippetsDeserializer<TProjection> : SerializerBase<TProjection>, IClientSideProjectionDeserializer
internal sealed class ClientSideProjectionSnippetsDeserializer<TProjection> : SerializerBase<TProjection>, IClientSideProjectionDeserializer
{
private readonly IBsonSerializer[] _snippetDeserializers;
private readonly Func<object[], TProjection> _projector;

public ClientSideProjectionSnippetsDeserializer(IBsonSerializer[] snippetDeserializers)
public ClientSideProjectionSnippetsDeserializer(IBsonSerializer[] snippetDeserializers, Func<object[], TProjection> projector)
{
_snippetDeserializers = snippetDeserializers;
_projector = projector;
}

protected object[] DeserializeSnippets(BsonDeserializationContext context)
public override TProjection Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
{
var snippets = DeserializeSnippets(context);
return _projector(snippets);
}

private object[] DeserializeSnippets(BsonDeserializationContext context)
{
var reader = context.Reader;
var snippets = new object[_snippetDeserializers.Length];

reader.ReadStartDocument();
reader.ReadName("_snippets");
reader.ReadStartArray();
var snippets = new object[_snippetDeserializers.Length];
var i = 0;
while (reader.ReadBsonType() != BsonType.EndOfDocument)
{
var name = reader.ReadName();
var i = ParseIndex(name);
if (i >= _snippetDeserializers.Length)
{
throw new BsonSerializationException($"Expected {_snippetDeserializers.Length} snippets but found more than that.");
}
snippets[i] = _snippetDeserializers[i].Deserialize(context);
i++;
}
reader.ReadEndDocument();

return snippets;

int ParseIndex(string name)
if (i != _snippetDeserializers.Length)
{
if (name.StartsWith("_") &&
int.TryParse(name.Substring(1), out var index) &&
index >= 0 && index < _snippetDeserializers.Length)
{
return index;
}

throw new FormatException("Invalid snippet name: " + name);
throw new BsonSerializationException($"Expected {_snippetDeserializers.Length} snippets but found {i}.");
}
}
}

internal class ClientSideProjectionSnippetsDeserializer<T1, TProjection> : ClientSideProjectionSnippetsDeserializer<TProjection>
{
private readonly Func<T1, TProjection> _projector;

public ClientSideProjectionSnippetsDeserializer(IBsonSerializer[] snippetDeserializers, Func<T1, TProjection> projector)
: base(snippetDeserializers)
{
_projector = projector;
}

public override TProjection Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
{
var snippets = DeserializeSnippets(context);
return _projector((T1)snippets[0]);
}
}

internal class ClientSideProjectionSnippetsDeserializer<T1, T2, TProjection> : ClientSideProjectionSnippetsDeserializer<TProjection>
{
private readonly Func<T1, T2, TProjection> _projector;

public ClientSideProjectionSnippetsDeserializer(IBsonSerializer[] snippetDeserializers, Func<T1, T2, TProjection> projector)
: base(snippetDeserializers)
{
_projector = projector;
}

public override TProjection Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
{
var snippets = DeserializeSnippets(context);
return _projector((T1)snippets[0], (T2)snippets[1]);
}
}

internal class ClientSideProjectionSnippetsDeserializer<T1, T2, T3, TProjection> : ClientSideProjectionSnippetsDeserializer<TProjection>
{
private readonly Func<T1, T2, T3, TProjection> _projector;

public ClientSideProjectionSnippetsDeserializer(IBsonSerializer[] snippetDeserializers, Func<T1, T2, T3, TProjection> projector)
: base(snippetDeserializers)
{
_projector = projector;
}

public override TProjection Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
{
var snippets = DeserializeSnippets(context);
return _projector((T1)snippets[0], (T2)snippets[1], (T3)snippets[2]);
}
}

internal class ClientSideProjectionSnippetsDeserializer<T1, T2, T3, T4, TProjection> : ClientSideProjectionSnippetsDeserializer<TProjection>
{
private readonly Func<T1, T2, T3, T4, TProjection> _projector;

public ClientSideProjectionSnippetsDeserializer(IBsonSerializer[] snippetDeserializers, Func<T1, T2, T3, T4, TProjection> projector)
: base(snippetDeserializers)
{
_projector = projector;
}
reader.ReadEndArray();
reader.ReadEndDocument();

public override TProjection Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
{
var snippets = DeserializeSnippets(context);
return _projector((T1)snippets[0], (T2)snippets[1], (T3)snippets[2], (T4)snippets[3]);
return snippets;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
using System.Linq.Expressions;
using MongoDB.Bson.Serialization;
using MongoDB.Driver.Core.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Ast;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
Expand All @@ -43,7 +42,7 @@ public static (AstProjectStage, IBsonSerializer) CreateClientSideProjection(
}
}

private static (AstExpression, IBsonSerializer) TranslateLambdaBodyUsingSnippets(
private static (AstComputedDocumentExpression, IBsonSerializer) TranslateLambdaBodyUsingSnippets(
TranslationContext context,
IBsonSerializer sourceSerializer,
LambdaExpression projectionLambda)
Expand All @@ -57,7 +56,7 @@ private static (AstExpression, IBsonSerializer) TranslateLambdaBodyUsingSnippets

var snippets = ClientSideProjectionSnippetsTranslator.TranslateSnippets(context, projectionLambda, sourceSerializer);

if (snippets.Length == 0 || snippets.Length > ClientSideProjectionSnippetsDeserializer.MaxNumberOfSnippets || snippets.Any(IsRoot))
if (snippets.Length == 0 || snippets.Any(IsRoot))
{
var clientSideProjectionDeserializer = ClientSideProjectionDeserializer.Create(sourceSerializer, projectionLambda);
return (null, clientSideProjectionDeserializer); // project directly off $$ROOT with no snippets
Expand All @@ -77,34 +76,28 @@ private static (AstExpression, IBsonSerializer) TranslateLambdaBodyUsingSnippets

private static AstComputedDocumentExpression CreateSnippetsComputedDocument(AggregationExpression[] snippets)
{
var numberOfSnippets = snippets.Length;
var computedFields = new AstComputedField[numberOfSnippets];
var snippetsArray = AstExpression.ComputedArray(snippets.Select(s => s.Ast));
var snippetsdField = AstExpression.ComputedField("_snippets", snippetsArray);
return (AstComputedDocumentExpression)AstExpression.ComputedDocument([snippetsdField]);

for (var i = 0; i < numberOfSnippets; i++)
{
var name = $"_{i}";
var snippet = snippets[i];
computedFields[i] = AstExpression.ComputedField(name, snippet.Ast);
}

return (AstComputedDocumentExpression)AstExpression.ComputedDocument(computedFields);
}

private static LambdaExpression RewriteSelector(LambdaExpression selectorLambda, AggregationExpression[] snippets)
{
var numberOfSnippets = snippets.Length;
var snippetParameters = new ParameterExpression[numberOfSnippets];
var rewrittenBody = selectorLambda.Body;
var snippetsParameter = Expression.Parameter(typeof(object[]), "snippets");

for (var i = 0; i < numberOfSnippets; i++)
for (var i = 0; i < snippets.Length; i++)
{
var snippet = snippets[i];
var snippetParameter = Expression.Parameter(snippet.Expression.Type, $"_{i}");
rewrittenBody = ExpressionReplacer.Replace(rewrittenBody, snippet.Expression, snippetParameter);
snippetParameters[i] = snippetParameter;
var snippetReference = // (T)_snippets[i]
Expression.Convert(
Expression.ArrayIndex(snippetsParameter, Expression.Constant(i)),
snippet.Expression.Type);
rewrittenBody = ExpressionReplacer.Replace(rewrittenBody, snippet.Expression, snippetReference);
}

return Expression.Lambda(rewrittenBody, snippetParameters);
return Expression.Lambda(rewrittenBody, snippetsParameter);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public void Bottom_without_GroupBy_should_have_helpful_error_message(
if (enableClientSideProjections)
{
var stages = Translate(collection, queryable, out var outputSerializer);
AssertStages(stages, "{ $project : { _0 : '$A', _id : 0 } }");
AssertStages(stages, "{ $project : { _snippets : ['$A'], _id : 0 } }");
outputSerializer.Should().BeAssignableTo<IClientSideProjectionDeserializer>();

var exception = Record.Exception(() => queryable.ToList());
Expand Down Expand Up @@ -398,7 +398,7 @@ public void BottomN_without_GroupBy_should_have_helpful_error_message(
if (enableClientSideProjections)
{
var stages = Translate(collection, queryable, out var outputSerializer);
AssertStages(stages, "{ $project : { _0 : '$A', _id : 0 } }");
AssertStages(stages, "{ $project : { _snippets : ['$A'], _id : 0 } }");
outputSerializer.Should().BeAssignableTo<IClientSideProjectionDeserializer>();

var exception = Record.Exception(() => queryable.ToList());
Expand Down Expand Up @@ -1660,7 +1660,7 @@ public void Top_without_GroupBy_should_have_helpful_error_message(
if (enableClientSideProjections)
{
var stages = Translate(collection, queryable, out var outputSerializer);
AssertStages(stages, "{ $project : { _0 : '$A', _id : 0 } }");
AssertStages(stages, "{ $project : { _snippets : ['$A'], _id : 0 } }");
outputSerializer.Should().BeAssignableTo<IClientSideProjectionDeserializer>();

var exception = Record.Exception(() => queryable.ToList());
Expand Down Expand Up @@ -1892,7 +1892,7 @@ public void TopN_without_GroupBy_should_have_helpful_error_message(
if (enableClientSideProjections)
{
var stages = Translate(collection, queryable, out var outputSerializer);
AssertStages(stages, "{ $project : { _0 : '$A', _id : 0 } }");
AssertStages(stages, "{ $project : { _snippets : ['$A'], _id : 0 } }");
outputSerializer.Should().BeAssignableTo<IClientSideProjectionDeserializer>();

var exception = Record.Exception(() => queryable.ToList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ public void OrderBy_on_entire_object_followed_by_ThenBy_should_throw(
if (enableClientSideProjections)
{
var stages = Translate(collection, queryable, out var outputSerializer);
AssertStages(stages, "{ $project : { _0 : '$Team', _id : 0 } }");
AssertStages(stages, "{ $project : { _snippets : ['$Team'], _id : 0 } }");
outputSerializer.Should().BeAssignableTo<IClientSideProjectionDeserializer>();

var result = queryable.Single();
Expand Down Expand Up @@ -187,7 +187,7 @@ public void OrderByDescending_on_entire_object_followed_by_ThenBy_should_throw(
if (enableClientSideProjections)
{
var stages = Translate(collection, queryable, out var outputSerializer);
AssertStages(stages, "{ $project : { _0 : '$Team', _id : 0 } }");
AssertStages(stages, "{ $project : { _snippets : ['$Team'], _id : 0 } }");
outputSerializer.Should().BeAssignableTo<IClientSideProjectionDeserializer>();

var result = queryable.Single();
Expand Down Expand Up @@ -218,7 +218,7 @@ public void ThenBy_on_entire_object_should_throw(
if (enableClientSideProjections)
{
var stages = Translate(collection, queryable, out var outputSerializer);
AssertStages(stages, "{ $project : { _0 : '$Team', _id : 0 } }");
AssertStages(stages, "{ $project : { _snippets : ['$Team'], _id : 0 } }");
outputSerializer.Should().BeAssignableTo<IClientSideProjectionDeserializer>();

var result = queryable.Single();
Expand Down Expand Up @@ -248,7 +248,7 @@ public void ThenByDescending_on_entire_object_should_throw(
if (enableClientSideProjections)
{
var stages = Translate(collection, queryable, out var outputSerializer);
AssertStages(stages, "{ $project : { _0 : '$Team', _id : 0 } }");
AssertStages(stages, "{ $project : { _snippets : ['$Team'], _id : 0 } }");
outputSerializer.Should().BeAssignableTo<IClientSideProjectionDeserializer>();

var result = queryable.Single();
Expand Down
Loading

0 comments on commit d175be3

Please sign in to comment.