szaszm commented on code in PR #1457:
URL: https://github.com/apache/nifi-minifi-cpp/pull/1457#discussion_r1061463628


##########
extensions/standard-processors/processors/PutTCP.cpp:
##########
@@ -114,6 +114,21 @@ void PutTCP::initialize() {
 
 void PutTCP::notifyStop() {}
 
+namespace {
+asio::ssl::context getSslContext(const 
std::shared_ptr<controllers::SSLContextService>& ssl_context_service) {
+  gsl_Expects(ssl_context_service);
+  asio::ssl::context ssl_context(asio::ssl::context::sslv23);

Review Comment:
   Can we restrict this to only allow TLS 1.2 or later (TLS 1.3 or later would 
be even better)? All SSL versions are hopelessly insecure, and TLS 1.0 and 
1.1 are widely deprecated by now due to known attacks. Even TLS 1.2 is only 
considered secure with an appropriately restricted cipher suite. For example, 
`SSL_CTX_set_min_proto_version(ssl_context.native_handle(), TLS1_2_VERSION)` 
enforces the minimum protocol version on the context.



##########
extensions/standard-processors/processors/PutTCP.cpp:
##########
@@ -160,339 +177,147 @@ void PutTCP::onSchedule(core::ProcessContext* const 
context, core::ProcessSessio
 }
 
 namespace {
+template<class SocketType>
+asio::awaitable<std::tuple<std::error_code>> handshake(SocketType&, 
asio::steady_timer::duration) {
+  co_return std::error_code();
+}
+
+template<>
+asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket, 
asio::steady_timer::duration timeout_duration) {
+  co_return co_await 
asyncOperationWithTimeout(socket.async_handshake(HandshakeType::client, 
use_nothrow_awaitable), timeout_duration);  // NOLINT
+}
+
 template<class SocketType>
 class ConnectionHandler : public ConnectionHandlerBase {
  public:
   ConnectionHandler(detail::ConnectionId connection_id,
                     std::chrono::milliseconds timeout,
                     std::shared_ptr<core::logging::Logger> logger,
                     std::optional<size_t> max_size_of_socket_send_buffer,
-                    std::shared_ptr<controllers::SSLContextService> 
ssl_context_service)
+                    std::optional<asio::ssl::context>& ssl_context)
       : connection_id_(std::move(connection_id)),
-        timeout_(timeout),
+        timeout_duration_(timeout),
         logger_(std::move(logger)),
         max_size_of_socket_send_buffer_(max_size_of_socket_send_buffer),
-        ssl_context_service_(std::move(ssl_context_service)) {
+        ssl_context_(ssl_context) {
   }
 
   ~ConnectionHandler() override = default;
 
-  nonstd::expected<void, std::error_code> sendData(const 
std::shared_ptr<io::InputStream>& flow_file_content_stream, const 
std::vector<std::byte>& delimiter) override;
+  asio::awaitable<std::error_code> sendStreamWithDelimiter(const 
std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>& 
delimiter, asio::io_context& io_context_) override;
 
  private:
-  nonstd::expected<std::shared_ptr<SocketType>, std::error_code> getSocket();
-
   [[nodiscard]] bool hasBeenUsedIn(std::chrono::milliseconds dur) const 
override {
-    return last_used_ && *last_used_ >= (std::chrono::steady_clock::now() - 
dur);
+    return last_used_ && *last_used_ >= (steady_clock::now() - dur);
   }
 
   void reset() override {
     last_used_.reset();
     socket_.reset();
-    io_context_.reset();
-    last_error_.clear();
-    deadline_.expires_at(asio::steady_timer::time_point::max());
   }
 
-  void checkDeadline(std::error_code error_code, SocketType* socket);
-  void startConnect(tcp::resolver::results_type::iterator endpoint_iter, const 
std::shared_ptr<SocketType>& socket);
-
-  void handleConnect(std::error_code error,
-                     tcp::resolver::results_type::iterator endpoint_iter,
-                     const std::shared_ptr<SocketType>& socket);
-  void handleConnectionSuccess(const tcp::resolver::results_type::iterator& 
endpoint_iter,
-                               const std::shared_ptr<SocketType>& socket);
-  void handleHandshake(std::error_code error,
-                       const tcp::resolver::results_type::iterator& 
endpoint_iter,
-                       const std::shared_ptr<SocketType>& socket);
-
-  void handleWrite(std::error_code error,
-                   std::size_t bytes_written,
-                   const std::shared_ptr<io::InputStream>& 
flow_file_content_stream,
-                   const std::vector<std::byte>& delimiter,
-                   const std::shared_ptr<SocketType>& socket);
-
-  void handleDelimiterWrite(std::error_code error, std::size_t bytes_written, 
const std::shared_ptr<SocketType>& socket);
+  [[nodiscard]] bool hasBeenUsed() const override { return 
last_used_.has_value(); }
+  [[nodiscard]] asio::awaitable<std::error_code> 
setupUsableSocket(asio::io_context& io_context);
+  [[nodiscard]] bool hasUsableSocket() const {  return socket_ && 
socket_->lowest_layer().is_open(); }
 
-  nonstd::expected<std::shared_ptr<SocketType>, std::error_code> 
establishConnection(const tcp::resolver::results_type& resolved_query);
+  asio::awaitable<std::error_code> establishNewConnection(const 
tcp::resolver::results_type& resolved_query, asio::io_context& io_context_);
+  asio::awaitable<std::error_code> send(const 
std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>& 
delimiter);
 
-  [[nodiscard]] bool hasBeenUsed() const override { return 
last_used_.has_value(); }
+  SocketType createNewSocket(asio::io_context& io_context_);
 
   detail::ConnectionId connection_id_;
-  std::optional<std::chrono::steady_clock::time_point> last_used_;
-  asio::io_context io_context_;
-  std::error_code last_error_;
-  asio::steady_timer deadline_{io_context_};
-  std::chrono::milliseconds timeout_;
-  std::shared_ptr<SocketType> socket_;
+  std::optional<SocketType> socket_;
+
+  std::optional<steady_clock::time_point> last_used_;
+  std::chrono::milliseconds timeout_duration_;
 
   std::shared_ptr<core::logging::Logger> logger_;
   std::optional<size_t> max_size_of_socket_send_buffer_;
 
-  std::shared_ptr<controllers::SSLContextService> ssl_context_service_;
-
-  nonstd::expected<tcp::resolver::results_type, std::error_code> 
resolveHostname();
-  nonstd::expected<void, std::error_code> sendDataToSocket(const 
std::shared_ptr<SocketType>& socket,
-                                                           const 
std::shared_ptr<io::InputStream>& flow_file_content_stream,
-                                                           const 
std::vector<std::byte>& delimiter);
+  std::optional<asio::ssl::context>& ssl_context_;
 };
 
-template<class SocketType>
-nonstd::expected<void, std::error_code> 
ConnectionHandler<SocketType>::sendData(const std::shared_ptr<io::InputStream>& 
flow_file_content_stream, const std::vector<std::byte>& delimiter) {
-  return getSocket() | utils::flatMap([&](const std::shared_ptr<SocketType>& 
socket) { return sendDataToSocket(socket, flow_file_content_stream, delimiter); 
});;
-}
-
-template<class SocketType>
-nonstd::expected<std::shared_ptr<SocketType>, std::error_code> 
ConnectionHandler<SocketType>::getSocket() {
-  if (socket_ && socket_->lowest_layer().is_open())
-    return socket_;
-  auto new_socket = resolveHostname() | utils::flatMap([&](const auto& 
resolved_query) { return establishConnection(resolved_query); });
-  if (!new_socket)
-    return nonstd::make_unexpected(new_socket.error());
-  socket_ = std::move(*new_socket);
-  return socket_;
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::checkDeadline(std::error_code error_code, 
SocketType* socket) {
-  if (error_code != asio::error::operation_aborted) {
-    deadline_.expires_at(asio::steady_timer::time_point::max());
-    last_error_ = asio::error::timed_out;
-    deadline_.async_wait([&](std::error_code error_code) { 
checkDeadline(error_code, socket); });
-    socket->lowest_layer().close();
-  }
-}
-
-template<class SocketType>
-void 
ConnectionHandler<SocketType>::startConnect(tcp::resolver::results_type::iterator
 endpoint_iter, const std::shared_ptr<SocketType>& socket) {
-  if (endpoint_iter == tcp::resolver::results_type::iterator()) {
-    logger_->log_trace("No more endpoints to try");
-    deadline_.cancel();
-    return;
-  }
-
-  last_error_.clear();
-  deadline_.expires_after(timeout_);
-  deadline_.async_wait([&](std::error_code error_code) -> void {
-    checkDeadline(error_code, socket.get());
-  });
-  socket->lowest_layer().async_connect(endpoint_iter->endpoint(),
-      [&socket, endpoint_iter, this](std::error_code err) {
-        handleConnect(err, endpoint_iter, socket);
-      });
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::handleConnect(std::error_code error,
-                                                  
tcp::resolver::results_type::iterator endpoint_iter,
-                                                  const 
std::shared_ptr<SocketType>& socket) {
-  bool connection_failed_before_deadline = error.operator bool();
-  bool connection_failed_due_to_deadline = !socket->lowest_layer().is_open();
-
-  if (connection_failed_due_to_deadline) {
-    core::logging::LOG_TRACE(logger_) << "Connecting to " << 
endpoint_iter->endpoint() << " timed out";
-    socket->lowest_layer().close();
-    return startConnect(++endpoint_iter, socket);
-  }
-
-  if (connection_failed_before_deadline) {
-    core::logging::LOG_TRACE(logger_) << "Connecting to " << 
endpoint_iter->endpoint() << " failed due to " << error.message();
-    last_error_ = error;
-    socket->lowest_layer().close();
-    return startConnect(++endpoint_iter, socket);
-  }
-
-  if (max_size_of_socket_send_buffer_)
-    
socket->lowest_layer().set_option(TcpSocket::send_buffer_size(*max_size_of_socket_send_buffer_));
-
-  handleConnectionSuccess(endpoint_iter, socket);
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::handleHandshake(std::error_code,
-                                                    const 
tcp::resolver::results_type::iterator&,
-                                                    const 
std::shared_ptr<SocketType>&) {
-  throw std::invalid_argument("Handshake called without SSL");
-}
-
 template<>
-void ConnectionHandler<SslSocket>::handleHandshake(std::error_code error,
-                                                   const 
tcp::resolver::results_type::iterator& endpoint_iter,
-                                                   const 
std::shared_ptr<SslSocket>& socket) {
-  if (!error) {
-    core::logging::LOG_TRACE(logger_) << "Successful handshake with " << 
endpoint_iter->endpoint();
-    deadline_.cancel();
-    return;
-  }
-  core::logging::LOG_TRACE(logger_) << "Handshake with " << 
endpoint_iter->endpoint() << " failed due to " << error.message();
-  last_error_ = error;
-  socket->lowest_layer().close();
-  startConnect(std::next(endpoint_iter), socket);
+TcpSocket ConnectionHandler<TcpSocket>::createNewSocket(asio::io_context& 
io_context_) {
+  gsl_Expects(!ssl_context_);
+  return TcpSocket{io_context_};
 }
 
 template<>
-void ConnectionHandler<TcpSocket>::handleConnectionSuccess(const 
tcp::resolver::results_type::iterator& endpoint_iter,
-                                                           const 
std::shared_ptr<TcpSocket>& socket) {
-  core::logging::LOG_TRACE(logger_) << "Connected to " << 
endpoint_iter->endpoint();
-  socket->lowest_layer().non_blocking(true);
-  deadline_.cancel();
-}
-
-template<>
-void ConnectionHandler<SslSocket>::handleConnectionSuccess(const 
tcp::resolver::results_type::iterator& endpoint_iter,
-                                                           const 
std::shared_ptr<SslSocket>& socket) {
-  core::logging::LOG_TRACE(logger_) << "Connected to " << 
endpoint_iter->endpoint();
-  socket->async_handshake(asio::ssl::stream_base::client, [this, &socket, 
endpoint_iter](const std::error_code handshake_error) {
-    handleHandshake(handshake_error, endpoint_iter, socket);
-  });
+SslSocket ConnectionHandler<SslSocket>::createNewSocket(asio::io_context& 
io_context_) {
+  gsl_Expects(ssl_context_);
+  return {io_context_, *ssl_context_};
 }
 
 template<class SocketType>
-void ConnectionHandler<SocketType>::handleWrite(std::error_code error,
-                                                std::size_t bytes_written,
-                                                const 
std::shared_ptr<io::InputStream>& flow_file_content_stream,
-                                                const std::vector<std::byte>& 
delimiter,
-                                                const 
std::shared_ptr<SocketType>& socket) {
-  bool write_failed_before_deadline = error.operator bool();
-  bool write_failed_due_to_deadline = !socket->lowest_layer().is_open();
-
-  if (write_failed_due_to_deadline) {
-    logger_->log_trace("Writing flowfile to socket timed out");
-    socket->lowest_layer().close();
-    deadline_.cancel();
-    return;
-  }
-
-  if (write_failed_before_deadline) {
-    last_error_ = error;
-    logger_->log_trace("Writing flowfile to socket failed due to %s", 
error.message());
-    socket->lowest_layer().close();
-    deadline_.cancel();
-    return;
-  }
-
-  logger_->log_trace("Writing flowfile(%zu bytes) to socket succeeded", 
bytes_written);
-  if (flow_file_content_stream->size() == flow_file_content_stream->tell()) {
-    asio::async_write(*socket, asio::buffer(delimiter), [&](std::error_code 
error, std::size_t bytes_written) {
-      handleDelimiterWrite(error, bytes_written, socket);
-    });
-  } else {
-    std::vector<std::byte> data_chunk;
-    data_chunk.resize(chunk_size);
-    gsl::span<std::byte> buffer{data_chunk};
-    size_t num_read = flow_file_content_stream->read(buffer);
-    asio::async_write(*socket, asio::buffer(data_chunk, num_read), [&](const 
std::error_code err, std::size_t bytes_written) {
-      handleWrite(err, bytes_written, flow_file_content_stream, delimiter, 
socket);
-    });
+asio::awaitable<std::error_code> 
ConnectionHandler<SocketType>::establishNewConnection(const 
tcp::resolver::results_type& resolved_query, asio::io_context& io_context) {
+  auto socket = createNewSocket(io_context);
+  std::error_code last_error;
+  for (const auto& endpoint : resolved_query) {

Review Comment:
   `resolved_query` is not a descriptive name: it holds the resolver's results, 
i.e. the candidate endpoints we try to connect to. I suggest renaming it to 
`endpoints`.



##########
libminifi/src/utils/net/TcpServer.cpp:
##########
@@ -15,53 +15,76 @@
  * limitations under the License.
  */
 #include "utils/net/TcpServer.h"
+#include "utils/net/AsioCoro.h"
 
 namespace org::apache::nifi::minifi::utils::net {
 
-TcpSession::TcpSession(asio::io_context& io_context, 
utils::ConcurrentQueue<Message>& concurrent_queue, std::optional<size_t> 
max_queue_size, std::shared_ptr<core::logging::Logger> logger)
-  : concurrent_queue_(concurrent_queue),
-    max_queue_size_(max_queue_size),
-    socket_(io_context),
-    logger_(std::move(logger)) {
+asio::awaitable<void> TcpServer::listen() {
+  asio::ip::tcp::acceptor acceptor(io_context_, 
asio::ip::tcp::endpoint(asio::ip::tcp::v6(), port_));
+  if (port_ == 0)
+    port_ = acceptor.local_endpoint().port();
+  while (true) {
+    auto [accept_error, socket] = co_await 
acceptor.async_accept(use_nothrow_awaitable);
+    if (accept_error) {
+      logger_->log_error("Error during accepting new connection: %s", 
accept_error.message());
+      break;
+    }
+    if (ssl_data_)
+      co_spawn(io_context_, secureSession(std::move(socket)), asio::detached);
+    else
+      co_spawn(io_context_, insecureSession(std::move(socket)), 
asio::detached);
+  }
 }
 
-asio::ip::tcp::socket& TcpSession::getSocket() {
-  return socket_;
-}
+asio::awaitable<void> TcpServer::readLoop(auto& socket) {
+  std::string read_message;
+  while (true) {
+    auto [read_error, bytes_read] = co_await asio::async_read_until(socket, 
asio::dynamic_buffer(read_message), '\n', use_nothrow_awaitable);  // NOLINT
+    if (read_error || bytes_read == 0)
+      co_return;
 
-void TcpSession::start() {
-  asio::async_read_until(socket_,
-                         buffer_,
-                         '\n',
-                         [self = shared_from_this()](const auto& error_code, 
size_t) -> void {
-                           self->handleReadUntilNewLine(error_code);
-                         });
+    if (!max_queue_size_ || max_queue_size_ > concurrent_queue_.size())
+      concurrent_queue_.enqueue(Message(read_message.substr(0, bytes_read - 
1), IpProtocol::TCP, socket.lowest_layer().remote_endpoint().address(), 
socket.lowest_layer().local_endpoint().port()));
+    else
+      logger_->log_warn("Queue is full. TCP message ignored.");
+    read_message.erase(0, bytes_read);
+  }
 }
 
-void TcpSession::handleReadUntilNewLine(std::error_code error_code) {
-  if (error_code)
-    return;
-  std::istream is(&buffer_);
-  std::string message;
-  std::getline(is, message);
-  if (!max_queue_size_ || max_queue_size_ > concurrent_queue_.size())
-    concurrent_queue_.enqueue(Message(message, IpProtocol::TCP, 
socket_.remote_endpoint().address(), socket_.local_endpoint().port()));
-  else
-    logger_->log_warn("Queue is full. TCP message ignored.");
-  asio::async_read_until(socket_,
-                         buffer_,
-                         '\n',
-                         [self = shared_from_this()](const auto& error_code, 
size_t) -> void {
-                           self->handleReadUntilNewLine(error_code);
-                         });
+asio::awaitable<void> TcpServer::insecureSession(asio::ip::tcp::socket socket) 
{
+  co_return co_await readLoop(socket);  // NOLINT
 }
 
-TcpServer::TcpServer(std::optional<size_t> max_queue_size, uint16_t port, 
std::shared_ptr<core::logging::Logger> logger)
-    : SessionHandlingServer<TcpSession>(max_queue_size, port, 
std::move(logger)) {
+namespace {
+asio::ssl::context setupSslContext(SslServerOptions& ssl_data) {
+  asio::ssl::context ssl_context(asio::ssl::context::sslv23);
+  ssl_context.set_options(
+      asio::ssl::context::default_workarounds
+      | asio::ssl::context::no_sslv2
+      | asio::ssl::context::single_dh_use);

Review Comment:
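   Please consider disabling more than just SSLv2. SSLv3 is also very insecure, 
and even TLS 1.0 and 1.1 are widely deprecated by now. For example, add 
`asio::ssl::context::no_sslv3 | asio::ssl::context::no_tlsv1 | 
asio::ssl::context::no_tlsv1_1` to the options.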



##########
libminifi/test/Utils.h:
##########
@@ -166,10 +170,10 @@ struct FlowFileQueueTestAccessor {
   FIELD_ACCESSOR(queue_);
 };
 
-bool sendMessagesViaSSL(const std::vector<std::string_view>& contents,
-                        const asio::ip::tcp::endpoint& remote_endpoint,
-                        const std::filesystem::path& ca_cert_path,
-                        const std::optional<minifi::utils::net::SslData>& 
ssl_data = std::nullopt) {
+std::error_code sendMessagesViaSSL(const std::vector<std::string_view>& 
contents,
+                                   const asio::ip::tcp::endpoint& 
remote_endpoint,
+                                   const std::filesystem::path& ca_cert_path,
+                                   const 
std::optional<minifi::utils::net::SslData>& ssl_data = std::nullopt) {

Review Comment:
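   I'm not asking for a change, but this perfectly demonstrates why I don't 
like aligned continuations: changing only the return type forced every 
continuation line to be re-indented, which bloats the diff.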



##########
libminifi/src/utils/net/UdpServer.cpp:
##########
@@ -15,32 +15,39 @@
  * limitations under the License.
  */
 #include "utils/net/UdpServer.h"
+#include "asio/use_awaitable.hpp"
+#include "asio/detached.hpp"
+#include "utils/net/AsioCoro.h"
 
 namespace org::apache::nifi::minifi::utils::net {
 
+constexpr size_t MAX_UDP_PACKET_SIZE = 65535;
+
 UdpServer::UdpServer(std::optional<size_t> max_queue_size,
                      uint16_t port,
                      std::shared_ptr<core::logging::Logger> logger)
-    : Server(max_queue_size, std::move(logger)),
-      socket_(io_context_, asio::ip::udp::endpoint(asio::ip::udp::v6(), port)) 
{
-  doReceive();
+    : Server(max_queue_size, port, std::move(logger)) {
 }
 
+asio::awaitable<void> UdpServer::listen() {

Review Comment:
   This is doing much more than just "listen" (which doesn't even happen with 
UDP, since UDP is connectionless). I also don't see the point of adding the 
messages to an internal queue instead of simply returning (or `co_yield`ing?) 
them.



##########
libminifi/test/Utils.h:
##########
@@ -183,33 +187,51 @@ bool sendMessagesViaSSL(const 
std::vector<std::string_view>& contents,
   asio::error_code err;
   socket.lowest_layer().connect(remote_endpoint, err);
   if (err) {
-    return false;
+    return err;
   }
   socket.handshake(asio::ssl::stream_base::client, err);
   if (err) {
-    return false;
+    return err;
   }
   for (auto& content : contents) {
     std::string tcp_message(content);
     tcp_message += '\n';
     asio::write(socket, asio::buffer(tcp_message, tcp_message.size()), err);
     if (err) {
-      return false;
+      return err;
     }
   }
-  return true;
+  return std::error_code();
 }
 
 #ifdef WIN32
 inline std::error_code hide_file(const std::filesystem::path& file_name) {
-    const bool success = SetFileAttributesA(file_name.string().c_str(), 
FILE_ATTRIBUTE_HIDDEN);
-    if (!success) {
-      // note: All possible documented error codes from GetLastError are in 
[0;15999] at the time of writing.
-      // The below casting is safe in [0;std::numeric_limits<int>::max()], int 
max is guaranteed to be at least 32767
-      return { static_cast<int>(GetLastError()), std::system_category() };
-    }
-    return {};
+  const bool success = SetFileAttributesA(file_name.string().c_str(), 
FILE_ATTRIBUTE_HIDDEN);
+  if (!success) {
+    // note: All possible documented error codes from GetLastError are in 
[0;15999] at the time of writing.
+    // The below casting is safe in [0;std::numeric_limits<int>::max()], int 
max is guaranteed to be at least 32767
+    return { static_cast<int>(GetLastError()), std::system_category() };
   }
+  return {};
+}
 #endif /* WIN32 */
 
+template<class T>
+uint16_t scheduleProcessorOnRandomPort(const std::shared_ptr<TestPlan>& 
test_plan, const std::shared_ptr<T>& processor) {

Review Comment:
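   Could you add some static checks for the template argument, e.g. a 
`static_assert` or a concept that constrains `T` to the processor types this 
helper is meant for?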



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to