Posted to pr@cassandra.apache.org by "adelapena (via GitHub)" <gi...@apache.org> on 2023/09/14 12:21:16 UTC

[GitHub] [cassandra] adelapena commented on a diff in pull request #2673: CASSANDRA-18715 Add support for a vector search index in SAI

adelapena commented on code in PR #2673:
URL: https://github.com/apache/cassandra/pull/2673#discussion_r1325818624


##########
src/java/org/apache/cassandra/config/CassandraRelevantProperties.java:
##########
@@ -420,12 +420,16 @@ public enum CassandraRelevantProperties
 
     // SAI specific properties
 
+    /** Whether to allow the user to specify custom options to the hnsw index */
+    SAI_HNSW_ALLOW_CUSTOM_PARAMETERS("cassandra.sai.hnsw.allow_custom_parameters", "false"),

Review Comment:
   This property could be moved next to the other vector search properties, a few lines below.



##########
src/java/org/apache/cassandra/cql3/Operator.java:
##########
@@ -258,6 +258,20 @@ public boolean isSatisfiedBy(AbstractType<?> type, ByteBuffer leftOperand, ByteB
         {
             throw new UnsupportedOperationException();
         }
+    },
+    ANN(15)
+    {
+        @Override
+        public String toString()
+        {
+            return "ANN";
+        }
+
+        @Override
+        public boolean isSatisfiedBy(AbstractType<?> type, ByteBuffer leftOperand, ByteBuffer rightOperand)
+        {
+            return true;

Review Comment:
   This probably deserves a comment explaining why it is a no-op.
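   
   For example, something along these lines (my wording; it assumes my reading of the patch is right, i.e. that the actual similarity search happens in the index and rows reaching this point are already part of the top-k result):
   ```java
   @Override
   public boolean isSatisfiedBy(AbstractType<?> type, ByteBuffer leftOperand, ByteBuffer rightOperand)
   {
       // ANN is not a row-level predicate: the index performs the approximate
       // similarity search and only returns rows belonging to the top-k result,
       // so any row reaching this post-filtering step is accepted.
       return true;
   }
   ```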



##########
src/java/org/apache/cassandra/cql3/Ordering.java:
##########
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.cql3;
+
+import org.apache.cassandra.cql3.restrictions.SingleColumnRestriction;
+import org.apache.cassandra.cql3.restrictions.SingleRestriction;
+import org.apache.cassandra.schema.ColumnMetadata;
+import org.apache.cassandra.schema.TableMetadata;
+
+/**
+ * A single element of an ORDER BY clause.
+ * <code>ORDER BY ordering1 [, ordering2 [, ...]] </code>
+ * <p>
+ * An ordering comprises an expression that produces the values to compare against each other
+ * and a sorting direction (ASC, DESC).
+ */
+public class Ordering
+{
+    public final Expression expression;
+    public final Direction direction;
+
+    public Ordering(Expression expression, Direction direction)
+    {
+        this.expression = expression;
+        this.direction = direction;
+    }
+
+    public interface Expression
+    {
+        default boolean hasNonClusteredOrdering()
+        {
+            return false;
+        }
+
+        default SingleRestriction toRestriction()
+        {
+            throw new UnsupportedOperationException();
+        }
+
+        ColumnMetadata getColumn();
+    }
+
+    /**
+     * Represents a single column in
+     * <code>ORDER BY column</code>

Review Comment:
   Nit: could be single-line
   ```suggestion
        * Represents a single column in {@code ORDER BY column}.
   ```



##########
src/java/org/apache/cassandra/service/reads/DataResolver.java:
##########
@@ -135,6 +135,9 @@ private boolean needsReplicaFilteringProtection()
         if (command.rowFilter().isEmpty())
             return false;
 
+        if (command.isTopK())
+            return false;

Review Comment:
   This method is called `needsReplicaFilteringProtection`, but it actually mixes whether the query needs replica filtering protection with whether we can provide it. Maybe we could rename it to `useReplicaFilteringProtection` (see the sketch below).
   
   I understand that top-k queries with CL>ONE do need replica filtering protection, but we aren't able to provide it, is that right?
   
   If a top-k query with CL>ONE gets a row from one replica and nothing from the other, we cannot know whether that row has been changed on the other node. That can lead to the resurrection of deleted/stale rows. Also, read repair might send the stale row to the silent replica. Am I missing something in top-k queries that prevents this situation? I haven't read through all of the patch yet, so I might be missing something.
   
   I guess a possible approach would be just forbidding top-k queries with CL>ONE.
   
   I understand that ANN in particular, unlike general top-k queries, only has to produce approximate results. So I guess we could use RFP for ANN queries and, if the silent replica gives us something that is not very far away from the index result, just let it pass.
   
   A third approach could be simply failing the queries if RFP actually finds a silent replica. In theory, this shouldn't be a very common situation. Or, instead of failing, it could trigger read repair and query the silent replica again. This latter approach, if viable, would be quite involved, so we could do it in a separate patch.
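   
   To illustrate the rename, a quick sketch (intent only, not the actual patch; `command` is the field already used by `DataResolver`):
   ```java
   private boolean useReplicaFilteringProtection()
   {
       if (command.rowFilter().isEmpty())
           return false;
   
       // Top-k queries might need RFP, but we currently can't provide it for
       // them, so the answer to "will we use it?" is no.
       if (command.isTopK())
           return false;
   
       // ... rest of the method unchanged
   }
   ```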
   
   
   



##########
test/distributed/org/apache/cassandra/distributed/test/sai/VectorDistributedTest.java:
##########
@@ -0,0 +1,465 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.distributed.test.sai;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.Multimap;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+import org.apache.cassandra.db.Keyspace;
+import org.apache.cassandra.db.marshal.Int32Type;
+import org.apache.cassandra.dht.Murmur3Partitioner;
+import org.apache.cassandra.distributed.Cluster;
+import org.apache.cassandra.distributed.api.ConsistencyLevel;
+import org.apache.cassandra.distributed.test.TestBaseImpl;
+import org.apache.cassandra.index.sai.SAITester;
+import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+import static org.apache.cassandra.distributed.api.Feature.GOSSIP;
+import static org.apache.cassandra.distributed.api.Feature.NETWORK;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+public class VectorDistributedTest extends TestBaseImpl
+{
+
+    @Rule
+    public SAITester.FailureWatcher failureRule = new SAITester.FailureWatcher();
+
+    private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}";
+    private static final String CREATE_TABLE = "CREATE TABLE %%s (pk int primary key, val vector<float, %d>)";
+    private static final String CREATE_TABLE_TWO_VECTORS = "CREATE TABLE %%s (pk int primary key, val1 vector<float, %d>, val2 vector<float, %d>)";
+    private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %%s(%s) USING 'StorageAttachedIndex'";
+
+    private static final VectorSimilarityFunction function = IndexWriterConfig.DEFAULT_SIMILARITY_FUNCTION;
+
+    private static final String INVALID_LIMIT_MESSAGE = "Use of ANN OF in an ORDER BY clause requires a LIMIT that is not greater than 1000";
+
+    private static final double MIN_RECALL = 0.8;
+
+    private static final int NUM_REPLICAS = 3;
+    private static final int RF = 2;
+
+    private static final AtomicInteger seq = new AtomicInteger();
+    private static String table;
+
+    private static Cluster cluster;
+
+    private static int dimensionCount;
+
+    @BeforeClass
+    public static void setupCluster() throws Exception
+    {
+        cluster = Cluster.build(NUM_REPLICAS)
+                         .withTokenCount(1) // VSTODO in-jvm-test in CC branch doesn't support multiple tokens
+                         .withDataDirCount(1) // VSTODO vector memtable flush doesn't support multiple directories yet
+                         .withConfig(config -> config.with(GOSSIP).with(NETWORK))
+                         .start();
+
+        cluster.schemaChange(withKeyspace(String.format(CREATE_KEYSPACE, RF)));
+    }
+
+    @AfterClass
+    public static void closeCluster()
+    {
+        if (cluster != null)
+            cluster.close();
+    }
+
+    @Before
+    public void before()
+    {
+        table = "table_" + seq.getAndIncrement();
+        dimensionCount = SAITester.getRandom().nextIntBetween(100, 2048);
+    }
+
+    @After
+    public void after()
+    {
+        cluster.schemaChange(formatQuery("DROP TABLE IF EXISTS %s"));
+    }
+
+    @Test
+    public void testVectorSearch()
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+
+        int vectorCount = SAITester.getRandom().nextIntBetween(500, 1000);
+        List<float[]> vectors = generateVectors(vectorCount);
+
+        int pk = 0;
+        for (float[] vector : vectors)
+            execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ", " + vectorString(vector) + " )");
+
+        // query memtable index
+        int limit = Math.min(SAITester.getRandom().nextIntBetween(10, 50), vectors.size());
+        float[] queryVector = randomVector();
+        Object[][] result = searchWithLimit(queryVector, limit);
+
+        List<float[]> resultVectors = getVectors(result);
+        assertDescendingScore(queryVector, resultVectors);
+        double memtableRecall = getRecall(vectors, queryVector, resultVectors);
+        assertThat(memtableRecall).isGreaterThanOrEqualTo(MIN_RECALL);
+
+        assertThatThrownBy(() -> searchWithoutLimit(randomVector(), vectorCount))
+        .hasMessageContaining(INVALID_LIMIT_MESSAGE);
+
+        int pageSize = SAITester.getRandom().nextIntBetween(40, 70);
+        limit = SAITester.getRandom().nextIntBetween(20, 50);
+        result = searchWithPageAndLimit(queryVector, pageSize, limit);
+
+        resultVectors = getVectors(result);
+        assertDescendingScore(queryVector, resultVectors);
+        double memtableRecallWithPaging = getRecall(vectors, queryVector, resultVectors);
+        assertThat(memtableRecallWithPaging).isGreaterThanOrEqualTo(MIN_RECALL);
+
+        assertThatThrownBy(() -> searchWithPageWithoutLimit(randomVector(), 10))
+        .hasMessageContaining(INVALID_LIMIT_MESSAGE);
+
+        // query on-disk index
+        cluster.forEach(n -> n.flush(KEYSPACE));
+
+        limit = Math.min(SAITester.getRandom().nextIntBetween(10, 50), vectors.size());
+        queryVector = randomVector();
+        result = searchWithLimit(queryVector, limit);
+        double sstableRecall = getRecall(vectors, queryVector, getVectors(result));
+        assertThat(sstableRecall).isGreaterThanOrEqualTo(MIN_RECALL);
+    }
+
+    @Test
+    public void testMultiSSTablesVectorSearch()
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+        // disable compaction
+        String tableName = table;
+        cluster.forEach(n -> n.runOnInstance(() -> {
+            Keyspace keyspace = Keyspace.open(KEYSPACE);
+            keyspace.getColumnFamilyStore(tableName).disableAutoCompaction();
+        }));
+
+        int vectorCountPerSSTable = SAITester.getRandom().nextIntBetween(200, 500);
+        int sstableCount = SAITester.getRandom().nextIntBetween(3, 5);
+        List<float[]> allVectors = new ArrayList<>(sstableCount * vectorCountPerSSTable);
+
+        int pk = 0;
+        for (int i = 0; i < sstableCount; i++)
+        {
+            List<float[]> vectors = generateVectors(vectorCountPerSSTable);
+            for (float[] vector : vectors)
+                execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ", " + vectorString(vector) + " )");
+
+            allVectors.addAll(vectors);
+            cluster.forEach(n -> n.flush(KEYSPACE));
+        }
+
+        // query multiple sstable indexes on multiple nodes
+        int limit = Math.min(SAITester.getRandom().nextIntBetween(50, 100), allVectors.size());
+        float[] queryVector = randomVector();
+        Object[][] result = searchWithLimit(queryVector, limit);
+
+        // expect recall to be at least 0.8
+        List<float[]> resultVectors = getVectors(result);
+        assertDescendingScore(queryVector, resultVectors);
+        double recall = getRecall(allVectors, queryVector, getVectors(result));
+        assertThat(recall).isGreaterThanOrEqualTo(MIN_RECALL);
+    }
+
+    @Test
+    public void testPartitionRestrictedVectorSearch()
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+
+        int vectorCount = SAITester.getRandom().nextIntBetween(500, 1000);
+        List<float[]> vectors = generateVectors(vectorCount);
+
+        int pk = 0;
+        for (float[] vector : vectors)
+            execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ", " + vectorString(vector) + " )");
+
+        // query memtable index
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key = SAITester.getRandom().nextIntBetween(0, vectorCount - 1);
+            float[] queryVector = randomVector();
+            searchByKeyWithLimit(key, queryVector, 1, vectors);
+        }
+
+        cluster.forEach(n -> n.flush(KEYSPACE));
+
+        // query on-disk index
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key = SAITester.getRandom().nextIntBetween(0, vectorCount - 1);
+            float[] queryVector = randomVector();
+            searchByKeyWithLimit(key, queryVector, 1, vectors);
+        }
+    }
+
+    @Test
+    public void rangeRestrictedTest() throws Throwable
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+
+        int vectorCount = SAITester.getRandom().nextIntBetween(500, 1000);
+        List<float[]> vectors = IntStream.range(0, vectorCount).mapToObj(s -> randomVector()).collect(Collectors.toList());
+
+        int pk = 0;
+        Multimap<Long, float[]> vectorsByToken = ArrayListMultimap.create();
+        for (float[] vector : vectors)
+        {
+            vectorsByToken.put(Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(pk)).getLongValue(), vector);
+            execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ',' + vectorString(vector) + " )");
+        }
+
+        // query memtable index
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key1 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token1 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key1)).getLongValue();
+            int key2 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token2 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key2)).getLongValue();
+
+            long minToken = Math.min(token1, token2);
+            long maxToken = Math.max(token1, token2);
+            List<float[]> expected = vectorsByToken.entries().stream()
+                                                   .filter(e -> e.getKey() >= minToken && e.getKey() <= maxToken)
+                                                   .map(Map.Entry::getValue)
+                                                   .collect(Collectors.toList());
+
+            float[] queryVector = randomVector();
+            List<float[]> resultVectors = searchWithRange(queryVector, minToken, maxToken, expected.size());
+            if (expected.isEmpty())
+                assertThat(resultVectors).isEmpty();
+            else
+            {
+                double recall = getRecall(resultVectors, queryVector, expected);
+                assertThat(recall).isGreaterThanOrEqualTo(0.8);
+            }
+        }
+
+        cluster.forEach(n -> n.flush(KEYSPACE));
+
+        // query on-disk index with existing key:
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key1 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token1 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key1)).getLongValue();
+            int key2 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token2 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key2)).getLongValue();
+
+            long minToken = Math.min(token1, token2);
+            long maxToken = Math.max(token1, token2);
+            List<float[]> expected = vectorsByToken.entries().stream()
+                                                   .filter(e -> e.getKey() >= minToken && e.getKey() <= maxToken)
+                                                   .map(Map.Entry::getValue)
+                                                   .collect(Collectors.toList());
+
+            float[] queryVector = randomVector();
+            List<float[]> resultVectors = searchWithRange(queryVector, minToken, maxToken, expected.size());
+            if (expected.isEmpty())
+                assertThat(resultVectors).isEmpty();
+            else
+            {
+                double recall = getRecall(resultVectors, queryVector, expected);
+                assertThat(recall).isGreaterThanOrEqualTo(0.8);
+            }
+        }
+    }
+
+    private List<float[]> searchWithRange(float[] queryVector, long minToken, long maxToken, int expectedSize) throws Throwable
+    {
+        Object[][] result = execute("SELECT val FROM %s WHERE token(pk) <= " + maxToken + " AND token(pk) >= " + minToken + " ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT 1000");
+        assertThat(result).hasNumberOfRows(expectedSize);
+        return getVectors(result);
+    }
+
+    private Object[][] searchWithLimit(float[] queryVector, int limit)
+    {
+        Object[][] result = execute("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT " + limit);
+        assertThat(result).hasNumberOfRows(limit);
+        return result;
+    }
+
+    private Object[][] searchWithoutLimit(float[] queryVector, int results)
+    {
+        Object[][] result = execute("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector));
+        assertThat(result).hasNumberOfRows(results);
+        return result;
+    }
+
+
+    private Object[][] searchWithPageWithoutLimit(float[] queryVector, int pageSize)
+    {
+        return executeWithPaging("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector), pageSize);
+    }
+
+    private Object[][] searchWithPageAndLimit(float[] queryVector, int pageSize, int limit)
+    {
+        // we don't know how many rows will be returned when paging, because the coordinator resumes from the last returned row's partition
+        return executeWithPaging("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT " + limit, pageSize);
+    }
+
+    private void searchByKeyWithoutLimit(int key, float[] queryVector, List<float[]> vectors)
+    {
+        Object[][] result = execute("SELECT val FROM %s WHERE pk = " + key + " AND val ann of " + Arrays.toString(queryVector));
+        assertThat(result).hasNumberOfRows(1);
+        float[] output = getVectors(result).get(0);
+        assertThat(output).isEqualTo(vectors.get(key));
+    }

Review Comment:
   Unused method.



##########
test/distributed/org/apache/cassandra/distributed/test/sai/VectorDistributedTest.java:
##########
@@ -0,0 +1,465 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.distributed.test.sai;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.Multimap;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+import org.apache.cassandra.db.Keyspace;
+import org.apache.cassandra.db.marshal.Int32Type;
+import org.apache.cassandra.dht.Murmur3Partitioner;
+import org.apache.cassandra.distributed.Cluster;
+import org.apache.cassandra.distributed.api.ConsistencyLevel;
+import org.apache.cassandra.distributed.test.TestBaseImpl;
+import org.apache.cassandra.index.sai.SAITester;
+import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+import static org.apache.cassandra.distributed.api.Feature.GOSSIP;
+import static org.apache.cassandra.distributed.api.Feature.NETWORK;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+public class VectorDistributedTest extends TestBaseImpl
+{
+
+    @Rule
+    public SAITester.FailureWatcher failureRule = new SAITester.FailureWatcher();
+
+    private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}";
+    private static final String CREATE_TABLE = "CREATE TABLE %%s (pk int primary key, val vector<float, %d>)";
+    private static final String CREATE_TABLE_TWO_VECTORS = "CREATE TABLE %%s (pk int primary key, val1 vector<float, %d>, val2 vector<float, %d>)";

Review Comment:
   Nit: unused



##########
src/antlr/Lexer.g:
##########
@@ -224,7 +224,8 @@ K_MASKED:      M A S K E D;
 K_UNMASK:      U N M A S K;
 K_SELECT_MASKED: S E L E C T '_' M A S K E D;
 
-K_VECTOR:       V E C T O R;
+K_VECTOR:      V E C T O R;
+K_ANN_OF:      A N N WS+ O F;

Review Comment:
   Is there any particular reason for adding `ANN OF` as a single keyword, instead of just adding `ANN` and reusing the existing `OF` keyword?
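   
   I.e. something like `K_ANN: A N N;` in the lexer, with the parser rule then matching the two tokens in sequence (`ANN` followed by the existing `OF` token), assuming that doesn't introduce an ambiguity with other uses of `OF`.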



##########
src/java/org/apache/cassandra/cql3/Ordering.java:
##########
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.cql3;
+
+import org.apache.cassandra.cql3.restrictions.SingleColumnRestriction;
+import org.apache.cassandra.cql3.restrictions.SingleRestriction;
+import org.apache.cassandra.schema.ColumnMetadata;
+import org.apache.cassandra.schema.TableMetadata;
+
+/**
+ * A single element of an ORDER BY clause.
+ * <code>ORDER BY ordering1 [, ordering2 [, ...]] </code>
+ * <p>
+ * An ordering comprises an expression that produces the values to compare against each other
+ * and a sorting direction (ASC, DESC).
+ */
+public class Ordering
+{
+    public final Expression expression;
+    public final Direction direction;
+
+    public Ordering(Expression expression, Direction direction)
+    {
+        this.expression = expression;
+        this.direction = direction;
+    }
+
+    public interface Expression
+    {
+        default boolean hasNonClusteredOrdering()
+        {
+            return false;
+        }
+
+        default SingleRestriction toRestriction()
+        {
+            throw new UnsupportedOperationException();
+        }
+
+        ColumnMetadata getColumn();
+    }
+
+    /**
+     * Represents a single column in
+     * <code>ORDER BY column</code>
+     */
+    public static class SingleColumn implements Expression
+    {
+        public final ColumnMetadata column;
+
+        public SingleColumn(ColumnMetadata column)
+        {
+            this.column = column;
+        }
+
+        @Override
+        public ColumnMetadata getColumn()
+        {
+            return column;
+        }
+    }
+
+    /**
+     * An expression used in Approximate Nearest Neighbor ordering.
+     * <code>ORDER BY column ANN OF value</code>
+     */
+    public static class Ann implements Expression
+    {
+        final ColumnMetadata column;
+        final Term vectorValue;
+
+        public Ann(ColumnMetadata column, Term vectorValue)
+        {
+            this.column = column;
+            this.vectorValue = vectorValue;
+        }
+
+        @Override
+        public boolean hasNonClusteredOrdering()
+        {
+            return true;
+        }
+
+        @Override
+        public SingleRestriction toRestriction()
+        {
+            return new SingleColumnRestriction.AnnRestriction(column, vectorValue);
+        }
+
+        @Override
+        public ColumnMetadata getColumn()
+        {
+            return column;
+        }
+    }
+
+    public enum Direction
+    {ASC, DESC}
+
+
+    /**
+     * Represents the AST of a single element in the ORDER BY clause.

Review Comment:
   What is AST?



##########
test/distributed/org/apache/cassandra/distributed/test/sai/VectorDistributedTest.java:
##########
@@ -0,0 +1,465 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.distributed.test.sai;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.Multimap;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+import org.apache.cassandra.db.Keyspace;
+import org.apache.cassandra.db.marshal.Int32Type;
+import org.apache.cassandra.dht.Murmur3Partitioner;
+import org.apache.cassandra.distributed.Cluster;
+import org.apache.cassandra.distributed.api.ConsistencyLevel;
+import org.apache.cassandra.distributed.test.TestBaseImpl;
+import org.apache.cassandra.index.sai.SAITester;
+import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+import static org.apache.cassandra.distributed.api.Feature.GOSSIP;
+import static org.apache.cassandra.distributed.api.Feature.NETWORK;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+public class VectorDistributedTest extends TestBaseImpl
+{
+
+    @Rule
+    public SAITester.FailureWatcher failureRule = new SAITester.FailureWatcher();
+
+    private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}";
+    private static final String CREATE_TABLE = "CREATE TABLE %%s (pk int primary key, val vector<float, %d>)";
+    private static final String CREATE_TABLE_TWO_VECTORS = "CREATE TABLE %%s (pk int primary key, val1 vector<float, %d>, val2 vector<float, %d>)";
+    private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %%s(%s) USING 'StorageAttachedIndex'";
+
+    private static final VectorSimilarityFunction function = IndexWriterConfig.DEFAULT_SIMILARITY_FUNCTION;
+
+    private static final String INVALID_LIMIT_MESSAGE = "Use of ANN OF in an ORDER BY clause requires a LIMIT that is not greater than 1000";
+
+    private static final double MIN_RECALL = 0.8;
+
+    private static final int NUM_REPLICAS = 3;
+    private static final int RF = 2;
+
+    private static final AtomicInteger seq = new AtomicInteger();
+    private static String table;
+
+    private static Cluster cluster;
+
+    private static int dimensionCount;
+
+    @BeforeClass
+    public static void setupCluster() throws Exception
+    {
+        cluster = Cluster.build(NUM_REPLICAS)
+                         .withTokenCount(1) // VSTODO in-jvm-test in CC branch doesn't support multiple tokens
+                         .withDataDirCount(1) // VSTODO vector memtable flush doesn't support multiple directories yet
+                         .withConfig(config -> config.with(GOSSIP).with(NETWORK))
+                         .start();
+
+        cluster.schemaChange(withKeyspace(String.format(CREATE_KEYSPACE, RF)));
+    }
+
+    @AfterClass
+    public static void closeCluster()
+    {
+        if (cluster != null)
+            cluster.close();
+    }
+
+    @Before
+    public void before()
+    {
+        table = "table_" + seq.getAndIncrement();
+        dimensionCount = SAITester.getRandom().nextIntBetween(100, 2048);
+    }
+
+    @After
+    public void after()
+    {
+        cluster.schemaChange(formatQuery("DROP TABLE IF EXISTS %s"));
+    }
+
+    @Test
+    public void testVectorSearch()
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+
+        int vectorCount = SAITester.getRandom().nextIntBetween(500, 1000);
+        List<float[]> vectors = generateVectors(vectorCount);
+
+        int pk = 0;
+        for (float[] vector : vectors)
+            execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ", " + vectorString(vector) + " )");
+
+        // query memtable index
+        int limit = Math.min(SAITester.getRandom().nextIntBetween(10, 50), vectors.size());
+        float[] queryVector = randomVector();
+        Object[][] result = searchWithLimit(queryVector, limit);
+
+        List<float[]> resultVectors = getVectors(result);
+        assertDescendingScore(queryVector, resultVectors);
+        double memtableRecall = getRecall(vectors, queryVector, resultVectors);
+        assertThat(memtableRecall).isGreaterThanOrEqualTo(MIN_RECALL);
+
+        assertThatThrownBy(() -> searchWithoutLimit(randomVector(), vectorCount))
+        .hasMessageContaining(INVALID_LIMIT_MESSAGE);
+
+        int pageSize = SAITester.getRandom().nextIntBetween(40, 70);
+        limit = SAITester.getRandom().nextIntBetween(20, 50);
+        result = searchWithPageAndLimit(queryVector, pageSize, limit);
+
+        resultVectors = getVectors(result);
+        assertDescendingScore(queryVector, resultVectors);
+        double memtableRecallWithPaging = getRecall(vectors, queryVector, resultVectors);
+        assertThat(memtableRecallWithPaging).isGreaterThanOrEqualTo(MIN_RECALL);
+
+        assertThatThrownBy(() -> searchWithPageWithoutLimit(randomVector(), 10))
+        .hasMessageContaining(INVALID_LIMIT_MESSAGE);
+
+        // query on-disk index
+        cluster.forEach(n -> n.flush(KEYSPACE));
+
+        limit = Math.min(SAITester.getRandom().nextIntBetween(10, 50), vectors.size());
+        queryVector = randomVector();
+        result = searchWithLimit(queryVector, limit);
+        double sstableRecall = getRecall(vectors, queryVector, getVectors(result));
+        assertThat(sstableRecall).isGreaterThanOrEqualTo(MIN_RECALL);
+    }
+
+    @Test
+    public void testMultiSSTablesVectorSearch()
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+        // disable compaction
+        String tableName = table;
+        cluster.forEach(n -> n.runOnInstance(() -> {
+            Keyspace keyspace = Keyspace.open(KEYSPACE);
+            keyspace.getColumnFamilyStore(tableName).disableAutoCompaction();
+        }));
+
+        int vectorCountPerSSTable = SAITester.getRandom().nextIntBetween(200, 500);
+        int sstableCount = SAITester.getRandom().nextIntBetween(3, 5);
+        List<float[]> allVectors = new ArrayList<>(sstableCount * vectorCountPerSSTable);
+
+        int pk = 0;
+        for (int i = 0; i < sstableCount; i++)
+        {
+            List<float[]> vectors = generateVectors(vectorCountPerSSTable);
+            for (float[] vector : vectors)
+                execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ", " + vectorString(vector) + " )");
+
+            allVectors.addAll(vectors);
+            cluster.forEach(n -> n.flush(KEYSPACE));
+        }
+
+        // query multiple sstable indexes on multiple nodes
+        int limit = Math.min(SAITester.getRandom().nextIntBetween(50, 100), allVectors.size());
+        float[] queryVector = randomVector();
+        Object[][] result = searchWithLimit(queryVector, limit);
+
+        // expect recall to be at least 0.8
+        List<float[]> resultVectors = getVectors(result);
+        assertDescendingScore(queryVector, resultVectors);
+        double recall = getRecall(allVectors, queryVector, getVectors(result));
+        assertThat(recall).isGreaterThanOrEqualTo(MIN_RECALL);
+    }
+
+    @Test
+    public void testPartitionRestrictedVectorSearch()
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+
+        int vectorCount = SAITester.getRandom().nextIntBetween(500, 1000);
+        List<float[]> vectors = generateVectors(vectorCount);
+
+        int pk = 0;
+        for (float[] vector : vectors)
+            execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ", " + vectorString(vector) + " )");
+
+        // query memtable index
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key = SAITester.getRandom().nextIntBetween(0, vectorCount - 1);
+            float[] queryVector = randomVector();
+            searchByKeyWithLimit(key, queryVector, 1, vectors);
+        }
+
+        cluster.forEach(n -> n.flush(KEYSPACE));
+
+        // query on-disk index
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key = SAITester.getRandom().nextIntBetween(0, vectorCount - 1);
+            float[] queryVector = randomVector();
+            searchByKeyWithLimit(key, queryVector, 1, vectors);
+        }
+    }
+
+    @Test
+    public void rangeRestrictedTest() throws Throwable
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+
+        int vectorCount = SAITester.getRandom().nextIntBetween(500, 1000);
+        List<float[]> vectors = IntStream.range(0, vectorCount).mapToObj(s -> randomVector()).collect(Collectors.toList());
+
+        int pk = 0;
+        Multimap<Long, float[]> vectorsByToken = ArrayListMultimap.create();
+        for (float[] vector : vectors)
+        {
+            vectorsByToken.put(Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(pk)).getLongValue(), vector);
+            execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ',' + vectorString(vector) + " )");
+        }
+
+        // query memtable index
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key1 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token1 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key1)).getLongValue();
+            int key2 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token2 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key2)).getLongValue();
+
+            long minToken = Math.min(token1, token2);
+            long maxToken = Math.max(token1, token2);
+            List<float[]> expected = vectorsByToken.entries().stream()
+                                                   .filter(e -> e.getKey() >= minToken && e.getKey() <= maxToken)
+                                                   .map(Map.Entry::getValue)
+                                                   .collect(Collectors.toList());
+
+            float[] queryVector = randomVector();
+            List<float[]> resultVectors = searchWithRange(queryVector, minToken, maxToken, expected.size());
+            if (expected.isEmpty())
+                assertThat(resultVectors).isEmpty();
+            else
+            {
+                double recall = getRecall(resultVectors, queryVector, expected);
+                assertThat(recall).isGreaterThanOrEqualTo(0.8);
+            }
+        }
+
+        cluster.forEach(n -> n.flush(KEYSPACE));
+
+        // query on-disk index with existing key:
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key1 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token1 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key1)).getLongValue();
+            int key2 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token2 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key2)).getLongValue();
+
+            long minToken = Math.min(token1, token2);
+            long maxToken = Math.max(token1, token2);
+            List<float[]> expected = vectorsByToken.entries().stream()
+                                                   .filter(e -> e.getKey() >= minToken && e.getKey() <= maxToken)
+                                                   .map(Map.Entry::getValue)
+                                                   .collect(Collectors.toList());
+
+            float[] queryVector = randomVector();
+            List<float[]> resultVectors = searchWithRange(queryVector, minToken, maxToken, expected.size());
+            if (expected.isEmpty())
+                assertThat(resultVectors).isEmpty();
+            else
+            {
+                double recall = getRecall(resultVectors, queryVector, expected);
+                assertThat(recall).isGreaterThanOrEqualTo(0.8);
+            }
+        }
+    }
+
+    private List<float[]> searchWithRange(float[] queryVector, long minToken, long maxToken, int expectedSize) throws Throwable
+    {
+        Object[][] result = execute("SELECT val FROM %s WHERE token(pk) <= " + maxToken + " AND token(pk) >= " + minToken + " ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT 1000");
+        assertThat(result).hasNumberOfRows(expectedSize);
+        return getVectors(result);
+    }
+
+    private Object[][] searchWithLimit(float[] queryVector, int limit)
+    {
+        Object[][] result = execute("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT " + limit);
+        assertThat(result).hasNumberOfRows(limit);
+        return result;
+    }
+
+    private Object[][] searchWithoutLimit(float[] queryVector, int results)
+    {
+        Object[][] result = execute("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector));
+        assertThat(result).hasNumberOfRows(results);
+        return result;
+    }
+
+
+    private Object[][] searchWithPageWithoutLimit(float[] queryVector, int pageSize)
+    {
+        return executeWithPaging("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector), pageSize);
+    }
+
+    private Object[][] searchWithPageAndLimit(float[] queryVector, int pageSize, int limit)
+    {
+        // we don't know how many rows will be returned when paging, because the coordinator resumes from the last returned row's partition
+        return executeWithPaging("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT " + limit, pageSize);
+    }
+
+    private void searchByKeyWithoutLimit(int key, float[] queryVector, List<float[]> vectors)
+    {
+        Object[][] result = execute("SELECT val FROM %s WHERE pk = " + key + " AND val ann of " + Arrays.toString(queryVector));
+        assertThat(result).hasNumberOfRows(1);
+        float[] output = getVectors(result).get(0);
+        assertThat(output).isEqualTo(vectors.get(key));
+    }
+
+    private void searchByKeyWithLimit(int key, float[] queryVector, int limit, List<float[]> vectors)
+    {
+        Object[][] result = execute("SELECT val FROM %s WHERE pk = " + key + " ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT " + limit);
+        assertThat(result).hasNumberOfRows(1);
+        float[] output = getVectors(result).get(0);
+        assertThat(output).isEqualTo(vectors.get(key));
+    }
+
+    private void assertDescendingScore(float[] queryVector, List<float[]> resultVectors)
+    {
+        float prevScore = -1;
+        for (float[] current : resultVectors)
+        {
+            float score = function.compare(current, queryVector);
+            if (prevScore >= 0)
+                assertThat(score).isLessThanOrEqualTo(prevScore);
+
+            prevScore = score;
+        }
+    }
+
+    private double getRecall(List<float[]> vectors, float[] query, List<float[]> result)
+    {
+        List<float[]> sortedVectors = new ArrayList<>(vectors);
+        sortedVectors.sort((a, b) -> Double.compare(function.compare(b, query), function.compare(a, query)));
+
+        assertThat(sortedVectors).containsAll(result);
+
+        List<float[]> nearestNeighbors = sortedVectors.subList(0, result.size());
+
+        int matches = 0;
+        for (float[] in : nearestNeighbors)
+        {
+            for (float[] out : result)
+            {
+                if (Arrays.compare(in, out) == 0)
+                {
+                    matches++;
+                    break;
+                }
+            }
+        }
+
+        return matches * 1.0 / result.size();
+    }
+
+    private List<float[]> generateVectors(int vectorCount)
+    {
+        return IntStream.range(0, vectorCount).mapToObj(s -> randomVector()).collect(Collectors.toList());
+    }
+
+    private List<float[]> getVectors(Object[][] result)
+    {
+        List<float[]> vectors = new ArrayList<>();
+
+        // verify results are part of inserted vectors
+        for (Object[] obj: result)
+        {
+            List<Float> list = (List<Float>) obj[0];
+            float[] array = new float[list.size()];
+            for (int index = 0; index < list.size(); index++)
+                array[index] = list.get(index);
+            vectors.add(array);
+        }
+
+        return vectors;
+    }
+
+    private String vectorString(float[] vector)
+    {
+        return Arrays.toString(vector);
+    }
+
+    private String randomVectorString()
+    {
+        return vectorString(randomVector());
+    }

Review Comment:
   Unused method.



##########
test/distributed/org/apache/cassandra/distributed/test/sai/VectorDistributedTest.java:
##########
@@ -0,0 +1,465 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.distributed.test.sai;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.Multimap;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+import org.apache.cassandra.db.Keyspace;
+import org.apache.cassandra.db.marshal.Int32Type;
+import org.apache.cassandra.dht.Murmur3Partitioner;
+import org.apache.cassandra.distributed.Cluster;
+import org.apache.cassandra.distributed.api.ConsistencyLevel;
+import org.apache.cassandra.distributed.test.TestBaseImpl;
+import org.apache.cassandra.index.sai.SAITester;
+import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+import static org.apache.cassandra.distributed.api.Feature.GOSSIP;
+import static org.apache.cassandra.distributed.api.Feature.NETWORK;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+public class VectorDistributedTest extends TestBaseImpl
+{
+
+    @Rule
+    public SAITester.FailureWatcher failureRule = new SAITester.FailureWatcher();
+
+    private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}";
+    private static final String CREATE_TABLE = "CREATE TABLE %%s (pk int primary key, val vector<float, %d>)";
+    private static final String CREATE_TABLE_TWO_VECTORS = "CREATE TABLE %%s (pk int primary key, val1 vector<float, %d>, val2 vector<float, %d>)";
+    private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %%s(%s) USING 'StorageAttachedIndex'";
+
+    private static final VectorSimilarityFunction function = IndexWriterConfig.DEFAULT_SIMILARITY_FUNCTION;
+
+    private static final String INVALID_LIMIT_MESSAGE = "Use of ANN OF in an ORDER BY clause requires a LIMIT that is not greater than 1000";
+
+    private static final double MIN_RECALL = 0.8;
+
+    private static final int NUM_REPLICAS = 3;
+    private static final int RF = 2;
+
+    private static final AtomicInteger seq = new AtomicInteger();
+    private static String table;
+
+    private static Cluster cluster;
+
+    private static int dimensionCount;
+
+    @BeforeClass
+    public static void setupCluster() throws Exception
+    {
+        cluster = Cluster.build(NUM_REPLICAS)
+                         .withTokenCount(1) // VSTODO in-jvm-test in CC branch doesn't support multiple tokens
+                         .withDataDirCount(1) // VSTODO vector memtable flush doesn't support multiple directories yet
+                         .withConfig(config -> config.with(GOSSIP).with(NETWORK))
+                         .start();
+
+        cluster.schemaChange(withKeyspace(String.format(CREATE_KEYSPACE, RF)));
+    }
+
+    @AfterClass
+    public static void closeCluster()
+    {
+        if (cluster != null)
+            cluster.close();
+    }
+
+    @Before
+    public void before()
+    {
+        table = "table_" + seq.getAndIncrement();
+        dimensionCount = SAITester.getRandom().nextIntBetween(100, 2048);
+    }
+
+    @After
+    public void after()
+    {
+        cluster.schemaChange(formatQuery("DROP TABLE IF EXISTS %s"));
+    }
+
+    @Test
+    public void testVectorSearch()
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+
+        int vectorCount = SAITester.getRandom().nextIntBetween(500, 1000);
+        List<float[]> vectors = generateVectors(vectorCount);
+
+        int pk = 0;
+        for (float[] vector : vectors)
+            execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ", " + vectorString(vector) + " )");
+
+        // query memtable index
+        int limit = Math.min(SAITester.getRandom().nextIntBetween(10, 50), vectors.size());
+        float[] queryVector = randomVector();
+        Object[][] result = searchWithLimit(queryVector, limit);
+
+        List<float[]> resultVectors = getVectors(result);
+        assertDescendingScore(queryVector, resultVectors);
+        double memtableRecall = getRecall(vectors, queryVector, resultVectors);
+        assertThat(memtableRecall).isGreaterThanOrEqualTo(MIN_RECALL);
+
+        assertThatThrownBy(() -> searchWithoutLimit(randomVector(), vectorCount))
+        .hasMessageContaining(INVALID_LIMIT_MESSAGE);
+
+        int pageSize = SAITester.getRandom().nextIntBetween(40, 70);
+        limit = SAITester.getRandom().nextIntBetween(20, 50);
+        result = searchWithPageAndLimit(queryVector, pageSize, limit);
+
+        resultVectors = getVectors(result);
+        assertDescendingScore(queryVector, resultVectors);
+        double memtableRecallWithPaging = getRecall(vectors, queryVector, resultVectors);
+        assertThat(memtableRecallWithPaging).isGreaterThanOrEqualTo(MIN_RECALL);
+
+        assertThatThrownBy(() -> searchWithPageWithoutLimit(randomVector(), 10))
+        .hasMessageContaining(INVALID_LIMIT_MESSAGE);
+
+        // query on-disk index
+        cluster.forEach(n -> n.flush(KEYSPACE));
+
+        limit = Math.min(SAITester.getRandom().nextIntBetween(10, 50), vectors.size());
+        queryVector = randomVector();
+        result = searchWithLimit(queryVector, limit);
+        double sstableRecall = getRecall(vectors, queryVector, getVectors(result));
+        assertThat(sstableRecall).isGreaterThanOrEqualTo(MIN_RECALL);
+    }
+
+    @Test
+    public void testMultiSSTablesVectorSearch()
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+        // disable compaction
+        String tableName = table;
+        cluster.forEach(n -> n.runOnInstance(() -> {
+            Keyspace keyspace = Keyspace.open(KEYSPACE);
+            keyspace.getColumnFamilyStore(tableName).disableAutoCompaction();
+        }));
+
+        int vectorCountPerSSTable = SAITester.getRandom().nextIntBetween(200, 500);
+        int sstableCount = SAITester.getRandom().nextIntBetween(3, 5);
+        List<float[]> allVectors = new ArrayList<>(sstableCount * vectorCountPerSSTable);
+
+        int pk = 0;
+        for (int i = 0; i < sstableCount; i++)
+        {
+            List<float[]> vectors = generateVectors(vectorCountPerSSTable);
+            for (float[] vector : vectors)
+                execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ", " + vectorString(vector) + " )");
+
+            allVectors.addAll(vectors);
+            cluster.forEach(n -> n.flush(KEYSPACE));
+        }
+
+        // query multiple sstable indexes in multiple node
+        int limit = Math.min(SAITester.getRandom().nextIntBetween(50, 100), allVectors.size());
+        float[] queryVector = randomVector();
+        Object[][] result = searchWithLimit(queryVector, limit);
+
+        // expect recall to be at least 0.8
+        List<float[]> resultVectors = getVectors(result);
+        assertDescendingScore(queryVector, resultVectors);
+        double recall = getRecall(allVectors, queryVector, getVectors(result));
+        assertThat(recall).isGreaterThanOrEqualTo(MIN_RECALL);
+    }
+
+    @Test
+    public void testPartitionRestrictedVectorSearch()
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+
+        int vectorCount = SAITester.getRandom().nextIntBetween(500, 1000);
+        List<float[]> vectors = generateVectors(vectorCount);
+
+        int pk = 0;
+        for (float[] vector : vectors)
+            execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ", " + vectorString(vector) + " )");
+
+        // query memtable index
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key = SAITester.getRandom().nextIntBetween(0, vectorCount - 1);
+            float[] queryVector = randomVector();
+            searchByKeyWithLimit(key, queryVector, 1, vectors);
+        }
+
+        cluster.forEach(n -> n.flush(KEYSPACE));
+
+        // query on-disk index
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key = SAITester.getRandom().nextIntBetween(0, vectorCount - 1);
+            float[] queryVector = randomVector();
+            searchByKeyWithLimit(key, queryVector, 1, vectors);
+        }
+    }
+
+    @Test
+    public void rangeRestrictedTest() throws Throwable
+    {
+        cluster.schemaChange(formatQuery(String.format(CREATE_TABLE, dimensionCount)));
+        cluster.schemaChange(formatQuery(String.format(CREATE_INDEX, "val")));
+        SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+
+        int vectorCount = SAITester.getRandom().nextIntBetween(500, 1000);
+        List<float[]> vectors = generateVectors(vectorCount);
+
+        int pk = 0;
+        Multimap<Long, float[]> vectorsByToken = ArrayListMultimap.create();
+        for (float[] vector : vectors)
+        {
+            vectorsByToken.put(Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(pk)).getLongValue(), vector);
+            execute("INSERT INTO %s (pk, val) VALUES (" + (pk++) + ',' + vectorString(vector) + " )");
+        }
+
+        // query memtable index
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key1 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token1 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key1)).getLongValue();
+            int key2 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token2 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key2)).getLongValue();
+
+            long minToken = Math.min(token1, token2);
+            long maxToken = Math.max(token1, token2);
+            List<float[]> expected = vectorsByToken.entries().stream()
+                                                   .filter(e -> e.getKey() >= minToken && e.getKey() <= maxToken)
+                                                   .map(Map.Entry::getValue)
+                                                   .collect(Collectors.toList());
+
+            float[] queryVector = randomVector();
+            List<float[]> resultVectors = searchWithRange(queryVector, minToken, maxToken, expected.size());
+            if (expected.isEmpty())
+                assertThat(resultVectors).isEmpty();
+            else
+            {
+                double recall = getRecall(resultVectors, queryVector, expected);
+                assertThat(recall).isGreaterThanOrEqualTo(0.8);
+            }
+        }
+
+        cluster.forEach(n -> n.flush(KEYSPACE));
+
+        // query on-disk index with random token ranges
+        for (int executionCount = 0; executionCount < 50; executionCount++)
+        {
+            int key1 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token1 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key1)).getLongValue();
+            int key2 = SAITester.getRandom().nextIntBetween(1, vectorCount * 2);
+            long token2 = Murmur3Partitioner.instance.getToken(Int32Type.instance.decompose(key2)).getLongValue();
+
+            long minToken = Math.min(token1, token2);
+            long maxToken = Math.max(token1, token2);
+            List<float[]> expected = vectorsByToken.entries().stream()
+                                                   .filter(e -> e.getKey() >= minToken && e.getKey() <= maxToken)
+                                                   .map(Map.Entry::getValue)
+                                                   .collect(Collectors.toList());
+
+            float[] queryVector = randomVector();
+            List<float[]> resultVectors = searchWithRange(queryVector, minToken, maxToken, expected.size());
+            if (expected.isEmpty())
+                assertThat(resultVectors).isEmpty();
+            else
+            {
+                double recall = getRecall(resultVectors, queryVector, expected);
+                assertThat(recall).isGreaterThanOrEqualTo(0.8);
+            }
+        }
+    }
+
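+    /**
+     * Runs an ANN query restricted to a token range. The LIMIT of 1000 is at least as large as
+     * any candidate set in these tests, so every row in the range is expected back.
+     */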
+    private List<float[]> searchWithRange(float[] queryVector, long minToken, long maxToken, int expectedSize) throws Throwable
+    {
+        Object[][] result = execute("SELECT val FROM %s WHERE token(pk) <= " + maxToken + " AND token(pk) >= " + minToken + " ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT 1000");
+        assertThat(result).hasNumberOfRows(expectedSize);
+        return getVectors(result);
+    }
+
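+    /** Unrestricted ANN query that expects exactly {@code limit} rows back. */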
+    private Object[][] searchWithLimit(float[] queryVector, int limit)
+    {
+        Object[][] result = execute("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT " + limit);
+        assertThat(result).hasNumberOfRows(limit);
+        return result;
+    }
+
+    private Object[][] searchWithoutLimit(float[] queryVector, int expectedRows)
+    {
+        Object[][] result = execute("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector));
+        assertThat(result).hasNumberOfRows(expectedRows);
+        return result;
+    }
+
+    private Object[][] searchWithPageWithoutLimit(float[] queryVector, int pageSize)
+    {
+        return executeWithPaging("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector), pageSize);
+    }
+
+    private Object[][] searchWithPageAndLimit(float[] queryVector, int pageSize, int limit)
+    {
+        // we don't know how many rows will be returned when paging, because the coordinator resumes from the last-returned row's partition
+        return executeWithPaging("SELECT val FROM %s ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT " + limit, pageSize);
+    }
+
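+    /** Partition-restricted ANN without LIMIT: the pk predicate narrows the result to a single row, which must match the inserted vector. */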
+    private void searchByKeyWithoutLimit(int key, float[] queryVector, List<float[]> vectors)
+    {
+        Object[][] result = execute("SELECT val FROM %s WHERE pk = " + key + " AND val ann of " + Arrays.toString(queryVector));
+        assertThat(result).hasNumberOfRows(1);
+        float[] output = getVectors(result).get(0);
+        assertThat(output).isEqualTo(vectors.get(key));
+    }
+
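+    /** Partition-restricted ANN with LIMIT: still expects exactly the one row owned by {@code key}. */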
+    private void searchByKeyWithLimit(int key, float[] queryVector, int limit, List<float[]> vectors)
+    {
+        Object[][] result = execute("SELECT val FROM %s WHERE pk = " + key + " ORDER BY val ann of " + Arrays.toString(queryVector) + " LIMIT " + limit);
+        assertThat(result).hasNumberOfRows(1);
+        float[] output = getVectors(result).get(0);
+        assertThat(output).isEqualTo(vectors.get(key));
+    }
+
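+    /** Asserts that results come back ordered by non-increasing similarity score against the query vector. */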
+    private void assertDescendingScore(float[] queryVector, List<float[]> resultVectors)
+    {
+        float prevScore = -1;
+        for (float[] current : resultVectors)
+        {
+            float score = function.compare(current, queryVector);
+            if (prevScore >= 0)
+                assertThat(score).isLessThanOrEqualTo(prevScore);
+
+            prevScore = score;
+        }
+    }
+
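+    /**
+     * Computes recall as |result &#8745; true top-k| / k, where the true top-k are the k inserted
+     * vectors most similar to the query and k is the result size.
+     */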
+    private double getRecall(List<float[]> vectors, float[] query, List<float[]> result)
+    {
+        List<float[]> sortedVectors = new ArrayList<>(vectors);
+        sortedVectors.sort((a, b) -> Double.compare(function.compare(b, query), function.compare(a, query)));
+
+        assertThat(sortedVectors).containsAll(result);
+
+        List<float[]> nearestNeighbors = sortedVectors.subList(0, result.size());
+
+        int matches = 0;
+        for (float[] in : nearestNeighbors)
+        {
+            for (float[] out : result)
+            {
+                if (Arrays.compare(in, out) == 0)
+                {
+                    matches++;
+                    break;
+                }
+            }
+        }
+
+        return matches * 1.0 / result.size();
+    }
+
+    private List<float[]> generateVectors(int vectorCount)
+    {
+        return IntStream.range(0, vectorCount).mapToObj(s -> randomVector()).collect(Collectors.toList());
+    }
+
+    private List<float[]> getVectors(Object[][] result)
+    {
+        List<float[]> vectors = new ArrayList<>();
+
+        // convert each returned row's List<Float> into a float[]
+        for (Object[] obj: result)
+        {
+            List<Float> list = (List<Float>) obj[0];
+            float[] array = new float[list.size()];
+            for (int index = 0; index < list.size(); index++)
+                array[index] = list.get(index);
+            vectors.add(array);
+        }
+
+        return vectors;
+    }
+
+    private String vectorString(float[] vector)
+    {
+        return Arrays.toString(vector);
+    }
+
+    private String randomVectorString()
+    {
+        return vectorString(randomVector());
+    }
+
+    private float[] randomVector()
+    {
+        float[] rawVector = new float[dimensionCount];
+        for (int i = 0; i < dimensionCount; i++)
+        {
+            rawVector[i] = SAITester.getRandom().nextFloat();
+        }
+        return rawVector;
+    }
+
+    private static Object[][] execute(String query)
+    {
+        return execute(query, ConsistencyLevel.QUORUM);
+    }
+
+    private static Object[][] executeAll(String query)
+    {
+        return execute(query, ConsistencyLevel.ALL);
+    }

Review Comment:
   This method seems to be unused; can it be removed?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: pr-unsubscribe@cassandra.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

