You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nifi.apache.org by sz...@apache.org on 2021/12/01 13:27:07 UTC

[nifi-minifi-cpp] 03/03: MINIFICPP-1692 TLSSocket: Break infinite loop when no more data can be read

This is an automated email from the ASF dual-hosted git repository.

szaszm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi-minifi-cpp.git

commit 0553df446aefb395b34c1e4e2b5ced31f07a4e04
Author: Adam Debreceni <ad...@apache.org>
AuthorDate: Wed Dec 1 14:23:32 2021 +0100

    MINIFICPP-1692 TLSSocket: Break infinite loop when no more data can be read
    
    Closes #1218
    Signed-off-by: Marton Szasz <sz...@apache.org>
---
 cmake/BuildTests.cmake                             |  13 +++
 .../TLSClientSocketSupportedProtocolsTest.cpp      | 118 ++++---------------
 libminifi/src/io/tls/TLSSocket.cpp                 |   4 +-
 libminifi/test/SimpleSSLTestServer.h               | 130 +++++++++++++++++++++
 libminifi/test/unit/tls/TLSStreamTests.cpp         |  82 +++++++++++++
 5 files changed, 247 insertions(+), 100 deletions(-)

diff --git a/cmake/BuildTests.cmake b/cmake/BuildTests.cmake
index ecb5c70..6d5402e 100644
--- a/cmake/BuildTests.cmake
+++ b/cmake/BuildTests.cmake
@@ -101,6 +101,7 @@ target_include_directories(${CATCH_MAIN_LIB} SYSTEM BEFORE PRIVATE "${CMAKE_SOUR
 SET(TEST_RESOURCES ${TEST_DIR}/resources)

 GETSOURCEFILES(UNIT_TESTS "${TEST_DIR}/unit/")
+GETSOURCEFILES(TLS_UNIT_TESTS "${TEST_DIR}/unit/tls/")  # built below only when OpenSSL is enabled
 GETSOURCEFILES(NANOFI_UNIT_TESTS "${NANOFI_TEST_DIR}")
 GETSOURCEFILES(INTEGRATION_TESTS "${TEST_DIR}/integration/")

@@ -115,6 +116,18 @@ FOREACH(testfile ${UNIT_TESTS})
 ENDFOREACH()
 message("-- Finished building ${UNIT_TEST_COUNT} unit test file(s)...")

+if (NOT OPENSSL_OFF)  # the TLS unit tests need OpenSSL
+  SET(UNIT_TEST_COUNT 0)
+  FOREACH(testfile ${TLS_UNIT_TESTS})
+    get_filename_component(testfilename "${testfile}" NAME_WE)
+    add_executable("${testfilename}" "${TEST_DIR}/unit/tls/${testfile}")
+    createTests("${testfilename}")
+    MATH(EXPR UNIT_TEST_COUNT "${UNIT_TEST_COUNT}+1")
+    add_test(NAME "${testfilename}" COMMAND "${testfilename}" "${TEST_RESOURCES}/" WORKING_DIRECTORY ${TEST_DIR})  # argv[1] of the test: directory holding the test certificates
+  ENDFOREACH()
+  message("-- Finished building ${UNIT_TEST_COUNT} TLS unit test file(s)...")
+endif()
+
 if(NOT WIN32 AND ENABLE_NANOFI)
   SET(UNIT_TEST_COUNT 0)
   FOREACH(testfile ${NANOFI_UNIT_TESTS})
diff --git a/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp b/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp
index 928bc56..6cc643c 100644
--- a/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp
+++ b/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp
@@ -19,6 +19,7 @@
 #include <sys/stat.h>
 #include <chrono>
 #include <thread>
+#include <filesystem>  // std::filesystem::path for the key directory
 #undef NDEBUG
 #include <cassert>
 #include <utility>
@@ -26,114 +27,34 @@
 #include <string>
 #include "properties/Configure.h"
 #include "io/tls/TLSSocket.h"
+#include "SimpleSSLTestServer.h"  // test TLS server, shared with the TLS unit tests

 namespace minifi = org::apache::nifi::minifi;

-#ifdef WIN32
-#pragma comment(lib, "Ws2_32.lib")
-using SocketDescriptor = SOCKET;
-#else
-using SocketDescriptor = int;
-static constexpr SocketDescriptor INVALID_SOCKET = -1;
-#endif /* WIN32 */
-
-
-class SimpleSSLTestServer  {
- public:
-  SimpleSSLTestServer(const SSL_METHOD* method, const std::string& port, const std::string& path)
-      : port_(port), had_connection_(false) {
-    ctx_ = SSL_CTX_new(method);
-    configureContext(path);
-    socket_descriptor_ = createSocket(std::stoi(port_));
-  }
-
-  ~SimpleSSLTestServer() {
-      SSL_shutdown(ssl_);
-      SSL_free(ssl_);
-      SSL_CTX_free(ctx_);
-  }
-
-  void waitForConnection() {
-    server_read_thread_ = std::thread([this]() -> void {
-        SocketDescriptor client = accept(socket_descriptor_, nullptr, nullptr);
-        if (client != INVALID_SOCKET) {
-            ssl_ = SSL_new(ctx_);
-            SSL_set_fd(ssl_, client);
-            had_connection_ = (SSL_accept(ssl_) == 1);
-        }
-    });
-  }
-
-  void shutdownServer() {
-#ifdef WIN32
-    shutdown(socket_descriptor_, SD_BOTH);
-    closesocket(socket_descriptor_);
-#else
-    shutdown(socket_descriptor_, SHUT_RDWR);
-    close(socket_descriptor_);
-#endif
-    server_read_thread_.join();
-  }
-
-  bool hadConnection() const {
-    return had_connection_;
-  }
-
- private:
-  SSL_CTX *ctx_ = nullptr;
-  SSL* ssl_ = nullptr;
-  std::string port_;
-  SocketDescriptor socket_descriptor_;
-  bool had_connection_;
-  std::thread server_read_thread_;
-
-  void configureContext(const std::string& path) {
-    SSL_CTX_set_ecdh_auto(ctx_, 1);
-    /* Set the key and cert */
-    assert(SSL_CTX_use_certificate_file(ctx_, (path + "cn.crt.pem").c_str(), SSL_FILETYPE_PEM) == 1);
-    assert(SSL_CTX_use_PrivateKey_file(ctx_, (path + "cn.ckey.pem").c_str(), SSL_FILETYPE_PEM) == 1);
-  }
-
-  static SocketDescriptor createSocket(int port) {
-    struct sockaddr_in addr;
-
-    addr.sin_family = AF_INET;
-    addr.sin_port = htons(port);
-    addr.sin_addr.s_addr = htonl(INADDR_ANY);
-
-    SocketDescriptor socket_descriptor = socket(AF_INET, SOCK_STREAM, 0);
-    assert(socket_descriptor >= 0);
-    assert(bind(socket_descriptor, (struct sockaddr*)&addr, sizeof(addr)) >= 0);
-    assert(listen(socket_descriptor, 1) >= 0);
-
-    return socket_descriptor;
-  }
-};
-
 class SimpleSSLTestServerTLSv1  : public SimpleSSLTestServer {
  public:
-  SimpleSSLTestServerTLSv1(const std::string& port, const std::string& path)
-      : SimpleSSLTestServer(TLSv1_server_method(), port, path) {
+  SimpleSSLTestServerTLSv1(int port, const std::filesystem::path& key_dir)
+      : SimpleSSLTestServer(TLSv1_server_method(), port, key_dir) {
   }
 };

 class SimpleSSLTestServerTLSv1_1  : public SimpleSSLTestServer {
  public:
-  SimpleSSLTestServerTLSv1_1(const std::string& port, const std::string& path)
-      : SimpleSSLTestServer(TLSv1_1_server_method(), port, path) {
+  SimpleSSLTestServerTLSv1_1(int port, const std::filesystem::path& key_dir)
+      : SimpleSSLTestServer(TLSv1_1_server_method(), port, key_dir) {
   }
 };

 class SimpleSSLTestServerTLSv1_2  : public SimpleSSLTestServer {
  public:
-  SimpleSSLTestServerTLSv1_2(const std::string& port, const std::string& path)
-      : SimpleSSLTestServer(TLSv1_2_server_method(), port, path) {
+  SimpleSSLTestServerTLSv1_2(int port, const std::filesystem::path& key_dir)
+      : SimpleSSLTestServer(TLSv1_2_server_method(), port, key_dir) {
   }
 };

 class TLSClientSocketSupportedProtocolsTest {
  public:
-  explicit TLSClientSocketSupportedProtocolsTest(const std::string& key_dir)
+  explicit TLSClientSocketSupportedProtocolsTest(const std::filesystem::path& key_dir)
       : key_dir_(key_dir), configuration_(std::make_shared<minifi::Configure>()) {
   }

@@ -147,14 +68,13 @@ class TLSClientSocketSupportedProtocolsTest {
  protected:
   void configureSecurity() {
     host_ = minifi::io::Socket::getMyHostName();
-    port_ = "38777";
     if (!key_dir_.empty()) {
       configuration_->set(minifi::Configure::nifi_remote_input_secure, "true");
-      configuration_->set(minifi::Configure::nifi_security_client_certificate, key_dir_ + "cn.crt.pem");
-      configuration_->set(minifi::Configure::nifi_security_client_private_key, key_dir_ + "cn.ckey.pem");
-      configuration_->set(minifi::Configure::nifi_security_client_pass_phrase, key_dir_ + "cn.pass");
-      configuration_->set(minifi::Configure::nifi_security_client_ca_certificate, key_dir_ + "nifi-cert.pem");
-      configuration_->set(minifi::Configure::nifi_default_directory, key_dir_);
+      configuration_->set(minifi::Configure::nifi_security_client_certificate, (key_dir_ / "cn.crt.pem").string());
+      configuration_->set(minifi::Configure::nifi_security_client_private_key, (key_dir_ / "cn.ckey.pem").string());
+      configuration_->set(minifi::Configure::nifi_security_client_pass_phrase, (key_dir_ / "cn.pass").string());
+      configuration_->set(minifi::Configure::nifi_security_client_ca_certificate, (key_dir_ / "nifi-cert.pem").string());
+      configuration_->set(minifi::Configure::nifi_default_directory, key_dir_.string());
     }
   }

@@ -166,11 +86,14 @@ class TLSClientSocketSupportedProtocolsTest {

   template <class TLSTestSever>
   void verifyTLSProtocolCompatibility(const bool should_be_compatible) {
-    TLSTestSever server(port_, key_dir_);
+    // port 0: the OS picks a free port, which is read back below via getPort()
+    TLSTestSever server(0, key_dir_);
     server.waitForConnection();

+    int port = server.getPort();
+
     const auto socket_context = std::make_shared<minifi::io::TLSContext>(configuration_);
-    client_socket_ = std::make_unique<minifi::io::TLSSocket>(socket_context, host_, std::stoi(port_), 0);
+    client_socket_ = std::make_unique<minifi::io::TLSSocket>(socket_context, host_, port, 0);
     const bool client_initialized_successfully = (client_socket_->initialize() == 0);
     assert(client_initialized_successfully == should_be_compatible);
     server.shutdownServer();
@@ -180,8 +103,7 @@ class TLSClientSocketSupportedProtocolsTest {
  protected:
     std::unique_ptr<minifi::io::TLSSocket> client_socket_;
     std::string host_;
-    std::string port_;
-    std::string key_dir_;
+    std::filesystem::path key_dir_;
     std::shared_ptr<minifi::Configure> configuration_;
 };

diff --git a/libminifi/src/io/tls/TLSSocket.cpp b/libminifi/src/io/tls/TLSSocket.cpp
index 5d76e8c..af8772a 100644
--- a/libminifi/src/io/tls/TLSSocket.cpp
+++ b/libminifi/src/io/tls/TLSSocket.cpp
@@ -434,9 +434,9 @@ size_t TLSSocket::read(uint8_t *buf, size_t buflen) {
       const auto ssl_read_size = gsl::narrow<int>(std::min(buflen, gsl::narrow<size_t>(std::numeric_limits<int>::max())));
       status = SSL_read(fd_ssl, buf, ssl_read_size);
       sslStatus = SSL_get_error(fd_ssl, status);
-    } while (status < 0 && sslStatus == SSL_ERROR_WANT_READ);
+    } while (status <= 0 && sslStatus == SSL_ERROR_WANT_READ);  // retry only while OpenSSL asks to read more input

-    if (status < 0)
+    if (status <= 0)  // includes 0 (e.g. the peer closed the session): otherwise `buflen -= status` below makes no progress
       break;

     buflen -= gsl::narrow<size_t>(status);
diff --git a/libminifi/test/SimpleSSLTestServer.h b/libminifi/test/SimpleSSLTestServer.h
new file mode 100644
index 0000000..b6cecf5
--- /dev/null
+++ b/libminifi/test/SimpleSSLTestServer.h
@@ -0,0 +1,196 @@
+/**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <openssl/ssl.h>
+#include <openssl/err.h>
+#include <atomic>
+#include <cstdint>
+#include <filesystem>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include "io/tls/TLSSocket.h"
+
+#ifdef WIN32
+#include <winsock2.h>
+#include <ws2tcpip.h>
+#pragma comment(lib, "Ws2_32.lib")
+using SocketDescriptor = SOCKET;
+#else
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <unistd.h>
+using SocketDescriptor = int;
+static constexpr SocketDescriptor INVALID_SOCKET = -1;
+#endif /* WIN32 */
+
+namespace minifi = org::apache::nifi::minifi;
+
+/**
+ * Minimal TLS server for tests. It listens on `port` (0 lets the OS pick a free port,
+ * see getPort()), accepts a single client on the background thread started by
+ * waitForConnection() and runs the server side of the TLS handshake using
+ * "cn.crt.pem" and "cn.ckey.pem" from `key_dir`.
+ * Setup failures throw std::runtime_error rather than relying on assert(), so they
+ * are still reported when the including test is built with NDEBUG.
+ */
+class SimpleSSLTestServer  {
+  struct SocketInitializer {
+    SocketInitializer() {
+#ifdef WIN32
+      static WSADATA s_wsaData;
+      const int iWinSockInitResult = WSAStartup(MAKEWORD(2, 2), &s_wsaData);
+      if (0 != iWinSockInitResult) {
+        throw std::runtime_error("Cannot initialize socket");
+      }
+#endif
+    }
+  };
+
+ public:
+  SimpleSSLTestServer(const SSL_METHOD* method, int port, const std::filesystem::path& key_dir)
+      : port_(port) {
+    static SocketInitializer socket_initializer{};
+    minifi::io::OpenSSLInitializer::getInstance();
+    ctx_ = SSL_CTX_new(method);
+    if (ctx_ == nullptr) {
+      throw std::runtime_error("Cannot create SSL context");
+    }
+    try {
+      configureContext(key_dir);
+      socket_descriptor_ = createSocket(port_);
+    } catch (...) {
+      // the destructor does not run for a partially constructed object
+      SSL_CTX_free(ctx_);
+      throw;
+    }
+  }
+
+  SimpleSSLTestServer(const SimpleSSLTestServer&) = delete;
+  SimpleSSLTestServer& operator=(const SimpleSSLTestServer&) = delete;
+
+  ~SimpleSSLTestServer() {
+    // Stop listening and join the accept thread first (a no-op if shutdownServer() was already called):
+    // that thread writes ssl_ and client_, and destroying a joinable std::thread calls std::terminate.
+    shutdownServer();
+    if (ssl_ != nullptr) {  // still null if no client was accepted
+      SSL_shutdown(ssl_);
+      SSL_free(ssl_);
+    }
+    if (client_ != INVALID_SOCKET) {
+      closeSocket(client_);
+    }
+    SSL_CTX_free(ctx_);
+  }
+
+  void waitForConnection() {
+    // capture the listening socket by value, so shutdownServer() can reset the member without a data race
+    server_read_thread_ = std::thread([this, listening_socket = socket_descriptor_]() -> void {
+      SocketDescriptor client = accept(listening_socket, nullptr, nullptr);
+      if (client == INVALID_SOCKET) {
+        return;
+      }
+      client_ = client;
+      ssl_ = SSL_new(ctx_);
+      if (ssl_ != nullptr) {
+        SSL_set_fd(ssl_, static_cast<int>(client));
+        had_connection_ = (SSL_accept(ssl_) == 1);
+      }
+    });
+  }
+
+  /// Closes the listening socket and waits for the accept thread. Safe to call more than once.
+  void shutdownServer() {
+    if (socket_descriptor_ != INVALID_SOCKET) {
+#ifdef WIN32
+      shutdown(socket_descriptor_, SD_BOTH);
+#else
+      shutdown(socket_descriptor_, SHUT_RDWR);
+#endif
+      closeSocket(socket_descriptor_);
+      socket_descriptor_ = INVALID_SOCKET;
+    }
+    if (server_read_thread_.joinable()) {
+      server_read_thread_.join();
+    }
+  }
+
+  bool hadConnection() const {
+    return had_connection_;
+  }
+
+  /// Returns the port the server listens on (the OS-assigned one when constructed with port 0).
+  int getPort() const {
+    sockaddr_in addr{};
+    socklen_t addr_len = sizeof(addr);
+    if (getsockname(socket_descriptor_, reinterpret_cast<sockaddr*>(&addr), &addr_len) != 0) {
+      throw std::runtime_error("Cannot query the port of the server socket");
+    }
+    return ntohs(addr.sin_port);
+  }
+
+ private:
+  SSL_CTX *ctx_ = nullptr;
+  SSL* ssl_ = nullptr;
+  int port_;
+  SocketDescriptor socket_descriptor_ = INVALID_SOCKET;
+  SocketDescriptor client_ = INVALID_SOCKET;
+  std::atomic<bool> had_connection_{false};
+  std::thread server_read_thread_;
+
+  static void closeSocket(SocketDescriptor descriptor) {
+#ifdef WIN32
+    closesocket(descriptor);
+#else
+    close(descriptor);
+#endif
+  }
+
+  void configureContext(const std::filesystem::path& key_dir) {
+    SSL_CTX_set_ecdh_auto(ctx_, 1);
+    /* Set the key and cert */
+    if (SSL_CTX_use_certificate_file(ctx_, (key_dir / "cn.crt.pem").string().c_str(), SSL_FILETYPE_PEM) != 1) {
+      throw std::runtime_error("Cannot load the server certificate");
+    }
+    if (SSL_CTX_use_PrivateKey_file(ctx_, (key_dir / "cn.ckey.pem").string().c_str(), SSL_FILETYPE_PEM) != 1) {
+      throw std::runtime_error("Cannot load the server private key");
+    }
+  }
+
+  static SocketDescriptor createSocket(int port) {
+    sockaddr_in addr{};  // value-initialized, so sin_zero is zeroed too
+    addr.sin_family = AF_INET;
+    addr.sin_port = htons(static_cast<uint16_t>(port));
+    addr.sin_addr.s_addr = htonl(INADDR_ANY);
+
+    SocketDescriptor socket_descriptor = socket(AF_INET, SOCK_STREAM, 0);
+    // compare with INVALID_SOCKET: SOCKET is unsigned on Windows, so a `>= 0` check is always true
+    if (socket_descriptor == INVALID_SOCKET) {
+      throw std::runtime_error("Cannot create the server socket");
+    }
+    if (bind(socket_descriptor, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) != 0
+        || listen(socket_descriptor, 1) != 0) {
+      closeSocket(socket_descriptor);
+      throw std::runtime_error("Cannot listen on port " + std::to_string(port));
+    }
+    return socket_descriptor;
+  }
+};
diff --git a/libminifi/test/unit/tls/TLSStreamTests.cpp b/libminifi/test/unit/tls/TLSStreamTests.cpp
new file mode 100644
index 0000000..9fc5939
--- /dev/null
+++ b/libminifi/test/unit/tls/TLSStreamTests.cpp
@@ -0,0 +1,98 @@
+/**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#undef LOAD_EXTENSIONS
+#undef NDEBUG
+
+#include <atomic>
+#include <cassert>
+#include <chrono>
+#include <cstdint>
+#include <filesystem>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "io/tls/TLSServerSocket.h"
+#include "io/tls/TLSSocket.h"
+#include "../../TestBase.h"
+#include "../../SimpleSSLTestServer.h"
+#include "../utils/IntegrationTestUtils.h"
+
+using namespace std::chrono_literals;
+
+// Builds a client-side TLS context that uses the test certificates found in key_dir.
+static std::shared_ptr<minifi::io::TLSContext> createContext(const std::filesystem::path& key_dir) {
+  auto configuration = std::make_shared<minifi::Configure>();
+  configuration->set(minifi::Configure::nifi_remote_input_secure, "true");
+  configuration->set(minifi::Configure::nifi_security_client_certificate, (key_dir / "cn.crt.pem").string());
+  configuration->set(minifi::Configure::nifi_security_client_private_key, (key_dir / "cn.ckey.pem").string());
+  configuration->set(minifi::Configure::nifi_security_client_pass_phrase, (key_dir / "cn.pass").string());
+  configuration->set(minifi::Configure::nifi_security_client_ca_certificate, (key_dir / "nifi-cert.pem").string());
+  configuration->set(minifi::Configure::nifi_default_directory, key_dir.string());
+
+  return std::make_shared<minifi::io::TLSContext>(configuration);
+}
+
+// Regression test for MINIFICPP-1692: TLSSocket::read() must return 0 instead of
+// looping forever once the server has closed the TLS session.
+int main(int argc, char** argv) {
+  if (argc < 2) {
+    throw std::logic_error("Specify the key directory");
+  }
+  std::filesystem::path key_dir(argv[1]);
+
+  LogTestController::getInstance().setTrace<minifi::io::Socket>();
+  LogTestController::getInstance().setTrace<minifi::io::TLSSocket>();
+  LogTestController::getInstance().setTrace<minifi::io::TLSServerSocket>();
+  LogTestController::getInstance().setTrace<minifi::io::TLSContext>();
+
+  // port 0: the OS picks a free port, which is read back with getPort()
+  auto server = std::make_unique<SimpleSSLTestServer>(TLSv1_2_server_method(), 0, key_dir);
+  int port = server->getPort();
+  server->waitForConnection();
+
+  std::string host = minifi::io::Socket::getMyHostName();
+
+  auto client_ctx = createContext(key_dir);
+  assert(client_ctx->initialize(false) == 0);
+
+  minifi::io::TLSSocket client_socket(client_ctx, host, port);
+  assert(client_socket.initialize() == 0);
+
+  std::atomic_bool read_complete{false};
+
+  // The server never sends application data, so this read can only end when the session is closed.
+  std::thread read_thread{[&] {
+    std::vector<uint8_t> buffer;
+    auto read_count = client_socket.read(buffer, 10);
+    assert(read_count == 0);
+    read_complete = true;
+  }};
+
+  // Destroying the server calls SSL_shutdown on the accepted connection, which must unblock the reader.
+  server->shutdownServer();
+  server.reset();
+
+  // Fail instead of hanging if read() never returns.
+  assert(utils::verifyEventHappenedInPollTime(1s, [&] {return read_complete.load();}));
+
+  read_thread.join();
+}