You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by xi...@apache.org on 2024/02/05 20:35:03 UTC

(pinot) branch master updated: Fixing array literal usage for vector (#12365)

This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 480d32c034 Fixing array literal usage for vector (#12365)
480d32c034 is described below

commit 480d32c0341cab76b156548c03fde8bea6b2bf78
Author: Xiang Fu <xi...@gmail.com>
AuthorDate: Mon Feb 5 12:34:58 2024 -0800

    Fixing array literal usage for vector (#12365)
---
 .../pinot/common/function/FunctionRegistry.java    |  2 +-
 .../common/function/TransformFunctionType.java     |  4 ++
 .../request/context/RequestContextUtils.java       | 31 ++++++++++++
 .../apache/pinot/sql/parsers/CalciteSqlParser.java | 20 +++++---
 .../rewriter/PredicateComparisonRewriter.java      | 18 +++++--
 .../pinot/integration/tests/custom/VectorTest.java | 57 ++++++++++++++++++++++
 6 files changed, 120 insertions(+), 12 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
index deb1673d8b..00df9498dd 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
@@ -195,7 +195,7 @@ public class FunctionRegistry {
     }
 
     @ScalarFunction(names = {"vectorSimilarity", "vector_similarity"}, isPlaceholder = true)
-    public static double vectorSimilarity(float[] vector1, float[] vector2) {
+    public static boolean vectorSimilarity(float[] vector1, float[] vector2, int topk) {
       throw new UnsupportedOperationException("Placeholder scalar function, should not reach here");
     }
   }
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 7753260192..20bc26854c 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -273,6 +273,10 @@ public enum TransformFunctionType {
   VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE),
       OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_norm"),
 
+  VECTOR_SIMILARITY("vectorSimilarity", ReturnTypes.BOOLEAN_NOT_NULL,
+      OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC),
+          ordinal -> ordinal > 1 && ordinal < 4), "vector_similarity"),
+
   ARRAY_VALUE_CONSTRUCTOR("arrayValueConstructor", "array_value_constructor"),
 
   // Trigonometry
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
index 28f3037b25..958a20da68 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
@@ -470,6 +470,37 @@ public class RequestContextUtils {
   }
 
   private static float[] getVectorValue(Expression thriftExpression) {
+    if (thriftExpression.getType() == ExpressionType.LITERAL) {
+      Literal literalExpression = thriftExpression.getLiteral();
+      if (literalExpression.isSetIntArrayValue()) {
+        float[] vector = new float[literalExpression.getIntArrayValue().size()];
+        for (int i = 0; i < literalExpression.getIntArrayValue().size(); i++) {
+          vector[i] = literalExpression.getIntArrayValue().get(i).floatValue();
+        }
+        return vector;
+      }
+      if (literalExpression.isSetLongArrayValue()) {
+        float[] vector = new float[literalExpression.getLongArrayValue().size()];
+        for (int i = 0; i < literalExpression.getLongArrayValue().size(); i++) {
+          vector[i] = literalExpression.getLongArrayValue().get(i).floatValue();
+        }
+        return vector;
+      }
+      if (literalExpression.isSetFloatArrayValue()) {
+        float[] vector = new float[literalExpression.getFloatArrayValue().size()];
+        for (int i = 0; i < literalExpression.getFloatArrayValue().size(); i++) {
+          vector[i] = literalExpression.getFloatArrayValue().get(i);
+        }
+        return vector;
+      }
+      if (literalExpression.isSetDoubleArrayValue()) {
+        float[] vector = new float[literalExpression.getDoubleArrayValue().size()];
+        for (int i = 0; i < literalExpression.getDoubleArrayValue().size(); i++) {
+          vector[i] = literalExpression.getDoubleArrayValue().get(i).floatValue();
+        }
+        return vector;
+      }
+    }
     if (thriftExpression.getType() != ExpressionType.FUNCTION) {
       throw new BadQueryRequestException(
           "Pinot does not support column or function on the right-hand side of the predicate");
diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
index 63058fd8c3..3d216bd643 100644
--- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
@@ -296,17 +296,25 @@ public class CalciteSqlParser {
             + "the signature is VECTOR_SIMILARITY(float[], float[], int).");
       }
       Expression vectorLiteral = filterExpression.getFunctionCall().getOperands().get(1);
-      // Array Literal is a function of type 'ARRAYVALUECONSTRUCTOR' with operands of Float/Double Literals
-      if (!vectorLiteral.isSetFunctionCall() || !vectorLiteral.getFunctionCall().getOperator().equalsIgnoreCase(
-          "arrayvalueconstructor")) {
-        throw new IllegalStateException("The second argument of VECTOR_SIMILARITY must be a float array literal, "
-            + "the signature is VECTOR_SIMILARITY(float[], float[], int).");
+      /*
+       * Array Literal could be either:
+       * 1. a function of type 'ARRAYVALUECONSTRUCTOR' with operands of float/double
+       * 2. a float/double array literal
+       * Also check in
+       * {@link org.apache.pinot.sql.parsers.rewriter.PredicateComparisonRewriter#updateFunctionExpression(Expression)}
+       */
+      if ((vectorLiteral.isSetFunctionCall() && !vectorLiteral.getFunctionCall().getOperator().equalsIgnoreCase(
+          "arrayvalueconstructor"))
+          || (vectorLiteral.isSetLiteral() && !vectorLiteral.getLiteral().isSetFloatArrayValue()
+          && !vectorLiteral.getLiteral().isSetDoubleArrayValue())) {
+        throw new IllegalStateException("The second argument of VECTOR_SIMILARITY must be a float/double array "
+            + "literal, the signature is VECTOR_SIMILARITY(float[], float[], int)");
       }
       if (filterExpression.getFunctionCall().getOperands().size() == 3) {
         Expression topK = filterExpression.getFunctionCall().getOperands().get(2);
         if (!topK.isSetLiteral()) {
           throw new IllegalStateException("The third argument of VECTOR_SIMILARITY must be an integer literal, "
-              + "the signature is VECTOR_SIMILARITY(float[], float[], int).");
+              + "the signature is VECTOR_SIMILARITY(float[], float[], int)");
         }
       }
     } else {
diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java
index 1917e37abc..c59b5126ec 100644
--- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java
@@ -132,11 +132,19 @@ public class PredicateComparisonRewriter implements QueryRewriter {
         case VECTOR_SIMILARITY: {
           Preconditions.checkArgument(operands.size() >= 2 && operands.size() <= 3,
               "For %s predicate, the number of operands must be at either 2 or 3, got: %s", filterKind, expression);
-          // Array Literal is a function of type 'ARRAYVALUECONSTRUCTOR' with operands of Float/Double Literals
-          if (operands.get(1).getFunctionCall() == null || !operands.get(1).getFunctionCall().getOperator()
-              .equalsIgnoreCase("arrayvalueconstructor")) {
+          /*
+           * Array Literal could be either:
+           * 1. a function of type 'ARRAYVALUECONSTRUCTOR' with operands of float/double
+           * 2. a float/double array literal
+           * Also check in {@link org.apache.pinot.sql.parsers.CalciteSqlParser#validateFilter(Expression)}
+           */
+          if ((operands.get(1).getFunctionCall() != null && !operands.get(1).getFunctionCall().getOperator()
+              .equalsIgnoreCase("arrayvalueconstructor"))
+              || (operands.get(1).getLiteral() != null && !operands.get(1).getLiteral().isSetFloatArrayValue()
+                  && !operands.get(1).getLiteral().isSetDoubleArrayValue())) {
             throw new SqlCompilationException(
-                String.format("For %s predicate, the second operand must be a float array literal, got: %s", filterKind,
+                String.format("For %s predicate, the second operand must be a float/double array literal, got: %s",
+                    filterKind,
                     expression));
           }
           if (operands.size() == 3 && operands.get(2).getLiteral() == null) {
@@ -165,7 +173,7 @@ public class PredicateComparisonRewriter implements QueryRewriter {
   /**
    * Rewrite predicates to boolean expressions with EQUALS operator
    *     Example1: "select * from table where col1" converts to
-   *                "select * from table where col1 = true"
+   *               "select * from table where col1 = true"
    *     Example2: "select * from table where startsWith(col1, 'str')" converts to
    *               "select * from table where startsWith(col1, 'str') = true"
    *
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
index 245cd31d05..2490215081 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
@@ -23,6 +23,8 @@ import com.google.common.collect.ImmutableList;
 import java.io.File;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.List;
+import java.util.Map;
 import java.util.stream.IntStream;
 import org.apache.avro.file.DataFileWriter;
 import org.apache.avro.generic.GenericData;
@@ -30,8 +32,12 @@ import org.apache.avro.generic.GenericDatumWriter;
 import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.pinot.common.function.scalar.VectorFunctions;
+import org.apache.pinot.spi.config.table.FieldConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
 import org.apache.pinot.spi.data.FieldSpec;
 import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
 import org.testng.annotations.Test;
 
 import static org.testng.Assert.assertEquals;
@@ -172,11 +178,62 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
     assertEquals(l2Distance, 22.627416997969522);
   }
 
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testVectorSimilarity(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    int topK = 5;
+    String oneVectorStringLiteral = "ARRAY[1.1"
+        + StringUtils.repeat(", 1.1", VECTOR_DIM_SIZE - 1)
+        + "]";
+    String query1 =
+        String.format("SELECT "
+                + "cosineDistance(%s, %s) AS dist "
+                + "FROM %s "
+                + "WHERE vectorSimilarity(%s, %s, %d) "
+                + "ORDER BY dist ASC "
+                + "LIMIT %d",
+            VECTOR_1, oneVectorStringLiteral, getTableName(), VECTOR_1, oneVectorStringLiteral, topK * 10, topK);
+    String query2 =
+        String.format("SELECT "
+                + "cosineDistance(%s, %s) as dist "
+                + "FROM %s "
+                + "ORDER BY dist ASC "
+                + "LIMIT %d",
+            VECTOR_1, oneVectorStringLiteral, getTableName(), topK);
+
+    JsonNode jsonNode1 = postQuery(query1);
+    JsonNode jsonNode2 = postQuery(query2);
+    for (int i = 0; i < topK; i++) {
+      double dist1 = jsonNode1.get("resultTable").get("rows").get(i).get(0).asDouble();
+      double dist2 = jsonNode2.get("resultTable").get("rows").get(i).get(0).asDouble();
+      assertEquals(dist1, dist2);
+    }
+  }
+
   @Override
   public String getTableName() {
     return DEFAULT_TABLE_NAME;
   }
 
+  @Override
+  public TableConfig createOfflineTableConfig() {
+    return new TableConfigBuilder(TableType.OFFLINE)
+        .setTableName(getTableName())
+        .setFieldConfigList(List.of(
+            new FieldConfig.Builder(VECTOR_1)
+                .withIndexTypes(List.of(FieldConfig.IndexType.VECTOR))
+                .withEncodingType(FieldConfig.EncodingType.RAW)
+                .withProperties(Map.of(
+                    "vectorIndexType", "HNSW",
+                    "vectorDimension", String.valueOf(VECTOR_DIM_SIZE),
+                    "vectorDistanceFunction", "COSINE",
+                    "version", "1"))
+                .build()
+        ))
+        .build();
+  }
+
   @Override
   public Schema createSchema() {
     return new Schema.SchemaBuilder().setSchemaName(getTableName())


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org