You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2021/05/28 18:50:50 UTC
[calcite] 03/03: [CALCITE-4497] In RelBuilder,
support windowed aggregate functions (OVER)
This is an automated email from the ASF dual-hosted git repository.
jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git
commit 241b6db49c3ca88c9cdf25ddd2c0c8457734a130
Author: Julian Hyde <jh...@apache.org>
AuthorDate: Thu Feb 25 16:55:58 2021 -0800
[CALCITE-4497] In RelBuilder, support windowed aggregate functions (OVER)
Add method RelBuilder.AggCall.over(), and class RelBuilder.OverCall.
---
.../java/org/apache/calcite/sql/SqlWindow.java | 5 +
.../apache/calcite/sql2rel/SqlToRelConverter.java | 210 +++++++------
.../java/org/apache/calcite/tools/RelBuilder.java | 342 ++++++++++++++++++++-
.../calcite/rel/rel2sql/RelToSqlConverterTest.java | 21 +-
.../org/apache/calcite/test/RelBuilderTest.java | 76 +++--
.../apache/calcite/test/SqlToRelConverterTest.xml | 5 +-
.../org/apache/calcite/piglet/PigRelOpVisitor.java | 37 +--
site/_docs/algebra.md | 27 ++
8 files changed, 554 insertions(+), 169 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlWindow.java b/core/src/main/java/org/apache/calcite/sql/SqlWindow.java
index b41e89a..8311d62 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlWindow.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlWindow.java
@@ -258,6 +258,11 @@ public class SqlWindow extends SqlCall {
} else {
return false;
}
+ return isAlwaysNonEmpty(lower, upper);
+ }
+
+ public static boolean isAlwaysNonEmpty(RexWindowBound lower,
+ RexWindowBound upper) {
final int lowerKey = lower.getOrderKey();
final int upperKey = upper.getOrderKey();
return lowerKey > -1 && lowerKey <= upperKey;
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
index a1d14de..0c0eacc 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
@@ -73,7 +73,6 @@ import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
-import org.apache.calcite.rex.RexCallBinding;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexDynamicParam;
import org.apache.calcite.rex.RexFieldAccess;
@@ -2052,13 +2051,14 @@ public class SqlToRelConverter {
}
}
- final ImmutableList.Builder<RexFieldCollation> orderKeys =
+ final ImmutableList.Builder<RexNode> orderKeys =
ImmutableList.builder();
for (SqlNode order : orderList) {
orderKeys.add(
bb.convertSortExpression(order,
RelFieldCollation.Direction.ASCENDING,
- RelFieldCollation.NullDirection.UNSPECIFIED));
+ RelFieldCollation.NullDirection.UNSPECIFIED,
+ bb::sortToRex));
}
try {
@@ -2079,14 +2079,10 @@ public class SqlToRelConverter {
&& q.getValue() == SqlSelectKeyword.DISTINCT;
final RexShuttle visitor =
- new HistogramShuttle(
- partitionKeys.build(), orderKeys.build(),
+ new HistogramShuttle(partitionKeys.build(), orderKeys.build(), rows,
RexWindowBounds.create(sqlLowerBound, lowerBound),
RexWindowBounds.create(sqlUpperBound, upperBound),
- rows,
- window.isAllowPartial(),
- isDistinct,
- ignoreNulls);
+ window.isAllowPartial(), isDistinct, ignoreNulls);
return rexAgg.accept(visitor);
} finally {
bb.window = null;
@@ -4481,6 +4477,18 @@ public class SqlToRelConverter {
}
}
+ /** Function that can convert a sort specification (expression, direction
+ * and null direction) to a target format.
+ *
+ * @param <R> Target format, such as {@link RexFieldCollation} or
+ * {@link RexNode}
+ */
+ @FunctionalInterface
+ interface SortExpressionConverter<R> {
+ R convert(SqlNode node, RelFieldCollation.Direction direction,
+ RelFieldCollation.NullDirection nullDirection);
+ }
+
/**
* Workspace for translating an individual SELECT statement (or sub-SELECT).
*/
@@ -5097,50 +5105,78 @@ public class SqlToRelConverter {
public RexFieldCollation convertSortExpression(SqlNode expr,
RelFieldCollation.Direction direction,
RelFieldCollation.NullDirection nullDirection) {
+ return convertSortExpression(expr, direction, nullDirection,
+ this::sortToRexFieldCollation);
+ }
+
+ /** Handles an item in an ORDER BY clause, using a converter
+ * function to produce the final result. */
+ <R> R convertSortExpression(SqlNode expr,
+ RelFieldCollation.Direction direction,
+ RelFieldCollation.NullDirection nullDirection,
+ SortExpressionConverter<R> converter) {
switch (expr.getKind()) {
case DESCENDING:
return convertSortExpression(((SqlCall) expr).operand(0),
- RelFieldCollation.Direction.DESCENDING, nullDirection);
+ RelFieldCollation.Direction.DESCENDING, nullDirection, converter);
case NULLS_LAST:
return convertSortExpression(((SqlCall) expr).operand(0),
- direction, RelFieldCollation.NullDirection.LAST);
+ direction, RelFieldCollation.NullDirection.LAST, converter);
case NULLS_FIRST:
return convertSortExpression(((SqlCall) expr).operand(0),
- direction, RelFieldCollation.NullDirection.FIRST);
+ direction, RelFieldCollation.NullDirection.FIRST, converter);
default:
- final Set<SqlKind> flags = EnumSet.noneOf(SqlKind.class);
- switch (direction) {
- case DESCENDING:
- flags.add(SqlKind.DESCENDING);
- break;
- default:
- break;
- }
- switch (nullDirection) {
- case UNSPECIFIED:
- final RelFieldCollation.NullDirection nullDefaultDirection =
+ return converter.convert(expr, direction, nullDirection);
+ }
+ }
+
+ private RexFieldCollation sortToRexFieldCollation(SqlNode expr,
+ RelFieldCollation.Direction direction,
+ RelFieldCollation.NullDirection nullDirection) {
+ final Set<SqlKind> flags = EnumSet.noneOf(SqlKind.class);
+ if (direction == RelFieldCollation.Direction.DESCENDING) {
+ flags.add(SqlKind.DESCENDING);
+ }
+ switch (nullDirection) {
+ case UNSPECIFIED:
+ final RelFieldCollation.NullDirection nullDefaultDirection =
+ validator().config().defaultNullCollation().last(desc(direction))
+ ? RelFieldCollation.NullDirection.LAST
+ : RelFieldCollation.NullDirection.FIRST;
+ if (nullDefaultDirection != direction.defaultNullDirection()) {
+ SqlKind nullDirectionSqlKind =
validator().config().defaultNullCollation().last(desc(direction))
- ? RelFieldCollation.NullDirection.LAST
- : RelFieldCollation.NullDirection.FIRST;
- if (nullDefaultDirection != direction.defaultNullDirection()) {
- SqlKind nullDirectionSqlKind =
- validator().config().defaultNullCollation().last(desc(direction))
- ? SqlKind.NULLS_LAST
- : SqlKind.NULLS_FIRST;
- flags.add(nullDirectionSqlKind);
- }
- break;
- case FIRST:
- flags.add(SqlKind.NULLS_FIRST);
- break;
- case LAST:
- flags.add(SqlKind.NULLS_LAST);
- break;
- default:
- break;
+ ? SqlKind.NULLS_LAST
+ : SqlKind.NULLS_FIRST;
+ flags.add(nullDirectionSqlKind);
}
- return new RexFieldCollation(convertExpression(expr), flags);
+ break;
+ case FIRST:
+ flags.add(SqlKind.NULLS_FIRST);
+ break;
+ case LAST:
+ flags.add(SqlKind.NULLS_LAST);
+ break;
+ default:
+ break;
}
+ return new RexFieldCollation(convertExpression(expr), flags);
+ }
+
+ private RexNode sortToRex(SqlNode expr,
+ RelFieldCollation.Direction direction,
+ RelFieldCollation.NullDirection nullDirection) {
+ RexNode node = convertExpression(expr);
+ if (direction == RelFieldCollation.Direction.DESCENDING) {
+ node = relBuilder.desc(node);
+ }
+ if (nullDirection == RelFieldCollation.NullDirection.FIRST) {
+ node = relBuilder.nullsFirst(node);
+ }
+ if (nullDirection == RelFieldCollation.NullDirection.LAST) {
+ node = relBuilder.nullsLast(node);
+ }
+ return node;
}
/**
@@ -5727,12 +5763,8 @@ public class SqlToRelConverter {
.map(order ->
bb.convertSortExpression(order,
RelFieldCollation.Direction.ASCENDING,
- RelFieldCollation.NullDirection.UNSPECIFIED))
- .map(fieldCollation ->
- new RelFieldCollation(
- lookupOrCreateGroupExpr(fieldCollation.left),
- fieldCollation.getDirection(),
- fieldCollation.getNullDirection()))
+ RelFieldCollation.NullDirection.UNSPECIFIED,
+ this::sortToFieldCollation))
.collect(Collectors.toList()));
}
final AggregateCall aggCall =
@@ -5757,6 +5789,17 @@ public class SqlToRelConverter {
aggMapping.put(outerCall, rex);
}
+ private RelFieldCollation sortToFieldCollation(SqlNode expr,
+ RelFieldCollation.Direction direction,
+ RelFieldCollation.NullDirection nullDirection) {
+ final RexNode node = bb.convertExpression(expr);
+ final int fieldIndex = lookupOrCreateGroupExpr(node);
+ if (nullDirection == RelFieldCollation.NullDirection.UNSPECIFIED) {
+ nullDirection = direction.defaultNullDirection();
+ }
+ return new RelFieldCollation(fieldIndex, direction, nullDirection);
+ }
+
private int lookupOrCreateGroupExpr(RexNode expr) {
int index = 0;
for (RexNode convertedInputExpr : Pair.left(convertedInputExprs)) {
@@ -5886,8 +5929,8 @@ public class SqlToRelConverter {
*/
static final boolean ENABLE_HISTOGRAM_AGG = false;
- private final List<RexNode> partitionKeys;
- private final ImmutableList<RexFieldCollation> orderKeys;
+ private final ImmutableList<RexNode> partitionKeys;
+ private final ImmutableList<RexNode> orderKeys;
private final RexWindowBound lowerBound;
private final RexWindowBound upperBound;
private final boolean rows;
@@ -5895,14 +5938,10 @@ public class SqlToRelConverter {
private final boolean distinct;
private final boolean ignoreNulls;
- HistogramShuttle(
- List<RexNode> partitionKeys,
- ImmutableList<RexFieldCollation> orderKeys,
+ HistogramShuttle(ImmutableList<RexNode> partitionKeys,
+ ImmutableList<RexNode> orderKeys, boolean rows,
RexWindowBound lowerBound, RexWindowBound upperBound,
- boolean rows,
- boolean allowPartial,
- boolean distinct,
- boolean ignoreNulls) {
+ boolean allowPartial, boolean distinct, boolean ignoreNulls) {
this.partitionKeys = partitionKeys;
this.orderKeys = orderKeys;
this.lowerBound = lowerBound;
@@ -5947,28 +5986,18 @@ public class SqlToRelConverter {
: rexBuilder.makeCast(histogramType, exprs.get(0)));
}
- RexCallBinding bind =
- new RexCallBinding(
- rexBuilder.getTypeFactory(),
- SqlStdOperatorTable.HISTOGRAM_AGG,
- exprs,
- ImmutableList.of());
-
RexNode over =
- rexBuilder.makeOver(
- SqlStdOperatorTable.HISTOGRAM_AGG
- .inferReturnType(bind),
- SqlStdOperatorTable.HISTOGRAM_AGG,
- exprs,
- partitionKeys,
- orderKeys,
- lowerBound,
- upperBound,
- rows,
- allowPartial,
- false,
- distinct,
- ignoreNulls);
+ relBuilder.aggregateCall(SqlStdOperatorTable.HISTOGRAM_AGG, exprs)
+ .distinct(distinct)
+ .ignoreNulls(ignoreNulls)
+ .over()
+ .partitionBy(partitionKeys)
+ .orderBy(orderKeys)
+ .let(c ->
+ rows ? c.rowsBetween(lowerBound, upperBound)
+ : c.rangeBetween(lowerBound, upperBound))
+ .allowPartial(allowPartial)
+ .toRex();
RexNode histogramCall =
rexBuilder.makeCall(
@@ -5998,19 +6027,18 @@ public class SqlToRelConverter {
SqlAggFunction aggOpToUse =
needSum0 ? SqlStdOperatorTable.SUM0
: aggOp;
- return rexBuilder.makeOver(
- type,
- aggOpToUse,
- exprs,
- partitionKeys,
- orderKeys,
- lowerBound,
- upperBound,
- rows,
- allowPartial,
- needSum0,
- distinct,
- ignoreNulls);
+ return relBuilder.aggregateCall(aggOpToUse, exprs)
+ .distinct(distinct)
+ .ignoreNulls(ignoreNulls)
+ .over()
+ .partitionBy(partitionKeys)
+ .orderBy(orderKeys)
+ .let(c ->
+ rows ? c.rowsBetween(lowerBound, upperBound)
+ : c.rangeBetween(lowerBound, upperBound))
+ .allowPartial(allowPartial)
+ .nullWhenCountZero(needSum0)
+ .toRex();
}
}
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index 269a72f..013e097 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -68,20 +68,25 @@ import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCallBinding;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexExecutor;
+import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSimplify;
import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.rex.RexWindowBound;
+import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.runtime.Hook;
import org.apache.calcite.schema.TransientTable;
import org.apache.calcite.schema.impl.ListTransientTable;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.SqlWindow;
import org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.calcite.sql.fun.SqlLikeOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@@ -124,6 +129,7 @@ import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.Deque;
+import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -235,7 +241,7 @@ public class RelBuilder {
return new RelBuilder(context, cluster, relOptSchema);
}
- /** Performs an action on this RelBuilder if a condition is true.
+ /** Performs an action on this RelBuilder.
*
* <p>For example, consider the following code:
*
@@ -822,6 +828,43 @@ public class RelBuilder {
return call(SqlStdOperatorTable.NULLS_FIRST, node);
}
+ // Methods that create window bounds
+
+ /** Creates an {@code UNBOUNDED PRECEDING} window bound,
+ * for use in methods such as {@link OverCall#rowsFrom(RexWindowBound)}
+ * and {@link OverCall#rangeBetween(RexWindowBound, RexWindowBound)}. */
+ public RexWindowBound unboundedPreceding() {
+ return RexWindowBounds.UNBOUNDED_PRECEDING;
+ }
+
+ /** Creates a {@code bound PRECEDING} window bound,
+ * for use in methods such as {@link OverCall#rowsFrom(RexWindowBound)}
+ * and {@link OverCall#rangeBetween(RexWindowBound, RexWindowBound)}. */
+ public RexWindowBound preceding(RexNode bound) {
+ return RexWindowBounds.preceding(bound);
+ }
+
+ /** Creates a {@code CURRENT ROW} window bound,
+ * for use in methods such as {@link OverCall#rowsFrom(RexWindowBound)}
+ * and {@link OverCall#rangeBetween(RexWindowBound, RexWindowBound)}. */
+ public RexWindowBound currentRow() {
+ return RexWindowBounds.CURRENT_ROW;
+ }
+
+ /** Creates a {@code bound FOLLOWING} window bound,
+ * for use in methods such as {@link OverCall#rowsFrom(RexWindowBound)}
+ * and {@link OverCall#rangeBetween(RexWindowBound, RexWindowBound)}. */
+ public RexWindowBound following(RexNode bound) {
+ return RexWindowBounds.following(bound);
+ }
+
+ /** Creates an {@code UNBOUNDED FOLLOWING} window bound,
+ * for use in methods such as {@link OverCall#rowsFrom(RexWindowBound)}
+ * and {@link OverCall#rangeBetween(RexWindowBound, RexWindowBound)}. */
+ public RexWindowBound unboundedFollowing() {
+ return RexWindowBounds.UNBOUNDED_FOLLOWING;
+ }
+
// Methods that create group keys and aggregate calls
/** Creates an empty group key. */
@@ -1813,7 +1856,11 @@ public class RelBuilder {
public RelBuilder aggregate(GroupKey groupKey, List<AggregateCall> aggregateCalls) {
return aggregate(groupKey,
aggregateCalls.stream()
- .map(AggCallImpl2::new)
+ .map(aggregateCall ->
+ new AggCallImpl2(aggregateCall,
+ aggregateCall.getArgList().stream()
+ .map(this::field)
+ .collect(Util.toImmutableList())))
.collect(Collectors.toList()));
}
@@ -2864,7 +2911,8 @@ public class RelBuilder {
private static RelFieldCollation collation(RexNode node,
RelFieldCollation.Direction direction,
- RelFieldCollation.@Nullable NullDirection nullDirection, List<RexNode> extraNodes) {
+ RelFieldCollation.@Nullable NullDirection nullDirection,
+ List<RexNode> extraNodes) {
switch (node.getKind()) {
case INPUT_REF:
return new RelFieldCollation(((RexInputRef) node).getIndex(), direction,
@@ -2887,6 +2935,34 @@ public class RelBuilder {
}
}
+ private static RexFieldCollation rexCollation(RexNode node,
+ RelFieldCollation.Direction direction,
+ RelFieldCollation.@Nullable NullDirection nullDirection) {
+ switch (node.getKind()) {
+ case DESCENDING:
+ return rexCollation(((RexCall) node).operands.get(0),
+ RelFieldCollation.Direction.DESCENDING, nullDirection);
+ case NULLS_LAST:
+ return rexCollation(((RexCall) node).operands.get(0),
+ direction, RelFieldCollation.NullDirection.LAST);
+ case NULLS_FIRST:
+ return rexCollation(((RexCall) node).operands.get(0),
+ direction, RelFieldCollation.NullDirection.FIRST);
+ default:
+ final Set<SqlKind> flags = EnumSet.noneOf(SqlKind.class);
+ if (direction == RelFieldCollation.Direction.DESCENDING) {
+ flags.add(SqlKind.DESCENDING);
+ }
+ if (nullDirection == RelFieldCollation.NullDirection.FIRST) {
+ flags.add(SqlKind.NULLS_FIRST);
+ }
+ if (nullDirection == RelFieldCollation.NullDirection.LAST) {
+ flags.add(SqlKind.NULLS_LAST);
+ }
+ return new RexFieldCollation(node, flags);
+ }
+ }
+
/**
* Creates a projection that converts the current relational expression's
* output to a desired row type.
@@ -3291,6 +3367,9 @@ public class RelBuilder {
default AggCall distinct() {
return distinct(true);
}
+
+ /** Converts this aggregate call to a windowed aggregate call. */
+ OverCall over();
}
/** Internal methods shared by all implementations of {@link AggCall}. */
@@ -3488,6 +3567,11 @@ public class RelBuilder {
registrar.registerExpressions(orderKeys);
}
+ @Override public OverCall over() {
+ return new OverCallImpl(aggFunction, distinct, operands, ignoreNulls,
+ alias);
+ }
+
@Override public AggCall sort(Iterable<RexNode> orderKeys) {
final ImmutableList<RexNode> orderKeyList =
ImmutableList.copyOf(orderKeys);
@@ -3548,11 +3632,19 @@ public class RelBuilder {
/** Implementation of {@link AggCall} that wraps an
* {@link AggregateCall}. */
- private static class AggCallImpl2 implements AggCallPlus {
+ private class AggCallImpl2 implements AggCallPlus {
private final AggregateCall aggregateCall;
+ private final ImmutableList<RexNode> operands;
- AggCallImpl2(AggregateCall aggregateCall) {
+ AggCallImpl2(AggregateCall aggregateCall, ImmutableList<RexNode> operands) {
this.aggregateCall = requireNonNull(aggregateCall, "aggregateCall");
+ this.operands = requireNonNull(operands, "operands");
+ }
+
+ @Override public OverCall over() {
+ return new OverCallImpl(aggregateCall.getAggregation(),
+ aggregateCall.isDistinct(), operands, aggregateCall.ignoreNulls(),
+ aggregateCall.name);
}
@Override public String toString() {
@@ -3613,6 +3705,246 @@ public class RelBuilder {
}
}
+ /** Call to a windowed aggregate function.
+ *
+ * <p>To create an {@code OverCall}, start with an {@link AggCall} (created
+ * by a method such as {@link #aggregateCall}, {@link #sum} or {@link #count})
+ * and call its {@link AggCall#over()} method. For example,
+ *
+ * <pre>{@code
+ * b.scan("EMP")
+ * .project(b.field("DEPTNO"),
+ * b.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
+ * .over()
+ * .partitionBy()
+ * .orderBy(b.field("EMPNO"))
+ * .rowsUnbounded()
+ * .allowPartial(true)
+ * .nullWhenCountZero(false)
+ * .as("x"))
+ * }</pre>
+ *
+ * <p>Unlike an aggregate call, a windowed aggregate call is an expression
+ * that you can use in a {@link Project} or {@link Filter}. So, to finish,
+ * call {@link OverCall#toRex()} to convert the {@code OverCall} to a
+ * {@link RexNode}; the {@link OverCall#as} method (used in the above example)
+ * does the same but also assigns a column alias.
+ */
+ public interface OverCall {
+ /** Performs an action on this OverCall. */
+ default <R> R let(Function<OverCall, R> consumer) {
+ return consumer.apply(this);
+ }
+
+ /** Sets the PARTITION BY clause to an array of expressions. */
+ OverCall partitionBy(RexNode... expressions);
+
+ /** Sets the PARTITION BY clause to a list of expressions. */
+ OverCall partitionBy(Iterable<? extends RexNode> expressions);
+
+ /** Sets the ORDER BY clause to an array of expressions.
+ *
+ * <p>Use {@link #desc(RexNode)}, {@link #nullsFirst(RexNode)},
+ * {@link #nullsLast(RexNode)} to control the sort order. */
+ OverCall orderBy(RexNode... expressions);
+
+ /** Sets the ORDER BY clause to a list of expressions.
+ *
+ * <p>Use {@link #desc(RexNode)}, {@link #nullsFirst(RexNode)},
+ * {@link #nullsLast(RexNode)} to control the sort order. */
+ OverCall orderBy(Iterable<? extends RexNode> expressions);
+
+ /** Sets an unbounded ROWS window,
+ * equivalent to SQL {@code ROWS BETWEEN UNBOUNDED PRECEDING AND
+ * UNBOUNDED FOLLOWING}. */
+ default OverCall rowsUnbounded() {
+ return rowsBetween(RexWindowBounds.UNBOUNDED_PRECEDING,
+ RexWindowBounds.UNBOUNDED_FOLLOWING);
+ }
+
+ /** Sets a ROWS window with a lower bound,
+ * equivalent to SQL {@code ROWS BETWEEN lower AND UNBOUNDED FOLLOWING}. */
+ default OverCall rowsFrom(RexWindowBound lower) {
+ return rowsBetween(lower, RexWindowBounds.UNBOUNDED_FOLLOWING);
+ }
+
+ /** Sets a ROWS window with an upper bound,
+ * equivalent to SQL {@code ROWS BETWEEN UNBOUNDED PRECEDING AND upper}. */
+ default OverCall rowsTo(RexWindowBound upper) {
+ return rowsBetween(RexWindowBounds.UNBOUNDED_PRECEDING, upper);
+ }
+
+ /** Sets a ROWS window with lower and upper bounds,
+ * equivalent to SQL {@code ROWS BETWEEN lower AND upper}. */
+ OverCall rowsBetween(RexWindowBound lower, RexWindowBound upper);
+
+ /** Sets an unbounded RANGE window,
+ * equivalent to SQL {@code RANGE BETWEEN UNBOUNDED PRECEDING AND
+ * UNBOUNDED FOLLOWING}. */
+ default OverCall rangeUnbounded() {
+ return rangeBetween(RexWindowBounds.UNBOUNDED_PRECEDING,
+ RexWindowBounds.UNBOUNDED_FOLLOWING);
+ }
+
+ /** Sets a RANGE window with a lower bound,
+ * equivalent to SQL {@code RANGE BETWEEN lower AND CURRENT ROW}. */
+ default OverCall rangeFrom(RexWindowBound lower) {
+ return rangeBetween(lower, RexWindowBounds.CURRENT_ROW);
+ }
+
+ /** Sets a RANGE window with an upper bound,
+ * equivalent to SQL {@code RANGE BETWEEN UNBOUNDED PRECEDING AND upper}. */
+ default OverCall rangeTo(RexWindowBound upper) {
+ return rangeBetween(RexWindowBounds.UNBOUNDED_PRECEDING, upper);
+ }
+
+ /** Sets a RANGE window with lower and upper bounds,
+ * equivalent to SQL {@code RANGE BETWEEN lower AND upper}. */
+ OverCall rangeBetween(RexWindowBound lower, RexWindowBound upper);
+
+ /** Sets whether to allow partial width windows; default true. */
+ OverCall allowPartial(boolean allowPartial);
+
+ /** Sets whether the aggregate function should evaluate to null if no rows
+ * are in the window; default false. */
+ OverCall nullWhenCountZero(boolean nullWhenCountZero);
+
+ /** Sets the alias of this expression, and converts it to a {@link RexNode};
+ * default is the alias that was set via {@link AggCall#as(String)}. */
+ RexNode as(String alias);
+
+ /** Converts this expression to a {@link RexNode}. */
+ RexNode toRex();
+ }
+
+ /** Implementation of {@link OverCall}. */
+ private class OverCallImpl implements OverCall {
+ private final ImmutableList<RexNode> operands;
+ private final boolean ignoreNulls;
+ private final @Nullable String alias;
+ private final boolean nullWhenCountZero;
+ private final boolean allowPartial;
+ private final boolean rows;
+ private final RexWindowBound lowerBound;
+ private final RexWindowBound upperBound;
+ private final ImmutableList<RexNode> partitionKeys;
+ private final ImmutableList<RexFieldCollation> sortKeys;
+ private final SqlAggFunction op;
+ private final boolean distinct;
+
+ private OverCallImpl(SqlAggFunction op, boolean distinct,
+ ImmutableList<RexNode> operands, boolean ignoreNulls,
+ @Nullable String alias, ImmutableList<RexNode> partitionKeys,
+ ImmutableList<RexFieldCollation> sortKeys, boolean rows,
+ RexWindowBound lowerBound, RexWindowBound upperBound,
+ boolean nullWhenCountZero, boolean allowPartial) {
+ this.op = op;
+ this.distinct = distinct;
+ this.operands = operands;
+ this.ignoreNulls = ignoreNulls;
+ this.alias = alias;
+ this.partitionKeys = partitionKeys;
+ this.sortKeys = sortKeys;
+ this.nullWhenCountZero = nullWhenCountZero;
+ this.allowPartial = allowPartial;
+ this.rows = rows;
+ this.lowerBound = lowerBound;
+ this.upperBound = upperBound;
+ }
+
+ /** Creates an OverCallImpl with default settings. */
+ OverCallImpl(SqlAggFunction op, boolean distinct,
+ ImmutableList<RexNode> operands, boolean ignoreNulls,
+ @Nullable String alias) {
+ this(op, distinct, operands, ignoreNulls, alias, ImmutableList.of(),
+ ImmutableList.of(), true, RexWindowBounds.UNBOUNDED_PRECEDING,
+ RexWindowBounds.UNBOUNDED_FOLLOWING, false, true);
+ }
+
+ @Override public OverCall partitionBy(
+ Iterable<? extends RexNode> expressions) {
+ return partitionBy_(ImmutableList.copyOf(expressions));
+ }
+
+ @Override public OverCall partitionBy(RexNode... expressions) {
+ return partitionBy_(ImmutableList.copyOf(expressions));
+ }
+
+ private OverCall partitionBy_(ImmutableList<RexNode> partitionKeys) {
+ return new OverCallImpl(op, distinct, operands, ignoreNulls, alias,
+ partitionKeys, sortKeys, rows, lowerBound, upperBound,
+ nullWhenCountZero, allowPartial);
+ }
+
+ private OverCall orderBy_(ImmutableList<RexFieldCollation> sortKeys) {
+ return new OverCallImpl(op, distinct, operands, ignoreNulls, alias,
+ partitionKeys, sortKeys, rows, lowerBound, upperBound,
+ nullWhenCountZero, allowPartial);
+ }
+
+ @Override public OverCall orderBy(Iterable<? extends RexNode> sortKeys) {
+ ImmutableList.Builder<RexFieldCollation> fieldCollations =
+ ImmutableList.builder();
+ sortKeys.forEach(sortKey ->
+ fieldCollations.add(
+ rexCollation(sortKey, RelFieldCollation.Direction.ASCENDING,
+ RelFieldCollation.NullDirection.UNSPECIFIED)));
+ return orderBy_(fieldCollations.build());
+ }
+
+ @Override public OverCall orderBy(RexNode... sortKeys) {
+ return orderBy(Arrays.asList(sortKeys));
+ }
+
+ @Override public OverCall rowsBetween(RexWindowBound lowerBound,
+ RexWindowBound upperBound) {
+ return new OverCallImpl(op, distinct, operands, ignoreNulls, alias,
+ partitionKeys, sortKeys, true, lowerBound, upperBound,
+ nullWhenCountZero, allowPartial);
+ }
+
+ @Override public OverCall rangeBetween(RexWindowBound lowerBound,
+ RexWindowBound upperBound) {
+ return new OverCallImpl(op, distinct, operands, ignoreNulls, alias,
+ partitionKeys, sortKeys, false, lowerBound, upperBound,
+ nullWhenCountZero, allowPartial);
+ }
+
+ @Override public OverCall allowPartial(boolean allowPartial) {
+ return new OverCallImpl(op, distinct, operands, ignoreNulls, alias,
+ partitionKeys, sortKeys, rows, lowerBound, upperBound,
+ nullWhenCountZero, allowPartial);
+ }
+
+ @Override public OverCall nullWhenCountZero(boolean nullWhenCountZero) {
+ return new OverCallImpl(op, distinct, operands, ignoreNulls, alias,
+ partitionKeys, sortKeys, rows, lowerBound, upperBound,
+ nullWhenCountZero, allowPartial);
+ }
+
+ @Override public RexNode as(String alias) {
+ return new OverCallImpl(op, distinct, operands, ignoreNulls, alias,
+ partitionKeys, sortKeys, rows, lowerBound, upperBound,
+ nullWhenCountZero, allowPartial).toRex();
+ }
+
+ @Override public RexNode toRex() {
+ final RexCallBinding bind =
+ new RexCallBinding(getTypeFactory(), op, operands,
+ ImmutableList.of()) {
+ @Override public int getGroupCount() {
+ return SqlWindow.isAlwaysNonEmpty(lowerBound, upperBound) ? 1 : 0;
+ }
+ };
+ final RelDataType type = op.inferReturnType(bind);
+ final RexNode over = getRexBuilder()
+ .makeOver(type, op, operands, partitionKeys, sortKeys,
+ lowerBound, upperBound, rows, allowPartial, nullWhenCountZero,
+ distinct, ignoreNulls);
+ return alias == null ? over : alias(over, alias);
+ }
+ }
+
/** Collects the extra expressions needed for {@link #aggregate}.
*
* <p>The extra expressions come from the group key and as arguments to
diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
index cbc1269..b276a15 100644
--- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
@@ -36,8 +36,6 @@ import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rel.type.RelDataTypeSystemImpl;
-import org.apache.calcite.rex.RexFieldCollation;
-import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.runtime.FlatLists;
import org.apache.calcite.runtime.Hook;
import org.apache.calcite.schema.SchemaPlus;
@@ -843,18 +841,13 @@ class RelToSqlConverterTest {
.scan("EMP")
.project(b.field("SAL"))
.project(
- b.alias(
- b.getRexBuilder().makeOver(
- b.getTypeFactory().createSqlType(SqlTypeName.INTEGER),
- SqlStdOperatorTable.RANK, ImmutableList.of(),
- ImmutableList.of(),
- ImmutableList.of(
- new RexFieldCollation(b.field("SAL"),
- ImmutableSet.of())),
- RexWindowBounds.UNBOUNDED_PRECEDING,
- RexWindowBounds.UNBOUNDED_FOLLOWING,
- true, true, false, false, false),
- "rank"))
+ b.aggregateCall(SqlStdOperatorTable.RANK)
+ .over()
+ .orderBy(b.field("SAL"))
+ .rowsUnbounded()
+ .allowPartial(true)
+ .nullWhenCountZero(false)
+ .as("rank"))
.as("t")
.aggregate(b.groupKey(),
b.count(b.field("t", "rank")).distinct().as("c"))
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index 2586317..8a3c731 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -45,7 +45,6 @@ import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.schema.SchemaPlus;
@@ -971,22 +970,44 @@ public class RelBuilderTest {
return b.call(SqlStdOperatorTable.CASE, list);
}
- /** Creates a {@link Project} that contains a windowed aggregate function. As
- * {@link RelBuilder} not explicitly support for {@link RexOver} the syntax is
- * a bit cumbersome. */
+ /** Creates a {@link Project} that contains a windowed aggregate function.
+ * Repeats the test using {@link RelBuilder.AggCall#over} and
+ * {@link RexBuilder#makeOver}. */
@Test void testProjectOver() {
- final Function<RelBuilder, RelNode> f = b -> b.scan("EMP")
+ final Function<RelBuilder, RelNode> f = b -> {
+ final RelDataType intType =
+ b.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
+ return b.scan("EMP")
+ .project(b.field("DEPTNO"),
+ b.alias(
+ b.getRexBuilder().makeOver(intType,
+ SqlStdOperatorTable.ROW_NUMBER, ImmutableList.of(),
+ ImmutableList.of(),
+ ImmutableList.of(
+ new RexFieldCollation(b.field("EMPNO"),
+ ImmutableSet.of())),
+ RexWindowBounds.UNBOUNDED_PRECEDING,
+ RexWindowBounds.UNBOUNDED_FOLLOWING,
+ true, true, false, false, false),
+ "x"))
+ .build();
+ };
+ final Function<RelBuilder, RelNode> f2 = b -> b.scan("EMP")
.project(b.field("DEPTNO"),
- over(b,
- ImmutableList.of(
- new RexFieldCollation(b.field("EMPNO"),
- ImmutableSet.of())),
- "x"))
+ b.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
+ .over()
+ .partitionBy()
+ .orderBy(b.field("EMPNO"))
+ .rowsUnbounded()
+ .allowPartial(true)
+ .nullWhenCountZero(false)
+ .as("x"))
.build();
final String expected = ""
+ "LogicalProject(DEPTNO=[$7], x=[ROW_NUMBER() OVER (ORDER BY $0)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
assertThat(f.apply(createBuilder()), hasTree(expected));
+ assertThat(f2.apply(createBuilder()), hasTree(expected));
}
/** Tests that RelBuilder does not merge a Project that contains a windowed
@@ -994,17 +1015,19 @@ public class RelBuilderTest {
@Test void testProjectOverOver() {
final Function<RelBuilder, RelNode> f = b -> b.scan("EMP")
.project(b.field("DEPTNO"),
- over(b,
- ImmutableList.of(
- new RexFieldCollation(b.field("EMPNO"),
- ImmutableSet.of())),
- "x"))
+ b.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
+ .over()
+ .partitionBy()
+ .orderBy(b.field("EMPNO"))
+ .rowsUnbounded()
+ .as("x"))
.project(b.field("DEPTNO"),
- over(b,
- ImmutableList.of(
- new RexFieldCollation(b.field("DEPTNO"),
- ImmutableSet.of())),
- "y"))
+ b.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
+ .over()
+ .partitionBy()
+ .orderBy(b.field("DEPTNO"))
+ .rowsUnbounded()
+ .as("y"))
.build();
final String expected = ""
+ "LogicalProject(DEPTNO=[$0], y=[ROW_NUMBER() OVER (ORDER BY $0)])\n"
@@ -1013,19 +1036,6 @@ public class RelBuilderTest {
assertThat(f.apply(createBuilder()), hasTree(expected));
}
- private RexNode over(RelBuilder b,
- ImmutableList<RexFieldCollation> fieldCollations, String alias) {
- final RelDataType intType =
- b.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
- return b.alias(
- b.getRexBuilder()
- .makeOver(intType, SqlStdOperatorTable.ROW_NUMBER,
- ImmutableList.of(), ImmutableList.of(), fieldCollations,
- RexWindowBounds.UNBOUNDED_PRECEDING,
- RexWindowBounds.UNBOUNDED_FOLLOWING, true, true, false,
- false, false), alias);
- }
-
@Test void testRename() {
final RelBuilder builder = RelBuilder.create(config().build());
diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
index 64b9104..0dcd4e6 100644
--- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
@@ -2305,8 +2305,7 @@ LogicalProject(EMPNO=[$0])
</TestCase>
<TestCase name="testInWithConstantList">
<Resource name="sql">
- <![CDATA[1 in (1,2,3)
-]]>
+ <![CDATA[1 in (1,2,3)]]>
</Resource>
<Resource name="plan">
<![CDATA[
@@ -7003,7 +7002,7 @@ order by row_number() over(partition by empno order by deptno)]]>
<![CDATA[
LogicalProject(DEPTNO=[$0], EXPR$1=[$1])
LogicalSort(sort0=[$2], dir0=[ASC-nulls-first])
- LogicalProject(DEPTNO=[$7], EXPR$1=[RANK() OVER (PARTITION BY $0 ORDER BY $7 NULLS FIRST)], EXPR$2=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $7 NULLS FIRST)])
+ LogicalProject(DEPTNO=[$7], EXPR$1=[RANK() OVER (PARTITION BY $0 ORDER BY $7)], EXPR$2=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $7)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
diff --git a/piglet/src/main/java/org/apache/calcite/piglet/PigRelOpVisitor.java b/piglet/src/main/java/org/apache/calcite/piglet/PigRelOpVisitor.java
index 7d0dc71..350119d 100644
--- a/piglet/src/main/java/org/apache/calcite/piglet/PigRelOpVisitor.java
+++ b/piglet/src/main/java/org/apache/calcite/piglet/PigRelOpVisitor.java
@@ -24,16 +24,14 @@ import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Pair;
import org.apache.pig.builtin.CubeDimensions;
import org.apache.pig.builtin.RollupDimensions;
@@ -670,30 +668,23 @@ class PigRelOpVisitor extends PigRelOpWalker.PlanPreVisitor {
loRank.isDenseRank() ? SqlStdOperatorTable.DENSE_RANK : SqlStdOperatorTable.RANK;
// Build the order keys
- List<RexFieldCollation> orderNodes = new ArrayList<>();
- for (int i = 0; i < loRank.getRankColPlans().size(); i++) {
+ List<RexNode> orderNodes = new ArrayList<>();
+ for (Pair<LogicalExpressionPlan, Boolean> p
+ : Pair.zip(loRank.getRankColPlans(), loRank.getAscendingCol())) {
RexNode orderNode =
- PigRelExVisitor.translatePigEx(builder, loRank.getRankColPlans().get(i));
- Set<SqlKind> flags = new HashSet<>();
- if (!loRank.getAscendingCol().get(i)) {
- flags.add(SqlKind.DESCENDING);
+ PigRelExVisitor.translatePigEx(builder, p.left);
+ final boolean ascending = p.right;
+ if (!ascending) {
+ orderNode = builder.desc(orderNode);
}
- orderNodes.add(new RexFieldCollation(orderNode, flags));
+ orderNodes.add(orderNode);
}
- return builder.getRexBuilder().makeOver(
- PigTypes.TYPE_FACTORY.createSqlType(SqlTypeName.BIGINT), // Return type
- rank, // Aggregate function
- Collections.emptyList(), // Operands for the aggregate function, empty here
- Collections.emptyList(), // No partition keys
- ImmutableList.copyOf(orderNodes), // order keys
- RexWindowBounds.UNBOUNDED_PRECEDING,
- RexWindowBounds.CURRENT_ROW,
- false, // Range-based
- true, // allow partial
- false, // not return null when count is zero
- false, // no distinct
- false);
+ return builder.aggregateCall(rank)
+ .over()
+ .rangeFrom(RexWindowBounds.UNBOUNDED_PRECEDING)
+ .orderBy(orderNodes)
+ .toRex();
}
@Override public void visit(LOStream loStream) throws FrontendException {
diff --git a/site/_docs/algebra.md b/site/_docs/algebra.md
index 9a55c8e..0c3ee62 100644
--- a/site/_docs/algebra.md
+++ b/site/_docs/algebra.md
@@ -491,3 +491,30 @@ To further modify the `AggCall`, call its methods:
| `filter(expr)` | Filters rows before aggregating (see SQL `FILTER (WHERE ...)`)
| `sort(expr...)`<br/>`sort(exprList)` | Sorts rows before aggregating (see SQL `WITHIN GROUP`)
| `unique(expr...)`<br/>`unique(exprList)` | Makes rows unique before aggregating (see SQL `WITHIN DISTINCT`)
+| `over()` | Converts this `AggCall` into a windowed aggregate (see `OverCall` below)
+
+#### Windowed aggregate call methods
+
+To create an
+[RelBuilder.OverCall]({{ site.apiRoot }}/org/apache/calcite/tools/RelBuilder.OverCall.html),
+which represents a call to a windowed aggregate function, create an aggregate
+call and then call its `over()` method, for instance `count().over()`.
+
+To further modify the `OverCall`, call its methods:
+
+| Method | Description
+|:-------------------- |:-----------
+| `rangeUnbounded()` | Creates an unbounded range-based window, `RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING`
+| `rangeFrom(lower)` | Creates a range-based window bounded below, `RANGE BETWEEN lower AND CURRENT ROW`
+| `rangeTo(upper)` | Creates a range-based window bounded above, `RANGE BETWEEN CURRENT ROW AND upper`
+| `rangeBetween(lower, upper)` | Creates a range-based window, `RANGE BETWEEN lower AND upper`
+| `rowsUnbounded()` | Creates an unbounded row-based window, `ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING`
+| `rowsFrom(lower)` | Creates a row-based window bounded below, `ROWS BETWEEN lower AND CURRENT ROW`
+| `rowsTo(upper)` | Creates a row-based window bounded above, `ROWS BETWEEN CURRENT ROW AND upper`
+| `rowsBetween(lower, upper)` | Creates a row-based window, `ROWS BETWEEN lower AND upper`
+| `partitionBy(expr...)`<br/>`partitionBy(exprList)` | Partitions the window on the given expressions (see SQL `PARTITION BY`)
+| `orderBy(expr...)`<br/>`orderBy(exprList)` | Sorts the rows in the window (see SQL `ORDER BY`)
+| `allowPartial(b)` | Sets whether to allow partial width windows; default true
+| `nullWhenCountZero(b)` | Sets whether the aggregate function should evaluate to null if no rows are in the window; default false
+| `as(alias)` | Assigns a column alias (see SQL `AS`) and converts this `OverCall` to a `RexNode`
+| `toRex()` | Converts this `OverCall` to a `RexNode`