Posted to commits@spark.apache.org by da...@apache.org on 2016/08/02 17:08:22 UTC

spark git commit: [SPARK-16062] [SPARK-15989] [SQL] Fix two bugs of Python-only UDTs

Repository: spark
Updated Branches:
  refs/heads/master 1dab63d8d -> 146001a9f


[SPARK-16062] [SPARK-15989] [SQL] Fix two bugs of Python-only UDTs

## What changes were proposed in this pull request?

There are two related bugs of Python-only UDTs. Because the test case for the second one also needs the first fix, I put them into one PR. If that is not appropriate, please let me know.

### First bug: When `MapObjects` works on Python-only UDTs

`RowEncoder` will use `PythonUserDefinedType.sqlType` for its deserializer expression. If the SQL type is `ArrayType`, we will have `MapObjects` working on it. But `MapObjects` doesn't consider `PythonUserDefinedType` as its input data type, so it causes an error like:

    import pyspark.sql.group
    from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
    from pyspark.sql.types import *

    schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
    df = spark.createDataFrame([(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema)
    df.show()

    File "/home/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 312, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o36.showString.
    : java.lang.RuntimeException: Error while decoding: scala.MatchError: org.apache.spark.sql.types.PythonUserDefinedTypef4ceede8 (of class org.apache.spark.sql.types.PythonUserDefinedType)
    ...
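
For reference, `PythonOnlyPoint` and `PythonOnlyUDT` come from `pyspark.sql.tests`. A Python-only UDT is defined entirely in Python with no corresponding Scala class, so the JVM side wraps it as a `PythonUserDefinedType` and stores its values using the UDT's `sqlType`. A minimal sketch of such a UDT, assuming the standard `UserDefinedType` contract (the names `MyPoint`/`MyPointUDT` are illustrative, not the actual test classes):

    from pyspark.sql.types import UserDefinedType, ArrayType, DoubleType

    class MyPoint(object):
        """A plain Python object with no Scala counterpart."""
        def __init__(self, x, y):
            self.x, self.y = x, y

    class MyPointUDT(UserDefinedType):
        """Python-only UDT: a point is stored as an array of two doubles."""
        @classmethod
        def sqlType(cls):
            return ArrayType(DoubleType(), False)

        @classmethod
        def module(cls):
            return "__main__"

        def serialize(self, obj):
            return [obj.x, obj.y]

        def deserialize(self, datum):
            return MyPoint(datum[0], datum[1])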

### Second bug: When a Python-only UDT is the element type of `ArrayType`

Before this fix, the following snippet also fails:

    import pyspark.sql.group
    from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
    from pyspark.sql.types import *

    schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
    df = spark.createDataFrame([(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], schema=schema)
    df.show()
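
Both bugs come from the same underlying fact (also noted in the `MapObjects` change below): values of a Python-only UDT are physically stored using the UDT's `sqlType`, so the JVM-side expressions must operate on that SQL type rather than on a user class. A rough sketch of the round trip, reusing the hypothetical `MyPoint`/`MyPointUDT` from above with illustrative values:

    udt = MyPointUDT()
    points = [MyPoint(1.0, 2.0), MyPoint(3.0, 4.0)]

    # What is actually stored for an ArrayType(MyPointUDT()) column:
    stored = [udt.serialize(p) for p in points]       # [[1.0, 2.0], [3.0, 4.0]]

    # What comes back when the rows are collected:
    restored = [udt.deserialize(v) for v in stored]   # MyPoint objects again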

## How was this patch tested?
PySpark's SQL tests (`python/pyspark/sql/tests.py`).
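
The new cases are `test_simple_udt_in_df`, `test_nested_udt_in_df` and `test_complex_nested_udt_in_df` (see the diff below). A minimal sketch for running just these cases with plain `unittest`, assuming `pyspark` and the bundled py4j are on `PYTHONPATH`:

    import unittest
    from pyspark.sql.tests import SQLTests

    # Hypothetical driver: build a suite with only the new UDT tests.
    suite = unittest.TestSuite()
    for name in ("test_simple_udt_in_df",
                 "test_nested_udt_in_df",
                 "test_complex_nested_udt_in_df"):
        suite.addTest(SQLTests(name))
    unittest.TextTestRunner(verbosity=2).run(suite)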

Author: Liang-Chi Hsieh <si...@tw.ibm.com>

Closes #13778 from viirya/fix-pyudt.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/146001a9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/146001a9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/146001a9

Branch: refs/heads/master
Commit: 146001a9ffefc7aaedd3d888d68c7a9b80bca545
Parents: 1dab63d
Author: Liang-Chi Hsieh <si...@tw.ibm.com>
Authored: Tue Aug 2 10:08:18 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Tue Aug 2 10:08:18 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 35 ++++++++++++++++++++
 .../sql/catalyst/encoders/RowEncoder.scala      |  9 ++++-
 .../catalyst/expressions/objects/objects.scala  | 17 ++++++++--
 3 files changed, 58 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/146001a9/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a8ca386..87dbb50 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -575,6 +575,41 @@ class SQLTests(ReusedPySparkTestCase):
         _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
         self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
 
+    def test_simple_udt_in_df(self):
+        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
+        df = self.spark.createDataFrame(
+            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
+            schema=schema)
+        df.show()
+
+    def test_nested_udt_in_df(self):
+        schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
+        df = self.spark.createDataFrame(
+            [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
+            schema=schema)
+        df.collect()
+
+        schema = StructType().add("key", LongType()).add("val",
+                                                         MapType(LongType(), PythonOnlyUDT()))
+        df = self.spark.createDataFrame(
+            [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
+            schema=schema)
+        df.collect()
+
+    def test_complex_nested_udt_in_df(self):
+        from pyspark.sql.functions import udf
+
+        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
+        df = self.spark.createDataFrame(
+            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
+            schema=schema)
+        df.collect()
+
+        gd = df.groupby("key").agg({"val": "collect_list"})
+        gd.collect()
+        udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
+        gd.select(udf(*gd)).collect()
+
     def test_udt_with_none(self):
         df = self.spark.range(0, 10, 1, 1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/146001a9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 67fca15..2a6fcd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -206,6 +206,7 @@ object RowEncoder {
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
     case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
     case _: StructType => ObjectType(classOf[Row])
+    case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType)
     case udt: UserDefinedType[_] => ObjectType(udt.userClass)
   }
 
@@ -220,9 +221,15 @@ object RowEncoder {
     CreateExternalRow(fields, schema)
   }
 
-  private def deserializerFor(input: Expression): Expression = input.dataType match {
+  private def deserializerFor(input: Expression): Expression = {
+    deserializerFor(input, input.dataType)
+  }
+
+  private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
     case dt if ScalaReflection.isNativeType(dt) => input
 
+    case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)
+
     case udt: UserDefinedType[_] =>
       val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
       val udtClass: Class[_] = if (annotation != null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/146001a9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 0658941..952a5f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -359,6 +359,13 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
 object MapObjects {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
+  /**
+   * Construct an instance of MapObjects case class.
+   *
+   * @param function The function applied on the collection elements.
+   * @param inputData An expression that when evaluated returns a collection object.
+   * @param elementType The data type of elements in the collection.
+   */
   def apply(
       function: Expression => Expression,
       inputData: Expression,
@@ -446,8 +453,14 @@ case class MapObjects private(
       case _ => ""
     }
 
+    // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
+    // When we want to apply MapObjects on it, we have to use it.
+    val inputDataType = inputData.dataType match {
+      case p: PythonUserDefinedType => p.sqlType
+      case _ => inputData.dataType
+    }
 
-    val (getLength, getLoopVar) = inputData.dataType match {
+    val (getLength, getLoopVar) = inputDataType match {
       case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
         s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
       case ObjectType(cls) if cls.isArray =>
@@ -461,7 +474,7 @@ case class MapObjects private(
           s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
     }
 
-    val loopNullCheck = inputData.dataType match {
+    val loopNullCheck = inputDataType match {
       case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
       // The element of primitive array will never be null.
       case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>

