You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@groovy.apache.org by su...@apache.org on 2020/10/08 14:13:18 UTC

[groovy] branch GROOVY-8258 updated: GROOVY-8258: support groupby

This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-8258
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/GROOVY-8258 by this push:
     new ad5c26a  GROOVY-8258: support groupby
ad5c26a is described below

commit ad5c26a21beb0aa576cdc22d3e69e6316e1fbf4e
Author: Daniel Sun <su...@apache.org>
AuthorDate: Thu Oct 8 22:12:51 2020 +0800

    GROOVY-8258: support groupby
---
 .../org/apache/groovy/linq/dsl/GinqAstBuilder.java | 14 ++++
 .../org/apache/groovy/linq/dsl/GinqVisitor.java    |  2 +
 .../linq/dsl/expression/DataSourceExpression.java  |  9 ++
 .../linq/dsl/expression/GroupExpression.java       | 44 ++++++++++
 .../linq/provider/collection/GinqAstWalker.groovy  | 97 ++++++++++++++++------
 .../groovy/linq/provider/collection/Queryable.java |  1 +
 .../provider/collection/QueryableCollection.java   | 18 +++-
 .../groovy/org/apache/groovy/linq/GinqTest.groovy  | 11 +++
 8 files changed, 170 insertions(+), 26 deletions(-)

diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
index 5b9fa5d..ce3874c 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
@@ -23,6 +23,7 @@ import org.apache.groovy.linq.dsl.expression.DataSourceExpression;
 import org.apache.groovy.linq.dsl.expression.FilterExpression;
 import org.apache.groovy.linq.dsl.expression.FromExpression;
 import org.apache.groovy.linq.dsl.expression.GinqExpression;
+import org.apache.groovy.linq.dsl.expression.GroupExpression;
 import org.apache.groovy.linq.dsl.expression.JoinExpression;
 import org.apache.groovy.linq.dsl.expression.OnExpression;
 import org.apache.groovy.linq.dsl.expression.OrderExpression;
@@ -140,6 +141,19 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
             return;
         }
 
