You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by ja...@apache.org on 2019/07/16 03:45:53 UTC
[flink] branch release-1.9 updated:
[FLINK-13268][table-planner-blink] Revert SqlSplittableAggFunction to support making two planners available in one jar
This is an automated email from the ASF dual-hosted git repository.
jark pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.9 by this push:
new c7b69b0 [FLINK-13268][table-planner-blink] Revert SqlSplittableAggFunction to support making two planners available in one jar
c7b69b0 is described below
commit c7b69b0753e357f424d177736c94f909d65153fb
Author: godfreyhe <go...@163.com>
AuthorDate: Mon Jul 15 21:13:23 2019 +0800
[FLINK-13268][table-planner-blink] Revert SqlSplittableAggFunction to support making two planners available in one jar
This closes #9119
---
.../calcite/sql/SqlSplittableAggFunction.java | 374 ---------------------
.../logical/FlinkAggregateJoinTransposeRule.java | 67 +---
.../table/plan/rules/FlinkStreamRuleSets.scala | 2 +-
.../batch/sql/agg/AggregateReduceGroupingTest.xml | 17 +-
.../logical/AggregateReduceGroupingRuleTest.xml | 11 +-
...xml => FlinkAggregateJoinTransposeRuleTest.xml} | 4 +-
.../FlinkAggregateOuterJoinTransposeRuleTest.xml | 267 ---------------
...a => FlinkAggregateJoinTransposeRuleTest.scala} | 3 +-
.../FlinkAggregateOuterJoinTransposeRuleTest.scala | 122 -------
9 files changed, 31 insertions(+), 836 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
deleted file mode 100644
index a69a82d..0000000
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
+++ /dev/null
@@ -1,374 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to you under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.calcite.sql;
-
-import com.google.common.collect.ImmutableList;
-import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.core.JoinRelType;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexLiteral;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.rex.RexUtil;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.sql.type.SqlTypeName;
-import org.apache.calcite.util.ImmutableIntList;
-import org.apache.calcite.util.mapping.Mappings;
-
-import java.math.BigDecimal;
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * This file is copied from Calcite and made the following changes:
- * 1. makeProperRexNodeForOuterJoin function added for CountSplitter and AbstractSumSplitter.
- * 2. If the join type is left or right outer join then make the proper rexNode, or follow the previous logic.
- *
- * This copy can be removed once [CALCITE-2378] is fixed.
- */
-
-/**
- * Aggregate function that can be split into partial aggregates.
- *
- * <p>For example, {@code COUNT(x)} can be split into {@code COUNT(x)} on
- * subsets followed by {@code SUM} to combine those counts.
- */
-public interface SqlSplittableAggFunction {
- AggregateCall split(AggregateCall aggregateCall,
- Mappings.TargetMapping mapping);
-
- /** Called to generate an aggregate for the other side of the join
- * than the side aggregate call's arguments come from. Returns null if
- * no aggregate is required. */
- AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e);
-
- /** Generates an aggregate call to merge sub-totals.
- *
- * <p>Most implementations will add a single aggregate call to
- * {@code aggCalls}, and return a {@link RexInputRef} that points to it.
- *
- * @param rexBuilder Rex builder
- * @param extra Place to define extra input expressions
- * @param offset Offset due to grouping columns (and indicator columns if
- * applicable)
- * @param inputRowType Input row type
- * @param aggregateCall Source aggregate call
- * @param leftSubTotal Ordinal of the sub-total coming from the left side of
- * the join, or -1 if there is no such sub-total
- * @param rightSubTotal Ordinal of the sub-total coming from the right side
- * of the join, or -1 if there is no such sub-total
- * @param joinRelType the join type
- *
- * @return Aggregate call
- */
- AggregateCall topSplit(RexBuilder rexBuilder, Registry<RexNode> extra,
- int offset, RelDataType inputRowType, AggregateCall aggregateCall,
- int leftSubTotal, int rightSubTotal, JoinRelType joinRelType);
-
- /** Generates an expression for the value of the aggregate function when
- * applied to a single row.
- *
- * <p>For example, if there is one row:
- * <ul>
- * <li>{@code SUM(x)} is {@code x}
- * <li>{@code MIN(x)} is {@code x}
- * <li>{@code MAX(x)} is {@code x}
- * <li>{@code COUNT(x)} is {@code CASE WHEN x IS NOT NULL THEN 1 ELSE 0 END 1}
- * which can be simplified to {@code 1} if {@code x} is never null
- * <li>{@code COUNT(*)} is 1
- * </ul>
- *
- * @param rexBuilder Rex builder
- * @param inputRowType Input row type
- * @param aggregateCall Aggregate call
- *
- * @return Expression for single row
- */
- RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
- AggregateCall aggregateCall);
-
- /** Collection in which one can register an element. Registering may return
- * a reference to an existing element.
- *
- * @param <E> element type */
- interface Registry<E> {
- int register(E e);
- }
-
- /** Splitting strategy for {@code COUNT}.
- *
- * <p>COUNT splits into itself followed by SUM. (Actually
- * SUM0, because the total needs to be 0, not null, if there are 0 rows.)
- * This rule works for any number of arguments to COUNT, including COUNT(*).
- */
- class CountSplitter implements SqlSplittableAggFunction {
- public static final CountSplitter INSTANCE = new CountSplitter();
-
- public AggregateCall split(AggregateCall aggregateCall,
- Mappings.TargetMapping mapping) {
- return aggregateCall.transform(mapping);
- }
-
- public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
- return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
- ImmutableIntList.of(), -1,
- typeFactory.createSqlType(SqlTypeName.BIGINT), null);
- }
-
- /**
- * This new function create a proper RexNode for {@coide COUNT} Agg with OuterJoin Condition.
- */
- private RexNode makeProperRexNodeForOuterJoin(RexBuilder rexBuilder,
- RelDataType inputRowType,
- AggregateCall aggregateCall,
- int index) {
- RexInputRef inputRef = rexBuilder.makeInputRef(inputRowType.getFieldList().get(index).getType(), index);
- RexLiteral literal;
- boolean isCountStar = aggregateCall.getArgList() == null || aggregateCall.getArgList().isEmpty();
- if (isCountStar) {
- literal = rexBuilder.makeExactLiteral(BigDecimal.ONE);
- } else {
- literal = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
- }
- RexNode predicate = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, inputRef);
- return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
- predicate,
- literal,
- rexBuilder.makeCast(aggregateCall.type, inputRef)
- );
- }
-
- public AggregateCall topSplit(RexBuilder rexBuilder,
- Registry<RexNode> extra, int offset, RelDataType inputRowType,
- AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal,
- JoinRelType joinRelType) {
- final List<RexNode> merges = new ArrayList<>();
- if (leftSubTotal >= 0) {
- // add support for right outer join
- if (joinRelType == JoinRelType.RIGHT) {
- merges.add(
- makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, leftSubTotal)
- );
- } else {
- // if it's a inner join, then do the previous logic
- merges.add(
- rexBuilder.makeInputRef(aggregateCall.type, leftSubTotal));
- }
- }
- if (rightSubTotal >= 0) {
- // add support for left outer join
- if (joinRelType == JoinRelType.LEFT) {
- merges.add(
- makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, rightSubTotal)
- );
- } else {
- // if it's a inner join, then do the previous logic
- merges.add(
- rexBuilder.makeInputRef(aggregateCall.type, rightSubTotal));
- }
- }
- RexNode node;
- switch (merges.size()) {
- case 1:
- node = merges.get(0);
- break;
- case 2:
- node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges);
- break;
- default:
- throw new AssertionError("unexpected count " + merges);
- }
- int ordinal = extra.register(node);
- return AggregateCall.create(SqlStdOperatorTable.SUM0, false, false,
- ImmutableList.of(ordinal), -1, aggregateCall.type,
- aggregateCall.name);
- }
-
- /**
- * {@inheritDoc}
- *
- * <p>{@code COUNT(*)}, and {@code COUNT} applied to all NOT NULL arguments,
- * become {@code 1}; otherwise
- * {@code CASE WHEN arg0 IS NOT NULL THEN 1 ELSE 0 END}.
- */
- public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
- AggregateCall aggregateCall) {
- final List<RexNode> predicates = new ArrayList<>();
- for (Integer arg : aggregateCall.getArgList()) {
- final RelDataType type = inputRowType.getFieldList().get(arg).getType();
- if (type.isNullable()) {
- predicates.add(
- rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
- rexBuilder.makeInputRef(type, arg)));
- }
- }
- final RexNode predicate =
- RexUtil.composeConjunction(rexBuilder, predicates, true);
- if (predicate == null) {
- return rexBuilder.makeExactLiteral(BigDecimal.ONE);
- } else {
- return rexBuilder.makeCall(SqlStdOperatorTable.CASE, predicate,
- rexBuilder.makeExactLiteral(BigDecimal.ONE),
- rexBuilder.makeExactLiteral(BigDecimal.ZERO));
- }
- }
- }
-
- /** Aggregate function that splits into two applications of itself.
- *
- * <p>Examples are MIN and MAX. */
- class SelfSplitter implements SqlSplittableAggFunction {
- public static final SelfSplitter INSTANCE = new SelfSplitter();
-
- public RexNode singleton(RexBuilder rexBuilder,
- RelDataType inputRowType, AggregateCall aggregateCall) {
- final int arg = aggregateCall.getArgList().get(0);
- final RelDataTypeField field = inputRowType.getFieldList().get(arg);
- return rexBuilder.makeInputRef(field.getType(), arg);
- }
-
- public AggregateCall split(AggregateCall aggregateCall,
- Mappings.TargetMapping mapping) {
- return aggregateCall.transform(mapping);
- }
-
- public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
- return null; // no aggregate function required on other side
- }
-
- public AggregateCall topSplit(RexBuilder rexBuilder,
- Registry<RexNode> extra, int offset, RelDataType inputRowType,
- AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal,
- JoinRelType joinRelType) {
- assert (leftSubTotal >= 0) != (rightSubTotal >= 0);
- final int arg = leftSubTotal >= 0 ? leftSubTotal : rightSubTotal;
- return aggregateCall.copy(ImmutableIntList.of(arg), -1);
- }
- }
-
- /** Common Splitting strategy for {@coide SUM} and {@coide SUM0}. */
- abstract class AbstractSumSplitter implements SqlSplittableAggFunction {
-
- public RexNode singleton(RexBuilder rexBuilder,
- RelDataType inputRowType, AggregateCall aggregateCall) {
- final int arg = aggregateCall.getArgList().get(0);
- final RelDataTypeField field = inputRowType.getFieldList().get(arg);
- return rexBuilder.makeInputRef(field.getType(), arg);
- }
-
- public AggregateCall split(AggregateCall aggregateCall,
- Mappings.TargetMapping mapping) {
- return aggregateCall.transform(mapping);
- }
-
- public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
- return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
- ImmutableIntList.of(), -1,
- typeFactory.createSqlType(SqlTypeName.BIGINT), null);
- }
-
- /**
- * This new function create a proper RexNode for {@coide SUM} Agg with OuterJoin Condition.
- */
- private RexNode makeProperRexNodeForOuterJoin(RexBuilder rexBuilder,
- RelDataType inputRowType,
- AggregateCall aggregateCall,
- int index) {
- RexInputRef inputRef = rexBuilder.makeInputRef(inputRowType.getFieldList().get(index).getType(), index);
- RexLiteral literal = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
- RexNode predicate = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, inputRef);
- return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
- predicate,
- literal,
- rexBuilder.makeCast(aggregateCall.type, inputRef)
- );
- }
-
- public AggregateCall topSplit(RexBuilder rexBuilder,
- Registry<RexNode> extra, int offset, RelDataType inputRowType,
- AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal,
- JoinRelType joinRelType) {
- final List<RexNode> merges = new ArrayList<>();
- final List<RelDataTypeField> fieldList = inputRowType.getFieldList();
- if (leftSubTotal >= 0) {
- // add support for left outer join
- if (joinRelType == JoinRelType.RIGHT && getMergeAggFunctionOfTopSplit() == SqlStdOperatorTable.SUM0) {
- merges.add(makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, leftSubTotal));
- } else {
- // if it's a inner join, then do the previous logic
- final RelDataType type = fieldList.get(leftSubTotal).getType();
- merges.add(rexBuilder.makeInputRef(type, leftSubTotal));
- }
- }
- if (rightSubTotal >= 0) {
- // add support for right outer join
- if (joinRelType == JoinRelType.LEFT && getMergeAggFunctionOfTopSplit() == SqlStdOperatorTable.SUM0) {
- merges.add(makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, offset + rightSubTotal));
- } else {
- // if it's a inner join, then do the previous logic
- final RelDataType type = fieldList.get(rightSubTotal).getType();
- merges.add(rexBuilder.makeInputRef(type, rightSubTotal));
- }
- }
- RexNode node;
- switch (merges.size()) {
- case 1:
- node = merges.get(0);
- break;
- case 2:
- node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges);
- node = rexBuilder.makeAbstractCast(aggregateCall.type, node);
- break;
- default:
- throw new AssertionError("unexpected count " + merges);
- }
- int ordinal = extra.register(node);
- return AggregateCall.create(getMergeAggFunctionOfTopSplit(), false, false,
- ImmutableList.of(ordinal), -1, aggregateCall.type,
- aggregateCall.name);
- }
-
- protected abstract SqlAggFunction getMergeAggFunctionOfTopSplit();
-
- }
-
- /** Splitting strategy for {@coide SUM}. */
- class SumSplitter extends AbstractSumSplitter {
-
- public static final SumSplitter INSTANCE = new SumSplitter();
-
- @Override public SqlAggFunction getMergeAggFunctionOfTopSplit() {
- return SqlStdOperatorTable.SUM;
- }
-
- }
-
- /** Splitting strategy for {@code SUM0}. */
- class Sum0Splitter extends AbstractSumSplitter {
-
- public static final Sum0Splitter INSTANCE = new Sum0Splitter();
-
- @Override public SqlAggFunction getMergeAggFunctionOfTopSplit() {
- return SqlStdOperatorTable.SUM0;
- }
- }
-}
-
-// End SqlSplittableAggFunction.java
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java
index 10c0b94..c803f20 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java
@@ -31,7 +31,6 @@ import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Join;
-import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
@@ -46,7 +45,6 @@ import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
@@ -73,7 +71,6 @@ import scala.collection.Seq;
* This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.AggregateJoinTransposeRule}.
* Modification:
* - Do not match temporal join since lookup table source doesn't support aggregate.
- * - Support Left/Right Outer Join
* - Fix type mismatch error
* - Support aggregate with AUXILIARY_GROUP
*/
@@ -86,25 +83,19 @@ import scala.collection.Seq;
public class FlinkAggregateJoinTransposeRule extends RelOptRule {
public static final FlinkAggregateJoinTransposeRule INSTANCE =
new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class,
- RelFactories.LOGICAL_BUILDER, false, false);
+ RelFactories.LOGICAL_BUILDER, false);
/** Extended instance of the rule that can push down aggregate functions. */
public static final FlinkAggregateJoinTransposeRule EXTENDED =
new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class,
- RelFactories.LOGICAL_BUILDER, true, false);
-
- public static final FlinkAggregateJoinTransposeRule LEFT_RIGHT_OUTER_JOIN_EXTENDED =
- new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class,
- RelFactories.LOGICAL_BUILDER, true, true);
+ RelFactories.LOGICAL_BUILDER, true);
private final boolean allowFunctions;
- private final boolean allowLeftOrRightOuterJoin;
-
/** Creates an FlinkAggregateJoinTransposeRule. */
public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass,
Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory,
- boolean allowFunctions, boolean allowLeftOrRightOuterJoin) {
+ boolean allowFunctions) {
super(
operandJ(aggregateClass, null,
aggregate -> aggregate.getGroupType() == Aggregate.Group.SIMPLE,
@@ -112,7 +103,6 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
relBuilderFactory, null);
this.allowFunctions = allowFunctions;
- this.allowLeftOrRightOuterJoin = allowLeftOrRightOuterJoin;
}
@Deprecated // to be removed before 2.0
@@ -121,7 +111,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
Class<? extends Join> joinClass,
RelFactories.JoinFactory joinFactory) {
this(aggregateClass, joinClass,
- RelBuilder.proto(aggregateFactory, joinFactory), false, false);
+ RelBuilder.proto(aggregateFactory, joinFactory), false);
}
@Deprecated // to be removed before 2.0
@@ -131,7 +121,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
RelFactories.JoinFactory joinFactory,
boolean allowFunctions) {
this(aggregateClass, joinClass,
- RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions, false);
+ RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions);
}
@Deprecated // to be removed before 2.0
@@ -141,7 +131,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
RelFactories.JoinFactory joinFactory,
RelFactories.ProjectFactory projectFactory) {
this(aggregateClass, joinClass,
- RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false, false);
+ RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false);
}
@Deprecated // to be removed before 2.0
@@ -153,7 +143,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
boolean allowFunctions) {
this(aggregateClass, joinClass,
RelBuilder.proto(aggregateFactory, joinFactory, projectFactory),
- allowFunctions, false);
+ allowFunctions);
}
private boolean containsSnapshot(RelNode relNode) {
@@ -189,13 +179,6 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
final RexBuilder rexBuilder = origAgg.getCluster().getRexBuilder();
final RelBuilder relBuilder = call.builder();
- boolean isLeftOrRightOuterJoin =
- join.getJoinType() == JoinRelType.LEFT || join.getJoinType() == JoinRelType.RIGHT;
-
- if (join.getJoinType() != JoinRelType.INNER && !(allowLeftOrRightOuterJoin && isLeftOrRightOuterJoin)) {
- return;
- }
-
// converts an aggregate with AUXILIARY_GROUP to a regular aggregate.
// if the converted aggregate can be push down,
// AggregateReduceGroupingRule will try reduce grouping of new aggregates created by this rule
@@ -210,18 +193,15 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
== null) {
return;
}
- if (allowLeftOrRightOuterJoin && isLeftOrRightOuterJoin) {
- // todo do not support max/min agg until we've built the proper model
- if (aggregateCall.getAggregation().kind == SqlKind.MAX ||
- aggregateCall.getAggregation().kind == SqlKind.MIN) {
- return;
- }
- }
if (aggregateCall.filterArg >= 0 || aggregateCall.isDistinct()) {
return;
}
}
+ if (join.getJoinType() != JoinRelType.INNER) {
+ return;
+ }
+
if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
return;
}
@@ -229,19 +209,8 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
// Do the columns used by the join appear in the output of the aggregate?
final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
final RelMetadataQuery mq = call.getMetadataQuery();
- ImmutableBitSet keyColumns;
- if (!isLeftOrRightOuterJoin) {
- keyColumns = keyColumns(aggregateColumns,
- mq.getPulledUpPredicates(join).pulledUpPredicates);
- } else {
- // this is an incomplete implementation
- if (isAggregateKeyApplicable(aggregateColumns, join)) {
- keyColumns = keyColumns(aggregateColumns,
- com.google.common.collect.ImmutableList.copyOf(RelOptUtil.conjunctions(join.getCondition())));
- } else {
- keyColumns = aggregateColumns;
- }
- }
+ final ImmutableBitSet keyColumns = keyColumns(aggregateColumns,
+ mq.getPulledUpPredicates(join).pulledUpPredicates);
final ImmutableBitSet joinColumns =
RelOptUtil.InputFinder.bits(join.getCondition());
final boolean allColumnsInAggregate =
@@ -423,7 +392,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
splitter.topSplit(rexBuilder, registry(projects),
groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e,
leftSubTotal == null ? -1 : leftSubTotal,
- rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth, join.getJoinType()));
+ rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
}
relBuilder.project(projects);
@@ -546,14 +515,6 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
}
}
- private static boolean isAggregateKeyApplicable(ImmutableBitSet aggregateKeys, Join join) {
- JoinInfo joinInfo = join.analyzeCondition();
- return (join.getJoinType() == JoinRelType.LEFT && joinInfo.leftSet().contains(aggregateKeys)) ||
- (join.getJoinType() == JoinRelType.RIGHT &&
- joinInfo.rightSet().shift(join.getInput(0).getRowType().getFieldCount())
- .contains(aggregateKeys));
- }
-
private static void populateEquivalence(Map<Integer, BitSet> equivalence,
int i0, int i1) {
BitSet bitSet = equivalence.get(i0);
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
index 491fc1b..d095f76 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
@@ -249,7 +249,7 @@ object FlinkStreamRuleSets {
// remove aggregation if it does not aggregate and input is already distinct
FlinkAggregateRemoveRule.INSTANCE,
// push aggregate through join
- FlinkAggregateJoinTransposeRule.LEFT_RIGHT_OUTER_JOIN_EXTENDED,
+ FlinkAggregateJoinTransposeRule.EXTENDED,
// using variants of aggregate union rule
AggregateUnionAggregateRule.AGG_ON_FIRST_INPUT,
AggregateUnionAggregateRule.AGG_ON_SECOND_INPUT,
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml
index f964110..475744b 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml
@@ -116,14 +116,13 @@ LogicalAggregate(group=[{0, 1, 2, 3}], EXPR$4=[COUNT($4)], EXPR$5=[AVG($5)])
</Resource>
<Resource name="planAfter">
<![CDATA[
-Calc(select=[a2, b2, a3, b3, *($f2, $f20) AS EXPR$4, /(CAST(CASE(=($f4, 0), null:BIGINT, $f3)), $f4) AS EXPR$5])
-+- HashJoin(joinType=[InnerJoin], where=[=(b2, a3)], select=[a2, b2, $f2, a3, b3, $f20, $f3, $f4], isBroadcast=[true], build=[right])
- :- Calc(select=[a2, b2, CAST(CASE(IS NOT NULL(c2), 1, 0)) AS $f2])
- : +- TableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2])
- +- Exchange(distribution=[broadcast])
- +- HashAggregate(isMerge=[true], groupBy=[a3, b3], select=[a3, b3, Final_COUNT(count1$0) AS $f2, Final_$SUM0(sum$1) AS $f3, Final_COUNT(count$2) AS $f4])
- +- Exchange(distribution=[hash[a3, b3]])
- +- LocalHashAggregate(groupBy=[a3, b3], select=[a3, b3, Partial_COUNT(*) AS count1$0, Partial_$SUM0(d3) AS sum$1, Partial_COUNT(d3) AS count$2])
+Calc(select=[a2, b2, a3, b3, EXPR$4, EXPR$5])
++- HashAggregate(isMerge=[true], groupBy=[a3, b3], auxGrouping=[a2, b2], select=[a3, b3, a2, b2, Final_COUNT(count$0) AS EXPR$4, Final_AVG(sum$1, count$2) AS EXPR$5])
+ +- Exchange(distribution=[hash[a3, b3]])
+ +- LocalHashAggregate(groupBy=[a3, b3], auxGrouping=[a2, b2], select=[a3, b3, a2, b2, Partial_COUNT(c2) AS count$0, Partial_AVG(d3) AS (sum$1, count$2)])
+ +- HashJoin(joinType=[InnerJoin], where=[=(b2, a3)], select=[a2, b2, c2, a3, b3, d3], isBroadcast=[true], build=[right])
+ :- TableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2])
+ +- Exchange(distribution=[broadcast])
+- Calc(select=[a3, b3, d3])
+- TableSourceScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(a3, b3, c3, d3)]]], fields=[a3, b3, c3, d3])
]]>
@@ -149,7 +148,7 @@ LogicalAggregate(group=[{0, 1, 2, 3, 4, 5}], EXPR$6=[COUNT($6)])
<![CDATA[
Calc(select=[a1, b1, a2, b2, a3, b3, *($f4, $f2) AS EXPR$6])
+- HashJoin(joinType=[InnerJoin], where=[=(a1, a3)], select=[a1, b1, a2, b2, $f4, a3, b3, $f2], isBroadcast=[true], build=[right])
- :- Calc(select=[a1, b1, a2, b2, CAST(CASE(IS NOT NULL(c1), 1, 0)) AS $f4])
+ :- Calc(select=[a1, b1, a2, b2, CASE(IS NOT NULL(c1), 1:BIGINT, 0:BIGINT) AS $f4])
: +- HashJoin(joinType=[InnerJoin], where=[=(a1, b2)], select=[a1, b1, c1, a2, b2], build=[right])
: :- Exchange(distribution=[hash[a1]])
: : +- Calc(select=[a1, b1, c1])
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
index d659cd9..1212910 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
@@ -107,11 +107,10 @@ LogicalAggregate(group=[{0, 1, 2, 3}], EXPR$4=[COUNT($4)], EXPR$5=[AVG($5)])
</Resource>
<Resource name="planAfter">
<![CDATA[
-FlinkLogicalCalc(select=[a2, b2, a3, b3, *($f2, $f20) AS EXPR$4, /(CAST(CASE(=($f4, 0), null:BIGINT, $f3)), $f4) AS EXPR$5])
-+- FlinkLogicalJoin(condition=[=($1, $3)], joinType=[inner])
- :- FlinkLogicalCalc(select=[a2, b2, CAST(CASE(IS NOT NULL(c2), 1, 0)) AS $f2])
- : +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2])
- +- FlinkLogicalAggregate(group=[{0, 1}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[COUNT($2)])
+FlinkLogicalCalc(select=[a2, b2, a3, b3, EXPR$4, EXPR$5])
++- FlinkLogicalAggregate(group=[{3, 4}], a2=[AUXILIARY_GROUP($0)], b2=[AUXILIARY_GROUP($1)], EXPR$4=[COUNT($2)], EXPR$5=[AVG($5)])
+ +- FlinkLogicalJoin(condition=[=($1, $3)], joinType=[inner])
+ :- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2])
+- FlinkLogicalCalc(select=[a3, b3, d3])
+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(a3, b3, c3, d3)]]], fields=[a3, b3, c3, d3])
]]>
@@ -137,7 +136,7 @@ LogicalAggregate(group=[{0, 1, 2, 3, 4, 5}], EXPR$6=[COUNT($6)])
<![CDATA[
FlinkLogicalCalc(select=[a1, b1, a2, b2, a3, b3, *($f4, $f2) AS EXPR$6])
+- FlinkLogicalJoin(condition=[=($0, $5)], joinType=[inner])
- :- FlinkLogicalCalc(select=[a1, b1, a2, b2, CAST(CASE(IS NOT NULL(c1), 1, 0)) AS $f4])
+ :- FlinkLogicalCalc(select=[a1, b1, a2, b2, CASE(IS NOT NULL(c1), 1:BIGINT, 0:BIGINT) AS $f4])
: +- FlinkLogicalJoin(condition=[=($0, $4)], joinType=[inner])
: :- FlinkLogicalCalc(select=[a1, b1, c1])
: : +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a1, b1, c1, d1)]]], fields=[a1, b1, c1, d1])
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml
similarity index 98%
rename from flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml
rename to flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml
index 475d60f..0d58e8b 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml
@@ -192,7 +192,7 @@ LogicalProject(EXPR$0=[$2], EXPR$1=[$3], a=[$0], b=[$1], EXPR$4=[$4])
LogicalProject(EXPR$0=[$2], EXPR$1=[$3], a=[$0], b=[$1], EXPR$4=[$4])
+- LogicalProject(a=[$3], b=[$4], a2=[$1], b2=[$0], $f6=[*($2, $5)])
+- LogicalJoin(condition=[=($0, $3)], joinType=[inner])
- :- LogicalProject(b2=[$1], a2=[$0], $f2=[CAST(CASE(IS NOT NULL($2), 1, 0)):BIGINT NOT NULL])
+ :- LogicalProject(b2=[$1], a2=[$0], $f2=[CASE(IS NOT NULL($2), 1:BIGINT, 0:BIGINT)])
: +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]])
+- LogicalAggregate(group=[{0, 1}], agg#0=[COUNT()])
+- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c)]]])
@@ -219,7 +219,7 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)])
LogicalAggregate(group=[{}], EXPR$0=[$SUM0($4)])
+- LogicalProject(a=[$0], $f1=[$1], a0=[$2], $f10=[$3], $f4=[*($1, $3)])
+- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
- :- LogicalProject(a=[$0], $f1=[CAST(CASE(IS NOT NULL($0), 1, 0)):BIGINT NOT NULL])
+ :- LogicalProject(a=[$0], $f1=[CASE(IS NOT NULL($0), 1:BIGINT, 0:BIGINT)])
: +- LogicalAggregate(group=[{0}])
: +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c)]]])
+- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml
deleted file mode 100644
index 4df5360..0000000
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml
+++ /dev/null
@@ -1,267 +0,0 @@
-<?xml version="1.0" ?>
-<!--
-Licensed to the Apache Software Foundation (ASF) under one or more
-contributor license agreements. See the NOTICE file distributed with
-this work for additional information regarding copyright ownership.
-The ASF licenses this file to you under the Apache License, Version 2.0
-(the "License"); you may not use this file except in compliance with
-the License. You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
--->
-<Root>
- <TestCase name="testPushCountAggThroughJoinOverUniqueColumn">
- <Resource name="sql">
- <![CDATA[SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a]]>
- </Resource>
- <Resource name="planBefore">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)])
-+- LogicalProject(a=[$0])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[inner])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalProject(a=[$0])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- <Resource name="planAfter">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[$SUM0($4)])
-+- LogicalProject(a=[$0], $f1=[$1], a0=[$2], $f10=[$3], $f4=[*($1, $3)])
- +- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
- :- LogicalProject(a=[$0], $f1=[CAST(CASE(IS NOT NULL($0), 1, 0)):BIGINT NOT NULL])
- : +- LogicalAggregate(group=[{0}])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- </TestCase>
- <TestCase name="testPushCountAggThroughLeftJoinAndGroupByLeft">
- <Resource name="sql">
- <![CDATA[SELECT COUNT(B.b) FROM (SELECT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a]]>
- </Resource>
- <Resource name="planBefore">
- <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
- +- LogicalProject(a=[$0], b=[$2])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalProject(a=[$0])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- <Resource name="planAfter">
- <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalProject(a=[$0], $f4=[*($1, CASE(IS NULL($3), 0, CAST($3):BIGINT NOT NULL))])
- +- LogicalJoin(condition=[=($0, $2)], joinType=[left])
- :- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- </TestCase>
- <TestCase name="testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByLeft">
- <Resource name="sql">
- <![CDATA[SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a]]>
- </Resource>
- <Resource name="planBefore">
- <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
- +- LogicalProject(a=[$0], b=[$2])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalProject(a=[$0])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- <Resource name="planAfter">
- <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalProject(a=[$0], $f3=[CASE(IS NULL($2), 0, CAST($2):BIGINT NOT NULL)])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- </TestCase>
- <TestCase name="testPushCountAggThroughLeftJoinOverUniqueColumn">
- <Resource name="sql">
- <![CDATA[SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a]]>
- </Resource>
- <Resource name="planBefore">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)])
-+- LogicalProject(a=[$0])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalProject(a=[$0])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- <Resource name="planAfter">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[$SUM0($4)])
-+- LogicalProject(a=[$0], $f1=[$1], a0=[$2], $f10=[$3], $f4=[*($1, CASE(IS NULL($3), 0, CAST($3):BIGINT NOT NULL))])
- +- LogicalJoin(condition=[=($0, $2)], joinType=[left])
- :- LogicalProject(a=[$0], $f1=[CAST(CASE(IS NOT NULL($0), 1, 0)):BIGINT NOT NULL])
- : +- LogicalAggregate(group=[{0}])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- </TestCase>
- <TestCase name="testPushSumAggThroughLeftJoinOverUniqueColumn">
- <Resource name="sql">
- <![CDATA[SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a]]>
- </Resource>
- <Resource name="planBefore">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])
-+- LogicalProject(a=[$0])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalProject(a=[$0])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- <Resource name="planAfter">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($3)])
-+- LogicalProject(a=[$0], a0=[$1], $f1=[$2], $f3=[CAST(*($0, $2)):INTEGER])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- </TestCase>
- <TestCase name="testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByRight">
- <Resource name="sql">
- <![CDATA[SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY B.a]]>
- </Resource>
- <Resource name="planBefore">
- <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
- +- LogicalProject(a0=[$1], b=[$2])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalProject(a=[$0])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- <Resource name="planAfter">
- <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{1}], EXPR$0=[$SUM0($3)])
- +- LogicalProject(a=[$0], a0=[$1], EXPR$0=[$2], $f3=[CASE(IS NULL($2), 0, CAST($2):BIGINT NOT NULL)])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- </TestCase>
- <TestCase name="testPushCountAggThroughRightJoin">
- <Resource name="sql">
- <![CDATA[SELECT COUNT(B.b) FROM T AS B RIGHT OUTER JOIN (SELECT a FROM T) AS A ON A.a=B.a GROUP BY A.a]]>
- </Resource>
- <Resource name="planBefore">
- <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
- +- LogicalProject(a0=[$4], b=[$1])
- +- LogicalJoin(condition=[=($4, $0)], joinType=[right])
- :- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalProject(a=[$0])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- <Resource name="planAfter">
- <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalProject(a0=[$2], $f4=[*(CASE(IS NULL($1), 0, CAST($1):BIGINT NOT NULL), $3)])
- +- LogicalJoin(condition=[=($2, $0)], joinType=[right])
- :- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- </TestCase>
- <TestCase name="testPushSumAggThroughJoinOverUniqueColumn">
- <Resource name="sql">
- <![CDATA[SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a]]>
- </Resource>
- <Resource name="planBefore">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])
-+- LogicalProject(a=[$0])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[inner])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalProject(a=[$0])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- <Resource name="planAfter">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($3)])
-+- LogicalProject(a=[$0], a0=[$1], $f1=[$2], $f3=[CAST(*($0, $2)):INTEGER])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[inner])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- </TestCase>
- <TestCase name="testPushCountAllAggThroughLeftJoinOverUniqueColumn">
- <Resource name="sql">
- <![CDATA[SELECT COUNT(*) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a]]>
- </Resource>
- <Resource name="planBefore">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
-+- LogicalProject($f0=[0])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalProject(a=[$0])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- <Resource name="planAfter">
- <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[$SUM0($3)])
-+- LogicalProject(a=[$0], a0=[$1], EXPR$0=[$2], $f3=[CASE(IS NULL($2), 1, CAST($2):BIGINT NOT NULL)])
- +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
- :- LogicalAggregate(group=[{0}])
- : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
- +- LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
- +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
- </Resource>
- </TestCase>
-</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.scala
similarity index 97%
rename from flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala
rename to flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.scala
index 316ff13..41218e8 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.scala
@@ -34,9 +34,8 @@ import org.junit.{Before, Test}
/**
* Test for [[FlinkAggregateJoinTransposeRule]].
- * this class only test inner join.
*/
-class FlinkAggregateInnerJoinTransposeRuleTest extends TableTestBase {
+class FlinkAggregateJoinTransposeRuleTest extends TableTestBase {
private val util = batchTestUtil()
@Before
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala
deleted file mode 100644
index 186514d..0000000
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.rules.logical
-
-import org.apache.flink.api.scala._
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.plan.optimize.program.{FlinkChainedProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE, StreamOptimizeContext}
-import org.apache.flink.table.util.TableTestBase
-
-import org.apache.calcite.plan.hep.HepMatchOrder
-import org.apache.calcite.rel.rules._
-import org.apache.calcite.tools.RuleSets
-import org.junit.{Before, Test}
-
-/**
- * Test for [[FlinkAggregateJoinTransposeRule]].
- * this class only test left/right outer join.
- */
-class FlinkAggregateOuterJoinTransposeRuleTest extends TableTestBase {
-
- private val util = streamTestUtil()
-
- @Before
- def setup(): Unit = {
- val program = new FlinkChainedProgram[StreamOptimizeContext]()
- program.addLast(
- "rules",
- FlinkHepRuleSetProgramBuilder.newBuilder
- .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION)
- .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
- .add(RuleSets.ofList(
- FlinkFilterJoinRule.FILTER_ON_JOIN,
- FlinkFilterJoinRule.JOIN,
- FilterAggregateTransposeRule.INSTANCE,
- FilterProjectTransposeRule.INSTANCE,
- FilterMergeRule.INSTANCE,
- AggregateProjectMergeRule.INSTANCE,
- FlinkAggregateJoinTransposeRule.LEFT_RIGHT_OUTER_JOIN_EXTENDED
- ))
- .build()
- )
- util.replaceStreamProgram(program)
-
- util.addTableSource[(Int, Long, String, Int)]("T", 'a, 'b, 'c, 'd)
- }
-
- @Test
- def testPushCountAggThroughJoinOverUniqueColumn(): Unit = {
- util.verifyPlan("SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a")
- }
-
- @Test
- def testPushSumAggThroughJoinOverUniqueColumn(): Unit = {
- util.verifyPlan("SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a")
- }
-
- @Test
- def testPushCountAggThroughLeftJoinOverUniqueColumn(): Unit = {
- val sqlQuery = "SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A " +
- "LEFT OUTER JOIN T AS B ON A.a=B.a"
- util.verifyPlan(sqlQuery)
- }
-
- @Test
- def testPushSumAggThroughLeftJoinOverUniqueColumn(): Unit = {
- val sqlQuery = "SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A " +
- "LEFT OUTER JOIN T AS B ON A.a=B.a"
- util.verifyPlan(sqlQuery)
- }
-
- @Test
- def testPushCountAllAggThroughLeftJoinOverUniqueColumn(): Unit = {
- val sqlQuery = "SELECT COUNT(*) FROM (SELECT DISTINCT a FROM T) AS A " +
- "LEFT OUTER JOIN T AS B ON A.a=B.a"
- util.verifyPlan(sqlQuery)
- }
-
- @Test
- def testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByLeft(): Unit = {
- val sqlQuery = "SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A " +
- "LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a"
- util.verifyPlan(sqlQuery)
- }
-
- @Test
- def testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByRight(): Unit = {
- val sqlQuery = "SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A " +
- "LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY B.a"
- util.verifyPlan(sqlQuery)
- }
-
- @Test
- def testPushCountAggThroughLeftJoinAndGroupByLeft(): Unit = {
- val sqlQuery = "SELECT COUNT(B.b) FROM (SELECT a FROM T) AS A " +
- "LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a"
- util.verifyPlan(sqlQuery)
- }
-
- @Test
- def testPushCountAggThroughRightJoin(): Unit = {
- val sqlQuery = "SELECT COUNT(B.b) FROM T AS B RIGHT OUTER JOIN (SELECT a FROM T) AS A " +
- "ON A.a=B.a GROUP BY A.a"
- util.verifyPlan(sqlQuery)
- }
-
-}