You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2017/02/27 23:35:22 UTC
[1/2] flink git commit: [FLINK-3475] [table] Add support for DISTINCT
aggregates in SQL queries.
Repository: flink
Updated Branches:
refs/heads/master 8bcb2ae3c -> 1a062b796
[FLINK-3475] [table] Add support for DISTINCT aggregates in SQL queries.
This closes #3111.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/36c9348f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/36c9348f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/36c9348f
Branch: refs/heads/master
Commit: 36c9348ff06cae1fe55925bcc6081154be2f10f5
Parents: 8bcb2ae
Author: Zhenghua Gao <do...@gmail.com>
Authored: Thu Jan 12 10:33:27 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Mon Feb 27 22:50:13 2017 +0100
----------------------------------------------------------------------
docs/dev/table_api.md | 3 +-
...nkAggregateExpandDistinctAggregatesRule.java | 1152 ++++++++++++++++++
.../flink/table/plan/rules/FlinkRuleSets.scala | 5 +-
.../rules/dataSet/DataSetAggregateRule.scala | 3 -
.../DataSetAggregateWithNullValuesRule.scala | 3 -
.../scala/batch/sql/AggregationsITCase.scala | 27 +-
.../scala/batch/sql/DistinctAggregateTest.scala | 476 ++++++++
.../batch/sql/QueryDecorrelationTest.scala | 2 +-
8 files changed, 1651 insertions(+), 20 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/docs/dev/table_api.md
----------------------------------------------------------------------
diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md
index 22fd636..1b43b7c 100644
--- a/docs/dev/table_api.md
+++ b/docs/dev/table_api.md
@@ -1324,13 +1324,12 @@ val result = tableEnv.sql(
#### Limitations
-The current version supports selection (filter), projection, inner equi-joins, grouping, non-distinct aggregates, and sorting on batch tables.
+The current version supports selection (filter), projection, inner equi-joins, grouping, aggregates, and sorting on batch tables.
Among others, the following SQL features are not supported, yet:
- Timestamps and intervals are limited to milliseconds precision
- Interval arithmetic is currently limited
-- Distinct aggregates (e.g., `COUNT(DISTINCT name)`)
- Non-equi joins and Cartesian products
- Efficient grouping sets
http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateExpandDistinctAggregatesRule.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateExpandDistinctAggregatesRule.java b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateExpandDistinctAggregatesRule.java
new file mode 100644
index 0000000..d7b1ffa
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateExpandDistinctAggregatesRule.java
@@ -0,0 +1,1152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.calcite.rules;
+
+import org.apache.calcite.plan.Contexts;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.fun.SqlCountAggFunction;
+import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.fun.SqlSumAggFunction;
+import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.ImmutableIntList;
+import org.apache.calcite.util.Pair;
+import org.apache.calcite.util.Util;
+
+import org.apache.flink.util.Preconditions;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedSet;
+import java.util.TreeSet;
+
+/*
+ * This is a copy of Calcite's AggregateExpandDistinctAggregatesRule in the Flink
+ * project, with a quick fix to avoid the bad cases described in CALCITE-1558.
+ * It should be dropped in favor of Calcite's own
+ * AggregateExpandDistinctAggregatesRule once Flink upgrades to Calcite 1.12 or above.
+ */
+
+/**
+ * Planner rule that expands distinct aggregates
+ * (such as {@code COUNT(DISTINCT x)}) from a
+ * {@link org.apache.calcite.rel.logical.LogicalAggregate}.
+ *
+ * <p>How this is done depends upon the arguments to the function. If all
+ * functions have the same argument
+ * (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have the argument
+ * {@code x}) then one extra {@link org.apache.calcite.rel.core.Aggregate} is
+ * sufficient.
+ *
+ * <p>If there are multiple arguments
+ * (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)})
+ * the rule creates separate {@code Aggregate}s and combines using a
+ * {@link org.apache.calcite.rel.core.Join}.
+ */
public final class FlinkAggregateExpandDistinctAggregatesRule extends RelOptRule {
  //~ Static fields/initializers ---------------------------------------------

  /** The default instance of the rule; operates only on logical expressions
   * and rewrites multiple distinct aggregates using grouping sets. */
  public static final FlinkAggregateExpandDistinctAggregatesRule INSTANCE =
      new FlinkAggregateExpandDistinctAggregatesRule(LogicalAggregate.class, true,
          RelFactories.LOGICAL_BUILDER);

  /** Instance of the rule that operates only on logical expressions and
   * generates a join. */
  public static final FlinkAggregateExpandDistinctAggregatesRule JOIN =
      new FlinkAggregateExpandDistinctAggregatesRule(LogicalAggregate.class, false,
          RelFactories.LOGICAL_BUILDER);

  /** Constant 2; used to encode grouping-set membership as a sum of powers
   * of two (see {@code Registrar.makeGroup}/{@code toNumber}). */
  private static final BigDecimal TWO = BigDecimal.valueOf(2L);

  /** Whether multiple distinct aggregates are rewritten with grouping sets
   * (true) or with a join of per-argument-set aggregates (false). */
  public final boolean useGroupingSets;
+ //~ Constructors -----------------------------------------------------------
+
  /**
   * Creates a FlinkAggregateExpandDistinctAggregatesRule.
   *
   * @param clazz             Aggregate class to match
   * @param useGroupingSets   Whether to rewrite multiple distinct aggregates
   *                          with grouping sets (otherwise a join is used)
   * @param relBuilderFactory Factory for building relational expressions
   */
  public FlinkAggregateExpandDistinctAggregatesRule(
      Class<? extends LogicalAggregate> clazz,
      boolean useGroupingSets,
      RelBuilderFactory relBuilderFactory) {
    super(operand(clazz, any()), relBuilderFactory, null);
    this.useGroupingSets = useGroupingSets;
  }
+
  /** @deprecated Use the {@link RelBuilderFactory}-based constructor;
   * this overload merely wraps the join factory in a builder factory. */
  @Deprecated // to be removed before 2.0
  public FlinkAggregateExpandDistinctAggregatesRule(
      Class<? extends LogicalAggregate> clazz,
      boolean useGroupingSets,
      RelFactories.JoinFactory joinFactory) {
    this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of(joinFactory)));
  }
+
  /** @deprecated Use the {@link RelBuilderFactory}-based constructor;
   * this overload defaults to the join-based rewrite (no grouping sets). */
  @Deprecated // to be removed before 2.0
  public FlinkAggregateExpandDistinctAggregatesRule(
      Class<? extends LogicalAggregate> clazz,
      RelFactories.JoinFactory joinFactory) {
    this(clazz, false, RelBuilder.proto(Contexts.of(joinFactory)));
  }
