Skip to content

Commit

Permalink
CSHARP-5459: Standardize on using AstExpression.RootVar and Context.C…
Browse files Browse the repository at this point in the history
…reateRootSymbol.
  • Loading branch information
rstam committed Jan 14, 2025
1 parent c8f0429 commit 8ac275a
Show file tree
Hide file tree
Showing 21 changed files with 47 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public static bool IsInt32Constant(this AstExpression expression, out int value)
public static bool IsMaxInt32(this AstExpression expression)
=> expression.IsInt32Constant(out var value) && value == int.MaxValue;

public static bool IsRootVar(this AstExpression expression)
=> expression is AstVarExpression varExpression && varExpression.Name == "ROOT" && varExpression.IsCurrent;

public static bool IsZero(this AstExpression expression)
=> expression is AstConstantExpression constantExpression && constantExpression.Value == 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,7 @@ static bool ProjectsRoot(AstProjectStage projectStage)
return projectStage.Specifications.Any(
specification =>
specification is AstProjectStageSetFieldSpecification setFieldSpecification &&
setFieldSpecification.Value is AstVarExpression varExpression &&
varExpression.Name == "ROOT");
setFieldSpecification.Value.IsRootVar());
}
}
}
Expand Down Expand Up @@ -370,14 +369,12 @@ public override AstNode VisitMapExpression(AstMapExpression node)
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
mapInputconstantFieldExpression.Value.IsString &&
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression &&
mapInputGetFieldVarExpression.Name == "ROOT")
mapInputGetFieldExpression.Input.IsRootVar())
{
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element));
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Push, rewrittenArg);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
var root = AstExpression.Var("ROOT", isCurrent: true);
return AstExpression.GetField(root, accumulatorFieldName);
return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
}

return base.VisitMapExpression(node);
Expand All @@ -388,8 +385,7 @@ public override AstNode VisitPickExpression(AstPickExpression node)
// { $pickOperator : { source : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", sortBy : s, selector : f(x) } }
// => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element) } } } + "$__agg0"
if (node.Source is AstGetFieldExpression getFieldExpression &&
getFieldExpression.Input is AstVarExpression varExpression &&
varExpression.Name == "ROOT" &&
getFieldExpression.Input.IsRootVar() &&
getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpression &&
constantFieldNameExpression.Value.IsString &&
constantFieldNameExpression.Value.AsString == "_elements")
Expand All @@ -398,17 +394,14 @@ getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpressio
var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element));
var accumulatorExpression = new AstPickAccumulatorExpression(@operator, node.SortBy, rewrittenSelector, node.N);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
var root = AstExpression.Var("ROOT", isCurrent: true);
return AstExpression.GetField(root, accumulatorFieldName);
return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
}

return base.VisitPickExpression(node);
}

public override AstNode VisitUnaryExpression(AstUnaryExpression node)
{
var root = AstExpression.Var("ROOT", isCurrent: true);

if (TryOptimizeSizeOfElements(out var optimizedExpression))
{
return optimizedExpression;
Expand Down Expand Up @@ -438,7 +431,7 @@ argGetFieldExpression.FieldName is AstConstantExpression constantFieldNameExpres
{
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Sum, 1);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);
optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
return true;
}
}
Expand All @@ -455,12 +448,11 @@ node.Arg is AstGetFieldExpression getFieldExpression &&
getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameExpression &&
getFieldConstantFieldNameExpression.Value.IsString &&
getFieldConstantFieldNameExpression.Value == "_elements" &&
getFieldExpression.Input is AstVarExpression getFieldInputVarExpression &&
getFieldInputVarExpression.Name == "ROOT")
getFieldExpression.Input.IsRootVar())
{
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);
optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
return true;
}

Expand All @@ -478,13 +470,12 @@ mapExpression.Input is AstGetFieldExpression mapInputGetFieldExpression &&
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
mapInputconstantFieldExpression.Value.IsString &&
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression &&
mapInputGetFieldVarExpression.Name == "ROOT")
mapInputGetFieldExpression.Input.IsRootVar())
{
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element));
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);
optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ protected override AstStage RenderGroupingStage(
var valueSerializer = (IBsonSerializer<TValue>)groupByTranslation.Serializer;
var serializedBoundaries = SerializationHelper.SerializeValues(valueSerializer, _boundaries);
var serializedDefault = _options != null && _options.DefaultBucket.HasValue ? SerializationHelper.SerializeValue(valueSerializer, _options.DefaultBucket.Value) : null;
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.Var("ROOT", isCurrent: true));
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar);
groupingOutputSerializer = IGroupingSerializer.Create(valueSerializer, inputSerializer);

