Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add spring mutli container support #1781

Merged
merged 6 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,37 +38,59 @@ class SpringTransactionManager(
databaseConfig = databaseConfig
) { this }

@Volatile override var defaultIsolationLevel: Int = -1
@Volatile
override var defaultIsolationLevel: Int = -1
get() {
if (field == -1) {
field = Database.getDefaultIsolationLevel(db)
}
return field
}

private val springTxKey = "SPRING_TX_KEY"
private val transactionStackKey = "SPRING_TRANSACTION_STACK_KEY"

private fun getTransactionStack(): List<TransactionManager> {
return TransactionSynchronizationManager.getResource(transactionStackKey)
?.let { it as List<TransactionManager> }
?: listOf()
}

private fun setTransactionStack(list: List<TransactionManager>) {
TransactionSynchronizationManager.unbindResourceIfPossible(transactionStackKey)
TransactionSynchronizationManager.bindResource(transactionStackKey, list)
}

private fun pushTransactionStack(transaction: TransactionManager) {
val transactionList = getTransactionStack()
setTransactionStack(transactionList + transaction)
}

private fun popTransactionStack() = setTransactionStack(getTransactionStack().dropLast(1))

private fun getLastTransactionStack() = getTransactionStack().lastOrNull()

override fun doBegin(transaction: Any, definition: TransactionDefinition) {
super.doBegin(transaction, definition)

if (TransactionSynchronizationManager.hasResource(obtainDataSource())) {
currentOrNull() ?: initTransaction()
}
if (!TransactionSynchronizationManager.hasResource(springTxKey)) {
TransactionSynchronizationManager.bindResource(springTxKey, transaction)
currentOrNull() ?: initTransaction(transaction)
}

pushTransactionStack(this@SpringTransactionManager)
}

override fun doCleanupAfterCompletion(transaction: Any) {
super.doCleanupAfterCompletion(transaction)
if (!TransactionSynchronizationManager.hasResource(obtainDataSource())) {
TransactionSynchronizationManager.unbindResourceIfPossible(this)
TransactionSynchronizationManager.unbindResource(springTxKey)
}

popTransactionStack()
TransactionManager.resetCurrent(getLastTransactionStack())

if (TransactionSynchronizationManager.isSynchronizationActive() && TransactionSynchronizationManager.getSynchronizations().isEmpty()) {
TransactionSynchronizationManager.clearSynchronization()
}
TransactionManager.resetCurrent(null)
}

override fun doSuspend(transaction: Any): Any {
Expand Down Expand Up @@ -100,21 +122,21 @@ class SpringTransactionManager(
isolationLevel = isolation
}

getTransaction(tDefinition)

return currentOrNull() ?: initTransaction()
val transactionStatus = (getTransaction(tDefinition) as DefaultTransactionStatus)
return currentOrNull() ?: initTransaction(transactionStatus.transaction)
}

