You are viewing a plain-text version of this content. The canonical link to it was not preserved in this copy.
Posted to commits@groovy.apache.org by su...@apache.org on 2020/10/07 10:00:24 UTC

[groovy] 02/02: GROOVY-8258: support nested linq

This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-8258
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit 463ae780480aabfaa79bef1c2d197f0503cfac6d
Author: Daniel Sun <su...@apache.org>
AuthorDate: Wed Oct 7 17:59:34 2020 +0800

    GROOVY-8258: support nested linq
---
 .../org/apache/groovy/linq/dsl/GinqAstBuilder.java | 20 ++++--
 .../linq/provider/collection/GinqAstWalker.groovy  | 12 +++-
 .../groovy/linq/provider/collection/Queryable.java |  4 --
 .../provider/collection/QueryableCollection.java   | 10 ++-
 .../groovy/org/apache/groovy/linq/GinqTest.groovy  | 77 ++++++++++++++++++++++
 .../collection/QueryableCollectionTest.groovy      |  1 -
 6 files changed, 109 insertions(+), 15 deletions(-)

diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
index ee37ce5..2e27d2b 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
@@ -43,7 +43,8 @@ import org.codehaus.groovy.syntax.Types;
  * @since 4.0.0
  */
 public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorReportable {
-    private SimpleGinqExpression simpleGinqExpression = new SimpleGinqExpression(); // store the result
+    private SimpleGinqExpression currentSimpleGinqExpression;
+    private SimpleGinqExpression latestSimpleGinqExpression;
     private GinqExpression ginqExpression; // store the return value
     private final SourceUnit sourceUnit;
 
@@ -52,7 +53,7 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
     }
 
     public SimpleGinqExpression getSimpleGinqExpression() {
-        return simpleGinqExpression;
+        return latestSimpleGinqExpression;
     }
 
     @Override
@@ -60,6 +61,10 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
         super.visitMethodCallExpression(call);
         final String methodName = call.getMethodAsString();
 
