You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") in the original page and is not reproduced in this plain-text copy.
Posted to commits@doris.apache.org by ya...@apache.org on 2022/03/02 09:27:04 UTC
[incubator-doris] branch master updated: [fix] fix a bug of encryption function with iv may return wrong result (#8277)
This is an automated email from the ASF dual-hosted git repository.
yangzhg pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new 246ac4e [fix] fix a bug of encryption function with iv may return wrong result (#8277)
246ac4e is described below
commit 246ac4e37aa4da6836b7850cb990f02d1c3725a3
Author: Zhengguo Yang <ya...@gmail.com>
AuthorDate: Wed Mar 2 17:26:44 2022 +0800
[fix] fix a bug of encryption function with iv may return wrong result (#8277)
---
be/src/exprs/encryption_functions.cpp | 93 +++++++++++++----------------------
be/src/exprs/minmax_predicate.h | 4 +-
be/src/udf/udf.cpp | 3 ++
be/src/udf/udf.h | 3 ++
be/src/util/encryption_util.cpp | 37 +++++++++++---
be/src/util/encryption_util.h | 4 +-
be/test/util/encryption_util_test.cpp | 91 +++++++++++++++++++++++++++++++++-
7 files changed, 164 insertions(+), 71 deletions(-)
diff --git a/be/src/exprs/encryption_functions.cpp b/be/src/exprs/encryption_functions.cpp
index 5d919a3..19ec1a7 100644
--- a/be/src/exprs/encryption_functions.cpp
+++ b/be/src/exprs/encryption_functions.cpp
@@ -55,32 +55,23 @@ StringVal encrypt(FunctionContext* ctx, const StringVal& src, const StringVal& k
if (src.len == 0 || src.is_null) {
return StringVal::null();
}
+ /*
+ * Buffer for ciphertext. Ensure the buffer is long enough for the
+ * ciphertext which may be longer than the plaintext, depending on the
+ * algorithm and mode.
+ */
+
int cipher_len = src.len + 16;
- std::unique_ptr<char[]> p;
- p.reset(new char[cipher_len]);
- int ret_code = 0;
- if (mode != AES_128_ECB && mode != AES_192_ECB && mode != AES_256_ECB && mode != AES_256_ECB &&
- mode != SM4_128_ECB) {
- if (iv.len == 0 || iv.is_null) {
- return StringVal::null();
- }
- int iv_len = 32; // max key length 256 / 8
- std::unique_ptr<char[]> init_vec;
- init_vec.reset(new char[iv_len]);
- std::memset(init_vec.get(), 0, iv.len + 1);
- memcpy(init_vec.get(), iv.ptr, iv.len);
- ret_code = EncryptionUtil::encrypt(
- mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr, key.len,
- (unsigned char*)init_vec.get(), true, (unsigned char*)p.get());
- } else {
- ret_code = EncryptionUtil::encrypt(mode, (unsigned char*)src.ptr, src.len,
- (unsigned char*)key.ptr, key.len, nullptr, true,
- (unsigned char*)p.get());
- }
- if (ret_code < 0) {
+ std::unique_ptr<char[]> cipher_text;
+ cipher_text.reset(new char[cipher_len]);
+ int cipher_text_len = 0;
+ cipher_text_len = EncryptionUtil::encrypt(mode, (unsigned char*)src.ptr, src.len,
+ (unsigned char*)key.ptr, key.len, (char*)iv.ptr, true,
+ (unsigned char*)cipher_text.get());
+ if (cipher_text_len < 0) {
return StringVal::null();
}
- return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
+ return AnyValUtil::from_buffer_temp(ctx, cipher_text.get(), cipher_text_len);
}
StringVal decrypt(FunctionContext* ctx, const StringVal& src, const StringVal& key,
@@ -89,31 +80,16 @@ StringVal decrypt(FunctionContext* ctx, const StringVal& src, const StringVal& k
return StringVal::null();
}
int cipher_len = src.len;
- std::unique_ptr<char[]> p;
- p.reset(new char[cipher_len]);
- int ret_code = 0;
- if (mode != AES_128_ECB && mode != AES_192_ECB && mode != AES_256_ECB && mode != AES_256_ECB &&
- mode != SM4_128_ECB) {
- if (iv.len == 0 || iv.is_null) {
- return StringVal::null();
- }
- int iv_len = 32; // max key length 256 / 8
- std::unique_ptr<char[]> init_vec;
- init_vec.reset(new char[iv_len]);
- std::memset(init_vec.get(), 0, iv.len + 1);
- memcpy(init_vec.get(), iv.ptr, iv.len);
- ret_code = EncryptionUtil::decrypt(
- mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr, key.len,
- (unsigned char*)init_vec.get(), true, (unsigned char*)p.get());
- } else {
- ret_code = EncryptionUtil::decrypt(mode, (unsigned char*)src.ptr, src.len,
- (unsigned char*)key.ptr, key.len, nullptr, true,
- (unsigned char*)p.get());
- }
- if (ret_code < 0) {
+ std::unique_ptr<char[]> plain_text;
+ plain_text.reset(new char[cipher_len]);
+ int plain_text_len = 0;
+ plain_text_len =
+ EncryptionUtil::decrypt(mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr,
+ key.len, (char*)iv.ptr, true, (unsigned char*)plain_text.get());
+ if (plain_text_len < 0) {
return StringVal::null();
}
- return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
+ return AnyValUtil::from_buffer_temp(ctx, plain_text.get(), plain_text_len);
}
StringVal EncryptionFunctions::aes_encrypt(FunctionContext* ctx, const StringVal& src,
@@ -197,15 +173,15 @@ StringVal EncryptionFunctions::from_base64(FunctionContext* ctx, const StringVal
return StringVal::null();
}
- int cipher_len = src.len;
- std::unique_ptr<char[]> p;
- p.reset(new char[cipher_len]);
+ int encoded_len = src.len;
+ std::unique_ptr<char[]> plain_text;
+ plain_text.reset(new char[encoded_len]);
- int ret_code = base64_decode((const char*)src.ptr, src.len, p.get());
- if (ret_code < 0) {
+ int plain_text_len = base64_decode((const char*)src.ptr, src.len, plain_text.get());
+ if (plain_text_len < 0) {
return StringVal::null();
}
- return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
+ return AnyValUtil::from_buffer_temp(ctx, plain_text.get(), plain_text_len);
}
StringVal EncryptionFunctions::to_base64(FunctionContext* ctx, const StringVal& src) {
@@ -213,15 +189,16 @@ StringVal EncryptionFunctions::to_base64(FunctionContext* ctx, const StringVal&
return StringVal::null();
}
- int cipher_len = (size_t)(4.0 * ceil((double)src.len / 3.0));
- std::unique_ptr<char[]> p;
- p.reset(new char[cipher_len]);
+ int encoded_len = (size_t)(4.0 * ceil((double)src.len / 3.0));
+ std::unique_ptr<char[]> encoded_text;
+ encoded_text.reset(new char[encoded_len]);
- int ret_code = base64_encode((unsigned char*)src.ptr, src.len, (unsigned char*)p.get());
- if (ret_code < 0) {
+ int encoded_text_len =
+ base64_encode((unsigned char*)src.ptr, src.len, (unsigned char*)encoded_text.get());
+ if (encoded_text_len < 0) {
return StringVal::null();
}
- return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
+ return AnyValUtil::from_buffer_temp(ctx, encoded_text.get(), encoded_text_len);
}
StringVal EncryptionFunctions::md5sum(FunctionContext* ctx, int num_args, const StringVal* args) {
diff --git a/be/src/exprs/minmax_predicate.h b/be/src/exprs/minmax_predicate.h
index 3a9ff5b..2c8140d 100644
--- a/be/src/exprs/minmax_predicate.h
+++ b/be/src/exprs/minmax_predicate.h
@@ -25,7 +25,6 @@ namespace doris {
// only used in Runtime Filter
class MinMaxFuncBase {
public:
- virtual ~MinMaxFuncBase() = default;
virtual void insert(const void* data) = 0;
virtual bool find(void* data) = 0;
virtual bool is_empty() = 0;
@@ -35,6 +34,7 @@ public:
virtual Status assign(void* min_data, void* max_data) = 0;
// merge from other minmax_func
virtual Status merge(MinMaxFuncBase* minmax_func, ObjectPool* pool) = 0;
+ virtual ~MinMaxFuncBase() = default;
};
template <class T>
@@ -114,4 +114,4 @@ private:
bool _empty = true;
};
-} // namespace doris
\ No newline at end of file
+} // namespace doris
diff --git a/be/src/udf/udf.cpp b/be/src/udf/udf.cpp
index eae0bf1..612aabb 100644
--- a/be/src/udf/udf.cpp
+++ b/be/src/udf/udf.cpp
@@ -562,4 +562,7 @@ void* FunctionContext::get_function_state(FunctionStateScope scope) const {
return nullptr;
}
}
+std::ostream& operator<<(std::ostream& os, const StringVal& string_val) {
+ return os << string_val.to_string();
+}
} // namespace doris_udf
diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h
index c262c8c..481f19a 100644
--- a/be/src/udf/udf.h
+++ b/be/src/udf/udf.h
@@ -21,6 +21,7 @@
#include <string.h>
#include <cstdint>
+#include <iostream>
#include <vector>
// This is the only Doris header required to develop UDFs and UDAs. This header
@@ -656,7 +657,9 @@ struct StringVal : public AnyVal {
void append(FunctionContext* ctx, const uint8_t* buf, int64_t len);
void append(FunctionContext* ctx, const uint8_t* buf, int64_t len, const uint8_t* buf2,
int64_t buf2_len);
+ std::string to_string() const { return std::string((char*)ptr, len); }
};
+std::ostream& operator<<(std::ostream& os, const StringVal& string_val);
struct DecimalV2Val : public AnyVal {
__int128 val;
diff --git a/be/src/util/encryption_util.cpp b/be/src/util/encryption_util.cpp
index eb95cf5..b9396e9 100644
--- a/be/src/util/encryption_util.cpp
+++ b/be/src/util/encryption_util.cpp
@@ -21,7 +21,9 @@
#include <openssl/evp.h>
#include <openssl/ossl_typ.h>
#include <sys/types.h>
+
#include <cstring>
+#include <string>
namespace doris {
@@ -171,20 +173,29 @@ static int do_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source,
uint32_t source_length, const unsigned char* key, uint32_t key_length,
- const unsigned char* iv, bool padding, unsigned char* encrypt) {
+ const char* iv_str, bool padding, unsigned char* encrypt) {
const EVP_CIPHER* cipher = get_evp_type(mode);
/* The encrypt key to be used for encryption */
unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8];
create_key(key, key_length, encrypt_key, mode);
- if (cipher == nullptr || (EVP_CIPHER_iv_length(cipher) > 0 && !iv)) {
+ int iv_length = EVP_CIPHER_iv_length(cipher);
+ if (cipher == nullptr || (iv_length > 0 && !iv_str)) {
return AES_BAD_DATA;
}
+ char* init_vec = nullptr;
+ std::string iv_default("DORISDORISDORIS_");
+
+ if (iv_str) {
+ init_vec = &iv_default[0];
+ memcpy(init_vec, iv_str, strnlen(iv_str, EVP_MAX_IV_LENGTH));
+ init_vec[iv_length] = '\0';
+ }
EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
EVP_CIPHER_CTX_reset(cipher_ctx);
int length = 0;
- int ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key, iv, padding,
- encrypt, &length);
+ int ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key,
+ reinterpret_cast<unsigned char*>(init_vec), padding, encrypt, &length);
EVP_CIPHER_CTX_free(cipher_ctx);
if (ret == 0) {
ERR_clear_error();
@@ -219,21 +230,31 @@ static int do_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt,
uint32_t encrypt_length, const unsigned char* key, uint32_t key_length,
- const unsigned char* iv, bool padding, unsigned char* decrypt_content) {
+ const char* iv_str, bool padding, unsigned char* decrypt_content) {
const EVP_CIPHER* cipher = get_evp_type(mode);
/* The encrypt key to be used for decryption */
unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8];
create_key(key, key_length, encrypt_key, mode);
- if (cipher == nullptr || (EVP_CIPHER_iv_length(cipher) > 0 && !iv)) {
+ int iv_length = EVP_CIPHER_iv_length(cipher);
+ if (cipher == nullptr || (iv_length > 0 && !iv_str)) {
return AES_BAD_DATA;
}
+ char* init_vec = nullptr;
+ std::string iv_default("DORISDORISDORIS_");
+
+ if (iv_str) {
+ init_vec = &iv_default[0];
+ memcpy(init_vec, iv_str, strnlen(iv_str, EVP_MAX_IV_LENGTH));
+ init_vec[iv_length] = '\0';
+ }
EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
EVP_CIPHER_CTX_reset(cipher_ctx);
int length = 0;
- int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key, iv, padding,
- decrypt_content, &length);
+ int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key,
+ reinterpret_cast<unsigned char*>(init_vec), padding, decrypt_content,
+ &length);
EVP_CIPHER_CTX_free(cipher_ctx);
if (ret > 0) {
return length;
diff --git a/be/src/util/encryption_util.h b/be/src/util/encryption_util.h
index e051d28..711a817 100644
--- a/be/src/util/encryption_util.h
+++ b/be/src/util/encryption_util.h
@@ -58,11 +58,11 @@ enum EncryptionState { AES_SUCCESS = 0, AES_BAD_DATA = -1 };
class EncryptionUtil {
public:
static int encrypt(EncryptionMode mode, const unsigned char* source, uint32_t source_length,
- const unsigned char* key, uint32_t key_length, const unsigned char* iv,
+ const unsigned char* key, uint32_t key_length, const char* iv_str,
bool padding, unsigned char* encrypt);
static int decrypt(EncryptionMode mode, const unsigned char* encrypt, uint32_t encrypt_length,
- const unsigned char* key, uint32_t key_length, const unsigned char* iv,
+ const unsigned char* key, uint32_t key_length, const char* iv_str,
bool padding, unsigned char* decrypt_content);
};
diff --git a/be/test/util/encryption_util_test.cpp b/be/test/util/encryption_util_test.cpp
index 30c9752..2f30ade 100644
--- a/be/test/util/encryption_util_test.cpp
+++ b/be/test/util/encryption_util_test.cpp
@@ -117,7 +117,6 @@ TEST_F(EncryptionUtilTest, sm4_test_by_case) {
std::unique_ptr<char[]> encrypt_1(new char[case_1.length()]);
int length_1 = base64_decode(case_1.c_str(), case_1.length(), encrypt_1.get());
- std::cout << encrypt_1.get();
std::unique_ptr<char[]> decrypted_1(new char[case_1.length()]);
int ret_code = EncryptionUtil::decrypt(SM4_128_ECB, (unsigned char*)encrypt_1.get(), length_1,
(unsigned char*)_aes_key.c_str(), _aes_key.length(),
@@ -137,6 +136,96 @@ TEST_F(EncryptionUtilTest, sm4_test_by_case) {
ASSERT_EQ(source_2, decrypted_content_2);
}
+TEST_F(EncryptionUtilTest, aes_with_iv_test_by_case) {
+ std::string case_1 = "XbJgw1AxBNwZZPpvzPtWyg=="; // base64 for encrypted "hello, doris"
+ std::string source_1 = "hello, doris";
+ std::string case_2 = "gpKcO/iwgeRCIWBQdkpAkQ=="; // base64 for encrypted "doris test"
+ std::string source_2 = "doris test";
+ std::string iv = "doris";
+
+ std::unique_ptr<char[]> encrypt_1(new char[case_1.length()]);
+ int length_1 = base64_decode(case_1.c_str(), case_1.length(), encrypt_1.get());
+ std::unique_ptr<char[]> decrypted_1(new char[case_1.length()]);
+ int ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_1.get(), length_1,
+ (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+ iv.c_str(), true, (unsigned char*)decrypted_1.get());
+ ASSERT_TRUE(ret_code > 0);
+ std::string decrypted_content_1(decrypted_1.get(), ret_code);
+ ASSERT_EQ(source_1, decrypted_content_1);
+ std::unique_ptr<char[]> decrypted_11(new char[case_1.length()]);
+
+ ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_1.get(), length_1,
+ (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+ iv.c_str(), true, (unsigned char*)decrypted_11.get());
+ ASSERT_TRUE(ret_code > 0);
+ std::string decrypted_content_11(decrypted_11.get(), ret_code);
+ ASSERT_EQ(source_1, decrypted_content_11);
+
+ std::unique_ptr<char[]> encrypt_2(new char[case_2.length()]);
+ int length_2 = base64_decode(case_2.c_str(), case_2.length(), encrypt_2.get());
+ std::unique_ptr<char[]> decrypted_2(new char[case_2.length()]);
+ ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_2.get(), length_2,
+ (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+ iv.c_str(), true, (unsigned char*)decrypted_2.get());
+ ASSERT_TRUE(ret_code > 0);
+ std::string decrypted_content_2(decrypted_2.get(), ret_code);
+ ASSERT_EQ(source_2, decrypted_content_2);
+
+ std::unique_ptr<char[]> decrypted_21(new char[case_2.length()]);
+ ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_2.get(), length_2,
+ (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+ iv.c_str(), true, (unsigned char*)decrypted_21.get());
+ ASSERT_TRUE(ret_code > 0);
+ std::string decrypted_content_21(decrypted_21.get(), ret_code);
+ ASSERT_EQ(source_2, decrypted_content_21);
+}
+
+TEST_F(EncryptionUtilTest, sm4_with_iv_test_by_case) {
+ std::string case_1 = "9FFlX59+3EbIC7rqylMNwg=="; // base64 for encrypted "hello, doris"
+ std::string source_1 = "hello, doris";
+ std::string case_2 = "RIJVVUUmMT/4CVNYdxVvXA=="; // base64 for encrypted "doris test"
+ std::string source_2 = "doris test";
+ std::string iv = "doris";
+
+ std::unique_ptr<char[]> encrypt_1(new char[case_1.length()]);
+ int length_1 = base64_decode(case_1.c_str(), case_1.length(), encrypt_1.get());
+ std::unique_ptr<char[]> decrypted_1(new char[case_1.length()]);
+ std::unique_ptr<char[]> decrypted_11(new char[case_1.length()]);
+
+ int ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_1.get(), length_1,
+ (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+ iv.c_str(), true, (unsigned char*)decrypted_1.get());
+ ASSERT_TRUE(ret_code > 0);
+ std::string decrypted_content_1(decrypted_1.get(), ret_code);
+ ASSERT_EQ(source_1, decrypted_content_1);
+
+ std::unique_ptr<char[]> encrypt_2(new char[case_2.length()]);
+ int length_2 = base64_decode(case_2.c_str(), case_2.length(), encrypt_2.get());
+ std::unique_ptr<char[]> decrypted_2(new char[case_2.length()]);
+ std::unique_ptr<char[]> decrypted_21(new char[case_2.length()]);
+
+ ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_2.get(), length_2,
+ (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+ iv.c_str(), true, (unsigned char*)decrypted_2.get());
+ ASSERT_TRUE(ret_code > 0);
+ std::string decrypted_content_2(decrypted_2.get(), ret_code);
+ ASSERT_EQ(source_2, decrypted_content_2);
+
+ ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_1.get(), length_1,
+ (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+ iv.c_str(), true, (unsigned char*)decrypted_11.get());
+ ASSERT_TRUE(ret_code > 0);
+ std::string decrypted_content_11(decrypted_11.get(), ret_code);
+ ASSERT_EQ(source_1, decrypted_content_11);
+
+ ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_2.get(), length_2,
+ (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+ iv.c_str(), true, (unsigned char*)decrypted_21.get());
+ ASSERT_TRUE(ret_code > 0);
+ std::string decrypted_content_21(decrypted_21.get(), ret_code);
+ ASSERT_EQ(source_2, decrypted_content_21);
+}
+
} // namespace doris
int main(int argc, char** argv) {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org