diff --git a/fluss-client/src/test/java/com/alibaba/fluss/client/write/ArrowLogWriteBatchTest.java b/fluss-client/src/test/java/com/alibaba/fluss/client/write/ArrowLogWriteBatchTest.java index ae74d53a..d79a05e5 100644 --- a/fluss-client/src/test/java/com/alibaba/fluss/client/write/ArrowLogWriteBatchTest.java +++ b/fluss-client/src/test/java/com/alibaba/fluss/client/write/ArrowLogWriteBatchTest.java @@ -17,6 +17,8 @@ package com.alibaba.fluss.client.write; import com.alibaba.fluss.compression.ArrowCompressionInfo; +import com.alibaba.fluss.compression.ArrowCompressionRatioEstimator; +import com.alibaba.fluss.compression.ArrowCompressionType; import com.alibaba.fluss.memory.MemorySegment; import com.alibaba.fluss.memory.PreAllocatedPagedOutputView; import com.alibaba.fluss.memory.TestingMemorySegmentPool; @@ -166,6 +168,77 @@ void testAppendWithPreAllocatedMemorySegments() throws Exception { } } + @Test + void testArrowCompressionRatioEstimated() throws Exception { + int bucketId = 0; + int maxSizeInBytes = 1024 * 10; + int pageSize = 512; + TestingMemorySegmentPool memoryPool = new TestingMemorySegmentPool(pageSize); + List memorySegmentList = new ArrayList<>(); + for (int i = 0; i < maxSizeInBytes / pageSize; i++) { + memorySegmentList.add(memoryPool.nextSegment()); + } + + TableBucket tb = new TableBucket(DATA1_TABLE_ID, bucketId); + ArrowCompressionInfo compressionInfo = + new ArrowCompressionInfo(ArrowCompressionType.ZSTD, 3); + + // The compression rate increases slowly, with an increment of only 0.005 + // (COMPRESSION_RATIO_IMPROVING_STEP#COMPRESSION_RATIO_IMPROVING_STEP) each time. Therefore, + // the loop runs 100 times, and theoretically, the final number of input records will be + // much greater than at the beginning. + int round = 100; + int[] recordCounts = new int[round]; + for (int i = 0; i < round; i++) { + ArrowLogWriteBatch arrowLogWriteBatch = + new ArrowLogWriteBatch( + tb, + DATA1_PHYSICAL_TABLE_PATH, + DATA1_TABLE_INFO.getSchemaId(), + writerProvider.getOrCreateWriter( + tb.getTableId(), + DATA1_TABLE_INFO.getSchemaId(), + maxSizeInBytes, + DATA1_ROW_TYPE, + compressionInfo), + new PreAllocatedPagedOutputView(memorySegmentList)); + + int recordCount = 0; + while (arrowLogWriteBatch.tryAppend( + createWriteRecord( + row( + DATA1_ROW_TYPE, + new Object[] { + recordCount, + "a a a" + + recordCount + })), + newWriteCallback())) { + recordCount++; + } + + recordCounts[i] = recordCount; + // batch full. + boolean appendResult = + arrowLogWriteBatch.tryAppend( + createWriteRecord(row(DATA1_ROW_TYPE, new Object[] {1, "a"})), + newWriteCallback()); + assertThat(appendResult).isFalse(); + + // close this batch and recycle the writer. + arrowLogWriteBatch.close(); + arrowLogWriteBatch.build(); + + ArrowCompressionRatioEstimator compressionRatioEstimator = + writerProvider.compressionRatioEstimator(); + float currentRatio = + compressionRatioEstimator.estimation(tb.getTableId(), compressionInfo); + assertThat(currentRatio).isNotEqualTo(1.0f); + } + + assertThat(recordCounts[round - 1]).isGreaterThan(recordCounts[0]); + } + private WriteRecord createWriteRecord(IndexedRow row) { return new WriteRecord(DATA1_PHYSICAL_TABLE_PATH, WriteKind.APPEND, row, null); } diff --git a/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionInfo.java b/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionInfo.java index 7f31a6e6..358a9893 100644 --- a/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionInfo.java +++ b/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionInfo.java @@ -48,7 +48,7 @@ public int getCompressionLevel() { return compressionLevel; } - /** Create a Arrow compression codec based on the compression type and level. */ + /** Create an Arrow compression codec based on the compression type and level. */ public CompressionCodec createCompressionCodec() { return ArrowCompressionFactory.INSTANCE.createCodec( ArrowCompressionFactory.toArrowCompressionCodecType(compressionType), diff --git a/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimator.java b/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimator.java new file mode 100644 index 00000000..347a05ad --- /dev/null +++ b/fluss-common/src/main/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimator.java @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.fluss.compression; + +import com.alibaba.fluss.annotation.Internal; + +import javax.annotation.concurrent.ThreadSafe; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +import static com.alibaba.fluss.utils.concurrent.LockUtils.inLock; + +/** + * This class help estimate the compression ratio for each table and each arrow compression type + * combination. + */ +@Internal +@ThreadSafe +public class ArrowCompressionRatioEstimator { + /** + * The constant speed to increase compression ratio when a batch compresses better than + * expected. + */ + public static final float COMPRESSION_RATIO_IMPROVING_STEP = 0.005f; + + /** + * The minimum speed to decrease compression ratio when a batch compresses worse than expected. + */ + public static final float COMPRESSION_RATIO_DETERIORATE_STEP = 0.05f; + + private final Map> compressionRatio; + private final Map tableLocks; + + public ArrowCompressionRatioEstimator() { + compressionRatio = new ConcurrentHashMap<>(); + tableLocks = new ConcurrentHashMap<>(); + } + + /** + * Update the compression ratio estimation for a table and related compression info with the + * observed compression ratio. + */ + public void updateEstimation( + long tableId, ArrowCompressionInfo compressionInfo, float observedRatio) { + Lock lock = tableLocks.computeIfAbsent(tableId, k -> new ReentrantLock()); + inLock( + lock, + () -> { + Map compressionRatioMap = + compressionRatio.computeIfAbsent( + tableId, k -> new ConcurrentHashMap<>()); + String compressionKey = compressionInfo.toString(); + float currentEstimation = + compressionRatioMap.getOrDefault(compressionKey, 1.0f); + if (observedRatio > currentEstimation) { + compressionRatioMap.put( + compressionKey, + Math.max( + currentEstimation + COMPRESSION_RATIO_DETERIORATE_STEP, + observedRatio)); + } else if (observedRatio < currentEstimation) { + compressionRatioMap.put( + compressionKey, + Math.max( + currentEstimation - COMPRESSION_RATIO_IMPROVING_STEP, + observedRatio)); + } + }); + } + + /** Get the compression ratio estimation for a table and related compression info. */ + public float estimation(long tableId, ArrowCompressionInfo compressionInfo) { + Lock lock = tableLocks.computeIfAbsent(tableId, k -> new ReentrantLock()); + return inLock( + lock, + () -> { + Map compressionRatioMap = + compressionRatio.computeIfAbsent( + tableId, k -> new ConcurrentHashMap<>()); + String compressionKey = compressionInfo.toString(); + + if (!compressionRatioMap.containsKey(compressionKey)) { + compressionRatioMap.put(compressionKey, 1.0f); + } + + return compressionRatioMap.get(compressionKey); + }); + } +} diff --git a/fluss-common/src/main/java/com/alibaba/fluss/record/MemoryLogRecordsArrowBuilder.java b/fluss-common/src/main/java/com/alibaba/fluss/record/MemoryLogRecordsArrowBuilder.java index 4db3e96f..6ac525d9 100644 --- a/fluss-common/src/main/java/com/alibaba/fluss/record/MemoryLogRecordsArrowBuilder.java +++ b/fluss-common/src/main/java/com/alibaba/fluss/record/MemoryLogRecordsArrowBuilder.java @@ -193,7 +193,6 @@ public int estimatedSizeInBytes() { if (reCalculateSizeInBytes) { // make size in bytes up-to-date - // TODO: consider the compression ratio estimatedSizeInBytes = ARROW_ROWKIND_OFFSET + rowKindWriter.sizeInBytes() + arrowWriter.sizeInBytes(); } diff --git a/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriter.java b/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriter.java index ddc495cc..bcd64dc0 100644 --- a/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriter.java +++ b/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriter.java @@ -18,6 +18,8 @@ import com.alibaba.fluss.annotation.Internal; import com.alibaba.fluss.compression.ArrowCompressionInfo; +import com.alibaba.fluss.compression.ArrowCompressionRatioEstimator; +import com.alibaba.fluss.compression.ArrowCompressionType; import com.alibaba.fluss.memory.AbstractPagedOutputView; import com.alibaba.fluss.row.InternalRow; import com.alibaba.fluss.row.arrow.writers.ArrowFieldWriter; @@ -61,6 +63,12 @@ public class ArrowWriter implements AutoCloseable { */ public static final double BUFFER_USAGE_RATIO = 0.96; + /** + * The factor which is used to estimate the compression ratio of the serialized {@link + * ArrowRecordBatch}. + */ + private static final float COMPRESSION_RATE_ESTIMATION_FACTOR = 1.05f; + /** * The identifier of the writer which is used to identify the writer in the {@link * ArrowWriterPool}. @@ -70,6 +78,9 @@ public class ArrowWriter implements AutoCloseable { /** Container that holds a set of vectors for the rows. */ final VectorSchemaRoot root; + /** The table id of current writer. */ + private final long tableId; + /** * An array of writers which are responsible for the serialization of each column of the rows. */ @@ -86,28 +97,46 @@ public class ArrowWriter implements AutoCloseable { private final RowType schema; + private final ArrowCompressionInfo compressionInfo; private final CompressionCodec compressionCodec; + private final ArrowCompressionRatioEstimator compressionRatioEstimator; private int writeLimitInBytes; private int estimatedMaxRecordsCount; private int recordsCount; + /** + * Number of bytes (excluding the batch header, writeRowKind and arrow metadata) written before + * compression. + */ + private int uncompressedBodySizeInBytes = 0; + + /** The latest estimated compression ratio for this ArrowWriter. */ + private float estimatedCompressionRatio; + /** identify the number of used times of the writer, used for idempotent recycle() invoking. */ private long epoch; ArrowWriter( + long tableId, String writerKey, int bufferSizeInBytes, RowType schema, BufferAllocator allocator, ArrowWriterProvider provider, - ArrowCompressionInfo compressionInfo) { + ArrowCompressionInfo compressionInfo, + ArrowCompressionRatioEstimator compressionRatioEstimator) { + this.tableId = tableId; this.writerKey = writerKey; this.schema = schema; this.root = VectorSchemaRoot.create(ArrowUtils.toArrowSchema(schema), allocator); this.provider = Preconditions.checkNotNull(provider); this.compressionCodec = compressionInfo.createCompressionCodec(); + this.compressionRatioEstimator = compressionRatioEstimator; + this.compressionInfo = compressionInfo; + this.estimatedCompressionRatio = + compressionRatioEstimator.estimation(tableId, compressionInfo); this.metadataLength = ArrowUtils.estimateArrowMetadataLength( @@ -138,7 +167,9 @@ public boolean isFull() { root.setRowCount(recordsCount); int metadataLength = getMetadataLength(); int bodyLength = getBodyLength(); - int currentSize = metadataLength + bodyLength; + + int estimatedBodyLength = estimatedBytesWritten(bodyLength); + int currentSize = metadataLength + estimatedBodyLength; if (currentSize >= writeLimitInBytes) { return true; } else { @@ -147,7 +178,7 @@ public boolean isFull() { (int) Math.ceil( (writeLimitInBytes - metadataLength) - / (bodyLength / (recordsCount * 1.0))); + / (estimatedBodyLength / (recordsCount * 1.0))); return false; } } else { @@ -224,11 +255,23 @@ public int serializeToOutputView(AbstractPagedOutputView outputView, int positio // update row count only when we try to write records to the output. root.setRowCount(recordsCount); + + // update the uncompressed body size. + uncompressedBodySizeInBytes = getBodyLength(); try (ArrowRecordBatch arrowBatch = new VectorUnloader(root, true, compressionCodec, true).getRecordBatch()) { PagedMemorySegmentWritableChannel channel = new PagedMemorySegmentWritableChannel(outputView); ArrowBlock block = MessageSerializer.serialize(new WriteChannel(channel), arrowBatch); + + checkState( + uncompressedBodySizeInBytes > 0, + "uncompressedRecordsSizeInBytes is 0 or negative"); + compressionRatioEstimator.updateEstimation( + tableId, + compressionInfo, + (float) block.getBodyLength() / uncompressedBodySizeInBytes); + return (int) (block.getMetadataLength() + block.getBodyLength()); } } @@ -253,6 +296,11 @@ public void recycle(long epoch) { if (this.epoch == epoch) { root.clear(); provider.recycleWriter(this); + + // reset the uncompressedBodySizeInBytes and compression ratio. + this.uncompressedBodySizeInBytes = 0; + this.estimatedCompressionRatio = + compressionRatioEstimator.estimation(tableId, compressionInfo); } } @@ -273,4 +321,13 @@ private void initFieldVector(FieldVector fieldVector) { fieldVector.allocateNew(); } } + + private int estimatedBytesWritten(int currentBytes) { + if (compressionInfo.getCompressionType() == ArrowCompressionType.NONE) { + return currentBytes; + } else { + return (int) + (currentBytes * estimatedCompressionRatio * COMPRESSION_RATE_ESTIMATION_FACTOR); + } + } } diff --git a/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriterPool.java b/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriterPool.java index a280a308..5641ba35 100644 --- a/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriterPool.java +++ b/fluss-common/src/main/java/com/alibaba/fluss/row/arrow/ArrowWriterPool.java @@ -19,6 +19,7 @@ import com.alibaba.fluss.annotation.Internal; import com.alibaba.fluss.annotation.VisibleForTesting; import com.alibaba.fluss.compression.ArrowCompressionInfo; +import com.alibaba.fluss.compression.ArrowCompressionRatioEstimator; import com.alibaba.fluss.exception.FlussRuntimeException; import com.alibaba.fluss.shaded.arrow.org.apache.arrow.memory.BufferAllocator; import com.alibaba.fluss.shaded.arrow.org.apache.arrow.vector.VectorSchemaRoot; @@ -51,11 +52,15 @@ public class ArrowWriterPool implements ArrowWriterProvider { @GuardedBy("lock") private boolean closed = false; + @GuardedBy("lock") + private final ArrowCompressionRatioEstimator compressionRatioEstimator; + private final ReentrantLock lock = new ReentrantLock(); public ArrowWriterPool(BufferAllocator allocator) { this.allocator = allocator; this.freeWriters = new HashMap<>(); + this.compressionRatioEstimator = new ArrowCompressionRatioEstimator(); } @Override @@ -97,12 +102,14 @@ public ArrowWriter getOrCreateWriter( } else { return initialize( new ArrowWriter( + tableId, writerKey, bufferSizeInBytes, schema, allocator, this, - compressionInfo), + compressionInfo, + compressionRatioEstimator), bufferSizeInBytes); } }); @@ -133,4 +140,9 @@ public void close() { public Map> freeWriters() { return freeWriters; } + + @VisibleForTesting + public ArrowCompressionRatioEstimator compressionRatioEstimator() { + return compressionRatioEstimator; + } } diff --git a/fluss-common/src/test/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimatorTest.java b/fluss-common/src/test/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimatorTest.java new file mode 100644 index 00000000..a59c469f --- /dev/null +++ b/fluss-common/src/test/java/com/alibaba/fluss/compression/ArrowCompressionRatioEstimatorTest.java @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2024 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.fluss.compression; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Test for {@link ArrowCompressionRatioEstimator}. */ +public class ArrowCompressionRatioEstimatorTest { + + private ArrowCompressionRatioEstimator compressionRatioEstimator; + + @BeforeEach + public void setup() { + compressionRatioEstimator = new ArrowCompressionRatioEstimator(); + } + + @Test + void testUpdateEstimation() { + class EstimationsObservedRatios { + final float currentEstimation; + final float observedRatio; + + EstimationsObservedRatios(float currentEstimation, float observedRatio) { + this.currentEstimation = currentEstimation; + this.observedRatio = observedRatio; + } + } + + // If currentEstimation is smaller than observedRatio, the updatedCompressionRatio is + // currentEstimation plus COMPRESSION_RATIO_DETERIORATE_STEP(0.05), otherwise + // currentEstimation minus COMPRESSION_RATIO_IMPROVING_STEP(0.005). There are four cases, + // and updatedCompressionRatio shouldn't smaller than observedRatio in all cases. + List estimationsObservedRatios = + Arrays.asList( + new EstimationsObservedRatios(0.8f, 0.84f), + new EstimationsObservedRatios(0.6f, 0.7f), + new EstimationsObservedRatios(0.6f, 0.4f), + new EstimationsObservedRatios(0.004f, 0.001f)); + long tableId = 150001L; + ArrowCompressionInfo compressionInfo = + new ArrowCompressionInfo(ArrowCompressionType.ZSTD, 3); + for (EstimationsObservedRatios estimationObservedRatio : estimationsObservedRatios) { + compressionRatioEstimator.updateEstimation( + tableId, compressionInfo, estimationObservedRatio.currentEstimation); + float updatedCompressionRatio = + compressionRatioEstimator.estimation(tableId, compressionInfo); + assertThat(updatedCompressionRatio) + .isGreaterThanOrEqualTo(estimationObservedRatio.observedRatio); + } + } +}