This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 69d595a  ARROW-4558: [C++][Flight] Implement gRPC customizations 
without UB
69d595a is described below

commit 69d595ae4c61902b3f2778e536fca6675350c88c
Author: Wes McKinney <[email protected]>
AuthorDate: Wed Feb 13 14:17:56 2019 -0600

    ARROW-4558: [C++][Flight] Implement gRPC customizations without UB
    
    I admit this feels gross, but it's less gross than what was there before. I 
can do some more cleanup but wanted to get feedback before spending any more 
time on it.
    
    So, the problem partially lies with the gRPC C++ library. The obvious 
thing, and first thing I tried, was to specialize 
`SerializationTraits<protocol::FlightData>` and do casts between `FlightData` 
and `protocol::FlightData` (the proto) at the last possible moment. 
Unfortunately, this seems to not be possible because of this:
    
    
https://github.com/grpc/grpc/blob/master/include/grpcpp/impl/codegen/proto_utils.h#L100
    
    So I had to override that Googly hack and resort to some shenanigans (see 
protocol-internal.h/protocol-internal.cc) to make sure the same templates are 
always visible in both `Flight.grpc.pb.cc` and our client.cc/server.cc
    
    Author: Wes McKinney <[email protected]>
    
    Closes #3633 from wesm/flight-cpp-avoid-ub and squashes the following 
commits:
    
    ed6eb800e <Wes McKinney> Further refinements, make protocol.h an internal 
header. Comments per feedback
    b3609d4f0 <Wes McKinney> Add comments about the purpose of protocol.cc
    ac405b326 <Wes McKinney> Ensure .proto file is compiled before anything else
    23fe416b4 <Wes McKinney> Implement gRPC customizations another way without 
calling reinterpret_cast on the client and server C++ types
---
 cpp/src/arrow/flight/CMakeLists.txt                |  11 +-
 cpp/src/arrow/flight/client.cc                     |  21 +-
 cpp/src/arrow/flight/customize_protobuf.h          | 100 +++++++++
 cpp/src/arrow/flight/internal.cc                   |   2 +
 cpp/src/arrow/flight/internal.h                    |   3 +-
 ...ialization-internal.cc => protocol-internal.cc} |  29 +--
 ...rialization-internal.cc => protocol-internal.h} |  24 +--
 cpp/src/arrow/flight/serialization-internal.cc     | 205 +++++++++++++++++++
 cpp/src/arrow/flight/serialization-internal.h      | 224 +--------------------
 cpp/src/arrow/flight/server.cc                     |  26 +--
 10 files changed, 354 insertions(+), 291 deletions(-)

