diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs index 671bc4ce40b..f1d20f9fa25 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs @@ -36,20 +36,32 @@ public static object Evaluate(this Expression expression) } } - public static (string CollectionName, IBsonSerializer DocumentSerializer) GetCollectionInfo(this Expression innerExpression, Expression containerExpression) + public static (IMongoQueryProviderInternal QueryProvider, bool IsRawCollectionExpression) FindMongoQueryProvider(this Expression innerExpression, Expression containerExpression) { - if (innerExpression is ConstantExpression constantExpression && - constantExpression.Value is IQueryable queryable && - queryable.Provider is IMongoQueryProviderInternal mongoQueryProvider && - mongoQueryProvider.CollectionNamespace != null) + var mongoQueryProvider = ExtractQueryProviderFromExpression(innerExpression); + if (mongoQueryProvider.QueryProvider is not null) { - return (mongoQueryProvider.CollectionNamespace.CollectionName, mongoQueryProvider.PipelineInputSerializer); + return mongoQueryProvider; } - var message = $"inner expression must be a MongoDB IQueryable against a collection"; + var message = "inner expression must be a MongoDB IQueryable against a collection"; throw new ExpressionNotSupportedException(innerExpression, containerExpression, because: message); } + private static (IMongoQueryProviderInternal QueryProvider, bool IsRawCollectionExpression) ExtractQueryProviderFromExpression(Expression expression, int depth = 0) + { + return expression switch + { + MethodCallExpression methodCallExpression => ExtractQueryProviderFromExpression(methodCallExpression.Arguments.FirstOrDefault(), depth + 1), + ConstantExpression constantExpression => constantExpression.Value switch + { + IQueryable { Provider: IMongoQueryProviderInternal queryProvider } => (queryProvider, depth == 0), + _ => default + }, + _ => default + }; + } + public static TValue GetConstantValue(this Expression expression, Expression containingExpression) { if (expression is ConstantExpression constantExpression) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupJoinMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupJoinMethodToPipelineTranslator.cs index cc13c665965..0d82b308369 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupJoinMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupJoinMethodToPipelineTranslator.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using System.Linq.Expressions; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; @@ -59,19 +60,36 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres var wrappedOuterSerializer = WrappedValueSerializer.Create("_outer", outerSerializer); var innerExpression = arguments[1]; - var (innerCollectionName, innerSerializer) = innerExpression.GetCollectionInfo(containerExpression: expression); + var (queryProvider, isRawCollectionExpression) = innerExpression.FindMongoQueryProvider(containerExpression: expression); var outerKeySelectorLambda = ExpressionHelper.UnquoteLambda(arguments[2]); var localField = outerKeySelectorLambda.TranslateToDottedFieldName(context, wrappedOuterSerializer); var innerKeySelectorLambda = ExpressionHelper.UnquoteLambda(arguments[3]); - var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, innerSerializer); + var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, queryProvider.PipelineInputSerializer); - var lookupStage = AstStage.Lookup( - from: innerCollectionName, - localField, - foreignField, - @as: "_inner"); + AstStage lookupStage; + + if (isRawCollectionExpression) + { + lookupStage = AstStage.Lookup( + from: queryProvider.CollectionNamespace.CollectionName, + localField, + foreignField, + @as: "_inner"); + } + else + { + var lookupPipeline = ExpressionToPipelineTranslator.Translate(context, innerExpression); + + lookupStage = AstStage.Lookup( + from: queryProvider.CollectionNamespace.CollectionName, + localField, + foreignField, + Array.Empty(), + lookupPipeline, + @as: "_inner"); + } var resultSelectorLambda = ExpressionHelper.UnquoteLambda(arguments[4]); var root = AstExpression.Var("ROOT", isCurrent: true); @@ -80,7 +98,7 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres var outerSymbol = context.CreateSymbol(outerParameter, outerField, outerSerializer); var innerParameter = resultSelectorLambda.Parameters[1]; var innerField = AstExpression.GetField(root, "_inner"); - var ienumerableInnerSerializer = IEnumerableSerializer.Create(innerSerializer); + var ienumerableInnerSerializer = IEnumerableSerializer.Create(queryProvider.PipelineInputSerializer); var innerSymbol = context.CreateSymbol(innerParameter, innerField, ienumerableInnerSerializer); var resultSelectorContext = context.WithSymbols(outerSymbol, innerSymbol); var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(resultSelectorContext, resultSelectorLambda.Body); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/JoinMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/JoinMethodToPipelineTranslator.cs index 3d61e3801b3..4e566a086fc 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/JoinMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/JoinMethodToPipelineTranslator.cs @@ -13,6 +13,8 @@ * limitations under the License. */ +using System.Linq; +using System; using System.Linq.Expressions; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; @@ -63,15 +65,31 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres AstProject.Exclude("_id")); var wrappedOuterSerializer = WrappedValueSerializer.Create("_outer", outerSerializer); - var (innerCollectionName, innerSerializer) = innerExpression.GetCollectionInfo(containerExpression: expression); + var (queryProvider, isRawCollectionExpression) = innerExpression.FindMongoQueryProvider(containerExpression: expression); var localField = outerKeySelectorLambda.TranslateToDottedFieldName(context, wrappedOuterSerializer); - var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, innerSerializer); + var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, queryProvider.PipelineInputSerializer); - var lookupStage = AstStage.Lookup( - from: innerCollectionName, - localField, - foreignField, - @as: "_inner"); + AstStage lookupStage; + + if (isRawCollectionExpression) + { + lookupStage = AstStage.Lookup( + from: queryProvider.CollectionNamespace.CollectionName, + localField, + foreignField, + @as: "_inner"); + } + else + { + var lookupPipeline = ExpressionToPipelineTranslator.Translate(context, innerExpression); + lookupStage = AstStage.Lookup( + from: queryProvider.CollectionNamespace.CollectionName, + localField, + foreignField, + Array.Empty(), + lookupPipeline, + @as: "_inner"); + } var unwindStage = AstStage.Unwind("_inner"); @@ -80,7 +98,7 @@ public static AstPipeline Translate(TranslationContext context, MethodCallExpres var outerSymbol = context.CreateSymbol(outerParameter, outerField, outerSerializer); var innerParameter = resultSelectorLambda.Parameters[1]; var innerField = AstExpression.GetField(AstExpression.RootVar, "_inner"); - var innerSymbol = context.CreateSymbol(innerParameter, innerField, innerSerializer); + var innerSymbol = context.CreateSymbol(innerParameter, innerField, queryProvider.PipelineInputSerializer); var resultSelectorContext = context.WithSymbols(outerSymbol, innerSymbol); var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(resultSelectorContext, resultSelectorLambda.Body); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(resultSelectorTranslation);