You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/08/08 08:12:45 UTC
spark git commit: [SPARK-21567][SQL] Dataset should work with type alias
Repository: spark
Updated Branches:
refs/heads/master 312bebfb6 -> ee1304199
[SPARK-21567][SQL] Dataset should work with type alias
## What changes were proposed in this pull request?
If we create a type alias for a type that works with Dataset, the alias itself does not work with Dataset.
A reproducible case looks like:
object C {
type TwoInt = (Int, Int)
def tupleTypeAlias: TwoInt = (1, 1)
}
Seq(1).toDS().map(_ => ("", C.tupleTypeAlias))
It throws an exception like:
type T1 is not a class
scala.ScalaReflectionException: type T1 is not a class
at scala.reflect.api.Symbols$SymbolApi$class.asClass(Symbols.scala:275)
...
This patch fixes the problem by calling `dealias` on the type (resolving the alias to its underlying type) in the relevant places in `ScalaReflection`.
## How was this patch tested?
Added a test case.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #18813 from viirya/SPARK-21567.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ee130419
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ee130419
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ee130419
Branch: refs/heads/master
Commit: ee1304199bcd9c1d5fc94f5b06fdd5f6fe7336a1
Parents: 312bebf
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Tue Aug 8 16:12:41 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Aug 8 16:12:41 2017 +0800
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 27 ++++++++++----------
.../org/apache/spark/sql/DatasetSuite.scala | 24 +++++++++++++++++
2 files changed, 38 insertions(+), 13 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ee130419/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 004b4ef..17e595f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -62,7 +62,7 @@ object ScalaReflection extends ScalaReflection {
def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
private def dataTypeFor(tpe: `Type`): DataType = {
- tpe match {
+ tpe.dealias match {
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
case t if t <:< definitions.DoubleTpe => DoubleType
@@ -94,7 +94,7 @@ object ScalaReflection extends ScalaReflection {
* JVM form instead of the Scala Array that handles auto boxing.
*/
private def arrayClassFor(tpe: `Type`): ObjectType = {
- val cls = tpe match {
+ val cls = tpe.dealias match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
@@ -193,7 +193,7 @@ object ScalaReflection extends ScalaReflection {
case _ => UpCast(expr, expected, walkedTypePath)
}
- tpe match {
+ tpe.dealias match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
case t if t <:< localTypeOf[Option[_]] =>
@@ -469,7 +469,7 @@ object ScalaReflection extends ScalaReflection {
}
}
- tpe match {
+ tpe.dealias match {
case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject
case t if t <:< localTypeOf[Option[_]] =>
@@ -643,7 +643,7 @@ object ScalaReflection extends ScalaReflection {
* we also treat [[DefinedByConstructorParams]] as product type.
*/
def optionOfProductType(tpe: `Type`): Boolean = {
- tpe match {
+ tpe.dealias match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
definedByConstructorParams(optType)
@@ -690,7 +690,7 @@ object ScalaReflection extends ScalaReflection {
/*
* Retrieves the runtime class corresponding to the provided type.
*/
- def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.typeSymbol.asClass)
+ def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)
case class Schema(dataType: DataType, nullable: Boolean)
@@ -705,7 +705,7 @@ object ScalaReflection extends ScalaReflection {
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = {
- tpe match {
+ tpe.dealias match {
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Schema(udt, nullable = true)
@@ -775,7 +775,7 @@ object ScalaReflection extends ScalaReflection {
* Whether the fields of the given type is defined entirely by its constructor parameters.
*/
def definedByConstructorParams(tpe: Type): Boolean = {
- tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams]
+ tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
}
private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
@@ -829,7 +829,7 @@ trait ScalaReflection {
* synthetic classes, emulating behaviour in Java bytecode.
*/
def getClassNameFromType(tpe: `Type`): String = {
- tpe.erasure.typeSymbol.asClass.fullName
+ tpe.dealias.erasure.typeSymbol.asClass.fullName
}
/**
@@ -848,9 +848,10 @@ trait ScalaReflection {
* support inner class.
*/
def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
- val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = tpe
- val params = constructParams(tpe)
+ val dealiasedTpe = tpe.dealias
+ val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
+ val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
+ val params = constructParams(dealiasedTpe)
// if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
if (actualTypeArgs.nonEmpty) {
params.map { p =>
@@ -864,7 +865,7 @@ trait ScalaReflection {
}
protected def constructParams(tpe: Type): Seq[Symbol] = {
- val constructorSymbol = tpe.member(termNames.CONSTRUCTOR)
+ val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR)
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramLists
} else {
http://git-wip-us.apache.org/repos/asf/spark/blob/ee130419/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 40235e3..6245b2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,16 @@ import org.apache.spark.sql.types._
case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
case class TestDataPoint2(x: Int, s: String)
+object TestForTypeAlias {
+ type TwoInt = (Int, Int)
+ type ThreeInt = (TwoInt, Int)
+ type SeqOfTwoInt = Seq[TwoInt]
+
+ def tupleTypeAlias: TwoInt = (1, 1)
+ def nestedTupleTypeAlias: ThreeInt = ((1, 1), 2)
+ def seqOfTupleTypeAlias: SeqOfTwoInt = Seq((1, 1), (2, 2))
+}
+
class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -1317,6 +1327,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.orderBy($"id"), expected)
checkAnswer(df.orderBy('id), expected)
}
+
+ test("SPARK-21567: Dataset should work with type alias") {
+ checkDataset(
+ Seq(1).toDS().map(_ => ("", TestForTypeAlias.tupleTypeAlias)),
+ ("", (1, 1)))
+
+ checkDataset(
+ Seq(1).toDS().map(_ => ("", TestForTypeAlias.nestedTupleTypeAlias)),
+ ("", ((1, 1), 2)))
+
+ checkDataset(
+ Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)),
+ ("", Seq((1, 1), (2, 2))))
+ }
}
case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org