You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/07/05 06:25:31 UTC
spark git commit: [SPARK-18623][SQL] Add `returnNullable` to
`StaticInvoke` and modify it to handle nullability properly.
Repository: spark
Updated Branches:
refs/heads/master f2c3b1dd6 -> a38643256
[SPARK-18623][SQL] Add `returnNullable` to `StaticInvoke` and modify it to handle nullability properly.
## What changes were proposed in this pull request?
Add `returnNullable` to `StaticInvoke`, just as #15780 is trying to add it to `Invoke`, and modify `StaticInvoke` to handle nullability properly.
## How was this patch tested?
Existing tests.
Author: Takuya UESHIN <ue...@happy-camper.st>
Author: Takuya UESHIN <ue...@databricks.com>
Closes #16056 from ueshin/issues/SPARK-18623.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a3864325
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a3864325
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a3864325
Branch: refs/heads/master
Commit: a38643256691947ff7f7c474b85c052a7d5d8553
Parents: f2c3b1d
Author: Takuya UESHIN <ue...@happy-camper.st>
Authored: Wed Jul 5 14:25:26 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Jul 5 14:25:26 2017 +0800
----------------------------------------------------------------------
.../spark/sql/catalyst/JavaTypeInference.scala | 21 ++++++----
.../spark/sql/catalyst/ScalaReflection.scala | 44 ++++++++++++--------
.../sql/catalyst/encoders/RowEncoder.scala | 27 ++++++++----
.../catalyst/expressions/objects/objects.scala | 42 +++++++++++++++----
4 files changed, 91 insertions(+), 43 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a3864325/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 90ec699..21363d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -216,7 +216,7 @@ object JavaTypeInference {
ObjectType(c),
"valueOf",
getPath :: Nil,
- propagateNull = true)
+ returnNullable = false)
case c if c == classOf[java.sql.Date] =>
StaticInvoke(
@@ -224,7 +224,7 @@ object JavaTypeInference {
ObjectType(c),
"toJavaDate",
getPath :: Nil,
- propagateNull = true)
+ returnNullable = false)
case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
@@ -232,7 +232,7 @@ object JavaTypeInference {
ObjectType(c),
"toJavaTimestamp",
getPath :: Nil,
- propagateNull = true)
+ returnNullable = false)
case c if c == classOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))
@@ -300,7 +300,8 @@ object JavaTypeInference {
ArrayBasedMapData.getClass,
ObjectType(classOf[JMap[_, _]]),
"toJavaMap",
- keyData :: valueData :: Nil)
+ keyData :: valueData :: Nil,
+ returnNullable = false)
case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
@@ -367,28 +368,32 @@ object JavaTypeInference {
classOf[UTF8String],
StringType,
"fromString",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"fromJavaDate",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case c if c == classOf[java.math.BigDecimal] =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case c if c == classOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
http://git-wip-us.apache.org/repos/asf/spark/blob/a3864325/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bea0de4..814f2c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -206,51 +206,53 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
val objectType = ObjectType(boxedType)
- StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+ StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Long] =>
val boxedType = classOf[java.lang.Long]
val objectType = ObjectType(boxedType)
- StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+ StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Double] =>
val boxedType = classOf[java.lang.Double]
val objectType = ObjectType(boxedType)
- StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+ StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Float] =>
val boxedType = classOf[java.lang.Float]
val objectType = ObjectType(boxedType)
- StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+ StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Short] =>
val boxedType = classOf[java.lang.Short]
val objectType = ObjectType(boxedType)
- StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+ StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Byte] =>
val boxedType = classOf[java.lang.Byte]
val objectType = ObjectType(boxedType)
- StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+ StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Boolean] =>
val boxedType = classOf[java.lang.Boolean]
val objectType = ObjectType(boxedType)
- StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+ StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
- getPath :: Nil)
+ getPath :: Nil,
+ returnNullable = false)
case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
- getPath :: Nil)
+ getPath :: Nil,
+ returnNullable = false)
case t if t <:< localTypeOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
@@ -446,7 +448,8 @@ object ScalaReflection extends ScalaReflection {
classOf[UnsafeArrayData],
ArrayType(dt, false),
"fromPrimitiveArray",
- input :: Nil)
+ input :: Nil,
+ returnNullable = false)
} else {
NewInstance(
classOf[GenericArrayData],
@@ -504,49 +507,56 @@ object ScalaReflection extends ScalaReflection {
classOf[UTF8String],
StringType,
"fromString",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"fromJavaDate",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case t if t <:< localTypeOf[BigDecimal] =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case t if t <:< localTypeOf[java.math.BigInteger] =>
StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
"apply",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case t if t <:< localTypeOf[scala.math.BigInt] =>
StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
"apply",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case t if t <:< localTypeOf[java.lang.Integer] =>
Invoke(inputObject, "intValue", IntegerType)
http://git-wip-us.apache.org/repos/asf/spark/blob/a3864325/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 0f8282d..cc32fac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -96,28 +96,32 @@ object RowEncoder {
DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case DateType =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"fromJavaDate",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case d: DecimalType =>
StaticInvoke(
Decimal.getClass,
d,
"fromDecimal",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case StringType =>
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case t @ ArrayType(et, cn) =>
et match {
@@ -126,7 +130,8 @@ object RowEncoder {
classOf[ArrayData],
t,
"toArrayData",
- inputObject :: Nil)
+ inputObject :: Nil,
+ returnNullable = false)
case _ => MapObjects(
element => serializerFor(ValidateExternalType(element, et), et),
inputObject,
@@ -254,14 +259,16 @@ object RowEncoder {
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
- input :: Nil)
+ input :: Nil,
+ returnNullable = false)
case DateType =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
- input :: Nil)
+ input :: Nil,
+ returnNullable = false)
case _: DecimalType =>
Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
@@ -280,7 +287,8 @@ object RowEncoder {
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
- arrayData :: Nil)
+ arrayData :: Nil,
+ returnNullable = false)
case MapType(kt, vt, valueNullable) =>
val keyArrayType = ArrayType(kt, false)
@@ -293,7 +301,8 @@ object RowEncoder {
ArrayBasedMapData.getClass,
ObjectType(classOf[Map[_, _]]),
"toScalaMap",
- keyData :: valueData :: Nil)
+ keyData :: valueData :: Nil,
+ returnNullable = false)
case schema @ StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
http://git-wip-us.apache.org/repos/asf/spark/blob/a3864325/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index ce07f4a..24c06d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -118,17 +118,20 @@ trait InvokeLike extends Expression with NonSQLExpression {
* @param arguments An optional list of expressions to pass as arguments to the function.
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
* of calling the function.
+ * @param returnNullable When false, indicating the invoked method will always return
+ * non-null value.
*/
case class StaticInvoke(
staticObject: Class[_],
dataType: DataType,
functionName: String,
arguments: Seq[Expression] = Nil,
- propagateNull: Boolean = true) extends InvokeLike {
+ propagateNull: Boolean = true,
+ returnNullable: Boolean = true) extends InvokeLike {
val objectName = staticObject.getName.stripSuffix("$")
- override def nullable: Boolean = true
+ override def nullable: Boolean = needNullCheck || returnNullable
override def children: Seq[Expression] = arguments
override def eval(input: InternalRow): Any =
@@ -141,19 +144,40 @@ case class StaticInvoke(
val callFunc = s"$objectName.$functionName($argString)"
- // If the function can return null, we do an extra check to make sure our null bit is still set
- // correctly.
- val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
- s"${ev.isNull} = ${ev.value} == null;"
+ val prepareIsNull = if (nullable) {
+ s"boolean ${ev.isNull} = $resultIsNull;"
} else {
+ ev.isNull = "false"
""
}
+ val evaluate = if (returnNullable) {
+ if (ctx.defaultValue(dataType) == "null") {
+ s"""
+ ${ev.value} = $callFunc;
+ ${ev.isNull} = ${ev.value} == null;
+ """
+ } else {
+ val boxedResult = ctx.freshName("boxedResult")
+ s"""
+ ${ctx.boxedType(dataType)} $boxedResult = $callFunc;
+ ${ev.isNull} = $boxedResult == null;
+ if (!${ev.isNull}) {
+ ${ev.value} = $boxedResult;
+ }
+ """
+ }
+ } else {
+ s"${ev.value} = $callFunc;"
+ }
+
val code = s"""
$argCode
- boolean ${ev.isNull} = $resultIsNull;
- final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc;
- $postNullCheck
+ $prepareIsNull
+ $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!$resultIsNull) {
+ $evaluate
+ }
"""
ev.copy(code = code)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org