You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by ru...@apache.org on 2020/06/29 08:36:13 UTC
[calcite] branch master updated: [CALCITE-4008] Implement Code generation for EnumerableSortedAggregate (Rui Wang).
This is an automated email from the ASF dual-hosted git repository.
rubenql pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/master by this push:
new bf9ff00 [CALCITE-4008] Implement Code generation for EnumerableSortedAggregate (Rui Wang).
bf9ff00 is described below
commit bf9ff001db743bcba35943daf7fec5fe8b8b207e
Author: amaliujia <am...@163.com>
AuthorDate: Thu Jun 18 11:09:17 2020 -0700
[CALCITE-4008] Implement Code generation for EnumerableSortedAggregate (Rui Wang).
---
.../adapter/enumerable/EnumerableAggregate.java | 266 +----------------
.../enumerable/EnumerableAggregateBase.java | 330 +++++++++++++++++++++
.../enumerable/EnumerableSortedAggregate.java | 143 ++++++++-
.../org/apache/calcite/util/BuiltInMethod.java | 2 +
.../enumerable/EnumerableSortedAggregateTest.java | 142 +++++++++
.../apache/calcite/linq4j/DefaultEnumerable.java | 11 +
.../apache/calcite/linq4j/EnumerableDefaults.java | 127 ++++++++
.../apache/calcite/linq4j/ExtendedEnumerable.java | 17 ++
8 files changed, 775 insertions(+), 263 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregate.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregate.java
index a0dee72..03f3192 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregate.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregate.java
@@ -16,11 +16,8 @@
*/
package org.apache.calcite.adapter.enumerable;
-import org.apache.calcite.adapter.enumerable.impl.AggAddContextImpl;
import org.apache.calcite.adapter.enumerable.impl.AggResultContextImpl;
import org.apache.calcite.adapter.java.JavaTypeFactory;
-import org.apache.calcite.config.CalciteSystemProperty;
-import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.linq4j.function.Function0;
import org.apache.calcite.linq4j.function.Function1;
@@ -29,35 +26,23 @@ import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
-import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.InvalidRelException;
-import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.ImmutableBitSet;
-import org.apache.calcite.util.Pair;
-import org.apache.calcite.util.Util;
import com.google.common.collect.ImmutableList;
import java.lang.reflect.Type;
import java.util.ArrayList;
-import java.util.Collections;
-import java.util.LinkedList;
import java.util.List;
/** Implementation of {@link org.apache.calcite.rel.core.Aggregate} in
* {@link org.apache.calcite.adapter.enumerable.EnumerableConvention enumerable calling convention}. */
-public class EnumerableAggregate extends Aggregate implements EnumerableRel {
+public class EnumerableAggregate extends EnumerableAggregateBase implements EnumerableRel {
public EnumerableAggregate(
RelOptCluster cluster,
RelTraitSet traitSet,
@@ -202,37 +187,8 @@ public class EnumerableAggregate extends Aggregate implements EnumerableRel {
final List<Expression> initExpressions = new ArrayList<>();
final BlockBuilder initBlock = new BlockBuilder();
- final List<Type> aggStateTypes = new ArrayList<>();
- for (final AggImpState agg : aggs) {
- agg.context = new AggContextImpl(agg, typeFactory);
- final List<Type> state = agg.implementor.getStateType(agg.context);
-
- if (state.isEmpty()) {
- agg.state = ImmutableList.of();
- continue;
- }
-
- aggStateTypes.addAll(state);
-
- final List<Expression> decls = new ArrayList<>(state.size());
- for (int i = 0; i < state.size(); i++) {
- String aggName = "a" + agg.aggIdx;
- if (CalciteSystemProperty.DEBUG.value()) {
- aggName = Util.toJavaId(agg.call.getAggregation().getName(), 0)
- .substring("ID$0$".length()) + aggName;
- }
- Type type = state.get(i);
- ParameterExpression pe =
- Expressions.parameter(type,
- initBlock.newName(aggName + "s" + i));
- initBlock.add(Expressions.declare(0, pe, null));
- decls.add(pe);
- }
- agg.state = decls;
- initExpressions.addAll(decls);
- agg.implementor.implementReset(agg.context,
- new AggResultContextImpl(initBlock, agg.call, decls, null, null));
- }
+ final List<Type> aggStateTypes = createAggStateTypes(
+ initExpressions, initBlock, aggs, typeFactory);
final PhysType accPhysType =
PhysTypeImpl.of(typeFactory,
@@ -258,55 +214,9 @@ public class EnumerableAggregate extends Aggregate implements EnumerableRel {
Expressions.parameter(inputPhysType.getJavaRowType(), "in");
final ParameterExpression acc_ =
Expressions.parameter(accPhysType.getJavaRowType(), "acc");
- for (int i = 0, stateOffset = 0; i < aggs.size(); i++) {
- final BlockBuilder builder2 = new BlockBuilder();
- final AggImpState agg = aggs.get(i);
- final int stateSize = agg.state.size();
- final List<Expression> accumulator = new ArrayList<>(stateSize);
- for (int j = 0; j < stateSize; j++) {
- accumulator.add(accPhysType.fieldReference(acc_, j + stateOffset));
- }
- agg.state = accumulator;
-
- stateOffset += stateSize;
-
- AggAddContext addContext =
- new AggAddContextImpl(builder2, accumulator) {
- public List<RexNode> rexArguments() {
- List<RelDataTypeField> inputTypes =
- inputPhysType.getRowType().getFieldList();
- List<RexNode> args = new ArrayList<>();
- for (int index : agg.call.getArgList()) {
- args.add(RexInputRef.of(index, inputTypes));
- }
- return args;
- }
-
- public RexNode rexFilterArgument() {
- return agg.call.filterArg < 0
- ? null
- : RexInputRef.of(agg.call.filterArg,
- inputPhysType.getRowType());
- }
-
- public RexToLixTranslator rowTranslator() {
- return RexToLixTranslator.forAggregation(typeFactory,
- currentBlock(),
- new RexToLixTranslator.InputGetterImpl(
- Collections.singletonList(
- Pair.of(inParameter, inputPhysType))),
- implementor.getConformance())
- .setNullable(currentNullables());
- }
- };
-
- agg.implementor.implementAdd(agg.context, addContext);
- builder2.add(acc_);
- agg.accumulatorAdder = builder.append("accumulatorAdder",
- Expressions.lambda(Function2.class, builder2.toBlock(), acc_,
- inParameter));
- }
+ createAccumulatorAdders(
+ inParameter, aggs, accPhysType, acc_, inputPhysType, builder, implementor, typeFactory);
final ParameterExpression lambdaFactory =
Expressions.parameter(AggregateLambdaFactory.class,
@@ -443,170 +353,4 @@ public class EnumerableAggregate extends Aggregate implements EnumerableRel {
}
return implementor.result(physType, builder.toBlock());
}
-
- private static boolean hasOrderedCall(List<AggImpState> aggs) {
- for (AggImpState agg : aggs) {
- if (!agg.call.collation.equals(RelCollations.EMPTY)) {
- return true;
- }
- }
- return false;
- }
-
- private void declareParentAccumulator(List<Expression> initExpressions,
- BlockBuilder initBlock, PhysType accPhysType) {
- if (accPhysType.getJavaRowType()
- instanceof JavaTypeFactoryImpl.SyntheticRecordType) {
- // We have to initialize the SyntheticRecordType instance this way, to
- // avoid using a class constructor with too many parameters.
- final JavaTypeFactoryImpl.SyntheticRecordType synType =
- (JavaTypeFactoryImpl.SyntheticRecordType)
- accPhysType.getJavaRowType();
- final ParameterExpression record0_ =
- Expressions.parameter(accPhysType.getJavaRowType(), "record0");
- initBlock.add(Expressions.declare(0, record0_, null));
- initBlock.add(
- Expressions.statement(
- Expressions.assign(record0_,
- Expressions.new_(accPhysType.getJavaRowType()))));
- List<Types.RecordField> fieldList = synType.getRecordFields();
- for (int i = 0; i < initExpressions.size(); i++) {
- Expression right = initExpressions.get(i);
- initBlock.add(
- Expressions.statement(
- Expressions.assign(
- Expressions.field(record0_, fieldList.get(i)), right)));
- }
- initBlock.add(record0_);
- } else {
- initBlock.add(accPhysType.record(initExpressions));
- }
- }
-
- /**
- * Implements the {@link AggregateLambdaFactory}.
- *
- * <p>Behavior depends upon ordering:
- * <ul>
- *
- * <li>{@code hasOrderedCall == true} means there is at least one aggregate
- * call including sort spec. We use {@link LazyAggregateLambdaFactory}
- * implementation to implement sorted aggregates for that.
- *
- * <li>{@code hasOrderedCall == false} indicates to use
- * {@link BasicAggregateLambdaFactory} to implement a non-sort
- * aggregate.
- *
- * </ul>
- */
- private void implementLambdaFactory(BlockBuilder builder,
- PhysType inputPhysType,
- List<AggImpState> aggs,
- Expression accumulatorInitializer,
- boolean hasOrderedCall,
- ParameterExpression lambdaFactory) {
- if (hasOrderedCall) {
- ParameterExpression pe = Expressions.parameter(List.class,
- builder.newName("lazyAccumulators"));
- builder.add(
- Expressions.declare(0, pe, Expressions.new_(LinkedList.class)));
-
- for (AggImpState agg : aggs) {
- if (agg.call.collation.equals(RelCollations.EMPTY)) {
- // if the call does not require ordering, fallback to
- // use a non-sorted lazy accumulator.
- builder.add(
- Expressions.statement(
- Expressions.call(pe,
- BuiltInMethod.COLLECTION_ADD.method,
- Expressions.new_(BuiltInMethod.BASIC_LAZY_ACCUMULATOR.constructor,
- agg.accumulatorAdder))));
- continue;
- }
- final Pair<Expression, Expression> pair =
- inputPhysType.generateCollationKey(
- agg.call.collation.getFieldCollations());
- builder.add(
- Expressions.statement(
- Expressions.call(pe,
- BuiltInMethod.COLLECTION_ADD.method,
- Expressions.new_(BuiltInMethod.SOURCE_SORTER.constructor,
- agg.accumulatorAdder, pair.left, pair.right))));
- }
- builder.add(
- Expressions.declare(0, lambdaFactory,
- Expressions.new_(
- BuiltInMethod.LAZY_AGGREGATE_LAMBDA_FACTORY.constructor,
- accumulatorInitializer, pe)));
- } else {
- // when hasOrderedCall == false
- ParameterExpression pe = Expressions.parameter(List.class,
- builder.newName("accumulatorAdders"));
- builder.add(
- Expressions.declare(0, pe, Expressions.new_(LinkedList.class)));
-
- for (AggImpState agg : aggs) {
- builder.add(
- Expressions.statement(
- Expressions.call(pe, BuiltInMethod.COLLECTION_ADD.method,
- agg.accumulatorAdder)));
- }
- builder.add(
- Expressions.declare(0, lambdaFactory,
- Expressions.new_(
- BuiltInMethod.BASIC_AGGREGATE_LAMBDA_FACTORY.constructor,
- accumulatorInitializer, pe)));
- }
- }
-
- /** An implementation of {@link AggContext}. */
- private class AggContextImpl implements AggContext {
- private final AggImpState agg;
- private final JavaTypeFactory typeFactory;
-
- AggContextImpl(AggImpState agg, JavaTypeFactory typeFactory) {
- this.agg = agg;
- this.typeFactory = typeFactory;
- }
-
- public SqlAggFunction aggregation() {
- return agg.call.getAggregation();
- }
-
- public RelDataType returnRelType() {
- return agg.call.type;
- }
-
- public Type returnType() {
- return EnumUtils.javaClass(typeFactory, returnRelType());
- }
-
- public List<? extends RelDataType> parameterRelTypes() {
- return EnumUtils.fieldRowTypes(getInput().getRowType(), null,
- agg.call.getArgList());
- }
-
- public List<? extends Type> parameterTypes() {
- return EnumUtils.fieldTypes(
- typeFactory,
- parameterRelTypes());
- }
-
- public List<ImmutableBitSet> groupSets() {
- return groupSets;
- }
-
- public List<Integer> keyOrdinals() {
- return groupSet.asList();
- }
-
- public List<? extends RelDataType> keyRelTypes() {
- return EnumUtils.fieldRowTypes(getInput().getRowType(), null,
- groupSet.asList());
- }
-
- public List<? extends Type> keyTypes() {
- return EnumUtils.fieldTypes(typeFactory, keyRelTypes());
- }
- }
}
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregateBase.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregateBase.java
new file mode 100644
index 0000000..d8233a4
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregateBase.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.adapter.enumerable;
+
+import org.apache.calcite.adapter.enumerable.impl.AggAddContextImpl;
+import org.apache.calcite.adapter.enumerable.impl.AggResultContextImpl;
+import org.apache.calcite.adapter.java.JavaTypeFactory;
+import org.apache.calcite.config.CalciteSystemProperty;
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
+import org.apache.calcite.linq4j.function.Function2;
+import org.apache.calcite.linq4j.tree.BlockBuilder;
+import org.apache.calcite.linq4j.tree.Expression;
+import org.apache.calcite.linq4j.tree.Expressions;
+import org.apache.calcite.linq4j.tree.ParameterExpression;
+import org.apache.calcite.linq4j.tree.Types;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.hint.RelHint;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.util.BuiltInMethod;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Pair;
+import org.apache.calcite.util.Util;
+
+import com.google.common.collect.ImmutableList;
+
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+
+/** Base class for EnumerableAggregate and EnumerableSortedAggregate. */
+public abstract class EnumerableAggregateBase extends Aggregate {
+ protected EnumerableAggregateBase(
+ RelOptCluster cluster,
+ RelTraitSet traitSet,
+ List<RelHint> hints,
+ RelNode input,
+ ImmutableBitSet groupSet,
+ List<ImmutableBitSet> groupSets,
+ List<AggregateCall> aggCalls) {
+ super(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls);
+ }
+
+ protected static boolean hasOrderedCall(List<AggImpState> aggs) {
+ for (AggImpState agg : aggs) {
+ if (!agg.call.collation.equals(RelCollations.EMPTY)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ protected void declareParentAccumulator(List<Expression> initExpressions,
+ BlockBuilder initBlock, PhysType accPhysType) {
+ if (accPhysType.getJavaRowType()
+ instanceof JavaTypeFactoryImpl.SyntheticRecordType) {
+ // We have to initialize the SyntheticRecordType instance this way, to
+ // avoid using a class constructor with too many parameters.
+ final JavaTypeFactoryImpl.SyntheticRecordType synType =
+ (JavaTypeFactoryImpl.SyntheticRecordType)
+ accPhysType.getJavaRowType();
+ final ParameterExpression record0_ =
+ Expressions.parameter(accPhysType.getJavaRowType(), "record0");
+ initBlock.add(Expressions.declare(0, record0_, null));
+ initBlock.add(
+ Expressions.statement(
+ Expressions.assign(record0_,
+ Expressions.new_(accPhysType.getJavaRowType()))));
+ List<Types.RecordField> fieldList = synType.getRecordFields();
+ for (int i = 0; i < initExpressions.size(); i++) {
+ Expression right = initExpressions.get(i);
+ initBlock.add(
+ Expressions.statement(
+ Expressions.assign(
+ Expressions.field(record0_, fieldList.get(i)), right)));
+ }
+ initBlock.add(record0_);
+ } else {
+ initBlock.add(accPhysType.record(initExpressions));
+ }
+ }
+
+ /**
+ * Implements the {@link AggregateLambdaFactory}.
+ *
+ * <p>Behavior depends upon ordering:
+ * <ul>
+ *
+ * <li>{@code hasOrderedCall == true} means there is at least one aggregate
+ * call including sort spec. We use {@link LazyAggregateLambdaFactory}
+ * implementation to implement sorted aggregates for that.
+ *
+ * <li>{@code hasOrderedCall == false} indicates to use
+ * {@link BasicAggregateLambdaFactory} to implement a non-sort
+ * aggregate.
+ *
+ * </ul>
+ */
+ protected void implementLambdaFactory(BlockBuilder builder,
+ PhysType inputPhysType, List<AggImpState> aggs,
+ Expression accumulatorInitializer, boolean hasOrderedCall,
+ ParameterExpression lambdaFactory) {
+ if (hasOrderedCall) {
+ ParameterExpression pe = Expressions.parameter(List.class,
+ builder.newName("lazyAccumulators"));
+ builder.add(
+ Expressions.declare(0, pe, Expressions.new_(LinkedList.class)));
+
+ for (AggImpState agg : aggs) {
+ if (agg.call.collation.equals(RelCollations.EMPTY)) {
+ // if the call does not require ordering, fallback to
+ // use a non-sorted lazy accumulator.
+ builder.add(
+ Expressions.statement(
+ Expressions.call(pe,
+ BuiltInMethod.COLLECTION_ADD.method,
+ Expressions.new_(BuiltInMethod.BASIC_LAZY_ACCUMULATOR.constructor,
+ agg.accumulatorAdder))));
+ continue;
+ }
+ final Pair<Expression, Expression> pair =
+ inputPhysType.generateCollationKey(
+ agg.call.collation.getFieldCollations());
+ builder.add(
+ Expressions.statement(
+ Expressions.call(pe,
+ BuiltInMethod.COLLECTION_ADD.method,
+ Expressions.new_(BuiltInMethod.SOURCE_SORTER.constructor,
+ agg.accumulatorAdder, pair.left, pair.right))));
+ }
+ builder.add(
+ Expressions.declare(0, lambdaFactory,
+ Expressions.new_(
+ BuiltInMethod.LAZY_AGGREGATE_LAMBDA_FACTORY.constructor,
+ accumulatorInitializer, pe)));
+ } else {
+ // when hasOrderedCall == false
+ ParameterExpression pe = Expressions.parameter(List.class,
+ builder.newName("accumulatorAdders"));
+ builder.add(
+ Expressions.declare(0, pe, Expressions.new_(LinkedList.class)));
+
+ for (AggImpState agg : aggs) {
+ builder.add(
+ Expressions.statement(
+ Expressions.call(pe, BuiltInMethod.COLLECTION_ADD.method,
+ agg.accumulatorAdder)));
+ }
+ builder.add(
+ Expressions.declare(0, lambdaFactory,
+ Expressions.new_(
+ BuiltInMethod.BASIC_AGGREGATE_LAMBDA_FACTORY.constructor,
+ accumulatorInitializer, pe)));
+ }
+ }
+
+ /** An implementation of {@link AggContext}. */
+ protected class AggContextImpl implements AggContext {
+ private final AggImpState agg;
+ private final JavaTypeFactory typeFactory;
+
+ AggContextImpl(AggImpState agg, JavaTypeFactory typeFactory) {
+ this.agg = agg;
+ this.typeFactory = typeFactory;
+ }
+
+ public SqlAggFunction aggregation() {
+ return agg.call.getAggregation();
+ }
+
+ public RelDataType returnRelType() {
+ return agg.call.type;
+ }
+
+ public Type returnType() {
+ return EnumUtils.javaClass(typeFactory, returnRelType());
+ }
+
+ public List<? extends RelDataType> parameterRelTypes() {
+ return EnumUtils.fieldRowTypes(getInput().getRowType(), null,
+ agg.call.getArgList());
+ }
+
+ public List<? extends Type> parameterTypes() {
+ return EnumUtils.fieldTypes(
+ typeFactory,
+ parameterRelTypes());
+ }
+
+ public List<ImmutableBitSet> groupSets() {
+ return groupSets;
+ }
+
+ public List<Integer> keyOrdinals() {
+ return groupSet.asList();
+ }
+
+ public List<? extends RelDataType> keyRelTypes() {
+ return EnumUtils.fieldRowTypes(getInput().getRowType(), null,
+ groupSet.asList());
+ }
+
+ public List<? extends Type> keyTypes() {
+ return EnumUtils.fieldTypes(typeFactory, keyRelTypes());
+ }
+ }
+
+ protected void createAccumulatorAdders(
+ final ParameterExpression inParameter,
+ final List<AggImpState> aggs,
+ final PhysType accPhysType,
+ final ParameterExpression accExpr,
+ final PhysType inputPhysType,
+ final BlockBuilder builder,
+ EnumerableRelImplementor implementor,
+ JavaTypeFactory typeFactory) {
+ for (int i = 0, stateOffset = 0; i < aggs.size(); i++) {
+ final BlockBuilder builder2 = new BlockBuilder();
+ final AggImpState agg = aggs.get(i);
+
+ final int stateSize = agg.state.size();
+ final List<Expression> accumulator = new ArrayList<>(stateSize);
+ for (int j = 0; j < stateSize; j++) {
+ accumulator.add(accPhysType.fieldReference(accExpr, j + stateOffset));
+ }
+ agg.state = accumulator;
+
+ stateOffset += stateSize;
+
+ AggAddContext addContext =
+ new AggAddContextImpl(builder2, accumulator) {
+ public List<RexNode> rexArguments() {
+ List<RelDataTypeField> inputTypes =
+ inputPhysType.getRowType().getFieldList();
+ List<RexNode> args = new ArrayList<>();
+ for (int index : agg.call.getArgList()) {
+ args.add(RexInputRef.of(index, inputTypes));
+ }
+ return args;
+ }
+
+ public RexNode rexFilterArgument() {
+ return agg.call.filterArg < 0
+ ? null
+ : RexInputRef.of(agg.call.filterArg,
+ inputPhysType.getRowType());
+ }
+
+ public RexToLixTranslator rowTranslator() {
+ return RexToLixTranslator.forAggregation(typeFactory,
+ currentBlock(),
+ new RexToLixTranslator.InputGetterImpl(
+ Collections.singletonList(
+ Pair.of(inParameter, inputPhysType))),
+ implementor.getConformance())
+ .setNullable(currentNullables());
+ }
+ };
+
+ agg.implementor.implementAdd(agg.context, addContext);
+ builder2.add(accExpr);
+ agg.accumulatorAdder = builder.append("accumulatorAdder",
+ Expressions.lambda(Function2.class, builder2.toBlock(), accExpr,
+ inParameter));
+ }
+ }
+
+ protected List<Type> createAggStateTypes(
+ final List<Expression> initExpressions,
+ final BlockBuilder initBlock,
+ final List<AggImpState> aggs,
+ JavaTypeFactory typeFactory) {
+ final List<Type> aggStateTypes = new ArrayList<>();
+ for (final AggImpState agg : aggs) {
+ agg.context = new AggContextImpl(agg, typeFactory);
+ final List<Type> state = agg.implementor.getStateType(agg.context);
+
+ if (state.isEmpty()) {
+ agg.state = ImmutableList.of();
+ continue;
+ }
+
+ aggStateTypes.addAll(state);
+
+ final List<Expression> decls = new ArrayList<>(state.size());
+ for (int i = 0; i < state.size(); i++) {
+ String aggName = "a" + agg.aggIdx;
+ if (CalciteSystemProperty.DEBUG.value()) {
+ aggName = Util.toJavaId(agg.call.getAggregation().getName(), 0)
+ .substring("ID$0$".length()) + aggName;
+ }
+ Type type = state.get(i);
+ ParameterExpression pe =
+ Expressions.parameter(type,
+ initBlock.newName(aggName + "s" + i));
+ initBlock.add(Expressions.declare(0, pe, null));
+ decls.add(pe);
+ }
+ agg.state = decls;
+ initExpressions.addAll(decls);
+ agg.implementor.implementReset(agg.context,
+ new AggResultContextImpl(initBlock, agg.call, decls, null, null));
+ }
+ return aggStateTypes;
+ }
+}
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java
index 00ca78a..1bb5fd3 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java
@@ -16,6 +16,15 @@
*/
package org.apache.calcite.adapter.enumerable;
+import org.apache.calcite.adapter.enumerable.impl.AggResultContextImpl;
+import org.apache.calcite.adapter.java.JavaTypeFactory;
+import org.apache.calcite.linq4j.Ord;
+import org.apache.calcite.linq4j.function.Function0;
+import org.apache.calcite.linq4j.function.Function2;
+import org.apache.calcite.linq4j.tree.BlockBuilder;
+import org.apache.calcite.linq4j.tree.Expression;
+import org.apache.calcite.linq4j.tree.Expressions;
+import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
@@ -25,6 +34,7 @@ import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
@@ -32,12 +42,13 @@ import org.apache.calcite.util.mapping.Mappings;
import com.google.common.collect.ImmutableList;
+import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
/** Sort based physical implementation of {@link Aggregate} in
* {@link EnumerableConvention enumerable calling convention}. */
-public class EnumerableSortedAggregate extends Aggregate implements EnumerableRel {
+public class EnumerableSortedAggregate extends EnumerableAggregateBase implements EnumerableRel {
public EnumerableSortedAggregate(
RelOptCluster cluster,
RelTraitSet traitSet,
@@ -90,6 +101,134 @@ public class EnumerableSortedAggregate extends Aggregate implements EnumerableRe
}
public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
- throw Util.needToImplement("EnumerableSortedAggregate");
+ if (!Aggregate.isSimple(this)) {
+ throw Util.needToImplement("EnumerableSortedAggregate");
+ }
+
+ final JavaTypeFactory typeFactory = implementor.getTypeFactory();
+ final BlockBuilder builder = new BlockBuilder();
+ final EnumerableRel child = (EnumerableRel) getInput();
+ final Result result = implementor.visitChild(this, 0, child, pref);
+ Expression childExp =
+ builder.append(
+ "child",
+ result.block);
+
+ final PhysType physType =
+ PhysTypeImpl.of(
+ typeFactory, getRowType(), pref.preferCustom());
+
+ final PhysType inputPhysType = result.physType;
+
+ ParameterExpression parameter =
+ Expressions.parameter(inputPhysType.getJavaRowType(), "a0");
+
+ final PhysType keyPhysType =
+ inputPhysType.project(groupSet.asList(), getGroupType() != Group.SIMPLE,
+ JavaRowFormat.LIST);
+ final int groupCount = getGroupCount();
+
+ final List<AggImpState> aggs = new ArrayList<>(aggCalls.size());
+ for (Ord<AggregateCall> call : Ord.zip(aggCalls)) {
+ aggs.add(new AggImpState(call.i, call.e, false));
+ }
+
+ // Function0<Object[]> accumulatorInitializer =
+ // new Function0<Object[]>() {
+ // public Object[] apply() {
+ // return new Object[] {0, 0};
+ // }
+ // };
+ final List<Expression> initExpressions = new ArrayList<>();
+ final BlockBuilder initBlock = new BlockBuilder();
+
+ final List<Type> aggStateTypes = createAggStateTypes(
+ initExpressions, initBlock, aggs, typeFactory);
+
+ final PhysType accPhysType =
+ PhysTypeImpl.of(typeFactory,
+ typeFactory.createSyntheticType(aggStateTypes));
+
+ declareParentAccumulator(initExpressions, initBlock, accPhysType);
+
+ final Expression accumulatorInitializer =
+ builder.append("accumulatorInitializer",
+ Expressions.lambda(
+ Function0.class,
+ initBlock.toBlock()));
+
+ // Function2<Object[], Employee, Object[]> accumulatorAdder =
+ // new Function2<Object[], Employee, Object[]>() {
+ // public Object[] apply(Object[] acc, Employee in) {
+ // acc[0] = ((Integer) acc[0]) + 1;
+ // acc[1] = ((Integer) acc[1]) + in.salary;
+ // return acc;
+ // }
+ // };
+ final ParameterExpression inParameter =
+ Expressions.parameter(inputPhysType.getJavaRowType(), "in");
+ final ParameterExpression acc_ =
+ Expressions.parameter(accPhysType.getJavaRowType(), "acc");
+
+ createAccumulatorAdders(
+ inParameter, aggs, accPhysType, acc_, inputPhysType, builder, implementor, typeFactory);
+
+ final ParameterExpression lambdaFactory =
+ Expressions.parameter(AggregateLambdaFactory.class,
+ builder.newName("lambdaFactory"));
+
+ implementLambdaFactory(builder, inputPhysType, aggs, accumulatorInitializer,
+ false, lambdaFactory);
+
+ final BlockBuilder resultBlock = new BlockBuilder();
+ final List<Expression> results = Expressions.list();
+ final ParameterExpression key_;
+ final Type keyType = keyPhysType.getJavaRowType();
+ key_ = Expressions.parameter(keyType, "key");
+ for (int j = 0; j < groupCount; j++) {
+ final Expression ref = keyPhysType.fieldReference(key_, j);
+ results.add(ref);
+ }
+
+ for (final AggImpState agg : aggs) {
+ results.add(
+ agg.implementor.implementResult(agg.context,
+ new AggResultContextImpl(resultBlock, agg.call, agg.state, key_,
+ keyPhysType)));
+ }
+ resultBlock.add(physType.record(results));
+
+ final Expression keySelector_ =
+ builder.append("keySelector",
+ inputPhysType.generateSelector(parameter,
+ groupSet.asList(),
+ keyPhysType.getFormat()));
+ // Generate the appropriate key Comparator. In the case of NULL values
+ // in group keys, the comparator must be able to support NULL values by giving a
+ // consistent sort ordering.
+ final Expression comparator = keyPhysType.generateComparator(getTraitSet().getCollation());
+
+ final Expression resultSelector_ =
+ builder.append("resultSelector",
+ Expressions.lambda(Function2.class,
+ resultBlock.toBlock(),
+ key_,
+ acc_));
+
+ builder.add(
+ Expressions.return_(null,
+ Expressions.call(childExp,
+ BuiltInMethod.SORTED_GROUP_BY.method,
+ Expressions.list(keySelector_,
+ Expressions.call(lambdaFactory,
+ BuiltInMethod.AGG_LAMBDA_FACTORY_ACC_INITIALIZER.method),
+ Expressions.call(lambdaFactory,
+ BuiltInMethod.AGG_LAMBDA_FACTORY_ACC_ADDER.method),
+ Expressions.call(lambdaFactory,
+ BuiltInMethod.AGG_LAMBDA_FACTORY_ACC_RESULT_SELECTOR.method,
+ resultSelector_), comparator)
+ )));
+
+ return implementor.result(physType, builder.toBlock());
}
}
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index b7e6176..4f92c5f 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -212,6 +212,8 @@ public enum BuiltInMethod {
WHERE2(ExtendedEnumerable.class, "where", Predicate2.class),
DISTINCT(ExtendedEnumerable.class, "distinct"),
DISTINCT2(ExtendedEnumerable.class, "distinct", EqualityComparer.class),
+ SORTED_GROUP_BY(ExtendedEnumerable.class, "sortedGroupBy", Function1.class,
+ Function0.class, Function2.class, Function2.class, Comparator.class),
GROUP_BY(ExtendedEnumerable.class, "groupBy", Function1.class),
GROUP_BY2(ExtendedEnumerable.class, "groupBy", Function1.class,
Function0.class, Function2.class, Function2.class),
diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableSortedAggregateTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableSortedAggregateTest.java
new file mode 100644
index 0000000..225aa28
--- /dev/null
+++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableSortedAggregateTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.test.enumerable;
+
+import org.apache.calcite.adapter.enumerable.EnumerableRules;
+import org.apache.calcite.adapter.java.ReflectiveSchema;
+import org.apache.calcite.config.CalciteConnectionProperty;
+import org.apache.calcite.config.Lex;
+import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.runtime.Hook;
+import org.apache.calcite.test.CalciteAssert;
+import org.apache.calcite.test.JdbcTest;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.function.Consumer;
+
+/**
+ * Tests for plans that use {@code EnumerableSortedAggregate}.
+ *
+ * <p>Every test installs a planner hook that removes
+ * {@code ENUMERABLE_AGGREGATE_RULE} and adds
+ * {@code ENUMERABLE_SORTED_AGGREGATE_RULE}, then checks both the plan and the
+ * ordered rows returned for a query on the {@code emps} table.
+ */
+public class EnumerableSortedAggregateTest {
+  /**
+   * Planner hook that swaps {@code ENUMERABLE_AGGREGATE_RULE} for
+   * {@code ENUMERABLE_SORTED_AGGREGATE_RULE}.
+   */
+  private static final Consumer<RelOptPlanner> USE_SORTED_AGGREGATE = planner -> {
+    planner.removeRule(EnumerableRules.ENUMERABLE_AGGREGATE_RULE);
+    planner.addRule(EnumerableRules.ENUMERABLE_SORTED_AGGREGATE_RULE);
+  };
+
+  /** Groups by a single key. */
+  @Test void sortedAgg() {
+    final String sql = "select deptno, "
+        + "max(salary) as max_salary, count(name) as num_employee "
+        + "from emps group by deptno";
+    final String plan =
+        "EnumerableSortedAggregate(group=[{1}], max_salary=[MAX($3)], num_employee=[COUNT($2)])\n"
+            + " EnumerableSort(sort0=[$1], dir0=[ASC])\n"
+            + " EnumerableTableScan(table=[[s, emps]])";
+    sortedAggTester()
+        .query(sql)
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(plan)
+        .returnsOrdered(
+            "deptno=10; max_salary=11500.0; num_employee=3",
+            "deptno=20; max_salary=8000.0; num_employee=1");
+  }
+
+  /** Groups by two keys, one of which contains a null value. */
+  @Test void sortedAggTwoGroupKeys() {
+    final String sql = "select deptno, commission, "
+        + "max(salary) as max_salary, count(name) as num_employee "
+        + "from emps group by deptno, commission";
+    final String plan =
+        "EnumerableSortedAggregate(group=[{1, 4}], max_salary=[MAX($3)], num_employee=[COUNT($2)])\n"
+            + " EnumerableSort(sort0=[$1], sort1=[$4], dir0=[ASC], dir1=[ASC])\n"
+            + " EnumerableTableScan(table=[[s, emps]])";
+    sortedAggTester()
+        .query(sql)
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(plan)
+        .returnsOrdered(
+            "deptno=10; commission=250; max_salary=11500.0; num_employee=1",
+            "deptno=10; commission=1000; max_salary=10000.0; num_employee=1",
+            "deptno=10; commission=null; max_salary=7000.0; num_employee=1",
+            "deptno=20; commission=500; max_salary=8000.0; num_employee=1");
+  }
+
+  /** Outer sort is expected to be pushed through aggregation. */
+  @Test void sortedAggGroupbyXOrderbyX() {
+    final String sql = "select deptno, "
+        + "max(salary) as max_salary, count(name) as num_employee "
+        + "from emps group by deptno order by deptno";
+    final String plan =
+        "EnumerableSortedAggregate(group=[{1}], max_salary=[MAX($3)], num_employee=[COUNT($2)])\n"
+            + " EnumerableSort(sort0=[$1], dir0=[ASC])\n"
+            + " EnumerableTableScan(table=[[s, emps]])";
+    sortedAggTester()
+        .query(sql)
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(plan)
+        .returnsOrdered(
+            "deptno=10; max_salary=11500.0; num_employee=3",
+            "deptno=20; max_salary=8000.0; num_employee=1");
+  }
+
+  /** Outer sort is not expected to be pushed through aggregation. */
+  @Test void sortedAggGroupbyXOrderbyY() {
+    final String sql = "select deptno, "
+        + "max(salary) as max_salary, count(name) as num_employee "
+        + "from emps group by deptno order by num_employee desc";
+    final String plan =
+        "EnumerableSort(sort0=[$2], dir0=[DESC])\n"
+            + " EnumerableSortedAggregate(group=[{1}], max_salary=[MAX($3)], num_employee=[COUNT($2)])\n"
+            + " EnumerableSort(sort0=[$1], dir0=[ASC])\n"
+            + " EnumerableTableScan(table=[[s, emps]])";
+    sortedAggTester()
+        .query(sql)
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(plan)
+        .returnsOrdered(
+            "deptno=10; max_salary=11500.0; num_employee=3",
+            "deptno=20; max_salary=8000.0; num_employee=1");
+  }
+
+  /** Groups by a key that contains a null value. */
+  @Test void sortedAggNullValueInSortedGroupByKeys() {
+    final String sql = "select commission, "
+        + "count(deptno) as num_dept "
+        + "from emps group by commission";
+    final String plan =
+        "EnumerableSortedAggregate(group=[{4}], num_dept=[COUNT()])\n"
+            + " EnumerableSort(sort0=[$4], dir0=[ASC])\n"
+            + " EnumerableTableScan(table=[[s, emps]])";
+    sortedAggTester()
+        .query(sql)
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(plan)
+        .returnsOrdered(
+            "commission=250; num_dept=1",
+            "commission=500; num_dept=1",
+            "commission=1000; num_dept=1",
+            "commission=null; num_dept=1");
+  }
+
+  /**
+   * Creates a fixture with Java lexical rules, forced decorrelation switched
+   * off, and a fresh {@code JdbcTest.HrSchema} registered as schema "s".
+   */
+  private static CalciteAssert.AssertThat sortedAggTester() {
+    return CalciteAssert.that()
+        .with(CalciteConnectionProperty.LEX, Lex.JAVA)
+        .with(CalciteConnectionProperty.FORCE_DECORRELATE, false)
+        .withSchema("s", new ReflectiveSchema(new JdbcTest.HrSchema()));
+  }
+}
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
index c55e4fc..3cc4c9e 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
@@ -319,6 +319,17 @@ public abstract class DefaultEnumerable<T> implements OrderedEnumerable<T> {
accumulatorInitializer, accumulatorAdder, resultSelector, comparer);
}
+  /**
+   * Groups the elements of this enumerable, whose group keys are expected to
+   * be sorted already, by delegating to
+   * {@code EnumerableDefaults.sortedGroupBy} with {@code getThis()} as the
+   * source sequence.
+   */
+  public <TKey, TAccumulate, TResult> Enumerable<TResult> sortedGroupBy(
+      Function1<T, TKey> keySelector,
+      Function0<TAccumulate> accumulatorInitializer,
+      Function2<TAccumulate, T, TAccumulate> accumulatorAdder,
+      Function2<TKey, TAccumulate, TResult> resultSelector,
+      Comparator<TKey> comparator) {
+    final Enumerable<T> source = getThis();
+    return EnumerableDefaults.sortedGroupBy(source, keySelector,
+        accumulatorInitializer, accumulatorAdder, resultSelector, comparator);
+  }
+
public <TInner, TKey, TResult> Enumerable<TResult> groupJoin(
Enumerable<TInner> inner, Function1<T, TKey> outerKeySelector,
Function1<TInner, TKey> innerKeySelector,
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
index bd9b2ab..1d1be29 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
@@ -817,6 +817,133 @@ public abstract class EnumerableDefaults {
resultSelector);
}
+  /**
+   * Groups the elements of a sequence whose group keys are already sorted,
+   * initializing one accumulator at a time. Key values are compared by using
+   * the specified comparator.
+   *
+   * <p>Goes over elements sequentially, adding to the accumulator each time an
+   * element with the same key is seen. When the key changes, creates a result
+   * value from the accumulator and then re-initializes the accumulator.
+   *
+   * <p>In the case of NULL values in group keys, the comparator must be able to
+   * support NULL values by giving a consistent sort ordering.
+   *
+   * <p>The returned enumerable does no work up front: each call to
+   * {@code enumerator()} opens a new enumerator over {@code enumerable}.
+   *
+   * @param enumerable Source sequence, already sorted on the group key
+   * @param keySelector Function that extracts the group key from an element
+   * @param accumulatorInitializer Function that creates the accumulator for a
+   *     new group
+   * @param accumulatorAdder Function that adds an element to an accumulator
+   *     and returns the resulting accumulator
+   * @param resultSelector Function that creates a result value from a group's
+   *     key and accumulator
+   * @param comparator Comparator that decides whether two consecutive keys
+   *     belong to the same group (they do when it returns 0)
+   * @return Enumerable of one result value per run of equal keys
+   */
+  public static <TSource, TKey, TAccumulate, TResult> Enumerable<TResult> sortedGroupBy(
+      Enumerable<TSource> enumerable,
+      Function1<TSource, TKey> keySelector,
+      Function0<TAccumulate> accumulatorInitializer,
+      Function2<TAccumulate, TSource, TAccumulate> accumulatorAdder,
+      final Function2<TKey, TAccumulate, TResult> resultSelector,
+      final Comparator<TKey> comparator) {
+    return new AbstractEnumerable<TResult>() {
+      public Enumerator<TResult> enumerator() {
+        // Diamond instead of a raw type, so the type arguments are checked.
+        return new SortedAggregateEnumerator<>(
+            enumerable, keySelector, accumulatorInitializer,
+            accumulatorAdder, resultSelector, comparator);
+      }
+    };
+  }
+
+  /**
+   * Enumerator over the results of aggregating an input that is already
+   * sorted on its group key.
+   *
+   * <p>Each call to {@link #moveNext()} consumes one run of consecutive input
+   * rows whose keys compare equal and produces a single result for that run.
+   * Between calls the input enumerator is left positioned on the first row of
+   * the next run, so only one accumulator exists at a time.
+   *
+   * <p>End of input is tracked with an explicit flag rather than with
+   * {@code null} sentinels, so a {@code resultSelector} or accumulator
+   * function that returns {@code null} does not cut the enumeration short.
+   *
+   * @param <TSource> Element type of the input
+   * @param <TKey> Group key type
+   * @param <TAccumulate> Accumulator type
+   * @param <TResult> Result type
+   */
+  private static class SortedAggregateEnumerator<TSource, TKey, TAccumulate, TResult>
+      implements Enumerator<TResult> {
+    private final Function1<TSource, TKey> keySelector;
+    private final Function0<TAccumulate> accumulatorInitializer;
+    private final Function2<TAccumulate, TSource, TAccumulate> accumulatorAdder;
+    private final Function2<TKey, TAccumulate, TResult> resultSelector;
+    private final Comparator<TKey> comparator;
+    private final Enumerator<TSource> enumerator;
+    /** Whether the input enumerator has been advanced to its first row. */
+    private boolean isInitialized;
+    /** Whether every input row has been folded into an emitted result. */
+    private boolean isInputExhausted;
+    /** Whether the most recent call to {@link #moveNext()} returned false. */
+    private boolean isLastMoveNextFalse;
+    private TResult curResult;
+
+    SortedAggregateEnumerator(
+        Enumerable<TSource> enumerable,
+        Function1<TSource, TKey> keySelector,
+        Function0<TAccumulate> accumulatorInitializer,
+        Function2<TAccumulate, TSource, TAccumulate> accumulatorAdder,
+        final Function2<TKey, TAccumulate, TResult> resultSelector,
+        final Comparator<TKey> comparator) {
+      this.keySelector = keySelector;
+      this.accumulatorInitializer = accumulatorInitializer;
+      this.accumulatorAdder = accumulatorAdder;
+      this.resultSelector = resultSelector;
+      this.comparator = comparator;
+      this.enumerator = enumerable.enumerator();
+      isInitialized = false;
+      isInputExhausted = false;
+      isLastMoveNextFalse = false;
+      curResult = null;
+    }
+
+    /**
+     * Returns the result of the group produced by the last successful
+     * {@link #moveNext()}; returns null if {@code moveNext} has not been
+     * called yet.
+     *
+     * @throws NoSuchElementException if the last {@code moveNext} returned false
+     */
+    @Override public TResult current() {
+      if (isLastMoveNextFalse) {
+        throw new NoSuchElementException();
+      }
+      return curResult;
+    }
+
+    /**
+     * Aggregates the next run of rows with equal keys.
+     *
+     * @return true if a result was produced, false if the input is empty or
+     *     all groups have already been returned
+     */
+    @Override public boolean moveNext() {
+      if (!isInitialized) {
+        isInitialized = true;
+        // Position on the first input row; an empty input yields no groups.
+        isInputExhausted = !enumerator.moveNext();
+      }
+      if (isInputExhausted) {
+        isLastMoveNextFalse = true;
+        return false;
+      }
+
+      // The input is positioned on the first row of the group to aggregate.
+      TSource o = enumerator.current();
+      TKey prevKey = keySelector.apply(o);
+      TAccumulate accumulator =
+          accumulatorAdder.apply(accumulatorInitializer.apply(), o);
+
+      // Assume this is the last group until a row with a different key shows up.
+      isInputExhausted = true;
+      while (enumerator.moveNext()) {
+        o = enumerator.current();
+        final TKey curKey = keySelector.apply(o);
+        if (comparator.compare(prevKey, curKey) != 0) {
+          // Key changed: leave the input on this row so that the next call
+          // starts the following group with it.
+          isInputExhausted = false;
+          break;
+        }
+        accumulator = accumulatorAdder.apply(accumulator, o);
+        prevKey = curKey;
+      }
+
+      // Computed exactly once per group, whether or not the value is null.
+      curResult = resultSelector.apply(prevKey, accumulator);
+      return true;
+    }
+
+    /** Resets the input enumerator and discards all aggregation state. */
+    @Override public void reset() {
+      enumerator.reset();
+      isInitialized = false;
+      isInputExhausted = false;
+      isLastMoveNextFalse = false;
+      curResult = null;
+    }
+
+    /** Closes the input enumerator. */
+    @Override public void close() {
+      enumerator.close();
+    }
+  }
+
private static <TSource, TKey, TAccumulate, TResult> Enumerable<TResult> groupBy_(
final Map<TKey, TAccumulate> map, Enumerable<TSource> enumerable,
Function1<TSource, TKey> keySelector,
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
index dd0ae26..5ed7924 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
@@ -460,6 +460,23 @@ public interface ExtendedEnumerable<TSource> {
EqualityComparer<TKey> comparer);
/**
+ * Groups the elements of a sequence whose group keys are already sorted,
+ * initializing one accumulator at a time. Key values are compared by using
+ * the specified comparator.
+ *
+ * <p>Goes over elements sequentially, adding to the accumulator each time an
+ * element with the same key is seen. When the key changes, creates a result
+ * value from the accumulator and then re-initializes the accumulator.
+ *
+ * <p>In the case of NULL values in group keys, the comparator must be able to
+ * support NULL values by giving a consistent sort ordering.
+ *
+ * @param keySelector Function that extracts the group key from an element
+ * @param accumulatorInitializer Function that creates the accumulator for a
+ * new group
+ * @param accumulatorAdder Function that adds an element to an accumulator and
+ * returns the resulting accumulator
+ * @param resultSelector Function that creates a result value from a group's
+ * key and accumulator
+ * @param comparator Comparator used to compare the keys of consecutive
+ * elements
+ * @return Enumerable of the result values
+ */
+ <TKey, TAccumulate, TResult> Enumerable<TResult> sortedGroupBy(
+ Function1<TSource, TKey> keySelector,
+ Function0<TAccumulate> accumulatorInitializer,
+ Function2<TAccumulate, TSource, TAccumulate> accumulatorAdder,
+ Function2<TKey, TAccumulate, TResult> resultSelector,
+ Comparator<TKey> comparator);
+
+ /**
* Correlates the elements of two sequences based on
* equality of keys and groups the results. The default equality
* comparer is used to compare keys.