Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/08 00:31:24 UTC

[spark] branch master updated: [SPARK-41002][CONNECT][PYTHON] Compatible `take`, `head` and `first` API in Python client

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a361b9ddfa [SPARK-41002][CONNECT][PYTHON] Compatible `take`, `head` and `first` API in Python client
2a361b9ddfa is described below

commit 2a361b9ddfa766c719399b35c38f4dafe68353ee
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Tue Nov 8 08:30:49 2022 +0800

    [SPARK-41002][CONNECT][PYTHON] Compatible `take`, `head` and `first` API in Python client
    
    ### What changes were proposed in this pull request?
    
    1. Add `take(n)` API.
    2. Change `head(n)` API to return `Union[Optional[Row], List[Row]]`.
    3. Update `first()` to return `Optional[Row]`.
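    
    A minimal usage sketch of the changed APIs (the table name below is hypothetical and only illustrative, not part of this patch):
    
    ```python
    df = spark.read.table("some_table")  # hypothetical table name
    
    df.take(3)   # List[Row]: up to 3 rows
    df.head(3)   # List[Row]: same as take(3) when n is passed
    df.head()    # Optional[Row]: first row, or None if the DataFrame is empty
    df.first()   # Optional[Row]: equivalent to head()
    ```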
    
    ### Why are the changes needed?
    
    Improve API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests added in `test_connect_basic.py`.
    
    Closes #38488 from amaliujia/SPARK-41002.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py            | 61 ++++++++++++++++++++--
 .../sql/tests/connect/test_connect_basic.py        | 36 +++++++++++--
 2 files changed, 90 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b9ba4b99ba0..9eecdbb7145 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -24,6 +24,7 @@ from typing import (
     Tuple,
     Union,
     TYPE_CHECKING,
+    overload,
 )
 
 import pandas
@@ -211,14 +212,66 @@ class DataFrame(object):
             plan.Filter(child=self._plan, filter=condition), session=self._session
         )
 
-    def first(self) -> Optional["pandas.DataFrame"]:
-        return self.head(1)
+    def first(self) -> Optional[Row]:
+        """Returns the first row as a :class:`Row`.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        :class:`Row`
+           First row if :class:`DataFrame` is not empty, otherwise ``None``.
+        """
+        return self.head()
 
     def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
         return GroupingFrame(self, *cols)
 
-    def head(self, n: int) -> Optional["pandas.DataFrame"]:
-        return self.limit(n).toPandas()
+    @overload
+    def head(self) -> Optional[Row]:
+        ...
+
+    @overload
+    def head(self, n: int) -> List[Row]:
+        ...
+
+    def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
+        """Returns the first ``n`` rows.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        n : int, optional
+            default 1. Number of rows to return.
+
+        Returns
+        -------
+        If ``n`` is ``None``, a single :class:`Row` (or ``None`` if the
+        :class:`DataFrame` is empty); otherwise, a list of at most ``n`` rows.
+        """
+        if n is None:
+            rs = self.head(1)
+            return rs[0] if rs else None
+        return self.take(n)
+
+    def take(self, num: int) -> List[Row]:
+        """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        num : int
+            Number of records to return. Will return this number of records
+            or whatever number is available.
+
+        Returns
+        -------
+        list
+            List of rows
+        """
+        return self.limit(num).collect()
 
     # TODO: extend `on` to also be type List[ColumnRef].
     def join(
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 18a752ee19d..a0f046907f7 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -46,6 +46,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
     if have_pandas:
         connect: RemoteSparkSession
     tbl_name: str
+    tbl_name_empty: str
     df_text: "DataFrame"
 
     @classmethod
@@ -61,6 +62,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
         cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()
 
         cls.tbl_name = "test_connect_basic_table_1"
+        cls.tbl_name_empty = "test_connect_basic_table_empty"
 
         # Cleanup test data
         cls.spark_connect_clean_up_test_data()
@@ -80,10 +82,21 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
         # Since we might create multiple Spark sessions, we need to create global temporary view
         # that is specifically maintained in the "global_temp" schema.
         df.write.saveAsTable(cls.tbl_name)
+        empty_table_schema = StructType(
+            [
+                StructField("firstname", StringType(), True),
+                StructField("middlename", StringType(), True),
+                StructField("lastname", StringType(), True),
+            ]
+        )
+        emptyRDD = cls.spark.sparkContext.emptyRDD()
+        empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
+        empty_df.write.saveAsTable(cls.tbl_name_empty)
 
     @classmethod
     def spark_connect_clean_up_test_data(cls: Any) -> None:
         cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
+        cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
 
 
 class SparkConnectTests(SparkConnectSQLTestCase):
@@ -145,10 +158,27 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         self.assertEqual(1, len(pdf.index))
 
     def test_head(self):
+        # SPARK-41002: test `head` API in Python Client
+        df = self.connect.read.table(self.tbl_name)
+        self.assertIsNotNone(len(df.head()))
+        self.assertIsNotNone(len(df.head(1)))
+        self.assertIsNotNone(len(df.head(5)))
+        df2 = self.connect.read.table(self.tbl_name_empty)
+        self.assertIsNone(df2.head())
+
+    def test_first(self):
+        # SPARK-41002: test `first` API in Python Client
+        df = self.connect.read.table(self.tbl_name)
+        self.assertIsNotNone(len(df.first()))
+        df2 = self.connect.read.table(self.tbl_name_empty)
+        self.assertIsNone(df2.first())
+
+    def test_take(self) -> None:
+        # SPARK-41002: test `take` API in Python Client
         df = self.connect.read.table(self.tbl_name)
-        pd = df.head(10)
-        self.assertIsNotNone(pd)
-        self.assertEqual(10, len(pd.index))
+        self.assertEqual(5, len(df.take(5)))
+        df2 = self.connect.read.table(self.tbl_name_empty)
+        self.assertEqual(0, len(df2.take(5)))
 
     def test_range(self):
         self.assertTrue(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org