You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@groovy.apache.org by su...@apache.org on 2020/10/22 06:20:55 UTC

[groovy] 03/03: GROOVY-9787: add the most powerful aggregate function `agg`

This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-8258
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit 5d189c19c7079b268d15340f2d4fd5f216f731ae
Author: Daniel Sun <su...@apache.org>
AuthorDate: Thu Oct 22 14:18:46 2020 +0800

    GROOVY-9787: add the most powerful aggregate function `agg`
---
 .../groovy/linq/provider/collection/GinqAstWalker.groovy    | 13 +++++++++----
 .../src/test/groovy/org/apache/groovy/linq/GinqTest.groovy  | 11 +++++++++++
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
index f158ef5..8b9436c 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
@@ -434,15 +434,20 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
             Expression transformedExpression = null
             if (expression instanceof VariableExpression) {
                 if (expression.isThisExpression()) return expression
+                if (expression.text && Character.isUpperCase(expression.text.charAt(0))) return expression // type should be transformed
 
                 if (isGroup) { //  groupby
                     // in #1, we will correct receiver of built-in aggregate functions
                     // the correct receiver is `__t.v2`, so we should not replace `__t` here
                     if (lambdaParamName != expression.text) {
                         if (visitingAggregateFunction) {
-                            transformedExpression = isJoin
-                                    ? correctVarsForJoin(expression, new VariableExpression(lambdaParamName))
-                                    : new VariableExpression(lambdaParamName)
+                            if ('_q' == expression.text) {
+                                transformedExpression = new VariableExpression(lambdaParamName)
+                            } else {
+                                transformedExpression = isJoin
+                                        ? correctVarsForJoin(expression, new VariableExpression(lambdaParamName))
+                                        : new VariableExpression(lambdaParamName)
+                            }
                         } else {
                             // replace `gk` in the groupby with `__t.v1.gk`, note: __t.v1 stores the group key
                             transformedExpression = propX(propX(new VariableExpression(lambdaParamName), 'v1'), expression.text)
@@ -461,7 +466,7 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
                             expression.objectExpression = propX(new VariableExpression(lambdaParamName), 'v2')
                             transformedExpression = expression
                             visitingAggregateFunction = false
-                        } else if (methodName in ['count', 'sum'] && 1 == ((TupleExpression) expression.arguments).getExpressions().size()) {
+                        } else if (methodName in ['count', 'sum', 'agg'] && 1 == ((TupleExpression) expression.arguments).getExpressions().size()) {
                             visitingAggregateFunction = true
                             Expression lambdaCode = ((TupleExpression) expression.arguments).getExpression(0)
                             lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, findRootObjectExpression(lambdaCode).text)
diff --git a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
index 5cdbc48..643bee2 100644
--- a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
+++ b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
@@ -1612,6 +1612,17 @@ class GinqTest {
     }
 
     @Test
+    void "testGinq - from groupby select - 7"() {
+        assertScript '''
+            assert [[1, 2], [3, 6], [6, 18]] == GINQ {
+                from n in [1, 1, 3, 3, 6, 6, 6]
+                groupby n
+                select n, agg(_q.stream().map(e -> e).reduce(BigDecimal.ZERO, BigDecimal::add)) // the most powerful aggregate function, `_q` represents the grouped Queryable object
+            }.toList()
+        '''
+    }
+
+    @Test
     void "testGinq - from where groupby select - 1"() {
         assertScript '''
             assert [[1, 2], [6, 3]] == GINQ {