You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/02 11:38:08 UTC

[spark] branch master updated: [SPARK-41725][PYTHON][TESTS][FOLLOW-UP] Remove collect for SQL command execution in tests

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 79da1ab400f [SPARK-41725][PYTHON][TESTS][FOLLOW-UP] Remove collect for SQL command execution in tests
79da1ab400f is described below

commit 79da1ab400f25dbceec45e107e5366d084138fa8
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Thu Mar 2 20:37:19 2023 +0900

    [SPARK-41725][PYTHON][TESTS][FOLLOW-UP] Remove collect for SQL command execution in tests
    
    ### What changes were proposed in this pull request?
    
    This PR removes the `sql("command").collect()` workaround from the PySpark test code.
    
    ### Why are the changes needed?
    
    The `collect` calls were added previously as a workaround for Spark Connect, which required `collect` for every `sql` command. This is fixed now, so we don't need to call `collect` anymore.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, test-only.
    
    ### How was this patch tested?
    
    CI in this PR should test it out.
    
    Closes #40251 from HyukjinKwon/SPARK-41725.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/catalog.py               | 84 ++++++++++++++---------------
 python/pyspark/sql/readwriter.py            | 18 +++----
 python/pyspark/sql/tests/test_catalog.py    | 56 +++++++++----------
 python/pyspark/sql/tests/test_readwriter.py |  8 +--
 python/pyspark/sql/tests/test_types.py      |  4 +-
 python/pyspark/testing/sqlutils.py          |  6 +--
 6 files changed, 83 insertions(+), 93 deletions(-)

diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index c83d02d4cb3..ccf88492acf 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -246,8 +246,6 @@ class Catalog:
             locationUri=jdb.locationUri(),
         )
 
-    # TODO(SPARK-41725): we don't have to `collect` for every `sql` but
-    #  Spark Connect requires it. We should remove them out.
     def databaseExists(self, dbName: str) -> bool:
         """Check if the database with the specified name exists.
 
@@ -275,7 +273,7 @@ class Catalog:
 
         >>> spark.catalog.databaseExists("test_new_database")
         False
-        >>> _ = spark.sql("CREATE DATABASE test_new_database").collect()
+        >>> _ = spark.sql("CREATE DATABASE test_new_database")
         >>> spark.catalog.databaseExists("test_new_database")
         True
 
@@ -283,7 +281,7 @@ class Catalog:
 
         >>> spark.catalog.databaseExists("spark_catalog.test_new_database")
         True
-        >>> _ = spark.sql("DROP DATABASE test_new_database").collect()
+        >>> _ = spark.sql("DROP DATABASE test_new_database")
         """
         return self._jcatalog.databaseExists(dbName)
 
@@ -372,8 +370,8 @@ class Catalog:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.getTable("tbl1")
         Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
 
@@ -383,7 +381,7 @@ class Catalog:
         Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
         >>> spark.catalog.getTable("spark_catalog.default.tbl1")
         Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
 
         Throw an analysis exception when the table does not exist.
 
