You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@solr.apache.org by ab...@apache.org on 2023/07/09 19:57:12 UTC
[solr] branch branch_9_3 updated: SOLR-16675: dense vector function queries (#1750)
This is an automated email from the ASF dual-hosted git repository.
abenedetti pushed a commit to branch branch_9_3
in repository https://gitbox.apache.org/repos/asf/solr.git
The following commit(s) were added to refs/heads/branch_9_3 by this push:
new 8b320c4144e SOLR-16675: dense vector function queries (#1750)
8b320c4144e is described below
commit 8b320c4144e077a2566dad264e2bff8cd40daf1c
Author: Elia Porciani <e....@sease.io>
AuthorDate: Sun Jul 9 21:49:22 2023 +0200
SOLR-16675: dense vector function queries (#1750)
---------
Co-authored-by: Alessandro Benedetti <a....@sease.io>
---
solr/CHANGES.txt | 2 +
.../org/apache/solr/schema/DenseVectorField.java | 13 +-
.../org/apache/solr/search/FunctionQParser.java | 77 ++++++++
.../src/java/org/apache/solr/search/StrParser.java | 17 ++
.../org/apache/solr/search/ValueSourceParser.java | 38 ++++
.../apache/solr/schema/DenseVectorFieldTest.java | 16 --
.../org/apache/solr/search/QueryEqualityTest.java | 10 +
.../function/TestDenseVectorFunctionQuery.java | 203 +++++++++++++++++++++
.../function/TestDenseVectorValueSourceParser.java | 85 +++++++++
.../query-guide/pages/function-queries.adoc | 15 ++
10 files changed, 458 insertions(+), 18 deletions(-)
diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt
index e273ff5a50d..f0dad57d6ba 100644
--- a/solr/CHANGES.txt
+++ b/solr/CHANGES.txt
@@ -46,6 +46,8 @@ New Features
* SOLR-16717: {!join} can join collections with multiple shards on both sides. (Mikhail Khludnev)
+* SOLR-16675: Added function queries for dense vector similarity. (Elia Porciani, Alessandro Benedetti)
+
Improvements
---------------------
diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java
index 4ab30fd72a8..5d2013cf204 100644
--- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java
+++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java
@@ -35,6 +35,8 @@ import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.ValueSource;
+import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource;
+import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
@@ -343,9 +345,16 @@ public class DenseVectorField extends FloatPointField {
+ /**
+ * Returns a function-query value source over this field's vectors: a {@link
+ * FloatKnnVectorFieldSource} when {@code vectorEncoding} is FLOAT32 and a {@link
+ * ByteKnnVectorFieldSource} when it is BYTE.
+ *
+ * @throws SolrException with {@code BAD_REQUEST} for any other vector encoding
+ */
@Override
public ValueSource getValueSource(SchemaField field, QParser parser) {
+
+ // pick the value source matching this field type's vectorEncoding
+ switch (vectorEncoding) {
+ case FLOAT32:
+ return new FloatKnnVectorFieldSource(field.getName());
+ case BYTE:
+ return new ByteKnnVectorFieldSource(field.getName());
+ }
+
+ // reached only for encodings not handled by the switch above
throw new SolrException(
- SolrException.ErrorCode.BAD_REQUEST,
- "Function queries are not supported for Dense Vector fields.");
+ SolrException.ErrorCode.BAD_REQUEST, "Vector encoding not supported for function queries.");
}
public Query getKnnVectorQuery(
diff --git a/solr/core/src/java/org/apache/solr/search/FunctionQParser.java b/solr/core/src/java/org/apache/solr/search/FunctionQParser.java
index bce14e20011..41d86f6eae6 100644
--- a/solr/core/src/java/org/apache/solr/search/FunctionQParser.java
+++ b/solr/core/src/java/org/apache/solr/search/FunctionQParser.java
@@ -18,8 +18,11 @@ package org.apache.solr.search;
import java.util.ArrayList;
import java.util.List;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.queries.function.FunctionQuery;
import org.apache.lucene.queries.function.ValueSource;
+import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource;
+import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
import org.apache.lucene.queries.function.valuesource.ConstValueSource;
import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource;
import org.apache.lucene.queries.function.valuesource.LiteralValueSource;
@@ -40,6 +43,9 @@ public class FunctionQParser extends QParser {
// When a field name is encountered, use the placeholder FieldNameValueSource instead of resolving
// to a real ValueSource
public static final int FLAG_USE_FIELDNAME_SOURCE = 0x04;
+
+ /**
+ * When this flag is set, constant vectors are parsed with {@link VectorEncoding#BYTE} encoding;
+ * otherwise {@link VectorEncoding#FLOAT32} encoding is used (see {@link #parseConstVector(int)}).
+ */
+ public static final int FLAG_PARSE_VECTOR_BYTE_ENCODING = 0x08;
public static final int FLAG_DEFAULT = FLAG_CONSUME_DELIMITER;
/**
@@ -243,6 +249,49 @@ public class FunctionQParser extends QParser {
return val;
}
+ /**
+ * Parses a constant vector literal of the form {@code [v1, v2, ..., vn]} starting at the current
+ * parser position, and advances the position past the closing bracket.
+ *
+ * <p>Elements must be separated by commas; whitespace around elements is ignored.
+ *
+ * @param encoding how each element is parsed: {@link VectorEncoding#BYTE} elements are parsed as
+ * bytes, {@link VectorEncoding#FLOAT32} elements as floats
+ * @return the parsed elements, in order
+ * @throws SyntaxError if the literal is malformed (missing brackets, missing or extra commas,
+ * empty vector, unexpected characters)
+ * @throws NumberFormatException if an element cannot be parsed with the requested encoding (e.g.
+ * a non-integer or out-of-range value for BYTE)
+ */
+ public List<Number> parseVector(VectorEncoding encoding) throws SyntaxError {
+ ArrayList<Number> values = new ArrayList<>();
+ // guard the bounds first so empty input reports a SyntaxError instead of an index exception
+ if (sp.pos >= sp.end || sp.val.charAt(sp.pos) != '[') {
+ throw new SyntaxError("Missing '[' at the beginning of vector");
+ }
+ sp.pos += 1;
+ // true while an element is required next (right after '[' or after a ',')
+ boolean valueExpected = true;
+ while (sp.pos < sp.end) {
+ char ch = sp.val.charAt(sp.pos);
+ if (Character.isWhitespace(ch)) {
+ sp.pos++;
+ } else if ((ch >= '0' && ch <= '9') || ch == '.' || ch == '+' || ch == '-') {
+ if (!valueExpected) {
+ // two elements not separated by a comma, e.g. "[1 2]"
+ throw new SyntaxError("Expected ',' or ']' at position " + sp.pos);
+ }
+ switch (encoding) {
+ case BYTE:
+ values.add(sp.getByte());
+ break;
+ case FLOAT32:
+ values.add(sp.getFloat());
+ break;
+ default:
+ throw new SyntaxError("Unexpected vector encoding: " + encoding);
+ }
+ valueExpected = false;
+ } else if (ch == ',') {
+ if (valueExpected) {
+ // leading comma or two consecutive commas, e.g. "[,1]" or "[1,,2]"
+ throw new SyntaxError("Unexpected ',' at position " + sp.pos);
+ }
+ sp.pos++;
+ valueExpected = true;
+ } else if (ch == ']' && !valueExpected) {
+ break;
+ } else {
+ // also covers ']' right after '[' or ',' (empty vector / trailing comma)
+ throw new SyntaxError("Unexpected " + ch + " at position " + sp.pos);
+ }
+ }
+ if (sp.pos >= sp.end) {
+ throw new SyntaxError("Missing ']' at the end of vector");
+ }
+ sp.pos++; // consume the closing ']'
+ return values;
+ }
+
/**
* Parse a list of ValueSource. Must be the final set of arguments to a ValueSource.
*
@@ -363,6 +412,8 @@ public class FunctionQParser extends QParser {
}
} else if (ch == '"' || ch == '\'') {
valueSource = new LiteralValueSource(sp.getQuotedString());
+ } else if (ch == '[') {
+ valueSource = parseConstVector(flags);
} else if (ch == '$') {
sp.pos++;
String param = sp.getId();
@@ -457,6 +508,32 @@ public class FunctionQParser extends QParser {
return valueSource;
}
+ /**
+ * Parses a constant vector literal (e.g. {@code [1, 2, 3]}) into a constant value source.
+ *
+ * @param flags parser flags; when {@link #FLAG_PARSE_VECTOR_BYTE_ENCODING} is set the elements
+ * are parsed as bytes and a {@link ConstKnnByteVectorValueSource} is returned, otherwise they
+ * are parsed as floats and a {@link ConstKnnFloatValueSource} is returned
+ * @return the constant vector value source
+ * @throws SyntaxError if the vector literal is malformed
+ */
+ public ValueSource parseConstVector(int flags) throws SyntaxError {
+ final boolean byteEncoded = (flags & FLAG_PARSE_VECTOR_BYTE_ENCODING) != 0;
+ final VectorEncoding encoding = byteEncoded ? VectorEncoding.BYTE : VectorEncoding.FLOAT32;
+ final List<Number> elements = parseVector(encoding);
+ final int dimension = elements.size();
+
+ switch (encoding) {
+ case FLOAT32:
+ {
+ float[] floats = new float[dimension];
+ int idx = 0;
+ for (Number element : elements) {
+ floats[idx++] = element.floatValue();
+ }
+ return new ConstKnnFloatValueSource(floats);
+ }
+ case BYTE:
+ {
+ byte[] bytes = new byte[dimension];
+ int idx = 0;
+ for (Number element : elements) {
+ bytes[idx++] = element.byteValue();
+ }
+ return new ConstKnnByteVectorValueSource(bytes);
+ }
+ }
+
+ throw new SyntaxError("wrong vector encoding:" + encoding);
+ }
+
/**
* @lucene.experimental
*/
diff --git a/solr/core/src/java/org/apache/solr/search/StrParser.java b/solr/core/src/java/org/apache/solr/search/StrParser.java
index 0b806448f27..07edf7a20b9 100644
--- a/solr/core/src/java/org/apache/solr/search/StrParser.java
+++ b/solr/core/src/java/org/apache/solr/search/StrParser.java
@@ -164,6 +164,23 @@ public class StrParser {
return Integer.parseInt(new String(arr, 0, i));
}
+ /**
+ * Skips leading whitespace, then consumes the longest run of characters in {@code [0-9+-]} and
+ * parses it as a byte.
+ *
+ * @return the parsed byte
+ * @throws NumberFormatException if the consumed text is empty, malformed, or out of byte range
+ */
+ public byte getByte() {
+ eatws();
+ final int start = pos;
+ while (pos < end) {
+ final char c = val.charAt(pos);
+ final boolean digitOrSign = (c >= '0' && c <= '9') || c == '+' || c == '-';
+ if (!digitOrSign) {
+ break;
+ }
+ pos++;
+ }
+ return Byte.parseByte(val.substring(start, pos));
+ }
+
public String getId() throws SyntaxError {
return getId("Expected identifier");
}
diff --git a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
index c58b33836d4..06466865393 100644
--- a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
+++ b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
@@ -26,12 +26,15 @@ import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.FunctionScoreQuery;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.docvalues.BoolDocValues;
import org.apache.lucene.queries.function.docvalues.DoubleDocValues;
import org.apache.lucene.queries.function.docvalues.LongDocValues;
+import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.ConstNumberSource;
import org.apache.lucene.queries.function.valuesource.ConstValueSource;
import org.apache.lucene.queries.function.valuesource.DefFunction;
@@ -39,6 +42,7 @@ import org.apache.lucene.queries.function.valuesource.DivFloatFunction;
import org.apache.lucene.queries.function.valuesource.DocFreqValueSource;
import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource;
import org.apache.lucene.queries.function.valuesource.DualFloatFunction;
+import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.IDFValueSource;
import org.apache.lucene.queries.function.valuesource.IfFunction;
import org.apache.lucene.queries.function.valuesource.JoinDocFreqValueSource;
@@ -340,6 +344,40 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin {
}
});
alias("sum", "add");
+ addParser(
+ "vectorSimilarity",
+ new ValueSourceParser() {
+ /**
+ * Parses {@code vectorSimilarity(ENCODING, SIMILARITY, vector1, vector2)}, where ENCODING
+ * is a {@link VectorEncoding} name, SIMILARITY a {@link VectorSimilarityFunction} name, and
+ * each vector is a value source (e.g. a field or a constant vector literal). Constant
+ * vector literals are parsed using ENCODING.
+ */
+ @Override
+ public ValueSource parse(FunctionQParser fp) throws SyntaxError {
+
+ VectorEncoding vectorEncoding =
+ parseEnumArg(fp, VectorEncoding.class, "vector encoding");
+ VectorSimilarityFunction functionName =
+ parseEnumArg(fp, VectorSimilarityFunction.class, "vector similarity function");
+
+ // FLAG_DEFAULT already equals FLAG_CONSUME_DELIMITER; the byte flag only changes how
+ // constant vector literals are parsed
+ int flags = FunctionQParser.FLAG_DEFAULT;
+ if (vectorEncoding == VectorEncoding.BYTE) {
+ flags |= FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING;
+ }
+ ValueSource v1 = fp.parseValueSource(flags);
+ ValueSource v2 = fp.parseValueSource(flags);
+
+ switch (vectorEncoding) {
+ case FLOAT32:
+ return new FloatVectorSimilarityFunction(functionName, v1, v2);
+ case BYTE:
+ return new ByteVectorSimilarityFunction(functionName, v1, v2);
+ default:
+ throw new SyntaxError("Invalid vector encoding: " + vectorEncoding);
+ }
+ }
+
+ /**
+ * Reads the next function argument and converts it to a constant of {@code type}. A
+ * missing or unknown name is reported as a {@link SyntaxError} listing the accepted
+ * values, rather than propagating the exception thrown by {@link Enum#valueOf}.
+ */
+ private <E extends Enum<E>> E parseEnumArg(
+ FunctionQParser fp, Class<E> type, String description) throws SyntaxError {
+ String name = fp.parseArg();
+ if (name == null) {
+ throw new SyntaxError(
+ "Missing " + description + ", expected one of " + List.of(type.getEnumConstants()));
+ }
+ try {
+ return Enum.valueOf(type, name);
+ } catch (IllegalArgumentException e) {
+ throw new SyntaxError(
+ "Invalid "
+ + description
+ + " '"
+ + name
+ + "', expected one of "
+ + List.of(type.getEnumConstants()),
+ e);
+ }
+ }
+ });
addParser(
"product",
diff --git a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java
index ed501a65f7c..1830a8f7ffc 100644
--- a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java
+++ b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java
@@ -603,22 +603,6 @@ public class DenseVectorFieldTest extends AbstractBadConfigTestBase {
}
}
- /** Not Supported */
- @Test
- public void query_functionQueryUsage_shouldThrowException() throws Exception {
- try {
- initCore("solrconfig-basic.xml", "schema-densevector.xml");
-
- assertQEx(
- "Running Function queries on a dense vector field should raise an Exception",
- "Function queries are not supported for Dense Vector fields.",
- req("q", "*:*", "fl", "id,field(vector)"),
- SolrException.ErrorCode.BAD_REQUEST);
- } finally {
- deleteCore();
- }
- }
-
@Test
public void denseVectorField_shouldBePresentAfterAtomicUpdate() throws Exception {
try {
diff --git a/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java b/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
index c159ca94628..adaf0fb8302 100644
--- a/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
+++ b/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
@@ -908,6 +908,16 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
assertFuncEquals("vector(foo_i,sum(4,bar_i))", "vector(foo_i, sum(4,bar_i))");
}
+ public void testFuncKnnVector() throws Exception {
+ assertFuncEquals(
+ "vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])",
+ "vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])");
+
+ assertFuncEquals(
+ "vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])",
+ "vectorSimilarity(BYTE, EUCLIDEAN, field(bar_i), [4, 5, 6])");
+ }
+
public void testFuncQuery() throws Exception {
SolrQueryRequest req = req("myQ", "asdf");
try {
diff --git a/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java b/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java
new file mode 100644
index 00000000000..c6573ff693c
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.search.function;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.common.SolrInputDocument;
+import org.apache.solr.common.params.CommonParams;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestDenseVectorFunctionQuery extends SolrTestCaseJ4 {
+ String IDField = "id";
+ String vectorField = "vector";
+ String vectorField2 = "vector2";
+ String byteVectorField = "vector_byte_encoding";
+
+ @Before
+ public void prepareIndex() throws Exception {
+ /* vectorDimension="4" similarityFunction="cosine" */
+ initCore("solrconfig-basic.xml", "schema-densevector.xml");
+
+ List<SolrInputDocument> docsToIndex = this.prepareDocs();
+ for (SolrInputDocument doc : docsToIndex) {
+ assertU(adoc(doc));
+ }
+
+ assertU(commit());
+ }
+
+ @After
+ public void cleanUp() {
+ clearIndex();
+ deleteCore();
+ }
+
+ private List<SolrInputDocument> prepareDocs() {
+ int docsCount = 6;
+ List<SolrInputDocument> docs = new ArrayList<>(docsCount);
+ for (int i = 1; i < docsCount + 1; i++) {
+ SolrInputDocument doc = new SolrInputDocument();
+ doc.addField(IDField, i);
+ docs.add(doc);
+ }
+
+ docs.get(0).addField(vectorField, Arrays.asList(1f, 2f, 3f, 4f));
+ docs.get(1).addField(vectorField, Arrays.asList(1.5f, 2.5f, 3.5f, 4.5f));
+ docs.get(2).addField(vectorField, Arrays.asList(7.5f, 15.5f, 17.5f, 22.5f));
+
+ docs.get(0).addField(vectorField2, Arrays.asList(5f, 4f, 1f, 2f));
+ docs.get(1).addField(vectorField2, Arrays.asList(2f, 2f, 1f, 4f));
+ docs.get(3).addField(vectorField, Arrays.asList(1.4f, 2.4f, 3.4f, 4.4f));
+
+ docs.get(0).addField(byteVectorField, Arrays.asList(1, 2, 3, 4));
+ docs.get(1).addField(byteVectorField, Arrays.asList(4, 2, 3, 1));
+
+ return docs;
+ }
+
+ @Test
+ public void floatConstantVectors_shouldReturnFloatSimilarity() {
+ assertQ(
+ req(
+ CommonParams.Q,
+ "{!func} vectorSimilarity(FLOAT32, COSINE, [1,2,3], [4,5,6])",
+ "fq",
+ "id:(1 2 3)",
+ "fl",
+ "id, score"),
+ "//result[@numFound='" + 3 + "']",
+ "//result/doc[1]/float[@name='score'][.='0.9873159']",
+ "//result/doc[2]/float[@name='score'][.='0.9873159']",
+ "//result/doc[3]/float[@name='score'][.='0.9873159']");
+ }
+
+ @Test
+ public void byteConstantVectors_shouldReturnFloatSimilarity() {
+ assertQ(
+ req(
+ CommonParams.Q,
+ "{!func} vectorSimilarity(BYTE, COSINE, [1,2,3], [4,5,6])",
+ "fq",
+ "id:(1 2 3)",
+ "fl",
+ "id, score"),
+ "//result[@numFound='" + 3 + "']",
+ "//result/doc[1]/float[@name='score'][.='0.9873159']",
+ "//result/doc[2]/float[@name='score'][.='0.9873159']",
+ "//result/doc[3]/float[@name='score'][.='0.9873159']");
+ }
+
+ @Test
+ public void floatFieldVectors_shouldReturnFloatSimilarity() {
+ assertQ(
+ req(
+ CommonParams.Q,
+ "{!func} vectorSimilarity(FLOAT32, DOT_PRODUCT, vector, vector2)",
+ "fq",
+ "id:(1 2)",
+ "fl",
+ "id, score"),
+ "//result[@numFound='" + 2 + "']",
+ "//result/doc[1]/float[@name='score'][.='15.25']",
+ "//result/doc[2]/float[@name='score'][.='12.5']");
+ }
+
+ @Test
+ public void byteFieldVectors_shouldReturnFloatSimilarity() {
+ assertQ(
+ req(
+ CommonParams.Q,
+ "{!func} vectorSimilarity(BYTE, EUCLIDEAN, vector_byte_encoding, vector_byte_encoding)",
+ "fq",
+ "id:(1 2)",
+ "fl",
+ "id, score"),
+ "//result[@numFound='" + 2 + "']",
+ "//result/doc[1]/float[@name='score'][.='1.0']",
+ "//result/doc[2]/float[@name='score'][.='1.0']");
+ }
+
+ @Test
+ public void resultOfVectorFunction_canBeUsedAsFloatFunctionInput() {
+
+ assertQ(
+ req(
+ CommonParams.Q,
+ "{!func} sub(1.5, vectorSimilarity(FLOAT32, EUCLIDEAN, [1,5,4,3], vector))",
+ "fq",
+ "id:(1 2)",
+ "fl",
+ "id, score"),
+ "//result[@numFound='" + 2 + "']",
+ "//result/doc[1]/float[@name='score'][.='1.4166666']",
+ "//result/doc[2]/float[@name='score'][.='1.4']");
+ }
+
+ @Test
+ public void byteFieldVectors_missingFieldValue_shouldReturnSimilarityZero() {
+ assertQ(
+ req(
+ CommonParams.Q,
+ "{!func} vectorSimilarity(BYTE, EUCLIDEAN, [1,5,4,3], vector_byte_encoding)",
+ "fq",
+ "id:3",
+ "fl",
+ "id, score"),
+ "//result[@numFound='" + 1 + "']",
+ "//result/doc[1]/float[@name='score'][.='0.0']");
+ }
+
+ @Test
+ public void floatFieldVectors_missingFieldValue_shouldReturnSimilarityZero() {
+
+ // document 3 does not contain value for vector2
+ assertQ(
+ req(
+ CommonParams.Q,
+ "{!func} vectorSimilarity(FLOAT32, DOT_PRODUCT, [1,5,4,3], vector2)",
+ "fq",
+ "id:(3)",
+ "fl",
+ "id, score"),
+ "//result[@numFound='" + 1 + "']",
+ "//result/doc[1]/float[@name='score'][.='0.0']");
+ }
+
+ @Test
+ public void vectorQueryInRerankQParser_ShouldRescoreOnlyFirstKResults() {
+ assertQ(
+ req(
+ CommonParams.Q,
+ "id:(1 2 3 4)",
+ "rq",
+ "{!rerank reRankQuery=$rqq reRankDocs=2 reRankWeight=1}",
+ "rqq",
+ "{!func} vectorSimilarity(FLOAT32, EUCLIDEAN, [1,5,4,3], vector)",
+ "fl",
+ "id, score"),
+ "//result[@numFound='" + 4 + "']",
+ "//result/doc[1]/float[@name='score'][.='0.8002023']",
+ "//result/doc[2]/float[@name='score'][.='0.7835356']",
+ "//result/doc[3]/float[@name='score'][.='0.7002023']",
+ "//result/doc[4]/float[@name='score'][.='0.7002023']");
+ }
+}
diff --git a/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorValueSourceParser.java b/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorValueSourceParser.java
new file mode 100644
index 00000000000..1f68efa531c
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorValueSourceParser.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.search.function;
+
+import java.io.IOException;
+import java.util.List;
+import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource;
+import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
+import org.apache.solr.search.FunctionQParser;
+import org.apache.solr.search.SyntaxError;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestDenseVectorValueSourceParser {
+
+ @Test
+ public void floatVectorParsing_shouldReturnConstKnnFloatValueSource()
+ throws SyntaxError, IOException {
+ FunctionQParser qp = new FunctionQParser("[1, 2,3,4]", null, null, null);
+ var valueSource = qp.parseConstVector(0);
+ Assert.assertEquals(ConstKnnFloatValueSource.class, valueSource.getClass());
+ var floatVectorValueSource = (ConstKnnFloatValueSource) valueSource;
+
+ float[] expected = {1.f, 2.f, 3.f, 4.f};
+ Assert.assertArrayEquals(
+ expected, floatVectorValueSource.getValues(null, null).floatVectorVal(0), 0.1f);
+ }
+
+ @Test
+ public void byteVectorParsing_shouldReturnConstKnnByterValueSource()
+ throws SyntaxError, IOException {
+ FunctionQParser qp = new FunctionQParser("[1, 2,3, 4]", null, null, null);
+ var valueSource = qp.parseConstVector(FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING);
+ Assert.assertEquals(ConstKnnByteVectorValueSource.class, valueSource.getClass());
+
+ var byteVectorValueSource = (ConstKnnByteVectorValueSource) valueSource;
+
+ byte[] expected = {1, 2, 3, 4};
+ Assert.assertArrayEquals(
+ expected, byteVectorValueSource.getValues(null, null).byteVectorVal(0));
+ }
+
+ @Test
+ public void byteVectorParsing_ValuesOutsideByteBoundaries_shouldRaiseAnException() {
+ FunctionQParser qp = new FunctionQParser("[1,2,3,170]", null, null, null);
+ Assert.assertThrows(
+ NumberFormatException.class,
+ () -> qp.parseConstVector(FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING));
+ }
+
+ @Test
+ public void byteVectorParsing_NonIntegerValues_shouldRaiseAnException() {
+ FunctionQParser qp = new FunctionQParser("[1,2,3.2,4]", null, null, null);
+ Assert.assertThrows(
+ NumberFormatException.class,
+ () -> qp.parseConstVector(FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING));
+ }
+
+ @Test
+ public void byteVectorParsing_WrongSyntaxForVector_shouldRaiseAnException() {
+
+ var testCases = List.of("<1,2,3.2,4>", "[1,,2,3,4]", "[1,2,3,4,5", "[1,2,3,4,,]", "1,2,3,4]");
+
+ for (String testCase : testCases) {
+ FunctionQParser qp = new FunctionQParser(testCase, null, null, null);
+ Assert.assertThrows(
+ SyntaxError.class,
+ () -> qp.parseConstVector(FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING));
+ }
+ }
+}
diff --git a/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc b/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc
index 8a81bddf8de..db839dc873e 100644
--- a/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc
+++ b/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc
@@ -152,6 +152,21 @@ There must be an even number of ValueSource instances passed in and the method a
* `dist(2, x,y,z,0,0,0):` Euclidean distance between (0,0,0) and (x,y,z) for each document.
* `dist(1,x,y,z,e,f,g)`: Manhattan distance between (x,y,z) and (e,f,g) where each letter is a field name.
+=== vectorSimilarity Function
+Returns the similarity between two KNN vectors in an n-dimensional space.
+Takes as input the vector element encoding, the similarity function, and two ValueSource instances, and calculates the similarity between the two vectors.
+
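+* The supported encodings are: `BYTE`, `FLOAT32`.
+* The supported similarity functions are: `EUCLIDEAN`, `COSINE`, `DOT_PRODUCT`.
+
+Each ValueSource must be a KNN vector: either a dense vector field or a constant vector written as a comma-separated list in square brackets (e.g. `[1,2,3]`).
+With the `BYTE` encoding, the elements of a constant vector must be integers between -128 and 127.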
+
+*Syntax Examples*
+
+* `vectorSimilarity(FLOAT32, COSINE, [1,2,3], [4,5,6])`: calculates the cosine similarity between [1, 2, 3] and [4, 5, 6] for each document.
+* `vectorSimilarity(FLOAT32, DOT_PRODUCT, vectorField1, vectorField2)`: calculates the dot product similarity between the vector in 'vectorField1' and in 'vectorField2' for each document.
+* `vectorSimilarity(BYTE, EUCLIDEAN, [1,5,4,3], vectorField)`: calculates the Euclidean similarity between the constant vector [1, 5, 4, 3] and the vector in 'vectorField' for each document.
+
=== docfreq(field,val) Function
Returns the number of documents that contain the term in the field.
This is a constant (the same value for all documents in the index).