Posted to commits@spark.apache.org by yu...@apache.org on 2019/09/29 07:47:33 UTC

[spark] 01/01: Add replaceConstraintsWithCast

This is an automated email from the ASF dual-hosted git repository.

yumwang pushed a commit to branch SPARK-29231
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 52c43ed599fcb1004d6132415c8f3d034933e983
Author: Yuming Wang <yu...@ebay.com>
AuthorDate: Sun Sep 29 15:46:01 2019 +0800

    Add replaceConstraintsWithCast
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala  |  2 +-
 .../plans/logical/QueryPlanConstraints.scala        | 21 +++++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 8ab25f2..f1e72b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -266,7 +266,7 @@ object TypeCoercion {
    * system limitation, this rule will truncate the decimal type. If a decimal and other fractional
    * types are compared, returns a double type.
    */
-  private def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = {
+  def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = {
     (dt1, dt2) match {
       case (t1: DecimalType, t2: DecimalType) =>
         Some(DecimalPrecision.widerDecimalType(t1, t2))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index 1355003..fc6211f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{findTightestCommonType, findWiderTypeForDecimal}
 import org.apache.spark.sql.catalyst.expressions._
 
 
@@ -67,6 +68,14 @@ trait ConstraintHelper {
         val candidateConstraints = constraints - eq
         inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
         inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
+
+      case eq @ EqualTo(Cast(l: Attribute, dataType, _), r: Attribute)
+        if findTightestCommonType(l.dataType, r.dataType)
+          .orElse(findWiderTypeForDecimal(l.dataType, r.dataType)).exists(_.sameType(dataType)) =>
+        val candidateConstraints = constraints - eq
+        inferredConstraints ++= replaceConstraintsWithCast(candidateConstraints, l, r)
+        inferredConstraints ++= replaceConstraintsWithCast(candidateConstraints, r, l)
+
       case _ => // No inference
     }
     inferredConstraints -- constraints
@@ -79,6 +88,18 @@ trait ConstraintHelper {
     case e: Expression if e.semanticEquals(source) => destination
   })
 
+  private def replaceConstraintsWithCast(
+      constraints: Set[Expression],
+      source: Expression,
+      destination: Attribute): Set[Expression] = {
+    replaceConstraints(constraints, source, destination).map {
+      case eq @ EqualTo(l, r) if l.dataType != r.dataType =>
+        eq.copy(right = Cast(r, l.dataType))
+      case other =>
+        other
+    }
+  }
+
   /**
    * Infers a set of `isNotNull` constraints from null intolerant expressions as well as
    * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
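
What the patch does, in short: the existing rule only infers constraints across plain `attr = attr`
equalities; the new case also fires for `cast(attr) = attr` when the cast target is the type that
coercion itself would pick (findTightestCommonType, falling back to findWiderTypeForDecimal, which
is why the TypeCoercion hunk drops that method's `private` modifier), and replaceConstraintsWithCast
re-inserts a Cast wherever the substitution leaves an EqualTo with mismatched data types. Below is a
minimal sketch of the intended effect, assuming this patched branch is on the classpath; the object
name and the example columns `a: int` / `b: bigint` are made up for illustration, not taken from the
commit.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ConstraintHelper
import org.apache.spark.sql.types.{IntegerType, LongType}

// Hypothetical driver mixing in the (patched) ConstraintHelper trait.
object InferWithCastExample extends ConstraintHelper {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType)()   // e.g. t1.a: int
    val b = AttributeReference("b", LongType)()      // e.g. t2.b: bigint
    // Constraints as they would look after type coercion of a join condition
    // `t1.a = t2.b` plus a filter `t1.a = 1`.
    val constraints: Set[Expression] = Set(
      EqualTo(Cast(a, LongType), b),
      EqualTo(a, Literal(1)))
    // With the new case, an additional constraint on `b` is inferred,
    // roughly `b = cast(1 as bigint)`, which the optimizer's filter
    // inference can then push to the other join side.
    println(inferAdditionalConstraints(constraints))
  }
}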

