update
sperlingxx committed Feb 7, 2025
1 parent 96758bc · commit d2084dc
Showing 1 changed file with 18 additions and 10 deletions.
sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala (28 changes: 18 additions & 10 deletions)
@@ -21,11 +21,9 @@ import java.net.URL
import java.time.ZoneId
import java.util.Properties
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.sys.process._
import scala.util.Try

import ai.rapids.cudf.{Cuda, CudaException, CudaFatalException, CudfException, MemoryCleaner, NvtxColor, NvtxRange}
import com.nvidia.spark.DFUDFPlugin
import com.nvidia.spark.rapids.RapidsConf.AllowMultipleJars
@@ -35,7 +33,6 @@ import com.nvidia.spark.rapids.io.async.TrafficController
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
import org.apache.commons.lang3.exception.ExceptionUtils

import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskContext, TaskFailedReason}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
import org.apache.spark.internal.Logging
@@ -47,6 +44,8 @@ import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.sql.rapids.execution.TrampolineUtil

import scala.collection.mutable

class PluginException(msg: String) extends RuntimeException(msg)

case class CudfVersionMismatchException(errorMsg: String) extends PluginException(errorMsg)
@@ -353,20 +352,29 @@ object RapidsPluginUtils extends Logging
     val resourceName = "spark-rapids-extra-plugins"
     val classLoader = RapidsPluginUtils.getClass.getClassLoader
     val resourceUrls = classLoader.getResources(resourceName)
-    val resourceUrlArray = resourceUrls.asScala.toArray.distinct
+    // Somehow, it is possible that the definition of same Plugin occurs multiple times in the
+    // resourceUrls. Therefore, deduplication work is essential in case of loading some plugins
+    // repeatedly.
+    val distinctResources = mutable.HashSet.empty[URL]
+    while (resourceUrls.hasMoreElements) {
+      val url = resourceUrls.nextElement()
+      if (distinctResources.contains(url)) {
+        logWarning(s"Found duplicated definition of ExtraPlugin: $url! Discarded it.")
+      } else {
+        distinctResources.add(url)
+      }
+    }
 
-    if (resourceUrlArray.isEmpty) {
+    if (distinctResources.isEmpty) {
       logDebug(s"Could not find file $resourceName in the classpath, not loading extra plugins")
       Seq.empty
     } else {
-      val plugins = scala.collection.mutable.ListBuffer[SparkPlugin]()
-      for (resourceUrl <- resourceUrlArray) {
+      distinctResources.iterator.flatMap { resourceUrl =>
         val source = scala.io.Source.fromURL(resourceUrl)
         val pluginClasses = source.getLines().toList
         source.close()
-        plugins ++= loadExtensions(classOf[SparkPlugin], pluginClasses)
-      }
-      plugins.toSeq
+        loadExtensions(classOf[SparkPlugin], pluginClasses)
+      }.toList
     }
   }

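The change replaces the one-line resourceUrls.asScala.toArray.distinct with an explicit pass over the java.util.Enumeration so that each duplicate resource URL can be logged before it is discarded. A rough standalone sketch of that dedup-then-read pattern, assuming a hypothetical helper name and plain println logging instead of the plugin's Logging trait:

import java.net.URL
import scala.collection.mutable
import scala.io.Source

object ExtraPluginDiscovery {
  // Hypothetical helper mirroring the flow in the commit: collect distinct resource URLs,
  // warn on duplicates, then read one fully-qualified plugin class name per line.
  def readPluginClassNames(resourceName: String): List[String] = {
    val urls = getClass.getClassLoader.getResources(resourceName)
    val distinct = mutable.HashSet.empty[URL]
    while (urls.hasMoreElements) {
      val url = urls.nextElement()
      // HashSet.add returns false when the URL was already seen, i.e. a duplicate definition.
      if (!distinct.add(url)) {
        println(s"Duplicate definition of $resourceName at $url, skipping it")
      }
    }
    distinct.iterator.flatMap { url =>
      val source = Source.fromURL(url)
      try source.getLines().toList
      finally source.close()
    }.toList
  }
}

Functionally the HashSet keeps the same set semantics as the old .distinct call; the difference is that duplicates are now reported with a warning instead of being dropped silently.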

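As the new code implies, each spark-rapids-extra-plugins resource on the classpath (for example, a file at src/main/resources/spark-rapids-extra-plugins in a typical build) is read line by line, and every line is treated as a fully-qualified class name handed to loadExtensions together with classOf[SparkPlugin]. A hypothetical extra plugin, purely for illustration, would list com.example.MyExtraPlugin as the only line of that file and provide a class like:

package com.example

import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}

// Minimal SparkPlugin: Spark allows either component to be null, so this plugin
// registers no driver-side or executor-side logic and only exercises the
// discovery and loading path.
class MyExtraPlugin extends SparkPlugin {
  override def driverPlugin(): DriverPlugin = null
  override def executorPlugin(): ExecutorPlugin = null
}

Because the resource file is read verbatim, every line ends up being passed to loadExtensions as a class name, so the file should contain nothing but class names.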