You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@nifi.apache.org by GitBox <gi...@apache.org> on 2022/12/14 11:58:11 UTC

[GitHub] [nifi-minifi-cpp] adamdebreceni commented on a diff in pull request #1457: MINIFICPP-1979 Use Coroutines with asio

adamdebreceni commented on code in PR #1457:
URL: https://github.com/apache/nifi-minifi-cpp/pull/1457#discussion_r1048381787


##########
extensions/standard-processors/processors/PutTCP.cpp:
##########
@@ -160,339 +177,145 @@ void PutTCP::onSchedule(core::ProcessContext* const context, core::ProcessSessio
 }
 
 namespace {
+template<class SocketType>
+asio::awaitable<std::tuple<std::error_code>> handshake(SocketType&, asio::steady_timer::duration) {
+  co_return std::error_code();
+}
+
+template<>
+asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket, asio::steady_timer::duration timeout_duration) {
+  co_return co_await asyncOperationWithTimeout(socket.async_handshake(HandshakeType::client, use_nothrow_awaitable), timeout_duration);  // NOLINT
+}
+
 template<class SocketType>
 class ConnectionHandler : public ConnectionHandlerBase {
  public:
   ConnectionHandler(detail::ConnectionId connection_id,
                     std::chrono::milliseconds timeout,
                     std::shared_ptr<core::logging::Logger> logger,
                     std::optional<size_t> max_size_of_socket_send_buffer,
-                    std::shared_ptr<controllers::SSLContextService> ssl_context_service)
+                    std::optional<asio::ssl::context>& ssl_context)
       : connection_id_(std::move(connection_id)),
-        timeout_(timeout),
+        timeout_duration_(timeout),
         logger_(std::move(logger)),
         max_size_of_socket_send_buffer_(max_size_of_socket_send_buffer),
-        ssl_context_service_(std::move(ssl_context_service)) {
+        ssl_context_(ssl_context) {
   }
 
   ~ConnectionHandler() override = default;
 
-  nonstd::expected<void, std::error_code> sendData(const std::shared_ptr<io::InputStream>& flow_file_content_stream, const std::vector<std::byte>& delimiter) override;
+  asio::awaitable<std::error_code> sendStreamWithDelimiter(const std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>& delimiter, asio::io_context& io_context_) override;
 
  private:
-  nonstd::expected<std::shared_ptr<SocketType>, std::error_code> getSocket();
-
   [[nodiscard]] bool hasBeenUsedIn(std::chrono::milliseconds dur) const override {
-    return last_used_ && *last_used_ >= (std::chrono::steady_clock::now() - dur);
+    return last_used_ && *last_used_ >= (steady_clock::now() - dur);
   }
 
   void reset() override {
     last_used_.reset();
     socket_.reset();
-    io_context_.reset();
-    last_error_.clear();
-    deadline_.expires_at(asio::steady_timer::time_point::max());
   }
 
-  void checkDeadline(std::error_code error_code, SocketType* socket);
-  void startConnect(tcp::resolver::results_type::iterator endpoint_iter, const std::shared_ptr<SocketType>& socket);
-
-  void handleConnect(std::error_code error,
-                     tcp::resolver::results_type::iterator endpoint_iter,
-                     const std::shared_ptr<SocketType>& socket);
-  void handleConnectionSuccess(const tcp::resolver::results_type::iterator& endpoint_iter,
-                               const std::shared_ptr<SocketType>& socket);
-  void handleHandshake(std::error_code error,
-                       const tcp::resolver::results_type::iterator& endpoint_iter,
-                       const std::shared_ptr<SocketType>& socket);
-
-  void handleWrite(std::error_code error,
-                   std::size_t bytes_written,
-                   const std::shared_ptr<io::InputStream>& flow_file_content_stream,
-                   const std::vector<std::byte>& delimiter,
-                   const std::shared_ptr<SocketType>& socket);
-
-  void handleDelimiterWrite(std::error_code error, std::size_t bytes_written, const std::shared_ptr<SocketType>& socket);
+  [[nodiscard]] bool hasBeenUsed() const override { return last_used_.has_value(); }
+  [[nodiscard]] asio::awaitable<std::error_code> setupUsableSocket(asio::io_context& io_context);
+  [[nodiscard]] bool hasUsableSocket() const {  return socket_ && socket_->lowest_layer().is_open(); }
 
-  nonstd::expected<std::shared_ptr<SocketType>, std::error_code> establishConnection(const tcp::resolver::results_type& resolved_query);
+  asio::awaitable<std::error_code> establishNewConnection(const tcp::resolver::results_type& resolved_query, asio::io_context& io_context_);
+  asio::awaitable<std::error_code> send(const std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>& delimiter);
 
-  [[nodiscard]] bool hasBeenUsed() const override { return last_used_.has_value(); }
+  SocketType createNewSocket(asio::io_context& io_context_);
 
   detail::ConnectionId connection_id_;
-  std::optional<std::chrono::steady_clock::time_point> last_used_;
-  asio::io_context io_context_;
-  std::error_code last_error_;
-  asio::steady_timer deadline_{io_context_};
-  std::chrono::milliseconds timeout_;
-  std::shared_ptr<SocketType> socket_;
+  std::optional<SocketType> socket_;
+
+  std::optional<steady_clock::time_point> last_used_;
+  std::chrono::milliseconds timeout_duration_;
 
   std::shared_ptr<core::logging::Logger> logger_;
   std::optional<size_t> max_size_of_socket_send_buffer_;
 
-  std::shared_ptr<controllers::SSLContextService> ssl_context_service_;
-
-  nonstd::expected<tcp::resolver::results_type, std::error_code> resolveHostname();
-  nonstd::expected<void, std::error_code> sendDataToSocket(const std::shared_ptr<SocketType>& socket,
-                                                           const std::shared_ptr<io::InputStream>& flow_file_content_stream,
-                                                           const std::vector<std::byte>& delimiter);
+  std::optional<asio::ssl::context>& ssl_context_;
 };
 
