You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/02 11:38:08 UTC
[spark] branch master updated: [SPARK-41725][PYTHON][TESTS][FOLLOW-UP] Remove collect for SQL command execution in tests
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 79da1ab400f [SPARK-41725][PYTHON][TESTS][FOLLOW-UP] Remove collect for SQL command execution in tests
79da1ab400f is described below
commit 79da1ab400f25dbceec45e107e5366d084138fa8
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Thu Mar 2 20:37:19 2023 +0900
[SPARK-41725][PYTHON][TESTS][FOLLOW-UP] Remove collect for SQL command execution in tests
### What changes were proposed in this pull request?
This PR removes the `sql("command").collect()` workaround in PySpark test code.
### Why are the changes needed?
They were added previously to work around an issue in Spark Connect. This is fixed now, so we don't need to call `collect` anymore.
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
CI in this PR should test it out.
Closes #40251 from HyukjinKwon/SPARK-41725.
Authored-by: Hyukjin Kwon <gu...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/catalog.py | 84 ++++++++++++++---------------
python/pyspark/sql/readwriter.py | 18 +++----
python/pyspark/sql/tests/test_catalog.py | 56 +++++++++----------
python/pyspark/sql/tests/test_readwriter.py | 8 +--
python/pyspark/sql/tests/test_types.py | 4 +-
python/pyspark/testing/sqlutils.py | 6 +--
6 files changed, 83 insertions(+), 93 deletions(-)
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index c83d02d4cb3..ccf88492acf 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -246,8 +246,6 @@ class Catalog:
locationUri=jdb.locationUri(),
)
- # TODO(SPARK-41725): we don't have to `collect` for every `sql` but
- # Spark Connect requires it. We should remove them out.
def databaseExists(self, dbName: str) -> bool:
"""Check if the database with the specified name exists.
@@ -275,7 +273,7 @@ class Catalog:
>>> spark.catalog.databaseExists("test_new_database")
False
- >>> _ = spark.sql("CREATE DATABASE test_new_database").collect()
+ >>> _ = spark.sql("CREATE DATABASE test_new_database")
>>> spark.catalog.databaseExists("test_new_database")
True
@@ -283,7 +281,7 @@ class Catalog:
>>> spark.catalog.databaseExists("spark_catalog.test_new_database")
True
- >>> _ = spark.sql("DROP DATABASE test_new_database").collect()
+ >>> _ = spark.sql("DROP DATABASE test_new_database")
"""
return self._jcatalog.databaseExists(dbName)
@@ -372,8 +370,8 @@ class Catalog:
Examples
--------
- >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
- >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+ >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> spark.catalog.getTable("tbl1")
Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
@@ -383,7 +381,7 @@ class Catalog:
Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
>>> spark.catalog.getTable("spark_catalog.default.tbl1")
Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
Throw an analysis exception when the table does not exist.
@@ -535,7 +533,7 @@ class Catalog:
Examples
--------
>>> _ = spark.sql(
- ... "CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'").collect()
+ ... "CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'")
>>> spark.catalog.getFunction("my_func1")
Function(name='my_func1', catalog='spark_catalog', namespace=['default'], ...
@@ -602,11 +600,11 @@ class Catalog:
Examples
--------
- >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
- >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet").collect()
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+ >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
>>> spark.catalog.listColumns("tblA")
[Column(name='name', description=None, dataType='string', nullable=True, ...
- >>> _ = spark.sql("DROP TABLE tblA").collect()
+ >>> _ = spark.sql("DROP TABLE tblA")
"""
if dbName is None:
iter = self._jcatalog.listColumns(tableName).toLocalIterator()
@@ -667,8 +665,8 @@ class Catalog:
>>> spark.catalog.tableExists("unexisting_table")
False
- >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
- >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+ >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> spark.catalog.tableExists("tbl1")
True
@@ -680,13 +678,13 @@ class Catalog:
True
>>> spark.catalog.tableExists("tbl1", "default")
True
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
Check if views exist:
>>> spark.catalog.tableExists("view1")
False
- >>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1").collect()
+ >>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1")
>>> spark.catalog.tableExists("view1")
True
@@ -698,14 +696,14 @@ class Catalog:
True
>>> spark.catalog.tableExists("view1", "default")
True
- >>> _ = spark.sql("DROP VIEW view1").collect()
+ >>> _ = spark.sql("DROP VIEW view1")
Check if temporary views exist:
- >>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1").collect()
+ >>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1")
>>> spark.catalog.tableExists("view1")
True
- >>> df = spark.sql("DROP VIEW view1").collect()
+ >>> df = spark.sql("DROP VIEW view1")
>>> spark.catalog.tableExists("view1")
False
"""
@@ -806,7 +804,7 @@ class Catalog:
Creating a managed table.
>>> _ = spark.catalog.createTable("tbl1", schema=spark.range(1).schema, source='parquet')
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
Creating an external table
@@ -814,7 +812,7 @@ class Catalog:
>>> with tempfile.TemporaryDirectory() as d:
... _ = spark.catalog.createTable(
... "tbl2", schema=spark.range(1).schema, path=d, source='parquet')
- >>> _ = spark.sql("DROP TABLE tbl2").collect()
+ >>> _ = spark.sql("DROP TABLE tbl2")
"""
if path is not None:
options["path"] = path
@@ -954,8 +952,8 @@ class Catalog:
Examples
--------
- >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
- >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+ >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> spark.catalog.cacheTable("tbl1")
>>> spark.catalog.isCached("tbl1")
True
@@ -972,7 +970,7 @@ class Catalog:
>>> spark.catalog.isCached("spark_catalog.default.tbl1")
True
>>> spark.catalog.uncacheTable("tbl1")
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
"""
return self._jcatalog.isCached(tableName)
@@ -994,8 +992,8 @@ class Catalog:
Examples
--------
- >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
- >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+ >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> spark.catalog.cacheTable("tbl1")
Throw an analysis exception when the table does not exist.
@@ -1009,7 +1007,7 @@ class Catalog:
>>> spark.catalog.cacheTable("spark_catalog.default.tbl1")
>>> spark.catalog.uncacheTable("tbl1")
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
"""
self._jcatalog.cacheTable(tableName)
@@ -1031,8 +1029,8 @@ class Catalog:
Examples
--------
- >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
- >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+ >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> spark.catalog.cacheTable("tbl1")
>>> spark.catalog.uncacheTable("tbl1")
>>> spark.catalog.isCached("tbl1")
@@ -1050,7 +1048,7 @@ class Catalog:
>>> spark.catalog.uncacheTable("spark_catalog.default.tbl1")
>>> spark.catalog.isCached("tbl1")
False
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
"""
self._jcatalog.uncacheTable(tableName)
@@ -1064,12 +1062,12 @@ class Catalog:
Examples
--------
- >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
- >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+ >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> spark.catalog.clearCache()
>>> spark.catalog.isCached("tbl1")
False
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
"""
self._jcatalog.clearCache()
@@ -1095,10 +1093,10 @@ class Catalog:
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
- ... _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+ ... _ = spark.sql("DROP TABLE IF EXISTS tbl1")
... _ = spark.sql(
- ... "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect()
- ... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect()
+ ... "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d))
+ ... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'")
... spark.catalog.cacheTable("tbl1")
... spark.table("tbl1").show()
+---+
@@ -1121,7 +1119,7 @@ class Catalog:
Using the fully qualified name for the table.
>>> spark.catalog.refreshTable("spark_catalog.default.tbl1")
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
"""
self._jcatalog.refreshTable(tableName)
@@ -1149,12 +1147,12 @@ class Catalog:
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
- ... _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+ ... _ = spark.sql("DROP TABLE IF EXISTS tbl1")
... spark.range(1).selectExpr(
... "id as key", "id as value").write.partitionBy("key").mode("overwrite").save(d)
... _ = spark.sql(
... "CREATE TABLE tbl1 (key LONG, value LONG)"
- ... "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d)).collect()
+ ... "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d))
... spark.table("tbl1").show()
... spark.catalog.recoverPartitions("tbl1")
... spark.table("tbl1").show()
@@ -1167,7 +1165,7 @@ class Catalog:
+-----+---+
| 0| 0|
+-----+---+
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
"""
self._jcatalog.recoverPartitions(tableName)
@@ -1191,10 +1189,10 @@ class Catalog:
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
- ... _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+ ... _ = spark.sql("DROP TABLE IF EXISTS tbl1")
... _ = spark.sql(
- ... "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect()
- ... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect()
+ ... "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d))
+ ... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'")
... spark.catalog.cacheTable("tbl1")
... spark.table("tbl1").show()
+---+
@@ -1214,7 +1212,7 @@ class Catalog:
>>> spark.table("tbl1").count()
0
- >>> _ = spark.sql("DROP TABLE tbl1").collect()
+ >>> _ = spark.sql("DROP TABLE tbl1")
"""
self._jcatalog.refreshByPath(path)
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 93fd938dff4..17b59311648 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -466,7 +466,7 @@ class DataFrameReader(OptionUtils):
| 8|
| 9|
+---+
- >>> _ = spark.sql("DROP TABLE tblA").collect()
+ >>> _ = spark.sql("DROP TABLE tblA")
"""
return self._df(self._jreader.table(tableName))
@@ -1232,7 +1232,7 @@ class DataFrameWriter(OptionUtils):
>>> from pyspark.sql.functions import input_file_name
>>> # Write a DataFrame into a Parquet file in a bucketed manner.
- ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table").collect()
+ ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table")
>>> spark.createDataFrame([
... (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
... schema=["age", "name"]
@@ -1246,7 +1246,7 @@ class DataFrameWriter(OptionUtils):
|120|Hyukjin Kwon|
|140| Haejoon Lee|
+---+------------+
- >>> _ = spark.sql("DROP TABLE bucketed_table").collect()
+ >>> _ = spark.sql("DROP TABLE bucketed_table")
"""
if not isinstance(numBuckets, int):
raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))
@@ -1296,7 +1296,7 @@ class DataFrameWriter(OptionUtils):
>>> from pyspark.sql.functions import input_file_name
>>> # Write a DataFrame into a Parquet file in a sorted-bucketed manner.
- ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table").collect()
+ ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table")
>>> spark.createDataFrame([
... (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
... schema=["age", "name"]
@@ -1311,7 +1311,7 @@ class DataFrameWriter(OptionUtils):
|120|Hyukjin Kwon|
|140| Haejoon Lee|
+---+------------+
- >>> _ = spark.sql("DROP TABLE sorted_bucketed_table").collect()
+ >>> _ = spark.sql("DROP TABLE sorted_bucketed_table")
"""
if isinstance(col, (list, tuple)):
if cols:
@@ -1417,7 +1417,7 @@ class DataFrameWriter(OptionUtils):
Examples
--------
- >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
>>> df = spark.createDataFrame([
... (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
... schema=["age", "name"]
@@ -1438,7 +1438,7 @@ class DataFrameWriter(OptionUtils):
|140| Haejoon Lee|
|140| Haejoon Lee|
+---+------------+
- >>> _ = spark.sql("DROP TABLE tblA").collect()
+ >>> _ = spark.sql("DROP TABLE tblA")
"""
if overwrite is not None:
self.mode("overwrite" if overwrite else "append")
@@ -1495,7 +1495,7 @@ class DataFrameWriter(OptionUtils):
--------
Creates a table from a DataFrame, and read it back.
- >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
>>> spark.createDataFrame([
... (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
... schema=["age", "name"]
@@ -1508,7 +1508,7 @@ class DataFrameWriter(OptionUtils):
|120|Hyukjin Kwon|
|140| Haejoon Lee|
+---+------------+
- >>> _ = spark.sql("DROP TABLE tblA").collect()
+ >>> _ = spark.sql("DROP TABLE tblA")
"""
self.mode(mode).options(**options)
if partitionBy is not None:
diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py
index 10f3ec12c9c..ae92ce57dc8 100644
--- a/python/pyspark/sql/tests/test_catalog.py
+++ b/python/pyspark/sql/tests/test_catalog.py
@@ -24,7 +24,7 @@ class CatalogTestsMixin:
spark = self.spark
with self.database("some_db"):
self.assertEqual(spark.catalog.currentDatabase(), "default")
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
spark.catalog.setCurrentDatabase("some_db")
self.assertEqual(spark.catalog.currentDatabase(), "some_db")
self.assertRaisesRegex(
@@ -38,7 +38,7 @@ class CatalogTestsMixin:
with self.database("some_db"):
databases = [db.name for db in spark.catalog.listDatabases()]
self.assertEqual(databases, ["default"])
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
databases = [db.name for db in spark.catalog.listDatabases()]
self.assertEqual(sorted(databases), ["default", "some_db"])
@@ -47,7 +47,7 @@ class CatalogTestsMixin:
spark = self.spark
with self.database("some_db"):
self.assertFalse(spark.catalog.databaseExists("some_db"))
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
self.assertTrue(spark.catalog.databaseExists("some_db"))
self.assertTrue(spark.catalog.databaseExists("spark_catalog.some_db"))
self.assertFalse(spark.catalog.databaseExists("spark_catalog.some_db2"))
@@ -55,7 +55,7 @@ class CatalogTestsMixin:
def test_get_database(self):
spark = self.spark
with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
db = spark.catalog.getDatabase("spark_catalog.some_db")
self.assertEqual(db.name, "some_db")
self.assertEqual(db.catalog, "spark_catalog")
@@ -65,16 +65,14 @@ class CatalogTestsMixin:
spark = self.spark
with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
with self.table("tab1", "some_db.tab2", "tab3_via_catalog"):
with self.tempView("temp_tab"):
self.assertEqual(spark.catalog.listTables(), [])
self.assertEqual(spark.catalog.listTables("some_db"), [])
spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
- spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
- spark.sql(
- "CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet"
- ).collect()
+ spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
+ spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
schema = StructType([StructField("a", IntegerType(), True)])
description = "this a table created via Catalog.createTable()"
@@ -187,7 +185,7 @@ class CatalogTestsMixin:
def test_list_functions(self):
spark = self.spark
with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
functions = dict((f.name, f) for f in spark.catalog.listFunctions())
functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
self.assertTrue(len(functions) > 200)
@@ -215,10 +213,8 @@ class CatalogTestsMixin:
if support_udf:
spark.udf.register("temp_func", lambda x: str(x))
- spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'").collect()
- spark.sql(
- "CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'"
- ).collect()
+ spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
+ spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
newFunctionsSomeDb = dict(
(f.name, f) for f in spark.catalog.listFunctions("some_db")
@@ -247,7 +243,7 @@ class CatalogTestsMixin:
self.assertFalse(spark.catalog.functionExists("default.func1"))
self.assertFalse(spark.catalog.functionExists("spark_catalog.default.func1"))
self.assertFalse(spark.catalog.functionExists("func1", "default"))
- spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'").collect()
+ spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
self.assertTrue(spark.catalog.functionExists("func1"))
self.assertTrue(spark.catalog.functionExists("default.func1"))
self.assertTrue(spark.catalog.functionExists("spark_catalog.default.func1"))
@@ -256,7 +252,7 @@ class CatalogTestsMixin:
def test_get_function(self):
spark = self.spark
with self.function("func1"):
- spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'").collect()
+ spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
func1 = spark.catalog.getFunction("spark_catalog.default.func1")
self.assertTrue(func1.name == "func1")
self.assertTrue(func1.namespace == ["default"])
@@ -269,12 +265,12 @@ class CatalogTestsMixin:
spark = self.spark
with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
with self.table("tab1", "some_db.tab2"):
- spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
+ spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
spark.sql(
"CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet"
- ).collect()
+ )
columns = sorted(
spark.catalog.listColumns("spark_catalog.default.tab1"), key=lambda c: c.name
)
@@ -343,11 +339,9 @@ class CatalogTestsMixin:
def test_table_cache(self):
spark = self.spark
with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
with self.table("tab1"):
- spark.sql(
- "CREATE TABLE some_db.tab1 (name STRING, age INT) USING parquet"
- ).collect()
+ spark.sql("CREATE TABLE some_db.tab1 (name STRING, age INT) USING parquet")
self.assertFalse(spark.catalog.isCached("some_db.tab1"))
self.assertFalse(spark.catalog.isCached("spark_catalog.some_db.tab1"))
spark.catalog.cacheTable("spark_catalog.some_db.tab1")
@@ -361,18 +355,16 @@ class CatalogTestsMixin:
# SPARK-36176: testing that table_exists returns correct boolean
spark = self.spark
with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
with self.table("tab1", "some_db.tab2"):
self.assertFalse(spark.catalog.tableExists("tab1"))
self.assertFalse(spark.catalog.tableExists("tab2", "some_db"))
- spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
+ spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
self.assertTrue(spark.catalog.tableExists("tab1"))
self.assertTrue(spark.catalog.tableExists("default.tab1"))
self.assertTrue(spark.catalog.tableExists("spark_catalog.default.tab1"))
self.assertTrue(spark.catalog.tableExists("tab1", "default"))
- spark.sql(
- "CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet"
- ).collect()
+ spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
self.assertFalse(spark.catalog.tableExists("tab2"))
self.assertTrue(spark.catalog.tableExists("some_db.tab2"))
self.assertTrue(spark.catalog.tableExists("spark_catalog.some_db.tab2"))
@@ -381,9 +373,9 @@ class CatalogTestsMixin:
def test_get_table(self):
spark = self.spark
with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db").collect()
+ spark.sql("CREATE DATABASE some_db")
with self.table("tab1"):
- spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
+ spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
self.assertEqual(spark.catalog.getTable("tab1").database, "default")
self.assertEqual(spark.catalog.getTable("default.tab1").catalog, "spark_catalog")
self.assertEqual(spark.catalog.getTable("spark_catalog.default.tab1").name, "tab1")
@@ -397,8 +389,8 @@ class CatalogTestsMixin:
with self.table("my_tab"):
spark.sql(
"CREATE TABLE my_tab (col STRING) USING TEXT LOCATION '{}'".format(tmp_dir)
- ).collect()
- spark.sql("INSERT INTO my_tab SELECT 'abc'").collect()
+ )
+ spark.sql("INSERT INTO my_tab SELECT 'abc'")
spark.catalog.cacheTable("my_tab")
self.assertEqual(spark.table("my_tab").count(), 1)
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 21c66284ace..17c158a870a 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -56,11 +56,11 @@ class ReadwriterTestsMixin:
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
try:
- self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+ self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
actual = self.spark.read.load(path=tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
finally:
- self.spark.sql("RESET spark.sql.sources.default").collect()
+ self.spark.sql("RESET spark.sql.sources.default")
csvpath = os.path.join(tempfile.mkdtemp(), "data")
df.write.option("quote", None).format("csv").save(csvpath)
@@ -95,11 +95,11 @@ class ReadwriterTestsMixin:
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
try:
- self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+ self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
actual = self.spark.read.load(path=tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
finally:
- self.spark.sql("RESET spark.sql.sources.default").collect()
+ self.spark.sql("RESET spark.sql.sources.default")
finally:
shutil.rmtree(tmpPath)
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 68424cad386..9db090fa810 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -379,13 +379,13 @@ class TypesTestsMixin:
def test_negative_decimal(self):
try:
- self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true").collect()
+ self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
df = self.spark.createDataFrame([(1,), (11,)], ["value"])
ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
actual = list(map(lambda r: int(r.value), ret))
self.assertEqual(actual, [0, 10])
finally:
- self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false").collect()
+ self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false")
def test_create_dataframe_from_objects(self):
data = [MyObject(1, "1"), MyObject(2, "2")]
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 937ad491479..077d854b1dd 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -202,7 +202,7 @@ class SQLTestUtils:
yield
finally:
for db in databases:
- self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db).collect()
+ self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
self.spark.catalog.setCurrentDatabase("default")
@contextmanager
@@ -217,7 +217,7 @@ class SQLTestUtils:
yield
finally:
for t in tables:
- self.spark.sql("DROP TABLE IF EXISTS %s" % t).collect()
+ self.spark.sql("DROP TABLE IF EXISTS %s" % t)
@contextmanager
def tempView(self, *views):
@@ -245,7 +245,7 @@ class SQLTestUtils:
yield
finally:
for f in functions:
- self.spark.sql("DROP FUNCTION IF EXISTS %s" % f).collect()
+ self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
@staticmethod
def assert_close(a, b):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org