Skip to content

Commit

Permalink
Refactor EncryptOrderByItemTokenGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Aug 3, 2024
1 parent def40c0 commit 33225e5
Showing 1 changed file with 23 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.shardingsphere.infra.binder.context.segment.select.orderby.OrderByItem;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
Expand All @@ -48,7 +47,7 @@
*/
@HighFrequencyInvocation
@Setter
public final class EncryptOrderByItemTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, EncryptRuleAware, SchemaMetaDataAware {
public final class EncryptOrderByItemTokenGenerator implements CollectionSQLTokenGenerator<SelectStatementContext>, EncryptRuleAware, SchemaMetaDataAware {

private EncryptRule encryptRule;

Expand All @@ -58,24 +57,36 @@ public final class EncryptOrderByItemTokenGenerator implements CollectionSQLToke

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof SelectStatementContext && containsOrderByItem(sqlStatementContext);
return sqlStatementContext instanceof SelectStatementContext && containsOrderByItem((SelectStatementContext) sqlStatementContext);
}

private boolean containsOrderByItem(final SelectStatementContext sqlStatementContext) {
if (!sqlStatementContext.getOrderByContext().getItems().isEmpty() && !sqlStatementContext.getOrderByContext().isGenerated()) {
return true;
}
for (SelectStatementContext each : sqlStatementContext.getSubqueryContexts().values()) {
if (containsOrderByItem(each)) {
return true;
}
}
return false;
}

@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
public Collection<SQLToken> generateSQLTokens(final SelectStatementContext sqlStatementContext) {
Collection<SQLToken> result = new LinkedHashSet<>();
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema);
ShardingSphereSchema schema = sqlStatementContext.getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema);
for (OrderByItem each : getOrderByItems(sqlStatementContext)) {
if (each.getSegment() instanceof ColumnOrderByItemSegment) {
ColumnSegment columnSegment = ((ColumnOrderByItemSegment) each.getSegment()).getColumn();
Map<String, String> columnTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(Collections.singleton(columnSegment), schema);
result.addAll(generateSQLTokensWithColumnSegments(Collections.singleton(columnSegment), columnTableNames));
Map<String, String> columnTableNames = sqlStatementContext.getTablesContext().findTableNames(Collections.singleton(columnSegment), schema);
result.addAll(generateSQLTokens(Collections.singleton(columnSegment), columnTableNames));
}
}
return result;
}

private Collection<SubstitutableColumnNameToken> generateSQLTokensWithColumnSegments(final Collection<ColumnSegment> columnSegments, final Map<String, String> columnTableNames) {
private Collection<SubstitutableColumnNameToken> generateSQLTokens(final Collection<ColumnSegment> columnSegments, final Map<String, String> columnTableNames) {
Collection<SubstitutableColumnNameToken> result = new LinkedList<>();
for (ColumnSegment each : columnSegments) {
String tableName = columnTableNames.getOrDefault(each.getExpression(), "");
Expand All @@ -86,34 +97,14 @@ private Collection<SubstitutableColumnNameToken> generateSQLTokensWithColumnSegm
return result;
}

private Collection<OrderByItem> getOrderByItems(final SQLStatementContext sqlStatementContext) {
if (!(sqlStatementContext instanceof SelectStatementContext)) {
return Collections.emptyList();
}
private Collection<OrderByItem> getOrderByItems(final SelectStatementContext sqlStatementContext) {
Collection<OrderByItem> result = new LinkedList<>();
SelectStatementContext statementContext = (SelectStatementContext) sqlStatementContext;
if (!statementContext.getOrderByContext().isGenerated()) {
result.addAll(statementContext.getOrderByContext().getItems());
if (!sqlStatementContext.getOrderByContext().isGenerated()) {
result.addAll(sqlStatementContext.getOrderByContext().getItems());
}
for (SelectStatementContext each : statementContext.getSubqueryContexts().values()) {
for (SelectStatementContext each : sqlStatementContext.getSubqueryContexts().values()) {
result.addAll(getOrderByItems(each));
}
return result;
}

private boolean containsOrderByItem(final SQLStatementContext sqlStatementContext) {
if (!(sqlStatementContext instanceof SelectStatementContext)) {
return false;
}
SelectStatementContext statementContext = (SelectStatementContext) sqlStatementContext;
if (!statementContext.getOrderByContext().getItems().isEmpty() && !statementContext.getOrderByContext().isGenerated()) {
return true;
}
for (SelectStatementContext each : statementContext.getSubqueryContexts().values()) {
if (containsOrderByItem(each)) {
return true;
}
}
return false;
}
}

0 comments on commit 33225e5

Please sign in to comment.