Skip to content

Commit

Permalink
Minor refactor for encrypt merge decorate logic
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Oct 31, 2024
1 parent 649c338 commit 351878c
Show file tree
Hide file tree
Showing 23 changed files with 86 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.merge.engine.decorator.ResultDecorator;
import org.apache.shardingsphere.infra.merge.engine.decorator.ResultDecoratorEngine;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dal.DALStatement;

import java.util.Optional;
Expand All @@ -38,13 +38,13 @@
public final class EncryptResultDecoratorEngine implements ResultDecoratorEngine<EncryptRule> {

@Override
public Optional<ResultDecorator<EncryptRule>> newInstance(final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database,
public Optional<ResultDecorator<EncryptRule>> newInstance(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database,
final EncryptRule encryptRule, final ConfigurationProperties props, final SQLStatementContext sqlStatementContext) {
if (sqlStatementContext instanceof SelectStatementContext) {
return Optional.of(new EncryptDQLResultDecorator(database, encryptRule, (SelectStatementContext) sqlStatementContext));
return Optional.of(new EncryptDQLResultDecorator(database, metaData, (SelectStatementContext) sqlStatementContext));
}
if (sqlStatementContext.getSqlStatement() instanceof DALStatement) {
return Optional.of(new EncryptDALResultDecorator(globalRuleMetaData));
return Optional.of(new EncryptDALResultDecorator(metaData.getGlobalRuleMetaData()));
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.shardingsphere.infra.merge.engine.decorator.ResultDecorator;
import org.apache.shardingsphere.infra.merge.result.MergedResult;
import org.apache.shardingsphere.infra.merge.result.impl.transparent.TransparentMergedResult;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;

/**
Expand All @@ -35,17 +36,17 @@ public final class EncryptDQLResultDecorator implements ResultDecorator<EncryptR

private final ShardingSphereDatabase database;

private final EncryptRule encryptRule;
private final ShardingSphereMetaData metaData;

private final SelectStatementContext selectStatementContext;

@Override
public MergedResult decorate(final QueryResult queryResult, final SQLStatementContext sqlStatementContext, final EncryptRule rule) {
return new EncryptMergedResult(database, encryptRule, selectStatementContext, new TransparentMergedResult(queryResult));
return new EncryptMergedResult(database, metaData, selectStatementContext, new TransparentMergedResult(queryResult));
}

@Override
public MergedResult decorate(final MergedResult mergedResult, final SQLStatementContext sqlStatementContext, final EncryptRule rule) {
return new EncryptMergedResult(database, encryptRule, selectStatementContext, mergedResult);
return new EncryptMergedResult(database, metaData, selectStatementContext, mergedResult);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.core.external.sql.identifier.SQLExceptionIdentifier;
import org.apache.shardingsphere.infra.merge.result.MergedResult;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;

import java.io.InputStream;
Expand All @@ -42,7 +43,7 @@ public final class EncryptMergedResult implements MergedResult {

private final ShardingSphereDatabase database;

private final EncryptRule encryptRule;
private final ShardingSphereMetaData metaData;

private final SelectStatementContext selectStatementContext;

Expand All @@ -61,6 +62,10 @@ public Object getValue(final int columnIndex, final Class<?> type) throws SQLExc
}
String originalTableName = columnProjection.get().getOriginalTable().getValue();
String originalColumnName = columnProjection.get().getOriginalColumn().getValue();
ShardingSphereDatabase database = metaData.containsDatabase(columnProjection.get().getColumnBoundInfo().getOriginalDatabase().getValue())
? metaData.getDatabase(columnProjection.get().getColumnBoundInfo().getOriginalDatabase().getValue())
: this.database;
EncryptRule encryptRule = database.getRuleMetaData().getSingleRule(EncryptRule.class);
if (!encryptRule.findEncryptTable(originalTableName).map(optional -> optional.isEncryptColumn(originalColumnName)).orElse(false)) {
return mergedResult.getValue(columnIndex, type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ private ColumnProjection buildColumnProjection(final ColumnProjectionSegment seg
IdentifierValue owner = segment.getColumn().getOwner().map(OwnerSegment::getIdentifier).orElse(null);
ColumnProjection result = new ColumnProjection(owner, segment.getColumn().getIdentifier(), segment.getAliasName().isPresent() ? segment.getAlias().orElse(null) : null, databaseType,
segment.getColumn().getLeftParentheses().orElse(null), segment.getColumn().getRightParentheses().orElse(null));
result.setOriginalColumn(segment.getColumn().getColumnBoundInfo().getOriginalColumn());
result.setOriginalTable(segment.getColumn().getColumnBoundInfo().getOriginalTable());
result.setColumnBoundInfo(segment.getColumn().getColumnBoundInfo());
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.infra.exception.generic.UnsupportedSQLOperationException;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.combine.CombineSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -58,19 +59,15 @@ void assertCheckWhenCombineStatementContainsEncryptColumn() {
CombineSegment combineSegment = mock(CombineSegment.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getSqlStatement().getCombine().get()).thenReturn(combineSegment);
ColumnProjection orderIdColumn = new ColumnProjection("o", "order_id", null, new MySQLDatabaseType());
orderIdColumn.setOriginalTable(new IdentifierValue("t_order"));
orderIdColumn.setOriginalColumn(new IdentifierValue("order_id"));
orderIdColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue(""), new IdentifierValue("t_order"), new IdentifierValue("order_id")));
ColumnProjection userIdColumn = new ColumnProjection("o", "user_id", null, new MySQLDatabaseType());
userIdColumn.setOriginalTable(new IdentifierValue("t_order"));
userIdColumn.setOriginalColumn(new IdentifierValue("user_id"));
userIdColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue(""), new IdentifierValue("t_order"), new IdentifierValue("user_id")));
SelectStatementContext leftSelectStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(leftSelectStatementContext.getProjectionsContext().getExpandProjections()).thenReturn(Arrays.asList(orderIdColumn, userIdColumn));
ColumnProjection merchantIdColumn = new ColumnProjection("m", "merchant_id", null, new MySQLDatabaseType());
merchantIdColumn.setOriginalTable(new IdentifierValue("t_merchant"));
merchantIdColumn.setOriginalColumn(new IdentifierValue("merchant_id"));
merchantIdColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue(""), new IdentifierValue("t_merchant"), new IdentifierValue("merchant_id")));
ColumnProjection merchantNameColumn = new ColumnProjection("m", "merchant_name", null, new MySQLDatabaseType());
merchantNameColumn.setOriginalTable(new IdentifierValue("t_merchant"));
merchantNameColumn.setOriginalColumn(new IdentifierValue("merchant_name"));
merchantNameColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue(""), new IdentifierValue("t_merchant"), new IdentifierValue("merchant_name")));
SelectStatementContext rightSelectStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(rightSelectStatementContext.getProjectionsContext().getExpandProjections()).thenReturn(Arrays.asList(merchantIdColumn, merchantNameColumn));
Map<Integer, SelectStatementContext> subqueryContexts = new LinkedHashMap<>(2, 1F);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.merge.engine.ResultProcessEngine;
import org.apache.shardingsphere.infra.merge.engine.decorator.ResultDecorator;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.spi.type.ordered.OrderedSPILoader;
import org.apache.shardingsphere.sql.parser.statement.mysql.dal.MySQLExplainStatement;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -60,7 +60,7 @@ class EncryptResultDecoratorEngineTest {
void assertNewInstanceWithSelectStatement() {
EncryptResultDecoratorEngine engine = (EncryptResultDecoratorEngine) OrderedSPILoader.getServices(ResultProcessEngine.class, Collections.singleton(rule)).get(rule);
Optional<ResultDecorator<EncryptRule>> actual =
engine.newInstance(mock(RuleMetaData.class), database, rule, mock(ConfigurationProperties.class), mock(SelectStatementContext.class, RETURNS_DEEP_STUBS));
engine.newInstance(mock(ShardingSphereMetaData.class), database, rule, mock(ConfigurationProperties.class), mock(SelectStatementContext.class, RETURNS_DEEP_STUBS));
assertTrue(actual.isPresent());
assertThat(actual.get(), instanceOf(EncryptDQLResultDecorator.class));
}
Expand All @@ -70,14 +70,14 @@ void assertNewInstanceWithDALStatement() {
SQLStatementContext sqlStatementContext = mock(ExplainStatementContext.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(mock(MySQLExplainStatement.class));
EncryptResultDecoratorEngine engine = (EncryptResultDecoratorEngine) OrderedSPILoader.getServices(ResultProcessEngine.class, Collections.singleton(rule)).get(rule);
Optional<ResultDecorator<EncryptRule>> actual = engine.newInstance(mock(RuleMetaData.class), database, rule, mock(ConfigurationProperties.class), sqlStatementContext);
Optional<ResultDecorator<EncryptRule>> actual = engine.newInstance(mock(ShardingSphereMetaData.class), database, rule, mock(ConfigurationProperties.class), sqlStatementContext);
assertTrue(actual.isPresent());
assertThat(actual.get(), instanceOf(EncryptDALResultDecorator.class));
}

@Test
void assertNewInstanceWithOtherStatement() {
EncryptResultDecoratorEngine engine = (EncryptResultDecoratorEngine) OrderedSPILoader.getServices(ResultProcessEngine.class, Collections.singleton(rule)).get(rule);
assertFalse(engine.newInstance(mock(RuleMetaData.class), database, rule, mock(ConfigurationProperties.class), mock(InsertStatementContext.class)).isPresent());
assertFalse(engine.newInstance(mock(ShardingSphereMetaData.class), database, rule, mock(ConfigurationProperties.class), mock(InsertStatementContext.class)).isPresent());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.merge.result.MergedResult;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.junit.jupiter.api.Test;

Expand All @@ -37,7 +38,7 @@ class EncryptDQLResultDecoratorTest {
void assertDecorateQueryResult() throws SQLException {
QueryResult queryResult = mock(QueryResult.class);
when(queryResult.next()).thenReturn(true);
EncryptDQLResultDecorator decorator = new EncryptDQLResultDecorator(mock(ShardingSphereDatabase.class), mock(EncryptRule.class), mock(SelectStatementContext.class));
EncryptDQLResultDecorator decorator = new EncryptDQLResultDecorator(mock(ShardingSphereDatabase.class), mock(ShardingSphereMetaData.class), mock(SelectStatementContext.class));
MergedResult actual = decorator.decorate(queryResult, mock(SQLStatementContext.class), mock(EncryptRule.class));
assertTrue(actual.next());
}
Expand All @@ -46,7 +47,7 @@ void assertDecorateQueryResult() throws SQLException {
void assertDecorateMergedResult() throws SQLException {
MergedResult mergedResult = mock(MergedResult.class);
when(mergedResult.next()).thenReturn(true);
EncryptDQLResultDecorator decorator = new EncryptDQLResultDecorator(mock(ShardingSphereDatabase.class), mock(EncryptRule.class), mock(SelectStatementContext.class));
EncryptDQLResultDecorator decorator = new EncryptDQLResultDecorator(mock(ShardingSphereDatabase.class), mock(ShardingSphereMetaData.class), mock(SelectStatementContext.class));
MergedResult actual = decorator.decorate(mergedResult, mock(SQLStatementContext.class), mock(EncryptRule.class));
assertTrue(actual.next());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

package org.apache.shardingsphere.encrypt.merge.dql;

import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.merge.result.MergedResult;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -45,7 +45,7 @@ class EncryptMergedResultTest {
private ShardingSphereDatabase database;

@Mock
private EncryptRule encryptRule;
private ShardingSphereMetaData metaData;

@Mock
private SelectStatementContext selectStatementContext;
Expand All @@ -55,32 +55,32 @@ class EncryptMergedResultTest {

@Test
void assertNext() throws SQLException {
assertFalse(new EncryptMergedResult(database, encryptRule, selectStatementContext, mergedResult).next());
assertFalse(new EncryptMergedResult(database, metaData, selectStatementContext, mergedResult).next());
}

@Test
void assertGetCalendarValue() throws SQLException {
Calendar calendar = Calendar.getInstance();
when(mergedResult.getCalendarValue(1, Date.class, calendar)).thenReturn(new Date(0L));
assertThat(new EncryptMergedResult(database, encryptRule, selectStatementContext, mergedResult).getCalendarValue(1, Date.class, calendar), is(new Date(0L)));
assertThat(new EncryptMergedResult(database, metaData, selectStatementContext, mergedResult).getCalendarValue(1, Date.class, calendar), is(new Date(0L)));
}

@Test
void assertGetInputStream() throws SQLException {
InputStream inputStream = mock(InputStream.class);
when(mergedResult.getInputStream(1, "asc")).thenReturn(inputStream);
assertThat(new EncryptMergedResult(database, encryptRule, selectStatementContext, mergedResult).getInputStream(1, "asc"), is(inputStream));
assertThat(new EncryptMergedResult(database, metaData, selectStatementContext, mergedResult).getInputStream(1, "asc"), is(inputStream));
}

@Test
void assertGetCharacterStream() throws SQLException {
Reader reader = mock(Reader.class);
when(mergedResult.getCharacterStream(1)).thenReturn(reader);
assertThat(new EncryptMergedResult(database, encryptRule, selectStatementContext, mergedResult).getCharacterStream(1), is(reader));
assertThat(new EncryptMergedResult(database, metaData, selectStatementContext, mergedResult).getCharacterStream(1), is(reader));
}

@Test
void assertWasNull() throws SQLException {
assertFalse(new EncryptMergedResult(database, encryptRule, selectStatementContext, mergedResult).wasNull());
assertFalse(new EncryptMergedResult(database, metaData, selectStatementContext, mergedResult).wasNull());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.merge.engine.decorator.ResultDecorator;
import org.apache.shardingsphere.infra.merge.engine.decorator.ResultDecoratorEngine;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.mask.constant.MaskOrder;
import org.apache.shardingsphere.mask.merge.dql.MaskDQLResultDecorator;
import org.apache.shardingsphere.mask.rule.MaskRule;
Expand All @@ -36,7 +36,7 @@
public final class MaskResultDecoratorEngine implements ResultDecoratorEngine<MaskRule> {

@Override
public Optional<ResultDecorator<MaskRule>> newInstance(final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database,
public Optional<ResultDecorator<MaskRule>> newInstance(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database,
final MaskRule maskRule, final ConfigurationProperties props, final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof SelectStatementContext ? Optional.of(new MaskDQLResultDecorator(maskRule, (SelectStatementContext) sqlStatementContext)) : Optional.empty();
}
Expand Down
Loading

0 comments on commit 351878c

Please sign in to comment.