You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this extraction.
Posted to commits@groovy.apache.org by su...@apache.org on 2020/10/09 10:16:08 UTC
[groovy] branch GROOVY-8258 updated: GROOVY-8258: support multiple joins
This is an automated email from the ASF dual-hosted git repository.
sunlan pushed a commit to branch GROOVY-8258
in repository https://gitbox.apache.org/repos/asf/groovy.git
The following commit(s) were added to refs/heads/GROOVY-8258 by this push:
new 46803f3 GROOVY-8258: support multiple joins
46803f3 is described below
commit 46803f30e70b320dd1a17264d0d1e6fef44caf5d
Author: Daniel Sun <su...@apache.org>
AuthorDate: Fri Oct 9 17:15:05 2020 +0800
GROOVY-8258: support multiple joins
---
.../linq/provider/collection/GinqAstWalker.groovy | 106 ++++++++++++++-------
.../groovy/org/apache/groovy/linq/GinqTest.groovy | 49 +++++++++-
2 files changed, 117 insertions(+), 38 deletions(-)
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
index 354e65e..9a63d2e 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
@@ -158,8 +158,6 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
@Override
MethodCallExpression visitJoinExpression(JoinExpression joinExpression) {
Expression receiver = joinExpression.getNodeMetaData(__METHOD_CALL_RECEIVER)
- DataSourceExpression dataSourceExpression = joinExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
- Expression receiverAliasExpr = dataSourceExpression.aliasExpr
List<FilterExpression> filterExpressionList = joinExpression.getFilterExpressionList()
int filterExpressionListSize = filterExpressionList.size()
@@ -183,7 +181,7 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
}
WhereExpression whereExpression = filterExpressionList.size() < (whereExpressionPos + 1) ? null : (WhereExpression) filterExpressionList.get(whereExpressionPos)
- MethodCallExpression joinMethodCallExpression = constructJoinMethodCallExpression(receiver, receiverAliasExpr, joinExpression, onExpression, whereExpression)
+ MethodCallExpression joinMethodCallExpression = constructJoinMethodCallExpression(receiver, joinExpression, onExpression, whereExpression)
return joinMethodCallExpression
}
@@ -209,18 +207,31 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
}
private MethodCallExpression constructJoinMethodCallExpression(
- Expression receiver, Expression receiverAliasExpr, JoinExpression joinExpression,
+ Expression receiver, JoinExpression joinExpression,
OnExpression onExpression, WhereExpression whereExpression) {
+
+ DataSourceExpression otherDataSourceExpression = joinExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
+ Expression otherAliasExpr = otherDataSourceExpression.aliasExpr
+
+ String otherParamName = otherAliasExpr.text
+ Expression filterExpr = EmptyExpression.INSTANCE
+ if (onExpression) {
+ filterExpr = onExpression.getFilterExpr()
+ Tuple2<String, Expression> paramNameAndLambdaCode = correctVariablesOfLambdaExpression(otherDataSourceExpression, filterExpr)
+ otherParamName = paramNameAndLambdaCode.v1
+ filterExpr = paramNameAndLambdaCode.v2
+ }
+
MethodCallExpression resultMethodCallExpression
MethodCallExpression joinMethodCallExpression = callX(receiver, joinExpression.joinName.replace('join', 'Join'),
args(
constructFromMethodCallExpression(joinExpression.dataSourceExpr),
null == onExpression ? EmptyExpression.INSTANCE : lambdaX(
params(
- param(ClassHelper.DYNAMIC_TYPE, receiverAliasExpr.text),
+ param(ClassHelper.DYNAMIC_TYPE, otherParamName),
param(ClassHelper.DYNAMIC_TYPE, joinExpression.aliasExpr.text)
),
- stmt(onExpression.getFilterExpr())
+ stmt(filterExpr)
)
)
)
@@ -330,13 +341,10 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
return namedListCtorCallExpression
}
- private static Expression correctVariablesOfGinqExpression(DataSourceExpression dataSourceExpression, Expression expr, boolean isGroup) {
+ private Expression correctVariablesOfGinqExpression(DataSourceExpression dataSourceExpression, Expression expr) {
+ boolean isGroup = isGroupByVisited()
boolean isJoin = dataSourceExpression instanceof JoinExpression
- DataSourceExpression otherDataSourceExpression = dataSourceExpression.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
- final Expression firstAliasExpr = null == otherDataSourceExpression ? EmptyExpression.INSTANCE : otherDataSourceExpression.aliasExpr
- final Expression secondAliasExpr = dataSourceExpression.aliasExpr
-
def correctVars = { Expression expression ->
if (expression instanceof VariableExpression) {
Expression transformedExpression = null
@@ -346,15 +354,41 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
// the correct receiver is `__t.v2`, so we should not replace `__t` here
if (__T != expression.text) {
// replace `gk` in the groupby with `__t.v1.gk`, note: __t.v1 stores the group key
- transformedExpression = propX(constructFirstAliasVariableAccess(), expression.text)
+ transformedExpression = propX(propX(new VariableExpression(__T), 'v1'), expression.text)
}
} else if (isJoin) {
- if (firstAliasExpr.text == expression.text) {
- // replace `n1` with `__t.v1`
- transformedExpression = constructFirstAliasVariableAccess()
- } else if (secondAliasExpr.text == expression.text) {
- // replace `n2` with `__t.v2`
- transformedExpression = constructSecondAliasVariableAccess()
+ /*
+ * `n1`(from node) join `n2` join `n3` will construct a join tree:
+ *
+ * __t (join node)
+ * |__ v2 (n3)
+ * |__ v1 (join node)
+ * |__ v2 (n2)
+ * |__ v1 (n1) (from node)
+ *
+ * Note: `__t` is a tuple with 2 elements
+ * so `n3`'s access path is `__t.v2`
+ * and `n2`'s access path is `__t.v1.v2`
+ * and `n1`'s access path is `__t.v1.v1`
+ *
+ * The following code shows how to construct the access path for variables
+ */
+ def prop = new VariableExpression(__T)
+ for (DataSourceExpression dse = dataSourceExpression;
+ null == transformedExpression && dse instanceof JoinExpression;
+ dse = dse.getNodeMetaData(__DATA_SOURCE_EXPRESSION)) {
+
+ DataSourceExpression otherDataSourceExpression = dse.getNodeMetaData(__DATA_SOURCE_EXPRESSION)
+ Expression firstAliasExpr = otherDataSourceExpression?.aliasExpr ?: EmptyExpression.INSTANCE
+ Expression secondAliasExpr = dse.aliasExpr
+
+ if (firstAliasExpr.text == expression.text && otherDataSourceExpression !instanceof JoinExpression) {
+ transformedExpression = propX(prop, 'v1')
+ } else if (secondAliasExpr.text == expression.text) {
+ transformedExpression = propX(prop, 'v2')
+ } else { // not found
+ prop = propX(prop, 'v1')
+ }
}
}
if (null != transformedExpression) {
@@ -366,7 +400,7 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
if (expression.implicitThis) {
String methodName = expression.methodAsString
if ('count' == methodName && ((TupleExpression) expression.arguments).getExpressions().isEmpty()) {
- expression.objectExpression = constructSecondAliasVariableAccess()
+ expression.objectExpression = propX(new VariableExpression(__T), 'v2')
return expression
}
}
@@ -406,20 +440,30 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
}
private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
- boolean isGroup = currentGinqExpression.getNodeMetaData(__GROUP_BY) ?: false
+ Tuple2<String, Expression> paramNameAndLambdaCode = correctVariablesOfLambdaExpression(dataSourceExpression, lambdaCode)
+
+ lambdaX(
+ params(param(ClassHelper.DYNAMIC_TYPE, paramNameAndLambdaCode.v1)),
+ stmt(paramNameAndLambdaCode.v2)
+ )
+ }
+
+ private Tuple2<String, Expression> correctVariablesOfLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+ boolean isGroup = isGroupByVisited()
String lambdaParamName
if (dataSourceExpression instanceof JoinExpression || isGroup) {
lambdaParamName = __T
- lambdaCode = correctVariablesOfGinqExpression(dataSourceExpression, lambdaCode, isGroup)
+ lambdaCode = correctVariablesOfGinqExpression(dataSourceExpression, lambdaCode)
} else {
lambdaParamName = dataSourceExpression.aliasExpr.text
}
- lambdaX(
- params(param(ClassHelper.DYNAMIC_TYPE, lambdaParamName)),
- stmt(lambdaCode)
- )
+ return Tuple.tuple(lambdaParamName, lambdaCode)
+ }
+
+ private boolean isGroupByVisited() {
+ return currentGinqExpression.getNodeMetaData(__GROUP_BY) ?: false
}
private static MethodCallExpression callXWithLambda(Expression receiver, String methodName, LambdaExpression lambdaExpression) {
@@ -430,18 +474,6 @@ class GinqAstWalker implements GinqVisitor<Object>, SyntaxErrorReportable {
)
}
- private static Expression constructFirstAliasVariableAccess() {
- constructAliasVariableAccess('v1')
- }
-
- private static Expression constructSecondAliasVariableAccess() {
- constructAliasVariableAccess('v2')
- }
-
- private static Expression constructAliasVariableAccess(String name) {
- propX(new VariableExpression(__T), name)
- }
-
private static makeQueryableCollectionClassExpression() {
new ClassExpression(ClassHelper.make(Queryable.class))
}
diff --git a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
index 22d274e..f6da498 100644
--- a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
+++ b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
@@ -25,7 +25,6 @@ import org.junit.Test
import static groovy.test.GroovyAssert.assertScript
-
@CompileStatic
class GinqTest {
@Test
@@ -312,6 +311,54 @@ class GinqTest {
}
@Test
+ void "testGinq - from innerjoin select - 11"() {
+ assertScript '''
+ def nums1 = [1, 2, 3]
+ def nums2 = [2, 3, 4]
+ def nums3 = [3, 4, 5]
+ assert [[3, 3, 3]] == GINQ {
+ from n1 in nums1
+ innerjoin n2 in nums2 on n2 == n1
+ innerjoin n3 in nums3 on n3 == n2
+ select n1, n2, n3
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - from innerjoin select - 12"() {
+ assertScript '''
+ def nums1 = [1, 2, 3]
+ def nums2 = [2, 3, 4]
+ def nums3 = [3, 4, 5]
+ assert [[3, 3, 3]] == GINQ {
+ from n1 in nums1
+ innerjoin n2 in nums2 on n2 == n1
+ innerjoin n3 in nums3 on n3 == n1
+ select n1, n2, n3
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - from innerjoin select - 13"() {
+ assertScript '''
+ def nums1 = [1, 2, 3]
+ def nums2 = [2, 3, 4]
+ def nums3 = [3, 4, 5]
+ assert [[3, 3, 3]] == GINQ {
+ from v in (
+ from n1 in nums1
+ innerjoin n2 in nums2 on n1 == n2
+ select n1, n2
+ )
+ innerjoin n3 in nums3 on v.n2 == n3
+ select v.n1, v.n2, n3
+ }.toList()
+ '''
+ }
+
+ @Test
void "testGinq - from innerjoin where select - 1"() {
assertScript '''
def nums1 = [1, 2, 3]