@@ -535,7 +533,7 @@ class Catalog:
         Examples
         --------
         >>> _ = spark.sql(
-        ...     "CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'").collect()
+        ...     "CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'")
         >>> spark.catalog.getFunction("my_func1")
         Function(name='my_func1', catalog='spark_catalog', namespace=['default'], ...
 
@@ -602,11 +600,11 @@ class Catalog:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
         >>> spark.catalog.listColumns("tblA")
         [Column(name='name', description=None, dataType='string', nullable=True, ...
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         if dbName is None:
             iter = self._jcatalog.listColumns(tableName).toLocalIterator()
@@ -667,8 +665,8 @@ class Catalog:
 
         >>> spark.catalog.tableExists("unexisting_table")
         False
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.tableExists("tbl1")
         True
 
@@ -680,13 +678,13 @@ class Catalog:
         True
         >>> spark.catalog.tableExists("tbl1", "default")
         True
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
 
         Check if views exist:
 
         >>> spark.catalog.tableExists("view1")
         False
-        >>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1").collect()
+        >>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1")
         >>> spark.catalog.tableExists("view1")
         True
 
@@ -698,14 +696,14 @@ class Catalog:
         True
         >>> spark.catalog.tableExists("view1", "default")
         True
-        >>> _ = spark.sql("DROP VIEW view1").collect()
+        >>> _ = spark.sql("DROP VIEW view1")
 
         Check if temporary views exist:
 
-        >>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1").collect()
+        >>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1")
         >>> spark.catalog.tableExists("view1")
         True
-        >>> df = spark.sql("DROP VIEW view1").collect()
+        >>> df = spark.sql("DROP VIEW view1")
         >>> spark.catalog.tableExists("view1")
         False
         """
@@ -806,7 +804,7 @@ class Catalog:
         Creating a managed table.
 
         >>> _ = spark.catalog.createTable("tbl1", schema=spark.range(1).schema, source='parquet')
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
 
         Creating an external table
 
@@ -814,7 +812,7 @@ class Catalog:
         >>> with tempfile.TemporaryDirectory() as d:
         ...     _ = spark.catalog.createTable(
         ...         "tbl2", schema=spark.range(1).schema, path=d, source='parquet')
-        >>> _ = spark.sql("DROP TABLE tbl2").collect()
+        >>> _ = spark.sql("DROP TABLE tbl2")
         """
         if path is not None:
             options["path"] = path
@@ -954,8 +952,8 @@ class Catalog:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.cacheTable("tbl1")
         >>> spark.catalog.isCached("tbl1")
         True
@@ -972,7 +970,7 @@ class Catalog:
         >>> spark.catalog.isCached("spark_catalog.default.tbl1")
         True
         >>> spark.catalog.uncacheTable("tbl1")
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         return self._jcatalog.isCached(tableName)
 
@@ -994,8 +992,8 @@ class Catalog:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.cacheTable("tbl1")
 
         Throw an analysis exception when the table does not exist.
@@ -1009,7 +1007,7 @@ class Catalog:
 
         >>> spark.catalog.cacheTable("spark_catalog.default.tbl1")
         >>> spark.catalog.uncacheTable("tbl1")
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.cacheTable(tableName)
 
@@ -1031,8 +1029,8 @@ class Catalog:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.cacheTable("tbl1")
         >>> spark.catalog.uncacheTable("tbl1")
         >>> spark.catalog.isCached("tbl1")
@@ -1050,7 +1048,7 @@ class Catalog:
         >>> spark.catalog.uncacheTable("spark_catalog.default.tbl1")
         >>> spark.catalog.isCached("tbl1")
         False
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.uncacheTable(tableName)
 
@@ -1064,12 +1062,12 @@ class Catalog:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.clearCache()
         >>> spark.catalog.isCached("tbl1")
         False
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.clearCache()
 
@@ -1095,10 +1093,10 @@ class Catalog:
 
         >>> import tempfile
         >>> with tempfile.TemporaryDirectory() as d:
-        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1")
         ...     _ = spark.sql(
-        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect()
-        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect()
+        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d))
+        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'")
         ...     spark.catalog.cacheTable("tbl1")
         ...     spark.table("tbl1").show()
         +---+
@@ -1121,7 +1119,7 @@ class Catalog:
         Using the fully qualified name for the table.
 
         >>> spark.catalog.refreshTable("spark_catalog.default.tbl1")
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.refreshTable(tableName)
 
@@ -1149,12 +1147,12 @@ class Catalog:
 
         >>> import tempfile
         >>> with tempfile.TemporaryDirectory() as d:
-        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1")
         ...     spark.range(1).selectExpr(
         ...         "id as key", "id as value").write.partitionBy("key").mode("overwrite").save(d)
         ...     _ = spark.sql(
         ...          "CREATE TABLE tbl1 (key LONG, value LONG)"
-        ...          "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d)).collect()
+        ...          "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d))
         ...     spark.table("tbl1").show()
         ...     spark.catalog.recoverPartitions("tbl1")
         ...     spark.table("tbl1").show()
@@ -1167,7 +1165,7 @@ class Catalog:
         +-----+---+
         |    0|  0|
         +-----+---+
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.recoverPartitions(tableName)
 
@@ -1191,10 +1189,10 @@ class Catalog:
 
         >>> import tempfile
         >>> with tempfile.TemporaryDirectory() as d:
-        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1")
         ...     _ = spark.sql(
-        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect()
-        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect()
+        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d))
+        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'")
         ...     spark.catalog.cacheTable("tbl1")
         ...     spark.table("tbl1").show()
         +---+
@@ -1214,7 +1212,7 @@ class Catalog:
         >>> spark.table("tbl1").count()
         0
 
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.refreshByPath(path)
 
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 93fd938dff4..17b59311648 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -466,7 +466,7 @@ class DataFrameReader(OptionUtils):
         |  8|
         |  9|
         +---+
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         return self._df(self._jreader.table(tableName))
 
@@ -1232,7 +1232,7 @@ class DataFrameWriter(OptionUtils):
 
         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a bucketed manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table").collect()
+        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table")
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1246,7 +1246,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE bucketed_table").collect()
+        >>> _ = spark.sql("DROP TABLE bucketed_table")
         """
         if not isinstance(numBuckets, int):
             raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))
@@ -1296,7 +1296,7 @@ class DataFrameWriter(OptionUtils):
 
         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a sorted-bucketed manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table").collect()
+        ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table")
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1311,7 +1311,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table").collect()
+        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table")
         """
         if isinstance(col, (list, tuple)):
             if cols:
@@ -1417,7 +1417,7 @@ class DataFrameWriter(OptionUtils):
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
         >>> df = spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1438,7 +1438,7 @@ class DataFrameWriter(OptionUtils):
         |140| Haejoon Lee|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         if overwrite is not None:
             self.mode("overwrite" if overwrite else "append")
@@ -1495,7 +1495,7 @@ class DataFrameWriter(OptionUtils):
         --------
         Creates a table from a DataFrame, and read it back.
 
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1508,7 +1508,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         self.mode(mode).options(**options)
         if partitionBy is not None:
diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py
index 10f3ec12c9c..ae92ce57dc8 100644
--- a/python/pyspark/sql/tests/test_catalog.py
+++ b/python/pyspark/sql/tests/test_catalog.py
@@ -24,7 +24,7 @@ class CatalogTestsMixin:
         spark = self.spark
         with self.database("some_db"):
             self.assertEqual(spark.catalog.currentDatabase(), "default")
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             spark.catalog.setCurrentDatabase("some_db")
             self.assertEqual(spark.catalog.currentDatabase(), "some_db")
             self.assertRaisesRegex(
@@ -38,7 +38,7 @@ class CatalogTestsMixin:
         with self.database("some_db"):
             databases = [db.name for db in spark.catalog.listDatabases()]
             self.assertEqual(databases, ["default"])
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             databases = [db.name for db in spark.catalog.listDatabases()]
             self.assertEqual(sorted(databases), ["default", "some_db"])
 
@@ -47,7 +47,7 @@ class CatalogTestsMixin:
         spark = self.spark
         with self.database("some_db"):
             self.assertFalse(spark.catalog.databaseExists("some_db"))
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             self.assertTrue(spark.catalog.databaseExists("some_db"))
             self.assertTrue(spark.catalog.databaseExists("spark_catalog.some_db"))
             self.assertFalse(spark.catalog.databaseExists("spark_catalog.some_db2"))
@@ -55,7 +55,7 @@ class CatalogTestsMixin:
     def test_get_database(self):
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             db = spark.catalog.getDatabase("spark_catalog.some_db")
             self.assertEqual(db.name, "some_db")
             self.assertEqual(db.catalog, "spark_catalog")
@@ -65,16 +65,14 @@ class CatalogTestsMixin:
 
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1", "some_db.tab2", "tab3_via_catalog"):
                 with self.tempView("temp_tab"):
                     self.assertEqual(spark.catalog.listTables(), [])
                     self.assertEqual(spark.catalog.listTables("some_db"), [])
                     spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
-                    spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
-                    spark.sql(
-                        "CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet"
-                    ).collect()
+                    spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
+                    spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
 
                     schema = StructType([StructField("a", IntegerType(), True)])
                     description = "this a table created via Catalog.createTable()"
@@ -187,7 +185,7 @@ class CatalogTestsMixin:
     def test_list_functions(self):
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             functions = dict((f.name, f) for f in spark.catalog.listFunctions())
             functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
             self.assertTrue(len(functions) > 200)
@@ -215,10 +213,8 @@ class CatalogTestsMixin:
 
                 if support_udf:
                     spark.udf.register("temp_func", lambda x: str(x))
-                spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'").collect()
-                spark.sql(
-                    "CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'"
-                ).collect()
+                spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
+                spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
                 newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
                 newFunctionsSomeDb = dict(
                     (f.name, f) for f in spark.catalog.listFunctions("some_db")
@@ -247,7 +243,7 @@ class CatalogTestsMixin:
             self.assertFalse(spark.catalog.functionExists("default.func1"))
             self.assertFalse(spark.catalog.functionExists("spark_catalog.default.func1"))
             self.assertFalse(spark.catalog.functionExists("func1", "default"))
-            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'").collect()
+            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
             self.assertTrue(spark.catalog.functionExists("func1"))
             self.assertTrue(spark.catalog.functionExists("default.func1"))
             self.assertTrue(spark.catalog.functionExists("spark_catalog.default.func1"))
@@ -256,7 +252,7 @@ class CatalogTestsMixin:
     def test_get_function(self):
         spark = self.spark
         with self.function("func1"):
-            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'").collect()
+            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
             func1 = spark.catalog.getFunction("spark_catalog.default.func1")
             self.assertTrue(func1.name == "func1")
             self.assertTrue(func1.namespace == ["default"])
@@ -269,12 +265,12 @@ class CatalogTestsMixin:
 
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1", "some_db.tab2"):
-                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
+                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                 spark.sql(
                     "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet"
-                ).collect()
+                )
                 columns = sorted(
                     spark.catalog.listColumns("spark_catalog.default.tab1"), key=lambda c: c.name
                 )
@@ -343,11 +339,9 @@ class CatalogTestsMixin:
     def test_table_cache(self):
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1"):
-                spark.sql(
-                    "CREATE TABLE some_db.tab1 (name STRING, age INT) USING parquet"
-                ).collect()
+                spark.sql("CREATE TABLE some_db.tab1 (name STRING, age INT) USING parquet")
                 self.assertFalse(spark.catalog.isCached("some_db.tab1"))
                 self.assertFalse(spark.catalog.isCached("spark_catalog.some_db.tab1"))
                 spark.catalog.cacheTable("spark_catalog.some_db.tab1")
@@ -361,18 +355,16 @@ class CatalogTestsMixin:
         # SPARK-36176: testing that table_exists returns correct boolean
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1", "some_db.tab2"):
                 self.assertFalse(spark.catalog.tableExists("tab1"))
                 self.assertFalse(spark.catalog.tableExists("tab2", "some_db"))
-                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
+                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                 self.assertTrue(spark.catalog.tableExists("tab1"))
                 self.assertTrue(spark.catalog.tableExists("default.tab1"))
                 self.assertTrue(spark.catalog.tableExists("spark_catalog.default.tab1"))
                 self.assertTrue(spark.catalog.tableExists("tab1", "default"))
-                spark.sql(
-                    "CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet"
-                ).collect()
+                spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
                 self.assertFalse(spark.catalog.tableExists("tab2"))
                 self.assertTrue(spark.catalog.tableExists("some_db.tab2"))
                 self.assertTrue(spark.catalog.tableExists("spark_catalog.some_db.tab2"))
@@ -381,9 +373,9 @@ class CatalogTestsMixin:
     def test_get_table(self):
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1"):
-                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
+                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                 self.assertEqual(spark.catalog.getTable("tab1").database, "default")
                 self.assertEqual(spark.catalog.getTable("default.tab1").catalog, "spark_catalog")
                 self.assertEqual(spark.catalog.getTable("spark_catalog.default.tab1").name, "tab1")
@@ -397,8 +389,8 @@ class CatalogTestsMixin:
             with self.table("my_tab"):
                 spark.sql(
                     "CREATE TABLE my_tab (col STRING) USING TEXT LOCATION '{}'".format(tmp_dir)
-                ).collect()
-                spark.sql("INSERT INTO my_tab SELECT 'abc'").collect()
+                )
+                spark.sql("INSERT INTO my_tab SELECT 'abc'")
                 spark.catalog.cacheTable("my_tab")
                 self.assertEqual(spark.table("my_tab").count(), 1)
 
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 21c66284ace..17c158a870a 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -56,11 +56,11 @@ class ReadwriterTestsMixin:
             self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
             try:
-                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
                 actual = self.spark.read.load(path=tmpPath)
                 self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
             finally:
-                self.spark.sql("RESET spark.sql.sources.default").collect()
+                self.spark.sql("RESET spark.sql.sources.default")
 
             csvpath = os.path.join(tempfile.mkdtemp(), "data")
             df.write.option("quote", None).format("csv").save(csvpath)
@@ -95,11 +95,11 @@ class ReadwriterTestsMixin:
             self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
             try:
-                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
                 actual = self.spark.read.load(path=tmpPath)
                 self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
             finally:
-                self.spark.sql("RESET spark.sql.sources.default").collect()
+                self.spark.sql("RESET spark.sql.sources.default")
         finally:
             shutil.rmtree(tmpPath)
 
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 68424cad386..9db090fa810 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -379,13 +379,13 @@ class TypesTestsMixin:
 
     def test_negative_decimal(self):
         try:
-            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true").collect()
+            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
             df = self.spark.createDataFrame([(1,), (11,)], ["value"])
             ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
             actual = list(map(lambda r: int(r.value), ret))
             self.assertEqual(actual, [0, 10])
         finally:
-            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false").collect()
+            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false")
 
     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 937ad491479..077d854b1dd 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -202,7 +202,7 @@ class SQLTestUtils:
             yield
         finally:
             for db in databases:
-                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db).collect()
+                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
             self.spark.catalog.setCurrentDatabase("default")
 
     @contextmanager
@@ -217,7 +217,7 @@ class SQLTestUtils:
             yield
         finally:
             for t in tables:
-                self.spark.sql("DROP TABLE IF EXISTS %s" % t).collect()
+                self.spark.sql("DROP TABLE IF EXISTS %s" % t)
 
     @contextmanager
     def tempView(self, *views):
@@ -245,7 +245,7 @@ class SQLTestUtils:
             yield
         finally:
             for f in functions:
-                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f).collect()
+                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
 
     @staticmethod
     def assert_close(a, b):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org