+        if ("groupby".equals(methodName)) {
+            GroupExpression groupExpression = new GroupExpression(call.getArguments());
+            groupExpression.setSourcePosition(call);
+
+            if (ginqExpression instanceof DataSourceExpression) {
+                ((DataSourceExpression) ginqExpression).setGroupExpression(groupExpression);
+            } else {
+                throw new GroovyBugError("The preceding expression is not a DataSourceExpression: " + ginqExpression);
+            }
+
+            return;
+        }
+
         if ("orderby".equals(methodName)) {
             OrderExpression orderExpression = new OrderExpression(call.getArguments());
             orderExpression.setSourcePosition(call);
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqVisitor.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqVisitor.java
index 3647b31..8515b71 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqVisitor.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqVisitor.java
@@ -21,6 +21,7 @@ package org.apache.groovy.linq.dsl;
 import org.apache.groovy.linq.dsl.expression.AbstractGinqExpression;
 import org.apache.groovy.linq.dsl.expression.FromExpression;
 import org.apache.groovy.linq.dsl.expression.GinqExpression;
+import org.apache.groovy.linq.dsl.expression.GroupExpression;
 import org.apache.groovy.linq.dsl.expression.JoinExpression;
 import org.apache.groovy.linq.dsl.expression.OnExpression;
 import org.apache.groovy.linq.dsl.expression.OrderExpression;
@@ -39,6 +40,7 @@ public interface GinqVisitor<R> {
     R visitJoinExpression(JoinExpression joinExpression);
     R visitOnExpression(OnExpression onExpression);
     R visitWhereExpression(WhereExpression whereExpression);
+    R visitGroupExpression(GroupExpression groupExpression);
     R visitOrderExpression(OrderExpression orderExpression);
     R visitSelectExpression(SelectExpression selectExpression);
     R visit(AbstractGinqExpression expression);
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/DataSourceExpression.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/DataSourceExpression.java
index 71fcd1a..850372b 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/DataSourceExpression.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/DataSourceExpression.java
@@ -32,6 +32,7 @@ public abstract class DataSourceExpression extends AbstractGinqExpression {
     protected Expression aliasExpr;
     protected Expression dataSourceExpr;
     protected final List<FilterExpression> filterExpressionList = new ArrayList<>(); // on, where
+    protected GroupExpression groupExpression;
     protected OrderExpression orderExpression;
 
     public DataSourceExpression(Expression aliasExpr, Expression dataSourceExpr) {
@@ -54,6 +55,14 @@ public abstract class DataSourceExpression extends AbstractGinqExpression {
         this.filterExpressionList.add(filterExpression);
     }
 
+    public GroupExpression getGroupExpression() {
+        return groupExpression;
+    }
+
+    public void setGroupExpression(GroupExpression groupExpression) {
+        this.groupExpression = groupExpression;
+    }
+
     public OrderExpression getOrderExpression() {
         return orderExpression;
     }
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/GroupExpression.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/GroupExpression.java
new file mode 100644
index 0000000..d5d56c8
--- /dev/null
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/expression/GroupExpression.java
@@ -0,0 +1,44 @@
+/*
+ *  Licensed to the Apache Software Foundation (ASF) under one
+ *  or more contributor license agreements.  See the NOTICE file
+ *  distributed with this work for additional information
+ *  regarding copyright ownership.  The ASF licenses this file
+ *  to you under the Apache License, Version 2.0 (the
+ *  "License"); you may not use this file except in compliance
+ *  with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing,
+ *  software distributed under the License is distributed on an
+ *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ *  KIND, either express or implied.  See the License for the
+ *  specific language governing permissions and limitations
+ *  under the License.
+ */
+package org.apache.groovy.linq.dsl.expression;
+
+import org.apache.groovy.linq.dsl.GinqVisitor;
+import org.codehaus.groovy.ast.expr.Expression;
+
+/**
+ * Represents the {@code groupby} clause of a GINQ expression
+ *
+ * @since 4.0.0
+ */
+public class GroupExpression extends AbstractGinqExpression {
+    private final Expression classifierExpr;
+
+    public GroupExpression(Expression classifierExpr) {
+        this.classifierExpr = classifierExpr;
+    }
+
+    @Override
+    public <R> R accept(GinqVisitor<R> visitor) {
+        return visitor.visitGroupExpression(this);
+    }
+
+    public Expression getClassifierExpr() {
+        return classifierExpr;
+    }
+}
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
index d1e99cc..17556a7 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
@@ -28,6 +28,7 @@ import org.apache.groovy.linq.dsl.expression.DataSourceExpression
 import org.apache.groovy.linq.dsl.expression.FilterExpression
 import org.apache.groovy.linq.dsl.expression.FromExpression
 import org.apache.groovy.linq.dsl.expression.GinqExpression
+import org.apache.groovy.linq.dsl.expression.GroupExpression
 import org.apache.groovy.linq.dsl.expression.JoinExpression
 import org.apache.groovy.linq.dsl.expression.OnExpression
 import org.apache.groovy.linq.dsl.expression.OrderExpression
@@ -51,7 +52,6 @@ import org.codehaus.groovy.ast.expr.MethodCallExpression
 import org.codehaus.groovy.ast.expr.PropertyExpression
 import org.codehaus.groovy.ast.expr.TupleExpression
 import org.codehaus.groovy.ast.expr.VariableExpression
-import org.codehaus.groovy.ast.tools.GeneralUtils
 import org.codehaus.groovy.control.SourceUnit
 import org.codehaus.groovy.syntax.Types
 
@@ -59,6 +59,7 @@ import java.util.stream.Collectors
 
 import static org.codehaus.groovy.ast.tools.GeneralUtils.args
 import static org.codehaus.groovy.ast.tools.GeneralUtils.callX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.ctorX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.lambdaX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.param
 import static org.codehaus.groovy.ast.tools.GeneralUtils.params
@@ -116,14 +117,13 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
         resultMethodCallExpression = fromMethodCallExpression
 
         List<FilterExpression> filterExpressionList = fromExpression.getFilterExpressionList()
-        OrderExpression orderExpression = fromExpression.getOrderExpression()
         WhereExpression whereExpression = filterExpressionList.isEmpty() ? null : (WhereExpression) filterExpressionList.get(0)
 
-        return decorateDataSourceMethodCallExpression(resultMethodCallExpression, fromExpression, whereExpression, orderExpression)
+        return decorateDataSourceMethodCallExpression(resultMethodCallExpression, fromExpression, whereExpression)
     }
 
     private MethodCallExpression decorateDataSourceMethodCallExpression(MethodCallExpression dataSourceMethodCallExpression,
-                                                                        DataSourceExpression dataSourceExpression, WhereExpression whereExpression, OrderExpression orderExpression) {
+                                                                        DataSourceExpression dataSourceExpression, WhereExpression whereExpression) {
         if (whereExpression) {
             whereExpression.putNodeMetaData(__DATA_SOURCE_EXPRESSION, dataSourceExpression)
             whereExpression.putNodeMetaData(__METHOD_CALL_RECEIVER, dataSourceMethodCallExpression)
@@ -132,7 +132,16 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
             dataSourceMethodCallExpression = whereMethodCallExpression
         }
 
+        GroupExpression groupExpression = dataSourceExpression.groupExpression
+        if (groupExpression) {
+            groupExpression.putNodeMetaData(__DATA_SOURCE_EXPRESSION, dataSourceExpression)
+            groupExpression.putNodeMetaData(__METHOD_CALL_RECEIVER, dataSourceMethodCallExpression)
 
+            MethodCallExpression groupMethodCallExpression = visitGroupExpression(groupExpression)
+            dataSourceMethodCallExpression = groupMethodCallExpression
+        }
+
+        OrderExpression orderExpression = dataSourceExpression.orderExpression
         if (orderExpression) {
             orderExpression.putNodeMetaData(__DATA_SOURCE_EXPRESSION, dataSourceExpression)
             orderExpression.putNodeMetaData(__METHOD_CALL_RECEIVER, dataSourceMethodCallExpression)
@@ -223,9 +232,7 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
             }
         }
 
-        OrderExpression orderExpression = joinExpression.orderExpression
-
-        return decorateDataSourceMethodCallExpression(resultMethodCallExpression, joinExpression, whereExpression, orderExpression)
+        return decorateDataSourceMethodCallExpression(resultMethodCallExpression, joinExpression, whereExpression)
     }
 
     @Override
@@ -238,6 +245,18 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
     }
 
     @Override
+    MethodCallExpression visitGroupExpression(GroupExpression groupExpression) {
+        DataSourceExpression dataSourceExpression = groupExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
+        Expression groupMethodCallReceiver = groupExpression.getNodeMetaData(__METHOD_CALL_RECEIVER)
+        Expression classifierExpr = groupExpression.classifierExpr
+
+        List<Expression> argumentExpressionList = ((ArgumentListExpression) classifierExpr).getExpressions()
+        ConstructorCallExpression namedListCtorCallExpression = constructNamedListCtorCallExpression(argumentExpressionList)
+
+        return callXWithLambda(groupMethodCallReceiver, "groupBy2", dataSourceExpression, namedListCtorCallExpression)
+    }
+
+    @Override
     MethodCallExpression visitOrderExpression(OrderExpression orderExpression) {
         DataSourceExpression dataSourceExpression = orderExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
         Expression orderMethodCallReceiver = orderExpression.getNodeMetaData(__METHOD_CALL_RECEIVER)
@@ -254,7 +273,7 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
 
             LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, target)
 
-            return GeneralUtils.ctorX(ClassHelper.make(Queryable.Order.class), args(lambdaExpression, new ConstantExpression(asc)))
+            return ctorX(ClassHelper.make(Queryable.Order.class), args(lambdaExpression, new ConstantExpression(asc)))
         }).collect(Collectors.toList())
 
         return callX(orderMethodCallReceiver, "orderBy", args(orderCtorCallExpressions))
