Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/29 07:51:13 UTC

spark git commit: [SPARK-9281] [SQL] use decimal or double when parsing SQL

Repository: spark
Updated Branches:
  refs/heads/master 6309b9346 -> 15667a0af


[SPARK-9281] [SQL] use decimal or double when parsing SQL

Right now, we parse every floating-point number in SQL as a double. When such a literal is used in an expression together with a DecimalType, it turns the decimal into a double as well, and some precision is lost along the way.

This PR changes the parser to turn a floating-point literal into a decimal or a double, depending on whether or not it uses scientific notation; see https://msdn.microsoft.com/en-us/library/ms179899.aspx

This is a breaking change; should we document it somewhere?
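
As a minimal standalone sketch of the new rule (mirroring the toDecimalOrDouble helper added in the diff below; runnable on its own in a Scala REPL):

    // Scientific notation -> Double; plain decimal notation -> java.math.BigDecimal,
    // which Catalyst turns into a DecimalType literal.
    def toDecimalOrDouble(value: String): Any = {
      val decimal = BigDecimal(value)
      if (value.contains('E') || value.contains('e')) decimal.doubleValue()
      else decimal.underlying()
    }

    toDecimalOrDouble("0.3")    // java.math.BigDecimal 0.3, kept exact
    toDecimalOrDouble("1.5E2")  // Double 150.0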

Author: Davies Liu <da...@databricks.com>

Closes #7642 from davies/parse_decimal and squashes the following commits:

1f576d9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into parse_decimal
5e142b6 [Davies Liu] fix scala style
eca99de [Davies Liu] fix tests
2afe702 [Davies Liu] Merge branch 'master' of github.com:apache/spark into parse_decimal
f4a320b [Davies Liu] Update SqlParser.scala
1c48e34 [Davies Liu] use decimal or double when parsing SQL


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/15667a0a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/15667a0a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/15667a0a

Branch: refs/heads/master
Commit: 15667a0afa5fb17f4cc6fbf32b2ddb573630f20a
Parents: 6309b93
Author: Davies Liu <da...@databricks.com>
Authored: Tue Jul 28 22:51:08 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Jul 28 22:51:08 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/SqlParser.scala   | 14 +++++-
 .../catalyst/analysis/HiveTypeCoercion.scala    | 50 +++++++++++++-------
 .../sql/catalyst/analysis/AnalysisSuite.scala   |  4 +-
 .../apache/spark/sql/MathExpressionsSuite.scala |  3 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 14 +++---
 .../org/apache/spark/sql/json/JsonSuite.scala   | 14 +++---
 6 files changed, 62 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/15667a0a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b423f0f..e5f115f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -332,8 +332,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
   protected lazy val numericLiteral: Parser[Literal] =
     ( integral  ^^ { case i => Literal(toNarrowestIntegerType(i)) }
     | sign.? ~ unsignedFloat ^^ {
-      // TODO(davies): some precisions may loss, we should create decimal literal
-      case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
+      case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f))
     }
     )
 
@@ -420,6 +419,17 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
     }
   }
 
+  private def toDecimalOrDouble(value: String): Any = {
+    val decimal = BigDecimal(value)
+    // follow the behavior in MS SQL Server
+    // https://msdn.microsoft.com/en-us/library/ms179899.aspx
+    if (value.contains('E') || value.contains('e')) {
+      decimal.doubleValue()
+    } else {
+      decimal.underlying()
+    }
+  }
+
   protected lazy val baseExpression: Parser[Expression] =
     ( "*" ^^^ UnresolvedStar(None)
     | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) }

http://git-wip-us.apache.org/repos/asf/spark/blob/15667a0a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e052750..ecc4898 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -109,13 +109,35 @@ object HiveTypeCoercion {
    * Find the tightest common type of a set of types by continuously applying
    * `findTightestCommonTypeOfTwo` on these types.
    */
-  private def findTightestCommonType(types: Seq[DataType]) = {
+  private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = {
     types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
       case None => None
       case Some(d) => findTightestCommonTypeOfTwo(d, c)
     })
   }
 
+  private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
+    case (t1: DecimalType, t2: DecimalType) =>
+      Some(DecimalPrecision.widerDecimalType(t1, t2))
+    case (t: IntegralType, d: DecimalType) =>
+      Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+    case (d: DecimalType, t: IntegralType) =>
+      Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+    case (t: FractionalType, d: DecimalType) =>
+      Some(DoubleType)
+    case (d: DecimalType, t: FractionalType) =>
+      Some(DoubleType)
+    case _ =>
+      findTightestCommonTypeToString(t1, t2)
+  }
+
+  private def findWiderCommonType(types: Seq[DataType]) = {
+    types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
+      case Some(d) => findWiderTypeForTwo(d, c)
+      case None => None
+    })
+  }
+
   /**
    * Applies any changes to [[AttributeReference]] data types that are made by other rules to
    * instances higher in the query tree.
@@ -182,20 +204,7 @@ object HiveTypeCoercion {
 
       val castedTypes = left.output.zip(right.output).map {
         case (lhs, rhs) if lhs.dataType != rhs.dataType =>
-          (lhs.dataType, rhs.dataType) match {
-            case (t1: DecimalType, t2: DecimalType) =>
-              Some(DecimalPrecision.widerDecimalType(t1, t2))
-            case (t: IntegralType, d: DecimalType) =>
-              Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
-            case (d: DecimalType, t: IntegralType) =>
-              Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
-            case (t: FractionalType, d: DecimalType) =>
-              Some(DoubleType)
-            case (d: DecimalType, t: FractionalType) =>
-              Some(DoubleType)
-            case _ =>
-              findTightestCommonTypeToString(lhs.dataType, rhs.dataType)
-          }
+          findWiderTypeForTwo(lhs.dataType, rhs.dataType)
         case other => None
       }
 
@@ -236,8 +245,13 @@ object HiveTypeCoercion {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case a @ BinaryArithmetic(left @ StringType(), r) =>
-        a.makeCopy(Array(Cast(left, DoubleType), r))
+      case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) =>
+        a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right))
+      case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) =>
+        a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT)))
+
+      case a @ BinaryArithmetic(left @ StringType(), right) =>
+        a.makeCopy(Array(Cast(left, DoubleType), right))
       case a @ BinaryArithmetic(left, right @ StringType()) =>
         a.makeCopy(Array(left, Cast(right, DoubleType)))
 
@@ -543,7 +557,7 @@ object HiveTypeCoercion {
       // compatible with every child column.
       case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
         val types = es.map(_.dataType)
-        findTightestCommonTypeAndPromoteToString(types) match {
+        findWiderCommonType(types) match {
           case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
           case None => c
         }

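As a toy model of the widening fold introduced above (a sketch only; the real rule delegates precision/scale arithmetic to DecimalPrecision.widerDecimalType and covers more types, and the names here are made up for illustration):

    sealed trait DT
    case object NullT extends DT
    case object DoubleT extends DT
    case class DecimalT(precision: Int, scale: Int) extends DT

    def widerOfTwo(a: DT, b: DT): Option[DT] = (a, b) match {
      case (NullT, t) => Some(t)
      case (t, NullT) => Some(t)
      case (DecimalT(p1, s1), DecimalT(p2, s2)) =>
        // the wider decimal keeps the max integral digits and the max scale
        val scale = s1 max s2
        Some(DecimalT(((p1 - s1) max (p2 - s2)) + scale, scale))
      case (_: DecimalT, DoubleT) | (DoubleT, _: DecimalT) => Some(DoubleT)
      case (t1, t2) if t1 == t2 => Some(t1)
      case _ => None
    }

    def widerCommon(types: Seq[DT]): Option[DT] =
      types.foldLeft[Option[DT]](Some(NullT)) {
        case (Some(d), c) => widerOfTwo(d, c)
        case (None, _) => None
      }

    widerCommon(Seq(DecimalT(3, 2), DecimalT(10, 1)))  // Some(DecimalT(11, 2))
    widerCommon(Seq(DecimalT(3, 2), DoubleT))          // Some(DoubleT)

Note also the BinaryArithmetic cases in this hunk: a string operand paired with a decimal is now cast to DecimalType.SYSTEM_DEFAULT rather than DoubleType, so the decimal side keeps its precision.
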
http://git-wip-us.apache.org/repos/asf/spark/blob/15667a0a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 4589fac..221b4e9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -145,11 +145,11 @@ class AnalysisSuite extends AnalysisTest {
         'e / 'e as 'div5))
     val pl = plan.asInstanceOf[Project].projectList
 
-    // StringType will be promoted into Double
     assert(pl(0).dataType == DoubleType)
     assert(pl(1).dataType == DoubleType)
     assert(pl(2).dataType == DoubleType)
-    assert(pl(3).dataType == DoubleType)
+    // StringType will be promoted into Decimal(38, 18)
+    assert(pl(3).dataType == DecimalType(38, 29))
     assert(pl(4).dataType == DoubleType)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/15667a0a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 2125670..8cf2ef5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -216,7 +216,8 @@ class MathExpressionsSuite extends QueryTest {
     checkAnswer(
       ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
         s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
-      Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142))
+      Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
+        BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
     )
   }
 
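A brief aside on the expected values above: since $pi is now parsed as a decimal literal, round returns decimals whose scale matches the rounding position, and "0E3" is java.math.BigDecimal's notation for zero with a scale of -3 (zero rounded at the thousands place):

    // scale is negative when rounding to the left of the decimal point
    new java.math.BigDecimal("0E3").scale    // -3
    new java.math.BigDecimal("3.142").scale  // 3
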

http://git-wip-us.apache.org/repos/asf/spark/blob/15667a0a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 42724ed..d13dde1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -368,7 +368,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       Row(1))
     checkAnswer(
       sql("SELECT COALESCE(null, 1, 1.5)"),
-      Row(1.toDouble))
+      Row(BigDecimal(1)))
     checkAnswer(
       sql("SELECT COALESCE(null, null, null)"),
       Row(null))
@@ -1234,19 +1234,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
 
   test("Floating point number format") {
     checkAnswer(
-      sql("SELECT 0.3"), Row(0.3)
+      sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying())
     )
 
     checkAnswer(
-      sql("SELECT -0.8"), Row(-0.8)
+      sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying())
     )
 
     checkAnswer(
-      sql("SELECT .5"), Row(0.5)
+      sql("SELECT .5"), Row(BigDecimal(0.5))
     )
 
     checkAnswer(
-      sql("SELECT -.18"), Row(-0.18)
+      sql("SELECT -.18"), Row(BigDecimal(-0.18))
     )
   }
 
@@ -1279,11 +1279,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     )
 
     checkAnswer(
-      sql("SELECT -5.2"), Row(-5.2)
+      sql("SELECT -5.2"), Row(BigDecimal(-5.2))
     )
 
     checkAnswer(
-      sql("SELECT +6.8"), Row(6.8)
+      sql("SELECT +6.8"), Row(BigDecimal(6.8))
     )
 
     checkAnswer(

http://git-wip-us.apache.org/repos/asf/spark/blob/15667a0a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 3ac312d..f19f22f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -422,14 +422,14 @@ class JsonSuite extends QueryTest with TestJsonData {
       Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
     )
 
-    // Widening to DoubleType
+    // Widening to DecimalType
     checkAnswer(
       sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"),
-      Row(21474836472.2) ::
-        Row(92233720368547758071.3) :: Nil
+      Row(BigDecimal("21474836472.2")) ::
+        Row(BigDecimal("92233720368547758071.3")) :: Nil
     )
 
-    // Widening to DoubleType
+    // Widening to Double
     checkAnswer(
       sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"),
       Row(101.2) :: Row(21474836471.2) :: Nil
@@ -438,13 +438,13 @@ class JsonSuite extends QueryTest with TestJsonData {
     // Number and String conflict: resolve the type as number in this query.
     checkAnswer(
       sql("select num_str + 1.2 from jsonTable where num_str > 14"),
-      Row(92233720368547758071.2)
+      Row(BigDecimal("92233720368547758071.2"))
     )
 
     // Number and String conflict: resolve the type as number in this query.
     checkAnswer(
       sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
-      Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
+      Row(new java.math.BigDecimal("92233720368547758071.2"))
     )
 
     // String and Boolean conflict: resolve the type as string.
@@ -503,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData {
     // Number and String conflict: resolve the type as number in this query.
     checkAnswer(
       sql("select num_str + 1.2 from jsonTable where num_str > 13"),
-      Row(14.3) :: Row(92233720368547758071.2) :: Nil
+      Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil
     )
   }
 

