You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2016/01/05 19:20:00 UTC

spark git commit: [SPARK-12438][SQL] Add SQLUserDefinedType support for encoder

Repository: spark
Updated Branches:
  refs/heads/master 1cdc42d2b -> b3c48e39f


[SPARK-12438][SQL] Add SQLUserDefinedType support for encoder

JIRA: https://issues.apache.org/jira/browse/SPARK-12438

ScalaReflection lacks support for classes annotated with SQLUserDefinedType, so no encoder can be derived for them. This change adds that support.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #10390 from viirya/encoder-udt.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b3c48e39
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b3c48e39
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b3c48e39

Branch: refs/heads/master
Commit: b3c48e39f4a0a42a0b6b433511b2cce0d1e3f03d
Parents: 1cdc42d
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Tue Jan 5 10:19:56 2016 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Jan 5 10:19:56 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 22 ++++++++++++++++++++
 .../spark/sql/catalyst/expressions/Cast.scala   | 14 +++++++++++++
 .../encoders/ExpressionEncoderSuite.scala       |  2 ++
 3 files changed, 38 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b3c48e39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9784c96..c6aa60b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -177,6 +177,7 @@ object ScalaReflection extends ScalaReflection {
       case _ => UpCast(expr, expected, walkedTypePath)
     }
 
+    val className = getClassNameFromType(tpe)
     tpe match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
 
@@ -360,6 +361,16 @@ object ScalaReflection extends ScalaReflection {
         } else {
           newInstance
         }
+
+      case t if Utils.classIsLoadable(className) &&
+        Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+        val udt = Utils.classForName(className)
+          .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+        val obj = NewInstance(
+          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+          Nil,
+          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
     }
   }
 
@@ -409,6 +420,7 @@ object ScalaReflection extends ScalaReflection {
     if (!inputObject.dataType.isInstanceOf[ObjectType]) {
       inputObject
     } else {
+      val className = getClassNameFromType(tpe)
       tpe match {
         case t if t <:< localTypeOf[Option[_]] =>
           val TypeRef(_, _, Seq(optType)) = t
@@ -559,6 +571,16 @@ object ScalaReflection extends ScalaReflection {
         case t if t <:< localTypeOf[java.lang.Boolean] =>
           Invoke(inputObject, "booleanValue", BooleanType)
 
+        case t if Utils.classIsLoadable(className) &&
+          Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+          val udt = Utils.classForName(className)
+            .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+          val obj = NewInstance(
+            udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+            Nil,
+            dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+          Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+
         case other =>
           throw new UnsupportedOperationException(
             s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))

http://git-wip-us.apache.org/repos/asf/spark/blob/b3c48e39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index b18f49f..d82d3ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.math.{BigDecimal => JavaBigDecimal}
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -81,6 +82,9 @@ object Cast {
                 toField.nullable)
         }
 
+    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass =>
+      true
+
     case _ => false
   }
 
@@ -431,6 +435,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
     case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
     case map: MapType => castMap(from.asInstanceOf[MapType], map)
     case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
+    case udt: UserDefinedType[_]
+      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+      identity[Any]
+    case _: UserDefinedType[_] =>
+      throw new SparkException(s"Cannot cast $from to $to.")
   }
 
   private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
@@ -473,6 +482,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
       castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
     case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
     case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
+    case udt: UserDefinedType[_]
+      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+      (c, evPrim, evNull) => s"$evPrim = $c;"
+    case _: UserDefinedType[_] =>
+      throw new SparkException(s"Cannot cast $from to $to.")
   }
 
   // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's

http://git-wip-us.apache.org/repos/asf/spark/blob/b3c48e39/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 3740dea..6453f1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -244,6 +244,8 @@ class ExpressionEncoderSuite extends SparkFunSuite {
     ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
   }
 
+  productTest(("UDT", new ExamplePoint(0.1, 0.2)))
+
   test("nullable of encoder schema") {
     def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = {
       assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org