Posted to commits@spark.apache.org by do...@apache.org on 2020/04/25 01:14:17 UTC

[spark] branch branch-3.0 updated: [SPARK-31552][SQL] Fix ClassCastException in ScalaReflection arrayClassFor

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new a26d769  [SPARK-31552][SQL] Fix ClassCastException in ScalaReflection arrayClassFor
a26d769 is described below

commit a26d769c1342f70b80f04846ba933358a3f2e311
Author: Kent Yao <ya...@hotmail.com>
AuthorDate: Fri Apr 24 18:04:26 2020 -0700

    [SPARK-31552][SQL] Fix ClassCastException in ScalaReflection arrayClassFor
    
    ### What changes were proposed in this pull request?
    
    The two methods `arrayClassFor` and `dataTypeFor` in `ScalaReflection` call each other recursively, but the cases handled by `dataTypeFor` are not fully covered by `arrayClassFor`.
    
    For example:
    ```scala
    scala> implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
    newArrayEncoder: [T <: Array[_]](implicit evidence$1: reflect.runtime.universe.TypeTag[T])org.apache.spark.sql.Encoder[T]
    
    scala> val decOne = Decimal(1, 38, 18)
    decOne: org.apache.spark.sql.types.Decimal = 1E-18
    
    scala> val decTwo = Decimal(2, 38, 18)
    decTwo: org.apache.spark.sql.types.Decimal = 2E-18
    
    scala> val decSpark = Array(decOne, decTwo)
    decSpark: Array[org.apache.spark.sql.types.Decimal] = Array(1E-18, 2E-18)
    
    scala> Seq(decSpark).toDF()
    java.lang.ClassCastException: org.apache.spark.sql.types.DecimalType cannot be cast to org.apache.spark.sql.types.ObjectType
      at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$arrayClassFor$1(ScalaReflection.scala:131)
      at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
      at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
      at org.apache.spark.sql.catalyst.ScalaReflection$.arrayClassFor(ScalaReflection.scala:120)
      at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$dataTypeFor$1(ScalaReflection.scala:105)
      at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
      at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
      at org.apache.spark.sql.catalyst.ScalaReflection$.dataTypeFor(ScalaReflection.scala:88)
      at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerForType$1(ScalaReflection.scala:399)
      at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
      at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
      at org.apache.spark.sql.catalyst.ScalaReflection$.serializerForType(ScalaReflection.scala:393)
      at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$.apply(ExpressionEncoder.scala:57)
      at newArrayEncoder(<console>:57)
      ... 53 elided
    
    scala>
    ```
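    
    For context, here is a minimal, hypothetical sketch (simplified names and branches, not the actual Spark source) of the mutual recursion: `arrayClassFor` handles the primitive element types itself and otherwise falls back to `dataTypeFor`, blindly casting the result to `ObjectType`. Element types such as `Decimal`, whose Catalyst `DataType` is a `DecimalType` rather than an `ObjectType`, therefore hit the `ClassCastException` shown above.
    
    ```scala
    // Hypothetical, simplified sketch (not the actual ScalaReflection source).
    object ReflectionSketch {
      sealed trait DataType
      case object DecimalType extends DataType
      final case class ObjectType(cls: Class[_]) extends DataType
    
      // Stand-in for dataTypeFor: Decimal maps to a DecimalType,
      // everything else to a generic ObjectType.
      def dataTypeFor(tpe: String): DataType = tpe match {
        case "Decimal" => DecimalType
        case _         => ObjectType(classOf[AnyRef])
      }
    
      // Stand-in for the pre-fix arrayClassFor: the fallback branch assumes
      // dataTypeFor always returns an ObjectType.
      def arrayClassFor(elemTpe: String): Class[_] = elemTpe match {
        case "Int" => classOf[Array[Int]]
        // ... other primitive cases ...
        case other =>
          val elementCls = dataTypeFor(other).asInstanceOf[ObjectType].cls
          java.lang.reflect.Array.newInstance(elementCls, 0).getClass
      }
    
      def main(args: Array[String]): Unit = {
        println(arrayClassFor("Int"))     // class [I
        println(arrayClassFor("Decimal")) // throws ClassCastException, as in the trace above
      }
    }
    ```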
    
    In this PR, we add the missing cases to `arrayClassFor`
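    
    In terms of the sketch above, the fix adds explicit branches, ahead of the `ObjectType` fallback, for the element types whose `DataType` is not an `ObjectType` (the real hunk is in the diff below). A hypothetical continuation of `ReflectionSketch` would look like:
    
    ```scala
    // Hypothetical continuation of ReflectionSketch above; the real fix covers
    // Array[Byte], CalendarInterval and Decimal in ScalaReflection.arrayClassFor.
    def arrayClassForFixed(elemTpe: String): Class[_] = elemTpe match {
      case "Array[Byte]"      => classOf[Array[Array[Byte]]]
      case "CalendarInterval" => classOf[Array[AnyRef]] // placeholder for Array[CalendarInterval]
      case "Decimal"          => classOf[Array[AnyRef]] // placeholder for Array[Decimal]
      case other              => arrayClassFor(other)   // unchanged fallback path
    }
    ```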
    
    ### Why are the changes needed?
    
    bugfix as described above
    
    ### Does this PR introduce any user-facing change?
    
    no
    
    ### How was this patch tested?
    
    add a test for array encoders
    
    Closes #28324 from yaooqinn/SPARK-31552.
    
    Authored-by: Kent Yao <ya...@hotmail.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
    (cherry picked from commit caf3ab84113c29c51049c1c906c91462e233e9d9)
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../spark/sql/catalyst/ScalaReflection.scala       |  3 +
 .../org/apache/spark/sql/DataFrameSuite.scala      | 82 +++++++++++++++++++++-
 2 files changed, 82 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9c8da32..05de21b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -126,6 +126,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]]
       case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]]
       case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]]
+      case t if isSubtype(t, localTypeOf[Array[Byte]]) => classOf[Array[Array[Byte]]]
+      case t if isSubtype(t, localTypeOf[CalendarInterval]) => classOf[Array[CalendarInterval]]
+      case t if isSubtype(t, localTypeOf[Decimal]) => classOf[Array[Decimal]]
       case other =>
         // There is probably a better way to do this, but I couldn't find it...
         val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f797290..f20e684 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -23,6 +23,7 @@ import java.sql.{Date, Timestamp}
 import java.util.UUID
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.reflect.runtime.universe.TypeTag
 import scala.util.Random
 
 import org.scalatest.Matchers._
@@ -30,10 +31,11 @@ import org.scalatest.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -41,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.Utils
@@ -2358,6 +2360,80 @@ class DataFrameSuite extends QueryTest
     val df = Seq((1, new CalendarInterval(1, 2, 3))).toDF("a", "b")
     checkAnswer(df.selectExpr("b"), Row(new CalendarInterval(1, 2, 3)))
   }
+
+  test("SPARK-31552: array encoder with different types") {
+    // primitives
+    val booleans = Array(true, false)
+    checkAnswer(Seq(booleans).toDF(), Row(booleans))
+
+    val bytes = Array(1.toByte, 2.toByte)
+    checkAnswer(Seq(bytes).toDF(), Row(bytes))
+    val shorts = Array(1.toShort, 2.toShort)
+    checkAnswer(Seq(shorts).toDF(), Row(shorts))
+    val ints = Array(1, 2)
+    checkAnswer(Seq(ints).toDF(), Row(ints))
+    val longs = Array(1L, 2L)
+    checkAnswer(Seq(longs).toDF(), Row(longs))
+
+    val floats = Array(1.0F, 2.0F)
+    checkAnswer(Seq(floats).toDF(), Row(floats))
+    val doubles = Array(1.0D, 2.0D)
+    checkAnswer(Seq(doubles).toDF(), Row(doubles))
+
+    val strings = Array("2020-04-24", "2020-04-25")
+    checkAnswer(Seq(strings).toDF(), Row(strings))
+
+    // tuples
+    val decOne = Decimal(1, 38, 18)
+    val decTwo = Decimal(2, 38, 18)
+    val tuple1 = (1, 2.2, "3.33", decOne, Date.valueOf("2012-11-22"))
+    val tuple2 = (2, 3.3, "4.44", decTwo, Date.valueOf("2022-11-22"))
+    checkAnswer(Seq(Array(tuple1, tuple2)).toDF(), Seq(Seq(tuple1, tuple2)).toDF())
+
+    // case classes
+    val gbks = Array(GroupByKey(1, 2), GroupByKey(4, 5))
+    checkAnswer(Seq(gbks).toDF(), Row(Array(Row(1, 2), Row(4, 5))))
+
+    // We can move this implicit def to [[SQLImplicits]] once we eventually add full
+    // support for the array encoder, as we do for Seq and Set.
+    // For now, the cases below (decimal/datetime/interval/binary/nested types, etc.)
+    // are not supported for arrays without it.
+    implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
+    // decimals
+    val decSpark = Array(decOne, decTwo)
+    val decScala = decSpark.map(_.toBigDecimal)
+    val decJava = decSpark.map(_.toJavaBigDecimal)
+    checkAnswer(Seq(decSpark).toDF(), Row(decJava))
+    checkAnswer(Seq(decScala).toDF(), Row(decJava))
+    checkAnswer(Seq(decJava).toDF(), Row(decJava))
+
+    // datetimes and intervals
+    val dates = strings.map(Date.valueOf)
+    checkAnswer(Seq(dates).toDF(), Row(dates))
+    val localDates = dates.map(d => DateTimeUtils.daysToLocalDate(DateTimeUtils.fromJavaDate(d)))
+    checkAnswer(Seq(localDates).toDF(), Row(dates))
+
+    val timestamps =
+      Array(Timestamp.valueOf("2020-04-24 12:34:56"), Timestamp.valueOf("2020-04-24 11:22:33"))
+    checkAnswer(Seq(timestamps).toDF(), Row(timestamps))
+    val instants =
+      timestamps.map(t => DateTimeUtils.microsToInstant(DateTimeUtils.fromJavaTimestamp(t)))
+    checkAnswer(Seq(instants).toDF(), Row(timestamps))
+
+    val intervals = Array(new CalendarInterval(1, 2, 3), new CalendarInterval(4, 5, 6))
+    checkAnswer(Seq(intervals).toDF(), Row(intervals))
+
+    // binary
+    val bins = Array(Array(1.toByte), Array(2.toByte), Array(3.toByte), Array(4.toByte))
+    checkAnswer(Seq(bins).toDF(), Row(bins))
+
+    // nested
+    val nestedIntArray = Array(Array(1), Array(2))
+    checkAnswer(Seq(nestedIntArray).toDF(), Row(nestedIntArray.map(wrapIntArray)))
+    val nestedDecArray = Array(decSpark)
+    checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)

