You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jp...@apache.org on 2022/01/25 08:10:09 UTC

[lucene] branch branch_9x updated: LUCENE-10384: Simplify LongHeap. (#615)

This is an automated email from the ASF dual-hosted git repository.

jpountz pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new b443015  LUCENE-10384: Simplify LongHeap. (#615)
b443015 is described below

commit b443015f01403c0ca8b602742f7d39a5acfa3652
Author: Adrien Grand <jp...@gmail.com>
AuthorDate: Tue Jan 25 09:04:52 2022 +0100

    LUCENE-10384: Simplify LongHeap. (#615)
    
    The min/max ordering logic moves to NeighborQueue.
---
 .../apache/lucene/codecs/lucene90/PForUtil.java    |  2 +-
 .../src/java/org/apache/lucene/util/LongHeap.java  | 63 ++++---------------
 .../org/apache/lucene/util/hnsw/NeighborQueue.java | 35 ++++++++---
 .../test/org/apache/lucene/util/TestLongHeap.java  | 73 +++++++---------------
 4 files changed, 61 insertions(+), 112 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/PForUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/PForUtil.java
index c8f470d..eb735c8 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/PForUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/PForUtil.java
@@ -60,7 +60,7 @@ final class PForUtil {
   /** Encode 128 integers from {@code longs} into {@code out}. */
   void encode(long[] longs, DataOutput out) throws IOException {
     // Determine the top MAX_EXCEPTIONS + 1 values
-    final LongHeap top = LongHeap.create(LongHeap.Order.MIN, MAX_EXCEPTIONS + 1);
+    final LongHeap top = new LongHeap(MAX_EXCEPTIONS + 1);
     for (int i = 0; i <= MAX_EXCEPTIONS; ++i) {
       top.push(longs[i]);
     }
diff --git a/lucene/core/src/java/org/apache/lucene/util/LongHeap.java b/lucene/core/src/java/org/apache/lucene/util/LongHeap.java
index 68b6b43..b1f64fa 100644
--- a/lucene/core/src/java/org/apache/lucene/util/LongHeap.java
+++ b/lucene/core/src/java/org/apache/lucene/util/LongHeap.java
@@ -17,27 +17,16 @@
 package org.apache.lucene.util;
 
 /**
- * A heap that stores longs; a primitive priority queue that like all priority queues maintains a
- * partial ordering of its elements such that the least element can always be found in constant
+ * A min heap that stores longs; a primitive priority queue that like all priority queues maintains
+ * a partial ordering of its elements such that the least element can always be found in constant
  * time. Put()'s and pop()'s require log(size). This heap provides unbounded growth via {@link
  * #push(long)}, and bounded-size insertion based on its nominal maxSize via {@link
- * #insertWithOverflow(long)}. The heap may be either a min heap, in which case the least element is
- * the smallest integer, or a max heap, when it is the largest, depending on the Order parameter.
+ * #insertWithOverflow(long)}. The heap is a min heap, meaning that the top element is the lowest
+ * value of the heap.
  *
  * @lucene.internal
  */
-public abstract class LongHeap {
-
-  /**
-   * Used to specify the ordering of the heap. A min-heap provides access to the minimum element in
-   * constant time, and when bounded, retains the maximum <code>maxSize</code> elements. A max-heap
-   * conversely provides access to the maximum element in constant time, and when bounded retains
-   * the minimum <code>maxSize</code> elements.
-   */
-  public enum Order {
-    MIN,
-    MAX
-  }
+public final class LongHeap {
 
   private final int maxSize;
 
@@ -50,7 +39,7 @@ public abstract class LongHeap {
    * @param maxSize the maximum size of the heap, or if negative, the initial size of an unbounded
    *     heap
    */
-  LongHeap(int maxSize) {
+  public LongHeap(int maxSize) {
     final int heapSize;
     if (maxSize < 1 || maxSize >= ArrayUtil.MAX_ARRAY_LENGTH) {
       // Throw exception to prevent confusing OOME:
@@ -63,33 +52,6 @@ public abstract class LongHeap {
     this.heap = new long[heapSize];
   }
 
-  public static LongHeap create(Order order, int maxSize) {
-    // TODO: override push() for unbounded queue
-    if (order == Order.MIN) {
-      return new LongHeap(maxSize) {
-        @Override
-        public boolean lessThan(long a, long b) {
-          return a < b;
-        }
-      };
-    } else {
-      return new LongHeap(maxSize) {
-        @Override
-        public boolean lessThan(long a, long b) {
-          return a > b;
-        }
-      };
-    }
-  }
-
-  /**
-   * Determines the ordering of objects in this priority queue. Subclasses must define this one
-   * method.
-   *
-   * @return <code>true</code> iff parameter <code>a</code> is less than parameter <code>b</code>.
-   */
-  public abstract boolean lessThan(long a, long b);
-
   /**
    * Adds a value in log(size) time. Grows unbounded as needed to accommodate new values.
    *
@@ -114,7 +76,7 @@ public abstract class LongHeap {
    */
   public boolean insertWithOverflow(long value) {
     if (size >= maxSize) {
-      if (lessThan(value, heap[1])) {
+      if (value < heap[1]) {
         return false;
       }
       updateTop(value);
@@ -190,7 +152,7 @@ public abstract class LongHeap {
     int i = origPos;
     long value = heap[i]; // save bottom value
     int j = i >>> 1;
-    while (j > 0 && lessThan(value, heap[j])) {
+    while (j > 0 && value < heap[j]) {
       heap[i] = heap[j]; // shift parents down
       i = j;
       j = j >>> 1;
@@ -202,15 +164,15 @@ public abstract class LongHeap {
     long value = heap[i]; // save top value
     int j = i << 1; // find smaller child
     int k = j + 1;
-    if (k <= size && lessThan(heap[k], heap[j])) {
+    if (k <= size && heap[k] < heap[j]) {
       j = k;
     }
-    while (j <= size && lessThan(heap[j], value)) {
+    while (j <= size && heap[j] < value) {
       heap[i] = heap[j]; // shift up child
       i = j;
       j = i << 1;
       k = j + 1;
-      if (k <= size && lessThan(heap[k], heap[j])) {
+      if (k <= size && heap[k] < heap[j]) {
         j = k;
       }
     }
@@ -236,7 +198,8 @@ public abstract class LongHeap {
    *
    * @lucene.internal
    */
-  protected final long[] getHeapArray() {
+  // pkg-private for testing
+  final long[] getHeapArray() {
     return heap;
   }
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
index 4102dff..83a8b75 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
@@ -29,17 +29,34 @@ import org.apache.lucene.util.NumericUtils;
  */
 public class NeighborQueue {
 
+  private static enum Order {
+    NATURAL {
+      @Override
+      long apply(long v) {
+        return v;
+      }
+    },
+    REVERSED {
+      @Override
+      long apply(long v) {
+        // This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It
+        // needs a function that returns MAX_VALUE for MIN_VALUE and vice-versa.
+        return -1 - v;
+      }
+    };
+
+    abstract long apply(long v);
+  }
+
   private final LongHeap heap;
+  private final Order order;
 
   // Used to track the number of neighbors visited during a single graph traversal
   private int visitedCount;
 
   NeighborQueue(int initialSize, boolean reversed) {
-    if (reversed) {
-      heap = LongHeap.create(LongHeap.Order.MAX, initialSize);
-    } else {
-      heap = LongHeap.create(LongHeap.Order.MIN, initialSize);
-    }
+    this.heap = new LongHeap(initialSize);
+    this.order = reversed ? Order.REVERSED : Order.NATURAL;
   }
 
   /** @return the number of elements in the heap */
@@ -71,12 +88,12 @@ public class NeighborQueue {
   }
 
   private long encode(int node, float score) {
-    return (((long) NumericUtils.floatToSortableInt(score)) << 32) | node;
+    return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | node);
   }
 
   /** Removes the top element and returns its node id. */
   public int pop() {
-    return (int) heap.pop();
+    return (int) order.apply(heap.pop());
   }
 
   int[] nodes() {
@@ -90,12 +107,12 @@ public class NeighborQueue {
 
   /** Returns the top element's node id. */
   public int topNode() {
-    return (int) heap.top();
+    return (int) order.apply(heap.top());
   }
 
   /** Returns the top element's node score. */
   public float topScore() {
-    return NumericUtils.sortableIntToFloat((int) (heap.top() >> 32));
+    return NumericUtils.sortableIntToFloat((int) (order.apply(heap.top()) >> 32));
   }
 
   public int visitedCount() {
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java b/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java
index fa1a449..f1eef26 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java
@@ -16,9 +16,6 @@
  */
 package org.apache.lucene.util;
 
-import static org.apache.lucene.util.LongHeap.Order.MAX;
-import static org.apache.lucene.util.LongHeap.Order.MIN;
-
 import java.util.ArrayList;
 import java.util.Random;
 import org.apache.lucene.tests.util.LuceneTestCase;
@@ -26,26 +23,11 @@ import org.apache.lucene.tests.util.TestUtil;
 
 public class TestLongHeap extends LuceneTestCase {
 
-  private static class AssertingLongHeap extends LongHeap {
-    AssertingLongHeap(int count) {
-      super(count);
-    }
-
-    @Override
-    public boolean lessThan(long a, long b) {
-      return (a < b);
-    }
-
-    final void checkValidity() {
-      long[] heapArray = getHeapArray();
-      for (int i = 1; i <= size(); i++) {
-        int parent = i >>> 1;
-        if (parent > 1) {
-          if (lessThan(heapArray[parent], heapArray[i]) == false) {
-            assertEquals(heapArray[parent], heapArray[i]);
-          }
-        }
-      }
+  private static void checkValidity(LongHeap heap) {
+    long[] heapArray = heap.getHeapArray();
+    for (int i = 2; i <= heap.size(); i++) {
+      int parent = i >>> 1;
+      assert heapArray[parent] <= heapArray[i];
     }
   }
 
@@ -54,7 +36,7 @@ public class TestLongHeap extends LuceneTestCase {
   }
 
   public static void testPQ(int count, Random gen) {
-    LongHeap pq = LongHeap.create(MIN, count);
+    LongHeap pq = new LongHeap(count);
     long sum = 0, sum2 = 0;
 
     for (int i = 0; i < count; i++) {
@@ -75,7 +57,7 @@ public class TestLongHeap extends LuceneTestCase {
   }
 
   public void testClear() {
-    LongHeap pq = LongHeap.create(MIN, 3);
+    LongHeap pq = new LongHeap(3);
     pq.push(2);
     pq.push(3);
     pq.push(1);
@@ -85,7 +67,7 @@ public class TestLongHeap extends LuceneTestCase {
   }
 
   public void testExceedBounds() {
-    LongHeap pq = LongHeap.create(MIN, 1);
+    LongHeap pq = new LongHeap(1);
     pq.push(2);
     pq.push(0);
     // expectThrows(ArrayIndexOutOfBoundsException.class, () -> pq.push(0));
@@ -94,7 +76,7 @@ public class TestLongHeap extends LuceneTestCase {
   }
 
   public void testFixedSize() {
-    LongHeap pq = LongHeap.create(MIN, 3);
+    LongHeap pq = new LongHeap(3);
     pq.insertWithOverflow(2);
     pq.insertWithOverflow(3);
     pq.insertWithOverflow(1);
@@ -105,20 +87,8 @@ public class TestLongHeap extends LuceneTestCase {
     assertEquals(3, pq.top());
   }
 
-  public void testFixedSizeMax() {
-    LongHeap pq = LongHeap.create(MAX, 3);
-    pq.insertWithOverflow(2);
-    pq.insertWithOverflow(3);
-    pq.insertWithOverflow(1);
-    pq.insertWithOverflow(5);
-    pq.insertWithOverflow(7);
-    pq.insertWithOverflow(1);
-    assertEquals(3, pq.size());
-    assertEquals(2, pq.top());
-  }
-
   public void testDuplicateValues() {
-    LongHeap pq = LongHeap.create(MIN, 3);
+    LongHeap pq = new LongHeap(3);
     pq.push(2);
     pq.push(3);
     pq.push(1);
@@ -131,7 +101,7 @@ public class TestLongHeap extends LuceneTestCase {
   public void testInsertions() {
     Random random = random();
     int numDocsInPQ = TestUtil.nextInt(random, 1, 100);
-    AssertingLongHeap pq = new AssertingLongHeap(numDocsInPQ);
+    LongHeap pq = new LongHeap(numDocsInPQ);
     Long lastLeast = null;
 
     // Basic insertion of new content
@@ -140,7 +110,7 @@ public class TestLongHeap extends LuceneTestCase {
       long newEntry = Math.abs(random.nextLong());
       sds.add(newEntry);
       pq.insertWithOverflow(newEntry);
-      pq.checkValidity();
+      checkValidity(pq);
       long newLeast = pq.top();
       if ((lastLeast != null) && (newLeast != newEntry) && (newLeast != lastLeast)) {
         // If there has been a change of least entry and it wasn't our new
@@ -153,17 +123,16 @@ public class TestLongHeap extends LuceneTestCase {
   }
 
   public void testInvalid() {
-    expectThrows(IllegalArgumentException.class, () -> LongHeap.create(MAX, -1));
-    expectThrows(IllegalArgumentException.class, () -> LongHeap.create(MAX, 0));
-    expectThrows(
-        IllegalArgumentException.class, () -> LongHeap.create(MAX, ArrayUtil.MAX_ARRAY_LENGTH));
+    expectThrows(IllegalArgumentException.class, () -> new LongHeap(-1));
+    expectThrows(IllegalArgumentException.class, () -> new LongHeap(0));
+    expectThrows(IllegalArgumentException.class, () -> new LongHeap(ArrayUtil.MAX_ARRAY_LENGTH));
   }
 
   public void testUnbounded() {
     int initialSize = random().nextInt(10) + 1;
-    LongHeap pq = LongHeap.create(MAX, initialSize);
+    LongHeap pq = new LongHeap(initialSize);
     int num = random().nextInt(100) + 1;
-    long minValue = Long.MAX_VALUE;
+    long maxValue = Long.MIN_VALUE;
     int count = 0;
     for (int i = 0; i < num; i++) {
       long value = random().nextLong();
@@ -178,19 +147,19 @@ public class TestLongHeap extends LuceneTestCase {
           }
         }
       }
-      minValue = Math.min(minValue, value);
+      maxValue = Math.max(maxValue, value);
     }
     assertEquals(count, pq.size());
-    long last = Long.MAX_VALUE;
+    long last = Long.MIN_VALUE;
     while (pq.size() > 0) {
       long top = pq.top();
       long next = pq.pop();
       assertEquals(top, next);
       --count;
-      assertTrue(next <= last);
+      assertTrue(next >= last);
       last = next;
     }
     assertEquals(0, count);
-    assertEquals(minValue, last);
+    assertEquals(maxValue, last);
   }
 }