You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by dw...@apache.org on 2022/04/29 19:41:24 UTC

[lucene] branch branch_9x updated: LUCENE-10539: Return a stream of completions from FSTCompletion. (#844)

This is an automated email from the ASF dual-hosted git repository.

dweiss pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new d0ce05888d6 LUCENE-10539: Return a stream of completions from FSTCompletion. (#844)
d0ce05888d6 is described below

commit d0ce05888d62eadf4e1172339338a389b37dad41
Author: Dawid Weiss <da...@carrotsearch.com>
AuthorDate: Fri Apr 29 21:35:35 2022 +0200

    LUCENE-10539: Return a stream of completions from FSTCompletion. (#844)
---
 lucene/CHANGES.txt                                 |   2 +
 .../lucene/search/suggest/fst/FSTCompletion.java   | 292 ++++++++++++---------
 .../search/suggest/fst/TestFSTCompletion.java      |  38 ++-
 3 files changed, 203 insertions(+), 129 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 5e53ac5cca9..f71afab5f65 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -17,6 +17,8 @@ API Changes
 New Features
 ---------------------
 
+* LUCENE-10539: Return a stream of completions from FSTCompletion. (Dawid Weiss)
+
 * LUCENE-10385: Implement Weight#count on IndexSortSortedNumericDocValuesRangeQuery
   to speed up computing the number of hits when possible. (Lu Xugang, Luca Cavanna, Adrien Grand)
 
diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/fst/FSTCompletion.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/fst/FSTCompletion.java
index 87009dedc53..12fd572846b 100644
--- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/fst/FSTCompletion.java
+++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/fst/FSTCompletion.java
@@ -17,8 +17,19 @@
 package org.apache.lucene.search.suggest.fst;
 
 import java.io.IOException;
-import java.util.*;
-import org.apache.lucene.util.*;
+import java.io.UncheckedIOException;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Spliterator;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.fst.FST;
 import org.apache.lucene.util.fst.FST.Arc;
 
@@ -53,7 +64,7 @@ public class FSTCompletion {
       return utf8.utf8ToString() + "/" + bucket;
     }
 
-    /** @see BytesRef#compareTo(BytesRef) */
+    /** Completions are equal when their {@link #utf8} images are equal (bucket is not compared). */
     @Override
     public int compareTo(Completion o) {
       return this.utf8.compareTo(o.utf8);
@@ -184,110 +195,174 @@ public class FSTCompletion {
       return EMPTY_RESULT;
     }
 
-    try {
-      BytesRef keyUtf8 = new BytesRef(key);
-      if (!higherWeightsFirst && rootArcs.length > 1) {
-        // We could emit a warning here (?). An optimal strategy for
-        // alphabetically sorted
-        // suggestions would be to add them with a constant weight -- this saves
-        // unnecessary
-        // traversals and sorting.
-        return lookupSortedAlphabetically(keyUtf8, num);
-      } else {
-        return lookupSortedByWeight(keyUtf8, num, false);
-      }
-    } catch (IOException e) {
-      // Should never happen, but anyway.
-      throw new RuntimeException(e);
+    if (!higherWeightsFirst && rootArcs.length > 1) {
+      // We could emit a warning here (?). An optimal strategy for
+      // alphabetically sorted
+      // suggestions would be to add them with a constant weight -- this saves
+      // unnecessary
+      // traversals and sorting.
+      return lookup(key).sorted().limit(num).collect(Collectors.toList());
+    } else {
+      return lookup(key).limit(num).collect(Collectors.toList());
     }
   }
 
   /**
-   * Lookup suggestions sorted alphabetically <b>if weights are not constant</b>. This is a
-   * workaround: in general, use constant weights for alphabetically sorted result.
+   * Lookup suggestions to <code>key</code> and return a stream of matching completions. The stream
+   * fetches completions dynamically - it can be filtered and limited to acquire the desired number
+   * of completions without collecting all of them.
+   *
+   * @param key The prefix to which suggestions should be sought.
+   * @return A stream of matching completions, fetched lazily as the stream is consumed.
    */
-  private List<Completion> lookupSortedAlphabetically(BytesRef key, int num) throws IOException {
-    // Greedily get num results from each weight branch.
-    List<Completion> res = lookupSortedByWeight(key, num, true);
-
-    // Sort and trim.
-    Collections.sort(res);
-    if (res.size() > num) {
-      res = res.subList(0, num);
+  public Stream<Completion> lookup(CharSequence key) {
+    if (key.length() == 0 || automaton == null) {
+      return Stream.empty();
+    }
+
+    try {
+      return lookupSortedByWeight(new BytesRef(key));
+    } catch (IOException e) {
+      throw new RuntimeException(e);
     }
-    return res;
   }
 
-  /**
-   * Lookup suggestions sorted by weight (descending order).
-   *
-   * @param collectAll If <code>true</code>, the routine terminates immediately when <code>num
-   *     </code> suggestions have been collected. If <code>false</code>, it will collect suggestions
-   *     from all weight arcs (needed for {@link #lookupSortedAlphabetically}.
-   */
-  private ArrayList<Completion> lookupSortedByWeight(BytesRef key, int num, boolean collectAll)
-      throws IOException {
-    // Don't overallocate the results buffers. This also serves the purpose of
-    // allowing the user of this class to request all matches using Integer.MAX_VALUE as
-    // the number of results.
-    final ArrayList<Completion> res = new ArrayList<>(Math.min(10, num));
-
-    final BytesRef output = BytesRef.deepCopyOf(key);
-    for (int i = 0; i < rootArcs.length; i++) {
-      final FST.Arc<Object> rootArc = rootArcs[i];
-      final FST.Arc<Object> arc = new FST.Arc<>().copyFrom(rootArc);
-
-      // Descend into the automaton using the key as prefix.
-      if (descendWithPrefix(arc, key)) {
-        // A subgraph starting from the current node has the completions
-        // of the key prefix. The arc we're at is the last key's byte,
-        // so we will collect it too.
-        output.length = key.length - 1;
-        if (collect(res, num, rootArc.label(), output, arc) && !collectAll) {
-          // We have enough suggestions to return immediately. Keep on looking
-          // for an
-          // exact match, if requested.
-          if (exactFirst) {
-            if (!checkExistingAndReorder(res, key)) {
-              int exactMatchBucket = getExactMatchStartingFromRootArc(i, key);
-              if (exactMatchBucket != -1) {
-                // Insert as the first result and truncate at num.
-                while (res.size() >= num) {
-                  res.remove(res.size() - 1);
-                }
-                res.add(0, new Completion(key, exactMatchBucket));
-              }
-            }
-          }
+  /** Lookup suggestions sorted by weight (descending order). */
+  private Stream<Completion> lookupSortedByWeight(BytesRef key) throws IOException {
+
+    // Look for an exact match first.
+    Completion exactCompletion;
+    if (exactFirst) {
+      Completion c = null;
+      for (int i = 0; i < rootArcs.length; i++) {
+        int exactMatchBucket = getExactMatchStartingFromRootArc(i, key);
+        if (exactMatchBucket != -1) {
+          // root arcs are sorted by decreasing weight so any first exact match will always win.
+          c = new Completion(key, exactMatchBucket);
           break;
         }
       }
+      exactCompletion = c;
+    } else {
+      exactCompletion = null;
+    }
+
+    Stream<Completion> stream =
+        IntStream.range(0, rootArcs.length)
+            .boxed()
+            .flatMap(
+                i -> {
+                  try {
+                    final FST.Arc<Object> rootArc = rootArcs[i];
+                    final FST.Arc<Object> arc = new FST.Arc<>().copyFrom(rootArc);
+                    if (descendWithPrefix(arc, key)) {
+                      // The subgraph below the current arc holds all completions of the
+                      // key prefix. The arc we're at is the key's last byte, which is
+                      // already part of the output buffer copied from the key.
+                      final BytesRef output = BytesRef.deepCopyOf(key);
+                      output.length = key.length;
+                      return completionStream(output, rootArc.label(), arc);
+                    } else {
+                      return Stream.empty();
+                    }
+                  } catch (IOException e) {
+                    throw new UncheckedIOException(e);
+                  }
+                });
+
+    // if requested, return the exact completion first and omit it in any further completions.
+    if (exactFirst && exactCompletion != null) {
+      stream =
+          Stream.concat(
+              Stream.of(exactCompletion),
+              stream.filter(completion -> exactCompletion.compareTo(completion) != 0));
     }
-    return res;
+    return stream;
   }
 
-  /**
-   * Checks if the list of
-   * {@link org.apache.lucene.search.suggest.Lookup.LookupResult}s already has a
-   * <code>key</code>. If so, reorders that
-   * {@link org.apache.lucene.search.suggest.Lookup.LookupResult} to the first
-   * position.
-   *
-   * @return Returns <code>true<code> if and only if <code>list</code> contained
-   *         <code>key</code>.
-   */
-  private boolean checkExistingAndReorder(ArrayList<Completion> list, BytesRef key) {
-    // We assume list does not have duplicates (because of how the FST is created).
-    for (int i = list.size(); --i >= 0; ) {
-      if (key.equals(list.get(i).utf8)) {
-        // Key found. Unless already at i==0, remove it and push up front so
-        // that the ordering
-        // remains identical with the exception of the exact match.
-        list.add(0, list.remove(i));
-        return true;
+  /** Return a stream of all completions starting from the provided arc. */
+  private Stream<? extends Completion> completionStream(
+      BytesRef output, int bucket, Arc<Object> fromArc) throws IOException {
+
+    FST.BytesReader fstReader = automaton.getBytesReader();
+
+    class State {
+      Arc<Object> arc;
+      int outputLength;
+
+      State(Arc<Object> arc, int outputLength) throws IOException {
+        this.arc = automaton.readFirstTargetArc(arc, new Arc<>(), fstReader);
+        this.outputLength = outputLength;
       }
     }
-    return false;
+
+    ArrayDeque<State> states = new ArrayDeque<>();
+    states.addLast(new State(fromArc, output.length));
+
+    return StreamSupport.stream(
+        new Spliterator<>() {
+          @Override
+          public boolean tryAdvance(Consumer<? super Completion> action) {
+            try {
+              while (!states.isEmpty()) {
+                var state = states.peekLast();
+                output.length = state.outputLength;
+                var arc = state.arc;
+                var arcLabel = arc.label();
+
+                if (arcLabel == FST.END_LABEL) {
+                  Completion completion = new Completion(output, bucket);
+                  action.accept(completion);
+
+                  if (arc.isLast()) {
+                    states.removeLast();
+                  } else {
+                    automaton.readNextArc(arc, fstReader);
+                  }
+
+                  return true;
+                } else {
+                  assert output.offset == 0;
+                  if (output.length == output.bytes.length) {
+                    output.bytes = ArrayUtil.grow(output.bytes);
+                  }
+                  output.bytes[output.length++] = (byte) arcLabel;
+
+                  State newState = new State(arc, output.length);
+
+                  if (arc.isLast()) {
+                    states.removeLast();
+                  } else {
+                    automaton.readNextArc(arc, fstReader);
+                  }
+
+                  states.addLast(newState);
+                }
+              }
+
+              return false;
+            } catch (IOException e) {
+              throw new UncheckedIOException(e);
+            }
+          }
+
+          @Override
+          public Spliterator<Completion> trySplit() {
+            // Don't try to split.
+            return null;
+          }
+
+          @Override
+          public long estimateSize() {
+            return Long.MAX_VALUE;
+          }
+
+          @Override
+          public int characteristics() {
+            return Spliterator.NONNULL | Spliterator.ORDERED;
+          }
+        },
+        false);
   }
 
   /**
@@ -312,41 +387,6 @@ public class FSTCompletion {
     return true;
   }
 
-  /**
-   * Recursive collect lookup results from the automaton subgraph starting at <code>arc</code>.
-   *
-   * @param num Maximum number of results needed (early termination).
-   */
-  private boolean collect(
-      List<Completion> res, int num, int bucket, BytesRef output, Arc<Object> arc)
-      throws IOException {
-    if (output.length == output.bytes.length) {
-      output.bytes = ArrayUtil.grow(output.bytes);
-    }
-    assert output.offset == 0;
-    output.bytes[output.length++] = (byte) arc.label();
-    FST.BytesReader fstReader = automaton.getBytesReader();
-    automaton.readFirstTargetArc(arc, arc, fstReader);
-    while (true) {
-      if (arc.label() == FST.END_LABEL) {
-        res.add(new Completion(output, bucket));
-        if (res.size() >= num) return true;
-      } else {
-        int save = output.length;
-        if (collect(res, num, bucket, output, new Arc<>().copyFrom(arc))) {
-          return true;
-        }
-        output.length = save;
-      }
-
-      if (arc.isLast()) {
-        break;
-      }
-      automaton.readNextArc(arc, fstReader);
-    }
-    return false;
-  }
-
   /** Returns the bucket count (discretization thresholds). */
   public int getBucketCount() {
     return rootArcs.length;
diff --git a/lucene/suggest/src/test/org/apache/lucene/search/suggest/fst/TestFSTCompletion.java b/lucene/suggest/src/test/org/apache/lucene/search/suggest/fst/TestFSTCompletion.java
index 76e52afa5ea..6e8508feef8 100644
--- a/lucene/suggest/src/test/org/apache/lucene/search/suggest/fst/TestFSTCompletion.java
+++ b/lucene/suggest/src/test/org/apache/lucene/search/suggest/fst/TestFSTCompletion.java
@@ -17,14 +17,23 @@
 package org.apache.lucene.search.suggest.fst;
 
 import java.nio.charset.StandardCharsets;
-import java.util.*;
-import org.apache.lucene.search.suggest.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Random;
+import java.util.stream.Collectors;
+import org.apache.lucene.search.suggest.Input;
+import org.apache.lucene.search.suggest.InputArrayIterator;
 import org.apache.lucene.search.suggest.Lookup.LookupResult;
+import org.apache.lucene.search.suggest.SuggestRebuildTestUtil;
+import org.apache.lucene.search.suggest.TestLookupBenchmark;
 import org.apache.lucene.search.suggest.fst.FSTCompletion.Completion;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.TestUtil;
-import org.apache.lucene.util.*;
+import org.apache.lucene.util.BytesRef;
 
 /** Unit tests for {@link FSTCompletion}. */
 public class TestFSTCompletion extends LuceneTestCase {
@@ -81,6 +90,20 @@ public class TestFSTCompletion extends LuceneTestCase {
     assertMatchEquals(completion.lookup(stringToCharSequence("one"), 2), "one/0.0", "oneness/1.0");
   }
 
+  public void testCompletionStream() throws Exception {
+    var completions =
+        completion
+            .lookup("fo")
+            .filter(completion -> !completion.utf8.utf8ToString().contains("fourteen"))
+            .sorted(
+                Comparator.comparing(
+                    completion -> completion.utf8.utf8ToString().toLowerCase(Locale.ROOT)))
+            .collect(Collectors.toList());
+
+    assertMatchEquals(
+        completions, "foundation/1", "four/0", "fourblah/1", "fourier/0", "fourty/1.0");
+  }
+
   public void testExactMatchReordering() throws Exception {
     // Check reordering of exact matches.
     assertMatchEquals(
@@ -130,8 +153,17 @@ public class TestFSTCompletion extends LuceneTestCase {
   }
 
   public void testFullMatchList() throws Exception {
+    // one/0.0 is returned first because it's an exact match.
     assertMatchEquals(
         completion.lookup(stringToCharSequence("one"), Integer.MAX_VALUE),
+        "one/0.0",
+        "oneness/1.0",
+        "onerous/1.0",
+        "onesimus/1.0");
+
+    // full sorted order by weight+alphabetical.
+    assertMatchEquals(
+        completion.lookup(stringToCharSequence("on"), Integer.MAX_VALUE),
         "oneness/1.0",
         "onerous/1.0",
         "onesimus/1.0",