From 0f1793b3ce312325d2e602abd79fd10d64f8c2f5 Mon Sep 17 00:00:00 2001 From: "kerem.acer" Date: Mon, 4 Nov 2024 23:30:17 +0300 Subject: [PATCH 1/2] CSHARP-5396: Fix LINQ join error when add .Where to foreign collection queryable --- .../ExtensionMethods/ExpressionExtensions.cs | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs index 671bc4ce40b..86232a3065d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs @@ -38,18 +38,30 @@ public static object Evaluate(this Expression expression) public static (string CollectionName, IBsonSerializer DocumentSerializer) GetCollectionInfo(this Expression innerExpression, Expression containerExpression) { - if (innerExpression is ConstantExpression constantExpression && - constantExpression.Value is IQueryable queryable && - queryable.Provider is IMongoQueryProviderInternal mongoQueryProvider && - mongoQueryProvider.CollectionNamespace != null) + var mongoQueryProvider = ExtractQueryProviderFromExpression(innerExpression); + if (mongoQueryProvider is not null) { return (mongoQueryProvider.CollectionNamespace.CollectionName, mongoQueryProvider.PipelineInputSerializer); } - var message = $"inner expression must be a MongoDB IQueryable against a collection"; + var message = "inner expression must be a MongoDB IQueryable against a collection"; throw new ExpressionNotSupportedException(innerExpression, containerExpression, because: message); } + private static IMongoQueryProviderInternal ExtractQueryProviderFromExpression(Expression expression) + { + return expression switch + { + MethodCallExpression methodCallExpression => ExtractQueryProviderFromExpression(methodCallExpression.Arguments.FirstOrDefault()), + ConstantExpression constantExpression => constantExpression.Value switch + { + IQueryable queryable => queryable.Provider as IMongoQueryProviderInternal, + _ => null + }, + _ => null + }; + } + public static TValue GetConstantValue(this Expression expression, Expression containingExpression) { if (expression is ConstantExpression constantExpression) From 19d088be923d240e410e1586eb85d5d8bb650f5a Mon Sep 17 00:00:00 2001 From: "kerem.acer" Date: Mon, 11 Nov 2024 08:09:32 +0300 Subject: [PATCH 2/2] Add support for arbitrary inner expression arguments --- .../ExtensionMethods/ExpressionExtensions.cs | 16 ++++----- .../GroupJoinMethodToPipelineTranslator.cs | 34 ++++++++++++++----- .../JoinMethodToPipelineTranslator.cs | 34 ++++++++++++++----- 3 files changed, 60 insertions(+), 24 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs index 86232a3065d..f1d20f9fa25 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs @@ -36,29 +36,29 @@ public static object Evaluate(this Expression expression) } } - public static (string CollectionName, IBsonSerializer DocumentSerializer) GetCollectionInfo(this Expression innerExpression, Expression containerExpression) + public static (IMongoQueryProviderInternal QueryProvider, bool IsRawCollectionExpression) FindMongoQueryProvider(this Expression innerExpression, Expression containerExpression) { var mongoQueryProvider = ExtractQueryProviderFromExpression(innerExpression); - if (mongoQueryProvider is not null) + if (mongoQueryProvider.QueryProvider is not null) { - return (mongoQueryProvider.CollectionNamespace.CollectionName, mongoQueryProvider.PipelineInputSerializer); + return mongoQueryProvider; } var message = "inner expression must be a MongoDB IQueryable against a collection"; throw new ExpressionNotSupportedException(innerExpression, containerExpression, because: message); } - private static IMongoQueryProviderInternal ExtractQueryProviderFromExpression(Expression expression) + private static (IMongoQueryProviderInternal QueryProvider, bool IsRawCollectionExpression) ExtractQueryProviderFromExpression(Expression expression, int depth = 0) { return expression switch { - MethodCallExpression methodCallExpression => ExtractQueryProviderFromExpression(methodCallExpression.Arguments.FirstOrDefault()), + MethodCallExpression methodCallExpression => ExtractQueryProviderFromExpression(methodCallExpression.Arguments.FirstOrDefault(), depth + 1), ConstantExpression constantExpression => constantExpression.Value switch { - IQueryable queryable => queryable.Provider as IMongoQueryProviderInternal, - _ => null + IQueryable { Provider: IMongoQueryProviderInternal queryProvider } => (queryProvider, depth == 0), + _ => default }, - _ => null + _ => default }; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupJoinMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupJoinMethodToPipelineTranslator.cs index cc13c665965..0d82b308369 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupJoinMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupJoinMethodToPipelineTranslator.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using System.Linq.Expressions; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; @@ -59,19 +60,36 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres var wrappedOuterSerializer = WrappedValueSerializer.Create("_outer", outerSerializer); var innerExpression = arguments[1]; - var (innerCollectionName, innerSerializer) = innerExpression.GetCollectionInfo(containerExpression: expression); + var (queryProvider, isRawCollectionExpression) = innerExpression.FindMongoQueryProvider(containerExpression: expression); var outerKeySelectorLambda = ExpressionHelper.UnquoteLambda(arguments[2]); var localField = outerKeySelectorLambda.TranslateToDottedFieldName(context, wrappedOuterSerializer); var innerKeySelectorLambda = ExpressionHelper.UnquoteLambda(arguments[3]); - var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, innerSerializer); + var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, queryProvider.PipelineInputSerializer); - var lookupStage = AstStage.Lookup( - from: innerCollectionName, - localField, - foreignField, - @as: "_inner"); + AstStage lookupStage; + + if (isRawCollectionExpression) + { + lookupStage = AstStage.Lookup( + from: queryProvider.CollectionNamespace.CollectionName, + localField, + foreignField, + @as: "_inner"); + } + else + { + var lookupPipeline = ExpressionToPipelineTranslator.Translate(context, innerExpression); + + lookupStage = AstStage.Lookup( + from: queryProvider.CollectionNamespace.CollectionName, + localField, + foreignField, + Array.Empty(), + lookupPipeline, + @as: "_inner"); + } var resultSelectorLambda = ExpressionHelper.UnquoteLambda(arguments[4]); var root = AstExpression.Var("ROOT", isCurrent: true); @@ -80,7 +98,7 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres var outerSymbol = context.CreateSymbol(outerParameter, outerField, outerSerializer); var innerParameter = resultSelectorLambda.Parameters[1]; var innerField = AstExpression.GetField(root, "_inner"); - var ienumerableInnerSerializer = IEnumerableSerializer.Create(innerSerializer); + var ienumerableInnerSerializer = IEnumerableSerializer.Create(queryProvider.PipelineInputSerializer); var innerSymbol = context.CreateSymbol(innerParameter, innerField, ienumerableInnerSerializer); var resultSelectorContext = context.WithSymbols(outerSymbol, innerSymbol); var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(resultSelectorContext, resultSelectorLambda.Body); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/JoinMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/JoinMethodToPipelineTranslator.cs index 3d61e3801b3..4e566a086fc 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/JoinMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/JoinMethodToPipelineTranslator.cs @@ -13,6 +13,8 @@ * limitations under the License. */ +using System.Linq; +using System; using System.Linq.Expressions; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; @@ -63,15 +65,31 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres AstProject.Exclude("_id")); var wrappedOuterSerializer = WrappedValueSerializer.Create("_outer", outerSerializer); - var (innerCollectionName, innerSerializer) = innerExpression.GetCollectionInfo(containerExpression: expression); + var (queryProvider, isRawCollectionExpression) = innerExpression.FindMongoQueryProvider(containerExpression: expression); var localField = outerKeySelectorLambda.TranslateToDottedFieldName(context, wrappedOuterSerializer); - var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, innerSerializer); + var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, queryProvider.PipelineInputSerializer); - var lookupStage = AstStage.Lookup( - from: innerCollectionName, - localField, - foreignField, - @as: "_inner"); + AstStage lookupStage; + + if (isRawCollectionExpression) + { + lookupStage = AstStage.Lookup( + from: queryProvider.CollectionNamespace.CollectionName, + localField, + foreignField, + @as: "_inner"); + } + else + { + var lookupPipeline = ExpressionToPipelineTranslator.Translate(context, innerExpression); + lookupStage = AstStage.Lookup( + from: queryProvider.CollectionNamespace.CollectionName, + localField, + foreignField, + Array.Empty(), + lookupPipeline, + @as: "_inner"); + } var unwindStage = AstStage.Unwind("_inner"); @@ -80,7 +98,7 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres var outerSymbol = context.CreateSymbol(outerParameter, outerField, outerSerializer); var innerParameter = resultSelectorLambda.Parameters[1]; var innerField = AstExpression.GetField(AstExpression.RootVar, "_inner"); - var innerSymbol = context.CreateSymbol(innerParameter, innerField, innerSerializer); + var innerSymbol = context.CreateSymbol(innerParameter, innerField, queryProvider.PipelineInputSerializer); var resultSelectorContext = context.WithSymbols(outerSymbol, innerSymbol); var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(resultSelectorContext, resultSelectorLambda.Body); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(resultSelectorTranslation);