You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/08 00:31:24 UTC
[spark] branch master updated: [SPARK-41002][CONNECT][PYTHON] Compatible `take`, `head` and `first` API in Python client
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2a361b9ddfa [SPARK-41002][CONNECT][PYTHON] Compatible `take`, `head` and `first` API in Python client
2a361b9ddfa is described below
commit 2a361b9ddfa766c719399b35c38f4dafe68353ee
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Tue Nov 8 08:30:49 2022 +0800
[SPARK-41002][CONNECT][PYTHON] Compatible `take`, `head` and `first` API in Python client
### What changes were proposed in this pull request?
1. Add `take(n)` API.
2. Change `head(n)` API to return `Union[Optional[Row], List[Row]]`.
3. Update `first()` to return `Optional[Row]`.
### Why are the changes needed?
Improve API coverage.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT
Closes #38488 from amaliujia/SPARK-41002.
Authored-by: Rui Wang <ru...@databricks.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/sql/connect/dataframe.py | 61 ++++++++++++++++++++--
.../sql/tests/connect/test_connect_basic.py | 36 +++++++++++--
2 files changed, 90 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b9ba4b99ba0..9eecdbb7145 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -24,6 +24,7 @@ from typing import (
Tuple,
Union,
TYPE_CHECKING,
+ overload,
)
import pandas
@@ -211,14 +212,66 @@ class DataFrame(object):
plan.Filter(child=self._plan, filter=condition), session=self._session
)
- def first(self) -> Optional["pandas.DataFrame"]:
- return self.head(1)
+ def first(self) -> Optional[Row]:
+ """Returns the first row as a :class:`Row`.
+
+ .. versionadded:: 3.4.0
+
+ Returns
+ -------
+ :class:`Row`
+ First row if :class:`DataFrame` is not empty, otherwise ``None``.
+ """
+ return self.head()
def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
return GroupingFrame(self, *cols)
- def head(self, n: int) -> Optional["pandas.DataFrame"]:
- return self.limit(n).toPandas()
+ @overload
+ def head(self) -> Optional[Row]:
+ ...
+
+ @overload
+ def head(self, n: int) -> List[Row]:
+ ...
+
+ def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
+ """Returns the first ``n`` rows.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ n : int, optional
+ default 1. Number of rows to return.
+
+ Returns
+ -------
+ If ``n`` is provided, return a list of :class:`Row`.
+ If ``n`` is not provided, return a single :class:`Row`, or ``None`` if empty.
+ """
+ if n is None:
+ rs = self.head(1)
+ return rs[0] if rs else None
+ return self.take(n)
+
+ def take(self, num: int) -> List[Row]:
+ """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ num : int
+ Number of records to return. Will return this number of records
+ or whatever number is available.
+
+ Returns
+ -------
+ list
+ List of rows
+ """
+ return self.limit(num).collect()
# TODO: extend `on` to also be type List[ColumnRef].
def join(
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 18a752ee19d..a0f046907f7 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -46,6 +46,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
if have_pandas:
connect: RemoteSparkSession
tbl_name: str
+ tbl_name_empty: str
df_text: "DataFrame"
@classmethod
@@ -61,6 +62,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()
cls.tbl_name = "test_connect_basic_table_1"
+ cls.tbl_name_empty = "test_connect_basic_table_empty"
# Cleanup test data
cls.spark_connect_clean_up_test_data()
@@ -80,10 +82,21 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
# Since we might create multiple Spark sessions, we need to create global temporary view
# that is specifically maintained in the "global_temp" schema.
df.write.saveAsTable(cls.tbl_name)
+ empty_table_schema = StructType(
+ [
+ StructField("firstname", StringType(), True),
+ StructField("middlename", StringType(), True),
+ StructField("lastname", StringType(), True),
+ ]
+ )
+ emptyRDD = cls.spark.sparkContext.emptyRDD()
+ empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
+ empty_df.write.saveAsTable(cls.tbl_name_empty)
@classmethod
def spark_connect_clean_up_test_data(cls: Any) -> None:
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
+ cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
class SparkConnectTests(SparkConnectSQLTestCase):
@@ -145,10 +158,27 @@ class SparkConnectTests(SparkConnectSQLTestCase):
self.assertEqual(1, len(pdf.index))
def test_head(self):
+ # SPARK-41002: test `head` API in Python Client
+ df = self.connect.read.table(self.tbl_name)
+ self.assertIsNotNone(len(df.head()))
+ self.assertIsNotNone(len(df.head(1)))
+ self.assertIsNotNone(len(df.head(5)))
+ df2 = self.connect.read.table(self.tbl_name_empty)
+ self.assertIsNone(df2.head())
+
+ def test_first(self):
+ # SPARK-41002: test `first` API in Python Client
+ df = self.connect.read.table(self.tbl_name)
+ self.assertIsNotNone(len(df.first()))
+ df2 = self.connect.read.table(self.tbl_name_empty)
+ self.assertIsNone(df2.first())
+
+ def test_take(self) -> None:
+ # SPARK-41002: test `take` API in Python Client
df = self.connect.read.table(self.tbl_name)
- pd = df.head(10)
- self.assertIsNotNone(pd)
- self.assertEqual(10, len(pd.index))
+ self.assertEqual(5, len(df.take(5)))
+ df2 = self.connect.read.table(self.tbl_name_empty)
+ self.assertEqual(0, len(df2.take(5)))
def test_range(self):
self.assertTrue(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org