You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/03/13 02:55:14 UTC

[spark] branch master updated: [MINOR][SQL] Refactor RowEncoder to use existing (De)serializerBuildHelper methods

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1b06cda  [MINOR][SQL] Refactor RowEncoder to use existing (De)serializerBuildHelper methods
1b06cda is described below

commit 1b06cda532b74ed555f759bcb4f73759966b71bb
Author: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
AuthorDate: Wed Mar 13 10:54:47 2019 +0800

    [MINOR][SQL] Refactor RowEncoder to use existing (De)serializerBuildHelper methods
    
    ## What changes were proposed in this pull request?
    
    This patch proposes reusing the existing methods of (De)serializerBuildHelper in RowEncoder, both to remove duplicated code and to ensure that serializers/deserializers for the same type are created consistently.
    
    ## How was this patch tested?
    
    Existing UT.
    
    Closes #24014 from HeartSaVioR/SPARK-27092.
    
    Authored-by: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/encoders/RowEncoder.scala   | 149 ++++++++-------------
 1 file changed, 55 insertions(+), 94 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 97709bd..3a06f8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -22,14 +22,15 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.{ScalaReflection, WalkedTypePath}
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.SerializerBuildHelper._
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * A factory for constructing encoders that convert external row to/from the Spark SQL
@@ -93,37 +94,19 @@ object RowEncoder {
         dataType = ObjectType(udtClass), false)
       Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
 
-    case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
-      StaticInvoke(
-        DateTimeUtils.getClass,
-        TimestampType,
-        "instantToMicros",
-        inputObject :: Nil,
-        returnNullable = false)
-
     case TimestampType =>
-      StaticInvoke(
-        DateTimeUtils.getClass,
-        TimestampType,
-        "fromJavaTimestamp",
-        inputObject :: Nil,
-        returnNullable = false)
-
-    case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
-      StaticInvoke(
-        DateTimeUtils.getClass,
-        DateType,
-        "localDateToDays",
-        inputObject :: Nil,
-        returnNullable = false)
+      if (SQLConf.get.datetimeJava8ApiEnabled) {
+        createSerializerForJavaInstant(inputObject)
+      } else {
+        createSerializerForSqlTimestamp(inputObject)
+      }
 
     case DateType =>
-      StaticInvoke(
-        DateTimeUtils.getClass,
-        DateType,
-        "fromJavaDate",
-        inputObject :: Nil,
-        returnNullable = false)
+      if (SQLConf.get.datetimeJava8ApiEnabled) {
+        createSerializerForJavaLocalDate(inputObject)
+      } else {
+        createSerializerForSqlDate(inputObject)
+      }
 
     case d: DecimalType =>
       CheckOverflow(StaticInvoke(
@@ -133,13 +116,7 @@ object RowEncoder {
         inputObject :: Nil,
         returnNullable = false), d)
 
-    case StringType =>
-      StaticInvoke(
-        classOf[UTF8String],
-        StringType,
-        "fromString",
-        inputObject :: Nil,
-        returnNullable = false)
+    case StringType => createSerializerForString(inputObject)
 
     case t @ ArrayType(et, containsNull) =>
       et match {
@@ -151,17 +128,14 @@ object RowEncoder {
             inputObject :: Nil,
             returnNullable = false)
 
-        case _ => MapObjects(
-          element => {
-            val value = serializerFor(ValidateExternalType(element, et), et)
-            if (!containsNull) {
-              AssertNotNull(value)
-            } else {
-              value
-            }
-          },
-          inputObject,
-          ObjectType(classOf[Object]))
+        case _ =>
+          createSerializerForMapObjects(
+            inputObject,
+            ObjectType(classOf[Object]),
+            element => {
+              val value = serializerFor(ValidateExternalType(element, et), et)
+              expressionWithNullSafety(value, containsNull, WalkedTypePath())
+            })
       }
 
     case t @ MapType(kt, vt, valueNullable) =>
@@ -188,9 +162,7 @@ object RowEncoder {
         propagateNull = false)
 
       if (inputObject.nullable) {
-        If(IsNull(inputObject),
-          Literal.create(null, nonNullOutput.dataType),
-          nonNullOutput)
+        expressionForNullableExpr(inputObject, nonNullOutput)
       } else {
         nonNullOutput
       }
@@ -217,9 +189,7 @@ object RowEncoder {
       })
 
       if (inputObject.nullable) {
-        If(IsNull(inputObject),
-          Literal.create(null, nonNullOutput.dataType),
-          nonNullOutput)
+        expressionForNullableExpr(inputObject, nonNullOutput)
       } else {
         nonNullOutput
       }
@@ -244,12 +214,18 @@ object RowEncoder {
 
   def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
-    case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
-      ObjectType(classOf[java.time.Instant])
-    case TimestampType => ObjectType(classOf[java.sql.Timestamp])
-    case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
-      ObjectType(classOf[java.time.LocalDate])
-    case DateType => ObjectType(classOf[java.sql.Date])
+    case TimestampType =>
+      if (SQLConf.get.datetimeJava8ApiEnabled) {
+        ObjectType(classOf[java.time.Instant])
+      } else {
+        ObjectType(classOf[java.sql.Timestamp])
+      }
+    case DateType =>
+      if (SQLConf.get.datetimeJava8ApiEnabled) {
+        ObjectType(classOf[java.time.LocalDate])
+      } else {
+        ObjectType(classOf[java.sql.Date])
+      }
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
     case StringType => ObjectType(classOf[java.lang.String])
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -291,44 +267,23 @@ object RowEncoder {
         dataType = ObjectType(udtClass))
       Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
 
-    case TimestampType if SQLConf.get.datetimeJava8ApiEnabled =>
-      StaticInvoke(
-        DateTimeUtils.getClass,
-        ObjectType(classOf[java.time.Instant]),
-        "microsToInstant",
-        input :: Nil,
-        returnNullable = false)
-
     case TimestampType =>
-      StaticInvoke(
-        DateTimeUtils.getClass,
-        ObjectType(classOf[java.sql.Timestamp]),
-        "toJavaTimestamp",
-        input :: Nil,
-        returnNullable = false)
-
-    case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
-      StaticInvoke(
-        DateTimeUtils.getClass,
-        ObjectType(classOf[java.time.LocalDate]),
-        "daysToLocalDate",
-        input :: Nil,
-        returnNullable = false)
+      if (SQLConf.get.datetimeJava8ApiEnabled) {
+        createDeserializerForInstant(input)
+      } else {
+        createDeserializerForSqlTimestamp(input)
+      }
 
     case DateType =>
-      StaticInvoke(
-        DateTimeUtils.getClass,
-        ObjectType(classOf[java.sql.Date]),
-        "toJavaDate",
-        input :: Nil,
-        returnNullable = false)
+      if (SQLConf.get.datetimeJava8ApiEnabled) {
+        createDeserializerForLocalDate(input)
+      } else {
+        createDeserializerForSqlDate(input)
+      }
 
-    case _: DecimalType =>
-      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
-        returnNullable = false)
+    case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)
 
-    case StringType =>
-      Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false)
+    case StringType => createDeserializerForString(input, returnNullable = false)
 
     case ArrayType(et, nullable) =>
       val arrayData =
@@ -368,4 +323,10 @@ object RowEncoder {
         Literal.create(null, externalDataTypeFor(input.dataType)),
         CreateExternalRow(convertedFields, schema))
   }
+
+  private def expressionForNullableExpr(
+      expr: Expression,
+      newExprWhenNotNull: Expression): Expression = {
+    If(IsNull(expr), Literal.create(null, newExprWhenNotNull.dataType), newExprWhenNotNull)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org