@@ -302,29 +321,50 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
             nameExpressionList << nameExpression
         }
 
-        ConstructorCallExpression namedListCtorCallExpression = GeneralUtils.ctorX(ClassHelper.make(NamedList.class), args(new ListExpression(elementExpressionList), new ListExpression(nameExpressionList)))
-        namedListCtorCallExpression
+        ConstructorCallExpression namedListCtorCallExpression = ctorX(ClassHelper.make(NamedList.class), args(new ListExpression(elementExpressionList), new ListExpression(nameExpressionList)))
+        return namedListCtorCallExpression
     }
 
-    private static Expression correctVariablesOfGinqExpression(JoinExpression joinExpression, Expression expr) {
-        DataSourceExpression dataSourceExpression = joinExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
-        final Expression firstAliasExpr = dataSourceExpression.aliasExpr
-        final Expression secondAliasExpr = joinExpression.aliasExpr
+    private static Expression correctVariablesOfGinqExpression(DataSourceExpression dataSourceExpression, Expression expr) {
+        boolean isJoin = dataSourceExpression instanceof JoinExpression
+
+        DataSourceExpression otherDataSourceExpression = dataSourceExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
+        final Expression firstAliasExpr = null == otherDataSourceExpression ? EmptyExpression.INSTANCE : otherDataSourceExpression.aliasExpr
+        final Expression secondAliasExpr = dataSourceExpression.aliasExpr
 
         def correctVars = { Expression expression ->
             if (expression instanceof VariableExpression) {
                 Expression transformedExpression = null
-                if (firstAliasExpr.text == expression.text) {
-                    // replace `n1` with `__t.v1`
-                    transformedExpression = constructFirstAliasVariableAccess()
-                } else if (secondAliasExpr.text == expression.text) {
-                    // replace `n2` with `__t.v2`
-                    transformedExpression = constructSecondAliasVariableAccess()
+                if (isJoin) {
+                    if (firstAliasExpr.text == expression.text) {
+                        // replace `n1` with `__t.v1`
+                        transformedExpression = constructFirstAliasVariableAccess()
+                    } else if (secondAliasExpr.text == expression.text) {
+                        // replace `n2` with `__t.v2`
+                        transformedExpression = constructSecondAliasVariableAccess()
+                    }
+                } else { //  groupby
+                    // the receiver of built-in aggregate functions is corrected at #1 below;
+                    // that receiver is `__t.v2`, so `__t` itself must not be replaced here
+                    if (__T != expression.text) {
+                        // replace `gk` in the groupby with `__t.v1.gk`, note: __t.v1 stores the group key
+                        transformedExpression = propX(constructFirstAliasVariableAccess(), expression.text)
+                    }
                 }
-
                 if (null != transformedExpression) {
                     return transformedExpression
                 }
+            } else if (expression instanceof MethodCallExpression) {
+                // #1
+                if (!isJoin) { // groupby
+                    if (expression.implicitThis) {
+                        String methodName = expression.methodAsString
+                        if ('count' == methodName && ((TupleExpression) expression.arguments).getExpressions().isEmpty()) {
+                            expression.objectExpression = constructSecondAliasVariableAccess()
+                            return expression
+                        }
+                    }
+                }
             }
 
             return expression
@@ -354,16 +394,25 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
     }
 
     private static MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, Expression lambdaCode) {
-        LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode)
+        LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode, receiver)
 
         callXWithLambda(receiver, methodName, lambdaExpression)
     }
 
     private static LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+        constructLambdaExpression(dataSourceExpression, lambdaCode, null)
