You are viewing a plain-text version of this content; the link to the canonical version is not preserved in this copy.
Posted to commits@spark.apache.org by hv...@apache.org on 2023/06/27 02:47:14 UTC
[spark] branch master updated: [SPARK-44030][SQL] Implement DataTypeExpression to offer Unapply for expression
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 686e37e640d [SPARK-44030][SQL] Implement DataTypeExpression to offer Unapply for expression
686e37e640d is described below
commit 686e37e640d078f9727e5457e47ce58033ce8684
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Mon Jun 26 22:47:01 2023 -0400
[SPARK-44030][SQL] Implement DataTypeExpression to offer Unapply for expression
### What changes were proposed in this pull request?
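Implement `DataTypeExpression` to provide `unapply` extractors that match an expression by its data type. This lets us drop the `unapply` methods from `DataType` and from the `AtomicType`, `NumericType`, `IntegralType` and `FractionalType` companion objects.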
### Why are the changes needed?
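To simplify the `DataType` interface: data types no longer need to know about catalyst `Expression`s, so `DataType.scala` no longer imports `Expression`.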
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Covered by existing tests; no new tests were added.
Closes #41559 from amaliujia/move_datatypes_1.
Authored-by: Rui Wang <ru...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../sql/catalyst/analysis/AnsiTypeCoercion.scala | 29 +++++-----
.../sql/catalyst/analysis/DecimalPrecision.scala | 20 +++----
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 46 +++++++--------
.../apache/spark/sql/types/AbstractDataType.scala | 40 +------------
.../org/apache/spark/sql/types/DataType.scala | 10 ----
.../spark/sql/types/DataTypeExpression.scala | 67 ++++++++++++++++++++++
.../apache/spark/sql/hive/client/HiveShim.scala | 4 +-
7 files changed, 119 insertions(+), 97 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index 56dbb2a8590..d3f20f87493 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -244,28 +244,29 @@ object AnsiTypeCoercion extends TypeCoercionBase {
val promoteType = findWiderTypeForString(left.dataType, right.dataType).get
b.withNewChildren(Seq(castExpr(left, promoteType), castExpr(right, promoteType)))
- case Abs(e @ StringType(), failOnError) => Abs(Cast(e, DoubleType), failOnError)
- case m @ UnaryMinus(e @ StringType(), _) => m.withNewChildren(Seq(Cast(e, DoubleType)))
- case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType))
+ case Abs(e @ StringTypeExpression(), failOnError) => Abs(Cast(e, DoubleType), failOnError)
+ case m @ UnaryMinus(e @ StringTypeExpression(), _) =>
+ m.withNewChildren(Seq(Cast(e, DoubleType)))
+ case UnaryPositive(e @ StringTypeExpression()) => UnaryPositive(Cast(e, DoubleType))
- case d @ DateAdd(left @ StringType(), _) =>
+ case d @ DateAdd(left @ StringTypeExpression(), _) =>
d.copy(startDate = Cast(d.startDate, DateType))
- case d @ DateAdd(_, right @ StringType()) =>
+ case d @ DateAdd(_, right @ StringTypeExpression()) =>
d.copy(days = Cast(right, IntegerType))
- case d @ DateSub(left @ StringType(), _) =>
+ case d @ DateSub(left @ StringTypeExpression(), _) =>
d.copy(startDate = Cast(d.startDate, DateType))
- case d @ DateSub(_, right @ StringType()) =>
+ case d @ DateSub(_, right @ StringTypeExpression()) =>
d.copy(days = Cast(right, IntegerType))
- case s @ SubtractDates(left @ StringType(), _, _) =>
+ case s @ SubtractDates(left @ StringTypeExpression(), _, _) =>
s.copy(left = Cast(s.left, DateType))
- case s @ SubtractDates(_, right @ StringType(), _) =>
+ case s @ SubtractDates(_, right @ StringTypeExpression(), _) =>
s.copy(right = Cast(s.right, DateType))
- case t @ TimeAdd(left @ StringType(), _, _) =>
+ case t @ TimeAdd(left @ StringTypeExpression(), _, _) =>
t.copy(start = Cast(t.start, TimestampType))
- case t @ SubtractTimestamps(left @ StringType(), _, _, _) =>
+ case t @ SubtractTimestamps(left @ StringTypeExpression(), _, _, _) =>
t.copy(left = Cast(t.left, t.right.dataType))
- case t @ SubtractTimestamps(_, right @ StringType(), _, _) =>
+ case t @ SubtractTimestamps(_, right @ StringTypeExpression(), _, _) =>
t.copy(right = Cast(right, t.left.dataType))
}
}
@@ -296,9 +297,9 @@ object AnsiTypeCoercion extends TypeCoercionBase {
case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
- case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) =>
+ case s @ SubtractTimestamps(DateTypeExpression(), AnyTimestampType(), _, _) =>
s.copy(left = Cast(s.left, s.right.dataType))
- case s @ SubtractTimestamps(AnyTimestampType(), DateType(), _, _) =>
+ case s @ SubtractTimestamps(AnyTimestampType(), DateTypeExpression(), _, _) =>
s.copy(right = Cast(s.right, s.left.dataType))
case s @ SubtractTimestamps(AnyTimestampType(), AnyTimestampType(), _, _)
if s.left.dataType != s.right.dataType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 46fbf071f43..90fd13dfb54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -105,7 +105,7 @@ object DecimalPrecision extends TypeCoercionRule {
*/
private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = {
- case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) =>
+ case GreaterThan(i @ IntegralTypeExpression(), DecimalLiteral(value)) =>
if (DecimalLiteral.smallerThanSmallestLong(value)) {
TrueLiteral
} else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -114,7 +114,7 @@ object DecimalPrecision extends TypeCoercionRule {
GreaterThan(i, Literal(value.floor.toLong))
}
- case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+ case GreaterThanOrEqual(i @ IntegralTypeExpression(), DecimalLiteral(value)) =>
if (DecimalLiteral.smallerThanSmallestLong(value)) {
TrueLiteral
} else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -123,7 +123,7 @@ object DecimalPrecision extends TypeCoercionRule {
GreaterThanOrEqual(i, Literal(value.ceil.toLong))
}
- case LessThan(i @ IntegralType(), DecimalLiteral(value)) =>
+ case LessThan(i @ IntegralTypeExpression(), DecimalLiteral(value)) =>
if (DecimalLiteral.smallerThanSmallestLong(value)) {
FalseLiteral
} else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -132,7 +132,7 @@ object DecimalPrecision extends TypeCoercionRule {
LessThan(i, Literal(value.ceil.toLong))
}
- case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+ case LessThanOrEqual(i @ IntegralTypeExpression(), DecimalLiteral(value)) =>
if (DecimalLiteral.smallerThanSmallestLong(value)) {
FalseLiteral
} else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -141,7 +141,7 @@ object DecimalPrecision extends TypeCoercionRule {
LessThanOrEqual(i, Literal(value.floor.toLong))
}
- case GreaterThan(DecimalLiteral(value), i @ IntegralType()) =>
+ case GreaterThan(DecimalLiteral(value), i @ IntegralTypeExpression()) =>
if (DecimalLiteral.smallerThanSmallestLong(value)) {
FalseLiteral
} else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -150,7 +150,7 @@ object DecimalPrecision extends TypeCoercionRule {
GreaterThan(Literal(value.ceil.toLong), i)
}
- case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+ case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralTypeExpression()) =>
if (DecimalLiteral.smallerThanSmallestLong(value)) {
FalseLiteral
} else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -159,7 +159,7 @@ object DecimalPrecision extends TypeCoercionRule {
GreaterThanOrEqual(Literal(value.floor.toLong), i)
}
- case LessThan(DecimalLiteral(value), i @ IntegralType()) =>
+ case LessThan(DecimalLiteral(value), i @ IntegralTypeExpression()) =>
if (DecimalLiteral.smallerThanSmallestLong(value)) {
TrueLiteral
} else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -168,7 +168,7 @@ object DecimalPrecision extends TypeCoercionRule {
LessThan(Literal(value.floor.toLong), i)
}
- case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+ case LessThanOrEqual(DecimalLiteral(value), i @ IntegralTypeExpression()) =>
if (DecimalLiteral.smallerThanSmallestLong(value)) {
TrueLiteral
} else if (DecimalLiteral.largerThanLargestLong(value)) {
@@ -208,9 +208,9 @@ object DecimalPrecision extends TypeCoercionRule {
b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
- case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) =>
+ case (l @ IntegralTypeExpression(), r @ DecimalType.Expression(_, _)) =>
b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
- case (l @ DecimalType.Expression(_, _), r @ IntegralType()) =>
+ case (l @ DecimalType.Expression(_, _), r @ IntegralTypeExpression()) =>
b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) =>
b.makeCopy(Array(l, Cast(r, DoubleType)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index bd2255134fc..ae4db0575ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -461,8 +461,8 @@ abstract class TypeCoercionBase {
m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
// Hive lets you do aggregation of timestamps... for some reason
- case Sum(e @ TimestampType(), _) => Sum(Cast(e, DoubleType))
- case Average(e @ TimestampType(), _) => Average(Cast(e, DoubleType))
+ case Sum(e @ TimestampTypeExpression(), _) => Sum(Cast(e, DoubleType))
+ case Average(e @ TimestampTypeExpression(), _) => Average(Cast(e, DoubleType))
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
@@ -1105,18 +1105,18 @@ object TypeCoercion extends TypeCoercionBase {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ BinaryArithmetic(left @ StringType(), right)
+ case a @ BinaryArithmetic(left @ StringTypeExpression(), right)
if right.dataType != CalendarIntervalType =>
a.makeCopy(Array(Cast(left, DoubleType), right))
- case a @ BinaryArithmetic(left, right @ StringType())
+ case a @ BinaryArithmetic(left, right @ StringTypeExpression())
if left.dataType != CalendarIntervalType =>
a.makeCopy(Array(left, Cast(right, DoubleType)))
// For equality between string and timestamp we cast the string to a timestamp
// so that things like rounding of subsecond precision does not affect the comparison.
- case p @ Equality(left @ StringType(), right @ TimestampType()) =>
+ case p @ Equality(left @ StringTypeExpression(), right @ TimestampTypeExpression()) =>
p.makeCopy(Array(Cast(left, TimestampType), right))
- case p @ Equality(left @ TimestampType(), right @ StringType()) =>
+ case p @ Equality(left @ TimestampTypeExpression(), right @ StringTypeExpression()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))
case p @ BinaryComparison(left, right)
@@ -1142,30 +1142,30 @@ object TypeCoercion extends TypeCoercionBase {
// We may simplify the expression if one side is literal numeric values
// TODO: Maybe these rules should go into the optimizer.
- case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
+ case EqualTo(bool @ BooleanTypeExpression(), Literal(value, _: NumericType))
if trueValues.contains(value) => bool
- case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
+ case EqualTo(bool @ BooleanTypeExpression(), Literal(value, _: NumericType))
if falseValues.contains(value) => Not(bool)
- case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
+ case EqualTo(Literal(value, _: NumericType), bool @ BooleanTypeExpression())
if trueValues.contains(value) => bool
- case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
+ case EqualTo(Literal(value, _: NumericType), bool @ BooleanTypeExpression())
if falseValues.contains(value) => Not(bool)
- case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
+ case EqualNullSafe(bool @ BooleanTypeExpression(), Literal(value, _: NumericType))
if trueValues.contains(value) => And(IsNotNull(bool), bool)
- case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
+ case EqualNullSafe(bool @ BooleanTypeExpression(), Literal(value, _: NumericType))
if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
- case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
+ case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanTypeExpression())
if trueValues.contains(value) => And(IsNotNull(bool), bool)
- case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
+ case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanTypeExpression())
if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
- case EqualTo(left @ BooleanType(), right @ NumericType()) =>
+ case EqualTo(left @ BooleanTypeExpression(), right @ NumericTypeExpression()) =>
EqualTo(Cast(left, right.dataType), right)
- case EqualTo(left @ NumericType(), right @ BooleanType()) =>
+ case EqualTo(left @ NumericTypeExpression(), right @ BooleanTypeExpression()) =>
EqualTo(left, Cast(right, left.dataType))
- case EqualNullSafe(left @ BooleanType(), right @ NumericType()) =>
+ case EqualNullSafe(left @ BooleanTypeExpression(), right @ NumericTypeExpression()) =>
EqualNullSafe(Cast(left, right.dataType), right)
- case EqualNullSafe(left @ NumericType(), right @ BooleanType()) =>
+ case EqualNullSafe(left @ NumericTypeExpression(), right @ BooleanTypeExpression()) =>
EqualNullSafe(left, Cast(right, left.dataType))
}
}
@@ -1175,13 +1175,13 @@ object TypeCoercion extends TypeCoercionBase {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
- case d @ DateAdd(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
+ case d @ DateAdd(StringTypeExpression(), _) => d.copy(startDate = Cast(d.startDate, DateType))
case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
- case d @ DateSub(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
+ case d @ DateSub(StringTypeExpression(), _) => d.copy(startDate = Cast(d.startDate, DateType))
- case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) =>
+ case s @ SubtractTimestamps(DateTypeExpression(), AnyTimestampType(), _, _) =>
s.copy(left = Cast(s.left, s.right.dataType))
- case s @ SubtractTimestamps(AnyTimestampType(), DateType(), _, _) =>
+ case s @ SubtractTimestamps(AnyTimestampType(), DateTypeExpression(), _, _) =>
s.copy(right = Cast(s.right, s.left.dataType))
case s @ SubtractTimestamps(AnyTimestampType(), AnyTimestampType(), _, _)
if s.left.dataType != s.right.dataType =>
@@ -1189,7 +1189,7 @@ object TypeCoercion extends TypeCoercionBase {
val newRight = castIfNotSameType(s.right, TimestampNTZType)
s.copy(left = newLeft, right = newRight)
- case t @ TimeAdd(StringType(), _, _) => t.copy(start = Cast(t.start, TimestampType))
+ case t @ TimeAdd(StringTypeExpression(), _, _) => t.copy(start = Cast(t.start, TimestampType))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index f498282d4f3..01fa27822b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -121,16 +121,7 @@ protected[sql] object AnyDataType extends AbstractDataType with Serializable {
*/
protected[sql] abstract class AtomicType extends DataType
-object AtomicType {
- /**
- * Enables matching against AtomicType for expressions:
- * {{{
- * case Cast(child @ AtomicType(), StringType) =>
- * ...
- * }}}
- */
- def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[AtomicType]
-}
+object AtomicType
/**
@@ -143,15 +134,6 @@ abstract class NumericType extends AtomicType
private[spark] object NumericType extends AbstractDataType {
- /**
- * Enables matching against NumericType for expressions:
- * {{{
- * case Cast(child @ NumericType(), StringType) =>
- * ...
- * }}}
- */
- def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
-
override private[spark] def defaultConcreteType: DataType = DoubleType
override private[spark] def simpleString: String = "numeric"
@@ -162,15 +144,6 @@ private[spark] object NumericType extends AbstractDataType {
private[sql] object IntegralType extends AbstractDataType {
- /**
- * Enables matching against IntegralType for expressions:
- * {{{
- * case Cast(child @ IntegralType(), StringType) =>
- * ...
- * }}}
- */
- def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
-
override private[sql] def defaultConcreteType: DataType = IntegerType
override private[sql] def simpleString: String = "integral"
@@ -182,16 +155,7 @@ private[sql] object IntegralType extends AbstractDataType {
private[sql] abstract class IntegralType extends NumericType
-private[sql] object FractionalType {
- /**
- * Enables matching against FractionalType for expressions:
- * {{{
- * case Cast(child @ FractionalType(), StringType) =>
- * ...
- * }}}
- */
- def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType]
-}
+private[sql] object FractionalType
private[sql] abstract class FractionalType extends NumericType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index f78a8de5e6a..893a41f3e39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -30,7 +30,6 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkThrowable
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.analysis.Resolver
-import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
@@ -50,15 +49,6 @@ import org.apache.spark.util.Utils
@JsonSerialize(using = classOf[DataTypeJsonSerializer])
@JsonDeserialize(using = classOf[DataTypeJsonDeserializer])
abstract class DataType extends AbstractDataType {
- /**
- * Enables matching against DataType for expressions:
- * {{{
- * case Cast(child @ BinaryType(), StringType) =>
- * ...
- * }}}
- */
- private[sql] def unapply(e: Expression): Boolean = e.dataType == this
-
/**
* The default size of a value of this data type, used internally for size estimation.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala
new file mode 100644
index 00000000000..f88e266b943
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+abstract class DataTypeExpression(val dataType: DataType) {
+ /**
+ * Enables matching an expression whose data type equals this extractor's `dataType`:
+ * {{{
+ * case Sum(e @ TimestampTypeExpression(), _) =>
+ * ...
+ * }}}
+ */
+ private[sql] def unapply(e: Expression): Boolean = e.dataType == dataType
+}
+
+case object BooleanTypeExpression extends DataTypeExpression(BooleanType)
+case object StringTypeExpression extends DataTypeExpression(StringType)
+case object TimestampTypeExpression extends DataTypeExpression(TimestampType)
+case object DateTypeExpression extends DataTypeExpression(DateType)
+case object ByteTypeExpression extends DataTypeExpression(ByteType)
+case object ShortTypeExpression extends DataTypeExpression(ShortType)
+case object IntegerTypeExpression extends DataTypeExpression(IntegerType)
+case object LongTypeExpression extends DataTypeExpression(LongType)
+case object DoubleTypeExpression extends DataTypeExpression(DoubleType)
+case object FloatTypeExpression extends DataTypeExpression(FloatType)
+
+object NumericTypeExpression {
+ /**
+ * Enables matching expressions whose data type is any NumericType subclass:
+ * {{{
+ * case EqualTo(left @ BooleanTypeExpression(), right @ NumericTypeExpression()) =>
+ * ...
+ * }}}
+ */
+ def unapply(e: Expression): Boolean = {
+ e.dataType.isInstanceOf[NumericType]
+ }
+}
+
+object IntegralTypeExpression {
+ /**
+ * Enables matching expressions whose data type is any IntegralType subclass:
+ * {{{
+ * case GreaterThan(i @ IntegralTypeExpression(), DecimalLiteral(value)) =>
+ * ...
+ * }}}
+ */
+ def unapply(e: Expression): Boolean = {
+ e.dataType.isInstanceOf[IntegralType]
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 9defd87aa7d..08615b90d80 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -50,7 +50,7 @@ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateFormatter, Type
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{AtomicType, DateType, IntegralType, StringType}
+import org.apache.spark.sql.types.{AtomicType, DateType, IntegralType, IntegralTypeExpression, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -987,7 +987,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
def unapply(expr: Expression): Option[Attribute] = {
expr match {
case attr: Attribute => Some(attr)
- case Cast(child @ IntegralType(), dt: IntegralType, _, _)
+ case Cast(child @ IntegralTypeExpression(), dt: IntegralType, _, _)
if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child)
case _ => None
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org