Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor sql bind and encrypt logic #34373

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand Down Expand Up @@ -61,8 +60,6 @@ public final class EncryptConditionEngine {

private final EncryptRule rule;

private final ShardingSphereDatabase database;

static {
LOGICAL_OPERATOR.add("AND");
LOGICAL_OPERATOR.add("&&");
Expand Down Expand Up @@ -155,25 +152,36 @@ private Optional<EncryptCondition> createBinaryEncryptCondition(final BinaryOper
return Optional.empty();
}
ShardingSpherePreconditions.checkContains(SUPPORTED_COMPARE_OPERATOR, operator, () -> new UnsupportedEncryptSQLException(operator));
return createCompareEncryptCondition(tableName, expression, expression.getRight());
return createCompareEncryptCondition(tableName, expression);
}

private Optional<EncryptCondition> createCompareEncryptCondition(final String tableName, final BinaryOperationExpression expression, final ExpressionSegment compareRightValue) {
if (!(expression.getLeft() instanceof ColumnSegment) || compareRightValue instanceof SubqueryExpressionSegment) {
private Optional<EncryptCondition> createCompareEncryptCondition(final String tableName, final BinaryOperationExpression expression) {
if (isLeftRightAllNotColumnSegment(expression) || isLeftRightContainsSubquerySegment(expression)) {
return Optional.empty();
}
if (compareRightValue instanceof SimpleExpressionSegment) {
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, compareRightValue));
ColumnSegment columnSegment = expression.getLeft() instanceof ColumnSegment ? (ColumnSegment) expression.getLeft() : (ColumnSegment) expression.getRight();
ExpressionSegment compareValueSegment = expression.getLeft() instanceof ColumnSegment ? expression.getRight() : expression.getLeft();
if (compareValueSegment instanceof SimpleExpressionSegment) {
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, columnSegment, compareValueSegment));
}
if (compareRightValue instanceof ListExpression) {
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, ((ListExpression) compareRightValue).getItems().get(0)));
if (compareValueSegment instanceof ListExpression) {
// TODO check this logic when items contain multiple values @duanzhengqiang
return Optional.of(createEncryptBinaryOperationCondition(tableName, expression, columnSegment, ((ListExpression) compareValueSegment).getItems().get(0)));
}
return Optional.empty();
}