diff --git a/cpp/src/arrow/flight/CMakeLists.txt 
b/cpp/src/arrow/flight/CMakeLists.txt
index a32a5fa..9183e26 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -58,11 +58,16 @@ add_custom_command(
 
 set_source_files_properties(${FLIGHT_GENERATED_PROTO_FILES} PROPERTIES 
GENERATED TRUE)
 
+add_custom_target(flight_grpc_gen ALL DEPENDS ${FLIGHT_GENERATED_PROTO_FILES})
+
+# Note, we do not compile the generated Protobuf sources directly, instead
+# compiling them via protocol-internal.cc which contains some gRPC template
+# overrides to enable Flight-specific optimizations. See comments in
+# protocol-internal.cc
 set(ARROW_FLIGHT_SRCS
     client.cc
-    Flight.pb.cc
-    Flight.grpc.pb.cc
     internal.cc
+    protocol-internal.cc
     serialization-internal.cc
     server.cc
     types.cc)
@@ -70,6 +75,8 @@ set(ARROW_FLIGHT_SRCS
 add_arrow_lib(arrow_flight
               SOURCES
               ${ARROW_FLIGHT_SRCS}
+              DEPENDENCIES
+              flight_grpc_gen
               SHARED_LINK_LIBS
               arrow_shared
               ${ARROW_FLIGHT_STATIC_LINK_LIBS}
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index fd13f79..8520777 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -16,6 +16,7 @@
 // under the License.
 
 #include "arrow/flight/client.h"
+#include "arrow/flight/protocol-internal.h"
 
 #include <memory>
 #include <sstream>
@@ -33,8 +34,6 @@
 #include "arrow/type.h"
 #include "arrow/util/logging.h"
 
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
 #include "arrow/flight/internal.h"
 #include "arrow/flight/serialization-internal.h"
 
@@ -74,13 +73,8 @@ class FlightStreamReader : public RecordBatchReader {
       return Status::OK();
     }
 
-    // For customizing read path for better memory/serialization efficiency
-    // XXX this cast is undefined behavior
-    auto custom_reader = 
reinterpret_cast<grpc::ClientReader<FlightData>*>(stream_.get());
-
-    // Explicitly specify the override to invoke - otherwise compiler
-    // may invoke through vtable (not updated by reinterpret_cast)
-    if (custom_reader->grpc::ClientReader<FlightData>::Read(&data)) {
+    // Pretend to be pb::FlightData and intercept in SerializationTraits
+    if (stream_->Read(reinterpret_cast<pb::FlightData*>(&data))) {
       std::unique_ptr<ipc::Message> message;
 
       // Validate IPC message
@@ -127,12 +121,9 @@ class FlightPutWriter::FlightPutWriterImpl : public 
ipc::RecordBatchWriter {
   Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) 
override {
     IpcPayload payload;
     RETURN_NOT_OK(ipc::internal::GetRecordBatchPayload(batch, pool_, 
&payload));
-    // XXX this cast is undefined behavior
-    auto custom_writer = 
reinterpret_cast<grpc::ClientWriter<IpcPayload>*>(writer_.get());
-    // Explicitly specify the override to invoke - otherwise compiler
-    // may invoke through vtable (not updated by reinterpret_cast)
-    if (!custom_writer->grpc::ClientWriter<IpcPayload>::Write(payload,
-                                                              
grpc::WriteOptions())) {
+
+    if (!writer_->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+                        grpc::WriteOptions())) {
       std::stringstream ss;
       ss << "Could not write record batch to stream: "
          << rpc_->context.debug_error_string();
diff --git a/cpp/src/arrow/flight/customize_protobuf.h 
b/cpp/src/arrow/flight/customize_protobuf.h
new file mode 100644
index 0000000..fd2e086
--- /dev/null
+++ b/cpp/src/arrow/flight/customize_protobuf.h
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <limits>
+#include <memory>
+
+#include "grpcpp/impl/codegen/config_protobuf.h"
+
+// It is necessary to undefine this macro so that the protobuf
+// SerializationTraits specialization is not declared in proto_utils.h. We've
+// copied that specialization below and modified it to exclude
+// protocol::FlightData from the default implementation so we can specialize
+// for our faster serialization-deserialization path
+#undef GRPC_OPEN_SOURCE_PROTO
+
+#include "grpcpp/impl/codegen/proto_utils.h"
+
+namespace arrow {
+namespace ipc {
+namespace internal {
+
+struct IpcPayload;
+
+}  // namespace internal
+}  // namespace ipc
+
+namespace flight {
+
+struct FlightData;
+
+namespace protocol {
+
+class FlightData;
+
+}  // namespace protocol
+}  // namespace flight
+}  // namespace arrow
+
+namespace grpc {
+
+using arrow::flight::FlightData;
+using arrow::ipc::internal::IpcPayload;
+
+class ByteBuffer;
+class Status;
+
+Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* 
own_buffer);
+Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out);
+
+// This class provides a protobuf serializer. It translates between protobuf
+// objects and grpc_byte_buffers. More information about SerializationTraits 
can
+// be found in include/grpcpp/impl/codegen/serialization_traits.h.
+template <class T>
+class SerializationTraits<
+    T, typename std::enable_if<
+           std::is_base_of<grpc::protobuf::Message, T>::value &&
+           !std::is_same<arrow::flight::protocol::FlightData, 
T>::value>::type> {
+ public:
+  static Status Serialize(const grpc::protobuf::Message& msg, ByteBuffer* bb,
+                          bool* own_buffer) {
+    return GenericSerialize<ProtoBufferWriter, T>(msg, bb, own_buffer);
+  }
+
+  static Status Deserialize(ByteBuffer* buffer, grpc::protobuf::Message* msg) {
+    return GenericDeserialize<ProtoBufferReader, T>(buffer, msg);
+  }
+};
+
+template <class T>
+class SerializationTraits<T, typename std::enable_if<std::is_same<
+                                 arrow::flight::protocol::FlightData, 
T>::value>::type> {
+ public:
+  static Status Serialize(const grpc::protobuf::Message& msg, ByteBuffer* bb,
+                          bool* own_buffer) {
+    return FlightDataSerialize(*reinterpret_cast<const IpcPayload*>(&msg), bb,
+                               own_buffer);
+  }
+
+  static Status Deserialize(ByteBuffer* buffer, grpc::protobuf::Message* msg) {
+    return FlightDataDeserialize(buffer, reinterpret_cast<FlightData*>(msg));
+  }
+};
+
+}  // namespace grpc
diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc
index a614450..21268e8 100644
--- a/cpp/src/arrow/flight/internal.cc
+++ b/cpp/src/arrow/flight/internal.cc
@@ -17,6 +17,8 @@
 
 #include "arrow/flight/internal.h"
 
+#include "arrow/flight/customize_protobuf.h"
+
 #include <memory>
 #include <string>
 #include <utility>
diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h
index 7f9bda1..15c3d71 100644
--- a/cpp/src/arrow/flight/internal.h
+++ b/cpp/src/arrow/flight/internal.h
@@ -26,8 +26,7 @@
 #include "arrow/ipc/writer.h"
 #include "arrow/util/macros.h"
 
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
+#include "arrow/flight/protocol-internal.h"
 #include "arrow/flight/types.h"
 
 namespace arrow {
diff --git a/cpp/src/arrow/flight/serialization-internal.cc 
b/cpp/src/arrow/flight/protocol-internal.cc
similarity index 55%
copy from cpp/src/arrow/flight/serialization-internal.cc
copy to cpp/src/arrow/flight/protocol-internal.cc
index 194a7b5..116bb2b 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/protocol-internal.cc
@@ -13,25 +13,14 @@
 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
-// under the License.
 
-#include "arrow/flight/serialization-internal.h"
+#include "arrow/flight/protocol-internal.h"
 
-namespace arrow {
-namespace flight {
-namespace internal {
-
-bool ReadBytesZeroCopy(const std::shared_ptr<arrow::Buffer>& source_data,
-                       CodedInputStream* input, 
std::shared_ptr<arrow::Buffer>* out) {
-  uint32_t length;
-  if (!input->ReadVarint32(&length)) {
-    return false;
-  }
-  *out = arrow::SliceBuffer(source_data, input->CurrentPosition(),
-                            static_cast<int64_t>(length));
-  return input->Skip(static_cast<int>(length));
-}
-
-}  // namespace internal
-}  // namespace flight
-}  // namespace arrow
+// NOTE(wesm): Including .cc files in another .cc file would ordinarily be a
+// no-no. We have customized the serialization path for FlightData, which is
+// currently only possible through some pre-processor commands that need to be
+// included before either of these files is compiled. Because we don't want to
+// edit the generated C++ files, we include them here and do our gRPC
+// customizations in protocol-internal.h
+#include "arrow/flight/Flight.grpc.pb.cc"  // NOLINT
+#include "arrow/flight/Flight.pb.cc"       // NOLINT
diff --git a/cpp/src/arrow/flight/serialization-internal.cc 
b/cpp/src/arrow/flight/protocol-internal.h
similarity index 55%
copy from cpp/src/arrow/flight/serialization-internal.cc
copy to cpp/src/arrow/flight/protocol-internal.h
index 194a7b5..d3ba77f 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/protocol-internal.h
@@ -13,25 +13,11 @@
 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
-// under the License.
 
-#include "arrow/flight/serialization-internal.h"
+#pragma once
 
-namespace arrow {
-namespace flight {
-namespace internal {
+// Need to include this first to get our gRPC customizations
+#include "arrow/flight/customize_protobuf.h"
 
-bool ReadBytesZeroCopy(const std::shared_ptr<arrow::Buffer>& source_data,
-                       CodedInputStream* input, 
std::shared_ptr<arrow::Buffer>* out) {
-  uint32_t length;
-  if (!input->ReadVarint32(&length)) {
-    return false;
-  }
-  *out = arrow::SliceBuffer(source_data, input->CurrentPosition(),
-                            static_cast<int64_t>(length));
-  return input->Skip(static_cast<int>(length));
-}
-
-}  // namespace internal
-}  // namespace flight
-}  // namespace arrow
+#include "arrow/flight/Flight.grpc.pb.h"
+#include "arrow/flight/Flight.pb.h"
diff --git a/cpp/src/arrow/flight/serialization-internal.cc 
b/cpp/src/arrow/flight/serialization-internal.cc
index 194a7b5..67b2155 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/serialization-internal.cc
@@ -32,6 +32,211 @@ bool ReadBytesZeroCopy(const 
std::shared_ptr<arrow::Buffer>& source_data,
   return input->Skip(static_cast<int>(length));
 }
 
+using google::protobuf::io::CodedInputStream;
+using google::protobuf::io::CodedOutputStream;
+
+// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow
+// consumers with zero-copy
+class GrpcBuffer : public arrow::MutableBuffer {
+ public:
+  GrpcBuffer(grpc_slice slice, bool incref)
+      : MutableBuffer(GRPC_SLICE_START_PTR(slice),
+                      static_cast<int64_t>(GRPC_SLICE_LENGTH(slice))),
+        slice_(incref ? grpc_slice_ref(slice) : slice) {}
+
+  ~GrpcBuffer() override {
+    // Decref slice
+    grpc_slice_unref(slice_);
+  }
+
+  static arrow::Status Wrap(grpc::ByteBuffer* cpp_buf,
+                            std::shared_ptr<arrow::Buffer>* out) {
+    // These types are guaranteed by static assertions in gRPC to have the same
+    // in-memory representation
+
+    auto buffer = *reinterpret_cast<grpc_byte_buffer**>(cpp_buf);
+
+    // This part below is based on the Flatbuffers gRPC SerializationTraits in
+    // flatbuffers/grpc.h
+
+    // Check if this is a single uncompressed slice.
+    if ((buffer->type == GRPC_BB_RAW) &&
+        (buffer->data.raw.compression == GRPC_COMPRESS_NONE) &&
+        (buffer->data.raw.slice_buffer.count == 1)) {
+      // If it is, then we can reference the `grpc_slice` directly.
+      grpc_slice slice = buffer->data.raw.slice_buffer.slices[0];
+
+      // Increment reference count so this memory remains valid
+      *out = std::make_shared<GrpcBuffer>(slice, true);
+    } else {
+      // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read
+      // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives
+      // us back a new slice with the refcount already incremented.
+      grpc_byte_buffer_reader reader;
+      if (!grpc_byte_buffer_reader_init(&reader, buffer)) {
+        return arrow::Status::IOError("Internal gRPC error reading from 
ByteBuffer");
+      }
+      grpc_slice slice = grpc_byte_buffer_reader_readall(&reader);
+      grpc_byte_buffer_reader_destroy(&reader);
+
+      // Steal the slice reference
+      *out = std::make_shared<GrpcBuffer>(slice, false);
+    }
+
+    return arrow::Status::OK();
+  }
+
+ private:
+  grpc_slice slice_;
+};
+
 }  // namespace internal
 }  // namespace flight
 }  // namespace arrow
+
+namespace grpc {
+
+using arrow::flight::FlightData;
+using arrow::flight::internal::GrpcBuffer;
+using arrow::flight::internal::ReadBytesZeroCopy;
+
+using google::protobuf::internal::WireFormatLite;
+using google::protobuf::io::ArrayOutputStream;
+using google::protobuf::io::CodedInputStream;
+using google::protobuf::io::CodedOutputStream;
+
+Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* 
own_buffer) {
+  size_t total_size = 0;
+
+  DCHECK_LT(msg.metadata->size(), kInt32Max);
+  const int32_t metadata_size = static_cast<int32_t>(msg.metadata->size());
+
+  // 1 byte for metadata tag
+  total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size);
+
+  int64_t body_size = 0;
+  for (const auto& buffer : msg.body_buffers) {
+    // Buffer may be null when the row length is zero, or when all
+    // entries are invalid.
+    if (!buffer) continue;
+
+    body_size += buffer->size();
+
+    const int64_t remainder = buffer->size() % 8;
+    if (remainder) {
+      body_size += 8 - remainder;
+    }
+  }
+
+  // 2 bytes for body tag
+  // Only written when there are body buffers
+  if (msg.body_length > 0) {
+    total_size += 2 + 
WireFormatLite::LengthDelimitedSize(static_cast<size_t>(body_size));
+  }
+
+  // TODO(wesm): messages over 2GB unlikely to be yet supported
+  if (total_size > kInt32Max) {
+    return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+                        "Cannot send record batches exceeding 2GB yet");
+  }
+
+  // Allocate slice, assign to output buffer
+  grpc::Slice slice(total_size);
+
+  // XXX(wesm): for debugging
+  // std::cout << "Writing record batch with total size " << total_size << 
std::endl;
+
+  ArrayOutputStream writer(const_cast<uint8_t*>(slice.begin()),
+                           static_cast<int>(slice.size()));
+  CodedOutputStream pb_stream(&writer);
+
+  // Write header
+  WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
+                           WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&pb_stream);
+  pb_stream.WriteVarint32(metadata_size);
+  pb_stream.WriteRawMaybeAliased(msg.metadata->data(),
+                                 static_cast<int>(msg.metadata->size()));
+
+  // Don't write tag if there are no body buffers
+  if (msg.body_length > 0) {
+    // Write body
+    WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
+                             WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&pb_stream);
+    pb_stream.WriteVarint32(static_cast<uint32_t>(body_size));
+
+    constexpr uint8_t kPaddingBytes[8] = {0};
+
+    for (const auto& buffer : msg.body_buffers) {
+      // Buffer may be null when the row length is zero, or when all
+      // entries are invalid.
+      if (!buffer) continue;
+
+      pb_stream.WriteRawMaybeAliased(buffer->data(), 
static_cast<int>(buffer->size()));
+
+      // Write padding if not multiple of 8
+      const int remainder = static_cast<int>(buffer->size() % 8);
+      if (remainder) {
+        pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder);
+      }
+    }
+  }
+
+  DCHECK_EQ(static_cast<int>(total_size), pb_stream.ByteCount());
+
+  // Hand off the slice to the returned ByteBuffer
+  grpc::ByteBuffer tmp(&slice, 1);
+  out->Swap(&tmp);
+  *own_buffer = true;
+  return grpc::Status::OK;
+}
+
+// Read internal::FlightData from grpc::ByteBuffer containing FlightData
+// protobuf without copying
+Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) {
+  if (!buffer) {
+    return Status(StatusCode::INTERNAL, "No payload");
+  }
+
+  std::shared_ptr<arrow::Buffer> wrapped_buffer;
+  GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer));
+
+  auto buffer_length = static_cast<int>(wrapped_buffer->size());
+  CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length);
+
+  // TODO(wesm): The 2-parameter version of this function is deprecated
+  pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */);
+
+  // This is the bytes remaining when using CodedInputStream like this
+  while (pb_stream.BytesUntilTotalBytesLimit()) {
+    const uint32_t tag = pb_stream.ReadTag();
+    const int field_number = WireFormatLite::GetTagFieldNumber(tag);
+    switch (field_number) {
+      case pb::FlightData::kFlightDescriptorFieldNumber: {
+        pb::FlightDescriptor pb_descriptor;
+        if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) {
+          return Status(StatusCode::INTERNAL, "Unable to parse 
FlightDescriptor");
+        }
+      } break;
+      case pb::FlightData::kDataHeaderFieldNumber: {
+        if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) {
+          return Status(StatusCode::INTERNAL, "Unable to read FlightData 
metadata");
+        }
+      } break;
+      case pb::FlightData::kDataBodyFieldNumber: {
+        if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) {
+          return Status(StatusCode::INTERNAL, "Unable to read FlightData 
body");
+        }
+      } break;
+      default:
+        DCHECK(false) << "cannot happen";
+    }
+  }
+  buffer->Clear();
+
+  // TODO(wesm): Where and when should we verify that the FlightData is not
+  // malformed or missing components?
+
+  return Status::OK;
+}
+
+}  // namespace grpc
diff --git a/cpp/src/arrow/flight/serialization-internal.h 
b/cpp/src/arrow/flight/serialization-internal.h
index 06cdcdf..19c8592 100644
--- a/cpp/src/arrow/flight/serialization-internal.h
+++ b/cpp/src/arrow/flight/serialization-internal.h
@@ -20,6 +20,9 @@
 
 #pragma once
 
