You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/01/15 02:55:45 UTC

spark git commit: [SPARK-23054][SQL] Fix incorrect results of casting UserDefinedType to String

Repository: spark
Updated Branches:
  refs/heads/master 42a1a15d7 -> b98ffa4d6


[SPARK-23054][SQL] Fix incorrect results of casting UserDefinedType to String

## What changes were proposed in this pull request?
This PR fixes an issue where casting `UserDefinedType`s to strings produces incorrect results. For example, before this change:
```
>>> from pyspark.ml.classification import MultilayerPerceptronClassifier
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(0.0, Vectors.dense([0.0, 0.0])), (1.0, Vectors.dense([0.0, 1.0]))], ["label", "features"])
>>> df.selectExpr("CAST(features AS STRING)").show(truncate = False)
+-------------------------------------------+
|features                                   |
+-------------------------------------------+
|[6,1,0,0,2800000020,2,0,0,0]               |
|[6,1,0,0,2800000020,2,0,0,3ff0000000000000]|
+-------------------------------------------+
```
The root cause is that `Cast` treats the input data as `UserDefinedType.sqlType` (the underlying storage type), so we should pass the data to `UserDefinedType.deserialize` and then call `toString` on the result.
With this PR, the result becomes:
```
+---------+
|features |
+---------+
|[0.0,0.0]|
|[0.0,1.0]|
+---------+
```

## How was this patch tested?
Added tests in `UserDefinedTypeSuite`.

Author: Takeshi Yamamuro <ya...@apache.org>

Closes #20246 from maropu/SPARK-23054.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b98ffa4d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b98ffa4d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b98ffa4d

Branch: refs/heads/master
Commit: b98ffa4d6dabaf787177d3f14b200fc4b118c7ce
Parents: 42a1a15
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Mon Jan 15 10:55:21 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Jan 15 10:55:21 2018 +0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/expressions/Cast.scala |  7 +++++++
 .../org/apache/spark/sql/UserDefinedTypeSuite.scala  | 15 +++++++++++++--
 2 files changed, 20 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b98ffa4d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f21aa1e..a95ebe3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -282,6 +282,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         builder.append("]")
         builder.build()
       })
+    case udt: UserDefinedType[_] =>
+      buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString))
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
   }
 
@@ -836,6 +838,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
              |$evPrim = $buffer.build();
            """.stripMargin
         }
+      case udt: UserDefinedType[_] =>
+        val udtRef = ctx.addReferenceObj("udt", udt)
+        (c, evPrim, evNull) => {
+          s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());"
+        }
       case _ =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/b98ffa4d/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index a08433b..cc8b600 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -21,7 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
@@ -44,6 +44,8 @@ object UDT {
       case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data)
       case _ => false
     }
+
+    override def toString: String = data.mkString("(", ", ", ")")
   }
 
   private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
@@ -143,7 +145,8 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType]
   override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
 }
 
-class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
+class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest
+    with ExpressionEvalHelper {
   import testImplicits._
 
   private lazy val pointsRDD = Seq(
@@ -304,4 +307,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
       pointsRDD.except(pointsRDD2),
       Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
   }
+
+  test("SPARK-23054 Cast UserDefinedType to string") {
+    val udt = new UDT.MyDenseVectorUDT()
+    val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0))
+    val data = udt.serialize(vector)
+    val ret = Cast(Literal(data, udt), StringType, None)
+    checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org