From becc6b065d4498e83ac95def832cd6243b2af85b Mon Sep 17 00:00:00 2001 From: binjie yang Date: Mon, 9 Dec 2024 20:35:23 +0800 Subject: [PATCH 1/3] [CELEBORN-1768][WRITER] Refactoring Shuffle Writer to extract common methods --- .../shuffle/celeborn/BasedShuffleWriter.java | 223 ++++++++++++++++ .../celeborn/HashBasedShuffleWriter.java | 209 ++------------- .../celeborn/SortBasedShuffleWriter.java | 246 ++---------------- .../shuffle/celeborn/SparkShuffleManager.java | 3 +- .../SortBasedShuffleWriterSuiteJ.java | 17 +- .../ColumnarHashBasedShuffleWriter.java | 2 +- 6 files changed, 279 insertions(+), 421 deletions(-) create mode 100644 client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java new file mode 100644 index 00000000000..f732a68ad43 --- /dev/null +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java @@ -0,0 +1,223 @@ +package org.apache.spark.shuffle.celeborn; + +import java.io.IOException; +import java.util.concurrent.atomic.LongAdder; + +import scala.Option; +import scala.Product2; +import scala.collection.Iterator; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.sql.execution.UnsafeRowSerializer; +import org.apache.spark.storage.BlockManagerId; + +import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.common.CelebornConf; + +public abstract class BasedShuffleWriter extends ShuffleWriter { + + protected static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + protected static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; + + protected final int PUSH_BUFFER_INIT_SIZE; + protected final int PUSH_BUFFER_MAX_SIZE; + protected final ShuffleDependency dep; + protected final Partitioner partitioner; + protected final ShuffleWriteMetricsReporter writeMetrics; + protected final int shuffleId; + protected final int mapId; + protected final int encodedAttemptId; + protected final TaskContext taskContext; + protected final ShuffleClient shuffleClient; + protected final int numMappers; + protected final int numPartitions; + protected final OpenByteArrayOutputStream serBuffer; + protected final SerializationStream serOutputStream; + private final boolean unsafeRowFastWrite; + + protected final LongAdder[] mapStatusLengths; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true and + * then call stop() with success = false if they get an exception, we want to make sure we don't + * try deleting files, etc. twice. 
+ */ + private volatile boolean stopping = false; + + protected long peakMemoryUsedBytes = 0; + protected long tmpRecordsWritten = 0; + + public BasedShuffleWriter( + int shuffleId, + CelebornShuffleHandle handle, + TaskContext taskContext, + CelebornConf conf, + ShuffleClient client, + ShuffleWriteMetricsReporter metrics) { + PUSH_BUFFER_INIT_SIZE = conf.clientPushBufferInitialSize(); + PUSH_BUFFER_MAX_SIZE = conf.clientPushBufferMaxSize(); + this.dep = handle.dependency(); + this.partitioner = dep.partitioner(); + this.writeMetrics = metrics; + this.shuffleId = shuffleId; + this.mapId = taskContext.partitionId(); + // [CELEBORN-1496] using the encoded attempt number instead of task attempt number + this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext); + this.taskContext = taskContext; + this.shuffleClient = client; + this.numMappers = handle.numMappers(); + this.numPartitions = dep.partitioner().numPartitions(); + SerializerInstance serializer = dep.serializer().newInstance(); + serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); + serOutputStream = serializer.serializeStream(serBuffer); + unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite(); + + mapStatusLengths = new LongAdder[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + mapStatusLengths[i] = new LongAdder(); + } + } + + protected void doWrite(scala.collection.Iterator> records) + throws IOException, InterruptedException { + if (canUseFastWrite()) { + fastWrite0(records); + } else if (dep.mapSideCombine()) { + if (dep.aggregator().isEmpty()) { + throw new UnsupportedOperationException( + "When using map side combine, an aggregator must be specified."); + } + write0(dep.aggregator().get().combineValuesByKey(records, taskContext)); + } else { + write0(records); + } + } + + @Override + public void write(Iterator> records) throws IOException { + boolean needCleanupPusher = true; + try { + doWrite(records); + close(); + needCleanupPusher = false; + } catch (InterruptedException e) { + TaskInterruptedHelper.throwTaskKillException(); + } finally { + if (needCleanupPusher) { + cleanupPusher(); + } + } + } + + @Override + public Option stop(boolean success) { + try { + taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes()); + + if (stopping) { + return Option.empty(); + } else { + stopping = true; + if (success) { + BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId(); + MapStatus mapStatus = + SparkUtils.createMapStatus( + bmId, SparkUtils.unwrap(mapStatusLengths), taskContext.taskAttemptId()); + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + return Option.empty(); + } + } + } finally { + shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId); + } + } + + // Added in SPARK-32917, for Spark 3.2 and above + @SuppressWarnings("MissingOverride") + public long[] getPartitionLengths() { + throw new UnsupportedOperationException( + "Celeborn is not compatible with Spark push mode, please set spark.shuffle.push.enabled to false"); + } + + abstract void fastWrite0(scala.collection.Iterator iterator) + throws IOException, InterruptedException; + + abstract void write0(scala.collection.Iterator iterator) throws IOException, InterruptedException; + + abstract void updatePeakMemoryUsed(); + + abstract void cleanupPusher() throws IOException; + + abstract void closeWrite() throws IOException, InterruptedException; + + 
@VisibleForTesting + boolean canUseFastWrite() { + boolean keyIsPartitionId = false; + if (unsafeRowFastWrite && dep.serializer() instanceof UnsafeRowSerializer) { + // SPARK-39391 renames PartitionIdPassthrough's package + String partitionerClassName = partitioner.getClass().getSimpleName(); + keyIsPartitionId = "PartitionIdPassthrough".equals(partitionerClassName); + } + return keyIsPartitionId; + } + + /** Return the peak memory used so far, in bytes. */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException { + int bytesWritten = + shuffleClient.pushData( + shuffleId, + mapId, + encodedAttemptId, + partitionId, + buffer, + 0, + numBytes, + numMappers, + numPartitions); + mapStatusLengths[partitionId].add(bytesWritten); + writeMetrics.incBytesWritten(bytesWritten); + } + + /** + * This method will push the remaining data and close these pushers. + * It's important, will send Mapper End RPC to LifecycleManager to update + * the attempt of the corresponding task. + * We should only call this method when the task is successfully completed. + */ + protected void close() throws IOException, InterruptedException { + long pushMergedDataTime = System.nanoTime(); + closeWrite(); + shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId); + writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime); + updateRecordsWrittenMetrics(); + + long waitStartTime = System.nanoTime(); + shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers); + writeMetrics.incWriteTime(System.nanoTime() - waitStartTime); + } + + protected void updateRecordsWrittenMetrics() { + writeMetrics.incRecordsWritten(tmpRecordsWritten); + tmpRecordsWritten = 0; + } +} diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index 4c5e6739b3b..6a0dc24b16e 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -19,30 +19,15 @@ import java.io.IOException; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.atomic.LongAdder; -import javax.annotation.Nullable; - -import scala.Option; import scala.Product2; -import scala.reflect.ClassTag; -import scala.reflect.ClassTag$; -import com.google.common.annotations.VisibleForTesting; -import org.apache.spark.Partitioner; -import org.apache.spark.ShuffleDependency; -import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.annotation.Private; -import org.apache.spark.scheduler.MapStatus; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; -import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.execution.UnsafeRowSerializer; import org.apache.spark.sql.execution.metric.SQLMetric; -import org.apache.spark.storage.BlockManagerId; import org.apache.spark.unsafe.Platform; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,50 +39,14 @@ import org.apache.celeborn.common.util.Utils; @Private -public class HashBasedShuffleWriter extends 
ShuffleWriter { +public class HashBasedShuffleWriter extends BasedShuffleWriter { private static final Logger logger = LoggerFactory.getLogger(HashBasedShuffleWriter.class); - private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); - private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; - - private final int PUSH_BUFFER_INIT_SIZE; - private final int PUSH_BUFFER_MAX_SIZE; - private final ShuffleDependency dep; - private final Partitioner partitioner; - private final ShuffleWriteMetricsReporter writeMetrics; - private final int shuffleId; - private final int mapId; - private final int encodedAttemptId; - private final TaskContext taskContext; - private final ShuffleClient shuffleClient; - private final int numMappers; - private final int numPartitions; - - @Nullable private MapStatus mapStatus; - private long peakMemoryUsedBytes = 0; - - private final OpenByteArrayOutputStream serBuffer; - private final SerializationStream serOutputStream; - private byte[][] sendBuffers; private int[] sendOffsets; - - private final LongAdder[] mapStatusLengths; - protected long tmpRecordsWritten = 0; - - private final SendBufferPool sendBufferPool; - - /** - * Are we in the process of stopping? Because map tasks can call stop() with success = true and - * then call stop() with success = false if they get an exception, we want to make sure we don't - * try deleting files, etc. twice. - */ - private volatile boolean stopping = false; - private DataPusher dataPusher; - - private final boolean unsafeRowFastWrite; + private final SendBufferPool sendBufferPool; // In order to facilitate the writing of unit test code, ShuffleClient needs to be passed in as // parameters. By the way, simplify the passed parameters. @@ -110,31 +59,9 @@ public HashBasedShuffleWriter( ShuffleWriteMetricsReporter metrics, SendBufferPool sendBufferPool) throws IOException { - this.mapId = taskContext.partitionId(); - this.dep = handle.dependency(); - this.shuffleId = shuffleId; - this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext); - SerializerInstance serializer = dep.serializer().newInstance(); - this.partitioner = dep.partitioner(); - this.writeMetrics = metrics; - this.taskContext = taskContext; - this.numMappers = handle.numMappers(); - this.numPartitions = dep.partitioner().numPartitions(); - this.shuffleClient = client; - - unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite(); - serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); - serOutputStream = serializer.serializeStream(serBuffer); - - mapStatusLengths = new LongAdder[numPartitions]; - for (int i = 0; i < numPartitions; i++) { - mapStatusLengths[i] = new LongAdder(); - } - - PUSH_BUFFER_INIT_SIZE = conf.clientPushBufferInitialSize(); - PUSH_BUFFER_MAX_SIZE = conf.clientPushBufferMaxSize(); - + super(shuffleId, handle, taskContext, conf, client, metrics); this.sendBufferPool = sendBufferPool; + sendBuffers = sendBufferPool.acquireBuffer(numPartitions); sendOffsets = new int[numPartitions]; @@ -159,42 +86,6 @@ public HashBasedShuffleWriter( } @Override - public void write(scala.collection.Iterator> records) throws IOException { - boolean needCleanupPusher = true; - try { - if (canUseFastWrite()) { - fastWrite0(records); - } else if (dep.mapSideCombine()) { - if (dep.aggregator().isEmpty()) { - throw new UnsupportedOperationException( - "When using map side combine, an aggregator must be specified."); - } - write0(dep.aggregator().get().combineValuesByKey(records, taskContext)); - } 
else { - write0(records); - } - close(); - needCleanupPusher = false; - } catch (InterruptedException e) { - TaskInterruptedHelper.throwTaskKillException(); - } finally { - if (needCleanupPusher) { - cleanupPusher(); - } - } - } - - @VisibleForTesting - boolean canUseFastWrite() { - boolean keyIsPartitionId = false; - if (unsafeRowFastWrite && dep.serializer() instanceof UnsafeRowSerializer) { - // SPARK-39391 renames PartitionIdPassthrough's package - String partitionerClassName = partitioner.getClass().getSimpleName(); - keyIsPartitionId = "PartitionIdPassthrough".equals(partitionerClassName); - } - return keyIsPartitionId; - } - protected void fastWrite0(scala.collection.Iterator iterator) throws IOException, InterruptedException { final scala.collection.Iterator> records = iterator; @@ -238,7 +129,9 @@ protected void fastWrite0(scala.collection.Iterator iterator) } } - private void write0(scala.collection.Iterator iterator) throws IOException, InterruptedException { + @Override + protected void write0(scala.collection.Iterator iterator) + throws IOException, InterruptedException { final scala.collection.Iterator> records = iterator; while (records.hasNext()) { @@ -265,6 +158,11 @@ private void write0(scala.collection.Iterator iterator) throws IOException, Inte } } + @Override + void updatePeakMemoryUsed() { + // do nothing, hash shuffle writer always update this used peak memory + } + private byte[] getOrCreateBuffer(int partitionId) { byte[] buffer = sendBuffers[partitionId]; if (buffer == null) { @@ -275,23 +173,6 @@ private byte[] getOrCreateBuffer(int partitionId) { return buffer; } - protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException { - logger.debug("Push giant record, size {}.", numBytes); - int bytesWritten = - shuffleClient.pushData( - shuffleId, - mapId, - encodedAttemptId, - partitionId, - buffer, - 0, - numBytes, - numMappers, - numPartitions); - mapStatusLengths[partitionId].add(bytesWritten); - writeMetrics.incBytesWritten(bytesWritten); - } - private int getOrUpdateOffset(int partitionId, int serializedRecordSize) throws IOException, InterruptedException { int offset = sendOffsets[partitionId]; @@ -322,7 +203,12 @@ private void flushSendBuffer(int partitionId, byte[] buffer, int size) writeMetrics.incWriteTime(System.nanoTime() - start); } - protected void closeWrite() throws IOException { + @Override + protected void closeWrite() throws IOException, InterruptedException { + // here we wait for all the in-flight batches to return which sent by dataPusher thread + dataPusher.waitOnTermination(); + sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue()); + shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId); // merge and push residual data to reduce network traffic // NB: since dataPusher thread have no in-flight data at this point, // we now push merged data by task thread will not introduce any contention @@ -356,7 +242,8 @@ protected void mergeData(int partitionId, byte[] buffer, int offset, int length) writeMetrics.incBytesWritten(bytesWritten); } - private void cleanupPusher() throws IOException { + @Override + protected void cleanupPusher() throws IOException { try { dataPusher.waitOnTermination(); sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue()); @@ -364,60 +251,4 @@ private void cleanupPusher() throws IOException { TaskInterruptedHelper.throwTaskKillException(); } } - - private void close() throws IOException, InterruptedException { - // here we wait for all the 
in-flight batches to return which sent by dataPusher thread - long pushMergedDataTime = System.nanoTime(); - dataPusher.waitOnTermination(); - sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue()); - shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId); - closeWrite(); - shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId); - writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime); - updateRecordsWrittenMetrics(); - - long waitStartTime = System.nanoTime(); - shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers); - writeMetrics.incWriteTime(System.nanoTime() - waitStartTime); - - BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId(); - mapStatus = - SparkUtils.createMapStatus( - bmId, SparkUtils.unwrap(mapStatusLengths), taskContext.taskAttemptId()); - } - - private void updateRecordsWrittenMetrics() { - writeMetrics.incRecordsWritten(tmpRecordsWritten); - tmpRecordsWritten = 0; - } - - @Override - public Option stop(boolean success) { - try { - taskContext.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes); - - if (stopping) { - return Option.empty(); - } else { - stopping = true; - if (success) { - if (mapStatus == null) { - throw new IllegalStateException("Cannot call stop(true) without having called write()"); - } - return Option.apply(mapStatus); - } else { - return Option.empty(); - } - } - } finally { - shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId); - } - } - - // Added in SPARK-32917, for Spark 3.2 and above - @SuppressWarnings("MissingOverride") - public long[] getPartitionLengths() { - throw new UnsupportedOperationException( - "Celeborn is not compatible with Spark push mode, please set spark.shuffle.push.enabled to false"); - } } diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index 5717910eea9..b6fd2f40770 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -18,28 +18,15 @@ package org.apache.spark.shuffle.celeborn; import java.io.IOException; -import java.util.concurrent.atomic.LongAdder; -import scala.Option; import scala.Product2; -import scala.reflect.ClassTag; -import scala.reflect.ClassTag$; -import com.google.common.annotations.VisibleForTesting; -import org.apache.spark.Partitioner; -import org.apache.spark.ShuffleDependency; -import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.annotation.Private; -import org.apache.spark.scheduler.MapStatus; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; -import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.execution.UnsafeRowSerializer; import org.apache.spark.sql.execution.metric.SQLMetric; -import org.apache.spark.storage.BlockManagerId; import org.apache.spark.unsafe.Platform; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,96 +34,40 @@ import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; -import 
org.apache.celeborn.common.util.Utils; @Private -public class SortBasedShuffleWriter extends ShuffleWriter { +public class SortBasedShuffleWriter extends BasedShuffleWriter { private static final Logger logger = LoggerFactory.getLogger(SortBasedShuffleWriter.class); - - private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); - private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; - - private final ShuffleDependency dep; - private final Partitioner partitioner; - private final ShuffleWriteMetricsReporter writeMetrics; - private final int shuffleId; - private final int mapId; - private final int encodedAttemptId; - private final TaskContext taskContext; - private final ShuffleClient shuffleClient; - private final int numMappers; - private final int numPartitions; - - private final long pushBufferMaxSize; + private final SendBufferPool sendBufferPool; private final SortBasedPusher pusher; - private long peakMemoryUsedBytes = 0; - - private final OpenByteArrayOutputStream serBuffer; - private final SerializationStream serOutputStream; - - private final LongAdder[] mapStatusLengths; - private long tmpRecordsWritten = 0; - - /** - * Are we in the process of stopping? Because map tasks can call stop() with success = true and - * then call stop() with success = false if they get an exception, we want to make sure we don't - * try deleting files, etc. twice. - */ - private volatile boolean stopping = false; - - private final boolean unsafeRowFastWrite; public SortBasedShuffleWriter( int shuffleId, - ShuffleDependency dep, - int numMappers, + CelebornShuffleHandle handle, TaskContext taskContext, CelebornConf conf, ShuffleClient client, ShuffleWriteMetricsReporter metrics, SendBufferPool sendBufferPool) throws IOException { - this(shuffleId, dep, numMappers, taskContext, conf, client, metrics, sendBufferPool, null); + this(shuffleId, handle, taskContext, conf, client, metrics, sendBufferPool, null); } // In order to facilitate the writing of unit test code, ShuffleClient needs to be passed in as // parameters. By the way, simplify the passed parameters. 
public SortBasedShuffleWriter( int shuffleId, - ShuffleDependency dep, - int numMappers, + CelebornShuffleHandle handle, TaskContext taskContext, CelebornConf conf, ShuffleClient client, ShuffleWriteMetricsReporter metrics, SendBufferPool sendBufferPool, - SortBasedPusher pusher) - throws IOException { - this.mapId = taskContext.partitionId(); - this.dep = dep; - this.shuffleId = shuffleId; - this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext); - SerializerInstance serializer = dep.serializer().newInstance(); - this.partitioner = dep.partitioner(); - this.writeMetrics = metrics; - this.taskContext = taskContext; - this.numMappers = numMappers; - this.numPartitions = dep.partitioner().numPartitions(); - this.shuffleClient = client; - unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite(); - - serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); - serOutputStream = serializer.serializeStream(serBuffer); - - this.mapStatusLengths = new LongAdder[numPartitions]; - for (int i = 0; i < numPartitions; i++) { - this.mapStatusLengths[i] = new LongAdder(); - } - - pushBufferMaxSize = conf.clientPushBufferMaxSize(); - + SortBasedPusher pusher) { + super(shuffleId, handle, taskContext, conf, client, metrics); + this.sendBufferPool = sendBufferPool; if (pusher == null) { this.pusher = new SortBasedPusher( @@ -159,99 +90,16 @@ public SortBasedShuffleWriter( } } - public SortBasedShuffleWriter( - CelebornShuffleHandle handle, - TaskContext taskContext, - CelebornConf conf, - ShuffleClient client, - ShuffleWriteMetricsReporter metrics, - SendBufferPool sendBufferPool) - throws IOException { - this( - SparkUtils.celebornShuffleId(client, handle, taskContext, true), - handle.dependency(), - handle.numMappers(), - taskContext, - conf, - client, - metrics, - sendBufferPool); - } - - public SortBasedShuffleWriter( - CelebornShuffleHandle handle, - TaskContext taskContext, - CelebornConf conf, - ShuffleClient client, - ShuffleWriteMetricsReporter metrics, - SendBufferPool sendBufferPool, - SortBasedPusher pusher) - throws IOException { - this( - SparkUtils.celebornShuffleId(client, handle, taskContext, true), - handle.dependency(), - handle.numMappers(), - taskContext, - conf, - client, - metrics, - sendBufferPool, - pusher); - } - - private void updatePeakMemoryUsed() { + @Override + protected void updatePeakMemoryUsed() { long mem = pusher.getPeakMemoryUsedBytes(); if (mem > peakMemoryUsedBytes) { peakMemoryUsedBytes = mem; } } - /** Return the peak memory used so far, in bytes. 
*/ - public long getPeakMemoryUsedBytes() { - updatePeakMemoryUsed(); - return peakMemoryUsedBytes; - } - - void doWrite(scala.collection.Iterator> records) throws IOException { - if (canUseFastWrite()) { - fastWrite0(records); - } else if (dep.mapSideCombine()) { - if (dep.aggregator().isEmpty()) { - throw new UnsupportedOperationException( - "When using map side combine, an aggregator must be specified."); - } - write0(dep.aggregator().get().combineValuesByKey(records, taskContext)); - } else { - write0(records); - } - } - @Override - public void write(scala.collection.Iterator> records) throws IOException { - boolean needCleanupPusher = true; - try { - doWrite(records); - close(); - needCleanupPusher = false; - } finally { - if (needCleanupPusher) { - cleanupPusher(); - } - } - } - - @VisibleForTesting - boolean canUseFastWrite() { - boolean keyIsPartitionId = false; - if (unsafeRowFastWrite && dep.serializer() instanceof UnsafeRowSerializer) { - // SPARK-39391 renames PartitionIdPassthrough's package - String partitionerClassName = partitioner.getClass().getSimpleName(); - keyIsPartitionId = "PartitionIdPassthrough".equals(partitionerClassName); - } - return keyIsPartitionId; - } - - private void fastWrite0(scala.collection.Iterator iterator) throws IOException { + protected void fastWrite0(scala.collection.Iterator iterator) throws IOException { final scala.collection.Iterator> records = iterator; SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) dep.serializer()); @@ -267,7 +115,7 @@ private void fastWrite0(scala.collection.Iterator iterator) throws IOException { dataSize.add(serializedRecordSize); } - if (serializedRecordSize > pushBufferMaxSize) { + if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) { byte[] giantBuffer = new byte[serializedRecordSize]; Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize)); Platform.copyMemory( @@ -301,7 +149,8 @@ private void doPush() throws IOException { writeMetrics.incWriteTime(System.nanoTime() - start); } - private void write0(scala.collection.Iterator iterator) throws IOException { + @Override + protected void write0(scala.collection.Iterator iterator) throws IOException { final scala.collection.Iterator> records = iterator; while (records.hasNext()) { @@ -316,7 +165,7 @@ private void write0(scala.collection.Iterator iterator) throws IOException { final int serializedRecordSize = serBuffer.size(); assert (serializedRecordSize > 0); - if (serializedRecordSize > pushBufferMaxSize) { + if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) { pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize); } else { boolean success = @@ -344,75 +193,18 @@ private void write0(scala.collection.Iterator iterator) throws IOException { } } - private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException { - logger.debug("Push giant record, size {}.", Utils.bytesToString(numBytes)); - int bytesWritten = - shuffleClient.pushData( - shuffleId, - mapId, - encodedAttemptId, - partitionId, - buffer, - 0, - numBytes, - numMappers, - numPartitions); - mapStatusLengths[partitionId].add(bytesWritten); - writeMetrics.incBytesWritten(bytesWritten); - } - - private void cleanupPusher() throws IOException { + @Override + protected void cleanupPusher() throws IOException { if (pusher != null) { pusher.close(false); } } - private void close() throws IOException { - logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed())); + @Override + protected void closeWrite() throws 
IOException, InterruptedException { long pushStartTime = System.nanoTime(); pusher.pushData(false); pusher.close(true); - - shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId); writeMetrics.incWriteTime(System.nanoTime() - pushStartTime); - writeMetrics.incRecordsWritten(tmpRecordsWritten); - - long waitStartTime = System.nanoTime(); - shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers); - writeMetrics.incWriteTime(System.nanoTime() - waitStartTime); - } - - @Override - public Option stop(boolean success) { - try { - taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes()); - - if (stopping) { - return Option.empty(); - } else { - stopping = true; - if (success) { - BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId(); - MapStatus mapStatus = - SparkUtils.createMapStatus( - bmId, SparkUtils.unwrap(mapStatusLengths), taskContext.taskAttemptId()); - if (mapStatus == null) { - throw new IllegalStateException("Cannot call stop(true) without having called write()"); - } - return Option.apply(mapStatus); - } else { - return Option.empty(); - } - } - } finally { - shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId); - } - } - - // Added in SPARK-32917, for Spark 3.2 and above - @SuppressWarnings("MissingOverride") - public long[] getPartitionLengths() { - throw new UnsupportedOperationException( - "Celeborn is not compatible with push-based shuffle, please set spark.shuffle.push.enabled to false"); } } diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index af3c400ec7c..8e419033512 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -277,8 +277,7 @@ public ShuffleWriter getWriter( if (ShuffleMode.SORT.equals(shuffleMode)) { return new SortBasedShuffleWriter<>( shuffleId, - h.dependency(), - h.numMappers(), + h, context, celebornConf, shuffleClient, diff --git a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java index 0963737c0a5..c0d44007c60 100644 --- a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java @@ -64,7 +64,13 @@ protected ShuffleWriter createShuffleWriter( ShuffleWriteMetricsReporter metrics) throws IOException { return new SortBasedShuffleWriter( - handle, context, conf, client, metrics, SendBufferPool.get(4, 30, 60)); + SparkUtils.celebornShuffleId(client, handle, taskContext, true), + handle, + context, + conf, + client, + metrics, + SendBufferPool.get(4, 30, 60)); } private SortBasedShuffleWriter createShuffleWriterWithPusher( @@ -76,7 +82,14 @@ private SortBasedShuffleWriter createShuffleWriterWithP SortBasedPusher pusher) throws Exception { return new SortBasedShuffleWriter( - handle, context, conf, client, metrics, SendBufferPool.get(4, 30, 60), pusher); + SparkUtils.celebornShuffleId(client, handle, taskContext, true), + handle, + context, + conf, + client, + metrics, + SendBufferPool.get(4, 30, 60), + pusher); } private SortBasedPusher createSortBasedPusher( diff --git 
a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java index b09b1306c90..d2867391125 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java +++ b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java @@ -132,7 +132,7 @@ private void fastColumnarWrite0(scala.collection.Iterator iterator) throws IOExc } @Override - protected void closeWrite() throws IOException { + protected void closeWrite() throws IOException, InterruptedException { if (canUseFastWrite() && isColumnarShuffle) { closeColumnarWrite(); } else { From 1671c86c2d7c3405246a758b30c6d91bb7a8b9c4 Mon Sep 17 00:00:00 2001 From: binjie yang Date: Mon, 9 Dec 2024 21:02:51 +0800 Subject: [PATCH 2/3] fix doc style --- .../apache/spark/shuffle/celeborn/BasedShuffleWriter.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java index f732a68ad43..e3bbc2381de 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java @@ -199,10 +199,9 @@ protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) thr } /** - * This method will push the remaining data and close these pushers. - * It's important, will send Mapper End RPC to LifecycleManager to update - * the attempt of the corresponding task. - * We should only call this method when the task is successfully completed. + * This method will push the remaining data and close these pushers. It's important, will send + * Mapper End RPC to LifecycleManager to update the attempt of the corresponding task. We should + * only call this method when the task is successfully completed. */ protected void close() throws IOException, InterruptedException { long pushMergedDataTime = System.nanoTime(); From 52da4bd35a1c595a6b27d0d6870096735bb73f5e Mon Sep 17 00:00:00 2001 From: binjie yang Date: Mon, 9 Dec 2024 21:54:55 +0800 Subject: [PATCH 3/3] add apache license --- .../shuffle/celeborn/BasedShuffleWriter.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java index e3bbc2381de..0b83c7b70a1 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.shuffle.celeborn; import java.io.IOException;
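
The base class extracted in PATCH 1/3 follows a template-method shape: `write()` and `close()` in `BasedShuffleWriter` drive a handful of hooks (`fastWrite0`, `write0`, `updatePeakMemoryUsed`, `cleanupPusher`, `closeWrite`) that each concrete writer fills in, while record serialization, giant-record pushes, `mapperEnd`, and `stop()` bookkeeping stay in the base. The sketch below is not part of this series; it is a minimal, hypothetical subclass written only to show where those hooks sit. It assumes the constructor and protected members (`serBuffer`, `serOutputStream`, `partitioner`, `PUSH_BUFFER_MAX_SIZE`, `pushGiantRecord`, `tmpRecordsWritten`) introduced by PATCH 1/3, and it restores the `<K, V, C>` type parameters that the plain-text rendering of the patch drops from the class and method signatures. `ExampleShuffleWriter` is an illustrative name, not a class in the patch.

```java
package org.apache.spark.shuffle.celeborn;

import java.io.IOException;

import scala.Product2;

import org.apache.spark.TaskContext;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;

// Hypothetical writer used only to illustrate the hooks that
// BasedShuffleWriter.write()/close() call; not a working shuffle writer.
public class ExampleShuffleWriter<K, V, C> extends BasedShuffleWriter<K, V, C> {

  public ExampleShuffleWriter(
      int shuffleId,
      CelebornShuffleHandle<K, V, C> handle,
      TaskContext taskContext,
      CelebornConf conf,
      ShuffleClient client,
      ShuffleWriteMetricsReporter metrics) {
    super(shuffleId, handle, taskContext, conf, client, metrics);
  }

  @Override
  protected void fastWrite0(scala.collection.Iterator iterator)
      throws IOException, InterruptedException {
    // Called when canUseFastWrite() is true (UnsafeRow + PartitionIdPassthrough);
    // this example just falls back to the generic path.
    write0(iterator);
  }

  @Override
  @SuppressWarnings("unchecked")
  protected void write0(scala.collection.Iterator iterator)
      throws IOException, InterruptedException {
    scala.collection.Iterator<Product2<K, ?>> records = iterator;
    while (records.hasNext()) {
      Product2<K, ?> record = records.next();
      int partitionId = partitioner.getPartition(record._1());
      serBuffer.reset();
      serOutputStream.writeKey(record._1(), OBJECT_CLASS_TAG);
      serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
      serOutputStream.flush();
      // A real writer would copy serBuffer into a per-partition send buffer and
      // push it once it grows past PUSH_BUFFER_MAX_SIZE; only the giant-record
      // path is shown here for brevity.
      if (serBuffer.size() > PUSH_BUFFER_MAX_SIZE) {
        pushGiantRecord(partitionId, serBuffer.getBuf(), serBuffer.size());
      }
      tmpRecordsWritten++;
    }
  }

  @Override
  protected void updatePeakMemoryUsed() {
    // Sample pusher memory here; this example manages no off-heap buffers.
  }

  @Override
  protected void cleanupPusher() throws IOException {
    // Release pusher resources when write() fails before close() runs.
  }

  @Override
  protected void closeWrite() throws IOException, InterruptedException {
    // Flush anything still buffered locally; the base class then calls
    // shuffleClient.pushMergedData() and mapperEnd() and updates write metrics.
  }
}
```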