You are viewing a plain-text version of this content. The link to the canonical version (originally shown as "here") was not preserved in this plain-text rendering.
Posted to commits@groovy.apache.org by su...@apache.org on 2020/12/31 08:27:43 UTC

[groovy] branch master updated: Tweak window function of GINQ

This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new 186c6e7  Tweak window function of GINQ
186c6e7 is described below

commit 186c6e7402ea8bd11890fa891087217fcdb84f23
Author: Daniel Sun <su...@apache.org>
AuthorDate: Thu Dec 31 15:15:47 2020 +0800

    Tweak window function of GINQ
---
 .../collection/runtime/QueryableCollection.java    |   8 +-
 .../ginq/provider/collection/runtime/Window.java   |   1 +
 .../provider/collection/runtime/WindowImpl.java    | 102 +++++++++++----------
 .../test/org/apache/groovy/ginq/GinqTest.groovy    |  15 +++
 4 files changed, 71 insertions(+), 55 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
index ade24b5..3e50bd8 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
@@ -71,12 +71,6 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
         this.sourceStream = sourceStream;
     }
 
-    protected List<Tuple2<T, Long>> listWithIndex;
-    QueryableCollection(Queryable<Tuple2<T, Long>> queryableWithIndex) {
-        this(queryableWithIndex.toList().stream().map(Tuple2::getV1).collect(Collectors.toList()));
-        this.listWithIndex = queryableWithIndex.toList();
-    }
-
     public Iterator<T> iterator() {
         readLock.lock();
         try {
@@ -531,7 +525,7 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
                         .findFirst()
                         .orElse(Queryable.emptyQueryable());
 
-        return new WindowImpl<>(currentRecord, partition, windowDefinition);
+        return WindowImpl.newInstance(currentRecord, partition, windowDefinition);
     }
 
     private static <T> Stream<T> toStream(Iterable<T> sourceIterable) {
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Window.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Window.java
index 3e683f2..5361de7 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Window.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Window.java
@@ -27,6 +27,7 @@ import java.util.function.Function;
  * @since 4.0.0
  */
 public interface Window<T> extends Queryable<T> {
+
     /**
      * Returns row number in the window, similar to SQL's {@code row_number()}
      *
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/WindowImpl.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/WindowImpl.java
index d99f039..867b865 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/WindowImpl.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/WindowImpl.java
@@ -18,14 +18,17 @@
  */
 package org.apache.groovy.ginq.provider.collection.runtime;
 
+import groovy.lang.Tuple;
 import groovy.lang.Tuple2;
 
+import java.util.Collections;
 import java.util.List;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.binarySearch;
 import static java.util.Comparator.comparing;
+import static org.apache.groovy.ginq.provider.collection.runtime.Queryable.from;
 
 /**
  * Represents window which stores elements used by window functions
@@ -37,30 +40,44 @@ import static java.util.Comparator.comparing;
 class WindowImpl<T, U extends Comparable<? super U>> extends QueryableCollection<T> implements Window<T> {
     private static final long serialVersionUID = -3458969297047398621L;
     private final Tuple2<T, Long> currentRecord;
+    private final Function<? super T, ? extends U> keyExtractor;
+    private final long size;
     private final int index;
-    private final WindowDefinition<T, U> windowDefinition;
     private final U value;
-    private final Function<? super T, ? extends U> keyExtractor;
-
-    WindowImpl(Tuple2<T, Long> currentRecord, Queryable<Tuple2<T, Long>> partition, WindowDefinition<T, U> windowDefinition) {
-        super(partition.orderBy(composeOrders(windowDefinition)));
-        this.currentRecord = currentRecord;
-        this.windowDefinition = windowDefinition;
 
+    static <T, U extends Comparable<? super U>> Window<T> newInstance(Tuple2<T, Long> currentRecord, Queryable<Tuple2<T, Long>> partition, WindowDefinition<T, U> windowDefinition) {
+        Function<? super T, ? extends U> keyExtractor;
         final List<Order<? super T, ? extends U>> orderList = windowDefinition.orderBy();
         if (null != orderList && 1 == orderList.size()) {
-            this.keyExtractor = orderList.get(0).getKeyExtractor();
-            this.value = keyExtractor.apply(currentRecord.getV1());
+            keyExtractor = orderList.get(0).getKeyExtractor();
         } else {
-            this.keyExtractor = null;
-            this.value = null;
+            keyExtractor = null;
         }
 
+        List<Tuple2<T, Long>> listWithIndex = partition.orderBy(composeOrders(windowDefinition)).toList();
+
         int tmpIndex = null == orderList || orderList.isEmpty()
-                        ? binarySearch(listWithIndex, currentRecord, comparing(Tuple2::getV2))
-                        : binarySearch(listWithIndex, currentRecord, makeComparator(composeOrders(orderList)).thenComparing(Tuple2::getV2));
+                ? binarySearch(listWithIndex, currentRecord, comparing(Tuple2::getV2))
+                : binarySearch(listWithIndex, currentRecord, makeComparator(composeOrders(orderList)).thenComparing(Tuple2::getV2));
+        int index = tmpIndex >= 0 ? tmpIndex : -tmpIndex - 1;
+
+        long size = partition.size();
+        Tuple2<Long, Long> indexTuple = getValidFirstAndLastIndex(windowDefinition, index, size);
+        List<T> list = null == indexTuple ? Collections.emptyList()
+                                  : from(listWithIndex.stream().map(Tuple2::getV1).collect(Collectors.toList()))
+                                      .limit(indexTuple.getV1(), indexTuple.getV2() - indexTuple.getV1() + 1)
+                                      .toList();
+
+        return new WindowImpl<>(currentRecord, index, list, keyExtractor);
+    }
 
-        this.index = tmpIndex >= 0 ? tmpIndex : -tmpIndex - 1;
+    private WindowImpl(Tuple2<T, Long> currentRecord, int index, List<T> list, Function<? super T, ? extends U> keyExtractor) {
+        super(list);
+        this.currentRecord = currentRecord;
+        this.keyExtractor = keyExtractor;
+        this.index = index;
+        this.value = null == keyExtractor ? null : keyExtractor.apply(currentRecord.getV1());
+        this.size = list.size();
     }
 
     @Override
@@ -73,7 +90,7 @@ class WindowImpl<T, U extends Comparable<? super U>> extends QueryableCollection
         V field;
         if (0 == lead) {
             field = extractor.apply(currentRecord.getV1());
-        } else if (0 <= index + lead && index + lead < this.size()) {
+        } else if (0 <= index + lead && index + lead < size) {
             field = extractor.apply(this.toList().get(index + (int) lead));
         } else {
             field = def;
@@ -88,31 +105,24 @@ class WindowImpl<T, U extends Comparable<? super U>> extends QueryableCollection
 
     @Override
     public <V> V firstValue(Function<? super T, ? extends V> extractor) {
-        long lastIndex = getLastIndex();
-        if (lastIndex < 0) {
-            return null;
-        }
-        long firstIndex = getFirstIndex();
-        if (firstIndex >= this.size()) {
+        List<T> list = this.toList();
+
+        if (list.isEmpty()) {
             return null;
         }
-        int resultIndex = (int) Math.max(0, firstIndex);
-        return extractor.apply(this.toList().get(resultIndex));
+
+        return extractor.apply(list.get(0));
     }
 
     @Override
     public <V> V lastValue(Function<? super T, ? extends V> extractor) {
-        long firstIndex = getFirstIndex();
-        long size = this.size();
-        if (firstIndex >= size) {
-            return null;
-        }
-        long lastIndex = getLastIndex();
-        if (lastIndex < 0) {
+        List<T> list = this.toList();
+
+        if (list.isEmpty()) {
             return null;
         }
-        int resultIndex = (int) Math.min(size - 1, lastIndex);
-        return extractor.apply(this.toList().get(resultIndex));
+
+        return extractor.apply(list.get(list.size() - 1));
     }
 
     @Override
@@ -147,29 +157,25 @@ class WindowImpl<T, U extends Comparable<? super U>> extends QueryableCollection
         return result;
     }
 
-    private long getFirstIndex() {
+    private static <T, U extends Comparable<? super U>> long getFirstIndex(WindowDefinition<T, U> windowDefinition, int index) {
         RowBound rowBound = windowDefinition.rows();
-        long firstRowIndex;
         final Long lower = rowBound.getLower();
-        if (null == lower || Long.MIN_VALUE == lower) {
-            firstRowIndex = 0;
-        } else {
-            firstRowIndex = index + lower;
-        }
-        return firstRowIndex;
+        return null == lower || Long.MIN_VALUE == lower ? 0 : index + lower;
     }
 
-    private long getLastIndex() {
+    private static <T, U extends Comparable<? super U>> long getLastIndex(WindowDefinition<T, U> windowDefinition, int index, long size) {
         RowBound rowBound = windowDefinition.rows();
-        long lastRowIndex;
-        long size = this.size();
         final Long upper = rowBound.getUpper();
-        if (null == upper || Long.MAX_VALUE == upper) {
-            lastRowIndex = size - 1;
-        } else {
-            lastRowIndex = index + upper;
+        return null == upper || Long.MAX_VALUE == upper ? size - 1 : index + upper;
+    }
+
+    private static <T, U extends Comparable<? super U>> Tuple2<Long, Long> getValidFirstAndLastIndex(WindowDefinition<T, U> windowDefinition, int index, long size) {
+        long firstIndex = getFirstIndex(windowDefinition, index);
+        long lastIndex = getLastIndex(windowDefinition, index, size);
+        if ((firstIndex < 0 && lastIndex < 0) || (firstIndex >= size && lastIndex >= size)) {
+            return null;
         }
-        return lastRowIndex;
+        return Tuple.tuple(Math.max(firstIndex, 0), Math.min(lastIndex, size - 1));
     }
 
     private static <T, U extends Comparable<? super U>> List<Order<Tuple2<T, Long>, U>> composeOrders(List<Queryable.Order<? super T, ? extends U>> orderList) {
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index ba2fe62..09a6c0d 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -5234,6 +5234,21 @@ class GinqTest {
         '''
     }
 
+    @Test
+    void "testGinq - window - 50"() {
+        assertGinqScript '''
+            assert [[1, 3, 2, 1, 2, 1.5, 1.5], [2, 6, 3, 1, 3, 2, 2], [3, 5, 3, 2, 2, 2.5, 2.5]] == GQ {
+                from n in [1, 2, 3]
+                select n, (sum(n) over(orderby n rows -1, 1)), 
+                          (max(n) over(orderby n rows -1, 1)), 
+                          (min(n) over(orderby n rows -1, 1)),
+                          (count(n) over(orderby n rows -1, 1)),
+                          (avg(n) over(orderby n rows -1, 1)),
+                          (median(n) over(orderby n rows -1, 1))
+            }.toList()
+        '''
+    }
+
     private static void assertGinqScript(String script) {
         String deoptimizedScript = script.replaceAll(/\bGQ\s*[{]/, 'GQ(optimize:false) {')
         List<String> scriptList = [deoptimizedScript, script]