Posted to commits@spark.apache.org by we...@apache.org on 2019/07/16 07:43:44 UTC

[spark] branch master updated: [SPARK-28395][SQL] Division operator support integral division

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6926849  [SPARK-28395][SQL] Division operator support integral division
6926849 is described below

commit 69268492471137dd7a3da54c218026c3b1fa7db3
Author: Yuming Wang <yu...@ebay.com>
AuthorDate: Tue Jul 16 15:43:15 2019 +0800

    [SPARK-28395][SQL] Division operator support integral division
    
    ## What changes were proposed in this pull request?
    
    PostgreSQL, Teradata, SQL Server, DB2 and Presto perform integral division with the `/` operator when both operands are integers.
    Oracle, Vertica, Hive, MySQL and MariaDB, by contrast, perform fractional division with the `/` operator.
    
    This PR adds a flag (`spark.sql.function.preferIntegralDivision`, default `false`) that controls whether the `/` operator performs integral division when both operands are integral types.
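    
    The rewrite that the flag controls can be sketched in plain Scala, without Spark on the classpath. This is only an illustrative model, not the patch itself: `DivisionCoercionSketch`, `Col`, `Operand` and `coerceDivide` are hypothetical names, and the types are simplified stand-ins for Catalyst's expressions and data types.
    
    ```scala
    // Hypothetical, simplified model of the coercion rule; these are NOT Spark's classes.
    object DivisionCoercionSketch {
      sealed trait DataType
      case object IntType extends DataType
      case object LongType extends DataType
      case object DoubleType extends DataType
    
      sealed trait Expr
      // Operands carry a data type; Col is an illustrative stand-in for any typed input.
      sealed trait Operand extends Expr { def dataType: DataType }
      final case class Col(name: String, dataType: DataType) extends Operand
      final case class Cast(child: Operand, dataType: DataType) extends Operand
      final case class Divide(left: Operand, right: Operand) extends Expr
      final case class IntegralDivide(left: Operand, right: Operand) extends Expr
    
      private def isIntegral(t: DataType): Boolean = t == IntType || t == LongType
    
      // With the flag on and two integral operands, `/` becomes integral division.
      // In every other case both operands are cast to Double (fractional division).
      def coerceDivide(d: Divide, preferIntegralDivision: Boolean): Expr = d match {
        case Divide(l, r)
            if preferIntegralDivision && isIntegral(l.dataType) && isIntegral(r.dataType) =>
          IntegralDivide(l, r)
        case Divide(l, r) =>
          Divide(Cast(l, DoubleType), Cast(r, DoubleType))
      }
    }
    ```
    
    In this model, `coerceDivide(Divide(Col("a", IntType), Col("b", LongType)), preferIntegralDivision = true)` yields an `IntegralDivide`, while the same call with the flag off, or with a `DoubleType` operand on either side, yields a `Divide` over two casts to `DoubleType`.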
    
    Examples:
    
    **PostgreSQL**:
    ```sql
    postgres=# select substr(version(), 0, 16), cast(10 as int) / cast(3 as int), cast(10.1 as float8) / cast(3 as int), cast(10 as int) / cast(3.1 as float8), cast(10.1 as float8)/cast(3.1 as float8);
         substr      | ?column? |     ?column?     |    ?column?     |     ?column?
    -----------------+----------+------------------+-----------------+------------------
     PostgreSQL 11.3 |        3 | 3.36666666666667 | 3.2258064516129 | 3.25806451612903
    (1 row)
    ```
    **SQL Server**:
    ```sql
    1> select cast(10 as int) / cast(3 as int), cast(10.1 as float) / cast(3 as int), cast(10 as int) / cast(3.1 as float), cast(10.1 as float)/cast(3.1 as float);
    2> go
    
    ----------- ------------------------ ------------------------ ------------------------
              3       3.3666666666666667        3.225806451612903        3.258064516129032
    
    (1 rows affected)
    ```
    **DB2**:
    ```sql
    [db2inst12f3c821d36b7 ~]$ db2 "select cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double) from table (sysproc.env_get_inst_info())"
    
    1           2                        3                        4
    ----------- ------------------------ ------------------------ ------------------------
              3   +3.36666666666667E+000   +3.22580645161290E+000   +3.25806451612903E+000
    
      1 record(s) selected.
    ```
    **Presto**:
    ```sql
    presto> select cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double);
     _col0 |       _col1        |       _col2       |       _col3
    -------+--------------------+-------------------+-------------------
         3 | 3.3666666666666667 | 3.225806451612903 | 3.258064516129032
    (1 row)
    ```
    **Teradata**:
    ![image](https://user-images.githubusercontent.com/5399861/61200701-e97d5380-a714-11e9-9a1d-57fd99d38c8d.png)
    
    **Oracle**:
    ```sql
    SQL> select 10 / 3 from dual;
    
          10/3
    ----------
    3.33333333
    ```
    **Vertica**:
    ```sql
    dbadmin=> select version(), cast(10 as int) / cast(3 as int), cast(10.1 as float8) / cast(3 as int), cast(10 as int) / cast(3.1 as float8), cast(10.1 as float8)/cast(3.1 as float8);
                  version               |       ?column?       |     ?column?     |    ?column?     |     ?column?
    ------------------------------------+----------------------+------------------+-----------------+------------------
     Vertica Analytic Database v9.1.1-0 | 3.333333333333333333 | 3.36666666666667 | 3.2258064516129 | 3.25806451612903
    (1 row)
    ```
    **Hive**:
    ```sql
    hive> select cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double);
    OK
    3.3333333333333335	3.3666666666666667	3.225806451612903	3.258064516129032
    Time taken: 0.143 seconds, Fetched: 1 row(s)
    ```
    **MariaDB**:
    ```sql
    MariaDB [(none)]> select version(), cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double);
    +--------------------------------------+----------------------------------+---------------------------------------+---------------------------------------+------------------------------------------+
    | version()                            | cast(10 as int) / cast(3 as int) | cast(10.1 as double) / cast(3 as int) | cast(10 as int) / cast(3.1 as double) | cast(10.1 as double)/cast(3.1 as double) |
    +--------------------------------------+----------------------------------+---------------------------------------+---------------------------------------+------------------------------------------+
    | 10.4.6-MariaDB-1:10.4.6+maria~bionic |                           3.3333 |                    3.3666666666666667 |                     3.225806451612903 |                        3.258064516129032 |
    +--------------------------------------+----------------------------------+---------------------------------------+---------------------------------------+------------------------------------------+
    1 row in set (0.000 sec)
    ```
    **MySQL**:
    ```sql
    mysql>  select version(), 10 / 3, 10 / 3.1, 10.1 / 3, 10.1 / 3.1;
    +-----------+--------+----------+----------+------------+
    | version() | 10 / 3 | 10 / 3.1 | 10.1 / 3 | 10.1 / 3.1 |
    +-----------+--------+----------+----------+------------+
    | 8.0.16    | 3.3333 |   3.2258 |  3.36667 |    3.25806 |
    +-----------+--------+----------+----------+------------+
    1 row in set (0.00 sec)
    ```
    ## How was this patch tested?
    
    Unit tests added to `TypeCoercionSuite`.
    
    Closes #25158 from wangyum/SPARK-28395.
    
    Authored-by: Yuming Wang <yu...@ebay.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala | 11 ++++++---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  8 +++++++
 .../sql/catalyst/analysis/TypeCoercionSuite.scala  | 27 ++++++++++++++++++++--
 3 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 1fdec89..3125f8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -59,7 +59,7 @@ object TypeCoercion {
       CaseWhenCoercion ::
       IfCoercion ::
       StackCoercion ::
-      Division ::
+      Division(conf) ::
       ImplicitTypeCasts ::
       DateTimeOperations ::
       WindowFrameCoercion ::
@@ -666,7 +666,7 @@ object TypeCoercion {
    * Hive only performs integral division with the DIV operator. The arguments to / are always
    * converted to fractional types.
    */
-  object Division extends TypeCoercionRule {
+  case class Division(conf: SQLConf)  extends TypeCoercionRule {
     override protected def coerceTypes(
         plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who has not been resolved yet,
@@ -677,7 +677,12 @@ object TypeCoercion {
       case d: Divide if d.dataType == DoubleType => d
       case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
       case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) =>
-        Divide(Cast(left, DoubleType), Cast(right, DoubleType))
+        (left.dataType, right.dataType) match {
+          case (_: IntegralType, _: IntegralType) if conf.preferIntegralDivision =>
+            IntegralDivide(left, right)
+          case _ =>
+            Divide(Cast(left, DoubleType), Cast(right, DoubleType))
+        }
     }
 
     private def isNumericOrNull(ex: Expression): Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f76103e..57f5128 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1524,6 +1524,12 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision")
+    .doc("When true, will perform integral division with the / operator " +
+      "if both sides are integral types.")
+    .booleanConf
+    .createWithDefault(false)
+
   val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION =
     buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation")
     .internal()
@@ -2294,6 +2300,8 @@ class SQLConf extends Serializable with Logging {
 
   def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)
 
+  def preferIntegralDivision: Boolean = getConf(PREFER_INTEGRAL_DIVISION)
+
   def allowCreatingManagedTableUsingNonemptyLocation: Boolean =
     getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index a725e4b..949bb30 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1456,7 +1456,7 @@ class TypeCoercionSuite extends AnalysisTest {
 
   test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
     "in aggregation function like sum") {
-    val rules = Seq(FunctionArgumentConversion, Division)
+    val rules = Seq(FunctionArgumentConversion, Division(conf))
     // Casts Integer to Double
     ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
     // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
@@ -1475,12 +1475,35 @@ class TypeCoercionSuite extends AnalysisTest {
   }
 
   test("SPARK-17117 null type coercion in divide") {
-    val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts)
+    val rules = Seq(FunctionArgumentConversion, Division(conf), ImplicitTypeCasts)
     val nullLit = Literal.create(null, NullType)
     ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
     ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
   }
 
+  test("SPARK-28395 Division operator support integral division") {
+    val rules = Seq(FunctionArgumentConversion, Division(conf))
+    Seq(true, false).foreach { preferIntegralDivision =>
+      withSQLConf(SQLConf.PREFER_INTEGRAL_DIVISION.key -> s"$preferIntegralDivision") {
+        val result1 = if (preferIntegralDivision) {
+          IntegralDivide(1L, 1L)
+        } else {
+          Divide(Cast(1L, DoubleType), Cast(1L, DoubleType))
+        }
+        ruleTest(rules, Divide(1L, 1L), result1)
+        val result2 = if (preferIntegralDivision) {
+          IntegralDivide(1, Cast(1, ShortType))
+        } else {
+          Divide(Cast(1, DoubleType), Cast(Cast(1, ShortType), DoubleType))
+        }
+        ruleTest(rules, Divide(1, Cast(1, ShortType)), result2)
+
+        ruleTest(rules, Divide(1L, 1D), Divide(Cast(1L, DoubleType), Cast(1D, DoubleType)))
+        ruleTest(rules, Divide(Decimal(1.1), 1L), Divide(Decimal(1.1), 1L))
+      }
+    }
+  }
+
   test("binary comparison with string promotion") {
     val rule = TypeCoercion.PromoteStrings(conf)
     ruleTest(rule,

