Posted to commits@flink.apache.org by tw...@apache.org on 2021/11/09 08:03:00 UTC

[flink] branch release-1.14 updated (aa6ef87 -> bafb0b4)

This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a change to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git.


    from aa6ef87  [FLINK-24761][table] Fix PartitionPruner code gen compile fail
     new 3610151  [hotfix][table-common][tests] Add tests for LogicalTypeMerging decimal rules
     new 0c47ed7  [hotfix][table-common] Add java docs to LogicalTypeMerging
     new bafb0b4  [FLINK-24691][table-planner] Fix decimal precision for SUM

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../functions/BuiltInFunctionDefinitions.java      |  19 ++++
 ...rategy.java => AggDecimalPlusTypeStrategy.java} |  50 +++++-----
 .../strategies/SpecificTypeStrategies.java         |   3 +
 .../types/logical/utils/LogicalTypeMerging.java    |  14 +++
 .../logical/utils/LogicalTypeMergingTest.java      | 109 +++++++++++++++++++++
 .../planner/expressions/ExpressionBuilder.java     |   9 ++
 .../functions/aggfunctions/SumAggFunction.java     |  41 +++++---
 .../table/planner/codegen/ExprCodeGenerator.scala  |  19 ++--
 .../runtime/stream/sql/AggregateITCase.scala       |  28 ++++++
 .../runtime/stream/table/AggregateITCase.scala     |  23 +++++
 10 files changed, 265 insertions(+), 50 deletions(-)
 copy flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/{DecimalPlusTypeStrategy.java => AggDecimalPlusTypeStrategy.java} (54%)
 create mode 100644 flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java

[flink] 01/03: [hotfix][table-common][tests] Add tests for LogicalTypeMerging decimal rules

Posted by tw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 36101511a96944ac3c05b1a265f5d726441f477a
Author: Marios Trivyzas <ma...@gmail.com>
AuthorDate: Thu Nov 4 09:54:07 2021 +0100

    [hotfix][table-common][tests] Add tests for LogicalTypeMerging decimal rules
    
    Add tests for all methods of `LogicalTypeMerging` which calculate the precision
    and scale of the resulting decimal for arithmetic operations like `+ - * / % round`
    as well as for `avg` and `sum` aggregate functions.
---
 .../logical/utils/LogicalTypeMergingTest.java      | 101 +++++++++++++++++++++
 1 file changed, 101 insertions(+)

diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
new file mode 100644
index 0000000..22c4bf8
--- /dev/null
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.logical.utils;
+
+import org.apache.flink.table.types.logical.DecimalType;
+
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+/** Tests for {@link LogicalTypeMerging#findCommonType(List)}. */
+public class LogicalTypeMergingTest {
+
+    @Test
+    public void testFindDivisionDecimalType() {
+        assertThat(
+                LogicalTypeMerging.findDivisionDecimalType(32, 8, 38, 8), equalTo(decimal(38, 6)));
+        assertThat(
+                LogicalTypeMerging.findDivisionDecimalType(30, 20, 30, 20),
+                equalTo(decimal(38, 8)));
+    }
+
+    @Test
+    public void testFindMultiplicationDecimalType() {
+        assertThat(
+                LogicalTypeMerging.findMultiplicationDecimalType(30, 10, 30, 10),
+                equalTo(decimal(38, 6)));
+        assertThat(
+                LogicalTypeMerging.findMultiplicationDecimalType(30, 20, 30, 20),
+                equalTo(decimal(38, 17)));
+        assertThat(
+                LogicalTypeMerging.findMultiplicationDecimalType(38, 2, 38, 3),
+                equalTo(decimal(38, 5)));
+    }
+
+    @Test
+    public void testFindModuloDecimalType() {
+        assertThat(
+                LogicalTypeMerging.findModuloDecimalType(30, 10, 30, 10), equalTo(decimal(30, 10)));
+        assertThat(
+                LogicalTypeMerging.findModuloDecimalType(30, 20, 25, 20), equalTo(decimal(25, 20)));
+        assertThat(
+                LogicalTypeMerging.findModuloDecimalType(10, 10, 10, 10), equalTo(decimal(10, 10)));
+    }
+
+    @Test
+    public void testFindAdditionDecimalType() {
+        assertThat(
+                LogicalTypeMerging.findAdditionDecimalType(38, 8, 32, 8), equalTo(decimal(38, 7)));
+        assertThat(
+                LogicalTypeMerging.findAdditionDecimalType(32, 8, 38, 8), equalTo(decimal(38, 7)));
+        assertThat(
+                LogicalTypeMerging.findAdditionDecimalType(30, 20, 28, 20),
+                equalTo(decimal(31, 20)));
+        assertThat(
+                LogicalTypeMerging.findAdditionDecimalType(10, 10, 10, 10),
+                equalTo(decimal(11, 10)));
+        assertThat(
+                LogicalTypeMerging.findAdditionDecimalType(38, 5, 38, 4), equalTo(decimal(38, 5)));
+    }
+
+    @Test
+    public void testFindRoundingDecimalType() {
+        assertThat(LogicalTypeMerging.findRoundDecimalType(32, 8, 5), equalTo(decimal(30, 5)));
+        assertThat(LogicalTypeMerging.findRoundDecimalType(32, 8, 10), equalTo(decimal(32, 8)));
+        assertThat(LogicalTypeMerging.findRoundDecimalType(30, 20, 18), equalTo(decimal(29, 18)));
+        assertThat(LogicalTypeMerging.findRoundDecimalType(10, 10, 2), equalTo(decimal(3, 2)));
+    }
+
+    @Test
+    public void testFindAvgAggType() {
+        assertThat(LogicalTypeMerging.findAvgAggType(decimal(38, 20)), equalTo(decimal(38, 20)));
+        assertThat(LogicalTypeMerging.findAvgAggType(decimal(38, 2)), equalTo(decimal(38, 6)));
+        assertThat(LogicalTypeMerging.findAvgAggType(decimal(38, 8)), equalTo(decimal(38, 8)));
+        assertThat(LogicalTypeMerging.findAvgAggType(decimal(30, 20)), equalTo(decimal(38, 20)));
+        assertThat(LogicalTypeMerging.findAvgAggType(decimal(10, 10)), equalTo(decimal(38, 10)));
+    }
+
+    private static final DecimalType decimal(int precision, int scale) {
+        return new DecimalType(false, precision, scale);
+    }
+}
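
For reference, the precision/scale values asserted above follow a SQL-style derivation with a
Hive/Spark-style cap at precision 38 (DecimalType.MAX_PRECISION). The sketch below only
illustrates the addition rule and the cap; the names `additionType` and `adjust` are invented for
the example and approximate the behaviour implied by the expected values, not the exact
`LogicalTypeMerging` implementation:

    import org.apache.flink.table.types.logical.DecimalType;

    public final class DecimalRuleSketch {

        // Addition: precision = max(p1 - s1, p2 - s2) + max(s1, s2) + 1, scale = max(s1, s2),
        // then capped at DecimalType.MAX_PRECISION (38).
        static DecimalType additionType(int p1, int s1, int p2, int s2) {
            int scale = Math.max(s1, s2);
            int precision = Math.max(p1 - s1, p2 - s2) + scale + 1;
            return adjust(precision, scale);
        }

        // Hive/Spark-style adjustment: prefer keeping the integral digits and shrink the
        // scale instead, but never below min(scale, 6).
        static DecimalType adjust(int precision, int scale) {
            if (precision <= DecimalType.MAX_PRECISION) {
                return new DecimalType(false, precision, scale);
            }
            int integralDigits = precision - scale;
            int minScale = Math.min(scale, 6);
            int adjustedScale = Math.max(DecimalType.MAX_PRECISION - integralDigits, minScale);
            return new DecimalType(false, DecimalType.MAX_PRECISION, adjustedScale);
        }

        public static void main(String[] args) {
            // Matches the expectations above: (38, 8) + (32, 8) -> DECIMAL(38, 7)
            System.out.println(additionType(38, 8, 32, 8).asSummaryString());
            // (30, 20) + (28, 20) -> DECIMAL(31, 20), no capping needed
            System.out.println(additionType(30, 20, 28, 20).asSummaryString());
        }
    }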

[flink] 03/03: [FLINK-24691][table-planner] Fix decimal precision for SUM

Posted by tw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git

commit bafb0b4c2377d6d502ed9dba8853631ebf16cfb7
Author: Marios Trivyzas <ma...@gmail.com>
AuthorDate: Thu Nov 4 13:03:26 2021 +0100

    [FLINK-24691][table-planner] Fix decimal precision for SUM
    
    Since SUM internally uses the `plus()` operator to implement the sum
    aggregation, the decimal return type calculated by `LogicalTypeMerging#findSumAggType()`
    gets overridden by the calculation for the `plus()` operator done by
    `LogicalTypeMerging#findAdditionDecimalType()`. To prevent this, add a special
    `aggDecimalPlus()` operator that is used exclusively by aggregate functions, so that
    their calculated precision is not overridden.
    
    This closes #17634.
---
 .../functions/BuiltInFunctionDefinitions.java      | 19 +++++++
 .../strategies/AggDecimalPlusTypeStrategy.java     | 66 ++++++++++++++++++++++
 .../strategies/SpecificTypeStrategies.java         |  3 +
 .../types/logical/utils/LogicalTypeMerging.java    | 14 +++--
 .../logical/utils/LogicalTypeMergingTest.java      | 10 +++-
 .../planner/expressions/ExpressionBuilder.java     |  9 +++
 .../functions/aggfunctions/SumAggFunction.java     | 41 +++++++++-----
 .../table/planner/codegen/ExprCodeGenerator.scala  | 19 ++++---
 .../runtime/stream/sql/AggregateITCase.scala       | 28 +++++++++
 .../runtime/stream/table/AggregateITCase.scala     | 23 ++++++++
 10 files changed, 204 insertions(+), 28 deletions(-)
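
To illustrate the clash described in the commit message above: the result/accumulator type for SUM
comes from `LogicalTypeMerging#findSumAggType()`, while routing the accumulation through the
generic plus re-derives the type with `LogicalTypeMerging#findAdditionDecimalType()`, which (per
the tests added in 01/03) trims the scale for high-precision inputs. A minimal sketch of the two
derivations; the class name and `main` scaffold are invented for illustration, and only the scale
behaviour is asserted by the tests in this thread:

    import org.apache.flink.table.types.logical.DecimalType;
    import org.apache.flink.table.types.logical.LogicalType;
    import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;

    public final class SumPrecisionSketch {

        public static void main(String[] args) {
            // Input column type, as in the ITCases of this patch: DECIMAL(32, 8).
            DecimalType input = new DecimalType(false, 32, 8);

            // Type picked for the SUM result/accumulator; it keeps the scale of 8
            // (the ITCases expect 8 decimal digits in the summed result).
            LogicalType sumType = LogicalTypeMerging.findSumAggType(input);

            // Type the generic "+" rule derives once the accumulator has reached
            // precision 38: DECIMAL(38, 7) per the 01/03 tests, i.e. one digit of
            // scale is lost compared to the SUM type.
            LogicalType plusType = LogicalTypeMerging.findAdditionDecimalType(38, 8, 32, 8);

            System.out.println(sumType.asSummaryString());
            System.out.println(plusType.asSummaryString());
        }
    }

The new AGG_DECIMAL_PLUS function avoids this by typing the accumulating plus with
findSumAggType, so accumulate and merge keep the precision and scale chosen for SUM.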

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
index 022b7b1..d1695f6 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
@@ -29,6 +29,7 @@ import org.apache.flink.table.types.inference.InputTypeStrategies;
 import org.apache.flink.table.types.inference.TypeStrategies;
 import org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies;
 import org.apache.flink.table.types.inference.strategies.SpecificTypeStrategies;
+import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.LogicalTypeFamily;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
 import org.apache.flink.table.types.logical.StructuredType.StructuredComparison;
@@ -743,6 +744,24 @@ public final class BuiltInFunctionDefinitions {
                                             explicit(DataTypes.STRING()))))
                     .build();
 
+    /**
+     * Special "+" operator used internally by {@code SumAggFunction} to implement SUM aggregation
+     * on a Decimal type. Uses {@link LogicalTypeMerging#findSumAggType(LogicalType)} to prevent
+     * the normal {@link #PLUS} from overriding the special calculation for precision and scale
+     * needed by SUM.
+     */
+    public static final BuiltInFunctionDefinition AGG_DECIMAL_PLUS =
+            BuiltInFunctionDefinition.newBuilder()
+                    .name("AGG_DECIMAL_PLUS")
+                    .kind(SCALAR)
+                    .inputTypeStrategy(
+                            sequence(
+                                    logical(LogicalTypeRoot.DECIMAL),
+                                    logical(LogicalTypeRoot.DECIMAL)))
+                    .outputTypeStrategy(SpecificTypeStrategies.AGG_DECIMAL_PLUS)
+                    .runtimeProvided()
+                    .build();
+
     /** Combines numeric subtraction and "datetime - interval" arithmetic. */
     public static final BuiltInFunctionDefinition MINUS =
             BuiltInFunctionDefinition.newBuilder()
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java
new file mode 100644
index 0000000..23be242
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.inference.strategies;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.TypeStrategy;
+import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
+import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
+import org.apache.flink.table.types.utils.TypeConversions;
+import org.apache.flink.util.Preconditions;
+
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Type strategy that returns the result type of a decimal addition, used internally by {@code
+ * SumAggFunction} to implement SUM aggregation on a Decimal type. Uses {@link
+ * LogicalTypeMerging#findSumAggType(LogicalType)} and prevents the {@link DecimalPlusTypeStrategy}
+ * from overriding the special calculation for precision and scale needed by SUM.
+ */
+@Internal
+class AggDecimalPlusTypeStrategy implements TypeStrategy {
+
+    private static final String ERROR_MSG =
+            "Both args of "
+                    + AggDecimalPlusTypeStrategy.class.getSimpleName()
+                    + " should be of type["
+                    + DecimalType.class.getSimpleName()
+                    + "]";
+
+    @Override
+    public Optional<DataType> inferType(CallContext callContext) {
+        final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes();
+        final LogicalType addend1 = argumentDataTypes.get(0).getLogicalType();
+        final LogicalType addend2 = argumentDataTypes.get(1).getLogicalType();
+
+        Preconditions.checkArgument(
+                LogicalTypeChecks.hasRoot(addend1, LogicalTypeRoot.DECIMAL), ERROR_MSG);
+        Preconditions.checkArgument(
+                LogicalTypeChecks.hasRoot(addend2, LogicalTypeRoot.DECIMAL), ERROR_MSG);
+
+        return Optional.of(
+                TypeConversions.fromLogicalToDataType(LogicalTypeMerging.findSumAggType(addend2)));
+    }
+}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
index 69b1e5f..c0e4dee 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
@@ -61,6 +61,9 @@ public final class SpecificTypeStrategies {
     /** See {@link DecimalPlusTypeStrategy}. */
     public static final TypeStrategy DECIMAL_PLUS = new DecimalPlusTypeStrategy();
 
+    /** See {@link AggDecimalPlusTypeStrategy}. */
+    public static final TypeStrategy AGG_DECIMAL_PLUS = new AggDecimalPlusTypeStrategy();
+
     /** See {@link DecimalScale0TypeStrategy}. */
     public static final TypeStrategy DECIMAL_SCALE_0 = new DecimalScale0TypeStrategy();
 
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
index 8cf5d96..e7aab32 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
@@ -304,15 +304,19 @@ public final class LogicalTypeMerging {
      *
      * <p>https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
      *
-     * <p>The rules although inspired by SQL Server they are not followed 100%, instead the approach
-     * of Spark/Hive is followed for adjusting the precision.
+     * <p>The rules (although inspired by SQL Server) are not followed 100%; instead, the approach
+     * of Spark/Hive is followed for adjusting the precision.
      *
      * <p>http://www.openkb.info/2021/05/understand-decimal-precision-and-scale.html
      *
-     * <p>For (38, 8) + (32, 8) -> (39, 8) (If precision is infinite) // integral part: 31
+     * <p>For (38, 8) + (32, 8) -> (39, 8): the rules for addition initially calculate a decimal
+     * type assuming an infinite precision, which results in a decimal with an integral part of 31
+     * digits.
      *
-     * <p>The rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead
-     * we follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31
+     * <p>This method is called subsequently to adjust the resulting decimal since the maximum
+     * allowed precision is 38 (so far a precision of 39 is calculated in the first step). So, the
+     * rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead we
+     * follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31
      */
     private static DecimalType adjustPrecisionScale(int precision, int scale) {
         if (precision <= DecimalType.MAX_PRECISION) {
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
index 22c4bf8..8217130 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.table.types.logical.utils;
 
 import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.LogicalType;
 
 import org.junit.Test;
 
@@ -27,7 +28,14 @@ import java.util.List;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 
-/** Tests for {@link LogicalTypeMerging#findCommonType(List)}. */
+/**
+ * Tests for {@link LogicalTypeMerging} for finding the result decimal type for the various
+ * operations, e.g.: {@link LogicalTypeMerging#findSumAggType(LogicalType)}, {@link
+ * LogicalTypeMerging#findAdditionDecimalType(int, int, int, int)}, etc.
+ *
+ * <p>For {@link LogicalTypeMerging#findCommonType(List)} tests please check {@link
+ * org.apache.flink.table.types.LogicalCommonTypeTest}
+ */
 public class LogicalTypeMergingTest {
 
     @Test
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
index 9effb05..1e58d93 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.table.planner.expressions;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.table.expressions.ApiExpressionUtils;
 import org.apache.flink.table.expressions.Expression;
 import org.apache.flink.table.expressions.TypeLiteralExpression;
@@ -28,6 +29,7 @@ import org.apache.flink.table.types.DataType;
 
 import java.util.List;
 
+import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CAST;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CONCAT;
@@ -100,6 +102,13 @@ public class ExpressionBuilder {
         return call(PLUS, input1, input2);
     }
 
+    // Used only for implementing the SumAggFunction to avoid overriding decimal precision/scale
+    // calculation for sum with the rules applied for the normal plus
+    @Internal
+    public static UnresolvedCallExpression aggDecimalPlus(Expression input1, Expression input2) {
+        return call(AGG_DECIMAL_PLUS, input1, input2);
+    }
+
     public static UnresolvedCallExpression minus(Expression input1, Expression input2) {
         return call(MINUS, input1, input2);
     }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java
index 4e93800..ba3ebbb 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java
@@ -28,16 +28,16 @@ import org.apache.flink.table.types.logical.DecimalType;
 import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
 
 import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
-import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.aggDecimalPlus;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
-import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
 
 /** built-in sum aggregate function. */
 public abstract class SumAggFunction extends DeclarativeAggregateFunction {
-    private UnresolvedReferenceExpression sum = unresolvedRef("sum");
+
+    protected UnresolvedReferenceExpression sum = unresolvedRef("sum");
 
     @Override
     public int operandCount() {
@@ -62,11 +62,13 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
     @Override
     public Expression[] accumulateExpressions() {
         return new Expression[] {
-            /* sum = */ adjustSumType(
+            /* sum = */ ifThenElse(
+                    isNull(operand(0)),
+                    sum,
                     ifThenElse(
                             isNull(operand(0)),
                             sum,
-                            ifThenElse(isNull(sum), operand(0), plus(sum, operand(0)))))
+                            ifThenElse(isNull(sum), operand(0), doPlus(sum, operand(0)))))
         };
     }
 
@@ -79,17 +81,16 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
     @Override
     public Expression[] mergeExpressions() {
         return new Expression[] {
-            /* sum = */ adjustSumType(
-                    ifThenElse(
-                            isNull(mergeOperand(sum)),
-                            sum,
-                            ifThenElse(
-                                    isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum)))))
+            /* sum = */ ifThenElse(
+                    isNull(mergeOperand(sum)),
+                    sum,
+                    ifThenElse(isNull(sum), mergeOperand(sum), doPlus(sum, mergeOperand(sum))))
         };
     }
 
-    private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
-        return cast(sumExpr, typeLiteral(getResultType()));
+    protected UnresolvedCallExpression doPlus(
+            UnresolvedReferenceExpression arg1, UnresolvedReferenceExpression arg2) {
+        return plus(arg1, arg2);
     }
 
     @Override
@@ -149,6 +150,7 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
     /** Built-in Decimal Sum aggregate function. */
     public static class DecimalSumAggFunction extends SumAggFunction {
         private DecimalType decimalType;
+        private DataType returnType;
 
         public DecimalSumAggFunction(DecimalType decimalType) {
             this.decimalType = decimalType;
@@ -156,8 +158,17 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
 
         @Override
         public DataType getResultType() {
-            DecimalType sumType = (DecimalType) LogicalTypeMerging.findSumAggType(decimalType);
-            return DataTypes.DECIMAL(sumType.getPrecision(), sumType.getScale());
+            if (returnType == null) {
+                DecimalType sumType = (DecimalType) LogicalTypeMerging.findSumAggType(decimalType);
+                returnType = DataTypes.DECIMAL(sumType.getPrecision(), sumType.getScale());
+            }
+            return returnType;
+        }
+
+        @Override
+        protected UnresolvedCallExpression doPlus(
+                UnresolvedReferenceExpression arg1, UnresolvedReferenceExpression arg2) {
+            return aggDecimalPlus(arg1, arg2);
         }
     }
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
index 4cea7d2..32cfe0f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
@@ -806,21 +806,26 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
 
       case bsf: BridgingSqlFunction =>
         bsf.getDefinition match {
-          case functionDefinition : FunctionDefinition
-            if functionDefinition eq BuiltInFunctionDefinitions.CURRENT_WATERMARK =>
+          case BuiltInFunctionDefinitions.CURRENT_WATERMARK =>
             generateWatermark(ctx, contextTerm, resultType)
-          case functionDefinition : FunctionDefinition
-            if functionDefinition eq BuiltInFunctionDefinitions.GREATEST =>
+
+          case BuiltInFunctionDefinitions.GREATEST =>
             operands.foreach { operand =>
               requireComparable(operand)
             }
             generateGreatestLeast(resultType, operands)
-          case functionDefinition : FunctionDefinition
-            if functionDefinition eq BuiltInFunctionDefinitions.LEAST =>
+
+          case BuiltInFunctionDefinitions.LEAST =>
             operands.foreach { operand =>
               requireComparable(operand)
             }
-            generateGreatestLeast(resultType, operands, false)
+            generateGreatestLeast(resultType, operands, greatest = false)
+
+          case BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS =>
+            val left = operands.head
+            val right = operands(1)
+            generateBinaryArithmeticOperator(ctx, "+", resultType, left, right)
+
           case _ =>
             new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType)
         }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
index 8db0b97..64bec0b 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
@@ -397,6 +397,34 @@ class AggregateITCase(
   }
 
   @Test
+  def testPrecisionForSumAggregationOnDecimal(): Unit = {
+    var t = tEnv.sqlQuery(
+        "select cast(sum(1.03520274) as DECIMAL(32, 8)), " +
+        "cast(sum(12345.035202748654) AS DECIMAL(30, 20)), " +
+        "cast(sum(12.345678901234567) AS DECIMAL(25, 22))")
+    var sink = new TestingRetractSink
+    t.toRetractStream[Row].addSink(sink).setParallelism(1)
+    env.execute()
+    var expected = List("1.03520274,12345.03520274865400000000,12.3456789012345670000000")
+    assertEquals(expected, sink.getRetractResults)
+
+    val data = new mutable.MutableList[(Double, Int)]
+    data .+= ((1.11111111, 1))
+    data .+= ((1.11111111, 2))
+    env.setParallelism(1)
+
+    t = failingDataSource(data).toTable(tEnv, 'a, 'b)
+    tEnv.registerTable("T", t)
+
+    t = tEnv.sqlQuery("select sum(cast(a as decimal(32, 8))) from T")
+    sink = new TestingRetractSink
+    t.toRetractStream[Row].addSink(sink)
+    env.execute()
+    expected = List("2.22222222")
+    assertEquals(expected, sink.getRetractResults)
+  }
+
+  @Test
   def testGroupByAgg(): Unit = {
     val data = new mutable.MutableList[(Int, Long, String)]
     data.+=((1, 1L, "A"))
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
index 89454c4..0c3a99a 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.scala._
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.bridge.scala._
 import org.apache.flink.table.api.internal.TableEnvironmentInternal
+import org.apache.flink.table.api.DataTypes.DECIMAL
 import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg, WeightedAvg}
 import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
 import org.apache.flink.table.planner.runtime.utils.TestData._
@@ -395,4 +396,26 @@ class AggregateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase
     val expected = mutable.MutableList("1,1", "2,3", "3,6", "4,10", "5,15", "6,21")
     assertEquals(expected.sorted, sink.getRetractResults.sorted)
   }
+
+  @Test
+  def testPrecisionForSumAggregationOnDecimal(): Unit = {
+    val data = new mutable.MutableList[(Double, Double, Double, Double)]
+    data.+=((1.03520274, 12345.035202748654, 12.345678901234567, 1.11111111))
+    data.+=((0, 0, 0, 1.11111111))
+    val t = failingDataSource(data).toTable(tEnv, 'a, 'b, 'c, 'd)
+
+    val results = t
+      .select('a.cast(DECIMAL(32, 8)).sum as 'a,
+        'b.cast(DECIMAL(30, 20)).sum as 'b,
+        'c.cast(DECIMAL(25, 20)).sum as 'c,
+        'd.cast(DECIMAL(32, 8)).sum as 'd)
+      .toRetractStream[Row]
+
+    val sink = new TestingRetractSink
+    results.addSink(sink).setParallelism(1)
+    env.execute()
+
+    val expected = List("1.03520274,12345.03520274865300000000,12.34567890123456700000,2.22222222")
+    assertEquals(expected, sink.getRetractResults)
+  }
 }

[flink] 02/03: [hotfix][table-common] Add java docs to LogicalTypeMerging

Posted by tw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 0c47ed79b4ea7e3804309426c61f3ecdc3b96273
Author: Marios Trivyzas <ma...@gmail.com>
AuthorDate: Thu Nov 4 10:06:25 2021 +0100

    [hotfix][table-common] Add java docs to LogicalTypeMerging
    
    Add more explanation to the `LogicalTypeMerging#adjustPrecisionScale()` method
    regarding the decision not to follow Microsoft SQL Server's rules 100% but to
    instead follow the Hive/Spark behaviour when calculating the resulting precision
    of a decimal operation.
---
 .../flink/table/types/logical/utils/LogicalTypeMerging.java    | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
index dfb1ec1..8cf5d96 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
@@ -303,6 +303,16 @@ public final class LogicalTypeMerging {
      * integral part of a result from being truncated.
      *
      * <p>https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
+     *
+     * <p>The rules although inspired by SQL Server they are not followed 100%, instead the approach
+     * of Spark/Hive is followed for adjusting the precision.
+     *
+     * <p>http://www.openkb.info/2021/05/understand-decimal-precision-and-scale.html
+     *
+     * <p>For (38, 8) + (32, 8) -> (39, 8) (If precision is infinite) // integral part: 31
+     *
+     * <p>The rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead
+     * we follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31
      */
     private static DecimalType adjustPrecisionScale(int precision, int scale) {
         if (precision <= DecimalType.MAX_PRECISION) {