You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/07/16 07:43:44 UTC
[spark] branch master updated: [SPARK-28395][SQL] Division operator
support integral division
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6926849 [SPARK-28395][SQL] Division operator support integral division
6926849 is described below
commit 69268492471137dd7a3da54c218026c3b1fa7db3
Author: Yuming Wang <yu...@ebay.com>
AuthorDate: Tue Jul 16 15:43:15 2019 +0800
[SPARK-28395][SQL] Division operator support integral division
## What changes were proposed in this pull request?
PostgreSQL, Teradata, SQL Server, DB2 and Presto perform integral division with the `/` operator.
In contrast, Oracle, Vertica, Hive, MySQL and MariaDB perform fractional division with the `/` operator.
This PR adds a flag (`spark.sql.function.preferIntegralDivision`, default `false`) that controls whether the `/` operator performs integral division when both operands are integral types.
Examples:
**PostgreSQL**:
```sql
postgres=# select substr(version(), 0, 16), cast(10 as int) / cast(3 as int), cast(10.1 as float8) / cast(3 as int), cast(10 as int) / cast(3.1 as float8), cast(10.1 as float8)/cast(3.1 as float8);
substr | ?column? | ?column? | ?column? | ?column?
-----------------+----------+------------------+-----------------+------------------
PostgreSQL 11.3 | 3 | 3.36666666666667 | 3.2258064516129 | 3.25806451612903
(1 row)
```
**SQL Server**:
```sql
1> select cast(10 as int) / cast(3 as int), cast(10.1 as float) / cast(3 as int), cast(10 as int) / cast(3.1 as float), cast(10.1 as float)/cast(3.1 as float);
2> go
----------- ------------------------ ------------------------ ------------------------
3 3.3666666666666667 3.225806451612903 3.258064516129032
(1 rows affected)
```
**DB2**:
```sql
[db2inst12f3c821d36b7 ~]$ db2 "select cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double) from table (sysproc.env_get_inst_info())"
1 2 3 4
----------- ------------------------ ------------------------ ------------------------
3 +3.36666666666667E+000 +3.22580645161290E+000 +3.25806451612903E+000
1 record(s) selected.
```
**Presto**:
```sql
presto> select cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double);
_col0 | _col1 | _col2 | _col3
-------+--------------------+-------------------+-------------------
3 | 3.3666666666666667 | 3.225806451612903 | 3.258064516129032
(1 row)
```
**Teradata**:
![image](https://user-images.githubusercontent.com/5399861/61200701-e97d5380-a714-11e9-9a1d-57fd99d38c8d.png)
**Oracle**:
```sql
SQL> select 10 / 3 from dual;
10/3
----------
3.33333333
```
**Vertica**
```sql
dbadmin=> select version(), cast(10 as int) / cast(3 as int), cast(10.1 as float8) / cast(3 as int), cast(10 as int) / cast(3.1 as float8), cast(10.1 as float8)/cast(3.1 as float8);
version | ?column? | ?column? | ?column? | ?column?
------------------------------------+----------------------+------------------+-----------------+------------------
Vertica Analytic Database v9.1.1-0 | 3.333333333333333333 | 3.36666666666667 | 3.2258064516129 | 3.25806451612903
(1 row)
```
**Hive**:
```sql
hive> select cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double);
OK
3.3333333333333335 3.3666666666666667 3.225806451612903 3.258064516129032
Time taken: 0.143 seconds, Fetched: 1 row(s)
```
**MariaDB**:
```sql
MariaDB [(none)]> select version(), cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double);
+--------------------------------------+----------------------------------+---------------------------------------+---------------------------------------+------------------------------------------+
| version() | cast(10 as int) / cast(3 as int) | cast(10.1 as double) / cast(3 as int) | cast(10 as int) / cast(3.1 as double) | cast(10.1 as double)/cast(3.1 as double) |
+--------------------------------------+----------------------------------+---------------------------------------+---------------------------------------+------------------------------------------+
| 10.4.6-MariaDB-1:10.4.6+maria~bionic | 3.3333 | 3.3666666666666667 | 3.225806451612903 | 3.258064516129032 |
+--------------------------------------+----------------------------------+---------------------------------------+---------------------------------------+------------------------------------------+
1 row in set (0.000 sec)
```
**MySQL**:
```sql
mysql> select version(), 10 / 3, 10 / 3.1, 10.1 / 3, 10.1 / 3.1;
+-----------+--------+----------+----------+------------+
| version() | 10 / 3 | 10 / 3.1 | 10.1 / 3 | 10.1 / 3.1 |
+-----------+--------+----------+----------+------------+
| 8.0.16 | 3.3333 | 3.2258 | 3.36667 | 3.25806 |
+-----------+--------+----------+----------+------------+
1 row in set (0.00 sec)
```
## How was this patch tested?
Unit tests.
Closes #25158 from wangyum/SPARK-28395.
Authored-by: Yuming Wang <yu...@ebay.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 11 ++++++---
.../org/apache/spark/sql/internal/SQLConf.scala | 8 +++++++
.../sql/catalyst/analysis/TypeCoercionSuite.scala | 27 ++++++++++++++++++++--
3 files changed, 41 insertions(+), 5 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 1fdec89..3125f8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -59,7 +59,7 @@ object TypeCoercion {
CaseWhenCoercion ::
IfCoercion ::
StackCoercion ::
- Division ::
+ Division(conf) ::
ImplicitTypeCasts ::
DateTimeOperations ::
WindowFrameCoercion ::
@@ -666,7 +666,7 @@ object TypeCoercion {
* Hive only performs integral division with the DIV operator. The arguments to / are always
* converted to fractional types.
*/
- object Division extends TypeCoercionRule {
+ case class Division(conf: SQLConf) extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who has not been resolved yet,
@@ -677,7 +677,12 @@ object TypeCoercion {
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) =>
- Divide(Cast(left, DoubleType), Cast(right, DoubleType))
+ (left.dataType, right.dataType) match {
+ case (_: IntegralType, _: IntegralType) if conf.preferIntegralDivision =>
+ IntegralDivide(left, right)
+ case _ =>
+ Divide(Cast(left, DoubleType), Cast(right, DoubleType))
+ }
}
private def isNumericOrNull(ex: Expression): Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f76103e..57f5128 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1524,6 +1524,12 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision")
+ .doc("When true, will perform integral division with the / operator " +
+ "if both sides are integral types.")
+ .booleanConf
+ .createWithDefault(false)
+
val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION =
buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation")
.internal()
@@ -2294,6 +2300,8 @@ class SQLConf extends Serializable with Logging {
def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)
+ def preferIntegralDivision: Boolean = getConf(PREFER_INTEGRAL_DIVISION)
+
def allowCreatingManagedTableUsingNonemptyLocation: Boolean =
getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index a725e4b..949bb30 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1456,7 +1456,7 @@ class TypeCoercionSuite extends AnalysisTest {
test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
"in aggregation function like sum") {
- val rules = Seq(FunctionArgumentConversion, Division)
+ val rules = Seq(FunctionArgumentConversion, Division(conf))
// Casts Integer to Double
ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
// Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
@@ -1475,12 +1475,35 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("SPARK-17117 null type coercion in divide") {
- val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts)
+ val rules = Seq(FunctionArgumentConversion, Division(conf), ImplicitTypeCasts)
val nullLit = Literal.create(null, NullType)
ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
}
+ test("SPARK-28395 Division operator support integral division") {
+ val rules = Seq(FunctionArgumentConversion, Division(conf))
+ Seq(true, false).foreach { preferIntegralDivision =>
+ withSQLConf(SQLConf.PREFER_INTEGRAL_DIVISION.key -> s"$preferIntegralDivision") {
+ val result1 = if (preferIntegralDivision) {
+ IntegralDivide(1L, 1L)
+ } else {
+ Divide(Cast(1L, DoubleType), Cast(1L, DoubleType))
+ }
+ ruleTest(rules, Divide(1L, 1L), result1)
+ val result2 = if (preferIntegralDivision) {
+ IntegralDivide(1, Cast(1, ShortType))
+ } else {
+ Divide(Cast(1, DoubleType), Cast(Cast(1, ShortType), DoubleType))
+ }
+ ruleTest(rules, Divide(1, Cast(1, ShortType)), result2)
+
+ ruleTest(rules, Divide(1L, 1D), Divide(Cast(1L, DoubleType), Cast(1D, DoubleType)))
+ ruleTest(rules, Divide(Decimal(1.1), 1L), Divide(Decimal(1.1), 1L))
+ }
+ }
+ }
+
test("binary comparison with string promotion") {
val rule = TypeCoercion.PromoteStrings(conf)
ruleTest(rule,
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org