-template<class SocketType>
-nonstd::expected<void, std::error_code> ConnectionHandler<SocketType>::sendData(const std::shared_ptr<io::InputStream>& flow_file_content_stream, const std::vector<std::byte>& delimiter) {
-  return getSocket() | utils::flatMap([&](const std::shared_ptr<SocketType>& socket) { return sendDataToSocket(socket, flow_file_content_stream, delimiter); });;
-}
-
-template<class SocketType>
-nonstd::expected<std::shared_ptr<SocketType>, std::error_code> ConnectionHandler<SocketType>::getSocket() {
-  if (socket_ && socket_->lowest_layer().is_open())
-    return socket_;
-  auto new_socket = resolveHostname() | utils::flatMap([&](const auto& resolved_query) { return establishConnection(resolved_query); });
-  if (!new_socket)
-    return nonstd::make_unexpected(new_socket.error());
-  socket_ = std::move(*new_socket);
-  return socket_;
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::checkDeadline(std::error_code error_code, SocketType* socket) {
-  if (error_code != asio::error::operation_aborted) {
-    deadline_.expires_at(asio::steady_timer::time_point::max());
-    last_error_ = asio::error::timed_out;
-    deadline_.async_wait([&](std::error_code error_code) { checkDeadline(error_code, socket); });
-    socket->lowest_layer().close();
-  }
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::startConnect(tcp::resolver::results_type::iterator endpoint_iter, const std::shared_ptr<SocketType>& socket) {
-  if (endpoint_iter == tcp::resolver::results_type::iterator()) {
-    logger_->log_trace("No more endpoints to try");
-    deadline_.cancel();
-    return;
-  }
-
-  last_error_.clear();
-  deadline_.expires_after(timeout_);
-  deadline_.async_wait([&](std::error_code error_code) -> void {
-    checkDeadline(error_code, socket.get());
-  });
-  socket->lowest_layer().async_connect(endpoint_iter->endpoint(),
-      [&socket, endpoint_iter, this](std::error_code err) {
-        handleConnect(err, endpoint_iter, socket);
-      });
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::handleConnect(std::error_code error,
-                                                  tcp::resolver::results_type::iterator endpoint_iter,
-                                                  const std::shared_ptr<SocketType>& socket) {
-  bool connection_failed_before_deadline = error.operator bool();
-  bool connection_failed_due_to_deadline = !socket->lowest_layer().is_open();
-
-  if (connection_failed_due_to_deadline) {
-    core::logging::LOG_TRACE(logger_) << "Connecting to " << endpoint_iter->endpoint() << " timed out";
-    socket->lowest_layer().close();
-    return startConnect(++endpoint_iter, socket);
-  }
-
-  if (connection_failed_before_deadline) {
-    core::logging::LOG_TRACE(logger_) << "Connecting to " << endpoint_iter->endpoint() << " failed due to " << error.message();
-    last_error_ = error;
-    socket->lowest_layer().close();
-    return startConnect(++endpoint_iter, socket);
-  }
-
-  if (max_size_of_socket_send_buffer_)
-    socket->lowest_layer().set_option(TcpSocket::send_buffer_size(*max_size_of_socket_send_buffer_));
-
-  handleConnectionSuccess(endpoint_iter, socket);
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::handleHandshake(std::error_code,
-                                                    const tcp::resolver::results_type::iterator&,
-                                                    const std::shared_ptr<SocketType>&) {
-  throw std::invalid_argument("Handshake called without SSL");
-}
-
 template<>
-void ConnectionHandler<SslSocket>::handleHandshake(std::error_code error,
-                                                   const tcp::resolver::results_type::iterator& endpoint_iter,
-                                                   const std::shared_ptr<SslSocket>& socket) {
-  if (!error) {
-    core::logging::LOG_TRACE(logger_) << "Successful handshake with " << endpoint_iter->endpoint();
-    deadline_.cancel();
-    return;
-  }
-  core::logging::LOG_TRACE(logger_) << "Handshake with " << endpoint_iter->endpoint() << " failed due to " << error.message();
-  last_error_ = error;
-  socket->lowest_layer().close();
-  startConnect(std::next(endpoint_iter), socket);
+TcpSocket ConnectionHandler<TcpSocket>::createNewSocket(asio::io_context& io_context_) {
+  gsl_Expects(!ssl_context_);
+  return TcpSocket{io_context_};
 }
 
 template<>
-void ConnectionHandler<TcpSocket>::handleConnectionSuccess(const tcp::resolver::results_type::iterator& endpoint_iter,
-                                                           const std::shared_ptr<TcpSocket>& socket) {
-  core::logging::LOG_TRACE(logger_) << "Connected to " << endpoint_iter->endpoint();
-  socket->lowest_layer().non_blocking(true);
-  deadline_.cancel();
-}
-
-template<>
-void ConnectionHandler<SslSocket>::handleConnectionSuccess(const tcp::resolver::results_type::iterator& endpoint_iter,
-                                                           const std::shared_ptr<SslSocket>& socket) {
-  core::logging::LOG_TRACE(logger_) << "Connected to " << endpoint_iter->endpoint();
-  socket->async_handshake(asio::ssl::stream_base::client, [this, &socket, endpoint_iter](const std::error_code handshake_error) {
-    handleHandshake(handshake_error, endpoint_iter, socket);
-  });
+SslSocket ConnectionHandler<SslSocket>::createNewSocket(asio::io_context& io_context_) {
+  gsl_Expects(ssl_context_);
+  return {io_context_, *ssl_context_};
 }
 
 template<class SocketType>
-void ConnectionHandler<SocketType>::handleWrite(std::error_code error,
-                                                std::size_t bytes_written,
-                                                const std::shared_ptr<io::InputStream>& flow_file_content_stream,
-                                                const std::vector<std::byte>& delimiter,
-                                                const std::shared_ptr<SocketType>& socket) {
-  bool write_failed_before_deadline = error.operator bool();
-  bool write_failed_due_to_deadline = !socket->lowest_layer().is_open();
-
-  if (write_failed_due_to_deadline) {
-    logger_->log_trace("Writing flowfile to socket timed out");
-    socket->lowest_layer().close();
-    deadline_.cancel();
-    return;
-  }
-
-  if (write_failed_before_deadline) {
-    last_error_ = error;
-    logger_->log_trace("Writing flowfile to socket failed due to %s", error.message());
-    socket->lowest_layer().close();
-    deadline_.cancel();
-    return;
-  }
-
-  logger_->log_trace("Writing flowfile(%zu bytes) to socket succeeded", bytes_written);
-  if (flow_file_content_stream->size() == flow_file_content_stream->tell()) {
-    asio::async_write(*socket, asio::buffer(delimiter), [&](std::error_code error, std::size_t bytes_written) {
-      handleDelimiterWrite(error, bytes_written, socket);
-    });
-  } else {
-    std::vector<std::byte> data_chunk;
-    data_chunk.resize(chunk_size);
-    gsl::span<std::byte> buffer{data_chunk};
-    size_t num_read = flow_file_content_stream->read(buffer);
-    asio::async_write(*socket, asio::buffer(data_chunk, num_read), [&](const std::error_code err, std::size_t bytes_written) {
-      handleWrite(err, bytes_written, flow_file_content_stream, delimiter, socket);
-    });
+asio::awaitable<std::error_code> ConnectionHandler<SocketType>::establishNewConnection(const tcp::resolver::results_type& resolved_query, asio::io_context& io_context) {
+  auto socket = createNewSocket(io_context);
+  std::error_code last_error;
+  for (const auto& endpoint : resolved_query) {
+    auto [connection_error] = co_await asyncOperationWithTimeout(socket.lowest_layer().async_connect(endpoint, use_nothrow_awaitable), timeout_duration_);
+    if (connection_error) {
+      core::logging::LOG_DEBUG(logger_) << "Connecting to " << endpoint.endpoint() << " failed due to " << connection_error.message();
+      last_error = connection_error;
+      continue;
+    }
+    auto [handshake_error] = co_await handshake(socket, timeout_duration_);
+    if (handshake_error) {
+      core::logging::LOG_DEBUG(logger_) << "Handshake with " << endpoint.endpoint() << " failed due to " << handshake_error.message();
+      last_error = handshake_error;
+      continue;
+    }
+    if (max_size_of_socket_send_buffer_)
+      socket.lowest_layer().set_option(TcpSocket::send_buffer_size(*max_size_of_socket_send_buffer_));
+    socket_.emplace(std::move(socket));
+    co_return std::error_code();
   }
+  co_return last_error;
 }
 
 template<class SocketType>
-void ConnectionHandler<SocketType>::handleDelimiterWrite(std::error_code error, std::size_t bytes_written, const std::shared_ptr<SocketType>& socket) {
-  bool write_failed_before_deadline = error.operator bool();
-  bool write_failed_due_to_deadline = !socket->lowest_layer().is_open();
-
-  if (write_failed_due_to_deadline) {
-    logger_->log_trace("Writing delimiter to socket timed out");
-    socket->lowest_layer().close();
-    deadline_.cancel();
-    return;
-  }
-
-  if (write_failed_before_deadline) {
-    last_error_ = error;
-    logger_->log_trace("Writing delimiter to socket failed due to %s", error.message());
-    socket->lowest_layer().close();
-    deadline_.cancel();
-    return;
-  }
-
-  logger_->log_trace("Writing delimiter(%zu bytes) to socket succeeded", bytes_written);
-  deadline_.cancel();
-}
-
-
-template<>
-nonstd::expected<std::shared_ptr<TcpSocket>, std::error_code> ConnectionHandler<TcpSocket>::establishConnection(const tcp::resolver::results_type& resolved_query) {
-  auto socket = std::make_shared<TcpSocket>(io_context_);
-  startConnect(resolved_query.begin(), socket);
-  deadline_.expires_after(timeout_);
-  deadline_.async_wait([&](std::error_code error_code) -> void {
-    checkDeadline(error_code, socket.get());
-  });
-  io_context_.run();
-  if (last_error_)
-    return nonstd::make_unexpected(last_error_);
-  return socket;
-}
-
-asio::ssl::context getSslContext(const auto& ssl_context_service) {
-  gsl_Expects(ssl_context_service);
-  asio::ssl::context ssl_context(asio::ssl::context::sslv23);
-  ssl_context.load_verify_file(ssl_context_service->getCACertificate());
-  ssl_context.set_verify_mode(asio::ssl::verify_peer);
-  if (auto cert_file = ssl_context_service->getCertificateFile(); !cert_file.empty())
-    ssl_context.use_certificate_file(cert_file, asio::ssl::context::pem);
-  if (auto private_key_file = ssl_context_service->getPrivateKeyFile(); !private_key_file.empty())
-    ssl_context.use_private_key_file(private_key_file, asio::ssl::context::pem);
-  ssl_context.set_password_callback([password = ssl_context_service->getPassphrase()](std::size_t&, asio::ssl::context_base::password_purpose&) { return password; });
-  return ssl_context;
+[[nodiscard]] asio::awaitable<std::error_code> ConnectionHandler<SocketType>::setupUsableSocket(asio::io_context& io_context) {
+  if (hasUsableSocket())
+    co_return std::error_code();
+  tcp::resolver resolver(io_context);
+  auto [resolve_error, resolve_result] = co_await asyncOperationWithTimeout(resolver.async_resolve(connection_id_.getHostname(), connection_id_.getPort(), use_nothrow_awaitable), timeout_duration_);
+  if (resolve_error)
+    co_return resolve_error;
+  co_return co_await establishNewConnection(resolve_result, io_context);
 }
 
-template<>
-nonstd::expected<std::shared_ptr<SslSocket>, std::error_code> ConnectionHandler<SslSocket>::establishConnection(const tcp::resolver::results_type& resolved_query) {
-  auto ssl_context = getSslContext(ssl_context_service_);
-  auto socket = std::make_shared<SslSocket>(io_context_, ssl_context);
-  startConnect(resolved_query.begin(), socket);
-  deadline_.async_wait([&](std::error_code error_code) -> void {
-    checkDeadline(error_code, socket.get());
-  });
-  io_context_.run();
-  if (last_error_)
-    return nonstd::make_unexpected(last_error_);
-  return socket;
+template<class SocketType>
+asio::awaitable<std::error_code> ConnectionHandler<SocketType>::sendStreamWithDelimiter(const std::shared_ptr<io::InputStream>& stream_to_send,
+                                                                                        const std::vector<std::byte>& delimiter,
+                                                                                        asio::io_context& io_context) {
+  if (auto connection_error = co_await setupUsableSocket(io_context))  // NOLINT

Review Comment:
   What was the linter error here that made the `// NOLINT` necessary?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@nifi.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org