diff --git a/src/main/kotlin/no/liflig/documentstore/repository/RepositoryJdbi.kt b/src/main/kotlin/no/liflig/documentstore/repository/RepositoryJdbi.kt
index 1707b9b..d8ae8c0 100644
--- a/src/main/kotlin/no/liflig/documentstore/repository/RepositoryJdbi.kt
+++ b/src/main/kotlin/no/liflig/documentstore/repository/RepositoryJdbi.kt
@@ -13,6 +13,7 @@ import no.liflig.documentstore.entity.EntityId
 import no.liflig.documentstore.entity.Version
 import no.liflig.documentstore.entity.Versioned
 import no.liflig.documentstore.entity.getEntityIdType
+import no.liflig.documentstore.utils.BatchOperationException
 import no.liflig.documentstore.utils.currentTimeWithMicrosecondPrecision
 import no.liflig.documentstore.utils.executeBatchOperation
 import no.liflig.documentstore.utils.isEmpty
@@ -206,31 +207,37 @@ open class RepositoryJdbi<EntityIdT : EntityId, EntityT : Entity<EntityIdT>>(
     val createdAt = currentTimeWithMicrosecondPrecision()
     val version = Version.initial()
 
-    transactional {
-      useHandle(jdbi) { handle ->
-        executeBatchOperation(
-            handle,
-            entities,
-            statement =
-                """
-                INSERT INTO "${tableName}" (id, data, version, created_at, modified_at)
-                VALUES (:id, :data::jsonb, :version, :createdAt, :modifiedAt)
-                """
-                    .trimIndent(),
-            bindParameters = { batch, entity ->
-              batch
-                  .bind("id", entity.id)
-                  .bind("data", serializationAdapter.toJson(entity))
-                  .bind("version", version)
-                  .bind("createdAt", createdAt)
-                  .bind("modifiedAt", createdAt)
-            },
-        )
+    try {
+      transactional {
+        useHandle(jdbi) { handle ->
+          executeBatchOperation(
+              handle,
+              entities,
+              statement =
+                  """
+                  INSERT INTO "${tableName}" (id, data, version, created_at, modified_at)
+                  VALUES (:id, :data::jsonb, :version, :createdAt, :modifiedAt)
+                  """
+                      .trimIndent(),
+              bindParameters = { batch, entity ->
+                batch
+                    .bind("id", entity.id)
+                    .bind("data", serializationAdapter.toJson(entity))
+                    .bind("version", version)
+                    .bind("createdAt", createdAt)
+                    .bind("modifiedAt", createdAt)
+              },
+          )
+        }
       }
+    } catch (e: BatchOperationException) {
+      @Suppress("UNCHECKED_CAST") // We know the entity is EntityT on this repository
+      throw mapCreateOrUpdateException(e.cause, e.entity as EntityT)
     }
 
     // We wait until here to create the result list, which may be large, to avoid allocating it
-    // before calling the database. That would keep the list in memory while we are waiting for the
+    // before calling the database. That would keep the list in memory while we are waiting for
+    // the
     // database, needlessly reducing throughput.
     return entities.map { entity ->
       Versioned(entity, version, createdAt = createdAt, modifiedAt = createdAt)
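With this change, `batchCreate` funnels per-entity batch failures through `mapCreateOrUpdateException`, just like single-entity `create`. A caller-side sketch of the resulting behavior, reusing the test helpers that appear at the bottom of this diff (`exampleRepo`, `ExampleEntity`, `UniqueFieldAlreadyExists`) — not part of the patch itself:

```kotlin
// Sketch: a unique-constraint violation inside batchCreate now surfaces as the
// repository's mapped domain exception instead of an opaque driver exception.
try {
  exampleRepo.batchCreate(
      listOf(
          ExampleEntity(text = "A", uniqueField = 1),
          ExampleEntity(text = "B", uniqueField = 1), // Duplicate unique field
      ),
  )
} catch (e: UniqueFieldAlreadyExists) {
  println("Batch failed for: ${e.failingEntity}")
}
```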
@@ -244,43 +251,48 @@ open class RepositoryJdbi<EntityIdT : EntityId, EntityT : Entity<EntityIdT>>(
 
     val modifiedAt = currentTimeWithMicrosecondPrecision()
 
-    transactional {
-      useHandle(jdbi) { handle ->
-        executeBatchOperation(
-            handle,
-            entities,
-            statement =
-                """
-                UPDATE "${tableName}"
-                SET
-                  data = :data::jsonb,
-                  version = :nextVersion,
-                  modified_at = :modifiedAt
-                WHERE
-                  id = :id AND
-                  version = :previousVersion
-                """
-                    .trimIndent(),
-            bindParameters = { batch, entity ->
-              val nextVersion = entity.version.next()
-
-              batch
-                  .bind("data", serializationAdapter.toJson(entity.item))
-                  .bind("nextVersion", nextVersion)
-                  .bind("modifiedAt", modifiedAt)
-                  .bind("id", entity.item.id)
-                  .bind("previousVersion", entity.version)
-            },
-            handleModifiedRowCounts = { counts, batchStartIndex ->
-              handleModifiedRowCounts(counts, batchStartIndex, entities, operation = "update")
-            },
-        )
+    try {
+      transactional {
+        useHandle(jdbi) { handle ->
+          executeBatchOperation(
+              handle,
+              entities,
+              statement =
+                  """
+                  UPDATE "${tableName}"
+                  SET
+                    data = :data::jsonb,
+                    version = :nextVersion,
+                    modified_at = :modifiedAt
+                  WHERE
+                    id = :id AND
+                    version = :previousVersion
+                  """
+                      .trimIndent(),
+              bindParameters = { batch, entity ->
+                val nextVersion = entity.version.next()
+
+                batch
+                    .bind("data", serializationAdapter.toJson(entity.item))
+                    .bind("nextVersion", nextVersion)
+                    .bind("modifiedAt", modifiedAt)
+                    .bind("id", entity.item.id)
+                    .bind("previousVersion", entity.version)
+              },
+              handleModifiedRowCounts = { counts, batch ->
+                handleModifiedRowCounts(counts, batch, operation = "update")
+              },
+          )
+        }
       }
+    } catch (e: BatchOperationException) {
+      @Suppress("UNCHECKED_CAST") // We know the entity is EntityT on this repository
+      throw mapCreateOrUpdateException(e.cause, e.entity as EntityT)
     }
 
     // We wait until here to create the result list, which may be large, to avoid allocating it
-    // before calling the database. That would keep the list in memory while we are waiting for the
-    // database, needlessly reducing throughput.
+    // before calling the database. That would keep the list in memory while we are waiting for
+    // the database, needlessly reducing throughput.
 
     return entities.map { entity ->
       entity.copy(modifiedAt = modifiedAt, version = entity.version.next())
@@ -307,8 +319,8 @@ open class RepositoryJdbi<EntityIdT : EntityId, EntityT : Entity<EntityIdT>>(
           bindParameters = { batch, entity ->
             batch.bind("id", entity.item.id).bind("previousVersion", entity.version)
           },
-          handleModifiedRowCounts = { counts, batchStartIndex ->
-            handleModifiedRowCounts(counts, batchStartIndex, entities, operation = "delete")
+          handleModifiedRowCounts = { counts, batch ->
+            handleModifiedRowCounts(counts, batch, operation = "delete")
          },
       )
     }
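For both `batchUpdate` and `batchDelete`, optimistic locking works by letting the `version = :previousVersion` guard in the `WHERE` clause filter out stale rows: a modified-row count of 0 marks the conflicting batch element. The `handleModifiedRowCounts` helper rewritten later in this diff implements the real check; here is a standalone, plain-Kotlin illustration of the idea (not library code):

```kotlin
// Standalone illustration: an UPDATE guarded by `version = :previousVersion` matches
// zero rows when another writer bumped the version first, so a 0 in the modified-row
// counts pinpoints the conflicting batch element.
fun findConflictIndex(modifiedRowCounts: IntArray): Int? {
  for ((index, count) in modifiedRowCounts.withIndex()) {
    if (count == 0) {
      return index
    }
  }
  return null
}

fun main() {
  // The element at index 1 was concurrently modified, so its UPDATE matched 0 rows:
  println(findConflictIndex(intArrayOf(1, 0, 1))) // 1
}
```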
@@ -482,11 +494,11 @@ open class RepositoryJdbi<EntityIdT : EntityId, EntityT : Entity<EntityIdT>>(
   }
 
   /**
-   * Method that you can override to map exceptions thrown in [create] or [update] to your own
-   * exception type. This is useful to handle e.g. unique constraint violations: instead of letting
-   * the database throw an opaque `PSQLException` that may be difficult to handle in layers above,
-   * you can instead check if the given exception is a unique index violation and map it to a more
-   * useful exception type here.
+   * Method that you can override to map exceptions thrown in [create] / [update] / [batchCreate] /
+   * [batchUpdate] to your own exception type. This is useful to handle e.g. unique constraint
+   * violations: instead of letting the database throw an opaque `PSQLException` that may be
+   * difficult to handle in layers above, you can instead check if the given exception is a unique
+   * index violation and map it to a more useful exception type here.
    *
    * If your implementation receives an exception here that it does not want to map, it should just
    * return it as-is.
@@ -494,6 +506,9 @@ open class RepositoryJdbi<EntityIdT : EntityId, EntityT : Entity<EntityIdT>>(
    * The entity that was attempted to be created or updated is also provided here, so you can add
    * extra context to the mapped exception.
    *
+   * This method is only called by [batchCreate] / [batchUpdate] if the batch operation failed
+   * because of a single entity (e.g. a unique constraint violation).
+   *
    * Example:
    * ```
    * override fun mapCreateOrUpdateException(e: Exception, entity: ExampleEntity): Exception {
@@ -519,24 +534,17 @@ open class RepositoryJdbi<EntityIdT : EntityId, EntityT : Entity<EntityIdT>>(
    */
   private fun handleModifiedRowCounts(
       modifiedRowCounts: IntArray,
-      batchStartIndex: Int,
-      entities: Iterable<Versioned<EntityT>>,
+      batch: List<Versioned<EntityT>>,
       operation: String,
   ) {
     for (count in modifiedRowCounts.withIndex()) {
       if (count.value == 0) {
-        var exceptionMessage =
-            "Entity was concurrently modified between being retrieved and trying to ${operation} it in batch ${operation} (rolling back batch ${operation})"
-        // We want to add the entity to the exception message for context, but we can only do this
-        // if the Iterable is indexable
-        if (entities is List) {
-          val entity = entities.getOrNull(batchStartIndex + count.index)
-          if (entity != null) {
-            exceptionMessage += " [Entity: ${entity}]"
-          }
-        }
-
-        throw ConflictRepositoryException(exceptionMessage)
+        // Should never be null, but we don't want to suppress the ConflictRepositoryException here
+        // if it is
+        val conflictedEntity: EntityT? = batch.getOrNull(count.index)?.item
+        throw ConflictRepositoryException(
+            "Entity was concurrently modified between being retrieved and trying to ${operation} it in batch ${operation} (rolling back batch ${operation}) [Entity: ${conflictedEntity}]",
+        )
       }
     }
   }
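The KDoc's example override is truncated by the hunk boundary above. For reference, one possible shape of such an override — a sketch, not part of this patch, assuming the PostgreSQL JDBC driver (SQLState 23505 is `unique_violation`). The cause-chain walk is an assumption about how JDBI and the driver wrap errors, not something the patch prescribes:

```kotlin
import java.sql.SQLException

// Hypothetical override for ExampleRepository (illustrative only): map Postgres
// unique-constraint violations (SQLState 23505) to a domain exception, and return
// everything else as-is.
override fun mapCreateOrUpdateException(e: Exception, entity: ExampleEntity): Exception {
  // Walk the cause chain, since JDBI wraps driver exceptions (and batch operations
  // may add a BatchUpdateException layer in between)
  var cause: Throwable? = e
  while (cause != null) {
    if (cause is SQLException && cause.sqlState == "23505") {
      return UniqueFieldAlreadyExists(entity, cause = e)
    }
    cause = cause.cause
  }
  return e
}
```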
diff --git a/src/main/kotlin/no/liflig/documentstore/utils/BatchOperation.kt b/src/main/kotlin/no/liflig/documentstore/utils/BatchOperation.kt
index 3f3cdc5..a7747b1 100644
--- a/src/main/kotlin/no/liflig/documentstore/utils/BatchOperation.kt
+++ b/src/main/kotlin/no/liflig/documentstore/utils/BatchOperation.kt
@@ -1,77 +1,77 @@
 package no.liflig.documentstore.utils
 
+import java.sql.BatchUpdateException
+import kotlin.math.min
+import no.liflig.documentstore.entity.Entity
+import no.liflig.documentstore.entity.Versioned
 import org.jdbi.v3.core.Handle
 import org.jdbi.v3.core.result.BatchResultBearing
 import org.jdbi.v3.core.statement.PreparedBatch
+import org.jdbi.v3.core.statement.UnableToExecuteStatementException
 
 /**
  * Uses [Prepared Batches from JDBI](https://jdbi.org/releases/3.45.1/#_prepared_batches) to execute
- * the given [statement] on the given [items]. For each item, [bindParameters] is called to bind
- * parameters to the statement.
+ * the given [statement] on the given [entities]. For each entity, [bindParameters] is called to
+ * bind parameters to the statement.
  *
- * The items are divided into multiple batches if the number of items exceeds [batchSize]. According
- * to Oracle,
+ * The entities are divided into multiple batches if the number of entities exceeds [batchSize].
+ * According to Oracle,
  * [the optimal size for batch operations in JDBC is 50-100](https://docs.oracle.com/cd/E11882_01/java.112/e16548/oraperf.htm#JJDBC28754).
  * We default to the conservative end of 50, since we send JSON which is rather memory inefficient.
  *
- * [PreparedBatch.execute] returns an array of modified row counts (1 count for every batch item).
- * If you want to handle this, use [handleModifiedRowCounts]. This function is called once for every
- * executed batch, which may be more than 1 if the number of items exceeds [batchSize]. A second
- * parameter is provided to [handleModifiedRowCounts] with the start index of the current batch,
- * which can then be used to get the corresponding entity for diagnostics purposes.
+ * [PreparedBatch.execute] returns an array of modified row counts (1 count for every entity). If
+ * you want to handle this, use [handleModifiedRowCounts]. This function is called once for every
+ * executed batch, which may be more than 1 if the number of entities exceeds [batchSize]. The first
+ * parameter is an array of the modified row counts for the executed batch, and the second parameter
+ * is the batch of entities.
  *
  * If you need to return something from the query, pass column names in [columnsToReturn]. This
  * will append `RETURNING` to the SQL statement with the given column names. You can then iterate
  * over the results with [handleReturnedColumns].
+ *
+ * @throws BatchOperationException If the batch operation failed because of a single entity (e.g. a
+ *   unique constraint violation).
  */
-internal fun <BatchItemT> executeBatchOperation(
+internal fun <EntityT> executeBatchOperation(
     handle: Handle,
-    items: Iterable<BatchItemT>,
+    entities: Iterable<EntityT>,
     statement: String,
-    bindParameters: (PreparedBatch, BatchItemT) -> PreparedBatch,
-    handleModifiedRowCounts: ((IntArray, Int) -> Unit)? = null,
+    bindParameters: (PreparedBatch, EntityT) -> PreparedBatch,
+    handleModifiedRowCounts: ((IntArray, List<EntityT>) -> Unit)? = null,
     columnsToReturn: Array<String>? = null,
     handleReturnedColumns: ((BatchResultBearing) -> Unit)? = null,
     batchSize: Int = 50,
 ) {
   runWithAutoCommitDisabled(handle) {
-    var currentBatch: PreparedBatch? = null
-    var elementCountInCurrentBatch = 0
-    var startIndexOfCurrentBatch = 0
-
-    for ((index, element) in items.withIndex()) {
-      if (currentBatch == null) {
-        currentBatch = handle.prepareBatch(statement)!! // Should never return null
-        startIndexOfCurrentBatch = index
-      }
+    val batchProvider = BatchProvider.create(entities, batchSize)
 
-      currentBatch = bindParameters(currentBatch, element)
-      currentBatch.add()
-      elementCountInCurrentBatch++
+    var batch = batchProvider.nextBatch()
+    while (batch != null) {
+      var batchStatement: PreparedBatch = handle.prepareBatch(statement)
+
+      for (entity in batch) {
+        batchStatement = bindParameters(batchStatement, entity)
+        batchStatement.add()
+      }
 
-      if (elementCountInCurrentBatch >= batchSize) {
+      try {
         executeBatch(
-            currentBatch,
-            startIndexOfCurrentBatch,
+            batchStatement,
+            batch,
             handleModifiedRowCounts,
             columnsToReturn,
             handleReturnedColumns,
         )
-
-        currentBatch = null
-        elementCountInCurrentBatch = 0
+      } catch (e: UnableToExecuteStatementException) {
+        val failingEntity = getFailingEntity(batch, e)
+        if (failingEntity != null) {
+          throw BatchOperationException(failingEntity, cause = e)
+        } else {
+          throw e
+        }
       }
-    }
 
-    // If currentBatch is non-null here, that means we still have remaining entities to update
-    if (currentBatch != null) {
-      executeBatch(
-          currentBatch,
-          startIndexOfCurrentBatch,
-          handleModifiedRowCounts,
-          columnsToReturn,
-          handleReturnedColumns,
-      )
+      batch = batchProvider.nextBatch()
     }
   }
 }
@@ -84,22 +84,22 @@ internal fun <BatchItemT> executeBatchOperation(
  * [PreparedBatch.executePreparedBatch]. That appends the given columns in a `RETURNING` clause on
  * the query, and gives us a result object which we can handle in [handleReturnedColumns].
  */
-private fun executeBatch(
-    currentBatch: PreparedBatch,
-    startIndexOfCurrentBatch: Int,
-    handleModifiedRowCounts: ((IntArray, Int) -> Unit)?,
+private fun <EntityT> executeBatch(
+    batchStatement: PreparedBatch,
+    batch: List<EntityT>,
+    handleModifiedRowCounts: ((IntArray, List<EntityT>) -> Unit)?,
     columnsToReturn: Array<String>?,
     handleReturnedColumns: ((BatchResultBearing) -> Unit)?,
 ) {
   if (columnsToReturn.isNullOrEmpty()) {
-    val modifiedRowCounts = currentBatch.execute()
+    val modifiedRowCounts = batchStatement.execute()
     if (handleModifiedRowCounts != null) {
-      handleModifiedRowCounts(modifiedRowCounts, startIndexOfCurrentBatch)
+      handleModifiedRowCounts(modifiedRowCounts, batch)
     }
   } else {
-    val result = currentBatch.executePreparedBatch(*columnsToReturn)
+    val result = batchStatement.executePreparedBatch(*columnsToReturn)
     if (handleModifiedRowCounts != null) {
-      handleModifiedRowCounts(result.modifiedRowCounts(), startIndexOfCurrentBatch)
+      handleModifiedRowCounts(result.modifiedRowCounts(), batch)
     }
     if (handleReturnedColumns != null) {
       handleReturnedColumns(result)
@@ -107,6 +107,134 @@ private fun executeBatch(
   }
 }
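For readers less familiar with the JDBI API being wrapped here, a minimal self-contained prepared-batch round trip (the table and bindings are illustrative, and `handle` is assumed to be an open `org.jdbi.v3.core.Handle`):

```kotlin
// Minimal JDBI prepared batch, illustrating the API that executeBatchOperation wraps:
// one statement is prepared, parameters are bound and added once per element, and
// execute() returns one modified-row count per added entry.
val modifiedRowCounts: IntArray =
    handle
        .prepareBatch("""INSERT INTO "example" (id, data) VALUES (:id, :data)""")
        .bind("id", 1)
        .bind("data", "first")
        .add()
        .bind("id", 2)
        .bind("data", "second")
        .add()
        .execute()
// modifiedRowCounts is now [1, 1]
```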
+/**
+ * We want [executeBatchOperation] to take an [Iterable], so that it can be used both with large
+ * streams of entities (like we do in [no.liflig.documentstore.migration.migrateEntity]), and
+ * in-memory lists of entities.
+ *
+ * As we iterate over entities to add them to a batch, we also want to store a list of the current
+ * batch of entities. We need this to get the failing entity for [BatchOperationException], which is
+ * useful to know exactly which entity in the batch failed.
+ *
+ * If the `Iterable` represents a stream of entities that may not all be in memory, we must create
+ * this entity list ourselves and add to it as we consume the stream. This is what
+ * [StreamingBatchProvider] does. But if the `Iterable` is already an in-memory list, we want to
+ * avoid the overhead of allocating these additional lists, and instead use [List.subList] to create
+ * a view of the list without copying. This is what [InMemoryBatchProvider] does.
+ */
+private interface BatchProvider<EntityT> {
+  /** Returns null when empty. */
+  fun nextBatch(): List<EntityT>?
+
+  companion object {
+    fun <EntityT> create(entities: Iterable<EntityT>, batchSize: Int): BatchProvider<EntityT> {
+      return if (entities is List) {
+        InMemoryBatchProvider(entities, batchSize)
+      } else {
+        StreamingBatchProvider(entities.iterator(), batchSize)
+      }
+    }
+  }
+}
+
+private class InMemoryBatchProvider<EntityT>(
+    private val entities: List<EntityT>,
+    private val batchSize: Int,
+) : BatchProvider<EntityT> {
+  private var startIndexOfCurrentBatch = 0
+
+  override fun nextBatch(): List<EntityT>? {
+    if (startIndexOfCurrentBatch >= entities.size) {
+      return null
+    }
+
+    val endIndex = min(startIndexOfCurrentBatch + batchSize, entities.size)
+    val batch = entities.subList(startIndexOfCurrentBatch, endIndex)
+    startIndexOfCurrentBatch = endIndex
+    return batch
+  }
+}
+
+private class StreamingBatchProvider<EntityT>(
+    private val entities: Iterator<EntityT>,
+    private val batchSize: Int,
+) : BatchProvider<EntityT> {
+  // Initialize with capacity equal to batchSize
+  private val currentBatch = ArrayList<EntityT>(batchSize)
+
+  override fun nextBatch(): List<EntityT>? {
+    currentBatch.clear()
+
+    while (entities.hasNext() && currentBatch.size < batchSize) {
+      currentBatch.add(entities.next())
+    }
+
+    return if (currentBatch.isEmpty()) null else currentBatch
+  }
+}
+
+/**
+ * We throw this exception if a batch operation failed due to a single entity, e.g. a unique
+ * constraint violation.
+ */
+internal class BatchOperationException(
+    val entity: Entity<*>,
+    override val cause: Exception,
+) : RuntimeException() {
+  override val message
+    get() = "Batch operation failed for entity: ${entity}"
+}
+
+private const val BATCH_ENTRY_EXCEPTION_PREFIX = "Batch entry"
+
+/**
+ * When [PreparedBatch.execute] throws, we would like to know if the failure was caused by a
+ * specific entity in the batch, e.g. in case a constraint was violated.
+ *
+ * For some JDBC drivers, one can use [BatchUpdateException.updateCounts] for this. But the Postgres
+ * driver [does not provide the info we need from it](https://github.com/pgjdbc/pgjdbc/issues/670).
+ *
+ * However, the Postgres driver does include the batch index of the entity that failed in the
+ * [exception message when it throws BatchUpdateException](https://github.com/pgjdbc/pgjdbc/blob/cf3d8e5ed1bca96873735a8731ed4082132361a5/pgjdbc/src/main/java/org/postgresql/jdbc/BatchResultHandler.java#L162-L164).
+ * So we can try to parse out the index from the exception here, and use that to get the failing
+ * entity. This is a kind of hacky solution, but presumably the exception message format of the
+ * Postgres driver is pretty stable. And if it changes, it will be caught by our tests.
+ */
+private fun <EntityT> getFailingEntity(
+    batch: List<EntityT>,
+    exception: UnableToExecuteStatementException
+): Entity<*>? {
+  try {
+    val cause = exception.cause
+    if (cause !is BatchUpdateException) {
+      return null
+    }
+
+    val message = cause.message ?: return null
+    if (!message.startsWith(BATCH_ENTRY_EXCEPTION_PREFIX)) {
+      return null
+    }
+
+    val start = BATCH_ENTRY_EXCEPTION_PREFIX.length + 1
+    val end = message.indexOf(' ', startIndex = start)
+    val indexString = message.substring(start, end)
+    val index = indexString.toInt()
+
+    /**
+     * executeBatchOperation is either called on a collection of Entity<*>, or Versioned<Entity<*>>.
+     * We want to unpack to an Entity<*> here, so that [BatchOperationException] can use just the
+     * Entity.
+     */
+    return when (val entity = batch.getOrNull(index)) {
+      is Entity<*> -> entity
+      is Versioned<*> -> entity.item
+      else -> null
+    }
+  } catch (e: Exception) {
+    return null
+  }
+}
 
 /**
  * When using batch operations, we typically want to send the batches as part of a transaction, so
  * that either the full operation is committed, or none at all. To do this, we have to make sure
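The "Batch entry" parsing above is the load-bearing hack in this patch, so here is a standalone demonstration of just the index extraction. The full message text below is illustrative, modeled on the pgjdbc format linked in the KDoc; only the `Batch entry <N> ` prefix matters. Note that the library version relies on the surrounding try/catch instead of the explicit guards used here:

```kotlin
// Standalone demo of the index extraction in getFailingEntity. Only the
// "Batch entry <N> " prefix is parsed; the rest of the message is ignored.
fun parseBatchEntryIndex(message: String): Int? {
  val prefix = "Batch entry"
  if (!message.startsWith(prefix)) return null
  val start = prefix.length + 1
  val end = message.indexOf(' ', startIndex = start)
  if (end == -1) return null
  return message.substring(start, end).toIntOrNull()
}

fun main() {
  val message =
      "Batch entry 3 INSERT INTO \"example\" ... was aborted: ERROR: duplicate key value " +
          "violates unique constraint \"example_unique_field_idx\""
  println(parseBatchEntryIndex(message)) // 3
}
```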
diff --git a/src/test/kotlin/no/liflig/documentstore/repository/MapExceptionTest.kt b/src/test/kotlin/no/liflig/documentstore/repository/MapExceptionTest.kt
index ecaed52..cb12249 100644
--- a/src/test/kotlin/no/liflig/documentstore/repository/MapExceptionTest.kt
+++ b/src/test/kotlin/no/liflig/documentstore/repository/MapExceptionTest.kt
@@ -1,6 +1,8 @@
 package no.liflig.documentstore.repository
 
+import kotlin.test.assertEquals
 import kotlin.test.assertFailsWith
+import no.liflig.documentstore.entity.mapEntities
 import no.liflig.documentstore.testutils.ExampleEntity
 import no.liflig.documentstore.testutils.ExampleRepository
 import no.liflig.documentstore.testutils.UniqueFieldAlreadyExists
@@ -11,20 +13,73 @@ class MapExceptionTest {
   @Test
   fun `mapCreateOrUpdateException catches and maps exceptions in create`() {
     /** We map to this custom exception in [ExampleRepository.mapCreateOrUpdateException]. */
-    assertFailsWith<UniqueFieldAlreadyExists> {
-      exampleRepo.create(ExampleEntity(text = "A", uniqueField = 1))
-      exampleRepo.create(ExampleEntity(text = "B", uniqueField = 1))
-    }
+    val exception =
+        assertFailsWith<UniqueFieldAlreadyExists> {
+          exampleRepo.create(ExampleEntity(text = "A", uniqueField = 1))
+          exampleRepo.create(ExampleEntity(text = "B", uniqueField = 1))
+        }
+    assertEquals(1, exception.failingEntity.uniqueField)
   }
 
   @Test
   fun `mapCreateOrUpdateException catches and maps exceptions in update`() {
     /** We map to this custom exception in [ExampleRepository.mapCreateOrUpdateException]. */
-    assertFailsWith<UniqueFieldAlreadyExists> {
-      val (entity1, version1) = exampleRepo.create(ExampleEntity(text = "A", uniqueField = 2))
-      exampleRepo.create(ExampleEntity(text = "B", uniqueField = 3))
+    val exception =
+        assertFailsWith<UniqueFieldAlreadyExists> {
+          val (entity1, version1) = exampleRepo.create(ExampleEntity(text = "A", uniqueField = 2))
+          exampleRepo.create(ExampleEntity(text = "B", uniqueField = 3))
 
-      exampleRepo.update(entity1.copy(uniqueField = 3), version1)
-    }
+          exampleRepo.update(entity1.copy(uniqueField = 3), version1)
+        }
+    assertEquals(3, exception.failingEntity.uniqueField)
   }
 
+  @Test
+  fun `mapCreateOrUpdateException catches and maps exceptions in batchCreate`() {
+    /**
+     * If this test starts failing, that may be because the Postgres JDBC driver has changed the
+     * format of their BatchUpdateException, which we use in
+     * [no.liflig.documentstore.utils.getFailingEntity] to get the entity that failed, so we can
+     * provide it to [RepositoryJdbi.mapCreateOrUpdateException].
+ */ + val exception = + assertFailsWith { + exampleRepo.batchCreate( + // Create 100 entities, and an additional entity with the same unique field as the + // last one, which should cause UniqueFieldAlreadyExists exception from the + // overridden mapCreateOrUpdateException in the repository + (100..199).map { i -> ExampleEntity(text = "test", uniqueField = i) } + + listOf(ExampleEntity(text = "test", uniqueField = 199)), + ) + } + assertEquals(199, exception.failingEntity.uniqueField) + } + + @Test + fun `mapCreateOrUpdateException catches and maps exceptions in batchUpdate`() { + /** + * If this test starts failing, that may be because the Postgres JDBC driver has changed the + * format of their BatchUpdateException, which we use in + * [no.liflig.documentstore.utils.getFailingEntity] to get the entity that failed, so we can + * provide it to [RepositoryJdbi.mapCreateOrUpdateException]. + */ + val entities = + exampleRepo.batchCreate( + (200..299).map { i -> ExampleEntity(text = "test", uniqueField = i) }) + + val exception = + assertFailsWith { + exampleRepo.batchUpdate( + entities.mapEntities { entity -> + entity.copy( + text = "updated", + // Update the 100 entities we created, but set the uniqueField on the last one + // to the same as a previous entity + uniqueField = if (entity.uniqueField == 299) 298 else entity.uniqueField, + ) + }, + ) + } + assertEquals(298, exception.failingEntity.uniqueField) } } diff --git a/src/test/kotlin/no/liflig/documentstore/testutils/ExampleRepository.kt b/src/test/kotlin/no/liflig/documentstore/testutils/ExampleRepository.kt index cadd29b..2f95f27 100644 --- a/src/test/kotlin/no/liflig/documentstore/testutils/ExampleRepository.kt +++ b/src/test/kotlin/no/liflig/documentstore/testutils/ExampleRepository.kt @@ -94,9 +94,11 @@ class ExampleRepository( } } -class UniqueFieldAlreadyExists(entity: ExampleEntity, override val cause: Exception) : - RuntimeException() { - override val message = "Received entity with unique field that already exists: ${entity}" +class UniqueFieldAlreadyExists( + val failingEntity: ExampleEntity, + override val cause: Exception, +) : RuntimeException() { + override val message = "Received entity with unique field that already exists: ${failingEntity}" } class ExampleRepositoryWithStringEntityId(jdbi: Jdbi) :