You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2023/06/27 02:47:14 UTC

[spark] branch master updated: [SPARK-44030][SQL] Implement DataTypeExpression to offer Unapply for expression

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 686e37e640d [SPARK-44030][SQL] Implement DataTypeExpression to offer Unapply for expression
686e37e640d is described below

commit 686e37e640d078f9727e5457e47ce58033ce8684
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Mon Jun 26 22:47:01 2023 -0400

    [SPARK-44030][SQL] Implement DataTypeExpression to offer Unapply for expression
    
    ### What changes were proposed in this pull request?
    
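    Introduce `DataTypeExpression` extractor objects that provide `unapply` for matching on the data type of an expression. With these in place, the `unapply(Expression)` methods can be dropped from `DataType` and its companion objects.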
    
    ### Why are the changes needed?
    
    Simplify DataType interface.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    existing tests
    
    Closes #41559 from amaliujia/move_datatypes_1.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../sql/catalyst/analysis/AnsiTypeCoercion.scala   | 29 +++++-----
 .../sql/catalyst/analysis/DecimalPrecision.scala   | 20 +++----
 .../spark/sql/catalyst/analysis/TypeCoercion.scala | 46 +++++++--------
 .../apache/spark/sql/types/AbstractDataType.scala  | 40 +------------
 .../org/apache/spark/sql/types/DataType.scala      | 10 ----
 .../spark/sql/types/DataTypeExpression.scala       | 67 ++++++++++++++++++++++
 .../apache/spark/sql/hive/client/HiveShim.scala    |  4 +-
 7 files changed, 119 insertions(+), 97 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index 56dbb2a8590..d3f20f87493 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -244,28 +244,29 @@ object AnsiTypeCoercion extends TypeCoercionBase {
         val promoteType = findWiderTypeForString(left.dataType, right.dataType).get
         b.withNewChildren(Seq(castExpr(left, promoteType), castExpr(right, promoteType)))
 
-      case Abs(e @ StringType(), failOnError) => Abs(Cast(e, DoubleType), failOnError)
-      case m @ UnaryMinus(e @ StringType(), _) => m.withNewChildren(Seq(Cast(e, DoubleType)))
-      case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType))
+      case Abs(e @ StringTypeExpression(), failOnError) => Abs(Cast(e, DoubleType), failOnError)
+      case m @ UnaryMinus(e @ StringTypeExpression(), _) =>
+        m.withNewChildren(Seq(Cast(e, DoubleType)))
+      case UnaryPositive(e @ StringTypeExpression()) => UnaryPositive(Cast(e, DoubleType))
 
-      case d @ DateAdd(left @ StringType(), _) =>
+      case d @ DateAdd(left @ StringTypeExpression(), _) =>
         d.copy(startDate = Cast(d.startDate, DateType))
-      case d @ DateAdd(_, right @ StringType()) =>
+      case d @ DateAdd(_, right @ StringTypeExpression()) =>
         d.copy(days = Cast(right, IntegerType))
-      case d @ DateSub(left @ StringType(), _) =>
+      case d @ DateSub(left @ StringTypeExpression(), _) =>
         d.copy(startDate = Cast(d.startDate, DateType))
-      case d @ DateSub(_, right @ StringType()) =>
+      case d @ DateSub(_, right @ StringTypeExpression()) =>
         d.copy(days = Cast(right, IntegerType))
 
-      case s @ SubtractDates(left @ StringType(), _, _) =>
+      case s @ SubtractDates(left @ StringTypeExpression(), _, _) =>
         s.copy(left = Cast(s.left, DateType))
-      case s @ SubtractDates(_, right @ StringType(), _) =>
+      case s @ SubtractDates(_, right @ StringTypeExpression(), _) =>
         s.copy(right = Cast(s.right, DateType))
-      case t @ TimeAdd(left @ StringType(), _, _) =>
+      case t @ TimeAdd(left @ StringTypeExpression(), _, _) =>
         t.copy(start = Cast(t.start, TimestampType))
-      case t @ SubtractTimestamps(left @ StringType(), _, _, _) =>
+      case t @ SubtractTimestamps(left @ StringTypeExpression(), _, _, _) =>
         t.copy(left = Cast(t.left, t.right.dataType))
-      case t @ SubtractTimestamps(_, right @ StringType(), _, _) =>
+      case t @ SubtractTimestamps(_, right @ StringTypeExpression(), _, _) =>
         t.copy(right = Cast(right, t.left.dataType))
     }
   }
@@ -296,9 +297,9 @@ object AnsiTypeCoercion extends TypeCoercionBase {
       case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
       case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
 
-      case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) =>
+      case s @ SubtractTimestamps(DateTypeExpression(), AnyTimestampType(), _, _) =>
         s.copy(left = Cast(s.left, s.right.dataType))
-      case s @ SubtractTimestamps(AnyTimestampType(), DateType(), _, _) =>
+      case s @ SubtractTimestamps(AnyTimestampType(), DateTypeExpression(), _, _) =>
         s.copy(right = Cast(s.right, s.left.dataType))
       case s @ SubtractTimestamps(AnyTimestampType(), AnyTimestampType(), _, _)
         if s.left.dataType != s.right.dataType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 46fbf071f43..90fd13dfb54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -105,7 +105,7 @@ object DecimalPrecision extends TypeCoercionRule {
    */
   private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = {
 
-    case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) =>
+    case GreaterThan(i @ IntegralTypeExpression(), DecimalLiteral(value)) =>
       if (DecimalLiteral.smallerThanSmallestLong(value)) {
         TrueLiteral
       } else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -114,7 +114,7 @@ object DecimalPrecision extends TypeCoercionRule {
         GreaterThan(i, Literal(value.floor.toLong))
       }
 
-    case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+    case GreaterThanOrEqual(i @ IntegralTypeExpression(), DecimalLiteral(value)) =>
       if (DecimalLiteral.smallerThanSmallestLong(value)) {
         TrueLiteral
       } else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -123,7 +123,7 @@ object DecimalPrecision extends TypeCoercionRule {
         GreaterThanOrEqual(i, Literal(value.ceil.toLong))
       }
 
-    case LessThan(i @ IntegralType(), DecimalLiteral(value)) =>
+    case LessThan(i @ IntegralTypeExpression(), DecimalLiteral(value)) =>
       if (DecimalLiteral.smallerThanSmallestLong(value)) {
         FalseLiteral
       } else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -132,7 +132,7 @@ object DecimalPrecision extends TypeCoercionRule {
         LessThan(i, Literal(value.ceil.toLong))
       }
 
-    case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+    case LessThanOrEqual(i @ IntegralTypeExpression(), DecimalLiteral(value)) =>
       if (DecimalLiteral.smallerThanSmallestLong(value)) {
         FalseLiteral
       } else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -141,7 +141,7 @@ object DecimalPrecision extends TypeCoercionRule {
         LessThanOrEqual(i, Literal(value.floor.toLong))
       }
 
-    case GreaterThan(DecimalLiteral(value), i @ IntegralType()) =>
+    case GreaterThan(DecimalLiteral(value), i @ IntegralTypeExpression()) =>
       if (DecimalLiteral.smallerThanSmallestLong(value)) {
         FalseLiteral
       } else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -150,7 +150,7 @@ object DecimalPrecision extends TypeCoercionRule {
         GreaterThan(Literal(value.ceil.toLong), i)
       }
 
-    case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+    case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralTypeExpression()) =>
       if (DecimalLiteral.smallerThanSmallestLong(value)) {
         FalseLiteral
       } else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -159,7 +159,7 @@ object DecimalPrecision extends TypeCoercionRule {
         GreaterThanOrEqual(Literal(value.floor.toLong), i)
       }
 
-    case LessThan(DecimalLiteral(value), i @ IntegralType()) =>
+    case LessThan(DecimalLiteral(value), i @ IntegralTypeExpression()) =>
       if (DecimalLiteral.smallerThanSmallestLong(value)) {
         TrueLiteral
       } else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -168,7 +168,7 @@ object DecimalPrecision extends TypeCoercionRule {
         LessThan(Literal(value.floor.toLong), i)
       }
 
-    case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+    case LessThanOrEqual(DecimalLiteral(value), i @ IntegralTypeExpression()) =>
       if (DecimalLiteral.smallerThanSmallestLong(value)) {
         TrueLiteral
       } else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -208,9 +208,9 @@ object DecimalPrecision extends TypeCoercionRule {
           b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
         // Promote integers inside a binary expression with fixed-precision decimals to decimals,
         // and fixed-precision decimals in an expression with floats / doubles to doubles
-        case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) =>
+        case (l @ IntegralTypeExpression(), r @ DecimalType.Expression(_, _)) =>
           b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
-        case (l @ DecimalType.Expression(_, _), r @ IntegralType()) =>
+        case (l @ DecimalType.Expression(_, _), r @ IntegralTypeExpression()) =>
           b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
         case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) =>
           b.makeCopy(Array(l, Cast(r, DoubleType)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index bd2255134fc..ae4db0575ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -461,8 +461,8 @@ abstract class TypeCoercionBase {
         m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
 
       // Hive lets you do aggregation of timestamps... for some reason
-      case Sum(e @ TimestampType(), _) => Sum(Cast(e, DoubleType))
-      case Average(e @ TimestampType(), _) => Average(Cast(e, DoubleType))
+      case Sum(e @ TimestampTypeExpression(), _) => Sum(Cast(e, DoubleType))
+      case Average(e @ TimestampTypeExpression(), _) => Average(Cast(e, DoubleType))
 
       // Coalesce should return the first non-null value, which could be any column
       // from the list. So we need to make sure the return type is deterministic and
@@ -1105,18 +1105,18 @@ object TypeCoercion extends TypeCoercionBase {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case a @ BinaryArithmetic(left @ StringType(), right)
+      case a @ BinaryArithmetic(left @ StringTypeExpression(), right)
         if right.dataType != CalendarIntervalType =>
         a.makeCopy(Array(Cast(left, DoubleType), right))
-      case a @ BinaryArithmetic(left, right @ StringType())
+      case a @ BinaryArithmetic(left, right @ StringTypeExpression())
         if left.dataType != CalendarIntervalType =>
         a.makeCopy(Array(left, Cast(right, DoubleType)))
 
       // For equality between string and timestamp we cast the string to a timestamp
       // so that things like rounding of subsecond precision does not affect the comparison.
-      case p @ Equality(left @ StringType(), right @ TimestampType()) =>
+      case p @ Equality(left @ StringTypeExpression(), right @ TimestampTypeExpression()) =>
         p.makeCopy(Array(Cast(left, TimestampType), right))
-      case p @ Equality(left @ TimestampType(), right @ StringType()) =>
+      case p @ Equality(left @ TimestampTypeExpression(), right @ StringTypeExpression()) =>
         p.makeCopy(Array(left, Cast(right, TimestampType)))
 
       case p @ BinaryComparison(left, right)
@@ -1142,30 +1142,30 @@ object TypeCoercion extends TypeCoercionBase {
 
       // We may simplify the expression if one side is literal numeric values
       // TODO: Maybe these rules should go into the optimizer.
-      case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
+      case EqualTo(bool @ BooleanTypeExpression(), Literal(value, _: NumericType))
         if trueValues.contains(value) => bool
-      case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
+      case EqualTo(bool @ BooleanTypeExpression(), Literal(value, _: NumericType))
         if falseValues.contains(value) => Not(bool)
-      case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
+      case EqualTo(Literal(value, _: NumericType), bool @ BooleanTypeExpression())
         if trueValues.contains(value) => bool
-      case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
+      case EqualTo(Literal(value, _: NumericType), bool @ BooleanTypeExpression())
         if falseValues.contains(value) => Not(bool)
-      case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
+      case EqualNullSafe(bool @ BooleanTypeExpression(), Literal(value, _: NumericType))
         if trueValues.contains(value) => And(IsNotNull(bool), bool)
-      case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
+      case EqualNullSafe(bool @ BooleanTypeExpression(), Literal(value, _: NumericType))
         if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
-      case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
+      case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanTypeExpression())
         if trueValues.contains(value) => And(IsNotNull(bool), bool)
-      case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
+      case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanTypeExpression())
         if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
 
-      case EqualTo(left @ BooleanType(), right @ NumericType()) =>
+      case EqualTo(left @ BooleanTypeExpression(), right @ NumericTypeExpression()) =>
         EqualTo(Cast(left, right.dataType), right)
-      case EqualTo(left @ NumericType(), right @ BooleanType()) =>
+      case EqualTo(left @ NumericTypeExpression(), right @ BooleanTypeExpression()) =>
         EqualTo(left, Cast(right, left.dataType))
-      case EqualNullSafe(left @ BooleanType(), right @ NumericType()) =>
+      case EqualNullSafe(left @ BooleanTypeExpression(), right @ NumericTypeExpression()) =>
         EqualNullSafe(Cast(left, right.dataType), right)
-      case EqualNullSafe(left @ NumericType(), right @ BooleanType()) =>
+      case EqualNullSafe(left @ NumericTypeExpression(), right @ BooleanTypeExpression()) =>
         EqualNullSafe(left, Cast(right, left.dataType))
     }
   }
@@ -1175,13 +1175,13 @@ object TypeCoercion extends TypeCoercionBase {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
       case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
-      case d @ DateAdd(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
+      case d @ DateAdd(StringTypeExpression(), _) => d.copy(startDate = Cast(d.startDate, DateType))
       case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
-      case d @ DateSub(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
+      case d @ DateSub(StringTypeExpression(), _) => d.copy(startDate = Cast(d.startDate, DateType))
 
-      case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) =>
+      case s @ SubtractTimestamps(DateTypeExpression(), AnyTimestampType(), _, _) =>
         s.copy(left = Cast(s.left, s.right.dataType))
-      case s @ SubtractTimestamps(AnyTimestampType(), DateType(), _, _) =>
+      case s @ SubtractTimestamps(AnyTimestampType(), DateTypeExpression(), _, _) =>
         s.copy(right = Cast(s.right, s.left.dataType))
       case s @ SubtractTimestamps(AnyTimestampType(), AnyTimestampType(), _, _)
         if s.left.dataType != s.right.dataType =>
@@ -1189,7 +1189,7 @@ object TypeCoercion extends TypeCoercionBase {
         val newRight = castIfNotSameType(s.right, TimestampNTZType)
         s.copy(left = newLeft, right = newRight)
 
-      case t @ TimeAdd(StringType(), _, _) => t.copy(start = Cast(t.start, TimestampType))
+      case t @ TimeAdd(StringTypeExpression(), _, _) => t.copy(start = Cast(t.start, TimestampType))
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index f498282d4f3..01fa27822b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -121,16 +121,7 @@ protected[sql] object AnyDataType extends AbstractDataType with Serializable {
  */
 protected[sql] abstract class AtomicType extends DataType
 
-object AtomicType {
-  /**
-   * Enables matching against AtomicType for expressions:
-   * {{{
-   *   case Cast(child @ AtomicType(), StringType) =>
-   *     ...
-   * }}}
-   */
-  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[AtomicType]
-}
+object AtomicType
 
 
 /**
@@ -143,15 +134,6 @@ abstract class NumericType extends AtomicType
 
 
 private[spark] object NumericType extends AbstractDataType {
-  /**
-   * Enables matching against NumericType for expressions:
-   * {{{
-   *   case Cast(child @ NumericType(), StringType) =>
-   *     ...
-   * }}}
-   */
-  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
-
   override private[spark] def defaultConcreteType: DataType = DoubleType
 
   override private[spark] def simpleString: String = "numeric"
@@ -162,15 +144,6 @@ private[spark] object NumericType extends AbstractDataType {
 
 
 private[sql] object IntegralType extends AbstractDataType {
-  /**
-   * Enables matching against IntegralType for expressions:
-   * {{{
-   *   case Cast(child @ IntegralType(), StringType) =>
-   *     ...
-   * }}}
-   */
-  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
-
   override private[sql] def defaultConcreteType: DataType = IntegerType
 
   override private[sql] def simpleString: String = "integral"
@@ -182,16 +155,7 @@ private[sql] object IntegralType extends AbstractDataType {
 private[sql] abstract class IntegralType extends NumericType
 
 
-private[sql] object FractionalType {
-  /**
-   * Enables matching against FractionalType for expressions:
-   * {{{
-   *   case Cast(child @ FractionalType(), StringType) =>
-   *     ...
-   * }}}
-   */
-  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType]
-}
+private[sql] object FractionalType
 
 
 private[sql] abstract class FractionalType extends NumericType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index f78a8de5e6a..893a41f3e39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -30,7 +30,6 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.SparkThrowable
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.analysis.Resolver
-import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
@@ -50,15 +49,6 @@ import org.apache.spark.util.Utils
 @JsonSerialize(using = classOf[DataTypeJsonSerializer])
 @JsonDeserialize(using = classOf[DataTypeJsonDeserializer])
 abstract class DataType extends AbstractDataType {
-  /**
-   * Enables matching against DataType for expressions:
-   * {{{
-   *   case Cast(child @ BinaryType(), StringType) =>
-   *     ...
-   * }}}
-   */
-  private[sql] def unapply(e: Expression): Boolean = e.dataType == this
-
   /**
    * The default size of a value of this data type, used internally for size estimation.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala
new file mode 100644
index 00000000000..f88e266b943
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+abstract class DataTypeExpression(val dataType: DataType) {
+  /**
+   * Enables matching an expression whose data type equals this extractor's `dataType`:
+   * {{{
+   *   case Cast(child @ BooleanTypeExpression(), StringType) =>
+   *     ...
+   * }}}
+   */
+  private[sql] def unapply(e: Expression): Boolean = e.dataType == dataType
+}
+
+case object BooleanTypeExpression extends DataTypeExpression(BooleanType)
+case object StringTypeExpression extends DataTypeExpression(StringType)
+case object TimestampTypeExpression extends DataTypeExpression(TimestampType)
+case object DateTypeExpression extends DataTypeExpression(DateType)
+case object ByteTypeExpression extends DataTypeExpression(ByteType)
+case object ShortTypeExpression extends DataTypeExpression(ShortType)
+case object IntegerTypeExpression extends DataTypeExpression(IntegerType)
+case object LongTypeExpression extends DataTypeExpression(LongType)
+case object DoubleTypeExpression extends DataTypeExpression(DoubleType)
+case object FloatTypeExpression extends DataTypeExpression(FloatType)
+
+object NumericTypeExpression {
+  /**
+   * Enables matching expressions whose data type is an instance of NumericType:
+   * {{{
+   *   case Cast(child @ NumericTypeExpression(), StringType) =>
+   *     ...
+   * }}}
+   */
+  def unapply(e: Expression): Boolean = {
+    e.dataType.isInstanceOf[NumericType]
+  }
+}
+
+object IntegralTypeExpression {
+  /**
+   * Enables matching expressions whose data type is an instance of IntegralType:
+   * {{{
+   *   case Cast(child @ IntegralTypeExpression(), StringType) =>
+   *     ...
+   * }}}
+   */
+  def unapply(e: Expression): Boolean = {
+    e.dataType.isInstanceOf[IntegralType]
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 9defd87aa7d..08615b90d80 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -50,7 +50,7 @@ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateFormatter, Type
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{AtomicType, DateType, IntegralType, StringType}
+import org.apache.spark.sql.types.{AtomicType, DateType, IntegralType, IntegralTypeExpression, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
@@ -987,7 +987,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
       def unapply(expr: Expression): Option[Attribute] = {
         expr match {
           case attr: Attribute => Some(attr)
-          case Cast(child @ IntegralType(), dt: IntegralType, _, _)
+          case Cast(child @ IntegralTypeExpression(), dt: IntegralType, _, _)
               if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child)
           case _ => None
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org