You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/04/26 06:15:53 UTC

[spark] branch branch-2.4 updated: [SPARK-31552][SQL][2.4] Fix ClassCastException in ScalaReflection arrayClassFor

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 61fc1f7  [SPARK-31552][SQL][2.4] Fix ClassCastException in ScalaReflection arrayClassFor
61fc1f7 is described below

commit 61fc1f719ba7667584734865123fa9133068f9fe
Author: Kent Yao <ya...@hotmail.com>
AuthorDate: Sat Apr 25 23:14:03 2020 -0700

    [SPARK-31552][SQL][2.4] Fix ClassCastException in ScalaReflection arrayClassFor
    
    This PR backports https://github.com/apache/spark/pull/28324 to branch-2.4
    
    ### What changes were proposed in this pull request?
    
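    The two methods `arrayClassFor` and `dataTypeFor` in `ScalaReflection` call each other circularly, but not every case handled by `dataTypeFor` is also handled by `arrayClassFor`. For an unhandled case, `arrayClassFor` falls through to its default branch, which casts the result of `dataTypeFor` to `ObjectType`; that cast fails when the result is a different `DataType`, such as `DecimalType`.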
    
    For example:
    ```scala
    scala> implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
    newArrayEncoder: [T <: Array[_]](implicit evidence$1: reflect.runtime.universe.TypeTag[T])org.apache.spark.sql.Encoder[T]
    
    scala> val decOne = Decimal(1, 38, 18)
    decOne: org.apache.spark.sql.types.Decimal = 1E-18
    
    scala> val decTwo = Decimal(2, 38, 18)
    decTwo: org.apache.spark.sql.types.Decimal = 2E-18
    
    scala> val decSpark = Array(decOne, decTwo)
    decSpark: Array[org.apache.spark.sql.types.Decimal] = Array(1E-18, 2E-18)
    
    scala> Seq(decSpark).toDF()
    java.lang.ClassCastException: org.apache.spark.sql.types.DecimalType cannot be cast to org.apache.spark.sql.types.ObjectType
      at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$arrayClassFor$1(ScalaReflection.scala:131)
      at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
      at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
      at org.apache.spark.sql.catalyst.ScalaReflection$.arrayClassFor(ScalaReflection.scala:120)
      at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$dataTypeFor$1(ScalaReflection.scala:105)
      at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
      at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
      at org.apache.spark.sql.catalyst.ScalaReflection$.dataTypeFor(ScalaReflection.scala:88)
      at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerForType$1(ScalaReflection.scala:399)
      at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
      at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
      at org.apache.spark.sql.catalyst.ScalaReflection$.serializerForType(ScalaReflection.scala:393)
      at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$.apply(ExpressionEncoder.scala:57)
      at newArrayEncoder(<console>:57)
      ... 53 elided
    
    scala>
    ```
    
    In this PR, we add the missing cases (`Array[Byte]`, `CalendarInterval` and `Decimal`) to `arrayClassFor`.
    
    ### Why are the changes needed?
    
    bugfix as described above
    
    ### Does this PR introduce any user-facing change?
    
    no
    
    ### How was this patch tested?
    
    Added a test for array encoders to `DataFrameSuite`.
    
    Closes #28341 from yaooqinn/SPARK-31552-24.
    
    Authored-by: Kent Yao <ya...@hotmail.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../spark/sql/catalyst/ScalaReflection.scala       |  7 ++-
 .../org/apache/spark/sql/DataFrameSuite.scala      | 73 ++++++++++++++++++++++
 2 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 858c33a..dc4291e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -114,8 +114,8 @@ object ScalaReflection extends ScalaReflection {
    * Given a type `T` this function constructs `ObjectType` that holds a class of type
    * `Array[T]`.
    *
-   * Special handling is performed for primitive types to map them back to their raw
-   * JVM form instead of the Scala Array that handles auto boxing.
+   * Special handling is performed for primitive types, Array[Byte], CalendarInterval and Decimal
+   * to map them back to their raw JVM form instead of the Scala Array that handles auto boxing.
    */
   private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects {
     val cls = tpe.dealias match {
@@ -126,6 +126,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]]
       case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]]
       case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]]
+      case t if isSubtype(t, localTypeOf[Array[Byte]]) => classOf[Array[Array[Byte]]]
+      case t if isSubtype(t, localTypeOf[CalendarInterval]) => classOf[Array[CalendarInterval]]
+      case t if isSubtype(t, localTypeOf[Decimal]) => classOf[Array[Decimal]]
       case other =>
         // There is probably a better way to do this, but I couldn't find it...
         val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index eee4478..e7d55ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -22,6 +22,8 @@ import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 import java.util.UUID
 
+import scala.collection.mutable.WrappedArray
+import scala.reflect.runtime.universe.TypeTag
 import scala.util.Random
 
 import org.scalatest.Matchers._
@@ -29,6 +31,7 @@ import org.scalatest.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
@@ -2649,4 +2652,74 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
     assert(e.getMessage.contains("Table or view not found:"))
   }
+
+  test("SPARK-31552: array encoder with different types") {
+    // primitives
+    val booleans = Array(true, false)
+    checkAnswer(Seq(booleans).toDF(), Row(booleans))
+
+    val bytes = Array(1.toByte, 2.toByte)
+    checkAnswer(Seq(bytes).toDF(), Row(bytes))
+    val shorts = Array(1.toShort, 2.toShort)
+    checkAnswer(Seq(shorts).toDF(), Row(shorts))
+    val ints = Array(1, 2)
+    checkAnswer(Seq(ints).toDF(), Row(ints))
+    val longs = Array(1L, 2L)
+    checkAnswer(Seq(longs).toDF(), Row(longs))
+
+    val floats = Array(1.0F, 2.0F)
+    checkAnswer(Seq(floats).toDF(), Row(floats))
+    val doubles = Array(1.0D, 2.0D)
+    checkAnswer(Seq(doubles).toDF(), Row(doubles))
+
+    val strings = Array("2020-04-24", "2020-04-25")
+    checkAnswer(Seq(strings).toDF(), Row(strings))
+
+    // tuples
+    val decOne = Decimal(1, 38, 18)
+    val decTwo = Decimal(2, 38, 18)
+    val tuple1 = (1, 2.2, "3.33", decOne, Date.valueOf("2012-11-22"))
+    val tuple2 = (2, 3.3, "4.44", decTwo, Date.valueOf("2022-11-22"))
+    checkAnswer(Seq(Array(tuple1, tuple2)).toDF(), Seq(Seq(tuple1, tuple2)).toDF())
+
+    // case classes
+    val gbks = Array(GroupByKey(1, 2), GroupByKey(4, 5))
+    checkAnswer(Seq(gbks).toDF(), Row(Array(Row(1, 2), Row(4, 5))))
+
+    // We can move this implicit def to `SQLImplicits` once array encoders are fully
+    // supported, as they are for Seq and Set.
+    // For now, the cases below (decimal/datetime/interval/binary/nested types, etc.)
+    // are not supported for arrays.
+    implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
+    // decimals
+    val decSpark = Array(decOne, decTwo)
+    val decScala = decSpark.map(_.toBigDecimal)
+    val decJava = decSpark.map(_.toJavaBigDecimal)
+    checkAnswer(Seq(decSpark).toDF(), Row(decJava))
+    checkAnswer(Seq(decScala).toDF(), Row(decJava))
+    checkAnswer(Seq(decJava).toDF(), Row(decJava))
+
+    // datetimes
+    val dates = strings.map(Date.valueOf)
+    checkAnswer(Seq(dates).toDF(), Row(dates))
+
+    val timestamps =
+      Array(Timestamp.valueOf("2020-04-24 12:34:56"), Timestamp.valueOf("2020-04-24 11:22:33"))
+    checkAnswer(Seq(timestamps).toDF(), Row(timestamps))
+
+    // binary
+    val bins = Array(Array(1.toByte), Array(2.toByte), Array(3.toByte), Array(4.toByte))
+
+    val binsRes = Seq(bins).toDF().head().get(0).asInstanceOf[WrappedArray[Array[Byte]]]
+    assert(binsRes.zip(bins).forall { case (a, b) => a.diff(b).isEmpty})
+
+    // nested
+    val nestedIntArray = Array(Array(1), Array(2))
+    checkAnswer(Seq(nestedIntArray).toDF(), Row(nestedIntArray.map(wrapIntArray)))
+    val nestedDecArray = Array(decSpark)
+    checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
+  }
 }
+
+case class GroupByKey(a: Int, b: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org