Posted to commits@kudu.apache.org by ad...@apache.org on 2019/10/18 03:30:13 UTC

[kudu] 01/03: [spark] Separate out DefaultSourceTests

This is an automated email from the ASF dual-hosted git repository.

adar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kudu.git

commit 930d54483a8ae850175ef1f1aff94a8f4342705f
Author: Grant Henke <gr...@apache.org>
AuthorDate: Thu Oct 17 14:13:57 2019 -0500

    [spark] Separate out DefaultSourceTests
    
    We have seen flaky test failures due to timeouts of DefaultSourceTest.
    This is primarily due to the sheer number of tests in that class.
    
    This patch breaks out the SQL-based tests, i.e. those using
    `sqlContext.sql(…)`, into their own class. There is no change in test
    methods or coverage.
    
    The result is 22 DefaultSourceTests and 22 SparkSQLTests.
    
    Change-Id: I54aa0327ffb5254c03fcfe8a0a08dba230360a40
    Reviewed-on: http://gerrit.cloudera.org:8080/14491
    Reviewed-by: Adar Dembo <ad...@cloudera.com>
    Reviewed-by: Hao Hao <ha...@cloudera.com>
    Tested-by: Kudu Jenkins
---
 .../apache/kudu/spark/kudu/DefaultSourceTest.scala | 497 -------------------
 .../org/apache/kudu/spark/kudu/KuduTestSuite.scala |  14 +
 .../org/apache/kudu/spark/kudu/SparkSQLTest.scala  | 534 +++++++++++++++++++++
 3 files changed, 548 insertions(+), 497 deletions(-)

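For context, a minimal sketch (not part of the commit) of the kind of SQL-based
test that moves into SparkSQLTest: register the Kudu table as a temporary view,
then query it via sqlContext.sql(...). It assumes the KuduTestSuite helpers used
throughout the diff below (ss, harness, table, tableName, insertRows); the class
and test names here are illustrative only.

    import org.apache.spark.sql.SQLContext
    import org.junit.Assert.assertEquals
    import org.junit.{Before, Test}

    // Hypothetical example class; the real tests live in SparkSQLTest.scala.
    class SparkSQLSketchTest extends KuduTestSuite {
      val rowCount = 10
      var sqlContext: SQLContext = _

      @Before
      def setUp(): Unit = {
        // Seed the test table and expose it to Spark SQL under its own name.
        insertRows(table, rowCount)
        sqlContext = ss.sqlContext
        val kuduOptions =
          Map("kudu.table" -> tableName, "kudu.master" -> harness.getMasterAddressesAsString)
        sqlContext.read
          .options(kuduOptions)
          .format("kudu")
          .load()
          .createOrReplaceTempView(tableName)
      }

      @Test
      def testCountViaSql(): Unit = {
        // Pure sqlContext.sql(...) usage -- the criterion for moving a test here.
        assertEquals(rowCount.toLong, sqlContext.sql(s"SELECT key FROM $tableName").count())
      }
    }
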
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
index 7dc96d2..521c83a 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
@@ -18,8 +18,6 @@ package org.apache.kudu.spark.kudu
 
 import scala.collection.JavaConverters._
 import scala.collection.immutable.IndexedSeq
-import scala.util.control.NonFatal
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.functions._
@@ -28,16 +26,10 @@ import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.types.StructType
 import org.junit.Assert._
 import org.scalatest.Matchers
-import org.apache.kudu.ColumnSchema.ColumnSchemaBuilder
 import org.apache.kudu.client.CreateTableOptions
-import org.apache.kudu.Schema
-import org.apache.kudu.Type
 import org.apache.kudu.test.RandomUtils
 import org.apache.kudu.spark.kudu.SparkListenerUtil.withJobTaskCounter
 import org.apache.kudu.test.KuduTestHarness.MasterServerConfig
-import org.apache.kudu.test.KuduTestHarness.TabletServerConfig
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.junit.Before
 import org.junit.Test
 
@@ -505,380 +497,6 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
   }
 
   @Test
-  def testTableNonFaultTolerantScan() {
-    val results = sqlContext.sql(s"SELECT * FROM $tableName").collectAsList()
-    assert(results.size() == rowCount)
-
-    assert(!results.get(0).isNullAt(2))
-    assert(results.get(1).isNullAt(2))
-  }
-
-  @Test
-  def testTableFaultTolerantScan() {
-    kuduOptions = Map(
-      "kudu.table" -> tableName,
-      "kudu.master" -> harness.getMasterAddressesAsString,
-      "kudu.faultTolerantScan" -> "true")
-
-    val table = "faultTolerantScanTest"
-    sqlContext.read.options(kuduOptions).format("kudu").load.createOrReplaceTempView(table)
-    val results = sqlContext.sql(s"SELECT * FROM $table").collectAsList()
-    assert(results.size() == rowCount)
-
-    assert(!results.get(0).isNullAt(2))
-    assert(results.get(1).isNullAt(2))
-  }
-
-  @Test
-  def testTableScanWithProjection() {
-    assertEquals(10, sqlContext.sql(s"""SELECT key FROM $tableName""").count())
-  }
-
-  @Test
-  def testTableScanWithProjectionAndPredicateDouble() {
-    assertEquals(
-      rows.count { case (key, i, s, ts) => i > 5 },
-      sqlContext
-        .sql(s"""SELECT key, c3_double FROM $tableName where c3_double > "5.0"""")
-        .count())
-  }
-
-  @Test
-  def testTableScanWithProjectionAndPredicateLong() {
-    assertEquals(
-      rows.count { case (key, i, s, ts) => i > 5 },
-      sqlContext
-        .sql(s"""SELECT key, c4_long FROM $tableName where c4_long > "5"""")
-        .count())
-  }
-
-  @Test
-  def testTableScanWithProjectionAndPredicateBool() {
-    assertEquals(
-      rows.count { case (key, i, s, ts) => i % 2 == 0 },
-      sqlContext
-        .sql(s"""SELECT key, c5_bool FROM $tableName where c5_bool = true""")
-        .count())
-  }
-
-  @Test
-  def testTableScanWithProjectionAndPredicateShort() {
-    assertEquals(
-      rows.count { case (key, i, s, ts) => i > 5 },
-      sqlContext
-        .sql(s"""SELECT key, c6_short FROM $tableName where c6_short > 5""")
-        .count())
-
-  }
-
-  @Test
-  def testTableScanWithProjectionAndPredicateFloat() {
-    assertEquals(
-      rows.count { case (key, i, s, ts) => i > 5 },
-      sqlContext
-        .sql(s"""SELECT key, c7_float FROM $tableName where c7_float > 5""")
-        .count())
-
-  }
-
-  @Test
-  def testTableScanWithProjectionAndPredicateDecimal32() {
-    assertEquals(
-      rows.count { case (key, i, s, ts) => i > 5 },
-      sqlContext
-        .sql(s"""SELECT key, c11_decimal32 FROM $tableName where c11_decimal32 > 5""")
-        .count())
-  }
-
-  @Test
-  def testTableScanWithProjectionAndPredicateDecimal64() {
-    assertEquals(
-      rows.count { case (key, i, s, ts) => i > 5 },
-      sqlContext
-        .sql(s"""SELECT key, c12_decimal64 FROM $tableName where c12_decimal64 > 5""")
-        .count())
-  }
-
-  @Test
-  def testTableScanWithProjectionAndPredicateDecimal128() {
-    assertEquals(
-      rows.count { case (key, i, s, ts) => i > 5 },
-      sqlContext
-        .sql(s"""SELECT key, c13_decimal128 FROM $tableName where c13_decimal128 > 5""")
-        .count())
-  }
-
-  @Test
-  def testTableScanWithProjectionAndPredicate() {
-    assertEquals(
-      rows.count { case (key, i, s, ts) => s != null && s > "5" },
-      sqlContext
-        .sql(s"""SELECT key FROM $tableName where c2_s > "5"""")
-        .count())
-
-    assertEquals(
-      rows.count { case (key, i, s, ts) => s != null },
-      sqlContext
-        .sql(s"""SELECT key, c2_s FROM $tableName where c2_s IS NOT NULL""")
-        .count())
-  }
-
-  @Test
-  def testBasicSparkSQL() {
-    val results = sqlContext.sql("SELECT * FROM " + tableName).collectAsList()
-    assert(results.size() == rowCount)
-
-    assert(results.get(1).isNullAt(2))
-    assert(!results.get(0).isNullAt(2))
-  }
-
-  @Test
-  def testBasicSparkSQLWithProjection() {
-    val results = sqlContext.sql("SELECT key FROM " + tableName).collectAsList()
-    assert(results.size() == rowCount)
-    assert(results.get(0).size.equals(1))
-    assert(results.get(0).getInt(0).equals(0))
-  }
-
-  @Test
-  def testBasicSparkSQLWithPredicate() {
-    val results = sqlContext
-      .sql("SELECT key FROM " + tableName + " where key=1")
-      .collectAsList()
-    assert(results.size() == 1)
-    assert(results.get(0).size.equals(1))
-    assert(results.get(0).getInt(0).equals(1))
-
-  }
-
-  @Test
-  def testBasicSparkSQLWithTwoPredicates() {
-    val results = sqlContext
-      .sql("SELECT key FROM " + tableName + " where key=2 and c2_s='2'")
-      .collectAsList()
-    assert(results.size() == 1)
-    assert(results.get(0).size.equals(1))
-    assert(results.get(0).getInt(0).equals(2))
-  }
-
-  @Test
-  def testBasicSparkSQLWithInListPredicate() {
-    val keys = Array(1, 5, 7)
-    val results = sqlContext
-      .sql(s"SELECT key FROM $tableName where key in (${keys.mkString(", ")})")
-      .collectAsList()
-    assert(results.size() == keys.length)
-    keys.zipWithIndex.foreach {
-      case (v, i) =>
-        assert(results.get(i).size.equals(1))
-        assert(results.get(i).getInt(0).equals(v))
-    }
-  }
-
-  @Test
-  def testBasicSparkSQLWithInListPredicateOnString() {
-    val keys = Array(1, 4, 6)
-    val results = sqlContext
-      .sql(s"SELECT key FROM $tableName where c2_s in (${keys.mkString("'", "', '", "'")})")
-      .collectAsList()
-    assert(results.size() == keys.count(_ % 2 == 0))
-    keys.filter(_ % 2 == 0).zipWithIndex.foreach {
-      case (v, i) =>
-        assert(results.get(i).size.equals(1))
-        assert(results.get(i).getInt(0).equals(v))
-    }
-  }
-
-  @Test
-  def testBasicSparkSQLWithInListAndComparisonPredicate() {
-    val keys = Array(1, 5, 7)
-    val results = sqlContext
-      .sql(s"SELECT key FROM $tableName where key>2 and key in (${keys.mkString(", ")})")
-      .collectAsList()
-    assert(results.size() == keys.count(_ > 2))
-    keys.filter(_ > 2).zipWithIndex.foreach {
-      case (v, i) =>
-        assert(results.get(i).size.equals(1))
-        assert(results.get(i).getInt(0).equals(v))
-    }
-  }
-
-  @Test
-  def testBasicSparkSQLWithTwoPredicatesNegative() {
-    val results = sqlContext
-      .sql("SELECT key FROM " + tableName + " where key=1 and c2_s='2'")
-      .collectAsList()
-    assert(results.size() == 0)
-  }
-
-  @Test
-  def testBasicSparkSQLWithTwoPredicatesIncludingString() {
-    val results = sqlContext
-      .sql("SELECT key FROM " + tableName + " where c2_s='2'")
-      .collectAsList()
-    assert(results.size() == 1)
-    assert(results.get(0).size.equals(1))
-    assert(results.get(0).getInt(0).equals(2))
-  }
-
-  @Test
-  def testBasicSparkSQLWithTwoPredicatesAndProjection() {
-    val results = sqlContext
-      .sql("SELECT key, c2_s FROM " + tableName + " where c2_s='2'")
-      .collectAsList()
-    assert(results.size() == 1)
-    assert(results.get(0).size.equals(2))
-    assert(results.get(0).getInt(0).equals(2))
-    assert(results.get(0).getString(1).equals("2"))
-  }
-
-  @Test
-  def testBasicSparkSQLWithTwoPredicatesGreaterThan() {
-    val results = sqlContext
-      .sql("SELECT key, c2_s FROM " + tableName + " where c2_s>='2'")
-      .collectAsList()
-    assert(results.size() == 4)
-    assert(results.get(0).size.equals(2))
-    assert(results.get(0).getInt(0).equals(2))
-    assert(results.get(0).getString(1).equals("2"))
-  }
-
-  @Test
-  def testSparkSQLStringStartsWithFilters() {
-    // This test requires a special table.
-    val testTableName = "startswith"
-    val schema = new Schema(
-      List(new ColumnSchemaBuilder("key", Type.STRING).key(true).build()).asJava)
-    val tableOptions = new CreateTableOptions()
-      .setRangePartitionColumns(List("key").asJava)
-      .setNumReplicas(1)
-    val testTable = kuduClient.createTable(testTableName, schema, tableOptions)
-
-    val kuduSession = kuduClient.newSession()
-    val chars = List('a', 'b', '乕', Char.MaxValue, '\u0000')
-    val keys = for {
-      x <- chars
-      y <- chars
-      z <- chars
-      w <- chars
-    } yield Array(x, y, z, w).mkString
-    keys.foreach { key =>
-      val insert = testTable.newInsert
-      val row = insert.getRow
-      val r = Array(1, 2, 3)
-      row.addString(0, key)
-      kuduSession.apply(insert)
-    }
-    val options: Map[String, String] =
-      Map("kudu.table" -> testTableName, "kudu.master" -> harness.getMasterAddressesAsString)
-    sqlContext.read.options(options).format("kudu").load.createOrReplaceTempView(testTableName)
-
-    val checkPrefixCount = { prefix: String =>
-      val results = sqlContext.sql(s"SELECT key FROM $testTableName WHERE key LIKE '$prefix%'")
-      assertEquals(keys.count(k => k.startsWith(prefix)), results.count())
-    }
-    // empty string
-    checkPrefixCount("")
-    // one character
-    for (x <- chars) {
-      checkPrefixCount(Array(x).mkString)
-    }
-    // all two character combos
-    for {
-      x <- chars
-      y <- chars
-    } {
-      checkPrefixCount(Array(x, y).mkString)
-    }
-  }
-
-  @Test
-  def testSparkSQLIsNullPredicate() {
-    var results = sqlContext
-      .sql("SELECT key FROM " + tableName + " where c2_s IS NULL")
-      .collectAsList()
-    assert(results.size() == 5)
-
-    results = sqlContext
-      .sql("SELECT key FROM " + tableName + " where key IS NULL")
-      .collectAsList()
-    assert(results.isEmpty())
-  }
-
-  @Test
-  def testSparkSQLIsNotNullPredicate() {
-    var results = sqlContext
-      .sql("SELECT key FROM " + tableName + " where c2_s IS NOT NULL")
-      .collectAsList()
-    assert(results.size() == 5)
-
-    results = sqlContext
-      .sql("SELECT key FROM " + tableName + " where key IS NOT NULL")
-      .collectAsList()
-    assert(results.size() == 10)
-  }
-
-  @Test
-  def testSQLInsertInto() {
-    val insertTable = "insertintotest"
-
-    // read 0 rows just to get the schema
-    val df = sqlContext.sql(s"SELECT * FROM $tableName LIMIT 0")
-    kuduContext.createTable(
-      insertTable,
-      df.schema,
-      Seq("key"),
-      new CreateTableOptions()
-        .setRangePartitionColumns(List("key").asJava)
-        .setNumReplicas(1))
-
-    val newOptions: Map[String, String] =
-      Map("kudu.table" -> insertTable, "kudu.master" -> harness.getMasterAddressesAsString)
-    sqlContext.read
-      .options(newOptions)
-      .format("kudu")
-      .load
-      .createOrReplaceTempView(insertTable)
-
-    sqlContext.sql(s"INSERT INTO TABLE $insertTable SELECT * FROM $tableName")
-    val results =
-      sqlContext.sql(s"SELECT key FROM $insertTable").collectAsList()
-    assertEquals(10, results.size())
-  }
-
-  @Test
-  def testSQLInsertOverwriteUnsupported() {
-    val insertTable = "insertoverwritetest"
-
-    // read 0 rows just to get the schema
-    val df = sqlContext.sql(s"SELECT * FROM $tableName LIMIT 0")
-    kuduContext.createTable(
-      insertTable,
-      df.schema,
-      Seq("key"),
-      new CreateTableOptions()
-        .setRangePartitionColumns(List("key").asJava)
-        .setNumReplicas(1))
-
-    val newOptions: Map[String, String] =
-      Map("kudu.table" -> insertTable, "kudu.master" -> harness.getMasterAddressesAsString)
-    sqlContext.read
-      .options(newOptions)
-      .format("kudu")
-      .load
-      .createOrReplaceTempView(insertTable)
-
-    try {
-      sqlContext.sql(s"INSERT OVERWRITE TABLE $insertTable SELECT * FROM $tableName")
-      fail("insert overwrite should throw UnsupportedOperationException")
-    } catch {
-      case _: UnsupportedOperationException => // good
-      case NonFatal(_) =>
-        fail("insert overwrite should throw UnsupportedOperationException")
-    }
-  }
-
-  @Test
   def testWriteUsingDefaultSource() {
     val df = sqlContext.read.options(kuduOptions).format("kudu").load
 
@@ -935,22 +553,6 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
     }.getMessage should include("Unknown column: foo")
   }
 
-  @Test
-  def testScanLocality() {
-    kuduOptions = Map(
-      "kudu.table" -> tableName,
-      "kudu.master" -> harness.getMasterAddressesAsString,
-      "kudu.scanLocality" -> "closest_replica")
-
-    val table = "scanLocalityTest"
-    sqlContext.read.options(kuduOptions).format("kudu").load.createOrReplaceTempView(table)
-    val results = sqlContext.sql(s"SELECT * FROM $table").collectAsList()
-    assert(results.size() == rowCount)
-
-    assert(!results.get(0).isNullAt(2))
-    assert(results.get(1).isNullAt(2))
-  }
-
   // Verify that the propagated timestamp is properly updated inside
   // the same client.
   @Test
@@ -1000,18 +602,6 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
   }
 
   /**
-   * Assuming that the only part of the logical plan is a Kudu scan, this
-   * function extracts the KuduRelation from the passed DataFrame for
-   * testing purposes.
-   */
-  def kuduRelationFromDataFrame(dataFrame: DataFrame) = {
-    val logicalPlan = dataFrame.queryExecution.logical
-    val logicalRelation = logicalPlan.asInstanceOf[LogicalRelation]
-    val baseRelation = logicalRelation.relation
-    baseRelation.asInstanceOf[KuduRelation]
-  }
-
-  /**
    * Verify that the kudu.scanRequestTimeoutMs parameter is parsed by the
    * DefaultSource and makes it into the KuduRelation as a configuration
    * parameter.
@@ -1028,35 +618,6 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
   }
 
   @Test
-  @TabletServerConfig(
-    flags = Array(
-      "--flush_threshold_mb=1",
-      "--flush_threshold_secs=1",
-      // Disable rowset compact to prevent DRSs being merged because they are too small.
-      "--enable_rowset_compaction=false"
-    ))
-  def testScanWithKeyRange() {
-    upsertRowsWithRowDataSize(table, rowCount * 100, 32 * 1024)
-
-    // Wait for mrs flushed
-    Thread.sleep(5 * 1000)
-
-    kuduOptions = Map(
-      "kudu.table" -> tableName,
-      "kudu.master" -> harness.getMasterAddressesAsString,
-      "kudu.splitSizeBytes" -> "1024")
-
-    // count the number of tasks that end.
-    val actualNumTasks = withJobTaskCounter(ss.sparkContext) { () =>
-      val t = "scanWithKeyRangeTest"
-      sqlContext.read.options(kuduOptions).format("kudu").load.createOrReplaceTempView(t)
-      val results = sqlContext.sql(s"SELECT * FROM $t").collectAsList()
-      assertEquals(rowCount * 100, results.size())
-    }
-    assert(actualNumTasks > 2)
-  }
-
-  @Test
   @MasterServerConfig(
     flags = Array(
       "--mock_table_metrics_for_testing=true",
@@ -1068,62 +629,4 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
     val kuduRelation = kuduRelationFromDataFrame(dataFrame)
     assert(kuduRelation.sizeInBytes == 1024)
   }
-
-  @Test
-  @MasterServerConfig(
-    flags = Array(
-      "--mock_table_metrics_for_testing=true",
-      "--on_disk_size_for_testing=1024",
-      "--live_row_count_for_testing=100"
-    ))
-  def testJoinWithTableStatistics(): Unit = {
-    val df = sqlContext.read.options(kuduOptions).format("kudu").load
-
-    // 1. Create two tables.
-    val table1 = "table1"
-    kuduContext.createTable(
-      table1,
-      df.schema,
-      Seq("key"),
-      new CreateTableOptions()
-        .setRangePartitionColumns(List("key").asJava)
-        .setNumReplicas(1))
-    var options1: Map[String, String] =
-      Map("kudu.table" -> table1, "kudu.master" -> harness.getMasterAddressesAsString)
-    df.write.options(options1).mode("append").format("kudu").save
-    val df1 = sqlContext.read.options(options1).format("kudu").load
-    df1.createOrReplaceTempView(table1)
-
-    val table2 = "table2"
-    kuduContext.createTable(
-      table2,
-      df.schema,
-      Seq("key"),
-      new CreateTableOptions()
-        .setRangePartitionColumns(List("key").asJava)
-        .setNumReplicas(1))
-    var options2: Map[String, String] =
-      Map("kudu.table" -> table2, "kudu.master" -> harness.getMasterAddressesAsString)
-    df.write.options(options2).mode("append").format("kudu").save
-    val df2 = sqlContext.read.options(options2).format("kudu").load
-    df2.createOrReplaceTempView(table2)
-
-    // 2. Get the table statistics of each table and verify.
-    val relation1 = kuduRelationFromDataFrame(df1)
-    val relation2 = kuduRelationFromDataFrame(df2)
-    assert(relation1.sizeInBytes == relation2.sizeInBytes)
-    assert(relation1.sizeInBytes == 1024)
-
-    // 3. Test join with table size should be able to broadcast.
-    val sqlStr = s"SELECT * FROM $table1 JOIN $table2 ON $table1.key = $table2.key"
-    var physical = sqlContext.sql(sqlStr).queryExecution.sparkPlan
-    var operators = physical.collect {
-      case j: BroadcastHashJoinExec => j
-    }
-    assert(operators.size == 1)
-
-    // Verify result.
-    var results = sqlContext.sql(sqlStr).collectAsList()
-    assert(results.size() == rowCount)
-  }
 }
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
index 54f0ffa..61e1069 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
@@ -31,6 +31,8 @@ import org.apache.kudu.Schema
 import org.apache.kudu.Type
 import org.apache.kudu.test.KuduTestHarness
 import org.apache.kudu.util.DecimalUtil
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SparkSession
 import org.junit.After
 import org.junit.Before
@@ -229,4 +231,16 @@ trait KuduTestSuite extends JUnitSuite {
     }
     rows
   }
+
+  /**
+   * Assuming that the only part of the logical plan is a Kudu scan, this
+   * function extracts the KuduRelation from the passed DataFrame for
+   * testing purposes.
+   */
+  def kuduRelationFromDataFrame(dataFrame: DataFrame) = {
+    val logicalPlan = dataFrame.queryExecution.logical
+    val logicalRelation = logicalPlan.asInstanceOf[LogicalRelation]
+    val baseRelation = logicalRelation.relation
+    baseRelation.asInstanceOf[KuduRelation]
+  }
 }
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkSQLTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkSQLTest.scala
new file mode 100644
index 0000000..a6aaf41
--- /dev/null
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkSQLTest.scala
@@ -0,0 +1,534 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kudu.spark.kudu
+
+import scala.collection.JavaConverters._
+import scala.collection.immutable.IndexedSeq
+import scala.util.control.NonFatal
+import org.apache.spark.sql.SQLContext
+import org.junit.Assert._
+import org.scalatest.Matchers
+import org.apache.kudu.ColumnSchema.ColumnSchemaBuilder
+import org.apache.kudu.client.CreateTableOptions
+import org.apache.kudu.Schema
+import org.apache.kudu.Type
+import org.apache.kudu.spark.kudu.SparkListenerUtil.withJobTaskCounter
+import org.apache.kudu.test.KuduTestHarness.MasterServerConfig
+import org.apache.kudu.test.KuduTestHarness.TabletServerConfig
+import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
+import org.junit.Before
+import org.junit.Test
+
+class SparkSQLTest extends KuduTestSuite with Matchers {
+  val rowCount = 10
+  var sqlContext: SQLContext = _
+  var rows: IndexedSeq[(Int, Int, String, Long)] = _
+  var kuduOptions: Map[String, String] = _
+
+  @Before
+  def setUp(): Unit = {
+    rows = insertRows(table, rowCount)
+
+    sqlContext = ss.sqlContext
+
+    kuduOptions =
+      Map("kudu.table" -> tableName, "kudu.master" -> harness.getMasterAddressesAsString)
+
+    sqlContext.read
+      .options(kuduOptions)
+      .format("kudu")
+      .load()
+      .createOrReplaceTempView(tableName)
+  }
+
+  @Test
+  def testBasicSparkSQL() {
+    val results = sqlContext.sql("SELECT * FROM " + tableName).collectAsList()
+    assert(results.size() == rowCount)
+
+    assert(results.get(1).isNullAt(2))
+    assert(!results.get(0).isNullAt(2))
+  }
+
+  @Test
+  def testBasicSparkSQLWithProjection() {
+    val results = sqlContext.sql("SELECT key FROM " + tableName).collectAsList()
+    assert(results.size() == rowCount)
+    assert(results.get(0).size.equals(1))
+    assert(results.get(0).getInt(0).equals(0))
+  }
+
+  @Test
+  def testBasicSparkSQLWithPredicate() {
+    val results = sqlContext
+      .sql("SELECT key FROM " + tableName + " where key=1")
+      .collectAsList()
+    assert(results.size() == 1)
+    assert(results.get(0).size.equals(1))
+    assert(results.get(0).getInt(0).equals(1))
+
+  }
+
+  @Test
+  def testBasicSparkSQLWithTwoPredicates() {
+    val results = sqlContext
+      .sql("SELECT key FROM " + tableName + " where key=2 and c2_s='2'")
+      .collectAsList()
+    assert(results.size() == 1)
+    assert(results.get(0).size.equals(1))
+    assert(results.get(0).getInt(0).equals(2))
+  }
+
+  @Test
+  def testBasicSparkSQLWithInListPredicate() {
+    val keys = Array(1, 5, 7)
+    val results = sqlContext
+      .sql(s"SELECT key FROM $tableName where key in (${keys.mkString(", ")})")
+      .collectAsList()
+    assert(results.size() == keys.length)
+    keys.zipWithIndex.foreach {
+      case (v, i) =>
+        assert(results.get(i).size.equals(1))
+        assert(results.get(i).getInt(0).equals(v))
+    }
+  }
+
+  @Test
+  def testBasicSparkSQLWithInListPredicateOnString() {
+    val keys = Array(1, 4, 6)
+    val results = sqlContext
+      .sql(s"SELECT key FROM $tableName where c2_s in (${keys.mkString("'", "', '", "'")})")
+      .collectAsList()
+    assert(results.size() == keys.count(_ % 2 == 0))
+    keys.filter(_ % 2 == 0).zipWithIndex.foreach {
+      case (v, i) =>
+        assert(results.get(i).size.equals(1))
+        assert(results.get(i).getInt(0).equals(v))
+    }
+  }
+
+  @Test
+  def testBasicSparkSQLWithInListAndComparisonPredicate() {
+    val keys = Array(1, 5, 7)
+    val results = sqlContext
+      .sql(s"SELECT key FROM $tableName where key>2 and key in (${keys.mkString(", ")})")
+      .collectAsList()
+    assert(results.size() == keys.count(_ > 2))
+    keys.filter(_ > 2).zipWithIndex.foreach {
+      case (v, i) =>
+        assert(results.get(i).size.equals(1))
+        assert(results.get(i).getInt(0).equals(v))
+    }
+  }
+
+  @Test
+  def testBasicSparkSQLWithTwoPredicatesNegative() {
+    val results = sqlContext
+      .sql("SELECT key FROM " + tableName + " where key=1 and c2_s='2'")
+      .collectAsList()
+    assert(results.size() == 0)
+  }
+
+  @Test
+  def testBasicSparkSQLWithTwoPredicatesIncludingString() {
+    val results = sqlContext
+      .sql("SELECT key FROM " + tableName + " where c2_s='2'")
+      .collectAsList()
+    assert(results.size() == 1)
+    assert(results.get(0).size.equals(1))
+    assert(results.get(0).getInt(0).equals(2))
+  }
+
+  @Test
+  def testBasicSparkSQLWithTwoPredicatesAndProjection() {
+    val results = sqlContext
+      .sql("SELECT key, c2_s FROM " + tableName + " where c2_s='2'")
+      .collectAsList()
+    assert(results.size() == 1)
+    assert(results.get(0).size.equals(2))
+    assert(results.get(0).getInt(0).equals(2))
+    assert(results.get(0).getString(1).equals("2"))
+  }
+
+  @Test
+  def testBasicSparkSQLWithTwoPredicatesGreaterThan() {
+    val results = sqlContext
+      .sql("SELECT key, c2_s FROM " + tableName + " where c2_s>='2'")
+      .collectAsList()
+    assert(results.size() == 4)
+    assert(results.get(0).size.equals(2))
+    assert(results.get(0).getInt(0).equals(2))
+    assert(results.get(0).getString(1).equals("2"))
+  }
+
+  @Test
+  def testSparkSQLStringStartsWithFilters() {
+    // This test requires a special table.
+    val testTableName = "startswith"
+    val schema = new Schema(
+      List(new ColumnSchemaBuilder("key", Type.STRING).key(true).build()).asJava)
+    val tableOptions = new CreateTableOptions()
+      .setRangePartitionColumns(List("key").asJava)
+      .setNumReplicas(1)
+    val testTable = kuduClient.createTable(testTableName, schema, tableOptions)
+
+    val kuduSession = kuduClient.newSession()
+    val chars = List('a', 'b', '乕', Char.MaxValue, '\u0000')
+    val keys = for {
+      x <- chars
+      y <- chars
+      z <- chars
+      w <- chars
+    } yield Array(x, y, z, w).mkString
+    keys.foreach { key =>
+      val insert = testTable.newInsert
+      val row = insert.getRow
+      val r = Array(1, 2, 3)
+      row.addString(0, key)
+      kuduSession.apply(insert)
+    }
+    val options: Map[String, String] =
+      Map("kudu.table" -> testTableName, "kudu.master" -> harness.getMasterAddressesAsString)
+    sqlContext.read.options(options).format("kudu").load.createOrReplaceTempView(testTableName)
+
+    val checkPrefixCount = { prefix: String =>
+      val results = sqlContext.sql(s"SELECT key FROM $testTableName WHERE key LIKE '$prefix%'")
+      assertEquals(keys.count(k => k.startsWith(prefix)), results.count())
+    }
+    // empty string
+    checkPrefixCount("")
+    // one character
+    for (x <- chars) {
+      checkPrefixCount(Array(x).mkString)
+    }
+    // all two character combos
+    for {
+      x <- chars
+      y <- chars
+    } {
+      checkPrefixCount(Array(x, y).mkString)
+    }
+  }
+
+  @Test
+  def testSparkSQLIsNullPredicate() {
+    var results = sqlContext
+      .sql("SELECT key FROM " + tableName + " where c2_s IS NULL")
+      .collectAsList()
+    assert(results.size() == 5)
+
+    results = sqlContext
+      .sql("SELECT key FROM " + tableName + " where key IS NULL")
+      .collectAsList()
+    assert(results.isEmpty())
+  }
+
+  @Test
+  def testSparkSQLIsNotNullPredicate() {
+    var results = sqlContext
+      .sql("SELECT key FROM " + tableName + " where c2_s IS NOT NULL")
+      .collectAsList()
+    assert(results.size() == 5)
+
+    results = sqlContext
+      .sql("SELECT key FROM " + tableName + " where key IS NOT NULL")
+      .collectAsList()
+    assert(results.size() == 10)
+  }
+
+  @Test
+  def testSQLInsertInto() {
+    val insertTable = "insertintotest"
+
+    // read 0 rows just to get the schema
+    val df = sqlContext.sql(s"SELECT * FROM $tableName LIMIT 0")
+    kuduContext.createTable(
+      insertTable,
+      df.schema,
+      Seq("key"),
+      new CreateTableOptions()
+        .setRangePartitionColumns(List("key").asJava)
+        .setNumReplicas(1))
+
+    val newOptions: Map[String, String] =
+      Map("kudu.table" -> insertTable, "kudu.master" -> harness.getMasterAddressesAsString)
+    sqlContext.read
+      .options(newOptions)
+      .format("kudu")
+      .load
+      .createOrReplaceTempView(insertTable)
+
+    sqlContext.sql(s"INSERT INTO TABLE $insertTable SELECT * FROM $tableName")
+    val results =
+      sqlContext.sql(s"SELECT key FROM $insertTable").collectAsList()
+    assertEquals(10, results.size())
+  }
+
+  @Test
+  def testSQLInsertOverwriteUnsupported() {
+    val insertTable = "insertoverwritetest"
+
+    // read 0 rows just to get the schema
+    val df = sqlContext.sql(s"SELECT * FROM $tableName LIMIT 0")
+    kuduContext.createTable(
+      insertTable,
+      df.schema,
+      Seq("key"),
+      new CreateTableOptions()
+        .setRangePartitionColumns(List("key").asJava)
+        .setNumReplicas(1))
+
+    val newOptions: Map[String, String] =
+      Map("kudu.table" -> insertTable, "kudu.master" -> harness.getMasterAddressesAsString)
+    sqlContext.read
+      .options(newOptions)
+      .format("kudu")
+      .load
+      .createOrReplaceTempView(insertTable)
+
+    try {
+      sqlContext.sql(s"INSERT OVERWRITE TABLE $insertTable SELECT * FROM $tableName")
+      fail("insert overwrite should throw UnsupportedOperationException")
+    } catch {
+      case _: UnsupportedOperationException => // good
+      case NonFatal(_) =>
+        fail("insert overwrite should throw UnsupportedOperationException")
+    }
+  }
+
+  @Test
+  def testTableScanWithProjection() {
+    assertEquals(10, sqlContext.sql(s"""SELECT key FROM $tableName""").count())
+  }
+
+  @Test
+  def testTableScanWithProjectionAndPredicateDouble() {
+    assertEquals(
+      rows.count { case (key, i, s, ts) => i > 5 },
+      sqlContext
+        .sql(s"""SELECT key, c3_double FROM $tableName where c3_double > "5.0"""")
+        .count())
+  }
+
+  @Test
+  def testTableScanWithProjectionAndPredicateLong() {
+    assertEquals(
+      rows.count { case (key, i, s, ts) => i > 5 },
+      sqlContext
+        .sql(s"""SELECT key, c4_long FROM $tableName where c4_long > "5"""")
+        .count())
+  }
+
+  @Test
+  def testTableScanWithProjectionAndPredicateBool() {
+    assertEquals(
+      rows.count { case (key, i, s, ts) => i % 2 == 0 },
+      sqlContext
+        .sql(s"""SELECT key, c5_bool FROM $tableName where c5_bool = true""")
+        .count())
+  }
+
+  @Test
+  def testTableScanWithProjectionAndPredicateShort() {
+    assertEquals(
+      rows.count { case (key, i, s, ts) => i > 5 },
+      sqlContext
+        .sql(s"""SELECT key, c6_short FROM $tableName where c6_short > 5""")
+        .count())
+
+  }
+
+  @Test
+  def testTableScanWithProjectionAndPredicateFloat() {
+    assertEquals(
+      rows.count { case (key, i, s, ts) => i > 5 },
+      sqlContext
+        .sql(s"""SELECT key, c7_float FROM $tableName where c7_float > 5""")
+        .count())
+
+  }
+
+  @Test
+  def testTableScanWithProjectionAndPredicateDecimal32() {
+    assertEquals(
+      rows.count { case (key, i, s, ts) => i > 5 },
+      sqlContext
+        .sql(s"""SELECT key, c11_decimal32 FROM $tableName where c11_decimal32 > 5""")
+        .count())
+  }
+
+  @Test
+  def testTableScanWithProjectionAndPredicateDecimal64() {
+    assertEquals(
+      rows.count { case (key, i, s, ts) => i > 5 },
+      sqlContext
+        .sql(s"""SELECT key, c12_decimal64 FROM $tableName where c12_decimal64 > 5""")
+        .count())
+  }
+
+  @Test
+  def testTableScanWithProjectionAndPredicateDecimal128() {
+    assertEquals(
+      rows.count { case (key, i, s, ts) => i > 5 },
+      sqlContext
+        .sql(s"""SELECT key, c13_decimal128 FROM $tableName where c13_decimal128 > 5""")
+        .count())
+  }
+
+  @Test
+  def testTableScanWithProjectionAndPredicate() {
+    assertEquals(
+      rows.count { case (key, i, s, ts) => s != null && s > "5" },
+      sqlContext
+        .sql(s"""SELECT key FROM $tableName where c2_s > "5"""")
+        .count())
+
+    assertEquals(
+      rows.count { case (key, i, s, ts) => s != null },
+      sqlContext
+        .sql(s"""SELECT key, c2_s FROM $tableName where c2_s IS NOT NULL""")
+        .count())
+  }
+
+  @Test
+  def testScanLocality() {
+    kuduOptions = Map(
+      "kudu.table" -> tableName,
+      "kudu.master" -> harness.getMasterAddressesAsString,
+      "kudu.scanLocality" -> "closest_replica")
+
+    val table = "scanLocalityTest"
+    sqlContext.read.options(kuduOptions).format("kudu").load.createOrReplaceTempView(table)
+    val results = sqlContext.sql(s"SELECT * FROM $table").collectAsList()
+    assert(results.size() == rowCount)
+
+    assert(!results.get(0).isNullAt(2))
+    assert(results.get(1).isNullAt(2))
+  }
+
+  @Test
+  def testTableNonFaultTolerantScan() {
+    val results = sqlContext.sql(s"SELECT * FROM $tableName").collectAsList()
+    assert(results.size() == rowCount)
+
+    assert(!results.get(0).isNullAt(2))
+    assert(results.get(1).isNullAt(2))
+  }
+
+  @Test
+  def testTableFaultTolerantScan() {
+    kuduOptions = Map(
+      "kudu.table" -> tableName,
+      "kudu.master" -> harness.getMasterAddressesAsString,
+      "kudu.faultTolerantScan" -> "true")
+
+    val table = "faultTolerantScanTest"
+    sqlContext.read.options(kuduOptions).format("kudu").load.createOrReplaceTempView(table)
+    val results = sqlContext.sql(s"SELECT * FROM $table").collectAsList()
+    assert(results.size() == rowCount)
+
+    assert(!results.get(0).isNullAt(2))
+    assert(results.get(1).isNullAt(2))
+  }
+
+  @Test
+  @TabletServerConfig(
+    flags = Array(
+      "--flush_threshold_mb=1",
+      "--flush_threshold_secs=1",
+      // Disable rowset compact to prevent DRSs being merged because they are too small.
+      "--enable_rowset_compaction=false"
+    ))
+  def testScanWithKeyRange() {
+    upsertRowsWithRowDataSize(table, rowCount * 100, 32 * 1024)
+
+    // Wait for mrs flushed
+    Thread.sleep(5 * 1000)
+
+    kuduOptions = Map(
+      "kudu.table" -> tableName,
+      "kudu.master" -> harness.getMasterAddressesAsString,
+      "kudu.splitSizeBytes" -> "1024")
+
+    // count the number of tasks that end.
+    val actualNumTasks = withJobTaskCounter(ss.sparkContext) { () =>
+      val t = "scanWithKeyRangeTest"
+      sqlContext.read.options(kuduOptions).format("kudu").load.createOrReplaceTempView(t)
+      val results = sqlContext.sql(s"SELECT * FROM $t").collectAsList()
+      assertEquals(rowCount * 100, results.size())
+    }
+    assert(actualNumTasks > 2)
+  }
+
+  @Test
+  @MasterServerConfig(
+    flags = Array(
+      "--mock_table_metrics_for_testing=true",
+      "--on_disk_size_for_testing=1024",
+      "--live_row_count_for_testing=100"
+    ))
+  def testJoinWithTableStatistics(): Unit = {
+    val df = sqlContext.read.options(kuduOptions).format("kudu").load
+
+    // 1. Create two tables.
+    val table1 = "table1"
+    kuduContext.createTable(
+      table1,
+      df.schema,
+      Seq("key"),
+      new CreateTableOptions()
+        .setRangePartitionColumns(List("key").asJava)
+        .setNumReplicas(1))
+    var options1: Map[String, String] =
+      Map("kudu.table" -> table1, "kudu.master" -> harness.getMasterAddressesAsString)
+    df.write.options(options1).mode("append").format("kudu").save
+    val df1 = sqlContext.read.options(options1).format("kudu").load
+    df1.createOrReplaceTempView(table1)
+
+    val table2 = "table2"
+    kuduContext.createTable(
+      table2,
+      df.schema,
+      Seq("key"),
+      new CreateTableOptions()
+        .setRangePartitionColumns(List("key").asJava)
+        .setNumReplicas(1))
+    var options2: Map[String, String] =
+      Map("kudu.table" -> table2, "kudu.master" -> harness.getMasterAddressesAsString)
+    df.write.options(options2).mode("append").format("kudu").save
+    val df2 = sqlContext.read.options(options2).format("kudu").load
+    df2.createOrReplaceTempView(table2)
+
+    // 2. Get the table statistics of each table and verify.
+    val relation1 = kuduRelationFromDataFrame(df1)
+    val relation2 = kuduRelationFromDataFrame(df2)
+    assert(relation1.sizeInBytes == relation2.sizeInBytes)
+    assert(relation1.sizeInBytes == 1024)
+
+    // 3. Test join with table size should be able to broadcast.
+    val sqlStr = s"SELECT * FROM $table1 JOIN $table2 ON $table1.key = $table2.key"
+    var physical = sqlContext.sql(sqlStr).queryExecution.sparkPlan
+    var operators = physical.collect {
+      case j: BroadcastHashJoinExec => j
+    }
+    assert(operators.size == 1)
+
+    // Verify result.
+    var results = sqlContext.sql(sqlStr).collectAsList()
+    assert(results.size() == rowCount)
+  }
+}