You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kvrocks.apache.org by hu...@apache.org on 2023/06/25 11:54:08 UTC

[kvrocks] branch unstable updated: Add Redis-compatible cursors for `SCAN` commands (#1489)

This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 431277c7 Add Redis-compatible cursors for `SCAN` commands (#1489)
431277c7 is described below

commit 431277c786c32c049e464914b7e6a98a76d48867
Author: 纪华裕 <80...@qq.com>
AuthorDate: Sun Jun 25 19:54:03 2023 +0800

    Add Redis-compatible cursors for `SCAN` commands (#1489)
    
    We have added the following steps to the processing of the `SCAN`, `ZSCAN`, `SSCAN`, `HSCAN` commands:
    - Before processing the command, convert the `numeric cursor` sent by the client back to the `keyname cursor`.
    - After processing the command, convert the `keyname cursor` back to a `numeric cursor` and return it to the client, and store the conversion dictionary in `Server#cursor_dict_`.
    
    Since those steps are non-intrusive to the internal implementation of the `SCAN` command, we can ensure that the behavior of commands such as `ZSCAN`, `SSCAN`, and `HSCAN` is consistent with that of the `SCAN` command.
    In the following, I will only use the `SCAN` command as an example to describe the new design.
    
    ### Cursor design
    
    We call the new cursor the `numeric cursor` and the old one the `keyname cursor`.
    The numeric cursor is composed of 3 parts:
    - 1-16 bit is counter, and the first bit is always 1
    - 17-32 bit is timestamp
    - 33-64 bit is hash of keyname
    
    The `counter` is a 16-bit unsigned integer that is incremented by 1 every time. When the `counter` overflows, it returns to 0 because it is an unsigned number. Since our `cursor_dict_size` is a power of 2, the `counter` is continuous mod `cursor_dict_size`.
    `timestamp` is a 16-bit timestamp in seconds, which wraps around after about 18 hours (2^16 seconds).
    `hash` is a 32-bit hash value of the `keyname`.
    
    ### Cursor dictionary(`Server#cursor_dict_`)
    
    The cursor dictionary is an array with a length of 16384(1024*16), which is determined at compile time, and occupies about 640KB of memory. Including the length of the referenced keyname strings, its size is about 1-2M.
    
    When converting the `keyname cursor` back to a `numeric cursor`, a new cursor is generated based on the above rules, and the index for storing it in the dictionary is determined based on the counter value (index = counter % DICT_SIZE).
    
    When converting the `numeric cursor` back to a `keyname cursor`, we get the counter from the cursor and calculate the index of the cursor in the `cursor_dict_` based on the counter. We only need to compare the cursor value of the item at that index with the input cursor value to determine if they are the same.
    
    ### Other information about the cursor
    This design guarantees the validity of the latest 16384(1024*16) cursors, while cursors that are older or not generated in our system are considered invalid cursors. For invalid cursors, we treat them as a 0 cursor, which means we will start iterating over the collection from the beginning.
    
    Our cursor is globally visible, and we store index information in the cursor. As long as the cursor remains valid, using the same cursor in different connections will produce the same results.
    
    We prevent other users from guessing the data traversed by adjacent cursors by adding the hash value of the `keyname` to the cursor. If a user tries to obtain adjacent cursor information by traversing the hash, the cursor will become an invalid cursor before the traversal is complete because the size of the 32-bit space is much larger than the length of the `cursor_dict_`.
    
    We add a timestamp to the cursor to ensure that the same cursor does not appear within a short period of time before and after a restart.
    
    Other behaviors are consistent with the original SCAN implementation.
    
    ### Configuration file
    Added `redis-cursor-compatible` configuration item.
    If enabled, the cursor will be an unsigned 64-bit integer.
    If disabled, the cursor will be a string.
    
    ### Test file
    Added tests for the `redis-cli --bigkeys` and `redis-cli --memkeys` commands. We only need to ensure that these commands run correctly, because their correctness is guaranteed by `redis-cli` on the premise that we ensure the correctness of the `scan` command.
    Modified the scan command test to test for the cases where `redis-cursor-compatible` is set to `yes` or `no`.
    
    
    ### Other changes:
    Fixed a bug where the returned cursor was not 0 when a `SCAN` command returned fewer elements than the requested count.
---
 kvrocks.conf                             |  6 ++++
 src/commands/cmd_hash.cc                 |  5 +--
 src/commands/cmd_server.cc               | 15 +++++----
 src/commands/cmd_set.cc                  |  5 +--
 src/commands/cmd_zset.cc                 |  5 +--
 src/commands/scan_base.h                 | 12 ++++---
 src/config/config.cc                     |  1 +
 src/config/config.h                      |  1 +
 src/server/server.cc                     | 55 ++++++++++++++++++++++++++++++++
 src/server/server.h                      | 45 ++++++++++++++++++++++++++
 src/storage/redis_db.cc                  |  3 +-
 tests/gocase/integration/cli/cli_test.go | 17 ++++++++++
 tests/gocase/unit/scan/scan_test.go      | 15 ++++++++-
 13 files changed, 165 insertions(+), 20 deletions(-)

diff --git a/kvrocks.conf b/kvrocks.conf
index 71bbe2ff..106f0d85 100644
--- a/kvrocks.conf
+++ b/kvrocks.conf
@@ -296,6 +296,12 @@ max-backup-keep-hours 24
 # Default: 16
 max-bitmap-to-string-mb 16
 
+# Whether to enable SCAN-like cursor compatible with Redis. 
+# If enabled, the cursor will be unsigned 64-bit integers.
+# If disabled, the cursor will be a string.
+# Default: no
+redis-cursor-compatible no
+
 ################################## TLS ###################################
 
 # By default, TLS/SSL is disabled, i.e. `tls-port` is set to 0.
diff --git a/src/commands/cmd_hash.cc b/src/commands/cmd_hash.cc
index b4ce5707..c6e2ddf6 100644
--- a/src/commands/cmd_hash.cc
+++ b/src/commands/cmd_hash.cc
@@ -366,12 +366,13 @@ class CommandHScan : public CommandSubkeyScanBase {
     redis::Hash hash_db(svr->storage, conn->GetNamespace());
     std::vector<std::string> fields;
     std::vector<std::string> values;
-    auto s = hash_db.Scan(key_, cursor_, limit_, prefix_, &fields, &values);
+    auto key_name = svr->GetKeyNameFromCursor(cursor_, CursorType::kTypeHash);
+    auto s = hash_db.Scan(key_, key_name, limit_, prefix_, &fields, &values);
     if (!s.ok() && !s.IsNotFound()) {
       return {Status::RedisExecErr, s.ToString()};
     }
 
-    *output = GenerateOutput(fields, values);
+    *output = GenerateOutput(svr, fields, values, CursorType::kTypeHash);
     return Status::OK();
   }
 };
diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc
index cc0b2f22..18c56f86 100644
--- a/src/commands/cmd_server.cc
+++ b/src/commands/cmd_server.cc
@@ -760,11 +760,11 @@ class CommandScan : public CommandScanBase {
     return Commander::Parse(args);
   }
 
-  static std::string GenerateOutput(const std::vector<std::string> &keys, std::string end_cursor) {
+  static std::string GenerateOutput(Server *svr, const std::vector<std::string> &keys, const std::string &end_cursor) {
     std::vector<std::string> list;
     if (!end_cursor.empty()) {
-      end_cursor = kCursorPrefix + end_cursor;
-      list.emplace_back(redis::BulkString(end_cursor));
+      list.emplace_back(
+          redis::BulkString(svr->GenerateCursorFromKeyName(end_cursor, CursorType::kTypeBase, kCursorPrefix)));
     } else {
       list.emplace_back(redis::BulkString("0"));
     }
@@ -776,14 +776,15 @@ class CommandScan : public CommandScanBase {
 
   Status Execute(Server *svr, Connection *conn, std::string *output) override {
     redis::Database redis_db(svr->storage, conn->GetNamespace());
+    auto key_name = svr->GetKeyNameFromCursor(cursor_, CursorType::kTypeBase);
+
     std::vector<std::string> keys;
-    std::string end_cursor;
-    auto s = redis_db.Scan(cursor_, limit_, prefix_, &keys, &end_cursor);
+    std::string end_key;
+    auto s = redis_db.Scan(key_name, limit_, prefix_, &keys, &end_key);
     if (!s.ok()) {
       return {Status::RedisExecErr, s.ToString()};
     }
-
-    *output = GenerateOutput(keys, end_cursor);
+    *output = GenerateOutput(svr, keys, end_key);
     return Status::OK();
   }
 };
diff --git a/src/commands/cmd_set.cc b/src/commands/cmd_set.cc
index c066e372..586cad4c 100644
--- a/src/commands/cmd_set.cc
+++ b/src/commands/cmd_set.cc
@@ -426,12 +426,13 @@ class CommandSScan : public CommandSubkeyScanBase {
   Status Execute(Server *svr, Connection *conn, std::string *output) override {
     redis::Set set_db(svr->storage, conn->GetNamespace());
     std::vector<std::string> members;
-    auto s = set_db.Scan(key_, cursor_, limit_, prefix_, &members);
+    auto key_name = svr->GetKeyNameFromCursor(cursor_, CursorType::kTypeSet);
+    auto s = set_db.Scan(key_, key_name, limit_, prefix_, &members);
     if (!s.ok() && !s.IsNotFound()) {
       return {Status::RedisExecErr, s.ToString()};
     }
 
-    *output = CommandScanBase::GenerateOutput(members);
+    *output = CommandScanBase::GenerateOutput(svr, members, CursorType::kTypeSet);
     return Status::OK();
   }
 };
diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc
index 995f9b3f..5ad2cad7 100644
--- a/src/commands/cmd_zset.cc
+++ b/src/commands/cmd_zset.cc
@@ -1319,7 +1319,8 @@ class CommandZScan : public CommandSubkeyScanBase {
     redis::ZSet zset_db(svr->storage, conn->GetNamespace());
     std::vector<std::string> members;
     std::vector<double> scores;
-    auto s = zset_db.Scan(key_, cursor_, limit_, prefix_, &members, &scores);
+    auto key_name = svr->GetKeyNameFromCursor(cursor_, CursorType::kTypeZSet);
+    auto s = zset_db.Scan(key_, key_name, limit_, prefix_, &members, &scores);
     if (!s.ok() && !s.IsNotFound()) {
       return {Status::RedisExecErr, s.ToString()};
     }
@@ -1329,7 +1330,7 @@ class CommandZScan : public CommandSubkeyScanBase {
     for (const auto &score : scores) {
       score_strings.emplace_back(util::Float2String(score));
     }
-    *output = GenerateOutput(members, score_strings);
+    *output = GenerateOutput(svr, members, score_strings, CursorType::kTypeZSet);
     return Status::OK();
   }
 };
diff --git a/src/commands/scan_base.h b/src/commands/scan_base.h
index 85077845..069a54e9 100644
--- a/src/commands/scan_base.h
+++ b/src/commands/scan_base.h
@@ -23,6 +23,7 @@
 #include "commander.h"
 #include "error_constants.h"
 #include "parse_util.h"
+#include "server/server.h"
 
 namespace redis {
 
@@ -63,10 +64,11 @@ class CommandScanBase : public Commander {
     }
   }
 
-  std::string GenerateOutput(const std::vector<std::string> &keys) const {
+  std::string GenerateOutput(Server *svr, const std::vector<std::string> &keys, CursorType cursor_type) const {
     std::vector<std::string> list;
     if (keys.size() == static_cast<size_t>(limit_)) {
-      list.emplace_back(redis::BulkString(keys.back()));
+      auto end_cursor = svr->GenerateCursorFromKeyName(keys.back(), cursor_type);
+      list.emplace_back(redis::BulkString(end_cursor));
     } else {
       list.emplace_back(redis::BulkString("0"));
     }
@@ -109,11 +111,13 @@ class CommandSubkeyScanBase : public CommandScanBase {
     return Commander::Parse(args);
   }
 
-  std::string GenerateOutput(const std::vector<std::string> &fields, const std::vector<std::string> &values) {
+  std::string GenerateOutput(Server *svr, const std::vector<std::string> &fields,
+                             const std::vector<std::string> &values, CursorType cursor_type) {
     std::vector<std::string> list;
     auto items_count = fields.size();
     if (items_count == static_cast<size_t>(limit_)) {
-      list.emplace_back(redis::BulkString(fields.back()));
+      auto end_cursor = svr->GenerateCursorFromKeyName(fields.back(), cursor_type);
+      list.emplace_back(redis::BulkString(end_cursor));
     } else {
       list.emplace_back(redis::BulkString("0"));
     }
diff --git a/src/config/config.cc b/src/config/config.cc
index ad94a0cc..64cce3c4 100644
--- a/src/config/config.cc
+++ b/src/config/config.cc
@@ -165,6 +165,7 @@ Config::Config() {
       {"unixsocketperm", true, new OctalField(&unixsocketperm, 0777, 1, INT_MAX)},
       {"log-retention-days", false, new IntField(&log_retention_days, -1, -1, INT_MAX)},
       {"persist-cluster-nodes-enabled", false, new YesNoField(&persist_cluster_nodes_enabled, true)},
+      {"redis-cursor-compatible", false, new YesNoField(&redis_cursor_compatible, false)},
 
       /* rocksdb options */
       {"rocksdb.compression", false,
diff --git a/src/config/config.h b/src/config/config.h
index 48b09be2..3a06300b 100644
--- a/src/config/config.h
+++ b/src/config/config.h
@@ -144,6 +144,7 @@ struct Config {
   int pipeline_size;
   int sequence_gap;
 
+  bool redis_cursor_compatible = false;
   int log_retention_days;
   // profiling
   int profiling_sample_ratio = 0;
diff --git a/src/server/server.cc b/src/server/server.cc
index 56b809c4..a5604496 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -28,6 +28,8 @@
 #include <sys/utsname.h>
 
 #include <atomic>
+#include <cstdint>
+#include <functional>
 #include <iomanip>
 #include <jsoncons/json.hpp>
 #include <memory>
@@ -60,6 +62,9 @@ Server::Server(engine::Storage *storage, Config *config)
     stats.commands_stats[iter.first].latency = 0;
   }
 
+  // init cursor_dict_
+  cursor_dict_ = std::make_unique<CursorDictType>();
+
 #ifdef ENABLE_OPENSSL
   // init ssl context
   if (config->tls_port) {
@@ -1753,3 +1758,53 @@ std::list<std::pair<std::string, uint32_t>> Server::GetSlaveHostAndPort() {
   slave_threads_mu_.unlock();
   return result;
 }
+
+// The numeric cursor consists of a 16-bit counter, a 16-bit time stamp, a 29-bit hash,and a 3-bit cursor type. The
+// hash is used to prevent information leakage. The time_stamp is used to prevent the generation of the same cursor in
+// the extremely short period before and after a restart.
+NumberCursor::NumberCursor(CursorType cursor_type, uint16_t counter, const std::string &key_name) {
+  auto hash = static_cast<uint32_t>(std::hash<std::string>{}(key_name));
+  auto time_stamp = static_cast<uint16_t>(util::GetTimeStamp());
+  // make hash top 3-bit zero
+  constexpr uint64_t hash_mask = 0x1FFFFFFFFFFFFFFF;
+  cursor_ = static_cast<uint64_t>(counter) | static_cast<uint64_t>(time_stamp) << 16 |
+            (static_cast<uint64_t>(hash) << 32 & hash_mask) | static_cast<uint64_t>(cursor_type) << 61;
+}
+
+bool NumberCursor::IsMatch(const CursorDictElement &element, CursorType cursor_type) const {
+  return cursor_ == element.cursor.cursor_ && cursor_type == getCursorType();
+}
+
+std::string Server::GenerateCursorFromKeyName(const std::string &key_name, CursorType cursor_type, const char *prefix) {
+  if (!config_->redis_cursor_compatible) {
+    // add prefix for SCAN
+    return prefix + key_name;
+  }
+  auto counter = cursor_counter_.fetch_add(1);
+  auto number_cursor = NumberCursor(cursor_type, counter, key_name);
+  cursor_dict_->at(number_cursor.GetIndex()) = {number_cursor, key_name};
+  return number_cursor.ToString();
+}
+
+std::string Server::GetKeyNameFromCursor(const std::string &cursor, CursorType cursor_type) {
+  // When cursor is 0, cursor string is empty
+  if (cursor.empty() || !config_->redis_cursor_compatible) {
+    return cursor;
+  }
+
+  auto s = ParseInt<uint64_t>(cursor, 10);
+  // When Cursor 0 or not a Integer return empty string.
+  // Although the parameter 'cursor' is not expected to be 0, we still added a check for 0 to increase the robustness of
+  // the code.
+  if (!s.IsOK() || *s == 0) {
+    return {};
+  }
+  auto number_cursor = NumberCursor(*s);
+  // Because the index information is fully stored in the cursor, we can directly obtain the index from the cursor.
+  auto item = cursor_dict_->at(number_cursor.GetIndex());
+  if (number_cursor.IsMatch(item, cursor_type)) {
+    return item.key_name;
+  }
+
+  return {};
+}
diff --git a/src/server/server.h b/src/server/server.h
index 408537ca..7ac6128a 100644
--- a/src/server/server.h
+++ b/src/server/server.h
@@ -22,6 +22,10 @@
 
 #include <inttypes.h>
 
+#include <array>
+#include <atomic>
+#include <cstddef>
+#include <cstdint>
 #include <list>
 #include <map>
 #include <memory>
@@ -84,6 +88,39 @@ struct ChannelSubscribeNum {
   size_t subscribe_num;
 };
 
+// CURSOR_DICT_SIZE must be 2^n where n <= 16
+constexpr const size_t CURSOR_DICT_SIZE = 1024 * 16;
+static_assert((CURSOR_DICT_SIZE & (CURSOR_DICT_SIZE - 1)) == 0, "CURSOR_DICT_SIZE must be 2^n");
+static_assert(CURSOR_DICT_SIZE <= (1 << 16), "CURSOR_DICT_SIZE must be less than or equal to 2^16");
+
+enum class CursorType : uint8_t {
+  kTypeNone = 0,  // none
+  kTypeBase = 1,  // cursor for SCAN
+  kTypeHash = 2,  // cursor for HSCAN
+  kTypeSet = 3,   // cursor for SSCAN
+  kTypeZSet = 4,  // cursor for ZSCAN
+};
+struct CursorDictElement;
+
+class NumberCursor {
+ public:
+  NumberCursor() = default;
+  explicit NumberCursor(CursorType cursor_type, uint16_t counter, const std::string &key_name);
+  explicit NumberCursor(uint64_t number_cursor) : cursor_(number_cursor) {}
+  size_t GetIndex() const { return cursor_ % CURSOR_DICT_SIZE; }
+  bool IsMatch(const CursorDictElement &element, CursorType cursor_type) const;
+  std::string ToString() const { return std::to_string(cursor_); }
+
+ private:
+  CursorType getCursorType() const { return static_cast<CursorType>(cursor_ >> 61); }
+  uint64_t cursor_;
+};
+
+struct CursorDictElement {
+  NumberCursor cursor;
+  std::string key_name;
+};
+
 enum SlowLog {
   kSlowLogMaxArgc = 32,
   kSlowLogMaxString = 128,
@@ -196,6 +233,9 @@ class Server {
   void GetLatestKeyNumStats(const std::string &ns, KeyNumStats *stats);
   time_t GetLastScanTime(const std::string &ns);
 
+  std::string GenerateCursorFromKeyName(const std::string &key_name, CursorType cursor_type, const char *prefix = "");
+  std::string GetKeyNameFromCursor(const std::string &cursor, CursorType cursor_type);
+
   int DecrClientNum();
   int IncrClientNum();
   int IncrMonitorClientNum();
@@ -319,4 +359,9 @@ class Server {
   std::atomic<size_t> watched_key_size_ = 0;
   std::map<std::string, std::set<redis::Connection *>> watched_key_map_;
   std::shared_mutex watched_key_mutex_;
+
+  // SCAN ring buffer
+  std::atomic<uint16_t> cursor_counter_ = {0};
+  using CursorDictType = std::array<CursorDictElement, CURSOR_DICT_SIZE>;
+  std::unique_ptr<CursorDictType> cursor_dict_;
 };
diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc
index 23da37df..1d2bb229 100644
--- a/src/storage/redis_db.cc
+++ b/src/storage/redis_db.cc
@@ -277,9 +277,8 @@ rocksdb::Status Database::Scan(const std::string &cursor, uint64_t limit, const
       keys->emplace_back(user_key);
       cnt++;
     }
-
     if (!storage_->IsSlotIdEncoded() || prefix.empty()) {
-      if (!keys->empty()) {
+      if (!keys->empty() && cnt >= limit) {
         end_cursor->append(user_key);
       }
       break;
diff --git a/tests/gocase/integration/cli/cli_test.go b/tests/gocase/integration/cli/cli_test.go
index 4edbde28..675b06a9 100644
--- a/tests/gocase/integration/cli/cli_test.go
+++ b/tests/gocase/integration/cli/cli_test.go
@@ -371,4 +371,21 @@ func TestRedisCli(t *testing.T) {
 		require.Equal(t, "1000", rdb.Get(ctx, "test-counter").Val())
 		require.Regexp(t, "(?s).*All data transferred.*errors: 0.*replies: 2102.*", r)
 	})
+
+	// We need flush all, put this test at the end
+	t.Run("Test redis-cursor-compatible mode", func(t *testing.T) {
+		rdb.FlushAll(ctx)
+		util.Populate(t, rdb, "", 1000, 10)
+		require.NoError(t, rdb.ConfigSet(ctx, "redis-cursor-compatible", "yes").Err())
+
+		t.Run("Use reids-cli --bigkeys", func(t *testing.T) {
+			r := runCli(t, srv, nil, "--bigkeys").Success()
+			require.Contains(t, r, "Sampled 1000 keys in the keyspace")
+		})
+
+		t.Run("Use reids-cli --memkeys", func(t *testing.T) {
+			r := runCli(t, srv, nil, "--memkeys").Success()
+			require.Contains(t, r, "Sampled 1000 keys in the keyspace")
+		})
+	})
 }
diff --git a/tests/gocase/unit/scan/scan_test.go b/tests/gocase/unit/scan/scan_test.go
index fdf18316..ad00d19c 100644
--- a/tests/gocase/unit/scan/scan_test.go
+++ b/tests/gocase/unit/scan/scan_test.go
@@ -52,13 +52,26 @@ func TestScanEmptyKey(t *testing.T) {
 	require.Equal(t, []string{"", "fab", "fiz", "foobar"}, keys)
 }
 
-func TestScan(t *testing.T) {
+func TestScanWithNumberCursor(t *testing.T) {
 	srv := util.StartServer(t, map[string]string{})
 	defer srv.Close()
+	ctx := context.Background()
+	rdb := srv.NewClient()
+	defer func() { require.NoError(t, rdb.Close()) }()
+	require.NoError(t, rdb.ConfigSet(ctx, "redis-cursor-compatible", "yes").Err())
+	ScanTest(t, rdb, ctx)
+}
 
+func TestScanWithStringCursor(t *testing.T) {
+	srv := util.StartServer(t, map[string]string{})
+	defer srv.Close()
 	ctx := context.Background()
 	rdb := srv.NewClient()
 	defer func() { require.NoError(t, rdb.Close()) }()
+	ScanTest(t, rdb, ctx)
+}
+
+func ScanTest(t *testing.T, rdb *redis.Client, ctx context.Context) {
 
 	t.Run("SCAN Basic", func(t *testing.T) {
 		require.NoError(t, rdb.FlushDB(ctx).Err())