+// Enable gRPC customizations
+#include "arrow/flight/protocol-internal.h"
+
 #include <limits>
 #include <memory>
 
@@ -29,14 +32,13 @@
 #include "google/protobuf/wire_format_lite.h"
 #include "grpc/byte_buffer_reader.h"
 #include "grpcpp/grpcpp.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
 
 #include "arrow/ipc/writer.h"
 #include "arrow/record_batch.h"
 #include "arrow/status.h"
 #include "arrow/util/logging.h"
 
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
 #include "arrow/flight/internal.h"
 #include "arrow/flight/types.h"
 
@@ -70,71 +72,13 @@ using google::protobuf::io::CodedOutputStream;
 bool ReadBytesZeroCopy(const std::shared_ptr<arrow::Buffer>& source_data,
                        CodedInputStream* input, 
std::shared_ptr<arrow::Buffer>* out);
 
-// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow
-// consumers with zero-copy
-class GrpcBuffer : public arrow::MutableBuffer {
- public:
-  GrpcBuffer(grpc_slice slice, bool incref)
-      : MutableBuffer(GRPC_SLICE_START_PTR(slice),
-                      static_cast<int64_t>(GRPC_SLICE_LENGTH(slice))),
-        slice_(incref ? grpc_slice_ref(slice) : slice) {}
-
-  ~GrpcBuffer() override {
-    // Decref slice
-    grpc_slice_unref(slice_);
-  }
-
-  static arrow::Status Wrap(grpc::ByteBuffer* cpp_buf,
-                            std::shared_ptr<arrow::Buffer>* out) {
-    // These types are guaranteed by static assertions in gRPC to have the same
-    // in-memory representation
-
-    auto buffer = *reinterpret_cast<grpc_byte_buffer**>(cpp_buf);
-
-    // This part below is based on the Flatbuffers gRPC SerializationTraits in
-    // flatbuffers/grpc.h
-
-    // Check if this is a single uncompressed slice.
-    if ((buffer->type == GRPC_BB_RAW) &&
-        (buffer->data.raw.compression == GRPC_COMPRESS_NONE) &&
-        (buffer->data.raw.slice_buffer.count == 1)) {
-      // If it is, then we can reference the `grpc_slice` directly.
-      grpc_slice slice = buffer->data.raw.slice_buffer.slices[0];
-
-      // Increment reference count so this memory remains valid
-      *out = std::make_shared<GrpcBuffer>(slice, true);
-    } else {
-      // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read
-      // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives
-      // us back a new slice with the refcount already incremented.
-      grpc_byte_buffer_reader reader;
-      if (!grpc_byte_buffer_reader_init(&reader, buffer)) {
-        return arrow::Status::IOError("Internal gRPC error reading from 
ByteBuffer");
-      }
-      grpc_slice slice = grpc_byte_buffer_reader_readall(&reader);
-      grpc_byte_buffer_reader_destroy(&reader);
-
-      // Steal the slice reference
-      *out = std::make_shared<GrpcBuffer>(slice, false);
-    }
-
-    return arrow::Status::OK();
-  }
-
- private:
-  grpc_slice slice_;
-};
-
 }  // namespace internal
