You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@groovy.apache.org by su...@apache.org on 2020/12/22 16:13:57 UTC
[groovy] branch master updated: Support basic window function
This is an automated email from the ASF dual-hosted git repository.
sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git
The following commit(s) were added to refs/heads/master by this push:
new be3f183 Support basic window function
be3f183 is described below
commit be3f183b01cceed9817b6c2afe1796a0f1962b7b
Author: Daniel Sun <su...@apache.org>
AuthorDate: Wed Dec 23 00:10:36 2020 +0800
Support basic window function
---
.../org/apache/groovy/ginq/dsl/GinqAstBuilder.java | 12 +-
.../ginq/provider/collection/GinqAstWalker.groovy | 188 ++++++++++++++++++---
.../provider/collection/runtime/Queryable.java | 22 +--
.../collection/runtime/QueryableCollection.java | 122 ++++++-------
.../test/org/apache/groovy/ginq/GinqTest.groovy | 21 +++
.../runtime/QueryableCollectionTest.groovy | 30 ++--
6 files changed, 277 insertions(+), 118 deletions(-)
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
index 72ecd03..a17beed 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
@@ -108,10 +108,18 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
return ginqExpression.getNodeMetaData(__LATEST_GINQ_EXPRESSION_CLAUSE);
}
+ private boolean visitingOverClause;
+
@Override
public void visitMethodCallExpression(MethodCallExpression call) {
- super.visitMethodCallExpression(call);
final String methodName = call.getMethodAsString();
+ if ("over".equals(methodName)) {
+ visitingOverClause = true;
+ }
+ super.visitMethodCallExpression(call);
+ if ("over".equals(methodName)) {
+ visitingOverClause = false;
+ }
if (!KEYWORD_SET.contains(methodName)) {
ignoredMethodCallExpressionList.add(call);
@@ -247,7 +255,7 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
return;
}
- if (KW_ORDERBY.equals(methodName)) {
+ if (KW_ORDERBY.equals(methodName) && !visitingOverClause) {
OrderExpression orderExpression = new OrderExpression(call.getArguments());
orderExpression.setSourcePosition(call.getMethod());
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index 8a2c8e3..017ed64 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -40,11 +40,13 @@ import org.apache.groovy.ginq.dsl.expression.WhereExpression
import org.apache.groovy.ginq.provider.collection.runtime.NamedRecord
import org.apache.groovy.ginq.provider.collection.runtime.Queryable
import org.apache.groovy.ginq.provider.collection.runtime.QueryableHelper
+import org.apache.groovy.ginq.provider.collection.runtime.WindowDefinition
import org.apache.groovy.util.Maps
import org.codehaus.groovy.GroovyBugError
import org.codehaus.groovy.ast.ClassHelper
import org.codehaus.groovy.ast.ClassNode
import org.codehaus.groovy.ast.CodeVisitorSupport
+import org.codehaus.groovy.ast.Parameter
import org.codehaus.groovy.ast.expr.ArgumentListExpression
import org.codehaus.groovy.ast.expr.BinaryExpression
import org.codehaus.groovy.ast.expr.CastExpression
@@ -166,10 +168,13 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
SelectExpression selectExpression = currentGinqExpression.selectExpression
selectExpression.putNodeMetaData(__METHOD_CALL_RECEIVER, resultMethodCallReceiver)
selectExpression.dataSourceExpression = resultDataSourceExpression
-
MethodCallExpression selectMethodCallExpression = this.visitSelectExpression(selectExpression)
List<Statement> statementList = []
+ boolean useWindowFunction = isUseWindowFunction(selectExpression)
+ if (useWindowFunction) {
+ statementList << stmt(callX(QUERYABLE_HELPER_TYPE, 'setVar', args(new ConstantExpression(USE_WINDOW_FUNCTION), new ConstantExpression(TRUE_STR))))
+ }
boolean isRootGinqExpression = ginqExpression === ginqExpression.getNodeMetaData(GinqAstBuilder.ROOT_GINQ_EXPRESSION)
boolean parallelEnabled = isRootGinqExpression && TRUE_STR == configuration.get(GinqGroovyMethods.CONF_PARALLEL)
@@ -195,6 +200,9 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
if (parallelEnabled) {
statementList << stmt(callX(QUERYABLE_HELPER_TYPE, 'removeVar', args(new ConstantExpression(PARALLEL))))
}
+ if (useWindowFunction) {
+ statementList << stmt(callX(QUERYABLE_HELPER_TYPE, 'removeVar', args(new ConstantExpression(USE_WINDOW_FUNCTION))))
+ }
statementList << returnS(varX(resultName))
def result = callX(lambdaX(block(statementList as Statement[])), "call")
@@ -203,6 +211,22 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
return result
}
+ private boolean isUseWindowFunction(SelectExpression selectExpression) {
+ boolean useWindowFunction = false
+ selectExpression.projectionExpr.visit(new GinqAstBaseVisitor() {
+ @Override
+ void visitMethodCallExpression(MethodCallExpression call) {
+ if (call.methodAsString == 'over') {
+ useWindowFunction = true
+ return
+ }
+
+ super.visitMethodCallExpression(call)
+ }
+ })
+ return useWindowFunction
+ }
+
private static boolean isAggregateFunction(Expression expression) {
Expression expr = expression
if (expression instanceof CastExpression) {
@@ -552,6 +576,15 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
Expression orderMethodCallReceiver = orderExpression.getNodeMetaData(__METHOD_CALL_RECEIVER)
Expression ordersExpr = orderExpression.ordersExpr
+ List<Expression> orderCtorCallExpressions = constructOrderCtorCallExpressions(ordersExpr, dataSourceExpression)
+
+ def orderMethodCallExpression = callX(orderMethodCallReceiver, "orderBy", args(orderCtorCallExpressions))
+ orderMethodCallExpression.setSourcePosition(orderExpression)
+
+ return orderMethodCallExpression
+ }
+
+ private List<Expression> constructOrderCtorCallExpressions(Expression ordersExpr, DataSourceExpression dataSourceExpression) {
List<Expression> argumentExpressionList = ((ArgumentListExpression) ordersExpr).getExpressions()
List<Expression> orderCtorCallExpressions = argumentExpressionList.stream().map(e -> {
Expression target = e
@@ -576,11 +609,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
return ctorX(ORDER_TYPE, args(lambdaExpression, new ConstantExpression(asc)))
}).collect(Collectors.toList())
-
- def orderMethodCallExpression = callX(orderMethodCallReceiver, "orderBy", args(orderCtorCallExpressions))
- orderMethodCallExpression.setSourcePosition(orderExpression)
-
- return orderMethodCallExpression
+ return orderCtorCallExpressions
}
@Override
@@ -628,17 +657,92 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
)
}
+ if (expression instanceof MethodCallExpression) {
+ if ('over' == expression.methodAsString) {
+ if (expression.objectExpression instanceof MethodCallExpression) {
+ VariableExpression wqVar = varX(getWindowQueryableName())
+
+ String lambdaParamName = getLambdaParamName(dataSourceExpression, lambdaCode)
+ VariableExpression currentRecordVar = varX(lambdaParamName)
+
+ currentGinqExpression.putNodeMetaData(__VISITING_WINDOW_FUNCTION, true)
+ def windowDefinitionFactoryMethodCallExpression = constructWindowDefinitionFactoryMethodCallExpression(expression, dataSourceExpression)
+ Expression newObjectExpression = callX(wqVar, 'over', args(
+ currentRecordVar,
+ windowDefinitionFactoryMethodCallExpression
+ ))
+
+ def windowFunctionMethodCallExpression = (MethodCallExpression) expression.objectExpression
+ def windowFunctionLambdaCode = ((ArgumentListExpression) windowFunctionMethodCallExpression.arguments).getExpression(0)
+ def windowFunctionLambdaName = findRootObjectExpression(windowFunctionLambdaCode).text
+ def newWindowFunctionLambdaName = '__wfp'
+
+ windowFunctionLambdaCode = ((ListExpression) (new ListExpression(Collections.singletonList(windowFunctionLambdaCode)).transformExpression(new ExpressionTransformer() {
+ @Override
+ Expression transform(Expression expr) {
+ if (expr instanceof VariableExpression) {
+ if (windowFunctionLambdaName == expr.text) {
+ if (dataSourceExpression instanceof JoinExpression) {
+ return correctVars(dataSourceExpression, newWindowFunctionLambdaName=getLambdaParamName(dataSourceExpression, expr), expr)
+ } else {
+ return new VariableExpression(newWindowFunctionLambdaName)
+ }
+ }
+ }
+ return expr.transformExpression(this)
+ }
+ }))).getExpression(0)
+
+ def result = callX(
+ newObjectExpression,
+ windowFunctionMethodCallExpression.methodAsString,
+ lambdaX(
+ params(param(ClassHelper.DYNAMIC_TYPE, newWindowFunctionLambdaName)),
+ block(stmt(windowFunctionLambdaCode))
+ )
+ )
+ currentGinqExpression.putNodeMetaData(__VISITING_WINDOW_FUNCTION, false)
+
+ return result
+ }
+ }
+ }
+
return expression.transformExpression(this)
}
})).getExpression(0)
- def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, lambdaCode)
+ def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, lambdaCode, param(ClassHelper.DYNAMIC_TYPE, getWindowQueryableName()))
currentGinqExpression.putNodeMetaData(__VISITING_SELECT, false)
return selectMethodCallExpression
}
+ private MethodCallExpression constructWindowDefinitionFactoryMethodCallExpression(MethodCallExpression methodCallExpression, DataSourceExpression dataSourceExpression) {
+ Expression orderExpr = null
+ ArgumentListExpression argumentListExpression = (ArgumentListExpression) methodCallExpression.arguments
+ if (!argumentListExpression.getExpressions().isEmpty()) {
+ MethodCallExpression windowMce = (MethodCallExpression) argumentListExpression.getExpression(0)
+ if ('orderby' == windowMce.methodAsString) {
+ orderExpr = windowMce.arguments
+ }
+ }
+ callX(new ClassExpression(WINDOW_DEFINITION_TYPE), 'of', constructOrderCtorCallExpressions(orderExpr, dataSourceExpression).get(0))
+ }
+
+ private int windowQueryableNameSeq = 0
+ private String getWindowQueryableName() {
+ String name = (String) currentGinqExpression.getNodeMetaData(__WINDOW_QUERYABLE_NAME)
+
+ if (!name) {
+ name = "${__WINDOW_QUERYABLE_NAME}${windowQueryableNameSeq++}"
+ currentGinqExpression.putNodeMetaData(__WINDOW_QUERYABLE_NAME, name)
+ }
+
+ return name
+ }
+
private static boolean isExpression(final Expression expr, final Class... expressionTypes) {
Arrays.stream(expressionTypes).anyMatch(clazz -> {
Expression tmpExpr = expr
@@ -918,6 +1022,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
}
}
}
+ } else {
+ if (visitingWindowFunction) {
+ boolean isJoin = dataSourceExpression instanceof JoinExpression
+ if (isJoin) {
+ Map<String, Expression> aliasAccessPathMap = findAliasAccessPath(dataSourceExpression, new VariableExpression(lambdaParamName))
+ transformedExpression = aliasAccessPathMap.get(expression.text)
+ }
+ }
}
} else if (expression instanceof MethodCallExpression) {
// #1
@@ -1005,22 +1117,29 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
return expression.accept(this)
}
- private MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, Expression lambdaCode) {
- LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode)
+ private MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, Expression lambdaCode, Parameter... extraParams) {
+ LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode, extraParams)
callXWithLambda(receiver, methodName, lambdaExpression)
}
- private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+ private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode, Parameter... extraParams) {
Tuple3<String, List<DeclarationExpression>, Expression> paramNameAndLambdaCode = correctVariablesOfLambdaExpression(dataSourceExpression, lambdaCode)
List<DeclarationExpression> declarationExpressionList = paramNameAndLambdaCode.v2
List<Statement> statementList = []
- statementList.addAll(declarationExpressionList.stream().map(e -> stmt(e)).collect(Collectors.toList()))
+ if (!visitingWindowFunction) {
+ statementList.addAll(declarationExpressionList.stream().map(e -> stmt(e)).collect(Collectors.toList()))
+ }
statementList.add(stmt(paramNameAndLambdaCode.v3))
+ def paramList = [param(ClassHelper.DYNAMIC_TYPE, paramNameAndLambdaCode.v1)]
+ if (extraParams) {
+ paramList.addAll(Arrays.asList(extraParams))
+ }
+
lambdaX(
- params(param(ClassHelper.DYNAMIC_TYPE, paramNameAndLambdaCode.v1)),
+ params(paramList as Parameter[]),
block(statementList as Statement[])
)
}
@@ -1030,18 +1149,29 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
"__t_${lambdaParamSeq++}"
}
- private Tuple3<String, List<DeclarationExpression>, Expression> correctVariablesOfLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+ private String getLambdaParamName(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
boolean groupByVisited = isGroupByVisited()
-
- List<DeclarationExpression> declarationExpressionList = Collections.emptyList()
String lambdaParamName
- if (dataSourceExpression instanceof JoinExpression || groupByVisited) {
+ if (dataSourceExpression instanceof JoinExpression || groupByVisited || visitingWindowFunction) {
lambdaParamName = lambdaCode.getNodeMetaData(__LAMBDA_PARAM_NAME)
- if (!lambdaParamName || visitingAggregateFunctionStack) {
+ if (!lambdaParamName || visitingAggregateFunctionStack || visitingWindowFunction) {
lambdaParamName = generateLambdaParamName()
}
lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, lambdaParamName)
+ } else {
+ lambdaParamName = dataSourceExpression.aliasExpr.text
+ lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, lambdaParamName)
+ }
+
+ return lambdaParamName
+ }
+
+ private Tuple3<String, List<DeclarationExpression>, Expression> correctVariablesOfLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+ boolean groupByVisited = isGroupByVisited()
+ List<DeclarationExpression> declarationExpressionList = Collections.emptyList()
+ String lambdaParamName = getLambdaParamName(dataSourceExpression, lambdaCode)
+ if (dataSourceExpression instanceof JoinExpression || groupByVisited) {
Tuple2<List<DeclarationExpression>, Expression> declarationAndLambdaCode = correctVariablesOfGinqExpression(dataSourceExpression, lambdaCode)
if (!visitingAggregateFunctionStack) {
declarationExpressionList = declarationAndLambdaCode.v1
@@ -1060,10 +1190,22 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
}
lambdaCode = declarationAndLambdaCode.v2
} else {
- lambdaParamName = dataSourceExpression.aliasExpr.text
- lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, lambdaParamName)
+ if (visitingWindowFunction) {
+ lambdaCode = ((ListExpression) (new ListExpression(Collections.singletonList(lambdaCode)).transformExpression(new ExpressionTransformer() {
+ @Override
+ Expression transform(Expression expr) {
+ if (expr instanceof VariableExpression) {
+ if (dataSourceExpression.aliasExpr.text == expr.text) {
+ return new VariableExpression(lambdaParamName)
+ }
+ }
+ return expr.transformExpression(this)
+ }
+ }))).getExpression(0)
+ }
}
+
if (lambdaCode instanceof ConstructorCallExpression) {
if (NAMEDRECORD_CLASS_NAME == lambdaCode.type.redirect().name) {
// store the source record
@@ -1082,6 +1224,10 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
return currentGinqExpression.getNodeMetaData(__VISITING_SELECT) ?: false
}
+ private boolean isVisitingWindowFunction() {
+ return currentGinqExpression.getNodeMetaData(__VISITING_WINDOW_FUNCTION) ?: false
+ }
+
private boolean isRowNumberUsed() {
return currentGinqExpression.getNodeMetaData(__RN_USED) ?: false
}
@@ -1121,6 +1267,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
private static final ClassNode ORDER_TYPE = makeWithoutCaching(Queryable.Order.class)
private static final ClassNode NAMED_RECORD_TYPE = makeWithoutCaching(NamedRecord.class)
private static final ClassNode QUERYABLE_HELPER_TYPE = makeWithoutCaching(QueryableHelper.class)
+ private static final ClassNode WINDOW_DEFINITION_TYPE = makeWithoutCaching(WindowDefinition.class)
private static final List<String> ORDER_OPTION_LIST = Arrays.asList('asc', 'desc')
private static final String FUNCTION_COUNT = 'count'
@@ -1134,15 +1281,18 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
private static final String NAMEDRECORD_CLASS_NAME = NamedRecord.class.name
+ private static final String USE_WINDOW_FUNCTION = 'useWindowFunction'
private static final String PARALLEL = 'parallel'
private static final String TRUE_STR = 'true'
private static final String __METHOD_CALL_RECEIVER = "__METHOD_CALL_RECEIVER"
private static final String __GROUPBY_VISITED = "__GROUPBY_VISITED"
private static final String __VISITING_SELECT = "__VISITING_SELECT"
+ private static final String __VISITING_WINDOW_FUNCTION = "__VISITING_WINDOW_FUNCTION"
private static final String __LAMBDA_PARAM_NAME = "__LAMBDA_PARAM_NAME"
private static final String __RN_USED = '__RN_USED'
private static final String __META_DATA_MAP_NAME_PREFIX = '__metaDataMap_'
+ private static final String __WINDOW_QUERYABLE_NAME = '__wq_'
private static final String __ROW_NUMBER_NAME_PREFIX = '__rowNumber_'
private static final String __SOURCE_RECORD = "__sourceRecord"
private static final String __GROUP = "__group"
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java
index c3096db..10c0618 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java
@@ -25,6 +25,7 @@ import java.math.BigDecimal;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
+import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.Predicate;
@@ -266,7 +267,7 @@ public interface Queryable<T> {
* @return the result of projecting
* @since 4.0.0
*/
- <U> Queryable<U> select(Function<? super T, ? extends U> mapper);
+ <U> Queryable<U> select(BiFunction<? super T, ? super Queryable<? extends T>, ? extends U> mapper);
/**
* Check if the result is empty, similar to SQL's {@code exists}
@@ -426,16 +427,15 @@ public interface Queryable<T> {
return toList().stream();
}
- default <U extends Comparable<? super U>> Window<T> over(T currentRecord, WindowDefinition<T, U> windowDefinition) {
- Queryable<T> partition =
- this.groupBy(windowDefinition.partitionBy()) // TODO cache the group result
- .where(e -> QueryableHelper.isIdentical(e.getV1(), windowDefinition.partitionBy().apply(currentRecord)))
- .select(e -> e.getV2())
- .toList()
- .get(0);
-
- return new WindowImpl<>(currentRecord, partition, windowDefinition);
- }
+ /**
+ * Open window for current record
+ *
+ * @param currentRecord current record
+ * @param windowDefinition window definition
+ * @param <U> the type of window value
+ * @return the window
+ */
+ <U extends Comparable<? super U>> Window<T> over(T currentRecord, WindowDefinition<T, U> windowDefinition);
/**
* Represents an order rule
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
index b9a4a47..6d91299 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
@@ -18,7 +18,6 @@
*/
package org.apache.groovy.ginq.provider.collection.runtime;
-import groovy.lang.GroovyRuntimeException;
import groovy.lang.Tuple2;
import groovy.transform.Internal;
import org.apache.groovy.internal.util.Supplier;
@@ -39,10 +38,10 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
-import java.util.concurrent.Callable;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.Predicate;
@@ -71,13 +70,16 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
}
public Iterator<T> iterator() {
- return readLock(() -> {
+ readLock.lock();
+ try {
if (null != sourceIterable) {
return sourceIterable.iterator();
}
return sourceStream.iterator();
- });
+ } finally {
+ readLock.unlock();
+ }
}
@Override
@@ -149,12 +151,12 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
@Override
public <U> Queryable<Tuple2<T, U>> rightJoin(Queryable<? extends U> queryable, BiPredicate<? super T, ? super U> joiner) {
- return outerJoin(queryable, this, (a, b) -> joiner.test(b, a)).select(e -> tuple(e.getV2(), e.getV1()));
+ return outerJoin(queryable, this, (a, b) -> joiner.test(b, a)).select((e, q) -> tuple(e.getV2(), e.getV1()));
}
@Override
public <U> Queryable<Tuple2<T, U>> rightHashJoin(Queryable<? extends U> queryable, Function<? super T, ?> fieldsExtractor1, Function<? super U, ?> fieldsExtractor2) {
- return outerHashJoin(queryable, this, fieldsExtractor2, fieldsExtractor1).select(e -> tuple(e.getV2(), e.getV1()));
+ return outerHashJoin(queryable, this, fieldsExtractor2, fieldsExtractor1).select((e, q) -> tuple(e.getV2(), e.getV1()));
}
@Override
@@ -247,8 +249,12 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
}
@Override
- public <U> Queryable<U> select(Function<? super T, ? extends U> mapper) {
- Stream<U> stream = this.stream().map(mapper);
+ public <U> Queryable<U> select(BiFunction<? super T, ? super Queryable<? extends T>, ? extends U> mapper) {
+ if (TRUE_STR.equals(QueryableHelper.getVar(USE_WINDOW_FUNCTION))) {
+ this.makeReusable();
+ }
+
+ Stream<U> stream = this.stream().map((T t) -> mapper.apply(t, this));
return from(stream);
}
@@ -438,7 +444,8 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
@Override
public List<T> toList() {
- return writeLock(() -> {
+ writeLock.lock();
+ try {
if (sourceIterable instanceof List) {
return (List<T>) sourceIterable;
}
@@ -447,7 +454,9 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
sourceIterable = result;
return result;
- });
+ } finally {
+ writeLock.unlock();
+ }
}
@Override
@@ -457,17 +466,33 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
@Override
public Stream<T> stream() {
- return writeLock(() -> {
+ writeLock.lock();
+ try {
if (isReusable()) {
sourceStream = toStream(sourceIterable); // we have to create new stream every time because Java stream can not be reused
}
- if (!sourceStream.isParallel() && isParallelEnabled()) {
+ if (!sourceStream.isParallel() && TRUE_STR.equals(QueryableHelper.getVar(PARALLEL))) {
sourceStream = sourceStream.parallel();
}
return sourceStream;
- });
+ } finally {
+ writeLock.unlock();
+ }
+ }
+
+ @Override
+ public <U extends Comparable<? super U>> Window<T> over(T currentRecord, WindowDefinition<T, U> windowDefinition) {
+ this.makeReusable();
+ Queryable<T> partition =
+ this.groupBy(windowDefinition.partitionBy()) // TODO cache the group result
+ .where(e -> QueryableHelper.isIdentical(e.getV1(), windowDefinition.partitionBy().apply(currentRecord)))
+ .select((e, q) -> e.getV2())
+ .toList()
+ .get(0);
+
+ return new WindowImpl<>(currentRecord, partition, windowDefinition);
}
private static <T> Stream<T> toStream(Iterable<T> sourceIterable) {
@@ -475,19 +500,25 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
}
private boolean isReusable() {
- return readLock(() -> {
+ readLock.lock();
+ try {
return null != sourceIterable;
- });
+ } finally {
+ readLock.unlock();
+ }
}
private void makeReusable() {
if (null != this.sourceIterable) return;
- writeLock(() -> {
+ writeLock.lock();
+ try {
if (null != this.sourceIterable) return;
this.sourceIterable = this.sourceStream.collect(Collectors.toList());
- });
+ } finally {
+ writeLock.unlock();
+ }
}
public Object asType(Class<?> clazz) {
@@ -514,58 +545,6 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
return DefaultGroovyMethods.asType(this, clazz);
}
- private void readLock(Runnable runnable) {
- final boolean parallel = isParallelEnabled();
-
- if (parallel) rl.lock();
- try {
- runnable.run();
- } finally {
- if (parallel) rl.unlock();
- }
- }
-
- private <R> R readLock(Callable<R> callable) {
- final boolean parallel = isParallelEnabled();
-
- if (parallel) rl.lock();
- try {
- return callable.call();
- } catch (Exception e) {
- throw new GroovyRuntimeException(e);
- } finally {
- if (parallel) rl.unlock();
- }
- }
-
- private void writeLock(Runnable runnable) {
- final boolean parallel = isParallelEnabled();
-
- if (parallel) wl.lock();
- try {
- runnable.run();
- } finally {
- if (parallel) wl.unlock();
- }
- }
-
- private <R> R writeLock(Callable<R> callable) {
- final boolean parallel = isParallelEnabled();
-
- if (parallel) wl.lock();
- try {
- return callable.call();
- } catch (Exception e) {
- throw new GroovyRuntimeException(e);
- } finally {
- if (parallel) wl.unlock();
- }
- }
-
- private static boolean isParallelEnabled() {
- return TRUE_STR.equals(QueryableHelper.getVar(PARALLEL));
- }
-
@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -587,9 +566,10 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
private Stream<T> sourceStream;
private volatile Iterable<T> sourceIterable;
private final ReadWriteLock rwl = new ReentrantReadWriteLock();
- private final Lock rl = rwl.readLock();
- private final Lock wl = rwl.writeLock();
+ private final Lock readLock = rwl.readLock();
+ private final Lock writeLock = rwl.writeLock();
private static final BigDecimal BD_TWO = BigDecimal.valueOf(2);
+ private static final String USE_WINDOW_FUNCTION = "useWindowFunction";
private static final String PARALLEL = "parallel";
private static final String TRUE_STR = "true";
private static final long serialVersionUID = -5067092453136522893L;
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index 737876b..6d2b48d 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -4648,6 +4648,27 @@ class GinqTest {
'''
}
+ @Test
+ void "testGinq - window - 0"() {
+ assertGinqScript '''
+ assert [[2, 1], [1, null], [3, 2]] == GQ {
+ from n in [2, 1, 3]
+ select n, (lag(n) over(orderby n))
+ }.toList()
+ '''
+ }
+
+ @Test
+ void "testGinq - window - 1"() {
+ assertGinqScript '''
+ assert [[2, 1], [1, null], [3, 2]] == GQ {
+ from n in [2, 1, 3]
+ join m in [2, 1, 3] on m == n
+ select n, (lag(n) over(orderby n))
+ }.toList()
+ '''
+ }
+
private static void assertGinqScript(String script) {
String deoptimizedScript = script.replaceAll(/\bGQ\s*[{]/, 'GQ(optimize:false) {')
List<String> scriptList = [deoptimizedScript, script]
diff --git a/subprojects/groovy-ginq/src/test/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollectionTest.groovy b/subprojects/groovy-ginq/src/test/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollectionTest.groovy
index cf91e74..716c99a 100644
--- a/subprojects/groovy-ginq/src/test/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollectionTest.groovy
+++ b/subprojects/groovy-ginq/src/test/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollectionTest.groovy
@@ -701,14 +701,14 @@ class QueryableCollectionTest {
@Test
void testGroupBySelect0() {
def nums = [1, 2, 2, 3, 3, 4, 4, 5]
- def result = from(nums).groupBy(e -> e).select(e -> Tuple.tuple(e.v1, e.v2.toList())).toList()
+ def result = from(nums).groupBy(e -> e).select((e, q) -> Tuple.tuple(e.v1, e.v2.toList())).toList()
assert [[1, [1]], [2, [2, 2]], [3, [3, 3]], [4, [4, 4]], [5, [5]]] == result
}
@Test
void testGroupBySelect1() {
def nums = [1, 2, 2, 3, 3, 4, 4, 5]
- def result = from(nums).groupBy(e -> e).select(e -> Tuple.tuple(e.v1, e.v2.count())).toList()
+ def result = from(nums).groupBy(e -> e).select((e, q) -> Tuple.tuple(e.v1, e.v2.count())).toList()
assert [[1, 1], [2, 2], [3, 2], [4, 2], [5, 1]] == result
}
@@ -718,7 +718,7 @@ class QueryableCollectionTest {
def nums = [1, 2, 2, 3, 3, 4, 4, 5]
def result =
from(nums).groupBy(e -> e)
- .select(e ->
+ .select((e, q) ->
Tuple.tuple(
e.v1,
e.v2.count(),
@@ -734,7 +734,7 @@ class QueryableCollectionTest {
def nums = [1, 2, 2, 3, 3, 4, 4, 5]
def result =
from(nums).groupBy(e -> e, g -> g.v1 > 2)
- .select(e ->
+ .select((e, q) ->
Tuple.tuple(
e.v1,
e.v2.count(),
@@ -751,7 +751,7 @@ class QueryableCollectionTest {
new Person2('David', 121, 'Male')]
def result = from(persons).groupBy(p -> p.gender)
- .select(e -> Tuple.tuple(e.v1, e.v2.count())).toList()
+ .select((e, q) -> Tuple.tuple(e.v1, e.v2.count())).toList()
assert [['Male', 2], ['Female', 1]] == result
}
@@ -764,7 +764,7 @@ class QueryableCollectionTest {
new Person2('David', 121, 'Male')]
def result = from(persons).groupBy(p -> new NamedTuple<>([p.gender], ['gender']))
- .select(e -> Tuple.tuple(e.v1.gender, e.v2.min(p -> p.weight), e.v2.max(p -> p.weight))).toList()
+ .select((e, q) -> Tuple.tuple(e.v1.gender, e.v2.min(p -> p.weight), e.v2.max(p -> p.weight))).toList()
assert [['Male', 121, 135], ['Female', 100, 100]] == result
}
@@ -775,7 +775,7 @@ class QueryableCollectionTest {
def nums = [1, 2, 2, 3, 3, 4, 4, 5]
def result =
from(nums).groupBy(e -> e)
- .select(e ->
+ .select((e, q) ->
Tuple.tuple(
e.v1,
e.v2.count(),
@@ -791,7 +791,7 @@ class QueryableCollectionTest {
def nums = [null, 2, 3]
def result =
from(nums).groupBy(e -> 1)
- .select(e ->
+ .select((e, q) ->
e.v2.sum(n -> n)
)
.toList()
@@ -804,7 +804,7 @@ class QueryableCollectionTest {
def nums = [null, 2, 3]
def result =
from(nums).groupBy(e -> 1)
- .select(e ->
+ .select((e, q) ->
e.v2.count(n -> n)
)
.toList()
@@ -817,7 +817,7 @@ class QueryableCollectionTest {
def nums = [null, 2, 3]
def result =
from(nums).groupBy(e -> 1)
- .select(e ->
+ .select((e, q) ->
e.v2.avg(n -> n)
)
.toList()
@@ -830,7 +830,7 @@ class QueryableCollectionTest {
def nums = [1, 3, 2]
def result =
from(nums).groupBy(e -> 1)
- .select(e ->
+ .select((e, q) ->
e.v2.median(n -> n)
)
.toList()
@@ -843,7 +843,7 @@ class QueryableCollectionTest {
def nums = [1, 3, 2, 4]
def result =
from(nums).groupBy(e -> 1)
- .select(e ->
+ .select((e, q) ->
e.v2.median(n -> n)
)
.toList()
@@ -856,7 +856,7 @@ class QueryableCollectionTest {
def nums = [1]
def result =
from(nums).groupBy(e -> 1)
- .select(e ->
+ .select((e, q) ->
e.v2.median(n -> n)
)
.toList()
@@ -922,7 +922,7 @@ class QueryableCollectionTest {
@Test
void testSelect() {
def nums = [1, 2, 3, 4, 5]
- def result = from(nums).select(e -> e + 1).toList()
+ def result = from(nums).select((e, q) -> e + 1).toList()
assert [2, 3, 4, 5, 6] == result
}
@@ -990,7 +990,7 @@ class QueryableCollectionTest {
.innerJoin(from(nums2), (a, b) -> a == b)
.where(t -> t.v1 > 1)
.limit(1, 2)
- .select(t -> t.v1 + 1)
+ .select((t, q) -> t.v1 + 1)
.toList()
assert [4, 5] == result
}