Skip to content

Commit

Permalink
[CELEBORN-894] Add config to disable checksum
Browse files Browse the repository at this point in the history
  • Loading branch information
jiang13021 committed Dec 18, 2024
1 parent e75d84f commit f719b49
Show file tree
Hide file tree
Showing 20 changed files with 432 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
private ConcurrentHashMap<String, TransportClient> currentClient =
JavaUtils.newConcurrentHashMap();
private long driverTimestamp;
private final int BATCH_HEADER_SIZE = 4 * 4;

private final TransportContext context;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import org.apache.celeborn.common.rpc.RpcAddress;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.rpc.RpcEnv;
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.*;
import org.apache.celeborn.common.write.DataBatches;
import org.apache.celeborn.common.write.PushState;
Expand Down Expand Up @@ -93,8 +92,6 @@ public class ShuffleClientImpl extends ShuffleClient {
private TransportContext transportContext;
protected TransportClientFactory dataClientFactory;

protected final int BATCH_HEADER_SIZE = 4 * 4;

protected byte[] extension;

// key: appShuffleIdentifier, value: shuffleId
Expand All @@ -121,6 +118,7 @@ public class ShuffleClientImpl extends ShuffleClient {
JavaUtils.newConcurrentHashMap();
private boolean pushReplicateEnabled;
private boolean fetchExcludeWorkerOnFailureEnabled;
private boolean clientShuffleChecksumEnabled;

private final ExecutorService pushDataRetryPool;

Expand Down Expand Up @@ -186,6 +184,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
shuffleCompressionEnabled = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
pushReplicateEnabled = conf.clientPushReplicateEnabled();
fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled();
clientShuffleChecksumEnabled = conf.clientShuffleChecksumEnabled();
if (conf.clientPushReplicateEnabled()) {
pushDataTimeout = conf.pushDataTimeoutMs() * 2;
} else {
Expand Down Expand Up @@ -1000,12 +999,16 @@ public int pushOrMergeData(
length = compressor.getCompressedTotalSize();
}

final byte[] body = new byte[BATCH_HEADER_SIZE + length];
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId);
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId);
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, nextBatchId);
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, length);
System.arraycopy(data, offset, body, BATCH_HEADER_SIZE, length);
final int headerSize;
if (clientShuffleChecksumEnabled) {
headerSize = PushDataHeaderUtils.BATCH_HEADER_SIZE;
} else {
headerSize = PushDataHeaderUtils.BATCH_HEADER_SIZE_WITHOUT_CHECKSUM;
}
final byte[] body = new byte[headerSize + length];
PushDataHeaderUtils.buildDataHeader(
body, mapId, attemptId, nextBatchId, length, clientShuffleChecksumEnabled);
System.arraycopy(data, offset, body, headerSize, length);

