This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 69d595a ARROW-4558: [C++][Flight] Implement gRPC customizations
without UB
69d595a is described below
commit 69d595ae4c61902b3f2778e536fca6675350c88c
Author: Wes McKinney <[email protected]>
AuthorDate: Wed Feb 13 14:17:56 2019 -0600
ARROW-4558: [C++][Flight] Implement gRPC customizations without UB
I admit this feels gross, but it's less gross than what was there before. I
can do some more cleanup but wanted to get feedback before spending any more
time on it.
So, the problem partially lies with the gRPC C++ library. The obvious
thing, and first thing I tried, was to specialize
`SerializationTraits<protocol::FlightData>` and do casts between `FlightData`
and `protocol::FlightData` (the proto) at the last possible moment.
Unfortunately, this seems to not be possible because of this:
https://github.com/grpc/grpc/blob/master/include/grpcpp/impl/codegen/proto_utils.h#L100
So I had to override that Googly hack and go to some shenanigans (see
protocol.h/protocol.cc) to make sure the same templates are always visible both
in `Flight.grpc.pb.cc` as well as our client.cc/server.cc
Author: Wes McKinney <[email protected]>
Closes #3633 from wesm/flight-cpp-avoid-ub and squashes the following
commits:
ed6eb800e <Wes McKinney> Further refinements, make protocol.h an internal
header. Comments per feedback
b3609d4f0 <Wes McKinney> Add comments about the purpose of protocol.cc
ac405b326 <Wes McKinney> Ensure .proto file is compiled before anything else
23fe416b4 <Wes McKinney> Implement gRPC customizations another way without
calling reinterpret_cast on the client and server C++ types
---
cpp/src/arrow/flight/CMakeLists.txt | 11 +-
cpp/src/arrow/flight/client.cc | 21 +-
cpp/src/arrow/flight/customize_protobuf.h | 100 +++++++++
cpp/src/arrow/flight/internal.cc | 2 +
cpp/src/arrow/flight/internal.h | 3 +-
...ialization-internal.cc => protocol-internal.cc} | 29 +--
...rialization-internal.cc => protocol-internal.h} | 24 +--
cpp/src/arrow/flight/serialization-internal.cc | 205 +++++++++++++++++++
cpp/src/arrow/flight/serialization-internal.h | 224 +--------------------
cpp/src/arrow/flight/server.cc | 26 +--
10 files changed, 354 insertions(+), 291 deletions(-)
diff --git a/cpp/src/arrow/flight/CMakeLists.txt
b/cpp/src/arrow/flight/CMakeLists.txt
index a32a5fa..9183e26 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -58,11 +58,16 @@ add_custom_command(
set_source_files_properties(${FLIGHT_GENERATED_PROTO_FILES} PROPERTIES
GENERATED TRUE)
+add_custom_target(flight_grpc_gen ALL DEPENDS ${FLIGHT_GENERATED_PROTO_FILES})
+
+# Note, we do not compile the generated Protobuf sources directly, instead
+# compiling them via protocol-internal.cc which contains some gRPC template
+# overrides to enable Flight-specific optimizations. See comments in
+# protocol-internal.cc
set(ARROW_FLIGHT_SRCS
client.cc
- Flight.pb.cc
- Flight.grpc.pb.cc
internal.cc
+ protocol-internal.cc
serialization-internal.cc
server.cc
types.cc)
@@ -70,6 +75,8 @@ set(ARROW_FLIGHT_SRCS
add_arrow_lib(arrow_flight
SOURCES
${ARROW_FLIGHT_SRCS}
+ DEPENDENCIES
+ flight_grpc_gen
SHARED_LINK_LIBS
arrow_shared
${ARROW_FLIGHT_STATIC_LINK_LIBS}
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index fd13f79..8520777 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -16,6 +16,7 @@
// under the License.
#include "arrow/flight/client.h"
+#include "arrow/flight/protocol-internal.h"
#include <memory>
#include <sstream>
@@ -33,8 +34,6 @@
#include "arrow/type.h"
#include "arrow/util/logging.h"
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
#include "arrow/flight/internal.h"
#include "arrow/flight/serialization-internal.h"
@@ -74,13 +73,8 @@ class FlightStreamReader : public RecordBatchReader {
return Status::OK();
}
- // For customizing read path for better memory/serialization efficiency
- // XXX this cast is undefined behavior
- auto custom_reader =
reinterpret_cast<grpc::ClientReader<FlightData>*>(stream_.get());
-
- // Explicitly specify the override to invoke - otherwise compiler
- // may invoke through vtable (not updated by reinterpret_cast)
- if (custom_reader->grpc::ClientReader<FlightData>::Read(&data)) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ if (stream_->Read(reinterpret_cast<pb::FlightData*>(&data))) {
std::unique_ptr<ipc::Message> message;
// Validate IPC message
@@ -127,12 +121,9 @@ class FlightPutWriter::FlightPutWriterImpl : public
ipc::RecordBatchWriter {
Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false)
override {
IpcPayload payload;
RETURN_NOT_OK(ipc::internal::GetRecordBatchPayload(batch, pool_,
&payload));
- // XXX this cast is undefined behavior
- auto custom_writer =
reinterpret_cast<grpc::ClientWriter<IpcPayload>*>(writer_.get());
- // Explicitly specify the override to invoke - otherwise compiler
- // may invoke through vtable (not updated by reinterpret_cast)
- if (!custom_writer->grpc::ClientWriter<IpcPayload>::Write(payload,
-
grpc::WriteOptions())) {
+
+ if (!writer_->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+ grpc::WriteOptions())) {
std::stringstream ss;
ss << "Could not write record batch to stream: "
<< rpc_->context.debug_error_string();
diff --git a/cpp/src/arrow/flight/customize_protobuf.h
b/cpp/src/arrow/flight/customize_protobuf.h
new file mode 100644
index 0000000..fd2e086
--- /dev/null
+++ b/cpp/src/arrow/flight/customize_protobuf.h
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <limits>
+#include <memory>
+
+#include "grpcpp/impl/codegen/config_protobuf.h"
+
+// It is necessary to undefine this macro so that the protobuf
+// SerializationTraits specialization is not declared in proto_utils.h. We've
+// copied that specialization below and modified it to exclude
+// protocol::FlightData from the default implementation so we can specialize
+// for our faster serialization-deserialization path
+#undef GRPC_OPEN_SOURCE_PROTO
+
+#include "grpcpp/impl/codegen/proto_utils.h"
+
+namespace arrow {
+namespace ipc {
+namespace internal {
+
+struct IpcPayload;
+
+} // namespace internal
+} // namespace ipc
+
+namespace flight {
+
+struct FlightData;
+
+namespace protocol {
+
+class FlightData;
+
+} // namespace protocol
+} // namespace flight
+} // namespace arrow
+
+namespace grpc {
+
+using arrow::flight::FlightData;
+using arrow::ipc::internal::IpcPayload;
+
+class ByteBuffer;
+class Status;
+
+Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool*
own_buffer);
+Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out);
+
+// This class provides a protobuf serializer. It translates between protobuf
+// objects and grpc_byte_buffers. More information about SerializationTraits
can
+// be found in include/grpcpp/impl/codegen/serialization_traits.h.
+template <class T>
+class SerializationTraits<
+ T, typename std::enable_if<
+ std::is_base_of<grpc::protobuf::Message, T>::value &&
+ !std::is_same<arrow::flight::protocol::FlightData,
T>::value>::type> {
+ public:
+ static Status Serialize(const grpc::protobuf::Message& msg, ByteBuffer* bb,
+ bool* own_buffer) {
+ return GenericSerialize<ProtoBufferWriter, T>(msg, bb, own_buffer);
+ }
+
+ static Status Deserialize(ByteBuffer* buffer, grpc::protobuf::Message* msg) {
+ return GenericDeserialize<ProtoBufferReader, T>(buffer, msg);
+ }
+};
+
+template <class T>
+class SerializationTraits<T, typename std::enable_if<std::is_same<
+ arrow::flight::protocol::FlightData,
T>::value>::type> {
+ public:
+ static Status Serialize(const grpc::protobuf::Message& msg, ByteBuffer* bb,
+ bool* own_buffer) {
+ return FlightDataSerialize(*reinterpret_cast<const IpcPayload*>(&msg), bb,
+ own_buffer);
+ }
+
+ static Status Deserialize(ByteBuffer* buffer, grpc::protobuf::Message* msg) {
+ return FlightDataDeserialize(buffer, reinterpret_cast<FlightData*>(msg));
+ }
+};
+
+} // namespace grpc
diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc
index a614450..21268e8 100644
--- a/cpp/src/arrow/flight/internal.cc
+++ b/cpp/src/arrow/flight/internal.cc
@@ -17,6 +17,8 @@
#include "arrow/flight/internal.h"
+#include "arrow/flight/customize_protobuf.h"
+
#include <memory>
#include <string>
#include <utility>
diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h
index 7f9bda1..15c3d71 100644
--- a/cpp/src/arrow/flight/internal.h
+++ b/cpp/src/arrow/flight/internal.h
@@ -26,8 +26,7 @@
#include "arrow/ipc/writer.h"
#include "arrow/util/macros.h"
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
+#include "arrow/flight/protocol-internal.h"
#include "arrow/flight/types.h"
namespace arrow {
diff --git a/cpp/src/arrow/flight/serialization-internal.cc
b/cpp/src/arrow/flight/protocol-internal.cc
similarity index 55%
copy from cpp/src/arrow/flight/serialization-internal.cc
copy to cpp/src/arrow/flight/protocol-internal.cc
index 194a7b5..116bb2b 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/protocol-internal.cc
@@ -13,25 +13,14 @@
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
-// under the License.
-#include "arrow/flight/serialization-internal.h"
+#include "arrow/flight/protocol-internal.h"
-namespace arrow {
-namespace flight {
-namespace internal {
-
-bool ReadBytesZeroCopy(const std::shared_ptr<arrow::Buffer>& source_data,
- CodedInputStream* input,
std::shared_ptr<arrow::Buffer>* out) {
- uint32_t length;
- if (!input->ReadVarint32(&length)) {
- return false;
- }
- *out = arrow::SliceBuffer(source_data, input->CurrentPosition(),
- static_cast<int64_t>(length));
- return input->Skip(static_cast<int>(length));
-}
-
-} // namespace internal
-} // namespace flight
-} // namespace arrow
+// NOTE(wesm): Including .cc files in another .cc file would ordinarily be a
+// no-no. We have customized the serialization path for FlightData, which is
+// currently only possible through some pre-processor commands that need to be
+// included before either of these files is compiled. Because we don't want to
+// edit the generated C++ files, we include them here and do our gRPC
+// customizations in protocol-internal.h
+#include "arrow/flight/Flight.grpc.pb.cc" // NOLINT
+#include "arrow/flight/Flight.pb.cc" // NOLINT
diff --git a/cpp/src/arrow/flight/serialization-internal.cc
b/cpp/src/arrow/flight/protocol-internal.h
similarity index 55%
copy from cpp/src/arrow/flight/serialization-internal.cc
copy to cpp/src/arrow/flight/protocol-internal.h
index 194a7b5..d3ba77f 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/protocol-internal.h
@@ -13,25 +13,11 @@
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
-// under the License.
-#include "arrow/flight/serialization-internal.h"
+#pragma once
-namespace arrow {
-namespace flight {
-namespace internal {
+// Need to include this first to get our gRPC customizations
+#include "arrow/flight/customize_protobuf.h"
-bool ReadBytesZeroCopy(const std::shared_ptr<arrow::Buffer>& source_data,
- CodedInputStream* input,
std::shared_ptr<arrow::Buffer>* out) {
- uint32_t length;
- if (!input->ReadVarint32(&length)) {
- return false;
- }
- *out = arrow::SliceBuffer(source_data, input->CurrentPosition(),
- static_cast<int64_t>(length));
- return input->Skip(static_cast<int>(length));
-}
-
-} // namespace internal
-} // namespace flight
-} // namespace arrow
+#include "arrow/flight/Flight.grpc.pb.h"
+#include "arrow/flight/Flight.pb.h"
diff --git a/cpp/src/arrow/flight/serialization-internal.cc
b/cpp/src/arrow/flight/serialization-internal.cc
index 194a7b5..67b2155 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/serialization-internal.cc
@@ -32,6 +32,211 @@ bool ReadBytesZeroCopy(const
std::shared_ptr<arrow::Buffer>& source_data,
return input->Skip(static_cast<int>(length));
}
+using google::protobuf::io::CodedInputStream;
+using google::protobuf::io::CodedOutputStream;
+
+// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow
+// consumers with zero-copy
+class GrpcBuffer : public arrow::MutableBuffer {
+ public:
+ GrpcBuffer(grpc_slice slice, bool incref)
+ : MutableBuffer(GRPC_SLICE_START_PTR(slice),
+ static_cast<int64_t>(GRPC_SLICE_LENGTH(slice))),
+ slice_(incref ? grpc_slice_ref(slice) : slice) {}
+
+ ~GrpcBuffer() override {
+ // Decref slice
+ grpc_slice_unref(slice_);
+ }
+
+ static arrow::Status Wrap(grpc::ByteBuffer* cpp_buf,
+ std::shared_ptr<arrow::Buffer>* out) {
+ // These types are guaranteed by static assertions in gRPC to have the same
+ // in-memory representation
+
+ auto buffer = *reinterpret_cast<grpc_byte_buffer**>(cpp_buf);
+
+ // This part below is based on the Flatbuffers gRPC SerializationTraits in
+ // flatbuffers/grpc.h
+
+ // Check if this is a single uncompressed slice.
+ if ((buffer->type == GRPC_BB_RAW) &&
+ (buffer->data.raw.compression == GRPC_COMPRESS_NONE) &&
+ (buffer->data.raw.slice_buffer.count == 1)) {
+ // If it is, then we can reference the `grpc_slice` directly.
+ grpc_slice slice = buffer->data.raw.slice_buffer.slices[0];
+
+ // Increment reference count so this memory remains valid
+ *out = std::make_shared<GrpcBuffer>(slice, true);
+ } else {
+ // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read
+ // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives
+ // us back a new slice with the refcount already incremented.
+ grpc_byte_buffer_reader reader;
+ if (!grpc_byte_buffer_reader_init(&reader, buffer)) {
+ return arrow::Status::IOError("Internal gRPC error reading from
ByteBuffer");
+ }
+ grpc_slice slice = grpc_byte_buffer_reader_readall(&reader);
+ grpc_byte_buffer_reader_destroy(&reader);
+
+ // Steal the slice reference
+ *out = std::make_shared<GrpcBuffer>(slice, false);
+ }
+
+ return arrow::Status::OK();
+ }
+
+ private:
+ grpc_slice slice_;
+};
+
} // namespace internal
} // namespace flight
} // namespace arrow
+
+namespace grpc {
+
+using arrow::flight::FlightData;
+using arrow::flight::internal::GrpcBuffer;
+using arrow::flight::internal::ReadBytesZeroCopy;
+
+using google::protobuf::internal::WireFormatLite;
+using google::protobuf::io::ArrayOutputStream;
+using google::protobuf::io::CodedInputStream;
+using google::protobuf::io::CodedOutputStream;
+
+Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool*
own_buffer) {
+ size_t total_size = 0;
+
+ DCHECK_LT(msg.metadata->size(), kInt32Max);
+ const int32_t metadata_size = static_cast<int32_t>(msg.metadata->size());
+
+ // 1 byte for metadata tag
+ total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size);
+
+ int64_t body_size = 0;
+ for (const auto& buffer : msg.body_buffers) {
+ // Buffer may be null when the row length is zero, or when all
+ // entries are invalid.
+ if (!buffer) continue;
+
+ body_size += buffer->size();
+
+ const int64_t remainder = buffer->size() % 8;
+ if (remainder) {
+ body_size += 8 - remainder;
+ }
+ }
+
+ // 2 bytes for body tag
+ // Only written when there are body buffers
+ if (msg.body_length > 0) {
+ total_size += 2 +
WireFormatLite::LengthDelimitedSize(static_cast<size_t>(body_size));
+ }
+
+ // TODO(wesm): messages over 2GB unlikely to be yet supported
+ if (total_size > kInt32Max) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Cannot send record batches exceeding 2GB yet");
+ }
+
+ // Allocate slice, assign to output buffer
+ grpc::Slice slice(total_size);
+
+ // XXX(wesm): for debugging
+ // std::cout << "Writing record batch with total size " << total_size <<
std::endl;
+
+ ArrayOutputStream writer(const_cast<uint8_t*>(slice.begin()),
+ static_cast<int>(slice.size()));
+ CodedOutputStream pb_stream(&writer);
+
+ // Write header
+ WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&pb_stream);
+ pb_stream.WriteVarint32(metadata_size);
+ pb_stream.WriteRawMaybeAliased(msg.metadata->data(),
+ static_cast<int>(msg.metadata->size()));
+
+ // Don't write tag if there are no body buffers
+ if (msg.body_length > 0) {
+ // Write body
+ WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&pb_stream);
+ pb_stream.WriteVarint32(static_cast<uint32_t>(body_size));
+
+ constexpr uint8_t kPaddingBytes[8] = {0};
+
+ for (const auto& buffer : msg.body_buffers) {
+ // Buffer may be null when the row length is zero, or when all
+ // entries are invalid.
+ if (!buffer) continue;
+
+ pb_stream.WriteRawMaybeAliased(buffer->data(),
static_cast<int>(buffer->size()));
+
+ // Write padding if not multiple of 8
+ const int remainder = static_cast<int>(buffer->size() % 8);
+ if (remainder) {
+ pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder);
+ }
+ }
+ }
+
+ DCHECK_EQ(static_cast<int>(total_size), pb_stream.ByteCount());
+
+ // Hand off the slice to the returned ByteBuffer
+ grpc::ByteBuffer tmp(&slice, 1);
+ out->Swap(&tmp);
+ *own_buffer = true;
+ return grpc::Status::OK;
+}
+
+// Read internal::FlightData from grpc::ByteBuffer containing FlightData
+// protobuf without copying
+Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) {
+ if (!buffer) {
+ return Status(StatusCode::INTERNAL, "No payload");
+ }
+
+ std::shared_ptr<arrow::Buffer> wrapped_buffer;
+ GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer));
+
+ auto buffer_length = static_cast<int>(wrapped_buffer->size());
+ CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length);
+
+ // TODO(wesm): The 2-parameter version of this function is deprecated
+ pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */);
+
+ // This is the bytes remaining when using CodedInputStream like this
+ while (pb_stream.BytesUntilTotalBytesLimit()) {
+ const uint32_t tag = pb_stream.ReadTag();
+ const int field_number = WireFormatLite::GetTagFieldNumber(tag);
+ switch (field_number) {
+ case pb::FlightData::kFlightDescriptorFieldNumber: {
+ pb::FlightDescriptor pb_descriptor;
+ if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) {
+ return Status(StatusCode::INTERNAL, "Unable to parse
FlightDescriptor");
+ }
+ } break;
+ case pb::FlightData::kDataHeaderFieldNumber: {
+ if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) {
+ return Status(StatusCode::INTERNAL, "Unable to read FlightData
metadata");
+ }
+ } break;
+ case pb::FlightData::kDataBodyFieldNumber: {
+ if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) {
+ return Status(StatusCode::INTERNAL, "Unable to read FlightData
body");
+ }
+ } break;
+ default:
+ DCHECK(false) << "cannot happen";
+ }
+ }
+ buffer->Clear();
+
+ // TODO(wesm): Where and when should we verify that the FlightData is not
+ // malformed or missing components?
+
+ return Status::OK;
+}
+
+} // namespace grpc
diff --git a/cpp/src/arrow/flight/serialization-internal.h
b/cpp/src/arrow/flight/serialization-internal.h
index 06cdcdf..19c8592 100644
--- a/cpp/src/arrow/flight/serialization-internal.h
+++ b/cpp/src/arrow/flight/serialization-internal.h
@@ -20,6 +20,9 @@
#pragma once
+// Enable gRPC customizations
+#include "arrow/flight/protocol-internal.h"
+
#include <limits>
#include <memory>
@@ -29,14 +32,13 @@
#include "google/protobuf/wire_format_lite.h"
#include "grpc/byte_buffer_reader.h"
#include "grpcpp/grpcpp.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
#include "arrow/ipc/writer.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
#include "arrow/util/logging.h"
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
#include "arrow/flight/internal.h"
#include "arrow/flight/types.h"
@@ -70,71 +72,13 @@ using google::protobuf::io::CodedOutputStream;
bool ReadBytesZeroCopy(const std::shared_ptr<arrow::Buffer>& source_data,
CodedInputStream* input,
std::shared_ptr<arrow::Buffer>* out);
-// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow
-// consumers with zero-copy
-class GrpcBuffer : public arrow::MutableBuffer {
- public:
- GrpcBuffer(grpc_slice slice, bool incref)
- : MutableBuffer(GRPC_SLICE_START_PTR(slice),
- static_cast<int64_t>(GRPC_SLICE_LENGTH(slice))),
- slice_(incref ? grpc_slice_ref(slice) : slice) {}
-
- ~GrpcBuffer() override {
- // Decref slice
- grpc_slice_unref(slice_);
- }
-
- static arrow::Status Wrap(grpc::ByteBuffer* cpp_buf,
- std::shared_ptr<arrow::Buffer>* out) {
- // These types are guaranteed by static assertions in gRPC to have the same
- // in-memory representation
-
- auto buffer = *reinterpret_cast<grpc_byte_buffer**>(cpp_buf);
-
- // This part below is based on the Flatbuffers gRPC SerializationTraits in
- // flatbuffers/grpc.h
-
- // Check if this is a single uncompressed slice.
- if ((buffer->type == GRPC_BB_RAW) &&
- (buffer->data.raw.compression == GRPC_COMPRESS_NONE) &&
- (buffer->data.raw.slice_buffer.count == 1)) {
- // If it is, then we can reference the `grpc_slice` directly.
- grpc_slice slice = buffer->data.raw.slice_buffer.slices[0];
-
- // Increment reference count so this memory remains valid
- *out = std::make_shared<GrpcBuffer>(slice, true);
- } else {
- // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read
- // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives
- // us back a new slice with the refcount already incremented.
- grpc_byte_buffer_reader reader;
- if (!grpc_byte_buffer_reader_init(&reader, buffer)) {
- return arrow::Status::IOError("Internal gRPC error reading from
ByteBuffer");
- }
- grpc_slice slice = grpc_byte_buffer_reader_readall(&reader);
- grpc_byte_buffer_reader_destroy(&reader);
-
- // Steal the slice reference
- *out = std::make_shared<GrpcBuffer>(slice, false);
- }
-
- return arrow::Status::OK();
- }
-
- private:
- grpc_slice slice_;
-};
-
} // namespace internal
-
} // namespace flight
} // namespace arrow
namespace grpc {
using arrow::flight::FlightData;
-using arrow::flight::internal::GrpcBuffer;
-using arrow::flight::internal::ReadBytesZeroCopy;
using google::protobuf::internal::WireFormatLite;
using google::protobuf::io::ArrayOutputStream;
@@ -158,163 +102,11 @@ inline arrow::Status FailSerialization(arrow::Status
status) {
return status;
}
-// Read internal::FlightData from grpc::ByteBuffer containing FlightData
-// protobuf without copying
-template <>
-class SerializationTraits<FlightData> {
- public:
- static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool*
own_buffer) {
- return FailSerialization(Status(
- StatusCode::UNIMPLEMENTED, "internal::FlightData serialization not
implemented"));
- }
-
- static Status Deserialize(ByteBuffer* buffer, FlightData* out) {
- if (!buffer) {
- return FailSerialization(Status(StatusCode::INTERNAL, "No payload"));
- }
-
- std::shared_ptr<arrow::Buffer> wrapped_buffer;
- GRPC_RETURN_NOT_OK(FailSerialization(GrpcBuffer::Wrap(buffer,
&wrapped_buffer)));
-
- auto buffer_length = static_cast<int>(wrapped_buffer->size());
- CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length);
-
- // TODO(wesm): The 2-parameter version of this function is deprecated
- pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */);
-
- // This is the bytes remaining when using CodedInputStream like this
- while (pb_stream.BytesUntilTotalBytesLimit()) {
- const uint32_t tag = pb_stream.ReadTag();
- const int field_number = WireFormatLite::GetTagFieldNumber(tag);
- switch (field_number) {
- case pb::FlightData::kFlightDescriptorFieldNumber: {
- pb::FlightDescriptor pb_descriptor;
- if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) {
- return FailSerialization(
- Status(StatusCode::INTERNAL, "Unable to parse
FlightDescriptor"));
- }
- } break;
- case pb::FlightData::kDataHeaderFieldNumber: {
- if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) {
- return FailSerialization(
- Status(StatusCode::INTERNAL, "Unable to read FlightData
metadata"));
- }
- } break;
- case pb::FlightData::kDataBodyFieldNumber: {
- if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) {
- return FailSerialization(
- Status(StatusCode::INTERNAL, "Unable to read FlightData
body"));
- }
- } break;
- default:
- DCHECK(false) << "cannot happen";
- }
- }
- buffer->Clear();
-
- // TODO(wesm): Where and when should we verify that the FlightData is not
- // malformed or missing components?
-
- return Status::OK;
- }
-};
-
// Write FlightData to a grpc::ByteBuffer without extra copying
-template <>
-class SerializationTraits<IpcPayload> {
- public:
- static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) {
- return FailSerialization(grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
- "IpcPayload deserialization not
implemented"));
- }
-
- static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out,
- bool* own_buffer) {
- size_t total_size = 0;
-
- DCHECK_LT(msg.metadata->size(), kInt32Max);
- const int32_t metadata_size = static_cast<int32_t>(msg.metadata->size());
-
- // 1 byte for metadata tag
- total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size);
-
- int64_t body_size = 0;
- for (const auto& buffer : msg.body_buffers) {
- // Buffer may be null when the row length is zero, or when all
- // entries are invalid.
- if (!buffer) continue;
-
- body_size += buffer->size();
-
- const int64_t remainder = buffer->size() % 8;
- if (remainder) {
- body_size += 8 - remainder;
- }
- }
-
- // 2 bytes for body tag
- // Only written when there are body buffers
- if (msg.body_length > 0) {
- total_size +=
- 2 +
WireFormatLite::LengthDelimitedSize(static_cast<size_t>(body_size));
- }
-
- // TODO(wesm): messages over 2GB unlikely to be yet supported
- if (total_size > kInt32Max) {
- return FailSerialization(
- grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
- "Cannot send record batches exceeding 2GB yet"));
- }
-
- // Allocate slice, assign to output buffer
- grpc::Slice slice(total_size);
-
- // XXX(wesm): for debugging
- // std::cout << "Writing record batch with total size " << total_size <<
std::endl;
+Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool*
own_buffer);
- ArrayOutputStream writer(const_cast<uint8_t*>(slice.begin()),
- static_cast<int>(slice.size()));
- CodedOutputStream pb_stream(&writer);
-
- // Write header
- WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
- WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&pb_stream);
- pb_stream.WriteVarint32(metadata_size);
- pb_stream.WriteRawMaybeAliased(msg.metadata->data(),
- static_cast<int>(msg.metadata->size()));
-
- // Don't write tag if there are no body buffers
- if (msg.body_length > 0) {
- // Write body
- WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
- WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&pb_stream);
- pb_stream.WriteVarint32(static_cast<uint32_t>(body_size));
-
- constexpr uint8_t kPaddingBytes[8] = {0};
-
- for (const auto& buffer : msg.body_buffers) {
- // Buffer may be null when the row length is zero, or when all
- // entries are invalid.
- if (!buffer) continue;
-
- pb_stream.WriteRawMaybeAliased(buffer->data(),
static_cast<int>(buffer->size()));
-
- // Write padding if not multiple of 8
- const int remainder = static_cast<int>(buffer->size() % 8);
- if (remainder) {
- pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder);
- }
- }
- }
-
- DCHECK_EQ(static_cast<int>(total_size), pb_stream.ByteCount());
-
- // Hand off the slice to the returned ByteBuffer
- grpc::ByteBuffer tmp(&slice, 1);
- out->Swap(&tmp);
- *own_buffer = true;
- return grpc::Status::OK;
- }
-};
+// Read internal::FlightData from grpc::ByteBuffer containing FlightData
+// protobuf without copying
+Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out);
} // namespace grpc
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index ac5b535..2fef93d 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -16,6 +16,7 @@
// under the License.
#include "arrow/flight/server.h"
+#include "arrow/flight/protocol-internal.h"
#include <cstdint>
#include <memory>
@@ -29,8 +30,6 @@
#include "arrow/status.h"
#include "arrow/util/logging.h"
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
#include "arrow/flight/internal.h"
#include "arrow/flight/serialization-internal.h"
#include "arrow/flight/types.h"
@@ -73,13 +72,9 @@ class FlightMessageReaderImpl : public FlightMessageReader {
return Status::OK();
}
- // XXX this cast is undefined behavior
- auto custom_reader =
reinterpret_cast<grpc::ServerReader<FlightData>*>(reader_);
-
FlightData data;
- // Explicitly specify the override to invoke - otherwise compiler
- // may invoke through vtable (not updated by reinterpret_cast)
- if (custom_reader->grpc::ServerReader<FlightData>::Read(&data)) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ if (reader_->Read(reinterpret_cast<pb::FlightData*>(&data))) {
std::unique_ptr<ipc::Message> message;
// Validate IPC message
@@ -184,26 +179,23 @@ class FlightServiceImpl : public FlightService::Service {
std::unique_ptr<FlightDataStream> data_stream;
GRPC_RETURN_NOT_OK(server_->DoGet(ticket, &data_stream));
- // Requires ServerWriter customization in grpc_customizations.h
- // XXX this cast is undefined behavior
- auto custom_writer = reinterpret_cast<ServerWriter<IpcPayload>*>(writer);
-
// Write the schema as the first message in the stream
IpcPayload schema_payload;
MemoryPool* pool = default_memory_pool();
ipc::DictionaryMemo dictionary_memo;
GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload(
*data_stream->schema(), pool, &dictionary_memo, &schema_payload));
- // Explicitly specify the override to invoke - otherwise compiler
- // may invoke through vtable (not updated by reinterpret_cast)
- custom_writer->grpc::ServerWriter<IpcPayload>::Write(schema_payload,
- grpc::WriteOptions());
+
+ // Pretend to be pb::FlightData, we cast back to IpcPayload in
SerializationTraits
+ writer->Write(*reinterpret_cast<const pb::FlightData*>(&schema_payload),
+ grpc::WriteOptions());
while (true) {
IpcPayload payload;
GRPC_RETURN_NOT_OK(data_stream->Next(&payload));
if (payload.metadata == nullptr ||
- !custom_writer->Write(payload, grpc::WriteOptions())) {
+ !writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+ grpc::WriteOptions())) {
// No more messages to write, or connection terminated for some other
// reason
break;