update spark patch to abort stage when rerun skew join stage
wangshengjie123 committed Nov 27, 2024
1 parent a11b57c commit 09522d6
Showing 4 changed files with 101 additions and 70 deletions.
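At a high level, this revision does two things. In DAGScheduler, the abort check no longer requires the failed stage to be a ResultStage: a fetch failure that points at a shuffle recorded in mapOutputTracker.skewShuffleIds now aborts the stage instead of retrying it. In SQLConf, the raw lookup of spark.celeborn.client.spark.fetch.throwsFetchFailure is replaced by a new entry, spark.celeborn.client.spark.stageRerun.enabled, which keeps the old key as an alternative and is exposed through a celebornStageRerunEnabled getter. A minimal standalone sketch of the revised abort decision follows (plain Scala mirroring the DAGScheduler hunks below; the parameter names are illustrative, not the scheduler's actual fields):

    // Sketch only: restates the condition the patch installs in
    // DAGScheduler.handleTaskCompletion. The stage is aborted when its retry
    // budget is exhausted, retries are disallowed for tests, or the failed
    // shuffle was registered by Celeborn as a skew-optimized shuffle.
    def shouldAbortStage(
        failedAttempts: Int,
        maxConsecutiveStageAttempts: Int,
        disallowStageRetryForTest: Boolean,
        skewShuffleIds: Set[Int],
        shuffleId: Int): Boolean =
      failedAttempts >= maxConsecutiveStageAttempts ||
        disallowStageRetryForTest ||
        skewShuffleIds.contains(shuffleId)
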
33 changes: 19 additions & 14 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch
@@ -35,7 +35,7 @@ index e469c9989f2..245d9b3b9de 100644

/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b950c07f3d8..e9e10bb647f 100644
index b950c07f3d8..d081b4642c9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1369,7 +1369,10 @@ private[spark] class DAGScheduler(
@@ -50,21 +50,20 @@ index b950c07f3d8..e9e10bb647f 100644
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
case _ =>
@@ -1780,7 +1783,8 @@ private[spark] class DAGScheduler(
@@ -1780,7 +1783,7 @@ private[spark] class DAGScheduler(
failedStage.failedAttemptIds.add(task.stageAttemptId)
val shouldAbortStage =
failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
- disallowStageRetryForTest
+ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage]
+ && mapOutputTracker.skewShuffleIds.contains(shuffleId))
+ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId)

// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6bc8ba4eebb..2e7d87c96eb 100644
index 6bc8ba4eebb..44db30dbaec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3431,6 +3431,12 @@ object SQLConf {
@@ -3431,6 +3431,19 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

@@ -73,16 +72,25 @@ index 6bc8ba4eebb..2e7d87c96eb 100644
+ .version("3.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val CELEBORN_STAGE_RERUN_ENABLED =
+ buildConf("spark.celeborn.client.spark.stageRerun.enabled")
+ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure")
+ .version("3.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
/**
* Holds information about keys that have been deprecated.
*
@@ -4154,6 +4160,9 @@ class SQLConf extends Serializable with Logging {
@@ -4154,6 +4167,11 @@ class SQLConf extends Serializable with Logging {

def legacyParquetNanosAsLong: Boolean = getConf(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG)

+ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean =
+ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)
+
+ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED)
+
/** ********************** SQLConf functionality methods ************ */

@@ -189,7 +197,7 @@ index 88abe68197b..150699a84a3 100644
logDebug(s"Right side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
index 3609548f374..f7c6d5dda90 100644
index 3609548f374..59c80198f19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
@@ -200,7 +208,7 @@ index 3609548f374..f7c6d5dda90 100644

object ShufflePartitionsUtil extends Logging {
final val SMALL_PARTITION_FACTOR = 0.2
@@ -376,11 +377,25 @@ object ShufflePartitionsUtil extends Logging {
@@ -376,11 +377,22 @@ object ShufflePartitionsUtil extends Logging {
def createSkewPartitionSpecs(
shuffleId: Int,
reducerId: Int,
@@ -215,10 +223,7 @@ index 3609548f374..f7c6d5dda90 100644
+ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") &&
+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle
+
+ val throwsFetchFailure = SparkEnv.get
+ .conf
+ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false")
+ .toBoolean
+ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled
+ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
+ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed")
+ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
@@ -227,7 +232,7 @@ index 3609548f374..f7c6d5dda90 100644
Some(mapStartIndices.indices.map { i =>
val startMapIndex = mapStartIndices(i)
val endMapIndex = if (i == mapStartIndices.length - 1) {
@@ -388,8 +403,20 @@ object ShufflePartitionsUtil extends Logging {
@@ -388,8 +400,20 @@ object ShufflePartitionsUtil extends Logging {
} else {
mapStartIndices(i + 1)
}
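Because the new SQLConf entry above declares the legacy key via withAlternative, the stage-rerun switch can be set under either name. A brief usage sketch, assuming an active SparkSession bound to spark; this is illustrative and not part of the patch:

    // Illustrative only: either key enables stage rerun for Celeborn skew handling;
    // the old key is still honored as an alternative of the new one.
    spark.conf.set("spark.celeborn.client.spark.stageRerun.enabled", "true")
    // Legacy spelling, still accepted:
    // spark.conf.set("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
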
46 changes: 30 additions & 16 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch
@@ -35,7 +35,7 @@ index b1974948430..0dc92ec44a8 100644

/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index bd2823bcac1..5d81b9de5b6 100644
index bd2823bcac1..4f40becadc7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1404,7 +1404,10 @@ private[spark] class DAGScheduler(
@@ -50,21 +50,20 @@ index bd2823bcac1..5d81b9de5b6 100644
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
case _ =>
@@ -1851,7 +1854,8 @@ private[spark] class DAGScheduler(
@@ -1851,7 +1854,7 @@ private[spark] class DAGScheduler(
failedStage.failedAttemptIds.add(task.stageAttemptId)
val shouldAbortStage =
failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
- disallowStageRetryForTest
+ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage]
+ && mapOutputTracker.skewShuffleIds.contains(shuffleId))
+ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId)

// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index af03ad9a4cb..7a3ee9ebfaf 100644
index af03ad9a4cb..6c36fb96d58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3784,6 +3784,12 @@ object SQLConf {
@@ -3784,6 +3784,19 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

@@ -73,16 +72,25 @@ index af03ad9a4cb..7a3ee9ebfaf 100644
+ .version("3.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val CELEBORN_STAGE_RERUN_ENABLED =
+ buildConf("spark.celeborn.client.spark.stageRerun.enabled")
+ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure")
+ .version("3.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
/**
* Holds information about keys that have been deprecated.
*
@@ -4549,6 +4555,9 @@ class SQLConf extends Serializable with Logging {
@@ -4549,6 +4562,11 @@ class SQLConf extends Serializable with Logging {
def histogramNumericPropagateInputType: Boolean =
getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE)

+ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean =
+ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)
+
+ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED)
+
/** ********************** SQLConf functionality methods ************ */

@@ -190,17 +198,25 @@ index d4a173bb9cc..21ef335e064 100644
logDebug(s"Right side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
index af689db3379..9d9f9c994b9 100644
index af689db3379..529097549ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
@@ -380,13 +380,27 @@ object ShufflePartitionsUtil extends Logging {
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec}
+import org.apache.spark.sql.internal.SQLConf

object ShufflePartitionsUtil extends Logging {
final val SMALL_PARTITION_FACTOR = 0.2
@@ -380,13 +381,23 @@ object ShufflePartitionsUtil extends Logging {
shuffleId: Int,
reducerId: Int,
targetSize: Long,
- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR)
- : Option[Seq[PartialReducerPartitionSpec]] = {
+ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR,
+ isCelebornShuffle: Boolean = false)
: Option[Seq[PartialReducerPartitionSpec]] = {
+ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = {
val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId)
if (mapPartitionSizes.exists(_ < 0)) return None
val mapStartIndices = splitSizeListByTargetSize(
@@ -210,10 +226,7 @@ index af689db3379..9d9f9c994b9 100644
+ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") &&
+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle
+
+ val throwsFetchFailure = SparkEnv.get
+ .conf
+ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false")
+ .toBoolean
+ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled
+ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
+ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed")
+ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
@@ -222,11 +235,12 @@ index af689db3379..9d9f9c994b9 100644
Some(mapStartIndices.indices.map { i =>
val startMapIndex = mapStartIndices(i)
val endMapIndex = if (i == mapStartIndices.length - 1) {
@@ -400,7 +414,14 @@ object ShufflePartitionsUtil extends Logging {
@@ -400,7 +411,15 @@ object ShufflePartitionsUtil extends Logging {
dataSize += mapPartitionSizes(mapIndex)
mapIndex += 1
}
- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
+
+ if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
+ // These `dataSize` variables may not be accurate as they only represent the sum of
+ // `dataSize` when the Celeborn optimize skewed partition read feature is enabled.
(The diffs for the remaining two changed files are not shown here.)

0 comments on commit 09522d6
