You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nifi.apache.org by sz...@apache.org on 2021/12/01 13:27:07 UTC
[nifi-minifi-cpp] 03/03: MINIFICPP-1692 TLSSocket: Break infinite loop when no more data can be read
This is an automated email from the ASF dual-hosted git repository.
szaszm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi-minifi-cpp.git
commit 0553df446aefb395b34c1e4e2b5ced31f07a4e04
Author: Adam Debreceni <ad...@apache.org>
AuthorDate: Wed Dec 1 14:23:32 2021 +0100
MINIFICPP-1692 TLSSocket: Break infinite loop when no more data can be read
Closes #1218
Signed-off-by: Marton Szasz <sz...@apache.org>
---
cmake/BuildTests.cmake | 13 +++
.../TLSClientSocketSupportedProtocolsTest.cpp | 118 ++++---------------
libminifi/src/io/tls/TLSSocket.cpp | 4 +-
libminifi/test/SimpleSSLTestServer.h | 130 +++++++++++++++++++++
libminifi/test/unit/tls/TLSStreamTests.cpp | 82 +++++++++++++
5 files changed, 247 insertions(+), 100 deletions(-)
diff --git a/cmake/BuildTests.cmake b/cmake/BuildTests.cmake
index ecb5c70..6d5402e 100644
--- a/cmake/BuildTests.cmake
+++ b/cmake/BuildTests.cmake
@@ -101,6 +101,7 @@ target_include_directories(${CATCH_MAIN_LIB} SYSTEM BEFORE PRIVATE "${CMAKE_SOUR
SET(TEST_RESOURCES ${TEST_DIR}/resources)
GETSOURCEFILES(UNIT_TESTS "${TEST_DIR}/unit/")
+GETSOURCEFILES(TLS_UNIT_TESTS "${TEST_DIR}/unit/tls/")
GETSOURCEFILES(NANOFI_UNIT_TESTS "${NANOFI_TEST_DIR}")
GETSOURCEFILES(INTEGRATION_TESTS "${TEST_DIR}/integration/")
@@ -115,6 +116,18 @@ FOREACH(testfile ${UNIT_TESTS})
ENDFOREACH()
message("-- Finished building ${UNIT_TEST_COUNT} unit test file(s)...")
+if (NOT OPENSSL_OFF)  # the TLS unit tests are only registered when OPENSSL_OFF is not set
+ SET(UNIT_TEST_COUNT 0)  # restart the counter for the TLS group
+ FOREACH(testfile ${TLS_UNIT_TESTS})
+ get_filename_component(testfilename "${testfile}" NAME_WE)  # NAME_WE: file name without directory or extension, used as the target name
+ add_executable("${testfilename}" "${TEST_DIR}/unit/tls/${testfile}")
+ createTests("${testfilename}")
+ MATH(EXPR UNIT_TEST_COUNT "${UNIT_TEST_COUNT}+1")
+ add_test(NAME "${testfilename}" COMMAND "${testfilename}" "${TEST_RESOURCES}/" WORKING_DIRECTORY ${TEST_DIR})  # the resources directory is passed as the test's first argument
+ ENDFOREACH()
+ message("-- Finished building ${UNIT_TEST_COUNT} TLS unit test file(s)...")
+endif()
+
if(NOT WIN32 AND ENABLE_NANOFI)
SET(UNIT_TEST_COUNT 0)
FOREACH(testfile ${NANOFI_UNIT_TESTS})
diff --git a/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp b/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp
index 928bc56..6cc643c 100644
--- a/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp
+++ b/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp
@@ -19,6 +19,7 @@
#include <sys/stat.h>
#include <chrono>
#include <thread>
+#include <filesystem>
#undef NDEBUG
#include <cassert>
#include <utility>
@@ -26,114 +27,34 @@
#include <string>
#include "properties/Configure.h"
#include "io/tls/TLSSocket.h"
+#include "SimpleSSLTestServer.h"
namespace minifi = org::apache::nifi::minifi;
-#ifdef WIN32
-#pragma comment(lib, "Ws2_32.lib")
-using SocketDescriptor = SOCKET;
-#else
-using SocketDescriptor = int;
-static constexpr SocketDescriptor INVALID_SOCKET = -1;
-#endif /* WIN32 */
-
-
-class SimpleSSLTestServer {
- public:
- SimpleSSLTestServer(const SSL_METHOD* method, const std::string& port, const std::string& path)
- : port_(port), had_connection_(false) {
- ctx_ = SSL_CTX_new(method);
- configureContext(path);
- socket_descriptor_ = createSocket(std::stoi(port_));
- }
-
- ~SimpleSSLTestServer() {
- SSL_shutdown(ssl_);
- SSL_free(ssl_);
- SSL_CTX_free(ctx_);
- }
-
- void waitForConnection() {
- server_read_thread_ = std::thread([this]() -> void {
- SocketDescriptor client = accept(socket_descriptor_, nullptr, nullptr);
- if (client != INVALID_SOCKET) {
- ssl_ = SSL_new(ctx_);
- SSL_set_fd(ssl_, client);
- had_connection_ = (SSL_accept(ssl_) == 1);
- }
- });
- }
-
- void shutdownServer() {
-#ifdef WIN32
- shutdown(socket_descriptor_, SD_BOTH);
- closesocket(socket_descriptor_);
-#else
- shutdown(socket_descriptor_, SHUT_RDWR);
- close(socket_descriptor_);
-#endif
- server_read_thread_.join();
- }
-
- bool hadConnection() const {
- return had_connection_;
- }
-
- private:
- SSL_CTX *ctx_ = nullptr;
- SSL* ssl_ = nullptr;
- std::string port_;
- SocketDescriptor socket_descriptor_;
- bool had_connection_;
- std::thread server_read_thread_;
-
- void configureContext(const std::string& path) {
- SSL_CTX_set_ecdh_auto(ctx_, 1);
- /* Set the key and cert */
- assert(SSL_CTX_use_certificate_file(ctx_, (path + "cn.crt.pem").c_str(), SSL_FILETYPE_PEM) == 1);
- assert(SSL_CTX_use_PrivateKey_file(ctx_, (path + "cn.ckey.pem").c_str(), SSL_FILETYPE_PEM) == 1);
- }
-
- static SocketDescriptor createSocket(int port) {
- struct sockaddr_in addr;
-
- addr.sin_family = AF_INET;
- addr.sin_port = htons(port);
- addr.sin_addr.s_addr = htonl(INADDR_ANY);
-
- SocketDescriptor socket_descriptor = socket(AF_INET, SOCK_STREAM, 0);
- assert(socket_descriptor >= 0);
- assert(bind(socket_descriptor, (struct sockaddr*)&addr, sizeof(addr)) >= 0);
- assert(listen(socket_descriptor, 1) >= 0);
-
- return socket_descriptor;
- }
-};
-
class SimpleSSLTestServerTLSv1 : public SimpleSSLTestServer {
public:
- SimpleSSLTestServerTLSv1(const std::string& port, const std::string& path)
- : SimpleSSLTestServer(TLSv1_server_method(), port, path) {
+ SimpleSSLTestServerTLSv1(int port, const std::filesystem::path& key_dir)
+ : SimpleSSLTestServer(TLSv1_server_method(), port, key_dir) {
}
};
class SimpleSSLTestServerTLSv1_1 : public SimpleSSLTestServer {
public:
- SimpleSSLTestServerTLSv1_1(const std::string& port, const std::string& path)
- : SimpleSSLTestServer(TLSv1_1_server_method(), port, path) {
+ SimpleSSLTestServerTLSv1_1(int port, const std::filesystem::path& key_dir)
+ : SimpleSSLTestServer(TLSv1_1_server_method(), port, key_dir) {
}
};
class SimpleSSLTestServerTLSv1_2 : public SimpleSSLTestServer {
public:
- SimpleSSLTestServerTLSv1_2(const std::string& port, const std::string& path)
- : SimpleSSLTestServer(TLSv1_2_server_method(), port, path) {
+ SimpleSSLTestServerTLSv1_2(int port, const std::filesystem::path& key_dir)
+ : SimpleSSLTestServer(TLSv1_2_server_method(), port, key_dir) {
}
};
class TLSClientSocketSupportedProtocolsTest {
public:
- explicit TLSClientSocketSupportedProtocolsTest(const std::string& key_dir)
+ explicit TLSClientSocketSupportedProtocolsTest(const std::filesystem::path& key_dir)
: key_dir_(key_dir), configuration_(std::make_shared<minifi::Configure>()) {
}
@@ -147,14 +68,13 @@ class TLSClientSocketSupportedProtocolsTest {
protected:
void configureSecurity() {
host_ = minifi::io::Socket::getMyHostName();
- port_ = "38777";
if (!key_dir_.empty()) {
configuration_->set(minifi::Configure::nifi_remote_input_secure, "true");
- configuration_->set(minifi::Configure::nifi_security_client_certificate, key_dir_ + "cn.crt.pem");
- configuration_->set(minifi::Configure::nifi_security_client_private_key, key_dir_ + "cn.ckey.pem");
- configuration_->set(minifi::Configure::nifi_security_client_pass_phrase, key_dir_ + "cn.pass");
- configuration_->set(minifi::Configure::nifi_security_client_ca_certificate, key_dir_ + "nifi-cert.pem");
- configuration_->set(minifi::Configure::nifi_default_directory, key_dir_);
+ configuration_->set(minifi::Configure::nifi_security_client_certificate, (key_dir_ / "cn.crt.pem").string());
+ configuration_->set(minifi::Configure::nifi_security_client_private_key, (key_dir_ / "cn.ckey.pem").string());
+ configuration_->set(minifi::Configure::nifi_security_client_pass_phrase, (key_dir_ / "cn.pass").string());
+ configuration_->set(minifi::Configure::nifi_security_client_ca_certificate, (key_dir_ / "nifi-cert.pem").string());
+ configuration_->set(minifi::Configure::nifi_default_directory, key_dir_.string());
}
}
@@ -166,11 +86,14 @@ class TLSClientSocketSupportedProtocolsTest {
template <class TLSTestSever>
void verifyTLSProtocolCompatibility(const bool should_be_compatible) {
- TLSTestSever server(port_, key_dir_);
+ // bind to random port
+ TLSTestSever server(0, key_dir_);
server.waitForConnection();
+ int port = server.getPort();
+
const auto socket_context = std::make_shared<minifi::io::TLSContext>(configuration_);
- client_socket_ = std::make_unique<minifi::io::TLSSocket>(socket_context, host_, std::stoi(port_), 0);
+ client_socket_ = std::make_unique<minifi::io::TLSSocket>(socket_context, host_, port, 0);
const bool client_initialized_successfully = (client_socket_->initialize() == 0);
assert(client_initialized_successfully == should_be_compatible);
server.shutdownServer();
@@ -180,8 +103,7 @@ class TLSClientSocketSupportedProtocolsTest {
protected:
std::unique_ptr<minifi::io::TLSSocket> client_socket_;
std::string host_;
- std::string port_;
- std::string key_dir_;
+ std::filesystem::path key_dir_;
std::shared_ptr<minifi::Configure> configuration_;
};
diff --git a/libminifi/src/io/tls/TLSSocket.cpp b/libminifi/src/io/tls/TLSSocket.cpp
index 5d76e8c..af8772a 100644
--- a/libminifi/src/io/tls/TLSSocket.cpp
+++ b/libminifi/src/io/tls/TLSSocket.cpp
@@ -434,9 +434,9 @@ size_t TLSSocket::read(uint8_t *buf, size_t buflen) {
const auto ssl_read_size = gsl::narrow<int>(std::min(buflen, gsl::narrow<size_t>(std::numeric_limits<int>::max())));
status = SSL_read(fd_ssl, buf, ssl_read_size);
sslStatus = SSL_get_error(fd_ssl, status);
- } while (status < 0 && sslStatus == SSL_ERROR_WANT_READ);
+ } while (status <= 0 && sslStatus == SSL_ERROR_WANT_READ);
- if (status < 0)
+ if (status <= 0)
break;
buflen -= gsl::narrow<size_t>(status);
diff --git a/libminifi/test/SimpleSSLTestServer.h b/libminifi/test/SimpleSSLTestServer.h
new file mode 100644
index 0000000..b6cecf5
--- /dev/null
+++ b/libminifi/test/SimpleSSLTestServer.h
@@ -0,0 +1,130 @@
+/**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <openssl/ssl.h>
+#include <openssl/err.h>
+#include <filesystem>
+#include <string>
+#include "io/tls/TLSSocket.h"
+
+#ifdef WIN32
+#include <winsock2.h>
+#include <ws2tcpip.h>
+#pragma comment(lib, "Ws2_32.lib")
+using SocketDescriptor = SOCKET;
+#else
+using SocketDescriptor = int;
+static constexpr SocketDescriptor INVALID_SOCKET = -1;
+#endif /* WIN32 */
+
+namespace minifi = org::apache::nifi::minifi;
+
+class SimpleSSLTestServer {  // Test-only TLS server: accepts a single client on a background thread and runs the server-side handshake.
+  struct SocketInitializer {  // Starts Winsock when WIN32 is defined; the constructor body is empty otherwise.
+    SocketInitializer() {
+#ifdef WIN32
+      static WSADATA s_wsaData;
+      const int iWinSockInitResult = WSAStartup(MAKEWORD(2, 2), &s_wsaData);
+      if (0 != iWinSockInitResult) {
+        throw std::runtime_error("Cannot initialize socket");
+      }
+#endif
+    }
+  };
+
+ public:
+  SimpleSSLTestServer(const SSL_METHOD* method, int port, const std::filesystem::path& key_dir)  // port 0 lets the OS choose; see getPort()
+      : port_(port), had_connection_(false) {
+    static SocketInitializer socket_initializer{};  // function-local static: constructed once, on first use
+    minifi::io::OpenSSLInitializer::getInstance();
+    ctx_ = SSL_CTX_new(method);
+    configureContext(key_dir);  // loads cn.crt.pem and cn.ckey.pem from key_dir
+    socket_descriptor_ = createSocket(port_);  // bound and listening by the time the constructor returns
+  }
+
+  ~SimpleSSLTestServer() {  // Call shutdownServer() first: the accept thread is not joined here.
+    if (ssl_ != nullptr) SSL_shutdown(ssl_);  // ssl_ stays null when no client was accepted; SSL_shutdown is not documented to accept NULL
+    SSL_free(ssl_);  // documented as a no-op for NULL
+    SSL_CTX_free(ctx_);
+  }
+
+  void waitForConnection() {  // Starts the accept thread and returns immediately; shutdownServer() joins it.
+    server_read_thread_ = std::thread([this]() -> void {
+      SocketDescriptor client = accept(socket_descriptor_, nullptr, nullptr);
+      if (client != INVALID_SOCKET) {
+        ssl_ = SSL_new(ctx_);
+        SSL_set_fd(ssl_, client);
+        had_connection_ = (SSL_accept(ssl_) == 1);  // true only when the TLS handshake completed
+      }
+    });
+  }
+
+  void shutdownServer() {  // Closes the listening socket, then waits for the accept thread (if one was started).
+#ifdef WIN32
+    shutdown(socket_descriptor_, SD_BOTH);
+    closesocket(socket_descriptor_);
+#else
+    shutdown(socket_descriptor_, SHUT_RDWR);
+    close(socket_descriptor_);
+#endif
+    if (server_read_thread_.joinable()) server_read_thread_.join();  // joining a never-started thread would throw std::system_error
+  }
+
+  bool hadConnection() const {  // Written by the accept thread without synchronization: read it only after shutdownServer().
+    return had_connection_;
+  }
+
+  int getPort() const {  // Returns the port the listening socket is actually bound to (relevant when constructed with port 0).
+    struct sockaddr_in addr{};
+    socklen_t addr_len = sizeof(addr);
+    if (getsockname(socket_descriptor_, reinterpret_cast<struct sockaddr*>(&addr), &addr_len) != 0) throw std::runtime_error("Cannot query the bound port");
+    return ntohs(addr.sin_port);
+  }
+
+ private:
+  SSL_CTX *ctx_ = nullptr;
+  SSL* ssl_ = nullptr;  // created by the accept thread; remains null if no client connects
+  int port_;  // port requested at construction (0 = any)
+  SocketDescriptor socket_descriptor_;  // listening socket
+  bool had_connection_;
+  std::thread server_read_thread_;
+
+  void configureContext(const std::filesystem::path& key_dir) {  // Throws instead of asserting so the loads still run when NDEBUG is defined.
+    SSL_CTX_set_ecdh_auto(ctx_, 1);
+    /* Set the key and cert */
+    if (SSL_CTX_use_certificate_file(ctx_, (key_dir / "cn.crt.pem").string().c_str(), SSL_FILETYPE_PEM) != 1) throw std::runtime_error("Cannot load cn.crt.pem");
+    if (SSL_CTX_use_PrivateKey_file(ctx_, (key_dir / "cn.ckey.pem").string().c_str(), SSL_FILETYPE_PEM) != 1) throw std::runtime_error("Cannot load cn.ckey.pem");
+  }
+
+  static SocketDescriptor createSocket(int port) {  // Creates an IPv4 TCP socket bound to INADDR_ANY:port, listening with backlog 1.
+    struct sockaddr_in addr{};  // value-initialized so the padding bytes are zero
+
+    addr.sin_family = AF_INET;
+    addr.sin_port = htons(port);
+    addr.sin_addr.s_addr = htonl(INADDR_ANY);
+
+    SocketDescriptor socket_descriptor = socket(AF_INET, SOCK_STREAM, 0);
+    if (socket_descriptor == INVALID_SOCKET) throw std::runtime_error("Cannot create socket");  // SOCKET is unsigned on Windows, so ">= 0" cannot detect failure
+    if (bind(socket_descriptor, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) throw std::runtime_error("Cannot bind socket");
+    if (listen(socket_descriptor, 1) < 0) throw std::runtime_error("Cannot listen on socket");
+
+    return socket_descriptor;
+  }
+};
diff --git a/libminifi/test/unit/tls/TLSStreamTests.cpp b/libminifi/test/unit/tls/TLSStreamTests.cpp
new file mode 100644
index 0000000..9fc5939
--- /dev/null
+++ b/libminifi/test/unit/tls/TLSStreamTests.cpp
@@ -0,0 +1,82 @@
+/**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#undef LOAD_EXTENSIONS
+#undef NDEBUG
+
+#include <cassert>
+
+#include "io/tls/TLSServerSocket.h"
+#include "io/tls/TLSSocket.h"
+#include "../../TestBase.h"
+#include "../../SimpleSSLTestServer.h"
+#include "../utils/IntegrationTestUtils.h"
+
+using namespace std::chrono_literals;
+
+static std::shared_ptr<minifi::io::TLSContext> createContext(const std::filesystem::path& key_dir) {  // Builds a client TLSContext whose credential files all live in key_dir.
+  const auto in_key_dir = [&key_dir](const char* file_name) { return (key_dir / file_name).string(); };  // key_dir/<file_name> as a string
+  auto client_config = std::make_shared<minifi::Configure>();
+  client_config->set(minifi::Configure::nifi_remote_input_secure, "true");
+  client_config->set(minifi::Configure::nifi_security_client_certificate, in_key_dir("cn.crt.pem"));
+  client_config->set(minifi::Configure::nifi_security_client_private_key, in_key_dir("cn.ckey.pem"));
+  client_config->set(minifi::Configure::nifi_security_client_pass_phrase, in_key_dir("cn.pass"));
+  client_config->set(minifi::Configure::nifi_security_client_ca_certificate, in_key_dir("nifi-cert.pem"));
+  client_config->set(minifi::Configure::nifi_default_directory, key_dir.string());
+  return std::make_shared<minifi::io::TLSContext>(client_config);
+}
+
+int main(int argc, char** argv) {  // Regression test for MINIFICPP-1692: TLSSocket::read must return once the TLS server has gone away.
+ if (argc < 2) {
+ throw std::logic_error("Specify the key directory");
+ }
+ std::filesystem::path key_dir(argv[1]);  // directory holding the cn.* and nifi-cert.pem files used by both sides
+
+ LogTestController::getInstance().setTrace<minifi::io::Socket>();
+ LogTestController::getInstance().setTrace<minifi::io::TLSSocket>();
+ LogTestController::getInstance().setTrace<minifi::io::TLSServerSocket>();
+ LogTestController::getInstance().setTrace<minifi::io::TLSContext>();
+
+ auto server = std::make_unique<SimpleSSLTestServer>(TLSv1_2_server_method(), 0, key_dir);  // port 0: bind to a random free port
+ int port = server->getPort();  // the port actually bound, read back from the listening socket
+ server->waitForConnection();  // accept + TLS handshake run on the server's background thread
+
+ std::string host = minifi::io::Socket::getMyHostName();
+
+ auto client_ctx = createContext(key_dir);
+ assert(client_ctx->initialize(false) == 0);
+
+ minifi::io::TLSSocket client_socket(client_ctx, host, port);
+ assert(client_socket.initialize() == 0);  // 0 means the client connected successfully
+
+ std::atomic_bool read_complete{false};
+
+ std::thread read_thread{[&] {
+ std::vector<uint8_t> buffer;
+ auto read_count = client_socket.read(buffer, 10);  // the test server never writes, so this can only return when the connection ends
+ assert(read_count == 0);  // nothing was sent, so nothing may be reported as read
+ read_complete = true;
+ }};
+
+ server->shutdownServer();  // closes the listening socket and joins the accept thread
+ server.reset();  // destroys the server; its destructor calls SSL_shutdown on the accepted connection
+
+ assert(utils::verifyEventHappenedInPollTime(1s, [&] {return read_complete.load();}));  // NOTE(review): presumably polls for up to 1s; fails if read() is still stuck
+
+ read_thread.join();
+}