You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this plain-text rendering.
Posted to commits@calcite.apache.org by ru...@apache.org on 2020/06/29 08:36:13 UTC

[calcite] branch master updated: [CALCITE-4008] Implement Code generation for EnumerableSortedAggregate (Rui Wang).

This is an automated email from the ASF dual-hosted git repository.

rubenql pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new bf9ff00  [CALCITE-4008] Implement Code generation for EnumerableSortedAggregate (Rui Wang).
bf9ff00 is described below

commit bf9ff001db743bcba35943daf7fec5fe8b8b207e
Author: amaliujia <am...@163.com>
AuthorDate: Thu Jun 18 11:09:17 2020 -0700

    [CALCITE-4008] Implement Code generation for EnumerableSortedAggregate (Rui Wang).
---
 .../adapter/enumerable/EnumerableAggregate.java    | 266 +----------------
 .../enumerable/EnumerableAggregateBase.java        | 330 +++++++++++++++++++++
 .../enumerable/EnumerableSortedAggregate.java      | 143 ++++++++-
 .../org/apache/calcite/util/BuiltInMethod.java     |   2 +
 .../enumerable/EnumerableSortedAggregateTest.java  | 142 +++++++++
 .../apache/calcite/linq4j/DefaultEnumerable.java   |  11 +
 .../apache/calcite/linq4j/EnumerableDefaults.java  | 127 ++++++++
 .../apache/calcite/linq4j/ExtendedEnumerable.java  |  17 ++
 8 files changed, 775 insertions(+), 263 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregate.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregate.java
index a0dee72..03f3192 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregate.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregate.java
@@ -16,11 +16,8 @@
  */
 package org.apache.calcite.adapter.enumerable;
 
-import org.apache.calcite.adapter.enumerable.impl.AggAddContextImpl;
 import org.apache.calcite.adapter.enumerable.impl.AggResultContextImpl;
 import org.apache.calcite.adapter.java.JavaTypeFactory;
-import org.apache.calcite.config.CalciteSystemProperty;
-import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
 import org.apache.calcite.linq4j.Ord;
 import org.apache.calcite.linq4j.function.Function0;
 import org.apache.calcite.linq4j.function.Function1;
@@ -29,35 +26,23 @@ import org.apache.calcite.linq4j.tree.BlockBuilder;
 import org.apache.calcite.linq4j.tree.Expression;
 import org.apache.calcite.linq4j.tree.Expressions;
 import org.apache.calcite.linq4j.tree.ParameterExpression;
-import org.apache.calcite.linq4j.tree.Types;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.InvalidRelException;
-import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.util.BuiltInMethod;
 import org.apache.calcite.util.ImmutableBitSet;
-import org.apache.calcite.util.Pair;
-import org.apache.calcite.util.Util;
 
 import com.google.common.collect.ImmutableList;
 
 import java.lang.reflect.Type;
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.LinkedList;
 import java.util.List;
 
 /** Implementation of {@link org.apache.calcite.rel.core.Aggregate} in
  * {@link org.apache.calcite.adapter.enumerable.EnumerableConvention enumerable calling convention}. */
