Posted to commits@spark.apache.org by ma...@apache.org on 2014/10/09 03:06:30 UTC

git commit: [SPARK-3707] [SQL] Fix bug of type coercion in DIV

Repository: spark
Updated Branches:
  refs/heads/master 00b779172 -> 4ec931951


[SPARK-3707] [SQL] Fix bug of type coercion in DIV

Calling `BinaryArithmetic.dataType` throws an exception until the expression is resolved, but the type coercion rule `Division` did not account for this.
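
A minimal, standalone sketch (not Spark's actual classes; the simplified Expression/Divide types here are assumptions for illustration) of why the rule must check `resolved` before touching `dataType`:

    sealed trait DataType
    case object DoubleType extends DataType
    case object IntegerType extends DataType

    trait Expression {
      def resolved: Boolean
      def dataType: DataType
    }

    case class Literal(value: Any, dataType: DataType) extends Expression {
      val resolved = true
    }

    case class UnresolvedAttribute(name: String) extends Expression {
      val resolved = false
      // Like BinaryArithmetic.dataType in Spark, asking an unresolved node for its type throws.
      def dataType: DataType =
        throw new UnsupportedOperationException(s"dataType of unresolved attribute $name")
    }

    case class Divide(left: Expression, right: Expression) extends Expression {
      def resolved: Boolean = left.resolved && right.resolved
      def dataType: DataType =
        if (resolved) left.dataType
        else throw new UnsupportedOperationException("dataType on unresolved Divide")
    }

    object CoercionSketch {
      // Buggy shape of the rule: the guard calls dataType unconditionally and can throw.
      def buggyRule(e: Expression): Expression = e match {
        case d: Divide if d.dataType == DoubleType => d
        case other => other
      }

      // Fixed shape, as in the patch below: check `resolved` first, so unresolved
      // expressions simply fall through instead of throwing.
      def fixedRule(e: Expression): Expression = e match {
        case d: Divide if d.resolved && d.dataType == DoubleType => d
        case other => other
      }

      def main(args: Array[String]): Unit = {
        val unresolved = Divide(UnresolvedAttribute("a"), Literal(2, IntegerType))
        println(fixedRule(unresolved))   // falls through safely
        // buggyRule(unresolved)         // would throw UnsupportedOperationException
      }
    }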

Author: Cheng Hao <ha...@intel.com>

Closes #2559 from chenghao-intel/type_coercion and squashes the following commits:

199a85d [Cheng Hao] Simplify the divide rule
dc55218 [Cheng Hao] fix bug of type coercion in div


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4ec93195
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4ec93195
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4ec93195

Branch: refs/heads/master
Commit: 4ec931951fea4efbfe5db39cf581704df7d2775b
Parents: 00b7791
Author: Cheng Hao <ha...@intel.com>
Authored: Wed Oct 8 17:52:27 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Oct 8 17:52:27 2014 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    |  7 +++-
 .../sql/catalyst/analysis/AnalysisSuite.scala   | 40 ++++++++++++++++++--
 2 files changed, 42 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4ec93195/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 79e5283..6488185 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -348,8 +348,11 @@ trait HiveTypeCoercion {
       case e if !e.childrenResolved => e
 
       // Decimal and Double remain the same
-      case d: Divide if d.dataType == DoubleType => d
-      case d: Divide if d.dataType == DecimalType => d
+      case d: Divide if d.resolved && d.dataType == DoubleType => d
+      case d: Divide if d.resolved && d.dataType == DecimalType => d
+
+      case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType))
+      case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r)
 
       case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/4ec93195/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 5809a10..7b45738 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.scalatest.{BeforeAndAfter, FunSuite}
 
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.types.IntegerType
+import org.apache.spark.sql.catalyst.types._
 
 class AnalysisSuite extends FunSuite with BeforeAndAfter {
   val caseSensitiveCatalog = new SimpleCatalog(true)
@@ -33,6 +34,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
 
   val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
+  val testRelation2 = LocalRelation(
+    AttributeReference("a", StringType)(),
+    AttributeReference("b", StringType)(),
+    AttributeReference("c", DoubleType)(),
+    AttributeReference("d", DecimalType)(),
+    AttributeReference("e", ShortType)())
 
   before {
     caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
@@ -74,7 +81,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     val e = intercept[RuntimeException] {
       caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
     }
-    assert(e.getMessage === "Table Not Found: tAbLe")
+    assert(e.getMessage == "Table Not Found: tAbLe")
 
     assert(
       caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
@@ -106,4 +113,31 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     }
     assert(e.getMessage().toLowerCase.contains("unresolved plan"))
   }
+
+  test("divide should be casted into fractional types") {
+    val testRelation2 = LocalRelation(
+      AttributeReference("a", StringType)(),
+      AttributeReference("b", StringType)(),
+      AttributeReference("c", DoubleType)(),
+      AttributeReference("d", DecimalType)(),
+      AttributeReference("e", ShortType)())
+
+    val expr0 = 'a / 2
+    val expr1 = 'a / 'b
+    val expr2 = 'a / 'c
+    val expr3 = 'a / 'd
+    val expr4 = 'e / 'e
+    val plan = caseInsensitiveAnalyze(Project(
+      Alias(expr0, s"Analyzer($expr0)")() ::
+      Alias(expr1, s"Analyzer($expr1)")() ::
+      Alias(expr2, s"Analyzer($expr2)")() ::
+      Alias(expr3, s"Analyzer($expr3)")() ::
+      Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2))
+    val pl = plan.asInstanceOf[Project].projectList
+    assert(pl(0).dataType == DoubleType)
+    assert(pl(1).dataType == DoubleType)
+    assert(pl(2).dataType == DoubleType)
+    assert(pl(3).dataType == DecimalType)
+    assert(pl(4).dataType == DoubleType)
+  }
 }