return AstStage.Bucket(
Expand Down Expand Up @@ -156,7 +156,7 @@ protected override AstStage RenderGroupingStage(
var valueSerializer = (IBsonSerializer<TValue>)groupByTranslation.Serializer;
var keySerializer = AggregateBucketAutoResultIdSerializer.Create(valueSerializer);
var serializedGranularity = _options != null && _options.Granularity.HasValue ? _options.Granularity.Value.Value : null;
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.Var("ROOT", isCurrent: true));
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar);
groupingOutputSerializer = IGroupingSerializer.Create(keySerializer, inputSerializer);

return AstStage.BucketAuto(
Expand Down Expand Up @@ -190,7 +190,7 @@ protected override AstStage RenderGroupingStage(
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
var context = TranslationContext.Create(partiallyEvaluatedGroupBy, translationOptions);
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.Var("ROOT", isCurrent: true));
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar);
var groupBySerializer = (IBsonSerializer<TValue>)groupByTranslation.Serializer;
groupingOutputSerializer = IGroupingSerializer.Create(groupBySerializer, inputSerializer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static string TranslateToDottedFieldName(this LambdaExpression fieldSelec
{
throw new ArgumentException($"ValueType '{parameterSerializer.ValueType.FullName}' of parameterSerializer does not match parameter type '{parameterExpression.Type.FullName}'.", nameof(parameterSerializer));
}
var parameterSymbol = context.CreateSymbolWithVarName(parameterExpression, varName: "ROOT", parameterSerializer, isCurrent: true);
var parameterSymbol = context.CreateRootSymbol(parameterExpression, parameterSerializer);
var lambdaContext = context.WithSymbol(parameterSymbol);
var lambdaBody = ConvertHelper.RemoveConvertToObject(fieldSelectorLambda.Body);
var fieldSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(lambdaContext, lambdaBody);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private static (string, IBsonSerializer) CreateGetFieldChainWithSafeFieldNamesPr
var wrappedValueSerializer = WrappedValueSerializer.Create(fieldName, serializer);
var input = getFieldExpression.Input;

if (input is AstVarExpression varExpression && varExpression.Name == "ROOT")
if (input.IsRootVar())
{
return (fieldName, wrappedValueSerializer);
}
Expand Down Expand Up @@ -132,7 +132,7 @@ private static bool IsGetFieldChainWithSafeFieldNames(AstGetFieldExpression getF
return
getFieldExpression.HasSafeFieldName(out _) &&
(
(getFieldExpression.Input is AstVarExpression varExpression && varExpression.Name == "ROOT") ||
(getFieldExpression.Input.IsRootVar()) ||
(getFieldExpression.Input is AstGetFieldExpression nestedGetFieldExpression && IsGetFieldChainWithSafeFieldNames(nestedGetFieldExpression))
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public static AggregationExpression TranslateLambdaBody(
}
var parameterSymbol =
asRoot ?
context.CreateSymbolWithVarName(parameterExpression, varName: "ROOT", parameterSerializer, isCurrent: true) :
context.CreateRootSymbol(parameterExpression, parameterSerializer) :
context.CreateSymbol(parameterExpression, parameterSerializer, isCurrent: false);

return TranslateLambdaBody(context, lambdaExpression, parameterSymbol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,18 +260,16 @@ private static bool IsGroupingSource(AstExpression source)
{
return
source is AstGetFieldExpression getFieldExpression &&
getFieldExpression.Input is AstVarExpression inputVarExpression &&
inputVarExpression.Name == "ROOT" &&
getFieldExpression.Input.IsRootVar() &&
getFieldExpression.FieldName is AstConstantExpression fieldNameConstantExpression &&
fieldNameConstantExpression.Value == "_elements";
}

private static bool IsValidKey(AggregationExpression keyTranslation)
{
if (keyTranslation.Ast is AstGetFieldExpression getFieldExpression &&
getFieldExpression.Input is AstVarExpression inputVarExpression &&
getFieldExpression.Input.IsRootVar() &&
getFieldExpression.FieldName is AstConstantExpression constantFieldName &&
inputVarExpression.Name == "ROOT" &&
constantFieldName.Value == "_id")
{
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
else
{
Ensure.That(sourceSerializer is IWrappedValueSerializer, "Expected sourceSerializer to be an IWrappedValueSerializer.", nameof(sourceSerializer));
var root = AstExpression.Var("ROOT", isCurrent: true);
valueExpression = AstExpression.GetField(root, "_v");
valueExpression = AstExpression.GetField(AstExpression.RootVar, "_v");
}

IBsonSerializer outputValueSerializer = expression.GetResultType() switch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public static ExecutableQuery<TDocument, bool> Translate<TDocument>(MongoQueryPr
wrappedValueSerializer,
AstStage.Project(
AstProject.ExcludeId(),
AstProject.Set("_v", AstExpression.Var("ROOT"))));
AstProject.Set("_v", AstExpression.RootVar)));
}

var itemExpression = arguments[1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
pipeline.OutputSerializer,
AstStage.Group(
id: BsonNull.Value,
fields: AstExpression.AccumulatorField("_last", AstUnaryAccumulatorOperator.Last, AstExpression.Var("ROOT"))));
fields: AstExpression.AccumulatorField("_last", AstUnaryAccumulatorOperator.Last, AstExpression.RootVar)));

var finalizer = method.Name == "LastOrDefault" ? __singleOrDefaultFinalizer : __singleFinalizer;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method);

var sourceSerializer = pipeline.OutputSerializer;
var root = AstExpression.Var("ROOT", isCurrent: true);
AstExpression valueAst;
IBsonSerializer valueSerializer;
if (method.IsOneOf(__maxWithSelectorMethods))
Expand All @@ -86,7 +85,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
}
else
{
valueAst = root;
valueAst = AstExpression.RootVar;
valueSerializer = pipeline.OutputSerializer;
}

