Posted to commits@spark.apache.org by we...@apache.org on 2022/07/04 13:00:09 UTC

[spark] branch master updated: [SPARK-39649][PYTHON] Make listDatabases / getDatabase / listColumns / refreshTable in PySpark support 3-layer-namespace

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e7a571532e [SPARK-39649][PYTHON] Make listDatabases / getDatabase / listColumns / refreshTable in PySpark support 3-layer-namespace
6e7a571532e is described below

commit 6e7a571532e7c2e76725aa5651310b6b6a6de1be
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Mon Jul 4 20:59:54 2022 +0800

    [SPARK-39649][PYTHON] Make listDatabases / getDatabase / listColumns / refreshTable in PySpark support 3-layer-namespace
    
    ### What changes were proposed in this pull request?
    
    Make the following commands in PySpark support the 3-layer-namespace (a usage sketch follows the list):
    
    - ~setCurrentDatabase~ (skipped, per the comments at https://github.com/apache/spark/pull/36969/files#diff-e6c98e62b4d35c54acd0481006733a84d6a12dec0a59b3d3024e103d708fae88R70-R71)
    - listDatabases
    - getDatabase
    - listColumns
    - refreshTable
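
    For illustration, a minimal sketch of the calls this change enables,
    assuming a live `SparkSession` named `spark`, the built-in `spark_catalog`,
    and a placeholder table `default.tab1`:

        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()
        spark.sql("CREATE TABLE IF NOT EXISTS tab1 (name STRING) USING parquet")

        # listDatabases: each Database now carries its catalog.
        for db in spark.catalog.listDatabases():
            print(db.catalog, db.name)

        # getDatabase: accepts a catalog-qualified database name.
        db = spark.catalog.getDatabase("spark_catalog.default")

        # listColumns / refreshTable: accept fully qualified table names.
        cols = spark.catalog.listColumns("spark_catalog.default.tab1")
        spark.catalog.refreshTable("spark_catalog.default.tab1")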
    
    ### Why are the changes needed?
    To support the 3-layer-namespace in the PySpark catalog API, matching the JVM-side `Catalog` it delegates to.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. A new API `Catalog.getDatabase` is added, `Database` gains a `catalog` field, and passing `dbName` to `listColumns` is deprecated (it now emits a `FutureWarning`).
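
    As a doctest-style sketch of the new shape (mirroring the examples this
    patch adds to `catalog.py`; a live `spark` session is assumed):

        >>> spark.catalog.getDatabase("default")
        Database(name='default', catalog=None, description='default database', ...
        >>> spark.catalog.getDatabase("spark_catalog.default")
        Database(name='default', catalog='spark_catalog', description='default database', ...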
    
    ### How was this patch tested?
    Updated unit tests in `python/pyspark/sql/tests/test_catalog.py`.
    
    Closes #37039 from zhengruifeng/py_3l_remaining.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 python/pyspark/sql/catalog.py            | 61 +++++++++++++++++++++++++++++---
 python/pyspark/sql/tests/test_catalog.py | 32 ++++++++++++++++-
 2 files changed, 88 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 6d38a37f6ab..624e3877db0 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -36,6 +36,7 @@ class CatalogMetadata(NamedTuple):
 
 class Database(NamedTuple):
     name: str
+    catalog: Optional[str]
     description: Optional[str]
     locationUri: str
 
@@ -139,11 +140,40 @@ class Catalog:
             jdb = iter.next()
             databases.append(
                 Database(
-                    name=jdb.name(), description=jdb.description(), locationUri=jdb.locationUri()
+                    name=jdb.name(),
+                    catalog=jdb.catalog(),
+                    description=jdb.description(),
+                    locationUri=jdb.locationUri(),
                 )
             )
         return databases
 
+    def getDatabase(self, dbName: str) -> Database:
+        """Get the database with the specified name.
+        This throws an AnalysisException when the database cannot be found.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        dbName : str
+            name of the database to get.
+
+        Examples
+        --------
+        >>> spark.catalog.getDatabase("default")
+        Database(name='default', catalog=None, description='default database', ...
+        >>> spark.catalog.getDatabase("spark_catalog.default")
+        Database(name='default', catalog='spark_catalog', description='default database', ...
+        """
+        jdb = self._jcatalog.getDatabase(dbName)
+        return Database(
+            name=jdb.name(),
+            catalog=jdb.catalog(),
+            description=jdb.description(),
+            locationUri=jdb.locationUri(),
+        )
+
     def databaseExists(self, dbName: str) -> bool:
         """Check if the database with the specified name exists.
 
@@ -309,14 +339,33 @@ class Catalog:
 
         .. versionadded:: 2.0.0
 
+        Parameters
+        ----------
+        tableName : str
+            name of the table to list columns of.
+        dbName : str, optional
+            name of the database to find the table in.
+
+            .. deprecated:: 3.4.0
+
+        .. versionchanged:: 3.4.0
+           Allowed ``tableName`` to be qualified with catalog name when ``dbName`` is None.
+
          Notes
          -----
          the order of arguments here is different from that of its JVM counterpart
          because Python does not support method overloading.
         """
         if dbName is None:
-            dbName = self.currentDatabase()
-        iter = self._jcatalog.listColumns(dbName, tableName).toLocalIterator()
+            iter = self._jcatalog.listColumns(tableName).toLocalIterator()
+        else:
+            warnings.warn(
+                "`dbName` has been deprecated since Spark 3.4 and might be removed in "
+                "a future version. Use listColumns(`dbName.tableName`) instead.",
+                FutureWarning,
+            )
+            iter = self._jcatalog.listColumns(dbName, tableName).toLocalIterator()
+
         columns = []
         while iter.hasNext():
             jcolumn = iter.next()
@@ -590,7 +639,11 @@ class Catalog:
 
     @since(2.0)
     def refreshTable(self, tableName: str) -> None:
-        """Invalidates and refreshes all the cached data and metadata of the given table."""
+        """Invalidates and refreshes all the cached data and metadata of the given table.
+
+        .. versionchanged:: 3.4.0
+           Allowed ``tableName`` to be qualified with catalog name.
+        """
         self._jcatalog.refreshTable(tableName)
 
     @since("2.1.1")
diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py
index 53c2489015a..49d96a9b7aa 100644
--- a/python/pyspark/sql/tests/test_catalog.py
+++ b/python/pyspark/sql/tests/test_catalog.py
@@ -53,6 +53,14 @@ class CatalogTests(ReusedSQLTestCase):
             self.assertTrue(spark.catalog.databaseExists("spark_catalog.some_db"))
             self.assertFalse(spark.catalog.databaseExists("spark_catalog.some_db2"))
 
+    def test_get_database(self):
+        spark = self.spark
+        with self.database("some_db"):
+            spark.sql("CREATE DATABASE some_db")
+            db = spark.catalog.getDatabase("spark_catalog.some_db")
+            self.assertEqual(db.name, "some_db")
+            self.assertEqual(db.catalog, "spark_catalog")
+
     def test_list_tables(self):
         from pyspark.sql.catalog import Table
 
@@ -245,7 +253,9 @@ class CatalogTests(ReusedSQLTestCase):
                 spark.sql(
                     "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet"
                 )
-                columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
+                columns = sorted(
+                    spark.catalog.listColumns("spark_catalog.default.tab1"), key=lambda c: c.name
+                )
                 columnsDefault = sorted(
                     spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name
                 )
@@ -352,6 +362,26 @@ class CatalogTests(ReusedSQLTestCase):
                 self.assertEqual(spark.catalog.getTable("default.tab1").catalog, "spark_catalog")
                 self.assertEqual(spark.catalog.getTable("spark_catalog.default.tab1").name, "tab1")
 
+    def test_refresh_table(self):
+        import os
+        import tempfile
+
+        spark = self.spark
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.table("my_tab"):
+                spark.sql(
+                    "CREATE TABLE my_tab (col STRING) USING TEXT LOCATION '{}'".format(tmp_dir)
+                )
+                spark.sql("INSERT INTO my_tab SELECT 'abc'")
+                spark.catalog.cacheTable("my_tab")
+                self.assertEqual(spark.table("my_tab").count(), 1)
+
+                os.system("rm -rf {}/*".format(tmp_dir))
+                self.assertEqual(spark.table("my_tab").count(), 1)
+
+                spark.catalog.refreshTable("spark_catalog.default.my_tab")
+                self.assertEqual(spark.table("my_tab").count(), 0)
+
 
 if __name__ == "__main__":
     import unittest
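
For context, a minimal sketch of the two `listColumns` paths this patch
introduces (same assumptions as the sketches above: a live `spark` session and
a placeholder table `default.tab1`):

    import warnings

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark.sql("CREATE TABLE IF NOT EXISTS tab1 (name STRING) USING parquet")

    # Preferred: qualify the table name itself; no warning is raised.
    columns = spark.catalog.listColumns("spark_catalog.default.tab1")

    # Deprecated since 3.4: a separate dbName argument triggers a FutureWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        columns = spark.catalog.listColumns("tab1", "default")
    assert any(issubclass(w.category, FutureWarning) for w in caught)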

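And a condensed, standalone version of the cache-invalidation flow that the new
`test_refresh_table` exercises (table name and location are placeholders):

    import os
    import tempfile

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    with tempfile.TemporaryDirectory() as tmp_dir:
        spark.sql(f"CREATE TABLE my_tab (col STRING) USING TEXT LOCATION '{tmp_dir}'")
        spark.sql("INSERT INTO my_tab SELECT 'abc'")
        spark.catalog.cacheTable("my_tab")

        # Delete the underlying files; the cached scan still reports one row.
        os.system(f"rm -rf {tmp_dir}/*")
        assert spark.table("my_tab").count() == 1

        # Refreshing by fully qualified name drops the stale cache entry.
        spark.catalog.refreshTable("spark_catalog.default.my_tab")
        assert spark.table("my_tab").count() == 0

    spark.sql("DROP TABLE my_tab")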

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org