From b42a4715e8b9070b074756118763970a49718988 Mon Sep 17 00:00:00 2001 From: "JinHyouk York (Ivan)" Date: Thu, 20 Jul 2023 21:09:53 +0900 Subject: [PATCH] feat: Add spring mutli container support (#1781) * feat: fix to support spring multi container * test: add test for spring multi container --- .../spring/SpringTransactionManager.kt | 55 +++-- .../SpringMultiContainerTransactionTest.kt | 225 ++++++++++++++++++ .../spring/SpringTransactionEntityTest.kt | 41 +++- 3 files changed, 294 insertions(+), 27 deletions(-) create mode 100644 spring-transaction/src/test/kotlin/org/jetbrains/exposed/spring/SpringMultiContainerTransactionTest.kt diff --git a/spring-transaction/src/main/kotlin/org/jetbrains/exposed/spring/SpringTransactionManager.kt b/spring-transaction/src/main/kotlin/org/jetbrains/exposed/spring/SpringTransactionManager.kt index 4a7dc585df..0d623da0ad 100644 --- a/spring-transaction/src/main/kotlin/org/jetbrains/exposed/spring/SpringTransactionManager.kt +++ b/spring-transaction/src/main/kotlin/org/jetbrains/exposed/spring/SpringTransactionManager.kt @@ -38,7 +38,8 @@ class SpringTransactionManager( databaseConfig = databaseConfig ) { this } - @Volatile override var defaultIsolationLevel: Int = -1 + @Volatile + override var defaultIsolationLevel: Int = -1 get() { if (field == -1) { field = Database.getDefaultIsolationLevel(db) @@ -46,29 +47,50 @@ class SpringTransactionManager( return field } - private val springTxKey = "SPRING_TX_KEY" + private val transactionStackKey = "SPRING_TRANSACTION_STACK_KEY" + + private fun getTransactionStack(): List { + return TransactionSynchronizationManager.getResource(transactionStackKey) + ?.let { it as List } + ?: listOf() + } + + private fun setTransactionStack(list: List) { + TransactionSynchronizationManager.unbindResourceIfPossible(transactionStackKey) + TransactionSynchronizationManager.bindResource(transactionStackKey, list) + } + + private fun pushTransactionStack(transaction: TransactionManager) { + val transactionList = getTransactionStack() + setTransactionStack(transactionList + transaction) + } + + private fun popTransactionStack() = setTransactionStack(getTransactionStack().dropLast(1)) + + private fun getLastTransactionStack() = getTransactionStack().lastOrNull() override fun doBegin(transaction: Any, definition: TransactionDefinition) { super.doBegin(transaction, definition) if (TransactionSynchronizationManager.hasResource(obtainDataSource())) { - currentOrNull() ?: initTransaction() - } - if (!TransactionSynchronizationManager.hasResource(springTxKey)) { - TransactionSynchronizationManager.bindResource(springTxKey, transaction) + currentOrNull() ?: initTransaction(transaction) } + + pushTransactionStack(this@SpringTransactionManager) } override fun doCleanupAfterCompletion(transaction: Any) { super.doCleanupAfterCompletion(transaction) if (!TransactionSynchronizationManager.hasResource(obtainDataSource())) { TransactionSynchronizationManager.unbindResourceIfPossible(this) - TransactionSynchronizationManager.unbindResource(springTxKey) } + + popTransactionStack() + TransactionManager.resetCurrent(getLastTransactionStack()) + if (TransactionSynchronizationManager.isSynchronizationActive() && TransactionSynchronizationManager.getSynchronizations().isEmpty()) { TransactionSynchronizationManager.clearSynchronization() } - TransactionManager.resetCurrent(null) } override fun doSuspend(transaction: Any): Any { @@ -100,22 +122,22 @@ class SpringTransactionManager( isolationLevel = isolation } - getTransaction(tDefinition) - - return currentOrNull() ?: initTransaction() + val transactionStatus = (getTransaction(tDefinition) as DefaultTransactionStatus) + return currentOrNull() ?: initTransaction(transactionStatus.transaction) } - private fun initTransaction(): Transaction { + private fun initTransaction(transaction: Any): Transaction { val connection = (TransactionSynchronizationManager.getResource(obtainDataSource()) as ConnectionHolder).connection @Suppress("TooGenericExceptionCaught") val transactionImpl = try { - SpringTransaction(JdbcConnectionImpl(connection), db, defaultIsolationLevel, defaultReadOnly, currentOrNull()) + SpringTransaction(JdbcConnectionImpl(connection), db, defaultIsolationLevel, defaultReadOnly, currentOrNull(), transaction) } catch (e: Exception) { exposedLogger.error("Failed to start transaction. Connection will be closed.", e) connection.close() throw e } + TransactionManager.resetCurrent(this) return Transaction(transactionImpl).apply { TransactionSynchronizationManager.bindResource(this@SpringTransactionManager, this) @@ -144,7 +166,8 @@ class SpringTransactionManager( override val db: Database, override val transactionIsolation: Int, override val readOnly: Boolean, - override val outerTransaction: Transaction? + override val outerTransaction: Transaction?, + private val currentTransaction: Any, ) : TransactionInterface { override fun commit() { @@ -157,9 +180,7 @@ class SpringTransactionManager( override fun close() { if (TransactionSynchronizationManager.isActualTransactionActive()) { - TransactionSynchronizationManager.getResource(springTxKey)?.let { springTx -> - this@SpringTransactionManager.doCleanupAfterCompletion(springTx) - } + this@SpringTransactionManager.doCleanupAfterCompletion(currentTransaction) } } } diff --git a/spring-transaction/src/test/kotlin/org/jetbrains/exposed/spring/SpringMultiContainerTransactionTest.kt b/spring-transaction/src/test/kotlin/org/jetbrains/exposed/spring/SpringMultiContainerTransactionTest.kt new file mode 100644 index 0000000000..16980c13d9 --- /dev/null +++ b/spring-transaction/src/test/kotlin/org/jetbrains/exposed/spring/SpringMultiContainerTransactionTest.kt @@ -0,0 +1,225 @@ +package org.jetbrains.exposed.spring + +import org.jetbrains.exposed.dao.id.LongIdTable +import org.jetbrains.exposed.sql.SchemaUtils +import org.jetbrains.exposed.sql.deleteAll +import org.jetbrains.exposed.sql.insertAndGetId +import org.jetbrains.exposed.sql.selectAll +import org.jetbrains.exposed.sql.transactions.transaction +import org.junit.Assert +import org.junit.Test +import org.springframework.context.annotation.AnnotationConfigApplicationContext +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType +import org.springframework.transaction.annotation.EnableTransactionManagement +import org.springframework.transaction.annotation.Transactional +import javax.sql.DataSource +import kotlin.test.BeforeTest + +open class SpringMultiContainerTransactionTest { + + val orderContainer = AnnotationConfigApplicationContext(OrderConfig::class.java) + val paymentContainer = AnnotationConfigApplicationContext(PaymentConfig::class.java) + + val orders: Orders = orderContainer.getBean(Orders::class.java) + val payments: Payments = paymentContainer.getBean(Payments::class.java) + + @BeforeTest + open fun beforeTest() { + orders.init() + payments.init() + } + + @Test + open fun test1() { + Assert.assertEquals(0, orders.findAll().size) + Assert.assertEquals(0, payments.findAll().size) + } + + @Test + open fun test2() { + orders.create() + Assert.assertEquals(1, orders.findAll().size) + payments.create() + Assert.assertEquals(1, payments.findAll().size) + } + + @Test + open fun test3() { + orders.transaction { + payments.create() + orders.create() + payments.create() + } + Assert.assertEquals(1, orders.findAll().size) + Assert.assertEquals(2, payments.findAll().size) + } + + @Test + open fun test4() { + kotlin.runCatching { + orders.transaction { + orders.create() + payments.create() + throw SpringTransactionTestException() + } + } + Assert.assertEquals(0, orders.findAll().size) + Assert.assertEquals(1, payments.findAll().size) + } + + @Test + open fun test5() { + kotlin.runCatching { + orders.transaction { + orders.create() + payments.databaseTemplate { + payments.create() + throw SpringTransactionTestException() + } + } + } + Assert.assertEquals(0, orders.findAll().size) + Assert.assertEquals(0, payments.findAll().size) + } + + @Test + open fun test6() { + Assert.assertEquals(0, orders.findAllWithExposedTrxBlock().size) + Assert.assertEquals(0, payments.findAllWithExposedTrxBlock().size) + } + + @Test + open fun test7() { + orders.createWithExposedTrxBlock() + Assert.assertEquals(1, orders.findAllWithExposedTrxBlock().size) + payments.createWithExposedTrxBlock() + Assert.assertEquals(1, payments.findAllWithExposedTrxBlock().size) + } + + @Test + open fun test8() { + orders.transaction { + payments.createWithExposedTrxBlock() + orders.createWithExposedTrxBlock() + payments.createWithExposedTrxBlock() + } + Assert.assertEquals(1, orders.findAllWithExposedTrxBlock().size) + Assert.assertEquals(2, payments.findAllWithExposedTrxBlock().size) + } + + @Test + open fun test9() { + kotlin.runCatching { + orders.transaction { + orders.createWithExposedTrxBlock() + payments.createWithExposedTrxBlock() + throw SpringTransactionTestException() + } + } + Assert.assertEquals(0, orders.findAllWithExposedTrxBlock().size) + Assert.assertEquals(1, payments.findAllWithExposedTrxBlock().size) + } + + @Test + open fun test10() { + kotlin.runCatching { + orders.transaction { + orders.createWithExposedTrxBlock() + payments.databaseTemplate { + payments.createWithExposedTrxBlock() + throw SpringTransactionTestException() + } + } + } + Assert.assertEquals(0, orders.findAllWithExposedTrxBlock().size) + Assert.assertEquals(0, payments.findAllWithExposedTrxBlock().size) + } +} + +@Configuration +@EnableTransactionManagement(proxyTargetClass = true) +open class OrderConfig { + + @Bean + open fun dataSource(): EmbeddedDatabase = EmbeddedDatabaseBuilder().setName("embeddedTest1").setType(EmbeddedDatabaseType.H2).build() + + @Bean + open fun transactionManager(dataSource: DataSource) = SpringTransactionManager(dataSource) + + @Bean + open fun orders() = Orders() +} + +@Transactional +open class Orders { + + open fun findAll() = Order.selectAll().map { it } + + open fun findAllWithExposedTrxBlock() = org.jetbrains.exposed.sql.transactions.transaction { findAll() } + + open fun create() = Order.insertAndGetId { + it[buyer] = 123 + }.value + + open fun createWithExposedTrxBlock() = org.jetbrains.exposed.sql.transactions.transaction { create() } + + open fun init() { + SchemaUtils.create(Order) + Order.deleteAll() + } + + open fun transaction(block: () -> Unit) { + block() + } +} + +object Order : LongIdTable("orders") { + val buyer = long("buyer_id") +} + +@Configuration +@EnableTransactionManagement(proxyTargetClass = true) +open class PaymentConfig { + + @Bean + open fun dataSource(): EmbeddedDatabase = EmbeddedDatabaseBuilder().setName("embeddedTest2").setType(EmbeddedDatabaseType.H2).build() + + @Bean + open fun transactionManager(dataSource: DataSource) = SpringTransactionManager(dataSource) + + @Bean + open fun payments() = Payments() +} + +@Transactional +open class Payments { + + open fun findAll() = Payment.selectAll().map { it } + + open fun findAllWithExposedTrxBlock() = transaction { findAll() } + + open fun create() = Payment.insertAndGetId { + it[state] = "state" + }.value + + open fun createWithExposedTrxBlock() = transaction { create() } + + open fun init() { + SchemaUtils.create(Payment) + Payment.deleteAll() + } + + open fun databaseTemplate(block: () -> Unit) { + block() + } +} + +object Payment : LongIdTable("payments") { + val state = varchar("state", 50) +} + +private class SpringTransactionTestException : Error() diff --git a/spring-transaction/src/test/kotlin/org/jetbrains/exposed/spring/SpringTransactionEntityTest.kt b/spring-transaction/src/test/kotlin/org/jetbrains/exposed/spring/SpringTransactionEntityTest.kt index 7718696539..959971aa87 100644 --- a/spring-transaction/src/test/kotlin/org/jetbrains/exposed/spring/SpringTransactionEntityTest.kt +++ b/spring-transaction/src/test/kotlin/org/jetbrains/exposed/spring/SpringTransactionEntityTest.kt @@ -5,12 +5,13 @@ import org.jetbrains.exposed.dao.UUIDEntityClass import org.jetbrains.exposed.dao.id.EntityID import org.jetbrains.exposed.dao.id.UUIDTable import org.jetbrains.exposed.sql.SchemaUtils -import org.jetbrains.exposed.sql.transactions.transaction import org.junit.Test import org.springframework.beans.factory.annotation.Autowired import org.springframework.test.annotation.Commit import org.springframework.transaction.annotation.Transactional import java.util.* +import kotlin.test.AfterTest +import kotlin.test.BeforeTest import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -39,6 +40,11 @@ class OrderDAO(id: EntityID) : UUIDEntity(id) { @org.springframework.stereotype.Service @Transactional open class Service { + + open fun init() { + SchemaUtils.create(CustomerTable, OrderTable) + } + open fun createCustomer(name: String): CustomerDAO { return CustomerDAO.new { this.name = name @@ -59,6 +65,14 @@ open class Service { open fun findOrderByProduct(product: String): OrderDAO? { return OrderDAO.find { OrderTable.product eq product }.singleOrNull() } + + open fun transaction(block: () -> Unit) { + block() + } + + open fun cleanUp() { + SchemaUtils.drop(CustomerTable, OrderTable) + } } open class SpringTransactionEntityTest : SpringTransactionTestBase() { @@ -66,29 +80,36 @@ open class SpringTransactionEntityTest : SpringTransactionTestBase() { @Autowired lateinit var service: Service - @Test @Commit - open fun test01() { - transaction { - SchemaUtils.create(CustomerTable, OrderTable) - } + @BeforeTest + open fun beforeTest() { + service.init() + } + @Test + @Commit + open fun test01() { val customer = service.createCustomer("Alice1") service.createOrder(customer, "Product1") val order = service.findOrderByProduct("Product1") assertNotNull(order) - transaction { + service.transaction { assertEquals("Alice1", order.customer.name) } } - @Test @Commit + @Test + @Commit fun test02() { service.doBoth("Bob", "Product2") val order = service.findOrderByProduct("Product2") assertNotNull(order) - transaction { + service.transaction { assertEquals("Bob", order.customer.name) - SchemaUtils.drop(CustomerTable, OrderTable) } } + + @AfterTest + fun afterTest() { + service.cleanUp() + } }