You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2021/11/09 08:03:00 UTC
[flink] branch release-1.14 updated (aa6ef87 -> bafb0b4)
This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a change to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git.
from aa6ef87 [FLINK-24761][table] Fix PartitionPruner code gen compile fail
new 3610151 [hotfix][table-common][tests] Add tests for LogicalTypeMerging decimal rules
new 0c47ed7 [hotfix][table-common] Add java docs to LogicalTypeMerging
new bafb0b4 [FLINK-24691][table-planner] Fix decimal precision for SUM
The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails. The revisions
listed as "add" were already present in the repository and have only
been added to this reference.
Summary of changes:
.../functions/BuiltInFunctionDefinitions.java | 19 ++++
...rategy.java => AggDecimalPlusTypeStrategy.java} | 50 +++++-----
.../strategies/SpecificTypeStrategies.java | 3 +
.../types/logical/utils/LogicalTypeMerging.java | 14 +++
.../logical/utils/LogicalTypeMergingTest.java | 109 +++++++++++++++++++++
.../planner/expressions/ExpressionBuilder.java | 9 ++
.../functions/aggfunctions/SumAggFunction.java | 41 +++++---
.../table/planner/codegen/ExprCodeGenerator.scala | 19 ++--
.../runtime/stream/sql/AggregateITCase.scala | 28 ++++++
.../runtime/stream/table/AggregateITCase.scala | 23 +++++
10 files changed, 265 insertions(+), 50 deletions(-)
copy flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/{DecimalPlusTypeStrategy.java => AggDecimalPlusTypeStrategy.java} (54%)
create mode 100644 flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
[flink] 01/03: [hotfix][table-common][tests] Add tests for
LogicalTypeMerging decimal rules
Posted by tw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 36101511a96944ac3c05b1a265f5d726441f477a
Author: Marios Trivyzas <ma...@gmail.com>
AuthorDate: Thu Nov 4 09:54:07 2021 +0100
[hotfix][table-common][tests] Add tests for LogicalTypeMerging decimal rules
Add tests for all methods of `LogicalTypeMerging` which calculate the precision
and scale of the resulting decimal for arithmetic operations like `+ - * / % round`
as well as for `avg` and `sum` aggregate functions.
---
.../logical/utils/LogicalTypeMergingTest.java | 101 +++++++++++++++++++++
1 file changed, 101 insertions(+)
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
new file mode 100644
index 0000000..22c4bf8
--- /dev/null
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.logical.utils;
+
+import org.apache.flink.table.types.logical.DecimalType;
+
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+/** Tests for {@link LogicalTypeMerging#findCommonType(List)}. */
+public class LogicalTypeMergingTest {
+
+ @Test
+ public void testFindDivisionDecimalType() {
+ assertThat(
+ LogicalTypeMerging.findDivisionDecimalType(32, 8, 38, 8), equalTo(decimal(38, 6)));
+ assertThat(
+ LogicalTypeMerging.findDivisionDecimalType(30, 20, 30, 20),
+ equalTo(decimal(38, 8)));
+ }
+
+ @Test
+ public void testFindMultiplicationDecimalType() {
+ assertThat(
+ LogicalTypeMerging.findMultiplicationDecimalType(30, 10, 30, 10),
+ equalTo(decimal(38, 6)));
+ assertThat(
+ LogicalTypeMerging.findMultiplicationDecimalType(30, 20, 30, 20),
+ equalTo(decimal(38, 17)));
+ assertThat(
+ LogicalTypeMerging.findMultiplicationDecimalType(38, 2, 38, 3),
+ equalTo(decimal(38, 5)));
+ }
+
+ @Test
+ public void testFindModuloDecimalType() {
+ assertThat(
+ LogicalTypeMerging.findModuloDecimalType(30, 10, 30, 10), equalTo(decimal(30, 10)));
+ assertThat(
+ LogicalTypeMerging.findModuloDecimalType(30, 20, 25, 20), equalTo(decimal(25, 20)));
+ assertThat(
+ LogicalTypeMerging.findModuloDecimalType(10, 10, 10, 10), equalTo(decimal(10, 10)));
+ }
+
+ @Test
+ public void testFindAdditionDecimalType() {
+ assertThat(
+ LogicalTypeMerging.findAdditionDecimalType(38, 8, 32, 8), equalTo(decimal(38, 7)));
+ assertThat(
+ LogicalTypeMerging.findAdditionDecimalType(32, 8, 38, 8), equalTo(decimal(38, 7)));
+ assertThat(
+ LogicalTypeMerging.findAdditionDecimalType(30, 20, 28, 20),
+ equalTo(decimal(31, 20)));
+ assertThat(
+ LogicalTypeMerging.findAdditionDecimalType(10, 10, 10, 10),
+ equalTo(decimal(11, 10)));
+ assertThat(
+ LogicalTypeMerging.findAdditionDecimalType(38, 5, 38, 4), equalTo(decimal(38, 5)));
+ }
+
+ @Test
+ public void testFindRoundingDecimalType() {
+ assertThat(LogicalTypeMerging.findRoundDecimalType(32, 8, 5), equalTo(decimal(30, 5)));
+ assertThat(LogicalTypeMerging.findRoundDecimalType(32, 8, 10), equalTo(decimal(32, 8)));
+ assertThat(LogicalTypeMerging.findRoundDecimalType(30, 20, 18), equalTo(decimal(29, 18)));
+ assertThat(LogicalTypeMerging.findRoundDecimalType(10, 10, 2), equalTo(decimal(3, 2)));
+ }
+
+ @Test
+ public void testFindAvgAggType() {
+ assertThat(LogicalTypeMerging.findAvgAggType(decimal(38, 20)), equalTo(decimal(38, 20)));
+ assertThat(LogicalTypeMerging.findAvgAggType(decimal(38, 2)), equalTo(decimal(38, 6)));
+ assertThat(LogicalTypeMerging.findAvgAggType(decimal(38, 8)), equalTo(decimal(38, 8)));
+ assertThat(LogicalTypeMerging.findAvgAggType(decimal(30, 20)), equalTo(decimal(38, 20)));
+ assertThat(LogicalTypeMerging.findAvgAggType(decimal(10, 10)), equalTo(decimal(38, 10)));
+ }
+
+ private static final DecimalType decimal(int precision, int scale) {
+ return new DecimalType(false, precision, scale);
+ }
+}
[flink] 03/03: [FLINK-24691][table-planner] Fix decimal precision
for SUM
Posted by tw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git
commit bafb0b4c2377d6d502ed9dba8853631ebf16cfb7
Author: Marios Trivyzas <ma...@gmail.com>
AuthorDate: Thu Nov 4 13:03:26 2021 +0100
[FLINK-24691][table-planner] Fix decimal precision for SUM
Since SUM internally uses the `plus()` operator to implement the sum
aggregation, the decimal return type calculated by `LogicalTypeMerging#findSumAggType()`
gets overridden by the calculation for the `plus()` operator done by
`LogicalTypeMerging#findAdditionDecimalType()`. To prevent this, add a special
`aggDecimalPlus()` operator to be used exclusively by aggregate functions, to avoid
overriding their calculated precision.
This closes #17634.
---
.../functions/BuiltInFunctionDefinitions.java | 19 +++++++
.../strategies/AggDecimalPlusTypeStrategy.java | 66 ++++++++++++++++++++++
.../strategies/SpecificTypeStrategies.java | 3 +
.../types/logical/utils/LogicalTypeMerging.java | 14 +++--
.../logical/utils/LogicalTypeMergingTest.java | 10 +++-
.../planner/expressions/ExpressionBuilder.java | 9 +++
.../functions/aggfunctions/SumAggFunction.java | 41 +++++++++-----
.../table/planner/codegen/ExprCodeGenerator.scala | 19 ++++---
.../runtime/stream/sql/AggregateITCase.scala | 28 +++++++++
.../runtime/stream/table/AggregateITCase.scala | 23 ++++++++
10 files changed, 204 insertions(+), 28 deletions(-)
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
index 022b7b1..d1695f6 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
@@ -29,6 +29,7 @@ import org.apache.flink.table.types.inference.InputTypeStrategies;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies;
import org.apache.flink.table.types.inference.strategies.SpecificTypeStrategies;
+import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeFamily;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.StructuredType.StructuredComparison;
@@ -743,6 +744,24 @@ public final class BuiltInFunctionDefinitions {
explicit(DataTypes.STRING()))))
.build();
+ /**
+ * Special "+" operator used internally by {@code SumAggFunction} to implement SUM aggregation
+ * on a Decimal type. Uses {@link LogicalTypeMerging#findSumAggType(LogicalType)} so that the
+ * normal {@link #PLUS} does not override the special calculation for precision and scale
+ * needed by SUM.
+ */
+ public static final BuiltInFunctionDefinition AGG_DECIMAL_PLUS =
+ BuiltInFunctionDefinition.newBuilder()
+ .name("AGG_DECIMAL_PLUS")
+ .kind(SCALAR)
+ .inputTypeStrategy(
+ sequence(
+ logical(LogicalTypeRoot.DECIMAL),
+ logical(LogicalTypeRoot.DECIMAL)))
+ .outputTypeStrategy(SpecificTypeStrategies.AGG_DECIMAL_PLUS)
+ .runtimeProvided()
+ .build();
+
/** Combines numeric subtraction and "datetime - interval" arithmetic. */
public static final BuiltInFunctionDefinition MINUS =
BuiltInFunctionDefinition.newBuilder()
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java
new file mode 100644
index 0000000..23be242
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.inference.strategies;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.TypeStrategy;
+import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
+import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
+import org.apache.flink.table.types.utils.TypeConversions;
+import org.apache.flink.util.Preconditions;
+
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Type strategy that returns the result type of the decimal addition used internally by {@code
+ * SumAggFunction} to implement SUM aggregation on a Decimal type. Uses {@link
+ * LogicalTypeMerging#findSumAggType(LogicalType)} and prevents the {@link DecimalPlusTypeStrategy}
+ * from overriding the special calculation for precision and scale needed by SUM.
+ */
+@Internal
+class AggDecimalPlusTypeStrategy implements TypeStrategy {
+
+ private static final String ERROR_MSG =
+ "Both args of "
+ + AggDecimalPlusTypeStrategy.class.getSimpleName()
+ + " should be of type["
+ + DecimalType.class.getSimpleName()
+ + "]";
+
+ @Override
+ public Optional<DataType> inferType(CallContext callContext) {
+ final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes();
+ final LogicalType addend1 = argumentDataTypes.get(0).getLogicalType();
+ final LogicalType addend2 = argumentDataTypes.get(1).getLogicalType();
+
+ Preconditions.checkArgument(
+ LogicalTypeChecks.hasRoot(addend1, LogicalTypeRoot.DECIMAL), ERROR_MSG);
+ Preconditions.checkArgument(
+ LogicalTypeChecks.hasRoot(addend2, LogicalTypeRoot.DECIMAL), ERROR_MSG);
+
+ return Optional.of(
+ TypeConversions.fromLogicalToDataType(LogicalTypeMerging.findSumAggType(addend2)));
+ }
+}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
index 69b1e5f..c0e4dee 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java
@@ -61,6 +61,9 @@ public final class SpecificTypeStrategies {
/** See {@link DecimalPlusTypeStrategy}. */
public static final TypeStrategy DECIMAL_PLUS = new DecimalPlusTypeStrategy();
+ /** See {@link AggDecimalPlusTypeStrategy}. */
+ public static final TypeStrategy AGG_DECIMAL_PLUS = new AggDecimalPlusTypeStrategy();
+
/** See {@link DecimalScale0TypeStrategy}. */
public static final TypeStrategy DECIMAL_SCALE_0 = new DecimalScale0TypeStrategy();
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
index 8cf5d96..e7aab32 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
@@ -304,15 +304,19 @@ public final class LogicalTypeMerging {
*
* <p>https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
*
- * <p>The rules although inspired by SQL Server they are not followed 100%, instead the approach
- * of Spark/Hive is followed for adjusting the precision.
+ * <p>The rules (although inspired by SQL Server) are not followed 100%, instead the approach of
+ * Spark/Hive is followed for adjusting the precision.
*
* <p>http://www.openkb.info/2021/05/understand-decimal-precision-and-scale.html
*
- * <p>For (38, 8) + (32, 8) -> (39, 8) (If precision is infinite) // integral part: 31
+ * <p>For (38, 8) + (32, 8) -> (39, 8): the rules for addition initially calculate a decimal
+ * type assuming unlimited precision, which results in a decimal with an integral part of 31
+ * digits.
*
- * <p>The rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead
- * we follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31
+ * <p>This method is called subsequently to adjust the resulting decimal since the maximum
+ * allowed precision is 38 (so far a precision of 39 is calculated in the first step). So, the
+ * rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead we
+ * follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31
*/
private static DecimalType adjustPrecisionScale(int precision, int scale) {
if (precision <= DecimalType.MAX_PRECISION) {
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
index 22c4bf8..8217130 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java
@@ -19,6 +19,7 @@
package org.apache.flink.table.types.logical.utils;
import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.LogicalType;
import org.junit.Test;
@@ -27,7 +28,14 @@ import java.util.List;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
-/** Tests for {@link LogicalTypeMerging#findCommonType(List)}. */
+/**
+ * Tests for {@link LogicalTypeMerging} for finding the result decimal type for the various
+ * operations, e.g.: {@link LogicalTypeMerging#findSumAggType(LogicalType)}, {@link
+ * LogicalTypeMerging#findAdditionDecimalType(int, int, int, int)}, etc.
+ *
+ * <p>For {@link LogicalTypeMerging#findCommonType(List)} tests please check {@link
+ * org.apache.flink.table.types.LogicalCommonTypeTest}
+ */
public class LogicalTypeMergingTest {
@Test
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
index 9effb05..1e58d93 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
@@ -18,6 +18,7 @@
package org.apache.flink.table.planner.expressions;
+import org.apache.flink.annotation.Internal;
import org.apache.flink.table.expressions.ApiExpressionUtils;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.TypeLiteralExpression;
@@ -28,6 +29,7 @@ import org.apache.flink.table.types.DataType;
import java.util.List;
+import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CAST;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CONCAT;
@@ -100,6 +102,13 @@ public class ExpressionBuilder {
return call(PLUS, input1, input2);
}
+ // Used only for implementing the SumAggFunction to avoid overriding decimal precision/scale
+ // calculation for sum with the rules applied for the normal plus
+ @Internal
+ public static UnresolvedCallExpression aggDecimalPlus(Expression input1, Expression input2) {
+ return call(AGG_DECIMAL_PLUS, input1, input2);
+ }
+
public static UnresolvedCallExpression minus(Expression input1, Expression input2) {
return call(MINUS, input1, input2);
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java
index 4e93800..ba3ebbb 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java
@@ -28,16 +28,16 @@ import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
-import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.aggDecimalPlus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
-import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
/** built-in sum aggregate function. */
public abstract class SumAggFunction extends DeclarativeAggregateFunction {
- private UnresolvedReferenceExpression sum = unresolvedRef("sum");
+
+ protected UnresolvedReferenceExpression sum = unresolvedRef("sum");
@Override
public int operandCount() {
@@ -62,11 +62,13 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
@Override
public Expression[] accumulateExpressions() {
return new Expression[] {
- /* sum = */ adjustSumType(
+ /* sum = */ ifThenElse(
+ isNull(operand(0)),
+ sum,
ifThenElse(
isNull(operand(0)),
sum,
- ifThenElse(isNull(sum), operand(0), plus(sum, operand(0)))))
+ ifThenElse(isNull(sum), operand(0), doPlus(sum, operand(0)))))
};
}
@@ -79,17 +81,16 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
@Override
public Expression[] mergeExpressions() {
return new Expression[] {
- /* sum = */ adjustSumType(
- ifThenElse(
- isNull(mergeOperand(sum)),
- sum,
- ifThenElse(
- isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum)))))
+ /* sum = */ ifThenElse(
+ isNull(mergeOperand(sum)),
+ sum,
+ ifThenElse(isNull(sum), mergeOperand(sum), doPlus(sum, mergeOperand(sum))))
};
}
- private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
- return cast(sumExpr, typeLiteral(getResultType()));
+ protected UnresolvedCallExpression doPlus(
+ UnresolvedReferenceExpression arg1, UnresolvedReferenceExpression arg2) {
+ return plus(arg1, arg2);
}
@Override
@@ -149,6 +150,7 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
/** Built-in Decimal Sum aggregate function. */
public static class DecimalSumAggFunction extends SumAggFunction {
private DecimalType decimalType;
+ private DataType returnType;
public DecimalSumAggFunction(DecimalType decimalType) {
this.decimalType = decimalType;
@@ -156,8 +158,17 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction {
@Override
public DataType getResultType() {
- DecimalType sumType = (DecimalType) LogicalTypeMerging.findSumAggType(decimalType);
- return DataTypes.DECIMAL(sumType.getPrecision(), sumType.getScale());
+ if (returnType == null) {
+ DecimalType sumType = (DecimalType) LogicalTypeMerging.findSumAggType(decimalType);
+ returnType = DataTypes.DECIMAL(sumType.getPrecision(), sumType.getScale());
+ }
+ return returnType;
+ }
+
+ @Override
+ protected UnresolvedCallExpression doPlus(
+ UnresolvedReferenceExpression arg1, UnresolvedReferenceExpression arg2) {
+ return aggDecimalPlus(arg1, arg2);
}
}
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
index 4cea7d2..32cfe0f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
@@ -806,21 +806,26 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
case bsf: BridgingSqlFunction =>
bsf.getDefinition match {
- case functionDefinition : FunctionDefinition
- if functionDefinition eq BuiltInFunctionDefinitions.CURRENT_WATERMARK =>
+ case BuiltInFunctionDefinitions.CURRENT_WATERMARK =>
generateWatermark(ctx, contextTerm, resultType)
- case functionDefinition : FunctionDefinition
- if functionDefinition eq BuiltInFunctionDefinitions.GREATEST =>
+
+ case BuiltInFunctionDefinitions.GREATEST =>
operands.foreach { operand =>
requireComparable(operand)
}
generateGreatestLeast(resultType, operands)
- case functionDefinition : FunctionDefinition
- if functionDefinition eq BuiltInFunctionDefinitions.LEAST =>
+
+ case BuiltInFunctionDefinitions.LEAST =>
operands.foreach { operand =>
requireComparable(operand)
}
- generateGreatestLeast(resultType, operands, false)
+ generateGreatestLeast(resultType, operands, greatest = false)
+
+ case BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS =>
+ val left = operands.head
+ val right = operands(1)
+ generateBinaryArithmeticOperator(ctx, "+", resultType, left, right)
+
case _ =>
new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType)
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
index 8db0b97..64bec0b 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
@@ -397,6 +397,34 @@ class AggregateITCase(
}
@Test
+ def testPrecisionForSumAggregationOnDecimal(): Unit = {
+ var t = tEnv.sqlQuery(
+ "select cast(sum(1.03520274) as DECIMAL(32, 8)), " +
+ "cast(sum(12345.035202748654) AS DECIMAL(30, 20)), " +
+ "cast(sum(12.345678901234567) AS DECIMAL(25, 22))")
+ var sink = new TestingRetractSink
+ t.toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+ var expected = List("1.03520274,12345.03520274865400000000,12.3456789012345670000000")
+ assertEquals(expected, sink.getRetractResults)
+
+ val data = new mutable.MutableList[(Double, Int)]
+ data .+= ((1.11111111, 1))
+ data .+= ((1.11111111, 2))
+ env.setParallelism(1)
+
+ t = failingDataSource(data).toTable(tEnv, 'a, 'b)
+ tEnv.registerTable("T", t)
+
+ t = tEnv.sqlQuery("select sum(cast(a as decimal(32, 8))) from T")
+ sink = new TestingRetractSink
+ t.toRetractStream[Row].addSink(sink)
+ env.execute()
+ expected = List("2.22222222")
+ assertEquals(expected, sink.getRetractResults)
+ }
+
+ @Test
def testGroupByAgg(): Unit = {
val data = new mutable.MutableList[(Int, Long, String)]
data.+=((1, 1L, "A"))
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
index 89454c4..0c3a99a 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.table.api.bridge.scala._
import org.apache.flink.table.api.internal.TableEnvironmentInternal
+import org.apache.flink.table.api.DataTypes.DECIMAL
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg, WeightedAvg}
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.apache.flink.table.planner.runtime.utils.TestData._
@@ -395,4 +396,26 @@ class AggregateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase
val expected = mutable.MutableList("1,1", "2,3", "3,6", "4,10", "5,15", "6,21")
assertEquals(expected.sorted, sink.getRetractResults.sorted)
}
+
+ @Test
+ def testPrecisionForSumAggregationOnDecimal(): Unit = {
+ val data = new mutable.MutableList[(Double, Double, Double, Double)]
+ data.+=((1.03520274, 12345.035202748654, 12.345678901234567, 1.11111111))
+ data.+=((0, 0, 0, 1.11111111))
+ val t = failingDataSource(data).toTable(tEnv, 'a, 'b, 'c, 'd)
+
+ val results = t
+ .select('a.cast(DECIMAL(32, 8)).sum as 'a,
+ 'b.cast(DECIMAL(30, 20)).sum as 'b,
+ 'c.cast(DECIMAL(25, 20)).sum as 'c,
+ 'd.cast(DECIMAL(32, 8)).sum as 'd)
+ .toRetractStream[Row]
+
+ val sink = new TestingRetractSink
+ results.addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List("1.03520274,12345.03520274865300000000,12.34567890123456700000,2.22222222")
+ assertEquals(expected, sink.getRetractResults)
+ }
}
[flink] 02/03: [hotfix][table-common] Add java docs to
LogicalTypeMerging
Posted by tw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 0c47ed79b4ea7e3804309426c61f3ecdc3b96273
Author: Marios Trivyzas <ma...@gmail.com>
AuthorDate: Thu Nov 4 10:06:25 2021 +0100
[hotfix][table-common] Add java docs to LogicalTypeMerging
Add more explanation to the `LogicalTypeMerging#adjustPrecisionScale()` method
regarding the decision not to follow Microsoft's SQL Server rules 100%,
but instead the Hive/Spark behaviour, when calculating the resulting precision
of a decimal operation.
---
.../flink/table/types/logical/utils/LogicalTypeMerging.java | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
index dfb1ec1..8cf5d96 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
@@ -303,6 +303,16 @@ public final class LogicalTypeMerging {
* integral part of a result from being truncated.
*
* <p>https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
+ *
+ * <p>The rules although inspired by SQL Server they are not followed 100%, instead the approach
+ * of Spark/Hive is followed for adjusting the precision.
+ *
+ * <p>http://www.openkb.info/2021/05/understand-decimal-precision-and-scale.html
+ *
+ * <p>For (38, 8) + (32, 8) -> (39, 8) (If precision is infinite) // integral part: 31
+ *
+ * <p>The rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead
+ * we follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31
*/
private static DecimalType adjustPrecisionScale(int precision, int scale) {
if (precision <= DecimalType.MAX_PRECISION) {