+
+ //~ Methods ----------------------------------------------------------------
+
+ public void onMatch(RelOptRuleCall call) {
+ final Aggregate aggregate = call.rel(0);
+ if (!aggregate.containsDistinctCall()) {
+ return;
+ }
+
+ // Find all of the agg expressions. We use a LinkedHashSet to ensure
+ // determinism.
+ int nonDistinctCount = 0;
+ int distinctCount = 0;
+ int filterCount = 0;
+ int unsupportedAggCount = 0;
+ final Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>();
+ for (AggregateCall aggCall : aggregate.getAggCallList()) {
+ if (aggCall.filterArg >= 0) {
+ ++filterCount;
+ }
+ if (!aggCall.isDistinct()) {
+ ++nonDistinctCount;
+ if (!(aggCall.getAggregation() instanceof SqlCountAggFunction
+ || aggCall.getAggregation() instanceof SqlSumAggFunction
+ || aggCall.getAggregation() instanceof SqlMinMaxAggFunction)) {
+ ++unsupportedAggCount;
+ }
+ continue;
+ }
+ ++distinctCount;
+ argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
+ }
+ Preconditions.checkState(argLists.size() > 0, "containsDistinctCall lied");
+
+ // If all of the agg expressions are distinct and have the same
+ // arguments then we can use a more efficient form.
+ if (nonDistinctCount == 0 && argLists.size() == 1) {
+ final Pair<List<Integer>, Integer> pair =
+ Iterables.getOnlyElement(argLists);
+ final RelBuilder relBuilder = call.builder();
+ convertMonopole(relBuilder, aggregate, pair.left, pair.right);
+ call.transformTo(relBuilder.build());
+ return;
+ }
+
+ if (useGroupingSets) {
+ rewriteUsingGroupingSets(call, aggregate, argLists);
+ return;
+ }
+
+ // If only one distinct aggregate and one or more non-distinct aggregates,
+ // we can generate multi-phase aggregates
+ if (distinctCount == 1 // one distinct aggregate
+ && filterCount == 0 // no filter
+ && unsupportedAggCount == 0 // sum/min/max/count in non-distinct aggregate
+ && nonDistinctCount > 0) { // one or more non-distinct aggregates
+ final RelBuilder relBuilder = call.builder();
+ convertSingletonDistinct(relBuilder, aggregate, argLists);
+ call.transformTo(relBuilder.build());
+ return;
+ }
+
+ // Create a list of the expressions which will yield the final result.
+ // Initially, the expressions point to the input field.
+ final List<RelDataTypeField> aggFields =
+ aggregate.getRowType().getFieldList();
+ final List<RexInputRef> refs = new ArrayList<>();
+ final List<String> fieldNames = aggregate.getRowType().getFieldNames();
+ final ImmutableBitSet groupSet = aggregate.getGroupSet();
+ final int groupAndIndicatorCount =
+ aggregate.getGroupCount() + aggregate.getIndicatorCount();
+ for (int i : Util.range(groupAndIndicatorCount)) {
+ refs.add(RexInputRef.of(i, aggFields));
+ }
+
+ // Aggregate the original relation, including any non-distinct aggregates.
+ final List<AggregateCall> newAggCallList = new ArrayList<>();
+ int i = -1;
+ for (AggregateCall aggCall : aggregate.getAggCallList()) {
+ ++i;
+ if (aggCall.isDistinct()) {
+ refs.add(null);
+ continue;
+ }
+ refs.add(
+ new RexInputRef(
+ groupAndIndicatorCount + newAggCallList.size(),
+ aggFields.get(groupAndIndicatorCount + i).getType()));
+ newAggCallList.add(aggCall);
+ }
+
+ // In the case where there are no non-distinct aggregates (regardless of
+ // whether there are group bys), there's no need to generate the
+ // extra aggregate and join.
+ final RelBuilder relBuilder = call.builder();
+ relBuilder.push(aggregate.getInput());
+ int n = 0;
+ if (!newAggCallList.isEmpty()) {
+ final RelBuilder.GroupKey groupKey =
+ relBuilder.groupKey(groupSet, aggregate.indicator, aggregate.getGroupSets());
+ relBuilder.aggregate(groupKey, newAggCallList);
+ ++n;
+ }
+
+ // For each set of operands, find and rewrite all calls which have that
+ // set of operands.
+ for (Pair<List<Integer>, Integer> argList : argLists) {
+ doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
+ }
+
+ relBuilder.project(refs, fieldNames);
+ call.transformTo(relBuilder.build());
+ }
+
  /**
   * Converts an aggregate with one distinct aggregate and one or more
   * non-distinct aggregates to multi-phase aggregates (see reference example
   * below).
   *
   * @param relBuilder Contains the input relational expression
   * @param aggregate  Original aggregate
   * @param argLists   Arguments and filters to the distinct aggregate function
   * @return the builder, with the two-phase aggregate pushed onto it
   */
  private RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
      Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // For example,
    //   SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
    //   FROM emp
    //   GROUP BY deptno
    //
    // becomes
    //
    //   SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
    //   FROM (
    //     SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
    //     FROM EMP
    //     GROUP BY deptno, sal)   // Aggregate B
    //   GROUP BY deptno           // Aggregate A
    relBuilder.push(aggregate.getInput());
    final List<Pair<RexNode, String>> projects = new ArrayList<>();
    // Maps an input-field ordinal (or a fake ordinal, below) to its position
    // in the projection that feeds the top aggregate A.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    SortedSet<Integer> newGroupSet = new TreeSet<>();
    final List<RelDataTypeField> childFields =
        relBuilder.peek().getRowType().getFieldList();
    final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;

    // Original group-by keys; used below to test group-key membership.
    SortedSet<Integer> groupSet = new TreeSet<>(aggregate.getGroupSet().asList());

    // Add the distinct aggregate column(s) to the group-by columns,
    // if not already a part of the group-by
    newGroupSet.addAll(aggregate.getGroupSet().asList());
    for (Pair<List<Integer>, Integer> argList : argLists) {
      newGroupSet.addAll(argList.getKey());
    }

    // Re-map the arguments to the aggregate A. These arguments will get
    // remapped because of the intermediate aggregate B generated as part of the
    // transformation.
    for (int arg : newGroupSet) {
      sourceOf.put(arg, projects.size());
      projects.add(RexInputRef.of2(arg, childFields));
    }
    // Generate the intermediate aggregate B
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final List<Integer> fakeArgs = new ArrayList<>();
    final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
    // First identify the real arguments, then use the rest for fake arguments
    // e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
    for (final AggregateCall aggCall : aggCalls) {
      if (!aggCall.isDistinct()) {
        for (int arg : aggCall.getArgList()) {
          if (!groupSet.contains(arg)) {
            // Temporary reservation so fake-ordinal selection below skips
            // ordinals already used by real arguments; released afterwards.
            sourceOf.put(arg, projects.size());
          }
        }
      }
    }
    int fakeArg0 = 0;
    for (final AggregateCall aggCall : aggCalls) {
      // We will deal with non-distinct aggregates below
      if (!aggCall.isDistinct()) {
        boolean isGroupKeyUsedInAgg = false;
        for (int arg : aggCall.getArgList()) {
          if (groupSet.contains(arg)) {
            isGroupKeyUsedInAgg = true;
            break;
          }
        }
        // Calls with no args (e.g. COUNT(*)) or args that collide with a
        // group key need a fake ordinal of their own.
        if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
          while (sourceOf.get(fakeArg0) != null) {
            ++fakeArg0;
          }
          fakeArgs.add(fakeArg0);
          ++fakeArg0;
        }
      }
    }
    for (final AggregateCall aggCall : aggCalls) {
      if (!aggCall.isDistinct()) {
        for (int arg : aggCall.getArgList()) {
          if (!groupSet.contains(arg)) {
            // Release the temporary reservation made above.
            sourceOf.remove(arg);
          }
        }
      }
    }
    // Compute the remapped arguments using fake arguments for non-distinct
    // aggregates with no arguments e.g. count(*).
    int fakeArgIdx = 0;
    for (final AggregateCall aggCall : aggCalls) {
      // Project the column corresponding to the distinct aggregate. Project
      // as-is all the non-distinct aggregates
      if (!aggCall.isDistinct()) {
        final AggregateCall newCall =
            AggregateCall.create(aggCall.getAggregation(), false,
                aggCall.getArgList(), -1,
                ImmutableBitSet.of(newGroupSet).cardinality(),
                relBuilder.peek(), null, aggCall.name);
        newAggCalls.add(newCall);
        if (newCall.getArgList().size() == 0) {
          int fakeArg = fakeArgs.get(fakeArgIdx);
          callArgMap.put(newCall, fakeArg);
          sourceOf.put(fakeArg, projects.size());
          projects.add(
              Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
                  newCall.getName()));
          ++fakeArgIdx;
        } else {
          for (int arg : newCall.getArgList()) {
            if (groupSet.contains(arg)) {
              // Argument doubles as a group key; route it through a fake
              // ordinal so key and aggregate input occupy distinct columns.
              int fakeArg = fakeArgs.get(fakeArgIdx);
              callArgMap.put(newCall, fakeArg);
              sourceOf.put(fakeArg, projects.size());
              projects.add(
                  Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
                      newCall.getName()));
              ++fakeArgIdx;
            } else {
              sourceOf.put(arg, projects.size());
              projects.add(
                  Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
                      newCall.getName()));
            }
          }
        }
      }
    }
    // Generate the aggregate B (see the reference example above)
    relBuilder.push(
        aggregate.copy(
            aggregate.getTraitSet(), relBuilder.build(),
            false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
    // Convert the existing aggregate to aggregate A (see the reference example above)
    final List<AggregateCall> newTopAggCalls =
        Lists.newArrayList(aggregate.getAggCallList());
    // Use the remapped arguments for the (non)distinct aggregate calls
    for (int i = 0; i < newTopAggCalls.size(); i++) {
      // Re-map arguments.
      final AggregateCall aggCall = newTopAggCalls.get(i);
      final int argCount = aggCall.getArgList().size();
      final List<Integer> newArgs = new ArrayList<>(argCount);
      final AggregateCall newCall;


      for (int j = 0; j < argCount; j++) {
        final Integer arg = aggCall.getArgList().get(j);
        if (callArgMap.containsKey(aggCall)) {
          newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
        }
        else {
          newArgs.add(sourceOf.get(arg));
        }
      }
      if (aggCall.isDistinct()) {
        // Aggregate B already grouped by the distinct column(s), so the top
        // call no longer needs the DISTINCT flag.
        newCall =
            AggregateCall.create(aggCall.getAggregation(), false, newArgs,
                -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(),
                aggCall.getType(), aggCall.name);
      } else {
        // If aggregate B had a COUNT aggregate call the corresponding aggregate at
        // aggregate A must be SUM. For other aggregates, it remains the same.
        if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
          if (aggCall.getArgList().size() == 0) {
            newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
          }
          if (hasGroupBy) {
            SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
            newCall =
                AggregateCall.create(sumAgg, false, newArgs, -1,
                    aggregate.getGroupSet().cardinality(), relBuilder.peek(),
                    aggCall.getType(), aggCall.getName());
          } else {
            // SUM0 ("sum, empty is zero") so a global COUNT over no rows
            // yields 0 rather than NULL.
            SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
            newCall =
                AggregateCall.create(sumAgg, false, newArgs, -1,
                    aggregate.getGroupSet().cardinality(), relBuilder.peek(),
                    aggCall.getType(), aggCall.getName());
          }
        } else {
          newCall =
              AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
                  aggregate.getGroupSet().cardinality(),
                  relBuilder.peek(), aggCall.getType(), aggCall.name);
        }
      }
      newTopAggCalls.set(i, newCall);
    }
    // Populate the group-by keys with the remapped arguments for aggregate A
    newGroupSet.clear();
    for (int arg : aggregate.getGroupSet()) {
      newGroupSet.add(sourceOf.get(arg));
    }
    relBuilder.push(
        aggregate.copy(aggregate.getTraitSet(),
            relBuilder.build(), aggregate.indicator,
            ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
    return relBuilder;
  }
  // NOTE(review): A commented-out earlier draft of convertSingletonDistinct
  // used to live here. It differed from the live method above only in testing
  // group-key membership via sourceOf.containsKey instead of a dedicated
  // groupSet (and omitted one ++fakeArg0 increment). It was dead code and has
  // been removed; consult version control if the original text is needed.
+
  /**
   * Rewrites an aggregate with multiple distinct argument sets using grouping
   * sets: a bottom aggregate computes every needed grouping (plus the
   * non-distinct calls), a projection tags rows with the grouping set they
   * belong to, and a top aggregate combines the results via filtered calls.
   *
   * @param call      Rule invocation; the result is registered on it
   * @param aggregate Original aggregate
   * @param argLists  Arguments and filters of the distinct aggregate calls
   */
  @SuppressWarnings("DanglingJavadoc")
  private void rewriteUsingGroupingSets(RelOptRuleCall call,
      Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // One grouping set per distinct argument list (plus filter column, if
    // any), each unioned with the original group-by keys.
    final Set<ImmutableBitSet> groupSetTreeSet =
        new TreeSet<>(ImmutableBitSet.ORDERING);
    groupSetTreeSet.add(aggregate.getGroupSet());
    for (Pair<List<Integer>, Integer> argList : argLists) {
      groupSetTreeSet.add(
          ImmutableBitSet.of(argList.left)
              .setIf(argList.right, argList.right >= 0)
              .union(aggregate.getGroupSet()));
    }

    final ImmutableList<ImmutableBitSet> groupSets =
        ImmutableList.copyOf(groupSetTreeSet);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

    // Non-distinct calls are evaluated in the bottom aggregate (despite the
    // variable name); distinct calls are handled by the grouping sets.
    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
      if (!aggCall.left.isDistinct()) {
        distinctAggCalls.add(aggCall.left.rename(aggCall.right));
      }
    }

    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets.size() > 1, groupSets),
        distinctAggCalls);
    final RelNode distinct = relBuilder.peek();
    final int groupCount = fullGroupSet.cardinality();
    // Indicator columns exist only when there is more than one grouping set.
    final int indicatorCount = groupSets.size() > 1 ? groupCount : 0;

    final RelOptCluster cluster = aggregate.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
    final RelDataType booleanType =
        typeFactory.createTypeWithNullability(
            typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
    final List<Pair<RexNode, String>> predicates = new ArrayList<>();
    // Maps each grouping set to the ordinal of its filter predicate column.
    final Map<ImmutableBitSet, Integer> filters = new HashMap<>();

    /** Function to register a filter for a group set. */
    class Registrar {
      // Lazily-built expression encoding the indicator columns as an integer.
      RexNode group = null;

      /** Adds a predicate matching {@code groupSet} and returns the ordinal
       * of the new predicate column in the projected row. */
      private int register(ImmutableBitSet groupSet) {
        if (group == null) {
          group = makeGroup(groupCount - 1);
        }
        final RexNode node =
            rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, group,
                rexBuilder.makeExactLiteral(
                    toNumber(remap(fullGroupSet, groupSet))));
        predicates.add(Pair.of(node, toString(groupSet)));
        return groupCount + indicatorCount + distinctAggCalls.size()
            + predicates.size() - 1;
      }

      /** Builds the indicator encoding: indicator i contributes 0 when true,
       * 2^i when false, summed over columns 0..i. */
      private RexNode makeGroup(int i) {
        final RexInputRef ref =
            rexBuilder.makeInputRef(booleanType, groupCount + i);
        final RexNode kase =
            rexBuilder.makeCall(SqlStdOperatorTable.CASE, ref,
                rexBuilder.makeExactLiteral(BigDecimal.ZERO),
                rexBuilder.makeExactLiteral(TWO.pow(i)));
        if (i == 0) {
          return kase;
        } else {
          return rexBuilder.makeCall(SqlStdOperatorTable.PLUS,
              makeGroup(i - 1), kase);
        }
      }

      /** Encodes a bit set as the sum of 2^bit for each set bit. */
      private BigDecimal toNumber(ImmutableBitSet bitSet) {
        BigDecimal n = BigDecimal.ZERO;
        for (int key : bitSet) {
          n = n.add(TWO.pow(key));
        }
        return n;
      }

      /** Builds a field name such as "$i0_2" from the set bits. */
      private String toString(ImmutableBitSet bitSet) {
        final StringBuilder buf = new StringBuilder("$i");
        for (int key : bitSet) {
          buf.append(key).append('_');
        }
        return buf.substring(0, buf.length() - 1);
      }
    }
    final Registrar registrar = new Registrar();
    for (ImmutableBitSet groupSet : groupSets) {
      filters.put(groupSet, registrar.register(groupSet));
    }

    if (!predicates.isEmpty()) {
      // Project all existing fields plus one boolean column per predicate.
      List<Pair<RexNode, String>> nodes = new ArrayList<>();
      for (RelDataTypeField f : relBuilder.peek().getRowType().getFieldList()) {
        final RexNode node = rexBuilder.makeInputRef(f.getType(), f.getIndex());
        nodes.add(Pair.of(node, f.getName()));
      }
      nodes.addAll(predicates);
      relBuilder.project(Pair.left(nodes), Pair.right(nodes));
    }

    // x walks over the columns produced by the bottom aggregate's calls.
    int x = groupCount + indicatorCount;
    final List<AggregateCall> newCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      final int newFilterArg;
      final List<Integer> newArgList;
      final SqlAggFunction aggregation;
      if (!aggCall.isDistinct()) {
        // The value was pre-computed below; MIN filtered to the original
        // grouping set retrieves it (one qualifying row per group).
        aggregation = SqlStdOperatorTable.MIN;
        newArgList = ImmutableIntList.of(x++);
        newFilterArg = filters.get(aggregate.getGroupSet());
      } else {
        aggregation = aggCall.getAggregation();
        newArgList = remap(fullGroupSet, aggCall.getArgList());
        newFilterArg =
            filters.get(
                ImmutableBitSet.of(aggCall.getArgList())
                    .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
                    .union(aggregate.getGroupSet()));
      }
      final AggregateCall newCall =
          AggregateCall.create(aggregation, false, newArgList, newFilterArg,
              aggregate.getGroupCount(), distinct, null, aggCall.name);
      newCalls.add(newCall);
    }

    relBuilder.aggregate(
        relBuilder.groupKey(
            remap(fullGroupSet, aggregate.getGroupSet()),
            aggregate.indicator,
            remap(fullGroupSet, aggregate.getGroupSets())),
        newCalls);
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
  }
