You are viewing a plain text version of this content; the hyperlink to its canonical version was lost in the plain-text conversion.
Posted to commits@spark.apache.org by yu...@apache.org on 2019/09/29 07:47:33 UTC
[spark] 01/01: Add replaceConstraintsWithCast
This is an automated email from the ASF dual-hosted git repository.
yumwang pushed a commit to branch SPARK-29231
in repository https://gitbox.apache.org/repos/asf/spark.git
commit 52c43ed599fcb1004d6132415c8f3d034933e983
Author: Yuming Wang <yu...@ebay.com>
AuthorDate: Sun Sep 29 15:46:01 2019 +0800
Add replaceConstraintsWithCast
---
.../plans/logical/QueryPlanConstraints.scala | 25 +++++++++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -67,6 +67,18 @@ trait ConstraintHelper {
         val candidateConstraints = constraints - eq
         inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
         inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
+
+      // A resolved EqualTo compares operands of the same data type, so `cast(child) = attr`
+      // lets `attr` be replaced by the cast expression itself: every rewritten predicate stays
+      // well-typed and no narrowing (potentially lossy) cast is introduced. The guard skips a
+      // no-op `cast(a) = a`, whose rewrite would otherwise recurse forever inside `transform`.
+      case eq @ EqualTo(l @ Cast(child: Attribute, _, _), r: Attribute)
+          if !child.semanticEquals(r) =>
+        inferredConstraints ++= replaceConstraintsWithCast(constraints - eq, r, l)
+      case eq @ EqualTo(l: Attribute, r @ Cast(child: Attribute, _, _))
+          if !child.semanticEquals(l) =>
+        inferredConstraints ++= replaceConstraintsWithCast(constraints - eq, l, r)
+
       case _ => // No inference
     }
     inferredConstraints -- constraints
@@ -79,6 +91,19 @@ trait ConstraintHelper {
     case e: Expression if e.semanticEquals(source) => destination
   })
 
+  /**
+   * Replaces every occurrence of `source` in `constraints` with `destination`, a [[Cast]] that
+   * the caller knows equals `source` and has the same data type, so rewritten constraints stay
+   * well-typed. The cast's child must not be `source`: `transform` also rewrites inside the
+   * replacement, so that case would never terminate.
+   */
+  private def replaceConstraintsWithCast(
+      constraints: Set[Expression],
+      source: Attribute,
+      destination: Cast): Set[Expression] = constraints.map(_ transform {
+    case e: Expression if e.semanticEquals(source) => destination
+  })
+
   /**
    * Infers a set of `isNotNull` constraints from null intolerant expressions as well as
    * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org