You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ab...@apache.org on 2023/01/23 12:00:16 UTC

[lucene] branch branch_9x updated: Introduce getters for KnnVectorQuery(#12029)

This is an automated email from the ASF dual-hosted git repository.

abenedetti pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new a85cd370c95 Introduce getters for KnnVectorQuery(#12029)
a85cd370c95 is described below

commit a85cd370c95e3269e4c000a1cbca5eec17cfdff7
Author: Alessandro Benedetti <a....@sease.io>
AuthorDate: Mon Jan 23 12:35:08 2023 +0100

    Introduce getters for KnnVectorQuery(#12029)
---
 lucene/CHANGES.txt                                 |  2 ++
 .../lucene/search/AbstractKnnVectorQuery.java      | 22 ++++++++++++++++++
 .../apache/lucene/search/KnnByteVectorQuery.java   |  7 ++++++
 .../org/apache/lucene/search/KnnVectorQuery.java   |  8 +++++++
 .../lucene/search/BaseKnnVectorQueryTestCase.java  | 27 ++++++++++++++++++++++
 .../lucene/search/TestKnnByteVectorQuery.java      | 10 ++++++++
 .../apache/lucene/search/TestKnnVectorQuery.java   |  8 +++++++
 7 files changed, 84 insertions(+)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index bc9c3e94bf8..11bf03a3a68 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -232,6 +232,8 @@ Other
 
 * LUCENE-10546: Update Faceting user guide. (Egor Potemkin)
 
+* GITHUB#12029: Introduce getters for KnnVectorQuery. (Alessandro Benedetti)
+
 Build
 ---------------------
 
diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index c5060c5c694..5d635890a31 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -238,6 +238,28 @@ abstract class AbstractKnnVectorQuery extends Query {
     return Objects.hash(field, k, filter);
   }
 
+  /**
+   * @return the KnnVectorField where the KnnVector search happens.
+   */
+  public String getField() {
+    return field;
+  }
+
+  /**
+   * @return the max number of results the KnnVector search returns.
+   */
+  public int getK() {
+    return k;
+  }
+
+  /**
+   * @return the filter that is executed before the KnnVector search happens. Only the results
+   *     accepted by this filter are returned by the KnnVector search.
+   */
+  public Query getFilter() {
+    return filter;
+  }
+
   /** Caches the results of a KnnVector search: a list of docs and their scores */
   static class DocAndScoreQuery extends Query {
 
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
index 87cbfde97e4..d44597ae81e 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
@@ -113,4 +113,11 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
   public int hashCode() {
     return Objects.hash(super.hashCode(), target);
   }
+
+  /**
+   * @return the target query vector of the search. Each vector element is a byte.
+   */
+  public BytesRef getTargetCopy() {
+    return BytesRef.deepCopyOf(target);
+  }
 }
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
index 5ed250be0b3..914947d3854 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
@@ -23,6 +23,7 @@ import org.apache.lucene.document.KnnVectorField;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.Bits;
 
 /**
@@ -107,4 +108,11 @@ public class KnnVectorQuery extends AbstractKnnVectorQuery {
     result = 31 * result + Arrays.hashCode(target);
     return result;
   }
+
+  /**
+   * @return the target query vector of the search. Each vector element is a float.
+   */
+  public float[] getTargetCopy() {
+    return ArrayUtil.copyOfSubArray(target, 0, target.length);
+  }
 }
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index 2363f207757..04629eb80fc 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -95,6 +95,33 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
     assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {0}, 10));
   }
 
+  public void testGetField() {
+    AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
+    Query filter1 = new TermQuery(new Term("id", "id1"));
+    AbstractKnnVectorQuery q2 = getKnnVectorQuery("f2", new float[] {0, 1}, 10, filter1);
+
+    assertEquals("f1", q1.getField());
+    assertEquals("f2", q2.getField());
+  }
+
+  public void testGetK() {
+    AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 6);
+    Query filter1 = new TermQuery(new Term("id", "id1"));
+    AbstractKnnVectorQuery q2 = getKnnVectorQuery("f2", new float[] {0, 1}, 7, filter1);
+
+    assertEquals(6, q1.getK());
+    assertEquals(7, q2.getK());
+  }
+
+  public void testGetFilter() {
+    AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 6);
+    Query filter1 = new TermQuery(new Term("id", "id1"));
+    AbstractKnnVectorQuery q2 = getKnnVectorQuery("f2", new float[] {0, 1}, 7, filter1);
+
+    assertNull(q1.getFilter());
+    assertEquals(filter1, q2.getFilter());
+  }
+
   /**
    * Tests if a AbstractKnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no
    * documents to match.
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
index f3cbfb01e16..dbd127c50a9 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
@@ -73,6 +73,16 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
     assertEquals("KnnByteVectorQuery:f1[0,...][10]", q1.toString("ignored"));
   }
 
+  public void testGetTarget() {
+    byte[] queryVectorBytes = floatToBytes(new float[] {0, 1});
+    BytesRef targetQueryVector = new BytesRef(queryVectorBytes);
+    KnnByteVectorQuery q1 = new KnnByteVectorQuery("f1", targetQueryVector, 10);
+
+    assertEquals(targetQueryVector, q1.getTargetCopy());
+    assertFalse(targetQueryVector == q1.getTargetCopy());
+    assertFalse(targetQueryVector.bytes == q1.getTargetCopy().bytes);
+  }
+
   @Override
   VectorEncoding vectorEncoding() {
     return VectorEncoding.BYTE;
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
index c50b6b864d9..8cdf95e1ce6 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
@@ -66,6 +66,14 @@ public class TestKnnVectorQuery extends BaseKnnVectorQueryTestCase {
     assertEquals("KnnVectorQuery:f1[0.0,...][10]", q1.toString("ignored"));
   }
 
+  public void testGetTarget() {
+    float[] queryVector = new float[] {0, 1};
+    KnnVectorQuery q1 = new KnnVectorQuery("f1", queryVector, 10);
+
+    assertArrayEquals(queryVector, q1.getTargetCopy(), 0);
+    assertNotEquals(queryVector, q1.getTargetCopy());
+  }
+
   @Override
   VectorEncoding vectorEncoding() {
     return VectorEncoding.FLOAT32;