You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/06/08 01:50:47 UTC
[spark] branch master updated: [SPARK-39321][SQL] Refactor TryCast to use RuntimeReplaceable
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 46175d13620 [SPARK-39321][SQL] Refactor TryCast to use RuntimeReplaceable
46175d13620 is described below
commit 46175d1362062035fb93f87f25d61a9b711359ab
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Wed Jun 8 09:50:08 2022 +0800
[SPARK-39321][SQL] Refactor TryCast to use RuntimeReplaceable
### What changes were proposed in this pull request?
This PR refactors `TryCast` to use `RuntimeReplaceable`, so that we don't need `CastBase` anymore. The unit tests are also simplified because we don't need to check the execution of `RuntimeReplaceable`, but only the analysis behavior.
### Why are the changes needed?
code cleanup
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests
Closes #36703 from cloud-fan/cast.
Authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/expressions/Cast.scala | 114 +++++------
.../spark/sql/catalyst/expressions/TryCast.scala | 122 ------------
.../spark/sql/catalyst/expressions/TryEval.scala | 86 +++++++-
.../spark/sql/catalyst/optimizer/expressions.scala | 3 +-
.../spark/sql/catalyst/parser/AstBuilder.scala | 10 +-
.../plans/logical/basicLogicalOperators.scala | 2 +-
.../apache/spark/sql/catalyst/util/package.scala | 2 +-
.../sql/catalyst/expressions/CastSuiteBase.scala | 217 ++++++++++-----------
...{CastSuite.scala => CastWithAnsiOffSuite.scala} | 20 +-
...stSuiteBase.scala => CastWithAnsiOnSuite.scala} | 149 +++++---------
.../sql/catalyst/expressions/TryCastSuite.scala | 67 +++++--
.../catalyst/optimizer/ConstantFoldingSuite.scala | 2 +-
.../apache/spark/sql/hive/client/HiveShim.scala | 4 +-
13 files changed, 347 insertions(+), 451 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 6ed25f5e45e..497261be2e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -425,29 +425,54 @@ object Cast {
}
}
-abstract class CastBase extends UnaryExpression
- with TimeZoneAwareExpression
- with NullIntolerant
- with SupportQueryContext {
+/**
+ * Cast the child expression to the target data type.
+ *
+ * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
+ * session local timezone by an analyzer [[ResolveTimeZone]].
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('10' as int);
+ 10
+ """,
+ since = "1.0.0",
+ group = "conversion_funcs")
+case class Cast(
+ child: Expression,
+ dataType: DataType,
+ timeZoneId: Option[String] = None,
+ ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends UnaryExpression
+ with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext {
- def child: Expression
+ def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
+ this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
- def dataType: DataType
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
- /**
- * Returns true iff we can cast `from` type to `to` type.
- */
- def canCast(from: DataType, to: DataType): Boolean
+ override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
- /**
- * Returns the error message if casting from one type to another one is invalid.
- */
- def typeCheckFailureMessage: String
+ final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
- override def toString: String = s"cast($child as ${dataType.simpleString})"
+ private def typeCheckFailureMessage: String = if (ansiEnabled) {
+ if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
+ Cast.typeCheckFailureMessage(child.dataType, dataType,
+ Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
+ } else {
+ Cast.typeCheckFailureMessage(child.dataType, dataType,
+ Some(SQLConf.ANSI_ENABLED.key -> "false"))
+ }
+ } else {
+ s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
+ }
override def checkInputDataTypes(): TypeCheckResult = {
- if (canCast(child.dataType, dataType)) {
+ if (ansiEnabled && Cast.canAnsiCast(child.dataType, dataType)) {
+ TypeCheckResult.TypeCheckSuccess
+ } else if (!ansiEnabled && Cast.canCast(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage)
@@ -456,8 +481,6 @@ abstract class CastBase extends UnaryExpression
override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)
- protected def ansiEnabled: Boolean
-
override def initQueryContext(): String = if (ansiEnabled) {
origin.context
} else {
@@ -470,7 +493,7 @@ abstract class CastBase extends UnaryExpression
childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)
override lazy val preCanonicalized: Expression = {
- val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[CastBase]
+ val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[Cast]
if (timeZoneId.isDefined && !needsTimeZone) {
basic.withTimeZone(null)
} else {
@@ -2246,6 +2269,8 @@ abstract class CastBase extends UnaryExpression
"""
}
+ override def toString: String = s"cast($child as ${dataType.simpleString})"
+
override def sql: String = dataType match {
// HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this
// type of casting can only be introduced by the analyzer, and can be omitted when converting
@@ -2255,57 +2280,6 @@ abstract class CastBase extends UnaryExpression
}
}
-/**
- * Cast the child expression to the target data type.
- *
- * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
- * session local timezone by an analyzer [[ResolveTimeZone]].
- */
-@ExpressionDescription(
- usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
- examples = """
- Examples:
- > SELECT _FUNC_('10' as int);
- 10
- """,
- since = "1.0.0",
- group = "conversion_funcs")
-case class Cast(
- child: Expression,
- dataType: DataType,
- timeZoneId: Option[String] = None,
- override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
- extends CastBase {
-
- def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
- this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
-
- override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
- copy(timeZoneId = Option(timeZoneId))
-
- final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
-
- override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
- Cast.canAnsiCast(from, to)
- } else {
- Cast.canCast(from, to)
- }
-
- override def typeCheckFailureMessage: String = if (ansiEnabled) {
- if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
- Cast.typeCheckFailureMessage(child.dataType, dataType,
- Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
- } else {
- Cast.typeCheckFailureMessage(child.dataType, dataType,
- Some(SQLConf.ANSI_ENABLED.key -> "false"))
- }
- } else {
- s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
- }
-
- override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
-}
-
/**
* Cast the child expression to the target data type, but will throw error if the cast might
* truncate, e.g. long -> int, timestamp -> date.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
deleted file mode 100644
index 9ac6329f281..00000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
-import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
-
-/**
- * A special version of [[AnsiCast]]. It performs the same operation (i.e. converts a value of
- * one data type into another data type), but returns a NULL value instead of raising an error
- * when the conversion can not be performed.
- *
- * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
- * session local timezone by an analyzer [[ResolveTimeZone]].
- */
-@ExpressionDescription(
- usage = """
- _FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
- This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
- true, except it returns NULL instead of raising an error. Note that the behavior of this
- expression doesn't depend on configuration `spark.sql.ansi.enabled`.
- """,
- examples = """
- Examples:
- > SELECT _FUNC_('10' as int);
- 10
- > SELECT _FUNC_(1234567890123L as int);
- null
- """,
- since = "3.2.0",
- group = "conversion_funcs")
-case class TryCast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
- extends CastBase {
- override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
- copy(timeZoneId = Option(timeZoneId))
-
- // Here we force `ansiEnabled` as true so that we can reuse the evaluation code branches which
- // throw exceptions on conversion failures.
- override protected val ansiEnabled: Boolean = true
-
- override def nullable: Boolean = true
-
- // If the target data type is a complex type which can't have Null values, we should guarantee
- // that the casting between the element types won't produce Null results.
- override def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
- case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
- canCast(fromType, toType) &&
- resolvableNullability(fn || forceNullable(fromType, toType), tn)
-
- case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
- canCast(fromKey, toKey) &&
- (!forceNullable(fromKey, toKey)) &&
- canCast(fromValue, toValue) &&
- resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
-
- case (StructType(fromFields), StructType(toFields)) =>
- fromFields.length == toFields.length &&
- fromFields.zip(toFields).forall {
- case (fromField, toField) =>
- canCast(fromField.dataType, toField.dataType) &&
- resolvableNullability(
- fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
- toField.nullable)
- }
-
- case _ =>
- Cast.canAnsiCast(from, to)
- }
-
- override def cast(from: DataType, to: DataType): Any => Any = (input: Any) =>
- try {
- super.cast(from, to)(input)
- } catch {
- case _: Exception =>
- null
- }
-
- override def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue,
- result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = {
- val javaType = JavaCode.javaType(resultType)
- code"""
- boolean $resultIsNull = $inputIsNull;
- $javaType $result = ${CodeGenerator.defaultValue(resultType)};
- if (!$inputIsNull) {
- try {
- ${cast(input, result, resultIsNull)}
- } catch (Exception e) {
- $resultIsNull = true;
- }
- }
- """
- }
-
- override def typeCheckFailureMessage: String =
- Cast.typeCheckFailureMessage(child.dataType, dataType, None)
-
- override protected def withNewChildInternal(newChild: Expression): TryCast =
- copy(child = newChild)
-
- override def toString: String = {
- s"try_cast($child as ${dataType.simpleString})"
- }
-
- override def sql: String = s"TRY_CAST(${child.sql} AS ${dataType.sql})"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
index c179c83befb..dc5bcae4c08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
@@ -18,9 +18,12 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -53,6 +56,87 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran
copy(child = newChild)
}
+/**
+ * A special version of [[Cast]] with ansi mode on. It performs the same operation (i.e. converts a
+ * value of one data type into another data type), but returns a NULL value instead of raising an
+ * error when the conversion can not be performed.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
+ This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
+ true, except it returns NULL instead of raising an error. Note that the behavior of this
+ expression doesn't depend on configuration `spark.sql.ansi.enabled`.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_('10' as int);
+ 10
+ > SELECT _FUNC_(1234567890123L as int);
+ null
+ """,
+ since = "3.2.0",
+ group = "conversion_funcs")
+case class TryCast(child: Expression, toType: DataType, timeZoneId: Option[String] = None)
+ extends UnaryExpression with RuntimeReplaceable with TimeZoneAwareExpression {
+
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
+ override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
+
+ // When this cast involves TimeZone, it's only resolved if the timeZoneId is set;
+ // Otherwise behave like Expression.resolved.
+ override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess &&
+ (!Cast.needsTimeZone(child.dataType, toType) || timeZoneId.isDefined)
+
+ override lazy val replacement = {
+ TryEval(Cast(child, toType, timeZoneId = timeZoneId, ansiEnabled = true))
+ }
+
+ // If the target data type is a complex type which can't have Null values, we should guarantee
+ // that the casting between the element types won't produce Null results.
+ private def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
+ case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+ canCast(fromType, toType) &&
+ resolvableNullability(fn || forceNullable(fromType, toType), tn)
+
+ case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+ canCast(fromKey, toKey) &&
+ (!forceNullable(fromKey, toKey)) &&
+ canCast(fromValue, toValue) &&
+ resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
+
+ case (StructType(fromFields), StructType(toFields)) =>
+ fromFields.length == toFields.length &&
+ fromFields.zip(toFields).forall {
+ case (fromField, toField) =>
+ canCast(fromField.dataType, toField.dataType) &&
+ resolvableNullability(
+ fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
+ toField.nullable)
+ }
+
+ case _ =>
+ Cast.canAnsiCast(from, to)
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (canCast(child.dataType, dataType)) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(Cast.typeCheckFailureMessage(child.dataType, toType, None))
+ }
+ }
+
+ override def toString: String = s"try_cast($child as ${dataType.simpleString})"
+
+ override def sql: String = s"TRY_CAST(${child.sql} AS ${dataType.sql})"
+
+ override protected def withNewChildInternal(newChild: Expression): Expression =
+ this.copy(child = newChild)
+}
+
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns the sum of `expr1`and `expr2` and the result is null on overflow. " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 62c328a29a8..3fc23c31ac7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -631,7 +631,8 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
true
- case _: CastBase => true
+ case _: Cast => true
+ case _: TryEval => true
case _: GetDateField | _: LastDay => true
case _: ExtractIntervalPart[_] => true
case _: ArraySetLike => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 7ae04010ad2..46847411bf0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1788,15 +1788,17 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
val rawDataType = typedVisit[DataType](ctx.dataType())
val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
- val cast = ctx.name.getType match {
+ ctx.name.getType match {
case SqlBaseParser.CAST =>
- Cast(expression(ctx.expression), dataType)
+ val cast = Cast(expression(ctx.expression), dataType)
+ cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
+ cast
case SqlBaseParser.TRY_CAST =>
+ // `TryCast` can only be user-specified and we don't need to set the USER_SPECIFIED_CAST
+ // tag, which is only used by `Cast`
TryCast(expression(ctx.expression), dataType)
}
- cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
- cast
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 677bdf27336..11d68294023 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -590,7 +590,7 @@ case class View(
// See more details in `SessionCatalog.fromCatalogTable`.
private def canRemoveProject(p: Project): Boolean = {
p.output.length == p.child.output.length && p.projectList.zip(p.child.output).forall {
- case (Alias(cast: CastBase, name), childAttr) =>
+ case (Alias(cast: Cast, name), childAttr) =>
cast.child match {
case a: AttributeReference =>
a.dataType == cast.dataType && a.name == name && childAttr.semanticEquals(a)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index e06072cbed2..f73fc7c6816 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -119,7 +119,7 @@ package object util extends Logging {
PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType)
case r: InheritAnalysisRules =>
PrettyAttribute(r.makeSQLString(r.parameters.map(toPrettySQL)), r.dataType)
- case c: CastBase if !c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) =>
+ case c: Cast if !c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) =>
PrettyAttribute(usePrettyExpression(c.child).sql, c.dataType)
case p: PythonUDF => PrettyPythonUDF(p.name, p.dataType, p.children)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index ba8ab708046..ca492e11226 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -43,11 +43,19 @@ import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
import org.apache.spark.unsafe.types.UTF8String
/**
- * Common test suite for [[Cast]], [[AnsiCast]] and [[TryCast]] expressions.
+ * Common test suite for [[Cast]] with ansi mode on and off. It only includes test cases that work
+ * for both ansi on and off.
*/
abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
- protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase
+ protected def ansiEnabled: Boolean
+
+ protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = {
+ v match {
+ case lit: Expression => Cast(lit, targetType, timeZoneId, ansiEnabled)
+ case _ => Cast(Literal(v), targetType, timeZoneId, ansiEnabled)
+ }
+ }
// expected cannot be null
protected def checkCast(v: Any, expected: Any): Unit = {
@@ -58,7 +66,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null)
}
- protected def verifyCastFailure(c: CastBase, optionalExpectedMsg: Option[String] = None): Unit = {
+ protected def verifyCastFailure(c: Cast, optionalExpectedMsg: Option[String] = None): Unit = {
val typeCheckResult = c.checkInputDataTypes()
assert(typeCheckResult.isFailure)
assert(typeCheckResult.isInstanceOf[TypeCheckFailure])
@@ -66,20 +74,15 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
if (optionalExpectedMsg.isDefined) {
assert(message.contains(optionalExpectedMsg.get))
- } else if (setConfigurationHint.nonEmpty) {
- assert(message.contains("with ANSI mode on"))
- assert(message.contains(setConfigurationHint))
} else {
assert("cannot cast [a-zA-Z]+ to [a-zA-Z]+".r.findFirstIn(message).isDefined)
+ if (ansiEnabled) {
+ assert(message.contains("with ANSI mode on"))
+ assert(message.contains(s"set ${SQLConf.ANSI_ENABLED.key} as false"))
+ }
}
}
- // Whether the test suite is for TryCast. If yes, there is no exceptions and the result is
- // always nullable.
- protected def isTryCast: Boolean = false
-
- protected def setConfigurationHint: String = ""
-
test("null cast") {
import DataTypeTestUtils._
@@ -281,8 +284,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast from string") {
- assert(cast("abcdef", StringType).nullable === isTryCast)
- assert(cast("abcdef", BinaryType).nullable === isTryCast)
+ assert(!cast("abcdef", StringType).nullable)
+ assert(!cast("abcdef", BinaryType).nullable)
assert(cast("abcdef", BooleanType).nullable)
assert(cast("abcdef", TimestampType).nullable)
assert(cast("abcdef", LongType).nullable)
@@ -981,14 +984,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
DayTimeIntervalType()), StringType), ansiInterval)
}
- if (!isTryCast) {
- Seq("INTERVAL '-106751991 04:00:54.775809' DAY TO SECOND",
- "INTERVAL '106751991 04:00:54.775808' DAY TO SECOND").foreach { interval =>
- val e = intercept[ArithmeticException] {
- cast(Literal.create(interval), DayTimeIntervalType()).eval()
- }.getMessage
- assert(e.contains("long overflow"))
- }
+ Seq("INTERVAL '-106751991 04:00:54.775809' DAY TO SECOND",
+ "INTERVAL '106751991 04:00:54.775808' DAY TO SECOND").foreach { interval =>
+ val e = intercept[ArithmeticException] {
+ cast(Literal.create(interval), DayTimeIntervalType()).eval()
+ }.getMessage
+ assert(e.contains("long overflow"))
}
Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue, Long.MinValue + 1,
@@ -1027,15 +1028,13 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
YearMonthIntervalType()), StringType), ansiInterval)
}
- if (!isTryCast) {
- Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH")
- .foreach { interval =>
- val e = intercept[IllegalArgumentException] {
- cast(Literal.create(interval), YearMonthIntervalType()).eval()
- }.getMessage
- assert(e.contains("Error parsing interval year-month string: integer overflow"))
- }
- }
+ Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH")
+ .foreach { interval =>
+ val e = intercept[IllegalArgumentException] {
+ cast(Literal.create(interval), YearMonthIntervalType()).eval()
+ }.getMessage
+ assert(e.contains("Error parsing interval year-month string: integer overflow"))
+ }
Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue)
.foreach { period =>
@@ -1098,9 +1097,27 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
}
- if (!isTryCast) {
- Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval =>
- val dataType = YearMonthIntervalType()
+ Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval =>
+ val dataType = YearMonthIntervalType()
+ val e = intercept[IllegalArgumentException] {
+ cast(Literal.create(interval), dataType).eval()
+ }.getMessage
+ assert(e.contains(s"Interval string does not match year-month format of " +
+ s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+ .map(format => s"`$format`").mkString(", ")} " +
+ s"when cast to ${dataType.typeName}: $interval"))
+ }
+ Seq(("1", YearMonthIntervalType(YEAR, MONTH)),
+ ("1", YearMonthIntervalType(YEAR, MONTH)),
+ ("1-1", YearMonthIntervalType(YEAR)),
+ ("1-1", YearMonthIntervalType(MONTH)),
+ ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(YEAR)),
+ ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(MONTH)),
+ ("INTERVAL '1' YEAR", YearMonthIntervalType(YEAR, MONTH)),
+ ("INTERVAL '1' YEAR", YearMonthIntervalType(MONTH)),
+ ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR)),
+ ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR, MONTH)))
+ .foreach { case (interval, dataType) =>
val e = intercept[IllegalArgumentException] {
cast(Literal.create(interval), dataType).eval()
}.getMessage
@@ -1109,26 +1126,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${dataType.typeName}: $interval"))
}
- Seq(("1", YearMonthIntervalType(YEAR, MONTH)),
- ("1", YearMonthIntervalType(YEAR, MONTH)),
- ("1-1", YearMonthIntervalType(YEAR)),
- ("1-1", YearMonthIntervalType(MONTH)),
- ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(YEAR)),
- ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(MONTH)),
- ("INTERVAL '1' YEAR", YearMonthIntervalType(YEAR, MONTH)),
- ("INTERVAL '1' YEAR", YearMonthIntervalType(MONTH)),
- ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR)),
- ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR, MONTH)))
- .foreach { case (interval, dataType) =>
- val e = intercept[IllegalArgumentException] {
- cast(Literal.create(interval), dataType).eval()
- }.getMessage
- assert(e.contains(s"Interval string does not match year-month format of " +
- s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${dataType.typeName}: $interval"))
- }
- }
}
test("SPARK-35735: Take into account day-time interval fields in cast") {
@@ -1218,63 +1215,61 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(Literal.create(interval), dataType), dt)
}
- if (!isTryCast) {
- Seq(
- ("INTERVAL '1 01:01:01.12345' DAY TO SECOND", DayTimeIntervalType(DAY, HOUR)),
- ("INTERVAL '1 01:01:01.12345' DAY TO HOUR", DayTimeIntervalType(DAY, SECOND)),
- ("INTERVAL '1 01:01:01.12345' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
- ("1 01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
- ("1 01:01:01.12345", DayTimeIntervalType(DAY, HOUR)),
- ("1 01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),
-
- ("INTERVAL '01:01:01.12345' HOUR TO SECOND", DayTimeIntervalType(DAY, HOUR)),
- ("INTERVAL '01:01:01.12345' HOUR TO HOUR", DayTimeIntervalType(DAY, SECOND)),
- ("INTERVAL '01:01:01.12345' HOUR TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
- ("01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
- ("01:01:01.12345", DayTimeIntervalType(HOUR, HOUR)),
- ("01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),
- ("INTERVAL '1.23' DAY", DayTimeIntervalType(DAY)),
- ("INTERVAL '1.23' HOUR", DayTimeIntervalType(HOUR)),
- ("INTERVAL '1.23' MINUTE", DayTimeIntervalType(MINUTE)),
- ("INTERVAL '1.23' SECOND", DayTimeIntervalType(MINUTE)),
- ("1.23", DayTimeIntervalType(DAY)),
- ("1.23", DayTimeIntervalType(HOUR)),
- ("1.23", DayTimeIntervalType(MINUTE)),
- ("1.23", DayTimeIntervalType(MINUTE)))
- .foreach { case (interval, dataType) =>
- val e = intercept[IllegalArgumentException] {
- cast(Literal.create(interval), dataType).eval()
- }.getMessage
- assert(e.contains(s"Interval string does not match day-time format of " +
- s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${dataType.typeName}: $interval, " +
- s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
- "to restore the behavior before Spark 3.0."))
- }
+ Seq(
+ ("INTERVAL '1 01:01:01.12345' DAY TO SECOND", DayTimeIntervalType(DAY, HOUR)),
+ ("INTERVAL '1 01:01:01.12345' DAY TO HOUR", DayTimeIntervalType(DAY, SECOND)),
+ ("INTERVAL '1 01:01:01.12345' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
+ ("1 01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
+ ("1 01:01:01.12345", DayTimeIntervalType(DAY, HOUR)),
+ ("1 01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),
+
+ ("INTERVAL '01:01:01.12345' HOUR TO SECOND", DayTimeIntervalType(DAY, HOUR)),
+ ("INTERVAL '01:01:01.12345' HOUR TO HOUR", DayTimeIntervalType(DAY, SECOND)),
+ ("INTERVAL '01:01:01.12345' HOUR TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
+ ("01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
+ ("01:01:01.12345", DayTimeIntervalType(HOUR, HOUR)),
+ ("01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),
+ ("INTERVAL '1.23' DAY", DayTimeIntervalType(DAY)),
+ ("INTERVAL '1.23' HOUR", DayTimeIntervalType(HOUR)),
+ ("INTERVAL '1.23' MINUTE", DayTimeIntervalType(MINUTE)),
+ ("INTERVAL '1.23' SECOND", DayTimeIntervalType(MINUTE)),
+ ("1.23", DayTimeIntervalType(DAY)),
+ ("1.23", DayTimeIntervalType(HOUR)),
+ ("1.23", DayTimeIntervalType(MINUTE)),
+ ("1.23", DayTimeIntervalType(MINUTE)))
+ .foreach { case (interval, dataType) =>
+ val e = intercept[IllegalArgumentException] {
+ cast(Literal.create(interval), dataType).eval()
+ }.getMessage
+ assert(e.contains(s"Interval string does not match day-time format of " +
+ s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+ .map(format => s"`$format`").mkString(", ")} " +
+ s"when cast to ${dataType.typeName}: $interval, " +
+ s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
+ "to restore the behavior before Spark 3.0."))
+ }
- // Check first field outof bound
- Seq(("INTERVAL '1067519911' DAY", DayTimeIntervalType(DAY)),
- ("INTERVAL '10675199111 04' DAY TO HOUR", DayTimeIntervalType(DAY, HOUR)),
- ("INTERVAL '1067519911 04:00' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
- ("INTERVAL '1067519911 04:00:54.775807' DAY TO SECOND", DayTimeIntervalType()),
- ("INTERVAL '25620477881' HOUR", DayTimeIntervalType(HOUR)),
- ("INTERVAL '25620477881:00' HOUR TO MINUTE", DayTimeIntervalType(HOUR, MINUTE)),
- ("INTERVAL '25620477881:00:54.775807' HOUR TO SECOND", DayTimeIntervalType(HOUR, SECOND)),
- ("INTERVAL '1537228672801' MINUTE", DayTimeIntervalType(MINUTE)),
- ("INTERVAL '1537228672801:54.7757' MINUTE TO SECOND", DayTimeIntervalType(MINUTE, SECOND)),
- ("INTERVAL '92233720368541.775807' SECOND", DayTimeIntervalType(SECOND)))
- .foreach { case (interval, dataType) =>
- val e = intercept[IllegalArgumentException] {
- cast(Literal.create(interval), dataType).eval()
- }.getMessage
- assert(e.contains(s"Interval string does not match day-time format of " +
- s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
- .map(format => s"`$format`").mkString(", ")} " +
- s"when cast to ${dataType.typeName}: $interval, " +
- s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
- "to restore the behavior before Spark 3.0."))
- }
- }
+ // Check first field out of bound
+ Seq(("INTERVAL '1067519911' DAY", DayTimeIntervalType(DAY)),
+ ("INTERVAL '10675199111 04' DAY TO HOUR", DayTimeIntervalType(DAY, HOUR)),
+ ("INTERVAL '1067519911 04:00' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
+ ("INTERVAL '1067519911 04:00:54.775807' DAY TO SECOND", DayTimeIntervalType()),
+ ("INTERVAL '25620477881' HOUR", DayTimeIntervalType(HOUR)),
+ ("INTERVAL '25620477881:00' HOUR TO MINUTE", DayTimeIntervalType(HOUR, MINUTE)),
+ ("INTERVAL '25620477881:00:54.775807' HOUR TO SECOND", DayTimeIntervalType(HOUR, SECOND)),
+ ("INTERVAL '1537228672801' MINUTE", DayTimeIntervalType(MINUTE)),
+ ("INTERVAL '1537228672801:54.7757' MINUTE TO SECOND", DayTimeIntervalType(MINUTE, SECOND)),
+ ("INTERVAL '92233720368541.775807' SECOND", DayTimeIntervalType(SECOND)))
+ .foreach { case (interval, dataType) =>
+ val e = intercept[IllegalArgumentException] {
+ cast(Literal.create(interval), dataType).eval()
+ }.getMessage
+ assert(e.contains(s"Interval string does not match day-time format of " +
+ s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+ .map(format => s"`$format`").mkString(", ")} " +
+ s"when cast to ${dataType.typeName}: $interval, " +
+ s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
+ "to restore the behavior before Spark 3.0."))
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala
similarity index 98%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala
index 630c45adba1..4e4bc096dea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala
@@ -35,26 +35,10 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* Test suite for data type casting expression [[Cast]] with ANSI mode disabled.
- * Note: for new test cases that work for [[Cast]], [[AnsiCast]] and [[TryCast]], please add them
- * in `CastSuiteBase` instead of this file to ensure the test coverage.
*/
-class CastSuite extends CastSuiteBase {
- override def beforeAll(): Unit = {
- super.beforeAll()
- SQLConf.get.setConf(SQLConf.ANSI_ENABLED, false)
- }
-
- override def afterAll(): Unit = {
- super.afterAll()
- SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
- }
+class CastWithAnsiOffSuite extends CastSuiteBase {
- override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
- v match {
- case lit: Expression => Cast(lit, targetType, timeZoneId)
- case _ => Cast(Literal(v), targetType, timeZoneId)
- }
- }
+ override def ansiEnabled: Boolean = false
test("null cast #2") {
import DataTypeTestUtils._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
similarity index 85%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
index 84f0d5c59aa..f2cfc529984 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
@@ -27,19 +27,15 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
- * Test suite base for
- * 1. [[Cast]] with ANSI mode enabled
- * 2. [[AnsiCast]]
- * 3. [[TryCast]]
- * Note: for new test cases that work for [[Cast]], [[AnsiCast]] and [[TryCast]], please add them
- * in `CastSuiteBase` instead of this file to ensure the test coverage.
+ * Test suite for data type casting expression [[Cast]] with ANSI mode enabled.
*/
-abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
+class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase {
+
+ override def ansiEnabled: Boolean = true
private def testIntMaxAndMin(dt: DataType): Unit = {
assert(Seq(IntegerType, ShortType, ByteType).contains(dt))
@@ -339,25 +335,21 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
- assert(ret.resolved == !isTryCast)
- if (!isTryCast) {
- checkExceptionInExpression[SparkRuntimeException](
- ret, """cannot be cast to "BOOLEAN"""")
- }
+ assert(ret.resolved)
+ checkExceptionInExpression[SparkRuntimeException](
+ ret, """cannot be cast to "BOOLEAN"""")
}
}
test("cast from array III") {
- if (!isTryCast) {
- val from: DataType = ArrayType(DoubleType, containsNull = false)
- val array = Literal.create(Seq(1.0, 2.0), from)
- val to: DataType = ArrayType(IntegerType, containsNull = false)
- val answer = Literal.create(Seq(1, 2), to).value
- checkEvaluation(cast(array, to), answer)
+ val from: DataType = ArrayType(DoubleType, containsNull = false)
+ val array = Literal.create(Seq(1.0, 2.0), from)
+ val to: DataType = ArrayType(IntegerType, containsNull = false)
+ val answer = Literal.create(Seq(1, 2), to).value
+ checkEvaluation(cast(array, to), answer)
- val overflowArray = Literal.create(Seq(Int.MaxValue + 1.0D), from)
- checkExceptionInExpression[ArithmeticException](cast(overflowArray, to), "overflow")
- }
+ val overflowArray = Literal.create(Seq(Int.MaxValue + 1.0D), from)
+ checkExceptionInExpression[ArithmeticException](cast(overflowArray, to), "overflow")
}
test("cast from map II") {
@@ -386,48 +378,40 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
{
val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
- assert(ret.resolved == !isTryCast)
- if (!isTryCast) {
- checkExceptionInExpression[NumberFormatException](
- ret,
- castErrMsg("a", IntegerType))
- }
+ assert(ret.resolved)
+ checkExceptionInExpression[NumberFormatException](
+ ret,
+ castErrMsg("a", IntegerType))
}
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
- assert(ret.resolved == !isTryCast)
- if (!isTryCast) {
- checkExceptionInExpression[SparkRuntimeException](
- ret,
- castErrMsg("123", BooleanType))
- }
+ assert(ret.resolved)
+ checkExceptionInExpression[SparkRuntimeException](
+ ret,
+ castErrMsg("123", BooleanType))
}
{
val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
- assert(ret.resolved == !isTryCast)
- if (!isTryCast) {
- checkExceptionInExpression[NumberFormatException](
- ret,
- castErrMsg("a", IntegerType))
- }
+ assert(ret.resolved)
+ checkExceptionInExpression[NumberFormatException](
+ ret,
+ castErrMsg("a", IntegerType))
}
}
test("cast from map III") {
- if (!isTryCast) {
- val from: DataType = MapType(DoubleType, DoubleType, valueContainsNull = false)
- val map = Literal.create(Map(1.0 -> 2.0), from)
- val to: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false)
- val answer = Literal.create(Map(1 -> 2), to).value
- checkEvaluation(cast(map, to), answer)
-
- Seq(
- Literal.create(Map((Int.MaxValue + 1.0) -> 2.0), from),
- Literal.create(Map(1.0 -> (Int.MinValue - 1.0)), from)).foreach { overflowMap =>
- checkExceptionInExpression[ArithmeticException](cast(overflowMap, to), "overflow")
- }
+ val from: DataType = MapType(DoubleType, DoubleType, valueContainsNull = false)
+ val map = Literal.create(Map(1.0 -> 2.0), from)
+ val to: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false)
+ val answer = Literal.create(Map(1 -> 2), to).value
+ checkEvaluation(cast(map, to), answer)
+
+ Seq(
+ Literal.create(Map((Int.MaxValue + 1.0) -> 2.0), from),
+ Literal.create(Map(1.0 -> (Int.MinValue - 1.0)), from)).foreach { overflowMap =>
+ checkExceptionInExpression[ArithmeticException](cast(overflowMap, to), "overflow")
}
}
@@ -487,26 +471,22 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
StructField("a", BooleanType, nullable = true),
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = false))))
- assert(ret.resolved == !isTryCast)
- if (!isTryCast) {
- checkExceptionInExpression[SparkRuntimeException](
- ret,
- castErrMsg("123", BooleanType))
- }
+ assert(ret.resolved)
+ checkExceptionInExpression[SparkRuntimeException](
+ ret,
+ castErrMsg("123", BooleanType))
}
}
test("cast from struct III") {
- if (!isTryCast) {
- val from: DataType = StructType(Seq(StructField("a", DoubleType, nullable = false)))
- val struct = Literal.create(InternalRow(1.0), from)
- val to: DataType = StructType(Seq(StructField("a", IntegerType, nullable = false)))
- val answer = Literal.create(InternalRow(1), to).value
- checkEvaluation(cast(struct, to), answer)
+ val from: DataType = StructType(Seq(StructField("a", DoubleType, nullable = false)))
+ val struct = Literal.create(InternalRow(1.0), from)
+ val to: DataType = StructType(Seq(StructField("a", IntegerType, nullable = false)))
+ val answer = Literal.create(InternalRow(1), to).value
+ checkEvaluation(cast(struct, to), answer)
- val overflowStruct = Literal.create(InternalRow(Int.MaxValue + 1.0), from)
- checkExceptionInExpression[ArithmeticException](cast(overflowStruct, to), "overflow")
- }
+ val overflowStruct = Literal.create(InternalRow(Int.MaxValue + 1.0), from)
+ checkExceptionInExpression[ArithmeticException](cast(overflowStruct, to), "overflow")
}
test("complex casting") {
@@ -533,12 +513,10 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
StructType(Seq(
StructField("l", LongType, nullable = true)))))))
- assert(ret.resolved === !isTryCast)
- if (!isTryCast) {
- checkExceptionInExpression[NumberFormatException](
- ret,
- castErrMsg("true", IntegerType))
- }
+ assert(ret.resolved)
+ checkExceptionInExpression[NumberFormatException](
+ ret,
+ castErrMsg("true", IntegerType))
}
test("ANSI mode: cast string to timestamp with parse error") {
@@ -599,28 +577,3 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
}
}
}
-
-/**
- * Test suite for data type casting expression [[Cast]] with ANSI mode disabled.
- */
-class CastSuiteWithAnsiModeOn extends AnsiCastSuiteBase {
- override def beforeAll(): Unit = {
- super.beforeAll()
- SQLConf.get.setConf(SQLConf.ANSI_ENABLED, true)
- }
-
- override def afterAll(): Unit = {
- super.afterAll()
- SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
- }
-
- override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
- v match {
- case lit: Expression => Cast(lit, targetType, timeZoneId)
- case _ => Cast(Literal(v), targetType, timeZoneId)
- }
- }
-
- override def setConfigurationHint: String =
- s"set ${SQLConf.ANSI_ENABLED.key} as false"
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
index bb9ab888947..bb66a9fd24a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
@@ -17,40 +17,65 @@
package org.apache.spark.sql.catalyst.expressions
-import scala.reflect.ClassTag
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC_OPT
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+// A test suite to check analysis behaviors of `TryCast`.
+class TryCastSuite extends SparkFunSuite {
-class TryCastSuite extends AnsiCastSuiteBase {
- override protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String]) = {
+ private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): TryCast = {
v match {
case lit: Expression => TryCast(lit, targetType, timeZoneId)
case _ => TryCast(Literal(v), targetType, timeZoneId)
}
}
- override def isTryCast: Boolean = true
-
- override protected def setConfigurationHint: String = ""
-
- override def checkExceptionInExpression[T <: Throwable : ClassTag](
- expression: => Expression,
- inputRow: InternalRow,
- expectedErrMsg: String): Unit = {
- checkEvaluation(expression, null, inputRow)
+ test("print string") {
+ assert(TryCast(Literal("1"), IntegerType).toString == "try_cast(1 as int)")
+ assert(TryCast(Literal("1"), IntegerType).sql == "TRY_CAST('1' AS INT)")
}
- override def checkCastToBooleanError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
- checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
+ test("nullability") {
+ assert(cast("abcdef", StringType).nullable)
+ assert(cast("abcdef", BinaryType).nullable)
}
- override def checkCastToNumericError(l: Literal, to: DataType,
- expectedDataTypeInErrorMsg: DataType, tryCastResult: Any): Unit = {
- checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
+ test("only require timezone for datetime types") {
+ assert(cast("abc", IntegerType).resolved)
+ assert(!cast("abc", TimestampType).resolved)
+ assert(cast("abc", TimestampType, UTC_OPT).resolved)
}
- test("try_cast: to_string") {
- assert(TryCast(Literal("1"), IntegerType).toString == "try_cast(1 as int)")
+ test("element type nullability") {
+ val array = Literal.create(Seq("123", "true"),
+ ArrayType(StringType, containsNull = false))
+ // array element can be null after try_cast which violates the target type.
+ val c1 = cast(array, ArrayType(BooleanType, containsNull = false))
+ assert(!c1.resolved)
+
+ val map = Literal.create(Map("a" -> "123", "b" -> "true"),
+ MapType(StringType, StringType, valueContainsNull = false))
+ // key can be null after try_cast which violates the map key requirement.
+ val c2 = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
+ assert(!c2.resolved)
+ // map value can be null after try_cast which violates the target type.
+ val c3 = cast(map, MapType(StringType, IntegerType, valueContainsNull = false))
+ assert(!c3.resolved)
+
+ val struct = Literal.create(
+ InternalRow(
+ UTF8String.fromString("123"),
+ UTF8String.fromString("true")),
+ new StructType()
+ .add("a", StringType, nullable = true)
+ .add("b", StringType, nullable = true))
+ // struct field `b` can be null after try_cast which violates the target type.
+ val c4 = cast(struct, new StructType()
+ .add("a", BooleanType, nullable = true)
+ .add("b", BooleanType, nullable = false))
+ assert(!c4.resolved)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index dc92cf24ab1..e1d1f064e34 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -146,7 +146,7 @@ class ConstantFoldingSuite extends PlanTest {
testRelation
.select(
Cast(Literal("2"), IntegerType) + Literal(3) + $"a" as "c1",
- Coalesce(Seq(TryCast(Literal("abc"), IntegerType), Literal(3))) as "c2")
+ Coalesce(Seq(TryCast(Literal("abc"), IntegerType).replacement, Literal(3))) as "c2")
val optimized = Optimize.execute(originalQuery.analyze)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 67bb72c1878..95e5582cb8c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -1146,8 +1146,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
// client-side filtering cannot be used with TimeZoneAwareExpression.
def hasTimeZoneAwareExpression(e: Expression): Boolean = {
e.exists {
- case cast: CastBase => cast.needsTimeZone
- case tz: TimeZoneAwareExpression => !tz.isInstanceOf[CastBase]
+ case cast: Cast => cast.needsTimeZone
+ case tz: TimeZoneAwareExpression => !tz.isInstanceOf[Cast]
case _ => false
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org