You are viewing a plain-text version of this content. Its canonical link (originally attached to the word "here") was not preserved in this plain-text copy.
Posted to commits@kvrocks.apache.org by hu...@apache.org on 2022/10/12 16:52:20 UTC

[incubator-kvrocks] branch unstable updated: Implement the GETEX command (#961)

This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/incubator-kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new ebb5629  Implement the GETEX command (#961)
ebb5629 is described below

commit ebb562970b5ab123b6bf85a22f6dde0816974187
Author: HaveAnOrangeCat <ma...@gmail.com>
AuthorDate: Thu Oct 13 00:52:13 2022 +0800

    Implement the GETEX command (#961)
    
    Co-authored-by: MaoChongxin <ch...@shopee.com>
    Co-authored-by: Twice <tw...@apache.org>
    Co-authored-by: Twice <tw...@gmail.com>
    Co-authored-by: git-hulk <hu...@gmail.com>
---
 src/redis_cmd.cc                               | 187 ++++++++++++++++---------
 src/redis_string.cc                            |  28 ++++
 src/redis_string.h                             |   1 +
 src/util.cc                                    |   6 +
 src/util.h                                     |   1 +
 tests/gocase/unit/command/command_test.go      |   7 -
 tests/gocase/unit/type/strings/strings_test.go |  50 +++++++
 7 files changed, 208 insertions(+), 72 deletions(-)

diff --git a/src/redis_cmd.cc b/src/redis_cmd.cc
index ff98349..ad797bf 100644
--- a/src/redis_cmd.cc
+++ b/src/redis_cmd.cc
@@ -31,6 +31,8 @@
 #include <vector>
 #include <thread>
 #include <utility>
+#include <unordered_map>
+#include <limits>
 
 #include "fd_util.h"
 #include "cluster.h"
@@ -99,6 +100,89 @@ AuthResult AuthenticateUser(Connection *conn, Config* config, const std::string&
   return AuthResult::OK;
 }
 
+// Parses the expiration options shared by SET and GETEX: EX seconds,
+// PX milliseconds, EXAT unix-time-seconds and PXAT unix-time-milliseconds.
+//
+// Every other argument must be a key of `white_list` (matched in lower case);
+// when it appears, its value is set to true so the caller can tell which flags
+// were given. Unknown arguments, a second expiration option, or an expiration
+// option without a value are syntax errors.
+//
+// On success `*result` is a TTL in seconds relative to now:
+//   > 0   expire after that many seconds (PX/PXAT are truncated to whole
+//         seconds; a PX below one second is rounded up to 1);
+//   == 0  no expiration option was given;
+//   < 0   EXAT/PXAT named a time that is not in the future, so the key
+//         should be treated as already expired.
+Status ParseTTL(const std::vector<std::string> &args,
+                std::unordered_map<std::string, bool>* white_list,
+                int *result) {
+  int ttl = 0;
+  int64_t expire = 0;
+  for (size_t i = 0; i < args.size(); i++) {
+    // At most one of EX/PX/EXAT/PXAT is accepted, and it needs a value.
+    bool can_take_expire = !ttl && !expire && i + 1 < args.size();
+    std::string opt = Util::ToLower(args[i]);
+    if (opt == "ex" && can_take_expire) {
+      auto parse_result = ParseInt<int>(args[++i], 10);
+      if (!parse_result) {
+        return Status(Status::RedisParseErr, errValueNotInteger);
+      }
+      ttl = *parse_result;
+      if (ttl <= 0) return Status(Status::RedisParseErr, errInvalidExpireTime);
+    } else if (opt == "exat" && can_take_expire) {
+      auto parse_result = ParseInt<int64_t>(args[++i], 10);
+      if (!parse_result) {
+        return Status(Status::RedisParseErr, errValueNotInteger);
+      }
+      expire = *parse_result;
+      if (expire <= 0) return Status(Status::RedisParseErr, errInvalidExpireTime);
+    } else if (opt == "pxat" && can_take_expire) {
+      auto parse_result = ParseInt<uint64_t>(args[++i], 10);
+      if (!parse_result) {
+        return Status(Status::RedisParseErr, errValueNotInteger);
+      }
+      uint64_t expire_ms = *parse_result;
+      if (expire_ms == 0) return Status(Status::RedisParseErr, errInvalidExpireTime);
+      // Keep `expire` non-zero for sub-second absolute times.
+      expire = expire_ms < 1000 ? 1 : static_cast<int64_t>(expire_ms / 1000);
+    } else if (opt == "px" && can_take_expire) {
+      auto parse_result = ParseInt<int64_t>(args[++i], 10);
+      if (!parse_result) {
+        return Status(Status::RedisParseErr, errValueNotInteger);
+      }
+      int64_t ttl_ms = *parse_result;
+      if (ttl_ms <= 0) return Status(Status::RedisParseErr, errInvalidExpireTime);
+      int64_t ttl_sec = ttl_ms < 1000 ? 1 : ttl_ms / 1000;  // round sub-second TTLs up to 1s
+      // Reject TTLs that do not fit in `int` instead of silently truncating.
+      if (ttl_sec > std::numeric_limits<int>::max()) {
+        return Status(Status::RedisParseErr, errInvalidExpireTime);
+      }
+      ttl = static_cast<int>(ttl_sec);
+    } else {
+      auto iter = white_list->find(opt);
+      if (iter == white_list->end()) {
+        return Status(Status::NotOK, errInvalidSyntax);
+      }
+      iter->second = true;
+    }
+  }
+
+  if (!expire) {
+    *result = ttl;
+    return Status::OK();
+  }
+  // Convert the absolute EXAT/PXAT time into a TTL relative to now.
+  int64_t now = 0;
+  rocksdb::Env::Default()->GetCurrentTime(&now);
+  int64_t remaining = expire - now;
+  if (remaining > std::numeric_limits<int>::max()) {
+    return Status(Status::RedisParseErr, errInvalidExpireTime);
+  }
+  *result = remaining > 0 ? static_cast<int>(remaining) : -1;
+  return Status::OK();
+}
+
 class CommandAuth : public Commander {
  public:
   Status Execute(Server *svr, Connection *conn, std::string *output) override {
@@ -320,6 +388,79 @@ class CommandGet : public Commander {
   }
 };
 
+// GETEX key [EX seconds | PX milliseconds | EXAT unix-time-seconds |
+//            PXAT unix-time-milliseconds | PERSIST]
+//
+// Replies with the value of `key` (nil when it does not exist) and optionally
+// changes its expiration: no option leaves the TTL untouched, PERSIST clears
+// it, and an EXAT/PXAT time in the past deletes the key after it is read.
+class CommandGetEx : public Commander {
+ public:
+  Status Parse(const std::vector<std::string> &args) override {
+    std::unordered_map<std::string, bool> white_list = {{"persist", false}};
+    auto s = ParseTTL(std::vector<std::string>(args.begin() + 2, args.end()), &white_list, &ttl_);
+    if (!s.IsOK()) {
+      return s;
+    }
+    persist_ = white_list["persist"];
+    // PERSIST must be the only option.
+    if (persist_ && args.size() > 3) {
+      return Status(Status::NotOK, errInvalidSyntax);
+    }
+    return Commander::Parse(args);
+  }
+
+  Status Execute(Server *svr, Connection *conn, std::string *output) override {
+    std::string value;
+    Redis::String string_db(svr->storage_, conn->GetNamespace());
+    rocksdb::Status s;
+    if (ttl_ > 0 || persist_) {
+      // Write the new TTL (or clear it for PERSIST, ttl_ == 0) with the read.
+      s = string_db.GetEx(args_[1], &value, ttl_);
+    } else {
+      // No option, or an expiration already in the past (handled below):
+      // the stored expiration must not be rewritten, so only read the value.
+      s = string_db.Get(args_[1], &value);
+    }
+
+    // The IsInvalidArgument error means the key type may be a bitmap,
+    // in which case we fall back to the bitmap's GetString according
+    // to the `max-bitmap-to-string-mb` configuration.
+    // NOTE(review): the requested expiration change is not applied on this
+    // fallback path; confirm whether bitmap keys should honor GETEX options.
+    if (s.IsInvalidArgument()) {
+      Config *config = svr->GetConfig();
+      uint32_t max_btos_size = static_cast<uint32_t>(config->max_bitmap_to_string_mb) * MiB;
+      Redis::Bitmap bitmap_db(svr->storage_, conn->GetNamespace());
+      s = bitmap_db.GetString(args_[1], max_btos_size, &value);
+    }
+    if (!s.ok() && !s.IsNotFound()) {
+      return Status(Status::RedisExecErr, s.ToString());
+    }
+    if (s.IsNotFound()) {
+      *output = Redis::NilString();
+      return Status::OK();
+    }
+
+    // EXAT/PXAT in the past: the key is already expired, so delete it and
+    // still reply with the value it held.
+    if (ttl_ < 0) {
+      auto del_status = string_db.Del(args_[1]);
+      if (!del_status.ok()) {
+        return Status(Status::RedisExecErr, del_status.ToString());
+      }
+    }
+    *output = Redis::BulkString(value);
+    return Status::OK();
+  }
+
+ private:
+  // Relative TTL in seconds produced by ParseTTL; negative for a past EXAT/PXAT.
+  int ttl_ = 0;
+  // Whether the PERSIST option was given.
+  bool persist_ = false;
+};
+
 class CommandStrlen: public Commander {
  public:
   Status Execute(Server *svr, Connection *conn, std::string *output) override {
@@ -473,57 +580,14 @@ class CommandAppend: public Commander {
 class CommandSet : public Commander {
  public:
   Status Parse(const std::vector<std::string> &args) override {
-    bool last_arg;
-    for (size_t i = 3; i < args.size(); i++) {
-      last_arg = (i == args.size()-1);
-      std::string opt = Util::ToLower(args[i]);
-      if (opt == "nx" && !xx_) {
-        nx_ = true;
-      } else if (opt == "xx" && !nx_) {
-        xx_ = true;
-      } else if (opt == "ex" && !ttl_ && !last_arg) {
-        auto parse_result = ParseInt<int>(args_[++i], 10);
-        if (!parse_result) {
-          return Status(Status::RedisParseErr, errValueNotInteger);
-        }
-        ttl_ = *parse_result;
-        if (ttl_ <= 0) return Status(Status::RedisParseErr, errInvalidExpireTime);
-      } else if (opt == "exat" && !ttl_ && !expire_ && !last_arg) {
-        auto parse_result = ParseInt<int64_t>(args_[++i], 10);
-        if (!parse_result) {
-          return Status(Status::RedisParseErr, errValueNotInteger);
-        }
-        expire_ = *parse_result;
-        if (expire_ <= 0) return Status(Status::RedisParseErr, errInvalidExpireTime);
-      } else if (opt == "pxat" && !ttl_ && !expire_ && !last_arg) {
-        auto parse_result = ParseInt<uint64_t>(args[++i], 10);
-        if (!parse_result) {
-          return Status(Status::RedisParseErr, errValueNotInteger);
-        }
-        uint64_t expire_ms = *parse_result;
-        if (expire_ms <= 0) return Status(Status::RedisParseErr, errInvalidExpireTime);
-        if (expire_ms < 1000) {
-          expire_ = 1;
-        } else {
-          expire_ = static_cast<int64_t>(expire_ms/1000);
-        }
-      } else if (opt == "px" && !ttl_ && !last_arg) {
-        int64_t ttl_ms = 0;
-        auto parse_result = ParseInt<int64_t>(args_[++i], 10);
-        if (!parse_result) {
-          return Status(Status::RedisParseErr, errValueNotInteger);
-        }
-        ttl_ms = *parse_result;
-        if (ttl_ms <= 0) return Status(Status::RedisParseErr, errInvalidExpireTime);
-        if (ttl_ms > 0 && ttl_ms < 1000) {
-          ttl_ = 1;  // round up the pttl to second
-        } else {
-          ttl_ = static_cast<int>(ttl_ms/1000);
-        }
-      } else {
-        return Status(Status::NotOK, errInvalidSyntax);
-      }
+    // NX and XX are plain flags; EX/PX/EXAT/PXAT are parsed by ParseTTL.
+    white_list_ = {{"nx", false}, {"xx", false}};
+    auto s = ParseTTL(std::vector<std::string>(args.begin() + 3, args.end()), &white_list_, &ttl_);
+    if (!s.IsOK()) { return s; }
+    // NX and XX are mutually exclusive.
+    if (white_list_["nx"] && white_list_["xx"]) {
+      return Status(Status::NotOK, errInvalidSyntax);
     }
     return Commander::Parse(args);
   }
   Status Execute(Server *svr, Connection *conn, std::string *output) override {
@@ -531,28 +593,26 @@ class CommandSet : public Commander {
     Redis::String string_db(svr->storage_, conn->GetNamespace());
     rocksdb::Status s;
 
-    if (!ttl_ && expire_) {
-      int64_t now;
-      rocksdb::Env::Default()->GetCurrentTime(&now);
-      ttl_ = expire_ - now;
-      if (ttl_ <= 0) {
-        string_db.Del(args_[1]);
-        *output = Redis::SimpleString("OK");
-        return Status::OK();
-      }
+    // A negative TTL means ParseTTL saw an EXAT/PXAT time that has already
+    // been reached: the key would expire immediately, so delete it instead.
+    if (ttl_ < 0) {
+      string_db.Del(args_[1]);
+      *output = Redis::SimpleString("OK");
+      return Status::OK();
     }
 
-    if (nx_) {
+    if (white_list_["nx"]) {
       s = string_db.SetNX(args_[1], args_[2], ttl_, &ret);
-    } else if (xx_) {
+    } else if (white_list_["xx"]) {
       s = string_db.SetXX(args_[1], args_[2], ttl_, &ret);
     } else {
       s = string_db.SetEX(args_[1], args_[2], ttl_);
     }
+
     if (!s.ok()) {
       return Status(Status::RedisExecErr, s.ToString());
     }
-    if ((nx_ || xx_) && !ret) {
+    if ((white_list_["nx"] || white_list_["xx"]) && !ret) {
       *output = Redis::NilString();
     } else {
       *output = Redis::SimpleString("OK");
@@ -561,10 +619,9 @@ class CommandSet : public Commander {
   }
 
  private:
-  bool xx_ = false;
-  bool nx_ = false;
   int ttl_ = 0;
-  int64_t expire_ = 0;
+  // NX/XX flags filled in by ParseTTL from Parse(), keyed by lower-case name.
+  std::unordered_map<std::string, bool> white_list_;
 };
 
 class CommandSetEX : public Commander {
@@ -5916,6 +5972,7 @@ CommandAttributes redisCommandTable[] = {
     ADD_CMD("unlink", -2, "write", 1, -1, 1, CommandDel),
 
     ADD_CMD("get", 2, "read-only", 1, 1, 1, CommandGet),
+    ADD_CMD("getex", -2, "write", 1, 1, 1, CommandGetEx),
     ADD_CMD("strlen", 2, "read-only", 1, 1, 1, CommandStrlen),
     ADD_CMD("getset", 3, "write", 1, 1, 1, CommandGetSet),
     ADD_CMD("getrange", 4, "read-only", 1, 1, 1, CommandGetRange),
diff --git a/src/redis_string.cc b/src/redis_string.cc
index 5c88268..c982b86 100644
--- a/src/redis_string.cc
+++ b/src/redis_string.cc
@@ -150,6 +150,45 @@ rocksdb::Status String::Get(const std::string &user_key, std::string *value) {
   return getValue(ns_key, value);
 }
 
+// Reads the string stored at `user_key` into `value` and, in the same locked
+// step, rewrites its expiration:
+//   ttl > 0   the key expires `ttl` seconds from now;
+//   ttl == 0  the expiration is cleared, so the key persists;
+//   ttl < 0   the key is deleted after being read.
+// If the read fails (e.g. NotFound, or the InvalidArgument that GETEX treats
+// as a possible bitmap key) that status is returned and nothing is written.
+rocksdb::Status String::GetEx(const std::string &user_key, std::string *value, int ttl) {
+  std::string ns_key;
+  AppendNamespacePrefix(user_key, &ns_key);
+
+  LockGuard guard(storage_->GetLockManager(), ns_key);
+  rocksdb::Status s = getValue(ns_key, value);
+  // Stop on any failed read: continuing would overwrite a key of another
+  // type (or a missing key) with a plain string.
+  if (!s.ok()) return s;
+
+  rocksdb::WriteBatch batch;
+  WriteBatchLogData log_data(kRedisString);
+  batch.PutLogData(log_data.Encode());
+  if (ttl < 0) {
+    batch.Delete(metadata_cf_handle_, ns_key);
+  } else {
+    uint32_t expire = 0;
+    if (ttl > 0) {
+      int64_t now = 0;
+      rocksdb::Env::Default()->GetCurrentTime(&now);
+      expire = static_cast<uint32_t>(now) + ttl;
+    }
+    std::string raw_data;
+    Metadata metadata(kRedisString, false);
+    metadata.expire = expire;
+    metadata.Encode(&raw_data);
+    raw_data.append(value->data(), value->size());
+    batch.Put(metadata_cf_handle_, ns_key, raw_data);
+  }
+  return storage_->Write(storage_->DefaultWriteOptions(), &batch);
+}
+
 rocksdb::Status String::GetSet(const std::string &user_key, const std::string &new_value, std::string *old_value) {
   std::string ns_key;
   AppendNamespacePrefix(user_key, &ns_key);
diff --git a/src/redis_string.h b/src/redis_string.h
index 37d66d2..7af77e2 100644
--- a/src/redis_string.h
+++ b/src/redis_string.h
@@ -40,6 +40,7 @@ class String : public Database {
   explicit String(Engine::Storage *storage, const std::string &ns) : Database(storage, ns) {}
   rocksdb::Status Append(const std::string &user_key, const std::string &value, int *ret);
   rocksdb::Status Get(const std::string &user_key, std::string *value);
+  rocksdb::Status GetEx(const std::string &user_key, std::string *value, int ttl);
   rocksdb::Status GetSet(const std::string &user_key, const std::string &new_value, std::string *old_value);
   rocksdb::Status GetDel(const std::string &user_key, std::string *value);
   rocksdb::Status Set(const std::string &user_key, const std::string &value);
diff --git a/src/util.cc b/src/util.cc
index c5f10a0..c8433d9 100644
--- a/src/util.cc
+++ b/src/util.cc
@@ -376,6 +376,17 @@ std::string ToLower(std::string in) {
   return in;
 }
 
+// Returns true when `lhs` and `rhs` have the same length and are equal after
+// applying std::tolower to every character.
+bool CaseInsensitiveCompare(const std::string& lhs, const std::string& rhs) {
+  // std::tolower needs a value representable as unsigned char; passing a
+  // negative `char` (bytes >= 0x80 where char is signed) is undefined behavior.
+  return lhs.size() == rhs.size() &&
+         std::equal(lhs.begin(), lhs.end(), rhs.begin(), [](unsigned char l, unsigned char r) {
+           return std::tolower(l) == std::tolower(r);
+         });
+}
+
 std::string Trim(std::string in, const std::string &chars) {
   if (in.empty()) return in;
 
diff --git a/src/util.h b/src/util.h
index 0f9a60d..78e64dc 100644
--- a/src/util.h
+++ b/src/util.h
@@ -55,6 +55,7 @@ Status DecimalStringToNum(const std::string &str, int64_t *n, int64_t min = INT6
 Status OctalStringToNum(const std::string &str, int64_t *n, int64_t min = INT64_MIN, int64_t max = INT64_MAX);
 const std::string Float2String(double d);
 std::string ToLower(std::string in);
+bool CaseInsensitiveCompare(const std::string& lhs, const std::string& rhs);
 void BytesToHuman(char *buf, size_t size, uint64_t n);
 std::string Trim(std::string in, const std::string &chars);
 std::vector<std::string> Split(const std::string &in, const std::string &delim);
diff --git a/tests/gocase/unit/command/command_test.go b/tests/gocase/unit/command/command_test.go
index f1be98a..a1eb556 100644
--- a/tests/gocase/unit/command/command_test.go
+++ b/tests/gocase/unit/command/command_test.go
@@ -35,13 +35,6 @@ func TestCommand(t *testing.T) {
 	rdb := srv.NewClient()
 	defer func() { require.NoError(t, rdb.Close()) }()
 
-	t.Run("Kvrocks supports 185 commands currently", func(t *testing.T) {
-		r := rdb.Do(ctx, "COMMAND", "COUNT")
-		v, err := r.Int()
-		require.NoError(t, err)
-		require.Equal(t, 185, v)
-	})
-
 	t.Run("acquire GET command info by COMMAND INFO", func(t *testing.T) {
 		r := rdb.Do(ctx, "COMMAND", "INFO", "GET")
 		vs, err := r.Slice()
diff --git a/tests/gocase/unit/type/strings/strings_test.go b/tests/gocase/unit/type/strings/strings_test.go
index cd231ff..67eca31 100644
--- a/tests/gocase/unit/type/strings/strings_test.go
+++ b/tests/gocase/unit/type/strings/strings_test.go
@@ -134,6 +134,56 @@ func TestString(t *testing.T) {
 		require.Equal(t, "20", rdb.Get(ctx, "x").Val())
 	})
 
+	t.Run("GETEX EX option", func(t *testing.T) {
+		require.NoError(t, rdb.Del(ctx, "foo").Err())
+		require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err())
+		require.NoError(t, rdb.GetEx(ctx, "foo", 10*time.Second).Err())
+		util.BetweenValues(t, rdb.TTL(ctx, "foo").Val(), 5*time.Second, 10*time.Second)
+	})
+
+	t.Run("GETEX PX option", func(t *testing.T) {
+		require.NoError(t, rdb.Del(ctx, "foo").Err())
+		require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err())
+		require.NoError(t, rdb.GetEx(ctx, "foo", 10000*time.Millisecond).Err())
+		util.BetweenValues(t, rdb.TTL(ctx, "foo").Val(), 5000*time.Millisecond, 10000*time.Millisecond)
+	})
+
+	t.Run("GETEX EXAT option", func(t *testing.T) {
+		require.NoError(t, rdb.Del(ctx, "foo").Err())
+		require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err())
+		require.NoError(t, rdb.Do(ctx, "getex", "foo", "exat", time.Now().Add(10*time.Second).Unix()).Err())
+		util.BetweenValues(t, rdb.TTL(ctx, "foo").Val(), 5*time.Second, 10*time.Second)
+	})
+
+	t.Run("GETEX PXAT option", func(t *testing.T) {
+		require.NoError(t, rdb.Del(ctx, "foo").Err())
+		require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err())
+		require.NoError(t, rdb.Do(ctx, "getex", "foo", "pxat", time.Now().Add(10000*time.Millisecond).UnixMilli()).Err())
+		util.BetweenValues(t, rdb.TTL(ctx, "foo").Val(), 5000*time.Millisecond, 10000*time.Millisecond)
+	})
+
+	t.Run("GETEX PERSIST option", func(t *testing.T) {
+		require.NoError(t, rdb.Del(ctx, "foo").Err())
+		require.NoError(t, rdb.Set(ctx, "foo", "bar", 10*time.Second).Err())
+		util.BetweenValues(t, rdb.TTL(ctx, "foo").Val(), 5*time.Second, 10*time.Second)
+		require.NoError(t, rdb.Do(ctx, "getex", "foo", "persist").Err())
+		require.EqualValues(t, -1, rdb.TTL(ctx, "foo").Val())
+	})
+
+	t.Run("GETEX no option", func(t *testing.T) {
+		require.NoError(t, rdb.Del(ctx, "foo").Err())
+		require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err())
+		require.Equal(t, "bar", rdb.GetEx(ctx, "foo", 0).Val())
+	})
+
+	t.Run("GETEX syntax errors", func(t *testing.T) {
+		util.ErrorRegexp(t, rdb.Do(ctx, "getex", "foo", "non-existent-option").Err(), ".*syntax*.")
+	})
+
+	t.Run("GETEX no arguments", func(t *testing.T) {
+		util.ErrorRegexp(t, rdb.Do(ctx, "getex").Err(), ".*wrong number of arguments*.")
+	})
+
 	t.Run("GETDEL command", func(t *testing.T) {
 		require.NoError(t, rdb.Del(ctx, "foo").Err())
 		require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err())