You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@lucene.apache.org by GitBox <gi...@apache.org> on 2021/04/14 19:20:41 UTC

[GitHub] [lucene] msokolov commented on a change in pull request #83: LUCENE-9798 : Fix looping bug and made Full Knn calculation parallelizable

msokolov commented on a change in pull request #83:
URL: https://github.com/apache/lucene/pull/83#discussion_r613192059



##########
File path: lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java
##########
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import org.apache.lucene.index.VectorValues;
+
+/**
+ * A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
+ * vectors.
+ */
+public class FullKnn {
+
+  private final int dim;
+  private final int topK;
+  private final VectorValues.SearchStrategy searchStrategy;
+  private final boolean quiet;
+
+  public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) {
+    this.dim = dim;
+    this.topK = topK;
+    this.searchStrategy = searchStrategy;
+    this.quiet = quiet;
+  }
+
+  /** internal object to track KNN calculation for one query */
+  private static class KnnJob {
+    public int currDocIndex;
+    float[] queryVector;
+    float[] currDocVector;
+    int queryIndex;
+    private LongHeap queue;
+    FloatBuffer docVectors;
+    VectorValues.SearchStrategy searchStrategy;
+
+    public KnnJob(
+        int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) {
+      this.queryIndex = queryIndex;
+      this.queryVector = queryVector;
+      this.currDocVector = new float[queryVector.length];
+      if (searchStrategy.reversed) {
+        queue = LongHeap.create(LongHeap.Order.MAX, topK);
+      } else {
+        queue = LongHeap.create(LongHeap.Order.MIN, topK);
+      }
+      this.searchStrategy = searchStrategy;
+    }
+
+    public void execute() {
+      while (this.docVectors.hasRemaining()) {
+        this.docVectors.get(this.currDocVector);
+        float d = this.searchStrategy.compare(this.queryVector, this.currDocVector);
+        this.queue.insertWithOverflow(encodeNodeIdAndScore(this.currDocIndex, d));
+        this.currDocIndex++;
+      }
+    }
+  }
+
+  /**
+   * computes the exact KNN match for each query vector in queryPath for all the document vectors in
+   * docPath
+   *
+   * @param docPath : path to the file containing the float 32 document vectors in bytes with
+   *     little-endian byte order
+   * @param queryPath : path to the file containing the containing 32-bit floating point vectors in
+   *     little-endian byte order
+   * @param numThreads : create numThreads to parallelize work
+   * @return : returns an int 2D array ( int matches[][]) of size 'numIters x topK'. matches[i] is
+   *     an array containing the indexes of the topK most similar document vectors to the ith query
+   *     vector, and is sorted by similarity, with the most similar vector first. Similarity is
+   *     defined by the searchStrategy used to construct this FullKnn.
+   * @throws IllegalArgumentException : if topK is greater than number of documents in docPath file
+   *     IOException : In case of IO exception while reading files.
+   */
+  public int[][] computeNN(Path docPath, Path queryPath, int numThreads) throws IOException {
+    assert numThreads > 0;
+    final int numDocs = (int) (Files.size(docPath) / (dim * Float.BYTES));
+    final int numQueries = (int) (Files.size(docPath) / (dim * Float.BYTES));
+
+    if (!quiet) {
+      System.out.println(
+          "computing true nearest neighbors of "
+              + numQueries
+              + " target vectors using "
+              + numThreads
+              + " threads.");
+    }
+
+    try (FileChannel docInput = FileChannel.open(docPath);
+        FileChannel queryInput = FileChannel.open(queryPath)) {
+      return doFullKnn(
+          numDocs,
+          numQueries,
+          numThreads,
+          new FileChannelBufferProvider(docInput),
+          new FileChannelBufferProvider(queryInput));
+    }
+  }
+
+  int[][] doFullKnn(
+      int numDocs,
+      int numQueries,
+      int numThreads,
+      BufferProvider docInput,
+      BufferProvider queryInput)
+      throws IOException {
+    if (numDocs < topK) {
+      throw new IllegalArgumentException(
+          String.format(
+              Locale.ROOT,
+              "topK (%d) cannot be greater than number of docs in docPath (%d)",
+              topK,
+              numDocs));
+    }
+
+    final ExecutorService executorService =
+        Executors.newFixedThreadPool(numThreads, new NamedThreadFactory("FullKnnExecutor"));
+    int[][] result = new int[numQueries][];
+
+    FloatBuffer queries = queryInput.getBuffer(0, numQueries * dim * Float.BYTES).asFloatBuffer();
+    float[] query = new float[dim];
+    List<KnnJob> jobList = new ArrayList<>(numThreads);
+    for (int i = 0; i < numQueries; ) {
+
+      for (int j = 0; j < numThreads && i < numQueries; i++, j++) {
+        queries.get(query);
+        jobList.add(
+            new KnnJob(i, ArrayUtil.copyOfSubArray(query, 0, query.length), topK, searchStrategy));
+      }
+
+      long maxBufferSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
+      int docsLeft = numDocs;
+      int currDocIndex = 0;
+      int offset = 0;
+      while (docsLeft > 0) {
+        long totalBytes = (long) docsLeft * dim * Float.BYTES;
+        int blockSize = (int) Math.min(totalBytes, maxBufferSize);
+
+        FloatBuffer docVectors = docInput.getBuffer(offset, blockSize).asFloatBuffer();
+        offset += blockSize;
+
+        final List<CompletableFuture<Void>> completableFutures =
+            jobList.stream()
+                .peek(job -> job.docVectors = docVectors.duplicate())

Review comment:
       What is `peek` doing here? I'm not familiar with this idiom, sorry. Would this be clearer as a for-loop?

##########
File path: lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java
##########
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import org.apache.lucene.index.VectorValues;
+
+/**
+ * A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
+ * vectors.
+ */
+public class FullKnn {
+
+  private final int dim;
+  private final int topK;
+  private final VectorValues.SearchStrategy searchStrategy;
+  private final boolean quiet;
+
+  public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) {
+    this.dim = dim;
+    this.topK = topK;
+    this.searchStrategy = searchStrategy;
+    this.quiet = quiet;
+  }
+
+  /** internal object to track KNN calculation for one query */
+  private static class KnnJob {
+    public int currDocIndex;
+    float[] queryVector;
+    float[] currDocVector;
+    int queryIndex;
+    private LongHeap queue;
+    FloatBuffer docVectors;
+    VectorValues.SearchStrategy searchStrategy;
+
+    public KnnJob(
+        int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) {
+      this.queryIndex = queryIndex;
+      this.queryVector = queryVector;
+      this.currDocVector = new float[queryVector.length];
+      if (searchStrategy.reversed) {
+        queue = LongHeap.create(LongHeap.Order.MAX, topK);
+      } else {
+        queue = LongHeap.create(LongHeap.Order.MIN, topK);
+      }
+      this.searchStrategy = searchStrategy;
+    }
+
+    public void execute() {
+      while (this.docVectors.hasRemaining()) {
+        this.docVectors.get(this.currDocVector);
+        float d = this.searchStrategy.compare(this.queryVector, this.currDocVector);
+        this.queue.insertWithOverflow(encodeNodeIdAndScore(this.currDocIndex, d));
+        this.currDocIndex++;
+      }
+    }
+  }
+
+  /**
+   * computes the exact KNN match for each query vector in queryPath for all the document vectors in
+   * docPath
+   *
+   * @param docPath : path to the file containing the float 32 document vectors in bytes with
+   *     little-endian byte order
+   * @param queryPath : path to the file containing the containing 32-bit floating point vectors in
+   *     little-endian byte order
+   * @param numThreads : create numThreads to parallelize work
+   * @return : returns an int 2D array ( int matches[][]) of size 'numIters x topK'. matches[i] is
+   *     an array containing the indexes of the topK most similar document vectors to the ith query
+   *     vector, and is sorted by similarity, with the most similar vector first. Similarity is
+   *     defined by the searchStrategy used to construct this FullKnn.
+   * @throws IllegalArgumentException : if topK is greater than number of documents in docPath file
+   *     IOException : In case of IO exception while reading files.
+   */
+  public int[][] computeNN(Path docPath, Path queryPath, int numThreads) throws IOException {
+    assert numThreads > 0;
+    final int numDocs = (int) (Files.size(docPath) / (dim * Float.BYTES));
+    final int numQueries = (int) (Files.size(docPath) / (dim * Float.BYTES));
+
+    if (!quiet) {
+      System.out.println(
+          "computing true nearest neighbors of "
+              + numQueries
+              + " target vectors using "
+              + numThreads
+              + " threads.");
+    }
+
+    try (FileChannel docInput = FileChannel.open(docPath);
+        FileChannel queryInput = FileChannel.open(queryPath)) {
+      return doFullKnn(
+          numDocs,
+          numQueries,
+          numThreads,
+          new FileChannelBufferProvider(docInput),
+          new FileChannelBufferProvider(queryInput));
+    }
+  }
+
+  int[][] doFullKnn(
+      int numDocs,
+      int numQueries,
+      int numThreads,
+      BufferProvider docInput,
+      BufferProvider queryInput)
+      throws IOException {
+    if (numDocs < topK) {
+      throw new IllegalArgumentException(
+          String.format(
+              Locale.ROOT,
+              "topK (%d) cannot be greater than number of docs in docPath (%d)",
+              topK,
+              numDocs));
+    }
+
+    final ExecutorService executorService =
+        Executors.newFixedThreadPool(numThreads, new NamedThreadFactory("FullKnnExecutor"));
+    int[][] result = new int[numQueries][];
+
+    FloatBuffer queries = queryInput.getBuffer(0, numQueries * dim * Float.BYTES).asFloatBuffer();
+    float[] query = new float[dim];
+    List<KnnJob> jobList = new ArrayList<>(numThreads);
+    for (int i = 0; i < numQueries; ) {
+
+      for (int j = 0; j < numThreads && i < numQueries; i++, j++) {
+        queries.get(query);
+        jobList.add(
+            new KnnJob(i, ArrayUtil.copyOfSubArray(query, 0, query.length), topK, searchStrategy));
+      }
+
+      long maxBufferSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
+      int docsLeft = numDocs;
+      int currDocIndex = 0;
+      int offset = 0;
+      while (docsLeft > 0) {
+        long totalBytes = (long) docsLeft * dim * Float.BYTES;
+        int blockSize = (int) Math.min(totalBytes, maxBufferSize);
+
+        FloatBuffer docVectors = docInput.getBuffer(offset, blockSize).asFloatBuffer();
+        offset += blockSize;
+
+        final List<CompletableFuture<Void>> completableFutures =
+            jobList.stream()
+                .peek(job -> job.docVectors = docVectors.duplicate())
+                .peek(job -> job.currDocIndex = currDocIndex)
+                .map(job -> CompletableFuture.runAsync(() -> job.execute(), executorService))
+                .collect(Collectors.toList());
+
+        CompletableFuture.allOf(
+                completableFutures.toArray(new CompletableFuture<?>[completableFutures.size()]))
+            .join();
+        docsLeft -= (blockSize / (dim * Float.BYTES));
+      }
+
+      jobList.forEach(
+          job -> {
+            result[job.queryIndex] = new int[topK];
+            for (int k = topK - 1; k >= 0; k--) {
+              result[job.queryIndex][k] = popNodeId(job.queue);
+              // System.out.print(" " + n);
+            }
+            if (!quiet && (job.queryIndex + 1) % 10 == 0) {
+              System.out.print(" " + (job.queryIndex + 1));
+              System.out.flush();
+            }
+          });
+
+      jobList.clear();
+    }
+
+    executorService.shutdown();
+    try {
+      if (!executorService.awaitTermination(1, TimeUnit.MINUTES)) {
+        executorService.shutdownNow();
+      }
+    } catch (InterruptedException e) {
+      executorService.shutdownNow();
+      throw new RuntimeException(
+          "Exception occured while waiting for executor service to finish.", e);
+    }
+
+    return result;
+  }
+
+  /**
+   * pops the queue and returns the last 4 bytes of long where nodeId is stored.
+   *
+   * @param queue queue from which to pop the node id
+   * @return the node id
+   */
+  private int popNodeId(LongHeap queue) {
+    return (int) queue.pop();
+  }
+
+  /**
+   * encodes the score and nodeId into a single long with score in first 4 bytes, to make it
+   * sortable by score.
+   *
+   * @param node node id
+   * @param score float score of the node wrt incoming node/query
+   * @return a score sortable long that can be used in heap
+   */
+  private static long encodeNodeIdAndScore(int node, float score) {
+    return (((long) NumericUtils.floatToSortableInt(score)) << 32) | node;
+  }
+
+  interface BufferProvider {

Review comment:
       Why do we need an interface here? There is no other implementation.

##########
File path: lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java
##########
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import org.apache.lucene.index.VectorValues;
+
+/**
+ * A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
+ * vectors.
+ */
+public class FullKnn {
+
+  private final int dim;
+  private final int topK;
+  private final VectorValues.SearchStrategy searchStrategy;
+  private final boolean quiet;
+
+  public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) {
+    this.dim = dim;
+    this.topK = topK;
+    this.searchStrategy = searchStrategy;
+    this.quiet = quiet;
+  }
+
+  /** internal object to track KNN calculation for one query */
+  private static class KnnJob {
+    public int currDocIndex;
+    float[] queryVector;
+    float[] currDocVector;
+    int queryIndex;
+    private LongHeap queue;
+    FloatBuffer docVectors;
+    VectorValues.SearchStrategy searchStrategy;
+
+    public KnnJob(
+        int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) {
+      this.queryIndex = queryIndex;
+      this.queryVector = queryVector;
+      this.currDocVector = new float[queryVector.length];
+      if (searchStrategy.reversed) {
+        queue = LongHeap.create(LongHeap.Order.MAX, topK);
+      } else {
+        queue = LongHeap.create(LongHeap.Order.MIN, topK);
+      }
+      this.searchStrategy = searchStrategy;
+    }
+
+    public void execute() {
+      while (this.docVectors.hasRemaining()) {
+        this.docVectors.get(this.currDocVector);

Review comment:
       We don't typically use `this.` except in constructors where the name is shadowed by a parameter. IMO it just clutters up the code; do you mind removing it?

##########
File path: lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java
##########
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import org.apache.lucene.index.VectorValues;
+
+/**
+ * A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
+ * vectors.
+ */
+public class FullKnn {
+
+  private final int dim;
+  private final int topK;
+  private final VectorValues.SearchStrategy searchStrategy;
+  private final boolean quiet;
+
+  public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) {
+    this.dim = dim;
+    this.topK = topK;
+    this.searchStrategy = searchStrategy;
+    this.quiet = quiet;
+  }
+
+  /** internal object to track KNN calculation for one query */
+  private static class KnnJob {
+    public int currDocIndex;
+    float[] queryVector;
+    float[] currDocVector;
+    int queryIndex;
+    private LongHeap queue;
+    FloatBuffer docVectors;
+    VectorValues.SearchStrategy searchStrategy;
+
+    public KnnJob(
+        int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) {
+      this.queryIndex = queryIndex;
+      this.queryVector = queryVector;
+      this.currDocVector = new float[queryVector.length];
+      if (searchStrategy.reversed) {
+        queue = LongHeap.create(LongHeap.Order.MAX, topK);
+      } else {
+        queue = LongHeap.create(LongHeap.Order.MIN, topK);
+      }
+      this.searchStrategy = searchStrategy;
+    }
+
+    public void execute() {
+      while (this.docVectors.hasRemaining()) {
+        this.docVectors.get(this.currDocVector);
+        float d = this.searchStrategy.compare(this.queryVector, this.currDocVector);
+        this.queue.insertWithOverflow(encodeNodeIdAndScore(this.currDocIndex, d));
+        this.currDocIndex++;
+      }
+    }
+  }
+
+  /**
+   * computes the exact KNN match for each query vector in queryPath for all the document vectors in
+   * docPath
+   *
+   * @param docPath : path to the file containing the float 32 document vectors in bytes with
+   *     little-endian byte order
+   * @param queryPath : path to the file containing the containing 32-bit floating point vectors in
+   *     little-endian byte order
+   * @param numThreads : create numThreads to parallelize work
+   * @return : returns an int 2D array ( int matches[][]) of size 'numIters x topK'. matches[i] is
+   *     an array containing the indexes of the topK most similar document vectors to the ith query
+   *     vector, and is sorted by similarity, with the most similar vector first. Similarity is
+   *     defined by the searchStrategy used to construct this FullKnn.
+   * @throws IllegalArgumentException : if topK is greater than number of documents in docPath file
+   *     IOException : In case of IO exception while reading files.
+   */
+  public int[][] computeNN(Path docPath, Path queryPath, int numThreads) throws IOException {
+    assert numThreads > 0;
+    final int numDocs = (int) (Files.size(docPath) / (dim * Float.BYTES));
+    final int numQueries = (int) (Files.size(docPath) / (dim * Float.BYTES));
+
+    if (!quiet) {
+      System.out.println(
+          "computing true nearest neighbors of "
+              + numQueries
+              + " target vectors using "
+              + numThreads
+              + " threads.");
+    }
+
+    try (FileChannel docInput = FileChannel.open(docPath);
+        FileChannel queryInput = FileChannel.open(queryPath)) {
+      return doFullKnn(
+          numDocs,
+          numQueries,
+          numThreads,
+          new FileChannelBufferProvider(docInput),
+          new FileChannelBufferProvider(queryInput));
+    }
+  }
+
+  int[][] doFullKnn(
+      int numDocs,
+      int numQueries,
+      int numThreads,
+      BufferProvider docInput,
+      BufferProvider queryInput)
+      throws IOException {
+    if (numDocs < topK) {
+      throw new IllegalArgumentException(
+          String.format(
+              Locale.ROOT,
+              "topK (%d) cannot be greater than number of docs in docPath (%d)",
+              topK,
+              numDocs));
+    }
+
+    final ExecutorService executorService =
+        Executors.newFixedThreadPool(numThreads, new NamedThreadFactory("FullKnnExecutor"));
+    int[][] result = new int[numQueries][];
+
+    FloatBuffer queries = queryInput.getBuffer(0, numQueries * dim * Float.BYTES).asFloatBuffer();
+    float[] query = new float[dim];
+    List<KnnJob> jobList = new ArrayList<>(numThreads);
+    for (int i = 0; i < numQueries; ) {
+
+      for (int j = 0; j < numThreads && i < numQueries; i++, j++) {
+        queries.get(query);
+        jobList.add(
+            new KnnJob(i, ArrayUtil.copyOfSubArray(query, 0, query.length), topK, searchStrategy));
+      }
+
+      long maxBufferSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
+      int docsLeft = numDocs;
+      int currDocIndex = 0;
+      int offset = 0;
+      while (docsLeft > 0) {
+        long totalBytes = (long) docsLeft * dim * Float.BYTES;
+        int blockSize = (int) Math.min(totalBytes, maxBufferSize);
+
+        FloatBuffer docVectors = docInput.getBuffer(offset, blockSize).asFloatBuffer();
+        offset += blockSize;
+
+        final List<CompletableFuture<Void>> completableFutures =
+            jobList.stream()
+                .peek(job -> job.docVectors = docVectors.duplicate())
+                .peek(job -> job.currDocIndex = currDocIndex)
+                .map(job -> CompletableFuture.runAsync(() -> job.execute(), executorService))
+                .collect(Collectors.toList());
+
+        CompletableFuture.allOf(
+                completableFutures.toArray(new CompletableFuture<?>[completableFutures.size()]))
+            .join();
+        docsLeft -= (blockSize / (dim * Float.BYTES));
+      }
+
+      jobList.forEach(

Review comment:
       Why use a stream when a for-loop would do? I guess it's a stylistic choice, but this code base tends to use the traditional iteration statements.

##########
File path: lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java
##########
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import org.apache.lucene.index.VectorValues;
+
+/**
+ * A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
+ * vectors.
+ */
+public class FullKnn {
+
+  private final int dim;
+  private final int topK;
+  private final VectorValues.SearchStrategy searchStrategy;
+  private final boolean quiet;
+
+  public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) {
+    this.dim = dim;
+    this.topK = topK;
+    this.searchStrategy = searchStrategy;
+    this.quiet = quiet;
+  }
+
+  /** internal object to track KNN calculation for one query */
+  private static class KnnJob {
+    public int currDocIndex;
+    float[] queryVector;
+    float[] currDocVector;
+    int queryIndex;
+    private LongHeap queue;
+    FloatBuffer docVectors;
+    VectorValues.SearchStrategy searchStrategy;
+
+    public KnnJob(
+        int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) {
+      this.queryIndex = queryIndex;
+      this.queryVector = queryVector;
+      this.currDocVector = new float[queryVector.length];
+      if (searchStrategy.reversed) {
+        queue = LongHeap.create(LongHeap.Order.MAX, topK);
+      } else {
+        queue = LongHeap.create(LongHeap.Order.MIN, topK);
+      }
+      this.searchStrategy = searchStrategy;
+    }
+
+    public void execute() {
+      while (this.docVectors.hasRemaining()) {
+        this.docVectors.get(this.currDocVector);
+        float d = this.searchStrategy.compare(this.queryVector, this.currDocVector);
+        this.queue.insertWithOverflow(encodeNodeIdAndScore(this.currDocIndex, d));
+        this.currDocIndex++;
+      }
+    }
+  }
+
+  /**
+   * computes the exact KNN match for each query vector in queryPath for all the document vectors in
+   * docPath
+   *
+   * @param docPath : path to the file containing the float 32 document vectors in bytes with
+   *     little-endian byte order
+   * @param queryPath : path to the file containing the containing 32-bit floating point vectors in
+   *     little-endian byte order
+   * @param numThreads : create numThreads to parallelize work
+   * @return : returns an int 2D array ( int matches[][]) of size 'numIters x topK'. matches[i] is
+   *     an array containing the indexes of the topK most similar document vectors to the ith query
+   *     vector, and is sorted by similarity, with the most similar vector first. Similarity is
+   *     defined by the searchStrategy used to construct this FullKnn.
+   * @throws IllegalArgumentException : if topK is greater than number of documents in docPath file
+   *     IOException : In case of IO exception while reading files.
+   */
+  public int[][] computeNN(Path docPath, Path queryPath, int numThreads) throws IOException {
+    assert numThreads > 0;
+    final int numDocs = (int) (Files.size(docPath) / (dim * Float.BYTES));
+    final int numQueries = (int) (Files.size(docPath) / (dim * Float.BYTES));
+
+    if (!quiet) {
+      System.out.println(
+          "computing true nearest neighbors of "
+              + numQueries
+              + " target vectors using "
+              + numThreads
+              + " threads.");
+    }
+
+    try (FileChannel docInput = FileChannel.open(docPath);
+        FileChannel queryInput = FileChannel.open(queryPath)) {
+      return doFullKnn(
+          numDocs,
+          numQueries,
+          numThreads,
+          new FileChannelBufferProvider(docInput),
+          new FileChannelBufferProvider(queryInput));
+    }
+  }
+
+  int[][] doFullKnn(
+      int numDocs,
+      int numQueries,
+      int numThreads,
+      BufferProvider docInput,
+      BufferProvider queryInput)
+      throws IOException {
+    if (numDocs < topK) {
+      throw new IllegalArgumentException(
+          String.format(
+              Locale.ROOT,
+              "topK (%d) cannot be greater than number of docs in docPath (%d)",
+              topK,
+              numDocs));
+    }
+
+    final ExecutorService executorService =
+        Executors.newFixedThreadPool(numThreads, new NamedThreadFactory("FullKnnExecutor"));
+    int[][] result = new int[numQueries][];
+
+    FloatBuffer queries = queryInput.getBuffer(0, numQueries * dim * Float.BYTES).asFloatBuffer();
+    float[] query = new float[dim];
+    List<KnnJob> jobList = new ArrayList<>(numThreads);
+    for (int i = 0; i < numQueries; ) {
+
+      for (int j = 0; j < numThreads && i < numQueries; i++, j++) {
+        queries.get(query);
+        jobList.add(
+            new KnnJob(i, ArrayUtil.copyOfSubArray(query, 0, query.length), topK, searchStrategy));
+      }
+
+      long maxBufferSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
+      int docsLeft = numDocs;
+      int currDocIndex = 0;
+      int offset = 0;
+      while (docsLeft > 0) {
+        long totalBytes = (long) docsLeft * dim * Float.BYTES;
+        int blockSize = (int) Math.min(totalBytes, maxBufferSize);
+
+        FloatBuffer docVectors = docInput.getBuffer(offset, blockSize).asFloatBuffer();
+        offset += blockSize;
+
+        final List<CompletableFuture<Void>> completableFutures =
+            jobList.stream()
+                .peek(job -> job.docVectors = docVectors.duplicate())
+                .peek(job -> job.currDocIndex = currDocIndex)
+                .map(job -> CompletableFuture.runAsync(() -> job.execute(), executorService))
+                .collect(Collectors.toList());
+
+        CompletableFuture.allOf(
+                completableFutures.toArray(new CompletableFuture<?>[completableFutures.size()]))
+            .join();
+        docsLeft -= (blockSize / (dim * Float.BYTES));
+      }
+
+      jobList.forEach(
+          job -> {
+            result[job.queryIndex] = new int[topK];
+            for (int k = topK - 1; k >= 0; k--) {
+              result[job.queryIndex][k] = popNodeId(job.queue);
+              // System.out.print(" " + n);
+            }
+            if (!quiet && (job.queryIndex + 1) % 10 == 0) {
+              System.out.print(" " + (job.queryIndex + 1));
+              System.out.flush();
+            }
+          });
+
+      jobList.clear();
+    }
+
+    executorService.shutdown();
+    try {
+      if (!executorService.awaitTermination(1, TimeUnit.MINUTES)) {
+        executorService.shutdownNow();

Review comment:
       You could simplify by moving this to a `finally` block rather than detecting both cases where it needs to be called. I think it will be safe to call even if the service was successfully shut down; if not, you could guard it with a call to `isShutdown()`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org
For additional commands, e-mail: issues-help@lucene.apache.org