Posted to commits@spark.apache.org by gu...@apache.org on 2021/07/07 06:14:56 UTC

[spark] branch master updated: [SPARK-35929][PYTHON] Support to infer nested dict as a struct when creating a DataFrame

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2537fe8  [SPARK-35929][PYTHON] Support to infer nested dict as a struct when creating a DataFrame
2537fe8 is described below

commit 2537fe8cbaf49070137d4b5bc39af078b306c4c8
Author: itholic <ha...@databricks.com>
AuthorDate: Wed Jul 7 15:14:18 2021 +0900

    [SPARK-35929][PYTHON] Support to infer nested dict as a struct when creating a DataFrame
    
    ### What changes were proposed in this pull request?
    
    Currently, a nested dict is always inferred as `MapType` when creating a DataFrame.
    
    This behavior causes an issue because the map's value type is inferred from the first non-null key/value pair of the dict, so values of other fields with a different type are lost, as below:
    
    ```python
    data = [{"inside_struct": {"payment": 100.5, "name": "Lee"}}]
    df = spark.createDataFrame(data)
    df.show(truncate=False)
    +--------------------------------+
    |inside_struct                   |
    +--------------------------------+
    |{name -> null, payment -> 100.5}|
    +--------------------------------+
    ```
    
    The "name" became `null`, but it should've been `"Lee"`.
    
    In this case, we need to be able to infer the schema with a `StructType` instead of a `MapType`.
    
    Therefore, this PR proposes adding a new configuration, `spark.sql.pyspark.inferNestedDictAsStruct.enabled`, to control which type is used when inferring nested dicts.
    - When `spark.sql.pyspark.inferNestedDictAsStruct.enabled` is `false` (the default), nested dicts are inferred as `MapType`
    - When `spark.sql.pyspark.inferNestedDictAsStruct.enabled` is `true`, nested dicts are inferred as `StructType`, as sketched below
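    
    With the configuration enabled, the example above is inferred as a struct and no value is lost (a sketch; the exact `show` output is illustrative):
    
    ```python
    spark.conf.set("spark.sql.pyspark.inferNestedDictAsStruct.enabled", True)
    data = [{"inside_struct": {"payment": 100.5, "name": "Lee"}}]
    df = spark.createDataFrame(data)
    df.show(truncate=False)
    +-------------+
    |inside_struct|
    +-------------+
    |{100.5, Lee} |
    +-------------+
    ```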
    
    ### Why are the changes needed?
    
    Because always inferring nested dicts as `MapType` produces incorrect results in some cases, as shown above.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A new configuration, `spark.sql.pyspark.inferNestedDictAsStruct.enabled`, is added. It defaults to `false`, so the existing behavior is unchanged.
    
    ### How was this patch tested?
    
    Added a unit test (`test_infer_nested_dict_as_struct` in `python/pyspark/sql/tests/test_types.py`).
    
    Closes #33214 from itholic/SPARK-35929.
    
    Lead-authored-by: itholic <ha...@databricks.com>
    Co-authored-by: Hyukjin Kwon <gu...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/session.py                      | 13 +++++++----
 python/pyspark/sql/tests/test_types.py             | 15 +++++++++++++
 python/pyspark/sql/types.py                        | 26 ++++++++++++++--------
 .../org/apache/spark/sql/internal/SQLConf.scala    |  9 ++++++++
 4 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 740ceb3..f3a63de 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -436,7 +436,9 @@ class SparkSession(SparkConversionMixin):
         """
         if not data:
             raise ValueError("can not infer schema from empty dataset")
-        schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
+        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
+        schema = reduce(_merge_type, (_infer_schema(row, names, infer_dict_as_struct)
+                        for row in data))
         if _has_nulltype(schema):
             raise ValueError("Some of types cannot be determined after inferring")
         return schema
@@ -462,11 +464,13 @@ class SparkSession(SparkConversionMixin):
             raise ValueError("The first row in RDD is empty, "
                              "can not infer schema")
 
+        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
         if samplingRatio is None:
-            schema = _infer_schema(first, names=names)
+            schema = _infer_schema(first, names=names, infer_dict_as_struct=infer_dict_as_struct)
             if _has_nulltype(schema):
                 for row in rdd.take(100)[1:]:
-                    schema = _merge_type(schema, _infer_schema(row, names=names))
+                    schema = _merge_type(schema, _infer_schema(
+                        row, names=names, infer_dict_as_struct=infer_dict_as_struct))
                     if not _has_nulltype(schema):
                         break
                 else:
@@ -475,7 +479,8 @@ class SparkSession(SparkConversionMixin):
         else:
             if samplingRatio < 0.99:
                 rdd = rdd.sample(False, float(samplingRatio))
-            schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type)
+            schema = rdd.map(lambda row: _infer_schema(
+                row, names, infer_dict_as_struct=infer_dict_as_struct)).reduce(_merge_type)
         return schema
 
     def _createFromRDD(self, rdd, schema, samplingRatio):
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index eb4caf0..0bb1f00 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -204,6 +204,21 @@ class TypesTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(rdd)
         self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
 
+    def test_infer_nested_dict_as_struct(self):
+        # SPARK-35929: Test inferring nested dict as a struct type.
+        NestedRow = Row("f1", "f2")
+
+        with self.sql_conf({"spark.sql.pyspark.inferNestedDictAsStruct.enabled": True}):
+            data = [NestedRow([{"payment": 200.5, "name": "A"}], [1, 2]),
+                    NestedRow([{"payment": 100.5, "name": "B"}], [2, 3])]
+
+            nestedRdd = self.sc.parallelize(data)
+            df = self.spark.createDataFrame(nestedRdd)
+            self.assertEqual(Row(f1=[Row(payment=200.5, name='A')], f2=[1, 2]), df.first())
+
+            df = self.spark.createDataFrame(data)
+            self.assertEqual(Row(f1=[Row(payment=200.5, name='A')], f2=[1, 2]), df.first())
+
     def test_create_dataframe_from_dict_respects_schema(self):
         df = self.spark.createDataFrame([{'a': 1}], ["b"])
         self.assertEqual(df.columns, ['b'])
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 78c7732..e3d8f49 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1003,7 +1003,7 @@ if sys.version_info[0] < 4:
     _array_type_mappings['u'] = StringType
 
 
-def _infer_type(obj):
+def _infer_type(obj, infer_dict_as_struct=False):
     """Infer the DataType from obj
     """
     if obj is None:
@@ -1020,14 +1020,22 @@ def _infer_type(obj):
         return dataType()
 
     if isinstance(obj, dict):
-        for key, value in obj.items():
-            if key is not None and value is not None:
-                return MapType(_infer_type(key), _infer_type(value), True)
-        return MapType(NullType(), NullType(), True)
+        if infer_dict_as_struct:
+            struct = StructType()
+            for key, value in obj.items():
+                if key is not None and value is not None:
+                    struct.add(key, _infer_type(value, infer_dict_as_struct), True)
+            return struct
+        else:
+            for key, value in obj.items():
+                if key is not None and value is not None:
+                    return MapType(_infer_type(key, infer_dict_as_struct),
+                                   _infer_type(value, infer_dict_as_struct), True)
+            return MapType(NullType(), NullType(), True)
     elif isinstance(obj, list):
         for v in obj:
             if v is not None:
-                return ArrayType(_infer_type(obj[0]), True)
+                return ArrayType(_infer_type(obj[0], infer_dict_as_struct), True)
         return ArrayType(NullType(), True)
     elif isinstance(obj, array):
         if obj.typecode in _array_type_mappings:
@@ -1036,12 +1044,12 @@ def _infer_type(obj):
             raise TypeError("not supported type: array(%s)" % obj.typecode)
     else:
         try:
-            return _infer_schema(obj)
+            return _infer_schema(obj, infer_dict_as_struct=infer_dict_as_struct)
         except TypeError:
             raise TypeError("not supported type: %s" % type(obj))
 
 
-def _infer_schema(row, names=None):
+def _infer_schema(row, names=None, infer_dict_as_struct=False):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
@@ -1067,7 +1075,7 @@ def _infer_schema(row, names=None):
     fields = []
     for k, v in items:
         try:
-            fields.append(StructField(k, _infer_type(v), True))
+            fields.append(StructField(k, _infer_type(v, infer_dict_as_struct), True))
         except TypeError as e:
             raise TypeError("Unable to infer the type of the field {}.".format(k)) from e
     return StructType(fields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index cc53d92..e9c5f6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3335,6 +3335,13 @@ object SQLConf {
     .intConf
     .createWithDefault(0)
 
+  val INFER_NESTED_DICT_AS_STRUCT = buildConf("spark.sql.pyspark.inferNestedDictAsStruct.enabled")
+    .doc("PySpark's SparkSession.createDataFrame infers the nested dict as a map by default. " +
+      "When it set to true, it infers the nested dict as a struct.")
+    .version("3.3.0")
+    .booleanConf
+    .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -4048,6 +4055,8 @@ class SQLConf extends Serializable with Logging {
 
   def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS)
 
+  def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
