You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by dw...@apache.org on 2021/03/10 09:53:05 UTC

[lucene] 01/01: LUCENE-9302: Grouping to use long to avoid overflows

This is an automated email from the ASF dual-hosted git repository.

dweiss pushed a commit to branch jira/lucene-9302
in repository https://gitbox.apache.org/repos/asf/lucene.git

commit ddecbb2eb551a6c1d4fc1a6802cb210e9099f17d
Author: Ishan Chattopadhyaya <is...@apache.org>
AuthorDate: Mon Apr 6 22:38:29 2020 +0530

    LUCENE-9302: Use long for grouping hit and group counts to avoid int overflow
---
 .../search/grouping/BlockGroupingCollector.java      |  4 ++--
 .../lucene/search/grouping/GroupingSearch.java       |  2 +-
 .../org/apache/lucene/search/grouping/TopGroups.java | 20 ++++++++++----------
 .../apache/lucene/search/grouping/TestGrouping.java  | 12 ++++++------
 4 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
index 23601ca..dad1101 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
@@ -353,8 +353,8 @@ public class BlockGroupingCollector extends SimpleCollector {
 
     return new TopGroups<>(new TopGroups<>(groupSort.getSort(),
                                        withinGroupSort.getSort(),
-                                       totalHitCount, totalGroupedHitCount, groups, maxScore),
-                         totalGroupCount);
+                                      (long) totalHitCount, (long) totalGroupedHitCount, groups, maxScore),
+                          (long) totalGroupCount);
   }
 
   @Override
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
index b88fb74..6ac5dc1 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
@@ -161,7 +161,7 @@ public class GroupingSearch {
     }
 
     if (allGroups) {
-      return new TopGroups(secondPassCollector.getTopGroups(groupDocsOffset), matchingGroups.size());
+      return new TopGroups(secondPassCollector.getTopGroups(groupDocsOffset), (long) matchingGroups.size());
     } else {
       return secondPassCollector.getTopGroups(groupDocsOffset);
     }
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
index b14e675..6bf711d 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
@@ -29,13 +29,13 @@ import org.apache.lucene.search.TotalHits.Relation;
  * @lucene.experimental */
 public class TopGroups<T> {
   /** Number of documents matching the search */
-  public final int totalHitCount;
+  public final long totalHitCount;
 
   /** Number of documents grouped into the topN groups */
-  public final int totalGroupedHitCount;
+  public final long totalGroupedHitCount;
 
   /** The total number of unique groups. If <code>null</code> this value is not computed. */
-  public final Integer totalGroupCount;
+  public final Long totalGroupCount;
 
   /** Group results in groupSort order */
   public final GroupDocs<T>[] groups;
@@ -50,7 +50,7 @@ public class TopGroups<T> {
    *  <code>Float.NaN</code> if scores were not computed. */
   public final float maxScore;
 
-  public TopGroups(SortField[] groupSort, SortField[] withinGroupSort, int totalHitCount, int totalGroupedHitCount, GroupDocs<T>[] groups, float maxScore) {
+  public TopGroups(SortField[] groupSort, SortField[] withinGroupSort, long totalHitCount, long totalGroupedHitCount, GroupDocs<T>[] groups, float maxScore) {
     this.groupSort = groupSort;
     this.withinGroupSort = withinGroupSort;
     this.totalHitCount = totalHitCount;
@@ -60,7 +60,7 @@ public class TopGroups<T> {
     this.maxScore = maxScore;
   }
 
-  public TopGroups(TopGroups<T> oldTopGroups, Integer totalGroupCount) {
+  public TopGroups(TopGroups<T> oldTopGroups, Long totalGroupCount) {
     this.groupSort = oldTopGroups.groupSort;
     this.withinGroupSort = oldTopGroups.withinGroupSort;
     this.totalHitCount = oldTopGroups.totalHitCount;
@@ -118,10 +118,10 @@ public class TopGroups<T> {
       return null;
     }
 
-    int totalHitCount = 0;
-    int totalGroupedHitCount = 0;
+    long totalHitCount = 0;
+    long totalGroupedHitCount = 0;
     // Optionally merge the totalGroupCount.
-    Integer totalGroupCount = null;
+    Long totalGroupCount = null;
 
     final int numGroups = shardGroups[0].groups.length;
     for(TopGroups<T> shard : shardGroups) {
@@ -132,7 +132,7 @@ public class TopGroups<T> {
       totalGroupedHitCount += shard.totalGroupedHitCount;
       if (shard.totalGroupCount != null) {
         if (totalGroupCount == null) {
-          totalGroupCount = 0;
+          totalGroupCount = 0L;
         }
 
         totalGroupCount += shard.totalGroupCount;
@@ -154,7 +154,7 @@ public class TopGroups<T> {
       final T groupValue = shardGroups[0].groups[groupIDX].groupValue;
       //System.out.println("  merge groupValue=" + groupValue + " sortValues=" + Arrays.toString(shardGroups[0].groups[groupIDX].groupSortValues));
       float maxScore = Float.NaN;
-      int totalHits = 0;
+      long totalHits = 0;
       double scoreSum = 0.0;
       for(int shardIDX=0;shardIDX<shardGroups.length;shardIDX++) {
         //System.out.println("    shard=" + shardIDX);
diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java
index f1ce508..73832c1 100644
--- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java
+++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java
@@ -452,7 +452,7 @@ public class TestGrouping extends LuceneTestCase {
     final List<BytesRef> sortedGroups = new ArrayList<>();
     final List<Comparable<?>[]> sortedGroupFields = new ArrayList<>();
 
-    int totalHitCount = 0;
+    long totalHitCount = 0;
     Set<BytesRef> knownGroups = new HashSet<>();
 
     //System.out.println("TEST: slowGrouping");
@@ -492,7 +492,7 @@ public class TestGrouping extends LuceneTestCase {
     final Comparator<GroupDoc> docSortComp = getComparator(docSort);
     @SuppressWarnings({"unchecked","rawtypes"})
     final GroupDocs<BytesRef>[] result = new GroupDocs[limit-groupOffset];
-    int totalGroupedHitCount = 0;
+    long totalGroupedHitCount = 0;
     for(int idx=groupOffset;idx < limit;idx++) {
       final BytesRef group = sortedGroups.get(idx);
       final List<GroupDoc> docs = groups.get(group);
@@ -523,7 +523,7 @@ public class TestGrouping extends LuceneTestCase {
     if (doAllGroups) {
       return new TopGroups<>(
         new TopGroups<>(groupSort.getSort(), docSort.getSort(), totalHitCount, totalGroupedHitCount, result, Float.NaN),
-          knownGroups.size()
+          (long) knownGroups.size()
       );
     } else {
       return new TopGroups<>(groupSort.getSort(), docSort.getSort(), totalHitCount, totalGroupedHitCount, result, Float.NaN);
@@ -960,7 +960,7 @@ public class TestGrouping extends LuceneTestCase {
           
           if (doAllGroups) {
             TopGroups<BytesRef> tempTopGroups = getTopGroups(c2, docOffset);
-            groupsResult = new TopGroups<>(tempTopGroups, allGroupsCollector.getGroupCount());
+            groupsResult = new TopGroups<>(tempTopGroups, (long) allGroupsCollector.getGroupCount());
           } else {
             groupsResult = getTopGroups(c2, docOffset);
           }
@@ -1046,8 +1046,8 @@ public class TestGrouping extends LuceneTestCase {
         final TopGroups<BytesRef> tempTopGroupsBlocks = (TopGroups<BytesRef>) c3.getTopGroups(docSort, groupOffset, docOffset, docOffset+docsPerGroup);
         final TopGroups<BytesRef> groupsResultBlocks;
         if (doAllGroups && tempTopGroupsBlocks != null) {
-          assertEquals((int) tempTopGroupsBlocks.totalGroupCount, allGroupsCollector2.getGroupCount());
-          groupsResultBlocks = new TopGroups<>(tempTopGroupsBlocks, allGroupsCollector2.getGroupCount());
+          assertEquals((long) tempTopGroupsBlocks.totalGroupCount, (long) allGroupsCollector2.getGroupCount());
+          groupsResultBlocks = new TopGroups<>(tempTopGroupsBlocks, (long) allGroupsCollector2.getGroupCount());
         } else {
           groupsResultBlocks = tempTopGroupsBlocks;
         }