You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/02 01:01:59 UTC
[spark] branch master updated: [SPARK-40930][CONNECT] Support Collect() in Python client
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 658681f3eb5 [SPARK-40930][CONNECT] Support Collect() in Python client
658681f3eb5 is described below
commit 658681f3eb5b8f3226ac8d3793e2c1a065351b6c
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Wed Nov 2 10:01:42 2022 +0900
[SPARK-40930][CONNECT] Support Collect() in Python client
### What changes were proposed in this pull request?
Before this PR, calling `collect()` raised a `NotImplementedError` recommending `toPandas()` instead.
With this PR, `collect()` returns a list of PySpark `Row` objects, built by converting each row of the `toPandas()` result (or an empty list if `toPandas()` returns `None`).
### Why are the changes needed?
Improve API coverage.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a unit test, `test_collect`, in `python/pyspark/sql/tests/connect/test_connect_basic.py`.
Closes #38409 from amaliujia/python_support_collect.
Authored-by: Rui Wang <ru...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/connect/dataframe.py | 13 ++++++++++---
python/pyspark/sql/tests/connect/test_connect_basic.py | 10 +++++++++-
2 files changed, 19 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c7107a7e79f..b9ddb0db300 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -34,7 +34,10 @@ from pyspark.sql.connect.column import (
Expression,
LiteralExpression,
)
-from pyspark.sql.types import StructType
+from pyspark.sql.types import (
+ StructType,
+ Row,
+)
if TYPE_CHECKING:
from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
@@ -317,8 +320,12 @@ class DataFrame(object):
return self._plan.print()
return ""
- def collect(self) -> None:
- raise NotImplementedError("Please use toPandas().")
+ def collect(self) -> List[Row]:
+ pdf = self.toPandas()
+ if pdf is not None:
+ return list(pdf.apply(lambda row: Row(**row), axis=1))
+ else:
+ return []
def toPandas(self) -> Optional["pandas.DataFrame"]:
if self._plan is None:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index e9a06f9c545..0d3fc76134e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -73,7 +73,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
# Setup Remote Spark Session
cls.connect = RemoteSparkSession(user_id="test_user")
df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
- # Since we might create multiple Spark sessions, we need to creata global temporary view
+ # Since we might create multiple Spark sessions, we need to create global temporary view
# that is specifically maintained in the "global_temp" schema.
df.write.saveAsTable(cls.tbl_name)
@@ -89,6 +89,14 @@ class SparkConnectTests(SparkConnectSQLTestCase):
# Check that the limit is applied
self.assertEqual(len(data.index), 10)
+ def test_collect(self):
+ df = self.connect.read.table(self.tbl_name)
+ data = df.limit(10).collect()
+ self.assertEqual(len(data), 10)
+ # Check Row has schema column names.
+ self.assertTrue("name" in data[0])
+ self.assertTrue("id" in data[0])
+
def test_simple_udf(self):
def conv_udf(x) -> str:
return "Martin"
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org