+        if ("from".equals(methodName)) {
+            currentSimpleGinqExpression = new SimpleGinqExpression(); // store the result
+        }
+
         if ("from".equals(methodName)  || "innerJoin".equals(methodName)) {
             ArgumentListExpression arguments = (ArgumentListExpression) call.getArguments();
             if (arguments.getExpressions().size() != 1) {
@@ -82,15 +87,15 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
             }
             BinaryExpression binaryExpression = (BinaryExpression) expression;
             Expression aliasExpr = binaryExpression.getLeftExpression();
-            Expression dataSourceExpr = binaryExpression.getRightExpression();
+            Expression dataSourceExpr = null == latestSimpleGinqExpression ? binaryExpression.getRightExpression() : latestSimpleGinqExpression;
 
             FilterableExpression filterableExpression = null;
             if ("from".equals(methodName)) {
                 filterableExpression = new FromExpression(aliasExpr, dataSourceExpr);
-                simpleGinqExpression.setFromExpression((FromExpression) filterableExpression);
+                currentSimpleGinqExpression.setFromExpression((FromExpression) filterableExpression);
             } else if ("innerJoin".equals(methodName)) {
                 filterableExpression = new InnerJoinExpression(aliasExpr, dataSourceExpr);
-                simpleGinqExpression.addJoinExpression((JoinExpression) filterableExpression);
+                currentSimpleGinqExpression.addJoinExpression((JoinExpression) filterableExpression);
             }
             filterableExpression.setSourcePosition(call);
             ginqExpression = filterableExpression;
@@ -127,9 +132,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
             SelectExpression selectExpression = new SelectExpression(call.getArguments());
             selectExpression.setSourcePosition(call);
 
-            simpleGinqExpression.setSelectExpression(selectExpression);
+            currentSimpleGinqExpression.setSelectExpression(selectExpression);
             ginqExpression = selectExpression;
 
+            latestSimpleGinqExpression = currentSimpleGinqExpression;
+            currentSimpleGinqExpression = null;
+
             return;
         }
     }
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
index 3dd1188..e0453c1 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
@@ -147,9 +147,17 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
 
     @CompileDynamic
     private MethodCallExpression constructFromMethodCallExpression(FromExpression fromExpression) {
-        macro {
-            $v{ makeQueryableCollectionClassExpression() }.from($v { fromExpression.dataSourceExpr })
+        MethodCallExpression fromMethodCallExpression = macro {
+            $v{ makeQueryableCollectionClassExpression() }.from($v {
+                if (fromExpression.dataSourceExpr instanceof SimpleGinqExpression) {
+                    return this.visitSimpleGinqExpression((SimpleGinqExpression) fromExpression.dataSourceExpr)
+                } else {
+                    return fromExpression.dataSourceExpr
+                }
+            })
         }
+
+        return fromMethodCallExpression
     }
 
     @CompileDynamic
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java
index 6e3e0a0..3f44c60 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/Queryable.java
@@ -45,10 +45,6 @@ public interface Queryable<T> {
         return new QueryableCollection<>(sourceStream);
     }
 
-    static <T> Queryable<T> from(Queryable<T> queryable) {
-        return queryable;
-    }
-
     <U> Queryable<Tuple2<T, U>> innerJoin(Queryable<? extends U> queryable, BiPredicate<? super T, ? super U> joiner);
 
     <U> Queryable<Tuple2<T, U>> leftJoin(Queryable<? extends U> queryable, BiPredicate<? super T, ? super U> joiner);
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java
index 44681ae..645bb08 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/QueryableCollection.java
@@ -48,8 +48,14 @@ class QueryableCollection<T> implements Queryable<T>, Iterable<T> {
     private Stream<T> sourceStream;
 
     QueryableCollection(Iterable<T> sourceIterable) {
-        this.sourceIterable = sourceIterable;
-        this.sourceStream = toStream(sourceIterable);
+        if (sourceIterable instanceof QueryableCollection) {
+            QueryableCollection<T> queryableCollection = (QueryableCollection<T>) sourceIterable;
+            this.sourceIterable = queryableCollection.sourceIterable;
+            this.sourceStream = queryableCollection.sourceStream;
+        } else {
+            this.sourceIterable = sourceIterable;
+            this.sourceStream = toStream(sourceIterable);
+        }
     }
 
     @SuppressWarnings("unchecked")
diff --git a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
index 3120276..f34bc3f 100644
--- a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
+++ b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
@@ -387,4 +387,81 @@ class GinqTest {
             }.toList()
         '''
     }
+
+    @Test
+    void "testGinq - nested from - 0"() {
+        assertScript '''
+            assert [1, 2, 3] == GINQ {
+                from v in (
+                    from n in [1, 2, 3]
+                    select n
+                )
+                select v
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from - 1"() {
+        assertScript '''
+            def numbers = [1, 2, 3]
+            assert [1, 2, 3] == GINQ {
+                from v in (
+                    from n in numbers
+                    select n
+                )
+                select v
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from - 2"() {
+        assertScript '''
+            def numbers = [1, 2, 3]
+            assert [1, 2] == GINQ {
+                from v in (
+                    from n in numbers
+                    where n < 3
+                    select n
+                )
+                select v
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from - 3"() {
+        assertScript '''
+            def numbers = [1, 2, 3]
+            assert [2] == GINQ {
+                from v in (
+                    from n in numbers
+                    where n < 3
+                    select n
+                )
+                where v > 1
+                select v
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from - 4"() {
+        assertScript '''
+            def nums1 = [1, 2, 3, 4, 5]
+            def nums2 = [1, 2, 3, 4, 5]
+            assert [[3, 3], [5, 5]] == GINQ {
+                from v in (
+                    from n1 in nums1
+                    innerJoin n2 in nums2
+                    on n1 == n2
+                    where n1 > 1 && n2 <= 5
+                    select n1, n2
+                )
+                where v[0] >= 3 && v[1] in [3, 5] // v[0] references column 1, and v[1] references column 2
+                select v
+            }.toList()
+        '''
+    }
 }
diff --git a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/provider/collection/QueryableCollectionTest.groovy b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/provider/collection/QueryableCollectionTest.groovy
index 26d4255..0c31357 100644
--- a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/provider/collection/QueryableCollectionTest.groovy
+++ b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/provider/collection/QueryableCollectionTest.groovy
@@ -36,7 +36,6 @@ class QueryableCollectionTest {
     void testFrom() {
         assert [1, 2, 3] == from(Stream.of(1, 2, 3)).toList()
         assert [1, 2, 3] == from(Arrays.asList(1, 2, 3)).toList()
-        assert [1, 2, 3] == from(from(Arrays.asList(1, 2, 3))).toList()
     }
 
     @Test