Posted to commits@kyuubi.apache.org by ch...@apache.org on 2022/06/01 10:16:17 UTC

[incubator-kyuubi] branch master updated: [KYUUBI #2788] Add excludeDatabases for TPC-H catalogs

This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 05ee19643 [KYUUBI #2788] Add excludeDatabases for TPC-H catalogs
05ee19643 is described below

commit 05ee19643fe688a21e177f5e07847624b6c730d7
Author: jiaoqingbo <11...@qq.com>
AuthorDate: Wed Jun 1 18:16:07 2022 +0800

    [KYUUBI #2788] Add excludeDatabases for TPC-H catalogs
    
    ### _Why are the changes needed?_
    
    fix #2788
    
    ### _How was this patch tested?_
    - [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #2797 from jiaoqingbo/kyuubi-2788.
    
    Closes #2788
    
    94cc1a5a [jiaoqingbo] [KYUUBI #2788] Add excludeDatabases for TPC-H catalogs
    
    Authored-by: jiaoqingbo <11...@qq.com>
    Signed-off-by: Cheng Pan <ch...@apache.org>
---
 .../kyuubi/spark/connector/tpch/TPCHCatalog.scala  |  18 ++-
 .../spark/connector/tpch/TPCHCatalogSuite.scala    | 121 +++++++++++++++------
 2 files changed, 101 insertions(+), 38 deletions(-)
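
For context, a rough usage sketch (not part of this commit): assuming the connector jar is on the classpath and the catalog is registered as "tpch" the same way the updated test suite does, the new excludeDatabases option is passed as a catalog option when the session is created. The exclude list "tiny,sf10" below is just a sample value taken from the new test.

    import org.apache.spark.sql.SparkSession

    // Hypothetical local session; the catalog name "tpch" and the option key
    // "excludeDatabases" follow the test changes in this commit.
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.catalog.tpch", "org.apache.kyuubi.spark.connector.tpch.TPCHCatalog")
      .config("spark.sql.catalog.tpch.excludeDatabases", "tiny,sf10") // databases to hide
      .getOrCreate()

    // Per the new test, the excluded databases no longer show up here.
    spark.sql("SHOW DATABASES IN tpch").show()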

diff --git a/extensions/spark/kyuubi-spark-connector-tpch/src/main/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalog.scala b/extensions/spark/kyuubi-spark-connector-tpch/src/main/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalog.scala
index 909eaf39f..b347fe1d3 100644
--- a/extensions/spark/kyuubi-spark-connector-tpch/src/main/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalog.scala
+++ b/extensions/spark/kyuubi-spark-connector-tpch/src/main/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalog.scala
@@ -21,15 +21,16 @@ import java.util
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException}
 import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange, SupportsNamespaces, Table => SparkTable, TableCatalog, TableChange}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
-class TPCHCatalog extends TableCatalog with SupportsNamespaces {
+class TPCHCatalog extends TableCatalog with SupportsNamespaces with Logging {
 
-  val databases: Array[String] = TPCHSchemaUtils.DATABASES
+  var databases: Array[String] = _
 
   val tables: Array[String] = TPCHSchemaUtils.BASE_TABLES.map(_.getTableName)
 
@@ -40,8 +41,19 @@ class TPCHCatalog extends TableCatalog with SupportsNamespaces {
   override def name: String = _name
 
   override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
-    this.options = options
     this._name = name
+    this.options = options
+    val uncheckedExcludeDatabases = options.getOrDefault("excludeDatabases", "")
+      .split(",").map(_.toLowerCase.trim).filter(_.nonEmpty)
+    val invalidExcludeDatabases = uncheckedExcludeDatabases diff TPCHSchemaUtils.DATABASES
+    if (invalidExcludeDatabases.nonEmpty) {
+      logWarning(
+        s"""Ignore unknown databases ${invalidExcludeDatabases.mkString(", ")} in excluding
+           |list. All known databases are ${TPCHSchemaUtils.BASE_TABLES.mkString(", ")}
+           |""".stripMargin)
+    }
+    val excludeDatabase = uncheckedExcludeDatabases diff invalidExcludeDatabases
+    this.databases = TPCHSchemaUtils.DATABASES diff excludeDatabase
   }
 
   override def listTables(namespace: Array[String]): Array[Identifier] = namespace match {
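
The exclusion handling in initialize above reduces to two collection diffs. A minimal standalone sketch of that behavior, using a made-up three-element stand-in for TPCHSchemaUtils.DATABASES:

    // Stand-in values only; the real list comes from TPCHSchemaUtils.DATABASES.
    val known = Array("tiny", "sf1", "sf10")
    val requested = "TINY, sf1 , bogus".split(",").map(_.toLowerCase.trim).filter(_.nonEmpty)
    // requested == Array("tiny", "sf1", "bogus"): case-insensitive, whitespace trimmed
    val invalid = requested diff known    // Array("bogus"), warned about and dropped
    val excluded = requested diff invalid // Array("tiny", "sf1")
    val visible = known diff excluded     // Array("sf10"), what the catalog will expose
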
diff --git a/extensions/spark/kyuubi-spark-connector-tpch/src/test/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalogSuite.scala b/extensions/spark/kyuubi-spark-connector-tpch/src/test/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalogSuite.scala
index 77c49ae93..abb3581cb 100644
--- a/extensions/spark/kyuubi-spark-connector-tpch/src/test/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalogSuite.scala
+++ b/extensions/spark/kyuubi-spark-connector-tpch/src/test/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalogSuite.scala
@@ -17,10 +17,12 @@
 
 package org.apache.kyuubi.spark.connector.tpch
 
+import org.apache.spark.SparkConf
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 import org.apache.kyuubi.KyuubiFunSuite
+import org.apache.kyuubi.spark.connector.common.LocalSparkSession.withSparkSession
 import org.apache.kyuubi.spark.connector.common.SparkUtils
 
 class TPCHCatalogSuite extends KyuubiFunSuite {
@@ -44,52 +46,101 @@ class TPCHCatalogSuite extends KyuubiFunSuite {
   }
 
   test("supports namespaces") {
-    spark.sql("use tpch")
-    assert(spark.sql(s"SHOW DATABASES").collect().length == 12)
-    assert(spark.sql(s"SHOW NAMESPACES IN tpch.sf1").collect().length == 0)
+    val sparkConf = new SparkConf()
+      .setMaster("local[*]")
+      .set("spark.ui.enabled", "false")
+      .set("spark.sql.catalogImplementation", "in-memory")
+      .set("spark.sql.catalog.tpch", classOf[TPCHCatalog].getName)
+      .set("spark.sql.cbo.enabled", "true")
+      .set("spark.sql.cbo.planStats.enabled", "true")
+    withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+      spark.sql("USE tpch")
+      assert(spark.sql(s"SHOW DATABASES").collect().length == 12)
+      assert(spark.sql(s"SHOW NAMESPACES IN tpch.sf1").collect().length == 0)
+    }
+  }
+
+  test("exclude databases") {
+    Seq(
+      "TINY,sf10" -> Seq("tiny", "sf10"),
+      "sf1 , " -> Seq("sf1"),
+      "none" -> Seq.empty[String]).foreach { case (confValue, expectedExcludeDatabases) =>
+      val sparkConf = new SparkConf().setMaster("local[*]")
+        .set("spark.ui.enabled", "false")
+        .set("spark.sql.catalogImplementation", "in-memory")
+        .set("spark.sql.catalog.tpch", classOf[TPCHCatalog].getName)
+        .set("spark.sql.catalog.tpch.excludeDatabases", confValue)
+      withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+        spark.sql("USE tpch")
+        assert(
+          spark.sql(s"SHOW DATABASES").collect.map(_.getString(0)).sorted ===
+            (TPCHSchemaUtils.DATABASES diff expectedExcludeDatabases).sorted)
+      }
+    }
   }
 
   test("tpch.tiny count") {
-    assert(spark.table("tpch.tiny.customer").count === 1500)
-    assert(spark.table("tpch.tiny.orders").count === 15000)
-    assert(spark.table("tpch.tiny.lineitem").count === 60175)
-    assert(spark.table("tpch.tiny.part").count === 2000)
-    assert(spark.table("tpch.tiny.partsupp").count === 8000)
-    assert(spark.table("tpch.tiny.supplier").count === 100)
-    assert(spark.table("tpch.tiny.nation").count === 25)
-    assert(spark.table("tpch.tiny.region").count === 5)
+    val sparkConf = new SparkConf()
+      .setMaster("local[*]")
+      .set("spark.ui.enabled", "false")
+      .set("spark.sql.catalogImplementation", "in-memory")
+      .set("spark.sql.catalog.tpch", classOf[TPCHCatalog].getName)
+      .set("spark.sql.cbo.enabled", "true")
+      .set("spark.sql.cbo.planStats.enabled", "true")
+    withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+      assert(spark.table("tpch.tiny.customer").count === 1500)
+      assert(spark.table("tpch.tiny.orders").count === 15000)
+      assert(spark.table("tpch.tiny.lineitem").count === 60175)
+      assert(spark.table("tpch.tiny.part").count === 2000)
+      assert(spark.table("tpch.tiny.partsupp").count === 8000)
+      assert(spark.table("tpch.tiny.supplier").count === 100)
+      assert(spark.table("tpch.tiny.nation").count === 25)
+      assert(spark.table("tpch.tiny.region").count === 5)
+    }
   }
 
   test("tpch.sf1 stats") {
-    def assertStats(tableName: String, sizeInBytes: BigInt, rowCount: BigInt): Unit = {
-      val stats = spark.table(tableName).queryExecution.analyzed.stats
-      assert(stats.sizeInBytes == sizeInBytes)
-      // stats.rowCount only has value after SPARK-33954
-      if (SparkUtils.isSparkVersionAtLeast("3.2")) {
-        assert(stats.rowCount.contains(rowCount), tableName)
+    val sparkConf = new SparkConf()
+      .setMaster("local[*]")
+      .set("spark.ui.enabled", "false")
+      .set("spark.sql.catalogImplementation", "in-memory")
+      .set("spark.sql.catalog.tpch", classOf[TPCHCatalog].getName)
+      .set("spark.sql.cbo.enabled", "true")
+      .set("spark.sql.cbo.planStats.enabled", "true")
+    withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+      def assertStats(tableName: String, sizeInBytes: BigInt, rowCount: BigInt): Unit = {
+        val stats = spark.table(tableName).queryExecution.analyzed.stats
+        assert(stats.sizeInBytes == sizeInBytes)
+        // stats.rowCount only has value after SPARK-33954
+        if (SparkUtils.isSparkVersionAtLeast("3.2")) {
+          assert(stats.rowCount.contains(rowCount), tableName)
+        }
       }
-    }
-
-    assertStats("tpch.sf1.customer", 26850000, 150000)
-    assertStats("tpch.sf1.orders", 156000000, 1500000)
-    assertStats("tpch.sf1.lineitem", 672136080, 6001215)
-    assertStats("tpch.sf1.part", 31000000, 200000)
-    assertStats("tpch.sf1.partsupp", 115200000, 800000)
-    assertStats("tpch.sf1.supplier", 1590000, 10000)
-    assertStats("tpch.sf1.nation", 3200, 25)
-    assertStats("tpch.sf1.region", 620, 5)
+      assertStats("tpch.sf1.customer", 26850000, 150000)
+      assertStats("tpch.sf1.orders", 156000000, 1500000)
+      assertStats("tpch.sf1.lineitem", 672136080, 6001215)
+      assertStats("tpch.sf1.part", 31000000, 200000)
+      assertStats("tpch.sf1.partsupp", 115200000, 800000)
+      assertStats("tpch.sf1.supplier", 1590000, 10000)
+      assertStats("tpch.sf1.nation", 3200, 25)
+      assertStats("tpch.sf1.region", 620, 5)
 
+    }
   }
 
   test("nonexistent table") {
-    val exception = intercept[AnalysisException] {
-      spark.table("tpch.sf1.nonexistent_table")
+    val sparkConf = new SparkConf()
+      .setMaster("local[*]")
+      .set("spark.ui.enabled", "false")
+      .set("spark.sql.catalogImplementation", "in-memory")
+      .set("spark.sql.catalog.tpch", classOf[TPCHCatalog].getName)
+      .set("spark.sql.cbo.enabled", "true")
+      .set("spark.sql.cbo.planStats.enabled", "true")
+    withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+      val exception = intercept[AnalysisException] {
+        spark.table("tpch.sf1.nonexistent_table")
+      }
+      assert(exception.message === "Table or view not found: tpch.sf1.nonexistent_table")
     }
-    assert(exception.message === "Table or view not found: tpch.sf1.nonexistent_table")
-  }
-
-  override def afterAll(): Unit = {
-    super.afterAll()
-    spark.stop()
   }
 }
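
As a quick follow-up to the usage sketch above (same hypothetical session, with excludeDatabases = "tiny,sf10"), the effect can be checked the same way the new exclude-databases test does:

    val dbs = spark.sql("SHOW DATABASES IN tpch").collect().map(_.getString(0))
    assert(!dbs.contains("tiny") && !dbs.contains("sf10"))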