You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by ya...@apache.org on 2022/03/02 09:27:04 UTC

[incubator-doris] branch master updated: [fix] fix a bug where encryption functions with an IV may return wrong results (#8277)

This is an automated email from the ASF dual-hosted git repository.

yangzhg pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 246ac4e  [fix] fix a bug where encryption functions with an IV may return wrong results (#8277)
246ac4e is described below

commit 246ac4e37aa4da6836b7850cb990f02d1c3725a3
Author: Zhengguo Yang <ya...@gmail.com>
AuthorDate: Wed Mar 2 17:26:44 2022 +0800

    [fix] fix a bug where encryption functions with an IV may return wrong results (#8277)
---
 be/src/exprs/encryption_functions.cpp | 93 +++++++++++++----------------------
 be/src/exprs/minmax_predicate.h       |  4 +-
 be/src/udf/udf.cpp                    |  3 ++
 be/src/udf/udf.h                      |  3 ++
 be/src/util/encryption_util.cpp       | 37 +++++++++++---
 be/src/util/encryption_util.h         |  4 +-
 be/test/util/encryption_util_test.cpp | 91 +++++++++++++++++++++++++++++++++-
 7 files changed, 164 insertions(+), 71 deletions(-)

diff --git a/be/src/exprs/encryption_functions.cpp b/be/src/exprs/encryption_functions.cpp
index 5d919a3..19ec1a7 100644
--- a/be/src/exprs/encryption_functions.cpp
+++ b/be/src/exprs/encryption_functions.cpp
@@ -55,32 +55,23 @@ StringVal encrypt(FunctionContext* ctx, const StringVal& src, const StringVal& k
     if (src.len == 0 || src.is_null) {
         return StringVal::null();
     }
+    /*
+     * Buffer for ciphertext. Ensure the buffer is long enough for the
+     * ciphertext which may be longer than the plaintext, depending on the
+     * algorithm and mode.
+     */
+
     int cipher_len = src.len + 16;
-    std::unique_ptr<char[]> p;
-    p.reset(new char[cipher_len]);
-    int ret_code = 0;
-    if (mode != AES_128_ECB && mode != AES_192_ECB && mode != AES_256_ECB && mode != AES_256_ECB &&
-        mode != SM4_128_ECB) {
-        if (iv.len == 0 || iv.is_null) {
-            return StringVal::null();
-        }
-        int iv_len = 32; // max  key length 256 / 8
-        std::unique_ptr<char[]> init_vec;
-        init_vec.reset(new char[iv_len]);
-        std::memset(init_vec.get(), 0, iv.len + 1);
-        memcpy(init_vec.get(), iv.ptr, iv.len);
-        ret_code = EncryptionUtil::encrypt(
-                mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr, key.len,
-                (unsigned char*)init_vec.get(), true, (unsigned char*)p.get());
-    } else {
-        ret_code = EncryptionUtil::encrypt(mode, (unsigned char*)src.ptr, src.len,
-                                           (unsigned char*)key.ptr, key.len, nullptr, true,
-                                           (unsigned char*)p.get());
-    }
-    if (ret_code < 0) {
+    std::unique_ptr<char[]> cipher_text;
+    cipher_text.reset(new char[cipher_len]);
+    int cipher_text_len = 0;
+    cipher_text_len = EncryptionUtil::encrypt(mode, (unsigned char*)src.ptr, src.len,
+                                              (unsigned char*)key.ptr, key.len, (char*)iv.ptr, true,
+                                              (unsigned char*)cipher_text.get());
+    if (cipher_text_len < 0) {
         return StringVal::null();
     }
-    return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
+    return AnyValUtil::from_buffer_temp(ctx, cipher_text.get(), cipher_text_len);
 }
 
 StringVal decrypt(FunctionContext* ctx, const StringVal& src, const StringVal& key,
@@ -89,31 +80,16 @@ StringVal decrypt(FunctionContext* ctx, const StringVal& src, const StringVal& k
         return StringVal::null();
     }
     int cipher_len = src.len;
-    std::unique_ptr<char[]> p;
-    p.reset(new char[cipher_len]);
-    int ret_code = 0;
-    if (mode != AES_128_ECB && mode != AES_192_ECB && mode != AES_256_ECB && mode != AES_256_ECB &&
-        mode != SM4_128_ECB) {
-        if (iv.len == 0 || iv.is_null) {
-            return StringVal::null();
-        }
-        int iv_len = 32; // max  key length 256 / 8
-        std::unique_ptr<char[]> init_vec;
-        init_vec.reset(new char[iv_len]);
-        std::memset(init_vec.get(), 0, iv.len + 1);
-        memcpy(init_vec.get(), iv.ptr, iv.len);
-        ret_code = EncryptionUtil::decrypt(
-                mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr, key.len,
-                (unsigned char*)init_vec.get(), true, (unsigned char*)p.get());
-    } else {
-        ret_code = EncryptionUtil::decrypt(mode, (unsigned char*)src.ptr, src.len,
-                                           (unsigned char*)key.ptr, key.len, nullptr, true,
-                                           (unsigned char*)p.get());
-    }
-    if (ret_code < 0) {
+    std::unique_ptr<char[]> plain_text;
+    plain_text.reset(new char[cipher_len]);
+    int plain_text_len = 0;
+    plain_text_len =
+            EncryptionUtil::decrypt(mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr,
+                                    key.len, (char*)iv.ptr, true, (unsigned char*)plain_text.get());
+    if (plain_text_len < 0) {
         return StringVal::null();
     }
-    return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
+    return AnyValUtil::from_buffer_temp(ctx, plain_text.get(), plain_text_len);
 }
 
 StringVal EncryptionFunctions::aes_encrypt(FunctionContext* ctx, const StringVal& src,
@@ -197,15 +173,15 @@ StringVal EncryptionFunctions::from_base64(FunctionContext* ctx, const StringVal
         return StringVal::null();
     }
 
-    int cipher_len = src.len;
-    std::unique_ptr<char[]> p;
-    p.reset(new char[cipher_len]);
+    int encoded_len = src.len;
+    std::unique_ptr<char[]> plain_text;
+    plain_text.reset(new char[encoded_len]);
 
-    int ret_code = base64_decode((const char*)src.ptr, src.len, p.get());
-    if (ret_code < 0) {
+    int plain_text_len = base64_decode((const char*)src.ptr, src.len, plain_text.get());
+    if (plain_text_len < 0) {
         return StringVal::null();
     }
-    return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
+    return AnyValUtil::from_buffer_temp(ctx, plain_text.get(), plain_text_len);
 }
 
 StringVal EncryptionFunctions::to_base64(FunctionContext* ctx, const StringVal& src) {
@@ -213,15 +189,16 @@ StringVal EncryptionFunctions::to_base64(FunctionContext* ctx, const StringVal&
         return StringVal::null();
     }
 
-    int cipher_len = (size_t)(4.0 * ceil((double)src.len / 3.0));
-    std::unique_ptr<char[]> p;
-    p.reset(new char[cipher_len]);
+    int encoded_len = (size_t)(4.0 * ceil((double)src.len / 3.0));
+    std::unique_ptr<char[]> encoded_text;
+    encoded_text.reset(new char[encoded_len]);
 
-    int ret_code = base64_encode((unsigned char*)src.ptr, src.len, (unsigned char*)p.get());
-    if (ret_code < 0) {
+    int encoded_text_len =
+            base64_encode((unsigned char*)src.ptr, src.len, (unsigned char*)encoded_text.get());
+    if (encoded_text_len < 0) {
         return StringVal::null();
     }
-    return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
+    return AnyValUtil::from_buffer_temp(ctx, encoded_text.get(), encoded_text_len);
 }
 
 StringVal EncryptionFunctions::md5sum(FunctionContext* ctx, int num_args, const StringVal* args) {
diff --git a/be/src/exprs/minmax_predicate.h b/be/src/exprs/minmax_predicate.h
index 3a9ff5b..2c8140d 100644
--- a/be/src/exprs/minmax_predicate.h
+++ b/be/src/exprs/minmax_predicate.h
@@ -25,7 +25,6 @@ namespace doris {
 // only used in Runtime Filter
 class MinMaxFuncBase {
 public:
-    virtual ~MinMaxFuncBase() = default;
     virtual void insert(const void* data) = 0;
     virtual bool find(void* data) = 0;
     virtual bool is_empty() = 0;
@@ -35,6 +34,7 @@ public:
     virtual Status assign(void* min_data, void* max_data) = 0;
     // merge from other minmax_func
     virtual Status merge(MinMaxFuncBase* minmax_func, ObjectPool* pool) = 0;
+    virtual ~MinMaxFuncBase() = default;
 };
 
 template <class T>
@@ -114,4 +114,4 @@ private:
     bool _empty = true;
 };
 
-} // namespace doris
\ No newline at end of file
+} // namespace doris
diff --git a/be/src/udf/udf.cpp b/be/src/udf/udf.cpp
index eae0bf1..612aabb 100644
--- a/be/src/udf/udf.cpp
+++ b/be/src/udf/udf.cpp
@@ -562,4 +562,7 @@ void* FunctionContext::get_function_state(FunctionStateScope scope) const {
         return nullptr;
     }
 }
+std::ostream& operator<<(std::ostream& os, const StringVal& string_val) {
+    return os << string_val.to_string();
+}
 } // namespace doris_udf
diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h
index c262c8c..481f19a 100644
--- a/be/src/udf/udf.h
+++ b/be/src/udf/udf.h
@@ -21,6 +21,7 @@
 #include <string.h>
 
 #include <cstdint>
+#include <iostream>
 #include <vector>
 
 // This is the only Doris header required to develop UDFs and UDAs. This header
@@ -656,7 +657,9 @@ struct StringVal : public AnyVal {
     void append(FunctionContext* ctx, const uint8_t* buf, int64_t len);
     void append(FunctionContext* ctx, const uint8_t* buf, int64_t len, const uint8_t* buf2,
                 int64_t buf2_len);
+    std::string to_string() const { return std::string((char*)ptr, len); }
 };
+std::ostream& operator<<(std::ostream& os, const StringVal& string_val);
 
 struct DecimalV2Val : public AnyVal {
     __int128 val;
diff --git a/be/src/util/encryption_util.cpp b/be/src/util/encryption_util.cpp
index eb95cf5..b9396e9 100644
--- a/be/src/util/encryption_util.cpp
+++ b/be/src/util/encryption_util.cpp
@@ -21,7 +21,9 @@
 #include <openssl/evp.h>
 #include <openssl/ossl_typ.h>
 #include <sys/types.h>
+
 #include <cstring>
+#include <string>
 
 namespace doris {
 
@@ -171,20 +173,29 @@ static int do_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
 
 int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source,
                             uint32_t source_length, const unsigned char* key, uint32_t key_length,
-                            const unsigned char* iv, bool padding, unsigned char* encrypt) {
+                            const char* iv_str, bool padding, unsigned char* encrypt) {
     const EVP_CIPHER* cipher = get_evp_type(mode);
     /* The encrypt key to be used for encryption */
     unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8];
     create_key(key, key_length, encrypt_key, mode);
 
-    if (cipher == nullptr || (EVP_CIPHER_iv_length(cipher) > 0 && !iv)) {
+    int iv_length = EVP_CIPHER_iv_length(cipher);
+    if (cipher == nullptr || (iv_length > 0 && !iv_str)) {
         return AES_BAD_DATA;
     }
+    char* init_vec = nullptr;
+    std::string iv_default("DORISDORISDORIS_");
+
+    if (iv_str) {
+        init_vec = &iv_default[0];
+        memcpy(init_vec, iv_str, strnlen(iv_str, EVP_MAX_IV_LENGTH));
+        init_vec[iv_length] = '\0';
+    }
     EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
     EVP_CIPHER_CTX_reset(cipher_ctx);
     int length = 0;
-    int ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key, iv, padding,
-                         encrypt, &length);
+    int ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key,
+                         reinterpret_cast<unsigned char*>(init_vec), padding, encrypt, &length);
     EVP_CIPHER_CTX_free(cipher_ctx);
     if (ret == 0) {
         ERR_clear_error();
@@ -219,21 +230,31 @@ static int do_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
 
 int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt,
                             uint32_t encrypt_length, const unsigned char* key, uint32_t key_length,
-                            const unsigned char* iv, bool padding, unsigned char* decrypt_content) {
+                            const char* iv_str, bool padding, unsigned char* decrypt_content) {
     const EVP_CIPHER* cipher = get_evp_type(mode);
 
     /* The encrypt key to be used for decryption */
     unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8];
     create_key(key, key_length, encrypt_key, mode);
 
-    if (cipher == nullptr || (EVP_CIPHER_iv_length(cipher) > 0 && !iv)) {
+    int iv_length = EVP_CIPHER_iv_length(cipher);
+    if (cipher == nullptr || (iv_length > 0 && !iv_str)) {
         return AES_BAD_DATA;
     }
+    char* init_vec = nullptr;
+    std::string iv_default("DORISDORISDORIS_");
+
+    if (iv_str) {
+        init_vec = &iv_default[0];
+        memcpy(init_vec, iv_str, strnlen(iv_str, EVP_MAX_IV_LENGTH));
+        init_vec[iv_length] = '\0';
+    }
     EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
     EVP_CIPHER_CTX_reset(cipher_ctx);
     int length = 0;
-    int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key, iv, padding,
-                         decrypt_content, &length);
+    int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key,
+                         reinterpret_cast<unsigned char*>(init_vec), padding, decrypt_content,
+                         &length);
     EVP_CIPHER_CTX_free(cipher_ctx);
     if (ret > 0) {
         return length;
diff --git a/be/src/util/encryption_util.h b/be/src/util/encryption_util.h
index e051d28..711a817 100644
--- a/be/src/util/encryption_util.h
+++ b/be/src/util/encryption_util.h
@@ -58,11 +58,11 @@ enum EncryptionState { AES_SUCCESS = 0, AES_BAD_DATA = -1 };
 class EncryptionUtil {
 public:
     static int encrypt(EncryptionMode mode, const unsigned char* source, uint32_t source_length,
-                       const unsigned char* key, uint32_t key_length, const unsigned char* iv,
+                       const unsigned char* key, uint32_t key_length, const char* iv_str,
                        bool padding, unsigned char* encrypt);
 
     static int decrypt(EncryptionMode mode, const unsigned char* encrypt, uint32_t encrypt_length,
-                       const unsigned char* key, uint32_t key_length, const unsigned char* iv,
+                       const unsigned char* key, uint32_t key_length, const char* iv_str,
                        bool padding, unsigned char* decrypt_content);
 };
 
diff --git a/be/test/util/encryption_util_test.cpp b/be/test/util/encryption_util_test.cpp
index 30c9752..2f30ade 100644
--- a/be/test/util/encryption_util_test.cpp
+++ b/be/test/util/encryption_util_test.cpp
@@ -117,7 +117,6 @@ TEST_F(EncryptionUtilTest, sm4_test_by_case) {
 
     std::unique_ptr<char[]> encrypt_1(new char[case_1.length()]);
     int length_1 = base64_decode(case_1.c_str(), case_1.length(), encrypt_1.get());
-    std::cout << encrypt_1.get();
     std::unique_ptr<char[]> decrypted_1(new char[case_1.length()]);
     int ret_code = EncryptionUtil::decrypt(SM4_128_ECB, (unsigned char*)encrypt_1.get(), length_1,
                                            (unsigned char*)_aes_key.c_str(), _aes_key.length(),
@@ -137,6 +136,96 @@ TEST_F(EncryptionUtilTest, sm4_test_by_case) {
     ASSERT_EQ(source_2, decrypted_content_2);
 }
 
+TEST_F(EncryptionUtilTest, aes_with_iv_test_by_case) {
+    std::string case_1 = "XbJgw1AxBNwZZPpvzPtWyg=="; // base64 for encrypted "hello, doris"
+    std::string source_1 = "hello, doris";
+    std::string case_2 = "gpKcO/iwgeRCIWBQdkpAkQ=="; // base64 for encrypted "doris test"
+    std::string source_2 = "doris test";
+    std::string iv = "doris";
+
+    std::unique_ptr<char[]> encrypt_1(new char[case_1.length()]);
+    int length_1 = base64_decode(case_1.c_str(), case_1.length(), encrypt_1.get());
+    std::unique_ptr<char[]> decrypted_1(new char[case_1.length()]);
+    int ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_1.get(), length_1,
+                                           (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+                                           iv.c_str(), true, (unsigned char*)decrypted_1.get());
+    ASSERT_TRUE(ret_code > 0);
+    std::string decrypted_content_1(decrypted_1.get(), ret_code);
+    ASSERT_EQ(source_1, decrypted_content_1);
+    std::unique_ptr<char[]> decrypted_11(new char[case_1.length()]);
+
+    ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_1.get(), length_1,
+                                       (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+                                       iv.c_str(), true, (unsigned char*)decrypted_11.get());
+    ASSERT_TRUE(ret_code > 0);
+    std::string decrypted_content_11(decrypted_11.get(), ret_code);
+    ASSERT_EQ(source_1, decrypted_content_11);
+
+    std::unique_ptr<char[]> encrypt_2(new char[case_2.length()]);
+    int length_2 = base64_decode(case_2.c_str(), case_2.length(), encrypt_2.get());
+    std::unique_ptr<char[]> decrypted_2(new char[case_2.length()]);
+    ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_2.get(), length_2,
+                                       (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+                                       iv.c_str(), true, (unsigned char*)decrypted_2.get());
+    ASSERT_TRUE(ret_code > 0);
+    std::string decrypted_content_2(decrypted_2.get(), ret_code);
+    ASSERT_EQ(source_2, decrypted_content_2);
+
+    std::unique_ptr<char[]> decrypted_21(new char[case_2.length()]);
+    ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_2.get(), length_2,
+                                       (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+                                       iv.c_str(), true, (unsigned char*)decrypted_21.get());
+    ASSERT_TRUE(ret_code > 0);
+    std::string decrypted_content_21(decrypted_21.get(), ret_code);
+    ASSERT_EQ(source_2, decrypted_content_21);
+}
+
+TEST_F(EncryptionUtilTest, sm4_with_iv_test_by_case) {
+    std::string case_1 = "9FFlX59+3EbIC7rqylMNwg=="; // base64 for encrypted "hello, doris"
+    std::string source_1 = "hello, doris";
+    std::string case_2 = "RIJVVUUmMT/4CVNYdxVvXA=="; // base64 for encrypted "doris test"
+    std::string source_2 = "doris test";
+    std::string iv = "doris";
+
+    std::unique_ptr<char[]> encrypt_1(new char[case_1.length()]);
+    int length_1 = base64_decode(case_1.c_str(), case_1.length(), encrypt_1.get());
+    std::unique_ptr<char[]> decrypted_1(new char[case_1.length()]);
+    std::unique_ptr<char[]> decrypted_11(new char[case_1.length()]);
+
+    int ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_1.get(), length_1,
+                                           (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+                                           iv.c_str(), true, (unsigned char*)decrypted_1.get());
+    ASSERT_TRUE(ret_code > 0);
+    std::string decrypted_content_1(decrypted_1.get(), ret_code);
+    ASSERT_EQ(source_1, decrypted_content_1);
+
+    std::unique_ptr<char[]> encrypt_2(new char[case_2.length()]);
+    int length_2 = base64_decode(case_2.c_str(), case_2.length(), encrypt_2.get());
+    std::unique_ptr<char[]> decrypted_2(new char[case_2.length()]);
+    std::unique_ptr<char[]> decrypted_21(new char[case_2.length()]);
+
+    ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_2.get(), length_2,
+                                       (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+                                       iv.c_str(), true, (unsigned char*)decrypted_2.get());
+    ASSERT_TRUE(ret_code > 0);
+    std::string decrypted_content_2(decrypted_2.get(), ret_code);
+    ASSERT_EQ(source_2, decrypted_content_2);
+
+    ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_1.get(), length_1,
+                                       (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+                                       iv.c_str(), true, (unsigned char*)decrypted_11.get());
+    ASSERT_TRUE(ret_code > 0);
+    std::string decrypted_content_11(decrypted_11.get(), ret_code);
+    ASSERT_EQ(source_1, decrypted_content_11);
+
+    ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_2.get(), length_2,
+                                       (unsigned char*)_aes_key.c_str(), _aes_key.length(),
+                                       iv.c_str(), true, (unsigned char*)decrypted_21.get());
+    ASSERT_TRUE(ret_code > 0);
+    std::string decrypted_content_21(decrypted_21.get(), ret_code);
+    ASSERT_EQ(source_2, decrypted_content_21);
+}
+
 } // namespace doris
 
 int main(int argc, char** argv) {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org