You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@groovy.apache.org by su...@apache.org on 2020/10/09 10:16:08 UTC

[groovy] branch GROOVY-8258 updated: GROOVY-8258: support multiple joins

This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-8258
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/GROOVY-8258 by this push:
     new 46803f3  GROOVY-8258: support multiple joins
46803f3 is described below

commit 46803f30e70b320dd1a17264d0d1e6fef44caf5d
Author: Daniel Sun <su...@apache.org>
AuthorDate: Fri Oct 9 17:15:05 2020 +0800

    GROOVY-8258: support multiple joins
---
 .../linq/provider/collection/GinqAstWalker.groovy  | 106 ++++++++++++++-------
 .../groovy/org/apache/groovy/linq/GinqTest.groovy  |  49 +++++++++-
 2 files changed, 117 insertions(+), 38 deletions(-)

diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
index 354e65e..9a63d2e 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
@@ -158,8 +158,6 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
     @Override
     MethodCallExpression visitJoinExpression(JoinExpression joinExpression) {
         Expression receiver = joinExpression.getNodeMetaData(__METHOD_CALL_RECEIVER)
-        DataSourceExpression dataSourceExpression = joinExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
-        Expression receiverAliasExpr = dataSourceExpression.aliasExpr
         List<FilterExpression> filterExpressionList = joinExpression.getFilterExpressionList()
         int filterExpressionListSize = filterExpressionList.size()
 
@@ -183,7 +181,7 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
         }
         WhereExpression whereExpression = filterExpressionList.size() < (whereExpressionPos + 1) ? null : (WhereExpression) filterExpressionList.get(whereExpressionPos)
 
-        MethodCallExpression joinMethodCallExpression = constructJoinMethodCallExpression(receiver, receiverAliasExpr, joinExpression, onExpression, whereExpression)
+        MethodCallExpression joinMethodCallExpression = constructJoinMethodCallExpression(receiver, joinExpression, onExpression, whereExpression)
 
         return joinMethodCallExpression
     }
@@ -209,18 +207,31 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
     }
 
     private MethodCallExpression constructJoinMethodCallExpression(
-            Expression receiver, Expression receiverAliasExpr, JoinExpression joinExpression,
+            Expression receiver, JoinExpression joinExpression,
             OnExpression onExpression, WhereExpression whereExpression) {
+
+        DataSourceExpression otherDataSourceExpression = joinExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
+        Expression otherAliasExpr = otherDataSourceExpression.aliasExpr
+
+        String otherParamName = otherAliasExpr.text
+        Expression filterExpr = EmptyExpression.INSTANCE
+        if (onExpression) {
+            filterExpr = onExpression.getFilterExpr()
+            Tuple2<String, Expression> paramNameAndLambdaCode = correctVariablesOfLambdaExpression(otherDataSourceExpression, filterExpr)
+            otherParamName = paramNameAndLambdaCode.v1
+            filterExpr = paramNameAndLambdaCode.v2
+        }
+
         MethodCallExpression resultMethodCallExpression
         MethodCallExpression joinMethodCallExpression = callX(receiver, joinExpression.joinName.replace('join', 'Join'),
                 args(
                         constructFromMethodCallExpression(joinExpression.dataSourceExpr),
                         null == onExpression ? EmptyExpression.INSTANCE : lambdaX(
                                 params(
-                                        param(ClassHelper.DYNAMIC_TYPE, receiverAliasExpr.text),
+                                        param(ClassHelper.DYNAMIC_TYPE, otherParamName),
                                         param(ClassHelper.DYNAMIC_TYPE, joinExpression.aliasExpr.text)
                                 ),
-                                stmt(onExpression.getFilterExpr())
+                                stmt(filterExpr)
                         )
                 )
         )
@@ -330,13 +341,10 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
         return namedListCtorCallExpression
     }
 