if (doPush) {
// check limit
Expand Down Expand Up @@ -1985,6 +1988,9 @@ private StatusCode getPushDataFailCause(String message) {
cause = StatusCode.PUSH_DATA_REPLICA_WORKER_EXCLUDED;
} else if (message.startsWith(StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND.name())) {
cause = StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND;
} else if (message.startsWith(StatusCode.PUSH_DATA_CHECKSUM_FAIL.name())) {
// TODO: prefer to retry instead of revive
cause = StatusCode.PUSH_DATA_CHECKSUM_FAIL;
} else if (ExceptionUtils.connectFail(message)) {
// Throw when push to primary worker connection causeException.
cause = StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_PRIMARY;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.protocol.*;
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.ExceptionMaker;
import org.apache.celeborn.common.util.PushDataHeaderUtils;
import org.apache.celeborn.common.util.Utils;

public abstract class CelebornInputStream extends InputStream {
Expand Down Expand Up @@ -151,8 +151,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private MetricsCallback callback;

// mapId, attemptId, batchId, size
private final int BATCH_HEADER_SIZE = 4 * 4;
private final byte[] sizeBuf = new byte[BATCH_HEADER_SIZE];
private final byte[] sizeBuf;
private LongAdder skipCount = new LongAdder();
private final boolean rangeReadFilter;
private final boolean enabledReadLocalShuffle;
Expand All @@ -169,6 +168,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private int partitionId;
private ExceptionMaker exceptionMaker;
private boolean closed = false;
private boolean checksumEnabled;

CelebornInputStreamImpl(
CelebornConf conf,
Expand Down Expand Up @@ -226,6 +226,12 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
init();
firstChunk = false;
}
this.checksumEnabled = conf.clientShuffleChecksumEnabled();
if (checksumEnabled) {
sizeBuf = new byte[PushDataHeaderUtils.BATCH_HEADER_SIZE];
} else {
sizeBuf = new byte[PushDataHeaderUtils.BATCH_HEADER_SIZE_WITHOUT_CHECKSUM];
}
}

private boolean skipLocation(int startMapIndex, int endMapIndex, PartitionLocation location) {
Expand Down Expand Up @@ -586,10 +592,13 @@ private boolean fillBuffer() throws IOException {
boolean hasData = false;
while (currentChunk.isReadable() || moveToNextChunk()) {
currentChunk.readBytes(sizeBuf);
int mapId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET);
int attemptId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 4);
int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8);
int size = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12);
if (checksumEnabled && !PushDataHeaderUtils.checkHeaderChecksum32(sizeBuf)) {
throw new CelebornIOException("Data Corrupted: checksum not match");
}
int mapId = PushDataHeaderUtils.getMapId(sizeBuf);
int attemptId = PushDataHeaderUtils.getAttemptId(sizeBuf);
int batchId = PushDataHeaderUtils.getBatchId(sizeBuf);
int size = PushDataHeaderUtils.getDataLength(sizeBuf);

if (shuffleCompressionEnabled) {
if (size > compressedBuf.length) {
Expand All @@ -614,7 +623,7 @@ private boolean fillBuffer() throws IOException {
Set<Integer> batchSet = batchesRead.get(mapId);
if (!batchSet.contains(batchId)) {
batchSet.add(batchId);
callback.incBytesRead(BATCH_HEADER_SIZE + size);
callback.incBytesRead(PushDataHeaderUtils.BATCH_HEADER_SIZE + size);
if (shuffleCompressionEnabled) {
// decompress data
int originalLength = decompressor.getOriginalLen(compressedBuf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ public abstract class ShuffleClientBaseSuiteJ {
REPLICA_REPLICATE_PORT,
PartitionLocation.Mode.REPLICA);

protected final int BATCH_HEADER_SIZE = 4 * 4;
protected ChannelFuture mockedFuture = mock(ChannelFuture.class);

protected CelebornConf setupEnv(CompressionCodec codec) throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.celeborn.common.protocol.message.ControlMessages.*;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.util.PushDataHeaderUtils;

public class ShuffleClientSuiteJ {

Expand Down Expand Up @@ -89,7 +90,6 @@ public class ShuffleClientSuiteJ {
PartitionLocation.Mode.REPLICA);

private static final byte[] TEST_BUF1 = "hello world".getBytes(StandardCharsets.UTF_8);
private final int BATCH_HEADER_SIZE = 4 * 4;

@Test
public void testPushData() throws IOException, InterruptedException {
Expand All @@ -109,12 +109,12 @@ public void testPushData() throws IOException, InterruptedException {
1);

if (codec.equals(CompressionCodec.NONE)) {
assertEquals(TEST_BUF1.length + BATCH_HEADER_SIZE, pushDataLen);
assertEquals(TEST_BUF1.length + PushDataHeaderUtils.BATCH_HEADER_SIZE, pushDataLen);
} else {
Compressor compressor = Compressor.getCompressor(conf);
compressor.compress(TEST_BUF1, 0, TEST_BUF1.length);
final int compressedTotalSize = compressor.getCompressedTotalSize();
assertEquals(compressedTotalSize + BATCH_HEADER_SIZE, pushDataLen);
assertEquals(compressedTotalSize + PushDataHeaderUtils.BATCH_HEADER_SIZE, pushDataLen);
}
}
}
Expand Down Expand Up @@ -162,12 +162,12 @@ public void testMergeData() throws IOException, InterruptedException {
1);

if (codec.equals(CompressionCodec.NONE)) {
assertEquals(TEST_BUF1.length + BATCH_HEADER_SIZE, mergeSize);
assertEquals(TEST_BUF1.length + PushDataHeaderUtils.BATCH_HEADER_SIZE, mergeSize);
} else {
Compressor compressor = Compressor.getCompressor(conf);
compressor.compress(TEST_BUF1, 0, TEST_BUF1.length);
final int compressedTotalSize = compressor.getCompressedTotalSize();
assertEquals(compressedTotalSize + BATCH_HEADER_SIZE, mergeSize);
assertEquals(compressedTotalSize + PushDataHeaderUtils.BATCH_HEADER_SIZE, mergeSize);
}

byte[] buf1k = RandomStringUtils.random(4000).getBytes(StandardCharsets.UTF_8);
Expand All @@ -184,12 +184,12 @@ public void testMergeData() throws IOException, InterruptedException {
1);

if (codec.equals(CompressionCodec.NONE)) {
assertEquals(buf1k.length + BATCH_HEADER_SIZE, largeMergeSize);
assertEquals(buf1k.length + PushDataHeaderUtils.BATCH_HEADER_SIZE, largeMergeSize);
} else {
Compressor compressor = Compressor.getCompressor(conf);
compressor.compress(buf1k, 0, buf1k.length);
final int compressedTotalSize = compressor.getCompressedTotalSize();
assertEquals(compressedTotalSize + BATCH_HEADER_SIZE, largeMergeSize);
assertEquals(compressedTotalSize + PushDataHeaderUtils.BATCH_HEADER_SIZE, largeMergeSize);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ public enum StatusCode {
OPEN_STREAM_FAILED(51),
SEGMENT_START_FAIL_REPLICA(52),
SEGMENT_START_FAIL_PRIMARY(53),
NO_SPLIT(54);
NO_SPLIT(54),
PUSH_DATA_CHECKSUM_FAIL(55);

private final byte value;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.common.util;

import java.util.zip.CRC32;

import org.apache.celeborn.common.unsafe.Platform;

public class PushDataHeaderUtils {
// Data Header Layout:
// | mapId (4 bytes) |
// | attemptId (4 bytes) |
// | batchId with checksum flag (4 bytes) |
// | length with checksum length (4 bytes) |
// | checksum (4 bytes) |
//
// Fields description:
// - mapId: Unique identifier for the map (4 bytes)
// - attemptId: Identifier for the attempt (4 bytes)
// - batchId with checksum flag:
// -- checksum flag: 1 bit (indicates if batchId has a checksum)
// -- batchId: 31 bits (always positive when represented as an integer)
// - length with checksum length: total length of the data + 4 bytes for checksum
// - checksum: Always positive integer (4 bytes)
public static final int BATCH_HEADER_SIZE = 5 * 4;
public static final int BATCH_HEADER_SIZE_WITHOUT_CHECKSUM = BATCH_HEADER_SIZE - 4;
public static final int MAP_ID_OFFSET = Platform.BYTE_ARRAY_OFFSET;
public static final int ATTEMPT_ID_OFFSET = Platform.BYTE_ARRAY_OFFSET + 4;
public static final int BATCH_ID_OFFSET = Platform.BYTE_ARRAY_OFFSET + 8;
public static final int LENGTH_OFFSET = Platform.BYTE_ARRAY_OFFSET + 12;
public static final int CHECKSUM_OFFSET = Platform.BYTE_ARRAY_OFFSET + 16;
public static final int POSITIVE_MASK = 0x7FFFFFFF;
public static final int HIGHEST_1_BIT_FLAG_MASK = 0x80000000;

public static void buildDataHeader(
byte[] data, int mapId, int attemptId, int batchId, int length, boolean enableChecksum) {
if (enableChecksum) {
assert data.length >= BATCH_HEADER_SIZE;
int batchIdWithChecksumFlag = batchIdWithChecksumFlag(batchId);
int lengthWithChecksum = length + 4;
Platform.putInt(data, MAP_ID_OFFSET, mapId);
Platform.putInt(data, ATTEMPT_ID_OFFSET, attemptId);
Platform.putInt(data, BATCH_ID_OFFSET, batchIdWithChecksumFlag);
Platform.putInt(data, LENGTH_OFFSET, lengthWithChecksum);
Platform.putInt(data, CHECKSUM_OFFSET, computeHeaderChecksum32(data));
} else {
assert data.length >= BATCH_HEADER_SIZE_WITHOUT_CHECKSUM;
Platform.putInt(data, MAP_ID_OFFSET, mapId);
Platform.putInt(data, ATTEMPT_ID_OFFSET, attemptId);
Platform.putInt(data, BATCH_ID_OFFSET, batchId);
Platform.putInt(data, LENGTH_OFFSET, length);
}
}

public static int batchIdWithChecksumFlag(int batchId) {
return batchId | HIGHEST_1_BIT_FLAG_MASK;
}

public static int batchIdWithoutChecksumFlag(int batchId) {
return batchId & POSITIVE_MASK;
}

public static boolean hasChecksumFlag(byte[] data) {
int batchId = Platform.getInt(data, BATCH_ID_OFFSET);
return (batchId & HIGHEST_1_BIT_FLAG_MASK) != 0;
}

public static int getMapId(byte[] data) {
return Platform.getInt(data, MAP_ID_OFFSET);
}

public static int getAttemptId(byte[] data) {
return Platform.getInt(data, ATTEMPT_ID_OFFSET);
}

public static int getBatchId(byte[] data) {
return batchIdWithoutChecksumFlag(Platform.getInt(data, BATCH_ID_OFFSET));
}

public static int getChecksumLength(byte[] data) {
if (hasChecksumFlag(data)) {
return 4;
} else {
return 0;
}
}

public static int getDataLength(byte[] data) {
return Platform.getInt(data, LENGTH_OFFSET) - getChecksumLength(data);
}

// lengthWithChecksumLength = Platform.getInt(data, LENGTH_OFFSET)
public static int getTotalLengthWithHeader(byte[] data) {
return BATCH_HEADER_SIZE_WITHOUT_CHECKSUM + Platform.getInt(data, LENGTH_OFFSET);
}

public static int computeHeaderChecksum32(byte[] data) {
assert data.length >= BATCH_HEADER_SIZE_WITHOUT_CHECKSUM;
CRC32 crc32 = new CRC32();
crc32.update(data, 0, BATCH_HEADER_SIZE_WITHOUT_CHECKSUM);
return ((int) crc32.getValue()) & POSITIVE_MASK;
}

public static boolean checkHeaderChecksum32(byte[] data) {
assert data.length >= BATCH_HEADER_SIZE;
int expectedChecksum = Platform.getInt(data, CHECKSUM_OFFSET);
int currentChecksum = computeHeaderChecksum32(data);
return currentChecksum == expectedChecksum;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def registerShuffleFilterExcludedWorkerEnabled: Boolean =
get(REGISTER_SHUFFLE_FILTER_EXCLUDED_WORKER_ENABLED)
def reviseLostShufflesEnabled: Boolean = get(REVISE_LOST_SHUFFLES_ENABLED)
def clientShuffleChecksumEnabled: Boolean = get(CLIENT_SHUFFLE_CHECKSUM_ENABLED)

// //////////////////////////////////////////////////////
// Worker //
Expand Down Expand Up @@ -1207,6 +1208,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
}).reduce(_ ++ _)
}.orElse(Some(Map("MEMORY" -> List("SSD", "HDD", "HDFS", "OSS"))))

def workerChecksumVerifyEnabled: Boolean = get(WORKER_CHECKSUM_VERIFY_ENABLED)

// //////////////////////////////////////////////////////
// Decommission //
// //////////////////////////////////////////////////////
Expand Down Expand Up @@ -4116,6 +4119,14 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefault(0)

val CLIENT_SHUFFLE_CHECKSUM_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.shuffle.checksum.enabled")
.categories("client")
.doc("Whether to enable checksum for shuffle data.")
.version("0.6.0")
.booleanConf
.createWithDefault(false)

val WORKER_PUSH_HEARTBEAT_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.worker.push.heartbeat.enabled")
.categories("worker")
Expand Down Expand Up @@ -5843,6 +5854,14 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(0)

val WORKER_CHECKSUM_VERIFY_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.worker.checksum.verify.enabled")
.categories("worker")
.version("0.6.0")
.doc("Whether to verify checksum when handling pushed data.")
.booleanConf
.createWithDefault(false)

val MASTER_SEND_APPLICATION_META_THREADS: ConfigEntry[Int] =
buildConf("celeborn.master.send.applicationMeta.threads")
.categories("master")
Expand Down
Loading

0 comments on commit f719b49

Please sign in to comment.