You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@groovy.apache.org by su...@apache.org on 2020/10/08 14:13:18 UTC
[groovy] branch GROOVY-8258 updated: GROOVY-8258: support groupby
This is an automated email from the ASF dual-hosted git repository.
sunlan pushed a commit to branch GROOVY-8258
in repository https://gitbox.apache.org/repos/asf/groovy.git
The following commit(s) were added to refs/heads/GROOVY-8258 by this push:
new ad5c26a GROOVY-8258: support groupby
ad5c26a is described below
commit ad5c26a21beb0aa576cdc22d3e69e6316e1fbf4e
Author: Daniel Sun <su...@apache.org>
AuthorDate: Thu Oct 8 22:12:51 2020 +0800
GROOVY-8258: support groupby
---
.../org/apache/groovy/linq/dsl/GinqAstBuilder.java | 14 ++++
.../org/apache/groovy/linq/dsl/GinqVisitor.java | 2 +
.../linq/dsl/expression/DataSourceExpression.java | 9 ++
.../linq/dsl/expression/GroupExpression.java | 44 ++++++++++
.../linq/provider/collection/GinqAstWalker.groovy | 97 ++++++++++++++++------
.../groovy/linq/provider/collection/Queryable.java | 1 +
.../provider/collection/QueryableCollection.java | 18 +++-
.../groovy/org/apache/groovy/linq/GinqTest.groovy | 11 +++
8 files changed, 170 insertions(+), 26 deletions(-)
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
index 5b9fa5d..ce3874c 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
@@ -23,6 +23,7 @@ import org.apache.groovy.linq.dsl.expression.DataSourceExpression;
import org.apache.groovy.linq.dsl.expression.FilterExpression;
import org.apache.groovy.linq.dsl.expression.FromExpression;
import org.apache.groovy.linq.dsl.expression.GinqExpression;
+import org.apache.groovy.linq.dsl.expression.GroupExpression;
import org.apache.groovy.linq.dsl.expression.JoinExpression;
import org.apache.groovy.linq.dsl.expression.OnExpression;
import org.apache.groovy.linq.dsl.expression.OrderExpression;
@@ -140,6 +141,19 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
return;
}
+ if ("groupby".equals(methodName)) {
+ GroupExpression groupExpression = new GroupExpression(call.getArguments());
+ groupExpression.setSourcePosition(call);
+
+ if (ginqExpression instanceof DataSourceExpression) {
+ ((DataSourceExpression) ginqExpression).setGroupExpression(groupExpression);
+ } else {
+ throw new GroovyBugError("The preceding expression is not a DataSourceExpression: " + ginqExpression);
+ }
+
+ return;
+ }
+
if ("orderby".equals(methodName)) {
OrderExpression orderExpression = new OrderExpression(call.getArguments());
orderExpression.setSourcePosition(call);
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqVisitor.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqVisitor.java
index 3647b31..8515b71 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqVisitor.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqVisitor.java
@@ -21,6 +21,7 @@ package org.apache.groovy.linq.dsl;
import org.apache.groovy.linq.dsl.expression.AbstractGinqExpression;
import org.apache.groovy.linq.dsl.expression.FromExpression;
import org.apache.groovy.linq.dsl.expression.GinqExpression;
+import org.apache.groovy.linq.dsl.expression.GroupExpression;
import org.apache.groovy.linq.dsl.expression.JoinExpression;
import org.apache.groovy.linq.dsl.expression.OnExpression;
import org.apache.groovy.linq.dsl.expression.OrderExpression;
@@ -39,6 +40,7 @@ public interface GinqVisitor<R> {
R visitJoinExpression(JoinExpression joinExpression);
R visitOnExpression(OnExpression onExpression);
R visitWhereExpression(WhereExpression whereExpression);
+ R visitGroupExpression(GroupExpression groupExpression);
R visitOrderExpression(OrderExpression orderExpression);
R visitSelectExpression(SelectExpression selectExpression);
R visit(AbstractGinqExpression expression);
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/DataSourceExpression.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/DataSourceExpression.java
index 71fcd1a..850372b 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/DataSourceExpression.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/DataSourceExpression.java
@@ -32,6 +32,7 @@ public abstract class DataSourceExpression extends AbstractGinqExpression {
protected Expression aliasExpr;
protected Expression dataSourceExpr;
protected final List<FilterExpression> filterExpressionList = new ArrayList<>(); // on, where
+ protected GroupExpression groupExpression;
protected OrderExpression orderExpression;
public DataSourceExpression(Expression aliasExpr, Expression dataSourceExpr) {
@@ -54,6 +55,14 @@ public abstract class DataSourceExpression extends AbstractGinqExpression {
this.filterExpressionList.add(filterExpression);
}
+ public GroupExpression getGroupExpression() { // the groupby clause of this data source, or null when the query has none
+ return groupExpression;
+ }
+
+ public void setGroupExpression(GroupExpression groupExpression) { // called by GinqAstBuilder when a `groupby` call is visited
+ this.groupExpression = groupExpression;
+ }
+
public OrderExpression getOrderExpression() {
return orderExpression;
}
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/GroupExpression.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/GroupExpression.java
new file mode 100644
index 0000000..d5d56c8
--- /dev/null
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/GroupExpression.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.groovy.linq.dsl.expression;
+
+import org.apache.groovy.linq.dsl.GinqVisitor;
+import org.codehaus.groovy.ast.expr.Expression;
+
+/**
+ * Represents the {@code groupby} clause of a GINQ expression, e.g. {@code groupby n}
+ *
+ * @since 4.0.0
+ */
+public class GroupExpression extends AbstractGinqExpression {
+ private final Expression classifierExpr; // argument(s) of the groupby call; the walker turns them into the group key
+
+ public GroupExpression(Expression classifierExpr) { // GinqAstBuilder passes call.getArguments() of the groupby call
+ this.classifierExpr = classifierExpr;
+ }
+
+ @Override
+ public <R> R accept(GinqVisitor<R> visitor) {
+ return visitor.visitGroupExpression(this);
+ }
+
+ public Expression getClassifierExpr() { // returns the raw groupby argument expression as given to the constructor
+ return classifierExpr;
+ }
+}
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
index d1e99cc..17556a7 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
@@ -28,6 +28,7 @@ import org.apache.groovy.linq.dsl.expression.DataSourceExpression
import org.apache.groovy.linq.dsl.expression.FilterExpression
import org.apache.groovy.linq.dsl.expression.FromExpression
import org.apache.groovy.linq.dsl.expression.GinqExpression
+import org.apache.groovy.linq.dsl.expression.GroupExpression
import org.apache.groovy.linq.dsl.expression.JoinExpression
import org.apache.groovy.linq.dsl.expression.OnExpression
import org.apache.groovy.linq.dsl.expression.OrderExpression
@@ -51,7 +52,6 @@ import org.codehaus.groovy.ast.expr.MethodCallExpression
import org.codehaus.groovy.ast.expr.PropertyExpression
import org.codehaus.groovy.ast.expr.TupleExpression
import org.codehaus.groovy.ast.expr.VariableExpression
-import org.codehaus.groovy.ast.tools.GeneralUtils
import org.codehaus.groovy.control.SourceUnit
import org.codehaus.groovy.syntax.Types
@@ -59,6 +59,7 @@ import java.util.stream.Collectors
import static org.codehaus.groovy.ast.tools.GeneralUtils.args
import static org.codehaus.groovy.ast.tools.GeneralUtils.callX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.ctorX
import static org.codehaus.groovy.ast.tools.GeneralUtils.lambdaX
import static org.codehaus.groovy.ast.tools.GeneralUtils.param
import static org.codehaus.groovy.ast.tools.GeneralUtils.params
@@ -116,14 +117,13 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
resultMethodCallExpression = fromMethodCallExpression
List<FilterExpression> filterExpressionList = fromExpression.getFilterExpressionList()
- OrderExpression orderExpression = fromExpression.getOrderExpression()
WhereExpression whereExpression = filterExpressionList.isEmpty() ? null : (WhereExpression) filterExpressionList.get(0)
- return decorateDataSourceMethodCallExpression(resultMethodCallExpression, fromExpression, whereExpression, orderExpression)
+ return decorateDataSourceMethodCallExpression(resultMethodCallExpression, fromExpression, whereExpression)
}
private MethodCallExpression decorateDataSourceMethodCallExpression(MethodCallExpression dataSourceMethodCallExpression,
- DataSourceExpression dataSourceExpression, WhereExpression whereExpression, OrderExpression orderExpression) {
+ DataSourceExpression dataSourceExpression, WhereExpression whereExpression) {
if (whereExpression) {
whereExpression.putNodeMetaData(__DATA_SOURCE_EXPRESSION, dataSourceExpression)
whereExpression.putNodeMetaData(__METHOD_CALL_RECEIVER, dataSourceMethodCallExpression)
@@ -132,7 +132,16 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
dataSourceMethodCallExpression = whereMethodCallExpression
}
+ GroupExpression groupExpression = dataSourceExpression.groupExpression
+ if (groupExpression) {
+ groupExpression.putNodeMetaData(__DATA_SOURCE_EXPRESSION, dataSourceExpression)
+ groupExpression.putNodeMetaData(__METHOD_CALL_RECEIVER, dataSourceMethodCallExpression)
+ MethodCallExpression groupMethodCallExpression = visitGroupExpression(groupExpression)
+ dataSourceMethodCallExpression = groupMethodCallExpression
+ }
+
+ OrderExpression orderExpression = dataSourceExpression.orderExpression
if (orderExpression) {
orderExpression.putNodeMetaData(__DATA_SOURCE_EXPRESSION, dataSourceExpression)
orderExpression.putNodeMetaData(__METHOD_CALL_RECEIVER, dataSourceMethodCallExpression)
@@ -223,9 +232,7 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
}
}
- OrderExpression orderExpression = joinExpression.orderExpression
-
- return decorateDataSourceMethodCallExpression(resultMethodCallExpression, joinExpression, whereExpression, orderExpression)
+ return decorateDataSourceMethodCallExpression(resultMethodCallExpression, joinExpression, whereExpression)
}
@Override
@@ -238,6 +245,18 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
}
@Override
+ MethodCallExpression visitGroupExpression(GroupExpression groupExpression) { // emits `<receiver>.groupBy2(<lambda returning a NamedList group key>)`
+ DataSourceExpression dataSourceExpression = groupExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION) // attached by decorateDataSourceMethodCallExpression
+ Expression groupMethodCallReceiver = groupExpression.getNodeMetaData(__METHOD_CALL_RECEIVER) // the data-source call chain built so far (including any where call)
+ Expression classifierExpr = groupExpression.classifierExpr
+
+ List<Expression> argumentExpressionList = ((ArgumentListExpression) classifierExpr).getExpressions() // NOTE(review): assumes groupby arguments are always an ArgumentListExpression — confirm for other argument shapes
+ ConstructorCallExpression namedListCtorCallExpression = constructNamedListCtorCallExpression(argumentExpressionList) // group key: a NamedList built from the groupby arguments
+
+ return callXWithLambda(groupMethodCallReceiver, "groupBy2", dataSourceExpression, namedListCtorCallExpression)
+ }
+
+ @Override
MethodCallExpression visitOrderExpression(OrderExpression orderExpression) {
DataSourceExpression dataSourceExpression = orderExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
Expression orderMethodCallReceiver = orderExpression.getNodeMetaData(__METHOD_CALL_RECEIVER)
@@ -254,7 +273,7 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, target)
- return GeneralUtils.ctorX(ClassHelper.make(Queryable.Order.class), args(lambdaExpression, new ConstantExpression(asc)))
+ return ctorX(ClassHelper.make(Queryable.Order.class), args(lambdaExpression, new ConstantExpression(asc)))
}).collect(Collectors.toList())
return callX(orderMethodCallReceiver, "orderBy", args(orderCtorCallExpressions))
@@ -302,29 +321,50 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
nameExpressionList << nameExpression
}
- ConstructorCallExpression namedListCtorCallExpression = GeneralUtils.ctorX(ClassHelper.make(NamedList.class), args(new ListExpression(elementExpressionList), new ListExpression(nameExpressionList)))
- namedListCtorCallExpression
+ ConstructorCallExpression namedListCtorCallExpression = ctorX(ClassHelper.make(NamedList.class), args(new ListExpression(elementExpressionList), new ListExpression(nameExpressionList)))
+ return namedListCtorCallExpression
}
- private static Expression correctVariablesOfGinqExpression(JoinExpression joinExpression, Expression expr) {
- DataSourceExpression dataSourceExpression = joinExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
- final Expression firstAliasExpr = dataSourceExpression.aliasExpr
- final Expression secondAliasExpr = joinExpression.aliasExpr
+ private static Expression correctVariablesOfGinqExpression(DataSourceExpression dataSourceExpression, Expression expr) {
+ boolean isJoin = dataSourceExpression instanceof JoinExpression
+
+ DataSourceExpression otherDataSourceExpression = dataSourceExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
+ final Expression firstAliasExpr = null == otherDataSourceExpression ? EmptyExpression.INSTANCE : otherDataSourceExpression.aliasExpr
+ final Expression secondAliasExpr = dataSourceExpression.aliasExpr
def correctVars = { Expression expression ->
if (expression instanceof VariableExpression) {
Expression transformedExpression = null
- if (firstAliasExpr.text == expression.text) {
- // replace `n1` with `__t.v1`
- transformedExpression = constructFirstAliasVariableAccess()
- } else if (secondAliasExpr.text == expression.text) {
- // replace `n2` with `__t.v2`
- transformedExpression = constructSecondAliasVariableAccess()
+ if (isJoin) {
+ if (firstAliasExpr.text == expression.text) {
+ // replace `n1` with `__t.v1`
+ transformedExpression = constructFirstAliasVariableAccess()
+ } else if (secondAliasExpr.text == expression.text) {
+ // replace `n2` with `__t.v2`
+ transformedExpression = constructSecondAliasVariableAccess()
+ }
+ } else { // groupby
+ // in #1, we will correct receiver of built-in aggregate functions
+ // the correct receiver is `__t.v2`, so we should not replace `__t` here
+ if (__T != expression.text) {
+ // replace `gk` in the groupby with `__t.v1.gk`, note: __t.v1 stores the group key
+ transformedExpression = propX(constructFirstAliasVariableAccess(), expression.text)
+ }
}
-
if (null != transformedExpression) {
return transformedExpression
}
+ } else if (expression instanceof MethodCallExpression) {
+ // #1
+ if (!isJoin) { // groupby
+ if (expression.implicitThis) {
+ String methodName = expression.methodAsString
+ if ('count' == methodName && ((TupleExpression) expression.arguments).getExpressions().isEmpty()) {
+ expression.objectExpression = constructSecondAliasVariableAccess()
+ return expression
+ }
+ }
+ }
}
return expression
@@ -354,16 +394,25 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
}
private static MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, Expression lambdaCode) {
- LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode)
+ LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode, receiver)
callXWithLambda(receiver, methodName, lambdaExpression)
}
private static LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+ constructLambdaExpression(dataSourceExpression, lambdaCode, null)
+ }
+
+ private static LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode, Expression receiver) {
+ boolean isGroup = false
+ if (receiver instanceof MethodCallExpression && receiver.methodAsString.startsWith('groupBy')) {
+ isGroup = true
+ }
+
String lambdaParamName
- if (dataSourceExpression instanceof JoinExpression) {
+ if (dataSourceExpression instanceof JoinExpression || isGroup) {
lambdaParamName = __T
- lambdaCode = correctVariablesOfGinqExpression((JoinExpression) dataSourceExpression, lambdaCode)
+ lambdaCode = correctVariablesOfGinqExpression(dataSourceExpression, lambdaCode)
} else {
lambdaParamName = dataSourceExpression.aliasExpr.text
}
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java
index a99fa9f..39fa431 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java
@@ -98,6 +98,7 @@ public interface Queryable<T> {
// Built-in aggregate functions {
int count();
BigDecimal sum(Function<? super T, BigDecimal> mapper);
+ <R> R agg(Function<? super Queryable<? extends T>, ? extends R> mapper);
// } Built-in aggregate functions
class Order<T, U extends Comparable<? super U>> {
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java
index b5df13a..c2d4c85 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java
@@ -117,6 +117,15 @@ class QueryableCollection<T> implements Queryable<T>, Iterable<T> {
return from(stream);
}
+
+ public <K> Queryable<Tuple2<K, Queryable<T>>> groupBy2(Function<? super T, ? extends K> classifier, BiPredicate<? super K, ? super Queryable<? extends T>> having) { // pure delegate to groupBy(classifier, having)
+ return this.groupBy(classifier, having);
+ }
+
+ public <K> Queryable<Tuple2<K, Queryable<T>>> groupBy2(Function<? super T, ? extends K> classifier) { // pure delegate to groupBy(classifier); GinqAstWalker emits groupBy2 for `groupby` (NOTE(review): reason for the separate name is not evident here)
+ return this.groupBy(classifier);
+ }
+
@Override
public <U extends Comparable<? super U>> Queryable<T> orderBy(Order<? super T, ? extends U>... orders) {
Comparator<T> comparator = null;
@@ -190,12 +199,17 @@ class QueryableCollection<T> implements Queryable<T>, Iterable<T> {
@Override
public int count() {
- return toList().size();
+ return agg(q -> q.toList().size()); // number of elements, computed through agg
}
@Override
public BigDecimal sum(Function<? super T, BigDecimal> mapper) {
- return this.stream().map(mapper).reduce(BigDecimal.ZERO, BigDecimal::add);
+ return agg(q -> this.stream().map(mapper).reduce(BigDecimal.ZERO, BigDecimal::add)); // NOTE(review): lambda ignores q and streams this; equivalent only because agg applies mapper to this
+ }
+
+ @Override
+ public <R> R agg(Function<? super Queryable<? extends T>, ? extends R> mapper) { // applies mapper to this whole queryable; count() and sum() are routed through here
+ return mapper.apply(this);
}
private static <T, U> Queryable<Tuple2<T, U>> outerJoin(Queryable<? extends T> queryable1, Queryable<? extends U> queryable2, BiPredicate<? super T, ? super U> joiner) {
diff --git a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
index 5ee2349..67f0be3 100644
--- a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
+++ b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
@@ -1238,6 +1238,17 @@ class GinqTest {
'''
}
+ @Test
+ void "testGinq - from groupBy select - 1"() { // groupby on a single key, selecting the key plus the built-in count() aggregate per group
+ assertScript '''
+ assert [[1, 2], [3, 2], [6, 3]] == GINQ {
+ from n in [1, 1, 3, 3, 6, 6, 6]
+ groupby n
+ select n, count() // reference the column `n` in the groupby clause, and `count()` is a built-in aggregate function
+ }.toList()
+ '''
+ }
+
@CompileDynamic
@Test
void "testGinq - query json - 1"() {