You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/04/25 01:14:17 UTC
[spark] branch branch-3.0 updated: [SPARK-31552][SQL] Fix
ClassCastException in ScalaReflection arrayClassFor
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new a26d769 [SPARK-31552][SQL] Fix ClassCastException in ScalaReflection arrayClassFor
a26d769 is described below
commit a26d769c1342f70b80f04846ba933358a3f2e311
Author: Kent Yao <ya...@hotmail.com>
AuthorDate: Fri Apr 24 18:04:26 2020 -0700
[SPARK-31552][SQL] Fix ClassCastException in ScalaReflection arrayClassFor
### What changes were proposed in this pull request?
The two methods `arrayClassFor` and `dataTypeFor` in `ScalaReflection` call each other circularly; the cases handled in `dataTypeFor` are not fully handled in `arrayClassFor`.
For example:
```scala
scala> implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
newArrayEncoder: [T <: Array[_]](implicit evidence$1: reflect.runtime.universe.TypeTag[T])org.apache.spark.sql.Encoder[T]
scala> val decOne = Decimal(1, 38, 18)
decOne: org.apache.spark.sql.types.Decimal = 1E-18
scala> val decTwo = Decimal(2, 38, 18)
decTwo: org.apache.spark.sql.types.Decimal = 2E-18
scala> val decSpark = Array(decOne, decTwo)
decSpark: Array[org.apache.spark.sql.types.Decimal] = Array(1E-18, 2E-18)
scala> Seq(decSpark).toDF()
java.lang.ClassCastException: org.apache.spark.sql.types.DecimalType cannot be cast to org.apache.spark.sql.types.ObjectType
at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$arrayClassFor$1(ScalaReflection.scala:131)
at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
at org.apache.spark.sql.catalyst.ScalaReflection$.arrayClassFor(ScalaReflection.scala:120)
at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$dataTypeFor$1(ScalaReflection.scala:105)
at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
at org.apache.spark.sql.catalyst.ScalaReflection$.dataTypeFor(ScalaReflection.scala:88)
at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerForType$1(ScalaReflection.scala:399)
at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:879)
at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:878)
at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:49)
at org.apache.spark.sql.catalyst.ScalaReflection$.serializerForType(ScalaReflection.scala:393)
at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$.apply(ExpressionEncoder.scala:57)
at newArrayEncoder(<console>:57)
... 53 elided
scala>
```
In this PR, we add the missing cases to `arrayClassFor`.
### Why are the changes needed?
Bugfix, as described above.
### Does this PR introduce any user-facing change?
no
### How was this patch tested?
add a test for array encoders
Closes #28324 from yaooqinn/SPARK-31552.
Authored-by: Kent Yao <ya...@hotmail.com>
Signed-off-by: Dongjoon Hyun <do...@apache.org>
(cherry picked from commit caf3ab84113c29c51049c1c906c91462e233e9d9)
Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
.../spark/sql/catalyst/ScalaReflection.scala | 3 +
.../org/apache/spark/sql/DataFrameSuite.scala | 82 +++++++++++++++++++++-
2 files changed, 82 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9c8da32..05de21b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -126,6 +126,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]]
case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]]
case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]]
+ case t if isSubtype(t, localTypeOf[Array[Byte]]) => classOf[Array[Array[Byte]]]
+ case t if isSubtype(t, localTypeOf[CalendarInterval]) => classOf[Array[CalendarInterval]]
+ case t if isSubtype(t, localTypeOf[Decimal]) => classOf[Array[Decimal]]
case other =>
// There is probably a better way to do this, but I couldn't find it...
val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f797290..f20e684 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -23,6 +23,7 @@ import java.sql.{Date, Timestamp}
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong
+import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
import org.scalatest.Matchers._
@@ -30,10 +31,11 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -41,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils
@@ -2358,6 +2360,80 @@ class DataFrameSuite extends QueryTest
val df = Seq((1, new CalendarInterval(1, 2, 3))).toDF("a", "b")
checkAnswer(df.selectExpr("b"), Row(new CalendarInterval(1, 2, 3)))
}
+
+ test("SPARK-31552: array encoder with different types") {
+ // primitives
+ val booleans = Array(true, false)
+ checkAnswer(Seq(booleans).toDF(), Row(booleans))
+
+ val bytes = Array(1.toByte, 2.toByte)
+ checkAnswer(Seq(bytes).toDF(), Row(bytes))
+ val shorts = Array(1.toShort, 2.toShort)
+ checkAnswer(Seq(shorts).toDF(), Row(shorts))
+ val ints = Array(1, 2)
+ checkAnswer(Seq(ints).toDF(), Row(ints))
+ val longs = Array(1L, 2L)
+ checkAnswer(Seq(longs).toDF(), Row(longs))
+
+ val floats = Array(1.0F, 2.0F)
+ checkAnswer(Seq(floats).toDF(), Row(floats))
+ val doubles = Array(1.0D, 2.0D)
+ checkAnswer(Seq(doubles).toDF(), Row(doubles))
+
+ val strings = Array("2020-04-24", "2020-04-25")
+ checkAnswer(Seq(strings).toDF(), Row(strings))
+
+ // tuples
+ val decOne = Decimal(1, 38, 18)
+ val decTwo = Decimal(2, 38, 18)
+ val tuple1 = (1, 2.2, "3.33", decOne, Date.valueOf("2012-11-22"))
+ val tuple2 = (2, 3.3, "4.44", decTwo, Date.valueOf("2022-11-22"))
+ checkAnswer(Seq(Array(tuple1, tuple2)).toDF(), Seq(Seq(tuple1, tuple2)).toDF())
+
+ // case classes
+ val gbks = Array(GroupByKey(1, 2), GroupByKey(4, 5))
+ checkAnswer(Seq(gbks).toDF(), Row(Array(Row(1, 2), Row(4, 5))))
+
+ // We can move this implicit def to [[SQLImplicits]] when we eventually make fully
+ // support for array encoder like Seq and Set
+ // For now cases below, decimal/datetime/interval/binary/nested types, etc,
+ // are not supported by array
+ implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
+ // decimals
+ val decSpark = Array(decOne, decTwo)
+ val decScala = decSpark.map(_.toBigDecimal)
+ val decJava = decSpark.map(_.toJavaBigDecimal)
+ checkAnswer(Seq(decSpark).toDF(), Row(decJava))
+ checkAnswer(Seq(decScala).toDF(), Row(decJava))
+ checkAnswer(Seq(decJava).toDF(), Row(decJava))
+
+ // datetimes and intervals
+ val dates = strings.map(Date.valueOf)
+ checkAnswer(Seq(dates).toDF(), Row(dates))
+ val localDates = dates.map(d => DateTimeUtils.daysToLocalDate(DateTimeUtils.fromJavaDate(d)))
+ checkAnswer(Seq(localDates).toDF(), Row(dates))
+
+ val timestamps =
+ Array(Timestamp.valueOf("2020-04-24 12:34:56"), Timestamp.valueOf("2020-04-24 11:22:33"))
+ checkAnswer(Seq(timestamps).toDF(), Row(timestamps))
+ val instants =
+ timestamps.map(t => DateTimeUtils.microsToInstant(DateTimeUtils.fromJavaTimestamp(t)))
+ checkAnswer(Seq(instants).toDF(), Row(timestamps))
+
+ val intervals = Array(new CalendarInterval(1, 2, 3), new CalendarInterval(4, 5, 6))
+ checkAnswer(Seq(intervals).toDF(), Row(intervals))
+
+ // binary
+ val bins = Array(Array(1.toByte), Array(2.toByte), Array(3.toByte), Array(4.toByte))
+ checkAnswer(Seq(bins).toDF(), Row(bins))
+
+ // nested
+ val nestedIntArray = Array(Array(1), Array(2))
+ checkAnswer(Seq(nestedIntArray).toDF(), Row(nestedIntArray.map(wrapIntArray)))
+ val nestedDecArray = Array(decSpark)
+ checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
+ }
}
case class GroupByKey(a: Int, b: Int)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org