Posted to commits@spark.apache.org by ma...@apache.org on 2015/02/13 03:08:11 UTC

spark git commit: [SPARK-3299][SQL] Public API in SQLContext to list tables

Repository: spark
Updated Branches:
  refs/heads/master c025a4688 -> 1d0596a16


[SPARK-3299][SQL] Public API in SQLContext to list tables

https://issues.apache.org/jira/browse/SPARK-3299

Author: Yin Huai <yh...@databricks.com>

Closes #4547 from yhuai/tables and squashes the following commits:

6c8f92e [Yin Huai] Add tableNames.
acbb281 [Yin Huai] Update Python test.
7793dcb [Yin Huai] Fix scala test.
572870d [Yin Huai] Address comments.
aba2e88 [Yin Huai] Format.
12c86df [Yin Huai] Add tables() to SQLContext to return a DataFrame containing existing tables.
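
For reference, a minimal sketch of how the new API can be exercised from user
code. The SparkContext, table name, and sample data below are illustrative
assumptions, not part of this patch:

    // Assumes an existing SparkContext `sc`.
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize((1 to 3).map(i => (i, s"str$i"))).toDataFrame("key", "value")
    df.registerTempTable("table1")

    // A DataFrame with columns tableName (StringType) and isTemporary (BooleanType).
    sqlContext.tables().filter("tableName = 'table1'").show()

    // A plain Array[String] of table names in the current database.
    sqlContext.tableNames().foreach(println)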


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1d0596a1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1d0596a1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1d0596a1

Branch: refs/heads/master
Commit: 1d0596a16e1d3add2631f5d8169aeec2876a1362
Parents: c025a46
Author: Yin Huai <yh...@databricks.com>
Authored: Thu Feb 12 18:08:01 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Thu Feb 12 18:08:01 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql/context.py                   | 34 +++++++++
 .../spark/sql/catalyst/analysis/Catalog.scala   | 37 ++++++++++
 .../scala/org/apache/spark/sql/SQLContext.scala | 36 +++++++++
 .../org/apache/spark/sql/ListTablesSuite.scala  | 76 +++++++++++++++++++
 .../spark/sql/hive/HiveMetastoreCatalog.scala   |  5 ++
 .../apache/spark/sql/hive/ListTablesSuite.scala | 77 ++++++++++++++++++++
 6 files changed, 265 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1d0596a1/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index db4bcbe..082f1b6 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -621,6 +621,40 @@ class SQLContext(object):
         """
         return DataFrame(self._ssql_ctx.table(tableName), self)
 
+    def tables(self, dbName=None):
+        """Returns a DataFrame containing names of tables in the given database.
+
+        If `dbName` is not specified, the current database will be used.
+
+        The returned DataFrame has two columns: tableName and isTemporary
+        (a BooleanType column indicating whether the table is temporary).
+
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
+        >>> df2 = sqlCtx.tables()
+        >>> df2.filter("tableName = 'table1'").first()
+        Row(tableName=u'table1', isTemporary=True)
+        """
+        if dbName is None:
+            return DataFrame(self._ssql_ctx.tables(), self)
+        else:
+            return DataFrame(self._ssql_ctx.tables(dbName), self)
+
+    def tableNames(self, dbName=None):
+        """Returns a list of names of tables in the database `dbName`.
+
+        If `dbName` is not specified, the current database will be used.
+
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
+        >>> "table1" in sqlCtx.tableNames()
+        True
+        >>> "table1" in sqlCtx.tableNames("db")
+        True
+        """
+        if dbName is None:
+            return [name for name in self._ssql_ctx.tableNames()]
+        else:
+            return [name for name in self._ssql_ctx.tableNames(dbName)]
+
     def cacheTable(self, tableName):
         """Caches the specified table in-memory."""
         self._ssql_ctx.cacheTable(tableName)

http://git-wip-us.apache.org/repos/asf/spark/blob/1d0596a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index df8d03b..f57eab2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -34,6 +34,12 @@ trait Catalog {
       tableIdentifier: Seq[String],
       alias: Option[String] = None): LogicalPlan
 
+  /**
+   * Returns tuples of (tableName, isTemporary) for all tables in the given database.
+   * isTemporary is a Boolean value indicating whether the table is temporary.
+   */
+  def getTables(databaseName: Option[String]): Seq[(String, Boolean)]
+
   def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit
 
   def unregisterTable(tableIdentifier: Seq[String]): Unit
@@ -101,6 +107,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
     // properly qualified with this alias.
     alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
   }
+
+  override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+    tables.map {
+      case (name, _) => (name, true)
+    }.toSeq
+  }
 }
 
 /**
@@ -137,6 +149,27 @@ trait OverrideCatalog extends Catalog {
     withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
   }
 
+  abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+    val dbName = if (!caseSensitive) {
+      if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None
+    } else {
+      databaseName
+    }
+
+    val temporaryTables = overrides.filter {
+      // If a temporary table does not have an associated database, we should return its name.
+      case ((None, _), _) => true
+      // If a temporary table does have an associated database, we should return it if the database
+      // matches the given database name.
+      case ((db: Some[String], _), _) if db == dbName => true
+      case _ => false
+    }.map {
+      case ((_, tableName), _) => (tableName, true)
+    }.toSeq
+
+    temporaryTables ++ super.getTables(databaseName)
+  }
+
   override def registerTable(
       tableIdentifier: Seq[String],
       plan: LogicalPlan): Unit = {
@@ -172,6 +205,10 @@ object EmptyCatalog extends Catalog {
     throw new UnsupportedOperationException
   }
 
+  override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+    throw new UnsupportedOperationException
+  }
+
   def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
     throw new UnsupportedOperationException
   }
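
The `abstract override` on getTables above is Scala's stackable-trait pattern:
OverrideCatalog prepends its in-memory temporary tables, then delegates through
`super` to whatever concrete catalog it is mixed into. A stripped-down
illustration of the mechanism, with invented names that are not part of the
patch:

    trait TableSource { def getNames: Seq[String] }

    class MetastoreSource extends TableSource {
      override def getNames: Seq[String] = Seq("permanent_table")
    }

    trait TempTableOverride extends TableSource {
      // `abstract override` defers to the concrete implementation that sits
      // below this trait in the class linearization, reached via `super`.
      abstract override def getNames: Seq[String] = "temp_table" +: super.getNames
    }

    val source = new MetastoreSource with TempTableOverride
    source.getNames  // Seq("temp_table", "permanent_table")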

http://git-wip-us.apache.org/repos/asf/spark/blob/1d0596a1/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8aae222..0f8af75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -774,6 +774,42 @@ class SQLContext(@transient val sparkContext: SparkContext)
   def table(tableName: String): DataFrame =
     DataFrame(this, catalog.lookupRelation(Seq(tableName)))
 
+  /**
+   * Returns a [[DataFrame]] containing names of existing tables in the current database.
+   * The returned DataFrame has two columns: tableName and isTemporary (a BooleanType column
+   * indicating whether the table is temporary).
+   */
+  def tables(): DataFrame = {
+    createDataFrame(catalog.getTables(None)).toDataFrame("tableName", "isTemporary")
+  }
+
+  /**
+   * Returns a [[DataFrame]] containing names of existing tables in the given database.
+   * The returned DataFrame has two columns: tableName and isTemporary (a BooleanType column
+   * indicating whether the table is temporary).
+   */
+  def tables(databaseName: String): DataFrame = {
+    createDataFrame(catalog.getTables(Some(databaseName))).toDataFrame("tableName", "isTemporary")
+  }
+
+  /**
+   * Returns an array of names of tables in the current database.
+   */
+  def tableNames(): Array[String] = {
+    catalog.getTables(None).map {
+      case (tableName, _) => tableName
+    }.toArray
+  }
+
+  /**
+   * Returns an array of names of tables in the given database.
+   */
+  def tableNames(databaseName: String): Array[String] = {
+    catalog.getTables(Some(databaseName)).map {
+      case (tableName, _) => tableName
+    }.toArray
+  }
+
   protected[sql] class SparkPlanner extends SparkStrategies {
     val sparkContext: SparkContext = self.sparkContext
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1d0596a1/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
new file mode 100644
index 0000000..5fc3534
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -0,0 +1,76 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
+
+class ListTablesSuite extends QueryTest with BeforeAndAfter {
+
+  import org.apache.spark.sql.test.TestSQLContext.implicits._
+
+  val df =
+    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDataFrame("key", "value")
+
+  before {
+    df.registerTempTable("ListTablesSuiteTable")
+  }
+
+  after {
+    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+  }
+
+  test("get all tables") {
+    checkAnswer(
+      tables().filter("tableName = 'ListTablesSuiteTable'"),
+      Row("ListTablesSuiteTable", true))
+
+    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+    assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+  }
+
+  test("getting all Tables with a database name has no impact on returned table names") {
+    checkAnswer(
+      tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
+      Row("ListTablesSuiteTable", true))
+
+    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+    assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+  }
+
+  test("query the returned DataFrame of tables") {
+    val tableDF = tables()
+    val schema = StructType(
+      StructField("tableName", StringType, true) ::
+      StructField("isTemporary", BooleanType, false) :: Nil)
+    assert(schema === tableDF.schema)
+
+    tableDF.registerTempTable("tables")
+    checkAnswer(
+      sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
+      Row(true, "ListTablesSuiteTable")
+    )
+    checkAnswer(
+      tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
+      Row("tables", true))
+    dropTempTable("tables")
+  }
+}
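
The schema asserted in the test above falls out of how tables() is built:
createDataFrame infers a nullable StringType and a non-nullable BooleanType
from Seq[(String, Boolean)], and toDataFrame only renames the columns. A
sketch of that step in isolation, assuming a SQLContext in scope as
`sqlContext`:

    // Each (name, isTemporary) pair from the catalog becomes one row.
    val pairs = Seq(("table1", true), ("table2", false))
    val tableDF = sqlContext.createDataFrame(pairs).toDataFrame("tableName", "isTemporary")
    tableDF.printSchema()
    // root
    //  |-- tableName: string (nullable = true)
    //  |-- isTemporary: boolean (nullable = false)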

http://git-wip-us.apache.org/repos/asf/spark/blob/1d0596a1/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index c78369d..eb1ee54 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -198,6 +198,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
     }
   }
 
+  override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+    val dbName = databaseName.getOrElse(hive.sessionState.getCurrentDatabase)
+    client.getAllTables(dbName).map(tableName => (tableName, false))
+  }
+
   /**
    * Create table with specified database, table name, table description and schema
    * @param databaseName Database Name
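
With this in place, metastore-backed tables surface through tables() with
isTemporary = false, while tables registered in the override catalog keep
isTemporary = true, so the two sources can be separated in one query. A
sketch, assuming an existing HiveContext bound to `hiveContext`:

    // Tables stored in the Hive metastore for the given database.
    hiveContext.tables("default").filter("isTemporary = false").show()

    // Temporary tables registered in this session.
    hiveContext.tables().filter("isTemporary = true").show()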

http://git-wip-us.apache.org/repos/asf/spark/blob/1d0596a1/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
new file mode 100644
index 0000000..068aa03
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
@@ -0,0 +1,77 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.hive
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.Row
+
+class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
+
+  import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+  val df =
+    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDataFrame("key", "value")
+
+  override def beforeAll(): Unit = {
+    // The catalog in HiveContext is case-insensitive.
+    catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan)
+    catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan)
+    sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
+    sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
+    sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
+  }
+
+  override def afterAll(): Unit = {
+    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+    catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"))
+    sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable")
+    sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable")
+    sql("DROP DATABASE IF EXISTS ListTablesSuiteDB")
+  }
+
+  test("get all tables of current database") {
+    val allTables = tables()
+    // We are using default DB.
+    checkAnswer(
+      allTables.filter("tableName = 'listtablessuitetable'"),
+      Row("listtablessuitetable", true))
+    assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0)
+    checkAnswer(
+      allTables.filter("tableName = 'hivelisttablessuitetable'"),
+      Row("hivelisttablessuitetable", false))
+    assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0)
+  }
+
+  test("getting all tables with a database name") {
+    val allTables = tables("ListTablesSuiteDB")
+    checkAnswer(
+      allTables.filter("tableName = 'listtablessuitetable'"),
+      Row("listtablessuitetable", true))
+    checkAnswer(
+      allTables.filter("tableName = 'indblisttablessuitetable'"),
+      Row("indblisttablessuitetable", true))
+    assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0)
+    checkAnswer(
+      allTables.filter("tableName = 'hiveindblisttablessuitetable'"),
+      Row("hiveindblisttablessuitetable", false))
+  }
+}
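
One detail these tests rely on: the HiveContext catalog is case-insensitive,
so a table registered as ListTablesSuiteTable is stored, and listed, in lower
case. A sketch of the observable behavior, assuming a HiveContext
`hiveContext` and a DataFrame `df` as in the suite above:

    df.registerTempTable("MyTempTable")
    hiveContext.tableNames().contains("mytemptable")  // true
    hiveContext.tableNames().contains("MyTempTable")  // false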


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org