Skip to content

Commit

Permalink
CSHARP-5321: Optimize client-side projections to perform as much as p…
Browse files Browse the repository at this point in the history
…ossible of the projection on the server.
  • Loading branch information
rstam committed Jan 14, 2025
1 parent 8ac275a commit 34ca6d3
Show file tree
Hide file tree
Showing 19 changed files with 611 additions and 90 deletions.
82 changes: 82 additions & 0 deletions src/MongoDB.Driver/ClientSideProjectionSnippetsDeserializer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Linq;
using MongoDB.Bson;
using MongoDB.Bson.IO;
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Serializers;

namespace MongoDB.Driver
{
internal static class ClientSideProjectionSnippetsDeserializer
{
public static IBsonSerializer Create(
Type projectionType,
IBsonSerializer[] snippetDeserializers,
Delegate projector)
{
var deserializerType = typeof(ClientSideProjectionSnippetsDeserializer<>).MakeGenericType(projectionType);
return (IBsonSerializer)Activator.CreateInstance(deserializerType, [snippetDeserializers, projector]);
}
}

internal sealed class ClientSideProjectionSnippetsDeserializer<TProjection> : SerializerBase<TProjection>, IClientSideProjectionDeserializer
{
private readonly IBsonSerializer[] _snippetDeserializers;
private readonly Func<object[], TProjection> _projector;

public ClientSideProjectionSnippetsDeserializer(IBsonSerializer[] snippetDeserializers, Func<object[], TProjection> projector)
{
_snippetDeserializers = snippetDeserializers;
_projector = projector;
}

public override TProjection Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
{
var snippets = DeserializeSnippets(context);
return _projector(snippets);
}

private object[] DeserializeSnippets(BsonDeserializationContext context)
{
var reader = context.Reader;

reader.ReadStartDocument();
reader.ReadName("_snippets");
reader.ReadStartArray();
var snippets = new object[_snippetDeserializers.Length];
var i = 0;
while (reader.ReadBsonType() != BsonType.EndOfDocument)
{
if (i >= _snippetDeserializers.Length)
{
throw new BsonSerializationException($"Expected {_snippetDeserializers.Length} snippets but found more than that.");
}
snippets[i] = _snippetDeserializers[i].Deserialize(context);
i++;
}
if (i != _snippetDeserializers.Length)
{
throw new BsonSerializationException($"Expected {_snippetDeserializers.Length} snippets but found {i}.");
}
reader.ReadEndArray();
reader.ReadEndDocument();

return snippets;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
{
internal class ExpressionIsReferencedVisitor : ExpressionVisitor
{
#region static
public static bool IsReferenced(Expression node, Expression expression)
{
var visitor = new ExpressionIsReferencedVisitor(expression);
visitor.Visit(node);
return visitor.ExpressionIsReferenced;
}
#endregion

private readonly Expression _expression;
private bool _expressionIsReferenced;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ internal static class LambdaExpressionExtensions
{
public static bool LambdaBodyReferencesParameter(this LambdaExpression lambda, ParameterExpression parameter)
{
var visitor = new ExpressionIsReferencedVisitor(parameter);
visitor.Visit(lambda.Body);
return visitor.ExpressionIsReferenced;
return ExpressionIsReferencedVisitor.IsReferenced(lambda.Body, parameter);
}

public static string TranslateToDottedFieldName(this LambdaExpression fieldSelectorLambda, TranslationContext context, IBsonSerializer parameterSerializer)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq;
using System.Linq.Expressions;
using MongoDB.Bson.Serialization;
using MongoDB.Driver.Core.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
{
internal static class ClientSideProjectionExpressionRewriter
{
public static (AstProjectStage, IBsonSerializer) CreateClientSideProjection(
TranslationContext context,
LambdaExpression projectionLambda,
IBsonSerializer sourceSerializer)
{
var (snippetsExpression, snippetsProjectionDeserializer) = ClientSideProjectionExpressionRewriter.TranslateLambdaBodyUsingSnippets(context, sourceSerializer, projectionLambda);
if (snippetsExpression == null)
{
return (null, snippetsProjectionDeserializer);
}
else
{
var snippetsTranslation = new AggregationExpression(projectionLambda, snippetsExpression, snippetsProjectionDeserializer);
return ProjectionHelper.CreateProjectStage(snippetsTranslation);
}
}

private static (AstComputedDocumentExpression, IBsonSerializer) TranslateLambdaBodyUsingSnippets(
TranslationContext context,
IBsonSerializer sourceSerializer,
LambdaExpression projectionLambda)
{
var wireVersion = context.TranslationOptions.CompatibilityLevel.ToWireVersion();
if (!Feature.FindProjectionExpressions.IsSupported(wireVersion))
{
var clientSideProjectionDeserializer = ClientSideProjectionDeserializer.Create(sourceSerializer, projectionLambda);
return (null, clientSideProjectionDeserializer); // project directly off $$ROOT with no snippets
}

var snippets = ClientSideProjectionSnippetsTranslator.TranslateSnippets(context, projectionLambda, sourceSerializer);

if (snippets.Length == 0 || snippets.Any(IsRoot))
{
var clientSideProjectionDeserializer = ClientSideProjectionDeserializer.Create(sourceSerializer, projectionLambda);
return (null, clientSideProjectionDeserializer); // project directly off $$ROOT with no snippets
}
else
{
var snippetsComputedDocument = CreateSnippetsComputedDocument(snippets);
var snippetDeserializers = snippets.Select(s => s.Serializer).ToArray();
var rewrittenSelectorLamdba = RewriteSelector(projectionLambda, snippets);
var rewrittenSelectorDelegate = rewrittenSelectorLamdba.Compile();
var clientSideProjectionSnippetsDeserializer = ClientSideProjectionSnippetsDeserializer.Create(projectionLambda.ReturnType, snippetDeserializers, rewrittenSelectorDelegate);
return (snippetsComputedDocument, clientSideProjectionSnippetsDeserializer);
}

static bool IsRoot(AggregationExpression snippet) => snippet.Ast.IsRootVar();
}

private static AstComputedDocumentExpression CreateSnippetsComputedDocument(AggregationExpression[] snippets)
{
var snippetsArray = AstExpression.ComputedArray(snippets.Select(s => s.Ast));
var snippetsdField = AstExpression.ComputedField("_snippets", snippetsArray);
return (AstComputedDocumentExpression)AstExpression.ComputedDocument([snippetsdField]);

}

private static LambdaExpression RewriteSelector(LambdaExpression selectorLambda, AggregationExpression[] snippets)
{
var rewrittenBody = selectorLambda.Body;
var snippetsParameter = Expression.Parameter(typeof(object[]), "snippets");

for (var i = 0; i < snippets.Length; i++)
{
var snippet = snippets[i];
var snippetReference = // (T)_snippets[i]
Expression.Convert(
Expression.ArrayIndex(snippetsParameter, Expression.Constant(i)),
snippet.Expression.Type);
rewrittenBody = ExpressionReplacer.Replace(rewrittenBody, snippet.Expression, snippetReference);
}

return Expression.Lambda(rewrittenBody, snippetsParameter);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using MongoDB.Bson.Serialization;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
{
internal class ClientSideProjectionSnippetsTranslator : ExpressionVisitor
{
#region static

private readonly static MethodInfo[] __orderByMethods =
[
EnumerableMethod.OrderBy,
EnumerableMethod.OrderByDescending,
QueryableMethod.OrderBy,
QueryableMethod.OrderByDescending
];

private readonly static MethodInfo[] __thenByMethods =
[
EnumerableMethod.ThenBy,
EnumerableMethod.ThenByDescending,
QueryableMethod.ThenBy,
QueryableMethod.ThenByDescending
];

public static AggregationExpression[] TranslateSnippets(TranslationContext context, LambdaExpression selectorLambda, IBsonSerializer sourceSerializer)
{
var rootParameter = selectorLambda.Parameters.Single();
var rootSymbol = context.CreateRootSymbol(rootParameter, sourceSerializer);
context = context.WithSymbol(rootSymbol);

var snippetTranslator = new ClientSideProjectionSnippetsTranslator(context, rootParameter);
snippetTranslator.Visit(selectorLambda.Body);

return snippetTranslator.Snippets.ToArray();
}

#endregion

private readonly TranslationContext _context;
private readonly ParameterExpression _rootParameter;
private readonly List<AggregationExpression> _snippets = new();

private ClientSideProjectionSnippetsTranslator(TranslationContext context, ParameterExpression rootParameter)
{
_context = context;
_rootParameter = rootParameter;
}

private List<AggregationExpression> Snippets => _snippets;

public override Expression Visit(Expression node)
{
if (ExpressionIsReferencedVisitor.IsReferenced(node, _rootParameter))
{
try
{
var snippet = ExpressionToAggregationExpressionTranslator.Translate(_context, node);
_snippets.Add(snippet);
return node;
}
catch
{
// don't split OrderBy/ThenBy between client and server
if (node is MethodCallExpression methodCallExpression &&
methodCallExpression.Method.IsOneOf(__thenByMethods))
{
var orderBySource = FindOrderBySource(node);
Visit(orderBySource); // resume visiting at orderBySource
return node; // suppress any further visiting below this node
}

// ignore exceptions and fall through
}
}

return base.Visit(node);

static Expression FindOrderBySource(Expression node)
{
if (node is MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.IsOneOf(__thenByMethods))
{
return FindOrderBySource(methodCallExpression.Arguments[0]);
}

if (methodCallExpression.Method.IsOneOf(__orderByMethods))
{
return methodCallExpression.Arguments[0];
}
}

throw new ArgumentException($"Node type {node.NodeType} is not a MethodCallExpression.");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

using System;
using System.Linq.Expressions;
using MongoDB.Bson.Serialization;
using MongoDB.Driver.Linq.Linq3Implementation.Ast;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
Expand Down Expand Up @@ -46,19 +47,21 @@ public static TranslatedPipeline Translate(TranslationContext context, MethodCal
ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method);

var sourceSerializer = pipeline.OutputSerializer;
AstProjectStage projectStage;
IBsonSerializer projectionSerializer;
try
{
var selectorTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, selectorLambda, sourceSerializer, asRoot: true);
var (projectStage, projectionSerializer) = ProjectionHelper.CreateProjectStage(selectorTranslation);
pipeline = pipeline.AddStages(projectionSerializer, projectStage);
(projectStage, projectionSerializer) = ProjectionHelper.CreateProjectStage(selectorTranslation);
}
catch (ExpressionNotSupportedException) when (context.TranslationOptions?.EnableClientSideProjections ?? false)
{
var clientSideProjectionDeserializer = ClientSideProjectionDeserializer.Create(sourceSerializer, selectorLambda);
pipeline = pipeline.AddStages(clientSideProjectionDeserializer, Array.Empty<AstStage>());
(projectStage, projectionSerializer) = ClientSideProjectionExpressionRewriter.CreateClientSideProjection(context, selectorLambda, sourceSerializer);
}

return pipeline;
return projectStage == null ?
new TranslatedPipeline(pipeline.Ast, projectionSerializer) : // just switch the output serializer
pipeline.AddStages(projectionSerializer, projectStage);
}

throw new ExpressionNotSupportedException(expression);
Expand Down
6 changes: 3 additions & 3 deletions src/MongoDB.Driver/Linq/LinqProviderAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ private static RenderedProjectionDefinition<TOutput> TranslateExpressionToProjec
}
catch (ExpressionNotSupportedException) when (translationOptions?.EnableClientSideProjections ?? false)
{
var projectorDelegate = expression.Compile();
var clientSideProjectionDeserializer = new ClientSideProjectionDeserializer<TInput, TOutput>(inputSerializer, projectorDelegate);
return new RenderedProjectionDefinition<TOutput>(document: null, clientSideProjectionDeserializer);
var (projectStage, projectionSerializer) = ClientSideProjectionExpressionRewriter.CreateClientSideProjection(context, expression, inputSerializer);
var projectionDocument = projectStage == null ? null : AstSimplifier.SimplifyAndConvert(projectStage).Render()["$project"].AsBsonDocument;
return new RenderedProjectionDefinition<TOutput>(projectionDocument, (IBsonSerializer<TOutput>)projectionSerializer);
}
}

Expand Down
Loading

0 comments on commit 34ca6d3

Please sign in to comment.