private EncryptBinaryCondition createEncryptBinaryOperationCondition(final String tableName, final BinaryOperationExpression expression, final ExpressionSegment compareRightValue) {
ColumnSegment columnSegment = (ColumnSegment) expression.getLeft();
return new EncryptBinaryCondition(columnSegment, tableName, expression.getOperator(), compareRightValue.getStartIndex(), expression.getStopIndex(), compareRightValue);
private boolean isLeftRightAllNotColumnSegment(final BinaryOperationExpression expression) {
return !(expression.getLeft() instanceof ColumnSegment) && !(expression.getRight() instanceof ColumnSegment);
}

private boolean isLeftRightContainsSubquerySegment(final BinaryOperationExpression expression) {
return expression.getLeft() instanceof SubqueryExpressionSegment || expression.getRight() instanceof SubqueryExpressionSegment;
}

private EncryptBinaryCondition createEncryptBinaryOperationCondition(final String tableName, final BinaryOperationExpression expression, final ColumnSegment columnSegment,
final ExpressionSegment compareValueSegment) {
return new EncryptBinaryCondition(columnSegment, tableName, expression.getOperator(), compareValueSegment.getStartIndex(), compareValueSegment.getStopIndex(), compareValueSegment);
}

private static Optional<EncryptCondition> createInEncryptCondition(final String tableName, final InExpression inExpression, final ExpressionSegment inRightValue) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
Expand Down Expand Up @@ -55,7 +54,7 @@ public void decorate(final EncryptRule rule, final ConfigurationProperties props
if (!containsEncryptTable(rule, sqlStatementContext)) {
return;
}
Collection<EncryptCondition> encryptConditions = createEncryptConditions(rule, sqlRewriteContext);
Collection<EncryptCondition> encryptConditions = createEncryptConditions(rule, sqlStatementContext);
if (!sqlRewriteContext.getParameters().isEmpty()) {
Collection<ParameterRewriter> parameterRewriters = new ParameterRewritersBuilder(sqlStatementContext)
.build(new EncryptParameterRewritersRegistry(rule, sqlRewriteContext.getDatabase().getName(), encryptConditions));
Expand All @@ -77,23 +76,13 @@ private boolean containsEncryptTable(final EncryptRule rule, final SQLStatementC
return false;
}

private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext) {
SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
&& !(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext()).getWhereSegments().isEmpty()) {
return createEncryptConditions(rule, sqlRewriteContext, ((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
return createEncryptConditions(rule, sqlRewriteContext, sqlStatementContext);
}

private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext,
final SQLStatementContext sqlStatementContext) {
private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLStatementContext sqlStatementContext) {
if (!(sqlStatementContext instanceof WhereAvailable)) {
return Collections.emptyList();
}
Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments);
return new EncryptConditionEngine(rule).createEncryptConditions(whereSegments);
}

private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
Expand All @@ -53,25 +51,7 @@ public final class EncryptPredicateParameterRewriter implements ParameterRewrite

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
if (sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty()) {
return true;
}
if (sqlStatementContext instanceof SelectStatementContext) {
return isSubqueryNeedRewrite((SelectStatementContext) sqlStatementContext);
}
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) {
return isSubqueryNeedRewrite(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
return false;
}

private boolean isSubqueryNeedRewrite(final SelectStatementContext selectStatementContext) {
for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) {
if (isNeedRewrite(each)) {
return true;
}
}
return false;
return sqlStatementContext instanceof WhereAvailable;
}

@Override
Expand All @@ -85,7 +65,7 @@ public void rewrite(final ParameterBuilder paramBuilder, final SQLStatementConte

private List<Object> getEncryptedValues(final String schemaName, final EncryptCondition encryptCondition, final List<Object> originalValues) {
String tableName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalTable().getValue();
String columnName = encryptCondition.getColumnSegment().getIdentifier().getValue();
String columnName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalColumn().getValue();
EncryptTable encryptTable = rule.getEncryptTable(tableName);
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(columnName);
if (encryptCondition instanceof EncryptBinaryCondition && "LIKE".equals(((EncryptBinaryCondition) encryptCondition).getOperator()) && encryptColumn.getLikeQuery().isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import org.apache.shardingsphere.encrypt.rewrite.token.generator.insert.EncryptInsertOnUpdateTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.insert.EncryptInsertValuesTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateColumnTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateRightValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptInsertPredicateValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateColumnTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateRightValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.predicate.EncryptPredicateValueTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.projection.EncryptInsertSelectProjectionTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.projection.EncryptSelectProjectionTokenGenerator;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.select.EncryptGroupByItemTokenGenerator;
Expand Down Expand Up @@ -70,8 +70,8 @@ public Collection<SQLTokenGenerator> getSQLTokenGenerators() {
addSQLTokenGenerator(result, new EncryptUpdateAssignmentTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptPredicateColumnTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertPredicateColumnTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptPredicateRightValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertPredicateRightValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptPredicateValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertPredicateValueTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertValuesTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertDefaultColumnsTokenGenerator(rule));
addSQLTokenGenerator(result, new EncryptInsertCipherNameTokenGenerator(rule));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@
import java.util.List;

/**
* Insert predicate right value token generator for encrypt.
* Insert predicate value token generator for encrypt.
*/
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptInsertPredicateRightValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {
public final class EncryptInsertPredicateValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {

private final EncryptRule rule;

Expand All @@ -58,7 +58,7 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)

@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
EncryptPredicateRightValueTokenGenerator generator = new EncryptPredicateRightValueTokenGenerator(rule);
EncryptPredicateValueTokenGenerator generator = new EncryptPredicateValueTokenGenerator(rule);
generator.setParameters(parameters);
generator.setEncryptConditions(encryptConditions);
generator.setDatabase(database);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public final class EncryptPredicateColumnTokenGenerator implements CollectionSQL

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
return sqlStatementContext instanceof WhereAvailable;
}

@Override
Expand Down Expand Up @@ -114,16 +114,15 @@ private boolean includesLike(final Collection<WhereSegment> whereSegments, final

private boolean isLikeColumnSegment(final AndPredicate andPredicate, final ColumnSegment targetColumnSegment) {
for (ExpressionSegment each : andPredicate.getPredicates()) {
if (each instanceof BinaryOperationExpression
&& "LIKE".equalsIgnoreCase(((BinaryOperationExpression) each).getOperator()) && isSameColumnSegment(((BinaryOperationExpression) each).getLeft(), targetColumnSegment)) {
if (each instanceof BinaryOperationExpression && "LIKE".equalsIgnoreCase(((BinaryOperationExpression) each).getOperator()) && isContainsColumnSegment(each, targetColumnSegment)) {
return true;
}
}
return false;
}

private boolean isSameColumnSegment(final ExpressionSegment columnSegment, final ColumnSegment targetColumnSegment) {
return columnSegment instanceof ColumnSegment && columnSegment.getStartIndex() == targetColumnSegment.getStartIndex() && columnSegment.getStopIndex() == targetColumnSegment.getStopIndex();
private boolean isContainsColumnSegment(final ExpressionSegment expressionSegment, final ColumnSegment targetColumnSegment) {
return expressionSegment.getStartIndex() <= targetColumnSegment.getStartIndex() && expressionSegment.getStopIndex() >= targetColumnSegment.getStopIndex();
}

private Collection<Projection> createColumnProjections(final String columnName, final QuoteCharacter quoteCharacter, final DatabaseType databaseType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@
import java.util.Optional;

/**
* Predicate right value token generator for encrypt.
* Predicate value token generator for encrypt.
*/
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptPredicateRightValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {
public final class EncryptPredicateValueTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, ParametersAware, EncryptConditionsAware, DatabaseAware {

private final EncryptRule rule;

Expand All @@ -69,7 +69,7 @@ public final class EncryptPredicateRightValueTokenGenerator implements Collectio

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
return sqlStatementContext instanceof WhereAvailable;
}

@Override
Expand Down Expand Up @@ -102,15 +102,15 @@ private SQLToken generateSQLToken(final String schemaName, final EncryptTable en
private List<Object> getEncryptedValues(final String schemaName, final EncryptTable encryptTable, final EncryptCondition encryptCondition, final List<Object> originalValues) {
String columnName = encryptCondition.getColumnSegment().getIdentifier().getValue();
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(columnName);
String tableName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalTable().getValue();
if (encryptCondition instanceof EncryptBinaryCondition && "LIKE".equalsIgnoreCase(((EncryptBinaryCondition) encryptCondition).getOperator())) {
LikeQueryColumnItem likeQueryColumnItem = encryptColumn.getLikeQuery()
.orElseThrow(() -> new MissingMatchedEncryptQueryAlgorithmException(encryptTable.getTable(), columnName, "LIKE"));
return likeQueryColumnItem.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), columnName, originalValues);
}
return encryptColumn.getAssistedQuery()
.map(optional -> optional.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), columnName, originalValues))
.orElseGet(() -> encryptColumn.getCipher().encrypt(database.getName(), schemaName, encryptCondition.getTableName(), columnName,
originalValues));
.map(optional -> optional.encrypt(database.getName(), schemaName, tableName, encryptCondition.getColumnSegment().getIdentifier().getValue(), originalValues))
.orElseGet(() -> encryptColumn.getCipher().encrypt(database.getName(), schemaName, tableName, encryptCondition.getColumnSegment().getIdentifier().getValue(), originalValues));
}

private Map<Integer, Object> getPositionValues(final Collection<Integer> valuePositions, final List<Object> encryptValues) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ public EncryptPredicateEqualRightValueToken(final int startIndex, final int stop
@Override
public String toString() {
if (paramMarkerIndexes.isEmpty()) {
return indexValues.get(0) instanceof String ? "'" + indexValues.get(0) + "'" : indexValues.get(0).toString();
return indexValues.isEmpty() ? "" : getIndexValue(indexValues);
}
return "?";
}

private String getIndexValue(final Map<Integer, Object> indexValues) {
return indexValues.get(0) instanceof String ? "'" + indexValues.get(0) + "'" : indexValues.get(0).toString();
}
}
Loading
Loading