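You are viewing a plain-text version of this content. The canonical link for it was a hyperlink on the word "here" in the original page; that link is not preserved in this plain-text version.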
Posted to commits@groovy.apache.org by su...@apache.org on 2020/12/22 16:13:57 UTC

[groovy] branch master updated: Support basic window function

This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new be3f183  Support basic window function
be3f183 is described below

commit be3f183b01cceed9817b6c2afe1796a0f1962b7b
Author: Daniel Sun <su...@apache.org>
AuthorDate: Wed Dec 23 00:10:36 2020 +0800

    Support basic window function
---
 .../org/apache/groovy/ginq/dsl/GinqAstBuilder.java |  12 +-
 .../ginq/provider/collection/GinqAstWalker.groovy  | 188 ++++++++++++++++++---
 .../provider/collection/runtime/Queryable.java     |  22 +--
 .../collection/runtime/QueryableCollection.java    | 122 ++++++-------
 .../test/org/apache/groovy/ginq/GinqTest.groovy    |  21 +++
 .../runtime/QueryableCollectionTest.groovy         |  30 ++--
 6 files changed, 277 insertions(+), 118 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
index 72ecd03..a17beed 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
@@ -108,10 +108,18 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
         return ginqExpression.getNodeMetaData(__LATEST_GINQ_EXPRESSION_CLAUSE);
     }
 
+    private boolean visitingOverClause;
+
     @Override
     public void visitMethodCallExpression(MethodCallExpression call) {
-        super.visitMethodCallExpression(call);
         final String methodName = call.getMethodAsString();
+        if ("over".equals(methodName)) {
+            visitingOverClause = true;
+        }
+        super.visitMethodCallExpression(call);
+        if ("over".equals(methodName)) {
+            visitingOverClause = false;
+        }
 
         if (!KEYWORD_SET.contains(methodName)) {
             ignoredMethodCallExpressionList.add(call);
@@ -247,7 +255,7 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
             return;
         }
 
-        if (KW_ORDERBY.equals(methodName)) {
+        if (KW_ORDERBY.equals(methodName) && !visitingOverClause) {
             OrderExpression orderExpression = new OrderExpression(call.getArguments());
             orderExpression.setSourcePosition(call.getMethod());
 
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index 8a2c8e3..017ed64 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -40,11 +40,13 @@ import org.apache.groovy.ginq.dsl.expression.WhereExpression
 import org.apache.groovy.ginq.provider.collection.runtime.NamedRecord
 import org.apache.groovy.ginq.provider.collection.runtime.Queryable
 import org.apache.groovy.ginq.provider.collection.runtime.QueryableHelper
+import org.apache.groovy.ginq.provider.collection.runtime.WindowDefinition
 import org.apache.groovy.util.Maps
 import org.codehaus.groovy.GroovyBugError
 import org.codehaus.groovy.ast.ClassHelper
 import org.codehaus.groovy.ast.ClassNode
 import org.codehaus.groovy.ast.CodeVisitorSupport
+import org.codehaus.groovy.ast.Parameter
 import org.codehaus.groovy.ast.expr.ArgumentListExpression
 import org.codehaus.groovy.ast.expr.BinaryExpression
 import org.codehaus.groovy.ast.expr.CastExpression
@@ -166,10 +168,13 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         SelectExpression selectExpression = currentGinqExpression.selectExpression
         selectExpression.putNodeMetaData(__METHOD_CALL_RECEIVER, resultMethodCallReceiver)
         selectExpression.dataSourceExpression = resultDataSourceExpression
-
         MethodCallExpression selectMethodCallExpression = this.visitSelectExpression(selectExpression)
 
         List<Statement> statementList = []
+        boolean useWindowFunction = isUseWindowFunction(selectExpression)
+        if (useWindowFunction) {
+            statementList << stmt(callX(QUERYABLE_HELPER_TYPE, 'setVar', args(new ConstantExpression(USE_WINDOW_FUNCTION), new ConstantExpression(TRUE_STR))))
+        }
 
         boolean isRootGinqExpression = ginqExpression === ginqExpression.getNodeMetaData(GinqAstBuilder.ROOT_GINQ_EXPRESSION)
         boolean parallelEnabled = isRootGinqExpression && TRUE_STR == configuration.get(GinqGroovyMethods.CONF_PARALLEL)
@@ -195,6 +200,9 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         if (parallelEnabled) {
             statementList << stmt(callX(QUERYABLE_HELPER_TYPE, 'removeVar', args(new ConstantExpression(PARALLEL))))
         }
+        if (useWindowFunction) {
+            statementList << stmt(callX(QUERYABLE_HELPER_TYPE, 'removeVar', args(new ConstantExpression(USE_WINDOW_FUNCTION))))
+        }
         statementList << returnS(varX(resultName))
 
         def result = callX(lambdaX(block(statementList as Statement[])), "call")
@@ -203,6 +211,22 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         return result
     }
 
+    private boolean isUseWindowFunction(SelectExpression selectExpression) {
+        boolean useWindowFunction = false
+        selectExpression.projectionExpr.visit(new GinqAstBaseVisitor() {
+            @Override
+            void visitMethodCallExpression(MethodCallExpression call) {
+                if (call.methodAsString == 'over') {
+                    useWindowFunction = true
+                    return
+                }
+
+                super.visitMethodCallExpression(call)
+            }
+        })
+        return useWindowFunction
+    }
+
     private static boolean isAggregateFunction(Expression expression) {
         Expression expr = expression
         if (expression instanceof CastExpression) {
@@ -552,6 +576,15 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         Expression orderMethodCallReceiver = orderExpression.getNodeMetaData(__METHOD_CALL_RECEIVER)
         Expression ordersExpr = orderExpression.ordersExpr
 
+        List<Expression> orderCtorCallExpressions = constructOrderCtorCallExpressions(ordersExpr, dataSourceExpression)
+
+        def orderMethodCallExpression = callX(orderMethodCallReceiver, "orderBy", args(orderCtorCallExpressions))
+        orderMethodCallExpression.setSourcePosition(orderExpression)
+
+        return orderMethodCallExpression
+    }
+
+    private List<Expression> constructOrderCtorCallExpressions(Expression ordersExpr, DataSourceExpression dataSourceExpression) {
         List<Expression> argumentExpressionList = ((ArgumentListExpression) ordersExpr).getExpressions()
         List<Expression> orderCtorCallExpressions = argumentExpressionList.stream().map(e -> {
             Expression target = e
@@ -576,11 +609,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
 
             return ctorX(ORDER_TYPE, args(lambdaExpression, new ConstantExpression(asc)))
         }).collect(Collectors.toList())
-
-        def orderMethodCallExpression = callX(orderMethodCallReceiver, "orderBy", args(orderCtorCallExpressions))
-        orderMethodCallExpression.setSourcePosition(orderExpression)
-
-        return orderMethodCallExpression
+        return orderCtorCallExpressions
     }
 
     @Override
@@ -628,17 +657,92 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                     )
                 }
 
+                if (expression instanceof MethodCallExpression) {
+                    if ('over' == expression.methodAsString) {
+                        if (expression.objectExpression instanceof MethodCallExpression) {
+                            VariableExpression wqVar = varX(getWindowQueryableName())
+
+                            String lambdaParamName = getLambdaParamName(dataSourceExpression, lambdaCode)
+                            VariableExpression currentRecordVar = varX(lambdaParamName)
+
+                            currentGinqExpression.putNodeMetaData(__VISITING_WINDOW_FUNCTION, true)
+                            def windowDefinitionFactoryMethodCallExpression = constructWindowDefinitionFactoryMethodCallExpression(expression, dataSourceExpression)
+                            Expression newObjectExpression = callX(wqVar, 'over', args(
+                                    currentRecordVar,
+                                    windowDefinitionFactoryMethodCallExpression
+                            ))
+
+                            def windowFunctionMethodCallExpression = (MethodCallExpression) expression.objectExpression
+                            def windowFunctionLambdaCode = ((ArgumentListExpression) windowFunctionMethodCallExpression.arguments).getExpression(0)
+                            def windowFunctionLambdaName = findRootObjectExpression(windowFunctionLambdaCode).text
+                            def newWindowFunctionLambdaName = '__wfp'
+
+                            windowFunctionLambdaCode = ((ListExpression) (new ListExpression(Collections.singletonList(windowFunctionLambdaCode)).transformExpression(new ExpressionTransformer() {
+                                @Override
+                                Expression transform(Expression expr) {
+                                    if (expr instanceof VariableExpression) {
+                                        if (windowFunctionLambdaName == expr.text) {
+                                            if (dataSourceExpression instanceof JoinExpression) {
+                                                return correctVars(dataSourceExpression, newWindowFunctionLambdaName=getLambdaParamName(dataSourceExpression, expr), expr)
+                                            } else {
+                                                return new VariableExpression(newWindowFunctionLambdaName)
+                                            }
+                                        }
+                                    }
+                                    return expr.transformExpression(this)
+                                }
+                            }))).getExpression(0)
+
+                            def result = callX(
+                                    newObjectExpression,
+                                    windowFunctionMethodCallExpression.methodAsString,
+                                    lambdaX(
+                                            params(param(ClassHelper.DYNAMIC_TYPE, newWindowFunctionLambdaName)),
+                                            block(stmt(windowFunctionLambdaCode))
+                                    )
+                            )
+                            currentGinqExpression.putNodeMetaData(__VISITING_WINDOW_FUNCTION, false)
+
+                            return result
+                        }
+                    }
+                }
+
                 return expression.transformExpression(this)
             }
         })).getExpression(0)
 
-        def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, lambdaCode)
+        def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, lambdaCode, param(ClassHelper.DYNAMIC_TYPE, getWindowQueryableName()))
 
         currentGinqExpression.putNodeMetaData(__VISITING_SELECT, false)
 
         return selectMethodCallExpression
     }
 
+    private MethodCallExpression constructWindowDefinitionFactoryMethodCallExpression(MethodCallExpression methodCallExpression, DataSourceExpression dataSourceExpression) {
+        Expression orderExpr = null
+        ArgumentListExpression argumentListExpression = (ArgumentListExpression) methodCallExpression.arguments
+        if (!argumentListExpression.getExpressions().isEmpty()) {
+            MethodCallExpression windowMce = (MethodCallExpression) argumentListExpression.getExpression(0)
+            if ('orderby' == windowMce.methodAsString) {
+                orderExpr = windowMce.arguments
+            }
+        }
+        callX(new ClassExpression(WINDOW_DEFINITION_TYPE), 'of', constructOrderCtorCallExpressions(orderExpr, dataSourceExpression).get(0))
+    }
+
+    private int windowQueryableNameSeq = 0
+    private String getWindowQueryableName() {
+        String name = (String) currentGinqExpression.getNodeMetaData(__WINDOW_QUERYABLE_NAME)
+
+        if (!name) {
+            name = "${__WINDOW_QUERYABLE_NAME}${windowQueryableNameSeq++}"
+            currentGinqExpression.putNodeMetaData(__WINDOW_QUERYABLE_NAME, name)
+        }
+
+        return name
+    }
+
     private static boolean isExpression(final Expression expr, final Class... expressionTypes) {
         Arrays.stream(expressionTypes).anyMatch(clazz -> {
             Expression tmpExpr = expr
@@ -918,6 +1022,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                         }
                     }
                 }
+            } else {
+                if (visitingWindowFunction) {
+                    boolean isJoin = dataSourceExpression instanceof JoinExpression
+                    if (isJoin) {
+                        Map<String, Expression>  aliasAccessPathMap = findAliasAccessPath(dataSourceExpression, new VariableExpression(lambdaParamName))
+                        transformedExpression = aliasAccessPathMap.get(expression.text)
+                    }
+                }
             }
         } else if (expression instanceof MethodCallExpression) {
             // #1
@@ -1005,22 +1117,29 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         return expression.accept(this)
     }
 
-    private MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, Expression lambdaCode) {
-        LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode)
+    private MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, Expression lambdaCode, Parameter... extraParams) {
+        LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode, extraParams)
 
         callXWithLambda(receiver, methodName, lambdaExpression)
     }
 
-    private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+    private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode, Parameter... extraParams) {
         Tuple3<String, List<DeclarationExpression>, Expression> paramNameAndLambdaCode = correctVariablesOfLambdaExpression(dataSourceExpression, lambdaCode)
 
         List<DeclarationExpression> declarationExpressionList = paramNameAndLambdaCode.v2
         List<Statement> statementList = []
-        statementList.addAll(declarationExpressionList.stream().map(e -> stmt(e)).collect(Collectors.toList()))
+        if (!visitingWindowFunction) {
+            statementList.addAll(declarationExpressionList.stream().map(e -> stmt(e)).collect(Collectors.toList()))
+        }
         statementList.add(stmt(paramNameAndLambdaCode.v3))
 
+        def paramList = [param(ClassHelper.DYNAMIC_TYPE, paramNameAndLambdaCode.v1)]
+        if (extraParams) {
+            paramList.addAll(Arrays.asList(extraParams))
+        }
+
         lambdaX(
-                params(param(ClassHelper.DYNAMIC_TYPE, paramNameAndLambdaCode.v1)),
+                params(paramList as Parameter[]),
                 block(statementList as Statement[])
         )
     }
@@ -1030,18 +1149,29 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         "__t_${lambdaParamSeq++}"
     }
 
-    private Tuple3<String, List<DeclarationExpression>, Expression> correctVariablesOfLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+    private String getLambdaParamName(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
         boolean groupByVisited = isGroupByVisited()
-
-        List<DeclarationExpression> declarationExpressionList = Collections.emptyList()
         String lambdaParamName
-        if (dataSourceExpression instanceof JoinExpression || groupByVisited) {
+        if (dataSourceExpression instanceof JoinExpression || groupByVisited || visitingWindowFunction) {
             lambdaParamName = lambdaCode.getNodeMetaData(__LAMBDA_PARAM_NAME)
-            if (!lambdaParamName || visitingAggregateFunctionStack) {
+            if (!lambdaParamName || visitingAggregateFunctionStack || visitingWindowFunction) {
                 lambdaParamName = generateLambdaParamName()
             }
 
             lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, lambdaParamName)
+        } else {
+            lambdaParamName = dataSourceExpression.aliasExpr.text
+            lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, lambdaParamName)
+        }
+
+        return lambdaParamName
+    }
+
+    private Tuple3<String, List<DeclarationExpression>, Expression> correctVariablesOfLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode) {
+        boolean groupByVisited = isGroupByVisited()
+        List<DeclarationExpression> declarationExpressionList = Collections.emptyList()
+        String lambdaParamName = getLambdaParamName(dataSourceExpression, lambdaCode)
+        if (dataSourceExpression instanceof JoinExpression || groupByVisited) {
             Tuple2<List<DeclarationExpression>, Expression> declarationAndLambdaCode = correctVariablesOfGinqExpression(dataSourceExpression, lambdaCode)
             if (!visitingAggregateFunctionStack) {
                 declarationExpressionList = declarationAndLambdaCode.v1
@@ -1060,10 +1190,22 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
             }
             lambdaCode = declarationAndLambdaCode.v2
         } else {
-            lambdaParamName = dataSourceExpression.aliasExpr.text
-            lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, lambdaParamName)
+            if (visitingWindowFunction) {
+                lambdaCode = ((ListExpression) (new ListExpression(Collections.singletonList(lambdaCode)).transformExpression(new ExpressionTransformer() {
+                    @Override
+                    Expression transform(Expression expr) {
+                        if (expr instanceof VariableExpression) {
+                            if (dataSourceExpression.aliasExpr.text == expr.text) {
+                                return new VariableExpression(lambdaParamName)
+                            }
+                        }
+                        return expr.transformExpression(this)
+                    }
+                }))).getExpression(0)
+            }
         }
 
+
         if (lambdaCode instanceof ConstructorCallExpression) {
             if (NAMEDRECORD_CLASS_NAME == lambdaCode.type.redirect().name) {
                 // store the source record
@@ -1082,6 +1224,10 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         return currentGinqExpression.getNodeMetaData(__VISITING_SELECT) ?: false
     }
 
+    private boolean isVisitingWindowFunction() {
+        return currentGinqExpression.getNodeMetaData(__VISITING_WINDOW_FUNCTION) ?: false
+    }
+
     private boolean isRowNumberUsed() {
         return currentGinqExpression.getNodeMetaData(__RN_USED)  ?: false
     }
@@ -1121,6 +1267,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
     private static final ClassNode ORDER_TYPE = makeWithoutCaching(Queryable.Order.class)
     private static final ClassNode NAMED_RECORD_TYPE = makeWithoutCaching(NamedRecord.class)
     private static final ClassNode QUERYABLE_HELPER_TYPE = makeWithoutCaching(QueryableHelper.class)
+    private static final ClassNode WINDOW_DEFINITION_TYPE = makeWithoutCaching(WindowDefinition.class)
 
     private static final List<String> ORDER_OPTION_LIST = Arrays.asList('asc', 'desc')
     private static final String FUNCTION_COUNT = 'count'
@@ -1134,15 +1281,18 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
 
     private static final String NAMEDRECORD_CLASS_NAME = NamedRecord.class.name
 
+    private static final String USE_WINDOW_FUNCTION = 'useWindowFunction'
     private static final String PARALLEL = 'parallel'
     private static final String TRUE_STR = 'true'
 
     private static final String __METHOD_CALL_RECEIVER = "__METHOD_CALL_RECEIVER"
     private static final String __GROUPBY_VISITED = "__GROUPBY_VISITED"
     private static final String __VISITING_SELECT = "__VISITING_SELECT"
+    private static final String __VISITING_WINDOW_FUNCTION = "__VISITING_WINDOW_FUNCTION"
     private static final String __LAMBDA_PARAM_NAME = "__LAMBDA_PARAM_NAME"
     private static final String  __RN_USED = '__RN_USED'
     private static final String __META_DATA_MAP_NAME_PREFIX = '__metaDataMap_'
+    private static final String __WINDOW_QUERYABLE_NAME = '__wq_'
     private static final String __ROW_NUMBER_NAME_PREFIX = '__rowNumber_'
     private static final String __SOURCE_RECORD = "__sourceRecord"
     private static final String __GROUP = "__group"
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java
index c3096db..10c0618 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java
@@ -25,6 +25,7 @@ import java.math.BigDecimal;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
+import java.util.function.BiFunction;
 import java.util.function.BiPredicate;
 import java.util.function.Function;
 import java.util.function.Predicate;
@@ -266,7 +267,7 @@ public interface Queryable<T> {
      * @return the result of projecting
      * @since 4.0.0
      */
-    <U> Queryable<U> select(Function<? super T, ? extends U> mapper);
+    <U> Queryable<U> select(BiFunction<? super T, ? super Queryable<? extends T>, ? extends U> mapper);
 
     /**
      * Check if the result is empty, similar to SQL's {@code exists}
@@ -426,16 +427,15 @@ public interface Queryable<T> {
         return toList().stream();
     }
 
-    default <U extends Comparable<? super U>> Window<T> over(T currentRecord, WindowDefinition<T, U> windowDefinition) {
-        Queryable<T> partition =
-                this.groupBy(windowDefinition.partitionBy()) // TODO cache the group result
-                        .where(e -> QueryableHelper.isIdentical(e.getV1(), windowDefinition.partitionBy().apply(currentRecord)))
-                        .select(e -> e.getV2())
-                        .toList()
-                        .get(0);
-
-        return new WindowImpl<>(currentRecord, partition, windowDefinition);
-    }
+    /**
+     * Open window for current record
+     *
+     * @param currentRecord current record
+     * @param windowDefinition window definition
+     * @param <U> the type of window value
+     * @return the window
+     */
+    <U extends Comparable<? super U>> Window<T> over(T currentRecord, WindowDefinition<T, U> windowDefinition);
 
     /**
      * Represents an order rule
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
index b9a4a47..6d91299 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
@@ -18,7 +18,6 @@
  */
 package org.apache.groovy.ginq.provider.collection.runtime;
 
-import groovy.lang.GroovyRuntimeException;
 import groovy.lang.Tuple2;
 import groovy.transform.Internal;
 import org.apache.groovy.internal.util.Supplier;
@@ -39,10 +38,10 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
-import java.util.concurrent.Callable;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.function.BiFunction;
 import java.util.function.BiPredicate;
 import java.util.function.Function;
 import java.util.function.Predicate;
@@ -71,13 +70,16 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
     }
 
     public Iterator<T> iterator() {
-        return readLock(() -> {
+        readLock.lock();
+        try {
             if (null != sourceIterable) {
                 return sourceIterable.iterator();
             }
 
             return sourceStream.iterator();
-        });
+        } finally {
+            readLock.unlock();
+        }
     }
 
     @Override
@@ -149,12 +151,12 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
 
     @Override
     public <U> Queryable<Tuple2<T, U>> rightJoin(Queryable<? extends U> queryable, BiPredicate<? super T, ? super U> joiner) {
-        return outerJoin(queryable, this, (a, b) -> joiner.test(b, a)).select(e -> tuple(e.getV2(), e.getV1()));
+        return outerJoin(queryable, this, (a, b) -> joiner.test(b, a)).select((e, q) -> tuple(e.getV2(), e.getV1()));
     }
 
     @Override
     public <U> Queryable<Tuple2<T, U>> rightHashJoin(Queryable<? extends U> queryable, Function<? super T, ?> fieldsExtractor1, Function<? super U, ?> fieldsExtractor2) {
-        return outerHashJoin(queryable, this, fieldsExtractor2, fieldsExtractor1).select(e -> tuple(e.getV2(), e.getV1()));
+        return outerHashJoin(queryable, this, fieldsExtractor2, fieldsExtractor1).select((e, q) -> tuple(e.getV2(), e.getV1()));
     }
 
     @Override
@@ -247,8 +249,12 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
     }
 
     @Override
-    public <U> Queryable<U> select(Function<? super T, ? extends U> mapper) {
-        Stream<U> stream = this.stream().map(mapper);
+    public <U> Queryable<U> select(BiFunction<? super T, ? super Queryable<? extends T>, ? extends U> mapper) {
+        if (TRUE_STR.equals(QueryableHelper.getVar(USE_WINDOW_FUNCTION))) {
+            this.makeReusable();
+        }
+
+        Stream<U> stream = this.stream().map((T t) -> mapper.apply(t, this));
 
         return from(stream);
     }
@@ -438,7 +444,8 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
 
     @Override
     public List<T> toList() {
-        return writeLock(() -> {
+        writeLock.lock();
+        try {
             if (sourceIterable instanceof List) {
                 return (List<T>) sourceIterable;
             }
@@ -447,7 +454,9 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
             sourceIterable = result;
 
             return result;
-        });
+        } finally {
+            writeLock.unlock();
+        }
     }
 
     @Override
@@ -457,17 +466,33 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
 
     @Override
     public Stream<T> stream() {
-        return writeLock(() -> {
+        writeLock.lock();
+        try {
             if (isReusable()) {
                 sourceStream = toStream(sourceIterable);  // we have to create new stream every time because Java stream can not be reused
             }
 
-            if (!sourceStream.isParallel() && isParallelEnabled()) {
+            if (!sourceStream.isParallel() && TRUE_STR.equals(QueryableHelper.getVar(PARALLEL))) {
                 sourceStream = sourceStream.parallel();
             }
 
             return sourceStream;
-        });
+        } finally {
+            writeLock.unlock();
+        }
+    }
+
+    @Override
+    public <U extends Comparable<? super U>> Window<T> over(T currentRecord, WindowDefinition<T, U> windowDefinition) {
+        this.makeReusable();
+        Queryable<T> partition =
+                this.groupBy(windowDefinition.partitionBy()) // TODO cache the group result
+                        .where(e -> QueryableHelper.isIdentical(e.getV1(), windowDefinition.partitionBy().apply(currentRecord)))
+                        .select((e, q) -> e.getV2())
+                        .toList()
+                        .get(0);
+
+        return new WindowImpl<>(currentRecord, partition, windowDefinition);
     }
 
     private static <T> Stream<T> toStream(Iterable<T> sourceIterable) {
@@ -475,19 +500,25 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
     }
 
     private boolean isReusable() {
-        return readLock(() -> {
+        readLock.lock();
+        try {
             return null != sourceIterable;
-        });
+        } finally {
+            readLock.unlock();
+        }
     }
 
     private void makeReusable() {
         if (null != this.sourceIterable) return;
 
-        writeLock(() -> {
+        writeLock.lock();
+        try {
             if (null != this.sourceIterable) return;
 
             this.sourceIterable = this.sourceStream.collect(Collectors.toList());
-        });
+        } finally {
+            writeLock.unlock();
+        }
     }
 
     public Object asType(Class<?> clazz) {
@@ -514,58 +545,6 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
         return DefaultGroovyMethods.asType(this, clazz);
     }
 
-    private void readLock(Runnable runnable) {
-        final boolean parallel = isParallelEnabled();
-
-        if (parallel) rl.lock();
-        try {
-            runnable.run();
-        } finally {
-            if (parallel) rl.unlock();
-        }
-    }
-
-    private <R> R readLock(Callable<R> callable) {
-        final boolean parallel = isParallelEnabled();
-
-        if (parallel) rl.lock();
-        try {
-            return callable.call();
-        } catch (Exception e) {
-            throw new GroovyRuntimeException(e);
-        } finally {
-            if (parallel) rl.unlock();
-        }
-    }
-
-    private void writeLock(Runnable runnable) {
-        final boolean parallel = isParallelEnabled();
-
-        if (parallel) wl.lock();
-        try {
-            runnable.run();
-        } finally {
-            if (parallel) wl.unlock();
-        }
-    }
-
-    private <R> R writeLock(Callable<R> callable) {
-        final boolean parallel = isParallelEnabled();
-
-        if (parallel) wl.lock();
-        try {
-            return callable.call();
-        } catch (Exception e) {
-            throw new GroovyRuntimeException(e);
-        } finally {
-            if (parallel) wl.unlock();
-        }
-    }
-
-    private static boolean isParallelEnabled() {
-        return TRUE_STR.equals(QueryableHelper.getVar(PARALLEL));
-    }
-
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;
@@ -587,9 +566,10 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
     private Stream<T> sourceStream;
     private volatile Iterable<T> sourceIterable;
     private final ReadWriteLock rwl = new ReentrantReadWriteLock();
-    private final Lock rl = rwl.readLock();
-    private final Lock wl = rwl.writeLock();
+    private final Lock readLock = rwl.readLock();
+    private final Lock writeLock = rwl.writeLock();
     private static final BigDecimal BD_TWO = BigDecimal.valueOf(2);
+    private static final String USE_WINDOW_FUNCTION = "useWindowFunction";
     private static final String PARALLEL = "parallel";
     private static final String TRUE_STR = "true";
     private static final long serialVersionUID = -5067092453136522893L;
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index 737876b..6d2b48d 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -4648,6 +4648,27 @@ class GinqTest {
         '''
     }
 
+    @Test
+    void "testGinq - window - 0"() {
+        assertGinqScript '''
+            assert [[2, 1], [1, null], [3, 2]] == GQ {
+                from n in [2, 1, 3]
+                select n, (lag(n) over(orderby n))
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - window - 1"() {
+        assertGinqScript '''
+            assert [[2, 1], [1, null], [3, 2]] == GQ {
+                from n in [2, 1, 3]
+                join m in [2, 1, 3] on m == n
+                select n, (lag(n) over(orderby n))
+            }.toList()
+        '''
+    }
+
     private static void assertGinqScript(String script) {
         String deoptimizedScript = script.replaceAll(/\bGQ\s*[{]/, 'GQ(optimize:false) {')
         List<String> scriptList = [deoptimizedScript, script]
diff --git a/subprojects/groovy-ginq/src/test/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollectionTest.groovy b/subprojects/groovy-ginq/src/test/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollectionTest.groovy
index cf91e74..716c99a 100644
--- a/subprojects/groovy-ginq/src/test/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollectionTest.groovy
+++ b/subprojects/groovy-ginq/src/test/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollectionTest.groovy
@@ -701,14 +701,14 @@ class QueryableCollectionTest {
     @Test
     void testGroupBySelect0() {
         def nums = [1, 2, 2, 3, 3, 4, 4, 5]
-        def result = from(nums).groupBy(e -> e).select(e -> Tuple.tuple(e.v1, e.v2.toList())).toList()
+        def result = from(nums).groupBy(e -> e).select((e, q) -> Tuple.tuple(e.v1, e.v2.toList())).toList()
         assert [[1, [1]], [2, [2, 2]], [3, [3, 3]], [4, [4, 4]], [5, [5]]] == result
     }
 
     @Test
     void testGroupBySelect1() {
         def nums = [1, 2, 2, 3, 3, 4, 4, 5]
-        def result = from(nums).groupBy(e -> e).select(e -> Tuple.tuple(e.v1, e.v2.count())).toList()
+        def result = from(nums).groupBy(e -> e).select((e, q) -> Tuple.tuple(e.v1, e.v2.count())).toList()
         assert [[1, 1], [2, 2], [3, 2], [4, 2], [5, 1]] == result
     }
 
@@ -718,7 +718,7 @@ class QueryableCollectionTest {
         def nums = [1, 2, 2, 3, 3, 4, 4, 5]
         def result =
                 from(nums).groupBy(e -> e)
-                        .select(e ->
+                        .select((e, q) ->
                                 Tuple.tuple(
                                         e.v1,
                                         e.v2.count(),
@@ -734,7 +734,7 @@ class QueryableCollectionTest {
         def nums = [1, 2, 2, 3, 3, 4, 4, 5]
         def result =
                 from(nums).groupBy(e -> e, g -> g.v1 > 2)
-                        .select(e ->
+                        .select((e, q) ->
                                 Tuple.tuple(
                                         e.v1,
                                         e.v2.count(),
@@ -751,7 +751,7 @@ class QueryableCollectionTest {
                        new Person2('David', 121, 'Male')]
 
         def result = from(persons).groupBy(p -> p.gender)
-                .select(e -> Tuple.tuple(e.v1, e.v2.count())).toList()
+                .select((e, q) -> Tuple.tuple(e.v1, e.v2.count())).toList()
 
         assert [['Male', 2], ['Female', 1]] == result
     }
@@ -764,7 +764,7 @@ class QueryableCollectionTest {
                        new Person2('David', 121, 'Male')]
 
         def result = from(persons).groupBy(p -> new NamedTuple<>([p.gender], ['gender']))
-                .select(e -> Tuple.tuple(e.v1.gender, e.v2.min(p -> p.weight), e.v2.max(p -> p.weight))).toList()
+                .select((e, q) -> Tuple.tuple(e.v1.gender, e.v2.min(p -> p.weight), e.v2.max(p -> p.weight))).toList()
 
         assert [['Male', 121, 135], ['Female', 100, 100]] == result
     }
@@ -775,7 +775,7 @@ class QueryableCollectionTest {
         def nums = [1, 2, 2, 3, 3, 4, 4, 5]
         def result =
                 from(nums).groupBy(e -> e)
-                        .select(e ->
+                        .select((e, q) ->
                                 Tuple.tuple(
                                         e.v1,
                                         e.v2.count(),
@@ -791,7 +791,7 @@ class QueryableCollectionTest {
         def nums = [null, 2, 3]
         def result =
                 from(nums).groupBy(e -> 1)
-                        .select(e ->
+                        .select((e, q) ->
                                 e.v2.sum(n -> n)
                                 )
                         .toList()
@@ -804,7 +804,7 @@ class QueryableCollectionTest {
         def nums = [null, 2, 3]
         def result =
                 from(nums).groupBy(e -> 1)
-                        .select(e ->
+                        .select((e, q) ->
                                 e.v2.count(n -> n)
                         )
                         .toList()
@@ -817,7 +817,7 @@ class QueryableCollectionTest {
         def nums = [null, 2, 3]
         def result =
                 from(nums).groupBy(e -> 1)
-                        .select(e ->
+                        .select((e, q) ->
                                 e.v2.avg(n -> n)
                         )
                         .toList()
@@ -830,7 +830,7 @@ class QueryableCollectionTest {
         def nums = [1, 3, 2]
         def result =
                 from(nums).groupBy(e -> 1)
-                        .select(e ->
+                        .select((e, q) ->
                                 e.v2.median(n -> n)
                         )
                         .toList()
@@ -843,7 +843,7 @@ class QueryableCollectionTest {
         def nums = [1, 3, 2, 4]
         def result =
                 from(nums).groupBy(e -> 1)
-                        .select(e ->
+                        .select((e, q) ->
                                 e.v2.median(n -> n)
                         )
                         .toList()
@@ -856,7 +856,7 @@ class QueryableCollectionTest {
         def nums = [1]
         def result =
                 from(nums).groupBy(e -> 1)
-                        .select(e ->
+                        .select((e, q) ->
                                 e.v2.median(n -> n)
                         )
                         .toList()
@@ -922,7 +922,7 @@ class QueryableCollectionTest {
     @Test
     void testSelect() {
         def nums = [1, 2, 3, 4, 5]
-        def result = from(nums).select(e -> e + 1).toList()
+        def result = from(nums).select((e, q) -> e + 1).toList()
         assert [2, 3, 4, 5, 6] == result
     }
 
@@ -990,7 +990,7 @@ class QueryableCollectionTest {
                         .innerJoin(from(nums2), (a, b) -> a == b)
                         .where(t -> t.v1 > 1)
                         .limit(1, 2)
-                        .select(t -> t.v1 + 1)
+                        .select((t, q) -> t.v1 + 1)
                         .toList()
         assert [4, 5] == result
     }