+
+ private static ImmutableBitSet remap(ImmutableBitSet groupSet,
+ ImmutableBitSet bitSet) {
+ final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
+ for (Integer bit : bitSet) {
+ builder.set(remap(groupSet, bit));
+ }
+ return builder.build();
+ }
+
+ private static ImmutableList<ImmutableBitSet> remap(ImmutableBitSet groupSet,
+ Iterable<ImmutableBitSet> bitSets) {
+ final ImmutableList.Builder<ImmutableBitSet> builder =
+ ImmutableList.builder();
+ for (ImmutableBitSet bitSet : bitSets) {
+ builder.add(remap(groupSet, bitSet));
+ }
+ return builder.build();
+ }
+
+ private static List<Integer> remap(ImmutableBitSet groupSet,
+ List<Integer> argList) {
+ ImmutableIntList list = ImmutableIntList.of();
+ for (int arg : argList) {
+ list = list.append(remap(groupSet, arg));
+ }
+ return list;
+ }
+
+ private static int remap(ImmutableBitSet groupSet, int arg) {
+ return arg < 0 ? -1 : groupSet.indexOf(arg);
+ }
+
+ /**
+ * Converts an aggregate relational expression that contains just one
+ * distinct aggregate function (or perhaps several over the same arguments)
+ * and no non-distinct aggregate functions.
+ */
+ private RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate aggregate,
+ List<Integer> argList, int filterArg) {
+ // For example,
+ // SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
+ // FROM emp
+ // GROUP BY deptno
+ //
+ // becomes
+ //
+ // SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
+ // FROM (
+ // SELECT DISTINCT deptno, sal AS distinct_sal
+ // FROM EMP GROUP BY deptno)
+ // GROUP BY deptno
+
+ // Project the columns of the GROUP BY plus the arguments
+ // to the agg function.
+ final Map<Integer, Integer> sourceOf = new HashMap<>();
+ createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);
+
+ // Create an aggregate on top, with the new aggregate list.
+ final List<AggregateCall> newAggCalls =
+ Lists.newArrayList(aggregate.getAggCallList());
+ rewriteAggCalls(newAggCalls, argList, sourceOf);
+ final int cardinality = aggregate.getGroupSet().cardinality();
+ relBuilder.push(
+ aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
+ aggregate.indicator, ImmutableBitSet.range(cardinality), null,
+ newAggCalls));
+ return relBuilder;
+ }
+
+ /**
+ * Converts all distinct aggregate calls to a given set of arguments.
+ *
+ * <p>This method is called several times, once for each set of arguments.
+ * Each time it is called, it generates a JOIN to a new SELECT DISTINCT
+ * relational expression, and modifies the set of top-level calls.
+ *
+ * @param aggregate Original aggregate
+ * @param n Ordinal of this in a join. {@code relBuilder} contains the
+ * input relational expression (either the original
+ * aggregate or the output from the previous call to this
+ * method). {@code n} is 0 if we're converting the
+ * first distinct aggregate in a query with no non-distinct
+ * aggregates.
+ * @param argList Arguments to the distinct aggregate function
+ * @param filterArg Argument that filters input to aggregate function, or -1
+ * @param refs Array of expressions which will be projected by the
+ * result of this rule. Those relating to this arg list will
+ * be modified.
+ */
+ private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n,
+ List<Integer> argList, int filterArg, List<RexInputRef> refs) {
+ final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
+ final List<RelDataTypeField> leftFields;
+ if (n == 0) {
+ leftFields = null;
+ } else {
+ leftFields = relBuilder.peek().getRowType().getFieldList();
+ }
+
+ // LogicalAggregate(
+ // child,
+ // {COUNT(DISTINCT 1), SUM(DISTINCT 1), SUM(2)})
+ //
+ // becomes
+ //
+ // LogicalAggregate(
+ // LogicalJoin(
+ // child,
+ // LogicalAggregate(child, < all columns > {}),
+ // INNER,
+ // <f2 = f5>))
+ //
+ // E.g.
+ // SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), MAX(age)
+ // FROM Emps
+ // GROUP BY deptno
+ //
+ // becomes
+ //
+ // SELECT e.deptno, adsal.sum_sal, adgender.count_gender, e.max_age
+ // FROM (
+ // SELECT deptno, MAX(age) as max_age
+ // FROM Emps GROUP BY deptno) AS e
+ // JOIN (
+ // SELECT deptno, COUNT(gender) AS count_gender FROM (
+ // SELECT DISTINCT deptno, gender FROM Emps) AS dgender
+ // GROUP BY deptno) AS adgender
+ // ON e.deptno = adgender.deptno
+ // JOIN (
+ // SELECT deptno, SUM(sal) AS sum_sal FROM (
+ // SELECT DISTINCT deptno, sal FROM Emps) AS dsal
+ // GROUP BY deptno) AS adsal
+ // ON e.deptno = adsal.deptno
+ // GROUP BY e.deptno
+ //
+ // Note that if a query contains no non-distinct aggregates, then the
+ // very first join/group by is omitted. In the example above, if
+ // MAX(age) is removed, then the sub-select of "e" is not needed, and
+ // instead the two other group by's are joined to one another.
+
+ // Project the columns of the GROUP BY plus the arguments
+ // to the agg function.
+ final Map<Integer, Integer> sourceOf = new HashMap<>();
+ createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);
+
+ // Now compute the aggregate functions on top of the distinct dataset.
+ // Each distinct agg becomes a non-distinct call to the corresponding
+ // field from the right; for example,
+ // "COUNT(DISTINCT e.sal)"
+ // becomes
+ // "COUNT(distinct_e.sal)".
+ final List<AggregateCall> aggCallList = new ArrayList<>();
+ final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+ final int groupAndIndicatorCount =
+ aggregate.getGroupCount() + aggregate.getIndicatorCount();
+ int i = groupAndIndicatorCount - 1;
+ for (AggregateCall aggCall : aggCalls) {
+ ++i;
+
+ // Ignore agg calls which are not distinct or have the wrong set
+ // arguments. If we're rewriting aggs whose args are {sal}, we will
+ // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
+ // COUNT(DISTINCT gender) or SUM(sal).
+ if (!aggCall.isDistinct()) {
+ continue;
+ }
+ if (!aggCall.getArgList().equals(argList)) {
+ continue;
+ }
+
+ // Re-map arguments.
+ final int argCount = aggCall.getArgList().size();
+ final List<Integer> newArgs = new ArrayList<>(argCount);
+ for (int j = 0; j < argCount; j++) {
+ final Integer arg = aggCall.getArgList().get(j);
+ newArgs.add(sourceOf.get(arg));
+ }
+ final int newFilterArg =
+ aggCall.filterArg >= 0 ? sourceOf.get(aggCall.filterArg) : -1;
+ final AggregateCall newAggCall =
+ AggregateCall.create(aggCall.getAggregation(), false, newArgs,
+ newFilterArg, aggCall.getType(), aggCall.getName());
+ assert refs.get(i) == null;
+ if (n == 0) {
+ refs.set(i,
+ new RexInputRef(groupAndIndicatorCount + aggCallList.size(),
+ newAggCall.getType()));
+ } else {
+ refs.set(i,
+ new RexInputRef(leftFields.size() + groupAndIndicatorCount
+ + aggCallList.size(), newAggCall.getType()));
+ }
+ aggCallList.add(newAggCall);
+ }
+
+ final Map<Integer, Integer> map = new HashMap<>();
+ for (Integer key : aggregate.getGroupSet()) {
+ map.put(key, map.size());
+ }
+ final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
+ assert newGroupSet
+ .equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
+ ImmutableList<ImmutableBitSet> newGroupingSets = null;
+ if (aggregate.indicator) {
+ newGroupingSets =
+ ImmutableBitSet.ORDERING.immutableSortedCopy(
+ ImmutableBitSet.permute(aggregate.getGroupSets(), map));
+ }
+
+ relBuilder.push(
+ aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
+ aggregate.indicator, newGroupSet, newGroupingSets, aggCallList));
+
+ // If there's no left child yet, no need to create the join
+ if (n == 0) {
+ return;
+ }
+
+ // Create the join condition. It is of the form
+ // 'left.f0 = right.f0 and left.f1 = right.f1 and ...'
+ // where {f0, f1, ...} are the GROUP BY fields.
+ final List<RelDataTypeField> distinctFields =
+ relBuilder.peek().getRowType().getFieldList();
+ final List<RexNode> conditions = Lists.newArrayList();
+ for (i = 0; i < groupAndIndicatorCount; ++i) {
+ // null values form its own group
+ // use "is not distinct from" so that the join condition
+ // allows null values to match.
+ conditions.add(
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
+ RexInputRef.of(i, leftFields),
+ new RexInputRef(leftFields.size() + i,
+ distinctFields.get(i).getType())));
+ }
+
+ // Join in the new 'select distinct' relation.
+ relBuilder.join(JoinRelType.INNER, conditions);
+ }
+
+ private static void rewriteAggCalls(
+ List<AggregateCall> newAggCalls,
+ List<Integer> argList,
+ Map<Integer, Integer> sourceOf) {
+ // Rewrite the agg calls. Each distinct agg becomes a non-distinct call
+ // to the corresponding field from the right; for example,
+ // "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
+ for (int i = 0; i < newAggCalls.size(); i++) {
+ final AggregateCall aggCall = newAggCalls.get(i);
+
+ // Ignore agg calls which are not distinct or have the wrong set
+ // arguments. If we're rewriting aggregates whose args are {sal}, we will
+ // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
+ // COUNT(DISTINCT gender) or SUM(sal).
+ if (!aggCall.isDistinct()) {
+ continue;
+ }
+ if (!aggCall.getArgList().equals(argList)) {
+ continue;
+ }
+
+ // Re-map arguments.
+ final int argCount = aggCall.getArgList().size();
+ final List<Integer> newArgs = new ArrayList<>(argCount);
+ for (int j = 0; j < argCount; j++) {
+ final Integer arg = aggCall.getArgList().get(j);
+ newArgs.add(sourceOf.get(arg));
+ }
+ final AggregateCall newAggCall =
+ AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
+ aggCall.getType(), aggCall.getName());
+ newAggCalls.set(i, newAggCall);
+ }
+ }
+
+ /**
+ * Given an {@link org.apache.calcite.rel.logical.LogicalAggregate}
+ * and the ordinals of the arguments to a
+ * particular call to an aggregate function, creates a 'select distinct'
+ * relational expression which projects the group columns and those
+ * arguments but nothing else.
+ *
+ * <p>For example, given
+ *
+ * <blockquote>
+ * <pre>select f0, count(distinct f1), count(distinct f2)
+ * from t group by f0</pre>
+ * </blockquote>
+ *
+ * and the argument list
+ *
+ * <blockquote>{2}</blockquote>
+ *
+ * returns
+ *
+ * <blockquote>
+ * <pre>select distinct f0, f2 from t</pre>
+ * </blockquote>
+ *
+ *
+ *
+ * <p>The <code>sourceOf</code> map is populated with the source of each
+ * column; in this case sourceOf.get(0) = 0, and sourceOf.get(1) = 2.</p>
+ *
+ * @param relBuilder Relational expression builder
+ * @param aggregate Aggregate relational expression
+ * @param argList Ordinals of columns to make distinct
+ * @param filterArg Ordinal of column to filter on, or -1
+ * @param sourceOf Out parameter, is populated with a map of where each
+ * output field came from
+ * @return Aggregate relational expression which projects the required
+ * columns
+ */
+ private RelBuilder createSelectDistinct(RelBuilder relBuilder,
+ Aggregate aggregate, List<Integer> argList, int filterArg,
+ Map<Integer, Integer> sourceOf) {
+ relBuilder.push(aggregate.getInput());
+ final List<Pair<RexNode, String>> projects = new ArrayList<>();
+ final List<RelDataTypeField> childFields =
+ relBuilder.peek().getRowType().getFieldList();
+ for (int i : aggregate.getGroupSet()) {
+ sourceOf.put(i, projects.size());
+ projects.add(RexInputRef.of2(i, childFields));
+ }
+ for (Integer arg : argList) {
+ if (filterArg >= 0) {
+ // Implement
+ // agg(DISTINCT arg) FILTER $f
+ // by generating
+ // SELECT DISTINCT ... CASE WHEN $f THEN arg ELSE NULL END AS arg
+ // and then applying
+ // agg(arg)
+ // as usual.
+ //
+ // It works except for (rare) agg functions that need to see null
+ // values.
+ final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
+ final RexInputRef filterRef = RexInputRef.of(filterArg, childFields);
+ final Pair<RexNode, String> argRef = RexInputRef.of2(arg, childFields);
+ RexNode condition =
+ rexBuilder.makeCall(SqlStdOperatorTable.CASE, filterRef,
+ argRef.left,
+ rexBuilder.ensureType(argRef.left.getType(),
+ rexBuilder.constantNull(), true));
+ sourceOf.put(arg, projects.size());
+ projects.add(Pair.of(condition, "i$" + argRef.right));
+ continue;
+ }
+ if (sourceOf.get(arg) != null) {
+ continue;
+ }
+ sourceOf.put(arg, projects.size());
+ projects.add(RexInputRef.of2(arg, childFields));
+ }
+ relBuilder.project(Pair.left(projects), Pair.right(projects));
+
+ // Get the distinct values of the GROUP BY fields and the arguments
+ // to the agg functions.
+ relBuilder.push(
+ aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false,
+ ImmutableBitSet.range(projects.size()),
+ null, ImmutableList.<AggregateCall>of()));
+ return relBuilder;
+ }
+}
+
+// End FlinkAggregateExpandDistinctAggregatesRule.java
http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index f9c8d8d..8f16d32 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.rules
import org.apache.calcite.rel.rules._
import org.apache.calcite.tools.{RuleSet, RuleSets}
-import org.apache.flink.table.calcite.rules.FlinkAggregateJoinTransposeRule
+import org.apache.flink.table.calcite.rules.{FlinkAggregateExpandDistinctAggregatesRule, FlinkAggregateJoinTransposeRule}
import org.apache.flink.table.plan.rules.dataSet._
import org.apache.flink.table.plan.rules.datastream._
import org.apache.flink.table.plan.rules.datastream.{DataStreamCalcRule, DataStreamScanRule, DataStreamUnionRule}
@@ -102,6 +102,9 @@ object FlinkRuleSets {
ProjectToCalcRule.INSTANCE,
CalcMergeRule.INSTANCE,
+ // distinct aggregate rule for FLINK-3475
+ FlinkAggregateExpandDistinctAggregatesRule.JOIN,
+
// translate to Flink DataSet nodes
DataSetWindowAggregateRule.INSTANCE,
DataSetAggregateRule.INSTANCE,
http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
index d1f932e..9c0acdd 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
@@ -44,9 +44,6 @@ class DataSetAggregateRule
// check if we have distinct aggregates
val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
- if (distinctAggs) {
- throw TableException("DISTINCT aggregates are currently not supported.")
- }
!distinctAggs
}
http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
index e8084fa..aa977b1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
@@ -50,9 +50,6 @@ class DataSetAggregateWithNullValuesRule
// check if we have distinct aggregates
val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
- if (distinctAggs) {
- throw TableException("DISTINCT aggregates are currently not supported.")
- }
!distinctAggs
}
http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
index d7e429c..a60cfaa 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
@@ -213,7 +213,7 @@ class AggregationsITCase(
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
- @Test(expected = classOf[TableException])
+ @Test
def testDistinctAggregate(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
@@ -221,14 +221,17 @@ class AggregationsITCase(
val sqlQuery = "SELECT sum(_1) as a, count(distinct _3) as b FROM MyTable"
- val ds = CollectionDataSets.get3TupleDataSet(env)
- tEnv.registerDataSet("MyTable", ds)
+ val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv)
+ tEnv.registerTable("MyTable", ds)
- // must fail. distinct aggregates are not supported
- tEnv.sql(sqlQuery).toDataSet[Row]
+ val result = tEnv.sql(sqlQuery)
+
+ val expected = "231,21"
+ val results = result.toDataSet[Row].collect()
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
}
- @Test(expected = classOf[TableException])
+ @Test
def testGroupedDistinctAggregate(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
@@ -236,11 +239,15 @@ class AggregationsITCase(
val sqlQuery = "SELECT _2, avg(distinct _1) as a, count(_3) as b FROM MyTable GROUP BY _2"
- val ds = CollectionDataSets.get3TupleDataSet(env)
- tEnv.registerDataSet("MyTable", ds)
+ val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv)
+ tEnv.registerTable("MyTable", ds)
- // must fail. distinct aggregates are not supported
- tEnv.sql(sqlQuery).toDataSet[Row]
+ val result = tEnv.sql(sqlQuery)
+
+ val expected =
+ "6,18,6\n5,13,5\n4,8,4\n3,5,3\n2,2,2\n1,1,1"
+ val results = result.toDataSet[Row].collect()
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test
http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala
new file mode 100644
index 0000000..38e4ea8
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.scala.batch.sql
+
+import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.utils.TableTestUtil._
+import org.junit.Test
+
+class DistinctAggregateTest extends TableTestBase {
+
+ @Test
+ def testSingleDistinctAggregate(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT COUNT(DISTINCT a) FROM MyTable"
+
+ val expected = unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a")
+ ),
+ term("groupBy", "a"),
+ term("select", "a")
+ ),
+ tuples(List(null)),
+ term("values", "a")
+ ),
+ term("union", "a")
+ ),
+ term("select", "COUNT(a) AS EXPR$0")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testMultiDistinctAggregateOnSameColumn(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT COUNT(DISTINCT a), SUM(DISTINCT a), MAX(DISTINCT a) FROM MyTable"
+
+ val expected = unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a")
+ ),
+ term("groupBy", "a"),
+ term("select", "a")
+ ),
+ tuples(List(null)),
+ term("values", "a")
+ ),
+ term("union", "a")
+ ),
+ term("select", "COUNT(a) AS EXPR$0", "SUM(a) AS EXPR$1", "MAX(a) AS EXPR$2")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSingleDistinctAggregateAndOneOrMultiNonDistinctAggregate(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+ // case 0x00: DISTINCT on COUNT and Non-DISTINCT on others
+ val sqlQuery0 = "SELECT COUNT(DISTINCT a), SUM(b) FROM MyTable"
+
+ val expected0 = unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b")
+ ),
+ term("groupBy", "a"),
+ term("select", "a", "SUM(b) AS EXPR$1")
+ ),
+ tuples(List(null, null)),
+ term("values", "a", "EXPR$1")
+ ),
+ term("union", "a", "EXPR$1")
+ ),
+ term("select", "COUNT(a) AS EXPR$0", "SUM(EXPR$1) AS EXPR$1")
+ )
+
+ util.verifySql(sqlQuery0, expected0)
+
+ // case 0x01: Non-DISTINCT on COUNT and DISTINCT on others
+ val sqlQuery1 = "SELECT COUNT(a), SUM(DISTINCT b) FROM MyTable"
+
+ val expected1 = unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b")
+ ),
+ term("groupBy", "b"),
+ term("select", "b", "COUNT(a) AS EXPR$0")
+ ),
+ tuples(List(null, null)),
+ term("values", "b", "EXPR$0")
+ ),
+ term("union", "b", "EXPR$0")
+ ),
+ term("select", "$SUM0(EXPR$0) AS EXPR$0", "SUM(b) AS EXPR$1")
+ )
+
+ util.verifySql(sqlQuery1, expected1)
+ }
+
+ @Test
+ def testMultiDistinctAggregateOnDifferentColumn(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT COUNT(DISTINCT a), SUM(DISTINCT b) FROM MyTable"
+
+ val expected = binaryNode(
+ "DataSetSingleRowJoin",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a")
+ ),
+ term("groupBy", "a"),
+ term("select", "a")
+ ),
+ tuples(List(null)),
+ term("values", "a")
+ ),
+ term("union", "a")
+ ),
+ term("select", "COUNT(a) AS EXPR$0")
+ ),
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "b")
+ ),
+ term("groupBy", "b"),
+ term("select", "b")
+ ),
+ tuples(List(null)),
+ term("values", "b")
+ ),
+ term("union", "b")
+ ),
+ term("select", "SUM(b) AS EXPR$1")
+ ),
+ term("where", "true"),
+ term("join", "EXPR$0", "EXPR$1"),
+ term("joinType", "NestedLoopJoin")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testMultiDistinctAndNonDistinctAggregateOnDifferentColumn(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT COUNT(DISTINCT a), SUM(DISTINCT b), COUNT(c) FROM MyTable"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ binaryNode(
+ "DataSetSingleRowJoin",
+ binaryNode(
+ "DataSetSingleRowJoin",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ batchTableNode(0),
+ tuples(List(null, null, null)),
+ term("values", "a, b, c")
+ ),
+ term("union", "a, b, c")
+ ),
+ term("select", "COUNT(c) AS EXPR$2")
+ ),
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a")
+ ),
+ term("groupBy", "a"),
+ term("select", "a")
+ ),
+ tuples(List(null)),
+ term("values", "a")
+ ),
+ term("union", "a")
+ ),
+ term("select", "COUNT(a) AS EXPR$0")
+ ),
+ term("where", "true"),
+ term("join", "EXPR$2, EXPR$0"),
+ term("joinType", "NestedLoopJoin")
+ ),
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "b")
+ ),
+ term("groupBy", "b"),
+ term("select", "b")
+ ),
+ tuples(List(null)),
+ term("values", "b")
+ ),
+ term("union", "b")
+ ),
+ term("select", "SUM(b) AS EXPR$1")
+ ),
+ term("where", "true"),
+ term("join", "EXPR$2", "EXPR$0, EXPR$1"),
+ term("joinType", "NestedLoopJoin")
+ ),
+ term("select", "EXPR$0, EXPR$1, EXPR$2")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSingleDistinctAggregateWithGrouping(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT a, COUNT(a), SUM(DISTINCT b) FROM MyTable GROUP BY a"
+
+ val expected = unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b")
+ ),
+ term("groupBy", "a", "b"),
+ term("select", "a", "b", "COUNT(a) AS EXPR$1")
+ ),
+ term("groupBy", "a"),
+ term("select", "a", "SUM(EXPR$1) AS EXPR$1", "SUM(b) AS EXPR$2")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSingleDistinctAggregateWithGroupingAndCountStar(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT a, COUNT(*), SUM(DISTINCT b) FROM MyTable GROUP BY a"
+
+ val expected = unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b")
+ ),
+ term("groupBy", "a", "b"),
+ term("select", "a", "b", "COUNT(*) AS EXPR$1")
+ ),
+ term("groupBy", "a"),
+ term("select", "a", "SUM(EXPR$1) AS EXPR$1", "SUM(b) AS EXPR$2")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testTwoDistinctAggregateWithGroupingAndCountStar(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT a, COUNT(*), SUM(DISTINCT b), COUNT(DISTINCT b) FROM MyTable GROUP BY a"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ binaryNode(
+ "DataSetJoin",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b")
+ ),
+ term("groupBy", "a"),
+ term("select", "a", "COUNT(*) AS EXPR$1")
+ ),
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b")
+ ),
+ term("groupBy", "a, b"),
+ term("select", "a, b")
+ ),
+ term("groupBy", "a"),
+ term("select", "a, SUM(b) AS EXPR$2, COUNT(b) AS EXPR$3")
+ ),
+ term("where", "IS NOT DISTINCT FROM(a, a0)"),
+ term("join", "a, EXPR$1, a0, EXPR$2, EXPR$3"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "a, EXPR$1, EXPR$2, EXPR$3")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testTwoDifferentDistinctAggregateWithGroupingAndCountStar(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT a, COUNT(*), SUM(DISTINCT b), COUNT(DISTINCT c) FROM MyTable GROUP BY a"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ binaryNode(
+ "DataSetJoin",
+ unaryNode(
+ "DataSetCalc",
+ binaryNode(
+ "DataSetJoin",
+ unaryNode(
+ "DataSetAggregate",
+ batchTableNode(0),
+ term("groupBy", "a"),
+ term("select", "a, COUNT(*) AS EXPR$1")
+ ),
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b")
+ ),
+ term("groupBy", "a, b"),
+ term("select", "a, b")
+ ),
+ term("groupBy", "a"),
+ term("select", "a, SUM(b) AS EXPR$2")
+ ),
+ term("where", "IS NOT DISTINCT FROM(a, a0)"),
+ term("join", "a, EXPR$1, a0, EXPR$2"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "a, EXPR$1, EXPR$2")
+ ),
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "c")
+ ),
+ term("groupBy", "a, c"),
+ term("select", "a, c")
+ ),
+ term("groupBy", "a"),
+ term("select", "a, COUNT(c) AS EXPR$3")
+ ),
+ term("where", "IS NOT DISTINCT FROM(a, a0)"),
+ term("join", "a, EXPR$1, EXPR$2, a0, EXPR$3"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "a, EXPR$1, EXPR$2, EXPR$3")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
index abf71e2..516fcd2 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
@@ -85,7 +85,7 @@ class QueryDecorrelationTest extends TableTestBase {
term("join", "empno", "salary", "empno0"),
term("joinType", "InnerJoin")
),
- term("select", "salary", "empno0")
+ term("select", "empno0", "salary")
),
term("groupBy", "empno0"),
term("select", "empno0", "AVG(salary) AS EXPR$0")
[2/2] flink git commit: [FLINK-5907] [java] Fix handling of trailing
empty fields in CsvInputFormat.
Posted by fh...@apache.org.
[FLINK-5907] [java] Fix handling of trailing empty fields in CsvInputFormat.
This closes #3417.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1a062b79
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1a062b79
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1a062b79
Branch: refs/heads/master
Commit: 1a062b796274c9f63caeb2bf12aad96e34efd0aa
Parents: 36c9348
Author: Kurt Young <yk...@gmail.com>
Authored: Sat Feb 25 16:37:37 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Mon Feb 27 22:50:19 2017 +0100
----------------------------------------------------------------------
.../api/common/io/GenericCsvInputFormat.java | 30 +++++++-
.../apache/flink/types/parser/FieldParser.java | 21 +++++
.../common/io/GenericCsvInputFormatTest.java | 4 +-
.../flink/types/parser/FieldParserTest.java | 46 +++++++++++
.../flink/api/java/io/RowCsvInputFormat.java | 13 +++-
.../flink/api/java/io/CsvInputFormatTest.java | 81 +++++++++++++++++++-
.../api/java/io/RowCsvInputFormatTest.java | 75 ++++++++++++++++--
7 files changed, 258 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/1a062b79/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java b/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java
index 20c643e..b934d41 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java
@@ -358,14 +358,14 @@ public abstract class GenericCsvInputFormat<OT> extends DelimitedInputFormat<OT>
for (int field = 0, output = 0; field < fieldIncluded.length; field++) {
// check valid start position
- if (startPos >= limit) {
+ if (startPos > limit || (startPos == limit && field != fieldIncluded.length - 1)) {
if (lenient) {
return false;
} else {
throw new ParseException("Row too short: " + new String(bytes, offset, numBytes));
}
}
-
+
if (fieldIncluded[field]) {
// parse field
@SuppressWarnings("unchecked")
@@ -373,7 +373,7 @@ public abstract class GenericCsvInputFormat<OT> extends DelimitedInputFormat<OT>
Object reuse = holders[output];
startPos = parser.resetErrorStateAndParse(bytes, startPos, limit, this.fieldDelim, reuse);
holders[output] = parser.getLastResult();
-
+
// check parse result
if (startPos < 0) {
// no good
@@ -387,6 +387,17 @@ public abstract class GenericCsvInputFormat<OT> extends DelimitedInputFormat<OT>
+ "in file: " + filePath);
}
}
+ else if (startPos == limit
+ && field != fieldIncluded.length - 1
+ && !FieldParser.endsWithDelimiter(bytes, startPos - 1, fieldDelim)) {
+ // We are at the end of the record, but not all fields have been read
+ // and the end is not a field delimiter indicating an empty last field.
+ if (lenient) {
+ return false;
+ } else {
+ throw new ParseException("Row too short: " + new String(bytes, offset, numBytes));
+ }
+ }
output++;
}
else {
@@ -398,6 +409,19 @@ public abstract class GenericCsvInputFormat<OT> extends DelimitedInputFormat<OT>
throw new ParseException("Line could not be parsed: '" + lineAsString+"'\n"
+ "Expect field types: "+fieldTypesToString()+" \n"
+ "in file: "+filePath);
+ } else {
+ return false;
+ }
+ }
+ else if (startPos == limit
+ && field != fieldIncluded.length - 1
+ && !FieldParser.endsWithDelimiter(bytes, startPos - 1, fieldDelim)) {
+ // We are at the end of the record, but not all fields have been read
+ // and the end is not a field delimiter indicating an empty last field.
+ if (lenient) {
+ return false;
+ } else {
+ throw new ParseException("Row too short: " + new String(bytes, offset, numBytes));
}
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1a062b79/flink-core/src/main/java/org/apache/flink/types/parser/FieldParser.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/types/parser/FieldParser.java b/flink-core/src/main/java/org/apache/flink/types/parser/FieldParser.java
index cf3c83d..c45f820 100644
--- a/flink-core/src/main/java/org/apache/flink/types/parser/FieldParser.java
+++ b/flink-core/src/main/java/org/apache/flink/types/parser/FieldParser.java
@@ -156,6 +156,27 @@ public abstract class FieldParser<T> {
return true;
}
+
+ /**
+ * Checks if the given byte array ends with the delimiter at the given end position.
+ *
+ * @param bytes The byte array that holds the value.
+ * @param endPos The index of the byte array where the check for the delimiter ends.
+ * @param delim The delimiter to check for.
+ *
+ * @return true if a delimiter ends at the given end position, false otherwise.
+ */
+ public static final boolean endsWithDelimiter(byte[] bytes, int endPos, byte[] delim) {
+ if (endPos < delim.length - 1) {
+ return false;
+ }
+ for (int pos = 0; pos < delim.length; ++pos) {
+ if (delim[pos] != bytes[endPos - delim.length + 1 + pos]) {
+ return false;
+ }
+ }
+ return true;
+ }
/**
* Sets the error state of the parser. Called by subclasses of the parser to set the type of error
http://git-wip-us.apache.org/repos/asf/flink/blob/1a062b79/flink-core/src/test/java/org/apache/flink/api/common/io/GenericCsvInputFormatTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/io/GenericCsvInputFormatTest.java b/flink-core/src/test/java/org/apache/flink/api/common/io/GenericCsvInputFormatTest.java
index c11a573..4873fa8 100644
--- a/flink-core/src/test/java/org/apache/flink/api/common/io/GenericCsvInputFormatTest.java
+++ b/flink-core/src/test/java/org/apache/flink/api/common/io/GenericCsvInputFormatTest.java
@@ -522,14 +522,14 @@ public class GenericCsvInputFormatTest {
"kkz|777|foobar|hhg\n" + // wrong data type in field
"kkz|777foobarhhg \n" + // too short, a skipped field never ends
"xyx|ignored|42|\n"; // another good line
- final FileInputSplit split = createTempFile(fileContent);
+ final FileInputSplit split = createTempFile(fileContent);
final Configuration parameters = new Configuration();
format.setFieldDelimiter("|");
format.setFieldTypesGeneric(StringValue.class, null, IntValue.class);
format.setLenient(true);
-
+
format.configure(parameters);
format.open(split);
http://git-wip-us.apache.org/repos/asf/flink/blob/1a062b79/flink-core/src/test/java/org/apache/flink/types/parser/FieldParserTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/types/parser/FieldParserTest.java b/flink-core/src/test/java/org/apache/flink/types/parser/FieldParserTest.java
new file mode 100644
index 0000000..bcb2bfb
--- /dev/null
+++ b/flink-core/src/test/java/org/apache/flink/types/parser/FieldParserTest.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.types.parser;
+
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+public class FieldParserTest {
+
+ @Test
+ public void testDelimiterNext() throws Exception {
+ byte[] bytes = "aaabc".getBytes();
+ byte[] delim = "aa".getBytes();
+ assertTrue(FieldParser.delimiterNext(bytes, 0, delim));
+ assertTrue(FieldParser.delimiterNext(bytes, 1, delim));
+ assertFalse(FieldParser.delimiterNext(bytes, 2, delim));
+ }
+
+ @Test
+ public void testEndsWithDelimiter() throws Exception {
+ byte[] bytes = "aabc".getBytes();
+ byte[] delim = "ab".getBytes();
+ assertFalse(FieldParser.endsWithDelimiter(bytes, 0, delim));
+ assertFalse(FieldParser.endsWithDelimiter(bytes, 1, delim));
+ assertTrue(FieldParser.endsWithDelimiter(bytes, 2, delim));
+ assertFalse(FieldParser.endsWithDelimiter(bytes, 3, delim));
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/flink/blob/1a062b79/flink-java/src/main/java/org/apache/flink/api/java/io/RowCsvInputFormat.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/io/RowCsvInputFormat.java b/flink-java/src/main/java/org/apache/flink/api/java/io/RowCsvInputFormat.java
index af2e9e4..ce37c74 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/io/RowCsvInputFormat.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/io/RowCsvInputFormat.java
@@ -151,7 +151,7 @@ public class RowCsvInputFormat extends CsvInputFormat<Row> implements ResultType
while (field < fieldIncluded.length) {
// check valid start position
- if (startPos >= limit) {
+ if (startPos > limit || (startPos == limit && field != fieldIncluded.length - 1)) {
if (isLenient()) {
return false;
} else {
@@ -198,6 +198,17 @@ public class RowCsvInputFormat extends CsvInputFormat<Row> implements ResultType
throw new ParseException(String.format("Unexpected parser position for column %1$s of row '%2$s'",
field, new String(bytes, offset, numBytes)));
}
+ else if (startPos == limit
+ && field != fieldIncluded.length - 1
+ && !FieldParser.endsWithDelimiter(bytes, startPos - 1, fieldDelimiter)) {
+ // We are at the end of the record, but not all fields have been read
+ // and the end is not a field delimiter indicating an empty last field.
+ if (isLenient()) {
+ return false;
+ } else {
+ throw new ParseException("Row too short: " + new String(bytes, offset, numBytes));
+ }
+ }
field++;
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1a062b79/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java
----------------------------------------------------------------------
diff --git a/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java b/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java
index cc0d5bc..a303ff7 100644
--- a/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java
+++ b/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java
@@ -430,7 +430,7 @@ public class CsvInputFormatTest {
assertEquals("", result.f0);
assertEquals("", result.f1);
assertEquals("", result.f2);
-
+
result = format.nextRecord(result);
assertNull(result);
assertTrue(format.reachedEnd());
@@ -441,6 +441,57 @@ public class CsvInputFormatTest {
}
@Test
+ public void testTailingEmptyFields() throws Exception {
+ final String fileContent = "aa,bb,cc\n" + // ok
+ "aa,bb,\n" + // the last field is empty
+ "aa,,\n" + // the last two fields are empty
+ ",,\n" + // all fields are empty
+ "aa,bb"; // row too short
+ final FileInputSplit split = createTempFile(fileContent);
+
+ final TupleTypeInfo<Tuple3<String, String, String>> typeInfo =
+ TupleTypeInfo.getBasicTupleTypeInfo(String.class, String.class, String.class);
+ final CsvInputFormat<Tuple3<String, String, String>> format =
+ new TupleCsvInputFormat<Tuple3<String, String, String>>(PATH, typeInfo);
+
+ format.setFieldDelimiter(",");
+
+ format.configure(new Configuration());
+ format.open(split);
+
+ Tuple3<String, String, String> result = new Tuple3<String, String, String>();
+
+ result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("aa", result.f0);
+ assertEquals("bb", result.f1);
+ assertEquals("cc", result.f2);
+
+ result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("aa", result.f0);
+ assertEquals("bb", result.f1);
+ assertEquals("", result.f2);
+
+ result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("aa", result.f0);
+ assertEquals("", result.f1);
+ assertEquals("", result.f2);
+
+ result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("", result.f0);
+ assertEquals("", result.f1);
+ assertEquals("", result.f2);
+
+ try {
+ format.nextRecord(result);
+ fail("Parse Exception was not thrown! (Row too short)");
+ } catch (ParseException e) {}
+ }
+
+ @Test
public void testIntegerFields() throws IOException {
try {
final String fileContent = "111|222|333|444|555\n666|777|888|999|000|\n";
@@ -957,6 +1008,34 @@ public class CsvInputFormatTest {
}
@Test
+ public void testPojoTypeWithTrailingEmptyFields() throws Exception {
+ final String fileContent = "123,,3.123,,\n456,BBB,3.23,,";
+ final FileInputSplit split = createTempFile(fileContent);
+
+ @SuppressWarnings("unchecked")
+ PojoTypeInfo<PrivatePojoItem> typeInfo = (PojoTypeInfo<PrivatePojoItem>) TypeExtractor.createTypeInfo(PrivatePojoItem.class);
+ CsvInputFormat<PrivatePojoItem> inputFormat = new PojoCsvInputFormat<PrivatePojoItem>(PATH, typeInfo);
+
+ inputFormat.configure(new Configuration());
+ inputFormat.open(split);
+
+ PrivatePojoItem item = new PrivatePojoItem();
+ inputFormat.nextRecord(item);
+
+ assertEquals(123, item.field1);
+ assertEquals("", item.field2);
+ assertEquals(Double.valueOf(3.123), item.field3);
+ assertEquals("", item.field4);
+
+ inputFormat.nextRecord(item);
+
+ assertEquals(456, item.field1);
+ assertEquals("BBB", item.field2);
+ assertEquals(Double.valueOf(3.23), item.field3);
+ assertEquals("", item.field4);
+ }
+
+ @Test
public void testPojoTypeWithMappingInformation() throws Exception {
File tempFile = File.createTempFile("CsvReaderPojoType", "tmp");
tempFile.deleteOnExit();
http://git-wip-us.apache.org/repos/asf/flink/blob/1a062b79/flink-java/src/test/java/org/apache/flink/api/java/io/RowCsvInputFormatTest.java
----------------------------------------------------------------------
diff --git a/flink-java/src/test/java/org/apache/flink/api/java/io/RowCsvInputFormatTest.java b/flink-java/src/test/java/org/apache/flink/api/java/io/RowCsvInputFormatTest.java
index b819641..f6bda30 100644
--- a/flink-java/src/test/java/org/apache/flink/api/java/io/RowCsvInputFormatTest.java
+++ b/flink-java/src/test/java/org/apache/flink/api/java/io/RowCsvInputFormatTest.java
@@ -230,7 +230,7 @@ public class RowCsvInputFormatTest {
@Test
public void readStringFields() throws Exception {
- String fileContent = "abc|def|ghijk\nabc||hhg\n|||";
+ String fileContent = "abc|def|ghijk\nabc||hhg\n|||\n||";
FileInputSplit split = createTempFile(fileContent);
@@ -264,13 +264,19 @@ public class RowCsvInputFormatTest {
assertEquals("", result.getField(2));
result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("", result.getField(0));
+ assertEquals("", result.getField(1));
+ assertEquals("", result.getField(2));
+
+ result = format.nextRecord(result);
assertNull(result);
assertTrue(format.reachedEnd());
}
@Test
public void readMixedQuotedStringFields() throws Exception {
- String fileContent = "@a|b|c@|def|@ghijk@\nabc||@|hhg@\n|||";
+ String fileContent = "@a|b|c@|def|@ghijk@\nabc||@|hhg@\n|||\n";
FileInputSplit split = createTempFile(fileContent);
@@ -351,6 +357,65 @@ public class RowCsvInputFormatTest {
}
@Test
+ public void testTailingEmptyFields() throws Exception {
+ String fileContent = "abc|-def|-ghijk\n" +
+ "abc|-def|-\n" +
+ "abc|-|-\n" +
+ "|-|-|-\n" +
+ "|-|-\n" +
+ "abc|-def\n";
+
+ FileInputSplit split = createTempFile(fileContent);
+
+ TypeInformation[] fieldTypes = new TypeInformation[]{
+ BasicTypeInfo.STRING_TYPE_INFO,
+ BasicTypeInfo.STRING_TYPE_INFO,
+ BasicTypeInfo.STRING_TYPE_INFO};
+
+ RowCsvInputFormat format = new RowCsvInputFormat(PATH, fieldTypes, "\n", "|");
+ format.setFieldDelimiter("|-");
+ format.configure(new Configuration());
+ format.open(split);
+
+ Row result = new Row(3);
+
+ result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("abc", result.getField(0));
+ assertEquals("def", result.getField(1));
+ assertEquals("ghijk", result.getField(2));
+
+ result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("abc", result.getField(0));
+ assertEquals("def", result.getField(1));
+ assertEquals("", result.getField(2));
+
+ result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("abc", result.getField(0));
+ assertEquals("", result.getField(1));
+ assertEquals("", result.getField(2));
+
+ result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("", result.getField(0));
+ assertEquals("", result.getField(1));
+ assertEquals("", result.getField(2));
+
+ result = format.nextRecord(result);
+ assertNotNull(result);
+ assertEquals("", result.getField(0));
+ assertEquals("", result.getField(1));
+ assertEquals("", result.getField(2));
+
+ try {
+ format.nextRecord(result);
+ fail("Parse Exception was not thrown! (Row too short)");
+ } catch (ParseException e) {}
+ }
+
+ @Test
public void testIntegerFields() throws Exception {
String fileContent = "111|222|333|444|555\n666|777|888|999|000|\n";
@@ -396,12 +461,12 @@ public class RowCsvInputFormatTest {
public void testEmptyFields() throws Exception {
String fileContent =
",,,,,,,,\n" +
+ ",,,,,,,\n" +
",,,,,,,,\n" +
+ ",,,,,,,\n" +
",,,,,,,,\n" +
",,,,,,,,\n" +
- ",,,,,,,,\n" +
- ",,,,,,,,\n" +
- ",,,,,,,,\n" +
+ ",,,,,,,\n" +
",,,,,,,,\n";
FileInputSplit split = createTempFile(fileContent);