-public class EnumerableAggregate extends Aggregate implements EnumerableRel {
+public class EnumerableAggregate extends EnumerableAggregateBase implements EnumerableRel {
   public EnumerableAggregate(
       RelOptCluster cluster,
       RelTraitSet traitSet,
@@ -202,37 +187,8 @@ public class EnumerableAggregate extends Aggregate implements EnumerableRel {
     final List<Expression> initExpressions = new ArrayList<>();
     final BlockBuilder initBlock = new BlockBuilder();
 
-    final List<Type> aggStateTypes = new ArrayList<>();
-    for (final AggImpState agg : aggs) {
-      agg.context = new AggContextImpl(agg, typeFactory);
-      final List<Type> state = agg.implementor.getStateType(agg.context);
-
-      if (state.isEmpty()) {
-        agg.state = ImmutableList.of();
-        continue;
-      }
-
-      aggStateTypes.addAll(state);
-
-      final List<Expression> decls = new ArrayList<>(state.size());
-      for (int i = 0; i < state.size(); i++) {
-        String aggName = "a" + agg.aggIdx;
-        if (CalciteSystemProperty.DEBUG.value()) {
-          aggName = Util.toJavaId(agg.call.getAggregation().getName(), 0)
-              .substring("ID$0$".length()) + aggName;
-        }
-        Type type = state.get(i);
-        ParameterExpression pe =
-            Expressions.parameter(type,
-                initBlock.newName(aggName + "s" + i));
-        initBlock.add(Expressions.declare(0, pe, null));
-        decls.add(pe);
-      }
-      agg.state = decls;
-      initExpressions.addAll(decls);
-      agg.implementor.implementReset(agg.context,
-          new AggResultContextImpl(initBlock, agg.call, decls, null, null));
-    }
+    final List<Type> aggStateTypes = createAggStateTypes(
+        initExpressions, initBlock, aggs, typeFactory);
 
     final PhysType accPhysType =
         PhysTypeImpl.of(typeFactory,
@@ -258,55 +214,9 @@ public class EnumerableAggregate extends Aggregate implements EnumerableRel {
         Expressions.parameter(inputPhysType.getJavaRowType(), "in");
     final ParameterExpression acc_ =
         Expressions.parameter(accPhysType.getJavaRowType(), "acc");
-    for (int i = 0, stateOffset = 0; i < aggs.size(); i++) {
-      final BlockBuilder builder2 = new BlockBuilder();
-      final AggImpState agg = aggs.get(i);
 
-      final int stateSize = agg.state.size();
-      final List<Expression> accumulator = new ArrayList<>(stateSize);
-      for (int j = 0; j < stateSize; j++) {
-        accumulator.add(accPhysType.fieldReference(acc_, j + stateOffset));
-      }
-      agg.state = accumulator;
-
-      stateOffset += stateSize;
-
-      AggAddContext addContext =
-          new AggAddContextImpl(builder2, accumulator) {
-            public List<RexNode> rexArguments() {
-              List<RelDataTypeField> inputTypes =
-                  inputPhysType.getRowType().getFieldList();
-              List<RexNode> args = new ArrayList<>();
-              for (int index : agg.call.getArgList()) {
-                args.add(RexInputRef.of(index, inputTypes));
-              }
-              return args;
-            }
-
-            public RexNode rexFilterArgument() {
-              return agg.call.filterArg < 0
-                  ? null
-                  : RexInputRef.of(agg.call.filterArg,
-                      inputPhysType.getRowType());
-            }
-
-            public RexToLixTranslator rowTranslator() {
-              return RexToLixTranslator.forAggregation(typeFactory,
-                  currentBlock(),
-                  new RexToLixTranslator.InputGetterImpl(
-                      Collections.singletonList(
-                          Pair.of(inParameter, inputPhysType))),
-                  implementor.getConformance())
-                  .setNullable(currentNullables());
-            }
-          };
-
-      agg.implementor.implementAdd(agg.context, addContext);
-      builder2.add(acc_);
-      agg.accumulatorAdder = builder.append("accumulatorAdder",
-          Expressions.lambda(Function2.class, builder2.toBlock(), acc_,
-              inParameter));
-    }
+    createAccumulatorAdders(
+        inParameter, aggs, accPhysType, acc_, inputPhysType, builder, implementor, typeFactory);
 
     final ParameterExpression lambdaFactory =
         Expressions.parameter(AggregateLambdaFactory.class,
@@ -443,170 +353,4 @@ public class EnumerableAggregate extends Aggregate implements EnumerableRel {
     }
     return implementor.result(physType, builder.toBlock());
   }
-
-  private static boolean hasOrderedCall(List<AggImpState> aggs) {
-    for (AggImpState agg : aggs) {
-      if (!agg.call.collation.equals(RelCollations.EMPTY)) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  private void declareParentAccumulator(List<Expression> initExpressions,
-      BlockBuilder initBlock, PhysType accPhysType) {
-    if (accPhysType.getJavaRowType()
-        instanceof JavaTypeFactoryImpl.SyntheticRecordType) {
-      // We have to initialize the SyntheticRecordType instance this way, to
-      // avoid using a class constructor with too many parameters.
-      final JavaTypeFactoryImpl.SyntheticRecordType synType =
-          (JavaTypeFactoryImpl.SyntheticRecordType)
-          accPhysType.getJavaRowType();
-      final ParameterExpression record0_ =
-          Expressions.parameter(accPhysType.getJavaRowType(), "record0");
-      initBlock.add(Expressions.declare(0, record0_, null));
-      initBlock.add(
-          Expressions.statement(
-              Expressions.assign(record0_,
-                  Expressions.new_(accPhysType.getJavaRowType()))));
-      List<Types.RecordField> fieldList = synType.getRecordFields();
-      for (int i = 0; i < initExpressions.size(); i++) {
-        Expression right = initExpressions.get(i);
-        initBlock.add(
-            Expressions.statement(
-                Expressions.assign(
-                    Expressions.field(record0_, fieldList.get(i)), right)));
-      }
-      initBlock.add(record0_);
-    } else {
-      initBlock.add(accPhysType.record(initExpressions));
-    }
-  }
-
-  /**
-   * Implements the {@link AggregateLambdaFactory}.
-   *
-   * <p>Behavior depends upon ordering:
-   * <ul>
-   *
-   * <li>{@code hasOrderedCall == true} means there is at least one aggregate
-   * call including sort spec. We use {@link LazyAggregateLambdaFactory}
-   * implementation to implement sorted aggregates for that.
-   *
-   * <li>{@code hasOrderedCall == false} indicates to use
-   * {@link BasicAggregateLambdaFactory} to implement a non-sort
-   * aggregate.
-   *
-   * </ul>
-   */
-  private void implementLambdaFactory(BlockBuilder builder,
-      PhysType inputPhysType,
-      List<AggImpState> aggs,
-      Expression accumulatorInitializer,
-      boolean hasOrderedCall,
-      ParameterExpression lambdaFactory) {
-    if (hasOrderedCall) {
-      ParameterExpression pe = Expressions.parameter(List.class,
-          builder.newName("lazyAccumulators"));
-      builder.add(
-          Expressions.declare(0, pe, Expressions.new_(LinkedList.class)));
-
-      for (AggImpState agg : aggs) {
-        if (agg.call.collation.equals(RelCollations.EMPTY)) {
-          // if the call does not require ordering, fallback to
-          // use a non-sorted lazy accumulator.
-          builder.add(
-              Expressions.statement(
-                  Expressions.call(pe,
-                      BuiltInMethod.COLLECTION_ADD.method,
-                      Expressions.new_(BuiltInMethod.BASIC_LAZY_ACCUMULATOR.constructor,
-                          agg.accumulatorAdder))));
-          continue;
-        }
-        final Pair<Expression, Expression> pair =
-            inputPhysType.generateCollationKey(
-                agg.call.collation.getFieldCollations());
-        builder.add(
-            Expressions.statement(
-                Expressions.call(pe,
-                    BuiltInMethod.COLLECTION_ADD.method,
-                    Expressions.new_(BuiltInMethod.SOURCE_SORTER.constructor,
-                        agg.accumulatorAdder, pair.left, pair.right))));
-      }
-      builder.add(
-          Expressions.declare(0, lambdaFactory,
-              Expressions.new_(
-                  BuiltInMethod.LAZY_AGGREGATE_LAMBDA_FACTORY.constructor,
-                  accumulatorInitializer, pe)));
-    } else {
-      // when hasOrderedCall == false
-      ParameterExpression pe = Expressions.parameter(List.class,
-          builder.newName("accumulatorAdders"));
-      builder.add(
-          Expressions.declare(0, pe, Expressions.new_(LinkedList.class)));
-
-      for (AggImpState agg : aggs) {
-        builder.add(
-            Expressions.statement(
-                Expressions.call(pe, BuiltInMethod.COLLECTION_ADD.method,
-                    agg.accumulatorAdder)));
-      }
-      builder.add(
-          Expressions.declare(0, lambdaFactory,
-              Expressions.new_(
-                  BuiltInMethod.BASIC_AGGREGATE_LAMBDA_FACTORY.constructor,
-                  accumulatorInitializer, pe)));
-    }
-  }
-
-  /** An implementation of {@link AggContext}. */
-  private class AggContextImpl implements AggContext {
-    private final AggImpState agg;
-    private final JavaTypeFactory typeFactory;
-
-    AggContextImpl(AggImpState agg, JavaTypeFactory typeFactory) {
-      this.agg = agg;
-      this.typeFactory = typeFactory;
-    }
-
-    public SqlAggFunction aggregation() {
-      return agg.call.getAggregation();
-    }
-
-    public RelDataType returnRelType() {
-      return agg.call.type;
-    }
-
-    public Type returnType() {
-      return EnumUtils.javaClass(typeFactory, returnRelType());
-    }
-
-    public List<? extends RelDataType> parameterRelTypes() {
-      return EnumUtils.fieldRowTypes(getInput().getRowType(), null,
-          agg.call.getArgList());
-    }
-
-    public List<? extends Type> parameterTypes() {
-      return EnumUtils.fieldTypes(
-          typeFactory,
-          parameterRelTypes());
-    }
-
-    public List<ImmutableBitSet> groupSets() {
-      return groupSets;
-    }
-
-    public List<Integer> keyOrdinals() {
-      return groupSet.asList();
-    }
-
-    public List<? extends RelDataType> keyRelTypes() {
-      return EnumUtils.fieldRowTypes(getInput().getRowType(), null,
-          groupSet.asList());
-    }
-
-    public List<? extends Type> keyTypes() {
-      return EnumUtils.fieldTypes(typeFactory, keyRelTypes());
-    }
-  }
 }
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregateBase.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregateBase.java
new file mode 100644
index 0000000..d8233a4
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableAggregateBase.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.adapter.enumerable;
+
+import org.apache.calcite.adapter.enumerable.impl.AggAddContextImpl;
+import org.apache.calcite.adapter.enumerable.impl.AggResultContextImpl;
+import org.apache.calcite.adapter.java.JavaTypeFactory;
+import org.apache.calcite.config.CalciteSystemProperty;
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
+import org.apache.calcite.linq4j.function.Function2;
+import org.apache.calcite.linq4j.tree.BlockBuilder;
+import org.apache.calcite.linq4j.tree.Expression;
+import org.apache.calcite.linq4j.tree.Expressions;
+import org.apache.calcite.linq4j.tree.ParameterExpression;
+import org.apache.calcite.linq4j.tree.Types;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.hint.RelHint;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.util.BuiltInMethod;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Pair;
+import org.apache.calcite.util.Util;
+
+import com.google.common.collect.ImmutableList;
+
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+
+/** Base class for EnumerableAggregate and EnumerableSortedAggregate. */
+public abstract class EnumerableAggregateBase extends Aggregate {
+  protected EnumerableAggregateBase(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      List<RelHint> hints,
+      RelNode input,
+      ImmutableBitSet groupSet,
+      List<ImmutableBitSet> groupSets,
+      List<AggregateCall> aggCalls) {
+    super(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls);
+  }
+
+  protected static boolean hasOrderedCall(List<AggImpState> aggs) {
+    for (AggImpState agg : aggs) {
+      if (!agg.call.collation.equals(RelCollations.EMPTY)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  protected void declareParentAccumulator(List<Expression> initExpressions,
+      BlockBuilder initBlock, PhysType accPhysType) {
+    if (accPhysType.getJavaRowType()
+        instanceof JavaTypeFactoryImpl.SyntheticRecordType) {
+      // We have to initialize the SyntheticRecordType instance this way, to
+      // avoid using a class constructor with too many parameters.
+      final JavaTypeFactoryImpl.SyntheticRecordType synType =
+          (JavaTypeFactoryImpl.SyntheticRecordType)
+              accPhysType.getJavaRowType();
+      final ParameterExpression record0_ =
+          Expressions.parameter(accPhysType.getJavaRowType(), "record0");
+      initBlock.add(Expressions.declare(0, record0_, null));
+      initBlock.add(
+          Expressions.statement(
+              Expressions.assign(record0_,
+                  Expressions.new_(accPhysType.getJavaRowType()))));
+      List<Types.RecordField> fieldList = synType.getRecordFields();
+      for (int i = 0; i < initExpressions.size(); i++) {
+        Expression right = initExpressions.get(i);
+        initBlock.add(
+            Expressions.statement(
+                Expressions.assign(
+                    Expressions.field(record0_, fieldList.get(i)), right)));
+      }
+      initBlock.add(record0_);
+    } else {
+      initBlock.add(accPhysType.record(initExpressions));
+    }
+  }
+
+  /**
+   * Implements the {@link AggregateLambdaFactory}.
+   *
+   * <p>Behavior depends upon ordering:
+   * <ul>
+   *
+   * <li>{@code hasOrderedCall == true} means there is at least one aggregate
+   * call including sort spec. We use {@link LazyAggregateLambdaFactory}
+   * implementation to implement sorted aggregates for that.
+   *
+   * <li>{@code hasOrderedCall == false} indicates to use
+   * {@link BasicAggregateLambdaFactory} to implement a non-sort
+   * aggregate.
+   *
+   * </ul>
+   */
+  protected void implementLambdaFactory(BlockBuilder builder,
+      PhysType inputPhysType, List<AggImpState> aggs,
+      Expression accumulatorInitializer, boolean hasOrderedCall,
+      ParameterExpression lambdaFactory) {
+    if (hasOrderedCall) {
+      ParameterExpression pe = Expressions.parameter(List.class,
+          builder.newName("lazyAccumulators"));
+      builder.add(
+          Expressions.declare(0, pe, Expressions.new_(LinkedList.class)));
+
+      for (AggImpState agg : aggs) {
+        if (agg.call.collation.equals(RelCollations.EMPTY)) {
+          // if the call does not require ordering, fallback to
+          // use a non-sorted lazy accumulator.
+          builder.add(
+              Expressions.statement(
+                  Expressions.call(pe,
+                      BuiltInMethod.COLLECTION_ADD.method,
+                      Expressions.new_(BuiltInMethod.BASIC_LAZY_ACCUMULATOR.constructor,
+                          agg.accumulatorAdder))));
+          continue;
+        }
+        final Pair<Expression, Expression> pair =
+            inputPhysType.generateCollationKey(
+                agg.call.collation.getFieldCollations());
+        builder.add(
+            Expressions.statement(
+                Expressions.call(pe,
+                    BuiltInMethod.COLLECTION_ADD.method,
+                    Expressions.new_(BuiltInMethod.SOURCE_SORTER.constructor,
+                        agg.accumulatorAdder, pair.left, pair.right))));
+      }
+      builder.add(
+          Expressions.declare(0, lambdaFactory,
+              Expressions.new_(
+                  BuiltInMethod.LAZY_AGGREGATE_LAMBDA_FACTORY.constructor,
+                  accumulatorInitializer, pe)));
+    } else {
+      // when hasOrderedCall == false
+      ParameterExpression pe = Expressions.parameter(List.class,
+          builder.newName("accumulatorAdders"));
+      builder.add(
+          Expressions.declare(0, pe, Expressions.new_(LinkedList.class)));
+
+      for (AggImpState agg : aggs) {
+        builder.add(
+            Expressions.statement(
+                Expressions.call(pe, BuiltInMethod.COLLECTION_ADD.method,
+                    agg.accumulatorAdder)));
+      }
+      builder.add(
+          Expressions.declare(0, lambdaFactory,
+              Expressions.new_(
+                  BuiltInMethod.BASIC_AGGREGATE_LAMBDA_FACTORY.constructor,
+                  accumulatorInitializer, pe)));
+    }
+  }
+
+  /** An implementation of {@link AggContext}. */
+  protected class AggContextImpl implements AggContext {
+    private final AggImpState agg;
+    private final JavaTypeFactory typeFactory;
+
+    AggContextImpl(AggImpState agg, JavaTypeFactory typeFactory) {
+      this.agg = agg;
+      this.typeFactory = typeFactory;
+    }
+
+    public SqlAggFunction aggregation() {
+      return agg.call.getAggregation();
+    }
+
+    public RelDataType returnRelType() {
+      return agg.call.type;
+    }
+
+    public Type returnType() {
+      return EnumUtils.javaClass(typeFactory, returnRelType());
+    }
+
+    public List<? extends RelDataType> parameterRelTypes() {
+      return EnumUtils.fieldRowTypes(getInput().getRowType(), null,
+          agg.call.getArgList());
+    }
+
+    public List<? extends Type> parameterTypes() {
+      return EnumUtils.fieldTypes(
+          typeFactory,
+          parameterRelTypes());
+    }
+
+    public List<ImmutableBitSet> groupSets() {
+      return groupSets;
+    }
+
+    public List<Integer> keyOrdinals() {
+      return groupSet.asList();
+    }
+
+    public List<? extends RelDataType> keyRelTypes() {
+      return EnumUtils.fieldRowTypes(getInput().getRowType(), null,
+          groupSet.asList());
+    }
+
+    public List<? extends Type> keyTypes() {
+      return EnumUtils.fieldTypes(typeFactory, keyRelTypes());
+    }
+  }
+
+  protected void createAccumulatorAdders(
+      final ParameterExpression inParameter,
+      final List<AggImpState> aggs,
+      final PhysType accPhysType,
+      final ParameterExpression accExpr,
+      final PhysType inputPhysType,
+      final BlockBuilder builder,
+      EnumerableRelImplementor implementor,
+      JavaTypeFactory typeFactory) {
+    for (int i = 0, stateOffset = 0; i < aggs.size(); i++) {
+      final BlockBuilder builder2 = new BlockBuilder();
+      final AggImpState agg = aggs.get(i);
+
+      final int stateSize = agg.state.size();
+      final List<Expression> accumulator = new ArrayList<>(stateSize);
+      for (int j = 0; j < stateSize; j++) {
+        accumulator.add(accPhysType.fieldReference(accExpr, j + stateOffset));
+      }
+      agg.state = accumulator;
+
+      stateOffset += stateSize;
+
+      AggAddContext addContext =
+          new AggAddContextImpl(builder2, accumulator) {
+            public List<RexNode> rexArguments() {
+              List<RelDataTypeField> inputTypes =
+                  inputPhysType.getRowType().getFieldList();
+              List<RexNode> args = new ArrayList<>();
+              for (int index : agg.call.getArgList()) {
+                args.add(RexInputRef.of(index, inputTypes));
+              }
+              return args;
+            }
+
+            public RexNode rexFilterArgument() {
+              return agg.call.filterArg < 0
+                  ? null
+                  : RexInputRef.of(agg.call.filterArg,
+                      inputPhysType.getRowType());
+            }
+
+            public RexToLixTranslator rowTranslator() {
+              return RexToLixTranslator.forAggregation(typeFactory,
+                  currentBlock(),
+                  new RexToLixTranslator.InputGetterImpl(
+                      Collections.singletonList(
+                          Pair.of(inParameter, inputPhysType))),
+                  implementor.getConformance())
+                  .setNullable(currentNullables());
+            }
+          };
+
+      agg.implementor.implementAdd(agg.context, addContext);
+      builder2.add(accExpr);
+      agg.accumulatorAdder = builder.append("accumulatorAdder",
+          Expressions.lambda(Function2.class, builder2.toBlock(), accExpr,
+              inParameter));
+    }
+  }
+
+  protected List<Type> createAggStateTypes(
+      final List<Expression> initExpressions,
+      final BlockBuilder initBlock,
+      final List<AggImpState> aggs,
+      JavaTypeFactory typeFactory) {
+    final List<Type> aggStateTypes = new ArrayList<>();
+    for (final AggImpState agg : aggs) {
+      agg.context = new AggContextImpl(agg, typeFactory);
+      final List<Type> state = agg.implementor.getStateType(agg.context);
+
+      if (state.isEmpty()) {
+        agg.state = ImmutableList.of();
+        continue;
+      }
+
+      aggStateTypes.addAll(state);
+
+      final List<Expression> decls = new ArrayList<>(state.size());
+      for (int i = 0; i < state.size(); i++) {
+        String aggName = "a" + agg.aggIdx;
+        if (CalciteSystemProperty.DEBUG.value()) {
+          aggName = Util.toJavaId(agg.call.getAggregation().getName(), 0)
+              .substring("ID$0$".length()) + aggName;
+        }
+        Type type = state.get(i);
+        ParameterExpression pe =
+            Expressions.parameter(type,
+                initBlock.newName(aggName + "s" + i));
+        initBlock.add(Expressions.declare(0, pe, null));
+        decls.add(pe);
+      }
+      agg.state = decls;
+      initExpressions.addAll(decls);
+      agg.implementor.implementReset(agg.context,
+          new AggResultContextImpl(initBlock, agg.call, decls, null, null));
+    }
+    return aggStateTypes;
+  }
+}
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java
index 00ca78a..1bb5fd3 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableSortedAggregate.java
@@ -16,6 +16,15 @@
  */
 package org.apache.calcite.adapter.enumerable;
 
+import org.apache.calcite.adapter.enumerable.impl.AggResultContextImpl;
+import org.apache.calcite.adapter.java.JavaTypeFactory;
+import org.apache.calcite.linq4j.Ord;
+import org.apache.calcite.linq4j.function.Function0;
+import org.apache.calcite.linq4j.function.Function2;
+import org.apache.calcite.linq4j.tree.BlockBuilder;
+import org.apache.calcite.linq4j.tree.Expression;
+import org.apache.calcite.linq4j.tree.Expressions;
+import org.apache.calcite.linq4j.tree.ParameterExpression;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelCollation;
@@ -25,6 +34,7 @@ import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.util.BuiltInMethod;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.Pair;
 import org.apache.calcite.util.Util;
@@ -32,12 +42,13 @@ import org.apache.calcite.util.mapping.Mappings;
 
 import com.google.common.collect.ImmutableList;
 
+import java.lang.reflect.Type;
 import java.util.ArrayList;
 import java.util.List;
 
 /** Sort based physical implementation of {@link Aggregate} in
  * {@link EnumerableConvention enumerable calling convention}. */
-public class EnumerableSortedAggregate extends Aggregate implements EnumerableRel {
+public class EnumerableSortedAggregate extends EnumerableAggregateBase implements EnumerableRel {
   public EnumerableSortedAggregate(
       RelOptCluster cluster,
       RelTraitSet traitSet,
@@ -90,6 +101,134 @@ public class EnumerableSortedAggregate extends Aggregate implements EnumerableRe
   }
 
   public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
-    throw Util.needToImplement("EnumerableSortedAggregate");
+    if (!Aggregate.isSimple(this)) {
+      throw Util.needToImplement("EnumerableSortedAggregate");
+    }
+
+    final JavaTypeFactory typeFactory = implementor.getTypeFactory();
+    final BlockBuilder builder = new BlockBuilder();
+    final EnumerableRel child = (EnumerableRel) getInput();
+    final Result result = implementor.visitChild(this, 0, child, pref);
+    Expression childExp =
+        builder.append(
+            "child",
+            result.block);
+
+    final PhysType physType =
+        PhysTypeImpl.of(
+            typeFactory, getRowType(), pref.preferCustom());
+
+    final PhysType inputPhysType = result.physType;
+
+    ParameterExpression parameter =
+        Expressions.parameter(inputPhysType.getJavaRowType(), "a0");
+
+    final PhysType keyPhysType =
+        inputPhysType.project(groupSet.asList(), getGroupType() != Group.SIMPLE,
+            JavaRowFormat.LIST);
+    final int groupCount = getGroupCount();
+
+    final List<AggImpState> aggs = new ArrayList<>(aggCalls.size());
+    for (Ord<AggregateCall> call : Ord.zip(aggCalls)) {
+      aggs.add(new AggImpState(call.i, call.e, false));
+    }
+
+    // Function0<Object[]> accumulatorInitializer =
+    //     new Function0<Object[]>() {
+    //         public Object[] apply() {
+    //             return new Object[] {0, 0};
+    //         }
+    //     };
+    final List<Expression> initExpressions = new ArrayList<>();
+    final BlockBuilder initBlock = new BlockBuilder();
+
+    final List<Type> aggStateTypes = createAggStateTypes(
+        initExpressions, initBlock, aggs, typeFactory);
+
+    final PhysType accPhysType =
+        PhysTypeImpl.of(typeFactory,
+            typeFactory.createSyntheticType(aggStateTypes));
+
+    declareParentAccumulator(initExpressions, initBlock, accPhysType);
+
+    final Expression accumulatorInitializer =
+        builder.append("accumulatorInitializer",
+            Expressions.lambda(
+                Function0.class,
+                initBlock.toBlock()));
+
+    // Function2<Object[], Employee, Object[]> accumulatorAdder =
+    //     new Function2<Object[], Employee, Object[]>() {
+    //         public Object[] apply(Object[] acc, Employee in) {
+    //              acc[0] = ((Integer) acc[0]) + 1;
+    //              acc[1] = ((Integer) acc[1]) + in.salary;
+    //             return acc;
+    //         }
+    //     };
+    final ParameterExpression inParameter =
+        Expressions.parameter(inputPhysType.getJavaRowType(), "in");
+    final ParameterExpression acc_ =
+        Expressions.parameter(accPhysType.getJavaRowType(), "acc");
+
+    createAccumulatorAdders(
+        inParameter, aggs, accPhysType, acc_, inputPhysType, builder, implementor, typeFactory);
+
+    final ParameterExpression lambdaFactory =
+        Expressions.parameter(AggregateLambdaFactory.class,
+            builder.newName("lambdaFactory"));
+
+    implementLambdaFactory(builder, inputPhysType, aggs, accumulatorInitializer,
+        false, lambdaFactory);
+
+    final BlockBuilder resultBlock = new BlockBuilder();
+    final List<Expression> results = Expressions.list();
+    final ParameterExpression key_;
+    final Type keyType = keyPhysType.getJavaRowType();
+    key_ = Expressions.parameter(keyType, "key");
+    for (int j = 0; j < groupCount; j++) {
+      final Expression ref = keyPhysType.fieldReference(key_, j);
+      results.add(ref);
+    }
+
+    for (final AggImpState agg : aggs) {
+      results.add(
+          agg.implementor.implementResult(agg.context,
+              new AggResultContextImpl(resultBlock, agg.call, agg.state, key_,
+                  keyPhysType)));
+    }
+    resultBlock.add(physType.record(results));
+
+    final Expression keySelector_ =
+        builder.append("keySelector",
+            inputPhysType.generateSelector(parameter,
+                groupSet.asList(),
+                keyPhysType.getFormat()));
+    // Generate the appropriate key Comparator. In the case of NULL values
+    // in group keys, the comparator must be able to support NULL values by giving a
+    // consistent sort ordering.
+    final Expression comparator = keyPhysType.generateComparator(getTraitSet().getCollation());
+
+    final Expression resultSelector_ =
+        builder.append("resultSelector",
+            Expressions.lambda(Function2.class,
+                resultBlock.toBlock(),
+                key_,
+                acc_));
+
+    builder.add(
+        Expressions.return_(null,
+            Expressions.call(childExp,
+                BuiltInMethod.SORTED_GROUP_BY.method,
+                Expressions.list(keySelector_,
+                    Expressions.call(lambdaFactory,
+                        BuiltInMethod.AGG_LAMBDA_FACTORY_ACC_INITIALIZER.method),
+                    Expressions.call(lambdaFactory,
+                        BuiltInMethod.AGG_LAMBDA_FACTORY_ACC_ADDER.method),
+                    Expressions.call(lambdaFactory,
+                        BuiltInMethod.AGG_LAMBDA_FACTORY_ACC_RESULT_SELECTOR.method,
+                        resultSelector_), comparator)
+                    )));
+
+    return implementor.result(physType, builder.toBlock());
   }
 }
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index b7e6176..4f92c5f 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -212,6 +212,8 @@ public enum BuiltInMethod {
   WHERE2(ExtendedEnumerable.class, "where", Predicate2.class),
   DISTINCT(ExtendedEnumerable.class, "distinct"),
   DISTINCT2(ExtendedEnumerable.class, "distinct", EqualityComparer.class),
+  SORTED_GROUP_BY(ExtendedEnumerable.class, "sortedGroupBy", Function1.class,
+      Function0.class, Function2.class, Function2.class, Comparator.class),
   GROUP_BY(ExtendedEnumerable.class, "groupBy", Function1.class),
   GROUP_BY2(ExtendedEnumerable.class, "groupBy", Function1.class,
       Function0.class, Function2.class, Function2.class),
diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableSortedAggregateTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableSortedAggregateTest.java
new file mode 100644
index 0000000..225aa28
--- /dev/null
+++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableSortedAggregateTest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.test.enumerable;
+
+import org.apache.calcite.adapter.enumerable.EnumerableRules;
+import org.apache.calcite.adapter.java.ReflectiveSchema;
+import org.apache.calcite.config.CalciteConnectionProperty;
+import org.apache.calcite.config.Lex;
+import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.runtime.Hook;
+import org.apache.calcite.test.CalciteAssert;
+import org.apache.calcite.test.JdbcTest;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.function.Consumer;
+
+/** Tests for {@code EnumerableSortedAggregate}. */
+public class EnumerableSortedAggregateTest {
+  /** Planner hook that swaps {@code ENUMERABLE_AGGREGATE_RULE} for
+   * {@code ENUMERABLE_SORTED_AGGREGATE_RULE}, so aggregates are implemented
+   * by {@code EnumerableSortedAggregate}. */
+  private static final Consumer<RelOptPlanner> USE_SORTED_AGGREGATE = planner -> {
+    planner.removeRule(EnumerableRules.ENUMERABLE_AGGREGATE_RULE);
+    planner.addRule(EnumerableRules.ENUMERABLE_SORTED_AGGREGATE_RULE);
+  };
+
+  @Test void sortedAgg() {
+    tester(false, new JdbcTest.HrSchema())
+        .query(
+            "select deptno, "
+            + "max(salary) as max_salary, count(name) as num_employee "
+            + "from emps group by deptno")
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(
+            "EnumerableSortedAggregate(group=[{1}], max_salary=[MAX($3)], num_employee=[COUNT($2)])\n"
+            + "  EnumerableSort(sort0=[$1], dir0=[ASC])\n"
+            + "    EnumerableTableScan(table=[[s, emps]])")
+        .returnsOrdered(
+            "deptno=10; max_salary=11500.0; num_employee=3",
+            "deptno=20; max_salary=8000.0; num_employee=1");
+  }
+
+  @Test void sortedAggTwoGroupKeys() {
+    tester(false, new JdbcTest.HrSchema())
+        .query(
+            "select deptno, commission, "
+            + "max(salary) as max_salary, count(name) as num_employee "
+            + "from emps group by deptno, commission")
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(
+            "EnumerableSortedAggregate(group=[{1, 4}], max_salary=[MAX($3)], num_employee=[COUNT($2)])\n"
+            + "  EnumerableSort(sort0=[$1], sort1=[$4], dir0=[ASC], dir1=[ASC])\n"
+            + "    EnumerableTableScan(table=[[s, emps]])")
+        .returnsOrdered(
+            "deptno=10; commission=250; max_salary=11500.0; num_employee=1",
+            "deptno=10; commission=1000; max_salary=10000.0; num_employee=1",
+            "deptno=10; commission=null; max_salary=7000.0; num_employee=1",
+            "deptno=20; commission=500; max_salary=8000.0; num_employee=1");
+  }
+
+  /** ORDER BY on the group key: the outer sort is expected to be pushed
+   * through the aggregate onto its input. */
+  @Test void sortedAggGroupbyXOrderbyX() {
+    tester(false, new JdbcTest.HrSchema())
+        .query(
+            "select deptno, "
+            + "max(salary) as max_salary, count(name) as num_employee "
+            + "from emps group by deptno order by deptno")
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(
+            "EnumerableSortedAggregate(group=[{1}], max_salary=[MAX($3)], num_employee=[COUNT($2)])\n"
+            + "  EnumerableSort(sort0=[$1], dir0=[ASC])\n"
+            + "    EnumerableTableScan(table=[[s, emps]])")
+        .returnsOrdered(
+            "deptno=10; max_salary=11500.0; num_employee=3",
+            "deptno=20; max_salary=8000.0; num_employee=1");
+  }
+
+  /** ORDER BY an aggregate result: the outer sort is not expected to be
+   * pushed through the aggregate, so it stays above it. */
+  @Test void sortedAggGroupbyXOrderbyY() {
+    tester(false, new JdbcTest.HrSchema())
+        .query(
+            "select deptno, "
+            + "max(salary) as max_salary, count(name) as num_employee "
+            + "from emps group by deptno order by num_employee desc")
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(
+            "EnumerableSort(sort0=[$2], dir0=[DESC])\n"
+            + "  EnumerableSortedAggregate(group=[{1}], max_salary=[MAX($3)], num_employee=[COUNT($2)])\n"
+            + "    EnumerableSort(sort0=[$1], dir0=[ASC])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])")
+        .returnsOrdered(
+            "deptno=10; max_salary=11500.0; num_employee=3",
+            "deptno=20; max_salary=8000.0; num_employee=1");
+  }
+
+  /** A NULL group key forms its own group; the expected output lists it last. */
+  @Test void sortedAggNullValueInSortedGroupByKeys() {
+    tester(false, new JdbcTest.HrSchema())
+        .query(
+            "select commission, "
+            + "count(deptno) as num_dept "
+            + "from emps group by commission")
+        .withHook(Hook.PLANNER, USE_SORTED_AGGREGATE)
+        .explainContains(
+            "EnumerableSortedAggregate(group=[{4}], num_dept=[COUNT()])\n"
+            + "  EnumerableSort(sort0=[$4], dir0=[ASC])\n"
+            + "    EnumerableTableScan(table=[[s, emps]])")
+        .returnsOrdered(
+            "commission=250; num_dept=1",
+            "commission=500; num_dept=1",
+            "commission=1000; num_dept=1",
+            "commission=null; num_dept=1");
+  }
+
+  /** Returns a fixture that runs queries with Java lexical conventions
+   * against {@code schema}, exposed as schema "s". */
+  private CalciteAssert.AssertThat tester(boolean forceDecorrelate, Object schema) {
+    return CalciteAssert.that()
+        .with(CalciteConnectionProperty.LEX, Lex.JAVA)
+        .with(CalciteConnectionProperty.FORCE_DECORRELATE, forceDecorrelate)
+        .withSchema("s", new ReflectiveSchema(schema));
+  }
+}
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
index c55e4fc..3cc4c9e 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
@@ -319,6 +319,17 @@ public abstract class DefaultEnumerable<T> implements OrderedEnumerable<T> {
         accumulatorInitializer, accumulatorAdder, resultSelector, comparer);
   }
 
+  /** Delegates to {@link EnumerableDefaults#sortedGroupBy}, using
+   * {@link #getThis()} as the source sequence. */
+  public <TKey, TAccumulate, TResult> Enumerable<TResult> sortedGroupBy(
+      Function1<T, TKey> keySelector, Function0<TAccumulate> accumulatorInitializer,
+      Function2<TAccumulate, T, TAccumulate> accumulatorAdder,
+      Function2<TKey, TAccumulate, TResult> resultSelector, Comparator<TKey> comparator) {
+    return EnumerableDefaults.sortedGroupBy(getThis(),
+        keySelector, accumulatorInitializer, accumulatorAdder,
+        resultSelector, comparator);
+  }
+
   public <TInner, TKey, TResult> Enumerable<TResult> groupJoin(
       Enumerable<TInner> inner, Function1<T, TKey> outerKeySelector,
       Function1<TInner, TKey> innerKeySelector,
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
index bd9b2ab..1d1be29 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
@@ -817,6 +817,133 @@ public abstract class EnumerableDefaults {
         resultSelector);
   }
 
+  /**
+   * Groups the elements of a sequence whose group keys are already sorted,
+   * using a single accumulator at a time.
+   *
+   * <p>Elements are visited sequentially; each one is added to the current
+   * accumulator while its key compares equal (per {@code comparator}) to the
+   * previous key. When the key changes, a result value is created from the
+   * accumulator and a fresh accumulator is initialized for the new group.
+   * If group keys may contain NULL values, the comparator must handle them
+   * and give a consistent sort ordering.
+   *
+   * @param enumerable Source sequence, already sorted by group key
+   * @param keySelector Extracts the group key from an element
+   * @param accumulatorInitializer Creates an empty accumulator for a group
+   * @param accumulatorAdder Adds an element to an accumulator
+   * @param resultSelector Creates a result from a group key and its accumulator
+   * @param comparator Compares group keys; zero means "same group"
+   * @return Sequence with one result per group, in input order
+   */
+  public static <TSource, TKey, TAccumulate, TResult> Enumerable<TResult> sortedGroupBy(
+      Enumerable<TSource> enumerable,
+      Function1<TSource, TKey> keySelector,
+      Function0<TAccumulate> accumulatorInitializer,
+      Function2<TAccumulate, TSource, TAccumulate> accumulatorAdder,
+      final Function2<TKey, TAccumulate, TResult> resultSelector,
+      final Comparator<TKey> comparator) {
+    return new AbstractEnumerable<TResult>() {
+      public Enumerator<TResult> enumerator() {
+        return new SortedAggregateEnumerator<>(
+            enumerable, keySelector, accumulatorInitializer,
+            accumulatorAdder, resultSelector, comparator);
+      }
+    };
+  }
+
+  /**
+   * Enumerator for {@link #sortedGroupBy}: emits one result per run of
+   * consecutive source elements whose keys compare equal.
+   *
+   * <p>Progress is tracked with explicit flags rather than {@code null}
+   * sentinels, so accumulators and results may themselves be null.
+   */
+  private static class SortedAggregateEnumerator<TSource, TKey, TAccumulate, TResult>
+      implements Enumerator<TResult> {
+    private final Function1<TSource, TKey> keySelector;
+    private final Function0<TAccumulate> accumulatorInitializer;
+    private final Function2<TAccumulate, TSource, TAccumulate> accumulatorAdder;
+    private final Function2<TKey, TAccumulate, TResult> resultSelector;
+    private final Comparator<TKey> comparator;
+    private final Enumerator<TSource> enumerator;
+    /** Whether {@link #moveNext} has been called since creation or reset. */
+    private boolean isInitialized;
+    /** Whether the source enumerator has run out of elements. */
+    private boolean isInputExhausted;
+    /** Whether the last call to {@link #moveNext} returned false. */
+    private boolean isLastMoveNextFalse;
+    private TResult curResult;
+
+    SortedAggregateEnumerator(
+        Enumerable<TSource> enumerable,
+        Function1<TSource, TKey> keySelector,
+        Function0<TAccumulate> accumulatorInitializer,
+        Function2<TAccumulate, TSource, TAccumulate> accumulatorAdder,
+        final Function2<TKey, TAccumulate, TResult> resultSelector,
+        final Comparator<TKey> comparator) {
+      this.keySelector = keySelector;
+      this.accumulatorInitializer = accumulatorInitializer;
+      this.accumulatorAdder = accumulatorAdder;
+      this.resultSelector = resultSelector;
+      this.comparator = comparator;
+      this.enumerator = enumerable.enumerator();
+    }
+
+    @Override public TResult current() {
+      if (isLastMoveNextFalse) {
+        throw new NoSuchElementException();
+      }
+      return curResult;
+    }
+
+    @Override public boolean moveNext() {
+      if (!isInitialized) {
+        isInitialized = true;
+        // Position the source on its first element; an empty source has no groups.
+        isInputExhausted = !enumerator.moveNext();
+      }
+      if (isInputExhausted) {
+        isLastMoveNextFalse = true;
+        return false;
+      }
+
+      // The source is positioned on the first element of the next group.
+      TSource o = enumerator.current();
+      TKey key = keySelector.apply(o);
+      TAccumulate accumulator =
+          accumulatorAdder.apply(accumulatorInitializer.apply(), o);
+      while (true) {
+        if (!enumerator.moveNext()) {
+          isInputExhausted = true;
+          break;
+        }
+        o = enumerator.current();
+        final TKey curKey = keySelector.apply(o);
+        if (comparator.compare(key, curKey) != 0) {
+          // This element starts the next group; leave the source on it.
+          break;
+        }
+        accumulator = accumulatorAdder.apply(accumulator, o);
+        key = curKey;
+      }
+      curResult = resultSelector.apply(key, accumulator);
+      return true;
+    }
+
+    @Override public void reset() {
+      enumerator.reset();
+      isInitialized = false;
+      isInputExhausted = false;
+      isLastMoveNextFalse = false;
+      curResult = null;
+    }
+
+    @Override public void close() {
+      enumerator.close();
+    }
+  }
+
   private static <TSource, TKey, TAccumulate, TResult> Enumerable<TResult> groupBy_(
       final Map<TKey, TAccumulate> map, Enumerable<TSource> enumerable,
       Function1<TSource, TKey> keySelector,
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
index dd0ae26..5ed7924 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
@@ -460,6 +460,23 @@ public interface ExtendedEnumerable<TSource> {
       EqualityComparer<TKey> comparer);
 
   /**
+   * Groups the elements of a sequence whose group keys are already sorted,
+   * using a single accumulator at a time. Elements are visited in order and
+   * added to the current accumulator while their key compares equal (per
+   * {@code comparator}) to the previous key; when the key changes, a result
+   * value is created from the accumulator and a new accumulator is
+   * initialized. The input must be sorted consistently with
+   * {@code comparator}; if group keys may contain NULL values, the
+   * comparator must handle them and give a consistent sort ordering.
+   */
+  <TKey, TAccumulate, TResult> Enumerable<TResult> sortedGroupBy(
+      Function1<TSource, TKey> keySelector,
+      Function0<TAccumulate> accumulatorInitializer,
+      Function2<TAccumulate, TSource, TAccumulate> accumulatorAdder,
+      Function2<TKey, TAccumulate, TResult> resultSelector,
+      Comparator<TKey> comparator);
+
+  /**
    * Correlates the elements of two sequences based on
    * equality of keys and groups the results. The default equality
    * comparer is used to compare keys.