You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ab...@apache.org on 2023/01/23 12:00:16 UTC
[lucene] branch branch_9x updated: Introduce getters for KnnVectorQuery(#12029)
This is an automated email from the ASF dual-hosted git repository.
abenedetti pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/branch_9x by this push:
new a85cd370c95 Introduce getters for KnnVectorQuery(#12029)
a85cd370c95 is described below
commit a85cd370c95e3269e4c000a1cbca5eec17cfdff7
Author: Alessandro Benedetti <a....@sease.io>
AuthorDate: Mon Jan 23 12:35:08 2023 +0100
Introduce getters for KnnVectorQuery(#12029)
---
lucene/CHANGES.txt | 2 ++
.../lucene/search/AbstractKnnVectorQuery.java | 22 ++++++++++++++++++
.../apache/lucene/search/KnnByteVectorQuery.java | 7 ++++++
.../org/apache/lucene/search/KnnVectorQuery.java | 8 +++++++
.../lucene/search/BaseKnnVectorQueryTestCase.java | 27 ++++++++++++++++++++++
.../lucene/search/TestKnnByteVectorQuery.java | 10 ++++++++
.../apache/lucene/search/TestKnnVectorQuery.java | 8 +++++++
7 files changed, 84 insertions(+)
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index bc9c3e94bf8..11bf03a3a68 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -232,6 +232,8 @@ Other
* LUCENE-10546: Update Faceting user guide. (Egor Potemkin)
+* GITHUB#12029: Introduce support in KnnVectorQuery for getters. (Alessandro Benedetti)
+
Build
---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index c5060c5c694..5d635890a31 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -238,6 +238,28 @@ abstract class AbstractKnnVectorQuery extends Query {
return Objects.hash(field, k, filter);
}
+  /**
+   * @return the name of the field (indexed as a KnnVectorField) that the KnnVector search targets.
+   */
+  public String getField() {
+    return field;
+  }
+
+  /**
+   * @return k, the upper bound on the number of hits the KnnVector search returns.
+   */
+  public int getK() {
+    return this.k;
+  }
+
+  /**
+   * @return the query applied before the KnnVector search, or {@code null} when the query was
+   *     built without a filter. Only documents accepted by it can be returned by the search.
+   */
+  public Query getFilter() {
+    return this.filter;
+  }
+
/** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query {
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
index 87cbfde97e4..d44597ae81e 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
@@ -113,4 +113,11 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
public int hashCode() {
return Objects.hash(super.hashCode(), target);
}
+
+  /**
+   * @return a deep copy of the target query vector (one byte per element), safe to modify.
+   */
+  public BytesRef getTargetCopy() {
+    return BytesRef.deepCopyOf(this.target);
+  }
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
index 5ed250be0b3..914947d3854 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
@@ -23,6 +23,7 @@ import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
/**
@@ -107,4 +108,11 @@ public class KnnVectorQuery extends AbstractKnnVectorQuery {
result = 31 * result + Arrays.hashCode(target);
return result;
}
+
+  /**
+   * @return a fresh copy of the target query vector (one float per element), safe to modify.
+   */
+  public float[] getTargetCopy() {
+    return ArrayUtil.copyOfSubArray(this.target, 0, this.target.length);
+  }
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index 2363f207757..04629eb80fc 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -95,6 +95,33 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {0}, 10));
}
+  public void testGetField() {
+    AbstractKnnVectorQuery unfiltered = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
+    Query idFilter = new TermQuery(new Term("id", "id1"));
+    AbstractKnnVectorQuery filtered = getKnnVectorQuery("f2", new float[] {0, 1}, 10, idFilter);
+
+    assertEquals("f1", unfiltered.getField());
+    assertEquals("f2", filtered.getField());
+  }
+
+  public void testGetK() {
+    AbstractKnnVectorQuery unfiltered = getKnnVectorQuery("f1", new float[] {0, 1}, 6);
+    Query idFilter = new TermQuery(new Term("id", "id1"));
+    AbstractKnnVectorQuery filtered = getKnnVectorQuery("f2", new float[] {0, 1}, 7, idFilter);
+
+    assertEquals(6, unfiltered.getK());
+    assertEquals(7, filtered.getK());
+  }
+
+  public void testGetFilter() {
+    AbstractKnnVectorQuery unfiltered = getKnnVectorQuery("f1", new float[] {0, 1}, 6);
+    Query idFilter = new TermQuery(new Term("id", "id1"));
+    AbstractKnnVectorQuery filtered = getKnnVectorQuery("f2", new float[] {0, 1}, 7, idFilter);
+
+    assertNull(unfiltered.getFilter());
+    assertEquals(idFilter, filtered.getFilter());
+  }
+
/**
* Tests if a AbstractKnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no
* documents to match.
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
index f3cbfb01e16..dbd127c50a9 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
@@ -73,6 +73,16 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
assertEquals("KnnByteVectorQuery:f1[0,...][10]", q1.toString("ignored"));
}
+  public void testGetTarget() {
+    byte[] queryVectorBytes = floatToBytes(new float[] {0, 1});
+    BytesRef targetQueryVector = new BytesRef(queryVectorBytes);
+    KnnByteVectorQuery q1 = new KnnByteVectorQuery("f1", targetQueryVector, 10);
+
+    assertEquals(targetQueryVector, q1.getTargetCopy());
+    assertNotSame(targetQueryVector, q1.getTargetCopy());
+    assertNotSame(targetQueryVector.bytes, q1.getTargetCopy().bytes);
+  }
+
@Override
VectorEncoding vectorEncoding() {
return VectorEncoding.BYTE;
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
index c50b6b864d9..8cdf95e1ce6 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
@@ -66,6 +66,14 @@ public class TestKnnVectorQuery extends BaseKnnVectorQueryTestCase {
assertEquals("KnnVectorQuery:f1[0.0,...][10]", q1.toString("ignored"));
}
+  public void testGetTarget() {
+    float[] queryVector = new float[] {0, 1};
+    KnnVectorQuery q1 = new KnnVectorQuery("f1", queryVector, 10);
+
+    assertArrayEquals(queryVector, q1.getTargetCopy(), 0);
+    assertNotSame(queryVector, q1.getTargetCopy());
+  }
+
@Override
VectorEncoding vectorEncoding() {
return VectorEncoding.FLOAT32;