Posted to commits@spark.apache.org by we...@apache.org on 2023/06/07 09:37:35 UTC
[spark] branch master updated: [SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 64855fa5582 [SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables
64855fa5582 is described below
commit 64855fa55821332c4913ad06b2cf902d4c565a94
Author: Jiaan Geng <be...@163.com>
AuthorDate: Wed Jun 7 17:37:07 2023 +0800
[SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables
### What changes were proposed in this pull request?
Currently, the SQL syntax `SHOW TABLES LIKE pattern` supports an optional pattern ('*' matches any sequence of characters, '|' separates alternatives), so that only the matching tables are returned.
But `Catalog.listTables` lacks this capability in both the Catalog API and the Connect Catalog API.
In fact, the optional pattern is very useful.
This PR also extracts the common `resolveNamespace` helper to clean up the duplicated code.
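For illustration, a minimal usage sketch of the new overload (run in e.g. `spark-shell`, so a `spark` session is in scope; the table names are hypothetical):

```scala
// Hypothetical tables, just so the pattern has something to match.
spark.sql("CREATE TABLE parquet_tab (id INT) USING parquet")
spark.sql("CREATE TABLE orc_tab (id INT) USING orc")

// Existing API: every table/view in the database.
spark.catalog.listTables("default").show()

// New overload: filter with a SHOW TABLES LIKE-style pattern ('*' is a wildcard).
spark.catalog.listTables("default", "parquet*").show() // matches only parquet_tab
```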
### Why are the changes needed?
This PR adds an optional pattern parameter to `Catalog.listTables`, matching the existing `SHOW TABLES LIKE` capability.
### Does this PR introduce _any_ user-facing change?
'No'.
New feature.
### How was this patch tested?
New test cases.
Closes #41461 from beliefer/SPARK-43961.
Authored-by: Jiaan Geng <be...@163.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../org/apache/spark/sql/catalog/Catalog.scala | 9 ++
.../apache/spark/sql/internal/CatalogImpl.scala | 13 +++
.../scala/org/apache/spark/sql/CatalogSuite.scala | 8 ++
.../src/main/protobuf/spark/connect/catalog.proto | 2 +
.../sql/connect/planner/SparkConnectPlanner.scala | 9 +-
project/MimaExcludes.scala | 4 +-
python/pyspark/sql/catalog.py | 24 ++++-
python/pyspark/sql/connect/catalog.py | 6 +-
python/pyspark/sql/connect/plan.py | 5 +-
python/pyspark/sql/connect/proto/catalog_pb2.py | 100 ++++++++++-----------
python/pyspark/sql/connect/proto/catalog_pb2.pyi | 33 ++++++-
python/pyspark/sql/tests/test_catalog.py | 48 ++++++++++
.../org/apache/spark/sql/catalog/Catalog.scala | 10 +++
.../apache/spark/sql/internal/CatalogImpl.scala | 65 ++++++++------
.../apache/spark/sql/internal/CatalogSuite.scala | 13 ++-
15 files changed, 259 insertions(+), 90 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 363f895db20..0ac704e68e6 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -76,6 +76,15 @@ abstract class Catalog {
@throws[AnalysisException]("database does not exist")
def listTables(dbName: String): Dataset[Table]
+ /**
+ * Returns a list of tables/views in the specified database (namespace) whose name matches the
+ * specified pattern (the name can be qualified with catalog). This includes all temporary views.
+ *
+ * @since 3.5.0
+ */
+ @throws[AnalysisException]("database does not exist")
+ def listTables(dbName: String, pattern: String): Dataset[Table]
+
/**
* Returns a list of functions registered in the current database (namespace). This includes all
* temporary functions.
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index c2ed7f4e19e..95a3332cfc2 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -101,6 +101,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
}
}
+ /**
+ * Returns a list of tables/views in the specified database (namespace) whose name matches the
+ * specified pattern (the name can be qualified with catalog). This includes all temporary views.
+ *
+ * @since 3.5.0
+ */
+ @throws[AnalysisException]("database does not exist")
+ def listTables(dbName: String, pattern: String): Dataset[Table] = {
+ sparkSession.newDataset(CatalogImpl.tableEncoder) { builder =>
+ builder.getCatalogBuilder.getListTablesBuilder.setDbName(dbName).setPattern(pattern)
+ }
+ }
+
/**
* Returns a list of functions registered in the current database (namespace). This includes all
* temporary functions.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala
index 396f7214c04..671f6ac4051 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala
@@ -126,6 +126,14 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper {
parquetTableName,
orcTableName,
jsonTableName))
+ assert(
+ spark.catalog
+ .listTables(spark.catalog.currentDatabase, "par*")
+ .collect()
+ .map(_.name)
+ .toSet == Set(parquetTableName))
+ assert(
+ spark.catalog.listTables(spark.catalog.currentDatabase, "txt*").collect().isEmpty)
}
assert(spark.catalog.tableExists(parquetTableName))
assert(!spark.catalog.tableExists(orcTableName))
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/catalog.proto b/connector/connect/common/src/main/protobuf/spark/connect/catalog.proto
index 57d75ee4a42..97b905da7c3 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/catalog.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/catalog.proto
@@ -77,6 +77,8 @@ message ListDatabases {
message ListTables {
// (Optional)
optional string db_name = 1;
+ // (Optional) The pattern that the table name needs to match
+ optional string pattern = 2;
}
// See `spark.catalog.listFunctions`
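As a rough client-side sketch of the new field (using the generated protobuf builder API that the planner change below also relies on; both fields are proto3 `optional`, so presence is explicit):

```scala
import org.apache.spark.connect.proto

// Build a ListTables message with the new optional pattern field set.
val listTables = proto.ListTables.newBuilder()
  .setDbName("default") // optional; the server falls back to the current database
  .setPattern("tab*")   // optional; omit to list all tables
  .build()

// proto3 `optional` gives explicit presence, which the planner checks.
assert(listTables.hasDbName && listTables.hasPattern)
```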
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f09a4a4895b..7e642b0bdf6 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2706,7 +2706,14 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformListTables(getListTables: proto.ListTables): LogicalPlan = {
if (getListTables.hasDbName) {
- session.catalog.listTables(getListTables.getDbName).logicalPlan
+ if (getListTables.hasPattern) {
+ session.catalog.listTables(getListTables.getDbName, getListTables.getPattern).logicalPlan
+ } else {
+ session.catalog.listTables(getListTables.getDbName).logicalPlan
+ }
+ } else if (getListTables.hasPattern) {
+ val currentDatabase = session.catalog.currentDatabase
+ session.catalog.listTables(currentDatabase, getListTables.getPattern).logicalPlan
} else {
session.catalog.listTables().logicalPlan
}
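Flattened for readability, the four cases above are equivalent to this sketch (`transformListTablesSketch` and `lt` are hypothetical names; it assumes the planner's scope, where `session` and the package-private `Dataset.logicalPlan` are available):

```scala
// Sketch: the same four-way dispatch as one match expression.
private def transformListTablesSketch(lt: proto.ListTables): LogicalPlan = {
  val c = session.catalog
  ((lt.hasDbName, lt.hasPattern) match {
    case (true, true)   => c.listTables(lt.getDbName, lt.getPattern)
    case (true, false)  => c.listTables(lt.getDbName)
    case (false, true)  => c.listTables(c.currentDatabase, lt.getPattern)
    case (false, false) => c.listTables()
  }).logicalPlan
}
```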
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index d09bc87998c..f22994ed75e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -52,7 +52,9 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.prettyJson"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.jsonValue"),
// [SPARK-43881][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listDatabases
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listDatabases")
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listDatabases"),
+ // [SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listTables")
)
// Default exclude rules
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index c0df6f38dbf..9650affc68a 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -202,7 +202,7 @@ class Catalog:
The pattern that the database name needs to match.
.. versionchanged: 3.5.0
- Added ``pattern`` argument.
+ Adds ``pattern`` argument.
Returns
-------
@@ -307,7 +307,9 @@ class Catalog:
"""
return self._jcatalog.databaseExists(dbName)
- def listTables(self, dbName: Optional[str] = None) -> List[Table]:
+ def listTables(
+ self, dbName: Optional[str] = None, pattern: Optional[str] = None
+ ) -> List[Table]:
"""Returns a list of tables/views in the specified database.
.. versionadded:: 2.0.0
@@ -320,6 +322,12 @@ class Catalog:
.. versionchanged:: 3.4.0
Allow ``dbName`` to be qualified with catalog name.
+ pattern : str
+ The pattern that the table name needs to match.
+
+ .. versionchanged:: 3.5.0
+ Adds ``pattern`` argument.
+
Returns
-------
list
@@ -336,13 +344,23 @@ class Catalog:
>>> spark.catalog.listTables()
[Table(name='test_view', catalog=None, namespace=[], description=None, ...
+ >>> spark.catalog.listTables(pattern="test*")
+ [Table(name='test_view', catalog=None, namespace=[], description=None, ...
+
+ >>> spark.catalog.listTables(pattern="table*")
+ []
+
>>> _ = spark.catalog.dropTempView("test_view")
>>> spark.catalog.listTables()
[]
"""
if dbName is None:
dbName = self.currentDatabase()
- iter = self._jcatalog.listTables(dbName).toLocalIterator()
+
+ if pattern is None:
+ iter = self._jcatalog.listTables(dbName).toLocalIterator()
+ else:
+ iter = self._jcatalog.listTables(dbName, pattern).toLocalIterator()
tables = []
while iter.hasNext():
jtable = iter.next()
diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py
index 790b194c3f8..6766060a7b9 100644
--- a/python/pyspark/sql/connect/catalog.py
+++ b/python/pyspark/sql/connect/catalog.py
@@ -116,8 +116,10 @@ class Catalog:
databaseExists.__doc__ = PySparkCatalog.databaseExists.__doc__
- def listTables(self, dbName: Optional[str] = None) -> List[Table]:
- pdf = self._execute_and_fetch(plan.ListTables(db_name=dbName))
+ def listTables(
+ self, dbName: Optional[str] = None, pattern: Optional[str] = None
+ ) -> List[Table]:
+ pdf = self._execute_and_fetch(plan.ListTables(db_name=dbName, pattern=pattern))
return [
Table(
name=row.iloc[0],
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 79c070101b6..95d7af90f65 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1648,14 +1648,17 @@ class ListDatabases(LogicalPlan):
class ListTables(LogicalPlan):
- def __init__(self, db_name: Optional[str] = None) -> None:
+ def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None:
super().__init__(None)
self._db_name = db_name
+ self._pattern = pattern
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(list_tables=proto.ListTables()))
if self._db_name is not None:
plan.catalog.list_tables.db_name = self._db_name
+ if self._pattern is not None:
+ plan.catalog.list_tables.pattern = self._pattern
return plan
diff --git a/python/pyspark/sql/connect/proto/catalog_pb2.py b/python/pyspark/sql/connect/proto/catalog_pb2.py
index f82b360a4a7..920ffa32444 100644
--- a/python/pyspark/sql/connect/proto/catalog_pb2.py
+++ b/python/pyspark/sql/connect/proto/catalog_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1bspark/connect/catalog.proto\x12\rspark.connect\x1a\x1aspark/connect/common.proto\x1a\x19spark/connect/types.proto"\xc6\x0e\n\x07\x43\x61talog\x12K\n\x10\x63urrent_database\x18\x01 \x01(\x0b\x32\x1e.spark.connect.CurrentDatabaseH\x00R\x0f\x63urrentDatabase\x12U\n\x14set_current_database\x18\x02 \x01(\x0b\x32!.spark.connect.SetCurrentDatabaseH\x00R\x12setCurrentDatabase\x12\x45\n\x0elist_databases\x18\x03 \x01(\x0b\x32\x1c.spark.connect.ListDatabasesH\x00R\rlistDatabases\x12<\n [...]
+ b'\n\x1bspark/connect/catalog.proto\x12\rspark.connect\x1a\x1aspark/connect/common.proto\x1a\x19spark/connect/types.proto"\xc6\x0e\n\x07\x43\x61talog\x12K\n\x10\x63urrent_database\x18\x01 \x01(\x0b\x32\x1e.spark.connect.CurrentDatabaseH\x00R\x0f\x63urrentDatabase\x12U\n\x14set_current_database\x18\x02 \x01(\x0b\x32!.spark.connect.SetCurrentDatabaseH\x00R\x12setCurrentDatabase\x12\x45\n\x0elist_databases\x18\x03 \x01(\x0b\x32\x1c.spark.connect.ListDatabasesH\x00R\rlistDatabases\x12<\n [...]
)
@@ -403,53 +403,53 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_LISTDATABASES._serialized_start = 2032
_LISTDATABASES._serialized_end = 2090
_LISTTABLES._serialized_start = 2092
- _LISTTABLES._serialized_end = 2146
- _LISTFUNCTIONS._serialized_start = 2148
- _LISTFUNCTIONS._serialized_end = 2205
- _LISTCOLUMNS._serialized_start = 2207
- _LISTCOLUMNS._serialized_end = 2293
- _GETDATABASE._serialized_start = 2295
- _GETDATABASE._serialized_end = 2333
- _GETTABLE._serialized_start = 2335
- _GETTABLE._serialized_end = 2418
- _GETFUNCTION._serialized_start = 2420
- _GETFUNCTION._serialized_end = 2512
- _DATABASEEXISTS._serialized_start = 2514
- _DATABASEEXISTS._serialized_end = 2555
- _TABLEEXISTS._serialized_start = 2557
- _TABLEEXISTS._serialized_end = 2643
- _FUNCTIONEXISTS._serialized_start = 2645
- _FUNCTIONEXISTS._serialized_end = 2740
- _CREATEEXTERNALTABLE._serialized_start = 2743
- _CREATEEXTERNALTABLE._serialized_end = 3069
- _CREATEEXTERNALTABLE_OPTIONSENTRY._serialized_start = 2980
- _CREATEEXTERNALTABLE_OPTIONSENTRY._serialized_end = 3038
- _CREATETABLE._serialized_start = 3072
- _CREATETABLE._serialized_end = 3437
- _CREATETABLE_OPTIONSENTRY._serialized_start = 2980
- _CREATETABLE_OPTIONSENTRY._serialized_end = 3038
- _DROPTEMPVIEW._serialized_start = 3439
- _DROPTEMPVIEW._serialized_end = 3482
- _DROPGLOBALTEMPVIEW._serialized_start = 3484
- _DROPGLOBALTEMPVIEW._serialized_end = 3533
- _RECOVERPARTITIONS._serialized_start = 3535
- _RECOVERPARTITIONS._serialized_end = 3585
- _ISCACHED._serialized_start = 3587
- _ISCACHED._serialized_end = 3628
- _CACHETABLE._serialized_start = 3631
- _CACHETABLE._serialized_end = 3763
- _UNCACHETABLE._serialized_start = 3765
- _UNCACHETABLE._serialized_end = 3810
- _CLEARCACHE._serialized_start = 3812
- _CLEARCACHE._serialized_end = 3824
- _REFRESHTABLE._serialized_start = 3826
- _REFRESHTABLE._serialized_end = 3871
- _REFRESHBYPATH._serialized_start = 3873
- _REFRESHBYPATH._serialized_end = 3908
- _CURRENTCATALOG._serialized_start = 3910
- _CURRENTCATALOG._serialized_end = 3926
- _SETCURRENTCATALOG._serialized_start = 3928
- _SETCURRENTCATALOG._serialized_end = 3982
- _LISTCATALOGS._serialized_start = 3984
- _LISTCATALOGS._serialized_end = 4041
+ _LISTTABLES._serialized_end = 2189
+ _LISTFUNCTIONS._serialized_start = 2191
+ _LISTFUNCTIONS._serialized_end = 2248
+ _LISTCOLUMNS._serialized_start = 2250
+ _LISTCOLUMNS._serialized_end = 2336
+ _GETDATABASE._serialized_start = 2338
+ _GETDATABASE._serialized_end = 2376
+ _GETTABLE._serialized_start = 2378
+ _GETTABLE._serialized_end = 2461
+ _GETFUNCTION._serialized_start = 2463
+ _GETFUNCTION._serialized_end = 2555
+ _DATABASEEXISTS._serialized_start = 2557
+ _DATABASEEXISTS._serialized_end = 2598
+ _TABLEEXISTS._serialized_start = 2600
+ _TABLEEXISTS._serialized_end = 2686
+ _FUNCTIONEXISTS._serialized_start = 2688
+ _FUNCTIONEXISTS._serialized_end = 2783
+ _CREATEEXTERNALTABLE._serialized_start = 2786
+ _CREATEEXTERNALTABLE._serialized_end = 3112
+ _CREATEEXTERNALTABLE_OPTIONSENTRY._serialized_start = 3023
+ _CREATEEXTERNALTABLE_OPTIONSENTRY._serialized_end = 3081
+ _CREATETABLE._serialized_start = 3115
+ _CREATETABLE._serialized_end = 3480
+ _CREATETABLE_OPTIONSENTRY._serialized_start = 3023
+ _CREATETABLE_OPTIONSENTRY._serialized_end = 3081
+ _DROPTEMPVIEW._serialized_start = 3482
+ _DROPTEMPVIEW._serialized_end = 3525
+ _DROPGLOBALTEMPVIEW._serialized_start = 3527
+ _DROPGLOBALTEMPVIEW._serialized_end = 3576
+ _RECOVERPARTITIONS._serialized_start = 3578
+ _RECOVERPARTITIONS._serialized_end = 3628
+ _ISCACHED._serialized_start = 3630
+ _ISCACHED._serialized_end = 3671
+ _CACHETABLE._serialized_start = 3674
+ _CACHETABLE._serialized_end = 3806
+ _UNCACHETABLE._serialized_start = 3808
+ _UNCACHETABLE._serialized_end = 3853
+ _CLEARCACHE._serialized_start = 3855
+ _CLEARCACHE._serialized_end = 3867
+ _REFRESHTABLE._serialized_start = 3869
+ _REFRESHTABLE._serialized_end = 3914
+ _REFRESHBYPATH._serialized_start = 3916
+ _REFRESHBYPATH._serialized_end = 3951
+ _CURRENTCATALOG._serialized_start = 3953
+ _CURRENTCATALOG._serialized_end = 3969
+ _SETCURRENTCATALOG._serialized_start = 3971
+ _SETCURRENTCATALOG._serialized_end = 4025
+ _LISTCATALOGS._serialized_start = 4027
+ _LISTCATALOGS._serialized_end = 4084
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/catalog_pb2.pyi b/python/pyspark/sql/connect/proto/catalog_pb2.pyi
index fd58ca543ae..77a924d6d51 100644
--- a/python/pyspark/sql/connect/proto/catalog_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/catalog_pb2.pyi
@@ -373,22 +373,51 @@ class ListTables(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DB_NAME_FIELD_NUMBER: builtins.int
+ PATTERN_FIELD_NUMBER: builtins.int
db_name: builtins.str
"""(Optional)"""
+ pattern: builtins.str
+ """(Optional) The pattern that the table name needs to match"""
def __init__(
self,
*,
db_name: builtins.str | None = ...,
+ pattern: builtins.str | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["_db_name", b"_db_name", "db_name", b"db_name"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_db_name",
+ b"_db_name",
+ "_pattern",
+ b"_pattern",
+ "db_name",
+ b"db_name",
+ "pattern",
+ b"pattern",
+ ],
) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["_db_name", b"_db_name", "db_name", b"db_name"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_db_name",
+ b"_db_name",
+ "_pattern",
+ b"_pattern",
+ "db_name",
+ b"db_name",
+ "pattern",
+ b"pattern",
+ ],
) -> None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_db_name", b"_db_name"]
) -> typing_extensions.Literal["db_name"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_pattern", b"_pattern"]
+ ) -> typing_extensions.Literal["pattern"] | None: ...
global___ListTables = ListTables
diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py
index 93390aa0881..716f0638866 100644
--- a/python/pyspark/sql/tests/test_catalog.py
+++ b/python/pyspark/sql/tests/test_catalog.py
@@ -86,13 +86,25 @@ class CatalogTestsMixin:
)
tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
+ tablesWithPattern = sorted(
+ spark.catalog.listTables(pattern="tab*"), key=lambda t: t.name
+ )
tablesDefault = sorted(
spark.catalog.listTables("default"), key=lambda t: t.name
)
+ tablesDefaultWithPattern = sorted(
+ spark.catalog.listTables("default", "tab*"), key=lambda t: t.name
+ )
tablesSomeDb = sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
+ tablesSomeDbWithPattern = sorted(
+ spark.catalog.listTables("some_db", "tab*"), key=lambda t: t.name
+ )
self.assertEqual(tables, tablesDefault)
+ self.assertEqual(tablesWithPattern, tablesDefaultWithPattern)
self.assertEqual(len(tables), 3)
+ self.assertEqual(len(tablesWithPattern), 2)
self.assertEqual(len(tablesSomeDb), 2)
+ self.assertEqual(len(tablesSomeDbWithPattern), 1)
# make table in old fashion
def makeTable(
@@ -157,6 +169,30 @@ class CatalogTestsMixin:
),
)
)
+ self.assertTrue(
+ compareTables(
+ tablesWithPattern[0],
+ makeTable(
+ name="tab1",
+ database="default",
+ description=None,
+ tableType="MANAGED",
+ isTemporary=False,
+ ),
+ )
+ )
+ self.assertTrue(
+ compareTables(
+ tablesWithPattern[1],
+ makeTable(
+ name="tab3_via_catalog",
+ database="default",
+ description=description,
+ tableType="MANAGED",
+ isTemporary=False,
+ ),
+ )
+ )
self.assertTrue(
compareTables(
tablesSomeDb[0],
@@ -181,6 +217,18 @@ class CatalogTestsMixin:
),
)
)
+ self.assertTrue(
+ compareTables(
+ tablesSomeDbWithPattern[0],
+ makeTable(
+ name="tab2",
+ database="some_db",
+ description=None,
+ tableType="MANAGED",
+ isTemporary=False,
+ ),
+ )
+ )
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index c2cdd2382c4..b8cb97e1650 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -79,6 +79,16 @@ abstract class Catalog {
@throws[AnalysisException]("database does not exist")
def listTables(dbName: String): Dataset[Table]
+ /**
+ * Returns a list of tables/views in the specified database (namespace)
+ * whose name matches the specified pattern (the name can be qualified with catalog).
+ * This includes all temporary views.
+ *
+ * @since 3.5.0
+ */
+ @throws[AnalysisException]("database does not exist")
+ def listTables(dbName: String, pattern: String): Dataset[Table]
+
/**
* Returns a list of functions registered in the current database (namespace).
* This includes all temporary functions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index f8da89eea0a..3c61102699e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -122,16 +122,28 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
@throws[AnalysisException]("database does not exist")
override def listTables(dbName: String): Dataset[Table] = {
- // `dbName` could be either a single database name (behavior in Spark 3.3 and prior) or
- // a qualified namespace with catalog name. We assume it's a single database name
- // and check if we can find it in the sessionCatalog. If so we list tables under
- // that database. Otherwise we will resolve the catalog/namespace and list tables there.
- val namespace = if (sessionCatalog.databaseExists(dbName)) {
- Seq(CatalogManager.SESSION_CATALOG_NAME, dbName)
- } else {
- parseIdent(dbName)
- }
- val plan = ShowTables(UnresolvedNamespace(namespace), None)
+ listTablesInternal(dbName, None)
+ }
+
+ /**
+ * Returns a list of tables/views in the specified database (namespace)
+ * whose name matches the specified pattern (the name can be qualified with catalog).
+ * This includes all temporary views.
+ *
+ * @since 3.5.0
+ */
+ @throws[AnalysisException]("database does not exist")
+ override def listTables(dbName: String, pattern: String): Dataset[Table] = {
+ listTablesInternal(dbName, Some(pattern))
+ }
+
+ private def listTablesInternal(dbName: String, pattern: Option[String]): Dataset[Table] = {
+ val namespace = resolveNamespace(dbName)
+ val plan = ShowTables(UnresolvedNamespace(namespace), pattern)
+ makeTablesDataset(plan)
+ }
+
+ private def makeTablesDataset(plan: ShowTables): Dataset[Table] = {
val qe = sparkSession.sessionState.executePlan(plan)
val catalog = qe.analyzed.collectFirst {
case ShowTables(r: ResolvedNamespace, _, _) => r.catalog
@@ -228,15 +240,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
@throws[AnalysisException]("database does not exist")
override def listFunctions(dbName: String): Dataset[Function] = {
- // `dbName` could be either a single database name (behavior in Spark 3.3 and prior) or
- // a qualified namespace with catalog name. We assume it's a single database name
- // and check if we can find it in the sessionCatalog. If so we list functions under
- // that database. Otherwise we will resolve the catalog/namespace and list functions there.
- val namespace = if (sessionCatalog.databaseExists(dbName)) {
- Seq(CatalogManager.SESSION_CATALOG_NAME, dbName)
- } else {
- parseIdent(dbName)
- }
+ val namespace = resolveNamespace(dbName)
val functions = collection.mutable.ArrayBuilder.make[Function]
// TODO: The SHOW FUNCTIONS should tell us the function type (built-in, temp, persistent) and
@@ -386,15 +390,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* `Database` can be found.
*/
override def getDatabase(dbName: String): Database = {
- // `dbName` could be either a single database name (behavior in Spark 3.3 and prior) or
- // a qualified namespace with catalog name. We assume it's a single database name
- // and check if we can find it in the sessionCatalog. Otherwise we will parse `dbName` and
- // resolve catalog/namespace with it.
- val namespace = if (sessionCatalog.databaseExists(dbName)) {
- Seq(CatalogManager.SESSION_CATALOG_NAME, dbName)
- } else {
- sparkSession.sessionState.sqlParser.parseMultipartIdentifier(dbName)
- }
+ val namespace = resolveNamespace(dbName)
val plan = UnresolvedNamespace(namespace)
sparkSession.sessionState.executePlan(plan).analyzed match {
case ResolvedNamespace(catalog, namespace) =>
@@ -403,6 +399,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
}
}
+ private def resolveNamespace(dbName: String): Seq[String] = {
+ // `dbName` could be either a single database name (behavior in Spark 3.3 and prior) or
+ // a qualified namespace with catalog name. We assume it's a single database name
+ // and check if we can find it in the sessionCatalog. If so we resolve it against the
+ // session catalog. Otherwise we parse `dbName` and resolve the qualified catalog/namespace.
+ if (sessionCatalog.databaseExists(dbName)) {
+ Seq(CatalogManager.SESSION_CATALOG_NAME, dbName)
+ } else {
+ parseIdent(dbName)
+ }
+ }
+
/**
* Gets the table or view with the specified name. This table can be a temporary view or a
* table/view. This throws an `AnalysisException` when no `Table` can be found.
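For intuition, `resolveNamespace` makes both of these calls equivalent when `default` exists in the session catalog (whose name, per `CatalogManager.SESSION_CATALOG_NAME`, is `spark_catalog`):

```scala
// Both calls resolve to the namespace Seq("spark_catalog", "default").
spark.catalog.listTables("default", "tab*")                // bare database name
spark.catalog.listTables("spark_catalog.default", "tab*")  // catalog-qualified
```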
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 6fa7ad56b68..5ef8e35da9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -179,11 +179,17 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
createTempTable("my_temp_table")
assert(spark.catalog.listTables().collect().map(_.name).toSet ==
Set("my_table1", "my_table2", "my_temp_table"))
+ assert(spark.catalog.listTables(spark.catalog.currentDatabase, "my_table*").collect()
+ .map(_.name).toSet == Set("my_table1", "my_table2"))
dropTable("my_table1")
assert(spark.catalog.listTables().collect().map(_.name).toSet ==
Set("my_table2", "my_temp_table"))
+ assert(spark.catalog.listTables(spark.catalog.currentDatabase, "my_table*").collect()
+ .map(_.name).toSet == Set("my_table2"))
dropTable("my_temp_table")
assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2"))
+ assert(spark.catalog.listTables(spark.catalog.currentDatabase, "my_table*").collect()
+ .map(_.name).toSet == Set("my_table2"))
}
test("SPARK-39828: Catalog.listTables() should respect currentCatalog") {
@@ -224,14 +230,17 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
Set("my_table1", "my_temp_table"))
assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet ==
Set("my_table2", "my_temp_table"))
+ assert(spark.catalog.listTables("my_db2", "my_table*").collect().map(_.name).toSet ==
+ Set("my_table2"))
dropTable("my_table1", Some("my_db1"))
assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet ==
Set("my_temp_table"))
+ assert(spark.catalog.listTables("my_db1", "my_table*").collect().isEmpty)
assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet ==
Set("my_table2", "my_temp_table"))
dropTable("my_temp_table")
- assert(spark.catalog.listTables("default").collect().map(_.name).isEmpty)
- assert(spark.catalog.listTables("my_db1").collect().map(_.name).isEmpty)
+ assert(spark.catalog.listTables("default").collect().isEmpty)
+ assert(spark.catalog.listTables("my_db1").collect().isEmpty)
assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet ==
Set("my_table2"))
val e = intercept[AnalysisException] {