You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/01/19 03:37:16 UTC

spark git commit: [SPARK-23054][SQL][PYSPARK][FOLLOWUP] Use sqlType casting when casting PythonUserDefinedType to String.

Repository: spark
Updated Branches:
  refs/heads/master 6121e91b7 -> 568055da9


[SPARK-23054][SQL][PYSPARK][FOLLOWUP] Use sqlType casting when casting PythonUserDefinedType to String.

## What changes were proposed in this pull request?

This is a follow-up to #20246.

If a Python UDT has no corresponding Scala UDT, casting it to string produces the raw `toString` of its internal value, e.g. `"org.apache.spark.sql.catalyst.expressions.UnsafeArrayDataxxxxxxxx"` when the internal type is `ArrayType`.

This PR fixes the problem by casting the value according to the UDT's underlying `sqlType` instead.

## How was this patch tested?

Added a new test; existing tests also pass.

Author: Takuya UESHIN <ue...@databricks.com>

Closes #20306 from ueshin/issues/SPARK-23054/fup1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/568055da
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/568055da
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/568055da

Branch: refs/heads/master
Commit: 568055da93049c207bb830f244ff9b60c638837c
Parents: 6121e91
Author: Takuya UESHIN <ue...@databricks.com>
Authored: Fri Jan 19 11:37:08 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Jan 19 11:37:08 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                              | 11 +++++++++++
 .../org/apache/spark/sql/catalyst/expressions/Cast.scala |  2 ++
 .../org/apache/spark/sql/test/ExamplePointUDT.scala      |  2 ++
 3 files changed, 15 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/568055da/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2548359..4fee2ec 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1189,6 +1189,17 @@ class SQLTests(ReusedSQLTestCase):
             ]
         )
 
+    def test_cast_to_string_with_udt(self):
+        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
+        from pyspark.sql.functions import col
+        row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
+        schema = StructType([StructField("point", ExamplePointUDT(), False),
+                             StructField("pypoint", PythonOnlyUDT(), False)])
+        df = self.spark.createDataFrame([row], schema)
+
+        result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
+        self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
+
     def test_column_operators(self):
         ci = self.df.key
         cs = self.df.value

http://git-wip-us.apache.org/repos/asf/spark/blob/568055da/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a95ebe3..79b0516 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -282,6 +282,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         builder.append("]")
         builder.build()
       })
+    case pudt: PythonUserDefinedType => castToString(pudt.sqlType)
     case udt: UserDefinedType[_] =>
       buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString))
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
@@ -838,6 +839,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
              |$evPrim = $buffer.build();
            """.stripMargin
         }
+      case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx)
       case udt: UserDefinedType[_] =>
         val udtRef = ctx.addReferenceObj("udt", udt)
         (c, evPrim, evNull) => {

http://git-wip-us.apache.org/repos/asf/spark/blob/568055da/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index a73e427..8bab7e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -34,6 +34,8 @@ private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializab
     case that: ExamplePoint => this.x == that.x && this.y == that.y
     case _ => false
   }
+
+  override def toString(): String = s"($x, $y)"
 }
 
 /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org