This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 37cb59240b GH-36952: [C++][FlightRPC][Python] Add methods to send headers (#36956)
37cb59240b is described below
commit 37cb59240b1fa4c5b8e596afdaebf9435c415cec
Author: David Li <[email protected]>
AuthorDate: Mon Jul 31 16:33:28 2023 -0400
GH-36952: [C++][FlightRPC][Python] Add methods to send headers (#36956)
### Rationale for this change
Some services need to send custom response headers/trailers, but Flight
server handlers previously had no API for doing so.
### What changes are included in this PR?
Add new methods to ServerCallContext (AddHeader/AddTrailer in C++, add_header/add_trailer in Python) to directly send response headers/trailers.
### Are these changes tested?
Yes
### Are there any user-facing changes?
Yes (new APIs)
* Closes: #36952
Authored-by: David Li <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
---
cpp/src/arrow/flight/client_middleware.h | 5 ++
cpp/src/arrow/flight/server.h | 9 +++
cpp/src/arrow/flight/test_definitions.cc | 87 ++++++++++++++++++++--
cpp/src/arrow/flight/test_definitions.h | 9 ++-
cpp/src/arrow/flight/transport/grpc/grpc_client.cc | 18 +----
cpp/src/arrow/flight/transport/grpc/grpc_server.cc | 9 +++
.../transport/ucx/flight_transport_ucx_test.cc | 2 +
cpp/src/arrow/flight/transport/ucx/ucx_server.cc | 3 +
python/pyarrow/_flight.pyx | 8 ++
python/pyarrow/includes/libarrow_flight.pxd | 2 +
python/pyarrow/tests/test_flight.py | 44 ++++++++++-
11 files changed, 174 insertions(+), 22 deletions(-)
diff --git a/cpp/src/arrow/flight/client_middleware.h
b/cpp/src/arrow/flight/client_middleware.h
index 5b67e784b9..8e3126553a 100644
--- a/cpp/src/arrow/flight/client_middleware.h
+++ b/cpp/src/arrow/flight/client_middleware.h
@@ -42,6 +42,11 @@ class ARROW_FLIGHT_EXPORT ClientMiddleware {
virtual void SendingHeaders(AddCallHeaders* outgoing_headers) = 0;
/// \brief A callback when headers are received from the server.
+ ///
+ /// This may be called more than once, since servers send both
+ /// headers and trailers. Some implementations (e.g. gRPC-Java, and
+ /// hence Arrow Flight in Java) may consolidate headers into
+ /// trailers if the RPC errored.
virtual void ReceivedHeaders(const CallHeaders& incoming_headers) = 0;
/// \brief A callback after the call has completed.
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index 76f1a31706..049c6cee3f 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -122,6 +122,15 @@ class ARROW_FLIGHT_EXPORT ServerCallContext {
virtual const std::string& peer_identity() const = 0;
/// \brief The peer address (not validated)
virtual const std::string& peer() const = 0;
+ /// \brief Add a response header. This is only valid before the server
+ /// starts sending the response; generally this isn't an issue unless you
+ /// are implementing FlightDataStream, ResultStream, or similar interfaces
+ /// yourself, or during a DoExchange or DoPut.
+ virtual void AddHeader(const std::string& key, const std::string& value)
const = 0;
+ /// \brief Add a response trailer. This is only valid before the server
+ /// sends the final status; generally this isn't an issue unless your RPC
+ /// handler launches a thread or similar.
+ virtual void AddTrailer(const std::string& key, const std::string& value)
const = 0;
/// \brief Look up a middleware by key. Do not maintain a reference
/// to the object beyond the request body.
/// \return The middleware, or nullptr if not found.
diff --git a/cpp/src/arrow/flight/test_definitions.cc
b/cpp/src/arrow/flight/test_definitions.cc
index 507c5ef404..4e13738004 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -18,17 +18,22 @@
#include "arrow/flight/test_definitions.h"
#include <chrono>
+#include <memory>
+#include <mutex>
#include "arrow/array/array_base.h"
#include "arrow/array/array_dict.h"
#include "arrow/array/util.h"
#include "arrow/flight/api.h"
+#include "arrow/flight/client_middleware.h"
#include "arrow/flight/test_util.h"
#include "arrow/table.h"
#include "arrow/testing/generator.h"
+#include "arrow/testing/gtest_util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/config.h"
#include "arrow/util/logging.h"
+#include "gmock/gmock.h"
#if defined(ARROW_CUDA)
#include "arrow/gpu/cuda_api.h"
@@ -1438,20 +1443,26 @@ class ErrorHandlingTestServer : public FlightServerBase
{
public:
Status GetFlightInfo(const ServerCallContext& context, const
FlightDescriptor& request,
std::unique_ptr<FlightInfo>* info) override {
- if (request.path.size() >= 2) {
+ if (request.path.size() == 1 && request.path[0] == "metadata") {
+ context.AddHeader("x-header", "header-value");
+ context.AddHeader("x-header-bin", "header\x01value");
+ context.AddTrailer("x-trailer", "trailer-value");
+ context.AddTrailer("x-trailer-bin", "trailer\x01value");
+ return Status::Invalid("Expected");
+ } else if (request.path.size() >= 2) {
const int raw_code = std::atoi(request.path[0].c_str());
ARROW_ASSIGN_OR_RAISE(StatusCode code, TryConvertStatusCode(raw_code));
if (request.path.size() == 2) {
- return Status(code, request.path[1]);
+ return {code, request.path[1]};
} else if (request.path.size() == 3) {
- return Status(code, request.path[1],
std::make_shared<TestStatusDetail>());
+ return {code, request.path[1], std::make_shared<TestStatusDetail>()};
} else {
const int raw_code = std::atoi(request.path[2].c_str());
ARROW_ASSIGN_OR_RAISE(FlightStatusCode flight_code,
TryConvertFlightStatusCode(raw_code));
- return Status(code, request.path[1],
- std::make_shared<FlightStatusDetail>(flight_code,
request.path[3]));
+ return {code, request.path[1],
+ std::make_shared<FlightStatusDetail>(flight_code,
request.path[3])};
}
}
return Status::NotImplemented("NYI");
@@ -1469,20 +1480,70 @@ class ErrorHandlingTestServer : public FlightServerBase
{
return MakeFlightError(FlightStatusCode::Unauthorized, "Unauthorized",
"extra info");
}
};
+
+class MetadataRecordingClientMiddleware : public ClientMiddleware {
+ public:
+ explicit MetadataRecordingClientMiddleware(
+ std::mutex& mutex, std::vector<std::pair<std::string, std::string>>&
headers)
+ : mutex_(mutex), headers_(headers) {}
+ void SendingHeaders(AddCallHeaders*) override {}
+ void ReceivedHeaders(const CallHeaders& incoming_headers) override {
+ std::lock_guard<std::mutex> guard(mutex_);
+ for (const auto& [key, value] : incoming_headers) {
+ headers_.emplace_back(key, value);
+ }
+ }
+ void CallCompleted(const Status&) override {}
+
+ private:
+ std::mutex& mutex_;
+ std::vector<std::pair<std::string, std::string>>& headers_;
+};
+
+class MetadataRecordingClientMiddlewareFactory : public
ClientMiddlewareFactory {
+ public:
+ void StartCall(const CallInfo&,
+ std::unique_ptr<ClientMiddleware>* middleware) override {
+ *middleware = std::make_unique<MetadataRecordingClientMiddleware>(mutex_,
headers_);
+ }
+
+ std::vector<std::pair<std::string, std::string>> GetHeaders() const {
+ std::lock_guard<std::mutex> guard(mutex_);
+ // Take copy
+ return headers_;
+ }
+
+ private:
+ mutable std::mutex mutex_;
+ std::vector<std::pair<std::string, std::string>> headers_;
+};
} // namespace
+struct ErrorHandlingTest::Impl {
+ std::shared_ptr<MetadataRecordingClientMiddlewareFactory> metadata =
+ std::make_shared<MetadataRecordingClientMiddlewareFactory>();
+};
+
void ErrorHandlingTest::SetUpTest() {
+ impl_ = std::make_shared<Impl>();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(),
"127.0.0.1", 0));
ASSERT_OK(MakeServer<ErrorHandlingTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
- [](FlightClientOptions* options) { return Status::OK(); }));
+ [&](FlightClientOptions* options) {
+ options->middleware.emplace_back(impl_->metadata);
+ return Status::OK();
+ }));
}
void ErrorHandlingTest::TearDownTest() {
ASSERT_OK(client_->Close());
ASSERT_OK(server_->Shutdown());
}
+std::vector<std::pair<std::string, std::string>>
ErrorHandlingTest::GetHeaders() {
+ return impl_->metadata->GetHeaders();
+}
+
void ErrorHandlingTest::TestGetFlightInfo() {
std::unique_ptr<FlightInfo> info;
for (const auto code : kStatusCodes) {
@@ -1518,6 +1579,20 @@ void ErrorHandlingTest::TestGetFlightInfo() {
}
}
+void ErrorHandlingTest::TestGetFlightInfoMetadata() {
+ auto descr = FlightDescriptor::Path({"metadata"});
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Expected"),
+ client_->GetFlightInfo(descr));
+ // This is janky because we don't/can't expose grpc::CallContext.
+ // See https://github.com/apache/arrow/issues/34607
+ ASSERT_THAT(GetHeaders(), ::testing::IsSupersetOf({
+ std::make_pair("x-header", "header-value"),
+ std::make_pair("x-header-bin",
"header\x01value"),
+ std::make_pair("x-trailer", "trailer-value"),
+ std::make_pair("x-trailer-bin",
"trailer\x01value"),
+ }));
+}
+
void CheckErrorDetail(const Status& status) {
auto detail = FlightStatusDetail::UnwrapStatus(status);
ASSERT_NE(detail, nullptr) << status.ToString();
diff --git a/cpp/src/arrow/flight/test_definitions.h
b/cpp/src/arrow/flight/test_definitions.h
index 7a7f905f3e..c73bc264b4 100644
--- a/cpp/src/arrow/flight/test_definitions.h
+++ b/cpp/src/arrow/flight/test_definitions.h
@@ -265,10 +265,16 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public
FlightTest {
// Test methods
void TestGetFlightInfo();
+ void TestGetFlightInfoMetadata();
void TestDoPut();
void TestDoExchange();
- private:
+ protected:
+ struct Impl;
+
+ std::vector<std::pair<std::string, std::string>> GetHeaders();
+
+ std::shared_ptr<Impl> impl_;
std::unique_ptr<FlightClient> client_;
std::unique_ptr<FlightServerBase> server_;
};
@@ -277,6 +283,7 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public
FlightTest {
static_assert(std::is_base_of<ErrorHandlingTest, FIXTURE>::value,
\
ARROW_STRINGIFY(FIXTURE) " must inherit from
ErrorHandlingTest"); \
TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }
\
+ TEST_F(FIXTURE, TestGetFlightInfoMetadata) { TestGetFlightInfoMetadata(); }
\
TEST_F(FIXTURE, TestDoPut) { TestDoPut(); }
\
TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
index 89f0886383..9b40015f9f 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
@@ -107,9 +107,9 @@ class GrpcClientInterceptorAdapter : public
::grpc::experimental::Interceptor {
public:
explicit GrpcClientInterceptorAdapter(
std::vector<std::unique_ptr<ClientMiddleware>> middleware)
- : middleware_(std::move(middleware)), received_headers_(false) {}
+ : middleware_(std::move(middleware)) {}
- void Intercept(::grpc::experimental::InterceptorBatchMethods* methods) {
+ void Intercept(::grpc::experimental::InterceptorBatchMethods* methods)
override {
using InterceptionHookPoints =
::grpc::experimental::InterceptionHookPoints;
if (methods->QueryInterceptionHookPoint(
InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
@@ -142,10 +142,6 @@ class GrpcClientInterceptorAdapter : public
::grpc::experimental::Interceptor {
private:
void ReceivedHeaders(
const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
- if (received_headers_) {
- return;
- }
- received_headers_ = true;
CallHeaders headers;
for (const auto& entry : metadata) {
headers.insert({std::string_view(entry.first.data(),
entry.first.length()),
@@ -157,20 +153,14 @@ class GrpcClientInterceptorAdapter : public
::grpc::experimental::Interceptor {
}
std::vector<std::unique_ptr<ClientMiddleware>> middleware_;
- // When communicating with a gRPC-Java server, the server may not
- // send back headers if the call fails right away. Instead, the
- // headers will be consolidated into the trailers. We don't want to
- // call the client middleware callback twice, so instead track
- // whether we saw headers - if not, then we need to check trailers.
- bool received_headers_;
};
class GrpcClientInterceptorAdapterFactory
: public ::grpc::experimental::ClientInterceptorFactoryInterface {
public:
- GrpcClientInterceptorAdapterFactory(
+ explicit GrpcClientInterceptorAdapterFactory(
std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware)
- : middleware_(middleware) {}
+ : middleware_(std::move(middleware)) {}
::grpc::experimental::Interceptor* CreateClientInterceptor(
::grpc::experimental::ClientRpcInfo* info) override {
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
index 2c7a1d5e99..50d4ffe002 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
@@ -111,6 +111,7 @@ class GrpcServerAuthSender : public ServerAuthSender {
};
class GrpcServerCallContext : public ServerCallContext {
+ public:
explicit GrpcServerCallContext(::grpc::ServerContext* context)
: context_(context), peer_(context_->peer()) {
for (const auto& entry : context->client_metadata()) {
@@ -143,6 +144,14 @@ class GrpcServerCallContext : public ServerCallContext {
return ToGrpcStatus(status, context_);
}
+ void AddHeader(const std::string& key, const std::string& value) const
override {
+ context_->AddInitialMetadata(key, value);
+ }
+
+ void AddTrailer(const std::string& key, const std::string& value) const
override {
+ context_->AddTrailingMetadata(key, value);
+ }
+
ServerMiddleware* GetMiddleware(const std::string& key) const override {
const auto& instance = middleware_map_.find(key);
if (instance == middleware_map_.end()) {
diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
index 3ac02bf718..c3481d834f 100644
--- a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
+++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
@@ -103,6 +103,8 @@ class UcxErrorHandlingTest : public ErrorHandlingTest,
public ::testing::Test {
std::string transport() const override { return "ucx"; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
+
+ void TestGetFlightInfoMetadata() { GTEST_SKIP() << "Middleware not
implemented"; }
};
ARROW_FLIGHT_TEST_ERROR_HANDLING(UcxErrorHandlingTest);
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
index 4a573d7429..8bbac34705 100644
--- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
@@ -72,6 +72,9 @@ class UcxServerCallContext : public flight::ServerCallContext
{
public:
const std::string& peer_identity() const override { return peer_; }
const std::string& peer() const override { return peer_; }
+ // Not supported
+ void AddHeader(const std::string& key, const std::string& value) const
override {}
+ void AddTrailer(const std::string& key, const std::string& value) const
override {}
ServerMiddleware* GetMiddleware(const std::string& key) const override {
return nullptr;
}
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index c9f5526754..0572ed77b4 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -1756,6 +1756,14 @@ cdef class ServerCallContext(_Weakrefable):
"""Check if the current RPC call has been canceled by the client."""
return self.context.is_cancelled()
+ def add_header(self, key, value):
+ """Add a response header."""
+ self.context.AddHeader(tobytes(key), tobytes(value))
+
+ def add_trailer(self, key, value):
+ """Add a response trailer."""
+ self.context.AddTrailer(tobytes(key), tobytes(value))
+
def get_middleware(self, key):
"""
Get a middleware instance by key.
diff --git a/python/pyarrow/includes/libarrow_flight.pxd
b/python/pyarrow/includes/libarrow_flight.pxd
index 34ba809438..624904ed77 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -257,6 +257,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow"
nogil:
c_string& peer_identity()
c_string& peer()
c_bool is_cancelled()
+ void AddHeader(const c_string& key, const c_string& value)
+ void AddTrailer(const c_string& key, const c_string& value)
CServerMiddleware* GetMiddleware(const c_string& key)
cdef cppclass CTimeoutDuration" arrow::flight::TimeoutDuration":
diff --git a/python/pyarrow/tests/test_flight.py
b/python/pyarrow/tests/test_flight.py
index 930523b9f5..6c1c582dce 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -833,7 +833,7 @@ class MultiHeaderClientMiddleware(ClientMiddleware):
def received_headers(self, headers):
# Let the test code know what the last set of headers we
# received were.
- self.factory.last_headers = headers
+ self.factory.last_headers.update(headers)
class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory):
@@ -2323,3 +2323,45 @@ def test_do_put_does_not_crash_when_schema_is_none():
with pytest.raises(TypeError, match=msg):
client.do_put(flight.FlightDescriptor.for_command('foo'),
schema=None)
+
+
+def test_headers_trailers():
+ """Ensure that server-sent headers/trailers make it through."""
+
+ class HeadersTrailersFlightServer(FlightServerBase):
+ def get_flight_info(self, context, descriptor):
+ context.add_header("x-header", "header-value")
+ context.add_header("x-header-bin", "header\x01value")
+ context.add_trailer("x-trailer", "trailer-value")
+ context.add_trailer("x-trailer-bin", "trailer\x01value")
+ return flight.FlightInfo(
+ pa.schema([]),
+ descriptor,
+ [],
+ -1, -1
+ )
+
+ class HeadersTrailersMiddlewareFactory(ClientMiddlewareFactory):
+ def __init__(self):
+ self.headers = []
+
+ def start_call(self, info):
+ return HeadersTrailersMiddleware(self)
+
+ class HeadersTrailersMiddleware(ClientMiddleware):
+ def __init__(self, factory):
+ self.factory = factory
+
+ def received_headers(self, headers):
+ for key, values in headers.items():
+ for value in values:
+ self.factory.headers.append((key, value))
+
+ factory = HeadersTrailersMiddlewareFactory()
+ with HeadersTrailersFlightServer() as server, \
+ FlightClient(("localhost", server.port), middleware=[factory]) as
client:
+ client.get_flight_info(flight.FlightDescriptor.for_path(""))
+ assert ("x-header", "header-value") in factory.headers
+ assert ("x-header-bin", b"header\x01value") in factory.headers
+ assert ("x-trailer", "trailer-value") in factory.headers
+ assert ("x-trailer-bin", b"trailer\x01value") in factory.headers