You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/01/17 13:44:59 UTC

[3/3] flink git commit: [FLINK-5144] [table] Fix error while applying rule AggregateJoinTransposeRule

[FLINK-5144] [table] Fix error while applying rule AggregateJoinTransposeRule

This closes #3062.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e187b5ee
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e187b5ee
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e187b5ee

Branch: refs/heads/master
Commit: e187b5ee9aa6d1bf9feec151ff460d1a28c4e5f0
Parents: dba7d7d
Author: Kurt Young <yk...@gmail.com>
Authored: Thu Jan 5 11:32:04 2017 +0800
Committer: twalthr <tw...@apache.org>
Committed: Tue Jan 17 14:44:27 2017 +0100

----------------------------------------------------------------------
 .../rules/FlinkAggregateJoinTransposeRule.java  |  346 +++
 .../calcite/sql2rel/FlinkRelDecorrelator.java   | 2216 ++++++++++++++++++
 .../flink/table/calcite/FlinkPlannerImpl.scala  |    7 +-
 .../flink/table/plan/rules/FlinkRuleSets.scala  |    5 +-
 .../batch/sql/QueryDecorrelationTest.scala      |  218 ++
 5 files changed, 2787 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/e187b5ee/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java
new file mode 100644
index 0000000..ac36b3c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.calcite.rules;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.apache.calcite.linq4j.Ord;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.logical.LogicalJoin;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlSplittableAggFunction;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.util.Bug;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.Mapping;
+import org.apache.calcite.util.mapping.Mappings;
+
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Copied from {@link org.apache.calcite.rel.rules.AggregateJoinTransposeRule}; should be
+ * removed once <a href="https://issues.apache.org/jira/browse/CALCITE-1544">[CALCITE-1544]</a> is fixed.
+ */
+public class FlinkAggregateJoinTransposeRule extends RelOptRule {
+	public static final FlinkAggregateJoinTransposeRule INSTANCE = new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER, false);
+
+	/**
+	 * Extended instance of the rule that can push down aggregate functions.
+	 */
+	public static final FlinkAggregateJoinTransposeRule EXTENDED = new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER, true);
+
+	private final boolean allowFunctions;
+
+	/**
+	 * Creates a FlinkAggregateJoinTransposeRule.
+	 */
+	public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory, boolean allowFunctions) {
+		super(operand(aggregateClass, null, Aggregate.IS_SIMPLE, operand(joinClass, any())), relBuilderFactory, null);
+		this.allowFunctions = allowFunctions;
+	}
+
+	@Deprecated // to be removed before 2.0
+	public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory) {
+		this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), false);
+	}
+
+	@Deprecated // to be removed before 2.0
+	public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, boolean allowFunctions) {
+		this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions);
+	}
+
+	@Deprecated // to be removed before 2.0
+	public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory) {
+		this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false);
+	}
+
+	@Deprecated // to be removed before 2.0
+	public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory, boolean allowFunctions) {
+		this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), allowFunctions);
+	}
+
+	public void onMatch(RelOptRuleCall call) {
+		final Aggregate aggregate = call.rel(0);
+		final Join join = call.rel(1);
+		final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
+		final RelBuilder relBuilder = call.builder();
+
+		// If any aggregate functions do not support splitting, bail out
+		// If any aggregate call has a filter, bail out
+		for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+			if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
+				return;
+			}
+			if (aggregateCall.filterArg >= 0) {
+				return;
+			}
+		}
+
+		// If it is not an inner join, we do not push the
+		// aggregate operator
+		if (join.getJoinType() != JoinRelType.INNER) {
+			return;
+		}
+
+		if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
+			return;
+		}
+
+		// Do the columns used by the join appear in the output of the aggregate?
+		final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
+		final RelMetadataQuery mq = RelMetadataQuery.instance();
+		final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
+		final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
+		final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
+		final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
+
+		// Split join condition
+		final List<Integer> leftKeys = Lists.newArrayList();
+		final List<Integer> rightKeys = Lists.newArrayList();
+		final List<Boolean> filterNulls = Lists.newArrayList();
+		RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
+		// If it contains non-equi join conditions, we bail out
+		if (!nonEquiConj.isAlwaysTrue()) {
+			return;
+		}
+
+		// Push each aggregate function down to each side that contains all of its
+		// arguments. Note that COUNT(*), because it has no arguments, can go to
+		// both sides.
+		final Map<Integer, Integer> map = new HashMap<>();
+		final List<Side> sides = new ArrayList<>();
+		int uniqueCount = 0;
+		int offset = 0;
+		int belowOffset = 0;
+		for (int s = 0; s < 2; s++) {
+			final Side side = new Side();
+			final RelNode joinInput = join.getInput(s);
+			int fieldCount = joinInput.getRowType().getFieldCount();
+			final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
+			final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
+			for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
+				map.put(c.e, belowOffset + c.i);
+			}
+			final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
+			final boolean unique;
+			if (!allowFunctions) {
+				assert aggregate.getAggCallList().isEmpty();
+				// If there are no functions, it doesn't matter as much whether we
+				// aggregate the inputs before the join, because there will not be
+				// any functions experiencing a cartesian product effect.
+				//
+				// But finding out whether the input is already unique requires a call
+				// to areColumnsUnique that currently (until [CALCITE-1048] "Make
+				// metadata more robust" is fixed) places a heavy load on
+				// the metadata system.
+				//
+				// So we choose to imagine that the input is already unique, which is
+				// untrue but harmless.
+				//
+				Util.discard(Bug.CALCITE_1048_FIXED);
+				unique = true;
+			} else {
+				final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
+				unique = unique0 != null && unique0;
+			}
+			if (unique) {
+				++uniqueCount;
+				side.aggregate = false;
+				side.newInput = joinInput;
+			} else {
+				side.aggregate = true;
+				List<AggregateCall> belowAggCalls = new ArrayList<>();
+				final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
+				final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
+				for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
+					final SqlAggFunction aggregation = aggCall.e.getAggregation();
+					final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
+					final AggregateCall call1;
+					if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
+						call1 = splitter.split(aggCall.e, mapping);
+					} else {
+						call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
+					}
+					if (call1 != null) {
+						side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
+					}
+				}
+				side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls).build();
+			}
+			offset += fieldCount;
+			belowOffset += side.newInput.getRowType().getFieldCount();
+			sides.add(side);
+		}
+
+		if (uniqueCount == 2) {
+			// Both inputs to the join are unique. There is nothing to be gained by
+			// this rule. In fact, this aggregate+join may be the result of a previous
+			// invocation of this rule; if we continue we might loop forever.
+			return;
+		}
+
+		// Update condition
+		final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {
+			public Integer apply(Integer a0) {
+				return map.get(a0);
+			}
+		}, join.getRowType().getFieldCount(), belowOffset);
+		final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
+
+		// Create new join
+		relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
+
+		// Aggregate above to sum up the sub-totals
+		final List<AggregateCall> newAggCalls = new ArrayList<>();
+		final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
+		final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
+		final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
+		for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
+			final SqlAggFunction aggregation = aggCall.e.getAggregation();
+			final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
+			final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
+			final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
+			newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
+		}
+
+		relBuilder.project(projects);
+
+		boolean aggConvertedToProjects = false;
+		if (allColumnsInAggregate) {
+			// let's see if we can convert aggregate into projects
+			List<RexNode> projects2 = new ArrayList<>();
+			for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
+				projects2.add(relBuilder.field(key));
+			}
+			for (AggregateCall newAggCall : newAggCalls) {
+				final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
+				if (splitter != null) {
+					projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
+				}
+			}
+			if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
+				// We successfully converted agg calls into projects.
+				relBuilder.project(projects2);
+				aggConvertedToProjects = true;
+			}
+		}
+
+		if (!aggConvertedToProjects) {
+			relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
+		}
+
+		call.transformTo(relBuilder.build());
+	}
+
+	/**
+	 * Expands a set of columns using a list of equality constraints. Each
+	 * 'x = y' constraint causes bit y to be set if bit x is set, and vice versa.
+	 * Only direct equivalences are applied (one step; no transitive closure).
+	 */
+	private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, ImmutableList<RexNode> predicates) {
+		SortedMap<Integer, BitSet> equivalence = new TreeMap<>();
+		for (RexNode pred : predicates) {
+			populateEquivalences(equivalence, pred);
+		}
+		ImmutableBitSet keyColumns = aggregateColumns;
+		for (Integer aggregateColumn : aggregateColumns) {
+			final BitSet bitSet = equivalence.get(aggregateColumn);
+			if (bitSet != null) {
+				keyColumns = keyColumns.union(bitSet);
+			}
+		}
+		return keyColumns;
+	}
+
+	private static void populateEquivalences(Map<Integer, BitSet> equivalence, RexNode predicate) {
+		switch (predicate.getKind()) {
+			case EQUALS:
+				RexCall call = (RexCall) predicate;
+				final List<RexNode> operands = call.getOperands();
+				if (operands.get(0) instanceof RexInputRef) {
+					final RexInputRef ref0 = (RexInputRef) operands.get(0);
+					if (operands.get(1) instanceof RexInputRef) {
+						final RexInputRef ref1 = (RexInputRef) operands.get(1);
+						populateEquivalence(equivalence, ref0.getIndex(), ref1.getIndex());
+						populateEquivalence(equivalence, ref1.getIndex(), ref0.getIndex());
+					}
+				}
+		}
+	}
+
+	private static void populateEquivalence(Map<Integer, BitSet> equivalence, int i0, int i1) {
+		BitSet bitSet = equivalence.get(i0);
+		if (bitSet == null) {
+			bitSet = new BitSet();
+			equivalence.put(i0, bitSet);
+		}
+		bitSet.set(i1);
+	}
+
+	/**
+	 * Creates a {@link SqlSplittableAggFunction.Registry}
+	 * that is a view of a list.
+	 */
+	private static <E> SqlSplittableAggFunction.Registry<E> registry(final List<E> list) {
+		return new SqlSplittableAggFunction.Registry<E>() {
+			public int register(E e) {
+				int i = list.indexOf(e);
+				if (i < 0) {
+					i = list.size();
+					list.add(e);
+				}
+				return i;
+			}
+		};
+	}
+
+	/**
+	 * Work space for an input to a join.
+	 */
+	private static class Side {
+		final Map<Integer, Integer> split = new HashMap<>();
+		RelNode newInput;
+		boolean aggregate;
+	}
+}
+
+// End FlinkAggregateJoinTransposeRule.java