Posted to commits@spark.apache.org by we...@apache.org on 2019/04/12 05:31:43 UTC

[spark] branch master updated: [SPARK-27199][SQL][FOLLOWUP] Fix bug in codegen templates in UnixTime and FromUnixTime

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bbbe54a  [SPARK-27199][SQL][FOLLOWUP] Fix bug in codegen templates in UnixTime and FromUnixTime
bbbe54a is described below

commit bbbe54aa79c0f6b66e3f3ac34515cc096beb5730
Author: Kris Mok <kr...@databricks.com>
AuthorDate: Fri Apr 12 13:31:18 2019 +0800

    [SPARK-27199][SQL][FOLLOWUP] Fix bug in codegen templates in UnixTime and FromUnixTime
    
    ## What changes were proposed in this pull request?
    
    SPARK-27199 introduced the use of `ZoneId` instead of `TimeZone` in a few date/time expressions.
    There were 3 occurrences of `ctx.addReferenceObj("zoneId", zoneId)` in that PR, which had a bug: while the `java.time.ZoneId` base type is public, the actual concrete implementation classes are not, so using the 2-arg version of `CodegenContext.addReferenceObj` would incorrectly generate code that references non-public types (`java.time.ZoneRegion`, to be specific). The 3-arg version should be used instead, with the class name of the referenced object explicitly specified as the public base type, `java.time.ZoneId`.
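    
    A minimal sketch of the difference between the two calls (a standalone illustration, not code from this PR; it assumes a Spark `CodegenContext` on the classpath and an arbitrary region-based zone ID):
    ```
    import java.time.ZoneId
    import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

    val ctx = new CodegenContext
    val zoneId = ZoneId.of("America/Los_Angeles")

    // 2-arg form (the bug): the declared type of the reference defaults to
    // zoneId.getClass.getName, i.e. the package-private java.time.ZoneRegion,
    // so the generated cast to that type fails with IllegalAccessError.
    val tz = ctx.addReferenceObj("zoneId", zoneId)

    // 3-arg form (the fix): pin the declared type to the public base type.
    val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
    ```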
    
    One such occurrence was caught in testing in the main PR of SPARK-27199 (https://github.com/apache/spark/pull/24141), for `DateFormatClass`. But the other 2 occurrences slipped through because no test cases covered them.
    
    Example of this bug in the current Apache Spark master, in a Spark Shell:
    ```
    scala> Seq(("2016-04-08", "yyyy-MM-dd")).toDF("s", "f").repartition(1).selectExpr("to_unix_timestamp(s, f)").show
    ...
    java.lang.IllegalAccessError: tried to access class java.time.ZoneRegion from class org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1
    ```
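    
    The root cause is visible directly in a Scala REPL (illustrative; on a standard JDK, region-based zone IDs are backed by the package-private class shown below):
    ```
    scala> java.time.ZoneId.of("America/Los_Angeles").getClass.getName
    res0: String = java.time.ZoneRegion
    ```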
    
    This PR fixes the codegen issues and adds the corresponding unit tests.
    
    ## How was this patch tested?
    
    Enhanced tests in `DateExpressionsSuite` for `to_unix_timestamp` and `from_unixtime`.
    
    Closes #24352 from rednaxelafx/fix-spark-27199.
    
    Authored-by: Kris Mok <kr...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../catalyst/expressions/datetimeExpressions.scala | 10 ++---
 .../expressions/DateExpressionsSuite.scala         | 48 ++++++++++++++++------
 2 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index aad9f20..9a6e6c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -541,7 +541,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val tf = TimestampFormatter.getClass.getName.stripSuffix("$")
-    val zid = ctx.addReferenceObj("zoneId", zoneId, "java.time.ZoneId")
+    val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
     val locale = ctx.addReferenceObj("locale", Locale.US)
     defineCodeGen(ctx, ev, (timestamp, format) => {
       s"""UTF8String.fromString($tf$$.MODULE$$.apply($format.toString(), $zid, $locale)
@@ -710,13 +710,13 @@ abstract class UnixTime
             }""")
         }
       case StringType =>
-        val tz = ctx.addReferenceObj("zoneId", zoneId)
+        val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
         val locale = ctx.addReferenceObj("locale", Locale.US)
         val tf = TimestampFormatter.getClass.getName.stripSuffix("$")
         nullSafeCodeGen(ctx, ev, (string, format) => {
           s"""
             try {
-              ${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $tz, $locale)
+              ${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $zid, $locale)
                 .parse($string.toString()) / $MICROS_PER_SECOND;
             } catch (java.lang.IllegalArgumentException e) {
               ${ev.isNull} = true;
@@ -849,13 +849,13 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
           }""")
       }
     } else {
-      val tz = ctx.addReferenceObj("zoneId", zoneId)
+      val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
       val locale = ctx.addReferenceObj("locale", Locale.US)
       val tf = TimestampFormatter.getClass.getName.stripSuffix("$")
       nullSafeCodeGen(ctx, ev, (seconds, f) => {
         s"""
         try {
-          ${ev.value} = UTF8String.fromString($tf.apply($f.toString(), $tz, $locale).
+          ${ev.value} = UTF8String.fromString($tf.apply($f.toString(), $zid, $locale).
             format($seconds * 1000000L));
         } catch (java.lang.IllegalArgumentException e) {
           ${ev.isNull} = true;
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 64bf899..88607d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -26,13 +26,14 @@ import java.util.concurrent.TimeUnit._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -652,7 +653,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("from_unixtime") {
-    val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
+    val fmt1 = "yyyy-MM-dd HH:mm:ss"
+    val sdf1 = new SimpleDateFormat(fmt1, Locale.US)
     val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
     val sdf2 = new SimpleDateFormat(fmt2, Locale.US)
     for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) {
@@ -661,10 +663,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       sdf2.setTimeZone(tz)
 
       checkEvaluation(
-        FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
+        FromUnixTime(Literal(0L), Literal(fmt1), timeZoneId),
         sdf1.format(new Timestamp(0)))
       checkEvaluation(FromUnixTime(
-        Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
+        Literal(1000L), Literal(fmt1), timeZoneId),
         sdf1.format(new Timestamp(1000000)))
       checkEvaluation(
         FromUnixTime(Literal(-1000L), Literal(fmt2), timeZoneId),
@@ -673,13 +675,22 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
         FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType), timeZoneId),
         null)
       checkEvaluation(
-        FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
+        FromUnixTime(Literal.create(null, LongType), Literal(fmt1), timeZoneId),
         null)
       checkEvaluation(
         FromUnixTime(Literal(1000L), Literal.create(null, StringType), timeZoneId),
         null)
       checkEvaluation(
         FromUnixTime(Literal(0L), Literal("not a valid format"), timeZoneId), null)
+
+      // The codegen path for non-literal input should also work
+      checkEvaluation(
+        expression = FromUnixTime(
+          BoundReference(ordinal = 0, dataType = LongType, nullable = true),
+          BoundReference(ordinal = 1, dataType = StringType, nullable = true),
+          timeZoneId),
+        expected = UTF8String.fromString(sdf1.format(new Timestamp(0))),
+        inputRow = InternalRow(0L, UTF8String.fromString(fmt1)))
     }
   }
 
@@ -739,7 +750,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("to_unix_timestamp") {
-    val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
+    val fmt1 = "yyyy-MM-dd HH:mm:ss"
+    val sdf1 = new SimpleDateFormat(fmt1, Locale.US)
     val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
     val sdf2 = new SimpleDateFormat(fmt2, Locale.US)
     val fmt3 = "yy-MM-dd"
@@ -754,15 +766,15 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
         val date1 = Date.valueOf("2015-07-24")
         checkEvaluation(ToUnixTimestamp(
-          Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), 0L)
+          Literal(sdf1.format(new Timestamp(0))), Literal(fmt1), timeZoneId), 0L)
         checkEvaluation(ToUnixTimestamp(
-          Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
+          Literal(sdf1.format(new Timestamp(1000000))), Literal(fmt1), timeZoneId),
           1000L)
         checkEvaluation(ToUnixTimestamp(
-          Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")),
+          Literal(new Timestamp(1000000)), Literal(fmt1)),
           1000L)
         checkEvaluation(
-          ToUnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
+          ToUnixTimestamp(Literal(date1), Literal(fmt1), timeZoneId),
           MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz)))
         checkEvaluation(
           ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId),
@@ -772,21 +784,31 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(
             DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz)))
         val t1 = ToUnixTimestamp(
-          CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
+          CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long]
         val t2 = ToUnixTimestamp(
-          CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
+          CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long]
         assert(t2 - t1 <= 1)
         checkEvaluation(ToUnixTimestamp(
           Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null)
         checkEvaluation(
           ToUnixTimestamp(
-            Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
+            Literal.create(null, DateType), Literal(fmt1), timeZoneId),
           null)
         checkEvaluation(ToUnixTimestamp(
           Literal(date1), Literal.create(null, StringType), timeZoneId),
           MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz)))
         checkEvaluation(
           ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null)
+
+        // The codegen path for non-literal input should also work
+        checkEvaluation(
+          expression = ToUnixTimestamp(
+            BoundReference(ordinal = 0, dataType = StringType, nullable = true),
+            BoundReference(ordinal = 1, dataType = StringType, nullable = true),
+            timeZoneId),
+          expected = 0L,
+          inputRow = InternalRow(
+            UTF8String.fromString(sdf1.format(new Timestamp(0))), UTF8String.fromString(fmt1)))
       }
     }
   }

