Skip to content

Commit

Permalink
fix : group by on functional expressions in mongo (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
RishabhB99 authored Jul 2, 2024
1 parent 43b12e8 commit c7ea09e
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.MAX;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.MIN;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.SUM;
import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.DIVIDE;
import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.FLOOR;
import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.LENGTH;
import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.MULTIPLY;
import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.SUBTRACT;
import static org.hypertrace.core.documentstore.expression.operators.LogicalOperator.AND;
import static org.hypertrace.core.documentstore.expression.operators.LogicalOperator.OR;
import static org.hypertrace.core.documentstore.expression.operators.RelationalOperator.CONTAINS;
Expand Down Expand Up @@ -87,6 +90,7 @@
import org.hypertrace.core.documentstore.model.options.UpdateOptions;
import org.hypertrace.core.documentstore.model.subdoc.SubDocumentUpdate;
import org.hypertrace.core.documentstore.model.subdoc.SubDocumentValue;
import org.hypertrace.core.documentstore.query.Aggregation;
import org.hypertrace.core.documentstore.query.Filter;
import org.hypertrace.core.documentstore.query.Pagination;
import org.hypertrace.core.documentstore.query.Query;
Expand Down Expand Up @@ -3312,6 +3316,52 @@ public void testNotExistsOperatorWithFindUsingBooleanRhs(String dataStoreName) t
testCountApi(dataStoreName, query, "query/not_exists_filter_response.json");
}

@ParameterizedTest
@ArgumentsSource(MongoProvider.class)
public void testMongoFunctionExpressionGroupBy(String dataStoreName) throws Exception {
Collection collection = getCollection(dataStoreName);

FunctionExpression functionExpression =
FunctionExpression.builder()
.operator(FLOOR)
.operand(
FunctionExpression.builder()
.operator(DIVIDE)
.operand(
FunctionExpression.builder()
.operator(SUBTRACT)
.operand(IdentifierExpression.of("price"))
.operand(ConstantExpression.of(5))
.build())
.operand(ConstantExpression.of(5))
.build())
.build();
List<SelectionSpec> selectionSpecs =
List.of(
SelectionSpec.of(functionExpression, "function"),
SelectionSpec.of(
AggregateExpression.of(COUNT, IdentifierExpression.of("function")),
"functionCount"));
Selection selection = Selection.builder().selectionSpecs(selectionSpecs).build();

Query query =
Query.builder()
.setSelection(selection)
.setAggregation(
Aggregation.builder().expression(IdentifierExpression.of("function")).build())
.setSort(
Sort.builder()
.sortingSpec(SortingSpec.of(IdentifierExpression.of("function"), ASC))
.build())
.build();

Iterator<Document> resultDocs = collection.aggregate(query);
assertDocsAndSizeEqualWithoutOrder(
dataStoreName, resultDocs, "query/function_expression_group_by_response.json", 3);

testCountApi(dataStoreName, query, "query/function_expression_group_by_response.json");
}

