This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 37cb59240b GH-36952: [C++][FlightRPC][Python] Add methods to send headers (#36956)
37cb59240b is described below

commit 37cb59240b1fa4c5b8e596afdaebf9435c415cec
Author: David Li <[email protected]>
AuthorDate: Mon Jul 31 16:33:28 2023 -0400

    GH-36952: [C++][FlightRPC][Python] Add methods to send headers (#36956)
    
    
    
    ### Rationale for this change
    
    Sending headers/trailers is required for services, but previously there was no way to do this.
    
    ### What changes are included in this PR?
    
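    Add new methods to directly send headers/trailers: `ServerCallContext::AddHeader`/`AddTrailer` in C++ and `ServerCallContext.add_header`/`add_trailer` in Python.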
    
    ### Are these changes tested?
    
    Yes
    
    ### Are there any user-facing changes?
    
    Yes (new APIs)
    
    * Closes: #36952
    
    Authored-by: David Li <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 cpp/src/arrow/flight/client_middleware.h           |  5 ++
 cpp/src/arrow/flight/server.h                      |  9 +++
 cpp/src/arrow/flight/test_definitions.cc           | 87 ++++++++++++++++++++--
 cpp/src/arrow/flight/test_definitions.h            |  9 ++-
 cpp/src/arrow/flight/transport/grpc/grpc_client.cc | 18 +----
 cpp/src/arrow/flight/transport/grpc/grpc_server.cc |  9 +++
 .../transport/ucx/flight_transport_ucx_test.cc     |  2 +
 cpp/src/arrow/flight/transport/ucx/ucx_server.cc   |  3 +
 python/pyarrow/_flight.pyx                         |  8 ++
 python/pyarrow/includes/libarrow_flight.pxd        |  2 +
 python/pyarrow/tests/test_flight.py                | 44 ++++++++++-
 11 files changed, 174 insertions(+), 22 deletions(-)

diff --git a/cpp/src/arrow/flight/client_middleware.h 
b/cpp/src/arrow/flight/client_middleware.h
index 5b67e784b9..8e3126553a 100644
--- a/cpp/src/arrow/flight/client_middleware.h
+++ b/cpp/src/arrow/flight/client_middleware.h
@@ -42,6 +42,11 @@ class ARROW_FLIGHT_EXPORT ClientMiddleware {
   virtual void SendingHeaders(AddCallHeaders* outgoing_headers) = 0;
 
   /// \brief A callback when headers are received from the server.
+  ///
+  /// This may be called more than once, since servers send both
+  /// headers and trailers.  Some implementations (e.g. gRPC-Java, and
+  /// hence Arrow Flight in Java) may consolidate headers into
+  /// trailers if the RPC errored.
   virtual void ReceivedHeaders(const CallHeaders& incoming_headers) = 0;
 
   /// \brief A callback after the call has completed.
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index 76f1a31706..049c6cee3f 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -122,6 +122,15 @@ class ARROW_FLIGHT_EXPORT ServerCallContext {
   virtual const std::string& peer_identity() const = 0;
   /// \brief The peer address (not validated)
   virtual const std::string& peer() const = 0;
+  /// \brief Add a response header.  This is only valid before the server
+  /// starts sending the response; generally this isn't an issue unless you
+  /// are implementing FlightDataStream, ResultStream, or similar interfaces
+  /// yourself, or during a DoExchange or DoPut.
+  virtual void AddHeader(const std::string& key, const std::string& value) 
const = 0;
+  /// \brief Add a response trailer.  This is only valid before the server
+  /// sends the final status; generally this isn't an issue unless your RPC
+  /// handler launches a thread or similar.
+  virtual void AddTrailer(const std::string& key, const std::string& value) 
const = 0;
   /// \brief Look up a middleware by key. Do not maintain a reference
   /// to the object beyond the request body.
   /// \return The middleware, or nullptr if not found.
diff --git a/cpp/src/arrow/flight/test_definitions.cc 
b/cpp/src/arrow/flight/test_definitions.cc
index 507c5ef404..4e13738004 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -18,17 +18,22 @@
 #include "arrow/flight/test_definitions.h"
 
 #include <chrono>
+#include <memory>
+#include <mutex>
 
 #include "arrow/array/array_base.h"
 #include "arrow/array/array_dict.h"
 #include "arrow/array/util.h"
 #include "arrow/flight/api.h"
+#include "arrow/flight/client_middleware.h"
 #include "arrow/flight/test_util.h"
 #include "arrow/table.h"
 #include "arrow/testing/generator.h"
+#include "arrow/testing/gtest_util.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/config.h"
 #include "arrow/util/logging.h"
+#include "gmock/gmock.h"
 
 #if defined(ARROW_CUDA)
 #include "arrow/gpu/cuda_api.h"
@@ -1438,20 +1443,26 @@ class ErrorHandlingTestServer : public FlightServerBase 
{
  public:
   Status GetFlightInfo(const ServerCallContext& context, const 
FlightDescriptor& request,
                        std::unique_ptr<FlightInfo>* info) override {
-    if (request.path.size() >= 2) {
+    if (request.path.size() == 1 && request.path[0] == "metadata") {
+      context.AddHeader("x-header", "header-value");
+      context.AddHeader("x-header-bin", "header\x01value");
+      context.AddTrailer("x-trailer", "trailer-value");
+      context.AddTrailer("x-trailer-bin", "trailer\x01value");
+      return Status::Invalid("Expected");
+    } else if (request.path.size() >= 2) {
       const int raw_code = std::atoi(request.path[0].c_str());
       ARROW_ASSIGN_OR_RAISE(StatusCode code, TryConvertStatusCode(raw_code));
 
       if (request.path.size() == 2) {
-        return Status(code, request.path[1]);
+        return {code, request.path[1]};
       } else if (request.path.size() == 3) {
-        return Status(code, request.path[1], 
std::make_shared<TestStatusDetail>());
+        return {code, request.path[1], std::make_shared<TestStatusDetail>()};
       } else {
         const int raw_code = std::atoi(request.path[2].c_str());
         ARROW_ASSIGN_OR_RAISE(FlightStatusCode flight_code,
                               TryConvertFlightStatusCode(raw_code));
-        return Status(code, request.path[1],
-                      std::make_shared<FlightStatusDetail>(flight_code, 
request.path[3]));
+        return {code, request.path[1],
+                std::make_shared<FlightStatusDetail>(flight_code, 
request.path[3])};
       }
     }
     return Status::NotImplemented("NYI");
@@ -1469,20 +1480,70 @@ class ErrorHandlingTestServer : public FlightServerBase 
{
     return MakeFlightError(FlightStatusCode::Unauthorized, "Unauthorized", 
"extra info");
   }
 };
+
+class MetadataRecordingClientMiddleware : public ClientMiddleware {
+ public:
+  explicit MetadataRecordingClientMiddleware(
+      std::mutex& mutex, std::vector<std::pair<std::string, std::string>>& 
headers)
+      : mutex_(mutex), headers_(headers) {}
+  void SendingHeaders(AddCallHeaders*) override {}
+  void ReceivedHeaders(const CallHeaders& incoming_headers) override {
+    std::lock_guard<std::mutex> guard(mutex_);
+    for (const auto& [key, value] : incoming_headers) {
+      headers_.emplace_back(key, value);
+    }
+  }
+  void CallCompleted(const Status&) override {}
+
+ private:
+  std::mutex& mutex_;
+  std::vector<std::pair<std::string, std::string>>& headers_;
+};
+
+class MetadataRecordingClientMiddlewareFactory : public 
ClientMiddlewareFactory {
+ public:
+  void StartCall(const CallInfo&,
+                 std::unique_ptr<ClientMiddleware>* middleware) override {
+    *middleware = std::make_unique<MetadataRecordingClientMiddleware>(mutex_, 
headers_);
+  }
+
+  std::vector<std::pair<std::string, std::string>> GetHeaders() const {
+    std::lock_guard<std::mutex> guard(mutex_);
+    // Take copy
+    return headers_;
+  }
+
+ private:
+  mutable std::mutex mutex_;
+  std::vector<std::pair<std::string, std::string>> headers_;
+};
 }  // namespace
 
+struct ErrorHandlingTest::Impl {
+  std::shared_ptr<MetadataRecordingClientMiddlewareFactory> metadata =
+      std::make_shared<MetadataRecordingClientMiddlewareFactory>();
+};
+
 void ErrorHandlingTest::SetUpTest() {
+  impl_ = std::make_shared<Impl>();
   ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), 
"127.0.0.1", 0));
   ASSERT_OK(MakeServer<ErrorHandlingTestServer>(
       location, &server_, &client_,
       [](FlightServerOptions* options) { return Status::OK(); },
-      [](FlightClientOptions* options) { return Status::OK(); }));
+      [&](FlightClientOptions* options) {
+        options->middleware.emplace_back(impl_->metadata);
+        return Status::OK();
+      }));
 }
 void ErrorHandlingTest::TearDownTest() {
   ASSERT_OK(client_->Close());
   ASSERT_OK(server_->Shutdown());
 }
 
+std::vector<std::pair<std::string, std::string>> 
ErrorHandlingTest::GetHeaders() {
+  return impl_->metadata->GetHeaders();
+}
+
 void ErrorHandlingTest::TestGetFlightInfo() {
   std::unique_ptr<FlightInfo> info;
   for (const auto code : kStatusCodes) {
@@ -1518,6 +1579,20 @@ void ErrorHandlingTest::TestGetFlightInfo() {
   }
 }
 
+void ErrorHandlingTest::TestGetFlightInfoMetadata() {
+  auto descr = FlightDescriptor::Path({"metadata"});
+  EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Expected"),
+                                  client_->GetFlightInfo(descr));
+  // This is janky because we don't/can't expose grpc::CallContext.
+  // See https://github.com/apache/arrow/issues/34607
+  ASSERT_THAT(GetHeaders(), ::testing::IsSupersetOf({
+                                std::make_pair("x-header", "header-value"),
+                                std::make_pair("x-header-bin", 
"header\x01value"),
+                                std::make_pair("x-trailer", "trailer-value"),
+                                std::make_pair("x-trailer-bin", 
"trailer\x01value"),
+                            }));
+}
+
 void CheckErrorDetail(const Status& status) {
   auto detail = FlightStatusDetail::UnwrapStatus(status);
   ASSERT_NE(detail, nullptr) << status.ToString();
diff --git a/cpp/src/arrow/flight/test_definitions.h 
b/cpp/src/arrow/flight/test_definitions.h
index 7a7f905f3e..c73bc264b4 100644
--- a/cpp/src/arrow/flight/test_definitions.h
+++ b/cpp/src/arrow/flight/test_definitions.h
@@ -265,10 +265,16 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public 
FlightTest {
 
   // Test methods
   void TestGetFlightInfo();
+  void TestGetFlightInfoMetadata();
   void TestDoPut();
   void TestDoExchange();
 
- private:
+ protected:
+  struct Impl;
+
+  std::vector<std::pair<std::string, std::string>> GetHeaders();
+
+  std::shared_ptr<Impl> impl_;
   std::unique_ptr<FlightClient> client_;
   std::unique_ptr<FlightServerBase> server_;
 };
@@ -277,6 +283,7 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public 
FlightTest {
   static_assert(std::is_base_of<ErrorHandlingTest, FIXTURE>::value,            
   \
                 ARROW_STRINGIFY(FIXTURE) " must inherit from 
ErrorHandlingTest"); \
   TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }                  
   \
+  TEST_F(FIXTURE, TestGetFlightInfoMetadata) { TestGetFlightInfoMetadata(); }  
   \
   TEST_F(FIXTURE, TestDoPut) { TestDoPut(); }                                  
   \
   TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }
 
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc 
b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
index 89f0886383..9b40015f9f 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
@@ -107,9 +107,9 @@ class GrpcClientInterceptorAdapter : public 
::grpc::experimental::Interceptor {
  public:
   explicit GrpcClientInterceptorAdapter(
       std::vector<std::unique_ptr<ClientMiddleware>> middleware)
-      : middleware_(std::move(middleware)), received_headers_(false) {}
+      : middleware_(std::move(middleware)) {}
 
-  void Intercept(::grpc::experimental::InterceptorBatchMethods* methods) {
+  void Intercept(::grpc::experimental::InterceptorBatchMethods* methods) 
override {
     using InterceptionHookPoints = 
::grpc::experimental::InterceptionHookPoints;
     if (methods->QueryInterceptionHookPoint(
             InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
@@ -142,10 +142,6 @@ class GrpcClientInterceptorAdapter : public 
::grpc::experimental::Interceptor {
  private:
   void ReceivedHeaders(
       const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
-    if (received_headers_) {
-      return;
-    }
-    received_headers_ = true;
     CallHeaders headers;
     for (const auto& entry : metadata) {
       headers.insert({std::string_view(entry.first.data(), 
entry.first.length()),
@@ -157,20 +153,14 @@ class GrpcClientInterceptorAdapter : public 
::grpc::experimental::Interceptor {
   }
 
   std::vector<std::unique_ptr<ClientMiddleware>> middleware_;
-  // When communicating with a gRPC-Java server, the server may not
-  // send back headers if the call fails right away. Instead, the
-  // headers will be consolidated into the trailers. We don't want to
-  // call the client middleware callback twice, so instead track
-  // whether we saw headers - if not, then we need to check trailers.
-  bool received_headers_;
 };
 
 class GrpcClientInterceptorAdapterFactory
     : public ::grpc::experimental::ClientInterceptorFactoryInterface {
  public:
-  GrpcClientInterceptorAdapterFactory(
+  explicit GrpcClientInterceptorAdapterFactory(
       std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware)
-      : middleware_(middleware) {}
+      : middleware_(std::move(middleware)) {}
 
   ::grpc::experimental::Interceptor* CreateClientInterceptor(
       ::grpc::experimental::ClientRpcInfo* info) override {
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc 
b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
index 2c7a1d5e99..50d4ffe002 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
@@ -111,6 +111,7 @@ class GrpcServerAuthSender : public ServerAuthSender {
 };
 
 class GrpcServerCallContext : public ServerCallContext {
+ public:
   explicit GrpcServerCallContext(::grpc::ServerContext* context)
       : context_(context), peer_(context_->peer()) {
     for (const auto& entry : context->client_metadata()) {
@@ -143,6 +144,14 @@ class GrpcServerCallContext : public ServerCallContext {
     return ToGrpcStatus(status, context_);
   }
 
+  void AddHeader(const std::string& key, const std::string& value) const 
override {
+    context_->AddInitialMetadata(key, value);
+  }
+
+  void AddTrailer(const std::string& key, const std::string& value) const 
override {
+    context_->AddTrailingMetadata(key, value);
+  }
+
   ServerMiddleware* GetMiddleware(const std::string& key) const override {
     const auto& instance = middleware_map_.find(key);
     if (instance == middleware_map_.end()) {
diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc 
b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
index 3ac02bf718..c3481d834f 100644
--- a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
+++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
@@ -103,6 +103,8 @@ class UcxErrorHandlingTest : public ErrorHandlingTest, 
public ::testing::Test {
   std::string transport() const override { return "ucx"; }
   void SetUp() override { SetUpTest(); }
   void TearDown() override { TearDownTest(); }
+
+  void TestGetFlightInfoMetadata() { GTEST_SKIP() << "Middleware not 
implemented"; }
 };
 ARROW_FLIGHT_TEST_ERROR_HANDLING(UcxErrorHandlingTest);
 
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc 
b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
index 4a573d7429..8bbac34705 100644
--- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
@@ -72,6 +72,9 @@ class UcxServerCallContext : public flight::ServerCallContext 
{
  public:
   const std::string& peer_identity() const override { return peer_; }
   const std::string& peer() const override { return peer_; }
+  // Not supported
+  void AddHeader(const std::string& key, const std::string& value) const 
override {}
+  void AddTrailer(const std::string& key, const std::string& value) const 
override {}
   ServerMiddleware* GetMiddleware(const std::string& key) const override {
     return nullptr;
   }
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index c9f5526754..0572ed77b4 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -1756,6 +1756,14 @@ cdef class ServerCallContext(_Weakrefable):
         """Check if the current RPC call has been canceled by the client."""
         return self.context.is_cancelled()
 
+    def add_header(self, key, value):
+        """Add a response header."""
+        self.context.AddHeader(tobytes(key), tobytes(value))
+
+    def add_trailer(self, key, value):
+        """Add a response trailer."""
+        self.context.AddTrailer(tobytes(key), tobytes(value))
+
     def get_middleware(self, key):
         """
         Get a middleware instance by key.
diff --git a/python/pyarrow/includes/libarrow_flight.pxd 
b/python/pyarrow/includes/libarrow_flight.pxd
index 34ba809438..624904ed77 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -257,6 +257,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" 
nogil:
         c_string& peer_identity()
         c_string& peer()
         c_bool is_cancelled()
+        void AddHeader(const c_string& key, const c_string& value)
+        void AddTrailer(const c_string& key, const c_string& value)
         CServerMiddleware* GetMiddleware(const c_string& key)
 
     cdef cppclass CTimeoutDuration" arrow::flight::TimeoutDuration":
diff --git a/python/pyarrow/tests/test_flight.py 
b/python/pyarrow/tests/test_flight.py
index 930523b9f5..6c1c582dce 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -833,7 +833,7 @@ class MultiHeaderClientMiddleware(ClientMiddleware):
     def received_headers(self, headers):
         # Let the test code know what the last set of headers we
         # received were.
-        self.factory.last_headers = headers
+        self.factory.last_headers.update(headers)
 
 
 class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory):
@@ -2323,3 +2323,45 @@ def test_do_put_does_not_crash_when_schema_is_none():
     with pytest.raises(TypeError, match=msg):
         client.do_put(flight.FlightDescriptor.for_command('foo'),
                       schema=None)
+
+
+def test_headers_trailers():
+    """Ensure that server-sent headers/trailers make it through."""
+
+    class HeadersTrailersFlightServer(FlightServerBase):
+        def get_flight_info(self, context, descriptor):
+            context.add_header("x-header", "header-value")
+            context.add_header("x-header-bin", "header\x01value")
+            context.add_trailer("x-trailer", "trailer-value")
+            context.add_trailer("x-trailer-bin", "trailer\x01value")
+            return flight.FlightInfo(
+                pa.schema([]),
+                descriptor,
+                [],
+                -1, -1
+            )
+
+    class HeadersTrailersMiddlewareFactory(ClientMiddlewareFactory):
+        def __init__(self):
+            self.headers = []
+
+        def start_call(self, info):
+            return HeadersTrailersMiddleware(self)
+
+    class HeadersTrailersMiddleware(ClientMiddleware):
+        def __init__(self, factory):
+            self.factory = factory
+
+        def received_headers(self, headers):
+            for key, values in headers.items():
+                for value in values:
+                    self.factory.headers.append((key, value))
+
+    factory = HeadersTrailersMiddlewareFactory()
+    with HeadersTrailersFlightServer() as server, \
+            FlightClient(("localhost", server.port), middleware=[factory]) as 
client:
+        client.get_flight_info(flight.FlightDescriptor.for_path(""))
+        assert ("x-header", "header-value") in factory.headers
+        assert ("x-header-bin", b"header\x01value") in factory.headers
+        assert ("x-trailer", "trailer-value") in factory.headers
+        assert ("x-trailer-bin", b"trailer\x01value") in factory.headers

Reply via email to