trevor211 commented on code in PR #1836:
URL: https://github.com/apache/incubator-brpc/pull/1836#discussion_r977579941

##########
src/brpc/rdma/rdma_endpoint.cpp:
##########
@@ -0,0 +1,1467 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#if BRPC_WITH_RDMA
+
+#include <gflags/gflags.h>
+#include "butil/fd_utility.h"
+#include "butil/logging.h" // CHECK, LOG
+#include "butil/sys_byteorder.h" // HostToNet,NetToHost
+#include "bthread/bthread.h"
+#include "brpc/errno.pb.h"
+#include "brpc/event_dispatcher.h"
+#include "brpc/input_messenger.h"
+#include "brpc/socket.h"
+#include "brpc/reloadable_flags.h"
+#include "brpc/rdma/block_pool.h"
+#include "brpc/rdma/rdma_helper.h"
+#include "brpc/rdma/rdma_endpoint.h"
+
+
+namespace brpc {
+namespace rdma {
+
+extern ibv_cq* (*IbvCreateCq)(ibv_context*, int, void*, ibv_comp_channel*, int);
+extern int (*IbvDestroyCq)(ibv_cq*);
+extern ibv_comp_channel* (*IbvCreateCompChannel)(ibv_context*);
+extern int (*IbvDestroyCompChannel)(ibv_comp_channel*);
+extern int (*IbvGetCqEvent)(ibv_comp_channel*, ibv_cq**, void**);
+extern void (*IbvAckCqEvents)(ibv_cq*, unsigned int);
+extern ibv_qp* (*IbvCreateQp)(ibv_pd*, ibv_qp_init_attr*);
+extern int (*IbvModifyQp)(ibv_qp*, ibv_qp_attr*, ibv_qp_attr_mask);
+extern int (*IbvQueryQp)(ibv_qp*, ibv_qp_attr*, ibv_qp_attr_mask, ibv_qp_init_attr*);
+extern int (*IbvDestroyQp)(ibv_qp*);
+extern bool g_skip_rdma_init;
+
+DEFINE_int32(rdma_sq_size, 128, "SQ size for RDMA");
+DEFINE_int32(rdma_rq_size, 128, "RQ size for RDMA");
+DEFINE_bool(rdma_recv_zerocopy, true, "Enable zerocopy for receive side");
+DEFINE_int32(rdma_zerocopy_min_size, 512, "The minimal size for receive zerocopy");
+DEFINE_string(rdma_recv_block_type, "default", "Default size type for recv WR: "
+              "default(8KB - 32B)/large(64KB - 32B)/huge(2MB - 32B)");
+DEFINE_int32(rdma_cqe_poll_once, 32, "The maximum of cqe number polled once.");
+DEFINE_int32(rdma_prepared_qp_size, 128, "SQ and RQ size for prepared QP.");
+DEFINE_int32(rdma_prepared_qp_cnt, 1024, "Initial count of prepared QP.");
+DEFINE_bool(rdma_trace_verbose, false, "Print log message verbosely");
+BRPC_VALIDATE_GFLAG(rdma_trace_verbose, brpc::PassValidate);
+
+static const size_t IOBUF_BLOCK_HEADER_LEN = 32; // implementation-dependent
+static const size_t IOBUF_BLOCK_DEFAULT_PAYLOAD =
+    butil::IOBuf::DEFAULT_BLOCK_SIZE - IOBUF_BLOCK_HEADER_LEN;
+
+// DO NOT change this value unless you know the safe value!!!
+// This is the number of reserved WRs in SQ/RQ for pure ACK.
+static const size_t RESERVED_WR_NUM = 3; + +// magic string RDMA (4B) +// message length (2B) +// hello version (2B) +// impl version (2B): 0 means should use tcp +// block size (2B) +// sq size (2B) +// rq size (2B) +// GID (16B) +// QP number (4B) +static const char* MAGIC_STR = "RDMA"; +static const size_t MAGIC_STR_LEN = 4; +static const size_t HELLO_MSG_LEN_MIN = 38; +static const size_t HELLO_MSG_LEN_MAX = 4096; +static const size_t ACK_MSG_LEN = 4; +static uint16_t g_rdma_hello_msg_len = 38; // In Byte +static uint16_t g_rdma_hello_version = 1; +static uint16_t g_rdma_impl_version = 1; +static uint16_t g_rdma_recv_block_size = 0; + +static const uint32_t MAX_INLINE_DATA = 64; +static const uint8_t MAX_HOP_LIMIT = 16; +static const uint8_t TIMEOUT = 14; +static const uint8_t RETRY_CNT = 7; +static const uint16_t MIN_QP_SIZE = 16; +static const uint16_t MIN_BLOCK_SIZE = 1024; +static const uint32_t ACK_MSG_RDMA_OK = 0x1; + +static butil::Mutex* g_rdma_resource_mutex = NULL; +static RdmaResource* g_rdma_resource_list = NULL; + +struct HelloMessage { + void Serialize(void* data) const; + void Deserialize(void* data); + + uint16_t msg_len; + uint16_t hello_ver; + uint16_t impl_ver; + uint16_t block_size; + uint16_t sq_size; + uint16_t rq_size; + uint16_t lid; + ibv_gid gid; + uint32_t qp_num; +}; + +void HelloMessage::Serialize(void* data) const { + uint16_t* current_pos = (uint16_t*)data; + *(current_pos++) = butil::HostToNet16(msg_len); + *(current_pos++) = butil::HostToNet16(hello_ver); + *(current_pos++) = butil::HostToNet16(impl_ver); + *(current_pos++) = butil::HostToNet16(block_size); + *(current_pos++) = butil::HostToNet16(sq_size); + *(current_pos++) = butil::HostToNet16(rq_size); + *(current_pos++) = butil::HostToNet16(lid); + memcpy(current_pos, gid.raw, 16); + uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16); + *qp_num_pos = butil::HostToNet32(qp_num); +} + +void HelloMessage::Deserialize(void* data) { + uint16_t* current_pos = (uint16_t*)data; + msg_len = butil::NetToHost16(*current_pos++); + hello_ver = butil::NetToHost16(*current_pos++); + impl_ver = butil::NetToHost16(*current_pos++); + block_size = butil::NetToHost16(*current_pos++); + sq_size = butil::NetToHost16(*current_pos++); + rq_size = butil::NetToHost16(*current_pos++); + lid = butil::NetToHost16(*current_pos++); + memcpy(gid.raw, current_pos, 16); + qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16)); +} + +RdmaResource::RdmaResource() + : qp(NULL) + , cq(NULL) + , comp_channel(NULL) + , next(NULL) { } + +RdmaResource::~RdmaResource() { + if (qp) { + IbvDestroyQp(qp); + qp = NULL; + } + if (cq) { + IbvDestroyCq(cq); + cq = NULL; + } + if (comp_channel) { + IbvDestroyCompChannel(comp_channel); + comp_channel = NULL; + } +} + +RdmaEndpoint::RdmaEndpoint(Socket* s) + : _socket(s) + , _state(UNINIT) + , _resource(NULL) + , _cq_events(0) + , _cq_sid(INVALID_SOCKET_ID) + , _sq_size(FLAGS_rdma_sq_size) + , _rq_size(FLAGS_rdma_rq_size) + , _sbuf() + , _rbuf() + , _rbuf_data() + , _remote_recv_block_size(0) + , _accumulated_ack(0) + , _unsolicited(0) + , _unsolicited_bytes(0) + , _sq_current(0) + , _sq_unsignaled(0) + , _sq_sent(0) + , _rq_received(0) + , _local_window_capacity(0) + , _remote_window_capacity(0) + , _window_size(0) + , _new_rq_wrs(0) +{ + if (_sq_size < MIN_QP_SIZE) { + _sq_size = MIN_QP_SIZE; + } + if (_rq_size < MIN_QP_SIZE) { + _rq_size = MIN_QP_SIZE; + } + _read_butex = bthread::butex_create_checked<butil::atomic<int> >(); +} + +RdmaEndpoint::~RdmaEndpoint() { + Reset(); + 
bthread::butex_destroy(_read_butex); +} + +void RdmaEndpoint::Reset() { + DeallocateResources(); + + _cq_events = 0; + _cq_sid = INVALID_SOCKET_ID; + _state = UNINIT; + _sbuf.clear(); + _rbuf.clear(); + _rbuf_data.clear(); + _accumulated_ack = 0; + _unsolicited = 0; + _sq_current = 0; + _sq_unsignaled = 0; + _local_window_capacity = 0; + _remote_window_capacity = 0; + _window_size.store(0, butil::memory_order_relaxed); + _new_rq_wrs = 0; + _sq_sent = 0; + _rq_received = 0; +} + +void RdmaConnect::StartConnect(const Socket* socket, + void (*done)(int err, void* data), + void* data) { + CHECK(socket->_rdma_ep != NULL); + SocketUniquePtr s; + if (Socket::Address(socket->id(), &s) != 0) { + return; + } + if (!IsRdmaAvailable()) { + socket->_rdma_ep->_state = RdmaEndpoint::FALLBACK_TCP; + s->_rdma_state = Socket::RDMA_OFF; + done(0, data); + return; + } + _done = done; + _data = data; + bthread_t tid; + if (bthread_start_background(&tid, &BTHREAD_ATTR_NORMAL, + RdmaEndpoint::ProcessHandshakeAtClient, socket->_rdma_ep) < 0) { + LOG(FATAL) << "Fail to start handshake bthread"; + } else { + s.release(); + } +} + +void RdmaConnect::StopConnect(Socket* socket) { } + +void RdmaConnect::Run() { + _done(errno, _data); +} + +static void TryReadOnTcpDuringRdmaEst(Socket* s) { + int progress = Socket::PROGRESS_INIT; + while (true) { + uint8_t tmp; + ssize_t nr = read(s->fd(), &tmp, 1); + if (nr < 0) { + if (errno != EAGAIN) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to read from " << s; + s->SetFailed(saved_errno, "Fail to read from %s: %s", + s->description().c_str(), berror(saved_errno)); + return; + } + if (!s->MoreReadEvents(&progress)) { + break; + } + } else if (nr == 0) { + s->SetEOF(); + return; + } else { + LOG(WARNING) << "Read unexpected data from " << s; + s->SetFailed(EPROTO, "Read unexpected data from %s", + s->description().c_str()); + return; + } + } +} + +void RdmaEndpoint::OnNewDataFromTcp(Socket* m) { + RdmaEndpoint* ep = m->_rdma_ep; + CHECK(ep != NULL); + + int progress = Socket::PROGRESS_INIT; + while (true) { + if (ep->_state == UNINIT) { + if (!m->CreatedByConnect()) { + if (!IsRdmaAvailable()) { + ep->_state = FALLBACK_TCP; + m->_rdma_state = Socket::RDMA_OFF; + continue; + } + bthread_t tid; + ep->_state = S_HELLO_WAIT; + SocketUniquePtr s; + m->ReAddress(&s); + if (bthread_start_background(&tid, &BTHREAD_ATTR_NORMAL, + ProcessHandshakeAtServer, ep) < 0) { + ep->_state = UNINIT; + LOG(FATAL) << "Fail to start handshake bthread"; + } else { + s.release(); + } + } else { + // The connection may be closed or reset before the client + // starts handshake. This will be handled by client handshake. + // Ignore the exception here. 
+ } + } else if (ep->_state < ESTABLISHED) { // during handshake + ep->_read_butex->fetch_add(1, butil::memory_order_release); + bthread::butex_wake(ep->_read_butex); + } else if (ep->_state == FALLBACK_TCP){ // handshake finishes + InputMessenger::OnNewMessages(m); + return; + } else if (ep->_state == ESTABLISHED) { + TryReadOnTcpDuringRdmaEst(ep->_socket); + return; + } + if (!m->MoreReadEvents(&progress)) { + break; + } + } +} + +bool HelloNegotiationValid(HelloMessage& msg) { + if (msg.hello_ver == g_rdma_hello_version && + msg.impl_ver == g_rdma_impl_version && + msg.block_size >= MIN_BLOCK_SIZE && + msg.sq_size >= MIN_QP_SIZE && + msg.rq_size >= MIN_QP_SIZE) { + // This can be modified for future compatibility + return true; + } + return false; +} + +static const int WAIT_TIMEOUT_MS = 50; + +int RdmaEndpoint::ReadFromFd(void* data, size_t len) { + CHECK(data != NULL); + int nr = 0; + size_t received = 0; + do { + const int expected_val = _read_butex->load(butil::memory_order_acquire); + const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); + nr = read(_socket->fd(), (uint8_t*)data + received, len - received); + if (nr < 0) { + if (errno == EAGAIN) { + if (bthread::butex_wait(_read_butex, expected_val, &duetime) < 0) { + if (errno != EWOULDBLOCK && errno != ETIMEDOUT) { + return -1; + } + } + } else { + return -1; + } + } else if (nr == 0) { // Got EOF + errno = EEOF; + return -1; + } else { + received += nr; + } + } while (received < len); + return 0; +} + +int RdmaEndpoint::WriteToFd(void* data, size_t len) { + CHECK(data != NULL); + int nw = 0; + size_t written = 0; + do { + const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); + nw = write(_socket->fd(), (uint8_t*)data + written, len - written); + if (nw < 0) { + if (errno == EAGAIN) { + if (_socket->WaitEpollOut(_socket->fd(), true, &duetime) < 0) { + if (errno != ETIMEDOUT) { + return -1; + } + } + } else { + return -1; + } + } else { + written += nw; + } + } while (written < len); + return 0; +} + +inline void RdmaEndpoint::TryReadOnTcp() { + if (_socket->_nevent.fetch_add(1, butil::memory_order_acq_rel) == 0) { + if (_state == FALLBACK_TCP) { + InputMessenger::OnNewMessages(_socket); + } else if (_state == ESTABLISHED) { + TryReadOnTcpDuringRdmaEst(_socket); + } + } +} + +void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { + RdmaEndpoint* ep = static_cast<RdmaEndpoint*>(arg); + SocketUniquePtr s(ep->_socket); + RdmaConnect::RunGuard rg((RdmaConnect*)s->_app_connect.get()); + + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Start handshake on " << s->_local_side; + + void* data = malloc(g_rdma_hello_msg_len); + if (!data) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to send hello message to server:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + // First initialize CQ and QP resources + ep->_state = C_ALLOC_QPCQ; + if (ep->AllocateResources() < 0) { + LOG(WARNING) << "Fallback to tcp:" << s->description(); + s->_rdma_state = Socket::RDMA_OFF; + ep->_state = FALLBACK_TCP; + return NULL; + } + + // Send hello message to server + ep->_state = C_HELLO_SEND; + HelloMessage local_msg; + local_msg.msg_len = g_rdma_hello_msg_len; + local_msg.hello_ver = g_rdma_hello_version; + local_msg.impl_ver = g_rdma_impl_version; + local_msg.block_size = g_rdma_recv_block_size; + local_msg.sq_size = ep->_sq_size; + local_msg.rq_size = ep->_rq_size; + 
local_msg.lid = GetRdmaLid(); + local_msg.gid = GetRdmaGid(); + if (BAIDU_LIKELY(ep->_resource)) { + local_msg.qp_num = ep->_resource->qp->qp_num; + } else { + // Only happens in UT + local_msg.qp_num = 0; + } + memcpy(data, MAGIC_STR, 4); + local_msg.Serialize((char*)data + 4); + if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to send hello message to server:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + // Check magic str + ep->_state = C_HELLO_WAIT; + if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to get hello message from server:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { + LOG(WARNING) << "Read unexpected data during handshake:" << s->description(); + s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(EPROTO)); + ep->_state = FAILED; + return NULL; + } + + // Read hello message from server + if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to get Hello Message from server:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + HelloMessage remote_msg; + remote_msg.Deserialize(data); + if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { + LOG(WARNING) << "Fail to parse Hello Message length from server:" + << s->description(); + s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(EPROTO)); + ep->_state = FAILED; + return NULL; + } + + if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { + // TODO: Read Hello Message customized data + // Just for future use, should not happen now + } + + if (!HelloNegotiationValid(remote_msg)) { + LOG(WARNING) << "Fail to negotiate with server, fallback to tcp:" + << s->description(); + s->_rdma_state = Socket::RDMA_OFF; + } else { + ep->_remote_recv_block_size = remote_msg.block_size; + ep->_local_window_capacity = + std::min(ep->_sq_size, remote_msg.rq_size) - RESERVED_WR_NUM; + ep->_remote_window_capacity = + std::min(ep->_rq_size, remote_msg.sq_size) - RESERVED_WR_NUM, + ep->_window_size.store(ep->_local_window_capacity, butil::memory_order_relaxed); + + ep->_state = C_BRINGUP_QP; + if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { + LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); + s->_rdma_state = Socket::RDMA_OFF; + } else { + s->_rdma_state = Socket::RDMA_ON; + } + } + + // Send ACK message to server + ep->_state = C_ACK_SEND; + uint32_t flags = 0; + if (s->_rdma_state != Socket::RDMA_OFF) { + flags |= ACK_MSG_RDMA_OK; + } + *(uint32_t*)data = butil::HostToNet32(flags); + if (ep->WriteToFd(data, ACK_MSG_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to send Ack Message to server:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + if (s->_rdma_state == Socket::RDMA_ON) { + 
ep->_state = ESTABLISHED; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Handshake ends (use rdma) on " << s->description(); + } else { + ep->_state = FALLBACK_TCP; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Handshake ends (use tcp) on " << s->description(); + } + + errno = 0; + + return NULL; +} + +void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { + RdmaEndpoint* ep = static_cast<RdmaEndpoint*>(arg); + SocketUniquePtr s(ep->_socket); + + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Start handshake on " << s->description(); + + void* data = malloc(g_rdma_hello_msg_len); + if (!data) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to recv hello message from client:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + ep->_state = S_HELLO_WAIT; + if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description() << " " << s->_remote_side; + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "It seems that the " + << "client does not use RDMA, fallback to TCP:" + << s->description(); + // we need to copy data read back to _socket->_read_buf + s->_read_buf.append(data, MAGIC_STR_LEN); + ep->_state = FALLBACK_TCP; + s->_rdma_state = Socket::RDMA_OFF; + ep->TryReadOnTcp(); + return NULL; + } + + if (ep->ReadFromFd(data, g_rdma_hello_msg_len - MAGIC_STR_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + HelloMessage remote_msg; + remote_msg.Deserialize(data); + if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { + LOG(WARNING) << "Fail to parse Hello Message length from client:" + << s->description(); + s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(EPROTO)); + ep->_state = FAILED; + return NULL; + } + if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { + // TODO: Read Hello Message customized header + // Just for future use, should not happen now + } + + if (!HelloNegotiationValid(remote_msg)) { + LOG(WARNING) << "Fail to negotiate with client, fallback to tcp:" + << s->description(); + s->_rdma_state = Socket::RDMA_OFF; + } else { + ep->_remote_recv_block_size = remote_msg.block_size; + ep->_local_window_capacity = + std::min(ep->_sq_size, remote_msg.rq_size) - RESERVED_WR_NUM; + ep->_remote_window_capacity = + std::min(ep->_rq_size, remote_msg.sq_size) - RESERVED_WR_NUM, + ep->_window_size.store(ep->_local_window_capacity, butil::memory_order_relaxed); + + ep->_state = S_ALLOC_QPCQ; + if (ep->AllocateResources() < 0) { + LOG(WARNING) << "Fail to allocate rdma resources, fallback to tcp:" + << s->description(); + s->_rdma_state = Socket::RDMA_OFF; + } else { + ep->_state = S_BRINGUP_QP; + if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { + LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" + << s->description(); + s->_rdma_state = Socket::RDMA_OFF; + } + } + } + + // Send hello message to client + ep->_state = 
S_HELLO_SEND; + HelloMessage local_msg; + local_msg.msg_len = g_rdma_hello_msg_len; + if (s->_rdma_state == Socket::RDMA_OFF) { + local_msg.impl_ver = 0; + local_msg.hello_ver = 0; + } else { + local_msg.lid = GetRdmaLid(); + local_msg.gid = GetRdmaGid(); + local_msg.block_size = g_rdma_recv_block_size; + local_msg.sq_size = ep->_sq_size; + local_msg.rq_size = ep->_rq_size; + local_msg.hello_ver = g_rdma_hello_version; + local_msg.impl_ver = g_rdma_impl_version; + if (BAIDU_LIKELY(ep->_resource)) { + local_msg.qp_num = ep->_resource->qp->qp_num; + } else { + // Only happens in UT + local_msg.qp_num = 0; + } + } + memcpy(data, MAGIC_STR, 4); + local_msg.Serialize((char*)data + 4); + if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to send Hello Message to client:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + // Recv ACK Message + ep->_state = S_ACK_WAIT; + if (ep->ReadFromFd(data, ACK_MSG_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to read ack message from client:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + // Check RDMA enable flag + uint32_t flags = butil::NetToHost32(*(uint32_t*)data); + if (flags & ACK_MSG_RDMA_OK) { + if (s->_rdma_state == Socket::RDMA_OFF) { + LOG(WARNING) << "Fail to parse Hello Message length from client:" + << s->description(); + s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", + s->description().c_str(), berror(EPROTO)); + ep->_state = FAILED; + return NULL; + } else { + s->_rdma_state = Socket::RDMA_ON; + ep->_state = ESTABLISHED; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Handshake ends (use rdma) on " << s->description(); + } + } else { + s->_rdma_state = Socket::RDMA_OFF; + ep->_state = FALLBACK_TCP; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Handshake ends (use tcp) on " << s->description(); + } + + ep->TryReadOnTcp(); + + return NULL; +} + +bool RdmaEndpoint::IsWritable() const { + if (BAIDU_UNLIKELY(g_skip_rdma_init)) { + // Just for UT + return false; + } + + return _window_size.load(butil::memory_order_relaxed) > 0; +} + +// RdmaIOBuf inherits from IOBuf to provide a new function. +// The reason is that we need to use some protected member function of IOBuf. +class RdmaIOBuf : public butil::IOBuf { +friend class RdmaEndpoint; +private: + // Cut the current IOBuf to ibv_sge list and `to' for at most first max_sge + // blocks or first max_len bytes. + // Return: the bytes included in the sglist, or -1 if failed + ssize_t cut_into_sglist_and_iobuf(ibv_sge* sglist, size_t* sge_index, + butil::IOBuf* to, size_t max_sge, size_t max_len) { + size_t len = 0; + while (*sge_index < max_sge) { + if (len == max_len || _ref_num() == 0) { + break; + } + butil::IOBuf::BlockRef const& r = _ref_at(0); + CHECK(r.length > 0); + const void* start = fetch1(); + uint32_t lkey = GetLKey((char*)start - r.offset); + if (lkey == 0) { + LOG(WARNING) << "Memory not registered for rdma. " + << "Is this iobuf allocated before calling " + << "GlobalRdmaInitializeOrDie? 
Or just forget to " + << "call RegisterMemoryForRdma for your own buffer?"; + errno = ERDMAMEM; + return -1; + } + size_t i = *sge_index; + if (len + r.length > max_len) { + // Split the block to comply with size for receiving + sglist[i].length = max_len - len; + len = max_len; + } else { + sglist[i].length = r.length; + len += r.length; + } + sglist[i].addr = (uint64_t)start; + sglist[i].lkey = lkey; + cutn(to, sglist[i].length); + (*sge_index)++; + } + return len; + } +}; + +// Note this function is coupled with the implementation of IOBuf +ssize_t RdmaEndpoint::CutFromIOBufList(butil::IOBuf** from, size_t ndata) { + if (BAIDU_UNLIKELY(g_skip_rdma_init)) { + // Just for UT + errno = EAGAIN; + return -1; + } + + CHECK(from != NULL); + CHECK(ndata > 0); + + size_t total_len = 0; + size_t current = 0; + uint32_t window = 0; + ibv_send_wr wr; + int max_sge = GetRdmaMaxSge(); + ibv_sge sglist[max_sge]; + while (current < ndata) { + window = _window_size.load(butil::memory_order_relaxed); + if (window == 0) { + if (total_len > 0) { + break; + } else { + errno = EAGAIN; + return -1; + } + } + butil::IOBuf* to = &_sbuf[_sq_current]; + size_t this_len = 0; + + memset(&wr, 0, sizeof(wr)); + wr.sg_list = sglist; + wr.opcode = IBV_WR_SEND_WITH_IMM; + + RdmaIOBuf* data = (RdmaIOBuf*)from[current]; + size_t sge_index = 0; + while (sge_index < (uint32_t)max_sge && + this_len < _remote_recv_block_size) { + if (data->size() == 0) { + // The current IOBuf is empty, find next one + ++current; + if (current == ndata) { + break; + } + data = (RdmaIOBuf*)from[current]; + continue; + } + + ssize_t len = data->cut_into_sglist_and_iobuf( + sglist, &sge_index, to, max_sge, + _remote_recv_block_size - this_len); + if (len < 0) { + return -1; + } + CHECK(len > 0); + this_len += len; + total_len += len; + } + if (this_len == 0) { + continue; + } + + wr.num_sge = sge_index; + + uint32_t imm = _new_rq_wrs.exchange(0, butil::memory_order_relaxed); + wr.imm_data = butil::HostToNet32(imm); + // Avoid too much recv completion event to reduce the cpu overhead + bool solicited = false; + if (window == 1 || current + 1 >= ndata) { + // Only last message in the write queue or last message in the + // current window will be flagged as solicited. + solicited = true; + } else { + if (_unsolicited > _local_window_capacity / 4) { + // Make sure the recv side can be signaled to return ack + solicited = true; + } else if (_accumulated_ack > _remote_window_capacity / 4) { + // Make sure the recv side can be signaled to handle ack + solicited = true; + } else if (_unsolicited_bytes > 1048576) { + // Make sure the recv side can be signaled when it receives enough data + solicited = true; + } else { + ++_unsolicited; + _unsolicited_bytes += this_len; + _accumulated_ack += imm; + } + } + if (solicited) { + wr.send_flags |= IBV_SEND_SOLICITED; + _unsolicited = 0; + _unsolicited_bytes = 0; + _accumulated_ack = 0; + } + + // Avoid too much send completion event to reduce the CPU overhead + ++_sq_unsignaled; + if (_sq_unsignaled >= _local_window_capacity / 4) { + // Refer to: + // http::www.rdmamojo.com/2014/06/30/working-unsignaled-completions/ + wr.send_flags |= IBV_SEND_SIGNALED; + _sq_unsignaled = 0; + } + + ibv_send_wr* bad = NULL; + if (ibv_post_send(_resource->qp, &wr, &bad) < 0) { + // We use other way to guarantee the Send Queue is not full. + // So we just consider this error as an unrecoverable error. 
+            PLOG(WARNING) << "Fail to ibv_post_send";
+            return -1;
+        }
+
+        ++_sq_current;
+        if (_sq_current == _sq_size - RESERVED_WR_NUM) {
+            _sq_current = 0;
+        }
+
+        // Update _window_size. Note that _window_size will never be negative.
+        // Because there is at most one thread can enter this function for each
+        // Socket, and the other thread of HandleCompletion can only add this
+        // counter.
+        _window_size.fetch_sub(1, butil::memory_order_relaxed);
+    }
+
+    return total_len;
+}
+
+int RdmaEndpoint::SendAck(int num) {
+    if (_new_rq_wrs.fetch_add(num, butil::memory_order_relaxed) > _remote_window_capacity / 2) {
+        return SendImm(_new_rq_wrs.exchange(0, butil::memory_order_relaxed));
+    }
+    return 0;
+}
+
+int RdmaEndpoint::SendImm(uint32_t imm) {
+    if (imm == 0) {
+        return 0;
+    }
+
+    ibv_send_wr wr;
+    memset(&wr, 0, sizeof(wr));
+    wr.opcode = IBV_WR_SEND_WITH_IMM;
+    wr.imm_data = butil::HostToNet32(imm);
+    wr.send_flags |= IBV_SEND_SOLICITED;
+    wr.send_flags |= IBV_SEND_SIGNALED;
+
+    ibv_send_wr* bad = NULL;
+    if (ibv_post_send(_resource->qp, &wr, &bad) < 0) {
+        // We use other way to guarantee the Send Queue is not full.
+        // So we just consider this error as an unrecoverable error.
+        PLOG(WARNING) << "Fail to ibv_post_send";
+        return -1;
+    }
+    return 0;
+}
+
+ssize_t RdmaEndpoint::HandleCompletion(ibv_wc& wc) {
+    bool zerocopy = FLAGS_rdma_recv_zerocopy;
+    switch (wc.opcode) {
+    case IBV_WC_SEND: { // send completion
+        // Do nothing
+        break;
+    }
+    case IBV_WC_RECV: { // recv completion
+        // Please note that only the first wc.byte_len bytes is valid
+        if (wc.byte_len > 0) {
+            if (wc.byte_len < (uint32_t)FLAGS_rdma_zerocopy_min_size) {
+                zerocopy = false;
+            }
+            CHECK(_state != FALLBACK_TCP);
+            if (zerocopy) {
+                butil::IOBuf tmp;
+                _rbuf[_rq_received].cutn(&tmp, wc.byte_len);
+                _socket->_read_buf.append(tmp);
+            } else {
+                // Copy data when the receive data is really small
+                _socket->_read_buf.append(_rbuf_data[_rq_received], wc.byte_len);
+            }
+        }
+        if (wc.imm_data > 0) {
+            // Clear sbuf here because we ignore event wakeup for send completions
+            uint32_t acks = butil::NetToHost32(wc.imm_data);
+            uint32_t num = acks;
+            while (num > 0) {
+                _sbuf[_sq_sent++].clear();
+                if (_sq_sent == _sq_size - RESERVED_WR_NUM) {
+                    _sq_sent = 0;
+                }
+                --num;
+            }
+            butil::subtle::MemoryBarrier();
+
+            // Update window
+            uint32_t wnd_thresh = _local_window_capacity / 8;
+            if (_window_size.fetch_add(acks, butil::memory_order_relaxed) >= wnd_thresh
+                    || acks >= wnd_thresh) {
+                // Do not wake up writing thread right after _window_size > 0.
+                // Otherwise the writing thread may switch to background too quickly.
+                _socket->WakeAsEpollOut();
+            }
+        }
+        // We must re-post recv WR
+        if (PostRecv(1, zerocopy) < 0) {
+            return -1;
+        }
+        if (wc.byte_len > 0) {
+            SendAck(1);
+        }
+        return wc.byte_len;
+    }
+    default:
+        CHECK(false) << "This should not happen";

Review Comment:
   ```suggestion
           // Do not CHECK abort since wc.opcode could be IBV_WC_DRIVER2(136) which is
           // a bug already fixed by rdma-core.
           // FYI: https://github.com/linux-rdma/rdma-core/commit/4c905646de3e75bdccada4abe9f0d273d76eaf50
           LOG(WARNING) << "This should not happen, wc.opcode = " << wc.opcode;
   ```
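   For readers skimming the archive, the following standalone sketch illustrates the pattern the suggestion asks for: log an unexpected work-completion opcode, such as IBV_WC_DRIVER2 (value 136) emitted by older rdma-core releases, and keep going instead of aborting via CHECK. This is not brpc code: `FakeWc`, `HandleCompletionSketch`, the stand-in opcode values in the `case` labels, and the plain `std::cerr` logging are illustrative assumptions, not part of the PR.

   ```cpp
   // Standalone illustration (not brpc code) of log-and-continue for an
   // unexpected completion opcode.
   #include <cstdint>
   #include <iostream>

   // Minimal stand-in for ibv_wc used only by this sketch.
   struct FakeWc {
       int opcode;
       uint32_t byte_len;
   };

   // Returns the bytes consumed for completions we understand, 0 otherwise.
   long HandleCompletionSketch(const FakeWc& wc) {
       switch (wc.opcode) {
       case 0:    // stand-in for IBV_WC_SEND: nothing to do
           return 0;
       case 128:  // stand-in for IBV_WC_RECV: consume the received bytes
           return wc.byte_len;
       default:
           // Log-and-continue instead of CHECK(false): an opcode such as
           // IBV_WC_DRIVER2(136) from a buggy rdma-core must not kill the process.
           std::cerr << "This should not happen, wc.opcode = " << wc.opcode << std::endl;
           return 0;
       }
   }

   int main() {
       FakeWc wc = {136, 0};  // simulate the unexpected driver-internal opcode
       std::cout << HandleCompletionSketch(wc) << std::endl;  // prints 0; no abort
       return 0;
   }
   ```

   The rationale for the log-and-continue choice: HandleCompletion runs on the CQ polling path shared by every RDMA connection in the process, so a fatal CHECK on an opcode the application does not control would take the whole server down, while a warning preserves the evidence in the log and lets traffic proceed.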
-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
