adamdebreceni commented on code in PR #1457:
URL: https://github.com/apache/nifi-minifi-cpp/pull/1457#discussion_r1067003083
##########
extensions/standard-processors/processors/PutTCP.cpp:
##########
@@ -160,339 +177,145 @@ void PutTCP::onSchedule(core::ProcessContext* const
context, core::ProcessSessio
}
namespace {
+template<class SocketType>
+asio::awaitable<std::tuple<std::error_code>> handshake(SocketType&,
asio::steady_timer::duration) {
+ co_return std::error_code();
+}
+
+template<>
+asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket,
asio::steady_timer::duration timeout_duration) {
+ co_return co_await
asyncOperationWithTimeout(socket.async_handshake(HandshakeType::client,
use_nothrow_awaitable), timeout_duration); // NOLINT
+}
+
template<class SocketType>
class ConnectionHandler : public ConnectionHandlerBase {
public:
ConnectionHandler(detail::ConnectionId connection_id,
std::chrono::milliseconds timeout,
std::shared_ptr<core::logging::Logger> logger,
std::optional<size_t> max_size_of_socket_send_buffer,
- std::shared_ptr<controllers::SSLContextService>
ssl_context_service)
+ std::optional<asio::ssl::context>& ssl_context)
: connection_id_(std::move(connection_id)),
- timeout_(timeout),
+ timeout_duration_(timeout),
logger_(std::move(logger)),
max_size_of_socket_send_buffer_(max_size_of_socket_send_buffer),
- ssl_context_service_(std::move(ssl_context_service)) {
+ ssl_context_(ssl_context) {
}
~ConnectionHandler() override = default;
- nonstd::expected<void, std::error_code> sendData(const
std::shared_ptr<io::InputStream>& flow_file_content_stream, const
std::vector<std::byte>& delimiter) override;
+ asio::awaitable<std::error_code> sendStreamWithDelimiter(const
std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>&
delimiter, asio::io_context& io_context_) override;
private:
- nonstd::expected<std::shared_ptr<SocketType>, std::error_code> getSocket();
-
[[nodiscard]] bool hasBeenUsedIn(std::chrono::milliseconds dur) const
override {
- return last_used_ && *last_used_ >= (std::chrono::steady_clock::now() -
dur);
+ return last_used_ && *last_used_ >= (steady_clock::now() - dur);
}
void reset() override {
last_used_.reset();
socket_.reset();
- io_context_.reset();
- last_error_.clear();
- deadline_.expires_at(asio::steady_timer::time_point::max());
}
- void checkDeadline(std::error_code error_code, SocketType* socket);
- void startConnect(tcp::resolver::results_type::iterator endpoint_iter, const
std::shared_ptr<SocketType>& socket);
-
- void handleConnect(std::error_code error,
- tcp::resolver::results_type::iterator endpoint_iter,
- const std::shared_ptr<SocketType>& socket);
- void handleConnectionSuccess(const tcp::resolver::results_type::iterator&
endpoint_iter,
- const std::shared_ptr<SocketType>& socket);
- void handleHandshake(std::error_code error,
- const tcp::resolver::results_type::iterator&
endpoint_iter,
- const std::shared_ptr<SocketType>& socket);
-
- void handleWrite(std::error_code error,
- std::size_t bytes_written,
- const std::shared_ptr<io::InputStream>&
flow_file_content_stream,
- const std::vector<std::byte>& delimiter,
- const std::shared_ptr<SocketType>& socket);
-
- void handleDelimiterWrite(std::error_code error, std::size_t bytes_written,
const std::shared_ptr<SocketType>& socket);
+ [[nodiscard]] bool hasBeenUsed() const override { return
last_used_.has_value(); }
+ [[nodiscard]] asio::awaitable<std::error_code>
setupUsableSocket(asio::io_context& io_context);
+ [[nodiscard]] bool hasUsableSocket() const { return socket_ &&
socket_->lowest_layer().is_open(); }
- nonstd::expected<std::shared_ptr<SocketType>, std::error_code>
establishConnection(const tcp::resolver::results_type& resolved_query);
+ asio::awaitable<std::error_code> establishNewConnection(const
tcp::resolver::results_type& resolved_query, asio::io_context& io_context_);
+ asio::awaitable<std::error_code> send(const
std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>&
delimiter);
- [[nodiscard]] bool hasBeenUsed() const override { return
last_used_.has_value(); }
+ SocketType createNewSocket(asio::io_context& io_context_);
detail::ConnectionId connection_id_;
- std::optional<std::chrono::steady_clock::time_point> last_used_;
- asio::io_context io_context_;
- std::error_code last_error_;
- asio::steady_timer deadline_{io_context_};
- std::chrono::milliseconds timeout_;
- std::shared_ptr<SocketType> socket_;
+ std::optional<SocketType> socket_;
+
+ std::optional<steady_clock::time_point> last_used_;
+ std::chrono::milliseconds timeout_duration_;
std::shared_ptr<core::logging::Logger> logger_;
std::optional<size_t> max_size_of_socket_send_buffer_;
- std::shared_ptr<controllers::SSLContextService> ssl_context_service_;
-
- nonstd::expected<tcp::resolver::results_type, std::error_code>
resolveHostname();
- nonstd::expected<void, std::error_code> sendDataToSocket(const
std::shared_ptr<SocketType>& socket,
- const
std::shared_ptr<io::InputStream>& flow_file_content_stream,
- const
std::vector<std::byte>& delimiter);
+ std::optional<asio::ssl::context>& ssl_context_;
};
-template<class SocketType>
-nonstd::expected<void, std::error_code>
ConnectionHandler<SocketType>::sendData(const std::shared_ptr<io::InputStream>&
flow_file_content_stream, const std::vector<std::byte>& delimiter) {
- return getSocket() | utils::flatMap([&](const std::shared_ptr<SocketType>&
socket) { return sendDataToSocket(socket, flow_file_content_stream, delimiter);
});;
-}
-
-template<class SocketType>
-nonstd::expected<std::shared_ptr<SocketType>, std::error_code>
ConnectionHandler<SocketType>::getSocket() {
- if (socket_ && socket_->lowest_layer().is_open())
- return socket_;
- auto new_socket = resolveHostname() | utils::flatMap([&](const auto&
resolved_query) { return establishConnection(resolved_query); });
- if (!new_socket)
- return nonstd::make_unexpected(new_socket.error());
- socket_ = std::move(*new_socket);
- return socket_;
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::checkDeadline(std::error_code error_code,
SocketType* socket) {
- if (error_code != asio::error::operation_aborted) {
- deadline_.expires_at(asio::steady_timer::time_point::max());
- last_error_ = asio::error::timed_out;
- deadline_.async_wait([&](std::error_code error_code) {
checkDeadline(error_code, socket); });
- socket->lowest_layer().close();
- }
-}
-
-template<class SocketType>
-void
ConnectionHandler<SocketType>::startConnect(tcp::resolver::results_type::iterator
endpoint_iter, const std::shared_ptr<SocketType>& socket) {
- if (endpoint_iter == tcp::resolver::results_type::iterator()) {
- logger_->log_trace("No more endpoints to try");
- deadline_.cancel();
- return;
- }
-
- last_error_.clear();
- deadline_.expires_after(timeout_);
- deadline_.async_wait([&](std::error_code error_code) -> void {
- checkDeadline(error_code, socket.get());
- });
- socket->lowest_layer().async_connect(endpoint_iter->endpoint(),
- [&socket, endpoint_iter, this](std::error_code err) {
- handleConnect(err, endpoint_iter, socket);
- });
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::handleConnect(std::error_code error,
-
tcp::resolver::results_type::iterator endpoint_iter,
- const
std::shared_ptr<SocketType>& socket) {
- bool connection_failed_before_deadline = error.operator bool();
- bool connection_failed_due_to_deadline = !socket->lowest_layer().is_open();
-
- if (connection_failed_due_to_deadline) {
- core::logging::LOG_TRACE(logger_) << "Connecting to " <<
endpoint_iter->endpoint() << " timed out";
- socket->lowest_layer().close();
- return startConnect(++endpoint_iter, socket);
- }
-
- if (connection_failed_before_deadline) {
- core::logging::LOG_TRACE(logger_) << "Connecting to " <<
endpoint_iter->endpoint() << " failed due to " << error.message();
- last_error_ = error;
- socket->lowest_layer().close();
- return startConnect(++endpoint_iter, socket);
- }
-
- if (max_size_of_socket_send_buffer_)
-
socket->lowest_layer().set_option(TcpSocket::send_buffer_size(*max_size_of_socket_send_buffer_));
-
- handleConnectionSuccess(endpoint_iter, socket);
-}
-
-template<class SocketType>
-void ConnectionHandler<SocketType>::handleHandshake(std::error_code,
- const
tcp::resolver::results_type::iterator&,
- const
std::shared_ptr<SocketType>&) {
- throw std::invalid_argument("Handshake called without SSL");
-}
-
template<>
-void ConnectionHandler<SslSocket>::handleHandshake(std::error_code error,
- const
tcp::resolver::results_type::iterator& endpoint_iter,
- const
std::shared_ptr<SslSocket>& socket) {
- if (!error) {
- core::logging::LOG_TRACE(logger_) << "Successful handshake with " <<
endpoint_iter->endpoint();
- deadline_.cancel();
- return;
- }
- core::logging::LOG_TRACE(logger_) << "Handshake with " <<
endpoint_iter->endpoint() << " failed due to " << error.message();
- last_error_ = error;
- socket->lowest_layer().close();
- startConnect(std::next(endpoint_iter), socket);
+TcpSocket ConnectionHandler<TcpSocket>::createNewSocket(asio::io_context&
io_context_) {
+ gsl_Expects(!ssl_context_);
+ return TcpSocket{io_context_};
}
template<>
-void ConnectionHandler<TcpSocket>::handleConnectionSuccess(const
tcp::resolver::results_type::iterator& endpoint_iter,
- const
std::shared_ptr<TcpSocket>& socket) {
- core::logging::LOG_TRACE(logger_) << "Connected to " <<
endpoint_iter->endpoint();
- socket->lowest_layer().non_blocking(true);
- deadline_.cancel();
-}
-
-template<>
-void ConnectionHandler<SslSocket>::handleConnectionSuccess(const
tcp::resolver::results_type::iterator& endpoint_iter,
- const
std::shared_ptr<SslSocket>& socket) {
- core::logging::LOG_TRACE(logger_) << "Connected to " <<
endpoint_iter->endpoint();
- socket->async_handshake(asio::ssl::stream_base::client, [this, &socket,
endpoint_iter](const std::error_code handshake_error) {
- handleHandshake(handshake_error, endpoint_iter, socket);
- });
+SslSocket ConnectionHandler<SslSocket>::createNewSocket(asio::io_context&
io_context_) {
+ gsl_Expects(ssl_context_);
+ return {io_context_, *ssl_context_};
}
template<class SocketType>
-void ConnectionHandler<SocketType>::handleWrite(std::error_code error,
- std::size_t bytes_written,
- const
std::shared_ptr<io::InputStream>& flow_file_content_stream,
- const std::vector<std::byte>&
delimiter,
- const
std::shared_ptr<SocketType>& socket) {
- bool write_failed_before_deadline = error.operator bool();
- bool write_failed_due_to_deadline = !socket->lowest_layer().is_open();
-
- if (write_failed_due_to_deadline) {
- logger_->log_trace("Writing flowfile to socket timed out");
- socket->lowest_layer().close();
- deadline_.cancel();
- return;
- }
-
- if (write_failed_before_deadline) {
- last_error_ = error;
- logger_->log_trace("Writing flowfile to socket failed due to %s",
error.message());
- socket->lowest_layer().close();
- deadline_.cancel();
- return;
- }
-
- logger_->log_trace("Writing flowfile(%zu bytes) to socket succeeded",
bytes_written);
- if (flow_file_content_stream->size() == flow_file_content_stream->tell()) {
- asio::async_write(*socket, asio::buffer(delimiter), [&](std::error_code
error, std::size_t bytes_written) {
- handleDelimiterWrite(error, bytes_written, socket);
- });
- } else {
- std::vector<std::byte> data_chunk;
- data_chunk.resize(chunk_size);
- gsl::span<std::byte> buffer{data_chunk};
- size_t num_read = flow_file_content_stream->read(buffer);
- asio::async_write(*socket, asio::buffer(data_chunk, num_read), [&](const
std::error_code err, std::size_t bytes_written) {
- handleWrite(err, bytes_written, flow_file_content_stream, delimiter,
socket);
- });
+asio::awaitable<std::error_code>
ConnectionHandler<SocketType>::establishNewConnection(const
tcp::resolver::results_type& resolved_query, asio::io_context& io_context) {
+ auto socket = createNewSocket(io_context);
+ std::error_code last_error;
+ for (const auto& endpoint : resolved_query) {
+ auto [connection_error] = co_await
asyncOperationWithTimeout(socket.lowest_layer().async_connect(endpoint,
use_nothrow_awaitable), timeout_duration_);
+ if (connection_error) {
+ core::logging::LOG_DEBUG(logger_) << "Connecting to " <<
endpoint.endpoint() << " failed due to " << connection_error.message();
+ last_error = connection_error;
+ continue;
+ }
+ auto [handshake_error] = co_await handshake(socket, timeout_duration_);
+ if (handshake_error) {
+ core::logging::LOG_DEBUG(logger_) << "Handshake with " <<
endpoint.endpoint() << " failed due to " << handshake_error.message();
+ last_error = handshake_error;
+ continue;
+ }
+ if (max_size_of_socket_send_buffer_)
+
socket.lowest_layer().set_option(TcpSocket::send_buffer_size(*max_size_of_socket_send_buffer_));
+ socket_.emplace(std::move(socket));
+ co_return std::error_code();
}
+ co_return last_error;
}
template<class SocketType>
-void ConnectionHandler<SocketType>::handleDelimiterWrite(std::error_code
error, std::size_t bytes_written, const std::shared_ptr<SocketType>& socket) {
- bool write_failed_before_deadline = error.operator bool();
- bool write_failed_due_to_deadline = !socket->lowest_layer().is_open();
-
- if (write_failed_due_to_deadline) {
- logger_->log_trace("Writing delimiter to socket timed out");
- socket->lowest_layer().close();
- deadline_.cancel();
- return;
- }
-
- if (write_failed_before_deadline) {
- last_error_ = error;
- logger_->log_trace("Writing delimiter to socket failed due to %s",
error.message());
- socket->lowest_layer().close();
- deadline_.cancel();
- return;
- }
-
- logger_->log_trace("Writing delimiter(%zu bytes) to socket succeeded",
bytes_written);
- deadline_.cancel();
-}
-
-
-template<>
-nonstd::expected<std::shared_ptr<TcpSocket>, std::error_code>
ConnectionHandler<TcpSocket>::establishConnection(const
tcp::resolver::results_type& resolved_query) {
- auto socket = std::make_shared<TcpSocket>(io_context_);
- startConnect(resolved_query.begin(), socket);
- deadline_.expires_after(timeout_);
- deadline_.async_wait([&](std::error_code error_code) -> void {
- checkDeadline(error_code, socket.get());
- });
- io_context_.run();
- if (last_error_)
- return nonstd::make_unexpected(last_error_);
- return socket;
-}
-
-asio::ssl::context getSslContext(const auto& ssl_context_service) {
- gsl_Expects(ssl_context_service);
- asio::ssl::context ssl_context(asio::ssl::context::sslv23);
- ssl_context.load_verify_file(ssl_context_service->getCACertificate());
- ssl_context.set_verify_mode(asio::ssl::verify_peer);
- if (auto cert_file = ssl_context_service->getCertificateFile();
!cert_file.empty())
- ssl_context.use_certificate_file(cert_file, asio::ssl::context::pem);
- if (auto private_key_file = ssl_context_service->getPrivateKeyFile();
!private_key_file.empty())
- ssl_context.use_private_key_file(private_key_file,
asio::ssl::context::pem);
- ssl_context.set_password_callback([password =
ssl_context_service->getPassphrase()](std::size_t&,
asio::ssl::context_base::password_purpose&) { return password; });
- return ssl_context;
+[[nodiscard]] asio::awaitable<std::error_code>
ConnectionHandler<SocketType>::setupUsableSocket(asio::io_context& io_context) {
+ if (hasUsableSocket())
+ co_return std::error_code();
+ tcp::resolver resolver(io_context);
+ auto [resolve_error, resolve_result] = co_await
asyncOperationWithTimeout(resolver.async_resolve(connection_id_.getHostname(),
connection_id_.getPort(), use_nothrow_awaitable), timeout_duration_);
+ if (resolve_error)
+ co_return resolve_error;
+ co_return co_await establishNewConnection(resolve_result, io_context);
}
-template<>
-nonstd::expected<std::shared_ptr<SslSocket>, std::error_code>
ConnectionHandler<SslSocket>::establishConnection(const
tcp::resolver::results_type& resolved_query) {
- auto ssl_context = getSslContext(ssl_context_service_);
- auto socket = std::make_shared<SslSocket>(io_context_, ssl_context);
- startConnect(resolved_query.begin(), socket);
- deadline_.async_wait([&](std::error_code error_code) -> void {
- checkDeadline(error_code, socket.get());
- });
- io_context_.run();
- if (last_error_)
- return nonstd::make_unexpected(last_error_);
- return socket;
+template<class SocketType>
+asio::awaitable<std::error_code>
ConnectionHandler<SocketType>::sendStreamWithDelimiter(const
std::shared_ptr<io::InputStream>& stream_to_send,
+
const std::vector<std::byte>& delimiter,
+
asio::io_context& io_context) {
+ if (auto connection_error = co_await setupUsableSocket(io_context)) //
NOLINT
+ co_return connection_error;
+ co_return co_await send(stream_to_send, delimiter);
}
template<class SocketType>
-nonstd::expected<void, std::error_code>
ConnectionHandler<SocketType>::sendDataToSocket(const
std::shared_ptr<SocketType>& socket,
-
const std::shared_ptr<io::InputStream>& flow_file_content_stream,
-
const std::vector<std::byte>& delimiter) {
- if (!socket || !socket->lowest_layer().is_open())
- return nonstd::make_unexpected(asio::error::not_socket);
-
- deadline_.expires_after(timeout_);
- deadline_.async_wait([&](std::error_code error_code) -> void {
- checkDeadline(error_code, socket.get());
- });
- io_context_.restart();
+asio::awaitable<std::error_code> ConnectionHandler<SocketType>::send(const
std::shared_ptr<io::InputStream>& stream_to_send,
+ const
std::vector<std::byte>& delimiter) {
+ gsl_Expects(hasUsableSocket());
std::vector<std::byte> data_chunk;
data_chunk.resize(chunk_size);
-
gsl::span<std::byte> buffer{data_chunk};
- size_t num_read = flow_file_content_stream->read(buffer);
- logger_->log_trace("read %zu bytes from flowfile", num_read);
- asio::async_write(*socket, asio::buffer(data_chunk, num_read), [&](const
std::error_code err, std::size_t bytes_written) {
- handleWrite(err, bytes_written, flow_file_content_stream, delimiter,
socket);
- });
- deadline_.async_wait([&](std::error_code error_code) -> void {
- checkDeadline(error_code, socket.get());
- });
- io_context_.run();
- if (last_error_)
- return nonstd::make_unexpected(last_error_);
- last_used_ = std::chrono::steady_clock::now();
- return {};
-}
+ while (stream_to_send->tell() < stream_to_send->size()) {
+ size_t num_read = stream_to_send->read(buffer);
Review Comment:
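Elsewhere in the codebase, read results seem to be checked with `io::isError(...)` rather than by direct comparison with the error value; consider using `io::isError(num_read)` here for consistency.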
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]