You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/03/13 02:55:14 UTC
[spark] branch master updated: [MINOR][SQL] Refactor RowEncoder to
use existing (De)serializerBuildHelper methods
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1b06cda [MINOR][SQL] Refactor RowEncoder to use existing (De)serializerBuildHelper methods
1b06cda is described below
commit 1b06cda532b74ed555f759bcb4f73759966b71bb
Author: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
AuthorDate: Wed Mar 13 10:54:47 2019 +0800
[MINOR][SQL] Refactor RowEncoder to use existing (De)serializerBuildHelper methods
## What changes were proposed in this pull request?
This patch proposes reusing the existing methods of (De)serializerBuildHelper in RowEncoder, both to remove duplicated code and to ensure that serializers/deserializers for the same type are created consistently.
## How was this patch tested?
Existing unit tests.
Closes #24014 from HeartSaVioR/SPARK-27092.
Authored-by: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/encoders/RowEncoder.scala | 149 ++++++++-------------
1 file changed, 55 insertions(+), 94 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 97709bd..3a06f8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -22,14 +22,15 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkException
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.{ScalaReflection, WalkedTypePath}
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
/**
* A factory for constructing encoders that convert external row to/from the Spark SQL
@@ -93,37 +94,19 @@ object RowEncoder {
dataType = ObjectType(udtClass), false)
Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
- case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
- StaticInvoke(
- DateTimeUtils.getClass,
- TimestampType,
- "instantToMicros",
- inputObject :: Nil,
- returnNullable = false)
-
case TimestampType =>
- StaticInvoke(
- DateTimeUtils.getClass,
- TimestampType,
- "fromJavaTimestamp",
- inputObject :: Nil,
- returnNullable = false)
-
- case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
- StaticInvoke(
- DateTimeUtils.getClass,
- DateType,
- "localDateToDays",
- inputObject :: Nil,
- returnNullable = false)
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ createSerializerForJavaInstant(inputObject)
+ } else {
+ createSerializerForSqlTimestamp(inputObject)
+ }
case DateType =>
- StaticInvoke(
- DateTimeUtils.getClass,
- DateType,
- "fromJavaDate",
- inputObject :: Nil,
- returnNullable = false)
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ createSerializerForJavaLocalDate(inputObject)
+ } else {
+ createSerializerForSqlDate(inputObject)
+ }
case d: DecimalType =>
CheckOverflow(StaticInvoke(
@@ -133,13 +116,7 @@ object RowEncoder {
inputObject :: Nil,
returnNullable = false), d)
- case StringType =>
- StaticInvoke(
- classOf[UTF8String],
- StringType,
- "fromString",
- inputObject :: Nil,
- returnNullable = false)
+ case StringType => createSerializerForString(inputObject)
case t @ ArrayType(et, containsNull) =>
et match {
@@ -151,17 +128,14 @@ object RowEncoder {
inputObject :: Nil,
returnNullable = false)
- case _ => MapObjects(
- element => {
- val value = serializerFor(ValidateExternalType(element, et), et)
- if (!containsNull) {
- AssertNotNull(value)
- } else {
- value
- }
- },
- inputObject,
- ObjectType(classOf[Object]))
+ case _ =>
+ createSerializerForMapObjects(
+ inputObject,
+ ObjectType(classOf[Object]),
+ element => {
+ val value = serializerFor(ValidateExternalType(element, et), et)
+ expressionWithNullSafety(value, containsNull, WalkedTypePath())
+ })
}
case t @ MapType(kt, vt, valueNullable) =>
@@ -188,9 +162,7 @@ object RowEncoder {
propagateNull = false)
if (inputObject.nullable) {
- If(IsNull(inputObject),
- Literal.create(null, nonNullOutput.dataType),
- nonNullOutput)
+ expressionForNullableExpr(inputObject, nonNullOutput)
} else {
nonNullOutput
}
@@ -217,9 +189,7 @@ object RowEncoder {
})
if (inputObject.nullable) {
- If(IsNull(inputObject),
- Literal.create(null, nonNullOutput.dataType),
- nonNullOutput)
+ expressionForNullableExpr(inputObject, nonNullOutput)
} else {
nonNullOutput
}
@@ -244,12 +214,18 @@ object RowEncoder {
def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if ScalaReflection.isNativeType(dt) => dt
- case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
- ObjectType(classOf[java.time.Instant])
- case TimestampType => ObjectType(classOf[java.sql.Timestamp])
- case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
- ObjectType(classOf[java.time.LocalDate])
- case DateType => ObjectType(classOf[java.sql.Date])
+ case TimestampType =>
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ ObjectType(classOf[java.time.Instant])
+ } else {
+ ObjectType(classOf[java.sql.Timestamp])
+ }
+ case DateType =>
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ ObjectType(classOf[java.time.LocalDate])
+ } else {
+ ObjectType(classOf[java.sql.Date])
+ }
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -291,44 +267,23 @@ object RowEncoder {
dataType = ObjectType(udtClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
- case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
- StaticInvoke(
- DateTimeUtils.getClass,
- ObjectType(classOf[java.time.Instant]),
- "microsToInstant",
- input :: Nil,
- returnNullable = false)
-
case TimestampType =>
- StaticInvoke(
- DateTimeUtils.getClass,
- ObjectType(classOf[java.sql.Timestamp]),
- "toJavaTimestamp",
- input :: Nil,
- returnNullable = false)
-
- case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
- StaticInvoke(
- DateTimeUtils.getClass,
- ObjectType(classOf[java.time.LocalDate]),
- "daysToLocalDate",
- input :: Nil,
- returnNullable = false)
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ createDeserializerForInstant(input)
+ } else {
+ createDeserializerForSqlTimestamp(input)
+ }
case DateType =>
- StaticInvoke(
- DateTimeUtils.getClass,
- ObjectType(classOf[java.sql.Date]),
- "toJavaDate",
- input :: Nil,
- returnNullable = false)
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ createDeserializerForLocalDate(input)
+ } else {
+ createDeserializerForSqlDate(input)
+ }
- case _: DecimalType =>
- Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
- returnNullable = false)
+ case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)
- case StringType =>
- Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false)
+ case StringType => createDeserializerForString(input, returnNullable = false)
case ArrayType(et, nullable) =>
val arrayData =
@@ -368,4 +323,10 @@ object RowEncoder {
Literal.create(null, externalDataTypeFor(input.dataType)),
CreateExternalRow(convertedFields, schema))
}
+
+ private def expressionForNullableExpr(
+ expr: Expression,
+ newExprWhenNotNull: Expression): Expression = {
+ If(IsNull(expr), Literal.create(null, newExprWhenNotNull.dataType), newExprWhenNotNull)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org