private static Collection getCollection(final String dataStoreName) {
return getCollection(dataStoreName, COLLECTION_NAME);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[
{
"function": 0.0,
"functionCount": 4
},
{
"function": 1.0,
"functionCount": 2
},
{
"function": 3.0,
"functionCount": 2
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import static org.hypertrace.core.documentstore.mongo.query.MongoPaginationHelper.getSkipClause;
import static org.hypertrace.core.documentstore.mongo.query.parser.MongoFilterTypeExpressionParser.getFilter;
import static org.hypertrace.core.documentstore.mongo.query.parser.MongoFilterTypeExpressionParser.getFilterClause;
import static org.hypertrace.core.documentstore.mongo.query.parser.MongoGroupTypeExpressionParser.getGroupClause;
import static org.hypertrace.core.documentstore.mongo.query.parser.MongoNonProjectedSortTypeExpressionParser.getNonProjectedSortClause;
import static org.hypertrace.core.documentstore.mongo.query.parser.MongoSelectTypeExpressionParser.getProjectClause;
import static org.hypertrace.core.documentstore.mongo.query.parser.MongoSelectTypeExpressionParser.getSelections;
Expand Down Expand Up @@ -41,6 +40,7 @@
import org.hypertrace.core.documentstore.model.config.ConnectionConfig;
import org.hypertrace.core.documentstore.mongo.query.parser.AliasParser;
import org.hypertrace.core.documentstore.mongo.query.parser.MongoFromTypeExpressionParser;
import org.hypertrace.core.documentstore.mongo.query.parser.MongoGroupTypeExpressionParser;
import org.hypertrace.core.documentstore.mongo.query.transformer.MongoQueryTransformer;
import org.hypertrace.core.documentstore.parser.AggregateExpressionChecker;
import org.hypertrace.core.documentstore.parser.FunctionExpressionChecker;
Expand All @@ -58,7 +58,7 @@ public class MongoQueryExecutor {
List.of(
query -> singleton(getFilterClause(query, Query::getFilter)),
MongoFromTypeExpressionParser::getFromClauses,
query -> singleton(getGroupClause(query)),
MongoGroupTypeExpressionParser::getGroupClauses,
query -> singleton(getProjectClause(query)),
query -> singleton(getFilterClause(query, Query::getAggregationFilter)),
query -> singleton(getSortClause(query)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ Map<String, Object> parse(final FunctionExpression expression) {

SelectTypeExpressionVisitor parser =
new MongoIdentifierPrefixingParser(
new MongoIdentifierExpressionParser(new MongoConstantExpressionParser()));
new MongoIdentifierExpressionParser(
new MongoFunctionExpressionParser(new MongoConstantExpressionParser())));

if (numArgs == 1) {
Object value = expression.getOperands().get(0).accept(parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
import static org.hypertrace.core.documentstore.mongo.MongoUtils.encodeKey;

import com.mongodb.BasicDBObject;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.hypertrace.core.documentstore.expression.impl.FunctionExpression;
import org.hypertrace.core.documentstore.expression.impl.IdentifierExpression;
import org.hypertrace.core.documentstore.expression.type.GroupTypeExpression;
import org.hypertrace.core.documentstore.parser.FunctionExpressionChecker;
import org.hypertrace.core.documentstore.parser.GroupByAliasGetter;
import org.hypertrace.core.documentstore.parser.GroupTypeExpressionVisitor;
import org.hypertrace.core.documentstore.parser.SelectTypeExpressionVisitor;
import org.hypertrace.core.documentstore.query.Query;
Expand All @@ -22,6 +27,10 @@
public final class MongoGroupTypeExpressionParser implements GroupTypeExpressionVisitor {

private static final String GROUP_CLAUSE = "$group";
private static final String ADD_FIELDS_CLAUSE = "$addFields";
private static final FunctionExpressionChecker FUNCTION_EXPRESSION_CHECKER =
new FunctionExpressionChecker();
private static final GroupByAliasGetter GROUP_BY_ALIAS_GETTER = new GroupByAliasGetter();

@SuppressWarnings("unchecked")
@Override
Expand All @@ -41,10 +50,32 @@ public Map<String, Object> visit(final IdentifierExpression expression) {
return Map.of(key, PREFIX + identifier);
}

public static BasicDBObject getGroupClause(final Query query) {
public static List<BasicDBObject> getGroupClauses(final Query query) {
final List<SelectionSpec> selectionSpecs = query.getSelections();
final List<GroupTypeExpression> expressions = query.getAggregations();

final List<BasicDBObject> basicDBObjects = new ArrayList<>();

final List<SelectionSpec> functionExpressionSelectionWithGroupBys =
getFunctionExpressionSelectionWithGroupBys(selectionSpecs, expressions);

if (!functionExpressionSelectionWithGroupBys.isEmpty()) {
MongoSelectTypeExpressionParser parser =
new MongoIdentifierPrefixingParser(
new MongoIdentifierExpressionParser(new MongoFunctionExpressionParser()));
Map<String, Object> addFields =
functionExpressionSelectionWithGroupBys.stream()
.map(spec -> MongoGroupTypeExpressionParser.parse(parser, spec))
.reduce(
new LinkedHashMap<>(),
(first, second) -> {
first.putAll(second);
return first;
});

basicDBObjects.add(new BasicDBObject(ADD_FIELDS_CLAUSE, addFields));
}

MongoGroupTypeExpressionParser parser = new MongoGroupTypeExpressionParser();
Map<String, Object> groupExp;

Expand Down Expand Up @@ -82,11 +113,13 @@ public static BasicDBObject getGroupClause(final Query query) {
});

if (MapUtils.isEmpty(definition) && CollectionUtils.isEmpty(expressions)) {
return new BasicDBObject();
return basicDBObjects;
}

definition.putAll(groupExp);
return new BasicDBObject(GROUP_CLAUSE, definition);

basicDBObjects.add(new BasicDBObject(GROUP_CLAUSE, definition));
return basicDBObjects;
}

private static Map<String, Object> parse(
Expand All @@ -99,4 +132,31 @@ private Map<String, Object> parse(final GroupTypeExpression expression) {
MongoGroupTypeExpressionParser parser = new MongoGroupTypeExpressionParser();
return expression.accept(parser);
}

private static List<SelectionSpec> getFunctionExpressionSelectionWithGroupBys(
final List<SelectionSpec> selectionSpecs, final List<GroupTypeExpression> expressions) {
List<String> groupByAliases = getGroupByAliases(expressions);

return selectionSpecs.stream()
.filter(
selectionSpec ->
isFunctionExpressionSelectionWithGroupBy(selectionSpec, groupByAliases))
.collect(Collectors.toUnmodifiableList());
}

public static boolean isFunctionExpressionSelectionWithGroupBy(
final SelectionSpec selectionSpec, final List<String> groupByAliases) {
return selectionSpec.getAlias() != null
&& groupByAliases.contains(selectionSpec.getAlias())
&& (Boolean) selectionSpec.getExpression().accept(FUNCTION_EXPRESSION_CHECKER);
}

@SuppressWarnings("unchecked")
public static List<String> getGroupByAliases(final List<GroupTypeExpression> expressions) {
return expressions.stream()
.map(expression -> (Optional<String>) expression.accept(GROUP_BY_ALIAS_GETTER))
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toUnmodifiableList());
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package org.hypertrace.core.documentstore.mongo.query.parser;

import static java.util.stream.Collectors.toMap;
import static org.hypertrace.core.documentstore.mongo.MongoCollection.ID_KEY;
import static org.hypertrace.core.documentstore.mongo.query.parser.MongoGroupTypeExpressionParser.getGroupByAliases;
import static org.hypertrace.core.documentstore.mongo.query.parser.MongoGroupTypeExpressionParser.isFunctionExpressionSelectionWithGroupBy;

import com.google.common.base.Joiner;
import com.mongodb.BasicDBObject;
import java.util.List;
import java.util.Map;
Expand All @@ -20,6 +24,8 @@ public abstract class MongoSelectTypeExpressionParser implements SelectTypeExpre

protected final MongoSelectTypeExpressionParser baseParser;

private static final Joiner DOT_JOINER = Joiner.on(".");

protected MongoSelectTypeExpressionParser() {
this(MongoUnsupportedSelectTypeExpressionParser.INSTANCE);
}
Expand Down Expand Up @@ -59,8 +65,17 @@ public static BasicDBObject getSelections(final Query query) {
new MongoIdentifierPrefixingParser(
new MongoIdentifierExpressionParser(new MongoFunctionExpressionParser()));

List<String> groupByAliases = getGroupByAliases(query.getAggregations());

Map<String, Object> projectionMap =
selectionSpecs.stream()
.map(
spec ->
isFunctionExpressionSelectionWithGroupBy(spec, groupByAliases)
? SelectionSpec.of(
IdentifierExpression.of(DOT_JOINER.join(ID_KEY, spec.getAlias())),
spec.getAlias())
: spec)
.map(spec -> MongoSelectTypeExpressionParser.parse(parser, spec))
.flatMap(map -> map.entrySet().stream())
.collect(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.hypertrace.core.documentstore.parser;

import java.util.Optional;
import org.hypertrace.core.documentstore.expression.impl.FunctionExpression;
import org.hypertrace.core.documentstore.expression.impl.IdentifierExpression;

@SuppressWarnings("unchecked")
public class GroupByAliasGetter implements GroupTypeExpressionVisitor {

@Override
public Optional<String> visit(FunctionExpression expression) {
return Optional.empty();
}

@Override
public Optional<String> visit(IdentifierExpression expression) {
return Optional.of(expression.getName());
}
}

0 comments on commit c7ea09e

Please sign in to comment.