diff --git a/infra/session/src/main/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionConnectionContext.java b/infra/session/src/main/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionConnectionContext.java index 1b62261baf772..67f058db78db6 100644 --- a/infra/session/src/main/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionConnectionContext.java +++ b/infra/session/src/main/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionConnectionContext.java @@ -41,14 +41,18 @@ public final class TransactionConnectionContext implements AutoCloseable { @Setter private volatile String readWriteSplitReplicaRoute; + private volatile TransactionManager transactionManager; + /** * Begin transaction. * - * @param transactionType transaction type + * @param transactionType transaction type + * @param transactionManager transaction manager */ - public void beginTransaction(final String transactionType) { + public void beginTransaction(final String transactionType, final TransactionManager transactionManager) { this.transactionType = transactionType; inTransaction = true; + this.transactionManager = transactionManager; } /** @@ -78,6 +82,15 @@ public Optional getReadWriteSplitReplicaRoute() { return Optional.ofNullable(readWriteSplitReplicaRoute); } + /** + * Get transaction manager. + * + * @return transaction manager + */ + public Optional getTransactionManager() { + return Optional.ofNullable(transactionManager); + } + @Override public void close() { transactionType = null; @@ -85,5 +98,6 @@ public void close() { beginMills = 0L; exceptionOccur = false; readWriteSplitReplicaRoute = null; + transactionManager = null; } } diff --git a/infra/session/src/main/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionManager.java b/infra/session/src/main/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionManager.java new file mode 100644 index 0000000000000..b38dcbc195a00 --- /dev/null +++ b/infra/session/src/main/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionManager.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.infra.session.connection.transaction; + +/** + * Transaction manager. + */ +public interface TransactionManager { +} diff --git a/infra/session/src/test/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionConnectionContextTest.java b/infra/session/src/test/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionConnectionContextTest.java index eb4bc69cc2623..0d0dd0d0993e4 100644 --- a/infra/session/src/test/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionConnectionContextTest.java +++ b/infra/session/src/test/java/org/apache/shardingsphere/infra/session/connection/transaction/TransactionConnectionContextTest.java @@ -25,6 +25,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; class TransactionConnectionContextTest { @@ -32,7 +33,7 @@ class TransactionConnectionContextTest { @Test void assertBeginTransaction() { - transactionConnectionContext.beginTransaction("XA"); + transactionConnectionContext.beginTransaction("XA", mock(TransactionManager.class)); assertThat(transactionConnectionContext.getTransactionType(), is(Optional.of("XA"))); assertTrue(transactionConnectionContext.isInTransaction()); } @@ -44,19 +45,19 @@ void assertIsNotInDistributedTransactionWhenNotBegin() { @Test void assertIsNotInDistributedTransactionWithLocal() { - transactionConnectionContext.beginTransaction("LOCAL"); + transactionConnectionContext.beginTransaction("LOCAL", mock(TransactionManager.class)); assertFalse(transactionConnectionContext.isInDistributedTransaction()); } @Test void assertIsInDistributedTransactionWithXA() { - transactionConnectionContext.beginTransaction("XA"); + transactionConnectionContext.beginTransaction("XA", mock(TransactionManager.class)); assertTrue(transactionConnectionContext.isInDistributedTransaction()); } @Test void assertIsInDistributedTransactionWithBASE() { - transactionConnectionContext.beginTransaction("BASE"); + transactionConnectionContext.beginTransaction("BASE", mock(TransactionManager.class)); assertTrue(transactionConnectionContext.isInDistributedTransaction()); } @@ -68,7 +69,7 @@ void assertGetReadWriteSplitReplicaRoute() { @Test void assertClose() { - transactionConnectionContext.beginTransaction("XA"); + transactionConnectionContext.beginTransaction("XA", mock(TransactionManager.class)); transactionConnectionContext.close(); assertFalse(transactionConnectionContext.getTransactionType().isPresent()); assertFalse(transactionConnectionContext.isInTransaction()); diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/DriverDatabaseConnectionManager.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/DriverDatabaseConnectionManager.java index bc4779a15039f..5218b69f5a800 100644 --- a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/DriverDatabaseConnectionManager.java +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/DriverDatabaseConnectionManager.java @@ -118,7 +118,7 @@ public void begin() throws SQLException { close(); connectionTransaction.begin(); } - connectionContext.getTransactionContext().beginTransaction(String.valueOf(connectionTransaction.getTransactionType())); + connectionContext.getTransactionContext().beginTransaction(String.valueOf(connectionTransaction.getTransactionType()), connectionTransaction.getDistributedTransactionManager()); } /** diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/ShardingSphereConnection.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/ShardingSphereConnection.java index 7da8af0ac3531..c41917e774504 100644 --- a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/ShardingSphereConnection.java +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/connection/ShardingSphereConnection.java @@ -171,7 +171,8 @@ private void processLocalTransaction() throws SQLException { return; } if (!autoCommit && !transactionContext.isInTransaction()) { - transactionContext.beginTransaction(String.valueOf(contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class).getDefaultType())); + transactionContext.beginTransaction(contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class).getDefaultType().name(), + databaseConnectionManager.getConnectionTransaction().getDistributedTransactionManager()); } } diff --git a/kernel/transaction/api/src/main/java/org/apache/shardingsphere/transaction/spi/ShardingSphereDistributedTransactionManager.java b/kernel/transaction/api/src/main/java/org/apache/shardingsphere/transaction/spi/ShardingSphereDistributedTransactionManager.java index 8a61f60b89803..94c53b99ae178 100644 --- a/kernel/transaction/api/src/main/java/org/apache/shardingsphere/transaction/spi/ShardingSphereDistributedTransactionManager.java +++ b/kernel/transaction/api/src/main/java/org/apache/shardingsphere/transaction/spi/ShardingSphereDistributedTransactionManager.java @@ -18,6 +18,7 @@ package org.apache.shardingsphere.transaction.spi; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; +import org.apache.shardingsphere.infra.session.connection.transaction.TransactionManager; import org.apache.shardingsphere.infra.spi.type.typed.TypedSPI; import org.apache.shardingsphere.transaction.api.TransactionType; @@ -29,7 +30,7 @@ /** * ShardingSphere distributed transaction manager. */ -public interface ShardingSphereDistributedTransactionManager extends TypedSPI, AutoCloseable { +public interface ShardingSphereDistributedTransactionManager extends TypedSPI, AutoCloseable, TransactionManager { /** * Initialize distributed transaction manager. diff --git a/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/ConnectionTransaction.java b/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/ConnectionTransaction.java index c0456ef3ab3c4..aa82ded550159 100644 --- a/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/ConnectionTransaction.java +++ b/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/ConnectionTransaction.java @@ -35,6 +35,7 @@ public final class ConnectionTransaction { @Getter private final TransactionType transactionType; + @Getter private final ShardingSphereDistributedTransactionManager distributedTransactionManager; private final TransactionConnectionContext transactionContext; @@ -42,7 +43,11 @@ public final class ConnectionTransaction { public ConnectionTransaction(final TransactionRule rule, final TransactionConnectionContext transactionContext) { transactionType = transactionContext.getTransactionType().isPresent() ? TransactionType.valueOf(transactionContext.getTransactionType().get()) : rule.getDefaultType(); this.transactionContext = transactionContext; - distributedTransactionManager = TransactionType.LOCAL == transactionType ? null : rule.getResource().getTransactionManager(rule.getDefaultType()); + if (transactionContext.getTransactionManager().isPresent()) { + distributedTransactionManager = (ShardingSphereDistributedTransactionManager) transactionContext.getTransactionManager().get(); + } else { + distributedTransactionManager = TransactionType.LOCAL == this.transactionType ? null : rule.getResource().getTransactionManager(rule.getDefaultType()); + } } /** diff --git a/kernel/transaction/core/src/test/java/org/apache/shardingsphere/transaction/ConnectionTransactionTest.java b/kernel/transaction/core/src/test/java/org/apache/shardingsphere/transaction/ConnectionTransactionTest.java index 3fde297470451..23a832a7e370d 100644 --- a/kernel/transaction/core/src/test/java/org/apache/shardingsphere/transaction/ConnectionTransactionTest.java +++ b/kernel/transaction/core/src/test/java/org/apache/shardingsphere/transaction/ConnectionTransactionTest.java @@ -48,21 +48,23 @@ void assertIsNotInDistributedTransactionWhenTransactionIsNotBegin() { @Test void assertIsNotInDistributedTransactionWhenIsNotDistributedTransaction() { TransactionConnectionContext context = new TransactionConnectionContext(); - context.beginTransaction("LOCAL"); + context.beginTransaction("LOCAL", mock(ShardingSphereDistributedTransactionManager.class)); assertFalse(new ConnectionTransaction(mock(TransactionRule.class), context).isInDistributedTransaction(context)); } @Test void assertIsNotInDistributedTransactionWhenDistributedTransactionIsNotBegin() { TransactionConnectionContext context = new TransactionConnectionContext(); - context.beginTransaction("XA"); + context.beginTransaction("XA", mock(ShardingSphereDistributedTransactionManager.class)); assertFalse(new ConnectionTransaction(mock(TransactionRule.class, RETURNS_DEEP_STUBS), context).isInDistributedTransaction(context)); } @Test void assertIsInDistributedTransaction() { TransactionConnectionContext context = new TransactionConnectionContext(); - context.beginTransaction("XA"); + ShardingSphereDistributedTransactionManager distributedTransactionManager = mock(ShardingSphereDistributedTransactionManager.class); + when(distributedTransactionManager.isInTransaction()).thenReturn(true); + context.beginTransaction("XA", distributedTransactionManager); TransactionRule rule = mock(TransactionRule.class, RETURNS_DEEP_STUBS); when(rule.getResource().getTransactionManager(rule.getDefaultType()).isInTransaction()).thenReturn(true); assertTrue(new ConnectionTransaction(rule, context).isInDistributedTransaction(context)); @@ -95,7 +97,9 @@ void assertIsHoldTransactionWithXAAndAutoCommit() { when(rule.getDefaultType()).thenReturn(TransactionType.XA); when(rule.getResource().getTransactionManager(TransactionType.XA).isInTransaction()).thenReturn(true); TransactionConnectionContext context = new TransactionConnectionContext(); - context.beginTransaction("XA"); + ShardingSphereDistributedTransactionManager distributedTransactionManager = mock(ShardingSphereDistributedTransactionManager.class); + when(distributedTransactionManager.isInTransaction()).thenReturn(true); + context.beginTransaction("XA", distributedTransactionManager); assertTrue(new ConnectionTransaction(rule, context).isHoldTransaction(true)); } @@ -144,7 +148,10 @@ void assertGetConnectionWithoutInDistributeTransaction() throws SQLException { @Test void assertGetConnectionWithInDistributeTransaction() throws SQLException { TransactionConnectionContext context = new TransactionConnectionContext(); - context.beginTransaction("XA"); + ShardingSphereDistributedTransactionManager distributedTransactionManager = mock(ShardingSphereDistributedTransactionManager.class); + when(distributedTransactionManager.isInTransaction()).thenReturn(true); + when(distributedTransactionManager.getConnection("foo_db", "foo_ds")).thenReturn(mock(Connection.class)); + context.beginTransaction("XA", distributedTransactionManager); TransactionRule rule = mock(TransactionRule.class, RETURNS_DEEP_STUBS); when(rule.getResource().getTransactionManager(rule.getDefaultType()).isInTransaction()).thenReturn(true); when(rule.getResource().getTransactionManager(rule.getDefaultType()).getConnection("foo_db", "foo_ds")).thenReturn(mock(Connection.class)); diff --git a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/jdbc/transaction/BackendTransactionManager.java b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/jdbc/transaction/BackendTransactionManager.java index 384c0e3445062..53bad50aceca1 100644 --- a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/jdbc/transaction/BackendTransactionManager.java +++ b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/jdbc/transaction/BackendTransactionManager.java @@ -70,7 +70,7 @@ public BackendTransactionManager(final ProxyDatabaseConnectionManager databaseCo public void begin() { if (!connection.getConnectionSession().getTransactionStatus().isInTransaction()) { connection.getConnectionSession().getTransactionStatus().setInTransaction(true); - getTransactionContext().beginTransaction(String.valueOf(transactionType)); + getTransactionContext().beginTransaction(transactionType.name(), distributedTransactionManager); connection.closeHandlers(true); connection.closeConnections(false); } diff --git a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/transaction/TransactionXAHandler.java b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/transaction/TransactionXAHandler.java index 2cce1c63bac7d..6b404c50e2219 100644 --- a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/transaction/TransactionXAHandler.java +++ b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/transaction/TransactionXAHandler.java @@ -29,13 +29,13 @@ import org.apache.shardingsphere.proxy.backend.response.data.QueryResponseRow; import org.apache.shardingsphere.proxy.backend.response.header.ResponseHeader; import org.apache.shardingsphere.proxy.backend.session.ConnectionSession; -import org.apache.shardingsphere.proxy.backend.util.TransactionUtils; import org.apache.shardingsphere.sql.parser.statement.core.statement.tcl.xa.XABeginStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.tcl.xa.XACommitStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.tcl.xa.XARecoveryStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.tcl.xa.XARollbackStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.tcl.xa.XAStatement; -import org.apache.shardingsphere.transaction.api.TransactionType; +import org.apache.shardingsphere.transaction.ShardingSphereTransactionManagerEngine; +import org.apache.shardingsphere.transaction.rule.TransactionRule; import org.apache.shardingsphere.transaction.xa.jta.exception.XATransactionNestedBeginException; import java.sql.SQLException; @@ -88,8 +88,9 @@ public ResponseHeader execute() throws SQLException { private ResponseHeader begin() throws SQLException { ShardingSpherePreconditions.checkState(!connectionSession.getTransactionStatus().isInTransaction(), XATransactionNestedBeginException::new); ResponseHeader result = backendHandler.execute(); - TransactionType transactionType = TransactionUtils.getTransactionType(connectionSession.getConnectionContext().getTransactionContext()); - connectionSession.getConnectionContext().getTransactionContext().beginTransaction(String.valueOf(transactionType)); + TransactionRule transactionRule = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class); + ShardingSphereTransactionManagerEngine engine = transactionRule.getResource(); + connectionSession.getConnectionContext().getTransactionContext().beginTransaction(transactionRule.getDefaultType().name(), engine.getTransactionManager(transactionRule.getDefaultType())); return result; } diff --git a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/transaction/TransactionBackendHandlerTest.java b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/transaction/TransactionBackendHandlerTest.java index c08e8350e919d..4ab0b3005b65f 100644 --- a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/transaction/TransactionBackendHandlerTest.java +++ b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/transaction/TransactionBackendHandlerTest.java @@ -26,6 +26,7 @@ import org.apache.shardingsphere.sql.parser.statement.core.statement.tcl.TCLStatement; import org.apache.shardingsphere.test.mock.AutoMockExtension; import org.apache.shardingsphere.test.mock.StaticMockSettings; +import org.apache.shardingsphere.transaction.api.TransactionType; import org.apache.shardingsphere.transaction.core.TransactionOperationType; import org.apache.shardingsphere.transaction.rule.TransactionRule; import org.junit.jupiter.api.Test; @@ -58,7 +59,9 @@ void assertExecute() throws SQLException { private ContextManager mockContextManager() { ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS); - when(result.getMetaDataContexts().getMetaData().getGlobalRuleMetaData()).thenReturn(new RuleMetaData(Collections.singleton(mock(TransactionRule.class)))); + TransactionRule transactionRule = mock(TransactionRule.class); + when(transactionRule.getDefaultType()).thenReturn(TransactionType.LOCAL); + when(result.getMetaDataContexts().getMetaData().getGlobalRuleMetaData()).thenReturn(new RuleMetaData(Collections.singleton(transactionRule))); return result; } } diff --git a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/session/ConnectionSessionTest.java b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/session/ConnectionSessionTest.java index 96da820ed6335..9ae58ecc1b721 100644 --- a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/session/ConnectionSessionTest.java +++ b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/session/ConnectionSessionTest.java @@ -27,6 +27,7 @@ import org.apache.shardingsphere.proxy.backend.context.ProxyContext; import org.apache.shardingsphere.test.mock.AutoMockExtension; import org.apache.shardingsphere.test.mock.StaticMockSettings; +import org.apache.shardingsphere.transaction.api.TransactionType; import org.apache.shardingsphere.transaction.rule.TransactionRule; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -82,7 +83,9 @@ void assertSwitchSchemaWhileBegin() { private ContextManager mockContextManager() { ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS); - when(result.getMetaDataContexts().getMetaData().getGlobalRuleMetaData()).thenReturn(new RuleMetaData(Collections.singleton(mock(TransactionRule.class)))); + TransactionRule transactionRule = mock(TransactionRule.class); + when(transactionRule.getDefaultType()).thenReturn(TransactionType.LOCAL); + when(result.getMetaDataContexts().getMetaData().getGlobalRuleMetaData()).thenReturn(new RuleMetaData(Collections.singleton(transactionRule))); return result; }