You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/04/01 16:20:07 UTC

[GitHub] [spark] eddyxu commented on a change in pull request #32026: [SPARK-34771] Support UDT for Pandas/Spark conversion with Arrow support Enabled

eddyxu commented on a change in pull request #32026:
URL: https://github.com/apache/spark/pull/32026#discussion_r605778192



##########
File path: python/pyspark/sql/types.py
##########
@@ -764,6 +764,21 @@ def __eq__(self, other):
         return type(self) == type(other)
 
 
+def _is_datatype_with_udt(dt):

Review comment:
       `_has_udt()` for short?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
##########
@@ -89,9 +89,57 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
 
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
-        s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
+      assert(plainSchemaSeq(outputTypes) == actualDataTypes,
+        "Incompatible schema from pandas_udf: " +
+          s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
       batch.rowIterator.asScala
     }
   }
+
+  private def plainSchemaSeq(schema: Seq[DataType]): Seq[DataType] = {
+    schema.map(v => ArrowEvalPythonExec.plainSchema(v)).toList
+  }
+
+}
+
+private[sql] object ArrowEvalPythonExec {
+  /**
+   * Erase User-Defined Types and returns the plain Spark StructType instead.
+   *
+   * UserDefinedType:
+   * - will be erased as dt.sqlType

Review comment:
       Should we rephrase `dt.sqlType` to `UserDefinedType.sqlType`, since `dt` has not been defined within this context?

##########
File path: python/pyspark/sql/tests/test_arrow.py
##########
@@ -196,6 +197,33 @@ def test_pandas_round_trip(self):
         pdf_arrow = df.toPandas()
         assert_frame_equal(pdf_arrow, pdf)
 
+    def test_udt_roundtrip(self):
+        pdf = pd.DataFrame({'point': pd.Series([ExamplePoint(1.0, 1.0), ExamplePoint(2.0, 2.0)])})
+        schema = StructType([StructField('point', ExamplePointUDT(), False)])
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
+            df = self.spark.createDataFrame(pdf, schema)
+            pdf_arrow = df.toPandas()
+            assert_frame_equal(pdf_arrow, pdf)
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": False}):
+            df = self.spark.createDataFrame(pdf, schema)
+            pdf_arrow = df.toPandas()
+            assert_frame_equal(pdf_arrow, pdf)
+
+    def test_array_udt_roundtrip(self):
+        pdf = pd.DataFrame({'points': pd.Series([
+            [ExamplePoint(1.0, 1.0), ExamplePoint(1.0, 2.0), ExamplePoint(1.0, 3.0)],

Review comment:
       In the original PR, there is a new UDT whose sqlType is a `StructType`, rather than the `ArrayType` used here for `ExamplePoint`. Should we also test that?
   
   Also, is it possible for a UDT's sqlType to be a primitive type? If so, we might want to add such cases to the tests as well.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
##########
@@ -89,9 +89,57 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
 
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
-        s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
+      assert(plainSchemaSeq(outputTypes) == actualDataTypes,
+        "Incompatible schema from pandas_udf: " +
+          s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
       batch.rowIterator.asScala
     }
   }
+
+  private def plainSchemaSeq(schema: Seq[DataType]): Seq[DataType] = {
+    schema.map(v => ArrowEvalPythonExec.plainSchema(v)).toList
+  }
+
+}
+
+private[sql] object ArrowEvalPythonExec {
+  /**
+   * Erase User-Defined Types and returns the plain Spark StructType instead.
+   *
+   * UserDefinedType:
+   * - will be erased as dt.sqlType
+   * - recursively rewrite internal ArrayType with `containsNull=true`
+   * ArrayType: containsNull will always be true when returned by PyArrow

Review comment:
       Should we add a link to the code where the PyArrow-to-Spark conversion happens?

##########
File path: python/pyspark/sql/types.py
##########
@@ -764,6 +764,21 @@ def __eq__(self, other):
         return type(self) == type(other)
 
 
+def _is_datatype_with_udt(dt):

Review comment:
       Is `MapType` explicitly unsupported? If so, should we add documentation to this method?

##########
File path: python/pyspark/sql/pandas/conversion.py
##########
@@ -452,24 +457,27 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
                 struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
             schema = struct
 
-        # Determine arrow types to coerce data when creating batches
+        # Determine data types to coerce data when creating batches
         if isinstance(schema, StructType):
-            arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
+            data_types = [f.dataType for f in schema.fields]
         elif isinstance(schema, DataType):
             raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
         else:
             # Any timestamps must be coerced to be compatible with Spark
-            arrow_types = [to_arrow_type(TimestampType())
-                           if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
-                           for t in pdf.dtypes]
+            data_types = [to_arrow_type(TimestampType())
+                          if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
+                          for t in pdf.dtypes]
 
         # Slice the DataFrame to be batched
         step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
         pdf_slices = (pdf.iloc[start:start + step] for start in range(0, len(pdf), step))
 
         # Create list of Arrow (columns, type) for serializer dump_stream
-        arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
-                      for pdf_slice in pdf_slices]
+        # Type can be Spark SQL Data Type or Arrow Data Type
+        arrow_data_with_t = [

Review comment:
       This seems to be only a rename. Should we leave the names as they are to reduce the size of the PR?

##########
File path: python/pyspark/sql/pandas/conversion.py
##########
@@ -20,9 +20,10 @@
 
 from pyspark.rdd import _load_from_socket
 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
-from pyspark.sql.types import IntegralType
 from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
-    DoubleType, BooleanType, MapType, TimestampType, StructType, DataType
+    DoubleType, BooleanType, MapType, TimestampType, StructType, DataType, \
+    IntegralType, _is_datatype_with_udt
+from pyspark.sql.pandas.types import _deserialize_pandas_with_udt

Review comment:
       These two `_`-prefixed functions are used outside their own files. Should we remove the `_` prefix, since they are no longer private?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org