Skip to content

Commit

Permalink
Refactor sql bind and encrypt logic (#34373)
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu authored Jan 16, 2025
1 parent 7ffbdca commit cc46a48
Show file tree
Hide file tree
Showing 26 changed files with 358 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand Down Expand Up @@ -61,8 +60,6 @@ public final class EncryptConditionEngine {

private final EncryptRule rule;

private final ShardingSphereDatabase database;

static {
LOGICAL_OPERATOR.add("AND");
LOGICAL_OPERATOR.add("&&");
Expand Down Expand Up @@ -155,25 +152,36 @@ private Optional<EncryptCondition> createBinaryEncryptCondition(final BinaryOper
return Optional.empty();
}
ShardingSpherePreconditions.checkContains(SUPPORTED_COMPARE_OPERATOR, operator, () -> new UnsupportedEncryptSQLException(operator));
return createCompareEncryptCondition(tableName, expression, expression.getRight());
return createCompareEncryptCondition(tableName, expression);
}

private Optional<EncryptCondition> createCompareEncryptCondition(final String tableName, final BinaryOperationExpression expression, final ExpressionSegment compareRightValue) {
if (!(expression.getLeft() instanceof ColumnSegment) || compareRightValue instanceof SubqueryExpressionSegment) {
private Optional<EncryptCondition> createCompareEncryptCondition(final String tableName, final BinaryOperationExpression expression) {
if (isLeftRightAllNotColumnSegment(expression) || isLeftRightContainsSubquerySegment(expression)) {
return Optional.empty();
}
if (compareRightValue instanceof SimpleExpressionSegment) {
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, compareRightValue));
ColumnSegment columnSegment = expression.getLeft() instanceof ColumnSegment ? (ColumnSegment) expression.getLeft() : (ColumnSegment) expression.getRight();
ExpressionSegment compareValueSegment = expression.getLeft() instanceof ColumnSegment ? expression.getRight() : expression.getLeft();
if (compareValueSegment instanceof SimpleExpressionSegment) {
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, columnSegment, compareValueSegment));
}
if (compareRightValue instanceof ListExpression) {
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, ((ListExpression) compareRightValue).getItems().get(0)));
if (compareValueSegment instanceof ListExpression) {
// TODO check this logic when items contain multiple values @duanzhengqiang
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, columnSegment, ((ListExpression) compareValueSegment).getItems().get(0)));
}
return Optional.empty();
}

