You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/04/26 06:15:53 UTC
[spark] branch branch-2.4 updated: [SPARK-31552][SQL][2.4] Fix
ClassCastException in ScalaReflection arrayClassFor
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new 61fc1f7 [SPARK-31552][SQL][2.4] Fix ClassCastException in ScalaReflection arrayClassFor
61fc1f7 is described below
commit 61fc1f719ba7667584734865123fa9133068f9fe
Author: Kent Yao <ya...@hotmail.com>
AuthorDate: Sat Apr 25 23:14:03 2020 -0700
[SPARK-31552][SQL][2.4] Fix ClassCastException in ScalaReflection arrayClassFor
This PR backports https://github.com/apache/spark/pull/28324 to branch-2.4
### What changes were proposed in this pull request?
The two methods `arrayClassFor` and `dataTypeFor` in `ScalaReflection` call each other circularly; the cases handled in `dataTypeFor` are not fully handled in `arrayClassFor`.
For example:
```scala
scala> implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
newArrayEncoder: [T <: Array[_]](implicit evidence$1: reflect.runtime.universe.TypeTag[T])org.apache.spark.sql.Encoder[T]
scala> val decOne = Decimal(1, 38, 18)
decOne: org.apache.spark.sql.types.Decimal = 1E-18
scala> val decTwo = Decimal(2, 38, 18)
decTwo: org.apache.spark.sql.types.Decimal = 2E-18
scala> val decSpark = Array(decOne, decTwo)
decSpark: Array[org.apache.spark.sql.types.Decimal] = Array(1E-18, 2E-18)
scala> Seq(decSpark).toDF()
java.lang.ClassCastException: org.apache.spark.sql.types.DecimalType cannot be cast to org.apache.spark.sql.types.ObjectType
at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$arrayClassFor$1(ScalaReflection.scala:131)
at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
at org.apache.spark.sql.catalyst.ScalaReflection$.arrayClassFor(ScalaReflection.scala:120)
at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$dataTypeFor$1(ScalaReflection.scala:105)
at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
at org.apache.spark.sql.catalyst.ScalaReflection$.dataTypeFor(ScalaReflection.scala:88)
at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerForType$1(ScalaReflection.scala:399)
at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
at org.apache.spark.sql.catalyst.ScalaReflection$.serializerForType(ScalaReflection.scala:393)
at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$.apply(ExpressionEncoder.scala:57)
at newArrayEncoder(<console>:57)
... 53 elided
scala>
```
In this PR, we add the missing cases to `arrayClassFor`.
### Why are the changes needed?
Bugfix, as described above.
### Does this PR introduce any user-facing change?
no
### How was this patch tested?
add a test for array encoders
Closes #28341 from yaooqinn/SPARK-31552-24.
Authored-by: Kent Yao <ya...@hotmail.com>
Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
.../spark/sql/catalyst/ScalaReflection.scala | 7 ++-
.../org/apache/spark/sql/DataFrameSuite.scala | 73 ++++++++++++++++++++++
2 files changed, 78 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 858c33a..dc4291e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -114,8 +114,8 @@ object ScalaReflection extends ScalaReflection {
* Given a type `T` this function constructs `ObjectType` that holds a class of type
* `Array[T]`.
*
- * Special handling is performed for primitive types to map them back to their raw
- * JVM form instead of the Scala Array that handles auto boxing.
+ * Special handling is performed for primitive types, Array[Byte], CalendarInterval and Decimal
+ * to map them back to their raw JVM form instead of the Scala Array that handles auto boxing.
*/
private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects {
val cls = tpe.dealias match {
@@ -126,6 +126,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]]
case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]]
case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]]
+ case t if isSubtype(t, localTypeOf[Array[Byte]]) => classOf[Array[Array[Byte]]]
+ case t if isSubtype(t, localTypeOf[CalendarInterval]) => classOf[Array[CalendarInterval]]
+ case t if isSubtype(t, localTypeOf[Decimal]) => classOf[Array[Decimal]]
case other =>
// There is probably a better way to do this, but I couldn't find it...
val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index eee4478..e7d55ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -22,6 +22,8 @@ import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.util.UUID
+import scala.collection.mutable.WrappedArray
+import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
import org.scalatest.Matchers._
@@ -29,6 +31,7 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
@@ -2649,4 +2652,74 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
assert(e.getMessage.contains("Table or view not found:"))
}
+
+ test("SPARK-31552: array encoder with different types") {
+ // primitives
+ val booleans = Array(true, false)
+ checkAnswer(Seq(booleans).toDF(), Row(booleans))
+
+ val bytes = Array(1.toByte, 2.toByte)
+ checkAnswer(Seq(bytes).toDF(), Row(bytes))
+ val shorts = Array(1.toShort, 2.toShort)
+ checkAnswer(Seq(shorts).toDF(), Row(shorts))
+ val ints = Array(1, 2)
+ checkAnswer(Seq(ints).toDF(), Row(ints))
+ val longs = Array(1L, 2L)
+ checkAnswer(Seq(longs).toDF(), Row(longs))
+
+ val floats = Array(1.0F, 2.0F)
+ checkAnswer(Seq(floats).toDF(), Row(floats))
+ val doubles = Array(1.0D, 2.0D)
+ checkAnswer(Seq(doubles).toDF(), Row(doubles))
+
+ val strings = Array("2020-04-24", "2020-04-25")
+ checkAnswer(Seq(strings).toDF(), Row(strings))
+
+ // tuples
+ val decOne = Decimal(1, 38, 18)
+ val decTwo = Decimal(2, 38, 18)
+ val tuple1 = (1, 2.2, "3.33", decOne, Date.valueOf("2012-11-22"))
+ val tuple2 = (2, 3.3, "4.44", decTwo, Date.valueOf("2022-11-22"))
+ checkAnswer(Seq(Array(tuple1, tuple2)).toDF(), Seq(Seq(tuple1, tuple2)).toDF())
+
+ // case classes
+ val gbks = Array(GroupByKey(1, 2), GroupByKey(4, 5))
+ checkAnswer(Seq(gbks).toDF(), Row(Array(Row(1, 2), Row(4, 5))))
+
+ // We can move this implicit def to `SQLImplicits` when we eventually make fully
+ // support for array encoder like Seq and Set
+ // For now cases below, decimal/datetime/interval/binary/nested types, etc,
+ // are not supported by array
+ implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
+ // decimals
+ val decSpark = Array(decOne, decTwo)
+ val decScala = decSpark.map(_.toBigDecimal)
+ val decJava = decSpark.map(_.toJavaBigDecimal)
+ checkAnswer(Seq(decSpark).toDF(), Row(decJava))
+ checkAnswer(Seq(decScala).toDF(), Row(decJava))
+ checkAnswer(Seq(decJava).toDF(), Row(decJava))
+
+ // datetimes
+ val dates = strings.map(Date.valueOf)
+ checkAnswer(Seq(dates).toDF(), Row(dates))
+
+ val timestamps =
+ Array(Timestamp.valueOf("2020-04-24 12:34:56"), Timestamp.valueOf("2020-04-24 11:22:33"))
+ checkAnswer(Seq(timestamps).toDF(), Row(timestamps))
+
+ // binary
+ val bins = Array(Array(1.toByte), Array(2.toByte), Array(3.toByte), Array(4.toByte))
+
+ val binsRes = Seq(bins).toDF().head().get(0).asInstanceOf[WrappedArray[Array[Byte]]]
+ assert(binsRes.zip(bins).forall { case (a, b) => a.diff(b).isEmpty})
+
+ // nested
+ val nestedIntArray = Array(Array(1), Array(2))
+ checkAnswer(Seq(nestedIntArray).toDF(), Row(nestedIntArray.map(wrapIntArray)))
+ val nestedDecArray = Array(decSpark)
+ checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
+ }
}
+
+case class GroupByKey(a: Int, b: Int)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org