Posted to commits@spark.apache.org by hv...@apache.org on 2023/06/27 01:42:32 UTC
[spark] branch master updated: [SPARK-44146][CONNECT] Isolate Spark Connect Session jars and classfiles
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b02ea4cd370 [SPARK-44146][CONNECT] Isolate Spark Connect Session jars and classfiles
b02ea4cd370 is described below
commit b02ea4cd370ce6a066561dfde9d517ea70805a2b
Author: vicennial <ve...@databricks.com>
AuthorDate: Mon Jun 26 21:42:19 2023 -0400
[SPARK-44146][CONNECT] Isolate Spark Connect Session jars and classfiles
### What changes were proposed in this pull request?
This PR follows up on https://github.com/apache/spark/pull/41625 and utilizes Spark's classloader/resource isolation to support multi-user Spark Connect sessions that are isolated from each other (currently, classfiles and jars), which enables multi-user REPLs and UDFs.
- Instead of a single `SparkConnectArtifactManager` instance handling all artifact movement, each instance is now responsible for a single `sessionHolder` (i.e., a Spark Connect session), which it requires in its constructor.
- Previously, all artifacts were stored under a common directory, `sparkConnectArtifactDirectory`, which was initialised in `SparkContext`. Moving forward, artifacts are instead separated by the underlying `SparkSession` they belong to (keyed by its `sessionUUID`), in the format `ROOT_ARTIFACT_DIR/<sessionUUID>/jars/...` (see the sketch after this list).
- The `SparkConnectArtifactManager` also builds a `JobArtifactSet` [here](https://github.com/apache/spark/pull/41701/files#diff-f833145e80f2b42f54f446a0f173e60e3f5ad657a6ad1f2135bc5c20bcddc90cR157-R168), which is eventually propagated to the executors, where the classloader isolation mechanism uses the `uuid` parameter.
- Currently, classfiles and jars are isolated, but files and archives aren't.
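To make the layout concrete, here is a minimal sketch of the per-session directory/URI derivation. `artifactRootPath` and `artifactRootURI` are stand-ins for the values defined in `SparkConnectArtifactManager`'s companion object (see the diff below); the literal values are illustrative only:

    import java.nio.file.{Path, Paths}

    object SessionLayoutSketch {
      // Stand-ins: the real root is a JVM-wide temp directory, and the real
      // URI is the one handed out by the RPC file server for that directory.
      val artifactRootPath: Path = Paths.get("/tmp/artifacts")
      val artifactRootURI: String = "spark://host:port/artifacts"

      // Everything below the root is keyed by the owning session's UUID.
      def dirAndUriFor(sessionUUID: String): (Path, String) =
        (artifactRootPath.resolve(sessionUUID), s"$artifactRootURI/$sessionUUID")
      // Jars then live under <dir>/jars/... and class files under <dir>/classes/...
    }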
### Why are the changes needed?
To enable support for multi-user sessions coexisting on a single Spark cluster. For example, multi-user Scala REPLs/UDFs will be supported with this PR.
### Does this PR introduce _any_ user-facing change?
Yes, multiple Spark Connect REPLs may use a single Spark cluster at once and execute their own UDFs without interfering with each other.
### How was this patch tested?
New unit tests in `ArtifactManagerSuite` + existing tests.
Closes #41701 from vicennial/SPARK-44146.
Authored-by: vicennial <ve...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../artifact/SparkConnectArtifactManager.scala | 205 +++++++++++++--------
.../sql/connect/planner/SparkConnectPlanner.scala | 19 +-
.../spark/sql/connect/service/SessionHolder.scala | 103 +++++++++++
.../service/SparkConnectAddArtifactsHandler.scala | 7 +-
.../service/SparkConnectAnalyzeHandler.scala | 21 ++-
.../sql/connect/service/SparkConnectService.scala | 4 +-
.../service/SparkConnectStreamHandler.scala | 15 +-
.../connect/artifact/ArtifactManagerSuite.scala | 167 ++++++++++++++---
.../scala/org/apache/spark/JobArtifactSet.scala | 19 +-
.../main/scala/org/apache/spark/SparkContext.scala | 23 +--
.../main/scala/org/apache/spark/rpc/RpcEnv.scala | 11 ++
.../spark/rpc/netty/NettyStreamManager.scala | 5 +
12 files changed, 441 insertions(+), 158 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 05c0a597722..0a91c6b9550 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -26,9 +26,12 @@ import javax.ws.rs.core.UriBuilder
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
+import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.{JobArtifactSet, SparkContext, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
import org.apache.spark.sql.connect.service.SessionHolder
@@ -39,45 +42,29 @@ import org.apache.spark.util.Utils
* The Artifact Manager for the [[SparkConnectService]].
*
* This class handles the storage of artifacts as well as preparing the artifacts for use.
- * Currently, jars and classfile artifacts undergo additional processing:
- * - Jars and pyfiles are automatically added to the underlying [[SparkContext]] and are
- * accessible by all users of the cluster.
- * - Class files are moved into a common directory that is shared among all users of the
- * cluster. Note: Under a multi-user setup, class file conflicts may occur between user
- * classes as the class file directory is shared.
+ *
+ * Artifacts belonging to different [[SparkSession]]s are segregated and isolated from each other
+ * with the help of the `sessionUUID`.
+ *
+ * Jars and classfile artifacts are stored under "jars" and "classes" sub-directories respectively
+ * while other types of artifacts are stored under the root directory for that particular
+ * [[SparkSession]].
+ *
+ * @param sessionHolder
+ * The object used to hold the Spark Connect session state.
*/
-class SparkConnectArtifactManager private[connect] {
+class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging {
+ import SparkConnectArtifactManager._
- // The base directory where all artifacts are stored.
- // Note: If a REPL is attached to the cluster, class file artifacts are stored in the
- // REPL's output directory.
- private[connect] lazy val artifactRootPath = SparkContext.getActive match {
- case Some(sc) =>
- sc.sparkConnectArtifactDirectory.toPath
- case None =>
- throw new RuntimeException("SparkContext is uninitialized!")
- }
- private[connect] lazy val artifactRootURI = {
- val fileServer = SparkEnv.get.rpcEnv.fileServer
- fileServer.addDirectory("artifacts", artifactRootPath.toFile)
- }
-
- // The base directory where all class files are stored.
- // Note: If a REPL is attached to the cluster, we piggyback on the existing REPL output
- // directory to store class file artifacts.
- private[connect] lazy val classArtifactDir = SparkEnv.get.conf
- .getOption("spark.repl.class.outputDir")
- .map(p => Paths.get(p))
- .getOrElse(ArtifactUtils.concatenatePaths(artifactRootPath, "classes"))
-
- private[connect] lazy val classArtifactUri: String =
- SparkEnv.get.conf.getOption("spark.repl.class.uri") match {
- case Some(uri) => uri
- case None =>
- throw new RuntimeException("Class artifact URI had not been initialised in SparkContext!")
- }
+ private val sessionUUID = sessionHolder.session.sessionUUID
+ // The base directory/URI where all artifacts are stored for this `sessionUUID`.
+ val (artifactPath, artifactURI): (Path, String) =
+ getArtifactDirectoryAndUriForSession(sessionHolder)
+ // The base directory/URI where all class file artifacts are stored for this `sessionUUID`.
+ val (classDir, classURI): (Path, String) = getClassfileDirectoryAndUriForSession(sessionHolder)
private val jarsList = new CopyOnWriteArrayList[Path]
+ private val jarsURI = new CopyOnWriteArrayList[String]
private val pythonIncludeList = new CopyOnWriteArrayList[String]
/**
@@ -98,13 +85,11 @@ class SparkConnectArtifactManager private[connect] {
* Add and prepare a staged artifact (i.e an artifact that has been rebuilt locally from bytes
* over the wire) for use.
*
- * @param session
* @param remoteRelativePath
* @param serverLocalStagingPath
* @param fragment
*/
private[connect] def addArtifact(
- sessionHolder: SessionHolder,
remoteRelativePath: Path,
serverLocalStagingPath: Path,
fragment: Option[String]): Unit = {
@@ -127,27 +112,28 @@ class SparkConnectArtifactManager private[connect] {
updater.save()
}(catchBlock = { tmpFile.delete() })
} else if (remoteRelativePath.startsWith(s"classes${File.separator}")) {
- // Move class files to common location (shared among all users)
+ // Move class files to the right directory.
val target = ArtifactUtils.concatenatePaths(
- classArtifactDir,
+ classDir,
remoteRelativePath.toString.stripPrefix(s"classes${File.separator}"))
Files.createDirectories(target.getParent)
// Allow overwriting class files to capture updates to classes.
+ // This is required because the client currently sends all the class files in each class file
+ // transfer.
Files.move(serverLocalStagingPath, target, StandardCopyOption.REPLACE_EXISTING)
} else {
- val target = ArtifactUtils.concatenatePaths(artifactRootPath, remoteRelativePath)
+ val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath)
Files.createDirectories(target.getParent)
- // Disallow overwriting jars because spark doesn't support removing jars that were
- // previously added,
+ // Disallow overwriting non-classfile artifacts
if (Files.exists(target)) {
throw new RuntimeException(
- s"Duplicate file: $remoteRelativePath. Files cannot be overwritten.")
+ s"Duplicate Artifact: $remoteRelativePath. " +
+ "Artifacts cannot be overwritten.")
}
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
- // Adding Jars to the underlying spark context (visible to all users)
- sessionHolder.session.sessionState.resourceLoader.addJar(target.toString)
jarsList.add(target)
+ jarsURI.add(artifactURI + "/" + target.toString)
} else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
sessionHolder.session.sparkContext.addFile(target.toString)
val stringRemotePath = remoteRelativePath.toString
@@ -165,8 +151,47 @@ class SparkConnectArtifactManager private[connect] {
}
}
+ /**
+ * Returns a [[JobArtifactSet]] pointing towards the session-specific jars and class files.
+ */
+ def jobArtifactSet: JobArtifactSet = {
+ val builder = Map.newBuilder[String, Long]
+ jarsURI.forEach { jar =>
+ builder += jar -> 0
+ }
+
+ new JobArtifactSet(
+ uuid = Option(sessionUUID),
+ replClassDirUri = Option(classURI),
+ jars = builder.result(),
+ files = Map.empty,
+ archives = Map.empty)
+ }
+
+ /**
+ * Returns a [[ClassLoader]] for session-specific jar/class file resources.
+ */
+ def classloader: ClassLoader = {
+ val urls = jarsList.asScala.map(_.toUri.toURL) :+ classDir.toUri.toURL
+ new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+ }
+
+ /**
+ * Cleans up all resources specific to this `sessionHolder`.
+ */
+ private[connect] def cleanUpResources(): Unit = {
+ logDebug(
+ s"Cleaning up resources for session with userId: ${sessionHolder.userId} and " +
+ s"sessionId: ${sessionHolder.sessionId}")
+ // Clean up cached relations
+ val blockManager = sessionHolder.session.sparkContext.env.blockManager
+ blockManager.removeCache(sessionHolder.userId, sessionHolder.sessionId)
+
+ // Clean up artifacts folder
+ FileUtils.deleteDirectory(artifactRootPath.toFile)
+ }
+
private[connect] def uploadArtifactToFs(
- sessionHolder: SessionHolder,
remoteRelativePath: Path,
serverLocalStagingPath: Path): Unit = {
val hadoopConf = sessionHolder.session.sparkContext.hadoopConfiguration
@@ -200,48 +225,80 @@ class SparkConnectArtifactManager private[connect] {
}
}
-object SparkConnectArtifactManager {
+object SparkConnectArtifactManager extends Logging {
val forwardToFSPrefix = "forward_to_fs"
- private var _activeArtifactManager: SparkConnectArtifactManager = _
+ private var currentArtifactRootUri: String = _
+ private var lastKnownSparkContextInstance: SparkContext = _
- /**
- * Obtain the active artifact manager or create a new artifact manager.
- *
- * @return
- */
- def getOrCreateArtifactManager: SparkConnectArtifactManager = {
- if (_activeArtifactManager == null) {
- _activeArtifactManager = new SparkConnectArtifactManager
- }
- _activeArtifactManager
+ private val ARTIFACT_DIRECTORY_PREFIX = "artifacts"
+
+ // The base directory where all artifacts are stored.
+ private[spark] lazy val artifactRootPath = {
+ Utils.createTempDir(ARTIFACT_DIRECTORY_PREFIX).toPath
}
- private lazy val artifactManager = getOrCreateArtifactManager
+ private[spark] def getArtifactDirectoryAndUriForSession(session: SparkSession): (Path, String) =
+ (
+ ArtifactUtils.concatenatePaths(artifactRootPath, session.sessionUUID),
+ s"$artifactRootURI/${session.sessionUUID}")
+
+ private[spark] def getArtifactDirectoryAndUriForSession(
+ sessionHolder: SessionHolder): (Path, String) =
+ getArtifactDirectoryAndUriForSession(sessionHolder.session)
+
+ private[spark] def getClassfileDirectoryAndUriForSession(
+ session: SparkSession): (Path, String) = {
+ val (artDir, artUri) = getArtifactDirectoryAndUriForSession(session)
+ (ArtifactUtils.concatenatePaths(artDir, "classes"), s"$artUri/classes/")
+ }
+
+ private[spark] def getClassfileDirectoryAndUriForSession(
+ sessionHolder: SessionHolder): (Path, String) =
+ getClassfileDirectoryAndUriForSession(sessionHolder.session)
/**
- * Obtain a classloader that contains jar and classfile artifacts on the classpath.
+ * Updates the URI for the artifact directory.
*
- * @return
+ * This is required if the SparkContext is restarted.
+ *
+ * Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted
+ * several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is
+ * not expected to be restarted at any point in time.
*/
- def classLoaderWithArtifacts: ClassLoader = {
- val urls = artifactManager.getSparkConnectAddedJars :+
- artifactManager.classArtifactDir.toUri.toURL
- new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+ private def refreshArtifactUri(sc: SparkContext): Unit = synchronized {
+ // If a competing thread had updated the URI, we do not need to refresh the URI again.
+ if (sc eq lastKnownSparkContextInstance) {
+ return
+ }
+ val oldArtifactUri = currentArtifactRootUri
+ currentArtifactRootUri = SparkEnv.get.rpcEnv.fileServer
+ .addDirectoryIfAbsent(ARTIFACT_DIRECTORY_PREFIX, artifactRootPath.toFile)
+ lastKnownSparkContextInstance = sc
+ logDebug(s"Artifact URI updated from $oldArtifactUri to $currentArtifactRootUri")
}
/**
- * Run a segment of code utilising a classloader that contains jar and classfile artifacts on
- * the classpath.
+ * Checks if the URI for the artifact directory needs to be updated. This is required in cases
+ * where SparkContext is restarted as the old URI would no longer be valid.
*
- * @param thunk
- * @tparam T
- * @return
+ * Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted
+ * several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is
+ * not expected to be restarted at any point in time.
*/
- def withArtifactClassLoader[T](thunk: => T): T = {
- Utils.withContextClassLoader(classLoaderWithArtifacts) {
- thunk
+ private def updateUriIfRequired(): Unit = {
+ SparkContext.getActive.foreach { sc =>
+ if (lastKnownSparkContextInstance == null || (sc ne lastKnownSparkContextInstance)) {
+ logDebug("Refreshing artifact URI due to SparkContext (re)initialisation!")
+ refreshArtifactUri(sc)
+ }
}
}
+
+ private[connect] def artifactRootURI: String = {
+ updateUriIfRequired()
+ require(currentArtifactRootUri != null)
+ currentArtifactRootUri
+ }
}
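In short, the manager is now constructed per session rather than fetched from a process-wide singleton. A minimal sketch of how server-side code consumes it (the wiring through `SessionHolder` is in the next file; this is illustrative, not new API):

    import org.apache.spark.sql.connect.service.SessionHolder

    // Sketch: given a session's holder, everything is scoped to that session.
    def useSessionArtifacts(holder: SessionHolder): Unit = {
      val mgr = holder.artifactManager          // one manager per Connect session
      val artifacts = mgr.jobArtifactSet        // session jars/classes, sent with jobs
      val loader = mgr.classloader              // URLClassLoader over this session only
      val helloCls = loader.loadClass("Hello")  // resolves classes added by this session
    }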
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 907d25e1ee1..c19fc5fe90e 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -25,6 +25,8 @@ import com.google.protobuf.{Any => ProtoAny, ByteString}
import io.grpc.{Context, Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver
import org.apache.commons.lang3.exception.ExceptionUtils
+import org.json4s._
+import org.json4s.jackson.JsonMethods.parse
import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
@@ -50,7 +52,6 @@ import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
-import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
@@ -88,6 +89,15 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
private lazy val pythonExec =
sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
+ // SparkConnectPlanner is used per request.
+ private lazy val pythonIncludes = {
+ implicit val formats = DefaultFormats
+ parse(session.conf.get("spark.connect.pythonUDF.includes", "[]"))
+ .extract[Array[String]]
+ .toList
+ .asJava
+ }
+
// The root of the query plan is a relation and we apply the transformations to it.
def transformRelation(rel: proto.Relation): LogicalPlan = {
val plan = rel.getRelTypeCase match {
@@ -1421,13 +1431,13 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = {
Utils.deserialize[UdfPacket](
fun.getScalarScalaUdf.getPayload.toByteArray,
- SparkConnectArtifactManager.classLoaderWithArtifacts)
+ Utils.getContextOrSparkClassLoader)
}
private def unpackForeachWriter(fun: proto.ScalarScalaUDF): ForeachWriterPacket = {
Utils.deserialize[ForeachWriterPacket](
fun.getPayload.toByteArray,
- SparkConnectArtifactManager.classLoaderWithArtifacts)
+ Utils.getContextOrSparkClassLoader)
}
/**
@@ -1484,8 +1494,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
command = fun.getCommand.toByteArray,
// Empty environment variables
envVars = Maps.newHashMap(),
- pythonIncludes =
- SparkConnectArtifactManager.getOrCreateArtifactManager.getSparkConnectPythonIncludes.asJava,
+ pythonIncludes = pythonIncludes,
pythonExec = pythonExec,
pythonVer = fun.getPythonVer,
// Empty broadcast variables
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index cc2327abb5c..00432209779 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -17,15 +17,22 @@
package org.apache.spark.sql.connect.service
+import java.nio.file.Path
import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
+import org.apache.spark.JobArtifactSet
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
+import org.apache.spark.util.Utils
/**
* Object used to hold the Spark Connect session state.
@@ -60,6 +67,102 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
}
}
}
+
+ private[connect] lazy val artifactManager = new SparkConnectArtifactManager(this)
+
+ /**
+ * Add an artifact to this SparkConnect session.
+ *
+ * @param remoteRelativePath
+ * @param serverLocalStagingPath
+ * @param fragment
+ */
+ private[connect] def addArtifact(
+ remoteRelativePath: Path,
+ serverLocalStagingPath: Path,
+ fragment: Option[String]): Unit = {
+ artifactManager.addArtifact(remoteRelativePath, serverLocalStagingPath, fragment)
+ }
+
+ /**
+ * A [[JobArtifactSet]] for this SparkConnect session.
+ */
+ def connectJobArtifactSet: JobArtifactSet = artifactManager.jobArtifactSet
+
+ /**
+ * A [[ClassLoader]] for jar/class file resources specific to this SparkConnect session.
+ */
+ def classloader: ClassLoader = artifactManager.classloader
+
+ /**
+ * Expire this session and trigger state cleanup mechanisms.
+ */
+ private[connect] def expireSession(): Unit = {
+ logDebug(s"Expiring session with userId: $userId and sessionId: $sessionId")
+ artifactManager.cleanUpResources()
+ }
+
+ /**
+ * Execute a block of code using this session's classloader.
+ * @param f
+ * @tparam T
+ */
+ def withContext[T](f: => T): T = {
+ // Needed for deserializing and evaluating the UDF on the driver
+ Utils.withContextClassLoader(classloader) {
+ // Needed for propagating the dependencies to the executors.
+ JobArtifactSet.withActive(connectJobArtifactSet) {
+ f
+ }
+ }
+ }
+
+ /**
+ * Set the session-based Python paths to include in Python UDF.
+ * @param f
+ * @tparam T
+ */
+ def withSessionBasedPythonPaths[T](f: => T): T = {
+ try {
+ session.conf.set(
+ "spark.connect.pythonUDF.includes",
+ compact(render(artifactManager.getSparkConnectPythonIncludes)))
+ f
+ } finally {
+ session.conf.unset("spark.connect.pythonUDF.includes")
+ }
+ }
+
+ /**
+ * Execute a block of code with this session as the active SparkConnect session.
+ * @param f
+ * @tparam T
+ */
+ def withSession[T](f: SparkSession => T): T = {
+ withSessionBasedPythonPaths {
+ withContext {
+ session.withActive {
+ f(session)
+ }
+ }
+ }
+ }
+
+ /**
+ * Execute a block of code using the session from this [[SessionHolder]] as the active
+ * SparkConnect session.
+ * @param f
+ * @tparam T
+ */
+ def withSessionHolder[T](f: SessionHolder => T): T = {
+ withSessionBasedPythonPaths {
+ withContext {
+ session.withActive {
+ f(this)
+ }
+ }
+ }
+ }
}
object SessionHolder {
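For readers following the diff: `withSession` layers three concerns so that a handler body runs with the session classloader, the propagated `JobArtifactSet`, and the active `SparkSession` all at once. Roughly, based on the code above:

    // Approximate expansion of `holder.withSession(f)` (sketch, not verbatim):
    holder.withSessionBasedPythonPaths {                          // session Python includes
      Utils.withContextClassLoader(holder.classloader) {          // driver-side deserialization
        JobArtifactSet.withActive(holder.connectJobArtifactSet) { // executor propagation
          holder.session.withActive(f(holder.session))            // active SparkSession
        }
      }
    }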
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index 179ff1b3ec9..e424331e761 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -49,8 +49,6 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
// several [[AddArtifactsRequest]]s.
private var chunkedArtifact: StagedChunkedArtifact = _
private var holder: SessionHolder = _
- private def artifactManager: SparkConnectArtifactManager =
- SparkConnectArtifactManager.getOrCreateArtifactManager
override def onNext(req: AddArtifactsRequest): Unit = {
if (this.holder == null) {
@@ -87,7 +85,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
}
protected def addStagedArtifactToArtifactManager(artifact: StagedArtifact): Unit = {
- artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath, artifact.fragment)
+ require(holder != null)
+ holder.addArtifact(artifact.path, artifact.stagedPath, artifact.fragment)
}
/**
@@ -103,7 +102,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
if (artifact.getCrcStatus.contains(true)) {
if (artifact.path.startsWith(
SparkConnectArtifactManager.forwardToFSPrefix + File.separator)) {
- artifactManager.uploadArtifactToFs(holder, artifact.path, artifact.stagedPath)
+ holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
} else {
addStagedArtifactToArtifactManager(artifact)
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index 947f6ebbebe..5c069bfaf5d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -24,7 +24,6 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
@@ -33,16 +32,18 @@ private[connect] class SparkConnectAnalyzeHandler(
responseObserver: StreamObserver[proto.AnalyzePlanResponse])
extends Logging {
- def handle(request: proto.AnalyzePlanRequest): Unit =
- SparkConnectArtifactManager.withArtifactClassLoader {
- val sessionHolder = SparkConnectService
- .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
- sessionHolder.session.withActive {
- val response = process(request, sessionHolder)
- responseObserver.onNext(response)
- responseObserver.onCompleted()
- }
+ def handle(request: proto.AnalyzePlanRequest): Unit = {
+ val sessionHolder = SparkConnectService.getOrCreateIsolatedSession(
+ request.getUserContext.getUserId,
+ request.getSessionId)
+ // `withSession` ensures that session-specific artifacts (such as JARs and class files) are
+ // available during processing (such as deserialization).
+ sessionHolder.withSessionHolder { sessionHolder =>
+ val response = process(request, sessionHolder)
+ responseObserver.onNext(response)
+ responseObserver.onCompleted()
}
+ }
def process(
request: proto.AnalyzePlanRequest,
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index c1647fd85a0..0f90bccaac8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -301,9 +301,7 @@ object SparkConnectService {
private class RemoveSessionListener extends RemovalListener[SessionCacheKey, SessionHolder] {
override def onRemoval(
notification: RemovalNotification[SessionCacheKey, SessionHolder]): Unit = {
- val SessionHolder(userId, sessionId, session) = notification.getValue
- val blockManager = session.sparkContext.env.blockManager
- blockManager.removeCache(userId, sessionId)
+ notification.getValue.expireSession()
}
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index b7fdb07d788..d809833d012 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -30,7 +30,6 @@ import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoUtils}
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
@@ -45,12 +44,14 @@ import org.apache.spark.util.{ThreadUtils, Utils}
class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse])
extends Logging {
- def handle(v: ExecutePlanRequest): Unit = SparkConnectArtifactManager.withArtifactClassLoader {
- val sessionHolder = SparkConnectService
- .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
- val session = sessionHolder.session
-
- session.withActive {
+ def handle(v: ExecutePlanRequest): Unit = {
+ val sessionHolder =
+ SparkConnectService
+ .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
+ // `withSession` ensures that session-specific artifacts (such as JARs and class files) are
+ // available during processing.
+ sessionHolder.withSession { session =>
+ // Add debug information to the query execution so that the jobs are traceable.
val debugString =
try {
Utils.redact(
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index b6da38fc572..42ab8ca18f6 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -21,7 +21,7 @@ import java.nio.file.{Files, Paths}
import org.apache.commons.io.FileUtils
-import org.apache.spark.SparkConf
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.sql.connect.ResourceHelper
import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
import org.apache.spark.sql.functions.col
@@ -39,21 +39,30 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
}
private val artifactPath = commonResourcePath.resolve("artifact-tests")
- private lazy val artifactManager = SparkConnectArtifactManager.getOrCreateArtifactManager
-
private def sessionHolder(): SessionHolder = {
SessionHolder("test", spark.sessionUUID, spark)
}
+ private lazy val artifactManager = new SparkConnectArtifactManager(sessionHolder())
+
+ private def sessionUUID: String = spark.sessionUUID
+
+ override def afterEach(): Unit = {
+ artifactManager.cleanUpResources()
+ super.afterEach()
+ }
test("Jar artifacts are added to spark session") {
val copyDir = Utils.createTempDir().toPath
FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
val stagingPath = copyDir.resolve("smallJar.jar")
val remotePath = Paths.get("jars/smallJar.jar")
- artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)
+ artifactManager.addArtifact(remotePath, stagingPath, None)
- val jarList = spark.sparkContext.listJars()
- assert(jarList.exists(_.contains(remotePath.toString)))
+ val expectedPath = SparkConnectArtifactManager.artifactRootPath
+ .resolve(s"$sessionUUID/jars/smallJar.jar")
+ assert(expectedPath.toFile.exists())
+ val jars = artifactManager.jobArtifactSet.jars
+ assert(jars.exists(_._1.contains(remotePath.toString)))
}
test("Class artifacts are added to the correct directory.") {
@@ -62,10 +71,11 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
val stagingPath = copyDir.resolve("smallClassFile.class")
val remotePath = Paths.get("classes/smallClassFile.class")
assert(stagingPath.toFile.exists())
- artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)
+ artifactManager.addArtifact(remotePath, stagingPath, None)
- val classFileDirectory = artifactManager.classArtifactDir
- val movedClassFile = classFileDirectory.resolve("smallClassFile.class").toFile
+ val movedClassFile = SparkConnectArtifactManager.artifactRootPath
+ .resolve(s"$sessionUUID/classes/smallClassFile.class")
+ .toFile
assert(movedClassFile.exists())
}
@@ -75,13 +85,14 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
val stagingPath = copyDir.resolve("Hello.class")
val remotePath = Paths.get("classes/Hello.class")
assert(stagingPath.toFile.exists())
- artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)
+ artifactManager.addArtifact(remotePath, stagingPath, None)
- val classFileDirectory = artifactManager.classArtifactDir
- val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
+ val movedClassFile = SparkConnectArtifactManager.artifactRootPath
+ .resolve(s"$sessionUUID/classes/Hello.class")
+ .toFile
assert(movedClassFile.exists())
- val classLoader = SparkConnectArtifactManager.classLoaderWithArtifacts
+ val classLoader = artifactManager.classloader
val instance = classLoader
.loadClass("Hello")
@@ -98,22 +109,26 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
val stagingPath = copyDir.resolve("Hello.class")
val remotePath = Paths.get("classes/Hello.class")
assert(stagingPath.toFile.exists())
- artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)
- val classFileDirectory = artifactManager.classArtifactDir
- val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
- assert(movedClassFile.exists())
+ val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", "session")
+ sessionHolder.addArtifact(remotePath, stagingPath, None)
- val classLoader = SparkConnectArtifactManager.classLoaderWithArtifacts
+ val movedClassFile = SparkConnectArtifactManager.artifactRootPath
+ .resolve(s"${sessionHolder.session.sessionUUID}/classes/Hello.class")
+ .toFile
+ assert(movedClassFile.exists())
+ val classLoader = sessionHolder.classloader
val instance = classLoader
.loadClass("Hello")
.getDeclaredConstructor(classOf[String])
.newInstance("Talon")
.asInstanceOf[String => String]
val udf = org.apache.spark.sql.functions.udf(instance)
- val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session
- session.range(10).select(udf(col("id").cast("string"))).collect()
+
+ sessionHolder.withSession { session =>
+ session.range(10).select(udf(col("id").cast("string"))).collect()
+ }
}
test("add a cache artifact to the Block Manager") {
@@ -125,7 +140,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
val blockManager = spark.sparkContext.env.blockManager
val blockId = CacheId(session.userId, session.sessionId, "abc")
try {
- artifactManager.addArtifact(session, remotePath, stagingPath, None)
+ artifactManager.addArtifact(remotePath, stagingPath, None)
val bytes = blockManager.getLocalBytes(blockId)
assert(bytes.isDefined)
val readback = new String(bytes.get.toByteBuffer().array(), StandardCharsets.UTF_8)
@@ -141,9 +156,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
withTempPath { path =>
val stagingPath = path.toPath
Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
- val session = sessionHolder()
val remotePath = Paths.get("pyfiles/abc.zip")
- artifactManager.addArtifact(session, remotePath, stagingPath, None)
+ artifactManager.addArtifact(remotePath, stagingPath, None)
assert(artifactManager.getSparkConnectPythonIncludes == Seq("abc.zip"))
}
}
@@ -155,10 +169,113 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
val stagingPath = copyDir.resolve("smallClassFile.class")
val remotePath = Paths.get("forward_to_fs", destFSDir.toString, "smallClassFileCopied.class")
assert(stagingPath.toFile.exists())
- artifactManager.uploadArtifactToFs(sessionHolder, remotePath, stagingPath)
- artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)
+ artifactManager.uploadArtifactToFs(remotePath, stagingPath)
+ artifactManager.addArtifact(remotePath, stagingPath, None)
val copiedClassFile = Paths.get(destFSDir.toString, "smallClassFileCopied.class").toFile
assert(copiedClassFile.exists())
}
+
+ test("Removal of resources") {
+ withTempPath { path =>
+ // Setup cache
+ val stagingPath = path.toPath
+ Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
+ val remotePath = Paths.get("cache/abc")
+ val session = sessionHolder()
+ val blockManager = spark.sparkContext.env.blockManager
+ val blockId = CacheId(session.userId, session.sessionId, "abc")
+ // Setup artifact dir
+ val copyDir = Utils.createTempDir().toPath
+ FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
+ try {
+ artifactManager.addArtifact(remotePath, stagingPath, None)
+ val stagingPathFile = copyDir.resolve("smallClassFile.class")
+ val remotePathFile = Paths.get("classes/smallClassFile.class")
+ artifactManager.addArtifact(remotePathFile, stagingPathFile, None)
+
+ // Verify resources exist
+ val bytes = blockManager.getLocalBytes(blockId)
+ assert(bytes.isDefined)
+ blockManager.releaseLock(blockId)
+ val expectedPath = SparkConnectArtifactManager.artifactRootPath
+ .resolve(s"$sessionUUID/classes/smallClassFile.class")
+ assert(expectedPath.toFile.exists())
+
+ // Remove resources
+ artifactManager.cleanUpResources()
+
+ assert(!blockManager.getLocalBytes(blockId).isDefined)
+ assert(!expectedPath.toFile.exists())
+ } finally {
+ try {
+ blockManager.releaseLock(blockId)
+ } catch {
+ case _: SparkException =>
+ case throwable: Throwable => throw throwable
+ } finally {
+ FileUtils.deleteDirectory(copyDir.toFile)
+ blockManager.removeCache(session.userId, session.sessionId)
+ }
+ }
+ }
+ }
+
+ test("Classloaders for spark sessions are isolated") {
+ val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", "session1")
+ val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", "session2")
+
+ def addHelloClass(holder: SessionHolder): Unit = {
+ val copyDir = Utils.createTempDir().toPath
+ FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
+ val stagingPath = copyDir.resolve("Hello.class")
+ val remotePath = Paths.get("classes/Hello.class")
+ assert(stagingPath.toFile.exists())
+ holder.addArtifact(remotePath, stagingPath, None)
+ }
+
+ // Add the classfile only for the first user
+ addHelloClass(holder1)
+
+ val classLoader1 = holder1.classloader
+ val instance1 = classLoader1
+ .loadClass("Hello")
+ .getDeclaredConstructor(classOf[String])
+ .newInstance("Talon")
+ .asInstanceOf[String => String]
+ val udf1 = org.apache.spark.sql.functions.udf(instance1)
+
+ holder1.withSession { session =>
+ session.range(10).select(udf1(col("id").cast("string"))).collect()
+ }
+
+ assertThrows[ClassNotFoundException] {
+ val classLoader2 = holder2.classloader
+ val instance2 = classLoader2
+ .loadClass("Hello")
+ .getDeclaredConstructor(classOf[String])
+ .newInstance("Talon")
+ .asInstanceOf[String => String]
+ }
+ }
+}
+
+class ArtifactUriSuite extends SparkFunSuite with LocalSparkContext {
+
+ private def createSparkContext(): Unit = {
+ resetSparkContext()
+ sc = new SparkContext("local[4]", "test", new SparkConf())
+
+ }
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ createSparkContext()
+ }
+
+ test("Artifact URI is reset when SparkContext is restarted") {
+ val oldUri = SparkConnectArtifactManager.artifactRootURI
+ createSparkContext()
+ val newUri = SparkConnectArtifactManager.artifactRootURI
+ assert(newUri != oldUri)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala
index d87c25c0b7c..3e402b3b330 100644
--- a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala
+++ b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.Serializable
+import java.util.Objects
/**
* Artifact set for a job.
@@ -41,7 +42,7 @@ class JobArtifactSet(
def withActive[T](f: => T): T = JobArtifactSet.withActive(this)(f)
override def hashCode(): Int = {
- Seq(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq).hashCode()
+ Objects.hash(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq)
}
override def equals(obj: Any): Boolean = {
@@ -76,17 +77,17 @@ object JobArtifactSet {
archives = sc.addedArchives.toMap)
}
+ private lazy val emptyJobArtifactSet = new JobArtifactSet(
+ None,
+ None,
+ Map.empty,
+ Map.empty,
+ Map.empty)
+
/**
* Empty artifact set for use in tests.
*/
- private[spark] def apply(): JobArtifactSet = {
- new JobArtifactSet(
- None,
- None,
- Map.empty,
- Map.empty,
- Map.empty)
- }
+ private[spark] def apply(): JobArtifactSet = emptyJobArtifactSet
/**
* Used for testing. Returns artifacts from [[SparkContext]] if one exists or otherwise, an
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 961f02c640e..fe3fe1be429 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -19,7 +19,6 @@ package org.apache.spark
import java.io._
import java.net.URI
-import java.nio.file.Files
import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID}
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference}
@@ -42,7 +41,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.logging.log4j.Level
-import org.apache.spark.annotation.{DeveloperApi, Experimental, Private}
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.executor.{Executor, ExecutorMetrics, ExecutorMetricsSource}
@@ -390,13 +389,6 @@ class SparkContext(config: SparkConf) extends Logging {
Utils.setLogLevel(Level.toLevel(upperCased))
}
- /**
- * :: Private ::
- * Returns the directory that stores artifacts transferred through Spark Connect.
- */
- @Private
- private[spark] lazy val sparkConnectArtifactDirectory: File = Utils.createTempDir("artifacts")
-
try {
_conf = config.clone()
_conf.get(SPARK_LOG_LEVEL).foreach { level =>
@@ -482,18 +474,7 @@ class SparkContext(config: SparkConf) extends Logging {
SparkEnv.set(_env)
// If running the REPL, register the repl's output dir with the file server.
- _conf.getOption("spark.repl.class.outputDir").orElse {
- if (_conf.get(PLUGINS).contains("org.apache.spark.sql.connect.SparkConnectPlugin")) {
- // For Spark Connect, we piggyback on the existing REPL integration to load class
- // files on the executors.
- // This is a temporary intermediate step due to unavailable classloader isolation.
- val classDirectory = sparkConnectArtifactDirectory.toPath.resolve("classes")
- Files.createDirectories(classDirectory)
- Some(classDirectory.toString)
- } else {
- None
- }
- }.foreach { path =>
+ _conf.getOption("spark.repl.class.outputDir").foreach { path =>
val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path))
_conf.set("spark.repl.class.uri", replUri)
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 82d3a28894b..2fce2889c09 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -186,6 +186,17 @@ private[spark] trait RpcEnvFileServer {
*/
def addDirectory(baseUri: String, path: File): String
+ /**
+ * Adds a local directory to be served via this file server.
+ * If the directory is already registered with the file server, it will result in a no-op.
+ *
+ * @param baseUri Leading URI path (files can be retrieved by appending their relative
+ * path to this base URI). This cannot be "files" nor "jars".
+ * @param path Path to the local directory.
+ * @return URI for the root of the directory in the file server.
+ */
+ def addDirectoryIfAbsent(baseUri: String, path: File): String
+
/** Validates and normalizes the base URI for directories. */
protected def validateDirectoryUri(baseUri: String): String = {
val baseCanonicalUri = new URI(baseUri).normalize().getPath
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
index 73eb9a34669..57243133aba 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
@@ -90,4 +90,9 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
s"${rpcEnv.address.toSparkURL}$fixedBaseUri"
}
+ override def addDirectoryIfAbsent(baseUri: String, path: File): String = {
+ val fixedBaseUri = validateDirectoryUri(baseUri)
+ dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path.getCanonicalFile)
+ s"${rpcEnv.address.toSparkURL}$fixedBaseUri"
+ }
}
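One note on the new file-server API: `addDirectoryIfAbsent` is idempotent, whereas `addDirectory` rejects a base URI that is already registered in the Netty implementation; the idempotent variant is what lets `refreshArtifactUri` safely re-register the artifact root after a SparkContext restart. A minimal call-site sketch, mirroring the usage earlier in this diff (`artifactRootPath` as above):

    // Safe to call on every (re)initialisation; re-registering is a no-op.
    val rootUri = SparkEnv.get.rpcEnv.fileServer
      .addDirectoryIfAbsent("artifacts", artifactRootPath.toFile)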