You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2016/01/05 19:20:00 UTC
spark git commit: [SPARK-12438][SQL] Add SQLUserDefinedType support for encoder
Repository: spark
Updated Branches:
refs/heads/master 1cdc42d2b -> b3c48e39f
[SPARK-12438][SQL] Add SQLUserDefinedType support for encoder
JIRA: https://issues.apache.org/jira/browse/SPARK-12438
ScalaReflection lacks support for SQLUserDefinedType. We should add it.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #10390 from viirya/encoder-udt.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b3c48e39
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b3c48e39
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b3c48e39
Branch: refs/heads/master
Commit: b3c48e39f4a0a42a0b6b433511b2cce0d1e3f03d
Parents: 1cdc42d
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Tue Jan 5 10:19:56 2016 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Jan 5 10:19:56 2016 -0800
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 22 ++++++++++++++++++++
.../spark/sql/catalyst/expressions/Cast.scala | 14 +++++++++++++
.../encoders/ExpressionEncoderSuite.scala | 2 ++
3 files changed, 38 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/b3c48e39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9784c96..c6aa60b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -177,6 +177,7 @@ object ScalaReflection extends ScalaReflection {
case _ => UpCast(expr, expected, walkedTypePath)
}
+ val className = getClassNameFromType(tpe)
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
@@ -360,6 +361,16 @@ object ScalaReflection extends ScalaReflection {
} else {
newInstance
}
+
+ case t if Utils.classIsLoadable(className) &&
+ Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ val udt = Utils.classForName(className)
+ .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
}
}
@@ -409,6 +420,7 @@ object ScalaReflection extends ScalaReflection {
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
+ val className = getClassNameFromType(tpe)
tpe match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
@@ -559,6 +571,16 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
+ case t if Utils.classIsLoadable(className) &&
+ Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ val udt = Utils.classForName(className)
+ .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+
case other =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
http://git-wip-us.apache.org/repos/asf/spark/blob/b3c48e39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index b18f49f..d82d3ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.math.{BigDecimal => JavaBigDecimal}
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -81,6 +82,9 @@ object Cast {
toField.nullable)
}
+ case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass =>
+ true
+
case _ => false
}
@@ -431,6 +435,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
case map: MapType => castMap(from.asInstanceOf[MapType], map)
case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
+ case udt: UserDefinedType[_]
+ if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+ identity[Any]
+ case _: UserDefinedType[_] =>
+ throw new SparkException(s"Cannot cast $from to $to.")
}
private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
@@ -473,6 +482,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
+ case udt: UserDefinedType[_]
+ if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+ (c, evPrim, evNull) => s"$evPrim = $c;"
+ case _: UserDefinedType[_] =>
+ throw new SparkException(s"Cannot cast $from to $to.")
}
// Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
http://git-wip-us.apache.org/repos/asf/spark/blob/b3c48e39/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 3740dea..6453f1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -244,6 +244,8 @@ class ExpressionEncoderSuite extends SparkFunSuite {
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
}
+ productTest(("UDT", new ExamplePoint(0.1, 0.2)))
+
test("nullable of encoder schema") {
def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = {
assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org