You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ab...@apache.org on 2022/06/29 09:37:47 UTC
[lucene] 02/02: LUCENE-10593: VectorSimilarityFunction reverse removal (#926)
This is an automated email from the ASF dual-hosted git repository.
abenedetti pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git
commit b3b7098cd9636c5ad2516055f768dd29b795a05d
Author: Alessandro Benedetti <a....@sease.io>
AuthorDate: Tue Jun 28 15:33:11 2022 +0200
LUCENE-10593: VectorSimilarityFunction reverse removal (#926)
* Vector Similarity Function reverse property removed
* NeighborQueue tie-breaking fixed (node id + node score encoding)
* NeighborQueue readability refactor
* BoundChecker removal (now it's only in backward-codecs)
---
lucene/CHANGES.txt | 2 +
.../lucene90/Lucene90BoundsChecker.java} | 27 ++++++++---
.../lucene90/Lucene90HnswGraphBuilder.java | 20 ++++----
.../lucene90/Lucene90HnswVectorsReader.java | 4 +-
.../lucene90/Lucene90OnHeapHnswGraph.java | 19 ++++----
.../lucene91/Lucene91BoundsChecker.java} | 27 ++++++++---
.../lucene91/Lucene91HnswGraphBuilder.java | 25 +++++-----
.../lucene91/Lucene91HnswVectorsReader.java | 4 +-
.../lucene92/Lucene92HnswVectorsReader.java | 4 +-
.../simpletext/SimpleTextKnnVectorsReader.java | 2 +-
.../codecs/lucene93/Lucene93HnswVectorsReader.java | 2 +-
.../lucene/index/VectorSimilarityFunction.java | 52 +++------------------
.../org/apache/lucene/search/KnnVectorQuery.java | 2 +-
.../apache/lucene/util/hnsw/HnswGraphBuilder.java | 24 +++++-----
.../apache/lucene/util/hnsw/HnswGraphSearcher.java | 22 ++++-----
.../org/apache/lucene/util/hnsw/NeighborQueue.java | 54 ++++++++++++++++++----
.../apache/lucene/util/hnsw/OnHeapHnswGraph.java | 10 ++--
.../apache/lucene/util/hnsw/KnnGraphTester.java | 5 +-
.../org/apache/lucene/util/hnsw/TestHnswGraph.java | 26 +----------
19 files changed, 161 insertions(+), 170 deletions(-)
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 41fad11b781..7c0c35e6c6d 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -103,6 +103,8 @@ Optimizations
* LUCENE-10606: For KnnVectorQuery, optimize case where filter is backed by BitSetIterator (Kaival Parikh)
+* LUCENE-10593: Vector similarity function and NeighborQueue reverse removal. (Alessandro Benedetti)
+
Bug Fixes
---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90BoundsChecker.java
similarity index 77%
copy from lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java
copy to lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90BoundsChecker.java
index 9cde17db421..e181dc0ec57 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90BoundsChecker.java
@@ -14,17 +14,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
-package org.apache.lucene.util.hnsw;
+package org.apache.lucene.backward_codecs.lucene90;
/**
* A helper class for an hnsw graph that serves as a comparator of the currently set bound value
* with a new value.
*/
-public abstract class BoundsChecker {
+public abstract class Lucene90BoundsChecker {
float bound;
+ /** Default Constructor */
+ public Lucene90BoundsChecker() {}
+
/** Update the bound if sample is better */
public abstract void update(float sample);
@@ -33,10 +35,21 @@ public abstract class BoundsChecker {
bound = sample;
}
- /** @return whether the sample exceeds (is worse than) the bound */
+ /**
+ * Check the sample
+ *
+ * @param sample a score
+ * @return whether the sample exceeds (is worse than) the bound
+ */
public abstract boolean check(float sample);
- public static BoundsChecker create(boolean reversed) {
+ /**
+ * Create a min or max bound checker
+ *
+ * @param reversed true to create a {@link Min} bound checker, false to create a {@link Max} one
+ * @return the bound checker
+ */
+ public static Lucene90BoundsChecker create(boolean reversed) {
if (reversed) {
return new Min();
} else {
@@ -48,7 +61,7 @@ public abstract class BoundsChecker {
* A helper class for an hnsw graph that serves as a comparator of the currently set maximum value
* with a new value.
*/
- public static class Max extends BoundsChecker {
+ public static class Max extends Lucene90BoundsChecker {
Max() {
bound = Float.NEGATIVE_INFINITY;
}
@@ -70,7 +83,7 @@ public abstract class BoundsChecker {
* A helper class for an hnsw graph that serves as a comparator of the currently set minimum value
* with a new value.
*/
- public static class Min extends BoundsChecker {
+ public static class Min extends Lucene90BoundsChecker {
Min() {
bound = Float.POSITIVE_INFINITY;
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
index bf07cd44f8c..d4c09455c7a 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
@@ -25,7 +25,6 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
-import org.apache.lucene.util.hnsw.BoundsChecker;
import org.apache.lucene.util.hnsw.NeighborQueue;
/**
@@ -51,7 +50,7 @@ public final class Lucene90HnswGraphBuilder {
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
- private final BoundsChecker bound;
+ private final Lucene90BoundsChecker bound;
final Lucene90OnHeapHnswGraph hnsw;
private InfoStream infoStream = InfoStream.getDefault();
@@ -91,7 +90,7 @@ public final class Lucene90HnswGraphBuilder {
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
- bound = BoundsChecker.create(similarityFunction.reversed);
+ bound = Lucene90BoundsChecker.create(false);
random = new SplittableRandom(seed);
scratch = new Lucene90NeighborArray(Math.max(beamWidth, maxConn + 1));
}
@@ -234,9 +233,9 @@ public final class Lucene90HnswGraphBuilder {
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
- float diversityCheck =
+ float neighborSimilarity =
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node()[i]));
- if (bound.check(diversityCheck) == false) {
+ if (bound.check(neighborSimilarity) == false) {
return false;
}
}
@@ -267,13 +266,14 @@ public final class Lucene90HnswGraphBuilder {
for (int i = neighbors.size() - 1; i >= 0; i--) {
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
// them, drop it
- int nbrNode = neighbors.node()[i];
+ int neighborId = neighbors.node()[i];
bound.set(neighbors.score()[i]);
- float[] nbrVector = vectorValues.vectorValue(nbrNode);
+ float[] neighborVector = vectorValues.vectorValue(neighborId);
for (int j = maxConn; j > i; j--) {
- float diversityCheck =
- similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node()[j]));
- if (bound.check(diversityCheck) == false) {
+ float neighborSimilarity =
+ similarityFunction.compare(
+ neighborVector, buildVectors.vectorValue(neighbors.node()[j]));
+ if (bound.check(neighborSimilarity) == false) {
// node j is too similar to node i given its score relative to the base node
// replace it with the new node, which is at [maxConn]
return i;
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
index 531140478c6..db3377bbde0 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -266,9 +266,9 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
- float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
+ float minSimilarity = results.topScore();
results.pop();
- scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
+ scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], minSimilarity);
}
TotalHits.Relation relation =
results.incomplete()
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
index 6457b8071e9..aeffedcc287 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
@@ -27,7 +27,6 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
-import org.apache.lucene.util.hnsw.BoundsChecker;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
@@ -85,9 +84,9 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
int size = graphValues.size();
// MIN heap, holding the top results
- NeighborQueue results = new NeighborQueue(numSeed, similarityFunction.reversed);
+ NeighborQueue results = new NeighborQueue(numSeed, false);
// MAX heap, from which to pull the candidate nodes
- NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);
+ NeighborQueue candidates = new NeighborQueue(numSeed, true);
int numVisited = 0;
// set of ordinals that have been visited by search on this layer, used to avoid backtracking
@@ -114,13 +113,13 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
// Set the bound to the worst current result and below reject any newly-generated candidates
// failing
// to exceed this bound
- BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
+ Lucene90BoundsChecker bound = Lucene90BoundsChecker.create(false);
bound.set(results.topScore());
while (candidates.size() > 0 && results.incomplete() == false) {
// get the best candidate (closest or best scoring)
- float topCandidateScore = candidates.topScore();
+ float topCandidateSimilarity = candidates.topScore();
if (results.size() >= topK) {
- if (bound.check(topCandidateScore)) {
+ if (bound.check(topCandidateSimilarity)) {
break;
}
}
@@ -138,11 +137,11 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
break;
}
- float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
- if (results.size() < numSeed || bound.check(score) == false) {
- candidates.add(friendOrd, score);
+ float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
+ if (results.size() < numSeed || bound.check(friendSimilarity) == false) {
+ candidates.add(friendOrd, friendSimilarity);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
- results.insertWithOverflow(friendOrd, score);
+ results.insertWithOverflow(friendOrd, friendSimilarity);
bound.set(results.topScore());
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91BoundsChecker.java
similarity index 77%
rename from lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java
rename to lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91BoundsChecker.java
index 9cde17db421..eb854f69eb3 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91BoundsChecker.java
@@ -14,17 +14,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
-package org.apache.lucene.util.hnsw;
+package org.apache.lucene.backward_codecs.lucene91;
/**
* A helper class for an hnsw graph that serves as a comparator of the currently set bound value
* with a new value.
*/
-public abstract class BoundsChecker {
+public abstract class Lucene91BoundsChecker {
float bound;
+ /** Default Constructor */
+ public Lucene91BoundsChecker() {}
+
/** Update the bound if sample is better */
public abstract void update(float sample);
@@ -33,10 +35,21 @@ public abstract class BoundsChecker {
bound = sample;
}
- /** @return whether the sample exceeds (is worse than) the bound */
+ /**
+ * Check the sample
+ *
+ * @param sample a score
+ * @return whether the sample exceeds (is worse than) the bound
+ */
public abstract boolean check(float sample);
- public static BoundsChecker create(boolean reversed) {
+ /**
+ * Create a min or max bound checker
+ *
+ * @param reversed true to create a {@link Min} bound checker, false to create a {@link Max} one
+ * @return the bound checker
+ */
+ public static Lucene91BoundsChecker create(boolean reversed) {
if (reversed) {
return new Min();
} else {
@@ -48,7 +61,7 @@ public abstract class BoundsChecker {
* A helper class for an hnsw graph that serves as a comparator of the currently set maximum value
* with a new value.
*/
- public static class Max extends BoundsChecker {
+ public static class Max extends Lucene91BoundsChecker {
Max() {
bound = Float.NEGATIVE_INFINITY;
}
@@ -70,7 +83,7 @@ public abstract class BoundsChecker {
* A helper class for an hnsw graph that serves as a comparator of the currently set minimum value
* with a new value.
*/
- public static class Min extends BoundsChecker {
+ public static class Min extends Lucene91BoundsChecker {
Min() {
bound = Float.POSITIVE_INFINITY;
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
index 002497d2d2a..54ff3240b32 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
@@ -28,7 +28,6 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
-import org.apache.lucene.util.hnsw.BoundsChecker;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
@@ -55,7 +54,7 @@ public final class Lucene91HnswGraphBuilder {
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
- private final BoundsChecker bound;
+ private final Lucene91BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
final Lucene91OnHeapHnswGraph hnsw;
@@ -104,9 +103,9 @@ public final class Lucene91HnswGraphBuilder {
this.graphSearcher =
new HnswGraphSearcher(
similarityFunction,
- new NeighborQueue(beamWidth, similarityFunction.reversed == false),
+ new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size()));
- bound = BoundsChecker.create(similarityFunction.reversed);
+ bound = Lucene91BoundsChecker.create(false);
scratch = new Lucene91NeighborArray(Math.max(beamWidth, maxConn + 1));
}
@@ -231,8 +230,8 @@ public final class Lucene91HnswGraphBuilder {
// extract all the Neighbors from the queue into an array; these will now be
// sorted from worst to best
for (int i = 0; i < candidateCount; i++) {
- float score = candidates.topScore();
- scratch.add(candidates.pop(), score);
+ float similarity = candidates.topScore();
+ scratch.add(candidates.pop(), similarity);
}
}
@@ -253,9 +252,9 @@ public final class Lucene91HnswGraphBuilder {
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
- float diversityCheck =
+ float neighborSimilarity =
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
- if (bound.check(diversityCheck) == false) {
+ if (bound.check(neighborSimilarity) == false) {
return false;
}
}
@@ -286,13 +285,13 @@ public final class Lucene91HnswGraphBuilder {
for (int i = neighbors.size() - 1; i >= 0; i--) {
// check each neighbor against its better-scoring neighbors. If it fails diversity check with
// them, drop it
- int nbrNode = neighbors.node[i];
+ int neighborId = neighbors.node[i];
bound.set(neighbors.score[i]);
- float[] nbrVector = vectorValues.vectorValue(nbrNode);
+ float[] neighborVector = vectorValues.vectorValue(neighborId);
for (int j = maxConn; j > i; j--) {
- float diversityCheck =
- similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
- if (bound.check(diversityCheck) == false) {
+ float neighborSimilarity =
+ similarityFunction.compare(neighborVector, buildVectors.vectorValue(neighbors.node[j]));
+ if (bound.check(neighborSimilarity) == false) {
// node j is too similar to node i given its score relative to the base node
// replace it with the new node, which is at [maxConn]
return i;
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
index 45f5fd5308e..42a8115e1b5 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
@@ -253,9 +253,9 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
- float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
+ float minSimilarity = results.topScore();
results.pop();
- scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), score);
+ scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), minSimilarity);
}
TotalHits.Relation relation =
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
index 4e9a6a2dd3b..1b6534f9876 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
@@ -246,9 +246,9 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
- float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
+ float minSimilarity = results.topScore();
results.pop();
- scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
+ scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), minSimilarity);
}
TotalHits.Relation relation =
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
index 081503a51f8..adfab333e11 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
@@ -170,7 +170,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
float[] vector = values.vectorValue();
- float score = vectorSimilarity.convertToScore(vectorSimilarity.compare(vector, target));
+ float score = vectorSimilarity.compare(vector, target);
topK.insertWithOverflow(new ScoreDoc(doc, score));
numVisited++;
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene93/Lucene93HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene93/Lucene93HnswVectorsReader.java
index e439b7a52f0..977911efc2c 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene93/Lucene93HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene93/Lucene93HnswVectorsReader.java
@@ -246,7 +246,7 @@ public final class Lucene93HnswVectorsReader extends KnnVectorsReader {
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
- float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
+ float score = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
}
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
index a237e3d8609..0aae4b72847 100644
--- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
+++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
@@ -26,15 +26,10 @@ import static org.apache.lucene.util.VectorUtil.*;
public enum VectorSimilarityFunction {
/** Euclidean distance */
- EUCLIDEAN(true) {
+ EUCLIDEAN {
@Override
public float compare(float[] v1, float[] v2) {
- return squareDistance(v1, v2);
- }
-
- @Override
- public float convertToScore(float similarity) {
- return 1 / (1 + similarity);
+ return 1 / (1 + squareDistance(v1, v2));
}
},
@@ -47,12 +42,7 @@ public enum VectorSimilarityFunction {
DOT_PRODUCT {
@Override
public float compare(float[] v1, float[] v2) {
- return dotProduct(v1, v2);
- }
-
- @Override
- public float convertToScore(float similarity) {
- return (1 + similarity) / 2;
+ return (1 + dotProduct(v1, v2)) / 2;
}
},
@@ -60,50 +50,22 @@ public enum VectorSimilarityFunction {
* Cosine similarity. NOTE: the preferred way to perform cosine similarity is to normalize all
* vectors to unit length, and instead use {@link VectorSimilarityFunction#DOT_PRODUCT}. You
* should only use this function if you need to preserve the original vectors and cannot normalize
- * them in advance.
+ * them in advance. The similarity score is normalised to ensure it is positive.
*/
COSINE {
@Override
public float compare(float[] v1, float[] v2) {
- return cosine(v1, v2);
- }
-
- @Override
- public float convertToScore(float similarity) {
- return (1 + similarity) / 2;
+ return (1 + cosine(v1, v2)) / 2;
}
};
/**
- * If true, the scores associated with vector comparisons are nonnegative and in reverse order;
- * that is, lower scores represent more similar vectors. Otherwise, if false, higher scores
- * represent more similar vectors, and scores may be negative or positive.
- */
- public final boolean reversed;
-
- VectorSimilarityFunction(boolean reversed) {
- this.reversed = reversed;
- }
-
- VectorSimilarityFunction() {
- reversed = false;
- }
-
- /**
- * Calculates a similarity score between the two vectors with a specified function.
+ * Calculates a similarity score between the two vectors with a specified function. Higher
+ * similarity scores correspond to closer vectors.
*
* @param v1 a vector
* @param v2 another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public abstract float compare(float[] v1, float[] v2);
-
- /**
- * Converts similarity scores used (may be negative, reversed, etc) into document scores, which
- * must be positive, with higher scores representing better matches.
- *
- * @param similarity the raw internal score as returned by {@link #compare(float[], float[])}.
- * @return normalizedSimilarity
- */
- public abstract float convertToScore(float similarity);
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
index bb9e80c1d63..6e68de193da 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
@@ -197,7 +197,7 @@ public class KnnVectorQuery extends Query {
assert vectorDoc == doc;
float[] vector = vectorValues.vectorValue();
- float score = similarityFunction.convertToScore(similarityFunction.compare(vector, target));
+ float score = similarityFunction.compare(vector, target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index b611d082c96..c86b9321d7a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -51,7 +51,6 @@ public final class HnswGraphBuilder {
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
- private final BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
final OnHeapHnswGraph hnsw;
@@ -96,15 +95,14 @@ public final class HnswGraphBuilder {
this.ml = 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
- this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode, similarityFunction.reversed);
+ this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
this.graphSearcher =
new HnswGraphSearcher(
similarityFunction,
- new NeighborQueue(beamWidth, similarityFunction.reversed == false),
+ new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size()));
- bound = BoundsChecker.create(similarityFunction.reversed);
// in scratch we store candidates in reverse order: worse candidates are first
- scratch = new NeighborArray(Math.max(beamWidth, M + 1), similarityFunction.reversed);
+ scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
}
/**
@@ -225,8 +223,8 @@ public final class HnswGraphBuilder {
// extract all the Neighbors from the queue into an array; these will now be
// sorted from worst to best
for (int i = 0; i < candidateCount; i++) {
- float score = candidates.topScore();
- scratch.add(candidates.pop(), score);
+ float maxSimilarity = candidates.topScore();
+ scratch.add(candidates.pop(), maxSimilarity);
}
}
@@ -245,11 +243,10 @@ public final class HnswGraphBuilder {
NeighborArray neighbors,
RandomAccessVectorValues vectorValues)
throws IOException {
- bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
- float diversityCheck =
+ float neighborSimilarity =
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
- if (bound.check(diversityCheck) == false) {
+ if (neighborSimilarity >= score) {
return false;
}
}
@@ -261,16 +258,17 @@ public final class HnswGraphBuilder {
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
+ float minAcceptedSimilarity;
for (int i = neighbors.size() - 1; i > 0; i--) {
int cNode = neighbors.node[i];
float[] cVector = vectorValues.vectorValue(cNode);
- bound.set(neighbors.score[i]);
+ minAcceptedSimilarity = neighbors.score[i];
// check the candidate against its better-scoring neighbors
for (int j = i - 1; j >= 0; j--) {
- float diversityCheck =
+ float neighborSimilarity =
similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
// node i is too similar to node j given its score relative to the base node
- if (bound.check(diversityCheck) == false) {
+ if (neighborSimilarity >= minAcceptedSimilarity) {
return i;
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index ba88995bd3b..59735f6be9d 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -80,7 +80,7 @@ public final class HnswGraphSearcher {
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
similarityFunction,
- new NeighborQueue(topK, similarityFunction.reversed == false),
+ new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
@@ -139,7 +139,7 @@ public final class HnswGraphSearcher {
int visitedLimit)
throws IOException {
int size = graph.size();
- NeighborQueue results = new NeighborQueue(topK, similarityFunction.reversed);
+ NeighborQueue results = new NeighborQueue(topK, false);
clearScratchState();
int numVisited = 0;
@@ -160,14 +160,14 @@ public final class HnswGraphSearcher {
// A bound that holds the minimum similarity to the query vector that a candidate vector must
// have to be considered.
- BoundsChecker bound = BoundsChecker.create(similarityFunction.reversed);
+ float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
if (results.size() >= topK) {
- bound.set(results.topScore());
+ minAcceptedSimilarity = results.topScore();
}
while (candidates.size() > 0 && results.incomplete() == false) {
// get the best candidate (closest or best scoring)
- float topCandidateScore = candidates.topScore();
- if (bound.check(topCandidateScore)) {
+ float topCandidateSimilarity = candidates.topScore();
+ if (topCandidateSimilarity < minAcceptedSimilarity) {
break;
}
@@ -184,13 +184,13 @@ public final class HnswGraphSearcher {
results.markIncomplete();
break;
}
- float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
+ float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
numVisited++;
- if (bound.check(score) == false) {
- candidates.add(friendOrd, score);
+ if (friendSimilarity >= minAcceptedSimilarity) {
+ candidates.add(friendOrd, friendSimilarity);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
- if (results.insertWithOverflow(friendOrd, score) && results.size() >= topK) {
- bound.set(results.topScore());
+ if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) {
+ minAcceptedSimilarity = results.topScore();
}
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
index a2c7253261b..50f0587bb5c 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
@@ -30,13 +30,13 @@ import org.apache.lucene.util.NumericUtils;
public class NeighborQueue {
private enum Order {
- NATURAL {
+ MIN_HEAP {
@Override
long apply(long v) {
return v;
}
},
- REVERSED {
+ MAX_HEAP {
@Override
long apply(long v) {
// This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It
@@ -56,9 +56,9 @@ public class NeighborQueue {
// Whether the search stopped early because it reached the visited nodes limit
private boolean incomplete;
- public NeighborQueue(int initialSize, boolean reversed) {
+ public NeighborQueue(int initialSize, boolean maxHeap) {
this.heap = new LongHeap(initialSize);
- this.order = reversed ? Order.REVERSED : Order.NATURAL;
+ this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP;
}
/** @return the number of elements in the heap */
@@ -89,32 +89,66 @@ public class NeighborQueue {
return heap.insertWithOverflow(encode(newNode, newScore));
}
+ /**
+ * Encodes the node ID and its similarity score as a long, preserving the Lucene tie-breaking
+ * rule that when two scores are equal, the smaller node ID must win.
+ *
+ * <p>The most significant 32 bits represent the float score, encoded as a sortable int.
+ *
+ * <p>The least significant 32 bits represent the node ID.
+ *
+ * <p>The bits representing the node ID are complemented to guarantee the win for the smaller node
+ * ID.
+ *
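+ * <p>The AND with 0xFFFFFFFFL (a long whose least significant 32 bits are all 1) is necessary
+ * because the complemented node ID is sign-extended when widened to long. It yields a long where:
+ *
+ * <p>The most significant 32 bits are 0, so the OR leaves the score bits untouched.
+ *
+ * <p>The least significant 32 bits represent the complemented node ID.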
+ *
+ * @param node the node ID
+ * @param score the node score
+ * @return the encoded score, node ID
+ */
private long encode(int node, float score) {
- return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | node);
+ return order.apply(
+ (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
+ }
+
+ private float decodeScore(long heapValue) {
+ return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
+ }
+
+ private int decodeNodeId(long heapValue) {
+ return (int) ~(order.apply(heapValue));
}
/** Removes the top element and returns its node id. */
public int pop() {
- return (int) order.apply(heap.pop());
+ return decodeNodeId(heap.pop());
}
public int[] nodes() {
int size = size();
int[] nodes = new int[size];
for (int i = 0; i < size; i++) {
- nodes[i] = (int) order.apply(heap.get(i + 1));
+ nodes[i] = decodeNodeId(heap.get(i + 1));
}
return nodes;
}
/** Returns the top element's node id. */
public int topNode() {
- return (int) order.apply(heap.top());
+ return decodeNodeId(heap.top());
}
- /** Returns the top element's node score. */
+ /**
+ * Returns the top element's node score. For the min heap this is the minimum score. For the max
+ * heap this is the maximum score.
+ */
public float topScore() {
- return NumericUtils.sortableIntToFloat((int) (order.apply(heap.top()) >> 32));
+ return decodeScore(heap.top());
}
public void clear() {
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index 1dc0845ccd5..e03a8e89e79 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -30,7 +30,6 @@ import org.apache.lucene.util.ArrayUtil;
*/
public final class OnHeapHnswGraph extends HnswGraph {
- private final boolean similarityReversed;
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level
@@ -52,8 +51,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
private int upto;
private NeighborArray cur;
- OnHeapHnswGraph(int M, int levelOfFirstNode, boolean similarityReversed) {
- this.similarityReversed = similarityReversed;
+ OnHeapHnswGraph(int M, int levelOfFirstNode) {
this.numLevels = levelOfFirstNode + 1;
this.graph = new ArrayList<>(numLevels);
this.entryNode = 0;
@@ -63,7 +61,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
this.nsize0 = (M * 2 + 1);
for (int l = 0; l < numLevels; l++) {
graph.add(new ArrayList<>());
- graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, similarityReversed == false));
+ graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, true));
}
this.nodesByLevel = new ArrayList<>(numLevels);
@@ -123,9 +121,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
}
}
}
- graph
- .get(level)
- .add(new NeighborArray(level == 0 ? nsize0 : nsize, similarityReversed == false));
+ graph.get(level).add(new NeighborArray(level == 0 ? nsize0 : nsize, true));
}
@Override
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index bb659f0bded..615e15be047 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -276,8 +276,7 @@ public class KnnGraphTester {
for (int i = 0; i < hnsw.size(); i++) {
NeighborArray neighbors = hnsw.getNeighbors(0, i);
System.out.printf(Locale.ROOT, "%5d", i);
- NeighborArray sorted =
- new NeighborArray(neighbors.size(), similarityFunction.reversed == false);
+ NeighborArray sorted = new NeighborArray(neighbors.size(), true);
for (int j = 0; j < neighbors.size(); j++) {
int node = neighbors.node[j];
float score = neighbors.score[j];
@@ -555,7 +554,7 @@ public class KnnGraphTester {
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
- NeighborQueue queue = new NeighborQueue(topK, similarityFunction.reversed);
+ NeighborQueue queue = new NeighborQueue(topK, false);
for (; j < numDocs && vectors.hasRemaining(); j++) {
vectors.get(vector);
float d = similarityFunction.compare(query, vector);
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index 8e357b91898..93be8e8aa8b 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -309,30 +309,6 @@ public class TestHnswGraph extends LuceneTestCase {
assertTrue(nn.visitedCount() <= visitedLimit);
}
- public void testBoundsCheckerMax() {
- BoundsChecker max = BoundsChecker.create(false);
- float f = random().nextFloat() - 0.5f;
- // any float > -MAX_VALUE is in bounds
- assertFalse(max.check(f));
- // f is now the bound (minus some delta)
- max.update(f);
- assertFalse(max.check(f)); // f is not out of bounds
- assertFalse(max.check(f + 1)); // anything greater than f is in bounds
- assertTrue(max.check(f - 1e-5f)); // delta is zero initially
- }
-
- public void testBoundsCheckerMin() {
- BoundsChecker min = BoundsChecker.create(true);
- float f = random().nextFloat() - 0.5f;
- // any float < MAX_VALUE is in bounds
- assertFalse(min.check(f));
- // f is now the bound (minus some delta)
- min.update(f);
- assertFalse(min.check(f)); // f is not out of bounds
- assertFalse(min.check(f - 1)); // anything less than f is in bounds
- assertTrue(min.check(f + 1e-5f)); // delta is zero initially
- }
-
public void testHnswGraphBuilderInvalid() {
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0));
expectThrows(
@@ -441,7 +417,7 @@ public class TestHnswGraph extends LuceneTestCase {
while (actual.size() > topK) {
actual.pop();
}
- NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
+ NeighborQueue expected = new NeighborQueue(topK, false);
for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));