You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kvrocks.apache.org by hu...@apache.org on 2022/09/18 10:39:29 UTC
[incubator-kvrocks] branch unstable updated: Implement the command `hello` (#881)
This is an automated email from the ASF dual-hosted git repository.
hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/incubator-kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 231f750 Implement the command `hello` (#881)
231f750 is described below
commit 231f7504db60506427ca091133debce2c987fe36
Author: mwish <ma...@gmail.com>
AuthorDate: Sun Sep 18 18:39:23 2022 +0800
Implement the command `hello` (#881)
---
src/redis_cmd.cc | 118 ++++++++++++++++++++++++++----
src/redis_connection.cc | 37 ++++------
src/redis_reply.cc | 33 +++++----
src/redis_reply.h | 6 +-
src/util.cc | 4 +-
tests/gocase/unit/command/command_test.go | 4 +-
tests/gocase/unit/hello/hello_test.go | 117 +++++++++++++++++++++++++++++
7 files changed, 258 insertions(+), 61 deletions(-)
diff --git a/src/redis_cmd.cc b/src/redis_cmd.cc
index 16b00e9..8943dd0 100644
--- a/src/redis_cmd.cc
+++ b/src/redis_cmd.cc
@@ -56,6 +56,7 @@
#include "scripting.h"
#include "slot_import.h"
#include "slot_migrate.h"
+#include "parse_util.h"
namespace Redis {
@@ -73,29 +74,47 @@ const char *errUnbalacedStreamList =
const char *errTimeoutIsNegative = "timeout is negative";
const char *errLimitOptionNotAllowed = "syntax error, LIMIT cannot be used without the special ~ option";
+// Outcome of AuthenticateUser(), mapped to a reply by AUTH and HELLO ... AUTH.
+enum class AuthResult {
+  OK,
+  INVALID_PASSWORD,
+  NO_REQUIRE_PASS,
+};
+
+// Authenticates `conn` with `user_password`.
+//
+// - If the password is a key of `config->tokens`, the connection switches to
+//   that token's namespace as a regular user and OK is returned.
+// - Otherwise, if `requirepass` is set and differs, INVALID_PASSWORD is
+//   returned and the connection is left untouched.
+// - Otherwise the connection becomes admin on the default namespace; the
+//   result is OK when `requirepass` matched, or NO_REQUIRE_PASS when no
+//   requirepass is configured (note: the admin switch has already happened).
+// `config` is only read, hence the const pointer.
+AuthResult AuthenticateUser(Connection *conn, const Config *config, const std::string &user_password) {
+  auto iter = config->tokens.find(user_password);
+  if (iter != config->tokens.end()) {
+    conn->SetNamespace(iter->second);
+    conn->BecomeUser();
+    return AuthResult::OK;
+  }
+  const auto &requirepass = config->requirepass;
+  if (!requirepass.empty() && user_password != requirepass) {
+    return AuthResult::INVALID_PASSWORD;
+  }
+  conn->SetNamespace(kDefaultNamespace);
+  conn->BecomeAdmin();
+  return requirepass.empty() ? AuthResult::NO_REQUIRE_PASS : AuthResult::OK;
+}
+
class CommandAuth : public Commander {
public:
+ // AUTH <password>: the credential check and the namespace/role switch are
+ // done by AuthenticateUser; this method only turns its result into a reply.
Status Execute(Server *svr, Connection *conn, std::string *output) override {
Config *config = svr->GetConfig();
- auto user_password = args_[1];
- auto iter = config->tokens.find(user_password);
- if (iter != config->tokens.end()) {
- conn->SetNamespace(iter->second);
- conn->BecomeUser();
+ auto& user_password = args_[1];
+ AuthResult result = AuthenticateUser(conn, config, user_password);
+ switch (result) {
+ case AuthResult::OK:
*output = Redis::SimpleString("OK");
- return Status::OK();
- }
- const auto requirepass = config->requirepass;
- if (!requirepass.empty() && user_password != requirepass) {
+ break;
+ case AuthResult::INVALID_PASSWORD:
*output = Redis::Error("ERR invalid password");
- return Status::OK();
- }
- conn->SetNamespace(kDefaultNamespace);
- conn->BecomeAdmin();
- if (requirepass.empty()) {
+ break;
+ // AuthenticateUser has already made the connection admin on the default
+ // namespace before reporting this case; only the reply is an error.
+ case AuthResult::NO_REQUIRE_PASS:
*output = Redis::Error("ERR Client sent AUTH, but no password is set");
- } else {
- *output = Redis::SimpleString("OK");
+ break;
}
return Status::OK();
}
@@ -4132,6 +4151,72 @@ class CommandEcho : public Commander {
}
};
+/* HELLO [<protocol-version> [AUTH <password>] [SETNAME <name>] ] */
+// Replies with a flat array of key/value pairs describing the server
+// (server, proto, mode). Options are matched case-insensitively, as in Redis,
+// and are all parsed before any of them is applied: a syntax error neither
+// authenticates nor renames the connection, and a failed AUTH does not rename it.
+class CommandHello final : public Commander {
+ public:
+  Status Execute(Server *svr, Connection *conn, std::string *output) override {
+    size_t next_arg = 1;
+    if (args_.size() >= 2) {
+      auto parse_result = ParseInt<int64_t>(args_[next_arg], /* base= */ 10);
+      ++next_arg;
+      if (!parse_result.IsOK()) {
+        return Status(Status::NotOK, "Protocol version is not an integer or out of range");
+      }
+      int64_t protocol = parse_result.GetValue();
+
+      // In redis, it will check protocol < 2 or protocol > 3,
+      // kvrocks only supports RESP2 by now, but for supporting some
+      // `hello 3`, it will not report error when using 3.
+      if (protocol < 2 || protocol > 3) {
+        return Status(Status::NotOK, "-NOPROTO unsupported protocol version");
+      }
+    }
+
+    // First pass: parse AUTH / SETNAME without applying anything yet.
+    const std::string *password = nullptr;
+    const std::string *name = nullptr;
+    for (; next_arg < args_.size(); ++next_arg) {
+      size_t more_args = args_.size() - next_arg - 1;
+      const std::string &opt = args_[next_arg];
+      const std::string opt_lower = Util::ToLower(opt);
+      if (opt_lower == "auth" && more_args != 0) {
+        password = &args_[++next_arg];
+      } else if (opt_lower == "setname" && more_args != 0) {
+        name = &args_[++next_arg];
+      } else {
+        *output = Redis::Error("ERR Syntax error in HELLO option " + opt);
+        return Status::OK();
+      }
+    }
+
+    // Second pass: authenticate before renaming so a failed AUTH does not
+    // rename the connection.
+    if (password != nullptr) {
+      switch (AuthenticateUser(conn, svr->GetConfig(), *password)) {
+        case AuthResult::INVALID_PASSWORD:
+          return Status(Status::NotOK, "invalid password");
+        case AuthResult::NO_REQUIRE_PASS:
+          return Status(Status::NotOK, "Client sent AUTH, but no password is set");
+        case AuthResult::OK:
+          break;
+      }
+    }
+    if (name != nullptr) {
+      conn->SetName(*name);
+    }
+
+    // Note: sentinel is not supported in kvrocks, so mode is cluster or standalone.
+    std::vector<std::string> output_list = {
+        Redis::BulkString("server"), Redis::BulkString("redis"),
+        Redis::BulkString("proto"),  Redis::Integer(2),
+        Redis::BulkString("mode"),
+        Redis::BulkString(svr->GetConfig()->cluster_enabled ? "cluster" : "standalone"),
+    };
+    *output = Redis::Array(output_list);
+    return Status::OK();
+  }
+};
+
class CommandScanBase : public Commander {
public:
Status ParseMatchAndCountParam(const std::string &type, std::string value) {
@@ -5680,6 +5765,7 @@ CommandAttributes redisCommandTable[] = {
ADD_CMD("debug", -2, "read-only exclusive", 0, 0, 0, CommandDebug),
ADD_CMD("command", -1, "read-only", 0, 0, 0, CommandCommand),
ADD_CMD("echo", 2, "read-only", 0, 0, 0, CommandEcho),
+ ADD_CMD("hello", -1, "read-only ok-loading", 0, 0, 0, CommandHello),
ADD_CMD("ttl", 2, "read-only", 1, 1, 1, CommandTTL),
ADD_CMD("pttl", 2, "read-only", 1, 1, 1, CommandPTTL),
diff --git a/src/redis_connection.cc b/src/redis_connection.cc
index a3932e4..db37b28 100644
--- a/src/redis_connection.cc
+++ b/src/redis_connection.cc
@@ -18,17 +18,17 @@
*
*/
-#include <rocksdb/perf_context.h>
-#include <rocksdb/iostats_context.h>
#include <glog/logging.h>
+#include <rocksdb/iostats_context.h>
+#include <rocksdb/perf_context.h>
#ifdef ENABLE_OPENSSL
#include <event2/bufferevent_ssl.h>
#endif
#include "redis_connection.h"
-#include "worker.h"
#include "server.h"
#include "tls_util.h"
+#include "worker.h"
namespace Redis {
@@ -74,9 +74,7 @@ void Connection::Close() {
owner_->FreeConnection(this);
}
-void Connection::Detach() {
- owner_->DetachConnection(this);
-}
+// Hands this connection over to owner_->DetachConnection().
+void Connection::Detach() { owner_->DetachConnection(this); }
void Connection::OnRead(struct bufferevent *bev, void *ctx) {
DLOG(INFO) << "[connection] on read: " << bufferevent_getfd(bev);
@@ -143,23 +141,21 @@ void Connection::SendFile(int fd) {
+// Stores the peer ip/port and caches the "ip:port" string in addr_.
void Connection::SetAddr(std::string ip, int port) {
ip_ = std::move(ip);
port_ = port;
- addr_ = ip_ +":"+ std::to_string(port_);
+ addr_ = ip_ + ":" + std::to_string(port_);
}
+// Seconds elapsed since create_time_, measured with time().
uint64_t Connection::GetAge() {
time_t now;
time(&now);
- return static_cast<uint64_t>(now-create_time_);
+ return static_cast<uint64_t>(now - create_time_);
}
-void Connection::SetLastInteraction() {
- time(&last_interaction_);
-}
+// Stamps last_interaction_ with the current time().
+void Connection::SetLastInteraction() { time(&last_interaction_); }
+// Seconds elapsed since last_interaction_.
uint64_t Connection::GetIdleTime() {
time_t now;
time(&now);
- return static_cast<uint64_t>(now-last_interaction_);
+ return static_cast<uint64_t>(now - last_interaction_);
}
// Currently, master connection is not handled in connection
@@ -185,17 +181,11 @@ std::string Connection::GetFlags() {
return flags;
}
+// Flag helpers: flags_ is treated as a bitmask of Flag values (set / clear / test).
-void Connection::EnableFlag(Flag flag) {
- flags_ |= flag;
-}
+void Connection::EnableFlag(Flag flag) { flags_ |= flag; }
-void Connection::DisableFlag(Flag flag) {
- flags_ &= (~flag);
-}
+void Connection::DisableFlag(Flag flag) { flags_ &= (~flag); }
-bool Connection::IsFlagEnabled(Flag flag) {
- return (flags_ & flag) > 0;
-}
+bool Connection::IsFlagEnabled(Flag flag) { return (flags_ & flag) > 0; }
void Connection::SubscribeChannel(const std::string &channel) {
for (const auto &chan : subscribe_channels_) {
@@ -333,7 +323,8 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
}
if (GetNamespace().empty()) {
- if (!password.empty() && Util::ToLower(cmd_tokens.front()) != "auth") {
+ if (!password.empty() && Util::ToLower(cmd_tokens.front()) != "auth" &&
+ Util::ToLower(cmd_tokens.front()) != "hello") {
Reply(Redis::Error("NOAUTH Authentication required."));
continue;
}
@@ -402,7 +393,7 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
if (!s.IsOK()) {
if (IsFlagEnabled(Connection::kMultiExec)) multi_error_ = true;
Reply(Redis::Error(s.Msg()));
- continue;;
+ continue;
}
}
diff --git a/src/redis_reply.cc b/src/redis_reply.cc
index 14a32ca..f9dcc1e 100644
--- a/src/redis_reply.cc
+++ b/src/redis_reply.cc
@@ -45,36 +45,39 @@ std::string MultiLen(int64_t len) {
return "*"+std::to_string(len)+"\r\n";
}
-std::string MultiBulkString(std::vector<std::string> values, bool output_nil_for_empty_string) {
+// Encodes `values` as a RESP array whose elements are bulk strings; an empty
+// value is emitted via NilString() when `output_nil_for_empty_string` is set.
+// The reply is built directly into `result`, so `values` is only read.
+std::string MultiBulkString(const std::vector<std::string>& values, bool output_nil_for_empty_string) {
+ std::string result = "*" + std::to_string(values.size()) + CRLF;
for (size_t i = 0; i < values.size(); i++) {
if (values[i].empty() && output_nil_for_empty_string) {
- values[i] = NilString();
+ result += NilString();
} else {
- values[i] = BulkString(values[i]);
+ result += BulkString(values[i]);
}
}
- return Array(values);
+ return result;
}
-std::string MultiBulkString(std::vector<std::string> values, const std::vector<rocksdb::Status> &statuses) {
+// Encodes `values` as a RESP array of bulk strings, emitting element i via
+// NilString() when statuses[i] exists and is not ok(). Values without a
+// matching status (i >= statuses.size()) are encoded normally.
+std::string MultiBulkString(const std::vector<std::string>& values, const std::vector<rocksdb::Status> &statuses) {
+ std::string result = "*" + std::to_string(values.size()) + CRLF;
for (size_t i = 0; i < values.size(); i++) {
if (i < statuses.size() && !statuses[i].ok()) {
- values[i] = NilString();
+ result += NilString();
} else {
- values[i] = BulkString(values[i]);
+ result += BulkString(values[i]);
}
}
- return Array(values);
+ return result;
}
-std::string Array(std::vector<std::string> list) {
- std::string::size_type n = std::accumulate(
- list.begin(), list.end(), std::string::size_type(0),
- [] ( std::string::size_type n, const std::string &s ) { return ( n += s.size() ); });
+// Frames `list` as a RESP array: a "*<count>\r\n" header followed by every
+// element appended verbatim, so each element must already be RESP-encoded.
+std::string Array(const std::vector<std::string>& list) {
+  // Seed with size_t{0}: std::accumulate's accumulator has the type of its
+  // initial value, and a plain `0` would sum (and narrow) the sizes in int.
+  size_t n = std::accumulate(list.begin(), list.end(), size_t{0},
+                             [](size_t acc, const std::string &s) { return acc + s.size(); });
std::string result = "*" + std::to_string(list.size()) + CRLF;
- result.reserve(n);
- return std::accumulate(list.begin(), list.end(), result,
- [](std::string &dest, std::string const &item) -> std::string& {dest += item; return dest;});
+  // One allocation for header + payload.
+  result.reserve(result.size() + n);
+  for (const auto& item : list) result += item;
+  return result;
}
std::string Command2RESP(const std::vector<std::string> &cmd_args) {
diff --git a/src/redis_reply.h b/src/redis_reply.h
index ff4820c..e049ef1 100644
--- a/src/redis_reply.h
+++ b/src/redis_reply.h
@@ -35,8 +35,8 @@ std::string Integer(int64_t data);
std::string BulkString(const std::string &data);
std::string NilString();
std::string MultiLen(int64_t len);
-std::string Array(std::vector<std::string> list);
-std::string MultiBulkString(std::vector<std::string> values, bool output_nil_for_empty_string = true);
-std::string MultiBulkString(std::vector<std::string> values, const std::vector<rocksdb::Status> &statuses);
+// Array() appends its elements verbatim, so they must already be RESP-encoded;
+// MultiBulkString() bulk-encodes each raw value itself.
+std::string Array(const std::vector<std::string>& list);
+std::string MultiBulkString(const std::vector<std::string>& values, bool output_nil_for_empty_string = true);
+std::string MultiBulkString(const std::vector<std::string>& values, const std::vector<rocksdb::Status> &statuses);
std::string Command2RESP(const std::vector<std::string> &cmd_args);
} // namespace Redis
diff --git a/src/util.cc b/src/util.cc
index 0c9a1ab..31944f0 100644
--- a/src/util.cc
+++ b/src/util.cc
@@ -344,7 +344,7 @@ Status DecimalStringToNum(const std::string &str, int64_t *n, int64_t min, int64
try {
*n = static_cast<int64_t>(std::stoll(str));
if (max > min && (*n < min || *n > max)) {
- return Status(Status::NotOK, "value shoud between "+std::to_string(min)+" and "+std::to_string(max));
+ return Status(Status::NotOK, "value should between "+std::to_string(min)+" and "+std::to_string(max));
}
} catch (std::exception &e) {
return Status(Status::NotOK, "value is not an integer or out of range");
@@ -356,7 +356,7 @@ Status OctalStringToNum(const std::string &str, int64_t *n, int64_t min, int64_t
try {
*n = static_cast<int64_t>(std::stoll(str, nullptr, 8));
if (max > min && (*n < min || *n > max)) {
- return Status(Status::NotOK, "value shoud between "+std::to_string(min)+" and "+std::to_string(max));
+ return Status(Status::NotOK, "value should between "+std::to_string(min)+" and "+std::to_string(max));
}
} catch (std::exception &e) {
return Status(Status::NotOK, "value is not an integer or out of range");
diff --git a/tests/gocase/unit/command/command_test.go b/tests/gocase/unit/command/command_test.go
index a2a424f..46c4e4c 100644
--- a/tests/gocase/unit/command/command_test.go
+++ b/tests/gocase/unit/command/command_test.go
@@ -35,11 +35,11 @@ func TestCommand(t *testing.T) {
rdb := srv.NewClient()
defer func() { require.NoError(t, rdb.Close()) }()
- t.Run("Kvrocks supports 180 commands currently", func(t *testing.T) {
+ t.Run("Kvrocks supports 181 commands currently", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "COUNT")
v, err := r.Int()
require.NoError(t, err)
- require.Equal(t, 180, v)
+ require.Equal(t, 181, v)
})
t.Run("acquire GET command info by COMMAND INFO", func(t *testing.T) {
diff --git a/tests/gocase/unit/hello/hello_test.go b/tests/gocase/unit/hello/hello_test.go
new file mode 100644
index 0000000..4f2adbf
--- /dev/null
+++ b/tests/gocase/unit/hello/hello_test.go
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package auth
+
+import (
+ "context"
+ "testing"
+
+ "github.com/apache/incubator-kvrocks/tests/gocase/util"
+ "github.com/stretchr/testify/require"
+)
+
+// TestHello covers protocol-version validation and SETNAME on a server
+// without requirepass. kvrocks always answers with proto 2 (RESP2).
+func TestHello(t *testing.T) {
+	srv := util.StartServer(t, map[string]string{})
+	defer srv.Close()
+
+	ctx := context.Background()
+	rdb := srv.NewClient()
+	defer func() { require.NoError(t, rdb.Close()) }()
+
+	t.Run("hello with protocol 1", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "1")
+		require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version")
+	})
+
+	t.Run("hello with protocol 2", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "2")
+		require.NoError(t, r.Err())
+		rList := r.Val().([]interface{})
+		require.EqualValues(t, "proto", rList[2])
+		require.EqualValues(t, 2, rList[3])
+	})
+
+	// `HELLO 3` is accepted but the reply still reports RESP2.
+	t.Run("hello with protocol 3", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "3")
+		require.NoError(t, r.Err())
+		rList := r.Val().([]interface{})
+		require.EqualValues(t, "proto", rList[2])
+		require.EqualValues(t, 2, rList[3])
+	})
+
+	t.Run("hello with protocol 5", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "5")
+		require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version")
+	})
+
+	t.Run("hello with non-integer protocol", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "AUTH")
+		require.ErrorContains(t, r.Err(), "Protocol version is not an integer or out of range")
+	})
+
+	t.Run("hello with SETNAME", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "2", "SETNAME", "kvrocks")
+		require.NoError(t, r.Err())
+		rList := r.Val().([]interface{})
+		require.EqualValues(t, "proto", rList[2])
+		require.EqualValues(t, 2, rList[3])
+
+		r = rdb.Do(ctx, "CLIENT", "GETNAME")
+		require.NoError(t, r.Err())
+		require.EqualValues(t, "kvrocks", r.Val())
+	})
+}
+
+// TestHelloWithAuth covers HELLO ... AUTH against a server with requirepass:
+// a wrong password is rejected, other commands need auth, and a correct
+// password (optionally combined with SETNAME) authenticates the connection.
+func TestHelloWithAuth(t *testing.T) {
+	srv := util.StartServer(t, map[string]string{
+		"requirepass": "foobar",
+	})
+	defer srv.Close()
+
+	ctx := context.Background()
+	rdb := srv.NewClient()
+	defer func() { require.NoError(t, rdb.Close()) }()
+
+	t.Run("AUTH fails when a wrong password is given", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "3", "AUTH", "wrong!")
+		require.ErrorContains(t, r.Err(), "invalid password")
+	})
+
+	t.Run("Arbitrary command gives an error when AUTH is required", func(t *testing.T) {
+		r := rdb.Set(ctx, "foo", "bar", 0)
+		require.ErrorContains(t, r.Err(), "NOAUTH Authentication required.")
+	})
+
+	t.Run("AUTH succeeds when the right password is given", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "3", "AUTH", "foobar")
+		require.NoError(t, r.Err())
+		rList := r.Val().([]interface{})
+		require.EqualValues(t, "proto", rList[2])
+		require.EqualValues(t, 2, rList[3])
+	})
+
+	t.Run("Once AUTH succeeded we can actually send commands to the server", func(t *testing.T) {
+		require.Equal(t, "OK", rdb.Set(ctx, "foo", 100, 0).Val())
+		require.EqualValues(t, 101, rdb.Incr(ctx, "foo").Val())
+	})
+
+	t.Run("hello with AUTH and SETNAME", func(t *testing.T) {
+		r := rdb.Do(ctx, "HELLO", "2", "AUTH", "foobar", "SETNAME", "kvrocks")
+		require.NoError(t, r.Err())
+		rList := r.Val().([]interface{})
+		require.EqualValues(t, "proto", rList[2])
+		require.EqualValues(t, 2, rList[3])
+
+		r = rdb.Do(ctx, "CLIENT", "GETNAME")
+		require.NoError(t, r.Err())
+		require.EqualValues(t, "kvrocks", r.Val())
+	})
+}