szaszm commented on a change in pull request #713: WIP: MINIFICPP-1119 MINIFICPP-1154 unify win/posix sockets + fix bugs URL: https://github.com/apache/nifi-minifi-cpp/pull/713#discussion_r380343221
########## File path: libminifi/src/io/ClientSocket.cpp ########## @@ -0,0 +1,621 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "io/ClientSocket.h" +#ifndef WIN32 +#include <netinet/tcp.h> +#include <sys/types.h> +#include <netinet/in.h> +#include <ifaddrs.h> +#include <unistd.h> +#else +#include <WS2tcpip.h> +#pragma comment(lib, "Ws2_32.lib") +#endif /* !WIN32 */ + +#include <memory> +#include <utility> +#include <vector> +#include <cerrno> +#include <string> +#include <system_error> +#include <cinttypes> +#include <Exception.h> +#include <utils/Deleters.h> +#include "io/validation.h" +#include "core/logging/LoggerConfiguration.h" +#include "utils/GeneralUtils.h" + +namespace util = org::apache::nifi::minifi::utils; + +namespace { + +std::string get_last_err_str() { +#ifdef WIN32 + const auto error_code = WSAGetLastError(); +#else + const auto error_code = errno; +#endif /* WIN32 */ + return std::system_category().message(error_code); +} + +std::string get_last_getaddrinfo_err_str(int getaddrinfo_result) { +#ifdef WIN32 + (void)getaddrinfo_result; // against unused warnings on windows + return get_last_err_str(); +#else + return gai_strerror(getaddrinfo_result); +#endif /* WIN32 */ +} + +bool valid_sock_fd(org::apache::nifi::minifi::io::SocketDescriptor fd) { +#ifdef WIN32 + return fd != INVALID_SOCKET && fd >= 0; +#else + return fd >= 0; +#endif /* WIN32 */ +} + +std::string sockaddr_ntop(const sockaddr* const sa) { + std::string result; + if (sa->sa_family == AF_INET) { + sockaddr_in sa_in{}; + std::memcpy(reinterpret_cast<void*>(&sa_in), sa, sizeof(sockaddr_in)); + result.resize(INET_ADDRSTRLEN); + if (inet_ntop(AF_INET, &sa_in.sin_addr, &result[0], INET_ADDRSTRLEN) == nullptr) { + throw minifi::Exception{ minifi::ExceptionType::GENERAL_EXCEPTION, get_last_err_str() }; + } + } else if (sa->sa_family == AF_INET6) { + sockaddr_in6 sa_in6{}; + std::memcpy(reinterpret_cast<void*>(&sa_in6), sa, sizeof(sockaddr_in6)); + result.resize(INET6_ADDRSTRLEN); + if (inet_ntop(AF_INET6, &sa_in6.sin6_addr, &result[0], INET6_ADDRSTRLEN) == nullptr) { + throw minifi::Exception{ minifi::ExceptionType::GENERAL_EXCEPTION, get_last_err_str() }; + } + } else { + throw minifi::Exception{ minifi::ExceptionType::GENERAL_EXCEPTION, "sockaddr_ntop: unknown address family" }; + } + result.resize(strlen(result.c_str())); // discard remaining null bytes at the end + return result; +} + +template<typename T, typename Pred, typename Adv> +auto find_if_custom_linked_list(T* const list, const Adv advance_func, const Pred predicate) -> + typename std::enable_if<std::is_convertible<decltype(advance_func(std::declval<T*>())), T*>::value && std::is_convertible<decltype(predicate(std::declval<T*>())), bool>::value, T*>::type +{ + for (T* it = list; it; it = advance_func(it)) { + if (predicate(it)) return it; + } + return nullptr; +} + +#ifndef WIN32 +std::error_code bind_to_local_network_interface(const minifi::io::SocketDescriptor fd, const minifi::io::NetworkInterface& interface) { + using ifaddrs_uniq_ptr = std::unique_ptr<ifaddrs, util::ifaddrs_deleter>; + const auto if_list_ptr = []() -> ifaddrs_uniq_ptr { + ifaddrs *list = nullptr; + const auto get_ifa_success = getifaddrs(&list) == 0; + assert(get_ifa_success || !list); + return ifaddrs_uniq_ptr{ list }; + }(); + if (!if_list_ptr) { return { errno, std::generic_category() }; } + + const auto advance_func = [](const ifaddrs *const p) { return p->ifa_next; }; + const auto predicate = [&interface](const ifaddrs *const item) { + return item->ifa_addr && item->ifa_name && (item->ifa_addr->sa_family == AF_INET || item->ifa_addr->sa_family == AF_INET6) + && item->ifa_name == interface.getInterface(); + }; + const auto *const itemFound = find_if_custom_linked_list(if_list_ptr.get(), advance_func, predicate); + if (itemFound == nullptr) { return std::make_error_code(std::errc::no_such_device_or_address); } + + const socklen_t addrlen = itemFound->ifa_addr->sa_family == AF_INET ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); + if (bind(fd, itemFound->ifa_addr, addrlen) != 0) { return { errno, std::generic_category() }; } + return {}; +} +#endif /* !WIN32 */ + +std::error_code set_non_blocking(const minifi::io::SocketDescriptor fd) noexcept { +#ifndef WIN32 + if (fcntl(fd, F_SETFL, O_NONBLOCK) < 0) { + return { errno, std::generic_category() }; + } +#else + u_long iMode = 1; + if (ioctlsocket(fd, FIONBIO, &iMode) == SOCKET_ERROR) { + return { WSAGetLastError(), std::system_category() }; + } +#endif /* !WIN32 */ + return {}; +} +} // namespace + +namespace org { +namespace apache { +namespace nifi { +namespace minifi { +namespace io { + +Socket::Socket(const std::shared_ptr<SocketContext>& /*context*/, std::string hostname, const uint16_t port, const uint16_t listeners) + : requested_hostname_(std::move(hostname)), + port_(port), + listeners_(listeners), + logger_(logging::LoggerFactory<Socket>::getLogger()) { + FD_ZERO(&total_list_); + FD_ZERO(&read_fds_); + initialize_socket(); +} + +Socket::Socket(const std::shared_ptr<SocketContext>& context, std::string hostname, const uint16_t port) + : Socket(context, std::move(hostname), port, 0) { +} + +Socket::Socket(Socket &&other) noexcept + : requested_hostname_{ std::move(other.requested_hostname_) }, + canonical_hostname_{ std::move(other.canonical_hostname_) }, + port_{ other.port_ }, + is_loopback_only_{ other.is_loopback_only_ }, + local_network_interface_{ std::move(other.local_network_interface_) }, + socket_file_descriptor_{ other.socket_file_descriptor_ }, + total_list_{ other.total_list_ }, + read_fds_{ other.read_fds_ }, + socket_max_{ other.socket_max_.load() }, + total_written_{ other.total_written_.load() }, + total_read_{ other.total_read_.load() }, + listeners_{ other.listeners_ }, + nonBlocking_{ other.nonBlocking_ }, + logger_{ other.logger_ } +{ + other = Socket{ {}, {}, {} }; +} + +Socket& Socket::operator=(Socket &&other) noexcept { + if (&other == this) return *this; + requested_hostname_ = util::exchange(other.requested_hostname_, ""); + canonical_hostname_ = util::exchange(other.canonical_hostname_, ""); + port_ = util::exchange(other.port_, 0); + is_loopback_only_ = util::exchange(other.is_loopback_only_, false); + local_network_interface_ = util::exchange(other.local_network_interface_, {}); + socket_file_descriptor_ = util::exchange(other.socket_file_descriptor_, INVALID_SOCKET); + total_list_ = other.total_list_; + FD_ZERO(&other.total_list_); + read_fds_ = other.read_fds_; + FD_ZERO(&other.read_fds_); + socket_max_.exchange(other.socket_max_); + other.socket_max_.exchange(0); + total_written_.exchange(other.total_written_); + other.total_written_.exchange(0); + total_read_.exchange(other.total_read_); + other.total_read_.exchange(0); + listeners_ = util::exchange(other.listeners_, 0); + nonBlocking_ = util::exchange(other.nonBlocking_, false); + logger_ = other.logger_; + return *this; +} + +Socket::~Socket() { + Socket::closeStream(); +} + +void Socket::closeStream() { + if (valid_sock_fd(socket_file_descriptor_)) { + logging::LOG_DEBUG(logger_) << "Closing " << socket_file_descriptor_; +#ifdef WIN32 + closesocket(socket_file_descriptor_); +#else + close(socket_file_descriptor_); +#endif + socket_file_descriptor_ = INVALID_SOCKET; + } + if (total_written_ > 0) { + local_network_interface_.log_write(total_written_); + total_written_ = 0; + } + if (total_read_ > 0) { + local_network_interface_.log_read(total_read_); + total_read_ = 0; + } +} + +void Socket::setNonBlocking() { + if (listeners_ <= 0) { + nonBlocking_ = true; + } +} + +int8_t Socket::createConnection(const addrinfo* const destination_addresses) { + for (const auto *current_addr = destination_addresses; current_addr; current_addr = current_addr->ai_next) { + if (!valid_sock_fd(socket_file_descriptor_ = socket(current_addr->ai_family, current_addr->ai_socktype, current_addr->ai_protocol))) { + logger_->log_warn("socket: %s", get_last_err_str()); + continue; + } + setSocketOptions(socket_file_descriptor_); + + if (listeners_ > 0) { + // server socket + const auto bind_result = bind(socket_file_descriptor_, current_addr->ai_addr, current_addr->ai_addrlen); + if (bind_result == SOCKET_ERROR) { + logger_->log_warn("bind: %s", get_last_err_str()); + closeStream(); + continue; + } + + const auto listen_result = listen(socket_file_descriptor_, listeners_); + if (listen_result == SOCKET_ERROR) { + logger_->log_warn("listen: %s", get_last_err_str()); + closeStream(); + continue; + } + + logger_->log_info("Listening on %s:%" PRIu16 " with backlog %" PRIu16, sockaddr_ntop(current_addr->ai_addr), port_, listeners_); + } else { + // client socket +#ifndef WIN32 + if (!local_network_interface_.getInterface().empty()) { + const auto err = bind_to_local_network_interface(socket_file_descriptor_, local_network_interface_); + if (err) logger_->log_info("Bind to interface %s failed %s", local_network_interface_.getInterface(), err.message()); + else logger_->log_info("Bind to interface %s", local_network_interface_.getInterface()); + } +#endif /* !WIN32 */ + + const auto connect_result = connect(socket_file_descriptor_, current_addr->ai_addr, current_addr->ai_addrlen); + if (connect_result == SOCKET_ERROR) { + logger_->log_warn("Couldn't connect to %s:%" PRIu16 ": %s", sockaddr_ntop(current_addr->ai_addr), port_, get_last_err_str()); + closeStream(); + continue; + } + + logger_->log_info("Connected to %s:%" PRIu16, sockaddr_ntop(current_addr->ai_addr), port_); + } + + FD_SET(socket_file_descriptor_, &total_list_); + socket_max_ = socket_file_descriptor_; + return 0; + } + return -1; +} + +int8_t Socket::createConnection(const addrinfo *, ip4addr &addr) { + if (!valid_sock_fd(socket_file_descriptor_ = socket(AF_INET, SOCK_STREAM, 0))) { + logger_->log_error("error while connecting to server socket"); + return -1; + } + + setSocketOptions(socket_file_descriptor_); + + if (listeners_ > 0) { + // server socket + sockaddr_in sa{}; + memset(&sa, 0, sizeof(struct sockaddr_in)); + sa.sin_family = AF_INET; + sa.sin_port = htons(port_); + sa.sin_addr.s_addr = htonl(is_loopback_only_ ? INADDR_LOOPBACK : INADDR_ANY); + if (bind(socket_file_descriptor_, reinterpret_cast<const sockaddr*>(&sa), sizeof(struct sockaddr_in)) == SOCKET_ERROR) { + logger_->log_error("Could not bind to socket, reason %s", get_last_err_str()); + return -1; + } + + if (listen(socket_file_descriptor_, listeners_) == -1) { + return -1; + } + logger_->log_debug("Created connection with %d listeners", listeners_); + } else { + // client socket +#ifndef WIN32 + if (!local_network_interface_.getInterface().empty()) { + const auto err = bind_to_local_network_interface(socket_file_descriptor_, local_network_interface_); + if (err) logger_->log_info("Bind to interface %s failed %s", local_network_interface_.getInterface(), err.message()); + else logger_->log_info("Bind to interface %s", local_network_interface_.getInterface()); + } +#endif /* !WIN32 */ + sockaddr_in sa_loc{}; + memset(&sa_loc, 0x00, sizeof(sa_loc)); + sa_loc.sin_family = AF_INET; + sa_loc.sin_port = htons(port_); + // use any address if you are connecting to the local machine for testing + // otherwise we must use the requested hostname + if (IsNullOrEmpty(requested_hostname_) || requested_hostname_ == "localhost") { + sa_loc.sin_addr.s_addr = htonl(is_loopback_only_ ? INADDR_LOOPBACK : INADDR_ANY); + } else { +#ifdef WIN32 + sa_loc.sin_addr.s_addr = addr.s_addr; + } + if (connect(socket_file_descriptor_, reinterpret_cast<const sockaddr*>(&sa_loc), sizeof(sockaddr_in)) == SOCKET_ERROR) { + int err = WSAGetLastError(); + if (err == WSAEADDRNOTAVAIL) { + logger_->log_error("invalid or unknown IP"); + } else if (err == WSAECONNREFUSED) { + logger_->log_error("Connection refused"); + } else { + logger_->log_error("Unknown error"); + } +#else + sa_loc.sin_addr.s_addr = addr; + } + if (connect(socket_file_descriptor_, reinterpret_cast<const sockaddr *>(&sa_loc), sizeof(sockaddr_in)) < 0) { +#endif /* WIN32 */ + closeStream(); + return -1; + } + } + + // add the listener to the total set + FD_SET(socket_file_descriptor_, &total_list_); + socket_max_ = socket_file_descriptor_; + logger_->log_debug("Created connection with file descriptor %d", socket_file_descriptor_); + return 0; +} + +int16_t Socket::initialize() { + addrinfo hints{}; + memset(&hints, 0, sizeof hints); // make sure the struct is empty + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_CANONNAME; + if (listeners_ > 0 && !is_loopback_only_) + hints.ai_flags = AI_PASSIVE; + hints.ai_protocol = 0; /* any protocol */ + + const char* const gai_node = [this]() -> const char* { + if (is_loopback_only_) return "localhost"; + if (!is_loopback_only_ && listeners_ > 0) return nullptr; // all non-localhost server sockets listen on wildcard address + if (!requested_hostname_.empty()) return requested_hostname_.c_str(); + return nullptr; + }(); + const auto gai_service = std::to_string(port_); + addrinfo* getaddrinfo_result = nullptr; + const int errcode = getaddrinfo(gai_node, gai_service.c_str(), &hints, &getaddrinfo_result); + const std::unique_ptr<addrinfo, util::addrinfo_deleter> addr_info{ getaddrinfo_result }; + getaddrinfo_result = nullptr; + if (errcode != 0) { + logger_->log_error("getaddrinfo: %s", get_last_getaddrinfo_err_str(errcode)); + return -1; + } + socket_file_descriptor_ = INVALID_SOCKET; + + // AI_CANONNAME always sets ai_canonname of the first addrinfo structure + canonical_hostname_ = !IsNullOrEmpty(addr_info->ai_canonname) ? addr_info->ai_canonname : requested_hostname_; Review comment: Behavior change: In the old code, we did a name lookup for the canonical name even for server sockets. This is no longer the case since now there is only one name lookup instead of two and we specify a wildcard node to `getaddrinfo` so that we get a result with wildcard addresses back. This means that the name resolution is no longer performed and we no longer set the `canonical_hostname_` member. As a stopgap, I'm assigning `requested_hostname_` to `canonical_hostname_`, but if this change in behavior is not acceptable, then we need to do a second name lookup in those cases anyway. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
