You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by xi...@apache.org on 2024/02/05 20:35:03 UTC
(pinot) branch master updated: Fixing array literal usage for vector (#12365)
This is an automated email from the ASF dual-hosted git repository.
xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 480d32c034 Fixing array literal usage for vector (#12365)
480d32c034 is described below
commit 480d32c0341cab76b156548c03fde8bea6b2bf78
Author: Xiang Fu <xi...@gmail.com>
AuthorDate: Mon Feb 5 12:34:58 2024 -0800
Fixing array literal usage for vector (#12365)
---
.../pinot/common/function/FunctionRegistry.java | 2 +-
.../common/function/TransformFunctionType.java | 4 ++
.../request/context/RequestContextUtils.java | 31 ++++++++++++
.../apache/pinot/sql/parsers/CalciteSqlParser.java | 20 +++++---
.../rewriter/PredicateComparisonRewriter.java | 18 +++++--
.../pinot/integration/tests/custom/VectorTest.java | 57 ++++++++++++++++++++++
6 files changed, 120 insertions(+), 12 deletions(-)
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
index deb1673d8b..00df9498dd 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
@@ -195,7 +195,7 @@ public class FunctionRegistry {
}
@ScalarFunction(names = {"vectorSimilarity", "vector_similarity"}, isPlaceholder = true)
- public static double vectorSimilarity(float[] vector1, float[] vector2) {
+ public static boolean vectorSimilarity(float[] vector1, float[] vector2, int topk) {
throw new UnsupportedOperationException("Placeholder scalar function, should not reach here");
}
}
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 7753260192..20bc26854c 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -273,6 +273,10 @@ public enum TransformFunctionType {
VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE),
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_norm"),
+ VECTOR_SIMILARITY("vectorSimilarity", ReturnTypes.BOOLEAN_NOT_NULL,
+ OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC),
+ ordinal -> ordinal > 1 && ordinal < 4), "vector_similarity"),
+
ARRAY_VALUE_CONSTRUCTOR("arrayValueConstructor", "array_value_constructor"),
// Trigonometry
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
index 28f3037b25..958a20da68 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
@@ -470,6 +470,37 @@ public class RequestContextUtils {
}
private static float[] getVectorValue(Expression thriftExpression) {
+ if (thriftExpression.getType() == ExpressionType.LITERAL) {
+ Literal literalExpression = thriftExpression.getLiteral();
+ if (literalExpression.isSetIntArrayValue()) {
+ float[] vector = new float[literalExpression.getIntArrayValue().size()];
+ for (int i = 0; i < literalExpression.getIntArrayValue().size(); i++) {
+ vector[i] = literalExpression.getIntArrayValue().get(i).floatValue();
+ }
+ return vector;
+ }
+ if (literalExpression.isSetLongArrayValue()) {
+ float[] vector = new float[literalExpression.getLongArrayValue().size()];
+ for (int i = 0; i < literalExpression.getLongArrayValue().size(); i++) {
+ vector[i] = literalExpression.getLongArrayValue().get(i).floatValue();
+ }
+ return vector;
+ }
+ if (literalExpression.isSetFloatArrayValue()) {
+ float[] vector = new float[literalExpression.getFloatArrayValue().size()];
+ for (int i = 0; i < literalExpression.getFloatArrayValue().size(); i++) {
+ vector[i] = literalExpression.getFloatArrayValue().get(i);
+ }
+ return vector;
+ }
+ if (literalExpression.isSetDoubleArrayValue()) {
+ float[] vector = new float[literalExpression.getDoubleArrayValue().size()];
+ for (int i = 0; i < literalExpression.getDoubleArrayValue().size(); i++) {
+ vector[i] = literalExpression.getDoubleArrayValue().get(i).floatValue();
+ }
+ return vector;
+ }
+ }
if (thriftExpression.getType() != ExpressionType.FUNCTION) {
throw new BadQueryRequestException(
"Pinot does not support column or function on the right-hand side of the predicate");
diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
index 63058fd8c3..3d216bd643 100644
--- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
@@ -296,17 +296,25 @@ public class CalciteSqlParser {
+ "the signature is VECTOR_SIMILARITY(float[], float[], int).");
}
Expression vectorLiteral = filterExpression.getFunctionCall().getOperands().get(1);
- // Array Literal is a function of type 'ARRAYVALUECONSTRUCTOR' with operands of Float/Double Literals
- if (!vectorLiteral.isSetFunctionCall() || !vectorLiteral.getFunctionCall().getOperator().equalsIgnoreCase(
- "arrayvalueconstructor")) {
- throw new IllegalStateException("The second argument of VECTOR_SIMILARITY must be a float array literal, "
- + "the signature is VECTOR_SIMILARITY(float[], float[], int).");
+ /*
+ * Array Literal could be either:
+ * 1. a function of type 'ARRAYVALUECONSTRUCTOR' with operands of float/double
+ * 2. a float/double array literal
+ * Also check in
+ * {@link org.apache.pinot.sql.parsers.rewriter.PredicateComparisonRewriter#updateFunctionExpression(Expression)}
+ */
+ if ((vectorLiteral.isSetFunctionCall() && !vectorLiteral.getFunctionCall().getOperator().equalsIgnoreCase(
+ "arrayvalueconstructor"))
+ || (vectorLiteral.isSetLiteral() && !vectorLiteral.getLiteral().isSetFloatArrayValue()
+ && !vectorLiteral.getLiteral().isSetDoubleArrayValue())) {
+ throw new IllegalStateException("The second argument of VECTOR_SIMILARITY must be a float/double array "
+ + "literal, the signature is VECTOR_SIMILARITY(float[], float[], int)");
}
if (filterExpression.getFunctionCall().getOperands().size() == 3) {
Expression topK = filterExpression.getFunctionCall().getOperands().get(2);
if (!topK.isSetLiteral()) {
throw new IllegalStateException("The third argument of VECTOR_SIMILARITY must be an integer literal, "
- + "the signature is VECTOR_SIMILARITY(float[], float[], int).");
+ + "the signature is VECTOR_SIMILARITY(float[], float[], int)");
}
}
} else {
diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java
index 1917e37abc..c59b5126ec 100644
--- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java
@@ -132,11 +132,19 @@ public class PredicateComparisonRewriter implements QueryRewriter {
case VECTOR_SIMILARITY: {
Preconditions.checkArgument(operands.size() >= 2 && operands.size() <= 3,
"For %s predicate, the number of operands must be at either 2 or 3, got: %s", filterKind, expression);
- // Array Literal is a function of type 'ARRAYVALUECONSTRUCTOR' with operands of Float/Double Literals
- if (operands.get(1).getFunctionCall() == null || !operands.get(1).getFunctionCall().getOperator()
- .equalsIgnoreCase("arrayvalueconstructor")) {
+ /*
+ * Array Literal could be either:
+ * 1. a function of type 'ARRAYVALUECONSTRUCTOR' with operands of float/double
+ * 2. a float/double array literal
+ * Also check in {@link org.apache.pinot.sql.parsers.CalciteSqlParser#validateFilter(Expression)}
+ */
+ if ((operands.get(1).getFunctionCall() != null && !operands.get(1).getFunctionCall().getOperator()
+ .equalsIgnoreCase("arrayvalueconstructor"))
+ || (operands.get(1).getLiteral() != null && !operands.get(1).getLiteral().isSetFloatArrayValue()
+ && !operands.get(1).getLiteral().isSetDoubleArrayValue())) {
throw new SqlCompilationException(
- String.format("For %s predicate, the second operand must be a float array literal, got: %s", filterKind,
+ String.format("For %s predicate, the second operand must be a float/double array literal, got: %s",
+ filterKind,
expression));
}
if (operands.size() == 3 && operands.get(2).getLiteral() == null) {
@@ -165,7 +173,7 @@ public class PredicateComparisonRewriter implements QueryRewriter {
/**
* Rewrite predicates to boolean expressions with EQUALS operator
* Example1: "select * from table where col1" converts to
- * "select * from table where col1 = true"
+ * "select * from table where col1 = true"
* Example2: "select * from table where startsWith(col1, 'str')" converts to
* "select * from table where startsWith(col1, 'str') = true"
*
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
index 245cd31d05..2490215081 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
@@ -23,6 +23,8 @@ import com.google.common.collect.ImmutableList;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.List;
+import java.util.Map;
import java.util.stream.IntStream;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
@@ -30,8 +32,12 @@ import org.apache.avro.generic.GenericDatumWriter;
import org.apache.commons.lang3.RandomUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.function.scalar.VectorFunctions;
+import org.apache.pinot.spi.config.table.FieldConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
@@ -172,11 +178,62 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
assertEquals(l2Distance, 22.627416997969522);
}
+ @Test(dataProvider = "useBothQueryEngines")
+ public void testVectorSimilarity(boolean useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ int topK = 5;
+ String oneVectorStringLiteral = "ARRAY[1.1"
+ + StringUtils.repeat(", 1.1", VECTOR_DIM_SIZE - 1)
+ + "]";
+ String query1 =
+ String.format("SELECT "
+ + "cosineDistance(%s, %s) AS dist "
+ + "FROM %s "
+ + "WHERE vectorSimilarity(%s, %s, %d) "
+ + "ORDER BY dist ASC "
+ + "LIMIT %d",
+ VECTOR_1, oneVectorStringLiteral, getTableName(), VECTOR_1, oneVectorStringLiteral, topK * 10, topK);
+ String query2 =
+ String.format("SELECT "
+ + "cosineDistance(%s, %s) as dist "
+ + "FROM %s "
+ + "ORDER BY dist ASC "
+ + "LIMIT %d",
+ VECTOR_1, oneVectorStringLiteral, getTableName(), topK);
+
+ JsonNode jsonNode1 = postQuery(query1);
+ JsonNode jsonNode2 = postQuery(query2);
+ for (int i = 0; i < topK; i++) {
+ double dist1 = jsonNode1.get("resultTable").get("rows").get(i).get(0).asDouble();
+ double dist2 = jsonNode2.get("resultTable").get("rows").get(i).get(0).asDouble();
+ assertEquals(dist1, dist2);
+ }
+ }
+
@Override
public String getTableName() {
return DEFAULT_TABLE_NAME;
}
+ @Override
+ public TableConfig createOfflineTableConfig() {
+ return new TableConfigBuilder(TableType.OFFLINE)
+ .setTableName(getTableName())
+ .setFieldConfigList(List.of(
+ new FieldConfig.Builder(VECTOR_1)
+ .withIndexTypes(List.of(FieldConfig.IndexType.VECTOR))
+ .withEncodingType(FieldConfig.EncodingType.RAW)
+ .withProperties(Map.of(
+ "vectorIndexType", "HNSW",
+ "vectorDimension", String.valueOf(VECTOR_DIM_SIZE),
+ "vectorDistanceFunction", "COSINE",
+ "version", "1"))
+ .build()
+ ))
+ .build();
+ }
+
@Override
public Schema createSchema() {
return new Schema.SchemaBuilder().setSchemaName(getTableName())
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org