You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/03/04 02:46:35 UTC
[spark] branch master updated: [SPARK-27001][SQL] Refactor
"serializerFor" method between ScalaReflection and JavaTypeInference
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 34f6066 [SPARK-27001][SQL] Refactor "serializerFor" method between ScalaReflection and JavaTypeInference
34f6066 is described below
commit 34f606678a90e860711a5f9f9618cf00788c9eb0
Author: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
AuthorDate: Mon Mar 4 10:45:48 2019 +0800
[SPARK-27001][SQL] Refactor "serializerFor" method between ScalaReflection and JavaTypeInference
## What changes were proposed in this pull request?
This patch proposes refactoring the `serializerFor` method shared between `ScalaReflection` and `JavaTypeInference`, consistent with the refactoring done for `deserializerFor` in #23854.
This patch also extracts the logic for recording the walked type path, since that logic is duplicated across `serializerFor` and `deserializerFor` in both `ScalaReflection` and `JavaTypeInference`.
## How was this patch tested?
Existing tests.
Closes #23908 from HeartSaVioR/SPARK-27001.
Authored-by: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../sql/catalyst/DeserializerBuildHelper.scala | 32 ++-
.../spark/sql/catalyst/JavaTypeInference.scala | 143 +++++---------
.../spark/sql/catalyst/ScalaReflection.scala | 220 +++++++--------------
.../spark/sql/catalyst/SerializerBuildHelper.scala | 198 +++++++++++++++++++
.../apache/spark/sql/catalyst/WalkedTypePath.scala | 57 ++++++
.../spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
.../sql/catalyst/encoders/ExpressionEncoder.scala | 2 +-
.../spark/sql/catalyst/encoders/RowEncoder.scala | 2 +-
.../spark/sql/catalyst/expressions/Cast.scala | 4 +-
.../catalyst/expressions/CodeGenerationSuite.scala | 2 +-
.../expressions/NullExpressionsSuite.scala | 2 +-
11 files changed, 394 insertions(+), 270 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index d75d3ca..e55c25c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -29,7 +29,7 @@ object DeserializerBuildHelper {
path: Expression,
part: String,
dataType: DataType,
- walkedTypePath: Seq[String]): Expression = {
+ walkedTypePath: WalkedTypePath): Expression = {
val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
upCastToExpectedType(newPath, dataType, walkedTypePath)
}
@@ -39,40 +39,30 @@ object DeserializerBuildHelper {
path: Expression,
ordinal: Int,
dataType: DataType,
- walkedTypePath: Seq[String]): Expression = {
+ walkedTypePath: WalkedTypePath): Expression = {
val newPath = GetStructField(path, ordinal)
upCastToExpectedType(newPath, dataType, walkedTypePath)
}
- def deserializerForWithNullSafety(
- expr: Expression,
- dataType: DataType,
- nullable: Boolean,
- walkedTypePath: Seq[String],
- funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
- val newExpr = funcForCreatingNewExpr(expr, walkedTypePath)
- expressionWithNullSafety(newExpr, nullable, walkedTypePath)
- }
-
def deserializerForWithNullSafetyAndUpcast(
expr: Expression,
dataType: DataType,
nullable: Boolean,
- walkedTypePath: Seq[String],
- funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
+ walkedTypePath: WalkedTypePath,
+ funcForCreatingDeserializer: (Expression, WalkedTypePath) => Expression): Expression = {
val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
- deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath,
- funcForCreatingNewExpr)
+ expressionWithNullSafety(funcForCreatingDeserializer(casted, walkedTypePath),
+ nullable, walkedTypePath)
}
- private def expressionWithNullSafety(
+ def expressionWithNullSafety(
expr: Expression,
nullable: Boolean,
- walkedTypePath: Seq[String]): Expression = {
+ walkedTypePath: WalkedTypePath): Expression = {
if (nullable) {
expr
} else {
- AssertNotNull(expr, walkedTypePath)
+ AssertNotNull(expr, walkedTypePath.getPaths)
}
}
@@ -167,10 +157,10 @@ object DeserializerBuildHelper {
private def upCastToExpectedType(
expr: Expression,
expected: DataType,
- walkedTypePath: Seq[String]): Expression = expected match {
+ walkedTypePath: WalkedTypePath): Expression = expected match {
case _: StructType => expr
case _: ArrayType => expr
case _: MapType => expr
- case _ => UpCast(expr, expected, walkedTypePath)
+ case _ => UpCast(expr, expected, walkedTypePath.getPaths)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 87b2ae8..933a6db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -27,12 +27,12 @@ import scala.language.existentials
import com.google.common.reflect.TypeToken
import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
/**
* Type-inference utilities for POJOs and Java collections.
@@ -195,7 +195,7 @@ object JavaTypeInference {
*/
def deserializerFor(beanClass: Class[_]): Expression = {
val typeToken = TypeToken.of(beanClass)
- val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil
+ val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName)
val (dataType, nullable) = inferDataType(typeToken)
// Assumes we are deserializing the first column of a row.
@@ -208,7 +208,7 @@ object JavaTypeInference {
private def deserializerFor(
typeToken: TypeToken[_],
path: Expression,
- walkedTypePath: Seq[String]): Expression = {
+ walkedTypePath: WalkedTypePath): Expression = {
typeToken.getRawType match {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
@@ -244,8 +244,7 @@ object JavaTypeInference {
case c if c.isArray =>
val elementType = c.getComponentType
- val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +:
- walkedTypePath
+ val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName)
val (dataType, elementNullable) = inferDataType(elementType)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
@@ -274,8 +273,7 @@ object JavaTypeInference {
case c if listType.isAssignableFrom(typeToken) =>
val et = elementType(typeToken)
- val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +:
- walkedTypePath
+ val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName)
val (dataType, elementNullable) = inferDataType(et)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
@@ -291,8 +289,8 @@ object JavaTypeInference {
case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
- val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" +
- s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath
+ val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName,
+ valueType.getType.getTypeName)
val keyData =
Invoke(
@@ -328,15 +326,12 @@ object JavaTypeInference {
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(fieldType)
- val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" +
- s""", name: "$fieldName")""") +: walkedTypePath
- val setter = deserializerForWithNullSafety(
- path,
- dataType,
+ val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName)
+ val setter = expressionWithNullSafety(
+ deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
+ newTypePath),
nullable = nullable,
- newTypePath,
- (expr, typePath) => deserializerFor(fieldType,
- addToPath(expr, fieldName, dataType, typePath), typePath))
+ newTypePath)
p.getWriteMethod.getName -> setter
}.toMap
@@ -367,12 +362,10 @@ object JavaTypeInference {
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
val (dataType, nullable) = inferDataType(elementType)
if (ScalaReflection.isNativeType(dataType)) {
- NewInstance(
- classOf[GenericArrayData],
- input :: Nil,
- dataType = ArrayType(dataType, nullable))
+ createSerializerForGenericArray(input, dataType, nullable = nullable)
} else {
- MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType))
+ createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
+ serializerFor(_, elementType))
}
}
@@ -380,60 +373,26 @@ object JavaTypeInference {
inputObject
} else {
typeToken.getRawType match {
- case c if c == classOf[String] =>
- StaticInvoke(
- classOf[UTF8String],
- StringType,
- "fromString",
- inputObject :: Nil,
- returnNullable = false)
-
- case c if c == classOf[java.sql.Timestamp] =>
- StaticInvoke(
- DateTimeUtils.getClass,
- TimestampType,
- "fromJavaTimestamp",
- inputObject :: Nil,
- returnNullable = false)
-
- case c if c == classOf[java.time.LocalDate] =>
- StaticInvoke(
- DateTimeUtils.getClass,
- DateType,
- "localDateToDays",
- inputObject :: Nil,
- returnNullable = false)
-
- case c if c == classOf[java.sql.Date] =>
- StaticInvoke(
- DateTimeUtils.getClass,
- DateType,
- "fromJavaDate",
- inputObject :: Nil,
- returnNullable = false)
+ case c if c == classOf[String] => createSerializerForString(inputObject)
+
+ case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)
+
+ case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)
+
+ case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)
+
+ case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
case c if c == classOf[java.math.BigDecimal] =>
- StaticInvoke(
- Decimal.getClass,
- DecimalType.SYSTEM_DEFAULT,
- "apply",
- inputObject :: Nil,
- returnNullable = false)
-
- case c if c == classOf[java.lang.Boolean] =>
- Invoke(inputObject, "booleanValue", BooleanType)
- case c if c == classOf[java.lang.Byte] =>
- Invoke(inputObject, "byteValue", ByteType)
- case c if c == classOf[java.lang.Short] =>
- Invoke(inputObject, "shortValue", ShortType)
- case c if c == classOf[java.lang.Integer] =>
- Invoke(inputObject, "intValue", IntegerType)
- case c if c == classOf[java.lang.Long] =>
- Invoke(inputObject, "longValue", LongType)
- case c if c == classOf[java.lang.Float] =>
- Invoke(inputObject, "floatValue", FloatType)
- case c if c == classOf[java.lang.Double] =>
- Invoke(inputObject, "doubleValue", DoubleType)
+ createSerializerForJavaBigDecimal(inputObject)
+
+ case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
+ case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
+ case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject)
+ case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject)
+ case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject)
+ case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
+ case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)
case _ if typeToken.isArray =>
toCatalystArray(inputObject, typeToken.getComponentType)
@@ -444,38 +403,34 @@ object JavaTypeInference {
case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
- ExternalMapToCatalyst(
+ createSerializerForMap(
inputObject,
- ObjectType(keyType.getRawType),
- serializerFor(_, keyType),
- keyNullable = true,
- ObjectType(valueType.getRawType),
- serializerFor(_, valueType),
- valueNullable = true
+ MapElementInformation(
+ ObjectType(keyType.getRawType),
+ nullable = true,
+ serializerFor(_, keyType)),
+ MapElementInformation(
+ ObjectType(valueType.getRawType),
+ nullable = true,
+ serializerFor(_, valueType))
)
case other if other.isEnum =>
- StaticInvoke(
- classOf[UTF8String],
- StringType,
- "fromString",
- Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil,
- returnNullable = false)
+ createSerializerForString(
+ Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))
case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
- val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
+ val fields = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val fieldValue = Invoke(
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
- expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
- })
-
- val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
- expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+ (fieldName, serializerFor(fieldValue, fieldType))
+ }
+ createSerializerForObject(inputObject, fields)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bbddd33..5b3109a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -21,10 +21,11 @@ import org.apache.commons.lang3.reflect.ConstructorUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.{Expression, _}
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -136,7 +137,7 @@ object ScalaReflection extends ScalaReflection {
*/
def deserializerForType(tpe: `Type`): Expression = {
val clsName = getClassNameFromType(tpe)
- val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
+ val walkedTypePath = new WalkedTypePath().recordRoot(clsName)
val Schema(dataType, nullable) = schemaFor(tpe)
// Assumes we are deserializing the first column of a row.
@@ -156,14 +157,14 @@ object ScalaReflection extends ScalaReflection {
private def deserializerFor(
tpe: `Type`,
path: Expression,
- walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
+ walkedTypePath: WalkedTypePath): Expression = cleanUpReflectionObjects {
tpe.dealias match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
val className = getClassNameFromType(optType)
- val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
+ val newTypePath = walkedTypePath.recordOption(className)
WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
case t if t <:< localTypeOf[java.lang.Integer] =>
@@ -225,7 +226,7 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
- val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
+ val newTypePath = walkedTypePath.recordArray(className)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
@@ -260,7 +261,7 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
- val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
+ val newTypePath = walkedTypePath.recordArray(className)
val mapFunction: Expression => Expression = element => {
deserializerForWithNullSafetyAndUpcast(
@@ -286,8 +287,7 @@ object ScalaReflection extends ScalaReflection {
val classNameForKey = getClassNameFromType(keyType)
val classNameForValue = getClassNameFromType(valueType)
- val newTypePath = (s"""- map key class: "${classNameForKey}"""" +
- s""", value class: "${classNameForValue}"""") +: walkedTypePath
+ val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue)
UnresolvedCatalystToExternalMap(
path,
@@ -322,28 +322,24 @@ object ScalaReflection extends ScalaReflection {
val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
val Schema(dataType, nullable) = schemaFor(fieldType)
val clsName = getClassNameFromType(fieldType)
- val newTypePath = (s"""- field (class: "$clsName", """ +
- s"""name: "$fieldName")""") +: walkedTypePath
+ val newTypePath = walkedTypePath.recordField(clsName, fieldName)
// For tuples, we based grab the inner fields by ordinal instead of name.
- deserializerForWithNullSafety(
- path,
- dataType,
+ val newPath = if (cls.getName startsWith "scala.Tuple") {
+ deserializerFor(
+ fieldType,
+ addToPathOrdinal(path, i, dataType, newTypePath),
+ newTypePath)
+ } else {
+ deserializerFor(
+ fieldType,
+ addToPath(path, fieldName, dataType, newTypePath),
+ newTypePath)
+ }
+ expressionWithNullSafety(
+ newPath,
nullable = nullable,
- newTypePath,
- (expr, typePath) => {
- if (cls.getName startsWith "scala.Tuple") {
- deserializerFor(
- fieldType,
- addToPathOrdinal(expr, i, dataType, typePath),
- newTypePath)
- } else {
- deserializerFor(
- fieldType,
- addToPath(expr, fieldName, dataType, typePath),
- newTypePath)
- }
- })
+ newTypePath)
}
val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
@@ -371,7 +367,7 @@ object ScalaReflection extends ScalaReflection {
*/
def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects {
val clsName = getClassNameFromType(tpe)
- val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
+ val walkedTypePath = new WalkedTypePath().recordRoot(clsName)
// The input object to `ExpressionEncoder` is located at first column of an row.
val isPrimitive = tpe.typeSymbol.asClass.isPrimitive
@@ -387,38 +383,28 @@ object ScalaReflection extends ScalaReflection {
private def serializerFor(
inputObject: Expression,
tpe: `Type`,
- walkedTypePath: Seq[String],
+ walkedTypePath: WalkedTypePath,
seenTypeSet: Set[`Type`] = Set.empty): Expression = cleanUpReflectionObjects {
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
dataTypeFor(elementType) match {
case dt: ObjectType =>
val clsName = getClassNameFromType(elementType)
- val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
- MapObjects(serializerFor(_, elementType, newPath, seenTypeSet), input, dt)
+ val newPath = walkedTypePath.recordArray(clsName)
+ createSerializerForMapObjects(input, dt,
+ serializerFor(_, elementType, newPath, seenTypeSet))
case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType) =>
val cls = input.dataType.asInstanceOf[ObjectType].cls
if (cls.isArray && cls.getComponentType.isPrimitive) {
- StaticInvoke(
- classOf[UnsafeArrayData],
- ArrayType(dt, false),
- "fromPrimitiveArray",
- input :: Nil,
- returnNullable = false)
+ createSerializerForPrimitiveArray(input, dt)
} else {
- NewInstance(
- classOf[GenericArrayData],
- input :: Nil,
- dataType = ArrayType(dt, schemaFor(elementType).nullable))
+ createSerializerForGenericArray(input, dt, nullable = schemaFor(elementType).nullable)
}
case dt =>
- NewInstance(
- classOf[GenericArrayData],
- input :: Nil,
- dataType = ArrayType(dt, schemaFor(elementType).nullable))
+ createSerializerForGenericArray(input, dt, nullable = schemaFor(elementType).nullable)
}
}
@@ -428,7 +414,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
val className = getClassNameFromType(optType)
- val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+ val newPath = walkedTypePath.recordOption(className)
val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
serializerFor(unwrapped, optType, newPath, seenTypeSet)
@@ -447,17 +433,20 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(keyType, valueType)) = t
val keyClsName = getClassNameFromType(keyType)
val valueClsName = getClassNameFromType(valueType)
- val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath
- val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath
+ val keyPath = walkedTypePath.recordKeyForMap(keyClsName)
+ val valuePath = walkedTypePath.recordValueForMap(valueClsName)
- ExternalMapToCatalyst(
+ createSerializerForMap(
inputObject,
- dataTypeFor(keyType),
- serializerFor(_, keyType, keyPath, seenTypeSet),
- keyNullable = !keyType.typeSymbol.asClass.isPrimitive,
- dataTypeFor(valueType),
- serializerFor(_, valueType, valuePath, seenTypeSet),
- valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
+ MapElementInformation(
+ dataTypeFor(keyType),
+ nullable = !keyType.typeSymbol.asClass.isPrimitive,
+ serializerFor(_, keyType, keyPath, seenTypeSet)),
+ MapElementInformation(
+ dataTypeFor(valueType),
+ nullable = !valueType.typeSymbol.asClass.isPrimitive,
+ serializerFor(_, valueType, valuePath, seenTypeSet))
+ )
case t if t <:< localTypeOf[scala.collection.Set[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
@@ -472,110 +461,47 @@ object ScalaReflection extends ScalaReflection {
toCatalystArray(newInput, elementType)
- case t if t <:< localTypeOf[String] =>
- StaticInvoke(
- classOf[UTF8String],
- StringType,
- "fromString",
- inputObject :: Nil,
- returnNullable = false)
+ case t if t <:< localTypeOf[String] => createSerializerForString(inputObject)
- case t if t <:< localTypeOf[java.time.Instant] =>
- StaticInvoke(
- DateTimeUtils.getClass,
- TimestampType,
- "instantToMicros",
- inputObject :: Nil,
- returnNullable = false)
+ case t if t <:< localTypeOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)
case t if t <:< localTypeOf[java.sql.Timestamp] =>
- StaticInvoke(
- DateTimeUtils.getClass,
- TimestampType,
- "fromJavaTimestamp",
- inputObject :: Nil,
- returnNullable = false)
+ createSerializerForSqlTimestamp(inputObject)
case t if t <:< localTypeOf[java.time.LocalDate] =>
- StaticInvoke(
- DateTimeUtils.getClass,
- DateType,
- "localDateToDays",
- inputObject :: Nil,
- returnNullable = false)
+ createSerializerForJavaLocalDate(inputObject)
- case t if t <:< localTypeOf[java.sql.Date] =>
- StaticInvoke(
- DateTimeUtils.getClass,
- DateType,
- "fromJavaDate",
- inputObject :: Nil,
- returnNullable = false)
+ case t if t <:< localTypeOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
- case t if t <:< localTypeOf[BigDecimal] =>
- StaticInvoke(
- Decimal.getClass,
- DecimalType.SYSTEM_DEFAULT,
- "apply",
- inputObject :: Nil,
- returnNullable = false)
+ case t if t <:< localTypeOf[BigDecimal] => createSerializerForScalaBigDecimal(inputObject)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
- StaticInvoke(
- Decimal.getClass,
- DecimalType.SYSTEM_DEFAULT,
- "apply",
- inputObject :: Nil,
- returnNullable = false)
+ createSerializerForJavaBigDecimal(inputObject)
case t if t <:< localTypeOf[java.math.BigInteger] =>
- StaticInvoke(
- Decimal.getClass,
- DecimalType.BigIntDecimal,
- "apply",
- inputObject :: Nil,
- returnNullable = false)
+ createSerializerForJavaBigInteger(inputObject)
- case t if t <:< localTypeOf[scala.math.BigInt] =>
- StaticInvoke(
- Decimal.getClass,
- DecimalType.BigIntDecimal,
- "apply",
- inputObject :: Nil,
- returnNullable = false)
+ case t if t <:< localTypeOf[scala.math.BigInt] => createSerializerForScalaBigInt(inputObject)
- case t if t <:< localTypeOf[java.lang.Integer] =>
- Invoke(inputObject, "intValue", IntegerType)
- case t if t <:< localTypeOf[java.lang.Long] =>
- Invoke(inputObject, "longValue", LongType)
- case t if t <:< localTypeOf[java.lang.Double] =>
- Invoke(inputObject, "doubleValue", DoubleType)
- case t if t <:< localTypeOf[java.lang.Float] =>
- Invoke(inputObject, "floatValue", FloatType)
- case t if t <:< localTypeOf[java.lang.Short] =>
- Invoke(inputObject, "shortValue", ShortType)
- case t if t <:< localTypeOf[java.lang.Byte] =>
- Invoke(inputObject, "byteValue", ByteType)
- case t if t <:< localTypeOf[java.lang.Boolean] =>
- Invoke(inputObject, "booleanValue", BooleanType)
+ case t if t <:< localTypeOf[java.lang.Integer] => createSerializerForInteger(inputObject)
+ case t if t <:< localTypeOf[java.lang.Long] => createSerializerForLong(inputObject)
+ case t if t <:< localTypeOf[java.lang.Double] => createSerializerForDouble(inputObject)
+ case t if t <:< localTypeOf[java.lang.Float] => createSerializerForFloat(inputObject)
+ case t if t <:< localTypeOf[java.lang.Short] => createSerializerForShort(inputObject)
+ case t if t <:< localTypeOf[java.lang.Byte] => createSerializerForByte(inputObject)
+ case t if t <:< localTypeOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t)
.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance()
- val obj = NewInstance(
- udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
- Nil,
- dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
- Invoke(obj, "serialize", udt, inputObject :: Nil)
+ val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
+ createSerializerForUserDefinedType(inputObject, udt, udtClass)
case t if UDTRegistration.exists(getClassNameFromType(t)) =>
val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
newInstance().asInstanceOf[UserDefinedType[_]]
- val obj = NewInstance(
- udt.getClass,
- Nil,
- dataType = ObjectType(udt.getClass))
- Invoke(obj, "serialize", udt, inputObject :: Nil)
+ val udtClass = udt.getClass
+ createSerializerForUserDefinedType(inputObject, udt, udtClass)
case t if definedByConstructorParams(t) =>
if (seenTypeSet.contains(t)) {
@@ -584,10 +510,10 @@ object ScalaReflection extends ScalaReflection {
}
val params = getConstructorParameters(t)
- val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+ val fields = params.map { case (fieldName, fieldType) =>
if (javaKeywords.contains(fieldName)) {
throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
- "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
+ "cannot be used as field name\n" + walkedTypePath)
}
// SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul
@@ -597,16 +523,14 @@ object ScalaReflection extends ScalaReflection {
val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(fieldType)
- val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
- expressions.Literal(fieldName) ::
- serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil
- })
- val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
- expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+ val newPath = walkedTypePath.recordField(clsName, fieldName)
+ (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
+ }
+ createSerializerForObject(inputObject, fields)
- case other =>
+ case _ =>
throw new UnsupportedOperationException(
- s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
+ s"No Encoder found for $tpe\n" + walkedTypePath)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
new file mode 100644
index 0000000..e035c4b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.objects._
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+object SerializerBuildHelper {
+
+ def createSerializerForBoolean(inputObject: Expression): Expression = {
+ Invoke(inputObject, "booleanValue", BooleanType)
+ }
+
+ def createSerializerForByte(inputObject: Expression): Expression = {
+ Invoke(inputObject, "byteValue", ByteType)
+ }
+
+ def createSerializerForShort(inputObject: Expression): Expression = {
+ Invoke(inputObject, "shortValue", ShortType)
+ }
+
+ def createSerializerForInteger(inputObject: Expression): Expression = {
+ Invoke(inputObject, "intValue", IntegerType)
+ }
+
+ def createSerializerForLong(inputObject: Expression): Expression = {
+ Invoke(inputObject, "longValue", LongType)
+ }
+
+ def createSerializerForFloat(inputObject: Expression): Expression = {
+ Invoke(inputObject, "floatValue", FloatType)
+ }
+
+ def createSerializerForDouble(inputObject: Expression): Expression = {
+ Invoke(inputObject, "doubleValue", DoubleType)
+ }
+
+ def createSerializerForString(inputObject: Expression): Expression = {
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForJavaInstant(inputObject: Expression): Expression = {
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ TimestampType,
+ "instantToMicros",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForSqlTimestamp(inputObject: Expression): Expression = {
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ TimestampType,
+ "fromJavaTimestamp",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForJavaLocalDate(inputObject: Expression): Expression = {
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ DateType,
+ "localDateToDays",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForSqlDate(inputObject: Expression): Expression = {
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ DateType,
+ "fromJavaDate",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
+ StaticInvoke(
+ Decimal.getClass,
+ DecimalType.SYSTEM_DEFAULT,
+ "apply",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = {
+ createSerializerForJavaBigDecimal(inputObject)
+ }
+
+ def createSerializerForJavaBigInteger(inputObject: Expression): Expression = {
+ StaticInvoke(
+ Decimal.getClass,
+ DecimalType.BigIntDecimal,
+ "apply",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForScalaBigInt(inputObject: Expression): Expression = {
+ createSerializerForJavaBigInteger(inputObject)
+ }
+
+ def createSerializerForPrimitiveArray(
+ inputObject: Expression,
+ dataType: DataType): Expression = {
+ StaticInvoke(
+ classOf[UnsafeArrayData],
+ ArrayType(dataType, false),
+ "fromPrimitiveArray",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForGenericArray(
+ inputObject: Expression,
+ dataType: DataType,
+ nullable: Boolean): Expression = {
+ NewInstance(
+ classOf[GenericArrayData],
+ inputObject :: Nil,
+ dataType = ArrayType(dataType, nullable))
+ }
+
+ def createSerializerForMapObjects(
+ inputObject: Expression,
+ dataType: ObjectType,
+ funcForNewExpr: Expression => Expression): Expression = {
+ MapObjects(funcForNewExpr, inputObject, dataType)
+ }
+
+ case class MapElementInformation(
+ dataType: DataType,
+ nullable: Boolean,
+ funcForNewExpr: Expression => Expression)
+
+ def createSerializerForMap(
+ inputObject: Expression,
+ keyInformation: MapElementInformation,
+ valueInformation: MapElementInformation): Expression = {
+ ExternalMapToCatalyst(
+ inputObject,
+ keyInformation.dataType,
+ keyInformation.funcForNewExpr,
+ keyNullable = keyInformation.nullable,
+ valueInformation.dataType,
+ valueInformation.funcForNewExpr,
+ valueNullable = valueInformation.nullable
+ )
+ }
+
+ private def argumentsForFieldSerializer(
+ fieldName: String,
+ serializerForFieldValue: Expression): Seq[Expression] = {
+ expressions.Literal(fieldName) :: serializerForFieldValue :: Nil
+ }
+
+ def createSerializerForObject(
+ inputObject: Expression,
+ fields: Seq[(String, Expression)]): Expression = {
+ val nonNullOutput = CreateNamedStruct(fields.flatMap { case(fieldName, fieldExpr) =>
+ argumentsForFieldSerializer(fieldName, fieldExpr)
+ })
+ val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+ expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+ }
+
+ def createSerializerForUserDefinedType(
+ inputObject: Expression,
+ udt: UserDefinedType[_],
+ udtClass: Class[_]): Expression = {
+ val obj = NewInstance(udtClass, Nil, dataType = ObjectType(udtClass))
+ Invoke(obj, "serialize", udt, inputObject :: Nil)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
new file mode 100644
index 0000000..cdb55b8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+/**
+ * This class records the paths the serializer and deserializer walk through to reach the
+ * current type. Note that each new path is prepended to the previously recorded paths, so
+ * the paths are kept in reverse order (most recently walked first).
+ */
+case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) extends Serializable {
+ def recordRoot(className: String): WalkedTypePath =
+ newInstance(s"""- root class: "$className"""")
+
+ def recordOption(className: String): WalkedTypePath =
+ newInstance(s"""- option value class: "$className"""")
+
+ def recordArray(elementClassName: String): WalkedTypePath =
+ newInstance(s"""- array element class: "$elementClassName"""")
+
+ def recordMap(keyClassName: String, valueClassName: String): WalkedTypePath = {
+ newInstance(s"""- map key class: "$keyClassName"""" +
+ s""", value class: "$valueClassName"""")
+ }
+
+ def recordKeyForMap(keyClassName: String): WalkedTypePath =
+ newInstance(s"""- map key class: "$keyClassName"""")
+
+ def recordValueForMap(valueClassName: String): WalkedTypePath =
+ newInstance(s"""- map value class: "$valueClassName"""")
+
+ def recordField(className: String, fieldName: String): WalkedTypePath =
+ newInstance(s"""- field (class: "$className", name: "$fieldName")""")
+
+ override def toString: String = {
+ walkedPaths.mkString("\n")
+ }
+
+ def getPaths: Seq[String] = walkedPaths
+
+ private def newInstance(newRecord: String): WalkedTypePath =
+ WalkedTypePath(newRecord +: walkedPaths)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 42904c5..ab9cedc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2348,7 +2348,7 @@ class Analyzer(
} else {
// always add an UpCast. it will be removed in the optimizer if it is unnecessary.
Some(Alias(
- UpCast(queryExpr, tableAttr.dataType, Seq()), tableAttr.name
+ UpCast(queryExpr, tableAttr.dataType), tableAttr.name
)(
explicitMetadata = Option(tableAttr.metadata)
))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index da5c1fd..abffda7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection, WalkedTypePath}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 68a603b..97709bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -155,7 +155,7 @@ object RowEncoder {
element => {
val value = serializerFor(ValidateExternalType(element, et), et)
if (!containsNull) {
- AssertNotNull(value, Seq.empty)
+ AssertNotNull(value)
} else {
value
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d591c58..84087ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
import java.util.concurrent.TimeUnit._
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, WalkedTypePath}
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -1378,7 +1378,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
* Cast the child expression to the target data type, but will throw error if the cast might
* truncate, e.g. long -> int, timestamp -> data.
*/
-case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String])
+case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String] = Nil)
extends UnaryExpression with Unevaluable {
override lazy val resolved = false
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index baa1b3b..7d49866 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -338,7 +338,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("should not apply common subexpression elimination on conditional expressions") {
val row = InternalRow(null)
val bound = BoundReference(0, IntegerType, true)
- val assertNotNull = AssertNotNull(bound, Nil)
+ val assertNotNull = AssertNotNull(bound)
val expr = If(IsNull(bound), Literal(1), Add(assertNotNull, assertNotNull))
val projection = GenerateUnsafeProjection.generate(
Seq(expr), subexpressionEliminationEnabled = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index b7ce367..49fd59c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -53,7 +53,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("AssertNotNUll") {
val ex = intercept[RuntimeException] {
- evaluateWithoutCodegen(AssertNotNull(Literal(null), Seq.empty[String]))
+ evaluateWithoutCodegen(AssertNotNull(Literal(null)))
}.getMessage
assert(ex.contains("Null value appeared in non-nullable field"))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org