Skip to content

Commit

Permalink
Refactor SQLRewriteEntry to remove too much parameters (#33462)
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu authored Oct 30, 2024
1 parent 8b676c0 commit 2ba71b1
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,16 @@ public SQLRewriteEntry(final ShardingSphereDatabase database, final RuleMetaData
* @return route unit and SQL rewrite result map
*/
public SQLRewriteResult rewrite(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) {
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(queryContext, routeContext, connectionContext);
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(queryContext, routeContext);
SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
return routeContext.getRouteUnits().isEmpty()
? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, queryContext)
: new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext, queryContext);
}

private SQLRewriteContext createSQLRewriteContext(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) {
private SQLRewriteContext createSQLRewriteContext(final QueryContext queryContext, final RouteContext routeContext) {
HintValueContext hintValueContext = queryContext.getHintValueContext();
SQLRewriteContext result = new SQLRewriteContext(database, queryContext.getSqlStatementContext(), queryContext.getSql(), queryContext.getParameters(), connectionContext, hintValueContext);
SQLRewriteContext result = new SQLRewriteContext(database, queryContext);
decorate(result, routeContext, hintValueContext);
result.generateSQLTokens();
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import lombok.Getter;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
Expand All @@ -31,6 +30,7 @@
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.builder.DefaultTokenGeneratorBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;

import java.util.Collection;
import java.util.LinkedList;
Expand Down Expand Up @@ -59,19 +59,18 @@ public final class SQLRewriteContext {

private final ConnectionContext connectionContext;

public SQLRewriteContext(final ShardingSphereDatabase database, final SQLStatementContext sqlStatementContext, final String sql, final List<Object> params,
final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
public SQLRewriteContext(final ShardingSphereDatabase database, final QueryContext queryContext) {
this.database = database;
this.sqlStatementContext = sqlStatementContext;
this.sql = sql;
parameters = params;
this.connectionContext = connectionContext;
if (!hintValueContext.isSkipSQLRewrite()) {
sqlStatementContext = queryContext.getSqlStatementContext();
sql = queryContext.getSql();
parameters = queryContext.getParameters();
connectionContext = queryContext.getConnectionContext();
if (!queryContext.getHintValueContext().isSkipSQLRewrite()) {
addSQLTokenGenerators(new DefaultTokenGeneratorBuilder(sqlStatementContext).getSQLTokenGenerators());
}
parameterBuilder = containsInsertValues(sqlStatementContext)
? new GroupedParameterBuilder(((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters())
: new StandardParameterBuilder(params);
: new StandardParameterBuilder(parameters);
}

private boolean containsInsertValues(final SQLStatementContext sqlStatementContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.OptionalSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down Expand Up @@ -84,24 +84,36 @@ void assertInsertStatementContext() {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getInsertSelectContext()).thenReturn(null);
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS);
when(queryContext.getSqlStatementContext()).thenReturn(statementContext);
when(queryContext.getSql()).thenReturn("INSERT INTO tbl VALUES (?)");
when(queryContext.getParameters()).thenReturn(Collections.singletonList(1));
when(queryContext.getHintValueContext()).thenReturn(hintValueContext);
SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext);
assertThat(sqlRewriteContext.getParameterBuilder(), instanceOf(GroupedParameterBuilder.class));
}

@Test
void assertNotInsertStatementContext() {
SelectStatementContext statementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(database, statementContext, "SELECT * FROM tbl WHERE id = ?", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS);
when(queryContext.getSqlStatementContext()).thenReturn(statementContext);
when(queryContext.getSql()).thenReturn("SELECT * FROM tbl WHERE id = ?");
when(queryContext.getParameters()).thenReturn(Collections.singletonList(1));
when(queryContext.getHintValueContext()).thenReturn(hintValueContext);
SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext);
assertThat(sqlRewriteContext.getParameterBuilder(), instanceOf(StandardParameterBuilder.class));
}

@Test
void assertGenerateOptionalSQLToken() {
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(database, sqlStatementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS);
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
when(queryContext.getSql()).thenReturn("INSERT INTO tbl VALUES (?)");
when(queryContext.getParameters()).thenReturn(Collections.singletonList(1));
when(queryContext.getHintValueContext()).thenReturn(hintValueContext);
SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext);
sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(optionalSQLTokenGenerator));
sqlRewriteContext.generateSQLTokens();
assertFalse(sqlRewriteContext.getSqlTokens().isEmpty());
Expand All @@ -110,8 +122,12 @@ void assertGenerateOptionalSQLToken() {

@Test
void assertGenerateCollectionSQLToken() {
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(database, sqlStatementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext);
QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS);
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
when(queryContext.getSql()).thenReturn("INSERT INTO tbl VALUES (?)");
when(queryContext.getParameters()).thenReturn(Collections.singletonList(1));
when(queryContext.getHintValueContext()).thenReturn(hintValueContext);
SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext);
sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(collectionSQLTokenGenerator));
sqlRewriteContext.generateSQLTokens();
assertFalse(sqlRewriteContext.getSqlTokens().isEmpty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
import org.apache.shardingsphere.sqltranslator.rule.builder.DefaultSQLTranslatorRuleConfigurationBuilder;
Expand All @@ -54,15 +53,21 @@ void assertRewrite() {
when(database.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits);
CommonSQLStatementContext sqlStatementContext = mock(CommonSQLStatementContext.class);
when(sqlStatementContext.getDatabaseType()).thenReturn(databaseType);
QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS);
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, sqlStatementContext, "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()), queryContext);
QueryContext queryContext = mockQueryContext(sqlStatementContext);
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)).rewrite(new SQLRewriteContext(database, queryContext), queryContext);
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
}

private QueryContext mockQueryContext(final CommonSQLStatementContext sqlStatementContext) {
QueryContext result = mock(QueryContext.class, RETURNS_DEEP_STUBS);
when(result.getSqlStatementContext()).thenReturn(sqlStatementContext);
when(result.getSql()).thenReturn("SELECT 1");
when(result.getParameters()).thenReturn(Collections.emptyList());
when(result.getHintValueContext()).thenReturn(new HintValueContext());
return result;
}

@Test
void assertRewriteStorageTypeIsEmpty() {
SQLTranslatorRule rule = new SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build());
Expand All @@ -73,10 +78,8 @@ void assertRewriteStorageTypeIsEmpty() {
CommonSQLStatementContext sqlStatementContext = mock(CommonSQLStatementContext.class);
DatabaseType databaseType = mock(DatabaseType.class);
when(sqlStatementContext.getDatabaseType()).thenReturn(databaseType);
QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS);
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(database, sqlStatementContext, "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), new HintValueContext()), queryContext);
QueryContext queryContext = mockQueryContext(sqlStatementContext);
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)).rewrite(new SQLRewriteContext(database, queryContext), queryContext);
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
}
Expand Down
Loading

0 comments on commit 2ba71b1

Please sign in to comment.