You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/18 21:57:58 UTC

spark git commit: [SPARK-9055][SQL] WidenTypes should also support Intersect and Except

Repository: spark
Updated Branches:
  refs/heads/master cdc36eef4 -> 3d2134fc0


[SPARK-9055][SQL] WidenTypes should also support Intersect and Except

JIRA: https://issues.apache.org/jira/browse/SPARK-9055

cc rxin

Author: Yijie Shen <he...@gmail.com>

Closes #7491 from yijieshen/widen and squashes the following commits:

079fa52 [Yijie Shen] widenType support for intersect and except


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d2134fc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d2134fc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d2134fc

Branch: refs/heads/master
Commit: 3d2134fc0d90379b89da08de7614aef1ac674b1b
Parents: cdc36ee
Author: Yijie Shen <he...@gmail.com>
Authored: Sat Jul 18 12:57:53 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Jul 18 12:57:53 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    | 93 +++++++++++---------
 .../catalyst/plans/logical/basicOperators.scala |  8 ++
 .../analysis/HiveTypeCoercionSuite.scala        | 34 ++++++-
 3 files changed, 94 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3d2134fc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 50db7d2..ff20835 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import javax.annotation.Nullable
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types._
 
@@ -168,52 +168,65 @@ object HiveTypeCoercion {
    * - LongType to DoubleType
    */
   object WidenTypes extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      // TODO: unions with fixed-precision decimals
-      case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
-        val castedInput = left.output.zip(right.output).map {
-          // When a string is found on one side, make the other side a string too.
-          case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
-            (lhs, Alias(Cast(rhs, StringType), rhs.name)())
-          case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
-            (Alias(Cast(lhs, StringType), lhs.name)(), rhs)
 
-          case (lhs, rhs) if lhs.dataType != rhs.dataType =>
-            logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}")
-            findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
-              val newLeft =
-                if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
-              val newRight =
-                if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
-
-              (newLeft, newRight)
-            }.getOrElse {
-              // If there is no applicable conversion, leave expression unchanged.
-              (lhs, rhs)
-            }
+    private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan):
+        (LogicalPlan, LogicalPlan) = {
+
+      // TODO: with fixed-precision decimals
+      val castedInput = left.output.zip(right.output).map {
+        // When a string is found on one side, make the other side a string too.
+        case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
+          (lhs, Alias(Cast(rhs, StringType), rhs.name)())
+        case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
+          (Alias(Cast(lhs, StringType), lhs.name)(), rhs)
+
+        case (lhs, rhs) if lhs.dataType != rhs.dataType =>
+          logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}")
+          findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
+            val newLeft =
+              if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
+            val newRight =
+              if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
+
+            (newLeft, newRight)
+          }.getOrElse {
+            // If there is no applicable conversion, leave expression unchanged.
+            (lhs, rhs)
+          }
 
-          case other => other
-        }
+        case other => other
+      }
 
-        val (castedLeft, castedRight) = castedInput.unzip
+      val (castedLeft, castedRight) = castedInput.unzip
 
-        val newLeft =
-          if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
-            logDebug(s"Widening numeric types in union $castedLeft ${left.output}")
-            Project(castedLeft, left)
-          } else {
-            left
-          }
+      val newLeft =
+        if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
+          logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}")
+          Project(castedLeft, left)
+        } else {
+          left
+        }
 
-        val newRight =
-          if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
-            logDebug(s"Widening numeric types in union $castedRight ${right.output}")
-            Project(castedRight, right)
-          } else {
-            right
-          }
+      val newRight =
+        if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
+          logDebug(s"Widening numeric types in $planName $castedRight ${right.output}")
+          Project(castedRight, right)
+        } else {
+          right
+        }
+      (newLeft, newRight)
+    }
 
+    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
+        val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right)
         Union(newLeft, newRight)
+      case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
+        val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right)
+        Except(newLeft, newRight)
+      case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
+        val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right)
+        Intersect(newLeft, newRight)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3d2134fc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 17a9124..986c315 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -141,6 +141,10 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
 
 case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   override def output: Seq[Attribute] = left.output
+
+  override lazy val resolved: Boolean =
+    childrenResolved &&
+      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
 }
 
 case class InsertIntoTable(
@@ -437,4 +441,8 @@ case object OneRowRelation extends LeafNode {
 
 case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   override def output: Seq[Attribute] = left.output
+
+  override lazy val resolved: Boolean =
+    childrenResolved &&
+      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3d2134fc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index d0fd033..c9b3c69 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.plans.PlanTest
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types._
 
@@ -305,6 +305,38 @@ class HiveTypeCoercionSuite extends PlanTest {
     )
   }
 
+  test("WidenTypes for union except and intersect") {
+    def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
+      logical.output.zip(expectTypes).foreach { case (attr, dt) =>
+        assert(attr.dataType === dt)
+      }
+    }
+
+    val left = LocalRelation(
+      AttributeReference("i", IntegerType)(),
+      AttributeReference("u", DecimalType.Unlimited)(),
+      AttributeReference("b", ByteType)(),
+      AttributeReference("d", DoubleType)())
+    val right = LocalRelation(
+      AttributeReference("s", StringType)(),
+      AttributeReference("d", DecimalType(2, 1))(),
+      AttributeReference("f", FloatType)(),
+      AttributeReference("l", LongType)())
+
+    val wt = HiveTypeCoercion.WidenTypes
+    val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType)
+
+    val r1 = wt(Union(left, right)).asInstanceOf[Union]
+    val r2 = wt(Except(left, right)).asInstanceOf[Except]
+    val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect]
+    checkOutput(r1.left, expectedTypes)
+    checkOutput(r1.right, expectedTypes)
+    checkOutput(r2.left, expectedTypes)
+    checkOutput(r2.right, expectedTypes)
+    checkOutput(r3.left, expectedTypes)
+    checkOutput(r3.right, expectedTypes)
+  }
+
   /**
    * There are rules that need to not fire before child expressions get resolved.
    * We use this test to make sure those rules do not fire early.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org