Skip to content

Commit

Permalink
Refactor sql bind and encrypt logic
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Jan 16, 2025
1 parent 7ffbdca commit 2d5c680
Show file tree
Hide file tree
Showing 25 changed files with 355 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand Down Expand Up @@ -61,8 +60,6 @@ public final class EncryptConditionEngine {

private final EncryptRule rule;

private final ShardingSphereDatabase database;

static {
LOGICAL_OPERATOR.add("AND");
LOGICAL_OPERATOR.add("&&");
Expand Down Expand Up @@ -155,25 +152,36 @@ private Optional<EncryptCondition> createBinaryEncryptCondition(final BinaryOper
return Optional.empty();
}
ShardingSpherePreconditions.checkContains(SUPPORTED_COMPARE_OPERATOR, operator, () -> new UnsupportedEncryptSQLException(operator));
return createCompareEncryptCondition(tableName, expression, expression.getRight());
return createCompareEncryptCondition(tableName, expression);
}

private Optional<EncryptCondition> createCompareEncryptCondition(final String tableName, final BinaryOperationExpression expression, final ExpressionSegment compareRightValue) {
if (!(expression.getLeft() instanceof ColumnSegment) || compareRightValue instanceof SubqueryExpressionSegment) {
private Optional<EncryptCondition> createCompareEncryptCondition(final String tableName, final BinaryOperationExpression expression) {
if (isLeftRightAllNotColumnSegment(expression) || isLeftRightContainsSubquerySegment(expression)) {
return Optional.empty();
}
if (compareRightValue instanceof SimpleExpressionSegment) {
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, compareRightValue));
ColumnSegment columnSegment = expression.getLeft() instanceof ColumnSegment ? (ColumnSegment) expression.getLeft() : (ColumnSegment) expression.getRight();
ExpressionSegment compareValueSegment = expression.getLeft() instanceof ColumnSegment ? expression.getRight() : expression.getLeft();
if (compareValueSegment instanceof SimpleExpressionSegment) {
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, columnSegment, compareValueSegment));
}
if (compareRightValue instanceof ListExpression) {
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, ((ListExpression) compareRightValue).getItems().get(0)));
if (compareValueSegment instanceof ListExpression) {
// TODO check this logic when items contain multiple values @duanzhengqiang
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, columnSegment, ((ListExpression) compareValueSegment).getItems().get(0)));
}
return Optional.empty();
}

