You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by go...@apache.org on 2023/01/18 14:14:40 UTC
[flink] branch master updated: [FLINK-29719][hive] Supports native count function for hive dialect
This is an automated email from the ASF dual-hosted git repository.
godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 606f297198a [FLINK-29719][hive] Supports native count function for hive dialect
606f297198a is described below
commit 606f297198acd74a5c1a39700bd84ad9e26e7b82
Author: fengli <ld...@163.com>
AuthorDate: Wed Jan 4 15:31:03 2023 +0800
[FLINK-29719][hive] Supports native count function for hive dialect
This closes #21596
---
.../table/functions/hive/HiveCountAggFunction.java | 116 +++++++++++++++++++++
.../apache/flink/table/module/hive/HiveModule.java | 6 +-
.../connectors/hive/HiveDialectAggITCase.java | 72 +++++++++++--
.../connectors/hive/HiveDialectQueryPlanTest.java | 24 ++++-
.../explain/testCountAggFunctionFallbackPlan.out | 35 +++++++
.../resources/explain/testCountAggFunctionPlan.out | 27 +++++
6 files changed, 268 insertions(+), 12 deletions(-)
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveCountAggFunction.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveCountAggFunction.java
new file mode 100644
index 00000000000..e15a0cbaf3d
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveCountAggFunction.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.hive;
+
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.expressions.Expression;
+import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
+import org.apache.flink.table.planner.expressions.ExpressionBuilder;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.CallContext;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
+
+/** Built-in Hive count aggregate function. */
+public class HiveCountAggFunction extends HiveDeclarativeAggregateFunction {
+
+ private final UnresolvedReferenceExpression count = unresolvedRef("count");
+ private Integer arguments;
+ private boolean countLiteral;
+
+ @Override
+ public int operandCount() {
+ return arguments;
+ }
+
+ @Override
+ public UnresolvedReferenceExpression[] aggBufferAttributes() {
+ return new UnresolvedReferenceExpression[] {count};
+ }
+
+ @Override
+ public DataType[] getAggBufferTypes() {
+ return new DataType[] {DataTypes.BIGINT()};
+ }
+
+ @Override
+ public DataType getResultType() {
+ return DataTypes.BIGINT();
+ }
+
+ @Override
+ public Expression[] initialValuesExpressions() {
+ return new Expression[] {/* count = */ literal(0L, getResultType().notNull())};
+ }
+
+ @Override
+ public Expression[] accumulateExpressions() {
+ // count(*) and count(literal) count every row, so always add one
+ if (arguments == 0 || countLiteral) {
+ return new Expression[] {/* count = */ plus(count, literal(1L))};
+ }
+
+ // otherwise, a row is counted only when none of its operands is null
+ List<Expression> operandExpressions = new ArrayList<>();
+ for (int i = 0; i < arguments; i++) {
+ operandExpressions.add(operand(i));
+ }
+ Expression operandExpression =
+ operandExpressions.stream()
+ .map(ExpressionBuilder::isNull)
+ .reduce(ExpressionBuilder::or)
+ .get();
+ return new Expression[] {
+ /* count = */ ifThenElse(operandExpression, count, plus(count, literal(1L)))
+ };
+ }
+
+ @Override
+ public Expression[] retractExpressions() {
+ throw new TableException("Count aggregate function does not support retraction.");
+ }
+
+ @Override
+ public Expression[] mergeExpressions() {
+ return new Expression[] {/* count = */ plus(count, mergeOperand(count))};
+ }
+
+ @Override
+ public Expression getValueExpression() {
+ return count;
+ }
+
+ @Override
+ public void setArguments(CallContext callContext) {
+ if (arguments == null) {
+ arguments = callContext.getArgumentDataTypes().size();
+ if (arguments == 1) {
+ // A single literal argument means the call is count(literal)
+ countLiteral = callContext.isArgumentLiteral(0);
+ }
+ }
+ }
+}
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
index 6ca6ca84dd9..bb598891fef 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
@@ -26,6 +26,7 @@ import org.apache.flink.table.catalog.hive.client.HiveShimLoader;
import org.apache.flink.table.catalog.hive.factories.HiveFunctionDefinitionFactory;
import org.apache.flink.table.factories.FunctionDefinitionFactory;
import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.functions.hive.HiveCountAggFunction;
import org.apache.flink.table.functions.hive.HiveMinAggFunction;
import org.apache.flink.table.functions.hive.HiveSumAggFunction;
import org.apache.flink.table.module.Module;
@@ -86,7 +87,7 @@ public class HiveModule implements Module {
"tumble_start")));
static final Set<String> BUILTIN_NATIVE_AGG_FUNC =
- Collections.unmodifiableSet(new HashSet<>(Arrays.asList("sum", "min")));
+ Collections.unmodifiableSet(new HashSet<>(Arrays.asList("sum", "count", "min")));
private final HiveFunctionDefinitionFactory factory;
private final String hiveVersion;
@@ -206,6 +207,9 @@ public class HiveModule implements Module {
case "sum":
// We override Hive's sum function by native implementation to supports hash-agg
return Optional.of(new HiveSumAggFunction());
+ case "count":
+ // We override Hive's count function by native implementation to support hash-agg
+ return Optional.of(new HiveCountAggFunction());
case "min":
// We override Hive's min function by native implementation to supports hash-agg
return Optional.of(new HiveMinAggFunction());
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
index bd124c7d554..3af77ad72b6 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
@@ -30,7 +30,6 @@ import org.apache.flink.util.CollectionUtil;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
@@ -58,6 +57,8 @@ public class HiveDialectAggITCase {
hiveCatalog.getHiveConf().setVar(HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT, "none");
hiveCatalog.open();
tableEnv = getTableEnvWithHiveCatalog();
+ // enable native hive agg function
+ tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, true);
// create tables
tableEnv.executeSql("create table foo (x int, y int)");
@@ -71,12 +72,6 @@ public class HiveDialectAggITCase {
.commit();
}
- @Before
- public void before() {
- // enable native hive agg function
- tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, true);
- }
-
@Test
public void testSimpleSumAggFunction() throws Exception {
tableEnv.executeSql(
@@ -167,6 +162,69 @@ public class HiveDialectAggITCase {
tableEnv.executeSql("drop table test_sum_group");
}
+ @Test
+ public void testSimpleCount() throws Exception {
+ tableEnv.executeSql("create table test_count(a int, x string, y string, z int, d bigint)");
+ tableEnv.executeSql(
+ "insert into test_count values (1, NULL, '2', 1, 2), "
+ + "(1, NULL, 'b', 2, NULL), "
+ + "(2, NULL, '4', 1, 2), "
+ + "(2, NULL, NULL, 4, 3)")
+ .await();
+
+ // test count(*)
+ List<Row> result =
+ CollectionUtil.iteratorToList(
+ tableEnv.executeSql("select count(*) from test_count").collect());
+ assertThat(result.toString()).isEqualTo("[+I[4]]");
+
+ // test count(1)
+ List<Row> result2 =
+ CollectionUtil.iteratorToList(
+ tableEnv.executeSql("select count(1) from test_count").collect());
+ assertThat(result2.toString()).isEqualTo("[+I[4]]");
+
+ // test count(col1)
+ List<Row> result3 =
+ CollectionUtil.iteratorToList(
+ tableEnv.executeSql("select count(y) from test_count").collect());
+ assertThat(result3.toString()).isEqualTo("[+I[3]]");
+
+ // test count(distinct col1)
+ List<Row> result4 =
+ CollectionUtil.iteratorToList(
+ tableEnv.executeSql("select count(distinct z) from test_count").collect());
+ assertThat(result4.toString()).isEqualTo("[+I[3]]");
+
+ // test count(distinct col1, col2)
+ List<Row> result5 =
+ CollectionUtil.iteratorToList(
+ tableEnv.executeSql("select count(distinct z, d) from test_count")
+ .collect());
+ assertThat(result5.toString()).isEqualTo("[+I[2]]");
+
+ tableEnv.executeSql("drop table test_count");
+ }
+
+ @Test
+ public void testCountAggWithGroupKey() throws Exception {
+ tableEnv.executeSql(
+ "create table test_count_group(a int, x string, y string, z int, d bigint)");
+ tableEnv.executeSql(
+ "insert into test_count_group values (1, NULL, '2', 1, 2), "
+ + "(1, NULL, '2', 2, NULL), "
+ + "(2, NULL, '4', 1, 2), "
+ + "(2, NULL, 3, 4, 3)")
+ .await();
+
+ List<Row> result =
+ CollectionUtil.iteratorToList(
+ tableEnv.executeSql(
+ "select count(*), count(x), count(distinct y), count(distinct z, d) from test_count_group group by a")
+ .collect());
+ assertThat(result.toString()).isEqualTo("[+I[2, 0, 1, 1], +I[2, 0, 2, 2]]");
+ }
+
@Test
public void testMinAggFunction() throws Exception {
tableEnv.executeSql(
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java
index 69b8ca9179f..48cc8b913d1 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java
@@ -70,25 +70,41 @@ public class HiveDialectQueryPlanTest {
@Test
public void testSumAggFunctionPlan() {
// test explain
- String actualPlan = explainSql("select x, sum(y) from foo group by x");
+ String sql = "select x, sum(y) from foo group by x";
+ String actualPlan = explainSql(sql);
assertThat(actualPlan).isEqualTo(readFromResource("/explain/testSumAggFunctionPlan.out"));
// test fallback to hive sum udaf
tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false);
- String actualSortAggPlan = explainSql("select x, sum(y) from foo group by x");
+ String actualSortAggPlan = explainSql(sql);
assertThat(actualSortAggPlan)
.isEqualTo(readFromResource("/explain/testSumAggFunctionFallbackPlan.out"));
}
+ @Test
+ public void testCountAggFunctionPlan() {
+ // test explain
+ String sql = "select x, count(*), count(y), count(distinct y) from foo group by x";
+ String actualPlan = explainSql(sql);
+ assertThat(actualPlan).isEqualTo(readFromResource("/explain/testCountAggFunctionPlan.out"));
+
+ // test fallback to hive count udaf
+ tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false);
+ String actualSortAggPlan = explainSql(sql);
+ assertThat(actualSortAggPlan)
+ .isEqualTo(readFromResource("/explain/testCountAggFunctionFallbackPlan.out"));
+ }
+
@Test
public void testMinAggFunctionPlan() {
// test explain
- String actualPlan = explainSql("select x, min(y) from foo group by x");
+ String sql = "select x, min(y) from foo group by x";
+ String actualPlan = explainSql(sql);
assertThat(actualPlan).isEqualTo(readFromResource("/explain/testMinAggFunctionPlan.out"));
// test fallback to hive min udaf
tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false);
- String actualSortAggPlan = explainSql("select x, min(y) from foo group by x");
+ String actualSortAggPlan = explainSql(sql);
assertThat(actualSortAggPlan)
.isEqualTo(readFromResource("/explain/testMinAggFunctionFallbackPlan.out"));
}
diff --git a/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionFallbackPlan.out b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionFallbackPlan.out
new file mode 100644
index 00000000000..e356f402212
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionFallbackPlan.out
@@ -0,0 +1,35 @@
+== Abstract Syntax Tree ==
+LogicalProject(x=[$0], _o__c1=[$1], _o__c2=[$2], _o__c3=[$3])
++- LogicalAggregate(group=[{0}], agg#0=[count()], agg#1=[count($1)], agg#2=[count(DISTINCT $1)])
+ +- LogicalProject($f0=[$0], $f1=[$1])
+ +- LogicalTableScan(table=[[test-catalog, default, foo]])
+
+== Optimized Physical Plan ==
+SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count($f3) AS $f3])
++- Sort(orderBy=[x ASC])
+ +- Exchange(distribution=[hash[x]])
+ +- LocalSortAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS $f3])
+ +- Calc(select=[x, y, $f1, $f2, =(CASE(=($e, 0), 0, 1), 0) AS $g_0, =(CASE(=($e, 0), 0, 1), 1) AS $g_1])
+ +- Sort(orderBy=[x ASC])
+ +- SortAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count($f1) AS $f1, Final_count($f2) AS $f2])
+ +- Sort(orderBy=[x ASC, y ASC, $e ASC])
+ +- Exchange(distribution=[hash[x, y, $e]])
+ +- LocalSortAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS $f1, Partial_count(y_0) AS $f2])
+ +- Sort(orderBy=[x ASC, y ASC, $e ASC])
+ +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}])
+ +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
+
+== Optimized Execution Plan ==
+SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count($f3) AS $f3])
++- Sort(orderBy=[x ASC])
+ +- Exchange(distribution=[hash[x]])
+ +- LocalSortAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS $f3])
+ +- Calc(select=[x, y, $f1, $f2, (CASE(($e = 0), 0, 1) = 0) AS $g_0, (CASE(($e = 0), 0, 1) = 1) AS $g_1])
+ +- Sort(orderBy=[x ASC])
+ +- SortAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count($f1) AS $f1, Final_count($f2) AS $f2])
+ +- Sort(orderBy=[x ASC, y ASC, $e ASC])
+ +- Exchange(distribution=[hash[x, y, $e]])
+ +- LocalSortAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS $f1, Partial_count(y_0) AS $f2])
+ +- Sort(orderBy=[x ASC, y ASC, $e ASC])
+ +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}])
+ +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
diff --git a/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionPlan.out b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionPlan.out
new file mode 100644
index 00000000000..fc6e8b6d8cb
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionPlan.out
@@ -0,0 +1,27 @@
+== Abstract Syntax Tree ==
+LogicalProject(x=[$0], _o__c1=[$1], _o__c2=[$2], _o__c3=[$3])
++- LogicalAggregate(group=[{0}], agg#0=[count()], agg#1=[count($1)], agg#2=[count(DISTINCT $1)])
+ +- LogicalProject($f0=[$0], $f1=[$1])
+ +- LogicalTableScan(table=[[test-catalog, default, foo]])
+
+== Optimized Physical Plan ==
+HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count(count$2) AS $f3])
++- Exchange(distribution=[hash[x]])
+ +- LocalHashAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS count$2])
+ +- Calc(select=[x, y, $f1, $f2, =(CASE(=($e, 0), 0, 1), 0) AS $g_0, =(CASE(=($e, 0), 0, 1), 1) AS $g_1])
+ +- HashAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count(count$0) AS $f1, Final_count(count$1) AS $f2])
+ +- Exchange(distribution=[hash[x, y, $e]])
+ +- LocalHashAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS count$0, Partial_count(y_0) AS count$1])
+ +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}])
+ +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
+
+== Optimized Execution Plan ==
+HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count(count$2) AS $f3])
++- Exchange(distribution=[hash[x]])
+ +- LocalHashAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS count$2])
+ +- Calc(select=[x, y, $f1, $f2, (CASE(($e = 0), 0, 1) = 0) AS $g_0, (CASE(($e = 0), 0, 1) = 1) AS $g_1])
+ +- HashAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count(count$0) AS $f1, Final_count(count$1) AS $f2])
+ +- Exchange(distribution=[hash[x, y, $e]])
+ +- LocalHashAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS count$0, Partial_count(y_0) AS count$1])
+ +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}])
+ +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])