You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kvrocks.apache.org by hu...@apache.org on 2022/09/18 10:39:29 UTC

[incubator-kvrocks] branch unstable updated: Implement the command `hello` (#881)

This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/incubator-kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 231f750  Implement the command `hello` (#881)
231f750 is described below

commit 231f7504db60506427ca091133debce2c987fe36
Author: mwish <ma...@gmail.com>
AuthorDate: Sun Sep 18 18:39:23 2022 +0800

    Implement the command `hello` (#881)
---
 src/redis_cmd.cc                          | 118 ++++++++++++++++++++++++++----
 src/redis_connection.cc                   |  37 ++++------
 src/redis_reply.cc                        |  33 +++++----
 src/redis_reply.h                         |   6 +-
 src/util.cc                               |   4 +-
 tests/gocase/unit/command/command_test.go |   4 +-
 tests/gocase/unit/hello/hello_test.go     | 117 +++++++++++++++++++++++++++++
 7 files changed, 258 insertions(+), 61 deletions(-)

diff --git a/src/redis_cmd.cc b/src/redis_cmd.cc
index 16b00e9..8943dd0 100644
--- a/src/redis_cmd.cc
+++ b/src/redis_cmd.cc
@@ -56,6 +56,7 @@
 #include "scripting.h"
 #include "slot_import.h"
 #include "slot_migrate.h"
+#include "parse_util.h"
 
 namespace Redis {
 
@@ -73,29 +74,57 @@ const char *errUnbalacedStreamList =
 const char *errTimeoutIsNegative = "timeout is negative";
 const char *errLimitOptionNotAllowed = "syntax error, LIMIT cannot be used without the special ~ option";
 
+// Outcome of AuthenticateUser(); AUTH and HELLO each map it to their own reply.
+enum class AuthResult {
+  OK,
+  INVALID_PASSWORD,
+  NO_REQUIRE_PASS,
+};
+
+// Authenticates `conn` with `user_password`; shared by AUTH and HELLO ... AUTH.
+//
+// - A password equal to a namespace token switches the connection to that
+//   namespace as a regular user and yields OK.
+// - Otherwise, if requirepass is set and differs, the connection is left
+//   untouched and INVALID_PASSWORD is returned.
+// - Otherwise the connection becomes admin of the default namespace. The result
+//   is OK when requirepass matched, or NO_REQUIRE_PASS when requirepass is empty
+//   (the connection is switched in that case as well).
+static AuthResult AuthenticateUser(Connection *conn, Config *config, const std::string &user_password) {
+  auto iter = config->tokens.find(user_password);
+  if (iter != config->tokens.end()) {
+    conn->SetNamespace(iter->second);
+    conn->BecomeUser();
+    return AuthResult::OK;
+  }
+  const auto &requirepass = config->requirepass;
+  if (!requirepass.empty() && user_password != requirepass) {
+    return AuthResult::INVALID_PASSWORD;
+  }
+  conn->SetNamespace(kDefaultNamespace);
+  conn->BecomeAdmin();
+  if (requirepass.empty()) {
+    return AuthResult::NO_REQUIRE_PASS;
+  }
+  return AuthResult::OK;
+}
+
 class CommandAuth : public Commander {
  public:
   Status Execute(Server *svr, Connection *conn, std::string *output) override {
     Config *config = svr->GetConfig();
-    auto user_password = args_[1];
-    auto iter = config->tokens.find(user_password);
-    if (iter != config->tokens.end()) {
-      conn->SetNamespace(iter->second);
-      conn->BecomeUser();
+    const auto &user_password = args_[1];
+    const AuthResult result = AuthenticateUser(conn, config, user_password);
+    switch (result) {
+    case AuthResult::OK:
       *output = Redis::SimpleString("OK");
-      return Status::OK();
-    }
-    const auto requirepass = config->requirepass;
-    if (!requirepass.empty() && user_password != requirepass) {
+      break;
+    case AuthResult::INVALID_PASSWORD:
       *output = Redis::Error("ERR invalid password");
-      return Status::OK();
-    }
-    conn->SetNamespace(kDefaultNamespace);
-    conn->BecomeAdmin();
-    if (requirepass.empty()) {
+      break;
+    case AuthResult::NO_REQUIRE_PASS:
       *output = Redis::Error("ERR Client sent AUTH, but no password is set");
-    } else {
-      *output = Redis::SimpleString("OK");
+      break;
     }
     return Status::OK();
   }
@@ -4132,6 +4151,84 @@ class CommandEcho : public Commander {
   }
 };
 
+/* HELLO [<protocol-version> [AUTH <password>] [SETNAME <name>] ] */
+//
+// Replies with a flat array of server properties. Option names are matched
+// case-insensitively, as in Redis, and SETNAME is only applied once the
+// authentication check below has passed.
+class CommandHello final : public Commander {
+ public:
+  Status Execute(Server *svr, Connection *conn, std::string *output) override {
+    Config *config = svr->GetConfig();
+    size_t next_arg = 1;
+    if (args_.size() >= 2) {
+      auto parse_result = ParseInt<int64_t>(args_[next_arg], /* base= */ 10);
+      ++next_arg;
+      if (!parse_result.IsOK()) {
+        return Status(Status::NotOK, "Protocol version is not an integer or out of range");
+      }
+      const int64_t protocol = parse_result.GetValue();
+
+      // Redis accepts protocol 2 and 3. Kvrocks only speaks RESP2 so far, but `HELLO 3`
+      // is still accepted (and answered in RESP2) for clients that always send it.
+      if (protocol < 2 || protocol > 3) {
+        // NOTE(review): Redis::Error() seems to add its own '-', so clients likely see
+        // "--NOPROTO ..."; kept as-is because the tests match on "-NOPROTO".
+        return Status(Status::NotOK, "-NOPROTO unsupported protocol version");
+      }
+    }
+
+    // Parse AUTH and SETNAME; the client name is remembered and applied later.
+    std::string client_name;
+    bool has_client_name = false;
+    for (; next_arg < args_.size(); ++next_arg) {
+      const size_t more_args = args_.size() - next_arg - 1;
+      const std::string &opt = args_[next_arg];
+      const std::string opt_lower = Util::ToLower(opt);
+      if (opt_lower == "auth" && more_args != 0) {
+        const auto &user_password = args_[next_arg + 1];
+        switch (AuthenticateUser(conn, config, user_password)) {
+        case AuthResult::INVALID_PASSWORD:
+          return Status(Status::NotOK, "ERR invalid password");
+        case AuthResult::NO_REQUIRE_PASS:
+          return Status(Status::NotOK, "ERR Client sent AUTH, but no password is set");
+        case AuthResult::OK:
+          break;
+        }
+        next_arg += 1;
+      } else if (opt_lower == "setname" && more_args != 0) {
+        client_name = args_[next_arg + 1];
+        has_client_name = true;
+        next_arg += 1;
+      } else {
+        *output = Redis::Error("Syntax error in HELLO option " + opt);
+        return Status::OK();
+      }
+    }
+
+    // Connection::ExecuteCommands lets HELLO past its NOAUTH gate so that it can carry
+    // AUTH; a client that is still unauthenticated at this point must be rejected.
+    if (conn->GetNamespace().empty() && !config->requirepass.empty()) {
+      return Status(Status::NotOK,
+                    "NOAUTH HELLO must be called with the client already authenticated, "
+                    "otherwise the HELLO <proto> AUTH <password> option can be used to "
+                    "authenticate the client and select the protocol version at the same time");
+    }
+    if (has_client_name) conn->SetName(client_name);
+
+    // Sentinel is not supported in kvrocks, so the mode is either cluster or standalone.
+    // The advertised proto is always 2, even when `HELLO 3` was requested.
+    std::vector<std::string> output_list = {
+        Redis::BulkString("server"), Redis::BulkString("redis"),
+        Redis::BulkString("proto"),  Redis::Integer(2),
+        Redis::BulkString("mode"),
+        Redis::BulkString(config->cluster_enabled ? "cluster" : "standalone"),
+    };
+    *output = Redis::Array(output_list);
+    return Status::OK();
+  }
+};
+
 class CommandScanBase : public Commander {
  public:
   Status ParseMatchAndCountParam(const std::string &type, std::string value) {
@@ -5680,6 +5765,7 @@ CommandAttributes redisCommandTable[] = {
     ADD_CMD("debug", -2, "read-only exclusive", 0, 0, 0, CommandDebug),
     ADD_CMD("command", -1, "read-only", 0, 0, 0, CommandCommand),
     ADD_CMD("echo", 2, "read-only", 0, 0, 0, CommandEcho),
+    ADD_CMD("hello", -1, "read-only ok-loading", 0, 0, 0, CommandHello),
 
     ADD_CMD("ttl", 2, "read-only", 1, 1, 1, CommandTTL),
     ADD_CMD("pttl", 2, "read-only", 1, 1, 1, CommandPTTL),
diff --git a/src/redis_connection.cc b/src/redis_connection.cc
index a3932e4..db37b28 100644
--- a/src/redis_connection.cc
+++ b/src/redis_connection.cc
@@ -18,17 +18,17 @@
  *
  */
 
-#include <rocksdb/perf_context.h>
-#include <rocksdb/iostats_context.h>
 #include <glog/logging.h>
+#include <rocksdb/iostats_context.h>
+#include <rocksdb/perf_context.h>
 #ifdef ENABLE_OPENSSL
 #include <event2/bufferevent_ssl.h>
 #endif
 
 #include "redis_connection.h"
-#include "worker.h"
 #include "server.h"
 #include "tls_util.h"
+#include "worker.h"
 
 namespace Redis {
 
@@ -74,9 +74,7 @@ void Connection::Close() {
   owner_->FreeConnection(this);
 }
 
-void Connection::Detach() {
-  owner_->DetachConnection(this);
-}
+void Connection::Detach() { owner_->DetachConnection(this); }
 
 void Connection::OnRead(struct bufferevent *bev, void *ctx) {
   DLOG(INFO) << "[connection] on read: " << bufferevent_getfd(bev);
@@ -143,23 +141,21 @@ void Connection::SendFile(int fd) {
 void Connection::SetAddr(std::string ip, int port) {
   ip_ = std::move(ip);
   port_ = port;
-  addr_ = ip_ +":"+ std::to_string(port_);
+  addr_ = ip_ + ":" + std::to_string(port_);
 }
 
 uint64_t Connection::GetAge() {
   time_t now;
   time(&now);
-  return static_cast<uint64_t>(now-create_time_);
+  return static_cast<uint64_t>(now - create_time_);
 }
 
-void Connection::SetLastInteraction() {
-  time(&last_interaction_);
-}
+void Connection::SetLastInteraction() { time(&last_interaction_); }
 
 uint64_t Connection::GetIdleTime() {
   time_t now;
   time(&now);
-  return static_cast<uint64_t>(now-last_interaction_);
+  return static_cast<uint64_t>(now - last_interaction_);
 }
 
 // Currently, master connection is not handled in connection
@@ -185,17 +181,11 @@ std::string Connection::GetFlags() {
   return flags;
 }
 
-void Connection::EnableFlag(Flag flag) {
-  flags_ |= flag;
-}
+void Connection::EnableFlag(Flag flag) { flags_ |= flag; }
 
-void Connection::DisableFlag(Flag flag) {
-  flags_ &= (~flag);
-}
+void Connection::DisableFlag(Flag flag) { flags_ &= (~flag); }
 
-bool Connection::IsFlagEnabled(Flag flag) {
-  return (flags_ & flag) > 0;
-}
+bool Connection::IsFlagEnabled(Flag flag) { return (flags_ & flag) > 0; }
 
 void Connection::SubscribeChannel(const std::string &channel) {
   for (const auto &chan : subscribe_channels_) {
@@ -333,7 +323,9 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
     }
 
     if (GetNamespace().empty()) {
-      if (!password.empty() && Util::ToLower(cmd_tokens.front()) != "auth") {
+      // With `password` set, only AUTH and HELLO (which can carry AUTH) get past this gate.
+      const auto first_token = Util::ToLower(cmd_tokens.front());
+      if (!password.empty() && first_token != "auth" && first_token != "hello") {
         Reply(Redis::Error("NOAUTH Authentication required."));
         continue;
       }
@@ -402,7 +394,7 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
       if (!s.IsOK()) {
         if (IsFlagEnabled(Connection::kMultiExec)) multi_error_ = true;
         Reply(Redis::Error(s.Msg()));
-        continue;;
+        continue;
       }
     }
 
diff --git a/src/redis_reply.cc b/src/redis_reply.cc
index 14a32ca..f9dcc1e 100644
--- a/src/redis_reply.cc
+++ b/src/redis_reply.cc
@@ -45,36 +45,40 @@ std::string MultiLen(int64_t len) {
   return "*"+std::to_string(len)+"\r\n";
 }
 
-std::string MultiBulkString(std::vector<std::string> values, bool output_nil_for_empty_string) {
+std::string MultiBulkString(const std::vector<std::string> &values, bool output_nil_for_empty_string) {
+  std::string result = "*" + std::to_string(values.size()) + CRLF;
   for (size_t i = 0; i < values.size(); i++) {
     if (values[i].empty() && output_nil_for_empty_string) {
-      values[i] = NilString();
+      result += NilString();
     }  else {
-      values[i] = BulkString(values[i]);
+      result += BulkString(values[i]);
     }
   }
-  return Array(values);
+  return result;
 }
 
 
-std::string MultiBulkString(std::vector<std::string> values, const std::vector<rocksdb::Status> &statuses) {
+std::string MultiBulkString(const std::vector<std::string> &values, const std::vector<rocksdb::Status> &statuses) {
+  std::string result = "*" + std::to_string(values.size()) + CRLF;
   for (size_t i = 0; i < values.size(); i++) {
     if (i < statuses.size() && !statuses[i].ok()) {
-      values[i] = NilString();
+      result += NilString();
     } else {
-      values[i] = BulkString(values[i]);
+      result += BulkString(values[i]);
     }
   }
-  return Array(values);
+  return result;
 }
-std::string Array(std::vector<std::string> list) {
-  std::string::size_type n = std::accumulate(
-    list.begin(), list.end(), std::string::size_type(0),
-    [] ( std::string::size_type n, const std::string &s ) { return ( n += s.size() ); });
+
+std::string Array(const std::vector<std::string> &list) {
+  // size_t{0}, not 0: accumulate's accumulator type comes from the init value, and int would truncate.
+  size_t n = std::accumulate(list.begin(), list.end(), size_t{0},
+                             [](size_t acc, const std::string &s) { return acc + s.size(); });
   std::string result = "*" + std::to_string(list.size()) + CRLF;
-  result.reserve(n);
-  return std::accumulate(list.begin(), list.end(), result,
-    [](std::string &dest, std::string const &item) -> std::string& {dest += item; return dest;});
+  std::string::size_type final_size = result.size() + n;
+  result.reserve(final_size);
+  for (const auto &item : list) result += item;
+  return result;
 }
 
 std::string Command2RESP(const std::vector<std::string> &cmd_args) {
diff --git a/src/redis_reply.h b/src/redis_reply.h
index ff4820c..e049ef1 100644
--- a/src/redis_reply.h
+++ b/src/redis_reply.h
@@ -35,8 +35,8 @@ std::string Integer(int64_t data);
 std::string BulkString(const std::string &data);
 std::string NilString();
 std::string MultiLen(int64_t len);
-std::string Array(std::vector<std::string> list);
-std::string MultiBulkString(std::vector<std::string> values, bool output_nil_for_empty_string = true);
-std::string MultiBulkString(std::vector<std::string> values, const std::vector<rocksdb::Status> &statuses);
+std::string Array(const std::vector<std::string> &list);
+std::string MultiBulkString(const std::vector<std::string> &values, bool output_nil_for_empty_string = true);
+std::string MultiBulkString(const std::vector<std::string> &values, const std::vector<rocksdb::Status> &statuses);
 std::string Command2RESP(const std::vector<std::string> &cmd_args);
 }  // namespace Redis
diff --git a/src/util.cc b/src/util.cc
index 0c9a1ab..31944f0 100644
--- a/src/util.cc
+++ b/src/util.cc
@@ -344,7 +344,7 @@ Status DecimalStringToNum(const std::string &str, int64_t *n, int64_t min, int64
   try {
     *n = static_cast<int64_t>(std::stoll(str));
     if (max > min && (*n < min || *n > max)) {
-      return Status(Status::NotOK, "value shoud between "+std::to_string(min)+" and "+std::to_string(max));
+      return Status(Status::NotOK, "value should be between "+std::to_string(min)+" and "+std::to_string(max));
     }
   } catch (std::exception &e) {
     return Status(Status::NotOK, "value is not an integer or out of range");
@@ -356,7 +356,7 @@ Status OctalStringToNum(const std::string &str, int64_t *n, int64_t min, int64_t
   try {
     *n = static_cast<int64_t>(std::stoll(str, nullptr, 8));
     if (max > min && (*n < min || *n > max)) {
-      return Status(Status::NotOK, "value shoud between "+std::to_string(min)+" and "+std::to_string(max));
+      return Status(Status::NotOK, "value should be between "+std::to_string(min)+" and "+std::to_string(max));
     }
   } catch (std::exception &e) {
     return Status(Status::NotOK, "value is not an integer or out of range");
diff --git a/tests/gocase/unit/command/command_test.go b/tests/gocase/unit/command/command_test.go
index a2a424f..46c4e4c 100644
--- a/tests/gocase/unit/command/command_test.go
+++ b/tests/gocase/unit/command/command_test.go
@@ -35,11 +35,11 @@ func TestCommand(t *testing.T) {
 	rdb := srv.NewClient()
 	defer func() { require.NoError(t, rdb.Close()) }()
 
-	t.Run("Kvrocks supports 180 commands currently", func(t *testing.T) {
+	t.Run("Kvrocks supports 181 commands currently", func(t *testing.T) {
 		r := rdb.Do(ctx, "COMMAND", "COUNT")
 		v, err := r.Int()
 		require.NoError(t, err)
-		require.Equal(t, 180, v)
+		require.Equal(t, 181, v) // one more than before: HELLO was added to the command table
 	})
 
 	t.Run("acquire GET command info by COMMAND INFO", func(t *testing.T) {
diff --git a/tests/gocase/unit/hello/hello_test.go b/tests/gocase/unit/hello/hello_test.go
new file mode 100644
index 0000000..4f2adbf
--- /dev/null
+++ b/tests/gocase/unit/hello/hello_test.go
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package hello
+
+import (
+	"context"
+	"testing"
+
+	"github.com/apache/incubator-kvrocks/tests/gocase/util"
+	"github.com/stretchr/testify/require"
+)
+
+// requireHelloProto2 checks that a HELLO call succeeded and that its reply
+// (a flat list of field/value pairs) reports protocol version 2, which is what
+// kvrocks answers with for both `HELLO 2` and `HELLO 3`.
+func requireHelloProto2(t *testing.T, val interface{}, err error) {
+	require.NoError(t, err)
+	rList, ok := val.([]interface{})
+	require.True(t, ok, "unexpected HELLO reply: %v", val)
+	require.GreaterOrEqual(t, len(rList), 4)
+	require.EqualValues(t, "proto", rList[2])
+	require.EqualValues(t, 2, rList[3])
+}
+
+func TestHello(t *testing.T) {
+	srv := util.StartServer(t, map[string]string{})
+	defer srv.Close()
+
+	ctx := context.Background()
+	rdb := srv.NewClient()
+	defer func() { require.NoError(t, rdb.Close()) }()
+
+	t.Run("hello with protocol 1 is rejected", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "1")
+		require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version")
+	})
+
+	t.Run("hello with protocol 2", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "2")
+		requireHelloProto2(t, r.Val(), r.Err())
+	})
+
+	t.Run("hello with protocol 3", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "3")
+		requireHelloProto2(t, r.Val(), r.Err())
+	})
+
+	t.Run("hello with protocol 5 is rejected", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "5")
+		require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version")
+	})
+
+	t.Run("hello with a non-integer protocol is rejected", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "AUTH")
+		require.ErrorContains(t, r.Err(), "Protocol version is not an integer or out of range")
+	})
+
+	t.Run("hello with SETNAME", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "2", "SETNAME", "kvrocks")
+		requireHelloProto2(t, r.Val(), r.Err())
+
+		r = rdb.Do(ctx, "CLIENT", "GETNAME")
+		require.NoError(t, r.Err())
+		require.EqualValues(t, "kvrocks", r.Val())
+	})
+}
+
+func TestHelloWithAuth(t *testing.T) {
+	srv := util.StartServer(t, map[string]string{
+		"requirepass": "foobar",
+	})
+	defer srv.Close()
+
+	ctx := context.Background()
+	rdb := srv.NewClient()
+	defer func() { require.NoError(t, rdb.Close()) }()
+
+	t.Run("AUTH fails when a wrong password is given", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "3", "AUTH", "wrong!")
+		require.ErrorContains(t, r.Err(), "invalid password")
+	})
+
+	t.Run("Arbitrary command gives an error when AUTH is required", func(t *testing.T) {
+		r := rdb.Set(ctx, "foo", "bar", 0)
+		require.ErrorContains(t, r.Err(), "NOAUTH Authentication required.")
+	})
+
+	t.Run("AUTH succeeds when the right password is given", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "3", "AUTH", "foobar")
+		requireHelloProto2(t, r.Val(), r.Err())
+	})
+
+	t.Run("Once AUTH succeeded we can actually send commands to the server", func(t *testing.T) {
+		require.Equal(t, "OK", rdb.Set(ctx, "foo", 100, 0).Val())
+		require.EqualValues(t, 101, rdb.Incr(ctx, "foo").Val())
+	})
+
+	t.Run("hello with AUTH and SETNAME", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "2", "AUTH", "foobar", "SETNAME", "kvrocks")
+		requireHelloProto2(t, r.Val(), r.Err())
+
+		r = rdb.Do(ctx, "CLIENT", "GETNAME")
+		require.NoError(t, r.Err())
+		require.EqualValues(t, "kvrocks", r.Val())
+	})
+}