You are viewing a plain-text version of this content. The hyperlink to the canonical version ("here") was not preserved in this plain-text copy.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/22 06:44:40 UTC

[spark] branch master updated: [SPARK-41213][CONNECT][PYTHON] Implement `DataFrame.__repr__` and `DataFrame.dtypes`

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 55addb38f2a [SPARK-41213][CONNECT][PYTHON] Implement `DataFrame.__repr__` and `DataFrame.dtypes`
55addb38f2a is described below

commit 55addb38f2a9cc2521a2e7908bb8ad3d49c9a8bb
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Nov 22 15:44:28 2022 +0900

    [SPARK-41213][CONNECT][PYTHON] Implement `DataFrame.__repr__` and `DataFrame.dtypes`
    
    ### What changes were proposed in this pull request?
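    Implement `DataFrame.__repr__` and `DataFrame.dtypes` in the Spark Connect Python client, and turn `DataFrame.schema` from a method into a property so that `dtypes`, `columns`, and `__repr__` can read it as an attribute.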
    
    ### Why are the changes needed?
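    To improve API coverage of the Spark Connect Python `DataFrame`, bringing it closer to parity with PySpark's `DataFrame`.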
    
    ### Does this PR introduce _any_ user-facing change?
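    Yes. `DataFrame.__repr__` and `DataFrame.dtypes` are added. `DataFrame.schema` is now a property rather than a method, so call sites must change from `df.schema()` to `df.schema`.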
    
    ### How was this patch tested?
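    Added a unit test (`test_repr`) that compares the Spark Connect `__repr__` output with PySpark's. Existing tests were updated to access `schema` as a property.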
    
    Closes #38735 from zhengruifeng/connect_df_repr.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py               | 19 ++++++++++++++++++-
 .../pyspark/sql/tests/connect/test_connect_basic.py   | 13 ++++++++++---
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 15aa028b11b..275a6d2668d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -115,6 +115,9 @@ class DataFrame(object):
         self._cache: Dict[str, Any] = {}
         self._session: "RemoteSparkSession" = session
 
+    def __repr__(self) -> str:
+        return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+
     @classmethod
     def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "DataFrame":
         """Main initialization method used to construct a new data frame with a child plan."""
@@ -137,13 +140,26 @@ class DataFrame(object):
     def colRegex(self, regex: str) -> "DataFrame":
         ...
 
+    @property
+    def dtypes(self) -> List[Tuple[str, str]]:
+        """Returns all column names and their data types as a list.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        list
+            List of columns as tuple pairs.
+        """
+        return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
+
     @property
     def columns(self) -> List[str]:
         """Returns the list of columns of the current data frame."""
         if self._plan is None:
             return []
 
-        return self.schema().names
+        return self.schema.names
 
     def sparkSession(self) -> "RemoteSparkSession":
         """Returns Spark session that created this :class:`DataFrame`.
@@ -736,6 +752,7 @@ class DataFrame(object):
         query = self._plan.to_proto(self._session)
         return self._session._to_pandas(query)
 
+    @property
     def schema(self) -> StructType:
         """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`.
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9e7a5f2f4a5..917cce0ebb0 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -137,7 +137,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         self.assertGreater(len(result), 0)
 
     def test_schema(self):
-        schema = self.connect.read.table(self.tbl_name).schema()
+        schema = self.connect.read.table(self.tbl_name).schema
         self.assertEqual(
             StructType(
                 [StructField("id", LongType(), True), StructField("name", StringType(), True)]
@@ -333,6 +333,14 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         expected = "+---+---+\n|  X|  Y|\n+---+---+\n|  1|  2|\n+---+---+\n"
         self.assertEqual(show_str, expected)
 
+    def test_repr(self):
+        # SPARK-41213: Test the __repr__ method
+        query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
+        self.assertEqual(
+            self.connect.sql(query).__repr__(),
+            self.spark.sql(query).__repr__(),
+        )
+
     def test_explain_string(self):
         # SPARK-41122: test explain API.
         plan_str = self.connect.sql("SELECT 1").explain(extended=True)
@@ -380,8 +388,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         col0 = (
             self.connect.range(1, 10)
             .select(col("id").alias("name", metadata={"max": 99}))
-            .schema()
-            .names[0]
+            .schema.names[0]
         )
         self.assertEqual("name", col0)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org