diff --git a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/DocStoreQueryV1Test.java b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/DocStoreQueryV1Test.java index f684b118..5cdac2c4 100644 --- a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/DocStoreQueryV1Test.java +++ b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/DocStoreQueryV1Test.java @@ -14,8 +14,11 @@ import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.MAX; import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.MIN; import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.SUM; +import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.DIVIDE; +import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.FLOOR; import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.LENGTH; import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.MULTIPLY; +import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.SUBTRACT; import static org.hypertrace.core.documentstore.expression.operators.LogicalOperator.AND; import static org.hypertrace.core.documentstore.expression.operators.LogicalOperator.OR; import static org.hypertrace.core.documentstore.expression.operators.RelationalOperator.CONTAINS; @@ -87,6 +90,7 @@ import org.hypertrace.core.documentstore.model.options.UpdateOptions; import org.hypertrace.core.documentstore.model.subdoc.SubDocumentUpdate; import org.hypertrace.core.documentstore.model.subdoc.SubDocumentValue; +import org.hypertrace.core.documentstore.query.Aggregation; import org.hypertrace.core.documentstore.query.Filter; import org.hypertrace.core.documentstore.query.Pagination; import org.hypertrace.core.documentstore.query.Query; @@ -3312,6 +3316,52 @@ public void testNotExistsOperatorWithFindUsingBooleanRhs(String dataStoreName) t testCountApi(dataStoreName, query, "query/not_exists_filter_response.json"); } + @ParameterizedTest + @ArgumentsSource(MongoProvider.class) + public void testMongoFunctionExpressionGroupBy(String dataStoreName) throws Exception { + Collection collection = getCollection(dataStoreName); + + FunctionExpression functionExpression = + FunctionExpression.builder() + .operator(FLOOR) + .operand( + FunctionExpression.builder() + .operator(DIVIDE) + .operand( + FunctionExpression.builder() + .operator(SUBTRACT) + .operand(IdentifierExpression.of("price")) + .operand(ConstantExpression.of(5)) + .build()) + .operand(ConstantExpression.of(5)) + .build()) + .build(); + List selectionSpecs = + List.of( + SelectionSpec.of(functionExpression, "function"), + SelectionSpec.of( + AggregateExpression.of(COUNT, IdentifierExpression.of("function")), + "functionCount")); + Selection selection = Selection.builder().selectionSpecs(selectionSpecs).build(); + + Query query = + Query.builder() + .setSelection(selection) + .setAggregation( + Aggregation.builder().expression(IdentifierExpression.of("function")).build()) + .setSort( + Sort.builder() + .sortingSpec(SortingSpec.of(IdentifierExpression.of("function"), ASC)) + .build()) + .build(); + + Iterator resultDocs = collection.aggregate(query); + assertDocsAndSizeEqualWithoutOrder( + dataStoreName, resultDocs, "query/function_expression_group_by_response.json", 3); + + testCountApi(dataStoreName, query, "query/function_expression_group_by_response.json"); + } + private static Collection getCollection(final String dataStoreName) { return getCollection(dataStoreName, COLLECTION_NAME); } diff --git a/document-store/src/integrationTest/resources/query/function_expression_group_by_response.json b/document-store/src/integrationTest/resources/query/function_expression_group_by_response.json new file mode 100644 index 00000000..5adbf5a6 --- /dev/null +++ b/document-store/src/integrationTest/resources/query/function_expression_group_by_response.json @@ -0,0 +1,14 @@ +[ + { + "function": 0.0, + "functionCount": 4 + }, + { + "function": 1.0, + "functionCount": 2 + }, + { + "function": 3.0, + "functionCount": 2 + } +] diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/MongoQueryExecutor.java b/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/MongoQueryExecutor.java index c78fdf08..578039fb 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/MongoQueryExecutor.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/MongoQueryExecutor.java @@ -12,7 +12,6 @@ import static org.hypertrace.core.documentstore.mongo.query.MongoPaginationHelper.getSkipClause; import static org.hypertrace.core.documentstore.mongo.query.parser.MongoFilterTypeExpressionParser.getFilter; import static org.hypertrace.core.documentstore.mongo.query.parser.MongoFilterTypeExpressionParser.getFilterClause; -import static org.hypertrace.core.documentstore.mongo.query.parser.MongoGroupTypeExpressionParser.getGroupClause; import static org.hypertrace.core.documentstore.mongo.query.parser.MongoNonProjectedSortTypeExpressionParser.getNonProjectedSortClause; import static org.hypertrace.core.documentstore.mongo.query.parser.MongoSelectTypeExpressionParser.getProjectClause; import static org.hypertrace.core.documentstore.mongo.query.parser.MongoSelectTypeExpressionParser.getSelections; @@ -41,6 +40,7 @@ import org.hypertrace.core.documentstore.model.config.ConnectionConfig; import org.hypertrace.core.documentstore.mongo.query.parser.AliasParser; import org.hypertrace.core.documentstore.mongo.query.parser.MongoFromTypeExpressionParser; +import org.hypertrace.core.documentstore.mongo.query.parser.MongoGroupTypeExpressionParser; import org.hypertrace.core.documentstore.mongo.query.transformer.MongoQueryTransformer; import org.hypertrace.core.documentstore.parser.AggregateExpressionChecker; import org.hypertrace.core.documentstore.parser.FunctionExpressionChecker; @@ -58,7 +58,7 @@ public class MongoQueryExecutor { List.of( query -> singleton(getFilterClause(query, Query::getFilter)), MongoFromTypeExpressionParser::getFromClauses, - query -> singleton(getGroupClause(query)), + MongoGroupTypeExpressionParser::getGroupClauses, query -> singleton(getProjectClause(query)), query -> singleton(getFilterClause(query, Query::getAggregationFilter)), query -> singleton(getSortClause(query)), diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoFunctionExpressionParser.java b/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoFunctionExpressionParser.java index 6a807370..c4953886 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoFunctionExpressionParser.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoFunctionExpressionParser.java @@ -62,7 +62,8 @@ Map parse(final FunctionExpression expression) { SelectTypeExpressionVisitor parser = new MongoIdentifierPrefixingParser( - new MongoIdentifierExpressionParser(new MongoConstantExpressionParser())); + new MongoIdentifierExpressionParser( + new MongoFunctionExpressionParser(new MongoConstantExpressionParser()))); if (numArgs == 1) { Object value = expression.getOperands().get(0).accept(parser); diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoGroupTypeExpressionParser.java b/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoGroupTypeExpressionParser.java index 801bf470..520520ce 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoGroupTypeExpressionParser.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoGroupTypeExpressionParser.java @@ -5,15 +5,20 @@ import static org.hypertrace.core.documentstore.mongo.MongoUtils.encodeKey; import com.mongodb.BasicDBObject; +import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.MapUtils; import org.hypertrace.core.documentstore.expression.impl.FunctionExpression; import org.hypertrace.core.documentstore.expression.impl.IdentifierExpression; import org.hypertrace.core.documentstore.expression.type.GroupTypeExpression; +import org.hypertrace.core.documentstore.parser.FunctionExpressionChecker; +import org.hypertrace.core.documentstore.parser.GroupByAliasGetter; import org.hypertrace.core.documentstore.parser.GroupTypeExpressionVisitor; import org.hypertrace.core.documentstore.parser.SelectTypeExpressionVisitor; import org.hypertrace.core.documentstore.query.Query; @@ -22,6 +27,10 @@ public final class MongoGroupTypeExpressionParser implements GroupTypeExpressionVisitor { private static final String GROUP_CLAUSE = "$group"; + private static final String ADD_FIELDS_CLAUSE = "$addFields"; + private static final FunctionExpressionChecker FUNCTION_EXPRESSION_CHECKER = + new FunctionExpressionChecker(); + private static final GroupByAliasGetter GROUP_BY_ALIAS_GETTER = new GroupByAliasGetter(); @SuppressWarnings("unchecked") @Override @@ -41,10 +50,32 @@ public Map visit(final IdentifierExpression expression) { return Map.of(key, PREFIX + identifier); } - public static BasicDBObject getGroupClause(final Query query) { + public static List getGroupClauses(final Query query) { final List selectionSpecs = query.getSelections(); final List expressions = query.getAggregations(); + final List basicDBObjects = new ArrayList<>(); + + final List functionExpressionSelectionWithGroupBys = + getFunctionExpressionSelectionWithGroupBys(selectionSpecs, expressions); + + if (!functionExpressionSelectionWithGroupBys.isEmpty()) { + MongoSelectTypeExpressionParser parser = + new MongoIdentifierPrefixingParser( + new MongoIdentifierExpressionParser(new MongoFunctionExpressionParser())); + Map addFields = + functionExpressionSelectionWithGroupBys.stream() + .map(spec -> MongoGroupTypeExpressionParser.parse(parser, spec)) + .reduce( + new LinkedHashMap<>(), + (first, second) -> { + first.putAll(second); + return first; + }); + + basicDBObjects.add(new BasicDBObject(ADD_FIELDS_CLAUSE, addFields)); + } + MongoGroupTypeExpressionParser parser = new MongoGroupTypeExpressionParser(); Map groupExp; @@ -82,11 +113,13 @@ public static BasicDBObject getGroupClause(final Query query) { }); if (MapUtils.isEmpty(definition) && CollectionUtils.isEmpty(expressions)) { - return new BasicDBObject(); + return basicDBObjects; } definition.putAll(groupExp); - return new BasicDBObject(GROUP_CLAUSE, definition); + + basicDBObjects.add(new BasicDBObject(GROUP_CLAUSE, definition)); + return basicDBObjects; } private static Map parse( @@ -99,4 +132,31 @@ private Map parse(final GroupTypeExpression expression) { MongoGroupTypeExpressionParser parser = new MongoGroupTypeExpressionParser(); return expression.accept(parser); } + + private static List getFunctionExpressionSelectionWithGroupBys( + final List selectionSpecs, final List expressions) { + List groupByAliases = getGroupByAliases(expressions); + + return selectionSpecs.stream() + .filter( + selectionSpec -> + isFunctionExpressionSelectionWithGroupBy(selectionSpec, groupByAliases)) + .collect(Collectors.toUnmodifiableList()); + } + + public static boolean isFunctionExpressionSelectionWithGroupBy( + final SelectionSpec selectionSpec, final List groupByAliases) { + return selectionSpec.getAlias() != null + && groupByAliases.contains(selectionSpec.getAlias()) + && (Boolean) selectionSpec.getExpression().accept(FUNCTION_EXPRESSION_CHECKER); + } + + @SuppressWarnings("unchecked") + public static List getGroupByAliases(final List expressions) { + return expressions.stream() + .map(expression -> (Optional) expression.accept(GROUP_BY_ALIAS_GETTER)) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toUnmodifiableList()); + } } diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoSelectTypeExpressionParser.java b/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoSelectTypeExpressionParser.java index eedc1ed3..3a561fc0 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoSelectTypeExpressionParser.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/mongo/query/parser/MongoSelectTypeExpressionParser.java @@ -1,7 +1,11 @@ package org.hypertrace.core.documentstore.mongo.query.parser; import static java.util.stream.Collectors.toMap; +import static org.hypertrace.core.documentstore.mongo.MongoCollection.ID_KEY; +import static org.hypertrace.core.documentstore.mongo.query.parser.MongoGroupTypeExpressionParser.getGroupByAliases; +import static org.hypertrace.core.documentstore.mongo.query.parser.MongoGroupTypeExpressionParser.isFunctionExpressionSelectionWithGroupBy; +import com.google.common.base.Joiner; import com.mongodb.BasicDBObject; import java.util.List; import java.util.Map; @@ -20,6 +24,8 @@ public abstract class MongoSelectTypeExpressionParser implements SelectTypeExpre protected final MongoSelectTypeExpressionParser baseParser; + private static final Joiner DOT_JOINER = Joiner.on("."); + protected MongoSelectTypeExpressionParser() { this(MongoUnsupportedSelectTypeExpressionParser.INSTANCE); } @@ -59,8 +65,17 @@ public static BasicDBObject getSelections(final Query query) { new MongoIdentifierPrefixingParser( new MongoIdentifierExpressionParser(new MongoFunctionExpressionParser())); + List groupByAliases = getGroupByAliases(query.getAggregations()); + Map projectionMap = selectionSpecs.stream() + .map( + spec -> + isFunctionExpressionSelectionWithGroupBy(spec, groupByAliases) + ? SelectionSpec.of( + IdentifierExpression.of(DOT_JOINER.join(ID_KEY, spec.getAlias())), + spec.getAlias()) + : spec) .map(spec -> MongoSelectTypeExpressionParser.parse(parser, spec)) .flatMap(map -> map.entrySet().stream()) .collect( diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/parser/GroupByAliasGetter.java b/document-store/src/main/java/org/hypertrace/core/documentstore/parser/GroupByAliasGetter.java new file mode 100644 index 00000000..defb7e86 --- /dev/null +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/parser/GroupByAliasGetter.java @@ -0,0 +1,19 @@ +package org.hypertrace.core.documentstore.parser; + +import java.util.Optional; +import org.hypertrace.core.documentstore.expression.impl.FunctionExpression; +import org.hypertrace.core.documentstore.expression.impl.IdentifierExpression; + +@SuppressWarnings("unchecked") +public class GroupByAliasGetter implements GroupTypeExpressionVisitor { + + @Override + public Optional visit(FunctionExpression expression) { + return Optional.empty(); + } + + @Override + public Optional visit(IdentifierExpression expression) { + return Optional.of(expression.getName()); + } +}