private EncryptBinaryCondition createEncryptBinaryOperationCondition(final String tableName, final BinaryOperationExpression expression, final ExpressionSegment compareRightValue) {
ColumnSegment columnSegment = (ColumnSegment) expression.getLeft();
return new EncryptBinaryCondition(columnSegment, tableName, expression.getOperator(), compareRightValue.getStartIndex(), expression.getStopIndex(), compareRightValue);
private boolean isLeftRightAllNotColumnSegment(final BinaryOperationExpression expression) {
return !(expression.getLeft() instanceof ColumnSegment) && !(expression.getRight() instanceof ColumnSegment);
}

private boolean isLeftRightContainsSubquerySegment(final BinaryOperationExpression expression) {
return expression.getLeft() instanceof SubqueryExpressionSegment || expression.getRight() instanceof SubqueryExpressionSegment;
}

private EncryptBinaryCondition createEncryptBinaryOperationCondition(final String tableName, final BinaryOperationExpression expression, final ColumnSegment columnSegment,
final ExpressionSegment compareValueSegment) {
return new EncryptBinaryCondition(columnSegment, tableName, expression.getOperator(), compareValueSegment.getStartIndex(), compareValueSegment.getStopIndex(), compareValueSegment);
}

private static Optional<EncryptCondition> createInEncryptCondition(final String tableName, final InExpression inExpression, final ExpressionSegment inRightValue) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
Expand Down Expand Up @@ -55,7 +54,7 @@ public void decorate(final EncryptRule rule, final ConfigurationProperties props
if (!containsEncryptTable(rule, sqlStatementContext)) {
return;
}
Collection<EncryptCondition> encryptConditions = createEncryptConditions(rule, sqlRewriteContext);
Collection<EncryptCondition> encryptConditions = createEncryptConditions(rule, sqlStatementContext);
if (!sqlRewriteContext.getParameters().isEmpty()) {
Collection<ParameterRewriter> parameterRewriters = new ParameterRewritersBuilder(sqlStatementContext)
.build(new EncryptParameterRewritersRegistry(rule, sqlRewriteContext.getDatabase().getName(), encryptConditions));
Expand All @@ -77,23 +76,13 @@ private boolean containsEncryptTable(final EncryptRule rule, final SQLStatementC
return false;
}

private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext) {
SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
&& !(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext()).getWhereSegments().isEmpty()) {
return createEncryptConditions(rule, sqlRewriteContext, ((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
return createEncryptConditions(rule, sqlRewriteContext, sqlStatementContext);
}

private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext,
final SQLStatementContext sqlStatementContext) {
private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLStatementContext sqlStatementContext) {
if (!(sqlStatementContext instanceof WhereAvailable)) {
return Collections.emptyList();
}
Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments);
return new EncryptConditionEngine(rule).createEncryptConditions(whereSegments);
}

private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
Expand All @@ -53,25 +51,7 @@ public final class EncryptPredicateParameterRewriter implements ParameterRewrite

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
if (sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty()) {
return true;
}
if (sqlStatementContext instanceof SelectStatementContext) {
return isSubqueryNeedRewrite((SelectStatementContext) sqlStatementContext);
}
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) {
return isSubqueryNeedRewrite(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
return false;
}

private boolean isSubqueryNeedRewrite(final SelectStatementContext selectStatementContext) {
for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) {
if (isNeedRewrite(each)) {
return true;
}
}
return false;
return sqlStatementContext instanceof WhereAvailable;
}

@Override
Expand All @@ -85,7 +65,7 @@ public void rewrite(final ParameterBuilder paramBuilder, final SQLStatementConte

private List<Object> getEncryptedValues(final String schemaName, final EncryptCondition encryptCondition, final List<Object> originalValues) {
String tableName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalTable().getValue();
String columnName = encryptCondition.getColumnSegment().getIdentifier().getValue();
String columnName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalColumn().getValue();
EncryptTable encryptTable = rule.getEncryptTable(tableName);
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(columnName);
if (encryptCondition instanceof EncryptBinaryCondition && "LIKE".equals(((EncryptBinaryCondition) encryptCondition).getOperator()) && encryptColumn.getLikeQuery().isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import org.apache.shardingsphere.encrypt.rewrite.token.generator.insert.EncryptInsertOnUpdateTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.insert.EncryptInsertValuesTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateColumnTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateRightValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateColumnTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateRightValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.projection.EncryptInsertSelectProjectionTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.projection.EncryptSelectProjectionTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.select.EncryptGroupByItemTokenGenerator;
Expand Down Expand Up @@ -70,8 +70,8 @@ public Collection<SQLTokenGenerator> getSQLTokenGenerators() {
addSQLTokenGenerator(result, new EncryptUpdateAssignmentTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptPredicateColumnTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertPredicateColumnTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptPredicateRightValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertPredicateRightValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptPredicateValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertPredicateValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertValuesTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertDefaultColumnsTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertCipherNameTokenGenerator(rule));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@
import java.util.List;

/**
* Insert predicate right value token generator for encrypt.
* Insert predicate value token generator for encrypt.
*/
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptInsertPredicateRightValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {
public final class EncryptInsertPredicateValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {

private final EncryptRule rule;

Expand All @@ -58,7 +58,7 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)

@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
EncryptPredicateRightValueTokenGenerator generator = new EncryptPredicateRightValueTokenGenerator(rule);
EncryptPredicateValueTokenGenerator generator = new EncryptPredicateValueTokenGenerator(rule);
generator.setParameters(parameters);
generator.setEncryptConditions(encryptConditions);
generator.setDatabase(database);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate;

import com.cedarsoftware.util.CaseInsensitiveSet;
import com.google.common.base.Preconditions;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
Expand Down Expand Up @@ -44,6 +45,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
Expand All @@ -57,11 +59,13 @@
@Setter
public final class EncryptPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext> {

private static final Collection<String> GREATER_LESS_THAN_EQUALS_OPERATORS = new CaseInsensitiveSet<>(Arrays.asList(">", "<", ">=", "<="));

private final EncryptRule rule;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
return sqlStatementContext instanceof WhereAvailable;
}

@Override
Expand Down Expand Up @@ -114,16 +118,15 @@ private boolean includesLike(final Collection<WhereSegment> whereSegments, final

private boolean isLikeColumnSegment(final AndPredicate andPredicate, final ColumnSegment targetColumnSegment) {
for (ExpressionSegment each : andPredicate.getPredicates()) {
if (each instanceof BinaryOperationExpression
&& "LIKE".equalsIgnoreCase(((BinaryOperationExpression) each).getOperator()) && isSameColumnSegment(((BinaryOperationExpression) each).getLeft(), targetColumnSegment)) {
if (each instanceof BinaryOperationExpression && "LIKE".equalsIgnoreCase(((BinaryOperationExpression) each).getOperator()) && isContainsColumnSegment(each, targetColumnSegment)) {
return true;
}
}
return false;
}

private boolean isSameColumnSegment(final ExpressionSegment columnSegment, final ColumnSegment targetColumnSegment) {
return columnSegment instanceof ColumnSegment && columnSegment.getStartIndex() == targetColumnSegment.getStartIndex() && columnSegment.getStopIndex() == targetColumnSegment.getStopIndex();
private boolean isContainsColumnSegment(final ExpressionSegment expressionSegment, final ColumnSegment targetColumnSegment) {
return expressionSegment.getStartIndex() <= targetColumnSegment.getStartIndex() && expressionSegment.getStopIndex() >= targetColumnSegment.getStopIndex();
}

private Collection<Projection> createColumnProjections(final String columnName, final QuoteCharacter quoteCharacter, final DatabaseType databaseType) {
Expand Down
Loading

0 comments on commit 2d5c680

Please sign in to comment.