Skip to content

Commit

Permalink
Support copy statement sql bind and add bind test case
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Dec 26, 2024
1 parent 2482ee4 commit d9037d3
Show file tree
Hide file tree
Showing 13 changed files with 275 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ public final class ShardingCopySupportedChecker implements SupportedSQLChecker<C

@Override
public boolean isCheck(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof CopyStatementContext;
return sqlStatementContext instanceof CopyStatementContext && ((CopyStatementContext) sqlStatementContext).getSqlStatement().getTable().isPresent();
}

@Override
public void check(final ShardingRule rule, final ShardingSphereDatabase database, final ShardingSphereSchema currentSchema, final CopyStatementContext sqlStatementContext) {
String tableName = sqlStatementContext.getSqlStatement().getTableSegment().getTableName().getIdentifier().getValue();
String tableName = sqlStatementContext.getSqlStatement().getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
ShardingSpherePreconditions.checkState(!rule.isShardingTable(tableName), () -> new UnsupportedShardingOperationException("COPY", tableName));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ class ShardingCopySupportedCheckerTest {
@Test
void assertCheckWhenTableSegmentForPostgreSQL() {
PostgreSQLCopyStatement sqlStatement = new PostgreSQLCopyStatement();
sqlStatement.setTableSegment(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
assertDoesNotThrow(() -> new ShardingCopySupportedChecker().check(rule, mock(), mock(), new CopyStatementContext(sqlStatement)));
}

@Test
void assertCheckWhenTableSegmentForOpenGauss() {
OpenGaussCopyStatement sqlStatement = new OpenGaussCopyStatement();
sqlStatement.setTableSegment(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
assertDoesNotThrow(() -> new ShardingCopySupportedChecker().check(rule, mock(), mock(), new CopyStatementContext(sqlStatement)));
}

Expand All @@ -67,7 +67,7 @@ void assertCheckCopyWithShardingTableForOpenGauss() {
}

private void assertCheckCopyTable(final CopyStatement sqlStatement) {
sqlStatement.setTableSegment(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
CopyStatementContext sqlStatementContext = new CopyStatementContext(sqlStatement);
String tableName = "t_order";
when(rule.isShardingTable(tableName)).thenReturn(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@
import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.CopyStatement;

import java.util.Collection;
import java.util.Collections;

/**
* Copy statement context.
*/
Expand All @@ -33,7 +37,8 @@ public final class CopyStatementContext extends CommonSQLStatementContext implem

public CopyStatementContext(final CopyStatement sqlStatement) {
super(sqlStatement);
tablesContext = new TablesContext(sqlStatement.getTableSegment());
Collection<SimpleTableSegment> tables = sqlStatement.getTable().isPresent() ? Collections.singleton(sqlStatement.getTable().get()) : Collections.emptyList();
tablesContext = new TablesContext(tables);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
*/
public enum SegmentType {

PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK, SET_ASSIGNMENT, VALUES, INSERT_COLUMNS, DEFINITION_COLUMNS
PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK, SET_ASSIGNMENT, VALUES, COPY, INSERT_COLUMNS, DEFINITION_COLUMNS
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.infra.binder.engine.segment.dml.prepare;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.engine.statement.dml.DeleteStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.dml.InsertStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.dml.SelectStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.dml.UpdateStatementBinder;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.prepare.PrepareStatementQuerySegment;

/**
* Prepare statement query segment binder.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class PrepareStatementQuerySegmentBinder {

/**
* Bind prepare statement query segment.
*
* @param segment prepare statement query segment
* @param binderContext SQL statement binder context
* @return bound prepare statement query segment
*/
public static PrepareStatementQuerySegment bind(final PrepareStatementQuerySegment segment, final SQLStatementBinderContext binderContext) {
PrepareStatementQuerySegment result = new PrepareStatementQuerySegment(segment.getStartIndex(), segment.getStopIndex());
segment.getSelect().ifPresent(optional -> result.setSelect(new SelectStatementBinder().bind(optional, binderContext)));
segment.getInsert().ifPresent(optional -> result.setInsert(new InsertStatementBinder().bind(optional, binderContext)));
segment.getUpdate().ifPresent(optional -> result.setUpdate(new UpdateStatementBinder().bind(optional, binderContext)));
segment.getDelete().ifPresent(optional -> result.setDelete(new DeleteStatementBinder().bind(optional, binderContext)));
return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.infra.binder.engine.statement.dml;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import lombok.SneakyThrows;
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
import org.apache.shardingsphere.infra.binder.engine.segment.dml.expression.type.ColumnSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.dml.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.dml.from.type.SimpleTableSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.dml.prepare.PrepareStatementQuerySegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.CopyStatement;

/**
* Copy statement binder.
*/
public final class CopyStatementBinder implements SQLStatementBinder<CopyStatement> {

@Override
public CopyStatement bind(final CopyStatement sqlStatement, final SQLStatementBinderContext binderContext) {
CopyStatement result = copy(sqlStatement);
Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts = LinkedHashMultimap.create();
sqlStatement.getTable().ifPresent(optional -> result.setTable(SimpleTableSegmentBinder.bind(optional, binderContext, tableBinderContexts)));
sqlStatement.getPrepareStatementQuery().ifPresent(optional -> result.setPrepareStatementQuery(PrepareStatementQuerySegmentBinder.bind(optional, binderContext)));
sqlStatement.getColumns().forEach(each -> result.getColumns().add(ColumnSegmentBinder.bind(each, SegmentType.COPY, binderContext, tableBinderContexts, LinkedHashMultimap.create())));
return result;
}

@SneakyThrows(ReflectiveOperationException.class)
private CopyStatement copy(final CopyStatement sqlStatement) {
CopyStatement result = sqlStatement.getClass().getDeclaredConstructor().newInstance();
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
result.getVariableNames().addAll(sqlStatement.getVariableNames());
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.engine.statement.dml.CopyStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.dml.DeleteStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.dml.InsertStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.dml.SelectStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.dml.UpdateStatementBinder;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.CopyStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.DMLStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.DeleteStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.InsertStatement;
Expand Down Expand Up @@ -63,6 +65,9 @@ public DMLStatement bind(final DMLStatement statement) {
if (statement instanceof DeleteStatement) {
return new DeleteStatementBinder().bind((DeleteStatement) statement, binderContext);
}
if (statement instanceof CopyStatement) {
return new CopyStatementBinder().bind((CopyStatement) statement, binderContext);
}
return statement;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public ASTNode visitDoStatement(final DoStatementContext ctx) {
public ASTNode visitCopy(final CopyContext ctx) {
OpenGaussCopyStatement result = new OpenGaussCopyStatement();
if (null != ctx.qualifiedName()) {
result.setTableSegment((SimpleTableSegment) visit(ctx.qualifiedName()));
result.setTable((SimpleTableSegment) visit(ctx.qualifiedName()));
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public ASTNode visitCopy(final CopyContext ctx) {
public ASTNode visitCopyWithTableOrQuery(final CopyWithTableOrQueryContext ctx) {
PostgreSQLCopyStatement result = new PostgreSQLCopyStatement();
if (null != ctx.qualifiedName()) {
result.setTableSegment((SimpleTableSegment) visit(ctx.qualifiedName()));
result.setTable((SimpleTableSegment) visit(ctx.qualifiedName()));
if (null != ctx.columnNames()) {
result.getColumns().addAll(((CollectionValue<ColumnSegment>) visit(ctx.columnNames())).getValue());
}
Expand Down Expand Up @@ -127,7 +127,7 @@ private PrepareStatementQuerySegment extractPrepareStatementQuerySegmentFromPrep
public ASTNode visitCopyWithTableOrQueryBinaryCsv(final CopyWithTableOrQueryBinaryCsvContext ctx) {
PostgreSQLCopyStatement result = new PostgreSQLCopyStatement();
if (null != ctx.qualifiedName()) {
result.setTableSegment((SimpleTableSegment) visit(ctx.qualifiedName()));
result.setTable((SimpleTableSegment) visit(ctx.qualifiedName()));
if (null != ctx.columnNames()) {
result.getColumns().addAll(((CollectionValue<ColumnSegment>) visit(ctx.columnNames())).getValue());
}
Expand All @@ -142,7 +142,7 @@ public ASTNode visitCopyWithTableOrQueryBinaryCsv(final CopyWithTableOrQueryBina
public ASTNode visitCopyWithTableBinary(final CopyWithTableBinaryContext ctx) {
PostgreSQLCopyStatement result = new PostgreSQLCopyStatement();
if (null != ctx.qualifiedName()) {
result.setTableSegment((SimpleTableSegment) visit(ctx.qualifiedName()));
result.setTable((SimpleTableSegment) visit(ctx.qualifiedName()));
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,24 @@
@Setter
public abstract class CopyStatement extends AbstractSQLStatement implements DMLStatement {

private SimpleTableSegment tableSegment;
private SimpleTableSegment table;

/**
* Get table.
*
* @return table
*/
public Optional<SimpleTableSegment> getTable() {
return Optional.ofNullable(table);
}

/**
* Set prepare statement query segment.
*
* @param prepareStatementQuery prepare statement query segment
*/
public void setPrepareStatementQuery(final PrepareStatementQuerySegment prepareStatementQuery) {
}

/**
* Get prepare statement query segment.
Expand Down
Loading

0 comments on commit d9037d3

Please sign in to comment.