You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/24 12:44:38 UTC
[spark] branch master updated: [SPARK-42533][CONNECT][SCALA] Add ssl for Scala client
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c0d301ea3c3 [SPARK-42533][CONNECT][SCALA] Add ssl for Scala client
c0d301ea3c3 is described below
commit c0d301ea3c3f6e3d1b10373823e0aeeb997e8daf
Author: Zhen Li <zh...@users.noreply.github.com>
AuthorDate: Fri Feb 24 08:44:23 2023 -0400
[SPARK-42533][CONNECT][SCALA] Add ssl for Scala client
### What changes were proposed in this pull request?
Adding SSL encryption and access token support for Scala client
### Why are the changes needed?
To support basic client-side encryption that protects data sent over the network.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests. Manual tests.
Closes #40133 from zhenlineo/ssl.
Authored-by: Zhen Li <zh...@users.noreply.github.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 7 +-
.../sql/connect/client/SparkConnectClient.scala | 193 +++++++++++++++++++--
.../apache/spark/sql/PlanGenerationTestSuite.scala | 2 +-
.../connect/client/SparkConnectClientSuite.scala | 76 ++++----
4 files changed, 229 insertions(+), 49 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index b086db09365..0e5aaace20d 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -224,7 +224,12 @@ object SparkSession extends Logging {
class Builder() extends Logging {
private var _client: SparkConnectClient = _
- def client(client: SparkConnectClient): Builder = {
+    /**
+     * Builds a `SparkConnectClient` from the given connection string and installs it on this
+     * builder.
+     *
+     * @param connectionString
+     *   the connection string passed to `SparkConnectClient.Builder.connectionString`.
+     * @return
+     *   this builder.
+     */
+    def remote(connectionString: String): Builder = {
+      val connectClient =
+        SparkConnectClient.builder().connectionString(connectionString).build()
+      // client(...) stores the client and already returns this builder.
+      client(connectClient)
+    }
+
+ private[sql] def client(client: SparkConnectClient): Builder = {
_client = client
this
}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 3049a0a0a5d..12bb581880c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -17,21 +17,22 @@
package org.apache.spark.sql.connect.client
-import scala.language.existentials
-
-import io.grpc.{ManagedChannel, ManagedChannelBuilder}
+import io.grpc.{CallCredentials, CallOptions, Channel, ClientCall, ClientInterceptor, CompositeChannelCredentials, ForwardingClientCall, Grpc, InsecureChannelCredentials, ManagedChannel, Metadata, MethodDescriptor, Status, TlsChannelCredentials}
import java.net.URI
import java.util.UUID
+import java.util.concurrent.Executor
import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.UserContext
import org.apache.spark.sql.connect.common.config.ConnectCommon
/**
* Conceptually the remote spark session that communicates with the server.
*/
-class SparkConnectClient(
+private[sql] class SparkConnectClient(
private val userContext: proto.UserContext,
- private val channel: ManagedChannel) {
+ private val channel: ManagedChannel,
+ private[client] val userAgent: String) {
private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
@@ -40,7 +41,7 @@ class SparkConnectClient(
* @return
* User ID.
*/
- def userId: String = userContext.getUserId()
+ private[client] def userId: String = userContext.getUserId()
// Generate a unique session ID for this client. This UUID must be unique to allow
// concurrent Spark sessions of the same user. If the channel is closed, creating
@@ -60,6 +61,8 @@ class SparkConnectClient(
.newBuilder()
.setPlan(plan)
.setUserContext(userContext)
+ .setClientId(sessionId)
+ .setClientType(userAgent)
.build()
stub.executePlan(request)
}
@@ -77,6 +80,7 @@ class SparkConnectClient(
.setExplain(proto.Explain.newBuilder().setExplainMode(mode))
.setUserContext(userContext)
.setClientId(sessionId)
+ .setClientType(userAgent)
.build()
analyze(request)
}
@@ -89,7 +93,21 @@ class SparkConnectClient(
}
}
-object SparkConnectClient {
+private[sql] object SparkConnectClient {
+
+ private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA"
+
+ private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
+ Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+
+ private val AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG: String =
+ "Authentication token cannot be passed over insecure connections. " +
+ "Either remove 'token' or set 'use_ssl=true'"
+
+  /**
+   * Creates a client on top of an existing channel, using the default user agent. For internal
+   * tests.
+   *
+   * @param userContext
+   *   the user context attached to the client.
+   * @param channel
+   *   the channel the client uses.
+   * @return
+   *   a client whose user agent is `DEFAULT_USER_AGENT`.
+   */
+  def apply(userContext: UserContext, channel: ManagedChannel): SparkConnectClient = {
+    new SparkConnectClient(userContext, channel, userAgent = DEFAULT_USER_AGENT)
+  }
+
def builder(): Builder = new Builder()
/**
@@ -98,14 +116,27 @@ object SparkConnectClient {
*/
class Builder() {
private val userContextBuilder = proto.UserContext.newBuilder()
+ private var userAgent: Option[String] = None
+
private var host: String = "localhost"
private var port: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT
+ private var token: Option[String] = None
+ // If no value is specified for isSslEnabled, SSL defaults to disabled.
+ private var isSslEnabled: Option[Boolean] = None
+
+ private var metadata: Map[String, String] = Map.empty
+
def userId(id: String): Builder = {
userContextBuilder.setUserId(id)
this
}
+    /**
+     * Sets the user name on the user context that `build()` hands to the client.
+     *
+     * @param name
+     *   the user name.
+     * @return
+     *   this builder.
+     */
+    def userName(name: String): Builder = {
+      userContextBuilder.setUserName(name)
+      this
+    }
+
def host(inputHost: String): Builder = {
require(inputHost != null)
host = inputHost
@@ -117,10 +148,58 @@ object SparkConnectClient {
this
}
+    /**
+     * Sets the access token used to authenticate the client.
+     *
+     * Setting the token implicitly sets use_ssl=true unless SSL has already been configured.
+     * All the following examples yield the same result:
+     *
+     * {{{
+     * sc://localhost/;token=aaa
+     * sc://localhost/;use_ssl=true;token=aaa
+     * sc://localhost/;token=aaa;use_ssl=true
+     * }}}
+     *
+     * @param inputToken
+     *   the user token; must be neither null nor empty.
+     * @return
+     *   this builder.
+     * @throws IllegalArgumentException
+     *   if the token is null or empty, or if SSL has been explicitly disabled (use_ssl=false).
+     *   In both cases the builder is left unchanged.
+     */
+    def token(inputToken: String): Builder = {
+      require(
+        inputToken != null && inputToken.nonEmpty,
+        "The authentication token must be a non-empty string.")
+      // Validate before mutating, so that a rejected call does not leave a token behind on a
+      // builder that is configured for an insecure connection.
+      if (isSslEnabled.contains(false)) {
+        throw new IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
+      }
+      token = Some(inputToken)
+      // Either SSL was unset (a token implies SSL) or it was already enabled.
+      isSslEnabled = Some(true)
+      this
+    }
+
+    /**
+     * Enables SSL, so that `build()` creates the channel with TLS credentials.
+     *
+     * @return
+     *   this builder.
+     */
+    def enableSsl(): Builder = {
+      isSslEnabled = Some(true)
+      this
+    }
+
+    /**
+     * Disables SSL, so that `build()` creates the channel with insecure credentials.
+     *
+     * A token is never sent over an insecure connection, so this is rejected once a token has
+     * been set.
+     *
+     * @return
+     *   this builder.
+     * @throws IllegalArgumentException
+     *   if a token has already been set.
+     */
+    def disableSsl(): Builder = {
+      require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
+      isSslEnabled = Some(false)
+      this
+    }
+
private object URIParams {
val PARAM_USER_ID = "user_id"
val PARAM_USE_SSL = "use_ssl"
val PARAM_TOKEN = "token"
+ val PARAM_USER_AGENT = "user_agent"
}
private def verifyURI(uri: URI): Unit = {
@@ -146,6 +225,12 @@ object SparkConnectClient {
}
}
+    /**
+     * Sets the user agent of the client. When it is never set, `build()` falls back to
+     * `DEFAULT_USER_AGENT`.
+     *
+     * @param value
+     *   the user agent; must not be null.
+     * @return
+     *   this builder.
+     */
+    def userAgent(value: String): Builder = {
+      require(value != null)
+      // value is known to be non-null here, so this is always a Some.
+      userAgent = Option(value)
+      this
+    }
+
private def parseURIParams(uri: URI): Unit = {
val params = uri.getPath.split(';').drop(1).filter(_ != "")
params.foreach { kv =>
@@ -158,13 +243,13 @@ object SparkConnectClient {
}
(arr(0), arr(1))
}
- if (key == URIParams.PARAM_USER_ID) {
- userContextBuilder.setUserId(value)
- } else {
- // TODO(SPARK-41917): Support SSL and Auth tokens.
- throw new UnsupportedOperationException(
- "Parameters apart from user_id" +
- " are currently unsupported.")
+ key match {
+ case URIParams.PARAM_USER_ID => userId(value)
+ case URIParams.PARAM_USER_AGENT => userAgent(value)
+ case URIParams.PARAM_TOKEN => token(value)
+ case URIParams.PARAM_USE_SSL =>
+ if (java.lang.Boolean.valueOf(value)) enableSsl() else disableSsl()
+ case _ => this.metadata = this.metadata + (key -> value)
}
}
}
@@ -176,7 +261,6 @@ object SparkConnectClient {
* Note: The connection string, if used, will override any previous host/port settings.
*/
def connectionString(connectionString: String): Builder = {
- // TODO(SPARK-41917): Support SSL and Auth tokens.
val uri = new URI(connectionString)
verifyURI(uri)
parseURIParams(uri)
@@ -189,9 +273,84 @@ object SparkConnectClient {
}
def build(): SparkConnectClient = {
- val channelBuilder = ManagedChannelBuilder.forAddress(host, port).usePlaintext()
+ val creds = isSslEnabled match {
+ case Some(false) | None => InsecureChannelCredentials.create()
+ case Some(true) =>
+ token match {
+ case Some(t) =>
+ // With access token added in the http header.
+ CompositeChannelCredentials.create(
+ TlsChannelCredentials.create,
+ new AccessTokenCallCredentials(t))
+ case None =>
+ TlsChannelCredentials.create
+ }
+ }
+
+ val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds)
+ if (metadata.nonEmpty) {
+ channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
+ }
val channel: ManagedChannel = channelBuilder.build()
- new SparkConnectClient(userContextBuilder.build(), channel)
+ new SparkConnectClient(
+ userContextBuilder.build(),
+ channel,
+ userAgent.getOrElse(DEFAULT_USER_AGENT))
+ }
+ }
+
+  /**
+   * A [[CallCredentials]] created from an access token.
+   *
+   * For every call the token is sent as the value `Bearer <token>` under the metadata key
+   * `AUTH_TOKEN_META_DATA_KEY`.
+   *
+   * NOTE(review): that key is declared with the name "Authentication" rather than the
+   * conventional "Authorization" header name; confirm this is what the receiving side expects.
+   *
+   * @param token
+   *   the raw access token; the "Bearer " prefix is added by this class.
+   */
+  private[client] class AccessTokenCallCredentials(token: String) extends CallCredentials {
+    override def applyRequestMetadata(
+        requestInfo: CallCredentials.RequestInfo,
+        appExecutor: Executor,
+        applier: CallCredentials.MetadataApplier): Unit = {
+      appExecutor.execute(new Runnable {
+        override def run(): Unit = {
+          try {
+            val authHeaders = new Metadata()
+            authHeaders.put(AUTH_TOKEN_META_DATA_KEY, s"Bearer $token")
+            applier.apply(authHeaders)
+          } catch {
+            // Any failure while building or applying the headers is reported to the applier
+            // as UNAUTHENTICATED instead of being thrown on the executor thread.
+            case cause: Throwable =>
+              applier.fail(Status.UNAUTHENTICATED.withCause(cause))
+          }
+        }
+      })
+    }
+
+    override def thisUsesUnstableApi(): Unit = {
+      // Marks this API as not stable. Left empty on purpose.
+    }
+  }
+
+  /**
+   * A client interceptor that adds extra parameters to the headers of every call.
+   *
+   * The header keys are built once, when the interceptor is created, instead of on every call.
+   * A name that `Metadata.Key.of` rejects therefore fails at construction time rather than when
+   * the first call is started.
+   *
+   * @param metadata
+   *   extra metadata placed in the request headers, for example "key: value". Values are sent
+   *   with the ASCII string marshaller.
+   */
+  private[client] class MetadataHeaderClientInterceptor(metadata: Map[String, String])
+      extends ClientInterceptor {
+
+    // Header entries with their keys already converted, in the iteration order of `metadata`.
+    private val headerEntries: Seq[(Metadata.Key[String], String)] =
+      metadata.toSeq.map { case (name, value) =>
+        Metadata.Key.of(name, Metadata.ASCII_STRING_MARSHALLER) -> value
+      }
+
+    override def interceptCall[ReqT, RespT](
+        method: MethodDescriptor[ReqT, RespT],
+        callOptions: CallOptions,
+        next: Channel): ClientCall[ReqT, RespT] = {
+      new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
+        next.newCall(method, callOptions)) {
+        override def start(
+            responseListener: ClientCall.Listener[RespT],
+            headers: Metadata): Unit = {
+          // Add the extra headers before handing the call over to the wrapped call.
+          headerEntries.foreach { case (key, value) => headers.put(key, value) }
+          super.start(responseListener, headers)
+        }
+      }
     }
   }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 6a54cc88aec..b759471e777 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -99,7 +99,7 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit
override protected def beforeAll(): Unit = {
super.beforeAll()
- val client = new SparkConnectClient(
+ val client = SparkConnectClient(
proto.UserContext.newBuilder().build(),
InProcessChannelBuilder.forName("/dev/null").build())
val builder = SparkSession.builder().client(client)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 908eddbe7bf..98dacbcab89 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.connect.client
import java.util.concurrent.TimeUnit
-import io.grpc.Server
+import io.grpc.{Server, StatusRuntimeException}
import io.grpc.netty.NettyServerBuilder
import io.grpc.stub.StreamObserver
import org.scalatest.BeforeAndAfterEach
@@ -65,10 +65,11 @@ class SparkConnectClientSuite
assert(client.userId == "abc123")
}
- private def testClientConnection(
- client: SparkConnectClient,
- serverPort: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT): Unit = {
+ // Use 0 to start the server at a random port
+ private def testClientConnection(serverPort: Int = 0)(
+ clientBuilder: Int => SparkConnectClient): Unit = {
startDummyServer(serverPort)
+ client = clientBuilder(server.getPort)
val request = AnalyzePlanRequest
.newBuilder()
.setClientId("abc123")
@@ -79,15 +80,28 @@ class SparkConnectClientSuite
}
test("Test connection") {
- val testPort = 16001
- client = SparkConnectClient.builder().port(testPort).build()
- testClientConnection(client, testPort)
+ testClientConnection() { testPort => SparkConnectClient.builder().port(testPort).build() }
}
test("Test connection string") {
- val testPort = 16000
- client = SparkConnectClient.builder().connectionString("sc://localhost:16000").build()
- testClientConnection(client, testPort)
+ testClientConnection() { testPort =>
+ SparkConnectClient.builder().connectionString(s"sc://localhost:$testPort").build()
+ }
+ }
+
+  test("Test encryption") {
+    // Port 0 starts the dummy server on a random port.
+    startDummyServer(0)
+    val sslConnectionString = s"sc://localhost:${server.getPort}/;use_ssl=true"
+    client = SparkConnectClient.builder().connectionString(sslConnectionString).build()
+
+    val analyzeRequest = AnalyzePlanRequest.newBuilder().setClientId("abc123").build()
+
+    // The dummy server has no server credentials installed, so the SSL handshake fails.
+    assertThrows[StatusRuntimeException] {
+      client.analyze(analyzeRequest)
+    }
+  }
private case class TestPackURI(
@@ -97,17 +111,27 @@ class SparkConnectClientSuite
private val URIs = Seq[TestPackURI](
TestPackURI("sc://host", isCorrect = true),
- TestPackURI("sc://localhost/", isCorrect = true, client => testClientConnection(client)),
+ TestPackURI(
+ "sc://localhost/",
+ isCorrect = true,
+ client => testClientConnection(ConnectCommon.CONNECT_GRPC_BINDING_PORT)(_ => client)),
TestPackURI(
"sc://localhost:1234/",
isCorrect = true,
- client => testClientConnection(client, 1234)),
- TestPackURI("sc://localhost/;", isCorrect = true, client => testClientConnection(client)),
+ client => testClientConnection(1234)(_ => client)),
+ TestPackURI(
+ "sc://localhost/;",
+ isCorrect = true,
+ client => testClientConnection(ConnectCommon.CONNECT_GRPC_BINDING_PORT)(_ => client)),
TestPackURI("sc://host:123", isCorrect = true),
TestPackURI(
"sc://host:123/;user_id=a94",
isCorrect = true,
client => assert(client.userId == "a94")),
+ TestPackURI(
+ "sc://host:123/;user_agent=a945",
+ isCorrect = true,
+ client => assert(client.userAgent == "a945")),
TestPackURI("scc://host:12", isCorrect = false),
TestPackURI("http://host", isCorrect = false),
TestPackURI("sc:/host:1234/path", isCorrect = false),
@@ -116,7 +140,15 @@ class SparkConnectClientSuite
TestPackURI("sc://host:123;user_id=a94", isCorrect = false),
TestPackURI("sc:///user_id=123", isCorrect = false),
TestPackURI("sc://host:-4", isCorrect = false),
- TestPackURI("sc://:123/", isCorrect = false))
+ TestPackURI("sc://:123/", isCorrect = false),
+ TestPackURI("sc://host:123/;use_ssl=true", isCorrect = true),
+ TestPackURI("sc://host:123/;token=mySecretToken", isCorrect = true),
+ TestPackURI("sc://host:123/;token=", isCorrect = false),
+ TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true),
+ TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true),
+ TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false),
+ TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = false),
+ TestPackURI("sc://host:123/;param1=value1;param2=value2", isCorrect = true))
private def checkTestPack(testPack: TestPackURI): Unit = {
val client = SparkConnectClient.builder().connectionString(testPack.connectionString).build()
@@ -132,22 +164,6 @@ class SparkConnectClientSuite
}
}
}
-
- // TODO(SPARK-41917): Remove test once SSL and Auth tokens are supported.
- test("Non user-id parameters throw unsupported errors") {
- assertThrows[UnsupportedOperationException] {
- SparkConnectClient.builder().connectionString("sc://host/;use_ssl=true").build()
- }
-
- assertThrows[UnsupportedOperationException] {
- SparkConnectClient.builder().connectionString("sc://host/;token=abc").build()
- }
-
- assertThrows[UnsupportedOperationException] {
- SparkConnectClient.builder().connectionString("sc://host/;xyz=abc").build()
-
- }
- }
}
class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org