private EncryptBinaryCondition createEncryptBinaryOperationCondition(final String tableName, final BinaryOperationExpression expression, final ExpressionSegment compareRightValue) {
ColumnSegment columnSegment = (ColumnSegment) expression.getLeft();
return new EncryptBinaryCondition(columnSegment, tableName, expression.getOperator(), compareRightValue.getStartIndex(), expression.getStopIndex(), compareRightValue);
private boolean isLeftRightAllNotColumnSegment(final BinaryOperationExpression expression) {
return !(expression.getLeft() instanceof ColumnSegment) && !(expression.getRight() instanceof ColumnSegment);
}

private boolean isLeftRightContainsSubquerySegment(final BinaryOperationExpression expression) {
return expression.getLeft() instanceof SubqueryExpressionSegment || expression.getRight() instanceof SubqueryExpressionSegment;
}

private EncryptBinaryCondition createEncryptBinaryOperationCondition(final String tableName, final BinaryOperationExpression expression, final ColumnSegment columnSegment,
final ExpressionSegment compareValueSegment) {
return new EncryptBinaryCondition(columnSegment, tableName, expression.getOperator(), compareValueSegment.getStartIndex(), compareValueSegment.getStopIndex(), compareValueSegment);
}

private static Optional<EncryptCondition> createInEncryptCondition(final String tableName, final InExpression inExpression, final ExpressionSegment inRightValue) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
Expand Down Expand Up @@ -55,7 +54,7 @@ public void decorate(final EncryptRule rule, final ConfigurationProperties props
if (!containsEncryptTable(rule, sqlStatementContext)) {
return;
}
Collection<EncryptCondition> encryptConditions = createEncryptConditions(rule, sqlRewriteContext);
Collection<EncryptCondition> encryptConditions = createEncryptConditions(rule, sqlStatementContext);
if (!sqlRewriteContext.getParameters().isEmpty()) {
Collection<ParameterRewriter> parameterRewriters = new ParameterRewritersBuilder(sqlStatementContext)
.build(new EncryptParameterRewritersRegistry(rule, sqlRewriteContext.getDatabase().getName(), encryptConditions));
Expand All @@ -77,23 +76,13 @@ private boolean containsEncryptTable(final EncryptRule rule, final SQLStatementC
return false;
}

private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext) {
SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
&& !(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext()).getWhereSegments().isEmpty()) {
return createEncryptConditions(rule, sqlRewriteContext, ((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
return createEncryptConditions(rule, sqlRewriteContext, sqlStatementContext);
}

private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext,
final SQLStatementContext sqlStatementContext) {
private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLStatementContext sqlStatementContext) {
if (!(sqlStatementContext instanceof WhereAvailable)) {
return Collections.emptyList();
}
Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments);
return new EncryptConditionEngine(rule).createEncryptConditions(whereSegments);
}

private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
Expand All @@ -53,25 +51,7 @@ public final class EncryptPredicateParameterRewriter implements ParameterRewrite

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
if (sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty()) {
return true;
}
if (sqlStatementContext instanceof SelectStatementContext) {
return isSubqueryNeedRewrite((SelectStatementContext) sqlStatementContext);
}
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) {
return isSubqueryNeedRewrite(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
return false;
}

private boolean isSubqueryNeedRewrite(final SelectStatementContext selectStatementContext) {
for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) {
if (isNeedRewrite(each)) {
return true;
}
}
return false;
return sqlStatementContext instanceof WhereAvailable;
}

@Override
Expand All @@ -85,7 +65,7 @@ public void rewrite(final ParameterBuilder paramBuilder, final SQLStatementConte

private List<Object> getEncryptedValues(final String schemaName, final EncryptCondition encryptCondition, final List<Object> originalValues) {
String tableName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalTable().getValue();
String columnName = encryptCondition.getColumnSegment().getIdentifier().getValue();
String columnName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalColumn().getValue();
EncryptTable encryptTable = rule.getEncryptTable(tableName);
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(columnName);
if (encryptCondition instanceof EncryptBinaryCondition && "LIKE".equals(((EncryptBinaryCondition) encryptCondition).getOperator()) && encryptColumn.getLikeQuery().isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import org.apache.shardingsphere.encrypt.rewrite.token.generator.insert.EncryptInsertOnUpdateTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.insert.EncryptInsertValuesTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateColumnTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateRightValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateColumnTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateRightValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.projection.EncryptInsertSelectProjectionTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.projection.EncryptSelectProjectionTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.select.EncryptGroupByItemTokenGenerator;
Expand Down Expand Up @@ -70,8 +70,8 @@ public Collection<SQLTokenGenerator> getSQLTokenGenerators() {
addSQLTokenGenerator(result, new EncryptUpdateAssignmentTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptPredicateColumnTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertPredicateColumnTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptPredicateRightValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertPredicateRightValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptPredicateValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertPredicateValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertValuesTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertDefaultColumnsTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertCipherNameTokenGenerator(rule));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@
import java.util.List;

/**
* Insert predicate right value token generator for encrypt.
* Insert predicate value token generator for encrypt.
*/
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptInsertPredicateRightValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {
public final class EncryptInsertPredicateValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {

private final EncryptRule rule;

Expand All @@ -58,7 +58,7 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)

@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
EncryptPredicateRightValueTokenGenerator generator = new EncryptPredicateRightValueTokenGenerator(rule);
EncryptPredicateValueTokenGenerator generator = new EncryptPredicateValueTokenGenerator(rule);
generator.setParameters(parameters);
generator.setEncryptConditions(encryptConditions);
generator.setDatabase(database);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public final class EncryptPredicateColumnTokenGenerator implements CollectionSQL

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
return sqlStatementContext instanceof WhereAvailable;
}

@Override
Expand Down Expand Up @@ -114,16 +114,15 @@ private boolean includesLike(final Collection<WhereSegment> whereSegments, final

private boolean isLikeColumnSegment(final AndPredicate andPredicate, final ColumnSegment targetColumnSegment) {
for (ExpressionSegment each : andPredicate.getPredicates()) {
if (each instanceof BinaryOperationExpression
&& "LIKE".equalsIgnoreCase(((BinaryOperationExpression) each).getOperator()) && isSameColumnSegment(((BinaryOperationExpression) each).getLeft(), targetColumnSegment)) {
if (each instanceof BinaryOperationExpression && "LIKE".equalsIgnoreCase(((BinaryOperationExpression) each).getOperator()) && isContainsColumnSegment(each, targetColumnSegment)) {
return true;
}
}
return false;
}

private boolean isSameColumnSegment(final ExpressionSegment columnSegment, final ColumnSegment targetColumnSegment) {
return columnSegment instanceof ColumnSegment && columnSegment.getStartIndex() == targetColumnSegment.getStartIndex() && columnSegment.getStopIndex() == targetColumnSegment.getStopIndex();
private boolean isContainsColumnSegment(final ExpressionSegment expressionSegment, final ColumnSegment targetColumnSegment) {
return expressionSegment.getStartIndex() <= targetColumnSegment.getStartIndex() && expressionSegment.getStopIndex() >= targetColumnSegment.getStopIndex();
}

private Collection<Projection> createColumnProjections(final String columnName, final QuoteCharacter quoteCharacter, final DatabaseType databaseType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@
import java.util.Optional;

/**
* Predicate right value token generator for encrypt.
* Predicate value token generator for encrypt.
*/
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptPredicateRightValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {
public final class EncryptPredicateValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {

private final EncryptRule rule;

Expand All @@ -69,7 +69,7 @@ public final class EncryptPredicateRightValueTokenGenerator implements Collectio

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
return sqlStatementContext instanceof WhereAvailable;
}

@Override
Expand Down Expand Up @@ -102,15 +102,15 @@ private SQLToken generateSQLToken(final String schemaName, final EncryptTable en
private List<Object> getEncryptedValues(final String schemaName, final EncryptTable encryptTable, final EncryptCondition encryptCondition, final List<Object> originalValues) {
String columnName = encryptCondition.getColumnSegment().getIdentifier().getValue();
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(columnName);
String tableName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalTable().getValue();
if (encryptCondition instanceof EncryptBinaryCondition && "LIKE".equalsIgnoreCase(((EncryptBinaryCondition) encryptCondition).getOperator())) {
LikeQueryColumnItem likeQueryColumnItem = encryptColumn.getLikeQuery()
.orElseThrow(() -> new MissingMatchedEncryptQueryAlgorithmException(encryptTable.getTable(), columnName, "LIKE"));
return likeQueryColumnItem.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), columnName, originalValues);
}
return encryptColumn.getAssistedQuery()
.map(optional -> optional.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), columnName, originalValues))
.orElseGet(() -> encryptColumn.getCipher().encrypt(database.getName(), schemaName, encryptCondition.getTableName(), columnName,
originalValues));
.map(optional -> optional.encrypt(database.getName(), schemaName, tableName, encryptCondition.getColumnSegment().getIdentifier().getValue(), originalValues))
.orElseGet(() -> encryptColumn.getCipher().encrypt(database.getName(), schemaName, tableName, encryptCondition.getColumnSegment().getIdentifier().getValue(), originalValues));
}

private Map<Integer, Object> getPositionValues(final Collection<Integer> valuePositions, final List<Object> encryptValues) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ public EncryptPredicateEqualRightValueToken(final int startIndex, final int stop
@Override
public String toString() {
if (paramMarkerIndexes.isEmpty()) {
return indexValues.get(0) instanceof String ? "'" + indexValues.get(0) + "'" : indexValues.get(0).toString();
return indexValues.isEmpty() ? "" : getIndexValue(indexValues);
}
return "?";
}

private String getIndexValue(final Map<Integer, Object> indexValues) {
return indexValues.get(0) instanceof String ? "'" + indexValues.get(0) + "'" : indexValues.get(0).toString();
}
}
Loading

0 comments on commit cc46a48

Please sign in to comment.