diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java index a445b98516573..8d1ee71672982 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java @@ -69,16 +69,16 @@ public SQLRewriteEntry(final ShardingSphereDatabase database, final RuleMetaData * @return route unit and SQL rewrite result map */ public SQLRewriteResult rewrite(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) { - SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(queryContext, routeContext, connectionContext); + SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(queryContext, routeContext); SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class); return routeContext.getRouteUnits().isEmpty() ? new GenericSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, queryContext) : new RouteSQLRewriteEngine(rule, database, globalRuleMetaData).rewrite(sqlRewriteContext, routeContext, queryContext); } - private SQLRewriteContext createSQLRewriteContext(final QueryContext queryContext, final RouteContext routeContext, final ConnectionContext connectionContext) { + private SQLRewriteContext createSQLRewriteContext(final QueryContext queryContext, final RouteContext routeContext) { HintValueContext hintValueContext = queryContext.getHintValueContext(); - SQLRewriteContext result = new SQLRewriteContext(database, queryContext.getSqlStatementContext(), queryContext.getSql(), queryContext.getParameters(), connectionContext, hintValueContext); + SQLRewriteContext result = new SQLRewriteContext(database, queryContext); decorate(result, routeContext, hintValueContext); result.generateSQLTokens(); return result; diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java index 6df38915018aa..f7c7616422d19 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java @@ -21,7 +21,6 @@ import lombok.Getter; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext; -import org.apache.shardingsphere.infra.hint.HintValueContext; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder; import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder; @@ -31,6 +30,7 @@ import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.builder.DefaultTokenGeneratorBuilder; import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken; import org.apache.shardingsphere.infra.session.connection.ConnectionContext; +import org.apache.shardingsphere.infra.session.query.QueryContext; import java.util.Collection; import java.util.LinkedList; @@ -59,19 +59,18 @@ public final class SQLRewriteContext { private final ConnectionContext connectionContext; - public SQLRewriteContext(final ShardingSphereDatabase database, final SQLStatementContext sqlStatementContext, final String sql, final List params, - final ConnectionContext connectionContext, final HintValueContext hintValueContext) { + public SQLRewriteContext(final ShardingSphereDatabase database, final QueryContext queryContext) { this.database = database; - this.sqlStatementContext = sqlStatementContext; - this.sql = sql; - parameters = params; - this.connectionContext = connectionContext; - if (!hintValueContext.isSkipSQLRewrite()) { + sqlStatementContext = queryContext.getSqlStatementContext(); + sql = queryContext.getSql(); + parameters = queryContext.getParameters(); + connectionContext = queryContext.getConnectionContext(); + if (!queryContext.getHintValueContext().isSkipSQLRewrite()) { addSQLTokenGenerators(new DefaultTokenGeneratorBuilder(sqlStatementContext).getSQLTokenGenerators()); } parameterBuilder = containsInsertValues(sqlStatementContext) ? new GroupedParameterBuilder(((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters()) - : new StandardParameterBuilder(params); + : new StandardParameterBuilder(parameters); } private boolean containsInsertValues(final SQLStatementContext sqlStatementContext) { diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java index 14e4d6eda3e30..2734ef2468c5e 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextTest.java @@ -30,7 +30,7 @@ import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator; import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.OptionalSQLTokenGenerator; import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken; -import org.apache.shardingsphere.infra.session.connection.ConnectionContext; +import org.apache.shardingsphere.infra.session.query.QueryContext; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -84,8 +84,12 @@ void assertInsertStatementContext() { InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS); when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false); when(statementContext.getInsertSelectContext()).thenReturn(null); - SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext); + QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS); + when(queryContext.getSqlStatementContext()).thenReturn(statementContext); + when(queryContext.getSql()).thenReturn("INSERT INTO tbl VALUES (?)"); + when(queryContext.getParameters()).thenReturn(Collections.singletonList(1)); + when(queryContext.getHintValueContext()).thenReturn(hintValueContext); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); assertThat(sqlRewriteContext.getParameterBuilder(), instanceOf(GroupedParameterBuilder.class)); } @@ -93,15 +97,23 @@ void assertInsertStatementContext() { void assertNotInsertStatementContext() { SelectStatementContext statementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS); when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false); - SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(database, statementContext, "SELECT * FROM tbl WHERE id = ?", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext); + QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS); + when(queryContext.getSqlStatementContext()).thenReturn(statementContext); + when(queryContext.getSql()).thenReturn("SELECT * FROM tbl WHERE id = ?"); + when(queryContext.getParameters()).thenReturn(Collections.singletonList(1)); + when(queryContext.getHintValueContext()).thenReturn(hintValueContext); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); assertThat(sqlRewriteContext.getParameterBuilder(), instanceOf(StandardParameterBuilder.class)); } @Test void assertGenerateOptionalSQLToken() { - SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(database, sqlStatementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext); + QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS); + when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext); + when(queryContext.getSql()).thenReturn("INSERT INTO tbl VALUES (?)"); + when(queryContext.getParameters()).thenReturn(Collections.singletonList(1)); + when(queryContext.getHintValueContext()).thenReturn(hintValueContext); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(optionalSQLTokenGenerator)); sqlRewriteContext.generateSQLTokens(); assertFalse(sqlRewriteContext.getSqlTokens().isEmpty()); @@ -110,8 +122,12 @@ void assertGenerateOptionalSQLToken() { @Test void assertGenerateCollectionSQLToken() { - SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(database, sqlStatementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), hintValueContext); + QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS); + when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext); + when(queryContext.getSql()).thenReturn("INSERT INTO tbl VALUES (?)"); + when(queryContext.getParameters()).thenReturn(Collections.singletonList(1)); + when(queryContext.getHintValueContext()).thenReturn(hintValueContext); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); sqlRewriteContext.addSQLTokenGenerators(Collections.singleton(collectionSQLTokenGenerator)); sqlRewriteContext.generateSQLTokens(); assertFalse(sqlRewriteContext.getSqlTokens().isEmpty()); diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java index 4c3f2f437d0a5..af9d7feae5087 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java @@ -27,7 +27,6 @@ import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema; import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext; import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult; -import org.apache.shardingsphere.infra.session.connection.ConnectionContext; import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule; import org.apache.shardingsphere.sqltranslator.rule.builder.DefaultSQLTranslatorRuleConfigurationBuilder; @@ -54,15 +53,21 @@ void assertRewrite() { when(database.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits); CommonSQLStatementContext sqlStatementContext = mock(CommonSQLStatementContext.class); when(sqlStatementContext.getDatabaseType()).thenReturn(databaseType); - QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS); - when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext); - GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)) - .rewrite(new SQLRewriteContext(database, sqlStatementContext, "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), - new HintValueContext()), queryContext); + QueryContext queryContext = mockQueryContext(sqlStatementContext); + GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)).rewrite(new SQLRewriteContext(database, queryContext), queryContext); assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1")); assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList())); } + private QueryContext mockQueryContext(final CommonSQLStatementContext sqlStatementContext) { + QueryContext result = mock(QueryContext.class, RETURNS_DEEP_STUBS); + when(result.getSqlStatementContext()).thenReturn(sqlStatementContext); + when(result.getSql()).thenReturn("SELECT 1"); + when(result.getParameters()).thenReturn(Collections.emptyList()); + when(result.getHintValueContext()).thenReturn(new HintValueContext()); + return result; + } + @Test void assertRewriteStorageTypeIsEmpty() { SQLTranslatorRule rule = new SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build()); @@ -73,10 +78,8 @@ void assertRewriteStorageTypeIsEmpty() { CommonSQLStatementContext sqlStatementContext = mock(CommonSQLStatementContext.class); DatabaseType databaseType = mock(DatabaseType.class); when(sqlStatementContext.getDatabaseType()).thenReturn(databaseType); - QueryContext queryContext = mock(QueryContext.class, RETURNS_DEEP_STUBS); - when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext); - GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)) - .rewrite(new SQLRewriteContext(database, sqlStatementContext, "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), new HintValueContext()), queryContext); + QueryContext queryContext = mockQueryContext(sqlStatementContext); + GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, database, mock(RuleMetaData.class)).rewrite(new SQLRewriteContext(database, queryContext), queryContext); assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1")); assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList())); } diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java index 39fed6e18bed1..07f0121ee2fd7 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java @@ -34,7 +34,6 @@ import org.apache.shardingsphere.infra.route.context.RouteContext; import org.apache.shardingsphere.infra.route.context.RouteMapper; import org.apache.shardingsphere.infra.route.context.RouteUnit; -import org.apache.shardingsphere.infra.session.connection.ConnectionContext; import org.apache.shardingsphere.infra.session.query.QueryContext; import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule; import org.apache.shardingsphere.sqltranslator.rule.builder.DefaultSQLTranslatorRuleConfigurationBuilder; @@ -59,12 +58,11 @@ void assertRewriteWithStandardParameterBuilder() { ShardingSphereDatabase database = mockDatabase(databaseType); CommonSQLStatementContext sqlStatementContext = mock(CommonSQLStatementContext.class); when(sqlStatementContext.getDatabaseType()).thenReturn(databaseType); - SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, sqlStatementContext, "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + QueryContext queryContext = mockQueryContext(sqlStatementContext, "SELECT ?"); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); - QueryContext queryContext = mock(QueryContext.class); - when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( new SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, queryContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); @@ -72,6 +70,15 @@ void assertRewriteWithStandardParameterBuilder() { assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); } + private QueryContext mockQueryContext(final CommonSQLStatementContext sqlStatementContext, final String sql) { + QueryContext result = mock(QueryContext.class, RETURNS_DEEP_STUBS); + when(result.getSqlStatementContext()).thenReturn(sqlStatementContext); + when(result.getSql()).thenReturn(sql); + when(result.getParameters()).thenReturn(Collections.singletonList(1)); + when(result.getHintValueContext()).thenReturn(new HintValueContext()); + return result; + } + private ShardingSphereDatabase mockDatabase(final DatabaseType databaseType) { ShardingSphereDatabase result = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS); when(result.getProtocolType()).thenReturn(databaseType); @@ -90,14 +97,13 @@ void assertRewriteWithStandardParameterBuilderWhenNeedAggregateRewrite() { DatabaseType databaseType = mock(DatabaseType.class); when(statementContext.getDatabaseType()).thenReturn(databaseType); ShardingSphereDatabase database = mockDatabase(databaseType); - SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, statementContext, "SELECT ?", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + QueryContext queryContext = mockQueryContext(statementContext, "SELECT ?"); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); RouteContext routeContext = new RouteContext(); RouteUnit firstRouteUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteUnit secondRouteUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_1"))); routeContext.getRouteUnits().add(firstRouteUnit); routeContext.getRouteUnits().add(secondRouteUnit); - QueryContext queryContext = mock(QueryContext.class); - when(queryContext.getSqlStatementContext()).thenReturn(statementContext); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( new SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, queryContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); @@ -115,13 +121,11 @@ void assertRewriteWithGroupedParameterBuilderForBroadcast() { DatabaseType databaseType = mock(DatabaseType.class); when(statementContext.getDatabaseType()).thenReturn(databaseType); ShardingSphereDatabase database = mockDatabase(databaseType); - SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + QueryContext queryContext = mockQueryContext(statementContext, "INSERT INTO tbl VALUES (?)"); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); - QueryContext queryContext = mock(QueryContext.class); - when(queryContext.getSqlStatementContext()).thenReturn(statementContext); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( new SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, queryContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); @@ -139,15 +143,13 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() { DatabaseType databaseType = mock(DatabaseType.class); when(statementContext.getDatabaseType()).thenReturn(databaseType); ShardingSphereDatabase database = mockDatabase(databaseType); - SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + QueryContext queryContext = mockQueryContext(statementContext, "INSERT INTO tbl VALUES (?)"); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); // TODO check why data node is "ds.tbl_0", not "ds_0.tbl_0" routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds.tbl_0"))); - QueryContext queryContext = mock(QueryContext.class); - when(queryContext.getSqlStatementContext()).thenReturn(statementContext); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( new SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, queryContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); @@ -165,14 +167,12 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() { DatabaseType databaseType = mock(DatabaseType.class); when(statementContext.getDatabaseType()).thenReturn(databaseType); ShardingSphereDatabase database = mockDatabase(databaseType); - SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + QueryContext queryContext = mockQueryContext(statementContext, "INSERT INTO tbl VALUES (?)"); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); routeContext.getOriginalDataNodes().add(Collections.emptyList()); - QueryContext queryContext = mock(QueryContext.class); - when(queryContext.getSqlStatementContext()).thenReturn(statementContext); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( new SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, queryContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); @@ -190,14 +190,12 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithNotSameDataNode() { DatabaseType databaseType = mock(DatabaseType.class); when(statementContext.getDatabaseType()).thenReturn(databaseType); ShardingSphereDatabase database = mockDatabase(databaseType); - SQLRewriteContext sqlRewriteContext = - new SQLRewriteContext(database, statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); + QueryContext queryContext = mockQueryContext(statementContext, "INSERT INTO tbl VALUES (?)"); + SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(database, queryContext); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); RouteContext routeContext = new RouteContext(); routeContext.getRouteUnits().add(routeUnit); routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds_1.tbl_1"))); - QueryContext queryContext = mock(QueryContext.class); - when(queryContext.getSqlStatementContext()).thenReturn(statementContext); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( new SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build()), database, mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext, queryContext); assertThat(actual.getSqlRewriteUnits().size(), is(1));