lidavidm commented on a change in pull request #12442:
URL: https://github.com/apache/arrow/pull/12442#discussion_r841089057



##########
File path: cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
##########
@@ -0,0 +1,1164 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include <array>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/types.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+// Defines to test different implementation strategies
+// Enable the CONTIG path for CPU-only data
+// #define ARROW_FLIGHT_UCX_SEND_CONTIG
+// Enable ucp_mem_map in IOV path
+// #define ARROW_FLIGHT_UCX_SEND_IOV_MAP
+
+constexpr char kHeaderMethod[] = ":method:";
+
+namespace {
+/// \brief Encode a length via UInt32ToBytesBe after range-checking it.
+///
+/// \param[in] in The length to encode
+/// \param[out] out Destination handed to UInt32ToBytesBe
+/// \return Invalid if the length is negative or does not fit in a uint32_t
+Status SizeToUInt32BytesBe(const int64_t in, uint8_t* out) {
+  constexpr int64_t kMaxLength =
+      static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
+  if (ARROW_PREDICT_FALSE(in < 0)) {
+    return Status::Invalid("Length cannot be negative");
+  }
+  if (ARROW_PREDICT_FALSE(in > kMaxLength)) {
+    return Status::Invalid("Length cannot exceed uint32_t");
+  }
+  UInt32ToBytesBe(static_cast<uint32_t>(in), out);
+  return Status::OK();
+}
+/// \brief Pick the UCX memory type to report for a buffer.
+///
+/// Every non-CPU buffer is reported as CUDA memory; CPU buffers are
+/// reported as UNKNOWN.
+ucs_memory_type InferMemoryType(const Buffer& buffer) {
+  return buffer.is_cpu() ? UCS_MEMORY_TYPE_UNKNOWN : UCS_MEMORY_TYPE_CUDA;
+}
+/// \brief Best-effort ucp_mem_map of a raw memory region.
+///
+/// \param[in] context The UCX context to map into
+/// \param[in] buffer Start of the region
+/// \param[in] size Length of the region in bytes
+/// \param[in] memory_type Memory type passed to UCX
+/// \param[out] memh_p Receives the handle; set to nullptr (with a logged
+///   warning) if the mapping fails
+void TryMapBuffer(ucp_context_h context, const void* buffer, const size_t size,
+                  ucs_memory_type memory_type, ucp_mem_h* memh_p) {
+  ucp_mem_map_params_t params;
+  std::memset(&params, 0, sizeof(params));
+  params.address = const_cast<void*>(buffer);
+  params.length = size;
+  params.memory_type = memory_type;
+  // Tell UCX which of the fields above are populated
+  params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
+                      UCP_MEM_MAP_PARAM_FIELD_LENGTH |
+                      UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
+
+  const auto map_status = ucp_mem_map(context, &params, memh_p);
+  if (map_status == UCS_OK) return;
+
+  *memh_p = nullptr;
+  ARROW_LOG(WARNING) << "Could not map memory: "
+                     << FromUcsStatus("ucp_mem_map", map_status);
+}
+/// \brief Best-effort ucp_mem_map of an Arrow buffer, using the memory type
+///   inferred by InferMemoryType.
+void TryMapBuffer(ucp_context_h context, const Buffer& buffer, ucp_mem_h* memh_p) {
+  const void* region = reinterpret_cast<void*>(buffer.address());
+  const size_t region_size = static_cast<size_t>(buffer.size());
+  TryMapBuffer(context, region, region_size, InferMemoryType(buffer), memh_p);
+}
+/// \brief Best-effort ucp_mem_unmap; does nothing for a null handle and only
+///   logs a warning if unmapping fails.
+void TryUnmapBuffer(ucp_context_h context, ucp_mem_h memh_p) {
+  if (!memh_p) return;
+
+  const auto unmap_status = ucp_mem_unmap(context, memh_p);
+  if (unmap_status != UCS_OK) {
+    ARROW_LOG(WARNING) << "Could not unmap memory: "
+                       << FromUcsStatus("ucp_mem_unmap", unmap_status);
+  }
+}
+
+/// \brief Wrapper around a UCX zero copy buffer (a host memory DATA
+///   buffer).
+///
+/// Owns a reference to the associated worker to avoid undefined
+/// behavior. The data is handed back to UCX with ucp_am_data_release
+/// when the buffer is destroyed.
+class UcxDataBuffer : public Buffer {
+ public:
+  /// \param[in] worker The worker the data belongs to; kept alive for the
+  ///   lifetime of this buffer
+  /// \param[in] data Start of the data
+  /// \param[in] size Length of the data in bytes
+  explicit UcxDataBuffer(std::shared_ptr<UcpWorker> worker, void* data, size_t size)
+      : Buffer(static_cast<const uint8_t*>(data), static_cast<int64_t>(size)),
+        worker_(std::move(worker)) {}
+
+  // Marked override so the compiler verifies this replaces the base class
+  // destructor.
+  ~UcxDataBuffer() override {
+    // Buffer exposes the data as const; UCX wants the original mutable pointer
+    ucp_am_data_release(worker_->get(), const_cast<uint8_t*>(data()));
+  }
+
+ private:
+  std::shared_ptr<UcpWorker> worker_;
+};
+};  // namespace
+
+constexpr size_t FrameHeader::kFrameHeaderBytes;
+constexpr uint8_t FrameHeader::kFrameVersion;
+
+/// \brief Fill in the frame header.
+///
+/// Writes the version to byte 0, the frame type to byte 1, the counter at
+/// offset 4 and the body size at offset 8. Bytes 2-3 are not written here.
+///
+/// \return Whatever SizeToUInt32BytesBe reports for body_size
+Status FrameHeader::Set(FrameType frame_type, uint32_t counter, int64_t body_size) {
+  auto* const bytes = header.data();
+  bytes[0] = kFrameVersion;
+  bytes[1] = static_cast<uint8_t>(frame_type);
+  UInt32ToBytesBe(counter, bytes + 4);
+  RETURN_NOT_OK(SizeToUInt32BytesBe(body_size, bytes + 8));
+  return Status::OK();
+}
+
+/// \brief Validate a raw frame header and build a Frame (without a body)
+///   from it.
+///
+/// \param[in] header Pointer to the raw header bytes
+/// \param[in] header_length Number of bytes available at header
+/// \return IOError if the header is too short, has an unexpected version or
+///   an unknown frame type; Cancelled if the frame is a disconnect frame
+arrow::Result<std::shared_ptr<Frame>> Frame::ParseHeader(const void* header,
+                                                         size_t header_length) {
+  if (header_length < FrameHeader::kFrameHeaderBytes) {
+    return Status::IOError("Header is too short, must be at least ",
+                           FrameHeader::kFrameHeaderBytes, " bytes, got ", header_length);
+  }
+
+  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(header);
+  const uint8_t version = bytes[0];
+  const uint8_t raw_type = bytes[1];
+  if (version != FrameHeader::kFrameVersion) {
+    return Status::IOError("Expected frame version ",
+                           static_cast<int>(FrameHeader::kFrameVersion), " but got ",
+                           static_cast<int>(version));
+  }
+  if (raw_type > static_cast<uint8_t>(FrameType::kMaxFrameType)) {
+    return Status::IOError("Unknown frame type ", static_cast<int>(raw_type));
+  }
+
+  const FrameType frame_type = static_cast<FrameType>(raw_type);
+  if (frame_type == FrameType::kDisconnect) {
+    return Status::Cancelled("Client initiated disconnect");
+  }
+
+  // Counter lives at offset 4, body size at offset 8
+  const uint32_t frame_counter = BytesToUInt32Be(bytes + 4);
+  const uint32_t frame_size = BytesToUInt32Be(bytes + 8);
+  return std::make_shared<Frame>(frame_type, frame_size, frame_counter, nullptr);
+}
+
+/// \brief Parse a serialized headers frame.
+///
+/// Layout: a 4-byte header count, then per header a 4-byte key length, a
+/// 4-byte value length, the key bytes and the value bytes. The parsed keys
+/// and values are views into the buffer, which the returned frame takes
+/// ownership of.
+///
+/// \return Invalid if the buffer ends before the announced data
+arrow::Result<HeadersFrame> HeadersFrame::Parse(std::unique_ptr<Buffer> buffer) {
+  HeadersFrame result;
+  const uint8_t* cursor = buffer->data();
+  const uint8_t* const end = cursor + buffer->size();
+
+  // Number of bytes between the cursor and the end of the buffer
+  const auto remaining = [&cursor, end]() { return end - cursor; };
+  // Decode the uint32 under the cursor and step over it; callers must have
+  // checked that 4 bytes remain
+  const auto read_uint32 = [&cursor]() {
+    const uint32_t decoded = BytesToUInt32Be(cursor);
+    cursor += 4;
+    return decoded;
+  };
+
+  if (ARROW_PREDICT_FALSE(remaining() < 4)) {
+    return Status::Invalid("Buffer underflow, expected number of headers");
+  }
+  const uint32_t num_headers = read_uint32();
+
+  for (uint32_t i = 0; i < num_headers; i++) {
+    // 1-based position used in error messages
+    const uint32_t ordinal = i + 1;
+
+    if (ARROW_PREDICT_FALSE(remaining() < 4)) {
+      return Status::Invalid("Buffer underflow, expected length of key ", ordinal);
+    }
+    const uint32_t key_length = read_uint32();
+
+    if (ARROW_PREDICT_FALSE(remaining() < 4)) {
+      return Status::Invalid("Buffer underflow, expected length of value ", ordinal);
+    }
+    const uint32_t value_length = read_uint32();
+
+    if (ARROW_PREDICT_FALSE(remaining() < key_length)) {
+      return Status::Invalid("Buffer underflow, expected key ", ordinal,
+                             " to have length ", key_length, ", but only ", remaining(),
+                             " bytes remain");
+    }
+    const util::string_view key(reinterpret_cast<const char*>(cursor), key_length);
+    cursor += key_length;
+
+    if (ARROW_PREDICT_FALSE(remaining() < value_length)) {
+      return Status::Invalid("Buffer underflow, expected value ", ordinal,
+                             " to have length ", value_length, ", but only ",
+                             remaining(), " bytes remain");
+    }
+    const util::string_view value(reinterpret_cast<const char*>(cursor), value_length);
+    cursor += value_length;
+
+    result.headers_.emplace_back(key, value);
+  }
+
+  result.buffer_ = std::move(buffer);
+  return result;
+}
+/// \brief Serialize a list of headers into a new buffer and parse it back
+///   into a HeadersFrame.
+///
+/// Layout: a 4-byte header count, then per header a 4-byte key length, a
+/// 4-byte value length, the key bytes and the value bytes.
+///
+/// \param[in] headers The key/value pairs to serialize
+/// \return Invalid if the header count or any key/value length does not fit
+///   in a uint32_t
+arrow::Result<HeadersFrame> HeadersFrame::Make(
+    const std::vector<std::pair<std::string, std::string>>& headers) {
+  // Accumulate in 64 bits: the individual sizes are size_t, so a 32-bit
+  // signed total would silently narrow and could overflow.
+  int64_t total_length = 4 /* # of headers */;
+  for (const auto& header : headers) {
+    total_length += 4 /* key length */ + 4 /* value length */ +
+                    static_cast<int64_t>(header.first.size()) /* key */ +
+                    static_cast<int64_t>(header.second.size()) /* value */;
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(total_length));
+  uint8_t* payload = buffer->mutable_data();
+
+  RETURN_NOT_OK(SizeToUInt32BytesBe(static_cast<int64_t>(headers.size()), payload));
+  payload += 4;
+  for (const auto& header : headers) {
+    const std::string& key = header.first;
+    const std::string& value = header.second;
+    RETURN_NOT_OK(SizeToUInt32BytesBe(static_cast<int64_t>(key.size()), payload));
+    payload += 4;
+    RETURN_NOT_OK(SizeToUInt32BytesBe(static_cast<int64_t>(value.size()), payload));
+    payload += 4;
+    std::memcpy(payload, key.data(), key.size());
+    payload += key.size();
+    std::memcpy(payload, value.data(), value.size());
+    payload += value.size();
+  }
+  // Re-parse so the frame's header views point into the new buffer
+  return Parse(std::move(buffer));
+}
+/// \brief Build a headers frame from the given headers plus headers that
+///   describe a Status.
+///
+/// Appends the status code and message. If the status carries a detail, the
+/// Flight detail code and extra info are appended as well, or the detail's
+/// string form when it is not a FlightStatusDetail.
+arrow::Result<HeadersFrame> HeadersFrame::Make(
+    const Status& status,
+    const std::vector<std::pair<std::string, std::string>>& headers) {
+  std::vector<std::pair<std::string, std::string>> all_headers = headers;
+  all_headers.emplace_back(kHeaderStatusCode,
+                           std::to_string(static_cast<int32_t>(status.code())));
+  all_headers.emplace_back(kHeaderStatusMessage, status.message());
+  if (!status.detail()) {
+    return Make(all_headers);
+  }
+
+  const auto flight_detail = FlightStatusDetail::UnwrapStatus(status);
+  if (flight_detail) {
+    all_headers.emplace_back(kHeaderStatusDetailCode,
+                             std::to_string(static_cast<int32_t>(flight_detail->code())));
+    all_headers.emplace_back(kHeaderStatusDetail, flight_detail->extra_info());
+  } else {
+    all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString());
+  }
+  return Make(all_headers);
+}
+
+/// \brief Look up a header by key with a linear scan.
+///
+/// \return The value of the first header whose key matches, or KeyError if
+///   there is none
+arrow::Result<util::string_view> HeadersFrame::Get(const std::string& key) {
+  for (auto it = headers_.begin(); it != headers_.end(); ++it) {
+    if (it->first == key) {
+      return it->second;
+    }
+  }
+  return Status::KeyError(key);
+}
+
+/// \brief Reconstruct the Status that the peer encoded in the headers.
+///
+/// \param[out] out Receives the decoded status. Unrecognized status codes
+///   become UnknownError, and unrecognized detail codes become
+///   FlightStatusCode::Internal.
+/// \return KeyError if the status code header is missing, OK otherwise (even
+///   when the decoded status itself is an error)
+Status HeadersFrame::GetStatus(Status* out) {
+  util::string_view code_str, message_str;
+  auto status = Get(kHeaderStatusCode).Value(&code_str);
+  if (!status.ok()) {
+    return Status::KeyError("Server did not send status code header ", kHeaderStatusCode);
+  }
+
+  StatusCode status_code = StatusCode::OK;
+  // A string_view is not NUL-terminated, and the value may be the very last
+  // bytes of the backing buffer. Copy it so strtol cannot read past the end.
+  const std::string code_string(code_str);
+  auto code = std::strtol(code_string.c_str(), nullptr, /*base=*/10);
+  switch (code) {
+    case 0:
+      status_code = StatusCode::OK;
+      break;
+    case 1:
+      status_code = StatusCode::OutOfMemory;
+      break;
+    case 2:
+      status_code = StatusCode::KeyError;
+      break;
+    case 3:
+      status_code = StatusCode::TypeError;
+      break;
+    case 4:
+      status_code = StatusCode::Invalid;
+      break;
+    case 5:
+      status_code = StatusCode::IOError;
+      break;
+    case 6:
+      status_code = StatusCode::CapacityError;
+      break;
+    case 7:
+      status_code = StatusCode::IndexError;
+      break;
+    case 8:
+      status_code = StatusCode::Cancelled;
+      break;
+    case 9:
+      status_code = StatusCode::UnknownError;
+      break;
+    case 10:
+      status_code = StatusCode::NotImplemented;
+      break;
+    case 11:
+      status_code = StatusCode::SerializationError;
+      break;
+    case 13:
+      status_code = StatusCode::RError;
+      break;
+    case 40:
+      status_code = StatusCode::CodeGenError;
+      break;
+    case 41:
+      status_code = StatusCode::ExpressionValidationError;
+      break;
+    case 42:
+      status_code = StatusCode::ExecutionError;
+      break;
+    case 45:
+      status_code = StatusCode::AlreadyExists;
+      break;
+    default:
+      status_code = StatusCode::UnknownError;
+      break;
+  }
+  if (status_code == StatusCode::OK) {
+    *out = Status::OK();
+    return Status::OK();
+  }
+
+  status = Get(kHeaderStatusMessage).Value(&message_str);
+  if (!status.ok()) {
+    *out = Status(status_code, "Server did not send status message header", nullptr);
+    return Status::OK();
+  }
+
+  util::string_view detail_code_str, detail_str;
+  FlightStatusCode detail_code = FlightStatusCode::Internal;
+
+  if (Get(kHeaderStatusDetailCode).Value(&detail_code_str).ok()) {
+    // Same NUL-termination concern as for the status code above
+    const std::string detail_code_string(detail_code_str);
+    auto detail_code_int =
+        std::strtol(detail_code_string.c_str(), nullptr, /*base=*/10);
+    switch (detail_code_int) {
+      case 1:
+        detail_code = FlightStatusCode::TimedOut;
+        break;
+      case 2:
+        detail_code = FlightStatusCode::Cancelled;
+        break;
+      case 3:
+        detail_code = FlightStatusCode::Unauthenticated;
+        break;
+      case 4:
+        detail_code = FlightStatusCode::Unauthorized;
+        break;
+      case 5:
+        detail_code = FlightStatusCode::Unavailable;
+        break;
+      case 6:
+        detail_code = FlightStatusCode::Failed;
+        break;
+      case 0:
+      default:
+        detail_code = FlightStatusCode::Internal;
+        break;
+    }
+  }
+  // The detail text is optional; a missing header leaves detail_str empty
+  ARROW_UNUSED(Get(kHeaderStatusDetail).Value(&detail_str));
+
+  std::shared_ptr<StatusDetail> detail = nullptr;
+  if (!detail_str.empty()) {
+    detail = std::make_shared<FlightStatusDetail>(detail_code, std::string(detail_str));
+  }
+  *out = Status(status_code, std::string(message_str), std::move(detail));
+  return Status::OK();
+}
+
+namespace {
+static constexpr uint32_t kMissingFieldSentinel = 
std::numeric_limits<uint32_t>::max();
+static constexpr uint32_t kInt32Max =
+    static_cast<uint32_t>(std::numeric_limits<int32_t>::max());
+/// \brief Compute the encoded size of one optional payload header field and
+///   add it to a running total.
+///
+/// \param[in] field Field name, used only in error messages
+/// \param[in] data The field contents; may be null
+/// \param[in,out] total_size Running total, increased by the field size. Left
+///   untouched when the field is absent or an error is returned.
+/// \return kMissingFieldSentinel if data is null, otherwise the field size;
+///   Invalid if the field exceeds the int32 maximum or the running total
+///   would no longer fit in a uint32_t
+arrow::Result<uint32_t> PayloadHeaderFieldSize(const std::string& field,
+                                               const std::shared_ptr<Buffer>& data,
+                                               uint32_t* total_size) {
+  if (!data) return kMissingFieldSentinel;
+  if (data->size() > kInt32Max) {
+    return Status::Invalid(field, " must be less than 2 GiB, was: ", data->size());
+  }
+  const uint32_t size = static_cast<uint32_t>(data->size());
+  // The total is unsigned, so an overflowing sum wraps around instead of going
+  // negative. Check the remaining headroom before adding.
+  if (size > std::numeric_limits<uint32_t>::max() - *total_size) {
+    return Status::Invalid("Payload header must fit in a uint32_t");
+  }
+  *total_size += size;
+  return size;
+}
+/// \brief Write one field as [4-byte length][data] and return the position
+///   just past it.
+///
+/// When size is kMissingFieldSentinel only the length prefix is written and
+/// data is not touched.
+uint8_t* PackField(uint32_t size, const std::shared_ptr<Buffer>& data, uint8_t* out) {
+  UInt32ToBytesBe(size, out);
+  uint8_t* const body = out + 4;
+  if (size == kMissingFieldSentinel) {
+    return body;
+  }
+  std::memcpy(body, data->data(), size);
+  return body + size;
+}
+}  // namespace
+
+/// \brief Pack the non-body parts of a Flight payload (descriptor, app
+///   metadata, IPC metadata) into a single newly allocated buffer.
+///
+/// \param[in] payload The payload whose fields are packed
+/// \param[in] memory_pool Pool used to allocate the header buffer
+arrow::Result<PayloadHeaderFrame> PayloadHeaderFrame::Make(const FlightPayload& payload,
+                                                           MemoryPool* memory_pool) {
+  // All non-data fields are copied into one buffer. Presumably they are
+  // much smaller than the data, so the copy is acceptable.
+
+  // Each field is encoded as [4 byte length][data]. An absent field is
+  // marked with UINT32_MAX as its length, because a length of 0 is a
+  // valid (empty but present) field.
+  uint32_t header_size = 12;  // three 4-byte length prefixes
+  ARROW_ASSIGN_OR_RAISE(
+      const uint32_t descriptor_size,
+      PayloadHeaderFieldSize("descriptor", payload.descriptor, &header_size));
+  ARROW_ASSIGN_OR_RAISE(
+      const uint32_t app_metadata_size,
+      PayloadHeaderFieldSize("app_metadata", payload.app_metadata, &header_size));
+  ARROW_ASSIGN_OR_RAISE(
+      const uint32_t ipc_metadata_size,
+      PayloadHeaderFieldSize("ipc_message.metadata", payload.ipc_message.metadata,
+                             &header_size));
+
+  ARROW_ASSIGN_OR_RAISE(auto header_buffer, AllocateBuffer(header_size, memory_pool));
+
+  // Fields are written back to back in this fixed order
+  uint8_t* cursor = header_buffer->mutable_data();
+  cursor = PackField(descriptor_size, payload.descriptor, cursor);
+  cursor = PackField(app_metadata_size, payload.app_metadata, cursor);
+  cursor = PackField(ipc_metadata_size, payload.ipc_message.metadata, cursor);
+
+  return PayloadHeaderFrame(std::move(header_buffer));
+}
+/// \brief Unpack the descriptor, app metadata and IPC metadata from this
+///   frame into a FlightData.
+///
+/// Consumes the frame's buffer: app_metadata and metadata become slices of
+/// it, so the frame cannot be unpacked a second time. The body is always set
+/// to nullptr.
+///
+/// \param[out] data Receives the unpacked fields; absent fields are set to
+///   nullptr
+/// \return Invalid if the buffer is missing or shorter than the encoded
+///   lengths require, or any error from deserializing the descriptor
+Status PayloadHeaderFrame::ToFlightData(internal::FlightData* data) {
+  std::shared_ptr<Buffer> buffer = std::move(buffer_);
+  if (!buffer) {
+    return Status::Invalid("Payload header buffer is missing or was already consumed");
+  }
+
+  // Offsets are tracked in 64 bits. A field length can be almost UINT32_MAX,
+  // so offset + size in uint32_t could wrap around and pass the bounds check.
+  const int64_t buffer_size = buffer->size();
+  int64_t offset = 0;
+
+  // Read the 4-byte length prefix at `offset`, step over it, and (unless the
+  // field is marked absent) verify that the field body fits in the buffer.
+  auto next_field = [&](uint32_t* size) -> Status {
+    if (buffer_size - offset < 4) {
+      return Status::Invalid("Buffer is too small: expected ", offset + 4,
+                             " bytes but have ", buffer_size);
+    }
+    *size = BytesToUInt32Be(buffer->data() + offset);
+    offset += 4;
+    if (*size != kMissingFieldSentinel &&
+        static_cast<int64_t>(*size) > buffer_size - offset) {
+      return Status::Invalid("Buffer is too small: expected ", offset + *size,
+                             " bytes but have ", buffer_size);
+    }
+    return Status::OK();
+  };
+
+  uint32_t size = 0;
+
+  // Unpack the descriptor
+  RETURN_NOT_OK(next_field(&size));
+  if (size != kMissingFieldSentinel) {
+    util::string_view desc(reinterpret_cast<const char*>(buffer->data() + offset), size);
+    data->descriptor.reset(new FlightDescriptor());
+    ARROW_ASSIGN_OR_RAISE(*data->descriptor, FlightDescriptor::Deserialize(desc));
+    offset += size;
+  } else {
+    data->descriptor = nullptr;
+  }
+
+  // Unpack app_metadata
+  // While we properly handle zero-size vs nullptr metadata here, gRPC
+  // doesn't (Protobuf doesn't differentiate between the two)
+  RETURN_NOT_OK(next_field(&size));
+  if (size != kMissingFieldSentinel) {
+    data->app_metadata = SliceBuffer(buffer, offset, size);
+    offset += size;
+  } else {
+    data->app_metadata = nullptr;
+  }
+
+  // Unpack the IPC header
+  RETURN_NOT_OK(next_field(&size));
+  if (size != kMissingFieldSentinel) {
+    // Last use of the buffer, so hand our reference to the slice
+    data->metadata = SliceBuffer(std::move(buffer), offset, size);
+  } else {
+    data->metadata = nullptr;
+  }
+  data->body = nullptr;
+  return Status::OK();
+}
+
+// pImpl the driver since async methods require a stable address
+class UcpCallDriver::Impl {
+ public:
+#if defined(ARROW_FLIGHT_UCX_SEND_CONTIG)
+  constexpr static bool kEnableContigSend = true;
+#else
+  constexpr static bool kEnableContigSend = false;
+#endif
+
+  Impl(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint)
+      : padding_bytes_({0, 0, 0, 0, 0, 0, 0, 0}),
+        worker_(std::move(worker)),
+        endpoint_(endpoint),
+        read_memory_pool_(default_memory_pool()),
+        write_memory_pool_(default_memory_pool()),
+        memory_manager_(CPUDevice::Instance()->default_memory_manager()),
+        name_("(unknown remote)"),
+        counter_(0) {
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+    TryMapBuffer(worker_->context().get(), padding_bytes_.data(), 
padding_bytes_.size(),
+                 UCS_MEMORY_TYPE_HOST, &padding_memh_p_);
+#endif
+
+    ucp_ep_attr_t attrs;
+    std::memset(&attrs, 0, sizeof(attrs));
+    attrs.field_mask =
+        UCP_EP_ATTR_FIELD_LOCAL_SOCKADDR | UCP_EP_ATTR_FIELD_REMOTE_SOCKADDR;
+    if (ucp_ep_query(endpoint_, &attrs) == UCS_OK) {
+      std::string local_addr, remote_addr;
+      ARROW_UNUSED(SockaddrToString(attrs.local_sockaddr).Value(&local_addr));
+      
ARROW_UNUSED(SockaddrToString(attrs.remote_sockaddr).Value(&remote_addr));
+      name_ = "local:" + local_addr + ";remote:" + remote_addr;
+    }
+  }
+
+  ~Impl() {
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+    TryUnmapBuffer(worker_->context().get(), padding_memh_p_);
+#endif
+  }
+
+  arrow::Result<std::shared_ptr<Frame>> ReadNextFrame() {
+    auto fut = ReadFrameAsync();
+    while (!fut.is_finished()) MakeProgress();
+    RETURN_NOT_OK(fut.status());
+    return fut.MoveResult();
+  }
+
+  Future<std::shared_ptr<Frame>> ReadFrameAsync() {
+    RETURN_NOT_OK(CheckClosed());
+
+    std::unique_lock<std::mutex> guard(frame_mutex_);
+    if (ARROW_PREDICT_FALSE(!status_.ok())) return status_;
+
+    const uint32_t counter_value = next_counter_++;
+    auto it = frames_.find(counter_value);
+    if (it != frames_.end()) {

Review comment:
       `counter_` is used for reordering. Since we use the async/nonblocking 
calls in UCX, we won't necessarily receive messages in the same order they're 
sent. (I'm not sure of the exact guarantee here. I believe it's true that UCX 
will deliver the initial notifications in the same order, but then we have an 
async call to actually read the data, and I found that those didn't necessarily 
complete in order.) So `counter_` makes sure we don't transpose messages. 
   
   If the UCX callback runs first, then there will already be a future in the 
map (i.e. there's a delivered message that the client has not yet processed). 
If the client tries to read a message first, then there will be no future in 
the map (i.e. the client expects a message the server has not yet delivered). 
I'll add this in comments to explain what's going on.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to