Expand All @@ -95,7 +94,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
AstStage.Group(
id: BsonNull.Value,
fields: AstExpression.AccumulatorField("_max", AstUnaryAccumulatorOperator.Max, valueAst)),
AstStage.ReplaceRoot(AstExpression.GetField(root, "_max")));
AstStage.ReplaceRoot(AstExpression.GetField(AstExpression.RootVar, "_max")));

return ExecutableQuery.Create(
provider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method);

var sourceSerializer = pipeline.OutputSerializer;
var root = AstExpression.Var("ROOT", isCurrent: true);
AstExpression valueAst;
IBsonSerializer valueSerializer;
if (method.IsOneOf(__minWithSelectorMethods))
Expand All @@ -86,7 +85,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
}
else
{
valueAst = root;
valueAst = AstExpression.RootVar;
valueSerializer = pipeline.OutputSerializer;
}

Expand All @@ -95,7 +94,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
AstStage.Group(
id: BsonNull.Value,
fields: AstExpression.AccumulatorField("_min", AstUnaryAccumulatorOperator.Min, valueAst)),
AstStage.ReplaceRoot(AstExpression.GetField(root, "_min")));
AstStage.ReplaceRoot(AstExpression.GetField(AstExpression.RootVar, "_min")));

return ExecutableQuery.Create(
provider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
}
else
{
var root = AstExpression.Var("ROOT", isCurrent: true);
valueAst = AstExpression.GetField(root, "_v");
valueAst = AstExpression.GetField(AstExpression.RootVar, "_v");
}
var outputValueType = expression.GetResultType();
var outputValueSerializer = BsonSerializer.LookupSerializer(outputValueType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
var arguments = expression.Arguments;

if (method.IsOneOf(__sumMethods))
{
{
var sourceExpression = arguments[0];
var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression);
ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method);
Expand All @@ -131,8 +131,7 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
else
{
Ensure.That(sourceSerializer is IWrappedValueSerializer, "Expected sourceSerializer to be an IWrappedValueSerializer.", nameof(sourceSerializer));
var rootVar = AstExpression.Var("ROOT", isCurrent: true);
valueAst = AstExpression.GetField(rootVar, "_v");
valueAst = AstExpression.GetField(AstExpression.RootVar, "_v");
}

var outputValueType = expression.GetResultType();
Expand Down
Loading

0 comments on commit 8ac275a

Please sign in to comment.