You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@solr.apache.org by ho...@apache.org on 2024/03/18 20:27:41 UTC
(solr) branch main updated: SOLR-17164: Add 2 arg variant of vectorSimilarity() function
This is an automated email from the ASF dual-hosted git repository.
hossman pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/solr.git
The following commit(s) were added to refs/heads/main by this push:
new 768c5af847a SOLR-17164: Add 2 arg variant of vectorSimilarity() function
768c5af847a is described below
commit 768c5af847a77a4a48b302cb929d1e0e24486e9a
Author: Chris Hostetter <ho...@apache.org>
AuthorDate: Mon Mar 18 13:27:24 2024 -0700
SOLR-17164: Add 2 arg variant of vectorSimilarity() function
---
solr/CHANGES.txt | 2 +
.../org/apache/solr/search/ValueSourceParser.java | 40 +----
.../solr/search/VectorSimilaritySourceParser.java | 183 ++++++++++++++++++++
.../org/apache/solr/search/QueryEqualityTest.java | 66 +++++++-
.../search/VectorSimilaritySourceParserTest.java | 187 +++++++++++++++++++++
.../function/TestDenseVectorFunctionQuery.java | 165 ++++++++++++++++++
.../query-guide/pages/function-queries.adoc | 32 +++-
7 files changed, 622 insertions(+), 53 deletions(-)
diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt
index 70571a9c2ed..9dc9fb0aa47 100644
--- a/solr/CHANGES.txt
+++ b/solr/CHANGES.txt
@@ -120,6 +120,8 @@ Improvements
* SOLR-17172: Add QueryLimits termination to the existing heavy SearchComponent-s. This allows query limits (e.g. timeAllowed,
cpuAllowed) to terminate expensive operations within components if limits are exceeded. (Andrzej Bialecki)
+* SOLR-17164: Add 2 arg variant of vectorSimilarity() function (Sanjay Dutt, hossman)
+
Optimizations
---------------------
* SOLR-17144: Close searcherExecutor thread per core after 1 minute (Pierre Salagnac, Christine Poerschke)
diff --git a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
index f3054efa1ea..a6bd40f8f49 100644
--- a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
+++ b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
@@ -26,15 +26,12 @@ import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
-import org.apache.lucene.index.VectorEncoding;
-import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.FunctionScoreQuery;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.docvalues.BoolDocValues;
import org.apache.lucene.queries.function.docvalues.DoubleDocValues;
import org.apache.lucene.queries.function.docvalues.LongDocValues;
-import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.ConstNumberSource;
import org.apache.lucene.queries.function.valuesource.ConstValueSource;
import org.apache.lucene.queries.function.valuesource.DefFunction;
@@ -42,7 +39,6 @@ import org.apache.lucene.queries.function.valuesource.DivFloatFunction;
import org.apache.lucene.queries.function.valuesource.DocFreqValueSource;
import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource;
import org.apache.lucene.queries.function.valuesource.DualFloatFunction;
-import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.IDFValueSource;
import org.apache.lucene.queries.function.valuesource.IfFunction;
import org.apache.lucene.queries.function.valuesource.JoinDocFreqValueSource;
@@ -344,41 +340,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin {
}
});
alias("sum", "add");
- addParser(
- "vectorSimilarity",
- new ValueSourceParser() {
- @Override
- public ValueSource parse(FunctionQParser fp) throws SyntaxError {
-
- VectorEncoding vectorEncoding = VectorEncoding.valueOf(fp.parseArg());
- VectorSimilarityFunction functionName = VectorSimilarityFunction.valueOf(fp.parseArg());
-
- int vectorEncodingFlag =
- vectorEncoding.equals(VectorEncoding.BYTE)
- ? FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING
- : 0;
- ValueSource v1 =
- fp.parseValueSource(
- FunctionQParser.FLAG_DEFAULT
- | FunctionQParser.FLAG_CONSUME_DELIMITER
- | vectorEncodingFlag);
- ValueSource v2 =
- fp.parseValueSource(
- FunctionQParser.FLAG_DEFAULT
- | FunctionQParser.FLAG_CONSUME_DELIMITER
- | vectorEncodingFlag);
-
- switch (vectorEncoding) {
- case FLOAT32:
- return new FloatVectorSimilarityFunction(functionName, v1, v2);
- case BYTE:
- return new ByteVectorSimilarityFunction(functionName, v1, v2);
- default:
- throw new SyntaxError("Invalid vector encoding: " + vectorEncoding);
- }
- }
- });
-
+ addParser("vectorSimilarity", new VectorSimilaritySourceParser());
addParser(
"product",
new ValueSourceParser() {
diff --git a/solr/core/src/java/org/apache/solr/search/VectorSimilaritySourceParser.java b/solr/core/src/java/org/apache/solr/search/VectorSimilaritySourceParser.java
new file mode 100644
index 00000000000..aed934cdc8e
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/search/VectorSimilaritySourceParser.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.search;
+
+import static org.apache.solr.common.SolrException.ErrorCode;
+import static org.apache.solr.common.SolrException.ErrorCode.BAD_REQUEST;
+
+import java.util.Arrays;
+import java.util.Locale;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.queries.function.ValueSource;
+import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
+import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.schema.DenseVectorField;
+import org.apache.solr.schema.FieldType;
+import org.apache.solr.schema.SchemaField;
+
+/**
+ * Parses the {@code vectorSimilarity} function query, which calculates the similarity between two
+ * vectors. Two variants are supported:
+ *
+ * <ul>
+ *   <li>{@code vectorSimilarity(ENCODING, SIMILARITY_FUNCTION, vector1, vector2)}: the vector
+ *       encoding and similarity function are given explicitly, and each vector may be a field or
+ *       a constant vector.
+ *   <li>{@code vectorSimilarity(vectorField1, vector2)}: the vector encoding and similarity
+ *       function are taken from the {@link DenseVectorField} named by the first argument, and the
+ *       second argument is either another compatible {@link DenseVectorField} or a constant
+ *       vector.
+ * </ul>
+ */
+public class VectorSimilaritySourceParser extends ValueSourceParser {
+  @Override
+  public ValueSource parse(FunctionQParser fp) throws SyntaxError {
+
+    final String arg1Str = fp.parseArg();
+    if (arg1Str == null || !fp.hasMoreArguments()) {
+      throw new SolrException(
+          BAD_REQUEST, "Invalid number of arguments. Please provide either two or four arguments.");
+    }
+
+    // A constant vector (or a param reference to one) as the 2nd arg can only mean the 2-arg
+    // variant; it is left in the parser for handle2ArgsVariant to parse as a ValueSource
+    final String arg2Str = peekIsConstVector(fp) ? null : fp.parseArg();
+    if (fp.hasMoreArguments() && arg2Str != null) {
+      return handle4ArgsVariant(fp, arg1Str, arg2Str);
+    }
+    return handle2ArgsVariant(fp, arg1Str, arg2Str);
+  }
+
+  /**
+   * Returns true if and only if the next argument is a constant vector, taking into consideration
+   * that the next (literal) argument may be a param reference. The argument itself is not consumed.
+   */
+  private boolean peekIsConstVector(final FunctionQParser fp) throws SyntaxError {
+    final char rawPeek = fp.sp.peek();
+    if ('[' == rawPeek) {
+      return true;
+    }
+    if ('$' == rawPeek) {
+      // parse the param reference to inspect its value, then rewind so it can be parsed again
+      final int savedPos = fp.sp.pos;
+      try {
+        final String rawParam = fp.parseArg();
+        return ((null != rawParam) && ('[' == (new StrParser(rawParam)).peek()));
+      } finally {
+        fp.sp.pos = savedPos;
+      }
+    }
+    return false;
+  }
+
+  /**
+   * Builds the {@link FunctionQParser} flags used to parse a vector argument, including the byte
+   * encoding flag when <code>vectorEncoding</code> is {@link VectorEncoding#BYTE}.
+   */
+  private static int buildVectorEncodingFlag(final VectorEncoding vectorEncoding) {
+    return FunctionQParser.FLAG_DEFAULT
+        | FunctionQParser.FLAG_CONSUME_DELIMITER
+        | (vectorEncoding.equals(VectorEncoding.BYTE)
+            ? FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING
+            : 0);
+  }
+
+  /**
+   * Handles {@code vectorSimilarity(ENCODING, SIMILARITY_FUNCTION, vector1, vector2)}. The first
+   * two args have already been consumed; expects to find args #3 and #4 (two vector ValueSources)
+   * still in the function parser.
+   */
+  private ValueSource handle4ArgsVariant(FunctionQParser fp, String vecEncStr, String vecSimFuncStr)
+      throws SyntaxError {
+    final var vectorEncoding = enumValueOrBadRequest(VectorEncoding.class, vecEncStr);
+    final var vectorSimilarityFunction =
+        enumValueOrBadRequest(VectorSimilarityFunction.class, vecSimFuncStr);
+    final int vectorEncodingFlag = buildVectorEncodingFlag(vectorEncoding);
+    final ValueSource v1 = fp.parseValueSource(vectorEncodingFlag);
+    final ValueSource v2 = fp.parseValueSource(vectorEncodingFlag);
+    return createSimilarityFunction(vectorSimilarityFunction, vectorEncoding, v1, v2);
+  }
+
+  /**
+   * Handles {@code vectorSimilarity(vectorField1, vector2)} using the vectorEncoding and
+   * similarityFunction configured on <code>field1Name</code>.
+   *
+   * <p>If <code>field2Name</code> is null, then expects to find a constant vector as the only
+   * remaining arg in the function parser. Otherwise <code>field2Name</code> must also be a {@link
+   * DenseVectorField} with the same vectorEncoding, similarityFunction, and vectorDimension.
+   */
+  private ValueSource handle2ArgsVariant(FunctionQParser fp, String field1Name, String field2Name)
+      throws SyntaxError {
+
+    final SchemaField field1 = fp.req.getSchema().getField(field1Name);
+    final DenseVectorField field1Type = requireVectorType(field1);
+
+    final var vectorEncoding = field1Type.getVectorEncoding();
+    final var vectorSimilarityFunction = field1Type.getSimilarityFunction();
+
+    final ValueSource v1 = field1Type.getValueSource(field1, fp);
+    final ValueSource v2;
+
+    if (null == field2Name) {
+      // NOTE: the constant vector's dimension is not checked against field1 here
+      final int vectorEncodingFlag = buildVectorEncodingFlag(vectorEncoding);
+      v2 = fp.parseValueSource(vectorEncodingFlag);
+
+    } else {
+      final SchemaField field2 = fp.req.getSchema().getField(field2Name);
+      final DenseVectorField field2Type = requireVectorType(field2);
+      if (vectorEncoding != field2Type.getVectorEncoding()
+          || vectorSimilarityFunction != field2Type.getSimilarityFunction()) {
+        throw new SolrException(
+            BAD_REQUEST,
+            String.format(
+                Locale.ROOT,
+                "Invalid arguments: vector field %s and vector field %s must have the same vectorEncoding and similarityFunction",
+                field1.getName(),
+                field2.getName()));
+      }
+      // The fields must also share a dimension; reject a mismatch at parse time as a
+      // BAD_REQUEST instead of leaving it to the per-document vector comparison
+      if (field1Type.getDimension() != field2Type.getDimension()) {
+        throw new SolrException(
+            BAD_REQUEST,
+            String.format(
+                Locale.ROOT,
+                "Invalid arguments: vector field %s (vectorDimension=%d) and vector field %s (vectorDimension=%d) must have the same vectorDimension",
+                field1.getName(),
+                field1Type.getDimension(),
+                field2.getName(),
+                field2Type.getDimension()));
+      }
+      v2 = field2Type.getValueSource(field2, fp);
+    }
+    return createSimilarityFunction(vectorSimilarityFunction, vectorEncoding, v1, v2);
+  }
+
+  /**
+   * Wraps <code>v1</code> and <code>v2</code> in the Lucene similarity ValueSource that matches
+   * <code>vectorEncoding</code>.
+   */
+  private ValueSource createSimilarityFunction(
+      VectorSimilarityFunction functionName,
+      VectorEncoding vectorEncoding,
+      ValueSource v1,
+      ValueSource v2)
+      throws SyntaxError {
+    switch (vectorEncoding) {
+      case FLOAT32:
+        return new FloatVectorSimilarityFunction(functionName, v1, v2);
+      case BYTE:
+        return new ByteVectorSimilarityFunction(functionName, v1, v2);
+      default:
+        throw new SyntaxError("Invalid vector encoding: " + vectorEncoding);
+    }
+  }
+
+  /**
+   * Returns the type of <code>field</code> as a {@link DenseVectorField}, or throws a {@link
+   * SolrException} ({@link ErrorCode#BAD_REQUEST}) if the field has any other type.
+   */
+  private DenseVectorField requireVectorType(final SchemaField field) {
+    final FieldType fieldType = field.getType();
+    if (fieldType instanceof DenseVectorField) {
+      return (DenseVectorField) fieldType;
+    }
+    throw new SolrException(
+        BAD_REQUEST,
+        String.format(
+            Locale.ROOT,
+            "Type mismatch: Expected [%s], but found a different field type for field: [%s]",
+            DenseVectorField.class.getSimpleName(),
+            field.getName()));
+  }
+
+  /**
+   * Helper method that returns the correct Enum instance for the <code>arg</code> String, or throws
+   * a {@link ErrorCode#BAD_REQUEST} with specifics on the "Invalid argument"
+   */
+  private static <T extends Enum<T>> T enumValueOrBadRequest(
+      final Class<T> enumClass, final String arg) throws SolrException {
+    assert null != enumClass;
+    try {
+      return Enum.valueOf(enumClass, arg);
+    } catch (IllegalArgumentException | NullPointerException e) {
+      // NPE means arg was null; either way keep the original exception as the cause
+      throw new SolrException(
+          BAD_REQUEST,
+          String.format(
+              Locale.ROOT,
+              "Invalid argument: %s is not a valid %s. Expected one of %s",
+              arg,
+              enumClass.getSimpleName(),
+              Arrays.toString(enumClass.getEnumConstants())),
+          e);
+    }
+  }
+}
diff --git a/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java b/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
index 8eb1c3da71f..653c0935879 100644
--- a/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
+++ b/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
@@ -911,13 +911,67 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
}
public void testFuncKnnVector() throws Exception {
+    // vectorSimilarity(...) queries must compare equal regardless of whitespace, $param
+    // references, or (for matching field config) which of the two syntax variants is used
-    assertFuncEquals(
-        "vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])",
-        "vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])");
+    // 4-arg variant: first vector given inline or via $param refs padded with whitespace
+    try (SolrQueryRequest req =
+        req(
+            "v1", "[1,2,3]",
+            "v2", " [1,2,3] ",
+            "v3", " [1, 2, 3] ")) {
+      assertFuncEquals(
+          req,
+          "vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])",
+          "vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])",
+          "vectorSimilarity(FLOAT32, COSINE,$v1, [4, 5, 6])",
+          "vectorSimilarity(FLOAT32, COSINE, $v2 , [4, 5, 6])",
+          "vectorSimilarity(FLOAT32, COSINE, $v3 , [4, 5, 6])");
+    }
-    assertFuncEquals(
-        "vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])",
-        "vectorSimilarity(BYTE, EUCLIDEAN, field(bar_i), [4, 5, 6])");
+    // 4-arg variant: field given by bare name, as field(...), or via $param refs
+    try (SolrQueryRequest req =
+        req(
+            "f1", "bar_i",
+            "f2", " bar_i ",
+            "f3", " field(bar_i) ")) {
+      assertFuncEquals(
+          req,
+          "vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])",
+          "vectorSimilarity(BYTE, EUCLIDEAN, field(bar_i), [4, 5, 6])",
+          "vectorSimilarity(BYTE, EUCLIDEAN,$f1, [4, 5, 6])",
+          "vectorSimilarity(BYTE, EUCLIDEAN, $f1, [4, 5, 6])",
+          "vectorSimilarity(BYTE, EUCLIDEAN, $f2, [4, 5, 6])",
+          "vectorSimilarity(BYTE, EUCLIDEAN, $f3, [4, 5, 6])");
+    }
+
+    // 4-arg vs 2-arg variants, with the field and/or constant vector given via $param refs;
+    // equality presumably relies on the test schema's "vector" field being FLOAT32 + COSINE,
+    // since the 2-arg form takes both settings from the field
+    try (SolrQueryRequest req =
+        req(
+            "f", "vector",
+            "v1", "[1,2,3,4]",
+            "v2", " [1, 2, 3, 4]")) {
+      assertFuncEquals(
+          req,
+          "vectorSimilarity(FLOAT32,COSINE,vector,[1,2,3,4])",
+          "vectorSimilarity(FLOAT32,COSINE,vector,$v1)",
+          "vectorSimilarity(FLOAT32,COSINE,vector, $v1)",
+          "vectorSimilarity(FLOAT32,COSINE,vector,$v2)",
+          "vectorSimilarity(FLOAT32,COSINE,vector, $v2)",
+          "vectorSimilarity(vector,[1,2,3,4])",
+          "vectorSimilarity( vector,[1,2,3,4])",
+          "vectorSimilarity( $f,[1,2,3,4])",
+          "vectorSimilarity(vector,$v1)",
+          "vectorSimilarity(vector, $v1)",
+          "vectorSimilarity( $f, $v1)",
+          "vectorSimilarity(vector,$v2)",
+          "vectorSimilarity(vector, $v2)");
+    }
+
+    // contrived, but helps us test the param resolution
+    // for both field names in the 2arg usecase
+    try (SolrQueryRequest req = req("f", "vector")) {
+      assertFuncEquals(
+          req,
+          "vectorSimilarity($f, $f)",
+          "vectorSimilarity($f, vector)",
+          "vectorSimilarity(vector, $f)",
+          "vectorSimilarity(vector, vector)");
+    }
}
public void testFuncQuery() throws Exception {
diff --git a/solr/core/src/test/org/apache/solr/search/VectorSimilaritySourceParserTest.java b/solr/core/src/test/org/apache/solr/search/VectorSimilaritySourceParserTest.java
new file mode 100644
index 00000000000..943bc350056
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/search/VectorSimilaritySourceParserTest.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.search;
+
+import static org.apache.solr.SolrTestCaseJ4.assumeWorkingMockito;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.queries.function.ValueSource;
+import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
+import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
+import org.apache.solr.SolrTestCase;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.common.params.SolrParams;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.schema.BinaryField;
+import org.apache.solr.schema.DenseVectorField;
+import org.apache.solr.schema.IndexSchema;
+import org.apache.solr.schema.IntPointField;
+import org.apache.solr.schema.SchemaField;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Unit tests for {@link VectorSimilaritySourceParser}, run directly against a {@link
+ * FunctionQParser} whose request and {@link IndexSchema} are Mockito mocks, so each test only
+ * stubs the schema fields it needs.
+ */
+public class VectorSimilaritySourceParserTest extends SolrTestCase {
+  private static final VectorSimilaritySourceParser vecSimilarity =
+      new VectorSimilaritySourceParser();
+
+  /** Error message expected whenever the argument count matches neither variant. */
+  private static final String INVALID_ARG_COUNT_MSG =
+      "Invalid number of arguments. Please provide either two or four arguments.";
+
+  private SolrQueryRequest request;
+  private SolrParams localParams;
+  private SolrParams params;
+  private IndexSchema indexSchema;
+
+  @BeforeClass
+  public static void beforeClass() {
+    assumeWorkingMockito();
+  }
+
+  /** Creates fresh mocks before every test so schema stubs do not leak between tests. */
+  @Before
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    resetMocks();
+  }
+
+  @Test
+  public void testReportErrorPassingZeroArg() throws SyntaxError {
+    SolrException error =
+        assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity()"));
+    assertEquals(INVALID_ARG_COUNT_MSG, error.getMessage());
+  }
+
+  @Test
+  public void testReportErrorPassingOneArg() throws SyntaxError {
+    SolrException error =
+        assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1)"));
+    assertEquals(INVALID_ARG_COUNT_MSG, error.getMessage());
+  }
+
+  @Test
+  public void testReportErrorIfSecArgsEmpty() throws Exception {
+    SchemaField field1 = new SchemaField("field1", new DenseVectorField(5));
+    when(indexSchema.getField("field1")).thenReturn(field1);
+
+    // a trailing comma with nothing after it must still be rejected as an invalid arg count
+    SolrException error =
+        assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1,)"));
+    assertEquals(INVALID_ARG_COUNT_MSG, error.getMessage());
+  }
+
+  @Test
+  public void testReportErrorIfFirstArgNotVector() throws SyntaxError {
+    SchemaField field1 = new SchemaField("field1", new IntPointField());
+    when(indexSchema.getField("field1")).thenReturn(field1);
+
+    SolrException error =
+        assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1, field2)"));
+    assertEquals(
+        "Type mismatch: Expected [DenseVectorField], but found a different field type for field: [field1]",
+        error.getMessage());
+  }
+
+  @Test
+  public void testReportErrorIfSecArgNotVector() throws SyntaxError {
+    DenseVectorField fieldType = new DenseVectorField(5);
+    SchemaField field1 = new SchemaField("field1", fieldType);
+    SchemaField field2 = new SchemaField("field2", new BinaryField());
+    when(indexSchema.getField("field1")).thenReturn(field1);
+    when(indexSchema.getField("field2")).thenReturn(field2);
+
+    SolrException error =
+        assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1, field2)"));
+    assertEquals(
+        "Type mismatch: Expected [DenseVectorField], but found a different field type for field: [field2]",
+        error.getMessage());
+  }
+
+  @Test
+  public void testReportErrorIfFieldMismatch() throws SyntaxError {
+    // field1 and field2 differ only in vectorEncoding; field2 and field3 differ only in
+    // similarityFunction
+    DenseVectorField vectorField1 =
+        new DenseVectorField(5, VectorSimilarityFunction.COSINE, VectorEncoding.BYTE);
+    SchemaField field1 = new SchemaField("field1", vectorField1);
+    DenseVectorField vectorField2 =
+        new DenseVectorField(5, VectorSimilarityFunction.COSINE, VectorEncoding.FLOAT32);
+    SchemaField field2 = new SchemaField("field2", vectorField2);
+    DenseVectorField vectorField3 =
+        new DenseVectorField(5, VectorSimilarityFunction.DOT_PRODUCT, VectorEncoding.FLOAT32);
+    SchemaField field3 = new SchemaField("field3", vectorField3);
+
+    when(indexSchema.getField("field1")).thenReturn(field1);
+    when(indexSchema.getField("field2")).thenReturn(field2);
+    when(indexSchema.getField("field3")).thenReturn(field3);
+
+    SolrException error =
+        assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1, field2)"));
+    assertEquals(
+        "Invalid arguments: vector field field1 and vector field field2 must have the same vectorEncoding and similarityFunction",
+        error.getMessage());
+
+    error =
+        assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field2, field3)"));
+    assertEquals(
+        "Invalid arguments: vector field field2 and vector field field3 must have the same vectorEncoding and similarityFunction",
+        error.getMessage());
+  }
+
+  @Test
+  public void test2ArgsByteVectorField() throws SyntaxError {
+    DenseVectorField vectorField =
+        new DenseVectorField(5, VectorSimilarityFunction.COSINE, VectorEncoding.BYTE);
+    SchemaField field1 = new SchemaField("field1", vectorField);
+    SchemaField field2 = new SchemaField("field2", vectorField);
+    when(indexSchema.getField("field1")).thenReturn(field1);
+    when(indexSchema.getField("field2")).thenReturn(field2);
+
+    // the BYTE encoding configured on field1 selects the byte similarity implementation
+    ValueSource valueSource = parseWithMocks("vectorSimilarity(field1, field2)");
+    assertTrue(valueSource instanceof ByteVectorSimilarityFunction);
+  }
+
+  @Test
+  public void test2ArgsFloatVectorAndConst() throws Exception {
+    DenseVectorField vectorField =
+        new DenseVectorField(5, VectorSimilarityFunction.COSINE, VectorEncoding.FLOAT32);
+    SchemaField field1 = new SchemaField("field1", vectorField);
+    when(indexSchema.getField("field1")).thenReturn(field1);
+
+    ValueSource valueSource = parseWithMocks("vectorSimilarity(field1, [1, 2, 3, 4, 5])");
+    assertTrue(valueSource instanceof FloatVectorSimilarityFunction);
+  }
+
+  /** Replaces every mock and wires the mocked request to return the mocked schema. */
+  private void resetMocks() {
+    request = mock(SolrQueryRequest.class);
+    localParams = mock(SolrParams.class);
+    params = mock(SolrParams.class);
+    indexSchema = mock(IndexSchema.class);
+    when(request.getSchema()).thenReturn(indexSchema);
+  }
+
+  /**
+   * Strips the leading <code>vectorSimilarity(</code> from <code>input</code> and parses the
+   * remainder with the shared parser, using the mocked request and params.
+   */
+  protected ValueSource parseWithMocks(final String input) throws SyntaxError {
+    final String funcPrefix = "vectorSimilarity(";
+    assert input.startsWith(funcPrefix);
+    final FunctionQParser fqp =
+        new FunctionQParser(input.substring(funcPrefix.length()), localParams, params, request);
+    return vecSimilarity.parse(fqp);
+  }
+}
diff --git a/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java b/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java
index c6573ff693c..503edd739db 100644
--- a/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java
+++ b/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java
@@ -20,6 +20,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.common.SolrException;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.params.CommonParams;
import org.junit.After;
@@ -200,4 +201,160 @@ public class TestDenseVectorFunctionQuery extends SolrTestCaseJ4 {
"//result/doc[3]/float[@name='score'][.='0.7002023']",
"//result/doc[4]/float[@name='score'][.='0.7002023']");
}
+
+  /** Zero args, one arg, and one arg with a trailing comma are all rejected as BAD_REQUEST. */
+  @Test
+  public void testReportsErrorInvalidNumberOfArgs() {
+    for (String func :
+        new String[] {
+          "vectorSimilarity()", "vectorSimilarity(vector)", "vectorSimilarity(vector,)"
+        }) {
+      assertQEx(
+          "vectorSimilarity test number of arguments failed!",
+          "Invalid number of arguments. Please provide either two or four arguments.",
+          req(CommonParams.Q, "{!func} " + func, "fq", "id:(1 2 3)", "fl", "id, score"),
+          SolrException.ErrorCode.BAD_REQUEST);
+    }
+  }
+
+  /** Invalid field names, enum names, and extra/missing args are rejected as BAD_REQUEST. */
+  @Test
+  public void testReportsErrorInvalidArgs() {
+    assertQEx(
+        "vectorSimilarity 2arg: first arg undefined field",
+        "undefined field: \"bogus\"",
+        req(CommonParams.Q, "{!func} vectorSimilarity(bogus, vector_byte_encoding)"),
+        SolrException.ErrorCode.BAD_REQUEST);
+    assertQEx(
+        "vectorSimilarity 2arg: second arg undefined field",
+        "undefined field: \"bogus\"",
+        req(CommonParams.Q, "{!func} vectorSimilarity(vector_byte_encoding, bogus)"),
+        SolrException.ErrorCode.BAD_REQUEST);
+    assertQEx(
+        "vectorSimilarity 3+ args: 1st arg not valid encoding",
+        "Invalid argument: BOGUS is not a valid VectorEncoding. Expected one of [",
+        req(
+            CommonParams.Q,
+            "{!func} vectorSimilarity(BOGUS, DOT_PRODUCT, vector_byte_encoding, vector_byte_encoding)"),
+        SolrException.ErrorCode.BAD_REQUEST);
+    assertQEx(
+        "vectorSimilarity 3+ args: 2nd arg not valid similarity function",
+        "Invalid argument: BOGUS is not a valid VectorSimilarityFunction. Expected one of [",
+        req(
+            CommonParams.Q,
+            "{!func} vectorSimilarity(BYTE, BOGUS, vector_byte_encoding, vector_byte_encoding)"),
+        SolrException.ErrorCode.BAD_REQUEST);
+    assertQEx(
+        "vectorSimilarity 3 args: first two are valid for 2 arg syntax",
+        "SyntaxError: Expected ')'",
+        req(CommonParams.Q, "{!func} vectorSimilarity(vector_byte_encoding,[1,2,3,3],BOGUS)"),
+        SolrException.ErrorCode.BAD_REQUEST);
+    assertQEx(
+        "vectorSimilarity 3 args: first two are valid for 4 arg syntax, w/valid 3rd arg field",
+        "SyntaxError: Expected identifier",
+        req(CommonParams.Q, "{!func} vectorSimilarity(BYTE, DOT_PRODUCT, vector_byte_encoding)"),
+        SolrException.ErrorCode.BAD_REQUEST);
+    assertQEx(
+        "vectorSimilarity 3 args: first two are valid for 4 arg syntax, w/valid 3rd arg const vector",
+        "SyntaxError: Expected identifier",
+        req(CommonParams.Q, "{!func} vectorSimilarity(BYTE, DOT_PRODUCT, [1,2,3,3])"),
+        SolrException.ErrorCode.BAD_REQUEST);
+    assertQEx(
+        "vectorSimilarity 5 args: valid 4 arg syntax with extra cruft",
+        "SyntaxError: Expected ')'",
+        req(
+            CommonParams.Q,
+            "{!func} vectorSimilarity(BYTE, DOT_PRODUCT, vector_byte_encoding, vector_byte_encoding, BOGUS)"),
+        SolrException.ErrorCode.BAD_REQUEST);
+  }
+
+  /** 2-arg syntax with a byte vector field and a constant vector. */
+  @Test
+  public void test2ArgsByteFieldAndConstVector() throws Exception {
+    assertQ(
+        req(
+            CommonParams.Q,
+            "{!func} vectorSimilarity(vector_byte_encoding, [1,2,3,3])",
+            "fq",
+            "id:(1 2)",
+            "fl",
+            "id, score",
+            "rows",
+            "1"),
+        "//result[@numFound='" + 2 + "']",
+        "//result/doc[1]/str[@name='id'][.=1]");
+    assertQ(
+        req(
+            CommonParams.Q,
+            "{!func} vectorSimilarity(vector_byte_encoding, [3,3,2,1])",
+            "fq",
+            "id:(1 2)",
+            "fl",
+            "id, score",
+            "rows",
+            "1"),
+        "//result[@numFound='" + 2 + "']",
+        "//result/doc[1]/str[@name='id'][.=2]");
+  }
+
+  /** 2-arg syntax with a float vector field and a constant vector. */
+  @Test
+  public void test2ArgsFloatFieldAndConstVector() throws Exception {
+    assertQ(
+        req(
+            CommonParams.Q,
+            "{!func} vectorSimilarity(vector, [1,2,3,3])",
+            "fq",
+            "id:(1 2 3)",
+            "fl",
+            "id, score"),
+        "//result[@numFound='" + 3 + "']",
+        "//result/doc[1]/str[@name='id'][.=2]",
+        "//result/doc[2]/str[@name='id'][.=3]",
+        "//result/doc[3]/str[@name='id'][.=1]");
+  }
+
+  /** 2-arg syntax with two float vector fields. */
+  @Test
+  public void test2ArgsFloatVectorField() throws Exception {
+    assertQ(
+        req(
+            CommonParams.Q,
+            "{!func} vectorSimilarity(vector, vector2)",
+            "fq",
+            "id:(1 2 3 4)",
+            "fl",
+            "id, score"),
+        "//result[@numFound='" + 4 + "']",
+        "//result/doc[1]/str[@name='id'][.=2]",
+        "//result/doc[2]/str[@name='id'][.=1]");
+  }
+
+  /**
+   * vector vs. vector2 scores 0.0 for docs 3 and 4; per the test name, each presumably lacks a
+   * value in one of the two fields (document setup not shown here).
+   */
+  @Test
+  public void test2ArgsIfEitherFieldMissingValueDocScoreZero() {
+    assertQ(
+        req(
+            CommonParams.Q,
+            "{!func} vectorSimilarity(vector, vector2)",
+            "fq",
+            "id:(3)",
+            "fl",
+            "id, score"),
+        "//result[@numFound='" + 1 + "']",
+        "//result/doc[1]/float[@name='score'][.=0.0]");
+    assertQ(
+        req(
+            CommonParams.Q,
+            "{!func} vectorSimilarity(vector, vector2)",
+            "fq",
+            "id:(4)",
+            "fl",
+            "id, score"),
+        "//result[@numFound='" + 1 + "']",
+        "//result/doc[1]/float[@name='score'][.=0.0]");
+  }
}
diff --git a/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc b/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc
index db839dc873e..48f9345f1cd 100644
--- a/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc
+++ b/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc
@@ -153,19 +153,35 @@ There must be an even number of ValueSource instances passed in and the method a
* `dist(1,x,y,z,e,f,g)`: Manhattan distance between (x,y,z) and (e,f,g) where each letter is a field name.
=== vectorSimilarity Function
-Returns the similarity between two Knn vectors in an n-dimensional space.
-Takes in input the vector element encoding, the similarity measure plus two ValueSource instances and calculates the similarity between the two vectors.
+Returns the similarity between two Knn vectors in an n-dimensional space. There are two variants of this function.
-* The encodings supported are: `BYTE`, `FLOAT32`.
-* The similarities supported are: `EUCLIDEAN`, `COSINE`, `DOT_PRODUCT`
+==== vectorSimilarity(vector1, vector2)
+
+This function accepts two vectors as input: The first argument must be the name of a `DenseVectorField`. The second argument can be either the name of a second `DenseVectorField` or a constant vector.
+
+If two field names are specified, they must be configured with the same `vectorDimension`, `vectorEncoding`, and `similarityFunction`. If a constant vector is specified, then it will be parsed using the `vectorEncoding` configured on the field specified by the first argument, and it must have the same number of dimensions as that field.
+
+*Syntax Examples*
-Each ValueSource must be a knn vector (field or constant).
+* `vectorSimilarity(vectorField1, vectorField2)`: calculates the configured similarity between vector fields `vectorField1` and `vectorField2` for each document.
+* `vectorSimilarity(vectorField1, [1,2,3,4])`: calculates the configured similarity between vector field `vectorField1` and `[1, 2, 3, 4]` for each document.
+
+[NOTE]
+Only field names that follow xref:indexing-guide:fields.adoc#field-properties[recommended field naming conventions] are guaranteed to work with this syntax. Atypical field names requiring `field("...")` syntax when used in Function Queries must use the more complex 4-argument variant of the `vectorSimilarity(...)` function described below.
+
+==== vectorSimilarity(ENCODING, SIMILARITY_FUNCTION, vector1, vector2)
+
+Takes as input the vector element encoding, the similarity measure, and two ValueSource instances (each either a `DenseVectorField` or a constant vector), and calculates the similarity between the two vectors.
+
+* The encodings supported are: `BYTE`, `FLOAT32`
+** This is used to parse any constant vector arguments
+* The similarities supported are: `EUCLIDEAN`, `COSINE`, `DOT_PRODUCT`
*Syntax Examples*
-* `vectorSimilarity(FLOAT32, COSINE, [1,2,3], [4,5,6])`: calculates the cosine similarity between [1, 2, 3] and [4, 5, 6] for each document.
-* `vectorSimilarity(FLOAT32, DOT_PRODUCT, vectorField1, vectorField2)`: calculates the dot product similarity between the vector in 'vectorField1' and in 'vectorField2' for each document.
-* `vectorSimilarity(BYTE, EUCLIDEAN, [1,5,4,3], vectorField)`: calculates the euclidean similarity between the vector in 'vectorField' and the constant vector [1, 5, 4, 3] for each document.
+* `vectorSimilarity(FLOAT32, COSINE, [1,2,3], [4,5,6])`: calculates the cosine similarity between `[1, 2, 3]` and `[4, 5, 6]` for each document.
+* `vectorSimilarity(FLOAT32, DOT_PRODUCT, vectorField1, vectorField2)`: calculates the dot product similarity between the vector in `vectorField1` and in `vectorField2` for each document.
+* `vectorSimilarity(BYTE, EUCLIDEAN, [1,5,4,3], vectorField)`: calculates the euclidean similarity between the vector in `vectorField` and the constant vector `[1, 5, 4, 3]` for each document.
=== docfreq(field,val) Function
Returns the number of documents that contain the term in the field.