Posted to commits@kudu.apache.org by mp...@apache.org on 2018/07/16 19:30:04 UTC

[2/2] kudu git commit: spark: Expose socketReadTimeoutMs to Spark connector

spark: Expose socketReadTimeoutMs to Spark connector

This patch exposes socketReadTimeoutMs in the KuduContext and the
DefaultSource.
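
For example, with this patch applied the timeout can be passed either as a
DataSource option or directly to the KuduContext constructor. A rough usage
sketch, not part of this commit — the table name and master address are
placeholders, and sqlContext/sc are assumed to come from a running Spark
session:

  // Via the DefaultSource (option names as defined in DefaultSource.scala):
  val df = sqlContext.read
    .option("kudu.table", "my_table")
    .option("kudu.master", "master-host:7051")
    .option("kudu.socketReadTimeoutMs", "10000")
    .format("org.apache.kudu.spark.kudu")
    .load()

  // Via the new three-argument KuduContext constructor:
  val kuduContext = new KuduContext("master-host:7051", sc, Some(10000L))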

This patch also performs a bit of cleanup by renaming
the KuduConnection object to KuduClientCache, which seems like a more
appropriate name.

Because socketReadTimeout is a KuduClient configuration parameter
related to connection handling, it was incorporated into the client
cache key.
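
As a consequence, contexts that differ only in this setting resolve to
distinct cached clients, while identical settings share one. A rough sketch,
assuming a live SparkContext `sc` and a master address string `masters`
(both placeholders):

  val ctxA = new KuduContext(masters, sc, Some(10000L))
  val ctxB = new KuduContext(masters, sc, Some(30000L))
  // ctxA and ctxB do not share an underlying client, since
  // (masters, Some(10000)) and (masters, Some(30000)) are distinct
  // CacheKeys in KuduClientCache.
  val ctxC = new KuduContext(masters, sc, Some(10000L))
  // ctxC resolves to the same cached AsyncKuduClient as ctxA.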

Manually tested in spark-shell using spark-on-yarn.

Added a basic test to ensure that the parameter is properly parsed by
the DefaultSource and configured in the KuduRelation instance.

Change-Id: I0ab0ff0b242790caffb7e2848958148ffe547c4d
Reviewed-on: http://gerrit.cloudera.org:8080/10839
Tested-by: Kudu Jenkins
Reviewed-by: Dan Burkert <da...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/kudu/repo
Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/eee82d90
Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/eee82d90
Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/eee82d90

Branch: refs/heads/master
Commit: eee82d90a54108f2d7e18e84ec0bbd391fcc129a
Parents: eaed285
Author: Mike Percy <mp...@apache.org>
Authored: Thu Jun 28 00:35:55 2018 -0700
Committer: Mike Percy <mp...@apache.org>
Committed: Mon Jul 16 19:26:47 2018 +0000

----------------------------------------------------------------------
 .../apache/kudu/spark/kudu/DefaultSource.scala  | 21 ++++++---
 .../apache/kudu/spark/kudu/KuduContext.scala    | 47 ++++++++++++--------
 .../org/apache/kudu/spark/kudu/KuduRDD.scala    |  3 +-
 .../kudu/spark/kudu/DefaultSourceTest.scala     | 15 +++++++
 .../kudu/spark/kudu/KuduContextTest.scala       |  2 +-
 5 files changed, 61 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
index dd5b824..090e5fb 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
@@ -52,6 +52,7 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider
   val IGNORE_NULL = "kudu.ignoreNull"
   val IGNORE_DUPLICATE_ROW_ERRORS = "kudu.ignoreDuplicateRowErrors"
   val SCAN_REQUEST_TIMEOUT_MS = "kudu.scanRequestTimeoutMs"
+  val SOCKET_READ_TIMEOUT_MS = "kudu.socketReadTimeoutMs"
 
   def defaultMasterAddrs: String = InetAddress.getLocalHost.getCanonicalHostName
 
@@ -59,6 +60,10 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider
     parameters.get(SCAN_REQUEST_TIMEOUT_MS).map(_.toLong)
   }
 
+  def getSocketReadTimeoutMs(parameters: Map[String, String]): Option[Long] = {
+    parameters.get(SOCKET_READ_TIMEOUT_MS).map(_.toLong)
+  }
+
   /**
     * Construct a BaseRelation using the provided context and parameters.
     *
@@ -82,8 +87,8 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider
     val writeOptions = new KuduWriteOptions(ignoreDuplicateRowErrors, ignoreNull)
 
     new KuduRelation(tableName, kuduMaster, faultTolerantScanner,
-      scanLocality, getScanRequestTimeoutMs(parameters), operationType, None,
-      writeOptions)(sqlContext)
+      scanLocality, getScanRequestTimeoutMs(parameters), getSocketReadTimeoutMs(parameters),
+      operationType, None, writeOptions)(sqlContext)
   }
 
   /**
@@ -119,7 +124,8 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider
     val scanLocality = getScanLocalityType(parameters.getOrElse(SCAN_LOCALITY, "closest_replica"))
 
     new KuduRelation(tableName, kuduMaster, faultTolerantScanner,
-      scanLocality, getScanRequestTimeoutMs(parameters), operationType, Some(schema))(sqlContext)
+      scanLocality, getScanRequestTimeoutMs(parameters), getSocketReadTimeoutMs(parameters),
+      operationType, Some(schema))(sqlContext)
   }
 
   private def getOperationType(opParam: String): OperationType = {
@@ -163,6 +169,7 @@ class KuduRelation(private val tableName: String,
                    private val faultTolerantScanner: Boolean,
                    private val scanLocality: ReplicaSelection,
                    private[kudu] val scanRequestTimeoutMs: Option[Long],
+                   private[kudu] val socketReadTimeoutMs: Option[Long],
                    private val operationType: OperationType,
                    private val userSchema: Option[StructType],
                    private val writeOptions: KuduWriteOptions = new KuduWriteOptions)(
@@ -171,13 +178,13 @@ class KuduRelation(private val tableName: String,
     with PrunedFilteredScan
     with InsertableRelation {
 
-  import KuduRelation._
+  private val context: KuduContext = new KuduContext(masterAddrs, sqlContext.sparkContext,
+                                                     socketReadTimeoutMs)
 
-  private val context: KuduContext = new KuduContext(masterAddrs, sqlContext.sparkContext)
   private val table: KuduTable = context.syncClient.openTable(tableName)
 
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
-    filters.filterNot(supportsFilter)
+    filters.filterNot(KuduRelation.supportsFilter)
 
   /**
     * Generates a SparkSQL schema object so SparkSQL knows what is being
@@ -200,7 +207,7 @@ class KuduRelation(private val tableName: String,
     val predicates = filters.flatMap(filterToPredicate)
     new KuduRDD(context, 1024 * 1024 * 20, requiredColumns, predicates,
                 table, faultTolerantScanner, scanLocality, scanRequestTimeoutMs,
-                sqlContext.sparkContext)
+                socketReadTimeoutMs, sqlContext.sparkContext)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
index a27e395..5f7d1f3 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
@@ -18,24 +18,22 @@
 package org.apache.kudu.spark.kudu
 
 import java.security.{AccessController, PrivilegedAction}
+
 import javax.security.auth.Subject
 import javax.security.auth.login.{AppConfigurationEntry, Configuration, LoginContext}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-
 import org.apache.hadoop.util.ShutdownHookManager
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType}
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.AccumulatorV2
-import org.apache.yetus.audience.InterfaceStability
+import org.apache.yetus.audience.{InterfaceAudience, InterfaceStability}
 import org.slf4j.{Logger, LoggerFactory}
-
 import org.apache.kudu.client.SessionConfiguration.FlushMode
 import org.apache.kudu.client._
-import org.apache.kudu.spark.kudu
 import org.apache.kudu.spark.kudu.SparkUtil._
 import org.apache.kudu.{Schema, Type}
 
@@ -48,7 +46,10 @@ import org.apache.kudu.{Schema, Type}
   */
 @InterfaceStability.Unstable
 class KuduContext(val kuduMaster: String,
-                  sc: SparkContext) extends Serializable {
+                  sc: SparkContext,
+                  val socketReadTimeoutMs: Option[Long]) extends Serializable {
+
+  def this(kuduMaster: String, sc: SparkContext) = this(kuduMaster, sc, None)
 
   /**
     * TimestampAccumulator accumulates the maximum value of client's
@@ -88,8 +89,6 @@ class KuduContext(val kuduMaster: String,
   val timestampAccumulator = new TimestampAccumulator()
   sc.register(timestampAccumulator)
 
-  import kudu.KuduContext._
-
   @Deprecated()
   def this(kuduMaster: String) {
     this(kuduMaster, new SparkContext())
@@ -98,7 +97,7 @@ class KuduContext(val kuduMaster: String,
   @transient lazy val syncClient: KuduClient = asyncClient.syncClient()
 
   @transient lazy val asyncClient: AsyncKuduClient = {
-    val c = KuduConnection.getAsyncClient(kuduMaster)
+    val c = KuduClientCache.getAsyncClient(kuduMaster, socketReadTimeoutMs)
     if (authnCredentials != null) {
       c.importAuthenticationCredentials(authnCredentials)
     }
@@ -107,7 +106,7 @@ class KuduContext(val kuduMaster: String,
 
   // Visible for testing.
   private[kudu] val authnCredentials : Array[Byte] = {
-    Subject.doAs(getSubject(sc), new PrivilegedAction[Array[Byte]] {
+    Subject.doAs(KuduContext.getSubject(sc), new PrivilegedAction[Array[Byte]] {
       override def run(): Array[Byte] = syncClient.exportAuthenticationCredentials()
     })
   }
@@ -128,7 +127,7 @@ class KuduContext(val kuduMaster: String,
     // TODO: localityScan, etc) to KuduRDD
     new KuduRDD(this, 1024*1024*20, columnProjection.toArray, Array(),
                 syncClient.openTable(tableName), false, ReplicaSelection.LEADER_ONLY,
-                None, sc)
+                None, None, sc)
   }
 
   /**
@@ -391,8 +390,8 @@ private object KuduContext {
   }
 }
 
-private object KuduConnection {
-  private[kudu] val asyncCache = new mutable.HashMap[String, AsyncKuduClient]()
+private object KuduClientCache {
+  private case class CacheKey(kuduMaster: String, socketReadTimeoutMs: Option[Long])
 
   /**
     * Set to
@@ -403,17 +402,29 @@ private object KuduConnection {
     */
   private val ShutdownHookPriority = 100
 
-  def getAsyncClient(kuduMaster: String): AsyncKuduClient = {
-    asyncCache.synchronized {
-      if (!asyncCache.contains(kuduMaster)) {
-        val asyncClient = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster).build()
+  private val clientCache = new mutable.HashMap[CacheKey, AsyncKuduClient]()
+
+  // Visible for testing.
+  private[kudu] def clearCacheForTests() = clientCache.clear()
+
+  def getAsyncClient(kuduMaster: String, socketReadTimeoutMs: Option[Long]): AsyncKuduClient = {
+    val cacheKey = CacheKey(kuduMaster, socketReadTimeoutMs)
+    clientCache.synchronized {
+      if (!clientCache.contains(cacheKey)) {
+        val builder = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster)
+        socketReadTimeoutMs match {
+          case Some(timeout) => builder.defaultSocketReadTimeoutMs(timeout)
+          case None =>
+        }
+
+        val asyncClient = builder.build()
         ShutdownHookManager.get().addShutdownHook(
           new Runnable {
             override def run(): Unit = asyncClient.close()
           }, ShutdownHookPriority)
-        asyncCache.put(kuduMaster, asyncClient)
+        clientCache.put(cacheKey, asyncClient)
       }
-      return asyncCache(kuduMaster)
+      return clientCache(cacheKey)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
index 7117983..4817da6 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
@@ -28,7 +28,7 @@ import org.apache.kudu.{Type, client}
 /**
   * A Resilient Distributed Dataset backed by a Kudu table.
   *
-  * To construct a KuduRDD, use {@link KuduContext#kuduRdd} or a Kudu DataSource.
+  * To construct a KuduRDD, use [[KuduContext#kuduRDD]] or a Kudu DataSource.
   */
 class KuduRDD private[kudu] (val kuduContext: KuduContext,
                              @transient val batchSize: Integer,
@@ -38,6 +38,7 @@ class KuduRDD private[kudu] (val kuduContext: KuduContext,
                              @transient val isFaultTolerant: Boolean,
                              @transient val scanLocality: ReplicaSelection,
                              @transient val scanRequestTimeoutMs: Option[Long],
+                             @transient val socketReadTimeoutMs: Option[Long],
                              @transient val sc: SparkContext) extends RDD[Row](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {

http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
index 701e398..f01a211 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
@@ -734,4 +734,19 @@ class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfterEac
     val kuduRelation = kuduRelationFromDataFrame(dataFrame)
     assert(kuduRelation.scanRequestTimeoutMs == Some(1))
   }
+
+  /**
+    * Verify that the kudu.socketReadTimeoutMs parameter is parsed by the
+    * DefaultSource and makes it into the KuduRelation as a configuration
+    * parameter.
+    */
+  test("socket read timeout propagation") {
+    kuduOptions = Map(
+      "kudu.table" -> tableName,
+      "kudu.master" -> miniCluster.getMasterAddresses,
+      "kudu.socketReadTimeoutMs" -> "1")
+    val dataFrame = sqlContext.read.options(kuduOptions).kudu
+    val kuduRelation = kuduRelationFromDataFrame(dataFrame)
+    assert(kuduRelation.socketReadTimeoutMs == Some(1))
+  }
 }

http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
index 4915002..47d4519 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
@@ -52,7 +52,7 @@ class KuduContextTest extends FunSuite with TestContext with Matchers {
 
   test("Test KuduContext serialization") {
     val serialized = serialize(kuduContext)
-    KuduConnection.asyncCache.clear()
+    KuduClientCache.clearCacheForTests()
     val deserialized = deserialize(serialized).asInstanceOf[KuduContext]
     assert(deserialized.authnCredentials != null)
     // Make a nonsense call just to make sure the re-hydrated client works.