Posted to commits@spark.apache.org by we...@apache.org on 2022/06/08 01:50:47 UTC

[spark] branch master updated: [SPARK-39321][SQL] Refactor TryCast to use RuntimeReplaceable

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 46175d13620 [SPARK-39321][SQL] Refactor TryCast to use RuntimeReplaceable
46175d13620 is described below

commit 46175d1362062035fb93f87f25d61a9b711359ab
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Wed Jun 8 09:50:08 2022 +0800

    [SPARK-39321][SQL] Refactor TryCast to use RuntimeReplaceable
    
    ### What changes were proposed in this pull request?
    
    This PR refactors `TryCast` to use `RuntimeReplaceable`, so that we no longer need `CastBase`. The unit tests are also simplified: for a `RuntimeReplaceable` expression we only need to check the analysis behavior, not the execution.
    
    ### Why are the changes needed?
    
    Code cleanup.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #36703 from cloud-fan/cast.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 114 +++++------
 .../spark/sql/catalyst/expressions/TryCast.scala   | 122 ------------
 .../spark/sql/catalyst/expressions/TryEval.scala   |  86 +++++++-
 .../spark/sql/catalyst/optimizer/expressions.scala |   3 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  10 +-
 .../plans/logical/basicLogicalOperators.scala      |   2 +-
 .../apache/spark/sql/catalyst/util/package.scala   |   2 +-
 .../sql/catalyst/expressions/CastSuiteBase.scala   | 217 ++++++++++-----------
 ...{CastSuite.scala => CastWithAnsiOffSuite.scala} |  20 +-
 ...stSuiteBase.scala => CastWithAnsiOnSuite.scala} | 149 +++++---------
 .../sql/catalyst/expressions/TryCastSuite.scala    |  67 +++++--
 .../catalyst/optimizer/ConstantFoldingSuite.scala  |   2 +-
 .../apache/spark/sql/hive/client/HiveShim.scala    |   4 +-
 13 files changed, 347 insertions(+), 451 deletions(-)
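
For orientation: try_cast keeps ANSI cast semantics but returns NULL where an
ANSI cast would raise an error. A minimal sketch of the user-visible behavior,
assuming a running SparkSession bound to `spark`:

    spark.sql("SET spark.sql.ansi.enabled = true")
    spark.sql("SELECT try_cast('abc' AS INT)").show()  // prints NULL, no error
    // spark.sql("SELECT cast('abc' AS INT)").show()   // would fail under ANSI mode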

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 6ed25f5e45e..497261be2e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -425,29 +425,54 @@ object Cast {
     }
 }
 
-abstract class CastBase extends UnaryExpression
-    with TimeZoneAwareExpression
-    with NullIntolerant
-    with SupportQueryContext {
+/**
+ * Cast the child expression to the target data type.
+ *
+ * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
+ * session local timezone by an analyzer [[ResolveTimeZone]].
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('10' as int);
+       10
+  """,
+  since = "1.0.0",
+  group = "conversion_funcs")
+case class Cast(
+    child: Expression,
+    dataType: DataType,
+    timeZoneId: Option[String] = None,
+    ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends UnaryExpression
+  with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext {
 
-  def child: Expression
+  def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
+    this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
 
-  def dataType: DataType
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
 
-  /**
-   * Returns true iff we can cast `from` type to `to` type.
-   */
-  def canCast(from: DataType, to: DataType): Boolean
+  override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
 
-  /**
-   * Returns the error message if casting from one type to another one is invalid.
-   */
-  def typeCheckFailureMessage: String
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
 
-  override def toString: String = s"cast($child as ${dataType.simpleString})"
+  private def typeCheckFailureMessage: String = if (ansiEnabled) {
+    if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
+      Cast.typeCheckFailureMessage(child.dataType, dataType,
+        Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
+    } else {
+      Cast.typeCheckFailureMessage(child.dataType, dataType,
+        Some(SQLConf.ANSI_ENABLED.key -> "false"))
+    }
+  } else {
+    s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
+  }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (canCast(child.dataType, dataType)) {
+    if (ansiEnabled && Cast.canAnsiCast(child.dataType, dataType)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else if (!ansiEnabled && Cast.canCast(child.dataType, dataType)) {
       TypeCheckResult.TypeCheckSuccess
     } else {
       TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage)
@@ -456,8 +481,6 @@ abstract class CastBase extends UnaryExpression
 
   override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)
 
-  protected def ansiEnabled: Boolean
-
   override def initQueryContext(): String = if (ansiEnabled) {
     origin.context
   } else {
@@ -470,7 +493,7 @@ abstract class CastBase extends UnaryExpression
     childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)
 
   override lazy val preCanonicalized: Expression = {
-    val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[CastBase]
+    val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[Cast]
     if (timeZoneId.isDefined && !needsTimeZone) {
       basic.withTimeZone(null)
     } else {
@@ -2246,6 +2269,8 @@ abstract class CastBase extends UnaryExpression
       """
   }
 
+  override def toString: String = s"cast($child as ${dataType.simpleString})"
+
   override def sql: String = dataType match {
     // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this
     // type of casting can only be introduced by the analyzer, and can be omitted when converting
@@ -2255,57 +2280,6 @@ abstract class CastBase extends UnaryExpression
   }
 }
 
-/**
- * Cast the child expression to the target data type.
- *
- * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
- * session local timezone by an analyzer [[ResolveTimeZone]].
- */
-@ExpressionDescription(
-  usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_('10' as int);
-       10
-  """,
-  since = "1.0.0",
-  group = "conversion_funcs")
-case class Cast(
-    child: Expression,
-    dataType: DataType,
-    timeZoneId: Option[String] = None,
-    override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
-  extends CastBase {
-
-  def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
-    this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
-
-  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
-    copy(timeZoneId = Option(timeZoneId))
-
-  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
-
-  override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
-    Cast.canAnsiCast(from, to)
-  } else {
-    Cast.canCast(from, to)
-  }
-
-  override def typeCheckFailureMessage: String = if (ansiEnabled) {
-    if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
-      Cast.typeCheckFailureMessage(child.dataType, dataType,
-        Some(SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
-    } else {
-      Cast.typeCheckFailureMessage(child.dataType, dataType,
-        Some(SQLConf.ANSI_ENABLED.key -> "false"))
-    }
-  } else {
-    s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
-  }
-
-  override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
-}
-
 /**
 * Cast the child expression to the target data type, but will throw an error if the cast might
 * truncate, e.g. long -> int, timestamp -> date.
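
With CastBase folded in, ANSI and legacy casting are now just two
configurations of the single Cast class above. A short sketch against these
catalyst internals (an internal API, not a stable surface):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.IntegerType

    val lenient = Cast(Literal("abc"), IntegerType, timeZoneId = None, ansiEnabled = false)
    assert(lenient.eval() == null)  // legacy cast turns invalid input into NULL

    val strict = Cast(Literal("abc"), IntegerType, timeZoneId = None, ansiEnabled = true)
    // strict.eval() would throw, matching spark.sql.ansi.enabled=true behavior
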
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
deleted file mode 100644
index 9ac6329f281..00000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
-import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
-
-/**
- * A special version of [[AnsiCast]]. It performs the same operation (i.e. converts a value of
- * one data type into another data type), but returns a NULL value instead of raising an error
- * when the conversion can not be performed.
- *
- * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
- * session local timezone by an analyzer [[ResolveTimeZone]].
- */
-@ExpressionDescription(
-  usage = """
-    _FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
-      This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
-      true, except it returns NULL instead of raising an error. Note that the behavior of this
-      expression doesn't depend on configuration `spark.sql.ansi.enabled`.
-  """,
-  examples = """
-    Examples:
-      > SELECT _FUNC_('10' as int);
-       10
-      > SELECT _FUNC_(1234567890123L as int);
-       null
-  """,
-  since = "3.2.0",
-  group = "conversion_funcs")
-case class TryCast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
-  extends CastBase {
-  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
-    copy(timeZoneId = Option(timeZoneId))
-
-  // Here we force `ansiEnabled` as true so that we can reuse the evaluation code branches which
-  // throw exceptions on conversion failures.
-  override protected val ansiEnabled: Boolean = true
-
-  override def nullable: Boolean = true
-
-  // If the target data type is a complex type which can't have Null values, we should guarantee
-  // that the casting between the element types won't produce Null results.
-  override def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
-    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
-      canCast(fromType, toType) &&
-        resolvableNullability(fn || forceNullable(fromType, toType), tn)
-
-    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
-      canCast(fromKey, toKey) &&
-        (!forceNullable(fromKey, toKey)) &&
-        canCast(fromValue, toValue) &&
-        resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
-
-    case (StructType(fromFields), StructType(toFields)) =>
-      fromFields.length == toFields.length &&
-        fromFields.zip(toFields).forall {
-          case (fromField, toField) =>
-            canCast(fromField.dataType, toField.dataType) &&
-              resolvableNullability(
-                fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
-                toField.nullable)
-        }
-
-    case _ =>
-      Cast.canAnsiCast(from, to)
-  }
-
-  override def cast(from: DataType, to: DataType): Any => Any = (input: Any) =>
-    try {
-      super.cast(from, to)(input)
-    } catch {
-      case _: Exception =>
-        null
-    }
-
-  override def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue,
-    result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = {
-    val javaType = JavaCode.javaType(resultType)
-    code"""
-      boolean $resultIsNull = $inputIsNull;
-      $javaType $result = ${CodeGenerator.defaultValue(resultType)};
-      if (!$inputIsNull) {
-        try {
-          ${cast(input, result, resultIsNull)}
-        } catch (Exception e) {
-          $resultIsNull = true;
-        }
-      }
-    """
-  }
-
-  override def typeCheckFailureMessage: String =
-    Cast.typeCheckFailureMessage(child.dataType, dataType, None)
-
-  override protected def withNewChildInternal(newChild: Expression): TryCast =
-    copy(child = newChild)
-
-  override def toString: String = {
-    s"try_cast($child as ${dataType.simpleString})"
-  }
-
-  override def sql: String = s"TRY_CAST(${child.sql} AS ${dataType.sql})"
-}
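
The hand-written try/catch above, in both the interpreted `cast` override and
the `castCode` codegen path, is what this commit deletes; TryEval in the next
file centralizes the same null-on-exception contract. A simplified stand-alone
model of that contract (plain Scala with a hypothetical `tryEval` helper, not
Spark's actual classes):

    // Evaluate a thunk, turning any exception into an absent value, the way
    // TryEval turns evaluation errors into NULL.
    def tryEval[T](thunk: => T): Option[T] =
      try Some(thunk) catch { case _: Exception => None }

    assert(tryEval("10".toInt).contains(10))
    assert(tryEval("abc".toInt).isEmpty)  // NumberFormatException becomes None
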
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
index c179c83befb..dc5bcae4c08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
@@ -18,9 +18,12 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvableNullability}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
 
 case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -53,6 +56,87 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran
     copy(child = newChild)
 }
 
+/**
+ * A special version of [[Cast]] with ansi mode on. It performs the same operation (i.e. converts a
+ * value of one data type into another data type), but returns a NULL value instead of raising an
+ * error when the conversion can not be performed.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.
+      This expression is identical to CAST with configuration `spark.sql.ansi.enabled` as
+      true, except it returns NULL instead of raising an error. Note that the behavior of this
+      expression doesn't depend on configuration `spark.sql.ansi.enabled`.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('10' as int);
+       10
+      > SELECT _FUNC_(1234567890123L as int);
+       null
+  """,
+  since = "3.2.0",
+  group = "conversion_funcs")
+case class TryCast(child: Expression, toType: DataType, timeZoneId: Option[String] = None)
+  extends UnaryExpression with RuntimeReplaceable with TimeZoneAwareExpression {
+
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
+
+  override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
+
+  // When this cast involves TimeZone, it's only resolved if the timeZoneId is set;
+  // Otherwise behave like Expression.resolved.
+  override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess &&
+    (!Cast.needsTimeZone(child.dataType, toType) || timeZoneId.isDefined)
+
+  override lazy val replacement = {
+    TryEval(Cast(child, toType, timeZoneId = timeZoneId, ansiEnabled = true))
+  }
+
+  // If the target data type is a complex type which can't have Null values, we should guarantee
+  // that the casting between the element types won't produce Null results.
+  private def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
+    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+      canCast(fromType, toType) &&
+        resolvableNullability(fn || forceNullable(fromType, toType), tn)
+
+    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+      canCast(fromKey, toKey) &&
+        (!forceNullable(fromKey, toKey)) &&
+        canCast(fromValue, toValue) &&
+        resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
+
+    case (StructType(fromFields), StructType(toFields)) =>
+      fromFields.length == toFields.length &&
+        fromFields.zip(toFields).forall {
+          case (fromField, toField) =>
+            canCast(fromField.dataType, toField.dataType) &&
+              resolvableNullability(
+                fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
+                toField.nullable)
+        }
+
+    case _ =>
+      Cast.canAnsiCast(from, to)
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (canCast(child.dataType, dataType)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(Cast.typeCheckFailureMessage(child.dataType, toType, None))
+    }
+  }
+
+  override def toString: String = s"try_cast($child as ${dataType.simpleString})"
+
+  override def sql: String = s"TRY_CAST(${child.sql} AS ${dataType.sql})"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    this.copy(child = newChild)
+}
+
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2) - Returns the sum of `expr1` and `expr2` and the result is null on overflow. " +
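
A small sketch of the new TryCast wiring defined in this file: TryCast
evaluates nothing itself and is swapped for its `replacement`, TryEval over an
ANSI Cast, before execution (catalyst internals, subject to change):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, TryCast, TryEval}
    import org.apache.spark.sql.types.IntegerType

    val tc = TryCast(Literal("abc"), IntegerType)
    tc.replacement match {
      case TryEval(c: Cast) => assert(c.ansiEnabled)  // reuses the error-throwing paths
      case other => sys.error(s"unexpected replacement: $other")
    }
    assert(tc.replacement.eval() == null)  // the failure surfaces as NULL
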
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 62c328a29a8..3fc23c31ac7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -631,7 +631,8 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
     case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
     case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
       true
-    case _: CastBase => true
+    case _: Cast => true
+    case _: TryEval => true
     case _: GetDateField | _: LastDay => true
     case _: ExtractIntervalPart[_] => true
     case _: ArraySetLike => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 7ae04010ad2..46847411bf0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1788,15 +1788,17 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
   override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
     val rawDataType = typedVisit[DataType](ctx.dataType())
     val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
-    val cast = ctx.name.getType match {
+    ctx.name.getType match {
       case SqlBaseParser.CAST =>
-        Cast(expression(ctx.expression), dataType)
+        val cast = Cast(expression(ctx.expression), dataType)
+        cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
+        cast
 
       case SqlBaseParser.TRY_CAST =>
+        // `TryCast` can only be user-specified and we don't need to set the USER_SPECIFIED_CAST
+        // tag, which is only used by `Cast`
         TryCast(expression(ctx.expression), dataType)
     }
-    cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
-    cast
   }
 
   /**
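
A sketch of the resulting parser behavior, using the internal
CatalystSqlParser (names as in the diff above):

    import org.apache.spark.sql.catalyst.expressions.{Cast, TryCast}
    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

    val c = CatalystSqlParser.parseExpression("CAST('10' AS INT)")
    assert(c.getTagValue(Cast.USER_SPECIFIED_CAST).contains(true))

    val t = CatalystSqlParser.parseExpression("TRY_CAST('10' AS INT)")
    assert(t.isInstanceOf[TryCast])
    assert(t.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty)  // tag is Cast-only now
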
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 677bdf27336..11d68294023 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -590,7 +590,7 @@ case class View(
   // See more details in `SessionCatalog.fromCatalogTable`.
   private def canRemoveProject(p: Project): Boolean = {
     p.output.length == p.child.output.length && p.projectList.zip(p.child.output).forall {
-      case (Alias(cast: CastBase, name), childAttr) =>
+      case (Alias(cast: Cast, name), childAttr) =>
         cast.child match {
           case a: AttributeReference =>
             a.dataType == cast.dataType && a.name == name && childAttr.semanticEquals(a)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index e06072cbed2..f73fc7c6816 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -119,7 +119,7 @@ package object util extends Logging {
       PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType)
     case r: InheritAnalysisRules =>
       PrettyAttribute(r.makeSQLString(r.parameters.map(toPrettySQL)), r.dataType)
-    case c: CastBase if !c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) =>
+    case c: Cast if !c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) =>
       PrettyAttribute(usePrettyExpression(c.child).sql, c.dataType)
     case p: PythonUDF => PrettyPythonUDF(p.name, p.dataType, p.children)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index ba8ab708046..ca492e11226 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -43,11 +43,19 @@ import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
- * Common test suite for [[Cast]], [[AnsiCast]] and [[TryCast]] expressions.
+ * Common test suite for [[Cast]] with ansi mode on and off. It only includes test cases that work
+ * for both ansi on and off.
  */
 abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
 
-  protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase
+  protected def ansiEnabled: Boolean
+
+  protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = {
+    v match {
+      case lit: Expression => Cast(lit, targetType, timeZoneId, ansiEnabled)
+      case _ => Cast(Literal(v), targetType, timeZoneId, ansiEnabled)
+    }
+  }
 
   // expected cannot be null
   protected def checkCast(v: Any, expected: Any): Unit = {
@@ -58,7 +66,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null)
   }
 
-  protected def verifyCastFailure(c: CastBase, optionalExpectedMsg: Option[String] = None): Unit = {
+  protected def verifyCastFailure(c: Cast, optionalExpectedMsg: Option[String] = None): Unit = {
     val typeCheckResult = c.checkInputDataTypes()
     assert(typeCheckResult.isFailure)
     assert(typeCheckResult.isInstanceOf[TypeCheckFailure])
@@ -66,20 +74,15 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
 
     if (optionalExpectedMsg.isDefined) {
       assert(message.contains(optionalExpectedMsg.get))
-    } else if (setConfigurationHint.nonEmpty) {
-      assert(message.contains("with ANSI mode on"))
-      assert(message.contains(setConfigurationHint))
     } else {
       assert("cannot cast [a-zA-Z]+ to [a-zA-Z]+".r.findFirstIn(message).isDefined)
+      if (ansiEnabled) {
+        assert(message.contains("with ANSI mode on"))
+        assert(message.contains(s"set ${SQLConf.ANSI_ENABLED.key} as false"))
+      }
     }
   }
 
-  // Whether the test suite is for TryCast. If yes, there is no exceptions and the result is
-  // always nullable.
-  protected def isTryCast: Boolean = false
-
-  protected def setConfigurationHint: String = ""
-
   test("null cast") {
     import DataTypeTestUtils._
 
@@ -281,8 +284,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("cast from string") {
-    assert(cast("abcdef", StringType).nullable === isTryCast)
-    assert(cast("abcdef", BinaryType).nullable === isTryCast)
+    assert(!cast("abcdef", StringType).nullable)
+    assert(!cast("abcdef", BinaryType).nullable)
     assert(cast("abcdef", BooleanType).nullable)
     assert(cast("abcdef", TimestampType).nullable)
     assert(cast("abcdef", LongType).nullable)
@@ -981,14 +984,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
         DayTimeIntervalType()), StringType), ansiInterval)
     }
 
-    if (!isTryCast) {
-      Seq("INTERVAL '-106751991 04:00:54.775809' DAY TO SECOND",
-        "INTERVAL '106751991 04:00:54.775808' DAY TO SECOND").foreach { interval =>
-        val e = intercept[ArithmeticException] {
-          cast(Literal.create(interval), DayTimeIntervalType()).eval()
-        }.getMessage
-        assert(e.contains("long overflow"))
-      }
+    Seq("INTERVAL '-106751991 04:00:54.775809' DAY TO SECOND",
+      "INTERVAL '106751991 04:00:54.775808' DAY TO SECOND").foreach { interval =>
+      val e = intercept[ArithmeticException] {
+        cast(Literal.create(interval), DayTimeIntervalType()).eval()
+      }.getMessage
+      assert(e.contains("long overflow"))
     }
 
     Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue, Long.MinValue + 1,
@@ -1027,15 +1028,13 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
         YearMonthIntervalType()), StringType), ansiInterval)
     }
 
-    if (!isTryCast) {
-      Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH")
-        .foreach { interval =>
-          val e = intercept[IllegalArgumentException] {
-            cast(Literal.create(interval), YearMonthIntervalType()).eval()
-          }.getMessage
-          assert(e.contains("Error parsing interval year-month string: integer overflow"))
-        }
-    }
+    Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH")
+      .foreach { interval =>
+        val e = intercept[IllegalArgumentException] {
+          cast(Literal.create(interval), YearMonthIntervalType()).eval()
+        }.getMessage
+        assert(e.contains("Error parsing interval year-month string: integer overflow"))
+      }
 
     Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue)
       .foreach { period =>
@@ -1098,9 +1097,27 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
         }
       }
 
-    if (!isTryCast) {
-      Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval =>
-        val dataType = YearMonthIntervalType()
+    Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval =>
+      val dataType = YearMonthIntervalType()
+      val e = intercept[IllegalArgumentException] {
+        cast(Literal.create(interval), dataType).eval()
+      }.getMessage
+      assert(e.contains(s"Interval string does not match year-month format of " +
+        s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+          .map(format => s"`$format`").mkString(", ")} " +
+        s"when cast to ${dataType.typeName}: $interval"))
+    }
+    Seq(("1", YearMonthIntervalType(YEAR, MONTH)),
+      ("1", YearMonthIntervalType(YEAR, MONTH)),
+      ("1-1", YearMonthIntervalType(YEAR)),
+      ("1-1", YearMonthIntervalType(MONTH)),
+      ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(YEAR)),
+      ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(MONTH)),
+      ("INTERVAL '1' YEAR", YearMonthIntervalType(YEAR, MONTH)),
+      ("INTERVAL '1' YEAR", YearMonthIntervalType(MONTH)),
+      ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR)),
+      ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR, MONTH)))
+      .foreach { case (interval, dataType) =>
         val e = intercept[IllegalArgumentException] {
           cast(Literal.create(interval), dataType).eval()
         }.getMessage
@@ -1109,26 +1126,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
             .map(format => s"`$format`").mkString(", ")} " +
           s"when cast to ${dataType.typeName}: $interval"))
       }
-      Seq(("1", YearMonthIntervalType(YEAR, MONTH)),
-        ("1", YearMonthIntervalType(YEAR, MONTH)),
-        ("1-1", YearMonthIntervalType(YEAR)),
-        ("1-1", YearMonthIntervalType(MONTH)),
-        ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(YEAR)),
-        ("INTERVAL '1-1' YEAR TO MONTH", YearMonthIntervalType(MONTH)),
-        ("INTERVAL '1' YEAR", YearMonthIntervalType(YEAR, MONTH)),
-        ("INTERVAL '1' YEAR", YearMonthIntervalType(MONTH)),
-        ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR)),
-        ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR, MONTH)))
-        .foreach { case (interval, dataType) =>
-          val e = intercept[IllegalArgumentException] {
-            cast(Literal.create(interval), dataType).eval()
-          }.getMessage
-          assert(e.contains(s"Interval string does not match year-month format of " +
-            s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
-              .map(format => s"`$format`").mkString(", ")} " +
-            s"when cast to ${dataType.typeName}: $interval"))
-        }
-    }
   }
 
   test("SPARK-35735: Take into account day-time interval fields in cast") {
@@ -1218,63 +1215,61 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
         checkEvaluation(cast(Literal.create(interval), dataType), dt)
       }
 
-    if (!isTryCast) {
-      Seq(
-        ("INTERVAL '1 01:01:01.12345' DAY TO SECOND", DayTimeIntervalType(DAY, HOUR)),
-        ("INTERVAL '1 01:01:01.12345' DAY TO HOUR", DayTimeIntervalType(DAY, SECOND)),
-        ("INTERVAL '1 01:01:01.12345' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
-        ("1 01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
-        ("1 01:01:01.12345", DayTimeIntervalType(DAY, HOUR)),
-        ("1 01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),
-
-        ("INTERVAL '01:01:01.12345' HOUR TO SECOND", DayTimeIntervalType(DAY, HOUR)),
-        ("INTERVAL '01:01:01.12345' HOUR TO HOUR", DayTimeIntervalType(DAY, SECOND)),
-        ("INTERVAL '01:01:01.12345' HOUR TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
-        ("01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
-        ("01:01:01.12345", DayTimeIntervalType(HOUR, HOUR)),
-        ("01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),
-        ("INTERVAL '1.23' DAY", DayTimeIntervalType(DAY)),
-        ("INTERVAL '1.23' HOUR", DayTimeIntervalType(HOUR)),
-        ("INTERVAL '1.23' MINUTE", DayTimeIntervalType(MINUTE)),
-        ("INTERVAL '1.23' SECOND", DayTimeIntervalType(MINUTE)),
-        ("1.23", DayTimeIntervalType(DAY)),
-        ("1.23", DayTimeIntervalType(HOUR)),
-        ("1.23", DayTimeIntervalType(MINUTE)),
-        ("1.23", DayTimeIntervalType(MINUTE)))
-        .foreach { case (interval, dataType) =>
-          val e = intercept[IllegalArgumentException] {
-            cast(Literal.create(interval), dataType).eval()
-          }.getMessage
-          assert(e.contains(s"Interval string does not match day-time format of " +
-            s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
-              .map(format => s"`$format`").mkString(", ")} " +
-            s"when cast to ${dataType.typeName}: $interval, " +
-            s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
-            "to restore the behavior before Spark 3.0."))
-        }
+    Seq(
+      ("INTERVAL '1 01:01:01.12345' DAY TO SECOND", DayTimeIntervalType(DAY, HOUR)),
+      ("INTERVAL '1 01:01:01.12345' DAY TO HOUR", DayTimeIntervalType(DAY, SECOND)),
+      ("INTERVAL '1 01:01:01.12345' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
+      ("1 01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
+      ("1 01:01:01.12345", DayTimeIntervalType(DAY, HOUR)),
+      ("1 01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),
+
+      ("INTERVAL '01:01:01.12345' HOUR TO SECOND", DayTimeIntervalType(DAY, HOUR)),
+      ("INTERVAL '01:01:01.12345' HOUR TO HOUR", DayTimeIntervalType(DAY, SECOND)),
+      ("INTERVAL '01:01:01.12345' HOUR TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
+      ("01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
+      ("01:01:01.12345", DayTimeIntervalType(HOUR, HOUR)),
+      ("01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),
+      ("INTERVAL '1.23' DAY", DayTimeIntervalType(DAY)),
+      ("INTERVAL '1.23' HOUR", DayTimeIntervalType(HOUR)),
+      ("INTERVAL '1.23' MINUTE", DayTimeIntervalType(MINUTE)),
+      ("INTERVAL '1.23' SECOND", DayTimeIntervalType(MINUTE)),
+      ("1.23", DayTimeIntervalType(DAY)),
+      ("1.23", DayTimeIntervalType(HOUR)),
+      ("1.23", DayTimeIntervalType(MINUTE)),
+      ("1.23", DayTimeIntervalType(MINUTE)))
+      .foreach { case (interval, dataType) =>
+        val e = intercept[IllegalArgumentException] {
+          cast(Literal.create(interval), dataType).eval()
+        }.getMessage
+        assert(e.contains(s"Interval string does not match day-time format of " +
+          s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+            .map(format => s"`$format`").mkString(", ")} " +
+          s"when cast to ${dataType.typeName}: $interval, " +
+          s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
+          "to restore the behavior before Spark 3.0."))
+      }
 
-      // Check first field out of bounds
-      Seq(("INTERVAL '1067519911' DAY", DayTimeIntervalType(DAY)),
-        ("INTERVAL '10675199111 04' DAY TO HOUR", DayTimeIntervalType(DAY, HOUR)),
-        ("INTERVAL '1067519911 04:00' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
-        ("INTERVAL '1067519911 04:00:54.775807' DAY TO SECOND", DayTimeIntervalType()),
-        ("INTERVAL '25620477881' HOUR", DayTimeIntervalType(HOUR)),
-        ("INTERVAL '25620477881:00' HOUR TO MINUTE", DayTimeIntervalType(HOUR, MINUTE)),
-        ("INTERVAL '25620477881:00:54.775807' HOUR TO SECOND", DayTimeIntervalType(HOUR, SECOND)),
-        ("INTERVAL '1537228672801' MINUTE", DayTimeIntervalType(MINUTE)),
-        ("INTERVAL '1537228672801:54.7757' MINUTE TO SECOND", DayTimeIntervalType(MINUTE, SECOND)),
-        ("INTERVAL '92233720368541.775807' SECOND", DayTimeIntervalType(SECOND)))
-        .foreach { case (interval, dataType) =>
-          val e = intercept[IllegalArgumentException] {
-            cast(Literal.create(interval), dataType).eval()
-          }.getMessage
-          assert(e.contains(s"Interval string does not match day-time format of " +
-            s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
-              .map(format => s"`$format`").mkString(", ")} " +
-            s"when cast to ${dataType.typeName}: $interval, " +
-            s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
-            "to restore the behavior before Spark 3.0."))
-        }
-    }
+    // Check first field out of bounds
+    Seq(("INTERVAL '1067519911' DAY", DayTimeIntervalType(DAY)),
+      ("INTERVAL '10675199111 04' DAY TO HOUR", DayTimeIntervalType(DAY, HOUR)),
+      ("INTERVAL '1067519911 04:00' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
+      ("INTERVAL '1067519911 04:00:54.775807' DAY TO SECOND", DayTimeIntervalType()),
+      ("INTERVAL '25620477881' HOUR", DayTimeIntervalType(HOUR)),
+      ("INTERVAL '25620477881:00' HOUR TO MINUTE", DayTimeIntervalType(HOUR, MINUTE)),
+      ("INTERVAL '25620477881:00:54.775807' HOUR TO SECOND", DayTimeIntervalType(HOUR, SECOND)),
+      ("INTERVAL '1537228672801' MINUTE", DayTimeIntervalType(MINUTE)),
+      ("INTERVAL '1537228672801:54.7757' MINUTE TO SECOND", DayTimeIntervalType(MINUTE, SECOND)),
+      ("INTERVAL '92233720368541.775807' SECOND", DayTimeIntervalType(SECOND)))
+      .foreach { case (interval, dataType) =>
+        val e = intercept[IllegalArgumentException] {
+          cast(Literal.create(interval), dataType).eval()
+        }.getMessage
+        assert(e.contains(s"Interval string does not match day-time format of " +
+          s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+            .map(format => s"`$format`").mkString(", ")} " +
+          s"when cast to ${dataType.typeName}: $interval, " +
+          s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
+          "to restore the behavior before Spark 3.0."))
+      }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala
similarity index 98%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala
index 630c45adba1..4e4bc096dea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala
@@ -35,26 +35,10 @@ import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Test suite for data type casting expression [[Cast]] with ANSI mode disabled.
- * Note: for new test cases that work for [[Cast]], [[AnsiCast]] and [[TryCast]], please add them
- *       in `CastSuiteBase` instead of this file to ensure the test coverage.
  */
-class CastSuite extends CastSuiteBase {
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    SQLConf.get.setConf(SQLConf.ANSI_ENABLED, false)
-  }
-
-  override def afterAll(): Unit = {
-    super.afterAll()
-    SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
-  }
+class CastWithAnsiOffSuite extends CastSuiteBase {
 
-  override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
-    v match {
-      case lit: Expression => Cast(lit, targetType, timeZoneId)
-      case _ => Cast(Literal(v), targetType, timeZoneId)
-    }
-  }
+  override def ansiEnabled: Boolean = false
 
   test("null cast #2") {
     import DataTypeTestUtils._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
similarity index 85%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
index 84f0d5c59aa..f2cfc529984 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
@@ -27,19 +27,15 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
 import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
- * Test suite base for
- *   1. [[Cast]] with ANSI mode enabled
- *   2. [[AnsiCast]]
- *   3. [[TryCast]]
- * Note: for new test cases that work for [[Cast]], [[AnsiCast]] and [[TryCast]], please add them
- *       in `CastSuiteBase` instead of this file to ensure the test coverage.
+ * Test suite for data type casting expression [[Cast]] with ANSI mode enabled.
  */
-abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
+class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase {
+
+  override def ansiEnabled: Boolean = true
 
   private def testIntMaxAndMin(dt: DataType): Unit = {
     assert(Seq(IntegerType, ShortType, ByteType).contains(dt))
@@ -339,25 +335,21 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
 
     {
       val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
-      assert(ret.resolved == !isTryCast)
-      if (!isTryCast) {
-        checkExceptionInExpression[SparkRuntimeException](
-          ret, """cannot be cast to "BOOLEAN"""")
-      }
+      assert(ret.resolved)
+      checkExceptionInExpression[SparkRuntimeException](
+        ret, """cannot be cast to "BOOLEAN"""")
     }
   }
 
   test("cast from array III") {
-    if (!isTryCast) {
-      val from: DataType = ArrayType(DoubleType, containsNull = false)
-      val array = Literal.create(Seq(1.0, 2.0), from)
-      val to: DataType = ArrayType(IntegerType, containsNull = false)
-      val answer = Literal.create(Seq(1, 2), to).value
-      checkEvaluation(cast(array, to), answer)
+    val from: DataType = ArrayType(DoubleType, containsNull = false)
+    val array = Literal.create(Seq(1.0, 2.0), from)
+    val to: DataType = ArrayType(IntegerType, containsNull = false)
+    val answer = Literal.create(Seq(1, 2), to).value
+    checkEvaluation(cast(array, to), answer)
 
-      val overflowArray = Literal.create(Seq(Int.MaxValue + 1.0D), from)
-      checkExceptionInExpression[ArithmeticException](cast(overflowArray, to), "overflow")
-    }
+    val overflowArray = Literal.create(Seq(Int.MaxValue + 1.0D), from)
+    checkExceptionInExpression[ArithmeticException](cast(overflowArray, to), "overflow")
   }
 
   test("cast from map II") {
@@ -386,48 +378,40 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
 
     {
       val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
-      assert(ret.resolved == !isTryCast)
-      if (!isTryCast) {
-        checkExceptionInExpression[NumberFormatException](
-          ret,
-          castErrMsg("a", IntegerType))
-      }
+      assert(ret.resolved)
+      checkExceptionInExpression[NumberFormatException](
+        ret,
+        castErrMsg("a", IntegerType))
     }
 
     {
       val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
-      assert(ret.resolved == !isTryCast)
-      if (!isTryCast) {
-        checkExceptionInExpression[SparkRuntimeException](
-          ret,
-          castErrMsg("123", BooleanType))
-      }
+      assert(ret.resolved)
+      checkExceptionInExpression[SparkRuntimeException](
+        ret,
+        castErrMsg("123", BooleanType))
     }
 
     {
       val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
-      assert(ret.resolved == !isTryCast)
-      if (!isTryCast) {
-        checkExceptionInExpression[NumberFormatException](
-          ret,
-          castErrMsg("a", IntegerType))
-      }
+      assert(ret.resolved)
+      checkExceptionInExpression[NumberFormatException](
+        ret,
+        castErrMsg("a", IntegerType))
     }
   }
 
   test("cast from map III") {
-    if (!isTryCast) {
-      val from: DataType = MapType(DoubleType, DoubleType, valueContainsNull = false)
-      val map = Literal.create(Map(1.0 -> 2.0), from)
-      val to: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false)
-      val answer = Literal.create(Map(1 -> 2), to).value
-      checkEvaluation(cast(map, to), answer)
-
-      Seq(
-        Literal.create(Map((Int.MaxValue + 1.0) -> 2.0), from),
-        Literal.create(Map(1.0 -> (Int.MinValue - 1.0)), from)).foreach { overflowMap =>
-        checkExceptionInExpression[ArithmeticException](cast(overflowMap, to), "overflow")
-      }
+    val from: DataType = MapType(DoubleType, DoubleType, valueContainsNull = false)
+    val map = Literal.create(Map(1.0 -> 2.0), from)
+    val to: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false)
+    val answer = Literal.create(Map(1 -> 2), to).value
+    checkEvaluation(cast(map, to), answer)
+
+    Seq(
+      Literal.create(Map((Int.MaxValue + 1.0) -> 2.0), from),
+      Literal.create(Map(1.0 -> (Int.MinValue - 1.0)), from)).foreach { overflowMap =>
+      checkExceptionInExpression[ArithmeticException](cast(overflowMap, to), "overflow")
     }
   }
 
@@ -487,26 +471,22 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
         StructField("a", BooleanType, nullable = true),
         StructField("b", BooleanType, nullable = true),
         StructField("c", BooleanType, nullable = false))))
-      assert(ret.resolved == !isTryCast)
-      if (!isTryCast) {
-        checkExceptionInExpression[SparkRuntimeException](
-          ret,
-          castErrMsg("123", BooleanType))
-      }
+      assert(ret.resolved)
+      checkExceptionInExpression[SparkRuntimeException](
+        ret,
+        castErrMsg("123", BooleanType))
     }
   }
 
   test("cast from struct III") {
-    if (!isTryCast) {
-      val from: DataType = StructType(Seq(StructField("a", DoubleType, nullable = false)))
-      val struct = Literal.create(InternalRow(1.0), from)
-      val to: DataType = StructType(Seq(StructField("a", IntegerType, nullable = false)))
-      val answer = Literal.create(InternalRow(1), to).value
-      checkEvaluation(cast(struct, to), answer)
+    val from: DataType = StructType(Seq(StructField("a", DoubleType, nullable = false)))
+    val struct = Literal.create(InternalRow(1.0), from)
+    val to: DataType = StructType(Seq(StructField("a", IntegerType, nullable = false)))
+    val answer = Literal.create(InternalRow(1), to).value
+    checkEvaluation(cast(struct, to), answer)
 
-      val overflowStruct = Literal.create(InternalRow(Int.MaxValue + 1.0), from)
-      checkExceptionInExpression[ArithmeticException](cast(overflowStruct, to), "overflow")
-    }
+    val overflowStruct = Literal.create(InternalRow(Int.MaxValue + 1.0), from)
+    checkExceptionInExpression[ArithmeticException](cast(overflowStruct, to), "overflow")
   }
 
   test("complex casting") {
@@ -533,12 +513,10 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
         StructType(Seq(
           StructField("l", LongType, nullable = true)))))))
 
-    assert(ret.resolved === !isTryCast)
-    if (!isTryCast) {
-      checkExceptionInExpression[NumberFormatException](
-        ret,
-        castErrMsg("true", IntegerType))
-    }
+    assert(ret.resolved)
+    checkExceptionInExpression[NumberFormatException](
+      ret,
+      castErrMsg("true", IntegerType))
   }
 
   test("ANSI mode: cast string to timestamp with parse error") {
@@ -599,28 +577,3 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase with QueryErrorsBase {
     }
   }
 }
-
-/**
- * Test suite for data type casting expression [[Cast]] with ANSI mode disabled.
- */
-class CastSuiteWithAnsiModeOn extends AnsiCastSuiteBase {
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    SQLConf.get.setConf(SQLConf.ANSI_ENABLED, true)
-  }
-
-  override def afterAll(): Unit = {
-    super.afterAll()
-    SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
-  }
-
-  override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
-    v match {
-      case lit: Expression => Cast(lit, targetType, timeZoneId)
-      case _ => Cast(Literal(v), targetType, timeZoneId)
-    }
-  }
-
-  override def setConfigurationHint: String =
-    s"set ${SQLConf.ANSI_ENABLED.key} as false"
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
index bb9ab888947..bb66a9fd24a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
@@ -17,40 +17,65 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.reflect.ClassTag
-
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC_OPT
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+// A test suite to check the analysis behavior of `TryCast`.
+class TryCastSuite extends SparkFunSuite {
 
-class TryCastSuite extends AnsiCastSuiteBase {
-  override protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String]) = {
+  private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): TryCast = {
     v match {
       case lit: Expression => TryCast(lit, targetType, timeZoneId)
       case _ => TryCast(Literal(v), targetType, timeZoneId)
     }
   }
 
-  override def isTryCast: Boolean = true
-
-  override protected def setConfigurationHint: String = ""
-
-  override def checkExceptionInExpression[T <: Throwable : ClassTag](
-      expression: => Expression,
-      inputRow: InternalRow,
-      expectedErrMsg: String): Unit = {
-    checkEvaluation(expression, null, inputRow)
+  test("print string") {
+    assert(TryCast(Literal("1"), IntegerType).toString == "try_cast(1 as int)")
+    assert(TryCast(Literal("1"), IntegerType).sql == "TRY_CAST('1' AS INT)")
   }
 
-  override def checkCastToBooleanError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
-    checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
+  test("nullability") {
+    assert(cast("abcdef", StringType).nullable)
+    assert(cast("abcdef", BinaryType).nullable)
   }
 
-  override def checkCastToNumericError(l: Literal, to: DataType,
-      expectedDataTypeInErrorMsg: DataType, tryCastResult: Any): Unit = {
-    checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
+  test("only require timezone for datetime types") {
+    assert(cast("abc", IntegerType).resolved)
+    assert(!cast("abc", TimestampType).resolved)
+    assert(cast("abc", TimestampType, UTC_OPT).resolved)
   }
 
-  test("try_cast: to_string") {
-    assert(TryCast(Literal("1"), IntegerType).toString == "try_cast(1 as int)")
+  test("element type nullability") {
+    val array = Literal.create(Seq("123", "true"),
+      ArrayType(StringType, containsNull = false))
+    // array element can be null after try_cast which violates the target type.
+    val c1 = cast(array, ArrayType(BooleanType, containsNull = false))
+    assert(!c1.resolved)
+
+    val map = Literal.create(Map("a" -> "123", "b" -> "true"),
+      MapType(StringType, StringType, valueContainsNull = false))
+    // key can be null after try_cast which violates the map key requirement.
+    val c2 = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
+    assert(!c2.resolved)
+    // map value can be null after try_cast which violates the target type.
+    val c3 = cast(map, MapType(StringType, IntegerType, valueContainsNull = false))
+    assert(!c3.resolved)
+
+    val struct = Literal.create(
+      InternalRow(
+        UTF8String.fromString("123"),
+        UTF8String.fromString("true")),
+      new StructType()
+        .add("a", StringType, nullable = true)
+        .add("b", StringType, nullable = true))
+    // struct field `b` can be null after try_cast which violates the target type.
+    val c4 = cast(struct, new StructType()
+      .add("a", BooleanType, nullable = true)
+      .add("b", BooleanType, nullable = false))
+    assert(!c4.resolved)
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index dc92cf24ab1..e1d1f064e34 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -146,7 +146,7 @@ class ConstantFoldingSuite extends PlanTest {
       testRelation
         .select(
           Cast(Literal("2"), IntegerType) + Literal(3) + $"a" as "c1",
-          Coalesce(Seq(TryCast(Literal("abc"), IntegerType), Literal(3))) as "c2")
+          Coalesce(Seq(TryCast(Literal("abc"), IntegerType).replacement, Literal(3))) as "c2")
 
     val optimized = Optimize.execute(originalQuery.analyze)
 
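The expected plan references `.replacement` because the optimizer's
ReplaceExpressions rule swaps a RuntimeReplaceable TryCast for its replacement
before other rules run, and that replacement then constant-folds. A minimal
sketch of why the fold yields NULL (catalyst internals):

    import org.apache.spark.sql.catalyst.expressions.{Literal, TryCast}
    import org.apache.spark.sql.types.IntegerType

    assert(TryCast(Literal("abc"), IntegerType).replacement.eval() == null)
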
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 67bb72c1878..95e5582cb8c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -1146,8 +1146,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     // client-side filtering cannot be used with TimeZoneAwareExpression.
     def hasTimeZoneAwareExpression(e: Expression): Boolean = {
       e.exists {
-        case cast: CastBase => cast.needsTimeZone
-        case tz: TimeZoneAwareExpression => !tz.isInstanceOf[CastBase]
+        case cast: Cast => cast.needsTimeZone
+        case tz: TimeZoneAwareExpression => !tz.isInstanceOf[Cast]
         case _ => false
       }
     }

