You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2017/07/29 03:59:57 UTC

spark git commit: [SPARK-20090][PYTHON] Add StructType.fieldNames in PySpark

Repository: spark
Updated Branches:
  refs/heads/master 0ef9fe64e -> b56f79cc3


[SPARK-20090][PYTHON] Add StructType.fieldNames in PySpark

## What changes were proposed in this pull request?

This PR proposes `StructType.fieldNames`, which returns a copy of the field name list, as an alternative to the (undocumented) `StructType.names` attribute.

There are two points here:

  - API consistency with Scala/Java

  - Provide a safe way to get the field names. `StructType.names` exposes the schema's own list of names, so manipulating that list might cause unexpected behaviour as below:

    ```python
    from pyspark.sql.types import *

    struct = StructType([StructField("f1", StringType(), True)])
    names = struct.names
    del names[0]
    spark.createDataFrame([{"f1": 1}], struct).show()
    ```

    ```
    ...
    java.lang.IllegalStateException: Input row doesn't have expected number of values required by the schema. 1 fields are required while 0 values are provided.
    	at org.apache.spark.sql.execution.python.EvaluatePython$.fromJava(EvaluatePython.scala:138)
    	at org.apache.spark.sql.SparkSession$$anonfun$6.apply(SparkSession.scala:741)
    	at org.apache.spark.sql.SparkSession$$anonfun$6.apply(SparkSession.scala:741)
    ...
    ```

## How was this patch tested?

Added tests in `python/pyspark/sql/tests.py`.

Author: hyukjinkwon <gu...@gmail.com>

Closes #18618 from HyukjinKwon/SPARK-20090.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b56f79cc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b56f79cc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b56f79cc

Branch: refs/heads/master
Commit: b56f79cc359d093d757af83171175cfd933162d1
Parents: 0ef9fe6
Author: hyukjinkwon <gu...@gmail.com>
Authored: Fri Jul 28 20:59:32 2017 -0700
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Fri Jul 28 20:59:32 2017 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 16 ++++++++--------
 python/pyspark/sql/types.py | 15 ++++++++++++++-
 2 files changed, 22 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b56f79cc/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 54756ed..cfd9c55 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1241,26 +1241,29 @@ class SQLTests(ReusedPySparkTestCase):
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), struct2.names)
         self.assertEqual(struct1, struct2)
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
         self.assertNotEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), struct2.names)
         self.assertEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
         self.assertNotEqual(struct1, struct2)
 
         # Catch exception raised during improper construction
-        with self.assertRaises(ValueError):
-            struct1 = StructType().add("name")
+        self.assertRaises(ValueError, lambda: StructType().add("name"))
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         for field in struct1:
@@ -1273,12 +1276,9 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertIs(struct1["f1"], struct1.fields[0])
         self.assertIs(struct1[0], struct1.fields[0])
         self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
-        with self.assertRaises(KeyError):
-            not_a_field = struct1["f9"]
-        with self.assertRaises(IndexError):
-            not_a_field = struct1[9]
-        with self.assertRaises(TypeError):
-            not_a_field = struct1[9.9]
+        self.assertRaises(KeyError, lambda: struct1["f9"])
+        self.assertRaises(IndexError, lambda: struct1[9])
+        self.assertRaises(TypeError, lambda: struct1[9.9])
 
     def test_parse_datatype_string(self):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string

http://git-wip-us.apache.org/repos/asf/spark/blob/b56f79cc/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c376805..ecb8eb9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -446,9 +446,12 @@ class StructType(DataType):
 
     This is the data type representing a :class:`Row`.
 
-    Iterating a :class:`StructType` will iterate its :class:`StructField`s.
+    Iterating a :class:`StructType` will iterate its :class:`StructField`\\s.
     A contained :class:`StructField` can be accessed by name or position.
 
+    .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead
+        to get a list of field names.
+
     >>> struct1 = StructType([StructField("f1", StringType(), True)])
     >>> struct1["f1"]
     StructField(f1,StringType,true)
@@ -563,6 +566,16 @@ class StructType(DataType):
     def fromJson(cls, json):
         return StructType([StructField.fromJson(f) for f in json["fields"]])
 
+    def fieldNames(self):
+        """
+        Returns all field names in a list.
+
+        >>> struct = StructType([StructField("f1", StringType(), True)])
+        >>> struct.fieldNames()
+        ['f1']
+        """
+        return list(self.names)
+
     def needConversion(self):
         # We need convert Row()/namedtuple into tuple()
         return True


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org