diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java index 2c66ec5d512ec..4e3435d7642aa 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java @@ -25,7 +25,6 @@ import org.apache.shardingsphere.encrypt.rule.table.EncryptTable; import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation; import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions; -import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor; import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; @@ -61,8 +60,6 @@ public final class EncryptConditionEngine { private final EncryptRule rule; - private final ShardingSphereDatabase database; - static { LOGICAL_OPERATOR.add("AND"); LOGICAL_OPERATOR.add("&&"); @@ -155,25 +152,36 @@ private Optional createBinaryEncryptCondition(final BinaryOper return Optional.empty(); } ShardingSpherePreconditions.checkContains(SUPPORTED_COMPARE_OPERATOR, operator, () -> new UnsupportedEncryptSQLException(operator)); - return createCompareEncryptCondition(tableName, expression, expression.getRight()); + return createCompareEncryptCondition(tableName, expression); } - private Optional createCompareEncryptCondition(final String tableName, final BinaryOperationExpression expression, final ExpressionSegment compareRightValue) { - if (!(expression.getLeft() instanceof ColumnSegment) || compareRightValue instanceof SubqueryExpressionSegment) { + private Optional createCompareEncryptCondition(final String tableName, final BinaryOperationExpression expression) { + if (isLeftRightAllNotColumnSegment(expression) || isLeftRightContainsSubquerySegment(expression)) { return Optional.empty(); } - if (compareRightValue instanceof SimpleExpressionSegment) { - return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, compareRightValue)); + ColumnSegment columnSegment = expression.getLeft() instanceof ColumnSegment ? (ColumnSegment) expression.getLeft() : (ColumnSegment) expression.getRight(); + ExpressionSegment compareValueSegment = expression.getLeft() instanceof ColumnSegment ? expression.getRight() : expression.getLeft(); + if (compareValueSegment instanceof SimpleExpressionSegment) { + return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, columnSegment, compareValueSegment)); } - if (compareRightValue instanceof ListExpression) { - return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, ((ListExpression) compareRightValue).getItems().get(0))); + if (compareValueSegment instanceof ListExpression) { + // TODO check this logic when items contain multiple values @duanzhengqiang + return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, columnSegment, ((ListExpression) compareValueSegment).getItems().get(0))); } return Optional.empty(); } - private EncryptBinaryCondition createEncryptBinaryOperationCondition(final String tableName, final BinaryOperationExpression expression, final ExpressionSegment compareRightValue) { - ColumnSegment columnSegment = (ColumnSegment) expression.getLeft(); - return new EncryptBinaryCondition(columnSegment, tableName, expression.getOperator(), compareRightValue.getStartIndex(), expression.getStopIndex(), compareRightValue); + private boolean isLeftRightAllNotColumnSegment(final BinaryOperationExpression expression) { + return !(expression.getLeft() instanceof ColumnSegment) && !(expression.getRight() instanceof ColumnSegment); + } + + private boolean isLeftRightContainsSubquerySegment(final BinaryOperationExpression expression) { + return expression.getLeft() instanceof SubqueryExpressionSegment || expression.getRight() instanceof SubqueryExpressionSegment; + } + + private EncryptBinaryCondition createEncryptBinaryOperationCondition(final String tableName, final BinaryOperationExpression expression, final ColumnSegment columnSegment, + final ExpressionSegment compareValueSegment) { + return new EncryptBinaryCondition(columnSegment, tableName, expression.getOperator(), compareValueSegment.getStartIndex(), compareValueSegment.getStopIndex(), compareValueSegment); } private static Optional createInEncryptCondition(final String tableName, final InExpression inExpression, final ExpressionSegment inRightValue) { diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java index de39b0c0a0dd2..bfcaa9c4890d3 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java @@ -26,7 +26,6 @@ import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation; import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; -import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext; import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.binder.context.type.TableAvailable; import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; @@ -55,7 +54,7 @@ public void decorate(final EncryptRule rule, final ConfigurationProperties props if (!containsEncryptTable(rule, sqlStatementContext)) { return; } - Collection encryptConditions = createEncryptConditions(rule, sqlRewriteContext); + Collection encryptConditions = createEncryptConditions(rule, sqlStatementContext); if (!sqlRewriteContext.getParameters().isEmpty()) { Collection parameterRewriters = new ParameterRewritersBuilder(sqlStatementContext) .build(new EncryptParameterRewritersRegistry(rule, sqlRewriteContext.getDatabase().getName(), encryptConditions)); @@ -77,23 +76,13 @@ private boolean containsEncryptTable(final EncryptRule rule, final SQLStatementC return false; } - private Collection createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext) { - SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext(); - if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext() - && !(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext()).getWhereSegments().isEmpty()) { - return createEncryptConditions(rule, sqlRewriteContext, ((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext()); - } - return createEncryptConditions(rule, sqlRewriteContext, sqlStatementContext); - } - - private Collection createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext, - final SQLStatementContext sqlStatementContext) { + private Collection createEncryptConditions(final EncryptRule rule, final SQLStatementContext sqlStatementContext) { if (!(sqlStatementContext instanceof WhereAvailable)) { return Collections.emptyList(); } Collection allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext); Collection whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts); - return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments); + return new EncryptConditionEngine(rule).createEncryptConditions(whereSegments); } private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection parameterRewriters) { diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/rewriter/EncryptPredicateParameterRewriter.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/rewriter/EncryptPredicateParameterRewriter.java index 54d8ad7127a45..549b00ef0bc35 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/rewriter/EncryptPredicateParameterRewriter.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/rewriter/EncryptPredicateParameterRewriter.java @@ -25,8 +25,6 @@ import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn; import org.apache.shardingsphere.encrypt.rule.table.EncryptTable; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; -import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext; -import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.binder.context.type.TableAvailable; import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry; @@ -53,25 +51,7 @@ public final class EncryptPredicateParameterRewriter implements ParameterRewrite @Override public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) { - if (sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty()) { - return true; - } - if (sqlStatementContext instanceof SelectStatementContext) { - return isSubqueryNeedRewrite((SelectStatementContext) sqlStatementContext); - } - if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) { - return isSubqueryNeedRewrite(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext()); - } - return false; - } - - private boolean isSubqueryNeedRewrite(final SelectStatementContext selectStatementContext) { - for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) { - if (isNeedRewrite(each)) { - return true; - } - } - return false; + return sqlStatementContext instanceof WhereAvailable; } @Override @@ -85,7 +65,7 @@ public void rewrite(final ParameterBuilder paramBuilder, final SQLStatementConte private List getEncryptedValues(final String schemaName, final EncryptCondition encryptCondition, final List originalValues) { String tableName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalTable().getValue(); - String columnName = encryptCondition.getColumnSegment().getIdentifier().getValue(); + String columnName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalColumn().getValue(); EncryptTable encryptTable = rule.getEncryptTable(tableName); EncryptColumn encryptColumn = encryptTable.getEncryptColumn(columnName); if (encryptCondition instanceof EncryptBinaryCondition && "LIKE".equals(((EncryptBinaryCondition) encryptCondition).getOperator()) && encryptColumn.getLikeQuery().isPresent()) { diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/EncryptTokenGenerateBuilder.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/EncryptTokenGenerateBuilder.java index eaaaee9c05e51..d8aea0caf6e9b 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/EncryptTokenGenerateBuilder.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/EncryptTokenGenerateBuilder.java @@ -30,9 +30,9 @@ import org.apache.shardingsphere.encrypt.rewrite.token.generator.insert.EncryptInsertOnUpdateTokenGenerator; import org.apache.shardingsphere.encrypt.rewrite.token.generator.insert.EncryptInsertValuesTokenGenerator; import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateColumnTokenGenerator; -import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateRightValueTokenGenerator; +import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateValueTokenGenerator; import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateColumnTokenGenerator; -import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateRightValueTokenGenerator; +import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateValueTokenGenerator; import org.apache.shardingsphere.encrypt.rewrite.token.generator.projection.EncryptInsertSelectProjectionTokenGenerator; import org.apache.shardingsphere.encrypt.rewrite.token.generator.projection.EncryptSelectProjectionTokenGenerator; import org.apache.shardingsphere.encrypt.rewrite.token.generator.select.EncryptGroupByItemTokenGenerator; @@ -70,8 +70,8 @@ public Collection getSQLTokenGenerators() { addSQLTokenGenerator(result, new EncryptUpdateAssignmentTokenGenerator(rule)); addSQLTokenGenerator(result, new EncryptPredicateColumnTokenGenerator(rule)); addSQLTokenGenerator(result, new EncryptInsertPredicateColumnTokenGenerator(rule)); - addSQLTokenGenerator(result, new EncryptPredicateRightValueTokenGenerator(rule)); - addSQLTokenGenerator(result, new EncryptInsertPredicateRightValueTokenGenerator(rule)); + addSQLTokenGenerator(result, new EncryptPredicateValueTokenGenerator(rule)); + addSQLTokenGenerator(result, new EncryptInsertPredicateValueTokenGenerator(rule)); addSQLTokenGenerator(result, new EncryptInsertValuesTokenGenerator(rule)); addSQLTokenGenerator(result, new EncryptInsertDefaultColumnsTokenGenerator(rule)); addSQLTokenGenerator(result, new EncryptInsertCipherNameTokenGenerator(rule)); diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptInsertPredicateRightValueTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptInsertPredicateValueTokenGenerator.java similarity index 89% rename from features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptInsertPredicateRightValueTokenGenerator.java rename to features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptInsertPredicateValueTokenGenerator.java index b067040abaa0b..28d809f7572b1 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptInsertPredicateRightValueTokenGenerator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptInsertPredicateValueTokenGenerator.java @@ -35,12 +35,12 @@ import java.util.List; /** - * Insert predicate right value token generator for encrypt. + * Insert predicate value token generator for encrypt. */ @HighFrequencyInvocation @RequiredArgsConstructor @Setter -public final class EncryptInsertPredicateRightValueTokenGenerator implements CollectionSQLTokenGenerator, ParametersAware, EncryptConditionsAware, DatabaseAware { +public final class EncryptInsertPredicateValueTokenGenerator implements CollectionSQLTokenGenerator, ParametersAware, EncryptConditionsAware, DatabaseAware { private final EncryptRule rule; @@ -58,7 +58,7 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) @Override public Collection generateSQLTokens(final SQLStatementContext sqlStatementContext) { - EncryptPredicateRightValueTokenGenerator generator = new EncryptPredicateRightValueTokenGenerator(rule); + EncryptPredicateValueTokenGenerator generator = new EncryptPredicateValueTokenGenerator(rule); generator.setParameters(parameters); generator.setEncryptConditions(encryptConditions); generator.setDatabase(database); diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java index af41bb328b9e6..20bfe4ee544c2 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java @@ -61,7 +61,7 @@ public final class EncryptPredicateColumnTokenGenerator implements CollectionSQL @Override public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) { - return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty(); + return sqlStatementContext instanceof WhereAvailable; } @Override @@ -114,16 +114,15 @@ private boolean includesLike(final Collection whereSegments, final private boolean isLikeColumnSegment(final AndPredicate andPredicate, final ColumnSegment targetColumnSegment) { for (ExpressionSegment each : andPredicate.getPredicates()) { - if (each instanceof BinaryOperationExpression - && "LIKE".equalsIgnoreCase(((BinaryOperationExpression) each).getOperator()) && isSameColumnSegment(((BinaryOperationExpression) each).getLeft(), targetColumnSegment)) { + if (each instanceof BinaryOperationExpression && "LIKE".equalsIgnoreCase(((BinaryOperationExpression) each).getOperator()) && isContainsColumnSegment(each, targetColumnSegment)) { return true; } } return false; } - private boolean isSameColumnSegment(final ExpressionSegment columnSegment, final ColumnSegment targetColumnSegment) { - return columnSegment instanceof ColumnSegment && columnSegment.getStartIndex() == targetColumnSegment.getStartIndex() && columnSegment.getStopIndex() == targetColumnSegment.getStopIndex(); + private boolean isContainsColumnSegment(final ExpressionSegment expressionSegment, final ColumnSegment targetColumnSegment) { + return expressionSegment.getStartIndex() <= targetColumnSegment.getStartIndex() && expressionSegment.getStopIndex() >= targetColumnSegment.getStopIndex(); } private Collection createColumnProjections(final String columnName, final QuoteCharacter quoteCharacter, final DatabaseType databaseType) { diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateRightValueTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateValueTokenGenerator.java similarity index 91% rename from features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateRightValueTokenGenerator.java rename to features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateValueTokenGenerator.java index a10c04592493c..28ae729382649 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateRightValueTokenGenerator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateValueTokenGenerator.java @@ -52,12 +52,12 @@ import java.util.Optional; /** - * Predicate right value token generator for encrypt. + * Predicate value token generator for encrypt. */ @HighFrequencyInvocation @RequiredArgsConstructor @Setter -public final class EncryptPredicateRightValueTokenGenerator implements CollectionSQLTokenGenerator, ParametersAware, EncryptConditionsAware, DatabaseAware { +public final class EncryptPredicateValueTokenGenerator implements CollectionSQLTokenGenerator, ParametersAware, EncryptConditionsAware, DatabaseAware { private final EncryptRule rule; @@ -69,7 +69,7 @@ public final class EncryptPredicateRightValueTokenGenerator implements Collectio @Override public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) { - return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty(); + return sqlStatementContext instanceof WhereAvailable; } @Override @@ -102,15 +102,15 @@ private SQLToken generateSQLToken(final String schemaName, final EncryptTable en private List getEncryptedValues(final String schemaName, final EncryptTable encryptTable, final EncryptCondition encryptCondition, final List originalValues) { String columnName = encryptCondition.getColumnSegment().getIdentifier().getValue(); EncryptColumn encryptColumn = encryptTable.getEncryptColumn(columnName); + String tableName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalTable().getValue(); if (encryptCondition instanceof EncryptBinaryCondition && "LIKE".equalsIgnoreCase(((EncryptBinaryCondition) encryptCondition).getOperator())) { LikeQueryColumnItem likeQueryColumnItem = encryptColumn.getLikeQuery() .orElseThrow(() -> new MissingMatchedEncryptQueryAlgorithmException(encryptTable.getTable(), columnName, "LIKE")); return likeQueryColumnItem.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), columnName, originalValues); } return encryptColumn.getAssistedQuery() - .map(optional -> optional.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), columnName, originalValues)) - .orElseGet(() -> encryptColumn.getCipher().encrypt(database.getName(), schemaName, encryptCondition.getTableName(), columnName, - originalValues)); + .map(optional -> optional.encrypt(database.getName(), schemaName, tableName, encryptCondition.getColumnSegment().getIdentifier().getValue(), originalValues)) + .orElseGet(() -> encryptColumn.getCipher().encrypt(database.getName(), schemaName, tableName, encryptCondition.getColumnSegment().getIdentifier().getValue(), originalValues)); } private Map getPositionValues(final Collection valuePositions, final List encryptValues) { diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateEqualRightValueToken.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateEqualRightValueToken.java index ea9f4b7e167fd..1e3456c6c626c 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateEqualRightValueToken.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateEqualRightValueToken.java @@ -46,8 +46,12 @@ public EncryptPredicateEqualRightValueToken(final int startIndex, final int stop @Override public String toString() { if (paramMarkerIndexes.isEmpty()) { - return indexValues.get(0) instanceof String ? "'" + indexValues.get(0) + "'" : indexValues.get(0).toString(); + return indexValues.isEmpty() ? "" : getIndexValue(indexValues); } return "?"; } + + private String getIndexValue(final Map indexValues) { + return indexValues.get(0) instanceof String ? "'" + indexValues.get(0) + "'" : indexValues.get(0).toString(); + } } diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptMergedResultTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptMergedResultTest.java index eb827dd98deab..c01891e93d45a 100644 --- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptMergedResultTest.java +++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/merge/dql/EncryptMergedResultTest.java @@ -57,6 +57,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -81,7 +82,8 @@ void assertNext() throws SQLException { void assertGetValueWithoutColumnProjection() throws SQLException { when(selectStatementContext.findColumnProjection(1)).thenReturn(Optional.empty()); when(mergedResult.getValue(1, String.class)).thenReturn("foo_value"); - assertThat(new EncryptMergedResult(mock(), mock(), selectStatementContext, mergedResult).getValue(1, String.class), is("foo_value")); + ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS); + assertThat(new EncryptMergedResult(database, mock(), selectStatementContext, mergedResult).getValue(1, String.class), is("foo_value")); } @Test diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/EncryptTokenGenerateBuilderTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/EncryptTokenGenerateBuilderTest.java index f5087709ffc1a..3b55e6741e810 100644 --- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/EncryptTokenGenerateBuilderTest.java +++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/EncryptTokenGenerateBuilderTest.java @@ -17,6 +17,8 @@ package org.apache.shardingsphere.encrypt.rewrite.token; +import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateColumnTokenGenerator; +import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateValueTokenGenerator; import org.apache.shardingsphere.encrypt.rewrite.token.generator.projection.EncryptSelectProjectionTokenGenerator; import org.apache.shardingsphere.encrypt.rule.EncryptRule; import org.apache.shardingsphere.infra.binder.context.segment.select.orderby.OrderByItem; @@ -56,9 +58,13 @@ void assertGetSQLTokenGenerators() { when(selectStatementContext.getWhereSegments()).thenReturn(Collections.emptyList()); EncryptTokenGenerateBuilder encryptTokenGenerateBuilder = new EncryptTokenGenerateBuilder(rule, selectStatementContext, Collections.emptyList(), mock(ShardingSphereDatabase.class)); Collection sqlTokenGenerators = encryptTokenGenerateBuilder.getSQLTokenGenerators(); - assertThat(sqlTokenGenerators.size(), is(1)); + assertThat(sqlTokenGenerators.size(), is(3)); Iterator iterator = sqlTokenGenerators.iterator(); SQLTokenGenerator item1 = iterator.next(); assertThat(item1, instanceOf(EncryptSelectProjectionTokenGenerator.class)); + SQLTokenGenerator item2 = iterator.next(); + assertThat(item2, instanceOf(EncryptPredicateColumnTokenGenerator.class)); + SQLTokenGenerator item3 = iterator.next(); + assertThat(item3, instanceOf(EncryptPredicateValueTokenGenerator.class)); } } diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateRightValueTokenGeneratorTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateValueTokenGeneratorTest.java similarity index 84% rename from features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateRightValueTokenGeneratorTest.java rename to features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateValueTokenGeneratorTest.java index 002e641a26ffc..29a88eae73239 100644 --- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateRightValueTokenGeneratorTest.java +++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateValueTokenGeneratorTest.java @@ -22,7 +22,6 @@ import org.apache.shardingsphere.encrypt.rewrite.token.generator.fixture.EncryptGeneratorFixtureBuilder; import org.apache.shardingsphere.infra.binder.context.statement.dml.UpdateStatementContext; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; -import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema; import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -35,13 +34,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; -class EncryptPredicateRightValueTokenGeneratorTest { +class EncryptPredicateValueTokenGeneratorTest { - private EncryptPredicateRightValueTokenGenerator generator; + private EncryptPredicateValueTokenGenerator generator; @BeforeEach void setup() { - generator = new EncryptPredicateRightValueTokenGenerator(EncryptGeneratorFixtureBuilder.createEncryptRule()); + generator = new EncryptPredicateValueTokenGenerator(EncryptGeneratorFixtureBuilder.createEncryptRule()); } @Test @@ -61,8 +60,7 @@ void assertGenerateSQLTokenFromGenerateNewSQLToken() { } private Collection getEncryptConditions(final UpdateStatementContext updateStatementContext) { - ShardingSphereDatabase database = new ShardingSphereDatabase("foo_db", mock(), mock(), mock(), Collections.singleton(new ShardingSphereSchema("foo_db"))); - return new EncryptConditionEngine(EncryptGeneratorFixtureBuilder.createEncryptRule(), database) + return new EncryptConditionEngine(EncryptGeneratorFixtureBuilder.createEncryptRule()) .createEncryptConditions(updateStatementContext.getWhereSegments()); } } diff --git a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/ShardingProjectionsTokenGeneratorTest.java b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/ShardingProjectionsTokenGeneratorTest.java index 203b0828d15ed..2c01d86970996 100644 --- a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/ShardingProjectionsTokenGeneratorTest.java +++ b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/ShardingProjectionsTokenGeneratorTest.java @@ -33,6 +33,7 @@ import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType; import org.apache.shardingsphere.sql.parser.statement.core.enums.OrderDirection; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.item.ColumnOrderByItemSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OwnerSegment; import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectStatement; @@ -116,9 +117,9 @@ void assertGenerateSQLToken() { } private AggregationProjection createAggregationProjection() { - AggregationDistinctProjection derivedProjection = new AggregationDistinctProjection(0, 0, AggregationType.COUNT, "", + AggregationDistinctProjection derivedProjection = new AggregationDistinctProjection(0, 0, AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, ""), new IdentifierValue("foo_agg_alias"), "foo_agg_expr", databaseType); - AggregationProjection result = new AggregationDistinctProjection(0, 0, AggregationType.COUNT, "", null, "", databaseType); + AggregationProjection result = new AggregationDistinctProjection(0, 0, AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, ""), null, "", databaseType); result.getDerivedAggregationProjections().add(derivedProjection); return result; } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/extractor/SQLStatementContextExtractor.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/extractor/SQLStatementContextExtractor.java index ed827c1acca56..5da332d1a441b 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/extractor/SQLStatementContextExtractor.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/extractor/SQLStatementContextExtractor.java @@ -22,6 +22,8 @@ import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation; import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertSelectContext; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.ddl.AlterViewStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.ddl.CreateViewStatementContext; import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext; import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.binder.context.type.IndexAvailable; @@ -83,13 +85,31 @@ public static Collection getAllSubqueryContexts(final SQ if (sqlStatementContext instanceof SelectStatementContext) { result.addAll(((SelectStatementContext) sqlStatementContext).getSubqueryContexts().values()); ((SelectStatementContext) sqlStatementContext).getSubqueryContexts().values().forEach(each -> result.addAll(getAllSubqueryContexts(each))); + return result; } if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) { InsertSelectContext insertSelectContext = ((InsertStatementContext) sqlStatementContext).getInsertSelectContext(); result.add(insertSelectContext.getSelectStatementContext()); result.addAll(insertSelectContext.getSelectStatementContext().getSubqueryContexts().values()); insertSelectContext.getSelectStatementContext().getSubqueryContexts().values().forEach(each -> result.addAll(getAllSubqueryContexts(each))); + return result; } + // SPEX ADDED: BEGIN + if (sqlStatementContext instanceof CreateViewStatementContext) { + CreateViewStatementContext createViewStatementContext = (CreateViewStatementContext) sqlStatementContext; + result.add(createViewStatementContext.getSelectStatementContext()); + result.addAll(createViewStatementContext.getSelectStatementContext().getSubqueryContexts().values()); + createViewStatementContext.getSelectStatementContext().getSubqueryContexts().values().forEach(each -> result.addAll(getAllSubqueryContexts(each))); + return result; + } + if (sqlStatementContext instanceof AlterViewStatementContext && ((AlterViewStatementContext) sqlStatementContext).getSelectStatementContext().isPresent()) { + AlterViewStatementContext alterViewStatementContext = (AlterViewStatementContext) sqlStatementContext; + result.add(alterViewStatementContext.getSelectStatementContext().get()); + result.addAll(alterViewStatementContext.getSelectStatementContext().get().getSubqueryContexts().values()); + alterViewStatementContext.getSelectStatementContext().get().getSubqueryContexts().values().forEach(each -> result.addAll(getAllSubqueryContexts(each))); + return result; + } + // SPEX ADDED: END return result; } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngine.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngine.java index 3fd824f99a485..b35e474d88b78 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngine.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngine.java @@ -120,7 +120,7 @@ private AggregationDistinctProjection createProjection(final AggregationDistinct IdentifierValue alias = projectionSegment.getAlias().orElseGet(() -> new IdentifierValue(DerivedColumn.AGGREGATION_DISTINCT_DERIVED.getDerivedColumnAlias(aggregationDistinctDerivedColumnCount++))); AggregationDistinctProjection result = new AggregationDistinctProjection( - projectionSegment.getStartIndex(), projectionSegment.getStopIndex(), projectionSegment.getType(), projectionSegment.getExpression(), alias, + projectionSegment.getStartIndex(), projectionSegment.getStopIndex(), projectionSegment.getType(), projectionSegment, alias, projectionSegment.getDistinctInnerExpression(), databaseType, projectionSegment.getSeparator().orElse(null)); if (AggregationType.AVG == result.getType()) { appendAverageDistinctDerivedProjection(result); @@ -130,7 +130,7 @@ private AggregationDistinctProjection createProjection(final AggregationDistinct private AggregationProjection createProjection(final AggregationProjectionSegment projectionSegment) { AggregationProjection result = - new AggregationProjection(projectionSegment.getType(), projectionSegment.getExpression(), projectionSegment.getAlias().orElse(null), databaseType, + new AggregationProjection(projectionSegment.getType(), projectionSegment, projectionSegment.getAlias().orElse(null), databaseType, projectionSegment.getSeparator().orElse(null)); if (AggregationType.AVG == result.getType()) { appendAverageDerivedProjection(result); @@ -143,11 +143,13 @@ private void appendAverageDistinctDerivedProjection(final AggregationDistinctPro String distinctInnerExpression = averageDistinctProjection.getDistinctInnerExpression(); String countAlias = DerivedColumn.AVG_COUNT_ALIAS.getDerivedColumnAlias(aggregationAverageDerivedColumnCount); String innerExpression = averageDistinctProjection.getExpression().substring(averageDistinctProjection.getExpression().indexOf(Paren.PARENTHESES.getLeftParen())); + AggregationProjectionSegment countExpression = new AggregationProjectionSegment(0, 0, AggregationType.COUNT, AggregationType.COUNT.name() + innerExpression); AggregationDistinctProjection countDistinctProjection = - new AggregationDistinctProjection(0, 0, AggregationType.COUNT, AggregationType.COUNT.name() + innerExpression, new IdentifierValue(countAlias), distinctInnerExpression, databaseType); + new AggregationDistinctProjection(0, 0, AggregationType.COUNT, countExpression, new IdentifierValue(countAlias), distinctInnerExpression, databaseType); String sumAlias = DerivedColumn.AVG_SUM_ALIAS.getDerivedColumnAlias(aggregationAverageDerivedColumnCount); + AggregationProjectionSegment sumExpression = new AggregationProjectionSegment(0, 0, AggregationType.SUM, AggregationType.SUM.name() + innerExpression); AggregationDistinctProjection sumDistinctProjection = - new AggregationDistinctProjection(0, 0, AggregationType.SUM, AggregationType.SUM.name() + innerExpression, new IdentifierValue(sumAlias), distinctInnerExpression, databaseType); + new AggregationDistinctProjection(0, 0, AggregationType.SUM, sumExpression, new IdentifierValue(sumAlias), distinctInnerExpression, databaseType); averageDistinctProjection.getDerivedAggregationProjections().add(countDistinctProjection); averageDistinctProjection.getDerivedAggregationProjections().add(sumDistinctProjection); aggregationAverageDerivedColumnCount++; @@ -156,9 +158,11 @@ private void appendAverageDistinctDerivedProjection(final AggregationDistinctPro private void appendAverageDerivedProjection(final AggregationProjection averageProjection) { String countAlias = DerivedColumn.AVG_COUNT_ALIAS.getDerivedColumnAlias(aggregationAverageDerivedColumnCount); String innerExpression = averageProjection.getExpression().substring(averageProjection.getExpression().indexOf(Paren.PARENTHESES.getLeftParen())); - AggregationProjection countProjection = new AggregationProjection(AggregationType.COUNT, AggregationType.COUNT.name() + innerExpression, new IdentifierValue(countAlias), databaseType); + AggregationProjectionSegment countExpression = new AggregationProjectionSegment(0, 0, AggregationType.COUNT, AggregationType.COUNT.name() + innerExpression); + AggregationProjection countProjection = new AggregationProjection(AggregationType.COUNT, countExpression, new IdentifierValue(countAlias), databaseType); String sumAlias = DerivedColumn.AVG_SUM_ALIAS.getDerivedColumnAlias(aggregationAverageDerivedColumnCount); - AggregationProjection sumProjection = new AggregationProjection(AggregationType.SUM, AggregationType.SUM.name() + innerExpression, new IdentifierValue(sumAlias), databaseType); + AggregationProjectionSegment sumExpression = new AggregationProjectionSegment(0, 0, AggregationType.SUM, AggregationType.SUM.name() + innerExpression); + AggregationProjection sumProjection = new AggregationProjection(AggregationType.SUM, sumExpression, new IdentifierValue(sumAlias), databaseType); averageProjection.getDerivedAggregationProjections().add(countProjection); averageProjection.getDerivedAggregationProjections().add(sumProjection); aggregationAverageDerivedColumnCount++; diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationDistinctProjection.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationDistinctProjection.java index e047645fa1e0a..72887593bfaf0 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationDistinctProjection.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationDistinctProjection.java @@ -21,6 +21,7 @@ import lombok.Getter; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue; /** @@ -36,17 +37,17 @@ public final class AggregationDistinctProjection extends AggregationProjection { private final String distinctInnerExpression; - public AggregationDistinctProjection(final int startIndex, final int stopIndex, final AggregationType type, final String expression, + public AggregationDistinctProjection(final int startIndex, final int stopIndex, final AggregationType type, final AggregationProjectionSegment aggregationSegment, final IdentifierValue alias, final String distinctInnerExpression, final DatabaseType databaseType) { - super(type, expression, alias, databaseType); + super(type, aggregationSegment, alias, databaseType); this.startIndex = startIndex; this.stopIndex = stopIndex; this.distinctInnerExpression = distinctInnerExpression; } - public AggregationDistinctProjection(final int startIndex, final int stopIndex, final AggregationType type, final String expression, + public AggregationDistinctProjection(final int startIndex, final int stopIndex, final AggregationType type, final AggregationProjectionSegment aggregationSegment, final IdentifierValue alias, final String distinctInnerExpression, final DatabaseType databaseType, final String separator) { - super(type, expression, alias, databaseType, separator); + super(type, aggregationSegment, alias, databaseType, separator); this.startIndex = startIndex; this.stopIndex = stopIndex; this.distinctInnerExpression = distinctInnerExpression; diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationProjection.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationProjection.java index 78ab5dd906e84..8a3fc641c812b 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationProjection.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationProjection.java @@ -27,6 +27,7 @@ import org.apache.shardingsphere.infra.binder.context.segment.select.projection.extractor.ProjectionIdentifierExtractEngine; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue; import java.util.ArrayList; @@ -44,7 +45,7 @@ public class AggregationProjection implements Projection { private final AggregationType type; - private final String expression; + private final AggregationProjectionSegment aggregationSegment; private final IdentifierValue alias; @@ -57,9 +58,9 @@ public class AggregationProjection implements Projection { @Setter private int index = -1; - public AggregationProjection(final AggregationType type, final String expression, final IdentifierValue alias, final DatabaseType databaseType) { + public AggregationProjection(final AggregationType type, final AggregationProjectionSegment aggregationSegment, final IdentifierValue alias, final DatabaseType databaseType) { this.type = type; - this.expression = expression; + this.aggregationSegment = aggregationSegment; this.alias = alias; this.databaseType = databaseType; separator = null; @@ -84,7 +85,12 @@ public String getColumnLabel() { ProjectionIdentifierExtractEngine extractEngine = new ProjectionIdentifierExtractEngine(databaseType); return getAlias().isPresent() && !DerivedColumn.isDerivedColumnName(getAlias().get().getValueWithQuoteCharacters()) ? extractEngine.getIdentifierValue(getAlias().get()) - : extractEngine.getColumnNameFromFunction(type.name(), expression); + : extractEngine.getColumnNameFromFunction(type.name(), aggregationSegment.getExpression()); + } + + @Override + public String getExpression() { + return aggregationSegment.getExpression(); } @Override diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/AlterViewStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/AlterViewStatementContext.java index cef5709022630..28d5206c4639b 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/AlterViewStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/AlterViewStatementContext.java @@ -22,9 +22,13 @@ import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext; import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.binder.context.type.TableAvailable; +import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData; import org.apache.shardingsphere.sql.parser.statement.core.enums.SubqueryType; import org.apache.shardingsphere.sql.parser.statement.core.extractor.TableExtractor; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment; import org.apache.shardingsphere.sql.parser.statement.core.statement.ddl.AlterViewStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectStatement; @@ -39,7 +43,7 @@ * Alter view statement context. */ @Getter -public final class AlterViewStatementContext extends CommonSQLStatementContext implements TableAvailable { +public final class AlterViewStatementContext extends CommonSQLStatementContext implements TableAvailable, WhereAvailable { private final TablesContext tablesContext; @@ -82,4 +86,19 @@ public Optional getSelectStatementContext() { public AlterViewStatement getSqlStatement() { return (AlterViewStatement) super.getSqlStatement(); } + + @Override + public Collection getWhereSegments() { + return getSelectStatementContext().isPresent() ? getSelectStatementContext().get().getWhereSegments() : Collections.emptyList(); + } + + @Override + public Collection getColumnSegments() { + return getSelectStatementContext().isPresent() ? getSelectStatementContext().get().getColumnSegments() : Collections.emptyList(); + } + + @Override + public Collection getJoinConditions() { + return getSelectStatementContext().isPresent() ? getSelectStatementContext().get().getJoinConditions() : Collections.emptyList(); + } } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/CreateViewStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/CreateViewStatementContext.java index c829761b25b8a..761d5b3a73806 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/CreateViewStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/CreateViewStatementContext.java @@ -22,11 +22,16 @@ import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext; import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.binder.context.type.TableAvailable; +import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData; import org.apache.shardingsphere.sql.parser.statement.core.enums.SubqueryType; import org.apache.shardingsphere.sql.parser.statement.core.extractor.TableExtractor; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment; import org.apache.shardingsphere.sql.parser.statement.core.statement.ddl.CreateViewStatement; +import java.util.Collection; import java.util.Collections; import java.util.List; @@ -34,7 +39,7 @@ * Create view statement context. */ @Getter -public final class CreateViewStatementContext extends CommonSQLStatementContext implements TableAvailable { +public final class CreateViewStatementContext extends CommonSQLStatementContext implements TableAvailable, WhereAvailable { private final TablesContext tablesContext; @@ -53,4 +58,19 @@ public CreateViewStatementContext(final ShardingSphereMetaData metaData, final L public CreateViewStatement getSqlStatement() { return (CreateViewStatement) super.getSqlStatement(); } + + @Override + public Collection getWhereSegments() { + return selectStatementContext.getWhereSegments(); + } + + @Override + public Collection getColumnSegments() { + return selectStatementContext.getColumnSegments(); + } + + @Override + public Collection getJoinConditions() { + return selectStatementContext.getJoinConditions(); + } } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/ExpressionSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/ExpressionSegmentBinder.java index 365f7fed040c5..d874e9cb85b86 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/ExpressionSegmentBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/ExpressionSegmentBinder.java @@ -23,6 +23,8 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType; +import org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.type.AggregationDistinctProjectionSegmentBinder; +import org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.type.AggregationProjectionSegmentBinder; import org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.type.BetweenExpressionSegmentBinder; import org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.type.BinaryOperationExpressionBinder; import org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.type.ColumnSegmentBinder; @@ -42,6 +44,8 @@ import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.InExpression; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.NotExpression; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationDistinctProjectionSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment; /** * Expression segment binder. @@ -89,6 +93,12 @@ public static ExpressionSegment bind(final ExpressionSegment segment, final Segm if (segment instanceof BetweenExpression) { return BetweenExpressionSegmentBinder.bind((BetweenExpression) segment, binderContext, tableBinderContexts, outerTableBinderContexts); } + if (segment instanceof AggregationDistinctProjectionSegment) { + return AggregationDistinctProjectionSegmentBinder.bind((AggregationDistinctProjectionSegment) segment, binderContext, tableBinderContexts, outerTableBinderContexts); + } + if (segment instanceof AggregationProjectionSegment) { + return AggregationProjectionSegmentBinder.bind((AggregationProjectionSegment) segment, binderContext, tableBinderContexts, outerTableBinderContexts); + } // TODO support more ExpressionSegment bound return segment; } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/type/AggregationDistinctProjectionSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/type/AggregationDistinctProjectionSegmentBinder.java new file mode 100644 index 0000000000000..accf8e8d5b911 --- /dev/null +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/type/AggregationDistinctProjectionSegmentBinder.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.type; + +import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString; +import com.google.common.collect.Multimap; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType; +import org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.ExpressionSegmentBinder; +import org.apache.shardingsphere.infra.binder.engine.segment.dml.from.context.TableSegmentBinderContext; +import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationDistinctProjectionSegment; + +/** + * Aggregation distinct projection segment binder. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class AggregationDistinctProjectionSegmentBinder { + + /** + * Bind aggregation distinct projection segment. + * + * @param segment aggregation distinct projection segment + * @param binderContext SQL statement binder context + * @param tableBinderContexts table binder contexts + * @param outerTableBinderContexts outer table binder contexts + * @return bound aggregation distinct projection segment + */ + public static AggregationDistinctProjectionSegment bind(final AggregationDistinctProjectionSegment segment, final SQLStatementBinderContext binderContext, + final Multimap tableBinderContexts, + final Multimap outerTableBinderContexts) { + AggregationDistinctProjectionSegment result = + new AggregationDistinctProjectionSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getType(), segment.getExpression(), segment.getDistinctInnerExpression()); + segment.getParameters().forEach(each -> result.getParameters().add(ExpressionSegmentBinder.bind(each, SegmentType.PROJECTION, binderContext, tableBinderContexts, outerTableBinderContexts))); + return result; + } +} diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/type/AggregationProjectionSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/type/AggregationProjectionSegmentBinder.java new file mode 100644 index 0000000000000..255134cffa2e8 --- /dev/null +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/dml/expression/type/AggregationProjectionSegmentBinder.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.type; + +import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString; +import com.google.common.collect.Multimap; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType; +import org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.ExpressionSegmentBinder; +import org.apache.shardingsphere.infra.binder.engine.segment.dml.from.context.TableSegmentBinderContext; +import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment; + +/** + * Aggregation projection segment binder. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class AggregationProjectionSegmentBinder { + + /** + * Bind aggregation projection segment. + * + * @param segment aggregation projection segment + * @param binderContext SQL statement binder context + * @param tableBinderContexts table binder contexts + * @param outerTableBinderContexts outer table binder contexts + * @return bound aggregation projection segment + */ + public static AggregationProjectionSegment bind(final AggregationProjectionSegment segment, final SQLStatementBinderContext binderContext, + final Multimap tableBinderContexts, + final Multimap outerTableBinderContexts) { + AggregationProjectionSegment result = new AggregationProjectionSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getType(), segment.getExpression()); + segment.getParameters().forEach(each -> result.getParameters().add(ExpressionSegmentBinder.bind(each, SegmentType.PROJECTION, binderContext, tableBinderContexts, outerTableBinderContexts))); + return result; + } +} diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/ProjectionsContextTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/ProjectionsContextTest.java index 26210e709fd2c..91f819685f1bc 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/ProjectionsContextTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/ProjectionsContextTest.java @@ -26,6 +26,7 @@ import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.database.mysql.type.MySQLDatabaseType; import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue; import org.junit.jupiter.api.Test; @@ -128,12 +129,12 @@ private ColumnProjection getColumnProjectionWithAlias() { } private AggregationProjection getAggregationProjection() { - return new AggregationProjection(AggregationType.COUNT, "(column)", new IdentifierValue("c"), mock(DatabaseType.class)); + return new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "(column)"), new IdentifierValue("c"), mock(DatabaseType.class)); } private AggregationDistinctProjection getAggregationDistinctProjection() { return new AggregationDistinctProjection( - 0, 0, AggregationType.COUNT, "(DISTINCT column)", new IdentifierValue("c"), "column", mock(DatabaseType.class)); + 0, 0, AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "(DISTINCT column)"), new IdentifierValue("c"), "column", mock(DatabaseType.class)); } @Test diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationDistinctProjectionTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationDistinctProjectionTest.java index a7c0bbca17c3f..c49053b7d35bf 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationDistinctProjectionTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationDistinctProjectionTest.java @@ -19,6 +19,7 @@ import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue; import org.junit.jupiter.api.Test; @@ -29,7 +30,7 @@ class AggregationDistinctProjectionTest { private final AggregationDistinctProjection aggregationDistinctProjection = new AggregationDistinctProjection( - 0, 0, AggregationType.COUNT, "(DISTINCT order_id)", new IdentifierValue("c"), "order_id", mock(DatabaseType.class)); + 0, 0, AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "(DISTINCT order_id)"), new IdentifierValue("c"), "order_id", mock(DatabaseType.class)); @Test void assertGetDistinctColumnName() { diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationProjectionTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationProjectionTest.java index 85be8d87a6294..089b4a4eabe84 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationProjectionTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/impl/AggregationProjectionTest.java @@ -18,10 +18,11 @@ package org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl; import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection; -import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter; +import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader; import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue; import org.junit.jupiter.api.Test; @@ -36,59 +37,64 @@ class AggregationProjectionTest { @Test void assertGetColumnName() { - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), null, TypedSPILoader.getService(DatabaseType.class, "MySQL")).getColumnName(), is("COUNT( A.\"DIRECTION\" )")); - assertThat(new AggregationProjection(AggregationType.COUNT, "count( a.\"direction\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "count( a.\"direction\" )"), null, TypedSPILoader.getService(DatabaseType.class, "MySQL")).getColumnName(), is("count( a.\"direction\" )")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), null, TypedSPILoader.getService(DatabaseType.class, "PostgreSQL")).getColumnName(), is("count")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), null, TypedSPILoader.getService(DatabaseType.class, "openGauss")).getColumnName(), is("count")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( a.\"direction\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( a.\"direction\" )"), null, TypedSPILoader.getService(DatabaseType.class, "Oracle")).getColumnName(), is("COUNT(A.\"DIRECTION\")")); } @Test void assertGetColumnLabelWithAliasNoQuote() { - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", new IdentifierValue("DIRECTION_COUNT"), + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), new IdentifierValue("DIRECTION_COUNT"), TypedSPILoader.getService(DatabaseType.class, "MySQL")).getColumnLabel(), is("DIRECTION_COUNT")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", new IdentifierValue("DIRECTION_COUNT"), + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), new IdentifierValue("DIRECTION_COUNT"), TypedSPILoader.getService(DatabaseType.class, "PostgreSQL")).getColumnLabel(), is("direction_count")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", new IdentifierValue("DIRECTION_COUNT"), + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), new IdentifierValue("DIRECTION_COUNT"), TypedSPILoader.getService(DatabaseType.class, "openGauss")).getColumnLabel(), is("direction_count")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", new IdentifierValue("direction_count"), + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), new IdentifierValue("direction_count"), TypedSPILoader.getService(DatabaseType.class, "Oracle")).getColumnLabel(), is("DIRECTION_COUNT")); } @Test void assertGetColumnLabelWithAliasAndQuote() { - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", new IdentifierValue("DIRECTION_COUNT", QuoteCharacter.BACK_QUOTE), + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), + new IdentifierValue("DIRECTION_COUNT", QuoteCharacter.BACK_QUOTE), TypedSPILoader.getService(DatabaseType.class, "MySQL")).getColumnLabel(), is("DIRECTION_COUNT")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", new IdentifierValue("DIRECTION_COUNT", QuoteCharacter.QUOTE), + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), + new IdentifierValue("DIRECTION_COUNT", QuoteCharacter.QUOTE), TypedSPILoader.getService(DatabaseType.class, "PostgreSQL")).getColumnLabel(), is("DIRECTION_COUNT")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", new IdentifierValue("DIRECTION_COUNT", QuoteCharacter.QUOTE), + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), + new IdentifierValue("DIRECTION_COUNT", QuoteCharacter.QUOTE), TypedSPILoader.getService(DatabaseType.class, "openGauss")).getColumnLabel(), is("DIRECTION_COUNT")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", new IdentifierValue("direction_count", QuoteCharacter.QUOTE), + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), + new IdentifierValue("direction_count", QuoteCharacter.QUOTE), TypedSPILoader.getService(DatabaseType.class, "Oracle")).getColumnLabel(), is("direction_count")); } @Test void assertGetColumnLabelWithoutAlias() { - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), null, TypedSPILoader.getService(DatabaseType.class, "MySQL")).getColumnLabel(), is("COUNT( A.\"DIRECTION\" )")); - assertThat(new AggregationProjection(AggregationType.COUNT, "count( a.\"direction\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "count( a.\"direction\" )"), null, TypedSPILoader.getService(DatabaseType.class, "MySQL")).getColumnLabel(), is("count( a.\"direction\" )")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), null, TypedSPILoader.getService(DatabaseType.class, "PostgreSQL")).getColumnLabel(), is("count")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), null, TypedSPILoader.getService(DatabaseType.class, "openGauss")).getColumnLabel(), is("count")); - assertThat(new AggregationProjection(AggregationType.COUNT, "COUNT( a.\"direction\" )", null, + assertThat(new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( a.\"direction\" )"), null, TypedSPILoader.getService(DatabaseType.class, "Oracle")).getColumnLabel(), is("COUNT(A.\"DIRECTION\")")); } @Test void assertGetAlias() { - Projection projection = new AggregationProjection(AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )", new IdentifierValue("AVG_DERIVED_COUNT_0"), mock(DatabaseType.class)); + Projection projection = new AggregationProjection(AggregationType.COUNT, new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT( A.\"DIRECTION\" )"), + new IdentifierValue("AVG_DERIVED_COUNT_0"), mock(DatabaseType.class)); Optional actual = projection.getAlias(); assertTrue(actual.isPresent()); assertThat(actual.get().getValue(), is("AVG_DERIVED_COUNT_0")); diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ColumnExtractor.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ColumnExtractor.java index 7588c7db466e3..5852d8776e96c 100644 --- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ColumnExtractor.java +++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ColumnExtractor.java @@ -66,32 +66,56 @@ public final class ColumnExtractor { public static Collection extract(final ExpressionSegment expression) { Collection result = new LinkedList<>(); if (expression instanceof BinaryOperationExpression) { - if (((BinaryOperationExpression) expression).getLeft() instanceof ColumnSegment) { - result.add((ColumnSegment) ((BinaryOperationExpression) expression).getLeft()); - } - if (((BinaryOperationExpression) expression).getRight() instanceof ColumnSegment) { - result.add((ColumnSegment) ((BinaryOperationExpression) expression).getRight()); - } - if (((BinaryOperationExpression) expression).getLeft() instanceof OuterJoinExpression) { - result.add(((OuterJoinExpression) ((BinaryOperationExpression) expression).getLeft()).getColumnName()); - } - if (((BinaryOperationExpression) expression).getRight() instanceof OuterJoinExpression) { - result.add(((OuterJoinExpression) ((BinaryOperationExpression) expression).getRight()).getColumnName()); - } + extractColumnsInBinaryOperationExpression((BinaryOperationExpression) expression, result); } if (expression instanceof InExpression && ((InExpression) expression).getLeft() instanceof ColumnSegment) { result.add((ColumnSegment) ((InExpression) expression).getLeft()); } if (expression instanceof InExpression && ((InExpression) expression).getLeft() instanceof RowExpression) { - extractColumnInRowExpression((InExpression) expression, result); + extractColumnsInRowExpression((InExpression) expression, result); + } + if (expression instanceof BetweenExpression) { + extractColumnsInBetweenExpression((BetweenExpression) expression, result); } - if (expression instanceof BetweenExpression && ((BetweenExpression) expression).getLeft() instanceof ColumnSegment) { - result.add((ColumnSegment) ((BetweenExpression) expression).getLeft()); + if (expression instanceof AggregationProjectionSegment) { + extractColumnsInAggregationProjectionSegment((AggregationProjectionSegment) expression, result); } return result; } - private static void extractColumnInRowExpression(final InExpression expression, final Collection result) { + private static void extractColumnsInBinaryOperationExpression(final BinaryOperationExpression expression, final Collection result) { + if (expression.getLeft() instanceof ColumnSegment) { + result.add((ColumnSegment) expression.getLeft()); + } + if (expression.getRight() instanceof ColumnSegment) { + result.add((ColumnSegment) expression.getRight()); + } + if (expression.getLeft() instanceof OuterJoinExpression) { + result.add(((OuterJoinExpression) expression.getLeft()).getColumnName()); + } + if (expression.getRight() instanceof OuterJoinExpression) { + result.add(((OuterJoinExpression) expression.getRight()).getColumnName()); + } + result.addAll(extract(expression.getLeft())); + result.addAll(extract(expression.getRight())); + } + + private static void extractColumnsInBetweenExpression(final BetweenExpression expression, final Collection result) { + if (expression.getLeft() instanceof ColumnSegment) { + result.add((ColumnSegment) expression.getLeft()); + } + if (expression.getBetweenExpr() instanceof ColumnSegment) { + result.add((ColumnSegment) expression.getBetweenExpr()); + } + if (expression.getAndExpr() instanceof ColumnSegment) { + result.add((ColumnSegment) expression.getAndExpr()); + } + result.addAll(extract(expression.getLeft())); + result.addAll(extract(expression.getBetweenExpr())); + result.addAll(extract(expression.getAndExpr())); + } + + private static void extractColumnsInRowExpression(final InExpression expression, final Collection result) { for (ExpressionSegment each : ((RowExpression) expression.getLeft()).getItems()) { if (each instanceof ColumnSegment) { result.add((ColumnSegment) each); @@ -99,6 +123,16 @@ private static void extractColumnInRowExpression(final InExpression expression, } } + private static void extractColumnsInAggregationProjectionSegment(final AggregationProjectionSegment expression, final Collection result) { + for (ExpressionSegment each : expression.getParameters()) { + if (each instanceof ColumnSegment) { + result.add((ColumnSegment) each); + } else { + result.addAll(extract(each)); + } + } + } + /** * Extract column segments. * diff --git a/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dql/GeneralDQLE2EIT.java b/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dql/GeneralDQLE2EIT.java index 9379586c32fe8..0cf4f04646f87 100644 --- a/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dql/GeneralDQLE2EIT.java +++ b/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dql/GeneralDQLE2EIT.java @@ -104,8 +104,8 @@ private void assertQueryForPreparedStatementWithXmlExpected(final E2ETestContext private void assertExecuteQueryWithExpectedDataSource(final AssertionTestParameter testParam, final E2ETestContext context) throws SQLException { try ( - Connection actualConnection = getEnvironmentEngine().getTargetDataSource().getConnection(); - Connection expectedConnection = getExpectedDataSource().getConnection()) { + Connection expectedConnection = getExpectedDataSource().getConnection(); + Connection actualConnection = getEnvironmentEngine().getTargetDataSource().getConnection()) { if (SQLExecuteType.LITERAL == context.getSqlExecuteType()) { assertExecuteQueryForStatement(context, actualConnection, expectedConnection, testParam); } else { @@ -117,10 +117,10 @@ private void assertExecuteQueryWithExpectedDataSource(final AssertionTestParamet private void assertExecuteQueryForStatement(final E2ETestContext context, final Connection actualConnection, final Connection expectedConnection, final AssertionTestParameter testParam) throws SQLException { try ( - Statement actualStatement = actualConnection.createStatement(); - ResultSet actualResultSet = actualStatement.executeQuery(context.getSQL()); Statement expectedStatement = expectedConnection.createStatement(); - ResultSet expectedResultSet = expectedStatement.executeQuery(context.getSQL())) { + ResultSet expectedResultSet = expectedStatement.executeQuery(context.getSQL()); + Statement actualStatement = actualConnection.createStatement(); + ResultSet actualResultSet = actualStatement.executeQuery(context.getSQL())) { assertResultSet(actualResultSet, expectedResultSet, testParam); } } @@ -128,15 +128,15 @@ private void assertExecuteQueryForStatement(final E2ETestContext context, final private void assertExecuteQueryForPreparedStatement(final E2ETestContext context, final Connection actualConnection, final Connection expectedConnection, final AssertionTestParameter testParam) throws SQLException { try ( - PreparedStatement actualPreparedStatement = actualConnection.prepareStatement(context.getSQL()); - PreparedStatement expectedPreparedStatement = expectedConnection.prepareStatement(context.getSQL())) { + PreparedStatement expectedPreparedStatement = expectedConnection.prepareStatement(context.getSQL()); + PreparedStatement actualPreparedStatement = actualConnection.prepareStatement(context.getSQL())) { for (SQLValue each : context.getAssertion().getSQLValues()) { actualPreparedStatement.setObject(each.getIndex(), each.getValue()); expectedPreparedStatement.setObject(each.getIndex(), each.getValue()); } try ( - ResultSet actualResultSet = actualPreparedStatement.executeQuery(); - ResultSet expectedResultSet = expectedPreparedStatement.executeQuery()) { + ResultSet expectedResultSet = expectedPreparedStatement.executeQuery(); + ResultSet actualResultSet = actualPreparedStatement.executeQuery()) { assertResultSet(actualResultSet, expectedResultSet, testParam); } } @@ -220,7 +220,7 @@ private void assertExecuteForStatement(final E2ETestContext context, final Conne try ( Statement actualStatement = actualConnection.createStatement(); Statement expectedStatement = expectedConnection.createStatement()) { - assertTrue(actualStatement.execute(context.getSQL()) && expectedStatement.execute(context.getSQL()), "Not a query statement."); + assertTrue(expectedStatement.execute(context.getSQL()) && actualStatement.execute(context.getSQL()), "Not a query statement."); try ( ResultSet actualResultSet = actualStatement.getResultSet(); ResultSet expectedResultSet = expectedStatement.getResultSet()) { @@ -238,7 +238,7 @@ private void assertExecuteForPreparedStatement(final E2ETestContext context, fin actualPreparedStatement.setObject(each.getIndex(), each.getValue()); expectedPreparedStatement.setObject(each.getIndex(), each.getValue()); } - assertTrue(actualPreparedStatement.execute() && expectedPreparedStatement.execute(), "Not a query statement."); + assertTrue(expectedPreparedStatement.execute() && actualPreparedStatement.execute(), "Not a query statement."); try ( ResultSet actualResultSet = actualPreparedStatement.getResultSet(); ResultSet expectedResultSet = expectedPreparedStatement.getResultSet()) {