-
 }  // namespace flight
 }  // namespace arrow
 
 namespace grpc {
 
 using arrow::flight::FlightData;
-using arrow::flight::internal::GrpcBuffer;
-using arrow::flight::internal::ReadBytesZeroCopy;
 
 using google::protobuf::internal::WireFormatLite;
 using google::protobuf::io::ArrayOutputStream;
@@ -158,163 +102,11 @@ inline arrow::Status FailSerialization(arrow::Status 
status) {
   return status;
 }
 
-// Read internal::FlightData from grpc::ByteBuffer containing FlightData
-// protobuf without copying
-template <>
-class SerializationTraits<FlightData> {
- public:
-  static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* 
own_buffer) {
-    return FailSerialization(Status(
-        StatusCode::UNIMPLEMENTED, "internal::FlightData serialization not 
implemented"));
-  }
-
-  static Status Deserialize(ByteBuffer* buffer, FlightData* out) {
-    if (!buffer) {
-      return FailSerialization(Status(StatusCode::INTERNAL, "No payload"));
-    }
-
-    std::shared_ptr<arrow::Buffer> wrapped_buffer;
-    GRPC_RETURN_NOT_OK(FailSerialization(GrpcBuffer::Wrap(buffer, 
&wrapped_buffer)));
-
-    auto buffer_length = static_cast<int>(wrapped_buffer->size());
-    CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length);
-
-    // TODO(wesm): The 2-parameter version of this function is deprecated
-    pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */);
-
-    // This is the bytes remaining when using CodedInputStream like this
-    while (pb_stream.BytesUntilTotalBytesLimit()) {
-      const uint32_t tag = pb_stream.ReadTag();
-      const int field_number = WireFormatLite::GetTagFieldNumber(tag);
-      switch (field_number) {
-        case pb::FlightData::kFlightDescriptorFieldNumber: {
-          pb::FlightDescriptor pb_descriptor;
-          if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) {
-            return FailSerialization(
-                Status(StatusCode::INTERNAL, "Unable to parse 
FlightDescriptor"));
-          }
-        } break;
-        case pb::FlightData::kDataHeaderFieldNumber: {
-          if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) {
-            return FailSerialization(
-                Status(StatusCode::INTERNAL, "Unable to read FlightData 
metadata"));
-          }
-        } break;
-        case pb::FlightData::kDataBodyFieldNumber: {
-          if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) {
-            return FailSerialization(
-                Status(StatusCode::INTERNAL, "Unable to read FlightData 
body"));
-          }
-        } break;
-        default:
-          DCHECK(false) << "cannot happen";
-      }
-    }
-    buffer->Clear();
-
-    // TODO(wesm): Where and when should we verify that the FlightData is not
-    // malformed or missing components?
-
-    return Status::OK;
-  }
-};
-
 // Write FlightData to a grpc::ByteBuffer without extra copying
-template <>
-class SerializationTraits<IpcPayload> {
- public:
-  static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) {
-    return FailSerialization(grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
-                                          "IpcPayload deserialization not 
implemented"));
-  }
-
-  static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out,
-                                bool* own_buffer) {
-    size_t total_size = 0;
-
-    DCHECK_LT(msg.metadata->size(), kInt32Max);
-    const int32_t metadata_size = static_cast<int32_t>(msg.metadata->size());
-
-    // 1 byte for metadata tag
-    total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size);
-
-    int64_t body_size = 0;
-    for (const auto& buffer : msg.body_buffers) {
-      // Buffer may be null when the row length is zero, or when all
-      // entries are invalid.
-      if (!buffer) continue;
-
-      body_size += buffer->size();
-
-      const int64_t remainder = buffer->size() % 8;
-      if (remainder) {
-        body_size += 8 - remainder;
-      }
-    }
-
-    // 2 bytes for body tag
-    // Only written when there are body buffers
-    if (msg.body_length > 0) {
-      total_size +=
-          2 + 
WireFormatLite::LengthDelimitedSize(static_cast<size_t>(body_size));
-    }
-
-    // TODO(wesm): messages over 2GB unlikely to be yet supported
-    if (total_size > kInt32Max) {
-      return FailSerialization(
-          grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
-                       "Cannot send record batches exceeding 2GB yet"));
-    }
-
-    // Allocate slice, assign to output buffer
-    grpc::Slice slice(total_size);
-
-    // XXX(wesm): for debugging
-    // std::cout << "Writing record batch with total size " << total_size << 
std::endl;
+Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* 
own_buffer);
 
-    ArrayOutputStream writer(const_cast<uint8_t*>(slice.begin()),
-                             static_cast<int>(slice.size()));
-    CodedOutputStream pb_stream(&writer);
-
-    // Write header
-    WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
-                             WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&pb_stream);
-    pb_stream.WriteVarint32(metadata_size);
-    pb_stream.WriteRawMaybeAliased(msg.metadata->data(),
-                                   static_cast<int>(msg.metadata->size()));
-
-    // Don't write tag if there are no body buffers
-    if (msg.body_length > 0) {
-      // Write body
-      WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
-                               WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&pb_stream);
-      pb_stream.WriteVarint32(static_cast<uint32_t>(body_size));
-
-      constexpr uint8_t kPaddingBytes[8] = {0};
-
-      for (const auto& buffer : msg.body_buffers) {
-        // Buffer may be null when the row length is zero, or when all
-        // entries are invalid.
-        if (!buffer) continue;
-
-        pb_stream.WriteRawMaybeAliased(buffer->data(), 
static_cast<int>(buffer->size()));
-
-        // Write padding if not multiple of 8
-        const int remainder = static_cast<int>(buffer->size() % 8);
-        if (remainder) {
-          pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder);
-        }
-      }
-    }
-
-    DCHECK_EQ(static_cast<int>(total_size), pb_stream.ByteCount());
-
-    // Hand off the slice to the returned ByteBuffer
-    grpc::ByteBuffer tmp(&slice, 1);
-    out->Swap(&tmp);
-    *own_buffer = true;
-    return grpc::Status::OK;
-  }
-};
+// Read internal::FlightData from grpc::ByteBuffer containing FlightData
+// protobuf without copying
+Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out);
 
 }  // namespace grpc
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index ac5b535..2fef93d 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -16,6 +16,7 @@
 // under the License.
 
 #include "arrow/flight/server.h"