private fun initTransaction(): Transaction {
private fun initTransaction(transaction: Any): Transaction {
val connection = (TransactionSynchronizationManager.getResource(obtainDataSource()) as ConnectionHolder).connection

val transactionImpl = try {
SpringTransaction(JdbcConnectionImpl(connection), db, defaultIsolationLevel, defaultReadOnly, currentOrNull())
SpringTransaction(JdbcConnectionImpl(connection), db, defaultIsolationLevel, defaultReadOnly, currentOrNull(), transaction)
} catch (e: Exception) {
exposedLogger.error("Failed to start transaction. Connection will be closed.", e)
connection.close()
throw e
}

TransactionManager.resetCurrent(this)
return Transaction(transactionImpl).apply {
TransactionSynchronizationManager.bindResource(this@SpringTransactionManager, this)
Expand Down Expand Up @@ -143,7 +165,8 @@ class SpringTransactionManager(
override val db: Database,
override val transactionIsolation: Int,
override val readOnly: Boolean,
override val outerTransaction: Transaction?
override val outerTransaction: Transaction?,
private val currentTransaction: Any,
) : TransactionInterface {

override fun commit() {
Expand All @@ -156,9 +179,7 @@ class SpringTransactionManager(

override fun close() {
if (TransactionSynchronizationManager.isActualTransactionActive()) {
TransactionSynchronizationManager.getResource(springTxKey)?.let { springTx ->
[email protected](springTx)
}
[email protected](currentTransaction)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package org.jetbrains.exposed.spring

import org.jetbrains.exposed.dao.id.LongIdTable
import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.deleteAll
import org.jetbrains.exposed.sql.insertAndGetId
import org.jetbrains.exposed.sql.selectAll
import org.jetbrains.exposed.sql.transactions.transaction
import org.junit.Assert
import org.junit.Test
import org.springframework.context.annotation.AnnotationConfigApplicationContext
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType
import org.springframework.transaction.annotation.EnableTransactionManagement
import org.springframework.transaction.annotation.Transactional
import javax.sql.DataSource
import kotlin.test.BeforeTest

/**
* @author [email protected]
*/
open class SpringMultiContainerTransactionTest {

val orderContainer = AnnotationConfigApplicationContext(OrderConfig::class.java)
val paymentContainer = AnnotationConfigApplicationContext(PaymentConfig::class.java)

val orders: Orders = orderContainer.getBean(Orders::class.java)
val payments: Payments = paymentContainer.getBean(Payments::class.java)

@BeforeTest
open fun beforeTest() {
orders.init()
payments.init()
}

@Test
open fun test1() {
Assert.assertEquals(0, orders.findAll().size)
Assert.assertEquals(0, payments.findAll().size)
}

@Test
open fun test2() {
orders.create()
Assert.assertEquals(1, orders.findAll().size)
payments.create()
Assert.assertEquals(1, payments.findAll().size)
}

@Test
open fun test3() {
orders.transaction {
payments.create()
orders.create()
payments.create()
}
Assert.assertEquals(1, orders.findAll().size)
Assert.assertEquals(2, payments.findAll().size)
}

@Test
open fun test4() {
kotlin.runCatching {
orders.transaction {
orders.create()
payments.create()
throw Error()
}
}
Assert.assertEquals(0, orders.findAll().size)
Assert.assertEquals(1, payments.findAll().size)
}

@Test
open fun test5() {
kotlin.runCatching {
orders.transaction {
orders.create()
payments.databaseTemplate {
payments.create()
throw Error()
}
}
}
Assert.assertEquals(0, orders.findAll().size)
Assert.assertEquals(0, payments.findAll().size)
}

@Test
open fun test6() {
Assert.assertEquals(0, orders.findAllWithExposedTrxBlock().size)
Assert.assertEquals(0, payments.findAllWithExposedTrxBlock().size)
}

@Test
open fun test7() {
orders.createWithExposedTrxBlock()
Assert.assertEquals(1, orders.findAllWithExposedTrxBlock().size)
payments.createWithExposedTrxBlock()
Assert.assertEquals(1, payments.findAllWithExposedTrxBlock().size)
}

@Test
open fun test8() {
orders.transaction {
payments.createWithExposedTrxBlock()
orders.createWithExposedTrxBlock()
payments.createWithExposedTrxBlock()
}
Assert.assertEquals(1, orders.findAllWithExposedTrxBlock().size)
Assert.assertEquals(2, payments.findAllWithExposedTrxBlock().size)
}

@Test
open fun test9() {
kotlin.runCatching {
orders.transaction {
orders.createWithExposedTrxBlock()
payments.createWithExposedTrxBlock()
throw Error()
}
}
Assert.assertEquals(0, orders.findAllWithExposedTrxBlock().size)
Assert.assertEquals(1, payments.findAllWithExposedTrxBlock().size)
}

@Test
open fun test10() {
kotlin.runCatching {
orders.transaction {
orders.createWithExposedTrxBlock()
payments.databaseTemplate {
payments.createWithExposedTrxBlock()
throw Error()
}
}
}
Assert.assertEquals(0, orders.findAllWithExposedTrxBlock().size)
Assert.assertEquals(0, payments.findAllWithExposedTrxBlock().size)
}
}

@Configuration
@EnableTransactionManagement(proxyTargetClass = true)
open class OrderConfig {

@Bean
open fun dataSource(): EmbeddedDatabase = EmbeddedDatabaseBuilder().setName("embeddedTest1").setType(EmbeddedDatabaseType.H2).build()

@Bean
open fun transactionManager(dataSource: DataSource) = SpringTransactionManager(dataSource)

@Bean
open fun orders() = Orders()
}

@Transactional
open class Orders {

open fun findAll() = Order.selectAll().map { it }

open fun findAllWithExposedTrxBlock() = org.jetbrains.exposed.sql.transactions.transaction { findAll() }

open fun create() = Order.insertAndGetId {
it[buyer] = 123
}.value

open fun createWithExposedTrxBlock() = org.jetbrains.exposed.sql.transactions.transaction { create() }

open fun init() {
SchemaUtils.create(Order)
Order.deleteAll()
}

open fun transaction(block: () -> Unit) {
block()
}
}

object Order : LongIdTable("orders") {
val buyer = long("buyer_id")
}

@Configuration
@EnableTransactionManagement(proxyTargetClass = true)
open class PaymentConfig {

@Bean
open fun dataSource(): EmbeddedDatabase = EmbeddedDatabaseBuilder().setName("embeddedTest2").setType(EmbeddedDatabaseType.H2).build()

@Bean
open fun transactionManager(dataSource: DataSource) = SpringTransactionManager(dataSource)

@Bean
open fun payments() = Payments()
}

@Transactional
open class Payments {

open fun findAll() = Payment.selectAll().map { it }

open fun findAllWithExposedTrxBlock() = transaction { findAll() }

open fun create() = Payment.insertAndGetId {
it[state] = "state"
}.value

open fun createWithExposedTrxBlock() = transaction { create() }

open fun init() {
SchemaUtils.create(Payment)
Payment.deleteAll()
}

open fun databaseTemplate(block: () -> Unit) {
block()
}
}

object Payment : LongIdTable("payments") {
val state = varchar("state", 50)
}
Loading