Posted to commits@flink.apache.org by ja...@apache.org on 2019/07/16 03:43:54 UTC

[flink] branch master updated: [FLINK-13268][table-planner-blink] Revert SqlSplittableAggFunction to support making two planners available in one jar

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new a7bdbb9  [FLINK-13268][table-planner-blink] Revert SqlSplittableAggFunction to support making two planners available in one jar
a7bdbb9 is described below

commit a7bdbb9aff734fc586e74c95984c7f956dbf576c
Author: godfreyhe <go...@163.com>
AuthorDate: Mon Jul 15 21:13:23 2019 +0800

    [FLINK-13268][table-planner-blink] Revert SqlSplittableAggFunction to support making two planners available in one jar
    
    This closes #9119
---
 .../calcite/sql/SqlSplittableAggFunction.java      | 374 ---------------------
 .../logical/FlinkAggregateJoinTransposeRule.java   |  67 +---
 .../table/plan/rules/FlinkStreamRuleSets.scala     |   2 +-
 .../batch/sql/agg/AggregateReduceGroupingTest.xml  |  17 +-
 .../logical/AggregateReduceGroupingRuleTest.xml    |  11 +-
 ...xml => FlinkAggregateJoinTransposeRuleTest.xml} |   4 +-
 .../FlinkAggregateOuterJoinTransposeRuleTest.xml   | 267 ---------------
 ...a => FlinkAggregateJoinTransposeRuleTest.scala} |   3 +-
 .../FlinkAggregateOuterJoinTransposeRuleTest.scala | 122 -------
 9 files changed, 31 insertions(+), 836 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
deleted file mode 100644
index a69a82d..0000000
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
+++ /dev/null
@@ -1,374 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to you under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.calcite.sql;
-
-import com.google.common.collect.ImmutableList;
-import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.core.JoinRelType;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexLiteral;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.rex.RexUtil;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.sql.type.SqlTypeName;
-import org.apache.calcite.util.ImmutableIntList;
-import org.apache.calcite.util.mapping.Mappings;
-
-import java.math.BigDecimal;
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * This file is copied from Calcite with the following changes:
- * 1. A makeProperRexNodeForOuterJoin function was added to CountSplitter and AbstractSumSplitter.
- * 2. If the join type is a left or right outer join, the proper RexNode is built; otherwise the previous logic is followed.
- *
- * This copy can be removed once [CALCITE-2378] is fixed.
- */
-
-/**
- * Aggregate function that can be split into partial aggregates.
- *
- * <p>For example, {@code COUNT(x)} can be split into {@code COUNT(x)} on
- * subsets followed by {@code SUM} to combine those counts.
- */
-public interface SqlSplittableAggFunction {
-	AggregateCall split(AggregateCall aggregateCall,
-			Mappings.TargetMapping mapping);
-
-	/** Called to generate an aggregate for the other side of the join
-	 * than the side aggregate call's arguments come from. Returns null if
-	 * no aggregate is required. */
-	AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e);
-
-	/** Generates an aggregate call to merge sub-totals.
-	 *
-	 * <p>Most implementations will add a single aggregate call to
-	 * {@code aggCalls}, and return a {@link RexInputRef} that points to it.
-	 *
-	 * @param rexBuilder Rex builder
-	 * @param extra Place to define extra input expressions
-	 * @param offset Offset due to grouping columns (and indicator columns if
-	 *     applicable)
-	 * @param inputRowType Input row type
-	 * @param aggregateCall Source aggregate call
-	 * @param leftSubTotal Ordinal of the sub-total coming from the left side of
-	 *     the join, or -1 if there is no such sub-total
-	 * @param rightSubTotal Ordinal of the sub-total coming from the right side
-	 *     of the join, or -1 if there is no such sub-total
-	 * @param joinRelType the join type
-	 *
-	 * @return Aggregate call
-	 */
-	AggregateCall topSplit(RexBuilder rexBuilder, Registry<RexNode> extra,
-			int offset, RelDataType inputRowType, AggregateCall aggregateCall,
-			int leftSubTotal, int rightSubTotal, JoinRelType joinRelType);
-
-	/** Generates an expression for the value of the aggregate function when
-	 * applied to a single row.
-	 *
-	 * <p>For example, if there is one row:
-	 * <ul>
-	 *   <li>{@code SUM(x)} is {@code x}
-	 *   <li>{@code MIN(x)} is {@code x}
-	 *   <li>{@code MAX(x)} is {@code x}
- *   <li>{@code COUNT(x)} is {@code CASE WHEN x IS NOT NULL THEN 1 ELSE 0 END}
-	 *   which can be simplified to {@code 1} if {@code x} is never null
-	 *   <li>{@code COUNT(*)} is 1
-	 * </ul>
-	 *
-	 * @param rexBuilder Rex builder
-	 * @param inputRowType Input row type
-	 * @param aggregateCall Aggregate call
-	 *
-	 * @return Expression for single row
-	 */
-	RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
-			AggregateCall aggregateCall);
-
-	/** Collection in which one can register an element. Registering may return
-	 * a reference to an existing element.
-	 *
-	 * @param <E> element type */
-	interface Registry<E> {
-		int register(E e);
-	}
-
-	/** Splitting strategy for {@code COUNT}.
-	 *
-	 * <p>COUNT splits into itself followed by SUM. (Actually
-	 * SUM0, because the total needs to be 0, not null, if there are 0 rows.)
-	 * This rule works for any number of arguments to COUNT, including COUNT(*).
-	 */
-	class CountSplitter implements SqlSplittableAggFunction {
-		public static final CountSplitter INSTANCE = new CountSplitter();
-
-		public AggregateCall split(AggregateCall aggregateCall,
-								   Mappings.TargetMapping mapping) {
-			return aggregateCall.transform(mapping);
-		}
-
-		public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
-			return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
-				ImmutableIntList.of(), -1,
-				typeFactory.createSqlType(SqlTypeName.BIGINT), null);
-		}
-
-		/**
-		 * This new function creates a proper RexNode for {@code COUNT} aggregates with an outer-join condition.
-		 */
-		private RexNode makeProperRexNodeForOuterJoin(RexBuilder rexBuilder,
-													  RelDataType inputRowType,
-													  AggregateCall aggregateCall,
-													  int index) {
-			RexInputRef inputRef = rexBuilder.makeInputRef(inputRowType.getFieldList().get(index).getType(), index);
-			RexLiteral literal;
-			boolean isCountStar = aggregateCall.getArgList() == null || aggregateCall.getArgList().isEmpty();
-			if (isCountStar) {
-				literal = rexBuilder.makeExactLiteral(BigDecimal.ONE);
-			} else {
-				literal = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
-			}
-			RexNode predicate = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, inputRef);
-			return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
-				predicate,
-				literal,
-				rexBuilder.makeCast(aggregateCall.type, inputRef)
-			);
-		}
-
-		public AggregateCall topSplit(RexBuilder rexBuilder,
-									  Registry<RexNode> extra, int offset, RelDataType inputRowType,
-									  AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal,
-									  JoinRelType joinRelType) {
-			final List<RexNode> merges = new ArrayList<>();
-			if (leftSubTotal >= 0) {
-				// add support for right outer join
-				if (joinRelType == JoinRelType.RIGHT) {
-					merges.add(
-						makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, leftSubTotal)
-					);
-				} else {
-					// if it's an inner join, then do the previous logic
-					merges.add(
-						rexBuilder.makeInputRef(aggregateCall.type, leftSubTotal));
-				}
-			}
-			if (rightSubTotal >= 0) {
-				// add support for left outer join
-				if (joinRelType == JoinRelType.LEFT) {
-					merges.add(
-						makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, rightSubTotal)
-					);
-				} else {
-					// if it's an inner join, then do the previous logic
-					merges.add(
-						rexBuilder.makeInputRef(aggregateCall.type, rightSubTotal));
-				}
-			}
-			RexNode node;
-			switch (merges.size()) {
-				case 1:
-					node = merges.get(0);
-					break;
-				case 2:
-					node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges);
-					break;
-				default:
-					throw new AssertionError("unexpected count " + merges);
-			}
-			int ordinal = extra.register(node);
-			return AggregateCall.create(SqlStdOperatorTable.SUM0, false, false,
-				ImmutableList.of(ordinal), -1, aggregateCall.type,
-				aggregateCall.name);
-		}
-
-		/**
-		 * {@inheritDoc}
-		 *
-		 * <p>{@code COUNT(*)}, and {@code COUNT} applied to all NOT NULL arguments,
-		 * become {@code 1}; otherwise
-		 * {@code CASE WHEN arg0 IS NOT NULL THEN 1 ELSE 0 END}.
-		 */
-		public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
-								 AggregateCall aggregateCall) {
-			final List<RexNode> predicates = new ArrayList<>();
-			for (Integer arg : aggregateCall.getArgList()) {
-				final RelDataType type = inputRowType.getFieldList().get(arg).getType();
-				if (type.isNullable()) {
-					predicates.add(
-						rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
-							rexBuilder.makeInputRef(type, arg)));
-				}
-			}
-			final RexNode predicate =
-				RexUtil.composeConjunction(rexBuilder, predicates, true);
-			if (predicate == null) {
-				return rexBuilder.makeExactLiteral(BigDecimal.ONE);
-			} else {
-				return rexBuilder.makeCall(SqlStdOperatorTable.CASE, predicate,
-					rexBuilder.makeExactLiteral(BigDecimal.ONE),
-					rexBuilder.makeExactLiteral(BigDecimal.ZERO));
-			}
-		}
-	}
-
-	/** Aggregate function that splits into two applications of itself.
-	 *
-	 * <p>Examples are MIN and MAX. */
-	class SelfSplitter implements SqlSplittableAggFunction {
-		public static final SelfSplitter INSTANCE = new SelfSplitter();
-
-		public RexNode singleton(RexBuilder rexBuilder,
-								 RelDataType inputRowType, AggregateCall aggregateCall) {
-			final int arg = aggregateCall.getArgList().get(0);
-			final RelDataTypeField field = inputRowType.getFieldList().get(arg);
-			return rexBuilder.makeInputRef(field.getType(), arg);
-		}
-
-		public AggregateCall split(AggregateCall aggregateCall,
-								   Mappings.TargetMapping mapping) {
-			return aggregateCall.transform(mapping);
-		}
-
-		public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
-			return null; // no aggregate function required on other side
-		}
-
-		public AggregateCall topSplit(RexBuilder rexBuilder,
-									  Registry<RexNode> extra, int offset, RelDataType inputRowType,
-									  AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal,
-									  JoinRelType joinRelType) {
-			assert (leftSubTotal >= 0) != (rightSubTotal >= 0);
-			final int arg = leftSubTotal >= 0 ? leftSubTotal : rightSubTotal;
-			return aggregateCall.copy(ImmutableIntList.of(arg), -1);
-		}
-	}
-
-	/** Common splitting strategy for {@code SUM} and {@code SUM0}. */
-	abstract class AbstractSumSplitter implements SqlSplittableAggFunction {
-
-		public RexNode singleton(RexBuilder rexBuilder,
-								 RelDataType inputRowType, AggregateCall aggregateCall) {
-			final int arg = aggregateCall.getArgList().get(0);
-			final RelDataTypeField field = inputRowType.getFieldList().get(arg);
-			return rexBuilder.makeInputRef(field.getType(), arg);
-		}
-
-		public AggregateCall split(AggregateCall aggregateCall,
-								   Mappings.TargetMapping mapping) {
-			return aggregateCall.transform(mapping);
-		}
-
-		public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
-			return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
-				ImmutableIntList.of(), -1,
-				typeFactory.createSqlType(SqlTypeName.BIGINT), null);
-		}
-
-		/**
-		 * This new function creates a proper RexNode for {@code SUM} aggregates with an outer-join condition.
-		 */
-		private RexNode makeProperRexNodeForOuterJoin(RexBuilder rexBuilder,
-													  RelDataType inputRowType,
-													  AggregateCall aggregateCall,
-													  int index) {
-			RexInputRef inputRef = rexBuilder.makeInputRef(inputRowType.getFieldList().get(index).getType(), index);
-			RexLiteral literal = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
-			RexNode predicate = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, inputRef);
-			return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
-				predicate,
-				literal,
-				rexBuilder.makeCast(aggregateCall.type, inputRef)
-			);
-		}
-
-		public AggregateCall topSplit(RexBuilder rexBuilder,
-									  Registry<RexNode> extra, int offset, RelDataType inputRowType,
-									  AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal,
-									  JoinRelType joinRelType) {
-			final List<RexNode> merges = new ArrayList<>();
-			final List<RelDataTypeField> fieldList = inputRowType.getFieldList();
-			if (leftSubTotal >= 0) {
-				// add support for right outer join
-				if (joinRelType == JoinRelType.RIGHT && getMergeAggFunctionOfTopSplit() == SqlStdOperatorTable.SUM0) {
-					merges.add(makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, leftSubTotal));
-				} else {
-					// if it's an inner join, then do the previous logic
-					final RelDataType type = fieldList.get(leftSubTotal).getType();
-					merges.add(rexBuilder.makeInputRef(type, leftSubTotal));
-				}
-			}
-			if (rightSubTotal >= 0) {
-				// add support for left outer join
-				if (joinRelType == JoinRelType.LEFT && getMergeAggFunctionOfTopSplit() == SqlStdOperatorTable.SUM0) {
-					merges.add(makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, offset + rightSubTotal));
-				} else {
-					// if it's an inner join, then do the previous logic
-					final RelDataType type = fieldList.get(rightSubTotal).getType();
-					merges.add(rexBuilder.makeInputRef(type, rightSubTotal));
-				}
-			}
-			RexNode node;
-			switch (merges.size()) {
-				case 1:
-					node = merges.get(0);
-					break;
-				case 2:
-					node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges);
-					node = rexBuilder.makeAbstractCast(aggregateCall.type, node);
-					break;
-				default:
-					throw new AssertionError("unexpected count " + merges);
-			}
-			int ordinal = extra.register(node);
-			return AggregateCall.create(getMergeAggFunctionOfTopSplit(), false, false,
-				ImmutableList.of(ordinal), -1, aggregateCall.type,
-				aggregateCall.name);
-		}
-
-		protected abstract SqlAggFunction getMergeAggFunctionOfTopSplit();
-
-	}
-
-	/** Splitting strategy for {@code SUM}. */
-	class SumSplitter extends AbstractSumSplitter {
-
-		public static final SumSplitter INSTANCE = new SumSplitter();
-
-		@Override public SqlAggFunction getMergeAggFunctionOfTopSplit() {
-			return SqlStdOperatorTable.SUM;
-		}
-
-	}
-
-	/** Splitting strategy for {@code SUM0}. */
-	class Sum0Splitter extends AbstractSumSplitter {
-
-		public static final Sum0Splitter INSTANCE = new Sum0Splitter();
-
-		@Override public SqlAggFunction getMergeAggFunctionOfTopSplit() {
-			return SqlStdOperatorTable.SUM0;
-		}
-	}
-}
-
-// End SqlSplittableAggFunction.java
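
For readers skimming the diff: the class removed above implemented Calcite's
"splittable aggregate" idea, in which COUNT over an equi-join is rewritten into
per-side partial counts whose per-key products are combined with SUM0 (so the
total is 0, not NULL, when there are no rows). The snippet below is a minimal,
Calcite-free sketch of that arithmetic; the class and variable names are
illustrative only and not part of any Flink or Calcite API.

    // Sketch of the COUNT split: COUNT(*) over an inner equi-join equals
    // SUM0(leftCount(key) * rightCount(key)) over the join keys.
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class CountSplitSketch {
        public static void main(String[] args) {
            // join keys of the left and right inputs of an inner join on "key"
            List<Integer> left = List.of(1, 1, 2);
            List<Integer> right = List.of(1, 2, 2, 3);

            // naive: COUNT(*) over the joined result
            long naive = 0;
            for (int l : left) {
                for (int r : right) {
                    if (l == r) {
                        naive++;
                    }
                }
            }

            // split: partial COUNT(*) per key on each side, then SUM0 of the products
            Map<Integer, Long> leftCnt = new HashMap<>();
            left.forEach(k -> leftCnt.merge(k, 1L, Long::sum));
            Map<Integer, Long> rightCnt = new HashMap<>();
            right.forEach(k -> rightCnt.merge(k, 1L, Long::sum));

            long split = 0; // SUM0: starts at 0 and stays 0 for empty input
            for (Map.Entry<Integer, Long> e : leftCnt.entrySet()) {
                // keys with no match on the other side contribute 0 to the total
                split += e.getValue() * rightCnt.getOrDefault(e.getKey(), 0L);
            }

            System.out.println(naive + " == " + split); // prints "4 == 4"
        }
    }

Both computations agree; the rule's job is to perform this rewrite on the
relational plan so the partial counts can be computed below the join.
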
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java
index 10c0b94..c803f20 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java
@@ -31,7 +31,6 @@ import org.apache.calcite.rel.SingleRel;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.Join;
-import org.apache.calcite.rel.core.JoinInfo;
 import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rel.logical.LogicalAggregate;
@@ -46,7 +45,6 @@ import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.SqlSplittableAggFunction;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
@@ -73,7 +71,6 @@ import scala.collection.Seq;
  * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.AggregateJoinTransposeRule}.
  * Modification:
  * - Do not match temporal join since lookup table source doesn't support aggregate.
- * - Support Left/Right Outer Join
  * - Fix type mismatch error
  * - Support aggregate with AUXILIARY_GROUP
  */
@@ -86,25 +83,19 @@ import scala.collection.Seq;
 public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 	public static final FlinkAggregateJoinTransposeRule INSTANCE =
 			new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class,
-					RelFactories.LOGICAL_BUILDER, false, false);
+					RelFactories.LOGICAL_BUILDER, false);
 
 	/** Extended instance of the rule that can push down aggregate functions. */
 	public static final FlinkAggregateJoinTransposeRule EXTENDED =
 			new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class,
-					RelFactories.LOGICAL_BUILDER, true, false);
-
-	public static final FlinkAggregateJoinTransposeRule LEFT_RIGHT_OUTER_JOIN_EXTENDED =
-			new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class,
-					RelFactories.LOGICAL_BUILDER, true, true);
+					RelFactories.LOGICAL_BUILDER, true);
 
 	private final boolean allowFunctions;
 
-	private final boolean allowLeftOrRightOuterJoin;
-
 	/** Creates an FlinkAggregateJoinTransposeRule. */
 	public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass,
 			Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory,
-			boolean allowFunctions, boolean allowLeftOrRightOuterJoin) {
+			boolean allowFunctions) {
 		super(
 				operandJ(aggregateClass, null,
 						aggregate -> aggregate.getGroupType() == Aggregate.Group.SIMPLE,
@@ -112,7 +103,6 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 				relBuilderFactory, null);
 
 		this.allowFunctions = allowFunctions;
-		this.allowLeftOrRightOuterJoin = allowLeftOrRightOuterJoin;
 	}
 
 	@Deprecated // to be removed before 2.0
@@ -121,7 +111,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 			Class<? extends Join> joinClass,
 			RelFactories.JoinFactory joinFactory) {
 		this(aggregateClass, joinClass,
-				RelBuilder.proto(aggregateFactory, joinFactory), false, false);
+				RelBuilder.proto(aggregateFactory, joinFactory), false);
 	}
 
 	@Deprecated // to be removed before 2.0
@@ -131,7 +121,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 			RelFactories.JoinFactory joinFactory,
 			boolean allowFunctions) {
 		this(aggregateClass, joinClass,
-				RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions, false);
+				RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions);
 	}
 
 	@Deprecated // to be removed before 2.0
@@ -141,7 +131,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 			RelFactories.JoinFactory joinFactory,
 			RelFactories.ProjectFactory projectFactory) {
 		this(aggregateClass, joinClass,
-				RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false, false);
+				RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false);
 	}
 
 	@Deprecated // to be removed before 2.0
@@ -153,7 +143,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 			boolean allowFunctions) {
 		this(aggregateClass, joinClass,
 				RelBuilder.proto(aggregateFactory, joinFactory, projectFactory),
-				allowFunctions, false);
+				allowFunctions);
 	}
 
 	private boolean containsSnapshot(RelNode relNode) {
@@ -189,13 +179,6 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 		final RexBuilder rexBuilder = origAgg.getCluster().getRexBuilder();
 		final RelBuilder relBuilder = call.builder();
 
-		boolean isLeftOrRightOuterJoin =
-				join.getJoinType() == JoinRelType.LEFT || join.getJoinType() == JoinRelType.RIGHT;
-
-		if (join.getJoinType() != JoinRelType.INNER && !(allowLeftOrRightOuterJoin && isLeftOrRightOuterJoin)) {
-			return;
-		}
-
 		// converts an aggregate with AUXILIARY_GROUP to a regular aggregate.
 		// if the converted aggregate can be push down,
 		// AggregateReduceGroupingRule will try reduce grouping of new aggregates created by this rule
@@ -210,18 +193,15 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 					== null) {
 				return;
 			}
-			if (allowLeftOrRightOuterJoin && isLeftOrRightOuterJoin) {
-				// todo do not support max/min agg until we've built the proper model
-				if (aggregateCall.getAggregation().kind == SqlKind.MAX ||
-						aggregateCall.getAggregation().kind == SqlKind.MIN) {
-					return;
-				}
-			}
 			if (aggregateCall.filterArg >= 0 || aggregateCall.isDistinct()) {
 				return;
 			}
 		}
 
+		if (join.getJoinType() != JoinRelType.INNER) {
+			return;
+		}
+
 		if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
 			return;
 		}
@@ -229,19 +209,8 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 		// Do the columns used by the join appear in the output of the aggregate?
 		final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
 		final RelMetadataQuery mq = call.getMetadataQuery();
-		ImmutableBitSet keyColumns;
-		if (!isLeftOrRightOuterJoin) {
-			keyColumns = keyColumns(aggregateColumns,
-					mq.getPulledUpPredicates(join).pulledUpPredicates);
-		} else {
-			// this is an incomplete implementation
-			if (isAggregateKeyApplicable(aggregateColumns, join)) {
-				keyColumns = keyColumns(aggregateColumns,
-						com.google.common.collect.ImmutableList.copyOf(RelOptUtil.conjunctions(join.getCondition())));
-			} else {
-				keyColumns = aggregateColumns;
-			}
-		}
+		final ImmutableBitSet keyColumns = keyColumns(aggregateColumns,
+				mq.getPulledUpPredicates(join).pulledUpPredicates);
 		final ImmutableBitSet joinColumns =
 				RelOptUtil.InputFinder.bits(join.getCondition());
 		final boolean allColumnsInAggregate =
@@ -423,7 +392,7 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 					splitter.topSplit(rexBuilder, registry(projects),
 							groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e,
 							leftSubTotal == null ? -1 : leftSubTotal,
-							rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth, join.getJoinType()));
+							rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
 		}
 
 		relBuilder.project(projects);
@@ -546,14 +515,6 @@ public class FlinkAggregateJoinTransposeRule extends RelOptRule {
 		}
 	}
 
-	private static boolean isAggregateKeyApplicable(ImmutableBitSet aggregateKeys, Join join) {
-		JoinInfo joinInfo = join.analyzeCondition();
-		return (join.getJoinType() == JoinRelType.LEFT && joinInfo.leftSet().contains(aggregateKeys)) ||
-				(join.getJoinType() == JoinRelType.RIGHT &&
-						joinInfo.rightSet().shift(join.getInput(0).getRowType().getFieldCount())
-								.contains(aggregateKeys));
-	}
-
 	private static void populateEquivalence(Map<Integer, BitSet> equivalence,
 			int i0, int i1) {
 		BitSet bitSet = equivalence.get(i0);
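
With this revert, FlinkAggregateJoinTransposeRule again matches only inner
joins; the removed branches handled LEFT/RIGHT outer joins by wrapping the
null-producing side's partial count in a CASE (0 for COUNT(x), 1 for COUNT(*)).
The following is a minimal, hypothetical sketch of why that coercion was
needed, using plain Java rather than Calcite RexNodes:

    // Hypothetical sketch (not Flink/Calcite code): after a LEFT outer join, an
    // unmatched left row has no partial count from the right side. The removed
    // outer-join support mapped that missing value to 0 for COUNT(x) and to 1 for
    // COUNT(*), mirroring the CASE expression built by makeProperRexNodeForOuterJoin.
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class OuterJoinCountSketch {
        public static void main(String[] args) {
            List<Integer> leftKeys = List.of(1, 2, 3);  // e.g. SELECT DISTINCT a FROM T
            List<Integer> rightKeys = List.of(1, 1, 2); // e.g. T, joined on a

            // partial COUNT(*) per key on the right side
            Map<Integer, Long> rightCnt = new HashMap<>();
            rightKeys.forEach(k -> rightCnt.merge(k, 1L, Long::sum));

            long countBb = 0;   // COUNT(B.b) over A LEFT OUTER JOIN B
            long countStar = 0; // COUNT(*)   over A LEFT OUTER JOIN B
            for (int k : leftKeys) {
                Long partial = rightCnt.get(k); // null when the left row is unmatched
                countBb += (partial == null) ? 0L : partial;   // CASE ... THEN 0 ELSE partial
                countStar += (partial == null) ? 1L : partial; // CASE ... THEN 1 ELSE partial
            }

            System.out.println("COUNT(B.b) = " + countBb);   // 3: the NULL row is not counted
            System.out.println("COUNT(*)   = " + countStar); // 4: the unmatched row still counts
        }
    }

Keeping this correct for every aggregate and join shape is what the reverted
LEFT_RIGHT_OUTER_JOIN_EXTENDED variant attempted; after the revert the rule
simply returns early for non-inner joins (see the added JoinRelType.INNER
check above).
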
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
index 491fc1b..d095f76 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
@@ -249,7 +249,7 @@ object FlinkStreamRuleSets {
     // remove aggregation if it does not aggregate and input is already distinct
     FlinkAggregateRemoveRule.INSTANCE,
     // push aggregate through join
-    FlinkAggregateJoinTransposeRule.LEFT_RIGHT_OUTER_JOIN_EXTENDED,
+    FlinkAggregateJoinTransposeRule.EXTENDED,
     // using variants of aggregate union rule
     AggregateUnionAggregateRule.AGG_ON_FIRST_INPUT,
     AggregateUnionAggregateRule.AGG_ON_SECOND_INPUT,
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml
index f964110..475744b 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml
@@ -116,14 +116,13 @@ LogicalAggregate(group=[{0, 1, 2, 3}], EXPR$4=[COUNT($4)], EXPR$5=[AVG($5)])
     </Resource>
     <Resource name="planAfter">
       <![CDATA[
-Calc(select=[a2, b2, a3, b3, *($f2, $f20) AS EXPR$4, /(CAST(CASE(=($f4, 0), null:BIGINT, $f3)), $f4) AS EXPR$5])
-+- HashJoin(joinType=[InnerJoin], where=[=(b2, a3)], select=[a2, b2, $f2, a3, b3, $f20, $f3, $f4], isBroadcast=[true], build=[right])
-   :- Calc(select=[a2, b2, CAST(CASE(IS NOT NULL(c2), 1, 0)) AS $f2])
-   :  +- TableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2])
-   +- Exchange(distribution=[broadcast])
-      +- HashAggregate(isMerge=[true], groupBy=[a3, b3], select=[a3, b3, Final_COUNT(count1$0) AS $f2, Final_$SUM0(sum$1) AS $f3, Final_COUNT(count$2) AS $f4])
-         +- Exchange(distribution=[hash[a3, b3]])
-            +- LocalHashAggregate(groupBy=[a3, b3], select=[a3, b3, Partial_COUNT(*) AS count1$0, Partial_$SUM0(d3) AS sum$1, Partial_COUNT(d3) AS count$2])
+Calc(select=[a2, b2, a3, b3, EXPR$4, EXPR$5])
++- HashAggregate(isMerge=[true], groupBy=[a3, b3], auxGrouping=[a2, b2], select=[a3, b3, a2, b2, Final_COUNT(count$0) AS EXPR$4, Final_AVG(sum$1, count$2) AS EXPR$5])
+   +- Exchange(distribution=[hash[a3, b3]])
+      +- LocalHashAggregate(groupBy=[a3, b3], auxGrouping=[a2, b2], select=[a3, b3, a2, b2, Partial_COUNT(c2) AS count$0, Partial_AVG(d3) AS (sum$1, count$2)])
+         +- HashJoin(joinType=[InnerJoin], where=[=(b2, a3)], select=[a2, b2, c2, a3, b3, d3], isBroadcast=[true], build=[right])
+            :- TableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2])
+            +- Exchange(distribution=[broadcast])
                +- Calc(select=[a3, b3, d3])
                   +- TableSourceScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(a3, b3, c3, d3)]]], fields=[a3, b3, c3, d3])
 ]]>
@@ -149,7 +148,7 @@ LogicalAggregate(group=[{0, 1, 2, 3, 4, 5}], EXPR$6=[COUNT($6)])
       <![CDATA[
 Calc(select=[a1, b1, a2, b2, a3, b3, *($f4, $f2) AS EXPR$6])
 +- HashJoin(joinType=[InnerJoin], where=[=(a1, a3)], select=[a1, b1, a2, b2, $f4, a3, b3, $f2], isBroadcast=[true], build=[right])
-   :- Calc(select=[a1, b1, a2, b2, CAST(CASE(IS NOT NULL(c1), 1, 0)) AS $f4])
+   :- Calc(select=[a1, b1, a2, b2, CASE(IS NOT NULL(c1), 1:BIGINT, 0:BIGINT) AS $f4])
    :  +- HashJoin(joinType=[InnerJoin], where=[=(a1, b2)], select=[a1, b1, c1, a2, b2], build=[right])
    :     :- Exchange(distribution=[hash[a1]])
    :     :  +- Calc(select=[a1, b1, c1])
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
index d659cd9..1212910 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
@@ -107,11 +107,10 @@ LogicalAggregate(group=[{0, 1, 2, 3}], EXPR$4=[COUNT($4)], EXPR$5=[AVG($5)])
     </Resource>
     <Resource name="planAfter">
       <![CDATA[
-FlinkLogicalCalc(select=[a2, b2, a3, b3, *($f2, $f20) AS EXPR$4, /(CAST(CASE(=($f4, 0), null:BIGINT, $f3)), $f4) AS EXPR$5])
-+- FlinkLogicalJoin(condition=[=($1, $3)], joinType=[inner])
-   :- FlinkLogicalCalc(select=[a2, b2, CAST(CASE(IS NOT NULL(c2), 1, 0)) AS $f2])
-   :  +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2])
-   +- FlinkLogicalAggregate(group=[{0, 1}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[COUNT($2)])
+FlinkLogicalCalc(select=[a2, b2, a3, b3, EXPR$4, EXPR$5])
++- FlinkLogicalAggregate(group=[{3, 4}], a2=[AUXILIARY_GROUP($0)], b2=[AUXILIARY_GROUP($1)], EXPR$4=[COUNT($2)], EXPR$5=[AVG($5)])
+   +- FlinkLogicalJoin(condition=[=($1, $3)], joinType=[inner])
+      :- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2])
       +- FlinkLogicalCalc(select=[a3, b3, d3])
          +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(a3, b3, c3, d3)]]], fields=[a3, b3, c3, d3])
 ]]>
@@ -137,7 +136,7 @@ LogicalAggregate(group=[{0, 1, 2, 3, 4, 5}], EXPR$6=[COUNT($6)])
       <![CDATA[
 FlinkLogicalCalc(select=[a1, b1, a2, b2, a3, b3, *($f4, $f2) AS EXPR$6])
 +- FlinkLogicalJoin(condition=[=($0, $5)], joinType=[inner])
-   :- FlinkLogicalCalc(select=[a1, b1, a2, b2, CAST(CASE(IS NOT NULL(c1), 1, 0)) AS $f4])
+   :- FlinkLogicalCalc(select=[a1, b1, a2, b2, CASE(IS NOT NULL(c1), 1:BIGINT, 0:BIGINT) AS $f4])
    :  +- FlinkLogicalJoin(condition=[=($0, $4)], joinType=[inner])
    :     :- FlinkLogicalCalc(select=[a1, b1, c1])
    :     :  +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a1, b1, c1, d1)]]], fields=[a1, b1, c1, d1])
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml
similarity index 98%
rename from flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml
rename to flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml
index 475d60f..0d58e8b 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml
@@ -192,7 +192,7 @@ LogicalProject(EXPR$0=[$2], EXPR$1=[$3], a=[$0], b=[$1], EXPR$4=[$4])
 LogicalProject(EXPR$0=[$2], EXPR$1=[$3], a=[$0], b=[$1], EXPR$4=[$4])
 +- LogicalProject(a=[$3], b=[$4], a2=[$1], b2=[$0], $f6=[*($2, $5)])
    +- LogicalJoin(condition=[=($0, $3)], joinType=[inner])
-      :- LogicalProject(b2=[$1], a2=[$0], $f2=[CAST(CASE(IS NOT NULL($2), 1, 0)):BIGINT NOT NULL])
+      :- LogicalProject(b2=[$1], a2=[$0], $f2=[CASE(IS NOT NULL($2), 1:BIGINT, 0:BIGINT)])
       :  +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]])
       +- LogicalAggregate(group=[{0, 1}], agg#0=[COUNT()])
          +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c)]]])
@@ -219,7 +219,7 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)])
 LogicalAggregate(group=[{}], EXPR$0=[$SUM0($4)])
 +- LogicalProject(a=[$0], $f1=[$1], a0=[$2], $f10=[$3], $f4=[*($1, $3)])
    +- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
-      :- LogicalProject(a=[$0], $f1=[CAST(CASE(IS NOT NULL($0), 1, 0)):BIGINT NOT NULL])
+      :- LogicalProject(a=[$0], $f1=[CASE(IS NOT NULL($0), 1:BIGINT, 0:BIGINT)])
       :  +- LogicalAggregate(group=[{0}])
       :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c)]]])
       +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml
deleted file mode 100644
index 4df5360..0000000
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml
+++ /dev/null
@@ -1,267 +0,0 @@
-<?xml version="1.0" ?>
-<!--
-Licensed to the Apache Software Foundation (ASF) under one or more
-contributor license agreements.  See the NOTICE file distributed with
-this work for additional information regarding copyright ownership.
-The ASF licenses this file to you under the Apache License, Version 2.0
-(the "License"); you may not use this file except in compliance with
-the License.  You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
--->
-<Root>
-  <TestCase name="testPushCountAggThroughJoinOverUniqueColumn">
-    <Resource name="sql">
-      <![CDATA[SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a]]>
-    </Resource>
-    <Resource name="planBefore">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)])
-+- LogicalProject(a=[$0])
-   +- LogicalJoin(condition=[=($0, $1)], joinType=[inner])
-      :- LogicalAggregate(group=[{0}])
-      :  +- LogicalProject(a=[$0])
-      :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-    <Resource name="planAfter">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[$SUM0($4)])
-+- LogicalProject(a=[$0], $f1=[$1], a0=[$2], $f10=[$3], $f4=[*($1, $3)])
-   +- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
-      :- LogicalProject(a=[$0], $f1=[CAST(CASE(IS NOT NULL($0), 1, 0)):BIGINT NOT NULL])
-      :  +- LogicalAggregate(group=[{0}])
-      :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-  </TestCase>
-  <TestCase name="testPushCountAggThroughLeftJoinAndGroupByLeft">
-    <Resource name="sql">
-      <![CDATA[SELECT COUNT(B.b) FROM (SELECT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a]]>
-    </Resource>
-    <Resource name="planBefore">
-      <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
-   +- LogicalProject(a=[$0], b=[$2])
-      +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-         :- LogicalProject(a=[$0])
-         :  +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-    <Resource name="planAfter">
-      <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalProject(a=[$0], $f4=[*($1, CASE(IS NULL($3), 0, CAST($3):BIGINT NOT NULL))])
-   +- LogicalJoin(condition=[=($0, $2)], joinType=[left])
-      :- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
-      :  +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-  </TestCase>
-  <TestCase name="testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByLeft">
-    <Resource name="sql">
-      <![CDATA[SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a]]>
-    </Resource>
-    <Resource name="planBefore">
-      <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
-   +- LogicalProject(a=[$0], b=[$2])
-      +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-         :- LogicalAggregate(group=[{0}])
-         :  +- LogicalProject(a=[$0])
-         :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-    <Resource name="planAfter">
-      <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalProject(a=[$0], $f3=[CASE(IS NULL($2), 0, CAST($2):BIGINT NOT NULL)])
-   +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-      :- LogicalAggregate(group=[{0}])
-      :  +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-  </TestCase>
-  <TestCase name="testPushCountAggThroughLeftJoinOverUniqueColumn">
-    <Resource name="sql">
-      <![CDATA[SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a]]>
-    </Resource>
-    <Resource name="planBefore">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)])
-+- LogicalProject(a=[$0])
-   +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-      :- LogicalAggregate(group=[{0}])
-      :  +- LogicalProject(a=[$0])
-      :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-    <Resource name="planAfter">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[$SUM0($4)])
-+- LogicalProject(a=[$0], $f1=[$1], a0=[$2], $f10=[$3], $f4=[*($1, CASE(IS NULL($3), 0, CAST($3):BIGINT NOT NULL))])
-   +- LogicalJoin(condition=[=($0, $2)], joinType=[left])
-      :- LogicalProject(a=[$0], $f1=[CAST(CASE(IS NOT NULL($0), 1, 0)):BIGINT NOT NULL])
-      :  +- LogicalAggregate(group=[{0}])
-      :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-  </TestCase>
-  <TestCase name="testPushSumAggThroughLeftJoinOverUniqueColumn">
-    <Resource name="sql">
-      <![CDATA[SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a]]>
-    </Resource>
-    <Resource name="planBefore">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])
-+- LogicalProject(a=[$0])
-   +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-      :- LogicalAggregate(group=[{0}])
-      :  +- LogicalProject(a=[$0])
-      :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-    <Resource name="planAfter">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($3)])
-+- LogicalProject(a=[$0], a0=[$1], $f1=[$2], $f3=[CAST(*($0, $2)):INTEGER])
-   +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-      :- LogicalAggregate(group=[{0}])
-      :  +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-  </TestCase>
-  <TestCase name="testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByRight">
-    <Resource name="sql">
-      <![CDATA[SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY B.a]]>
-    </Resource>
-    <Resource name="planBefore">
-      <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
-   +- LogicalProject(a0=[$1], b=[$2])
-      +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-         :- LogicalAggregate(group=[{0}])
-         :  +- LogicalProject(a=[$0])
-         :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-    <Resource name="planAfter">
-      <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{1}], EXPR$0=[$SUM0($3)])
-   +- LogicalProject(a=[$0], a0=[$1], EXPR$0=[$2], $f3=[CASE(IS NULL($2), 0, CAST($2):BIGINT NOT NULL)])
-      +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-         :- LogicalAggregate(group=[{0}])
-         :  +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-         +- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
-            +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-  </TestCase>
-  <TestCase name="testPushCountAggThroughRightJoin">
-    <Resource name="sql">
-      <![CDATA[SELECT COUNT(B.b) FROM T AS B RIGHT OUTER JOIN (SELECT a FROM T) AS A ON A.a=B.a GROUP BY A.a]]>
-    </Resource>
-    <Resource name="planBefore">
-      <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
-   +- LogicalProject(a0=[$4], b=[$1])
-      +- LogicalJoin(condition=[=($4, $0)], joinType=[right])
-         :- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-         +- LogicalProject(a=[$0])
-            +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-    <Resource name="planAfter">
-      <![CDATA[
-LogicalProject(EXPR$0=[$1])
-+- LogicalProject(a0=[$2], $f4=[*(CASE(IS NULL($1), 0, CAST($1):BIGINT NOT NULL), $3)])
-   +- LogicalJoin(condition=[=($2, $0)], joinType=[right])
-      :- LogicalAggregate(group=[{0}], EXPR$0=[COUNT($1)])
-      :  +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-  </TestCase>
-  <TestCase name="testPushSumAggThroughJoinOverUniqueColumn">
-    <Resource name="sql">
-      <![CDATA[SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a]]>
-    </Resource>
-    <Resource name="planBefore">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])
-+- LogicalProject(a=[$0])
-   +- LogicalJoin(condition=[=($0, $1)], joinType=[inner])
-      :- LogicalAggregate(group=[{0}])
-      :  +- LogicalProject(a=[$0])
-      :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-    <Resource name="planAfter">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($3)])
-+- LogicalProject(a=[$0], a0=[$1], $f1=[$2], $f3=[CAST(*($0, $2)):INTEGER])
-   +- LogicalJoin(condition=[=($0, $1)], joinType=[inner])
-      :- LogicalAggregate(group=[{0}])
-      :  +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalAggregate(group=[{0}], agg#0=[COUNT()])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-  </TestCase>
-  <TestCase name="testPushCountAllAggThroughLeftJoinOverUniqueColumn">
-    <Resource name="sql">
-      <![CDATA[SELECT COUNT(*) FROM (SELECT DISTINCT a FROM T) AS A LEFT OUTER JOIN T AS B ON A.a=B.a]]>
-    </Resource>
-    <Resource name="planBefore">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
-+- LogicalProject($f0=[0])
-   +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-      :- LogicalAggregate(group=[{0}])
-      :  +- LogicalProject(a=[$0])
-      :     +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-    <Resource name="planAfter">
-      <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[$SUM0($3)])
-+- LogicalProject(a=[$0], a0=[$1], EXPR$0=[$2], $f3=[CASE(IS NULL($2), 1, CAST($2):BIGINT NOT NULL)])
-   +- LogicalJoin(condition=[=($0, $1)], joinType=[left])
-      :- LogicalAggregate(group=[{0}])
-      :  +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-      +- LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
-         +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c, d)]]])
-]]>
-    </Resource>
-  </TestCase>
-</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.scala
similarity index 97%
rename from flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala
rename to flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.scala
index 316ff13..41218e8 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.scala
@@ -34,9 +34,8 @@ import org.junit.{Before, Test}
 
 /**
   * Test for [[FlinkAggregateJoinTransposeRule]].
-  * This class only tests inner joins.
   */
-class FlinkAggregateInnerJoinTransposeRuleTest extends TableTestBase {
+class FlinkAggregateJoinTransposeRuleTest extends TableTestBase {
   private val util = batchTestUtil()
 
   @Before
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala
deleted file mode 100644
index 186514d..0000000
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.rules.logical
-
-import org.apache.flink.api.scala._
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.plan.optimize.program.{FlinkChainedProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE, StreamOptimizeContext}
-import org.apache.flink.table.util.TableTestBase
-
-import org.apache.calcite.plan.hep.HepMatchOrder
-import org.apache.calcite.rel.rules._
-import org.apache.calcite.tools.RuleSets
-import org.junit.{Before, Test}
-
-/**
-  * Test for [[FlinkAggregateJoinTransposeRule]].
-  * This class only tests left/right outer joins.
-  */
-class FlinkAggregateOuterJoinTransposeRuleTest extends TableTestBase {
-
-  private val util = streamTestUtil()
-
-  @Before
-  def setup(): Unit = {
-    val program = new FlinkChainedProgram[StreamOptimizeContext]()
-    program.addLast(
-      "rules",
-      FlinkHepRuleSetProgramBuilder.newBuilder
-        .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION)
-        .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
-        .add(RuleSets.ofList(
-          FlinkFilterJoinRule.FILTER_ON_JOIN,
-          FlinkFilterJoinRule.JOIN,
-          FilterAggregateTransposeRule.INSTANCE,
-          FilterProjectTransposeRule.INSTANCE,
-          FilterMergeRule.INSTANCE,
-          AggregateProjectMergeRule.INSTANCE,
-          FlinkAggregateJoinTransposeRule.LEFT_RIGHT_OUTER_JOIN_EXTENDED
-        ))
-        .build()
-    )
-    util.replaceStreamProgram(program)
-
-    util.addTableSource[(Int, Long, String, Int)]("T", 'a, 'b, 'c, 'd)
-  }
-
-  @Test
-  def testPushCountAggThroughJoinOverUniqueColumn(): Unit = {
-    util.verifyPlan("SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a")
-  }
-
-  @Test
-  def testPushSumAggThroughJoinOverUniqueColumn(): Unit = {
-    util.verifyPlan("SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a")
-  }
-
-  @Test
-  def testPushCountAggThroughLeftJoinOverUniqueColumn(): Unit = {
-    val sqlQuery = "SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A " +
-      "LEFT OUTER JOIN T AS B ON A.a=B.a"
-    util.verifyPlan(sqlQuery)
-  }
-
-  @Test
-  def testPushSumAggThroughLeftJoinOverUniqueColumn(): Unit = {
-    val sqlQuery = "SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A " +
-      "LEFT OUTER JOIN T AS B ON A.a=B.a"
-    util.verifyPlan(sqlQuery)
-  }
-
-  @Test
-  def testPushCountAllAggThroughLeftJoinOverUniqueColumn(): Unit = {
-    val sqlQuery = "SELECT COUNT(*) FROM (SELECT DISTINCT a FROM T) AS A " +
-      "LEFT OUTER JOIN T AS B ON A.a=B.a"
-    util.verifyPlan(sqlQuery)
-  }
-
-  @Test
-  def testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByLeft(): Unit = {
-    val sqlQuery = "SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A " +
-      "LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a"
-    util.verifyPlan(sqlQuery)
-  }
-
-  @Test
-  def testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByRight(): Unit = {
-    val sqlQuery = "SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A " +
-      "LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY B.a"
-    util.verifyPlan(sqlQuery)
-  }
-
-  @Test
-  def testPushCountAggThroughLeftJoinAndGroupByLeft(): Unit = {
-    val sqlQuery = "SELECT COUNT(B.b) FROM (SELECT a FROM T) AS A " +
-      "LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a"
-    util.verifyPlan(sqlQuery)
-  }
-
-  @Test
-  def testPushCountAggThroughRightJoin(): Unit = {
-    val sqlQuery = "SELECT COUNT(B.b) FROM T AS B RIGHT OUTER JOIN (SELECT a FROM T) AS A " +
-      "ON A.a=B.a GROUP BY A.a"
-    util.verifyPlan(sqlQuery)
-  }
-
-}