Skip to content

Commit

Permalink
[log] Refactor ArrowCompressionRatioEstimator to be lock-free and est…
Browse files Browse the repository at this point in the history
…imate ratio for only a table (#330)
  • Loading branch information
wuchong committed Feb 1, 2025
1 parent 8e2bdbc commit f96c392
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import com.alibaba.fluss.utils.Preconditions;

import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
Expand Down Expand Up @@ -112,6 +113,7 @@ void testAppendOnly() throws Exception {

@ParameterizedTest
@ValueSource(booleans = {true, false})
@Disabled("TODO, fix me in #116")
void testAppendWithSmallBuffer(boolean indexedFormat) throws Exception {
TableDescriptor desc =
indexedFormat
Expand Down Expand Up @@ -991,7 +993,7 @@ void testArrowCompressionAndProject(String compression, String level) throws Exc
.property(ConfigOptions.TABLE_LOG_ARROW_COMPRESSION_TYPE.key(), compression)
.property(ConfigOptions.TABLE_LOG_ARROW_COMPRESSION_ZSTD_LEVEL.key(), level)
.build();
TablePath tablePath = TablePath.of("test_db_1", "test_arrow_compression_and_project");
TablePath tablePath = TablePath.of("test_db_1", "test_arrow_" + compression + level);
createTable(tablePath, tableDescriptor, false);

try (Connection conn = ConnectionFactory.createConnection(clientConf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.alibaba.fluss.client.write;

import com.alibaba.fluss.compression.ArrowCompressionInfo;
import com.alibaba.fluss.compression.ArrowCompressionRatioEstimator;
import com.alibaba.fluss.compression.ArrowCompressionType;
import com.alibaba.fluss.memory.MemorySegment;
import com.alibaba.fluss.memory.PreAllocatedPagedOutputView;
Expand All @@ -29,12 +28,14 @@
import com.alibaba.fluss.record.LogRecordReadContext;
import com.alibaba.fluss.record.MemoryLogRecords;
import com.alibaba.fluss.record.bytesview.BytesView;
import com.alibaba.fluss.row.arrow.ArrowWriter;
import com.alibaba.fluss.row.arrow.ArrowWriterPool;
import com.alibaba.fluss.row.indexed.IndexedRow;
import com.alibaba.fluss.shaded.arrow.org.apache.arrow.memory.BufferAllocator;
import com.alibaba.fluss.shaded.arrow.org.apache.arrow.memory.RootAllocator;
import com.alibaba.fluss.utils.CloseableIterator;

import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -188,37 +189,38 @@ void testArrowCompressionRatioEstimated() throws Exception {
// (COMPRESSION_RATIO_IMPROVING_STEP#COMPRESSION_RATIO_IMPROVING_STEP) each time. Therefore,
// the loop runs 100 times, and theoretically, the final number of input records will be
// much greater than at the beginning.
int round = 100;
int[] recordCounts = new int[round];
for (int i = 0; i < round; i++) {
float previousRatio = -1.0f;
float currentRatio = 1.0f;
int lastBytesInSize = 0;
// exit the loop until compression ratio is converged
while (previousRatio != currentRatio) {
ArrowWriter arrowWriter =
writerProvider.getOrCreateWriter(
tb.getTableId(),
DATA1_TABLE_INFO.getSchemaId(),
maxSizeInBytes,
DATA1_ROW_TYPE,
compressionInfo);

ArrowLogWriteBatch arrowLogWriteBatch =
new ArrowLogWriteBatch(
tb,
DATA1_PHYSICAL_TABLE_PATH,
DATA1_TABLE_INFO.getSchemaId(),
writerProvider.getOrCreateWriter(
tb.getTableId(),
DATA1_TABLE_INFO.getSchemaId(),
maxSizeInBytes,
DATA1_ROW_TYPE,
compressionInfo),
new PreAllocatedPagedOutputView(memorySegmentList));
arrowWriter,
new PreAllocatedPagedOutputView(memorySegmentList),
System.currentTimeMillis());

int recordCount = 0;
while (arrowLogWriteBatch.tryAppend(
createWriteRecord(
row(
DATA1_ROW_TYPE,
new Object[] {
recordCount,
"a a a"
+ recordCount
})),
new Object[] {recordCount, RandomStringUtils.random(100)})),
newWriteCallback())) {
recordCount++;
}

recordCounts[i] = recordCount;
// batch full.
boolean appendResult =
arrowLogWriteBatch.tryAppend(
Expand All @@ -228,16 +230,18 @@ void testArrowCompressionRatioEstimated() throws Exception {

// close this batch and recycle the writer.
arrowLogWriteBatch.close();
arrowLogWriteBatch.build();
BytesView built = arrowLogWriteBatch.build();
lastBytesInSize = built.getBytesLength();

ArrowCompressionRatioEstimator compressionRatioEstimator =
writerProvider.compressionRatioEstimator();
float currentRatio =
compressionRatioEstimator.estimation(tb.getTableId(), compressionInfo);
assertThat(currentRatio).isNotEqualTo(1.0f);
previousRatio = currentRatio;
currentRatio = arrowWriter.getCompressionRatioEstimator().estimation();
}

assertThat(recordCounts[round - 1]).isGreaterThan(recordCounts[0]);
// when the compression ratio is converged, the memory buffer should be fully used.
assertThat(lastBytesInSize)
.isGreaterThan((int) (maxSizeInBytes * ArrowWriter.BUFFER_USAGE_RATIO))
.isLessThan(maxSizeInBytes);
assertThat(currentRatio).isLessThan(1.0f);
}

private WriteRecord createWriteRecord(IndexedRow row) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import com.alibaba.fluss.record.MemoryLogRecords;
import com.alibaba.fluss.record.RowKind;
import com.alibaba.fluss.row.InternalRow;
import com.alibaba.fluss.row.arrow.ArrowWriter;
import com.alibaba.fluss.row.indexed.IndexedRow;
import com.alibaba.fluss.rpc.GatewayClientProxy;
import com.alibaba.fluss.rpc.RpcClient;
Expand All @@ -46,6 +47,7 @@
import com.alibaba.fluss.utils.CloseableIterator;
import com.alibaba.fluss.utils.clock.ManualClock;

import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -75,6 +77,22 @@

/** Test for {@link RecordAccumulator}. */
public class RecordAccumulatorTest {
private static final long ZSTD_TABLE_ID = 16001L;
private static final PhysicalTablePath ZSTD_PHYSICAL_TABLE_PATH =
PhysicalTablePath.of(TablePath.of("test_db_1", "test_zstd_table_1"));
private static final TableInfo ZSTD_TABLE_INFO =
new TableInfo(
ZSTD_PHYSICAL_TABLE_PATH.getTablePath(),
ZSTD_TABLE_ID,
TableDescriptor.builder()
.schema(DATA1_SCHEMA)
.distributedBy(3)
.property(ConfigOptions.TABLE_LOG_ARROW_COMPRESSION_TYPE.key(), "zstd")
.build(),
1,
System.currentTimeMillis(),
System.currentTimeMillis());

ServerNode node1 = new ServerNode(1, "localhost", 90, ServerType.TABLET_SERVER);
ServerNode node2 = new ServerNode(2, "localhost", 91, ServerType.TABLET_SERVER);
ServerNode node3 = new ServerNode(3, "localhost", 92, ServerType.TABLET_SERVER);
Expand All @@ -91,6 +109,7 @@ public class RecordAccumulatorTest {
new BucketLocation(DATA1_PHYSICAL_TABLE_PATH, DATA1_TABLE_ID, 2, node2, serverNodes);
private final BucketLocation bucket4 =
new BucketLocation(DATA1_PHYSICAL_TABLE_PATH, DATA1_TABLE_ID, 3, node2, serverNodes);

private final WriteCallback writeCallback =
exception -> {
if (exception != null) {
Expand Down Expand Up @@ -149,6 +168,63 @@ void testDrainBatches() throws Exception {
verifyTableBucketInBatches(batches3, tb1, tb3);
}

@Test
void testDrainCompressedBatches() throws Exception {
int batchSize = 10 * 1024;
int bucketNum = 10;
RecordAccumulator accum =
createTestRecordAccumulator(
Integer.MAX_VALUE, batchSize, batchSize, Integer.MAX_VALUE);
List<BucketLocation> bucketLocations = new ArrayList<>();
for (int b = 0; b < bucketNum; b++) {
bucketLocations.add(
new BucketLocation(
ZSTD_PHYSICAL_TABLE_PATH, ZSTD_TABLE_ID, b, node1, serverNodes));
}
// all buckets are located in node1
cluster = updateCluster(bucketLocations);

appendUntilCompressionRatioStable(accum, batchSize);

for (int i = 0; i < bucketNum; i++) {
appendUntilBatchFull(accum, i);
}

// all 3 buckets are located in node1
Map<Integer, List<WriteBatch>> batches =
accum.drain(cluster, Collections.singleton(node1), batchSize * bucketNum);
// the compression ratio is smaller than 1.0,
// so bucketNum * batch_size should contain all compressed batches for each bucket
assertThat(batches.containsKey(node1.id())).isTrue();
assertThat(batches.get(node1.id()).size()).isEqualTo(bucketNum);
}

private void appendUntilCompressionRatioStable(RecordAccumulator accum, int batchSize)
throws Exception {
while (true) {
appendUntilBatchFull(accum, 0);
Map<Integer, List<WriteBatch>> batches =
accum.drain(cluster, Collections.singleton(node1), Integer.MAX_VALUE);
WriteBatch batch = batches.get(node1.id()).get(0);
int actualSize = batch.build().getBytesLength();
if (actualSize > batchSize * ArrowWriter.BUFFER_USAGE_RATIO) {
return;
}
}
}

private void appendUntilBatchFull(RecordAccumulator accum, int bucketId) throws Exception {
while (true) {
InternalRow row = row(DATA1_ROW_TYPE, new Object[] {1, RandomStringUtils.random(10)});
PhysicalTablePath tablePath = PhysicalTablePath.of(ZSTD_TABLE_INFO.getTablePath());
WriteRecord record = new WriteRecord(tablePath, WriteKind.APPEND, row, null);
// append until the batch is full
if (accum.append(record, writeCallback, cluster, bucketId, false).batchIsFull) {
break;
}
}
}

@Test
void testFull() throws Exception {
// test case assumes that the records do not fill the batch completely
Expand Down Expand Up @@ -465,6 +541,7 @@ private Cluster updateCluster(List<BucketLocation> bucketLocations) {
System.currentTimeMillis());
Map<TablePath, TableInfo> tableInfoByPath = new HashMap<>();
tableInfoByPath.put(DATA1_TABLE_PATH, data1NonPkTableInfo);
tableInfoByPath.put(ZSTD_TABLE_INFO.getTablePath(), ZSTD_TABLE_INFO);

return new Cluster(
aliveTabletServersById,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,87 +20,49 @@

import javax.annotation.concurrent.ThreadSafe;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import static com.alibaba.fluss.utils.concurrent.LockUtils.inLock;

/**
* This class help estimate the compression ratio for each table and each arrow compression type
* combination.
*/
/** This class helps estimate the compression ratio for a table. */
@Internal
@ThreadSafe
public class ArrowCompressionRatioEstimator {
/**
* The constant speed to increase compression ratio when a batch compresses better than
* expected.
*/
public static final float COMPRESSION_RATIO_IMPROVING_STEP = 0.005f;
private static final float COMPRESSION_RATIO_IMPROVING_STEP = 0.005f;

/**
* The minimum speed to decrease compression ratio when a batch compresses worse than expected.
*/
public static final float COMPRESSION_RATIO_DETERIORATE_STEP = 0.05f;
private static final float COMPRESSION_RATIO_DETERIORATE_STEP = 0.05f;

/**
* The default compression ratio when initialize a new {@link ArrowCompressionRatioEstimator}.
*/
private static final float DEFAULT_COMPRESSION_RATIO = 1.0f;

private final Map<Long, Map<String, Float>> compressionRatio;
private final Map<Long, Lock> tableLocks;
/** The current compression ratio, use volatile for lock-free. */
private volatile float compressionRatio;

public ArrowCompressionRatioEstimator() {
compressionRatio = new ConcurrentHashMap<>();
tableLocks = new ConcurrentHashMap<>();
compressionRatio = DEFAULT_COMPRESSION_RATIO;
}

/**
* Update the compression ratio estimation for a table and related compression info with the
* observed compression ratio.
*/
public void updateEstimation(
long tableId, ArrowCompressionInfo compressionInfo, float observedRatio) {
Lock lock = tableLocks.computeIfAbsent(tableId, k -> new ReentrantLock());
inLock(
lock,
() -> {
Map<String, Float> compressionRatioMap =
compressionRatio.computeIfAbsent(
tableId, k -> new ConcurrentHashMap<>());
String compressionKey = compressionInfo.toString();
float currentEstimation =
compressionRatioMap.getOrDefault(compressionKey, 1.0f);
if (observedRatio > currentEstimation) {
compressionRatioMap.put(
compressionKey,
Math.max(
currentEstimation + COMPRESSION_RATIO_DETERIORATE_STEP,
observedRatio));
} else if (observedRatio < currentEstimation) {
compressionRatioMap.put(
compressionKey,
Math.max(
currentEstimation - COMPRESSION_RATIO_IMPROVING_STEP,
observedRatio));
}
});
/** Update the compression ratio estimation with the observed compression ratio. */
public void updateEstimation(float observedRatio) {
float currentEstimation = compressionRatio;
// it is possible it can't guarantee atomic and isolation update,
// but it's fine as it's just an estimation
if (observedRatio > currentEstimation) {
compressionRatio =
Math.max(currentEstimation + COMPRESSION_RATIO_DETERIORATE_STEP, observedRatio);
} else if (observedRatio < currentEstimation) {
compressionRatio =
Math.max(currentEstimation - COMPRESSION_RATIO_IMPROVING_STEP, observedRatio);
}
}

/** Get the compression ratio estimation for a table and related compression info. */
public float estimation(long tableId, ArrowCompressionInfo compressionInfo) {
Lock lock = tableLocks.computeIfAbsent(tableId, k -> new ReentrantLock());
return inLock(
lock,
() -> {
Map<String, Float> compressionRatioMap =
compressionRatio.computeIfAbsent(
tableId, k -> new ConcurrentHashMap<>());
String compressionKey = compressionInfo.toString();

if (!compressionRatioMap.containsKey(compressionKey)) {
compressionRatioMap.put(compressionKey, 1.0f);
}

return compressionRatioMap.get(compressionKey);
});
/** Get current compression ratio estimation. */
public float estimation() {
return compressionRatio;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,16 @@ public void recycleArrowWriter() {

public int estimatedSizeInBytes() {
if (bytesView != null) {
// accurate total size in bytes
// accurate total size in bytes (compressed if compression is enabled)
return bytesView.getBytesLength();
}

if (reCalculateSizeInBytes) {
// make size in bytes up-to-date
estimatedSizeInBytes =
ARROW_ROWKIND_OFFSET + rowKindWriter.sizeInBytes() + arrowWriter.sizeInBytes();
ARROW_ROWKIND_OFFSET
+ rowKindWriter.sizeInBytes()
+ arrowWriter.estimatedSizeInBytes();
}

reCalculateSizeInBytes = false;
Expand Down
Loading

0 comments on commit f96c392

Please sign in to comment.