Posted to commits@flink.apache.org by tw...@apache.org on 2021/11/09 08:03:03 UTC

[flink] 03/03: [FLINK-24691][table-planner] Fix decimal precision for SUM

This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git

commit bafb0b4c2377d6d502ed9dba8853631ebf16cfb7
Author: Marios Trivyzas <ma...@gmail.com>
AuthorDate: Thu Nov 4 13:03:26 2021 +0100

    [FLINK-24691][table-planner] Fix decimal precision for SUM
    
    Since SUM internally uses the `plus()` operator to implement the sum
    aggregation, the decimal return type calculated by `LogicalTypeMerging#findSumAggType()`
    gets overridden by the calculation for the `plus()` operator done by
    `LogicalTypeMerging#findAdditionDecimalType()`. To prevent this, add a special
    `aggDecimalPlus()` operator used exclusively by aggregate functions, so that
    their calculated precision is not overridden.
    
    This closes #17634.
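
    As an illustration (mirroring the ITCases added below): for a query such as
    `SELECT SUM(CAST(a AS DECIMAL(32, 8))) FROM T`, `findSumAggType()` keeps
    scale 8 for the result, whereas letting the generic `plus()` rules recompute
    the type would cap the intermediate precision at 38 and shrink the scale
    (see the `(39, 8) -> (38, 7)` example documented in
    `LogicalTypeMerging#adjustPrecisionScale()`), silently dropping a fractional
    digit from the aggregated sum.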
---
 .../functions/BuiltInFunctionDefinitions.java      | 19 +++++++
 .../strategies/AggDecimalPlusTypeStrategy.java     | 66 ++++++++++++++++++++++
 .../strategies/SpecificTypeStrategies.java         |  3 +
 .../types/logical/utils/LogicalTypeMerging.java    | 14 +++--
 .../logical/utils/LogicalTypeMergingTest.java      | 10 +++-
 .../planner/expressions/ExpressionBuilder.java     |  9 +++
 .../functions/aggfunctions/SumAggFunction.java     | 41 +++++++++-----
 .../table/planner/codegen/ExprCodeGenerator.scala  | 19 ++++---
 .../runtime/stream/sql/AggregateITCase.scala       | 28 +++++++++
 .../runtime/stream/table/AggregateITCase.scala     | 23 ++++++++
 10 files changed, 204 insertions(+), 28 deletions(-)

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
index 022b7b1..d1695f6 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
@@ -29,6 +29,7 @@ import org.apache.flink.table.types.inference.InputTypeStrategies;
 import org.apache.flink.table.types.inference.TypeStrategies;
 import org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies;
 import org.apache.flink.table.types.inference.strategies.SpecificTypeStrategies;
+import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.LogicalTypeFamily;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
 import org.apache.flink.table.types.logical.StructuredType.StructuredComparison;
@@ -743,6 +744,24 @@ public final class BuiltInFunctionDefinitions {
                                             explicit(DataTypes.STRING()))))
                     .build();
 
+    /**
+     * Special "+" operator used internally by {@code SumAggFunction} to implement SUM aggregation
+     * on a Decimal type. Uses {@link LogicalTypeMerging#findSumAggType(LogicalType)} to prevent
+     * the normal {@link #PLUS} from overriding the special calculation for precision and scale
+     * needed by SUM.
+     */
+    public static final BuiltInFunctionDefinition AGG_DECIMAL_PLUS =
+            BuiltInFunctionDefinition.newBuilder()
+                    .name("AGG_DECIMAL_PLUS")
+                    .kind(SCALAR)
+                    .inputTypeStrategy(
+                            sequence(
+                                    logical(LogicalTypeRoot.DECIMAL),
+                                    logical(LogicalTypeRoot.DECIMAL)))
+                    .outputTypeStrategy(SpecificTypeStrategies.AGG_DECIMAL_PLUS)
+                    .runtimeProvided()
+                    .build();
+
     /** Combines numeric subtraction and "datetime - interval" arithmetic. */
     public static final BuiltInFunctionDefinition MINUS =
             BuiltInFunctionDefinition.newBuilder()
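
A hedged, standalone sketch of what the new definition's output type strategy
computes for decimal inputs. The concrete precision below is an assumption drawn
from the Javadoc and the ITCases in this commit; only the preserved scale is
asserted by the tests:

    // SUM-specific result type: scale preserved, precision widened for overflow head-room
    LogicalType sumType = LogicalTypeMerging.findSumAggType(new DecimalType(32, 8));
    // expected to be along the lines of DECIMAL(38, 8); routing the same addition
    // through the generic PLUS would instead re-run the addition rules and shrink
    // the scale once the intermediate precision exceeds 38
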
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java
new file mode 100644
index 0000000..23be242
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.inference.strategies;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.TypeStrategy;
+import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
+import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
+import org.apache.flink.table.types.utils.TypeConversions;
+import org.apache.flink.util.Preconditions;
+
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Type strategy that returns the result type of a decimal addition, used internally by {@code
+ * SumAggFunction} to implement SUM aggregation on a Decimal type. Uses {@link
+ * LogicalTypeMerging#findSumAggType(LogicalType)} and prevents the {@link DecimalPlusTypeStrategy}
+ * from overriding the special calculation for precision and scale needed by SUM.
+ */
+@Internal
+class AggDecimalPlusTypeStrategy implements TypeStrategy {
+
+    private static final String ERROR_MSG =
+            "Both args of "
+                    + AggDecimalPlusTypeStrategy.class.getSimpleName()
+                    + " should be of type["
+                    + DecimalType.class.getSimpleName()
+                    + "]";
+
+    @Override
+    public Optional<DataType> inferType(CallContext callContext) {
+        final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes();
+        final LogicalType addend1 = argumentDataTypes.get(0).getLogicalType();
+        final LogicalType addend2 = argumentDataTypes.get(1).getLogicalType();
+
+        Preconditions.checkArgument(
+                LogicalTypeChecks.hasRoot(addend1, LogicalTypeRoot.DECIMAL), ERROR_MSG);
+        Preconditions.checkArgument(
+                LogicalTypeChecks.hasRoot(addend2, LogicalTypeRoot.DECIMAL), ERROR_MSG);
+
+        return Optional.of(
+                TypeConversions.fromLogicalToDataType(LogicalTypeMerging.findSumAggType(addend2)));
+    }
+}
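
A brief note on the strategy above: the inferred type is derived from the second
addend via LogicalTypeMerging#findSumAggType(LogicalType), the same method that
DecimalSumAggFunction#getResultType() uses (see below), so the accumulate and
merge expressions resolve to the SUM result type rather than to the generic
addition type.
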
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
index 69b1e5f..c0e4dee 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
@@ -61,6 +61,9 @@ public final class SpecificTypeStrategies {
     /** See {@link DecimalPlusTypeStrategy}. */
     public static final TypeStrategy DECIMAL_PLUS = new DecimalPlusTypeStrategy();
 
+    /** See {@link AggDecimalPlusTypeStrategy}. */
+    public static final TypeStrategy AGG_DECIMAL_PLUS = new AggDecimalPlusTypeStrategy();
+
     /** See {@link DecimalScale0TypeStrategy}. */
     public static final TypeStrategy DECIMAL_SCALE_0 = new DecimalScale0TypeStrategy();
 
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
index 8cf5d96..e7aab32 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
@@ -304,15 +304,19 @@ public final class LogicalTypeMerging {
      *
      * <p>https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
      *
-     * <p>The rules although inspired by SQL Server they are not followed 100%, instead the approach
-     * of Spark/Hive is followed for adjusting the precision.
+     * <p>The rules (although inspired by SQL Server) are not followed 100%, instead the approach of
+     * Spark/Hive is followed for adjusting the precision.
      *
      * <p>http://www.openkb.info/2021/05/understand-decimal-precision-and-scale.html
      *
-     * <p>For (38, 8) + (32, 8) -> (39, 8) (If precision is infinite) // integral part: 31
+     * <p>For example, (38, 8) + (32, 8) -> (39, 8): the rules for addition initially calculate a
+     * decimal type assuming its precision can be infinite, which here results in a decimal with
+     * an integral part of 31 digits.
      *
-     * <p>The rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead
-     * we follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31
+     * <p>This method is called subsequently to adjust the resulting decimal since the maximum
+     * allowed precision is 38 (so far a precision of 39 is calculated in the first step). So, the
+     * rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead we
+     * follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31
      */
     private static DecimalType adjustPrecisionScale(int precision, int scale) {
         if (precision <= DecimalType.MAX_PRECISION) {
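
The documented case above, written out as plain arithmetic (an illustration of
this single example, not the general rule implemented by the method):

    // intermediate type produced by the addition rules: (39, 8)
    int precision = 39, scale = 8;
    int integralDigits = precision - scale;                  // 31 digits before the point
    int adjustedPrecision = Math.min(precision, 38);         // capped at DECIMAL's maximum
    int adjustedScale = adjustedPrecision - integralDigits;  // 7
    // Hive/Spark style result: (38, 7), keeping all 31 integral digits;
    // SQL Server would keep the scale instead: (38, 8) with only 30 integral digits
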
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
index 22c4bf8..8217130 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.table.types.logical.utils;
 
 import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.LogicalType;
 
 import org.junit.Test;
 
@@ -27,7 +28,14 @@ import java.util.List;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 
-/** Tests for {@link LogicalTypeMerging#findCommonType(List)}. */
+/**
+ * Tests for {@link LogicalTypeMerging} for finding the result decimal type for the various
+ * operations, e.g.: {@link LogicalTypeMerging#findSumAggType(LogicalType)}, {@link
+ * LogicalTypeMerging#findAdditionDecimalType(int, int, int, int)}, etc.
+ *
+ * <p>For {@link LogicalTypeMerging#findCommonType(List)} tests please check {@link
+ * org.apache.flink.table.types.LogicalCommonTypeTest}
+ */
 public class LogicalTypeMergingTest {
 
     @Test
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
index 9effb05..1e58d93 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.table.planner.expressions;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.table.expressions.ApiExpressionUtils;
 import org.apache.flink.table.expressions.Expression;
 import org.apache.flink.table.expressions.TypeLiteralExpression;
@@ -28,6 +29,7 @@ import org.apache.flink.table.types.DataType;
 
 import java.util.List;
 
+import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CAST;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CONCAT;
@@ -100,6 +102,13 @@ public class ExpressionBuilder {
         return call(PLUS, input1, input2);
     }
 
+    // Used only for implementing SumAggFunction, to avoid overriding the decimal precision/scale
+    // calculation for SUM with the rules applied to the normal plus()
+    @Internal
+    public static UnresolvedCallExpression aggDecimalPlus(Expression input1, Expression input2) {
+        return call(AGG_DECIMAL_PLUS, input1, input2);
+    }
+
     public static UnresolvedCallExpression minus(Expression input1, Expression input2) {
         return call(MINUS, input1, input2);
     }
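
A minimal usage sketch of the new builder method (the "value" reference below is
illustrative; the real call site is DecimalSumAggFunction#doPlus(), changed in
the next file):

    // builds a call to BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS instead of PLUS,
    // so type inference keeps the SUM-specific decimal precision/scale
    UnresolvedReferenceExpression sum = unresolvedRef("sum");
    UnresolvedReferenceExpression value = unresolvedRef("value");
    UnresolvedCallExpression acc = aggDecimalPlus(sum, value);
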
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java
index 4e93800..ba3ebbb 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java
@@ -28,16 +28,16 @@ import org.apache.flink.table.types.logical.DecimalType;
 import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
 
 import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
-import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.aggDecimalPlus;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
-import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
 
 /** built-in sum aggregate function. */
 public abstract class SumAggFunction extends DeclarativeAggregateFunction {
-    private UnresolvedReferenceExpression sum = unresolvedRef("sum");
+
+    protected UnresolvedReferenceExpression sum = unresolvedRef("sum");
 
     @Override
     public int operandCount() {
@@ -62,11 +62,13 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
     @Override
     public Expression[] accumulateExpressions() {
         return new Expression[] {
-            /* sum = */ adjustSumType(
+            /* sum = */ ifThenElse(
+                    isNull(operand(0)),
+                    sum,
                     ifThenElse(
                             isNull(operand(0)),
                             sum,
-                            ifThenElse(isNull(sum), operand(0), plus(sum, operand(0)))))
+                            ifThenElse(isNull(sum), operand(0), doPlus(sum, operand(0)))))
         };
     }
 
@@ -79,17 +81,16 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
     @Override
     public Expression[] mergeExpressions() {
         return new Expression[] {
-            /* sum = */ adjustSumType(
-                    ifThenElse(
-                            isNull(mergeOperand(sum)),
-                            sum,
-                            ifThenElse(
-                                    isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum)))))
+            /* sum = */ ifThenElse(
+                    isNull(mergeOperand(sum)),
+                    sum,
+                    ifThenElse(isNull(sum), mergeOperand(sum), doPlus(sum, mergeOperand(sum))))
         };
     }
 
-    private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
-        return cast(sumExpr, typeLiteral(getResultType()));
+    protected UnresolvedCallExpression doPlus(
+            UnresolvedReferenceExpression arg1, UnresolvedReferenceExpression arg2) {
+        return plus(arg1, arg2);
     }
 
     @Override
@@ -149,6 +150,7 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
     /** Built-in Decimal Sum aggregate function. */
     public static class DecimalSumAggFunction extends SumAggFunction {
         private DecimalType decimalType;
+        private DataType returnType;
 
         public DecimalSumAggFunction(DecimalType decimalType) {
             this.decimalType = decimalType;
@@ -156,8 +158,17 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
 
         @Override
         public DataType getResultType() {
-            DecimalType sumType = (DecimalType) LogicalTypeMerging.findSumAggType(decimalType);
-            return DataTypes.DECIMAL(sumType.getPrecision(), sumType.getScale());
+            if (returnType == null) {
+                DecimalType sumType = (DecimalType) LogicalTypeMerging.findSumAggType(decimalType);
+                returnType = DataTypes.DECIMAL(sumType.getPrecision(), sumType.getScale());
+            }
+            return returnType;
+        }
+
+        @Override
+        protected UnresolvedCallExpression doPlus(
+                UnresolvedReferenceExpression arg1, UnresolvedReferenceExpression arg2) {
+            return aggDecimalPlus(arg1, arg2);
         }
     }
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
index 4cea7d2..32cfe0f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
@@ -806,21 +806,26 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
 
       case bsf: BridgingSqlFunction =>
         bsf.getDefinition match {
-          case functionDefinition : FunctionDefinition
-            if functionDefinition eq BuiltInFunctionDefinitions.CURRENT_WATERMARK =>
+          case BuiltInFunctionDefinitions.CURRENT_WATERMARK =>
             generateWatermark(ctx, contextTerm, resultType)
-          case functionDefinition : FunctionDefinition
-            if functionDefinition eq BuiltInFunctionDefinitions.GREATEST =>
+
+          case BuiltInFunctionDefinitions.GREATEST =>
             operands.foreach { operand =>
               requireComparable(operand)
             }
             generateGreatestLeast(resultType, operands)
-          case functionDefinition : FunctionDefinition
-            if functionDefinition eq BuiltInFunctionDefinitions.LEAST =>
+
+          case BuiltInFunctionDefinitions.LEAST =>
             operands.foreach { operand =>
               requireComparable(operand)
             }
-            generateGreatestLeast(resultType, operands, false)
+            generateGreatestLeast(resultType, operands, greatest = false)
+
+          case BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS =>
+            val left = operands.head
+            val right = operands(1)
+            generateBinaryArithmeticOperator(ctx, "+", resultType, left, right)
+
           case _ =>
             new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType)
         }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
index 8db0b97..64bec0b 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
@@ -397,6 +397,34 @@ class AggregateITCase(
   }
 
   @Test
+  def testPrecisionForSumAggregationOnDecimal(): Unit = {
+    var t = tEnv.sqlQuery(
+        "select cast(sum(1.03520274) as DECIMAL(32, 8)), " +
+        "cast(sum(12345.035202748654) AS DECIMAL(30, 20)), " +
+        "cast(sum(12.345678901234567) AS DECIMAL(25, 22))")
+    var sink = new TestingRetractSink
+    t.toRetractStream[Row].addSink(sink).setParallelism(1)
+    env.execute()
+    var expected = List("1.03520274,12345.03520274865400000000,12.3456789012345670000000")
+    assertEquals(expected, sink.getRetractResults)
+
+    val data = new mutable.MutableList[(Double, Int)]
+    data .+= ((1.11111111, 1))
+    data .+= ((1.11111111, 2))
+    env.setParallelism(1)
+
+    t = failingDataSource(data).toTable(tEnv, 'a, 'b)
+    tEnv.registerTable("T", t)
+
+    t = tEnv.sqlQuery("select sum(cast(a as decimal(32, 8))) from T")
+    sink = new TestingRetractSink
+    t.toRetractStream[Row].addSink(sink)
+    env.execute()
+    expected = List("2.22222222")
+    assertEquals(expected, sink.getRetractResults)
+  }
+
+  @Test
   def testGroupByAgg(): Unit = {
     val data = new mutable.MutableList[(Int, Long, String)]
     data.+=((1, 1L, "A"))
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
index 89454c4..0c3a99a 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.scala._
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.bridge.scala._
 import org.apache.flink.table.api.internal.TableEnvironmentInternal
+import org.apache.flink.table.api.DataTypes.DECIMAL
 import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg, WeightedAvg}
 import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
 import org.apache.flink.table.planner.runtime.utils.TestData._
@@ -395,4 +396,26 @@ class AggregateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase
     val expected = mutable.MutableList("1,1", "2,3", "3,6", "4,10", "5,15", "6,21")
     assertEquals(expected.sorted, sink.getRetractResults.sorted)
   }
+
+  @Test
+  def testPrecisionForSumAggregationOnDecimal(): Unit = {
+    val data = new mutable.MutableList[(Double, Double, Double, Double)]
+    data.+=((1.03520274, 12345.035202748654, 12.345678901234567, 1.11111111))
+    data.+=((0, 0, 0, 1.11111111))
+    val t = failingDataSource(data).toTable(tEnv, 'a, 'b, 'c, 'd)
+
+    val results = t
+      .select('a.cast(DECIMAL(32, 8)).sum as 'a,
+        'b.cast(DECIMAL(30, 20)).sum as 'b,
+        'c.cast(DECIMAL(25, 20)).sum as 'c,
+        'd.cast(DECIMAL(32, 8)).sum as 'd)
+      .toRetractStream[Row]
+
+    val sink = new TestingRetractSink
+    results.addSink(sink).setParallelism(1)
+    env.execute()
+
+    val expected = List("1.03520274,12345.03520274865300000000,12.34567890123456700000,2.22222222")
+    assertEquals(expected, sink.getRetractResults)
+  }
 }