You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/03/18 06:40:24 UTC
spark git commit: [SPARK-19896][SQL] Throw an exception if case
classes have circular references in toDS
Repository: spark
Updated Branches:
refs/heads/master c083b6b7d -> ccba622e3
[SPARK-19896][SQL] Throw an exception if case classes have circular references in toDS
## What changes were proposed in this pull request?
If case classes have circular references, as in the example below, `toDS()` throws a StackOverflowError:
```
scala> :paste
case class classA(i: Int, cls: classB)
case class classB(cls: classA)
scala> Seq(classA(0, null)).toDS()
java.lang.StackOverflowError
at scala.reflect.internal.Symbols$Symbol.info(Symbols.scala:1494)
at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.scala$reflect$runtime$SynchronizedSymbols$SynchronizedSymbol$$super$info(JavaMirrors.scala:66)
at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$$anonfun$info$1.apply(SynchronizedSymbols.scala:127)
at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$$anonfun$info$1.apply(SynchronizedSymbols.scala:127)
at scala.reflect.runtime.Gil$class.gilSynchronized(Gil.scala:19)
at scala.reflect.runtime.JavaUniverse.gilSynchronized(JavaUniverse.scala:16)
at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$class.gilSynchronizedIfNotThreadsafe(SynchronizedSymbols.scala:123)
at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.gilSynchronizedIfNotThreadsafe(JavaMirrors.scala:66)
at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$class.info(SynchronizedSymbols.scala:127)
at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.info(JavaMirrors.scala:66)
at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:48)
at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45)
at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45)
at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45)
at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45)
```
This PR adds code to throw an UnsupportedOperationException in that case, as follows:
```
scala> :paste
case class A(cls: B)
case class B(cls: A)
scala> Seq(A(null)).toDS()
java.lang.UnsupportedOperationException: cannot have circular references in class, but got the circular reference of class B
at org.apache.spark.sql.catalyst.ScalaReflection$.org$apache$spark$sql$catalyst$ScalaReflection$$serializerFor(ScalaReflection.scala:627)
at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:644)
at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:632)
at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
at scala.collection.immutable.List.foreach(List.scala:381)
at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
```
## How was this patch tested?
Added tests in `DatasetSuite`.
Author: Takeshi Yamamuro <ya...@apache.org>
Closes #17318 from maropu/SPARK-19896.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ccba622e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ccba622e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ccba622e
Branch: refs/heads/master
Commit: ccba622e35741d8344ec8d74b6750529b2c7219b
Parents: c083b6b
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Sat Mar 18 14:40:16 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Sat Mar 18 14:40:16 2017 +0800
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 20 ++++++++++------
.../org/apache/spark/sql/DatasetSuite.scala | 24 ++++++++++++++++++++
2 files changed, 37 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ccba622e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 7f7dd51..c4af284 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -470,14 +470,15 @@ object ScalaReflection extends ScalaReflection {
private def serializerFor(
inputObject: Expression,
tpe: `Type`,
- walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
+ walkedTypePath: Seq[String],
+ seenTypeSet: Set[`Type`] = Set.empty): Expression = ScalaReflectionLock.synchronized {
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
dataTypeFor(elementType) match {
case dt: ObjectType =>
val clsName = getClassNameFromType(elementType)
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
- MapObjects(serializerFor(_, elementType, newPath), input, dt)
+ MapObjects(serializerFor(_, elementType, newPath, seenTypeSet), input, dt)
case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType) =>
@@ -511,7 +512,7 @@ object ScalaReflection extends ScalaReflection {
val className = getClassNameFromType(optType)
val newPath = s"""- option value class: "$className"""" +: walkedTypePath
val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
- serializerFor(unwrapped, optType, newPath)
+ serializerFor(unwrapped, optType, newPath, seenTypeSet)
// Since List[_] also belongs to localTypeOf[Product], we put this case before
// "case t if definedByConstructorParams(t)" to make sure it will match to the
@@ -534,9 +535,9 @@ object ScalaReflection extends ScalaReflection {
ExternalMapToCatalyst(
inputObject,
dataTypeFor(keyType),
- serializerFor(_, keyType, keyPath),
+ serializerFor(_, keyType, keyPath, seenTypeSet),
dataTypeFor(valueType),
- serializerFor(_, valueType, valuePath),
+ serializerFor(_, valueType, valuePath, seenTypeSet),
valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
case t if t <:< localTypeOf[String] =>
@@ -622,6 +623,11 @@ object ScalaReflection extends ScalaReflection {
Invoke(obj, "serialize", udt, inputObject :: Nil)
case t if definedByConstructorParams(t) =>
+ if (seenTypeSet.contains(t)) {
+ throw new UnsupportedOperationException(
+ s"cannot have circular references in class, but got the circular reference of class $t")
+ }
+
val params = getConstructorParameters(t)
val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
if (javaKeywords.contains(fieldName)) {
@@ -634,7 +640,8 @@ object ScalaReflection extends ScalaReflection {
returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
- expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+ expressions.Literal(fieldName) ::
+ serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil
})
val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
@@ -643,7 +650,6 @@ object ScalaReflection extends ScalaReflection {
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
}
-
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/ccba622e/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b37bf13..6417e7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1136,6 +1136,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head ==
new java.sql.Timestamp(100000))
}
+
+ test("SPARK-19896: cannot have circular references in in case class") {
+ val errMsg1 = intercept[UnsupportedOperationException] {
+ Seq(CircularReferenceClassA(null)).toDS
+ }
+ assert(errMsg1.getMessage.startsWith("cannot have circular references in class, but got the " +
+ "circular reference of class"))
+ val errMsg2 = intercept[UnsupportedOperationException] {
+ Seq(CircularReferenceClassC(null)).toDS
+ }
+ assert(errMsg2.getMessage.startsWith("cannot have circular references in class, but got the " +
+ "circular reference of class"))
+ val errMsg3 = intercept[UnsupportedOperationException] {
+ Seq(CircularReferenceClassD(null)).toDS
+ }
+ assert(errMsg3.getMessage.startsWith("cannot have circular references in class, but got the " +
+ "circular reference of class"))
+ }
}
case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
@@ -1214,3 +1232,9 @@ object DatasetTransform {
case class Route(src: String, dest: String, cost: Int)
case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])
+
+case class CircularReferenceClassA(cls: CircularReferenceClassB)
+case class CircularReferenceClassB(cls: CircularReferenceClassA)
+case class CircularReferenceClassC(ar: Array[CircularReferenceClassC])
+case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE])
+case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD])
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org