Posted to commits@kudu.apache.org by ab...@apache.org on 2021/04/23 13:42:19 UTC

[kudu] branch master updated: [spark] KUDU-1884 Add custom SASL protocol name

This is an automated email from the ASF dual-hosted git repository.

abukor pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kudu.git


The following commit(s) were added to refs/heads/master by this push:
     new dc5b5bd  [spark] KUDU-1884 Add custom SASL protocol name
dc5b5bd is described below

commit dc5b5bd899755faa506363bd00d3bbbac8d594d3
Author: Attila Bukor <ab...@apache.org>
AuthorDate: Tue Apr 20 18:35:16 2021 +0200

    [spark] KUDU-1884 Add custom SASL protocol name
    
    The Java client already supports setting a custom SASL protocol name
    on a KuduClient or AsyncKuduClient instance, which is needed when the
    cluster uses a non-default service principal name. This patch exposes
    the setting in KuduContext and DefaultSource.
    
    Change-Id: Ifd0dba4f829f369c363cc89bb58650249035f356
    Reviewed-on: http://gerrit.cloudera.org:8080/17328
    Tested-by: Attila Bukor <ab...@apache.org>
    Reviewed-by: Alexey Serbin <as...@cloudera.com>
    Reviewed-by: Grant Henke <gr...@apache.org>
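
A minimal usage sketch of the new option, assuming placeholder master
addresses ("master1:7051,master2:7051"), a placeholder table name
("my_table"), and a placeholder SASL protocol name ("kuduprod") that
matches the primary of the cluster's Kudu service principal:

    import org.apache.spark.sql.SparkSession
    import org.apache.kudu.spark.kudu.KuduContext

    val spark = SparkSession.builder().appName("kudu-sasl-example").getOrCreate()

    // Read through the Spark data source; "kudu.saslProtocolName" should match
    // the first component of the Kudu service principal (default: "kudu").
    val df = spark.read
      .format("kudu")
      .option("kudu.master", "master1:7051,master2:7051")
      .option("kudu.table", "my_table")
      .option("kudu.saslProtocolName", "kuduprod")
      .load()

    // Or pass the name directly to KuduContext as the new optional
    // fourth constructor argument.
    val kuduContext = new KuduContext(
      "master1:7051,master2:7051",
      spark.sparkContext,
      None,              // socketReadTimeoutMs
      Some("kuduprod"))  // saslProtocolName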
---
 .../org/apache/kudu/spark/kudu/DefaultSource.scala | 15 ++++++++++++--
 .../org/apache/kudu/spark/kudu/KuduContext.scala   | 16 +++++++++++----
 .../apache/kudu/spark/kudu/DefaultSourceTest.scala | 23 +++++++++++++++++++++-
 .../org/apache/kudu/spark/kudu/KuduTestSuite.scala |  6 +++++-
 4 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
index 2dcf335..b873736 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
@@ -72,6 +72,7 @@ class DefaultSource
   val HANDLE_SCHEMA_DRIFT = "kudu.handleSchemaDrift"
   val USE_DRIVER_METADATA = "kudu.useDriverMetadata"
   val SNAPSHOT_TIMESTAMP_MS = "kudu.snapshotTimestampMs"
+  val SASL_PROTOCOL_NAME = "kudu.saslProtocolName"
 
   /**
    * A nice alias for the data source so that when specifying the format
@@ -109,6 +110,7 @@ class DefaultSource
     val tableName = getTableName(parameters)
     val kuduMaster = getMasterAddrs(parameters)
     val operationType = getOperationType(parameters)
+    val saslProtocolName = getSaslProtocolName(parameters)
     val schemaOption = Option(schema)
     val readOptions = getReadOptions(parameters)
     val writeOptions = getWriteOptions(parameters)
@@ -116,6 +118,7 @@ class DefaultSource
     new KuduRelation(
       tableName,
       kuduMaster,
+      saslProtocolName,
       operationType,
       schemaOption,
       readOptions,
@@ -157,12 +160,14 @@ class DefaultSource
     val tableName = getTableName(parameters)
     val masterAddrs = getMasterAddrs(parameters)
     val operationType = getOperationType(parameters)
+    val saslProtocolName = getSaslProtocolName(parameters)
     val readOptions = getReadOptions(parameters)
     val writeOptions = getWriteOptions(parameters)
 
     new KuduSink(
       tableName,
       masterAddrs,
+      saslProtocolName,
       operationType,
       readOptions,
       writeOptions
@@ -227,6 +232,10 @@ class DefaultSource
     parameters.getOrElse(KUDU_MASTER, InetAddress.getLocalHost.getCanonicalHostName)
   }
 
+  private def getSaslProtocolName(parameters: Map[String, String]): String = {
+    parameters.getOrElse(SASL_PROTOCOL_NAME, "kudu")
+  }
+
   private def getScanLocalityType(opParam: String): ReplicaSelection = {
     opParam.toLowerCase(Locale.ENGLISH) match {
       case "leader_only" => ReplicaSelection.LEADER_ONLY
@@ -274,6 +283,7 @@ class DefaultSource
 class KuduRelation(
     val tableName: String,
     val masterAddrs: String,
+    val saslProtocolName: String,
     val operationType: OperationType,
     val userSchema: Option[StructType],
     val readOptions: KuduReadOptions = new KuduReadOptions,
@@ -282,7 +292,7 @@ class KuduRelation(
   val log: Logger = LoggerFactory.getLogger(getClass)
 
   private val context: KuduContext =
-    new KuduContext(masterAddrs, sqlContext.sparkContext)
+    new KuduContext(masterAddrs, sqlContext.sparkContext, None, Some(saslProtocolName))
 
   private val table: KuduTable = context.syncClient.openTable(tableName)
 
@@ -498,13 +508,14 @@ private[spark] object KuduRelation {
 class KuduSink(
     val tableName: String,
     val masterAddrs: String,
+    val saslProtocolName: String,
     val operationType: OperationType,
     val readOptions: KuduReadOptions = new KuduReadOptions,
     val writeOptions: KuduWriteOptions)(val sqlContext: SQLContext)
     extends Sink {
 
   private val context: KuduContext =
-    new KuduContext(masterAddrs, sqlContext.sparkContext)
+    new KuduContext(masterAddrs, sqlContext.sparkContext, None, Some(saslProtocolName))
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     context.writeRows(data, tableName, operationType, writeOptions)
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
index 563533e..c364dc8 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
@@ -59,7 +59,11 @@ import org.apache.kudu.Type
 @InterfaceAudience.Public
 @InterfaceStability.Evolving
 @SerialVersionUID(1L)
-class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeoutMs: Option[Long])
+class KuduContext(
+    val kuduMaster: String,
+    sc: SparkContext,
+    val socketReadTimeoutMs: Option[Long],
+    val saslProtocolName: Option[String] = None)
     extends Serializable {
   val log: Logger = LoggerFactory.getLogger(getClass)
 
@@ -149,7 +153,7 @@ class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeou
   @transient lazy val syncClient: KuduClient = asyncClient.syncClient()
 
   @transient lazy val asyncClient: AsyncKuduClient = {
-    val c = KuduClientCache.getAsyncClient(kuduMaster)
+    val c = KuduClientCache.getAsyncClient(kuduMaster, saslProtocolName)
     if (authnCredentials != null) {
       c.importAuthenticationCredentials(authnCredentials)
     }
@@ -607,10 +611,14 @@ private object KuduClientCache {
     clientCache.clear()
   }
 
-  def getAsyncClient(kuduMaster: String): AsyncKuduClient = {
+  def getAsyncClient(kuduMaster: String, saslProtocolName: Option[String]): AsyncKuduClient = {
     clientCache.synchronized {
       if (!clientCache.contains(kuduMaster)) {
-        val asyncClient = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster).build()
+        val builder = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster)
+        if (saslProtocolName.nonEmpty) {
+          builder.saslProtocolName(saslProtocolName.get)
+        }
+        val asyncClient = builder.build()
         val hookHandle = new Runnable {
           override def run(): Unit = asyncClient.close()
         }
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
index 76f2af1..327ba4c 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
@@ -19,7 +19,6 @@ package org.apache.kudu.spark.kudu
 
 import java.nio.charset.StandardCharsets
 import java.util
-
 import scala.collection.JavaConverters._
 import scala.collection.immutable.IndexedSeq
 import org.apache.spark.SparkException
@@ -35,6 +34,7 @@ import org.apache.kudu.client.CreateTableOptions
 import org.apache.kudu.test.KuduTestHarness
 import org.apache.kudu.test.RandomUtils
 import org.apache.kudu.spark.kudu.SparkListenerUtil.withJobTaskCounter
+import org.apache.kudu.test.KuduTestHarness.EnableKerberos
 import org.apache.kudu.test.KuduTestHarness.MasterServerConfig
 import org.junit.Before
 import org.junit.Test
@@ -876,4 +876,25 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
     val kuduRelation = kuduRelationFromDataFrame(dataFrame)
     assert(kuduRelation.sizeInBytes == 1024)
   }
+
+  @Test
+  @EnableKerberos(principal = "oryx")
+  def testNonDefaultPrincipal(): Unit = {
+    KuduClientCache.clearCacheForTests()
+    val exception = intercept[Exception] {
+      val df = sqlContext.read.options(kuduOptions).format("kudu").load
+      df.count()
+    }
+    assertTrue(exception.getCause.getMessage.contains("this client is not authenticated"))
+
+    KuduClientCache.clearCacheForTests()
+    kuduOptions = Map(
+      "kudu.table" -> tableName,
+      "kudu.master" -> harness.getMasterAddressesAsString,
+      "kudu.saslProtocolName" -> "oryx"
+    )
+
+    val df = sqlContext.read.options(kuduOptions).format("kudu").load
+    assertEquals(rowCount, df.count())
+  }
 }
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
index 88d440d..ebf41a4 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
@@ -139,7 +139,11 @@ trait KuduTestSuite {
   @Before
   def setUpBase(): Unit = {
     ss = SparkSession.builder().config(conf).getOrCreate()
-    kuduContext = new KuduContext(harness.getMasterAddressesAsString, ss.sparkContext)
+    kuduContext = new KuduContext(
+      harness.getMasterAddressesAsString,
+      ss.sparkContext,
+      None,
+      Some(harness.getPrincipal()))
 
     // Spark tests should use the client from the kuduContext.
     kuduClient = kuduContext.syncClient