+    }
+
+    private static LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode, Expression receiver) {
+        boolean isGroup = false
+        if (receiver instanceof MethodCallExpression && receiver.methodAsString.startsWith('groupBy')) {
+            isGroup = true
+        }
+
         String lambdaParamName
-        if (dataSourceExpression instanceof JoinExpression) {
+        if (dataSourceExpression instanceof JoinExpression || isGroup) {
             lambdaParamName = __T
-            lambdaCode = correctVariablesOfGinqExpression((JoinExpression) dataSourceExpression, lambdaCode)
+            lambdaCode = correctVariablesOfGinqExpression(dataSourceExpression, lambdaCode)
         } else {
             lambdaParamName = dataSourceExpression.aliasExpr.text
         }
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java
index a99fa9f..39fa431 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java
@@ -98,6 +98,7 @@ public interface Queryable<T> {
     //  Built-in aggregate functions {
     int count();
     BigDecimal sum(Function<? super T, BigDecimal> mapper);
+    <R> R agg(Function<? super Queryable<? extends T>, ? extends R> mapper);
     // } Built-in aggregate functions
 
     class Order<T, U extends Comparable<? super U>> {
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java
index b5df13a..c2d4c85 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java
@@ -117,6 +117,15 @@ class QueryableCollection<T> implements Queryable<T>, Iterable<T> {
         return from(stream);
     }
 
+
+    public <K> Queryable<Tuple2<K, Queryable<T>>> groupBy2(Function<? super T, ? extends K> classifier, BiPredicate<? super K, ? super Queryable<? extends T>> having) {
+        return this.groupBy(classifier, having);
+    }
+
+    public <K> Queryable<Tuple2<K, Queryable<T>>> groupBy2(Function<? super T, ? extends K> classifier) {
+        return this.groupBy(classifier);
+    }
+
     @Override
     public <U extends Comparable<? super U>> Queryable<T> orderBy(Order<? super T, ? extends U>... orders) {
         Comparator<T> comparator = null;
@@ -190,12 +199,17 @@ class QueryableCollection<T> implements Queryable<T>, Iterable<T> {
 
     @Override
     public int count() {
-        return toList().size();
+        return agg(q -> q.toList().size());
     }
 
     @Override
     public BigDecimal sum(Function<? super T, BigDecimal> mapper) {
-        return this.stream().map(mapper).reduce(BigDecimal.ZERO, BigDecimal::add);
+        return agg(q -> this.stream().map(mapper).reduce(BigDecimal.ZERO, BigDecimal::add));
+    }
+
+    @Override
+    public <R> R agg(Function<? super Queryable<? extends T>, ? extends R> mapper) {
+        return mapper.apply(this);
     }
 
     private static <T, U> Queryable<Tuple2<T, U>> outerJoin(Queryable<? extends T> queryable1, Queryable<? extends U> queryable2, BiPredicate<? super T, ? super U> joiner) {
diff --git a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
index 5ee2349..67f0be3 100644
--- a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
+++ b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
@@ -1238,6 +1238,17 @@ class GinqTest {
         '''
     }
 
+    @Test
+    void "testGinq - from groupBy select - 1"() {
+        assertScript '''
+            assert [[1, 2], [3, 2], [6, 3]] == GINQ {
+                from n in [1, 1, 3, 3, 6, 6, 6]
+                groupby n
+                select n, count() // reference the column `n` in the groupby clause, and `count()` is a built-in aggregate function
+            }.toList()
+        '''
+    }
+
     @CompileDynamic
     @Test
     void "testGinq - query json - 1"() {