zeroshade commented on code in PR #28:
URL: https://github.com/apache/arrow-experiments/pull/28#discussion_r1574884794
##########
dissociated-ipc/cudf-flight-poc.cc:
##########
@@ -0,0 +1,799 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <future>
+#include <iostream>
+#include <queue>
+#include <thread>
+
+#include <gflags/gflags.h>
+#include <cudf/interop.hpp>
+#include <cudf/io/parquet.hpp>
+
+#include <arrow/c/abi.h>
+#include <arrow/c/bridge.h>
+#include <arrow/device.h>
+#include <arrow/flight/client.h>
+#include <arrow/flight/server.h>
+#include <arrow/gpu/cuda_api.h>
+#include <arrow/ipc/api.h>
+#include <arrow/util/endian.h>
+#include <arrow/util/logging.h>
+#include <arrow/util/uri.h>
+
+#include "ucx_client.h"
+#include "ucx_server.h"
+
+namespace flight = arrow::flight;
+namespace ipc = arrow::ipc;
+
+// Define some constants for the `want_data` tags
+static constexpr ucp_tag_t kWantDataTag = 0x00000DEADBA0BAB0;
+static constexpr ucp_tag_t kWantCtrlTag = 0xFFFFFDEADBA0BAB0;
+// define a mask to check the tag
+static constexpr ucp_tag_t kWantCtrlMask = 0xFFFFF00000000000;
+
+enum class MetadataMsgType : uint8_t {
+  EOS = 0,
+  METADATA = 1,
+};
+
+cudf::column_metadata column_info_to_metadata(const cudf::io::column_name_info& info) {
+  cudf::column_metadata result;
+  result.name = info.name;
+  std::transform(info.children.begin(), info.children.end(),
+                 std::back_inserter(result.children_meta), column_info_to_metadata);
+  return result;
+}
+
+std::vector<cudf::column_metadata> table_metadata_to_column(
+    const cudf::io::table_metadata& tbl_meta) {
+  std::vector<cudf::column_metadata> result;
+
+  std::transform(tbl_meta.schema_info.begin(), tbl_meta.schema_info.end(),
+                 std::back_inserter(result), column_info_to_metadata);
+  return result;
+}
+
+// a UCX server which serves cuda record batches via the dissociated ipc protocol
+class CudaUcxServer : public UcxServer {
+ public:
+  CudaUcxServer() {
+    // create a buffer holding 8 bytes on the GPU to use for padding buffers
+    cuda_padding_bytes_ = rmm::device_buffer(8, rmm::cuda_stream_view{});
+    cuMemsetD8(reinterpret_cast<uintptr_t>(cuda_padding_bytes_.data()), 0, 8);
+  }
+  virtual ~CudaUcxServer() {
+    if (listening_.load()) {
+      ARROW_UNUSED(Shutdown());
+    }
+  }
+
+  arrow::Status initialize() {
+    // load the parquet data directly onto the GPU as a libcudf table
+    auto source = cudf::io::source_info("./data/taxi-data/train.parquet");
+    auto options = cudf::io::parquet_reader_options::builder(source);
+    cudf::io::chunked_parquet_reader rdr(1 * 1024 * 1024, options);
+
+    // get arrow::RecordBatches for each chunk of the parquet data while
+    // leaving the data on the GPU
+    arrow::RecordBatchVector batches;
+    auto chunk = rdr.read_chunk();
+    auto schema = cudf::to_arrow_schema(chunk.tbl->view(),
+                                        table_metadata_to_column(chunk.metadata));
+    auto device_out = cudf::to_arrow_device(std::move(*chunk.tbl));
+    ARROW_ASSIGN_OR_RAISE(auto data,
+                          arrow::ImportDeviceRecordBatch(device_out.get(), schema.get()));
+
+    batches.push_back(std::move(data));
+
+    while (rdr.has_next()) {
+      chunk = rdr.read_chunk();
+      device_out = cudf::to_arrow_device(std::move(*chunk.tbl));
+      ARROW_ASSIGN_OR_RAISE(
+          data, arrow::ImportDeviceRecordBatch(device_out.get(), schema.get()));
+      batches.push_back(std::move(data));
+    }
+
+    data_sets_.emplace("train.parquet", std::move(batches));
+
+    // initialize the server and let it choose its own port
+    ARROW_RETURN_NOT_OK(Init("127.0.0.1", 0));
+
+    ARROW_ASSIGN_OR_RAISE(ctrl_location_,
+                          flight::Location::Parse(location_.ToString() + "?want_data=" +
+                                                  std::to_string(kWantCtrlTag)));
+    ARROW_ASSIGN_OR_RAISE(data_location_,
+                          flight::Location::Parse(location_.ToString() + "?want_data=" +
+                                                  std::to_string(kWantDataTag)));
+    return arrow::Status::OK();
+  }
+
+  inline flight::Location ctrl_location() const { return ctrl_location_; }
+  inline flight::Location data_location() const { return data_location_; }
+
+ protected:
+  arrow::Status setup_handlers(UcxServer::ClientWorker* worker) override {
+    return arrow::Status::OK();
+  }
+
+  arrow::Status do_work(UcxServer::ClientWorker* worker) override {
+    // probe for a message with the want_data tag synchronously,
+    // so this will block until it receives a message with this tag
+    ARROW_ASSIGN_OR_RAISE(
+        auto tag_info, worker->conn_->ProbeForTagSync(kWantDataTag, ~kWantCtrlMask, 1));
+
+    std::string msg;
+    msg.resize(tag_info.first.length);
+    ARROW_RETURN_NOT_OK(
+        worker->conn_->RecvTagData(tag_info.second, reinterpret_cast<void*>(msg.data()),
+                                   msg.size(), nullptr, nullptr, UCS_MEMORY_TYPE_HOST));
+
+    ARROW_LOG(DEBUG) << "server received WantData: " << msg;
+
+    // simulate two separate servers, one metadata server and one body data server
+    if (tag_info.first.sender_tag & kWantCtrlMask) {
+      return send_metadata_stream(worker, msg);
+    }
+
+    return send_data_stream(worker, msg);
+  }
+
+ private:
+  arrow::Status send_metadata_stream(UcxServer::ClientWorker* worker,
+                                     const std::string& ident) {
+    auto it = data_sets_.find(ident);
+    if (it == data_sets_.end()) {
+      return arrow::Status::Invalid("data set not found:", ident);
+    }
+
+    ipc::IpcWriteOptions ipc_options;
+    ipc::DictionaryFieldMapper mapper;
+    const auto& record_list = it->second;
+    auto schema = record_list[0]->schema();
+    ARROW_RETURN_NOT_OK(mapper.AddSchemaFields(*schema));
+
+    // for each record in the stream, collect the IPC metadata to send
+    uint32_t sequence_num = 0;
+    // schema payload is first
+    ipc::IpcPayload payload;
+    ARROW_RETURN_NOT_OK(ipc::GetSchemaPayload(*schema, ipc_options, mapper, &payload));
+    ARROW_RETURN_NOT_OK(write_ipc_metadata(worker->conn_.get(), payload, sequence_num++));
+
+    // then any dictionaries
+    ARROW_ASSIGN_OR_RAISE(const auto dictionaries,
+                          ipc::CollectDictionaries(*record_list[0], mapper));
+    for (const auto& pair : dictionaries) {
+      ARROW_RETURN_NOT_OK(
+          ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options, &payload));
+      ARROW_RETURN_NOT_OK(
+          write_ipc_metadata(worker->conn_.get(), payload, sequence_num++));
+    }
+
+    // finally the record batch metadata messages
+    for (const auto& batch : record_list) {
+      ARROW_RETURN_NOT_OK(ipc::GetRecordBatchPayload(*batch, ipc_options, &payload));
+      ARROW_RETURN_NOT_OK(
+          write_ipc_metadata(worker->conn_.get(), payload, sequence_num++));
+    }
+
+    // finally, we send the End-Of-Stream message
+    std::array<uint8_t, 5> eos_bytes{static_cast<uint8_t>(MetadataMsgType::EOS), 0, 0, 0,
+                                     0};
+    utils::Uint32ToBytesLE(sequence_num, eos_bytes.data() + 1);
+
+    ARROW_RETURN_NOT_OK(worker->conn_->Flush());
+    return worker->conn_->SendAM(0, eos_bytes.data(), eos_bytes.size());

Review Comment:
   We don't; the disconnect handling in the server will call flush before it disconnects, once we finish sending the stream. I'm just calling `Flush` before sending the End-of-Stream message as a convenience, to ensure all of the other messages have been sent before the End-of-Stream goes out.
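For context, the End-of-Stream active message built above is five bytes: one `MetadataMsgType` byte followed by the total number of metadata messages as a little-endian `uint32_t`. A minimal sketch of how a receiver might decode it is below; `DecodeEosMessage` is a hypothetical helper for illustration only and is not part of this PR.

// Hypothetical sketch (not from this PR): decoding the 5-byte EOS active
// message described above. Byte 0 is the MetadataMsgType tag and bytes 1-4
// hold the total metadata-message count as a little-endian uint32_t.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>

uint32_t DecodeEosMessage(const uint8_t* data, size_t length) {
  if (length != 5 || data[0] != 0 /* MetadataMsgType::EOS */) {
    throw std::runtime_error("not an End-of-Stream message");
  }
  uint32_t total_messages = 0;
  std::memcpy(&total_messages, data + 1, sizeof(total_messages));
  // assumes a little-endian host; a big-endian host would need a byte swap here
  return total_messages;  // how many IPC metadata messages the server sent
}

A receiver that tracks how many metadata messages it has seen can compare against this count to confirm the stream arrived intact before tearing down the connection, which is why the server flushes everything else before the EOS message goes out.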