+#include "arrow/flight/protocol-internal.h"
 
 #include <cstdint>
 #include <memory>
@@ -29,8 +30,6 @@
 #include "arrow/status.h"
 #include "arrow/util/logging.h"
 
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
 #include "arrow/flight/internal.h"
 #include "arrow/flight/serialization-internal.h"
 #include "arrow/flight/types.h"
@@ -73,13 +72,9 @@ class FlightMessageReaderImpl : public FlightMessageReader {
       return Status::OK();
     }
 
-    // XXX this cast is undefined behavior
-    auto custom_reader = 
reinterpret_cast<grpc::ServerReader<FlightData>*>(reader_);
-
     FlightData data;
-    // Explicitly specify the override to invoke - otherwise compiler
-    // may invoke through vtable (not updated by reinterpret_cast)
-    if (custom_reader->grpc::ServerReader<FlightData>::Read(&data)) {
+    // Pretend to be pb::FlightData and intercept in SerializationTraits
+    if (reader_->Read(reinterpret_cast<pb::FlightData*>(&data))) {
       std::unique_ptr<ipc::Message> message;
 
       // Validate IPC message
@@ -184,26 +179,23 @@ class FlightServiceImpl : public FlightService::Service {
     std::unique_ptr<FlightDataStream> data_stream;
     GRPC_RETURN_NOT_OK(server_->DoGet(ticket, &data_stream));
 
-    // Requires ServerWriter customization in grpc_customizations.h
-    // XXX this cast is undefined behavior
-    auto custom_writer = reinterpret_cast<ServerWriter<IpcPayload>*>(writer);
-
     // Write the schema as the first message in the stream
     IpcPayload schema_payload;
     MemoryPool* pool = default_memory_pool();
     ipc::DictionaryMemo dictionary_memo;
     GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload(
         *data_stream->schema(), pool, &dictionary_memo, &schema_payload));
-    // Explicitly specify the override to invoke - otherwise compiler
-    // may invoke through vtable (not updated by reinterpret_cast)
-    custom_writer->grpc::ServerWriter<IpcPayload>::Write(schema_payload,
-                                                         grpc::WriteOptions());
+
+    // Pretend to be pb::FlightData, we cast back to IpcPayload in 
SerializationTraits
+    writer->Write(*reinterpret_cast<const pb::FlightData*>(&schema_payload),
+                  grpc::WriteOptions());
 
     while (true) {
       IpcPayload payload;
       GRPC_RETURN_NOT_OK(data_stream->Next(&payload));
       if (payload.metadata == nullptr ||
-          !custom_writer->Write(payload, grpc::WriteOptions())) {
+          !writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+                         grpc::WriteOptions())) {
         // No more messages to write, or connection terminated for some other
         // reason
         break;

Reply via email to