You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kudu.apache.org by ab...@apache.org on 2021/04/23 13:42:19 UTC
[kudu] branch master updated: [spark] KUDU-1884 Add custom SASL protocol name
This is an automated email from the ASF dual-hosted git repository.
abukor pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kudu.git
The following commit(s) were added to refs/heads/master by this push:
new dc5b5bd [spark] KUDU-1884 Add custom SASL protocol name
dc5b5bd is described below
commit dc5b5bd899755faa506363bd00d3bbbac8d594d3
Author: Attila Bukor <ab...@apache.org>
AuthorDate: Tue Apr 20 18:35:16 2021 +0200
[spark] KUDU-1884 Add custom SASL protocol name
The Java client already supports setting a custom SASL protocol name for a
KuduClient or AsyncKuduClient instance, which is needed when using a
non-default service principal name. This patch exposes this setting in
KuduContext and DefaultSource (via the "kudu.saslProtocolName" option).
Change-Id: Ifd0dba4f829f369c363cc89bb58650249035f356
Reviewed-on: http://gerrit.cloudera.org:8080/17328
Tested-by: Attila Bukor <ab...@apache.org>
Reviewed-by: Alexey Serbin <as...@cloudera.com>
Reviewed-by: Grant Henke <gr...@apache.org>
---
.../org/apache/kudu/spark/kudu/DefaultSource.scala | 15 ++++++++++++--
.../org/apache/kudu/spark/kudu/KuduContext.scala | 16 +++++++++++----
.../apache/kudu/spark/kudu/DefaultSourceTest.scala | 23 +++++++++++++++++++++-
.../org/apache/kudu/spark/kudu/KuduTestSuite.scala | 6 +++++-
4 files changed, 52 insertions(+), 8 deletions(-)
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
index 2dcf335..b873736 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
@@ -72,6 +72,7 @@ class DefaultSource
val HANDLE_SCHEMA_DRIFT = "kudu.handleSchemaDrift"
val USE_DRIVER_METADATA = "kudu.useDriverMetadata"
val SNAPSHOT_TIMESTAMP_MS = "kudu.snapshotTimestampMs"
+ val SASL_PROTOCOL_NAME = "kudu.saslProtocolName"
/**
* A nice alias for the data source so that when specifying the format
@@ -109,6 +110,7 @@ class DefaultSource
val tableName = getTableName(parameters)
val kuduMaster = getMasterAddrs(parameters)
val operationType = getOperationType(parameters)
+ val saslProtocolName = getSaslProtocolName(parameters)
val schemaOption = Option(schema)
val readOptions = getReadOptions(parameters)
val writeOptions = getWriteOptions(parameters)
@@ -116,6 +118,7 @@ class DefaultSource
new KuduRelation(
tableName,
kuduMaster,
+ saslProtocolName,
operationType,
schemaOption,
readOptions,
@@ -157,12 +160,14 @@ class DefaultSource
val tableName = getTableName(parameters)
val masterAddrs = getMasterAddrs(parameters)
val operationType = getOperationType(parameters)
+ val saslProtocolName = getSaslProtocolName(parameters)
val readOptions = getReadOptions(parameters)
val writeOptions = getWriteOptions(parameters)
new KuduSink(
tableName,
masterAddrs,
+ saslProtocolName,
operationType,
readOptions,
writeOptions
@@ -227,6 +232,10 @@ class DefaultSource
parameters.getOrElse(KUDU_MASTER, InetAddress.getLocalHost.getCanonicalHostName)
}
+ private def getSaslProtocolName(parameters: Map[String, String]): String = {
+ parameters.getOrElse(SASL_PROTOCOL_NAME, "kudu")
+ }
+
private def getScanLocalityType(opParam: String): ReplicaSelection = {
opParam.toLowerCase(Locale.ENGLISH) match {
case "leader_only" => ReplicaSelection.LEADER_ONLY
@@ -274,6 +283,7 @@ class DefaultSource
class KuduRelation(
val tableName: String,
val masterAddrs: String,
+ val saslProtocolName: String,
val operationType: OperationType,
val userSchema: Option[StructType],
val readOptions: KuduReadOptions = new KuduReadOptions,
@@ -282,7 +292,7 @@ class KuduRelation(
val log: Logger = LoggerFactory.getLogger(getClass)
private val context: KuduContext =
- new KuduContext(masterAddrs, sqlContext.sparkContext)
+ new KuduContext(masterAddrs, sqlContext.sparkContext, None, Some(saslProtocolName))
private val table: KuduTable = context.syncClient.openTable(tableName)
@@ -498,13 +508,14 @@ private[spark] object KuduRelation {
class KuduSink(
val tableName: String,
val masterAddrs: String,
+ val saslProtocolName: String,
val operationType: OperationType,
val readOptions: KuduReadOptions = new KuduReadOptions,
val writeOptions: KuduWriteOptions)(val sqlContext: SQLContext)
extends Sink {
private val context: KuduContext =
- new KuduContext(masterAddrs, sqlContext.sparkContext)
+ new KuduContext(masterAddrs, sqlContext.sparkContext, None, Some(saslProtocolName))
override def addBatch(batchId: Long, data: DataFrame): Unit = {
context.writeRows(data, tableName, operationType, writeOptions)
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
index 563533e..c364dc8 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
@@ -59,7 +59,11 @@ import org.apache.kudu.Type
@InterfaceAudience.Public
@InterfaceStability.Evolving
@SerialVersionUID(1L)
-class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeoutMs: Option[Long])
+class KuduContext(
+ val kuduMaster: String,
+ sc: SparkContext,
+ val socketReadTimeoutMs: Option[Long],
+ val saslProtocolName: Option[String] = None)
extends Serializable {
val log: Logger = LoggerFactory.getLogger(getClass)
@@ -149,7 +153,7 @@ class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeou
@transient lazy val syncClient: KuduClient = asyncClient.syncClient()
@transient lazy val asyncClient: AsyncKuduClient = {
- val c = KuduClientCache.getAsyncClient(kuduMaster)
+ val c = KuduClientCache.getAsyncClient(kuduMaster, saslProtocolName)
if (authnCredentials != null) {
c.importAuthenticationCredentials(authnCredentials)
}
@@ -607,10 +611,14 @@ private object KuduClientCache {
clientCache.clear()
}
- def getAsyncClient(kuduMaster: String): AsyncKuduClient = {
+ def getAsyncClient(kuduMaster: String, saslProtocolName: Option[String]): AsyncKuduClient = {
clientCache.synchronized {
if (!clientCache.contains(kuduMaster)) {
- val asyncClient = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster).build()
+ val builder = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster)
+ if (saslProtocolName.nonEmpty) {
+ builder.saslProtocolName(saslProtocolName.get)
+ }
+ val asyncClient = builder.build()
val hookHandle = new Runnable {
override def run(): Unit = asyncClient.close()
}
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
index 76f2af1..327ba4c 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
@@ -19,7 +19,6 @@ package org.apache.kudu.spark.kudu
import java.nio.charset.StandardCharsets
import java.util
-
import scala.collection.JavaConverters._
import scala.collection.immutable.IndexedSeq
import org.apache.spark.SparkException
@@ -35,6 +34,7 @@ import org.apache.kudu.client.CreateTableOptions
import org.apache.kudu.test.KuduTestHarness
import org.apache.kudu.test.RandomUtils
import org.apache.kudu.spark.kudu.SparkListenerUtil.withJobTaskCounter
+import org.apache.kudu.test.KuduTestHarness.EnableKerberos
import org.apache.kudu.test.KuduTestHarness.MasterServerConfig
import org.junit.Before
import org.junit.Test
@@ -876,4 +876,25 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
val kuduRelation = kuduRelationFromDataFrame(dataFrame)
assert(kuduRelation.sizeInBytes == 1024)
}
+
+ @Test
+ @EnableKerberos(principal = "oryx")
+ def testNonDefaultPrincipal(): Unit = {
+ KuduClientCache.clearCacheForTests()
+ val exception = intercept[Exception] {
+ val df = sqlContext.read.options(kuduOptions).format("kudu").load
+ df.count()
+ }
+ assertTrue(exception.getCause.getMessage.contains("this client is not authenticated"))
+
+ KuduClientCache.clearCacheForTests()
+ kuduOptions = Map(
+ "kudu.table" -> tableName,
+ "kudu.master" -> harness.getMasterAddressesAsString,
+ "kudu.saslProtocolName" -> "oryx"
+ )
+
+ val df = sqlContext.read.options(kuduOptions).format("kudu").load
+ assertEquals(rowCount, df.count())
+ }
}
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
index 88d440d..ebf41a4 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
@@ -139,7 +139,11 @@ trait KuduTestSuite {
@Before
def setUpBase(): Unit = {
ss = SparkSession.builder().config(conf).getOrCreate()
- kuduContext = new KuduContext(harness.getMasterAddressesAsString, ss.sparkContext)
+ kuduContext = new KuduContext(
+ harness.getMasterAddressesAsString,
+ ss.sparkContext,
+ None,
+ Some(harness.getPrincipal()))
// Spark tests should use the client from the kuduContext.
kuduClient = kuduContext.syncClient