This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 9f183fc3e3 GH-36512: [C++][FlightRPC] Add async GetFlightInfo client call (#36517)
9f183fc3e3 is described below
commit 9f183fc3e38abfd2edc492767280e6e917997e2c
Author: David Li <[email protected]>
AuthorDate: Wed Aug 9 11:01:44 2023 -0400
GH-36512: [C++][FlightRPC] Add async GetFlightInfo client call (#36517)
### Rationale for this change
Asynchronous client APIs are a long-requested feature for Flight RPC.
### What changes are included in this PR?
This PR adds only the C++ client implementation of an asynchronous GetFlightInfo call.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
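Yes, new APIs: the EXPERIMENTAL listener-based and Future-based `FlightClient::GetFlightInfoAsync` overloads, plus `FlightClient::supports_async`.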
* Closes: #36512
Authored-by: David Li <[email protected]>
Signed-off-by: David Li <[email protected]>
---
cpp/src/arrow/flight/CMakeLists.txt | 5 +
cpp/src/arrow/flight/api.h | 1 +
cpp/src/arrow/flight/client.cc | 58 +++++
cpp/src/arrow/flight/client.h | 28 +++
cpp/src/arrow/flight/flight_internals_test.cc | 6 +-
cpp/src/arrow/flight/flight_test.cc | 25 +-
cpp/src/arrow/flight/serialization_internal.cc | 24 +-
cpp/src/arrow/flight/serialization_internal.h | 2 +-
cpp/src/arrow/flight/test_definitions.cc | 267 ++++++++++++++++++++-
cpp/src/arrow/flight/test_definitions.h | 26 ++
cpp/src/arrow/flight/transport.cc | 16 ++
cpp/src/arrow/flight/transport.h | 50 ++--
cpp/src/arrow/flight/transport/grpc/grpc_client.cc | 194 ++++++++++++++-
.../arrow/flight/transport/grpc/util_internal.cc | 107 ++++++++-
.../arrow/flight/transport/grpc/util_internal.h | 8 +
cpp/src/arrow/flight/type_fwd.h | 5 +
cpp/src/arrow/flight/types.cc | 92 ++++++-
cpp/src/arrow/flight/types.h | 82 ++++++-
cpp/src/arrow/flight/types_async.h | 80 ++++++
19 files changed, 1024 insertions(+), 52 deletions(-)
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 7383a7eec9..6e76181533 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -119,6 +119,11 @@ else()
add_definitions(-DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc_impl::experimental)
endif()
+# Was in a different namespace, or simply not supported, prior to this
+if(ARROW_GRPC_VERSION VERSION_GREATER_EQUAL "1.40")
+ add_definitions(-DGRPC_ENABLE_ASYNC)
+endif()
+
# </KLUDGE> Restore the CXXFLAGS that were modified above
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}")
diff --git a/cpp/src/arrow/flight/api.h b/cpp/src/arrow/flight/api.h
index 61c475dc20..ed31b5c8fa 100644
--- a/cpp/src/arrow/flight/api.h
+++ b/cpp/src/arrow/flight/api.h
@@ -27,3 +27,4 @@
#include "arrow/flight/server_middleware.h"
#include "arrow/flight/server_tracing_middleware.h"
#include "arrow/flight/types.h"
+#include "arrow/flight/types_async.h"
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index ec5377b7c1..eb62ec65ff 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -32,6 +32,7 @@
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/table.h"
+#include "arrow/util/future.h"
#include "arrow/util/logging.h"
#include "arrow/flight/client_auth.h"
@@ -39,11 +40,48 @@
#include "arrow/flight/transport.h"
#include "arrow/flight/transport/grpc/grpc_client.h"
#include "arrow/flight/types.h"
+#include "arrow/flight/types_async.h"
namespace arrow {
namespace flight {
+namespace {
+template <typename T>
+class UnaryUnaryAsyncListener : public AsyncListener<T> {
+ public:
+ UnaryUnaryAsyncListener() : future_(arrow::Future<T>::Make()) {}
+
+ void OnNext(T result) override {
+ DCHECK(!result_.ok());
+ result_ = std::move(result);
+ }
+
+ void OnFinish(Status status) override {
+ if (status.ok()) {
+ DCHECK(result_.ok());
+ } else {
+ // Default-initialized result is not ok
+ DCHECK(!result_.ok());
+ result_ = std::move(status);
+ }
+ future_.MarkFinished(std::move(result_));
+ }
+
+ static std::pair<std::shared_ptr<AsyncListener<T>>, arrow::Future<T>> Make()
{
+ auto self = std::make_shared<UnaryUnaryAsyncListener<T>>();
+ // Keep the listener alive by stashing it in the future
+ self->future_.AddCallback([self](const arrow::Result<T>&) {});
+ auto future = self->future_;
+ return std::make_pair(std::move(self), std::move(future));
+ }
+
+ private:
+ arrow::Result<T> result_;
+ arrow::Future<T> future_;
+};
+} // namespace
+
const char* kWriteSizeDetailTypeId = "flight::FlightWriteSizeStatusDetail";
FlightCallOptions::FlightCallOptions()
@@ -584,6 +622,24 @@ arrow::Result<std::unique_ptr<FlightInfo>>
FlightClient::GetFlightInfo(
return info;
}
+void FlightClient::GetFlightInfoAsync(
+ const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ std::shared_ptr<AsyncListener<FlightInfo>> listener) {
+ if (auto status = CheckOpen(); !status.ok()) {
+ listener->OnFinish(std::move(status));
+ return;
+ }
+ transport_->GetFlightInfoAsync(options, descriptor, std::move(listener));
+}
+
+arrow::Future<FlightInfo> FlightClient::GetFlightInfoAsync(
+ const FlightCallOptions& options, const FlightDescriptor& descriptor) {
+ RETURN_NOT_OK(CheckOpen());
+ auto [listener, future] = UnaryUnaryAsyncListener<FlightInfo>::Make();
+ transport_->GetFlightInfoAsync(options, descriptor, std::move(listener));
+ return future;
+}
+
arrow::Result<std::unique_ptr<SchemaResult>> FlightClient::GetSchema(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
RETURN_NOT_OK(CheckOpen());
@@ -658,6 +714,8 @@ Status FlightClient::Close() {
return Status::OK();
}
+bool FlightClient::supports_async() const { return
transport_->supports_async(); }
+
Status FlightClient::CheckOpen() const {
if (closed_) {
return Status::Invalid("FlightClient is closed");
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 7204b469a6..cc1c35aaeb 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -271,6 +271,31 @@ class ARROW_FLIGHT_EXPORT FlightClient {
return GetFlightInfo({}, descriptor);
}
+ /// \brief Asynchronous GetFlightInfo.
+ /// \param[in] options Per-RPC options
+ /// \param[in] descriptor the dataset request
+ /// \param[in] listener Callbacks for response and RPC completion
+ ///
+ /// This API is EXPERIMENTAL.
+ void GetFlightInfoAsync(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ std::shared_ptr<AsyncListener<FlightInfo>> listener);
+ void GetFlightInfoAsync(const FlightDescriptor& descriptor,
+ std::shared_ptr<AsyncListener<FlightInfo>> listener)
{
+ return GetFlightInfoAsync({}, descriptor, std::move(listener));
+ }
+
+ /// \brief Asynchronous GetFlightInfo returning a Future.
+ /// \param[in] options Per-RPC options
+ /// \param[in] descriptor the dataset request
+ ///
+ /// This API is EXPERIMENTAL.
+ arrow::Future<FlightInfo> GetFlightInfoAsync(const FlightCallOptions&
options,
+ const FlightDescriptor&
descriptor);
+ arrow::Future<FlightInfo> GetFlightInfoAsync(const FlightDescriptor&
descriptor) {
+ return GetFlightInfoAsync({}, descriptor);
+ }
+
/// \brief Request schema for a single flight, which may be an existing
/// dataset or a command to be executed
/// \param[in] options Per-RPC options
@@ -355,6 +380,9 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \since 8.0.0
Status Close();
+ /// \brief Whether this client supports asynchronous methods.
+ bool supports_async() const;
+
private:
FlightClient();
Status CheckOpen() const;
diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc
index e56bab6db2..72a25018e8 100644
--- a/cpp/src/arrow/flight/flight_internals_test.cc
+++ b/cpp/src/arrow/flight/flight_internals_test.cc
@@ -76,9 +76,7 @@ void TestRoundtrip(const std::vector<FlightType>& values,
ASSERT_OK(internal::ToProto(values[i], &pb_value));
if constexpr (std::is_same_v<FlightType, FlightInfo>) {
- FlightInfo::Data data;
- ASSERT_OK(internal::FromProto(pb_value, &data));
- FlightInfo value(std::move(data));
+ ASSERT_OK_AND_ASSIGN(FlightInfo value, internal::FromProto(pb_value));
EXPECT_EQ(values[i], value);
} else if constexpr (std::is_same_v<FlightType, SchemaResult>) {
std::string data;
@@ -742,5 +740,7 @@ TEST(TransportErrorHandling, ReconstructStatus) {
ASSERT_EQ(detail->extra_info(), "Binary error details");
}
+// TODO: test TransportStatusDetail
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 1e7ea9bb00..c36c9eee71 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -40,6 +40,7 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/util.h"
#include "arrow/util/base64.h"
+#include "arrow/util/future.h"
#include "arrow/util/logging.h"
#ifdef GRPCPP_GRPCPP_H
@@ -91,9 +92,16 @@ const char kAuthHeader[] = "authorization";
//------------------------------------------------------------
// Common transport tests
+#ifdef GRPC_ENABLE_ASYNC
+constexpr bool kGrpcSupportsAsync = true;
+#else
+constexpr bool kGrpcSupportsAsync = false;
+#endif
+
class GrpcConnectivityTest : public ConnectivityTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
+ bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
@@ -102,6 +110,7 @@ ARROW_FLIGHT_TEST_CONNECTIVITY(GrpcConnectivityTest);
class GrpcDataTest : public DataTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
+ bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
@@ -110,6 +119,7 @@ ARROW_FLIGHT_TEST_DATA(GrpcDataTest);
class GrpcDoPutTest : public DoPutTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
+ bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
@@ -118,6 +128,7 @@ ARROW_FLIGHT_TEST_DO_PUT(GrpcDoPutTest);
class GrpcAppMetadataTest : public AppMetadataTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
+ bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
@@ -126,6 +137,7 @@ ARROW_FLIGHT_TEST_APP_METADATA(GrpcAppMetadataTest);
class GrpcIpcOptionsTest : public IpcOptionsTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
+ bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
@@ -134,6 +146,7 @@ ARROW_FLIGHT_TEST_IPC_OPTIONS(GrpcIpcOptionsTest);
class GrpcCudaDataTest : public CudaDataTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
+ bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
@@ -142,11 +155,21 @@ ARROW_FLIGHT_TEST_CUDA_DATA(GrpcCudaDataTest);
class GrpcErrorHandlingTest : public ErrorHandlingTest, public ::testing::Test
{
protected:
std::string transport() const override { return "grpc"; }
+ bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
ARROW_FLIGHT_TEST_ERROR_HANDLING(GrpcErrorHandlingTest);
+class GrpcAsyncClientTest : public AsyncClientTest, public ::testing::Test {
+ protected:
+ std::string transport() const override { return "grpc"; }
+ bool supports_async() const override { return kGrpcSupportsAsync; }
+ void SetUp() override { SetUpTest(); }
+ void TearDown() override { TearDownTest(); }
+};
+ARROW_FLIGHT_TEST_ASYNC_CLIENT(GrpcAsyncClientTest);
+
//------------------------------------------------------------
// Ad-hoc gRPC-specific tests
@@ -443,7 +466,7 @@ class TestTls : public ::testing::Test {
Location location_;
std::unique_ptr<FlightClient> client_;
std::unique_ptr<FlightServerBase> server_;
- bool server_is_initialized_;
+ bool server_is_initialized_ = false;
};
// A server middleware that rejects all calls.
diff --git a/cpp/src/arrow/flight/serialization_internal.cc b/cpp/src/arrow/flight/serialization_internal.cc
index b0859e1d91..5d09a1a045 100644
--- a/cpp/src/arrow/flight/serialization_internal.cc
+++ b/cpp/src/arrow/flight/serialization_internal.cc
@@ -230,20 +230,21 @@ Status ToProto(const FlightDescriptor& descriptor,
pb::FlightDescriptor* pb_desc
// FlightInfo
-Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) {
- RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info->descriptor));
+arrow::Result<FlightInfo> FromProto(const pb::FlightInfo& pb_info) {
+ FlightInfo::Data info;
+ RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info.descriptor));
- info->schema = pb_info.schema();
+ info.schema = pb_info.schema();
- info->endpoints.resize(pb_info.endpoint_size());
+ info.endpoints.resize(pb_info.endpoint_size());
for (int i = 0; i < pb_info.endpoint_size(); ++i) {
- RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info->endpoints[i]));
+ RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info.endpoints[i]));
}
- info->total_records = pb_info.total_records();
- info->total_bytes = pb_info.total_bytes();
- info->ordered = pb_info.ordered();
- return Status::OK();
+ info.total_records = pb_info.total_records();
+ info.total_bytes = pb_info.total_bytes();
+ info.ordered = pb_info.ordered();
+ return FlightInfo(std::move(info));
}
Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* basic_auth) {
@@ -291,9 +292,8 @@ Status ToProto(const FlightInfo& info, pb::FlightInfo*
pb_info) {
Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
CancelFlightInfoRequest* request) {
- FlightInfo::Data data;
- RETURN_NOT_OK(FromProto(pb_request.info(), &data));
- request->info = std::make_unique<FlightInfo>(std::move(data));
+ ARROW_ASSIGN_OR_RAISE(FlightInfo info, FromProto(pb_request.info()));
+ request->info = std::make_unique<FlightInfo>(std::move(info));
return Status::OK();
}
diff --git a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h
index b0a3491ac2..30eb0b3181 100644
--- a/cpp/src/arrow/flight/serialization_internal.h
+++ b/cpp/src/arrow/flight/serialization_internal.h
@@ -59,7 +59,7 @@ Status FromProto(const pb::FlightDescriptor& pb_descr,
FlightDescriptor* descr);
Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint*
endpoint);
Status FromProto(const pb::RenewFlightEndpointRequest& pb_request,
RenewFlightEndpointRequest* request);
-Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info);
+arrow::Result<FlightInfo> FromProto(const pb::FlightInfo& pb_info);
Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
CancelFlightInfoRequest* request);
Status FromProto(const pb::SchemaResult& pb_result, std::string* result);
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index 4e13738004..55be3244fb 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -20,6 +20,7 @@
#include <chrono>
#include <memory>
#include <mutex>
+#include <unordered_map>
#include "arrow/array/array_base.h"
#include "arrow/array/array_dict.h"
@@ -27,7 +28,11 @@
#include "arrow/flight/api.h"
#include "arrow/flight/client_middleware.h"
#include "arrow/flight/test_util.h"
+#include "arrow/flight/types.h"
+#include "arrow/flight/types_async.h"
+#include "arrow/status.h"
#include "arrow/table.h"
+#include "arrow/testing/future_util.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/checked_cast.h"
@@ -123,6 +128,28 @@ void ConnectivityTest::TestBrokenConnection() {
//------------------------------------------------------------
// Tests of data plane methods
+namespace {
+class GetFlightInfoListener : public AsyncListener<FlightInfo> {
+ public:
+ void OnNext(FlightInfo message) override {
+ info = std::move(message);
+ counter++;
+ }
+ void OnFinish(Status status) override {
+ ASSERT_FALSE(future.is_finished());
+ if (status.ok()) {
+ future.MarkFinished(std::move(info));
+ } else {
+ future.MarkFinished(std::move(status));
+ }
+ }
+
+ FlightInfo info = FlightInfo(FlightInfo::Data{});
+ int counter = 0;
+ arrow::Future<FlightInfo> future = arrow::Future<FlightInfo>::Make();
+};
+} // namespace
+
void DataTest::SetUpTest() {
server_ = ExampleTestServer();
@@ -150,6 +177,14 @@ void DataTest::CheckDoGet(
ASSERT_OK_AND_ASSIGN(auto info, client_->GetFlightInfo(descr));
check_endpoints(info->endpoints());
+ if (supports_async()) {
+ auto listener = std::make_shared<GetFlightInfoListener>();
+ client_->GetFlightInfoAsync(descr, listener);
+ ASSERT_FINISHES_OK(listener->future);
+ ASSERT_EQ(1, listener->counter);
+ check_endpoints(listener->future.MoveResult()->endpoints());
+ }
+
ipc::DictionaryMemo dict_memo;
ASSERT_OK_AND_ASSIGN(auto schema, info->GetSchema(&dict_memo));
AssertSchemaEqual(*expected_schema, *schema);
@@ -671,11 +706,11 @@ void DoPutTest::SetUpTest() {
void DoPutTest::TearDownTest() {
ASSERT_OK(client_->Close());
ASSERT_OK(server_->Shutdown());
- reinterpret_cast<DoPutTestServer*>(server_.get())->batches_.clear();
+ checked_cast<DoPutTestServer*>(server_.get())->batches_.clear();
}
void DoPutTest::CheckBatches(const FlightDescriptor& expected_descriptor,
const RecordBatchVector& expected_batches) {
- auto* do_put_server = (DoPutTestServer*)server_.get();
+ auto* do_put_server = static_cast<DoPutTestServer*>(server_.get());
ASSERT_EQ(do_put_server->descriptor_, expected_descriptor);
ASSERT_EQ(do_put_server->batches_.size(), expected_batches.size());
for (size_t i = 0; i < expected_batches.size(); ++i) {
@@ -1410,6 +1445,26 @@ static const std::vector<StatusCode> kStatusCodes = {
StatusCode::AlreadyExists,
};
+// For each Arrow status code, what Flight code do we get?
+static const std::unordered_map<StatusCode, TransportStatusCode>
kTransportStatusCodes = {
+ {StatusCode::OutOfMemory, TransportStatusCode::kUnknown},
+ {StatusCode::KeyError, TransportStatusCode::kNotFound},
+ {StatusCode::TypeError, TransportStatusCode::kUnknown},
+ {StatusCode::Invalid, TransportStatusCode::kInvalidArgument},
+ {StatusCode::IOError, TransportStatusCode::kUnknown},
+ {StatusCode::CapacityError, TransportStatusCode::kUnknown},
+ {StatusCode::IndexError, TransportStatusCode::kUnknown},
+ {StatusCode::Cancelled, TransportStatusCode::kCancelled},
+ {StatusCode::UnknownError, TransportStatusCode::kUnknown},
+ {StatusCode::NotImplemented, TransportStatusCode::kUnimplemented},
+ {StatusCode::SerializationError, TransportStatusCode::kUnknown},
+ {StatusCode::RError, TransportStatusCode::kUnknown},
+ {StatusCode::CodeGenError, TransportStatusCode::kUnknown},
+ {StatusCode::ExpressionValidationError, TransportStatusCode::kUnknown},
+ {StatusCode::ExecutionError, TransportStatusCode::kUnknown},
+ {StatusCode::AlreadyExists, TransportStatusCode::kAlreadyExists},
+};
+
static const std::vector<FlightStatusCode> kFlightStatusCodes = {
FlightStatusCode::Internal, FlightStatusCode::TimedOut,
FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated,
@@ -1517,6 +1572,15 @@ class MetadataRecordingClientMiddlewareFactory : public
ClientMiddlewareFactory
mutable std::mutex mutex_;
std::vector<std::pair<std::string, std::string>> headers_;
};
+
+class TransportStatusListener : public AsyncListener<FlightInfo> {
+ public:
+ void OnNext(FlightInfo /*message*/) override {}
+ void OnFinish(Status status) override {
future.MarkFinished(std::move(status)); }
+
+ arrow::Future<> future = arrow::Future<>::Make();
+};
+
} // namespace
struct ErrorHandlingTest::Impl {
@@ -1544,6 +1608,98 @@ std::vector<std::pair<std::string, std::string>>
ErrorHandlingTest::GetHeaders()
return impl_->metadata->GetHeaders();
}
+void ErrorHandlingTest::TestAsyncGetFlightInfo() {
+ if (!supports_async()) {
+ GTEST_SKIP() << "Transport does not support async";
+ }
+ // Server-side still does all the junk around trying to translate Arrow
+ // status codes, so this test is a little indirect
+
+ for (const auto code : kStatusCodes) {
+ ARROW_SCOPED_TRACE("C++ status code: ", static_cast<int>(code), ": ",
+ Status::CodeAsString(code));
+
+ // Just the status code
+ {
+ auto descr = FlightDescriptor::Path(
+ {std::to_string(static_cast<int>(code)), "Expected message"});
+ auto listener = std::make_shared<TransportStatusListener>();
+
+ client_->GetFlightInfoAsync(descr, listener);
+ EXPECT_FINISHES(listener->future);
+ auto detail = TransportStatusDetail::Unwrap(listener->future.status());
+ ASSERT_TRUE(detail.has_value());
+
+ EXPECT_EQ(detail->get().code(), kTransportStatusCodes.at(code));
+ // Exact equality - should have no extra junk in the message
+ EXPECT_EQ(detail->get().message(), "Expected message");
+ }
+
+ // Custom status detail
+ {
+ auto descr = FlightDescriptor::Path(
+ {std::to_string(static_cast<int>(code)), "Expected message", ""});
+ auto listener = std::make_shared<TransportStatusListener>();
+
+ client_->GetFlightInfoAsync(descr, listener);
+ EXPECT_FINISHES(listener->future);
+ auto detail = TransportStatusDetail::Unwrap(listener->future.status());
+ ASSERT_TRUE(detail.has_value());
+
+ EXPECT_EQ(detail->get().code(), kTransportStatusCodes.at(code));
+ // The server-side arrow::Status-to-TransportStatus conversion puts the
+ // detail into the main error message.
+ EXPECT_EQ(detail->get().message(),
+ "Expected message. Detail: Custom status detail");
+
+ std::string_view arrow_code, arrow_message;
+ for (const auto& [key, value] : detail->get().details()) {
+ if (key == "x-arrow-status") {
+ arrow_code = value;
+ } else if (key == "x-arrow-status-message-bin") {
+ arrow_message = value;
+ }
+ }
+ EXPECT_EQ(arrow_code, std::to_string(static_cast<int>(code)));
+ EXPECT_EQ(arrow_message, "Expected message");
+ }
+
+ // Flight status detail
+ for (const auto flight_code : kFlightStatusCodes) {
+ ARROW_SCOPED_TRACE("Flight status code: ",
static_cast<int>(flight_code));
+ auto descr = FlightDescriptor::Path(
+ {std::to_string(static_cast<int>(code)), "Expected message",
+ std::to_string(static_cast<int>(flight_code)), "Expected detail
message"});
+ auto listener = std::make_shared<TransportStatusListener>();
+
+ client_->GetFlightInfoAsync(descr, listener);
+ EXPECT_FINISHES(listener->future);
+ auto detail = TransportStatusDetail::Unwrap(listener->future.status());
+ ASSERT_TRUE(detail.has_value());
+
+ // The server-side arrow::Status-to-TransportStatus conversion puts the
+ // detail into the main error message.
+ EXPECT_THAT(detail->get().message(),
+ ::testing::HasSubstr("Expected message. Detail:"));
+
+ std::string_view arrow_code, arrow_message, binary_detail;
+ for (const auto& [key, value] : detail->get().details()) {
+ if (key == "x-arrow-status") {
+ arrow_code = value;
+ } else if (key == "x-arrow-status-message-bin") {
+ arrow_message = value;
+ } else if (key == "grpc-status-details-bin") {
+ binary_detail = value;
+ }
+ }
+
+ EXPECT_EQ(arrow_code, std::to_string(static_cast<int>(code)));
+ EXPECT_EQ(arrow_message, "Expected message");
+ EXPECT_EQ(binary_detail, "Expected detail message");
+ }
+ }
+}
+
void ErrorHandlingTest::TestGetFlightInfo() {
std::unique_ptr<FlightInfo> info;
for (const auto code : kStatusCodes) {
@@ -1656,5 +1812,112 @@ void ErrorHandlingTest::TestDoExchange() {
reader_thread.join();
}
+//------------------------------------------------------------
+// Test async clients
+
+void AsyncClientTest::SetUpTest() {
+ if (!supports_async()) {
+ GTEST_SKIP() << "async is not supported";
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(),
"127.0.0.1", 0));
+
+ server_ = ExampleTestServer();
+ FlightServerOptions server_options(location);
+ ASSERT_OK(server_->Init(server_options));
+
+ std::string uri = location.scheme() + "://127.0.0.1:" +
std::to_string(server_->port());
+ ASSERT_OK_AND_ASSIGN(auto real_location, Location::Parse(uri));
+ FlightClientOptions client_options = FlightClientOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(client_, FlightClient::Connect(real_location,
client_options));
+
+ ASSERT_TRUE(client_->supports_async());
+}
+void AsyncClientTest::TearDownTest() {
+ if (supports_async()) {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
+}
+
+void AsyncClientTest::TestGetFlightInfo() {
+ class Listener : public AsyncListener<FlightInfo> {
+ public:
+ void OnNext(FlightInfo info) override {
+ info_ = std::move(info);
+ counter_++;
+ }
+
+ void OnFinish(Status status) override {
+ ASSERT_FALSE(future_.is_finished());
+ if (status.ok()) {
+ future_.MarkFinished(std::move(info_));
+ } else {
+ future_.MarkFinished(std::move(status));
+ }
+ }
+
+ int counter_ = 0;
+ FlightInfo info_ = FlightInfo(FlightInfo::Data());
+ arrow::Future<FlightInfo> future_ = arrow::Future<FlightInfo>::Make();
+ };
+
+ auto descr = FlightDescriptor::Command("status-outofmemory");
+ auto listener = std::make_shared<Listener>();
+ client_->GetFlightInfoAsync(descr, listener);
+
+ ASSERT_FINISHES_AND_RAISES(UnknownError, listener->future_);
+ ASSERT_THAT(listener->future_.status().ToString(),
::testing::HasSubstr("Sentinel"));
+ ASSERT_EQ(0, listener->counter_);
+}
+
+void AsyncClientTest::TestGetFlightInfoFuture() {
+ auto descr = FlightDescriptor::Command("status-outofmemory");
+ auto future = client_->GetFlightInfoAsync(descr);
+ ASSERT_FINISHES_AND_RAISES(UnknownError, future);
+ ASSERT_THAT(future.status().ToString(), ::testing::HasSubstr("Sentinel"));
+
+ descr = FlightDescriptor::Command("my_command");
+ future = client_->GetFlightInfoAsync(descr);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto info, future);
+ // See test_util.cc:ExampleFlightInfo
+ ASSERT_EQ(descr, info.descriptor());
+ ASSERT_EQ(1000, info.total_records());
+ ASSERT_EQ(100000, info.total_bytes());
+}
+
+void AsyncClientTest::TestListenerLifetime() {
+ arrow::Future<FlightInfo> future = arrow::Future<FlightInfo>::Make();
+
+ class Listener : public AsyncListener<FlightInfo> {
+ public:
+ void OnNext(FlightInfo info) override { info_ = std::move(info); }
+
+ void OnFinish(Status status) override {
+ if (status.ok()) {
+ future_.MarkFinished(std::move(info_));
+ } else {
+ future_.MarkFinished(std::move(status));
+ }
+ }
+
+ FlightInfo info_ = FlightInfo(FlightInfo::Data());
+ arrow::Future<FlightInfo> future_;
+ };
+
+ // Bad client code: don't retain a reference to the listener, which owns the
+ // RPC state. We should still be able to get the result without crashing. (The
+ // RPC state is disposed of in the background via the 'garbage bin' in the
+ // gRPC client implementation.)
+ {
+ auto descr = FlightDescriptor::Command("my_command");
+ auto listener = std::make_shared<Listener>();
+ listener->future_ = future;
+ client_->GetFlightInfoAsync(descr, std::move(listener));
+ }
+
+ ASSERT_FINISHES_OK(future);
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h
index c73bc264b4..1e0e8c209a 100644
--- a/cpp/src/arrow/flight/test_definitions.h
+++ b/cpp/src/arrow/flight/test_definitions.h
@@ -40,6 +40,7 @@ namespace flight {
class ARROW_FLIGHT_EXPORT FlightTest {
protected:
virtual std::string transport() const = 0;
+ virtual bool supports_async() const { return false; }
virtual void SetUpTest() {}
virtual void TearDownTest() {}
};
@@ -266,6 +267,7 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public
FlightTest {
// Test methods
void TestGetFlightInfo();
void TestGetFlightInfoMetadata();
+ void TestAsyncGetFlightInfo();
void TestDoPut();
void TestDoExchange();
@@ -282,10 +284,34 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public
FlightTest {
#define ARROW_FLIGHT_TEST_ERROR_HANDLING(FIXTURE)
\
static_assert(std::is_base_of<ErrorHandlingTest, FIXTURE>::value,
\
ARROW_STRINGIFY(FIXTURE) " must inherit from
ErrorHandlingTest"); \
+ TEST_F(FIXTURE, TestAsyncGetFlightInfo) { TestAsyncGetFlightInfo(); }
\
TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }
\
TEST_F(FIXTURE, TestGetFlightInfoMetadata) { TestGetFlightInfoMetadata(); }
\
TEST_F(FIXTURE, TestDoPut) { TestDoPut(); }
\
TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }
+/// \brief Tests of the async client.
+class ARROW_FLIGHT_EXPORT AsyncClientTest : public FlightTest {
+ public:
+ void SetUpTest() override;
+ void TearDownTest() override;
+
+ // Test methods
+ void TestGetFlightInfo();
+ void TestGetFlightInfoFuture();
+ void TestListenerLifetime();
+
+ private:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+#define ARROW_FLIGHT_TEST_ASYNC_CLIENT(FIXTURE)
\
+ static_assert(std::is_base_of<AsyncClientTest, FIXTURE>::value,
\
+ ARROW_STRINGIFY(FIXTURE) " must inherit from
AsyncClientTest"); \
+ TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }
\
+ TEST_F(FIXTURE, TestGetFlightInfoFuture) { TestGetFlightInfoFuture(); }
\
+ TEST_F(FIXTURE, TestListenerLifetime) { TestListenerLifetime(); }
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport.cc b/cpp/src/arrow/flight/transport.cc
index a0281ffd61..88228f2503 100644
--- a/cpp/src/arrow/flight/transport.cc
+++ b/cpp/src/arrow/flight/transport.cc
@@ -24,6 +24,7 @@
#include "arrow/flight/client_auth.h"
#include "arrow/flight/transport_server.h"
#include "arrow/flight/types.h"
+#include "arrow/flight/types_async.h"
#include "arrow/ipc/message.h"
#include "arrow/result.h"
#include "arrow/status.h"
@@ -74,6 +75,11 @@ Status ClientTransport::GetFlightInfo(const
FlightCallOptions& options,
std::unique_ptr<FlightInfo>* info) {
return Status::NotImplemented("GetFlightInfo for this transport");
}
+void ClientTransport::GetFlightInfoAsync(
+ const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ std::shared_ptr<AsyncListener<FlightInfo>> listener) {
+ listener->OnFinish(Status::NotImplemented("Async GetFlightInfo for this
transport"));
+}
arrow::Result<std::unique_ptr<SchemaResult>> ClientTransport::GetSchema(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
return Status::NotImplemented("GetSchema for this transport");
@@ -95,6 +101,16 @@ Status ClientTransport::DoExchange(const FlightCallOptions&
options,
std::unique_ptr<ClientDataStream>* stream) {
return Status::NotImplemented("DoExchange for this transport");
}
+void ClientTransport::SetAsyncRpc(AsyncListenerBase* listener,
+ std::unique_ptr<AsyncRpc>&& rpc) {
+ listener->rpc_state_ = std::move(rpc);
+}
+AsyncRpc* ClientTransport::GetAsyncRpc(AsyncListenerBase* listener) {
+ return listener->rpc_state_.get();
+}
+std::unique_ptr<AsyncRpc> ClientTransport::ReleaseAsyncRpc(AsyncListenerBase*
listener) {
+ return std::move(listener->rpc_state_);
+}
class TransportRegistry::Impl final {
public:
diff --git a/cpp/src/arrow/flight/transport.h b/cpp/src/arrow/flight/transport.h
index 6406734e6e..69605d2112 100644
--- a/cpp/src/arrow/flight/transport.h
+++ b/cpp/src/arrow/flight/transport.h
@@ -64,7 +64,9 @@
#include <vector>
#include "arrow/flight/type_fwd.h"
+#include "arrow/flight/types.h"
#include "arrow/flight/visibility.h"
+#include "arrow/ipc/options.h"
#include "arrow/type_fwd.h"
namespace arrow {
@@ -182,6 +184,9 @@ class ARROW_FLIGHT_EXPORT ClientTransport {
virtual Status GetFlightInfo(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::unique_ptr<FlightInfo>* info);
+ virtual void GetFlightInfoAsync(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ std::shared_ptr<AsyncListener<FlightInfo>>
listener);
virtual arrow::Result<std::unique_ptr<SchemaResult>> GetSchema(
const FlightCallOptions& options, const FlightDescriptor& descriptor);
virtual Status ListFlights(const FlightCallOptions& options, const Criteria&
criteria,
@@ -192,6 +197,12 @@ class ARROW_FLIGHT_EXPORT ClientTransport {
std::unique_ptr<ClientDataStream>* stream);
virtual Status DoExchange(const FlightCallOptions& options,
std::unique_ptr<ClientDataStream>* stream);
+
+ virtual bool supports_async() const { return false; }
+
+ static void SetAsyncRpc(AsyncListenerBase* listener,
std::unique_ptr<AsyncRpc>&& rpc);
+ static AsyncRpc* GetAsyncRpc(AsyncListenerBase* listener);
+ static std::unique_ptr<AsyncRpc> ReleaseAsyncRpc(AsyncListenerBase*
listener);
};
/// A registry of transport implementations.
@@ -223,24 +234,33 @@ ARROW_FLIGHT_EXPORT
TransportRegistry* GetDefaultTransportRegistry();
//------------------------------------------------------------
-// Error propagation helpers
+// Async APIs
-/// \brief Abstract status code as per the Flight specification.
-enum class TransportStatusCode {
- kOk = 0,
- kUnknown = 1,
- kInternal = 2,
- kInvalidArgument = 3,
- kTimedOut = 4,
- kNotFound = 5,
- kAlreadyExists = 6,
- kCancelled = 7,
- kUnauthenticated = 8,
- kUnauthorized = 9,
- kUnimplemented = 10,
- kUnavailable = 11,
+/// \brief Transport-specific state for an async RPC.
+///
+/// Transport implementations may subclass this to store their own
+/// state, and stash an instance in a user-supplied AsyncListener via
+/// ClientTransport::GetAsyncRpc and ClientTransport::SetAsyncRpc.
+///
+/// This API is EXPERIMENTAL.
+class ARROW_FLIGHT_EXPORT AsyncRpc {
+ public:
+ virtual ~AsyncRpc() = default;
+ /// \brief Request cancellation of the RPC.
+ virtual void TryCancel() {}
+
+ /// Only needed for DoPut/DoExchange
+ virtual void Begin(const FlightDescriptor& descriptor,
std::shared_ptr<Schema> schema) {
+ }
+ /// Only needed for DoPut/DoExchange
+ virtual void Write(arrow::flight::FlightStreamChunk chunk) {}
+ /// Only needed for DoPut/DoExchange
+ virtual void DoneWriting() {}
};
+//------------------------------------------------------------
+// Error propagation helpers
+
/// \brief Abstract error status.
///
/// Transport implementations may use side channels (e.g. HTTP
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
index 9b40015f9f..7108f35549 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
@@ -17,15 +17,19 @@
#include "arrow/flight/transport/grpc/grpc_client.h"
+#include <condition_variable>
+#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
+#include <thread>
#include <unordered_map>
#include <utility>
#include <grpcpp/grpcpp.h>
+#include <grpcpp/support/client_callback.h>
#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
#include <grpcpp/security/tls_credentials_options.h>
#endif
@@ -51,6 +55,7 @@
#include "arrow/flight/transport/grpc/serialization_internal.h"
#include "arrow/flight/transport/grpc/util_internal.h"
#include "arrow/flight/types.h"
+#include "arrow/flight/types_async.h"
namespace arrow {
@@ -549,6 +554,127 @@ class GrpcResultStream : public ResultStream {
std::unique_ptr<::grpc::ClientReader<pb::Result>> stream_;
};
+#ifdef GRPC_ENABLE_ASYNC
+/// Force destruction to wait for RPC completion.
+class FinishedFlag {
+ public:
+ ~FinishedFlag() { Wait(); }
+
+ void Finish() {
+ std::lock_guard<std::mutex> guard(mutex_);
+ finished_ = true;
+ cv_.notify_all();
+ }
+ void Wait() const {
+ std::unique_lock<std::mutex> guard(mutex_);
+ cv_.wait(guard, [&]() { return finished_; });
+ }
+
+ private:
+ mutable std::mutex mutex_;
+ mutable std::condition_variable cv_;
+ bool finished_{false};
+};
+
+// XXX: it appears that if we destruct gRPC resources (like a
+// ClientContext) from a gRPC callback, we will be running on a gRPC
+// thread and we may attempt to join ourselves (because gRPC
+// apparently refcounts threads). Avoid that by transferring gRPC
+// resources to a dedicated thread for destruction.
+class GrpcGarbageBin {
+ public:
+ GrpcGarbageBin() {
+ grpc_destructor_thread_ = std::thread([&]() {
+ while (true) {
+ std::unique_lock<std::mutex> guard(grpc_destructor_mutex_);
+ grpc_destructor_cv_.wait(guard,
+ [&]() { return !running_ ||
!garbage_bin_.empty(); });
+
+ garbage_bin_.clear();
+
+ if (!running_) return;
+ }
+ });
+ }
+
+ void Dispose(std::unique_ptr<internal::AsyncRpc> trash) {
+ std::unique_lock<std::mutex> guard(grpc_destructor_mutex_);
+ if (!running_) return;
+ garbage_bin_.push_back(std::move(trash));
+ grpc_destructor_cv_.notify_all();
+ }
+
+ void Stop() {
+ {
+ std::unique_lock<std::mutex> guard(grpc_destructor_mutex_);
+ running_ = false;
+ grpc_destructor_cv_.notify_all();
+ }
+ grpc_destructor_thread_.join();
+ }
+
+ private:
+ bool running_ = true;
+ std::thread grpc_destructor_thread_;
+ std::mutex grpc_destructor_mutex_;
+ std::condition_variable grpc_destructor_cv_;
+ std::deque<std::unique_ptr<internal::AsyncRpc>> garbage_bin_;
+};
+
+template <typename Result, typename Request, typename Response>
+class UnaryUnaryAsyncCall : public ::grpc::ClientUnaryReactor, public
internal::AsyncRpc {
+ public:
+ ClientRpc rpc;
+ std::shared_ptr<AsyncListener<Result>> listener;
+ std::shared_ptr<GrpcGarbageBin> garbage_bin_;
+
+ Request pb_request;
+ Response pb_response;
+ Status client_status;
+
+ // Declared last so that it is destroyed first: blocks until the RPC finishes
+ FinishedFlag finished;
+
+ explicit UnaryUnaryAsyncCall(const FlightCallOptions& options,
+ std::shared_ptr<AsyncListener<Result>> listener,
+ std::shared_ptr<GrpcGarbageBin> garbage_bin)
+ : rpc(options),
+ listener(std::move(listener)),
+ garbage_bin_(std::move(garbage_bin)) {}
+
+ void TryCancel() override { rpc.context.TryCancel(); }
+
+ void OnDone(const ::grpc::Status& status) override {
+ if (status.ok()) {
+ auto result = internal::FromProto(pb_response);
+ client_status = result.status();
+ if (client_status.ok()) {
+ listener->OnNext(std::move(result).MoveValueUnsafe());
+ }
+ }
+ Finish(status);
+ }
+
+ void Finish(const ::grpc::Status& status) {
+ auto listener = std::move(this->listener);
+ listener->OnFinish(
+ CombinedTransportStatus(status, std::move(client_status),
&rpc.context));
+ // Dispose() below may trigger destruction of this object, so Finish() first
+ finished.Finish();
+ // Instead of potentially destructing gRPC resources here,
+ // transfer it to a dedicated background thread
+ garbage_bin_->Dispose(
+ flight::internal::ClientTransport::ReleaseAsyncRpc(listener.get()));
+ }
+};
+
+#define LISTENER_NOT_OK(LISTENER, EXPR) \
+ if (auto arrow_status = (EXPR); !arrow_status.ok()) { \
+ (LISTENER)->OnFinish(std::move(arrow_status)); \
+ return; \
+ }
+#endif
+
class GrpcClientImpl : public internal::ClientTransport {
public:
static arrow::Result<std::unique_ptr<internal::ClientTransport>> Make() {
@@ -702,14 +828,30 @@ class GrpcClientImpl : public internal::ClientTransport {
stub_ = pb::FlightService::NewStub(
::grpc::experimental::CreateCustomChannelWithInterceptors(
grpc_uri.str(), creds, args, std::move(interceptors)));
+
+#ifdef GRPC_ENABLE_ASYNC
+ garbage_bin_ = std::make_shared<GrpcGarbageBin>();
+#endif
+
return Status::OK();
}
Status Close() override {
- // TODO(ARROW-15473): if we track ongoing RPCs, we can cancel them first
- // gRPC does not offer a real Close(). We could reset() the gRPC
- // client but that can cause gRPC to hang in shutdown
- // (ARROW-15793).
+#ifdef GRPC_ENABLE_ASYNC
+ // TODO(https://github.com/apache/arrow/issues/30949): if there are async
+ // RPCs running when the client is stopped, then when they go to use the
+ // garbage bin, they'll instead synchronously dispose of resources from
+ // the callback thread, and will likely crash. We could instead cancel
+ // them first and wait for completion before stopping the thread, but
+ // tracking all of the RPCs may be unacceptable overhead for clients that
+ // are making many small concurrent RPC calls, so it remains to be seen
+ // whether there's a pressing need for this.
+ garbage_bin_->Stop();
+#endif
+ // TODO(https://github.com/apache/arrow/issues/30949): if we track ongoing
+ // RPCs, we can cancel them first. gRPC does not offer a real Close(). We
+ // could reset() the gRPC client but that can cause gRPC to hang in
+ // shutdown (https://github.com/apache/arrow/issues/31235).
return Status::OK();
}
@@ -745,8 +887,7 @@ class GrpcClientImpl : public internal::ClientTransport {
pb::FlightInfo pb_info;
while (!options.stop_token.IsStopRequested() && stream->Read(&pb_info)) {
- FlightInfo::Data info_data;
- RETURN_NOT_OK(internal::FromProto(pb_info, &info_data));
+ ARROW_ASSIGN_OR_RAISE(FlightInfo info_data,
internal::FromProto(pb_info));
flights.emplace_back(std::move(info_data));
}
if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
@@ -796,9 +937,8 @@ class GrpcClientImpl : public internal::ClientTransport {
stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response),
&rpc.context);
RETURN_NOT_OK(s);
- FlightInfo::Data info_data;
- RETURN_NOT_OK(internal::FromProto(pb_response, &info_data));
- info->reset(new FlightInfo(std::move(info_data)));
+ ARROW_ASSIGN_OR_RAISE(auto info_data, internal::FromProto(pb_response));
+ *info = std::make_unique<FlightInfo>(std::move(info_data));
return Status::OK();
}
@@ -855,6 +995,36 @@ class GrpcClientImpl : public internal::ClientTransport {
return Status::OK();
}
+#ifdef GRPC_ENABLE_ASYNC
+ /// \brief Start an asynchronous GetFlightInfo RPC.
+ ///
+ /// On success the call object is stored in the listener (SetAsyncRpc) and
+ /// the reactor is started.  If the request cannot be prepared, the error is
+ /// delivered through listener->OnFinish and no RPC is started.
+ ///
+ /// \param[in] options per-call options used to construct the call
+ /// \param[in] descriptor the descriptor serialized as the request
+ /// \param[in] listener receives the result and the final status
+ void GetFlightInfoAsync(const FlightCallOptions& options,
+                         const FlightDescriptor& descriptor,
+                         std::shared_ptr<AsyncListener<FlightInfo>> listener) override {
+   using AsyncCall =
+       UnaryUnaryAsyncCall<FlightInfo, pb::FlightDescriptor, pb::FlightInfo>;
+   auto call = std::make_unique<AsyncCall>(options, listener, garbage_bin_);
+
+   // Until the call is handed to gRPC, OnDone() will never run for it, so
+   // nothing would ever set call->finished.  ~FinishedFlag blocks until
+   // Finish() is called, hence the flag must be set before `call` is dropped
+   // on an early return; otherwise its destruction would wait forever.
+   Status setup_status = internal::ToProto(descriptor, &call->pb_request);
+   if (setup_status.ok()) {
+     setup_status = call->rpc.SetToken(auth_handler_.get());
+   }
+   if (!setup_status.ok()) {
+     call->finished.Finish();
+     listener->OnFinish(std::move(setup_status));
+     return;
+   }
+
+   stub_->experimental_async()->GetFlightInfo(&call->rpc.context, &call->pb_request,
+                                              &call->pb_response, call.get());
+   // Ownership moves into the listener; fetch the pointer back to start the call.
+   ClientTransport::SetAsyncRpc(listener.get(), std::move(call));
+   arrow::internal::checked_cast<AsyncCall*>(
+       ClientTransport::GetAsyncRpc(listener.get()))
+       ->StartCall();
+ }
+
+ bool supports_async() const override { return true; }
+#else
+ void GetFlightInfoAsync(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ std::shared_ptr<AsyncListener<FlightInfo>> listener)
override {
+ listener->OnFinish(
+ Status::NotImplemented("gRPC 1.40 or newer is required to use async"));
+ }
+
+ bool supports_async() const override { return false; }
+#endif
+
private:
Status AuthenticateInternal(ClientRpc& rpc) {
std::shared_ptr<
@@ -894,6 +1064,10 @@ class GrpcClientImpl : public internal::ClientTransport {
::GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig>
noop_auth_check_;
#endif
+
+#ifdef GRPC_ENABLE_ASYNC
+ std::shared_ptr<GrpcGarbageBin> garbage_bin_;
+#endif
};
std::once_flag kGrpcClientTransportInitialized;
} // namespace
@@ -907,6 +1081,8 @@ void InitializeFlightGrpcClient() {
});
}
+#undef LISTENER_NOT_OK
+
} // namespace grpc
} // namespace transport
} // namespace flight
diff --git a/cpp/src/arrow/flight/transport/grpc/util_internal.cc
b/cpp/src/arrow/flight/transport/grpc/util_internal.cc
index f431fc30ec..88ec15bc66 100644
--- a/cpp/src/arrow/flight/transport/grpc/util_internal.cc
+++ b/cpp/src/arrow/flight/transport/grpc/util_internal.cc
@@ -20,6 +20,7 @@
#include <cstdlib>
#include <map>
#include <memory>
+#include <optional>
#include <string>
#include <grpcpp/grpcpp.h>
@@ -28,6 +29,7 @@
#include "arrow/flight/types.h"
#include "arrow/status.h"
#include "arrow/util/string.h"
+#include "arrow/util/string_builder.h"
namespace arrow {
@@ -37,6 +39,8 @@ namespace flight {
namespace transport {
namespace grpc {
+using internal::TransportStatus;
+
const char* kGrpcAuthHeader = "auth-token-bin";
const char* kGrpcStatusCodeHeader = "x-arrow-status";
const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin";
@@ -82,11 +86,106 @@ static bool FromGrpcContext(const ::grpc::ClientContext&
ctx,
return true;
}
+static TransportStatus TransportStatusFromGrpc(const ::grpc::Status&
grpc_status) {
+ switch (grpc_status.error_code()) {
+ case ::grpc::StatusCode::OK:
+ return TransportStatus{TransportStatusCode::kOk, ""};
+ case ::grpc::StatusCode::CANCELLED:
+ return TransportStatus{TransportStatusCode::kCancelled,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::UNKNOWN:
+ return TransportStatus{TransportStatusCode::kUnknown,
grpc_status.error_message()};
+ case ::grpc::StatusCode::INVALID_ARGUMENT:
+ return TransportStatus{TransportStatusCode::kInvalidArgument,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::DEADLINE_EXCEEDED:
+ return TransportStatus{TransportStatusCode::kTimedOut,
grpc_status.error_message()};
+ case ::grpc::StatusCode::NOT_FOUND:
+ return TransportStatus{TransportStatusCode::kNotFound,
grpc_status.error_message()};
+ case ::grpc::StatusCode::ALREADY_EXISTS:
+ return TransportStatus{TransportStatusCode::kAlreadyExists,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::PERMISSION_DENIED:
+ return TransportStatus{TransportStatusCode::kUnauthorized,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::RESOURCE_EXHAUSTED:
+ return TransportStatus{TransportStatusCode::kUnavailable,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::FAILED_PRECONDITION:
+ return TransportStatus{TransportStatusCode::kUnavailable,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::ABORTED:
+ return TransportStatus{TransportStatusCode::kUnavailable,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::OUT_OF_RANGE:
+ return TransportStatus{TransportStatusCode::kInvalidArgument,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::UNIMPLEMENTED:
+ return TransportStatus{TransportStatusCode::kUnimplemented,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::INTERNAL:
+ return TransportStatus{TransportStatusCode::kInternal,
grpc_status.error_message()};
+ case ::grpc::StatusCode::UNAVAILABLE:
+ return TransportStatus{TransportStatusCode::kUnavailable,
+ grpc_status.error_message()};
+ case ::grpc::StatusCode::DATA_LOSS:
+ return TransportStatus{TransportStatusCode::kInternal,
grpc_status.error_message()};
+ case ::grpc::StatusCode::UNAUTHENTICATED:
+ return TransportStatus{TransportStatusCode::kUnauthenticated,
+ grpc_status.error_message()};
+ default:
+ return TransportStatus{TransportStatusCode::kUnknown,
+ util::StringBuilder("(",
grpc_status.error_code(), ")",
+ grpc_status.error_message())};
+ }
+}
+
+/// Combine the gRPC status of a call with a client-side Arrow status.
+///
+/// If the gRPC status is OK, the Arrow status is returned unchanged.
+/// Otherwise the gRPC status is converted to a TransportStatus; a non-OK
+/// Arrow status takes precedence over the converted one.  When a context is
+/// given and the server sent any of the known trailers, they are attached to
+/// the result as a TransportStatusDetail.
+///
+/// \param[in] grpc_status the status reported by gRPC
+/// \param[in] arrow_status a client-side status (may be OK)
+/// \param[in] ctx the call context to read trailers from (may be null)
+Status CombinedTransportStatus(const ::grpc::Status& grpc_status,
+                               arrow::Status arrow_status, ::grpc::ClientContext* ctx) {
+  if (grpc_status.ok()) {
+    // Nothing to report from the transport: the client-side status, OK or
+    // not, is the whole answer.
+    return arrow_status;
+  }
+
+  // Can't share with FromGrpcCode because that function sometimes constructs an
+  // Arrow Status directly.
+  // Deliberately not const: the message is moved into the detail below.
+  TransportStatus base_status = TransportStatusFromGrpc(grpc_status);
+
+  std::vector<std::pair<std::string, std::string>> details;
+  if (ctx) {
+    // Attach rich error details
+    const std::multimap<::grpc::string_ref, ::grpc::string_ref>& trailers =
+        ctx->GetServerTrailingMetadata();
+
+    for (const auto key : {
+             // gRPC error details
+             kBinaryErrorDetailsKey,
+             // Sync C++ servers send information about the Arrow status
+             kGrpcStatusCodeHeader,
+             kGrpcStatusMessageHeader,
+             kGrpcStatusDetailHeader,
+         }) {
+      for (auto [it, end] = trailers.equal_range(key); it != end; ++it) {
+        details.emplace_back(key, std::string(it->second.data(), it->second.size()));
+      }
+    }
+  }
+
+  if (arrow_status.ok()) {
+    arrow_status = base_status.ToStatus();
+  }
+
+  if (!details.empty()) {
+    return arrow_status.WithDetail(std::make_shared<TransportStatusDetail>(
+        base_status.code, std::move(base_status.message), std::move(details)));
+  }
+  return arrow_status;
+}
+
/// Convert a gRPC status to an Arrow status, ignoring any
/// implementation-defined headers that encode further detail.
static Status FromGrpcCode(const ::grpc::Status& grpc_status) {
- using internal::TransportStatus;
- using internal::TransportStatusCode;
switch (grpc_status.error_code()) {
case ::grpc::StatusCode::OK:
return Status::OK();
@@ -169,8 +268,6 @@ Status FromGrpcStatus(const ::grpc::Status& grpc_status,
::grpc::ClientContext*
/// Convert an Arrow status to a gRPC status.
static ::grpc::Status ToRawGrpcStatus(const Status& arrow_status) {
- using internal::TransportStatus;
- using internal::TransportStatusCode;
if (arrow_status.ok()) return ::grpc::Status::OK;
TransportStatus transport_status = TransportStatus::FromStatus(arrow_status);
@@ -215,7 +312,7 @@ static ::grpc::Status ToRawGrpcStatus(const Status&
arrow_status) {
grpc_code = ::grpc::StatusCode::UNKNOWN;
break;
}
- return ::grpc::Status(grpc_code, std::move(transport_status.message));
+ return {grpc_code, std::move(transport_status.message)};
}
/// Convert an Arrow status to a gRPC status, and add extra headers to
diff --git a/cpp/src/arrow/flight/transport/grpc/util_internal.h
b/cpp/src/arrow/flight/transport/grpc/util_internal.h
index a267e55654..5687c7a872 100644
--- a/cpp/src/arrow/flight/transport/grpc/util_internal.h
+++ b/cpp/src/arrow/flight/transport/grpc/util_internal.h
@@ -18,6 +18,7 @@
#pragma once
#include "arrow/flight/transport/grpc/protocol_grpc_internal.h"
+#include "arrow/flight/types.h"
#include "arrow/flight/visibility.h"
#include "arrow/util/macros.h"
@@ -71,6 +72,13 @@ extern const char* kGrpcStatusDetailHeader;
ARROW_FLIGHT_EXPORT
extern const char* kBinaryErrorDetailsKey;
+/// \brief Combine a gRPC status, possible client-side Arrow status,
+/// and a gRPC ClientContext into a transport status.
+ARROW_FLIGHT_EXPORT
+Status CombinedTransportStatus(const ::grpc::Status& grpc_status,
+ arrow::Status arrow_status,
+ ::grpc::ClientContext* ctx = nullptr);
+
/// Convert a gRPC status to an Arrow status. Optionally, provide a
/// ClientContext to recover the exact Arrow status if it was passed
/// over the wire.
diff --git a/cpp/src/arrow/flight/type_fwd.h b/cpp/src/arrow/flight/type_fwd.h
index c82c4e6d8f..ac2effbc91 100644
--- a/cpp/src/arrow/flight/type_fwd.h
+++ b/cpp/src/arrow/flight/type_fwd.h
@@ -24,6 +24,10 @@ class Uri;
namespace flight {
struct Action;
struct ActionType;
+template <typename T>
+class AsyncListener;
+class AsyncListenerBase;
+class AsyncRpc;
struct BasicAuth;
class ClientAuthHandler;
class ClientMiddleware;
@@ -51,6 +55,7 @@ class ServerMiddleware;
class ServerMiddlewareFactory;
struct Ticket;
namespace internal {
+class AsyncRpc;
class ClientTransport;
struct FlightData;
class ServerTransport;
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index 7c72595ed6..b7cd55325b 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -24,12 +24,16 @@
#include "arrow/buffer.h"
#include "arrow/flight/serialization_internal.h"
+#include "arrow/flight/types_async.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/dictionary.h"
#include "arrow/ipc/reader.h"
#include "arrow/status.h"
#include "arrow/table.h"
+#include "arrow/util/base64.h"
#include "arrow/util/formatting.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string.h"
#include "arrow/util/string_builder.h"
#include "arrow/util/uri.h"
@@ -299,9 +303,8 @@ arrow::Result<std::unique_ptr<FlightInfo>>
FlightInfo::Deserialize(
if (!pb_info.ParseFromZeroCopyStream(&input)) {
return Status::Invalid("Not a valid FlightInfo");
}
- FlightInfo::Data data;
- RETURN_NOT_OK(internal::FromProto(pb_info, &data));
- return std::make_unique<FlightInfo>(std::move(data));
+ ARROW_ASSIGN_OR_RAISE(FlightInfo info, internal::FromProto(pb_info));
+ return std::make_unique<FlightInfo>(std::move(info));
}
std::string FlightInfo::ToString() const {
@@ -873,5 +876,88 @@ arrow::Result<std::string> BasicAuth::SerializeToString()
const {
return out;
}
+//------------------------------------------------------------
+// Error propagation helpers
+
+std::string ToString(TransportStatusCode code) {
+ switch (code) {
+ case TransportStatusCode::kOk:
+ return "kOk";
+ case TransportStatusCode::kUnknown:
+ return "kUnknown";
+ case TransportStatusCode::kInternal:
+ return "kInternal";
+ case TransportStatusCode::kInvalidArgument:
+ return "kInvalidArgument";
+ case TransportStatusCode::kTimedOut:
+ return "kTimedOut";
+ case TransportStatusCode::kNotFound:
+ return "kNotFound";
+ case TransportStatusCode::kAlreadyExists:
+ return "kAlreadyExists";
+ case TransportStatusCode::kCancelled:
+ return "kCancelled";
+ case TransportStatusCode::kUnauthenticated:
+ return "kUnauthenticated";
+ case TransportStatusCode::kUnauthorized:
+ return "kUnauthorized";
+ case TransportStatusCode::kUnimplemented:
+ return "kUnimplemented";
+ case TransportStatusCode::kUnavailable:
+ return "kUnavailable";
+ }
+ return "(unknown code)";
+}
+
+std::string TransportStatusDetail::ToString() const {
+ std::string repr = "TransportStatusDetail{";
+ repr += arrow::flight::ToString(code());
+ repr += ", message=\"";
+ repr += message();
+ repr += "\", details={";
+
+ bool first = true;
+ for (const auto& [key, value] : details()) {
+ if (!first) {
+ repr += ", ";
+ }
+ first = false;
+
+ repr += "{\"";
+ repr += key;
+ repr += "\", ";
+ if (arrow::internal::EndsWith(key, "-bin")) {
+ repr += arrow::util::base64_encode(value);
+ } else {
+ repr += "\"";
+ repr += value;
+ repr += "\"";
+ }
+ repr += "}";
+ }
+
+ repr += "}}";
+ return repr;
+}
+
+std::optional<std::reference_wrapper<const TransportStatusDetail>>
+TransportStatusDetail::Unwrap(const Status& status) {
+ std::shared_ptr<StatusDetail> detail = status.detail();
+ if (!detail) return std::nullopt;
+ if (detail->type_id() != kTypeId) return std::nullopt;
+ return std::cref(arrow::internal::checked_cast<const
TransportStatusDetail&>(*detail));
+}
+
+//------------------------------------------------------------
+// Async types
+
+AsyncListenerBase::AsyncListenerBase() = default;
+AsyncListenerBase::~AsyncListenerBase() = default;
+void AsyncListenerBase::TryCancel() {
+ if (rpc_state_) {
+ rpc_state_->TryCancel();
+ }
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index ca86c27e86..c5d72d5167 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -24,15 +24,18 @@
#include <cstdint>
#include <map>
#include <memory>
+#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
+#include "arrow/flight/type_fwd.h"
#include "arrow/flight/visibility.h"
#include "arrow/ipc/options.h"
#include "arrow/ipc/writer.h"
#include "arrow/result.h"
+#include "arrow/status.h"
namespace arrow {
@@ -71,7 +74,8 @@ namespace flight {
/// > is from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59.999999999Z.
using Timestamp = std::chrono::system_clock::time_point;
-/// \brief A Flight-specific status code.
+/// \brief A Flight-specific status code. Used to encode some
+/// additional status codes into an Arrow Status.
enum class FlightStatusCode : int8_t {
/// An implementation error has occurred.
Internal,
@@ -774,5 +778,81 @@ class ARROW_FLIGHT_EXPORT SimpleResultStream : public
ResultStream {
size_t position_;
};
+/// \defgroup flight-error Error Handling
+/// Types for handling errors from RPCs. Flight uses a set of status
+/// codes standardized across Flight implementations, so these types
+/// let applications work directly with those codes instead of having
+/// to translate to and from Arrow Status.
+/// @{
+
+/// \brief Abstract status code for an RPC as per the Flight
+/// specification.
+enum class TransportStatusCode {
+ /// \brief No error.
+ kOk = 0,
+ /// \brief An unknown error occurred.
+ kUnknown = 1,
+ /// \brief An error occurred in the transport implementation, or an
+ /// error internal to the service implementation occurred.
+ kInternal = 2,
+ /// \brief An argument is invalid.
+ kInvalidArgument = 3,
+ /// \brief The request timed out.
+ kTimedOut = 4,
+ /// \brief An argument is not necessarily invalid, but references
+ /// some resource that does not exist. Prefer over
+ /// kInvalidArgument where applicable.
+ kNotFound = 5,
+ /// \brief The request attempted to create some resource that
+ /// already exists.
+ kAlreadyExists = 6,
+ /// \brief The request was explicitly cancelled.
+ kCancelled = 7,
+ /// \brief The client is not authenticated.
+ kUnauthenticated = 8,
+ /// \brief The client is not authorized to perform this request.
+ kUnauthorized = 9,
+ /// \brief The request is not implemented
+ kUnimplemented = 10,
+ /// \brief There is a network connectivity error, or some resource
+ /// is otherwise unavailable. Most likely a temporary condition.
+ kUnavailable = 11,
+};
+
+/// \brief Convert a code to a string.
+std::string ToString(TransportStatusCode code);
+
+/// \brief An error from an RPC call, using Flight error codes directly
+/// instead of trying to translate to Arrow Status.
+///
+/// Currently, only attached to the Status passed to AsyncListener::OnFinish.
+///
+/// This API is EXPERIMENTAL.
+class ARROW_FLIGHT_EXPORT TransportStatusDetail : public StatusDetail {
+ public:
+ constexpr static const char* kTypeId = "flight::TransportStatusDetail";
+ explicit TransportStatusDetail(TransportStatusCode code, std::string message,
+ std::vector<std::pair<std::string,
std::string>> details)
+ : code_(code), message_(std::move(message)),
details_(std::move(details)) {}
+ const char* type_id() const override { return kTypeId; }
+ std::string ToString() const override;
+
+ static std::optional<std::reference_wrapper<const TransportStatusDetail>>
Unwrap(
+ const Status& status);
+
+ TransportStatusCode code() const { return code_; }
+ std::string_view message() const { return message_; }
+ const std::vector<std::pair<std::string, std::string>>& details() const {
+ return details_;
+ }
+
+ private:
+ TransportStatusCode code_;
+ std::string message_;
+ std::vector<std::pair<std::string, std::string>> details_;
+};
+
+/// @}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/types_async.h
b/cpp/src/arrow/flight/types_async.h
new file mode 100644
index 0000000000..a241e64fb4
--- /dev/null
+++ b/cpp/src/arrow/flight/types_async.h
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/flight/type_fwd.h"
+#include "arrow/flight/types.h"
+#include "arrow/ipc/options.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow::flight {
+
+/// \defgroup flight-async Async Flight Types
+/// Common types used for asynchronous Flight APIs.
+/// @{
+
+/// \brief Non-templated state for an async RPC.
+///
+/// This API is EXPERIMENTAL.
+class ARROW_FLIGHT_EXPORT AsyncListenerBase {
+ public:
+ AsyncListenerBase();
+ virtual ~AsyncListenerBase();
+
+ /// \brief Request cancellation of the RPC.
+ ///
+ /// The RPC is not cancelled until AsyncListener::OnFinish is called.
+ void TryCancel();
+
+ private:
+ friend class arrow::flight::internal::ClientTransport;
+
+ /// Transport-specific state for this RPC. Transport
+ /// implementations may store and retrieve state here via
+ /// ClientTransport::SetAsyncRpc and ClientTransport::GetAsyncRpc.
+ std::unique_ptr<internal::AsyncRpc> rpc_state_;
+};
+
+/// \brief Callbacks for results from async RPCs.
+///
+/// A single listener may not be used for multiple concurrent RPC
+/// calls. The application MUST hold the listener alive until
+/// OnFinish() is called and has finished.
+///
+/// This API is EXPERIMENTAL.
+template <typename T>
+class ARROW_FLIGHT_EXPORT AsyncListener : public AsyncListenerBase {
+ public:
+ /// \brief Get the next server result.
+ ///
+ /// This will never be called concurrently with itself or OnFinish.
+ virtual void OnNext(T message) = 0;
+ /// \brief Get the final status.
+ ///
+ /// This will never be called concurrently with itself or OnNext. If the
+ /// error comes from the remote server, then a TransportStatusDetail will be
+ /// attached. Otherwise, the error is generated by the client-side
+ /// transport and will not have a TransportStatusDetail.
+ virtual void OnFinish(Status status) = 0;
+};
+
+/// @}
+
+} // namespace arrow::flight