From fc287b57a6441962e8d1a601813de49dc57fbb36 Mon Sep 17 00:00:00 2001
From: Jark Wu
Date: Fri, 31 Jan 2025 21:50:50 +0800
Subject: [PATCH] [log] Refactor ArrowCompressionRatioEstimator to be
 lock-free and estimate the ratio per table (#330)

---
 .../fluss/client/table/FlussTableITCase.java  |  4 +-
 .../client/write/ArrowLogWriteBatchTest.java  | 52 ++++++-----
 .../client/write/RecordAccumulatorTest.java   | 77 ++++++++++++++++
 .../ArrowCompressionRatioEstimator.java       | 90 ++++++-------------
 .../record/MemoryLogRecordsArrowBuilder.java  |  6 +-
 .../alibaba/fluss/row/arrow/ArrowWriter.java  | 69 +++++++-------
 .../fluss/row/arrow/ArrowWriterPool.java      | 14 ++-
 .../ArrowCompressionRatioEstimatorTest.java   |  9 +-
 .../row/arrow/ArrowReaderWriterTest.java      |  2 +-
 9 files changed, 179 insertions(+), 144 deletions(-)

diff --git a/fluss-client/src/test/java/com/alibaba/fluss/client/table/FlussTableITCase.java b/fluss-client/src/test/java/com/alibaba/fluss/client/table/FlussTableITCase.java
index e51de8159..5d6d47002 100644
--- a/fluss-client/src/test/java/com/alibaba/fluss/client/table/FlussTableITCase.java
+++ b/fluss-client/src/test/java/com/alibaba/fluss/client/table/FlussTableITCase.java
@@ -53,6 +53,7 @@ import com.alibaba.fluss.utils.Preconditions;
 
 import org.apache.commons.lang3.StringUtils;
+import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
@@ -112,6 +113,7 @@ void testAppendOnly() throws Exception {
 
     @ParameterizedTest
     @ValueSource(booleans = {true, false})
+    @Disabled("TODO, fix me in #116")
     void testAppendWithSmallBuffer(boolean indexedFormat) throws Exception {
         TableDescriptor desc =
                 indexedFormat
@@ -991,7 +993,7 @@ void testArrowCompressionAndProject(String compression, String level) throws Exc
                         .property(ConfigOptions.TABLE_LOG_ARROW_COMPRESSION_TYPE.key(), compression)
                         .property(ConfigOptions.TABLE_LOG_ARROW_COMPRESSION_ZSTD_LEVEL.key(), level)
                         .build();
-        TablePath tablePath = TablePath.of("test_db_1", "test_arrow_compression_and_project");
+        TablePath tablePath = TablePath.of("test_db_1", "test_arrow_" + compression + level);
         createTable(tablePath, tableDescriptor, false);
 
         try (Connection conn = ConnectionFactory.createConnection(clientConf);
diff --git a/fluss-client/src/test/java/com/alibaba/fluss/client/write/ArrowLogWriteBatchTest.java b/fluss-client/src/test/java/com/alibaba/fluss/client/write/ArrowLogWriteBatchTest.java
index dfefc1770..aa53088bc 100644
--- a/fluss-client/src/test/java/com/alibaba/fluss/client/write/ArrowLogWriteBatchTest.java
+++ b/fluss-client/src/test/java/com/alibaba/fluss/client/write/ArrowLogWriteBatchTest.java
@@ -17,7 +17,6 @@
 package com.alibaba.fluss.client.write;
 
 import com.alibaba.fluss.compression.ArrowCompressionInfo;
-import com.alibaba.fluss.compression.ArrowCompressionRatioEstimator;
 import com.alibaba.fluss.compression.ArrowCompressionType;
 import com.alibaba.fluss.memory.MemorySegment;
 import com.alibaba.fluss.memory.PreAllocatedPagedOutputView;
@@ -29,12 +28,14 @@
 import com.alibaba.fluss.record.LogRecordReadContext;
 import com.alibaba.fluss.record.MemoryLogRecords;
 import com.alibaba.fluss.record.bytesview.BytesView;
+import com.alibaba.fluss.row.arrow.ArrowWriter;
 import com.alibaba.fluss.row.arrow.ArrowWriterPool;
 import com.alibaba.fluss.row.indexed.IndexedRow;
 import com.alibaba.fluss.shaded.arrow.org.apache.arrow.memory.BufferAllocator;
 import com.alibaba.fluss.shaded.arrow.org.apache.arrow.memory.RootAllocator;
 import com.alibaba.fluss.utils.CloseableIterator;
 
+import org.apache.commons.lang3.RandomStringUtils;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -188,37 +189,38 @@ void testArrowCompressionRatioEstimated() throws Exception {
         // (COMPRESSION_RATIO_IMPROVING_STEP) each time. Therefore, the loop below runs until
         // the estimated ratio converges, and the final number of input records will be much
         // greater than at the beginning.
-        int round = 100;
-        int[] recordCounts = new int[round];
-        for (int i = 0; i < round; i++) {
+        float previousRatio = -1.0f;
+        float currentRatio = 1.0f;
+        int lastBytesInSize = 0;
+        // loop until the compression ratio converges
+        while (previousRatio != currentRatio) {
+            ArrowWriter arrowWriter =
+                    writerProvider.getOrCreateWriter(
+                            tb.getTableId(),
+                            DATA1_TABLE_INFO.getSchemaId(),
+                            maxSizeInBytes,
+                            DATA1_ROW_TYPE,
+                            compressionInfo);
+
             ArrowLogWriteBatch arrowLogWriteBatch =
                     new ArrowLogWriteBatch(
                             tb,
                             DATA1_PHYSICAL_TABLE_PATH,
                             DATA1_TABLE_INFO.getSchemaId(),
-                            writerProvider.getOrCreateWriter(
-                                    tb.getTableId(),
-                                    DATA1_TABLE_INFO.getSchemaId(),
-                                    maxSizeInBytes,
-                                    DATA1_ROW_TYPE,
-                                    compressionInfo),
-                            new PreAllocatedPagedOutputView(memorySegmentList));
+                            arrowWriter,
+                            new PreAllocatedPagedOutputView(memorySegmentList),
+                            System.currentTimeMillis());
             int recordCount = 0;
             while (arrowLogWriteBatch.tryAppend(
                     createWriteRecord(
                             row(
                                     DATA1_ROW_TYPE,
-                                    new Object[] {
-                                        recordCount,
-                                        "a a a"
-                                                + recordCount
-                                    })),
+                                    new Object[] {recordCount, RandomStringUtils.random(100)})),
                     newWriteCallback())) {
                 recordCount++;
             }
-            recordCounts[i] = recordCount;
 
             // batch full.
             boolean appendResult =
                     arrowLogWriteBatch.tryAppend(
@@ -228,16 +230,18 @@ void testArrowCompressionRatioEstimated() throws Exception {
 
             // close this batch and recycle the writer.
             arrowLogWriteBatch.close();
-            arrowLogWriteBatch.build();
+            BytesView built = arrowLogWriteBatch.build();
+            lastBytesInSize = built.getBytesLength();
 
-            ArrowCompressionRatioEstimator compressionRatioEstimator =
-                    writerProvider.compressionRatioEstimator();
-            float currentRatio =
-                    compressionRatioEstimator.estimation(tb.getTableId(), compressionInfo);
-            assertThat(currentRatio).isNotEqualTo(1.0f);
+            previousRatio = currentRatio;
+            currentRatio = arrowWriter.getCompressionRatioEstimator().estimation();
         }
-        assertThat(recordCounts[round - 1]).isGreaterThan(recordCounts[0]);
 
+        // once the compression ratio has converged, the memory buffer should be almost fully used.
+        assertThat(lastBytesInSize)
+                .isGreaterThan((int) (maxSizeInBytes * ArrowWriter.BUFFER_USAGE_RATIO))
+                .isLessThan(maxSizeInBytes);
+        assertThat(currentRatio).isLessThan(1.0f);
     }
 
     private WriteRecord createWriteRecord(IndexedRow row) {
diff --git a/fluss-client/src/test/java/com/alibaba/fluss/client/write/RecordAccumulatorTest.java b/fluss-client/src/test/java/com/alibaba/fluss/client/write/RecordAccumulatorTest.java
index b7d7ce3af..b1d9b59d0 100644
--- a/fluss-client/src/test/java/com/alibaba/fluss/client/write/RecordAccumulatorTest.java
+++ b/fluss-client/src/test/java/com/alibaba/fluss/client/write/RecordAccumulatorTest.java
@@ -38,6 +38,7 @@
 import com.alibaba.fluss.record.MemoryLogRecords;
 import com.alibaba.fluss.record.RowKind;
 import com.alibaba.fluss.row.InternalRow;
+import com.alibaba.fluss.row.arrow.ArrowWriter;
 import com.alibaba.fluss.row.indexed.IndexedRow;
 import com.alibaba.fluss.rpc.GatewayClientProxy;
 import com.alibaba.fluss.rpc.RpcClient;
@@ -46,6 +47,7 @@
 import com.alibaba.fluss.utils.CloseableIterator;
 import com.alibaba.fluss.utils.clock.ManualClock;
 
+import org.apache.commons.lang3.RandomStringUtils;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
@@ -75,6 +77,22 @@
 /** Test for {@link RecordAccumulator}. */
 public class RecordAccumulatorTest {
 
+    private static final long ZSTD_TABLE_ID = 16001L;
+    private static final PhysicalTablePath ZSTD_PHYSICAL_TABLE_PATH =
+            PhysicalTablePath.of(TablePath.of("test_db_1", "test_zstd_table_1"));
+    private static final TableInfo ZSTD_TABLE_INFO =
+            new TableInfo(
+                    ZSTD_PHYSICAL_TABLE_PATH.getTablePath(),
+                    ZSTD_TABLE_ID,
+                    TableDescriptor.builder()
+                            .schema(DATA1_SCHEMA)
+                            .distributedBy(3)
+                            .property(ConfigOptions.TABLE_LOG_ARROW_COMPRESSION_TYPE.key(), "zstd")
+                            .build(),
+                    1,
+                    System.currentTimeMillis(),
+                    System.currentTimeMillis());
+
     ServerNode node1 = new ServerNode(1, "localhost", 90, ServerType.TABLET_SERVER);
     ServerNode node2 = new ServerNode(2, "localhost", 91, ServerType.TABLET_SERVER);
     ServerNode node3 = new ServerNode(3, "localhost", 92, ServerType.TABLET_SERVER);
@@ -91,6 +109,7 @@ public class RecordAccumulatorTest {
             new BucketLocation(DATA1_PHYSICAL_TABLE_PATH, DATA1_TABLE_ID, 2, node2, serverNodes);
     private final BucketLocation bucket4 =
             new BucketLocation(DATA1_PHYSICAL_TABLE_PATH, DATA1_TABLE_ID, 3, node2, serverNodes);
+
     private final WriteCallback writeCallback =
             exception -> {
                 if (exception != null) {
@@ -149,6 +168,63 @@ void testDrainBatches() throws Exception {
         verifyTableBucketInBatches(batches3, tb1, tb3);
     }
 
+    @Test
+    void testDrainCompressedBatches() throws Exception {
+        int batchSize = 10 * 1024;
+        int bucketNum = 10;
+        RecordAccumulator accum =
+                createTestRecordAccumulator(
+                        Integer.MAX_VALUE, batchSize, batchSize, Integer.MAX_VALUE);
+        List<BucketLocation> bucketLocations = new ArrayList<>();
+        for (int b = 0; b < bucketNum; b++) {
+            bucketLocations.add(
+                    new BucketLocation(
+                            ZSTD_PHYSICAL_TABLE_PATH, ZSTD_TABLE_ID, b, node1, serverNodes));
+        }
+        // all buckets are located in node1
+        cluster = updateCluster(bucketLocations);
+
+        appendUntilCompressionRatioStable(accum, batchSize);
+
+        for (int i = 0; i < bucketNum; i++) {
+            appendUntilBatchFull(accum, i);
+        }
+
+        // all buckets are located in node1, so drain them all from node1
+        Map<Integer, List<WriteBatch>> batches =
+                accum.drain(cluster, Collections.singleton(node1), batchSize * bucketNum);
+        // the compression ratio is smaller than 1.0, so a drain budget of bucketNum * batchSize
+        // is enough to hold one compressed batch for every bucket
+        assertThat(batches.containsKey(node1.id())).isTrue();
+        assertThat(batches.get(node1.id()).size()).isEqualTo(bucketNum);
+    }
+
+    private void appendUntilCompressionRatioStable(RecordAccumulator accum, int batchSize)
+            throws Exception {
+        while (true) {
+            appendUntilBatchFull(accum, 0);
+            Map<Integer, List<WriteBatch>> batches =
+                    accum.drain(cluster, Collections.singleton(node1), Integer.MAX_VALUE);
+            WriteBatch batch = batches.get(node1.id()).get(0);
+            int actualSize = batch.build().getBytesLength();
+            if (actualSize > batchSize * ArrowWriter.BUFFER_USAGE_RATIO) {
+                return;
+            }
+        }
+    }
+
+    private void appendUntilBatchFull(RecordAccumulator accum, int bucketId) throws Exception {
+        while (true) {
+            InternalRow row = row(DATA1_ROW_TYPE, new Object[] {1, RandomStringUtils.random(10)});
+            PhysicalTablePath tablePath = PhysicalTablePath.of(ZSTD_TABLE_INFO.getTablePath());
+            WriteRecord record = new WriteRecord(tablePath, WriteKind.APPEND, row, null);
+            // append until the batch is full
+            if (accum.append(record, writeCallback, cluster, bucketId, false).batchIsFull) {
+                break;
+            }
+        }
+    }
+
     @Test
     void testFull() throws Exception {
         // test case assumes that the records do not fill the batch completely
@@ -465,6 +541,7 @@ private Cluster updateCluster(List<BucketLocation> bucketLocations) {
                         System.currentTimeMillis());
         Map<TablePath, TableInfo> tableInfoByPath = new HashMap<>();
         tableInfoByPath.put(DATA1_TABLE_PATH, data1NonPkTableInfo);
+        tableInfoByPath.put(ZSTD_TABLE_INFO.getTablePath(), ZSTD_TABLE_INFO);
 
         return new Cluster(
                 aliveTabletServersById,
diff --git a/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimator.java b/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimator.java
index 347a05ad6..757bfc6a7 100644
--- a/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimator.java
+++ b/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimator.java
@@ -20,17 +20,7 @@
 
 import javax.annotation.concurrent.ThreadSafe;
 
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.locks.Lock;
-import java.util.concurrent.locks.ReentrantLock;
-
-import static com.alibaba.fluss.utils.concurrent.LockUtils.inLock;
-
-/**
- * This class help estimate the compression ratio for each table and each arrow compression type
- * combination.
- */
+/** This class helps estimate the compression ratio for a table. */
 @Internal
 @ThreadSafe
 public class ArrowCompressionRatioEstimator {
@@ -38,69 +28,41 @@ public class ArrowCompressionRatioEstimator {
      * The constant step to decrease the estimated compression ratio when a batch compresses
      * better than expected.
      */
-    public static final float COMPRESSION_RATIO_IMPROVING_STEP = 0.005f;
+    private static final float COMPRESSION_RATIO_IMPROVING_STEP = 0.005f;
 
     /**
      * The minimum step to increase the estimated compression ratio when a batch compresses worse
      * than expected.
      */
-    public static final float COMPRESSION_RATIO_DETERIORATE_STEP = 0.05f;
+    private static final float COMPRESSION_RATIO_DETERIORATE_STEP = 0.05f;
+
+    /**
+     * The default compression ratio when initializing a new {@link ArrowCompressionRatioEstimator}.
+     */
+    private static final float DEFAULT_COMPRESSION_RATIO = 1.0f;
 
-    private final Map<Long, Map<String, Float>> compressionRatio;
-    private final Map<Long, Lock> tableLocks;
+    /** The current compression ratio; volatile so that reads and writes are lock-free. */
+    private volatile float compressionRatio;
 
     public ArrowCompressionRatioEstimator() {
-        compressionRatio = new ConcurrentHashMap<>();
-        tableLocks = new ConcurrentHashMap<>();
+        compressionRatio = DEFAULT_COMPRESSION_RATIO;
     }
 
-    /**
-     * Update the compression ratio estimation for a table and related compression info with the
-     * observed compression ratio.
-     */
-    public void updateEstimation(
-            long tableId, ArrowCompressionInfo compressionInfo, float observedRatio) {
-        Lock lock = tableLocks.computeIfAbsent(tableId, k -> new ReentrantLock());
-        inLock(
-                lock,
-                () -> {
-                    Map<String, Float> compressionRatioMap =
-                            compressionRatio.computeIfAbsent(
-                                    tableId, k -> new ConcurrentHashMap<>());
-                    String compressionKey = compressionInfo.toString();
-                    float currentEstimation =
-                            compressionRatioMap.getOrDefault(compressionKey, 1.0f);
-                    if (observedRatio > currentEstimation) {
-                        compressionRatioMap.put(
-                                compressionKey,
-                                Math.max(
-                                        currentEstimation + COMPRESSION_RATIO_DETERIORATE_STEP,
-                                        observedRatio));
-                    } else if (observedRatio < currentEstimation) {
-                        compressionRatioMap.put(
-                                compressionKey,
-                                Math.max(
-                                        currentEstimation - COMPRESSION_RATIO_IMPROVING_STEP,
-                                        observedRatio));
-                    }
-                });
+    /** Update the compression ratio estimation with the observed compression ratio. */
+    public void updateEstimation(float observedRatio) {
+        float currentEstimation = compressionRatio;
+        // the read-modify-write below is not atomic, so a concurrent update may be lost,
+        // but that is acceptable because the value is only an estimation
+        if (observedRatio > currentEstimation) {
+            compressionRatio =
+                    Math.max(currentEstimation + COMPRESSION_RATIO_DETERIORATE_STEP, observedRatio);
+        } else if (observedRatio < currentEstimation) {
+            compressionRatio =
+                    Math.max(currentEstimation - COMPRESSION_RATIO_IMPROVING_STEP, observedRatio);
+        }
     }
 
-    /** Get the compression ratio estimation for a table and related compression info. */
-    public float estimation(long tableId, ArrowCompressionInfo compressionInfo) {
-        Lock lock = tableLocks.computeIfAbsent(tableId, k -> new ReentrantLock());
-        return inLock(
-                lock,
-                () -> {
-                    Map<String, Float> compressionRatioMap =
-                            compressionRatio.computeIfAbsent(
-                                    tableId, k -> new ConcurrentHashMap<>());
-                    String compressionKey = compressionInfo.toString();
-
-                    if (!compressionRatioMap.containsKey(compressionKey)) {
-                        compressionRatioMap.put(compressionKey, 1.0f);
-                    }
-
-                    return compressionRatioMap.get(compressionKey);
-                });
+    /** Gets the current compression ratio estimation. */
+    public float estimation() {
+        return compressionRatio;
     }
 }
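[Editor's note, illustration only, not part of the patch: the lock-free update rule above walks the estimate from the default 1.0 toward the observed ratio in fixed steps, which is exactly the convergence property the reworked loop in ArrowLogWriteBatchTest relies on. A minimal, self-contained driver, assuming only the patched ArrowCompressionRatioEstimator is on the classpath:

    import com.alibaba.fluss.compression.ArrowCompressionRatioEstimator;

    public class EstimatorConvergenceSketch {
        public static void main(String[] args) {
            ArrowCompressionRatioEstimator estimator = new ArrowCompressionRatioEstimator();
            float observed = 0.42f; // pretend every batch compresses to 42% of its raw size
            float previous = -1.0f;
            int updates = 0;
            // mirror the test: loop until the estimation stops changing
            while (previous != estimator.estimation()) {
                previous = estimator.estimation();
                estimator.updateEstimation(observed);
                updates++;
            }
            // starting from 1.0, the estimate descends by COMPRESSION_RATIO_IMPROVING_STEP
            // (0.005) per update; Math.max clamps it so it lands exactly on the observed ratio
            System.out.printf("converged to %.3f after %d updates%n", estimator.estimation(), updates);
        }
    }

Because a deteriorating batch moves the estimate up by at least 0.05 while an improving batch moves it down by only 0.005, the estimator is deliberately pessimistic and recovers quickly from underestimation.]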
diff --git a/fluss-common/src/main/java/com/alibaba/fluss/record/MemoryLogRecordsArrowBuilder.java b/fluss-common/src/main/java/com/alibaba/fluss/record/MemoryLogRecordsArrowBuilder.java
index 6ac525d93..9e8fb61b9 100644
--- a/fluss-common/src/main/java/com/alibaba/fluss/record/MemoryLogRecordsArrowBuilder.java
+++ b/fluss-common/src/main/java/com/alibaba/fluss/record/MemoryLogRecordsArrowBuilder.java
@@ -187,14 +187,16 @@ public void recycleArrowWriter() {
 
     public int estimatedSizeInBytes() {
         if (bytesView != null) {
-            // accurate total size in bytes
+            // accurate total size in bytes (compressed if compression is enabled)
            return bytesView.getBytesLength();
         }
 
         if (reCalculateSizeInBytes) {
             // make size in bytes up-to-date
             estimatedSizeInBytes =
-                    ARROW_ROWKIND_OFFSET + rowKindWriter.sizeInBytes() + arrowWriter.sizeInBytes();
+                    ARROW_ROWKIND_OFFSET
+                            + rowKindWriter.sizeInBytes()
+                            + arrowWriter.estimatedSizeInBytes();
         }
 
         reCalculateSizeInBytes = false;
diff --git a/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriter.java b/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriter.java
index ccac635cb..66d5633d0 100644
--- a/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriter.java
+++ b/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriter.java
@@ -17,9 +17,9 @@
 package com.alibaba.fluss.row.arrow;
 
 import com.alibaba.fluss.annotation.Internal;
+import com.alibaba.fluss.annotation.VisibleForTesting;
 import com.alibaba.fluss.compression.ArrowCompressionInfo;
 import com.alibaba.fluss.compression.ArrowCompressionRatioEstimator;
-import com.alibaba.fluss.compression.ArrowCompressionType;
 import com.alibaba.fluss.memory.AbstractPagedOutputView;
 import com.alibaba.fluss.row.InternalRow;
 import com.alibaba.fluss.row.arrow.writers.ArrowFieldWriter;
@@ -61,13 +61,7 @@ public class ArrowWriter implements AutoCloseable {
      * The buffer usage ratio which is used to determine whether the writer is full. The writer is
      * full if the buffer usage ratio exceeds the threshold.
      */
-    public static final double BUFFER_USAGE_RATIO = 0.96;
-
-    /**
-     * The factor which is used to estimate the compression ratio of the serialized {@link
-     * ArrowRecordBatch}.
-     */
-    private static final float COMPRESSION_RATE_ESTIMATION_FACTOR = 1.05f;
+    public static final float BUFFER_USAGE_RATIO = 0.95f;
 
     /**
      * The identifier of the writer which is used to identify the writer in the {@link
@@ -78,9 +72,6 @@ public class ArrowWriter implements AutoCloseable {
     /** Container that holds a set of vectors for the rows. */
     final VectorSchemaRoot root;
 
-    /** The table id of current writer. */
-    private final long tableId;
-
     /**
      * An array of writers which are responsible for the serialization of each column of the rows.
      */
@@ -97,7 +88,6 @@ public class ArrowWriter implements AutoCloseable {
 
     private final RowType schema;
 
-    private final ArrowCompressionInfo compressionInfo;
     private final CompressionCodec compressionCodec;
 
     private final ArrowCompressionRatioEstimator compressionRatioEstimator;
@@ -113,7 +103,6 @@ public class ArrowWriter implements AutoCloseable {
     private long epoch;
 
     ArrowWriter(
-            long tableId,
             String writerKey,
             int bufferSizeInBytes,
             RowType schema,
@@ -121,16 +110,13 @@ ArrowWriter(
             ArrowWriterProvider provider,
             ArrowCompressionInfo compressionInfo,
             ArrowCompressionRatioEstimator compressionRatioEstimator) {
-        this.tableId = tableId;
         this.writerKey = writerKey;
         this.schema = schema;
         this.root = VectorSchemaRoot.create(ArrowUtils.toArrowSchema(schema), allocator);
         this.provider = Preconditions.checkNotNull(provider);
         this.compressionCodec = compressionInfo.createCompressionCodec();
         this.compressionRatioEstimator = compressionRatioEstimator;
-        this.compressionInfo = compressionInfo;
-        this.estimatedCompressionRatio =
-                compressionRatioEstimator.estimation(tableId, compressionInfo);
+        this.estimatedCompressionRatio = compressionRatioEstimator.estimation();
 
         this.metadataLength =
                 ArrowUtils.estimateArrowMetadataLength(
@@ -160,9 +146,7 @@ public boolean isFull() {
         if (recordsCount > 0 && recordsCount >= estimatedMaxRecordsCount) {
             root.setRowCount(recordsCount);
             int metadataLength = getMetadataLength();
-            int bodyLength = getBodyLength();
-
-            int estimatedBodyLength = estimatedBytesWritten(bodyLength);
+            int estimatedBodyLength = estimatedBytesWritten(getBodyLength());
             int currentSize = metadataLength + estimatedBodyLength;
             if (currentSize >= writeLimitInBytes) {
                 return true;
@@ -183,7 +167,17 @@ public boolean isFull() {
     }
 
     public void reset(int bufferSizeInBytes) {
-        this.writeLimitInBytes = (int) (bufferSizeInBytes * BUFFER_USAGE_RATIO);
+        int newWriteLimit = (int) (bufferSizeInBytes * BUFFER_USAGE_RATIO);
+        if (newWriteLimit == writeLimitInBytes) {
+            // the write limit is unchanged, so start estimating from half of the previous
+            // record count for better performance and accuracy.
+            estimatedMaxRecordsCount = recordsCount / 2;
+        } else {
+            // a new write limit invalidates the previous estimate, so reset the estimated
+            // count to -1 and re-estimate it after the first row.
+            estimatedMaxRecordsCount = -1;
+        }
+        writeLimitInBytes = newWriteLimit;
         for (int i = 0; i < fieldWriters.length; i++) {
             FieldVector fieldVector = root.getVector(i);
             initFieldVector(fieldVector);
@@ -191,8 +185,8 @@ public void reset(int bufferSizeInBytes) {
         }
         root.setRowCount(0);
         recordsCount = 0;
-        // initial estimated count should < 0, so that we can estimate the count after the first row
-        estimatedMaxRecordsCount = -1;
+        // refresh the compression ratio estimation for the next batch.
+        estimatedCompressionRatio = compressionRatioEstimator.estimation();
     }
 
     /** Writes the specified row which is serialized into Arrow format. */
@@ -230,11 +224,13 @@ public int getBodyLength() {
 
     /**
      * Gets the total size in bytes of each serialized {@link ArrowRecordBatch} generated by this
-     * root.
+     * root. Returns the estimated compressed size if compression is enabled.
      */
-    public int sizeInBytes() {
+    public int estimatedSizeInBytes() {
         root.setRowCount(recordsCount);
-        return getMetadataLength() + getBodyLength();
+        int metadataLength = getMetadataLength();
+        int estimatedBodyLength = estimatedBytesWritten(getBodyLength());
+        return metadataLength + estimatedBodyLength;
     }
 
     /** Serializes the current row batch to Arrow format and returns the written size in bytes. */
@@ -261,10 +257,9 @@ public int serializeToOutputView(AbstractPagedOutputView outputView, int positio
         checkState(
                 uncompressedBodySizeInBytes > 0,
                 "uncompressedBodySizeInBytes is 0 or negative");
-        compressionRatioEstimator.updateEstimation(
-                tableId,
-                compressionInfo,
-                (float) block.getBodyLength() / uncompressedBodySizeInBytes);
+        float actualCompressionRatio =
+                (float) block.getBodyLength() / uncompressedBodySizeInBytes;
+        compressionRatioEstimator.updateEstimation(actualCompressionRatio);
 
         return (int) (block.getMetadataLength() + block.getBodyLength());
     }
@@ -290,10 +285,6 @@ public void recycle(long epoch) {
         if (this.epoch == epoch) {
             root.clear();
             provider.recycleWriter(this);
-
-            // reset the compression ratio.
-            this.estimatedCompressionRatio =
-                    compressionRatioEstimator.estimation(tableId, compressionInfo);
         }
     }
 
@@ -316,11 +307,15 @@ private void initFieldVector(FieldVector fieldVector) {
     }
 
     private int estimatedBytesWritten(int currentBytes) {
-        if (compressionInfo.getCompressionType() == ArrowCompressionType.NONE) {
+        if (compressionCodec.getCodecType() == CompressionUtil.CodecType.NO_COMPRESSION) {
             return currentBytes;
         } else {
-            return (int)
-                    (currentBytes * estimatedCompressionRatio * COMPRESSION_RATE_ESTIMATION_FACTOR);
+            return (int) (currentBytes * estimatedCompressionRatio);
         }
     }
+
+    @VisibleForTesting
+    public ArrowCompressionRatioEstimator getCompressionRatioEstimator() {
+        return compressionRatioEstimator;
+    }
 }
diff --git a/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriterPool.java b/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriterPool.java
index 5641ba354..bd3d797a6 100644
--- a/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriterPool.java
+++ b/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriterPool.java
@@ -53,14 +53,14 @@ public class ArrowWriterPool implements ArrowWriterProvider {
     private boolean closed = false;
 
     @GuardedBy("lock")
-    private final ArrowCompressionRatioEstimator compressionRatioEstimator;
+    private final Map<String, ArrowCompressionRatioEstimator> compressionRatioEstimators;
 
     private final ReentrantLock lock = new ReentrantLock();
 
     public ArrowWriterPool(BufferAllocator allocator) {
         this.allocator = allocator;
         this.freeWriters = new HashMap<>();
-        this.compressionRatioEstimator = new ArrowCompressionRatioEstimator();
+        this.compressionRatioEstimators = new HashMap<>();
     }
 
     @Override
@@ -97,12 +97,14 @@ public ArrowWriter getOrCreateWriter(
                         "Arrow VectorSchemaRoot pool closed while getting/creating root.");
             }
             Deque<ArrowWriter> writers = freeWriters.get(writerKey);
+            ArrowCompressionRatioEstimator compressionRatioEstimator =
+                    compressionRatioEstimators.computeIfAbsent(
+                            writerKey, k -> new ArrowCompressionRatioEstimator());
             if (writers != null && !writers.isEmpty()) {
                 return initialize(writers.pollFirst(), bufferSizeInBytes);
             } else {
                 return initialize(
                         new ArrowWriter(
-                                tableId,
                                 writerKey,
                                 bufferSizeInBytes,
                                 schema,
@@ -130,6 +132,7 @@ public void close() {
                 }
             }
             freeWriters.clear();
+            compressionRatioEstimators.clear();
             closed = true;
         } finally {
             lock.unlock();
@@ -140,9 +143,4 @@ public Map<String, Deque<ArrowWriter>> freeWriters() {
         return freeWriters;
     }
-
-    @VisibleForTesting
-    public ArrowCompressionRatioEstimator compressionRatioEstimator() {
-        return compressionRatioEstimator;
-    }
 }
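[Editor's note, illustration only, not part of the patch: with the per-key map above, callers that previously read the pool-wide estimator now reach the estimator through the writer itself. A rough usage sketch under stated assumptions: RowType.of(DataTypes.INT(), DataTypes.STRING()) is an assumed schema helper, the ArrowCompressionInfo constructor and the other calls are taken from the tests in this patch, and error handling is elided:

    import com.alibaba.fluss.compression.ArrowCompressionInfo;
    import com.alibaba.fluss.compression.ArrowCompressionType;
    import com.alibaba.fluss.row.arrow.ArrowWriter;
    import com.alibaba.fluss.row.arrow.ArrowWriterPool;
    import com.alibaba.fluss.shaded.arrow.org.apache.arrow.memory.RootAllocator;
    import com.alibaba.fluss.types.DataTypes;
    import com.alibaba.fluss.types.RowType;

    public class WriterPoolUsageSketch {
        public static void main(String[] args) throws Exception {
            RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
            ArrowWriterPool pool = new ArrowWriterPool(allocator);
            RowType schema = RowType.of(DataTypes.INT(), DataTypes.STRING()); // assumed helper
            ArrowWriter writer =
                    pool.getOrCreateWriter(
                            1L, // table id
                            1, // schema id
                            64 * 1024, // buffer size in bytes
                            schema,
                            new ArrowCompressionInfo(ArrowCompressionType.ZSTD, 3));
            // the estimator now lives with the writer key (derived from table id and schema id),
            // not with the pool, so writers of different tables no longer share one estimate
            float ratio = writer.getCompressionRatioEstimator().estimation();
            System.out.println("estimated compression ratio: " + ratio);
            pool.close();
            allocator.close();
        }
    }

Keying the estimator by writer key rather than by (table, compression type) pairs is what allows the per-table locking to disappear: each estimator is a single volatile float.]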
diff --git a/fluss-common/src/test/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimatorTest.java b/fluss-common/src/test/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimatorTest.java
index a59c469f9..ecba1b17c 100644
--- a/fluss-common/src/test/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimatorTest.java
+++ b/fluss-common/src/test/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimatorTest.java
@@ -56,14 +56,9 @@ class EstimationsObservedRatios {
                         new EstimationsObservedRatios(0.6f, 0.7f),
                         new EstimationsObservedRatios(0.6f, 0.4f),
                         new EstimationsObservedRatios(0.004f, 0.001f));
-        long tableId = 150001L;
-        ArrowCompressionInfo compressionInfo =
-                new ArrowCompressionInfo(ArrowCompressionType.ZSTD, 3);
         for (EstimationsObservedRatios estimationObservedRatio : estimationsObservedRatios) {
-            compressionRatioEstimator.updateEstimation(
-                    tableId, compressionInfo, estimationObservedRatio.currentEstimation);
-            float updatedCompressionRatio =
-                    compressionRatioEstimator.estimation(tableId, compressionInfo);
+            compressionRatioEstimator.updateEstimation(estimationObservedRatio.currentEstimation);
+            float updatedCompressionRatio = compressionRatioEstimator.estimation();
             assertThat(updatedCompressionRatio)
                     .isGreaterThanOrEqualTo(estimationObservedRatio.observedRatio);
         }
diff --git a/fluss-common/src/test/java/com/alibaba/fluss/row/arrow/ArrowReaderWriterTest.java b/fluss-common/src/test/java/com/alibaba/fluss/row/arrow/ArrowReaderWriterTest.java
index a1283b7ca..5f4f99f1b 100644
--- a/fluss-common/src/test/java/com/alibaba/fluss/row/arrow/ArrowReaderWriterTest.java
+++ b/fluss-common/src/test/java/com/alibaba/fluss/row/arrow/ArrowReaderWriterTest.java
@@ -150,7 +150,7 @@ void testReaderWriter() throws IOException {
 
         // skip arrow batch header.
         int size = writer.serializeToOutputView(pagedOutputView, ARROW_ROWKIND_OFFSET);
-        MemorySegment segment = MemorySegment.allocateHeapMemory(writer.sizeInBytes());
+        MemorySegment segment = MemorySegment.allocateHeapMemory(writer.estimatedSizeInBytes());
         assertThat(pagedOutputView.getWrittenSegments().size()).isEqualTo(1);
         MemorySegment firstSegment = pagedOutputView.getCurrentSegment();
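[Editor's note, illustration only, not part of the patch: the sizing heuristic that replaces COMPRESSION_RATE_ESTIMATION_FACTOR can be modeled in a few lines. A compressed batch is estimated as rawBodyLength * estimatedCompressionRatio, and the writer reports full once metadata plus the estimated body reaches bufferSize * BUFFER_USAGE_RATIO. A simplified, self-contained model of that arithmetic; the constant mirrors ArrowWriter, everything else is hypothetical:

    public class SizingHeuristicModel {
        static final float BUFFER_USAGE_RATIO = 0.95f; // mirrors ArrowWriter.BUFFER_USAGE_RATIO

        // models ArrowWriter.isFull() combined with estimatedBytesWritten(...)
        static boolean isFull(int metadataLength, int rawBodyLength, float estimatedRatio, int bufferSize) {
            int writeLimit = (int) (bufferSize * BUFFER_USAGE_RATIO);
            int estimatedBody = (int) (rawBodyLength * estimatedRatio);
            return metadataLength + estimatedBody >= writeLimit;
        }

        public static void main(String[] args) {
            // with a converged ratio of 0.42, a 64 KiB buffer admits roughly 1/0.42 times
            // more raw bytes than an uncompressed writer before the batch is considered full
            System.out.println(isFull(256, 140_000, 0.42f, 64 * 1024)); // ~59 KiB estimated -> false
            System.out.println(isFull(256, 150_000, 0.42f, 64 * 1024)); // ~63 KiB estimated -> true
        }
    }

This is why the reworked tests assert that the built batch ends up between bufferSize * BUFFER_USAGE_RATIO and bufferSize: once the ratio estimate has converged, the estimated and actual compressed sizes agree closely enough that the buffer is nearly, but not over, filled.]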