-    private static Expression correctVariablesOfGinqExpression(DataSourceExpression dataSourceExpression, Expression expr, boolean isGroup) {
+    private Expression correctVariablesOfGinqExpression(DataSourceExpression dataSourceExpression, Expression expr) {
+        boolean isGroup = isGroupByVisited()
         boolean isJoin = dataSourceExpression instanceof JoinExpression
 
-        DataSourceExpression otherDataSourceExpression = dataSourceExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
-        final Expression firstAliasExpr = null == otherDataSourceExpression ? EmptyExpression.INSTANCE : otherDataSourceExpression.aliasExpr
-        final Expression secondAliasExpr = dataSourceExpression.aliasExpr
-
         def correctVars = { Expression expression ->
             if (expression instanceof VariableExpression) {
                 Expression transformedExpression = null
@@ -346,15 +354,41 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
                     // the correct receiver is `__t.v2`, so we should not replace `__t` here
                     if (__T != expression.text) {
                         // replace `gk` in the groupby with `__t.v1.gk`, note: __t.v1 stores the group key
-                        transformedExpression = propX(constructFirstAliasVariableAccess(), expression.text)
+                        transformedExpression = propX(propX(new VariableExpression(__T), 'v1'), expression.text)
                     }
                 } else if (isJoin) {
-                    if (firstAliasExpr.text == expression.text) {
-                        // replace `n1` with `__t.v1`
-                        transformedExpression = constructFirstAliasVariableAccess()
-                    } else if (secondAliasExpr.text == expression.text) {
-                        // replace `n2` with `__t.v2`
-                        transformedExpression = constructSecondAliasVariableAccess()
+                    /*
+                     * `n1`(from node) join `n2` join `n3`  will construct a join tree:
+                     *
+                     *  __t (join node)
+                     *    |__ v2 (n3)
+                     *    |__ v1 (join node)
+                     *         |__ v2 (n2)
+                     *         |__ v1 (n1) (from node)
+                     *
+                     * Note: `__t` is a tuple with 2 elements
+                     * so  `n3`'s access path is `__t.v2`
+                     * and `n2`'s access path is `__t.v1.v2`
+                     * and `n1`'s access path is `__t.v1.v1`
+                     *
+                     * The following code shows how to construct the access path for variables
+                     */
+                    def prop = new VariableExpression(__T)
+                    for (DataSourceExpression dse = dataSourceExpression;
+                         null == transformedExpression && dse instanceof JoinExpression;
+                         dse = dse.getNodeMetaData(__DATA_SOURCE_EXPRESSION)) {
+
+                        DataSourceExpression otherDataSourceExpression = dse.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
+                        Expression firstAliasExpr = otherDataSourceExpression?.aliasExpr ?: EmptyExpression.INSTANCE
+                        Expression secondAliasExpr = dse.aliasExpr
+
+                        if (firstAliasExpr.text == expression.text && otherDataSourceExpression !instanceof JoinExpression) {
+                            transformedExpression = propX(prop, 'v1')
+                        } else if (secondAliasExpr.text == expression.text) {
+                            transformedExpression = propX(prop, 'v2')
+                        } else { // not found
+                            prop = propX(prop, 'v1')
+                        }
                     }
                 }
                 if (null != transformedExpression) {
@@ -366,7 +400,7 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
                     if (expression.implicitThis) {
                         String methodName = expression.methodAsString
                         if ('count' == methodName && ((TupleExpression) expression.arguments).getExpressions().isEmpty()) {
-                            expression.objectExpression = constructSecondAliasVariableAccess()
+                            expression.objectExpression = propX(new VariableExpression(__T), 'v2')
                             return expression
                         }
                     }
@@ -406,20 +440,30 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
     }
 
     private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
-        boolean isGroup = currentGinqExpression.getNodeMetaData(__GROUP_BY) ?: false
+        Tuple2<String, Expression> paramNameAndLambdaCode = correctVariablesOfLambdaExpression(dataSourceExpression, lambdaCode)
+
+        lambdaX(
+                params(param(ClassHelper.DYNAMIC_TYPE, paramNameAndLambdaCode.v1)),
+                stmt(paramNameAndLambdaCode.v2)
+        )
+    }
+
+    private Tuple2<String, Expression> correctVariablesOfLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+        boolean isGroup = isGroupByVisited()
 
         String lambdaParamName
         if (dataSourceExpression instanceof JoinExpression || isGroup) {
             lambdaParamName = __T
-            lambdaCode = correctVariablesOfGinqExpression(dataSourceExpression, lambdaCode, isGroup)
+            lambdaCode = correctVariablesOfGinqExpression(dataSourceExpression, lambdaCode)
         } else {
             lambdaParamName = dataSourceExpression.aliasExpr.text
         }
 
-        lambdaX(
-                params(param(ClassHelper.DYNAMIC_TYPE, lambdaParamName)),
-                stmt(lambdaCode)
-        )
+        return Tuple.tuple(lambdaParamName, lambdaCode)
+    }
+
+    private boolean isGroupByVisited() {
+        return currentGinqExpression.getNodeMetaData(__GROUP_BY) ?: false
     }
 
     private static MethodCallExpression callXWithLambda(Expression receiver, String methodName, LambdaExpression lambdaExpression) {
@@ -430,18 +474,6 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
         )
     }
 
-    private static Expression constructFirstAliasVariableAccess() {
-        constructAliasVariableAccess('v1')
-    }
-
-    private static Expression constructSecondAliasVariableAccess() {
-        constructAliasVariableAccess('v2')
-    }
-
-    private static Expression constructAliasVariableAccess(String name) {
-        propX(new VariableExpression(__T), name)
-    }
-
     private static makeQueryableCollectionClassExpression() {
         new ClassExpression(ClassHelper.make(Queryable.class))
     }
diff --git a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
index 22d274e..f6da498 100644
--- a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
+++ b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
@@ -25,7 +25,6 @@ import org.junit.Test
 
 import static groovy.test.GroovyAssert.assertScript
 
-
 @CompileStatic
 class GinqTest {
     @Test
@@ -312,6 +311,54 @@ class GinqTest {
     }
 
     @Test
+    void "testGinq - from innerjoin select - 11"() {
+        assertScript '''
+            def nums1 = [1, 2, 3]
+            def nums2 = [2, 3, 4]
+            def nums3 = [3, 4, 5]
+            assert [[3, 3, 3]] == GINQ {
+                from n1 in nums1
+                innerjoin n2 in nums2 on n2 == n1
+                innerjoin n3 in nums3 on n3 == n2
+                select n1, n2, n3
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - from innerjoin select - 12"() {
+        assertScript '''
+            def nums1 = [1, 2, 3]
+            def nums2 = [2, 3, 4]
+            def nums3 = [3, 4, 5]
+            assert [[3, 3, 3]] == GINQ {
+                from n1 in nums1
+                innerjoin n2 in nums2 on n2 == n1
+                innerjoin n3 in nums3 on n3 == n1
+                select n1, n2, n3
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - from innerjoin select - 13"() {
+        assertScript '''
+            def nums1 = [1, 2, 3]
+            def nums2 = [2, 3, 4]
+            def nums3 = [3, 4, 5]
+            assert [[3, 3, 3]] == GINQ {
+                from v in (
+                    from n1 in nums1
+                    innerjoin n2 in nums2 on n1 == n2
+                    select n1, n2
+                )
+                innerjoin n3 in nums3 on v.n2 == n3
+                select v.n1, v.n2, n3
+            }.toList()
+        '''
+    }
+
+    @Test
     void "testGinq - from innerjoin where select - 1"() {
         assertScript '''
             def nums1 = [1, 2, 3]