You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/24 12:44:38 UTC

[spark] branch master updated: [SPARK-42533][CONNECT][SCALA] Add ssl for Scala client

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c0d301ea3c3 [SPARK-42533][CONNECT][SCALA] Add ssl for Scala client
c0d301ea3c3 is described below

commit c0d301ea3c3f6e3d1b10373823e0aeeb997e8daf
Author: Zhen Li <zh...@users.noreply.github.com>
AuthorDate: Fri Feb 24 08:44:23 2023 -0400

    [SPARK-42533][CONNECT][SCALA] Add ssl for Scala client
    
    ### What changes were proposed in this pull request?
    Add SSL encryption and access token support to the Scala client.
    
    ### Why are the changes needed?
    To support basic client-side encryption that protects data sent over the network.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests. Manual tests.
    
    Closes #40133 from zhenlineo/ssl.
    
    Authored-by: Zhen Li <zh...@users.noreply.github.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |   7 +-
 .../sql/connect/client/SparkConnectClient.scala    | 193 +++++++++++++++++++--
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   2 +-
 .../connect/client/SparkConnectClientSuite.scala   |  76 ++++----
 4 files changed, 229 insertions(+), 49 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index b086db09365..0e5aaace20d 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -224,7 +224,12 @@ object SparkSession extends Logging {
   class Builder() extends Logging {
     private var _client: SparkConnectClient = _
 
-    def client(client: SparkConnectClient): Builder = {
+    def remote(connectionString: String): Builder = {
+      val connectClient = SparkConnectClient.builder().connectionString(connectionString).build()
+      client(connectClient) // `client` stores the client and returns this builder
+    }
+
+    private[sql] def client(client: SparkConnectClient): Builder = {
       _client = client
       this
     }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 3049a0a0a5d..12bb581880c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -17,21 +17,22 @@
 
 package org.apache.spark.sql.connect.client
 
-import scala.language.existentials
-
-import io.grpc.{ManagedChannel, ManagedChannelBuilder}
+import io.grpc.{CallCredentials, CallOptions, Channel, ClientCall, ClientInterceptor, CompositeChannelCredentials, ForwardingClientCall, Grpc, InsecureChannelCredentials, ManagedChannel, Metadata, MethodDescriptor, Status, TlsChannelCredentials}
 import java.net.URI
 import java.util.UUID
+import java.util.concurrent.Executor
 
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.UserContext
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 
 /**
  * Conceptually the remote spark session that communicates with the server.
  */
-class SparkConnectClient(
+private[sql] class SparkConnectClient(
     private val userContext: proto.UserContext,
-    private val channel: ManagedChannel) {
+    private val channel: ManagedChannel,
+    private[client] val userAgent: String) {
 
   private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
 
@@ -40,7 +41,7 @@ class SparkConnectClient(
    * @return
    *   User ID.
    */
-  def userId: String = userContext.getUserId()
+  private[client] def userId: String = userContext.getUserId()
 
   // Generate a unique session ID for this client. This UUID must be unique to allow
   // concurrent Spark sessions of the same user. If the channel is closed, creating
@@ -60,6 +61,8 @@ class SparkConnectClient(
       .newBuilder()
       .setPlan(plan)
       .setUserContext(userContext)
+      .setClientId(sessionId)
+      .setClientType(userAgent)
       .build()
     stub.executePlan(request)
   }
@@ -77,6 +80,7 @@ class SparkConnectClient(
       .setExplain(proto.Explain.newBuilder().setExplainMode(mode))
       .setUserContext(userContext)
       .setClientId(sessionId)
+      .setClientType(userAgent)
       .build()
     analyze(request)
   }
@@ -89,7 +93,21 @@ class SparkConnectClient(
   }
 }
 
-object SparkConnectClient {
+private[sql] object SparkConnectClient {
+  // Client type reported on requests when no user_agent has been configured.
+  private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA"
+  // Header carrying "Bearer <token>". NOTE(review): the standard HTTP name is "Authorization".
+  private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
+    Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)
+  // Error message used when a token is combined with use_ssl=false.
+  private val AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG: String =
+    "Authentication token cannot be passed over insecure connections. " +
+      "Either remove 'token' or set 'use_ssl=true'"
+
+  // Test-only factory: wraps a caller-supplied channel and uses DEFAULT_USER_AGENT.
+  def apply(userContext: UserContext, channel: ManagedChannel): SparkConnectClient =
+    new SparkConnectClient(userContext, channel, DEFAULT_USER_AGENT)
+
   def builder(): Builder = new Builder()
 
   /**
@@ -98,14 +116,27 @@ object SparkConnectClient {
    */
   class Builder() {
     private val userContextBuilder = proto.UserContext.newBuilder()
+    private var userAgent: Option[String] = None
+
     private var host: String = "localhost"
     private var port: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT
 
+    private var token: Option[String] = None
+    // If no value specified for isSslEnabled, default to false
+    private var isSslEnabled: Option[Boolean] = None
+
+    private var metadata: Map[String, String] = Map.empty
+
     def userId(id: String): Builder = {
       userContextBuilder.setUserId(id)
       this
     }
 
+    def userName(name: String): Builder = {
+      userContextBuilder.setUserName(name) // recorded in the UserContext attached to requests
+      this
+    }
+
     def host(inputHost: String): Builder = {
       require(inputHost != null)
       host = inputHost
@@ -117,10 +148,58 @@ object SparkConnectClient {
       this
     }
 
+    /**
+     * Sets the access token. Setting the token implicitly sets use_ssl=true, so all the
+     * following examples yield the same results:
+     *
+     * {{{
+     * sc://localhost/;token=aaa
+     * sc://localhost/;use_ssl=true;token=aaa
+     * sc://localhost/;token=aaa;use_ssl=true
+     * }}}
+     *
+     * Throws IllegalArgumentException if the token is null/empty or use_ssl=false was set.
+     *
+     * @param inputToken
+     *   the user token.
+     * @return
+     *   this builder.
+     */
+    def token(inputToken: String): Builder = {
+      require(inputToken != null && inputToken.nonEmpty)
+      // Validate against an explicit use_ssl=false before storing the token; enable SSL if unset.
+      isSslEnabled match {
+        case None => isSslEnabled = Some(true)
+        case Some(false) =>
+          throw new IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
+        case Some(true) => // Good, the ssl is enabled
+      }
+      token = Some(inputToken)
+      this
+    }
+
+    def enableSsl(): Builder = {
+      isSslEnabled = Some(true) // build() will then use TLS channel credentials
+      this
+    }
+
+    /**
+     * Disables SSL. Throws IllegalArgumentException if a token has already been set.
+     *
+     * @return
+     *   this builder.
+     */
+    def disableSsl(): Builder = {
+      require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
+      isSslEnabled = Some(false)
+      this
+    }
+
     private object URIParams {
       val PARAM_USER_ID = "user_id"
       val PARAM_USE_SSL = "use_ssl"
       val PARAM_TOKEN = "token"
+      val PARAM_USER_AGENT = "user_agent"
     }
 
     private def verifyURI(uri: URI): Unit = {
@@ -146,6 +225,12 @@ object SparkConnectClient {
       }
     }
 
+    def userAgent(value: String): Builder = {
+      require(value != null)
+      userAgent = Some(value) // sent as the request client type; defaults to DEFAULT_USER_AGENT
+      this
+    }
+
     private def parseURIParams(uri: URI): Unit = {
       val params = uri.getPath.split(';').drop(1).filter(_ != "")
       params.foreach { kv =>
@@ -158,13 +243,13 @@ object SparkConnectClient {
           }
           (arr(0), arr(1))
         }
-        if (key == URIParams.PARAM_USER_ID) {
-          userContextBuilder.setUserId(value)
-        } else {
-          // TODO(SPARK-41917): Support SSL and Auth tokens.
-          throw new UnsupportedOperationException(
-            "Parameters apart from user_id" +
-              " are currently unsupported.")
+        key match {
+          case URIParams.PARAM_USER_ID => userId(value)
+          case URIParams.PARAM_USER_AGENT => userAgent(value)
+          case URIParams.PARAM_TOKEN => token(value)
+          case URIParams.PARAM_USE_SSL =>
+            if (java.lang.Boolean.valueOf(value)) enableSsl() else disableSsl()
+          case _ => this.metadata = this.metadata + (key -> value)
         }
       }
     }
@@ -176,7 +261,6 @@ object SparkConnectClient {
      * Note: The connection string, if used, will override any previous host/port settings.
      */
     def connectionString(connectionString: String): Builder = {
-      // TODO(SPARK-41917): Support SSL and Auth tokens.
       val uri = new URI(connectionString)
       verifyURI(uri)
       parseURIParams(uri)
@@ -189,9 +273,84 @@ object SparkConnectClient {
     }
 
     def build(): SparkConnectClient = {
-      val channelBuilder = ManagedChannelBuilder.forAddress(host, port).usePlaintext()
+      val creds = isSslEnabled match {
+        case Some(false) | None => InsecureChannelCredentials.create()
+        case Some(true) =>
+          token match {
+            case Some(t) =>
+              // With access token added in the http header.
+              CompositeChannelCredentials.create(
+                TlsChannelCredentials.create,
+                new AccessTokenCallCredentials(t))
+            case None =>
+              TlsChannelCredentials.create
+          }
+      }
+
+      val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds)
+      if (metadata.nonEmpty) {
+        channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
+      }
       val channel: ManagedChannel = channelBuilder.build()
-      new SparkConnectClient(userContextBuilder.build(), channel)
+      new SparkConnectClient(
+        userContextBuilder.build(),
+        channel,
+        userAgent.getOrElse(DEFAULT_USER_AGENT))
+    }
+  }
+
+  /**
+   * A [[CallCredentials]] that attaches an access token to every call.
+   *
+   * @param token
+   *   the raw access token, without a "Bearer " prefix. It is sent in the request header
+   *   named by AUTH_TOKEN_META_DATA_KEY with the value "Bearer <token>".
+   */
+  private[client] class AccessTokenCallCredentials(token: String) extends CallCredentials {
+    override def applyRequestMetadata(
+        requestInfo: CallCredentials.RequestInfo,
+        appExecutor: Executor,
+        applier: CallCredentials.MetadataApplier): Unit = {
+      appExecutor.execute(() => {
+        try {
+          val headers = new Metadata()
+          headers.put(AUTH_TOKEN_META_DATA_KEY, s"Bearer $token")
+          applier.apply(headers)
+        } catch {
+          case scala.util.control.NonFatal(e) => // fatal errors (e.g. OOM) still propagate
+            applier.fail(Status.UNAUTHENTICATED.withCause(e))
+        }
+      })
+    }
+
+    override def thisUsesUnstableApi(): Unit = {
+      // Required by the experimental CallCredentials API; intentionally left empty.
+    }
+  }
+
+  /**
+   * A client interceptor that adds the given extra entries to the headers of every call.
+   *
+   * @param metadata
+   *   extra header entries (name -> ASCII value), for example "key: value".
+   */
+  private[client] class MetadataHeaderClientInterceptor(metadata: Map[String, String])
+      extends ClientInterceptor {
+    override def interceptCall[ReqT, RespT](
+        method: MethodDescriptor[ReqT, RespT],
+        callOptions: CallOptions,
+        next: Channel): ClientCall[ReqT, RespT] = {
+      new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
+        next.newCall(method, callOptions)) {
+        override def start(
+            responseListener: ClientCall.Listener[RespT],
+            headers: Metadata): Unit = {
+          metadata.foreach { case (key, value) =>
+            headers.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value)
+          }
+          super.start(responseListener, headers)
+        }
+      }
     }
   }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 6a54cc88aec..b759471e777 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -99,7 +99,7 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit
 
   override protected def beforeAll(): Unit = {
     super.beforeAll()
-    val client = new SparkConnectClient(
+    val client = SparkConnectClient(
       proto.UserContext.newBuilder().build(),
       InProcessChannelBuilder.forName("/dev/null").build())
     val builder = SparkSession.builder().client(client)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 908eddbe7bf..98dacbcab89 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.connect.client
 
 import java.util.concurrent.TimeUnit
 
-import io.grpc.Server
+import io.grpc.{Server, StatusRuntimeException}
 import io.grpc.netty.NettyServerBuilder
 import io.grpc.stub.StreamObserver
 import org.scalatest.BeforeAndAfterEach
@@ -65,10 +65,11 @@ class SparkConnectClientSuite
     assert(client.userId == "abc123")
   }
 
-  private def testClientConnection(
-      client: SparkConnectClient,
-      serverPort: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT): Unit = {
+  // Use 0 to start the server at a random port
+  private def testClientConnection(serverPort: Int = 0)(
+      clientBuilder: Int => SparkConnectClient): Unit = {
     startDummyServer(serverPort)
+    client = clientBuilder(server.getPort)
     val request = AnalyzePlanRequest
       .newBuilder()
       .setClientId("abc123")
@@ -79,15 +80,28 @@ class SparkConnectClientSuite
   }
 
   test("Test connection") {
-    val testPort = 16001
-    client = SparkConnectClient.builder().port(testPort).build()
-    testClientConnection(client, testPort)
+    testClientConnection() { testPort => SparkConnectClient.builder().port(testPort).build() }
   }
 
   test("Test connection string") {
-    val testPort = 16000
-    client = SparkConnectClient.builder().connectionString("sc://localhost:16000").build()
-    testClientConnection(client, testPort)
+    testClientConnection() { testPort =>
+      SparkConnectClient.builder().connectionString(s"sc://localhost:$testPort").build()
+    }
+  }
+
+  test("Test encryption") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}/;use_ssl=true")
+      .build()
+
+    val request = AnalyzePlanRequest.newBuilder().setClientId("abc123").build()
+
+    // Fails the SSL handshake because the dummy server has no server credentials installed.
+    assertThrows[StatusRuntimeException] {
+      client.analyze(request)
+    }
   }
 
   private case class TestPackURI(
@@ -97,17 +111,27 @@ class SparkConnectClientSuite
 
   private val URIs = Seq[TestPackURI](
     TestPackURI("sc://host", isCorrect = true),
-    TestPackURI("sc://localhost/", isCorrect = true, client => testClientConnection(client)),
+    TestPackURI(
+      "sc://localhost/",
+      isCorrect = true,
+      client => testClientConnection(ConnectCommon.CONNECT_GRPC_BINDING_PORT)(_ => client)),
     TestPackURI(
       "sc://localhost:1234/",
       isCorrect = true,
-      client => testClientConnection(client, 1234)),
-    TestPackURI("sc://localhost/;", isCorrect = true, client => testClientConnection(client)),
+      client => testClientConnection(1234)(_ => client)),
+    TestPackURI(
+      "sc://localhost/;",
+      isCorrect = true,
+      client => testClientConnection(ConnectCommon.CONNECT_GRPC_BINDING_PORT)(_ => client)),
     TestPackURI("sc://host:123", isCorrect = true),
     TestPackURI(
       "sc://host:123/;user_id=a94",
       isCorrect = true,
       client => assert(client.userId == "a94")),
+    TestPackURI(
+      "sc://host:123/;user_agent=a945",
+      isCorrect = true,
+      client => assert(client.userAgent == "a945")),
     TestPackURI("scc://host:12", isCorrect = false),
     TestPackURI("http://host", isCorrect = false),
     TestPackURI("sc:/host:1234/path", isCorrect = false),
@@ -116,7 +140,15 @@ class SparkConnectClientSuite
     TestPackURI("sc://host:123;user_id=a94", isCorrect = false),
     TestPackURI("sc:///user_id=123", isCorrect = false),
     TestPackURI("sc://host:-4", isCorrect = false),
-    TestPackURI("sc://:123/", isCorrect = false))
+    TestPackURI("sc://:123/", isCorrect = false),
+    TestPackURI("sc://host:123/;use_ssl=true", isCorrect = true),
+    TestPackURI("sc://host:123/;token=mySecretToken", isCorrect = true),
+    TestPackURI("sc://host:123/;token=", isCorrect = false),
+    TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true),
+    TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true),
+    TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false),
+    TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = false),
+    TestPackURI("sc://host:123/;param1=value1;param2=value2", isCorrect = true))
 
   private def checkTestPack(testPack: TestPackURI): Unit = {
     val client = SparkConnectClient.builder().connectionString(testPack.connectionString).build()
@@ -132,22 +164,6 @@ class SparkConnectClientSuite
       }
     }
   }
-
-  // TODO(SPARK-41917): Remove test once SSL and Auth tokens are supported.
-  test("Non user-id parameters throw unsupported errors") {
-    assertThrows[UnsupportedOperationException] {
-      SparkConnectClient.builder().connectionString("sc://host/;use_ssl=true").build()
-    }
-
-    assertThrows[UnsupportedOperationException] {
-      SparkConnectClient.builder().connectionString("sc://host/;token=abc").build()
-    }
-
-    assertThrows[UnsupportedOperationException] {
-      SparkConnectClient.builder().connectionString("sc://host/;xyz=abc").build()
-
-    }
-  }
 }
 
 class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org