This is an automated email from the ASF dual-hosted git repository.
felipecrv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 69bce8f0cd GH-43677: [C++][FlightRPC] Move the FlightTestServer to its
own .cc and .h files (#43678)
69bce8f0cd is described below
commit 69bce8f0cd02297ecc31caef22db67e654c16e28
Author: Felipe Oliveira Carvalho <[email protected]>
AuthorDate: Tue Aug 13 21:27:36 2024 -0300
GH-43677: [C++][FlightRPC] Move the FlightTestServer to its own .cc and .h
files (#43678)
### Rationale for this change
One way of learning about a codebase is reading the tests. As it is now,
it's hard to find the minimal `FlightServerBase` subclass in
`flight/test_util.cc`, so I moved it to its own file.
### What changes are included in this PR?
- Renaming `FlightTestServer` to `TestFlightServer`
- Moving the class to `test_flight_server.{h,cc}`
- Bonus: Moving the server and client auth handlers to
`test_auth_handlers.{h,cc}`
### Are these changes tested?
By existing tests.
### Are there any user-facing changes?
`ExampleTestServer` is removed from the testing library in favor of
`TestFlightServer::Make`.
* GitHub Issue: #43677
Authored-by: Felipe Oliveira Carvalho <[email protected]>
Signed-off-by: Felipe Oliveira Carvalho <[email protected]>
---
cpp/src/arrow/flight/CMakeLists.txt | 2 +
cpp/src/arrow/flight/flight_test.cc | 8 +-
.../flight/integration_tests/test_integration.cc | 1 +
cpp/src/arrow/flight/test_auth_handlers.cc | 141 ++++++
cpp/src/arrow/flight/test_auth_handlers.h | 89 ++++
cpp/src/arrow/flight/test_definitions.cc | 15 +-
cpp/src/arrow/flight/test_flight_server.cc | 417 ++++++++++++++++++
cpp/src/arrow/flight/test_flight_server.h | 92 ++++
cpp/src/arrow/flight/test_server.cc | 3 +-
cpp/src/arrow/flight/test_util.cc | 486 +--------------------
cpp/src/arrow/flight/test_util.h | 65 ---
11 files changed, 759 insertions(+), 560 deletions(-)
diff --git a/cpp/src/arrow/flight/CMakeLists.txt
b/cpp/src/arrow/flight/CMakeLists.txt
index 43ac48b876..98f93705f6 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -262,7 +262,9 @@ if(ARROW_TESTING)
OUTPUTS
ARROW_FLIGHT_TESTING_LIBRARIES
SOURCES
+ test_auth_handlers.cc
test_definitions.cc
+ test_flight_server.cc
test_util.cc
DEPENDENCIES
flight_grpc_gen
diff --git a/cpp/src/arrow/flight/flight_test.cc
b/cpp/src/arrow/flight/flight_test.cc
index 101bb06b21..3d52bc3f5a 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -52,7 +52,9 @@
// Include before test_util.h (boost), contains Windows fixes
#include "arrow/flight/platform.h"
#include "arrow/flight/serialization_internal.h"
+#include "arrow/flight/test_auth_handlers.h"
#include "arrow/flight/test_definitions.h"
+#include "arrow/flight/test_flight_server.h"
#include "arrow/flight/test_util.h"
// OTel includes must come after any gRPC includes, and
// client_header_internal.h includes gRPC. See:
@@ -247,7 +249,7 @@ TEST(TestFlight, ConnectUriUnix) {
// CI environments don't have an IPv6 interface configured
TEST(TestFlight, DISABLED_IpV6Port) {
- std::unique_ptr<FlightServerBase> server = ExampleTestServer();
+ std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("[::1]", 0));
FlightServerOptions options(location);
@@ -261,7 +263,7 @@ TEST(TestFlight, DISABLED_IpV6Port) {
}
TEST(TestFlight, ServerCallContextIncomingHeaders) {
- auto server = ExampleTestServer();
+ auto server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
@@ -290,7 +292,7 @@ TEST(TestFlight, ServerCallContextIncomingHeaders) {
class TestFlightClient : public ::testing::Test {
public:
void SetUp() {
- server_ = ExampleTestServer();
+ server_ = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
FlightServerOptions options(location);
diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc
b/cpp/src/arrow/flight/integration_tests/test_integration.cc
index 665c1f1ba0..da6fcf81eb 100644
--- a/cpp/src/arrow/flight/integration_tests/test_integration.cc
+++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc
@@ -36,6 +36,7 @@
#include "arrow/flight/sql/server.h"
#include "arrow/flight/sql/server_session_middleware.h"
#include "arrow/flight/sql/types.h"
+#include "arrow/flight/test_auth_handlers.h"
#include "arrow/flight/test_util.h"
#include "arrow/flight/types.h"
#include "arrow/ipc/dictionary.h"
diff --git a/cpp/src/arrow/flight/test_auth_handlers.cc
b/cpp/src/arrow/flight/test_auth_handlers.cc
new file mode 100644
index 0000000000..856ccf0f2b
--- /dev/null
+++ b/cpp/src/arrow/flight/test_auth_handlers.cc
@@ -0,0 +1,141 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+
+#include "arrow/flight/client_auth.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/server_auth.h"
+#include "arrow/flight/test_auth_handlers.h"
+#include "arrow/flight/types.h"
+#include "arrow/flight/visibility.h"
+#include "arrow/status.h"
+
+namespace arrow::flight {
+
+// TestServerAuthHandler
+//
+// Server side of the plain password handshake: the client must send the
+// configured password, and the configured username is sent back and later
+// reported as the peer identity.
+
+TestServerAuthHandler::TestServerAuthHandler(const std::string& username,
+ const std::string& password)
+ : username_(username), password_(password) {}
+
+TestServerAuthHandler::~TestServerAuthHandler() {}
+
+// Reads one message from the client and accepts it only if it equals the
+// password; on success writes the username back to the client.
+Status TestServerAuthHandler::Authenticate(const ServerCallContext& context,
+ ServerAuthSender* outgoing,
+ ServerAuthReader* incoming) {
+ std::string token;
+ RETURN_NOT_OK(incoming->Read(&token));
+ if (token != password_) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ RETURN_NOT_OK(outgoing->Write(username_));
+ return Status::OK();
+}
+
+// Accepts `token` only if it equals the password (which is what
+// TestClientAuthHandler::GetToken() returns) and sets the peer identity to
+// the username.
+Status TestServerAuthHandler::IsValid(const ServerCallContext& context,
+ const std::string& token,
+ std::string* peer_identity) {
+ if (token != password_) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ *peer_identity = username_;
+ return Status::OK();
+}
+
+// TestServerBasicAuthHandler
+//
+// Server side of the BasicAuth handshake: the client sends a serialized
+// BasicAuth, and the username is returned to it as the token that IsValid()
+// later checks.
+
+TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string&
username,
+ const std::string&
password) {
+ basic_auth_.username = username;
+ basic_auth_.password = password;
+}
+
+TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {}
+
+// Reads one message, deserializes it as BasicAuth, and requires both the
+// username and the password to match. A message that fails to deserialize
+// propagates the deserialization error rather than Unauthenticated.
+Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext&
context,
+ ServerAuthSender* outgoing,
+ ServerAuthReader* incoming) {
+ std::string token;
+ RETURN_NOT_OK(incoming->Read(&token));
+ ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth,
BasicAuth::Deserialize(token));
+ if (incoming_auth.username != basic_auth_.username ||
+ incoming_auth.password != basic_auth_.password) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ // The username doubles as the session token (see IsValid()).
+ RETURN_NOT_OK(outgoing->Write(basic_auth_.username));
+ return Status::OK();
+}
+
+// Accepts `token` only if it equals the username handed out by
+// Authenticate(); the peer identity is set to that username.
+Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context,
+ const std::string& token,
+ std::string* peer_identity) {
+ if (token != basic_auth_.username) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ *peer_identity = basic_auth_.username;
+ return Status::OK();
+}
+
+// TestClientAuthHandler
+//
+// Client counterpart of TestServerAuthHandler.
+
+TestClientAuthHandler::TestClientAuthHandler(const std::string& username,
+ const std::string& password)
+ : username_(username), password_(password) {}
+
+TestClientAuthHandler::~TestClientAuthHandler() {}
+
+// Sends the password and expects the server to answer with the configured
+// username. NOTE(review): a username mismatch is reported with the message
+// "Invalid token".
+Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing,
+ ClientAuthReader* incoming) {
+ RETURN_NOT_OK(outgoing->Write(password_));
+ std::string username;
+ RETURN_NOT_OK(incoming->Read(&username));
+ if (username != username_) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ return Status::OK();
+}
+
+// The password itself is used as the per-call token.
+Status TestClientAuthHandler::GetToken(std::string* token) {
+ *token = password_;
+ return Status::OK();
+}
+
+// TestClientBasicAuthHandler
+//
+// Client counterpart of TestServerBasicAuthHandler.
+
+TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string&
username,
+ const std::string&
password) {
+ basic_auth_.username = username;
+ basic_auth_.password = password;
+}
+
+TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {}
+
+// Sends the serialized BasicAuth and stores the server's reply as the token.
+// The reply is not validated here.
+Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing,
+ ClientAuthReader* incoming) {
+ ARROW_ASSIGN_OR_RAISE(std::string pb_result,
basic_auth_.SerializeToString());
+ RETURN_NOT_OK(outgoing->Write(pb_result));
+ RETURN_NOT_OK(incoming->Read(&token_));
+ return Status::OK();
+}
+
+// Returns the token received by Authenticate() (empty before it has run).
+Status TestClientBasicAuthHandler::GetToken(std::string* token) {
+ *token = token_;
+ return Status::OK();
+}
+
+} // namespace arrow::flight
diff --git a/cpp/src/arrow/flight/test_auth_handlers.h
b/cpp/src/arrow/flight/test_auth_handlers.h
new file mode 100644
index 0000000000..74f48798f3
--- /dev/null
+++ b/cpp/src/arrow/flight/test_auth_handlers.h
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "arrow/flight/client_auth.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/server_auth.h"
+#include "arrow/flight/types.h"
+#include "arrow/flight/visibility.h"
+#include "arrow/status.h"
+
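+// Two pairs of authentication handlers for tests. TestServerAuthHandler and
+// TestClientAuthHandler check for a predefined password and set the peer
+// identity to a predefined username; TestServerBasicAuthHandler and
+// TestClientBasicAuthHandler exchange a serialized BasicAuth (username and
+// password) and then use the username as the token.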
+
+namespace arrow::flight {
+
+/// \brief Server-side test auth handler using a fixed password.
+///
+/// The handshake succeeds only if the client sends `password`; the server
+/// then replies with `username`. IsValid() accepts `password` as the token
+/// and reports `username` as the peer identity.
+class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler {
+ public:
+ explicit TestServerAuthHandler(const std::string& username,
+ const std::string& password);
+ ~TestServerAuthHandler() override;
+ Status Authenticate(const ServerCallContext& context, ServerAuthSender*
outgoing,
+ ServerAuthReader* incoming) override;
+ Status IsValid(const ServerCallContext& context, const std::string& token,
+ std::string* peer_identity) override;
+
+ private:
+ std::string username_;
+ std::string password_;
+};
+
+/// \brief Server-side test auth handler using a serialized BasicAuth.
+///
+/// The handshake requires the client's BasicAuth to match both `username`
+/// and `password`; the server then replies with `username`, which IsValid()
+/// accepts as the token.
+class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public
ServerAuthHandler {
+ public:
+ explicit TestServerBasicAuthHandler(const std::string& username,
+ const std::string& password);
+ ~TestServerBasicAuthHandler() override;
+ Status Authenticate(const ServerCallContext& context, ServerAuthSender*
outgoing,
+ ServerAuthReader* incoming) override;
+ Status IsValid(const ServerCallContext& context, const std::string& token,
+ std::string* peer_identity) override;
+
+ private:
+ // Expected credentials.
+ BasicAuth basic_auth_;
+};
+
+/// \brief Client-side counterpart of TestServerAuthHandler.
+///
+/// Sends `password` during the handshake, expects `username` in reply, and
+/// uses `password` as the per-call token.
+class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler {
+ public:
+ explicit TestClientAuthHandler(const std::string& username,
+ const std::string& password);
+ ~TestClientAuthHandler() override;
+ Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming)
override;
+ Status GetToken(std::string* token) override;
+
+ private:
+ std::string username_;
+ std::string password_;
+};
+
+/// \brief Client-side counterpart of TestServerBasicAuthHandler.
+///
+/// Sends the serialized BasicAuth during the handshake and uses the server's
+/// reply as the per-call token.
+class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public
ClientAuthHandler {
+ public:
+ explicit TestClientBasicAuthHandler(const std::string& username,
+ const std::string& password);
+ ~TestClientBasicAuthHandler() override;
+ Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming)
override;
+ Status GetToken(std::string* token) override;
+
+ private:
+ BasicAuth basic_auth_;
+ // Reply received from the server in Authenticate(); returned by GetToken().
+ std::string token_;
+};
+
+} // namespace arrow::flight
diff --git a/cpp/src/arrow/flight/test_definitions.cc
b/cpp/src/arrow/flight/test_definitions.cc
index c43b693d84..273d394c28 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -27,6 +27,7 @@
#include "arrow/array/util.h"
#include "arrow/flight/api.h"
#include "arrow/flight/client_middleware.h"
+#include "arrow/flight/test_flight_server.h"
#include "arrow/flight/test_util.h"
#include "arrow/flight/types.h"
#include "arrow/flight/types_async.h"
@@ -53,7 +54,7 @@ using arrow::internal::checked_cast;
// Tests of initialization/shutdown
void ConnectivityTest::TestGetPort() {
- std::unique_ptr<FlightServerBase> server = ExampleTestServer();
+ std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(),
"127.0.0.1", 0));
FlightServerOptions options(location);
@@ -61,7 +62,7 @@ void ConnectivityTest::TestGetPort() {
ASSERT_GT(server->port(), 0);
}
void ConnectivityTest::TestBuilderHook() {
- std::unique_ptr<FlightServerBase> server = ExampleTestServer();
+ std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(),
"127.0.0.1", 0));
FlightServerOptions options(location);
@@ -80,7 +81,7 @@ void ConnectivityTest::TestShutdown() {
constexpr int kIterations = 10;
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(),
"127.0.0.1", 0));
for (int i = 0; i < kIterations; i++) {
- std::unique_ptr<FlightServerBase> server = ExampleTestServer();
+ std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
@@ -92,7 +93,7 @@ void ConnectivityTest::TestShutdown() {
}
}
void ConnectivityTest::TestShutdownWithDeadline() {
- std::unique_ptr<FlightServerBase> server = ExampleTestServer();
+ std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(),
"127.0.0.1", 0));
FlightServerOptions options(location);
@@ -105,7 +106,7 @@ void ConnectivityTest::TestShutdownWithDeadline() {
ASSERT_OK(server->Wait());
}
void ConnectivityTest::TestBrokenConnection() {
- std::unique_ptr<FlightServerBase> server = ExampleTestServer();
+ std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(),
"127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
@@ -151,7 +152,7 @@ class GetFlightInfoListener : public
AsyncListener<FlightInfo> {
} // namespace
void DataTest::SetUpTest() {
- server_ = ExampleTestServer();
+ server_ = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(),
"127.0.0.1", 0));
FlightServerOptions options(location);
@@ -1822,7 +1823,7 @@ void AsyncClientTest::SetUpTest() {
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(),
"127.0.0.1", 0));
- server_ = ExampleTestServer();
+ server_ = TestFlightServer::Make();
FlightServerOptions server_options(location);
ASSERT_OK(server_->Init(server_options));
diff --git a/cpp/src/arrow/flight/test_flight_server.cc
b/cpp/src/arrow/flight/test_flight_server.cc
new file mode 100644
index 0000000000..0ea95ebd15
--- /dev/null
+++ b/cpp/src/arrow/flight/test_flight_server.cc
@@ -0,0 +1,417 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+
+#include "arrow/flight/test_flight_server.h"
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/flight/type_fwd.h"
+#include "arrow/status.h"
+
+namespace arrow::flight {
+namespace {
+
+// RecordBatchReader with an empty schema that produces no batches
+// (ReadNext() returns a null batch, i.e. end of stream) but fails on Close().
+// DoGet() serves it for the "ticket-stream-error" ticket.
+class ErrorRecordBatchReader : public RecordBatchReader {
+ public:
+ ErrorRecordBatchReader() : schema_(arrow::schema({})) {}
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* out) override {
+ *out = nullptr;
+ return Status::OK();
+ }
+
+ Status Close() override {
+ // This should be propagated over DoGet to the client
+ return Status::IOError("Expected error");
+ }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+};
+
+// Builds a reader over the example batches for the well-known tickets
+// "ticket-ints-1", "ticket-floats-1", "ticket-dicts-1" and
+// "ticket-large-batch-1"; any other ticket yields NotImplemented.
+Status GetBatchForFlight(const Ticket& ticket,
std::shared_ptr<RecordBatchReader>* out) {
+ if (ticket.ticket == "ticket-ints-1") {
+ RecordBatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
+ return Status::OK();
+ } else if (ticket.ticket == "ticket-floats-1") {
+ RecordBatchVector batches;
+ RETURN_NOT_OK(ExampleFloatBatches(&batches));
+ ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
+ return Status::OK();
+ } else if (ticket.ticket == "ticket-dicts-1") {
+ RecordBatchVector batches;
+ RETURN_NOT_OK(ExampleDictBatches(&batches));
+ ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
+ return Status::OK();
+ } else if (ticket.ticket == "ticket-large-batch-1") {
+ RecordBatchVector batches;
+ RETURN_NOT_OK(ExampleLargeBatches(&batches));
+ ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
+ return Status::OK();
+ } else {
+ return Status::NotImplemented("no stream implemented for ticket: " +
ticket.ticket);
+ }
+}
+
+} // namespace
+
+// Creates a TestFlightServer that has not been initialized yet; callers
+// still have to call Init() with their FlightServerOptions.
+std::unique_ptr<FlightServerBase> TestFlightServer::Make() {
+ return std::make_unique<TestFlightServer>();
+}
+
+// Lists every flight from ExampleFlightInfo(). Any non-empty criteria
+// expression returns an empty listing instead.
+Status TestFlightServer::ListFlights(const ServerCallContext& context,
+ const Criteria* criteria,
+ std::unique_ptr<FlightListing>* listings)
{
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+ if (criteria && criteria->expression != "") {
+ // For test purposes, if we get criteria, return no results
+ flights.clear();
+ }
+ *listings = std::make_unique<SimpleFlightListing>(flights);
+ return Status::OK();
+}
+
+// Returns the ExampleFlightInfo() entry whose descriptor equals `request`,
+// or Invalid if none matches. The CMD descriptor "status-outofmemory"
+// returns OutOfMemory instead.
+Status TestFlightServer::GetFlightInfo(const ServerCallContext& context,
+ const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* out) {
+ // Test that Arrow-C++ status codes make it through the transport
+ if (request.type == FlightDescriptor::DescriptorType::CMD &&
+ request.cmd == "status-outofmemory") {
+ return Status::OutOfMemory("Sentinel");
+ }
+
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+
+ for (const auto& info : flights) {
+ if (info.descriptor().Equals(request)) {
+ *out = std::make_unique<FlightInfo>(info);
+ return Status::OK();
+ }
+ }
+ return Status::Invalid("Flight not found: ", request.ToString());
+}
+
+// Serves the stream for a ticket. Special tickets:
+// - "ARROW-5095-fail": fails with UnknownError.
+// - "ARROW-5095-success": returns OK without setting *data_stream.
+// - "ARROW-13253-DoGet-Batch": streams one batch larger than 2 GiB.
+// - "ticket-stream-error": streams an ErrorRecordBatchReader, whose Close()
+//   fails.
+// Every other ticket is resolved by GetBatchForFlight().
+Status TestFlightServer::DoGet(const ServerCallContext& context, const Ticket&
request,
+ std::unique_ptr<FlightDataStream>* data_stream)
{
+ // Test for ARROW-5095
+ if (request.ticket == "ARROW-5095-fail") {
+ return Status::UnknownError("Server-side error");
+ }
+ if (request.ticket == "ARROW-5095-success") {
+ return Status::OK();
+ }
+ if (request.ticket == "ARROW-13253-DoGet-Batch") {
+ // Make batch > 2GiB in size
+ ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch());
+ ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch}));
+ *data_stream = std::make_unique<RecordBatchStream>(std::move(reader));
+ return Status::OK();
+ }
+ if (request.ticket == "ticket-stream-error") {
+ auto reader = std::make_shared<ErrorRecordBatchReader>();
+ *data_stream = std::make_unique<RecordBatchStream>(std::move(reader));
+ return Status::OK();
+ }
+
+ std::shared_ptr<RecordBatchReader> batch_reader;
+ RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
+
+ *data_stream = std::make_unique<RecordBatchStream>(batch_reader);
+ return Status::OK();
+}
+
+// Reads and discards everything the client uploads; only read errors are
+// reported. Nothing is written back through `writer`.
+Status TestFlightServer::DoPut(const ServerCallContext&,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) {
+ return reader->ToRecordBatches().status();
+}
+
+// Runs the exchange scenario named by the descriptor's command. Non-CMD
+// descriptors are rejected with Invalid and unknown commands with
+// NotImplemented.
+Status TestFlightServer::DoExchange(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader>
reader,
+ std::unique_ptr<FlightMessageWriter>
writer) {
+ // Test various scenarios for a DoExchange
+ if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) {
+ return Status::Invalid("Must provide a command descriptor");
+ }
+
+ const std::string& cmd = reader->descriptor().cmd;
+ if (cmd == "error") {
+ // Immediately return an error to the client.
+ return Status::NotImplemented("Expected error");
+ } else if (cmd == "get") {
+ return RunExchangeGet(std::move(reader), std::move(writer));
+ } else if (cmd == "put") {
+ return RunExchangePut(std::move(reader), std::move(writer));
+ } else if (cmd == "counter") {
+ return RunExchangeCounter(std::move(reader), std::move(writer));
+ } else if (cmd == "total") {
+ return RunExchangeTotal(std::move(reader), std::move(writer));
+ } else if (cmd == "echo") {
+ return RunExchangeEcho(std::move(reader), std::move(writer));
+ } else if (cmd == "large_batch") {
+ return RunExchangeLargeBatch(std::move(reader), std::move(writer));
+ } else if (cmd == "TestUndrained") {
+ // Read only the schema and return without draining the rest of the
+ // client's stream.
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ return Status::OK();
+ } else {
+ return Status::NotImplemented("Scenario not implemented: ", cmd);
+ }
+}
+
+// A simple example - act like DoGet.
+// Ignores whatever the client sends and writes the example int batches
+// (ExampleIntSchema() / ExampleIntBatches()).
+Status TestFlightServer::RunExchangeGet(std::unique_ptr<FlightMessageReader>
reader,
+ std::unique_ptr<FlightMessageWriter>
writer) {
+ RETURN_NOT_OK(writer->Begin(ExampleIntSchema()));
+ RecordBatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ for (const auto& batch : batches) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ }
+ return Status::OK();
+}
+
+// A simple example - act like DoPut
+// The client must send exactly the example int batches (schema compared
+// without metadata) followed by end of stream; the server then replies
+// with a single "done" metadata message.
+Status TestFlightServer::RunExchangePut(std::unique_ptr<FlightMessageReader>
reader,
+ std::unique_ptr<FlightMessageWriter>
writer) {
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ if (!schema->Equals(ExampleIntSchema(), false)) {
+ return Status::Invalid("Schema is not as expected");
+ }
+ RecordBatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ FlightStreamChunk chunk;
+ for (const auto& batch : batches) {
+ ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
+ if (!chunk.data) {
+ return Status::Invalid("Expected another batch");
+ }
+ if (!batch->Equals(*chunk.data)) {
+ return Status::Invalid("Batch does not match");
+ }
+ }
+ // Anything after the expected batches (data or metadata) is an error.
+ ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
+ if (chunk.data || chunk.app_metadata) {
+ return Status::Invalid("Too many batches");
+ }
+
+ RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done")));
+ return Status::OK();
+}
+
+// Read some number of record batches from the client, send a
+// metadata message back with the count, then echo the batches back.
+// Metadata-only messages are consumed but neither counted nor echoed; if no
+// batches arrive, only the count is sent and the writer is never begun.
+Status
TestFlightServer::RunExchangeCounter(std::unique_ptr<FlightMessageReader>
reader,
+
std::unique_ptr<FlightMessageWriter> writer) {
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ FlightStreamChunk chunk;
+ int chunks = 0;
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
+ // An empty chunk (no data, no metadata) ends the loop.
+ if (!chunk.data && !chunk.app_metadata) {
+ break;
+ }
+ if (chunk.data) {
+ batches.push_back(chunk.data);
+ chunks++;
+ }
+ }
+
+ // Echo back the number of record batches read.
+ std::shared_ptr<Buffer> buf = Buffer::FromString(std::to_string(chunks));
+ RETURN_NOT_OK(writer->WriteMetadata(buf));
+ // Echo the record batches themselves.
+ if (chunks > 0) {
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ RETURN_NOT_OK(writer->Begin(schema));
+
+ for (const auto& batch : batches) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ }
+ }
+
+ return Status::OK();
+}
+
+// Read int64 batches from the client, each time sending back a
+// batch with a running sum of columns.
+// Every field must be INT64. Nulls are skipped. Each reply is a one-row
+// batch holding the cumulative per-column sums over all batches so far.
+Status TestFlightServer::RunExchangeTotal(std::unique_ptr<FlightMessageReader>
reader,
+ std::unique_ptr<FlightMessageWriter>
writer) {
+ FlightStreamChunk chunk{};
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ // Ensure the schema contains only int64 columns
+ for (const auto& field : schema->fields()) {
+ if (field->type()->id() != Type::type::INT64) {
+ return Status::Invalid("Field is not INT64: ", field->name());
+ }
+ }
+ std::vector<int64_t> sums(schema->num_fields());
+ std::vector<std::shared_ptr<Array>> columns(schema->num_fields());
+ RETURN_NOT_OK(writer->Begin(schema));
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
+ if (!chunk.data && !chunk.app_metadata) {
+ break;
+ }
+ if (chunk.data) {
+ if (!chunk.data->schema()->Equals(schema, false)) {
+ // A compliant client implementation would make this impossible
+ return Status::Invalid("Schemas are incompatible");
+ }
+
+ // Update the running totals
+ auto builder = std::make_shared<Int64Builder>();
+ int col_index = 0;
+ for (const auto& column : chunk.data->columns()) {
+ auto arr = std::dynamic_pointer_cast<Int64Array>(column);
+ if (!arr) {
+ return MakeFlightError(FlightStatusCode::Internal, "Could not cast
array");
+ }
+ // NOTE(review): `row` is an int while column->length() is int64_t; a
+ // column longer than INT_MAX rows would overflow the counter.
+ for (int row = 0; row < column->length(); row++) {
+ if (!arr->IsNull(row)) {
+ sums[col_index] += arr->Value(row);
+ }
+ }
+
+ builder->Reset();
+ RETURN_NOT_OK(builder->Append(sums[col_index]));
+ RETURN_NOT_OK(builder->Finish(&columns[col_index]));
+
+ col_index++;
+ }
+
+ // Echo the totals to the client
+ auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns);
+ RETURN_NOT_OK(writer->WriteRecordBatch(*response));
+ }
+ }
+ return Status::OK();
+}
+
+// Echo the client's messages back.
+// Each message is returned in the same shape (data with metadata, data
+// only, or metadata only). The writer is begun with the schema of the first
+// data message; later schema changes are not checked.
+Status TestFlightServer::RunExchangeEcho(std::unique_ptr<FlightMessageReader>
reader,
+ std::unique_ptr<FlightMessageWriter>
writer) {
+ FlightStreamChunk chunk;
+ bool begun = false;
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
+ if (!chunk.data && !chunk.app_metadata) {
+ break;
+ }
+ if (!begun && chunk.data) {
+ begun = true;
+ RETURN_NOT_OK(writer->Begin(chunk.data->schema()));
+ }
+ if (chunk.data && chunk.app_metadata) {
+ RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data,
chunk.app_metadata));
+ } else if (chunk.data) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
+ } else if (chunk.app_metadata) {
+ RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata));
+ }
+ }
+ return Status::OK();
+}
+
+// Regression test for ARROW-13253
+// Ignores the client's messages and sends a single batch larger than 2 GiB
+// (VeryLargeBatch()).
+Status TestFlightServer::RunExchangeLargeBatch(
+ std::unique_ptr<FlightMessageReader>, std::unique_ptr<FlightMessageWriter>
writer) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch());
+ RETURN_NOT_OK(writer->Begin(batch->schema()));
+ return writer->WriteRecordBatch(*batch);
+}
+
+// "action1": returns three results whose bodies are the action body with
+// "-part0", "-part1" and "-part2" appended.
+// NOTE(review): assumes action.body is non-null.
+Status TestFlightServer::RunAction1(const Action& action,
+ std::unique_ptr<ResultStream>* out) {
+ std::vector<Result> results;
+ for (int i = 0; i < 3; ++i) {
+ Result result;
+ std::string value = action.body->ToString() + "-part" + std::to_string(i);
+ result.body = Buffer::FromString(std::move(value));
+ results.push_back(result);
+ }
+ *out = std::make_unique<SimpleResultStream>(std::move(results));
+ return Status::OK();
+}
+
+// "action2": returns an empty result stream.
+Status TestFlightServer::RunAction2(std::unique_ptr<ResultStream>* out) {
+ // Empty
+ *out = std::make_unique<SimpleResultStream>(std::vector<Result>{});
+ return Status::OK();
+}
+
+// "list-incoming-headers": returns one "name: value" result per incoming
+// header whose name starts with the prefix given in the action body.
+// NOTE(review): assumes action.body is non-null.
+Status TestFlightServer::ListIncomingHeaders(const ServerCallContext& context,
+ const Action& action,
+ std::unique_ptr<ResultStream>*
out) {
+ std::vector<Result> results;
+ std::string_view prefix(*action.body);
+ for (const auto& header : context.incoming_headers()) {
+ if (header.first.substr(0, prefix.size()) != prefix) {
+ continue;
+ }
+ Result result;
+ result.body =
+ Buffer::FromString(std::string(header.first) + ": " +
std::string(header.second));
+ results.push_back(result);
+ }
+ *out = std::make_unique<SimpleResultStream>(std::move(results));
+ return Status::OK();
+}
+
+// Dispatches "action1", "action2" and "list-incoming-headers"; any other
+// action type yields NotImplemented.
+Status TestFlightServer::DoAction(const ServerCallContext& context, const
Action& action,
+ std::unique_ptr<ResultStream>* out) {
+ if (action.type == "action1") {
+ return RunAction1(action, out);
+ } else if (action.type == "action2") {
+ return RunAction2(out);
+ } else if (action.type == "list-incoming-headers") {
+ return ListIncomingHeaders(context, action, out);
+ } else {
+ return Status::NotImplemented(action.type);
+ }
+}
+
+// Advertises the actions from ExampleActionTypes().
+Status TestFlightServer::ListActions(const ServerCallContext& context,
+ std::vector<ActionType>* out) {
+ std::vector<ActionType> actions = ExampleActionTypes();
+ *out = std::move(actions);
+ return Status::OK();
+}
+
+// Returns the serialized schema of the ExampleFlightInfo() entry whose
+// descriptor equals `request`, or Invalid if none matches.
+Status TestFlightServer::GetSchema(const ServerCallContext& context,
+ const FlightDescriptor& request,
+ std::unique_ptr<SchemaResult>* schema) {
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+
+ for (const auto& info : flights) {
+ if (info.descriptor().Equals(request)) {
+ *schema = std::make_unique<SchemaResult>(info.serialized_schema());
+ return Status::OK();
+ }
+ }
+ return Status::Invalid("Flight not found: ", request.ToString());
+}
+
+} // namespace arrow::flight
diff --git a/cpp/src/arrow/flight/test_flight_server.h
b/cpp/src/arrow/flight/test_flight_server.h
new file mode 100644
index 0000000000..794dd834c0
--- /dev/null
+++ b/cpp/src/arrow/flight/test_flight_server.h
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/flight/server.h"
+#include "arrow/flight/type_fwd.h"
+#include "arrow/flight/visibility.h"
+#include "arrow/status.h"
+
+namespace arrow::flight {
+
+class ARROW_FLIGHT_EXPORT TestFlightServer : public FlightServerBase {
+ public:
+ static std::unique_ptr<FlightServerBase> Make();
+
+ Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
+ std::unique_ptr<FlightListing>* listings) override;
+
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* out) override;
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override;
+
+ Status DoPut(const ServerCallContext&, std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override;
+
+ Status DoExchange(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer) override;
+
+ // A simple example - act like DoGet.
+ Status RunExchangeGet(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer);
+
+ // A simple example - act like DoPut
+ Status RunExchangePut(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer);
+
+ // Read some number of record batches from the client, send a
+ // metadata message back with the count, then echo the batches back.
+ Status RunExchangeCounter(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer);
+
+ // Read int64 batches from the client, each time sending back a
+ // batch with a running sum of columns.
+ Status RunExchangeTotal(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer);
+
+ // Echo the client's messages back.
+ Status RunExchangeEcho(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer);
+
+ // Regression test for ARROW-13253
+ Status RunExchangeLargeBatch(std::unique_ptr<FlightMessageReader>,
+ std::unique_ptr<FlightMessageWriter> writer);
+
+ Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out);
+
+ Status RunAction2(std::unique_ptr<ResultStream>* out);
+
+ Status ListIncomingHeaders(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* out);
+
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* out) override;
+
+ Status ListActions(const ServerCallContext& context,
+ std::vector<ActionType>* out) override;
+
+ Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<SchemaResult>* schema) override;
+};
+
+} // namespace arrow::flight
diff --git a/cpp/src/arrow/flight/test_server.cc b/cpp/src/arrow/flight/test_server.cc
index 18bf2b4135..ba84b8f532 100644
--- a/cpp/src/arrow/flight/test_server.cc
+++ b/cpp/src/arrow/flight/test_server.cc
@@ -26,6 +26,7 @@
#include <gflags/gflags.h>
#include "arrow/flight/server.h"
+#include "arrow/flight/test_flight_server.h"
#include "arrow/flight/test_util.h"
#include "arrow/flight/types.h"
#include "arrow/util/logging.h"
@@ -38,7 +39,7 @@ std::unique_ptr<arrow::flight::FlightServerBase> g_server;
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
- g_server = arrow::flight::ExampleTestServer();
+ g_server = arrow::flight::TestFlightServer::Make();
arrow::flight::Location location;
if (FLAGS_unix.empty()) {
diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc
index 8b4245e74e..127827ff38 100644
--- a/cpp/src/arrow/flight/test_util.cc
+++ b/cpp/src/arrow/flight/test_util.cc
@@ -49,8 +49,7 @@
#include "arrow/flight/api.h"
#include "arrow/flight/serialization_internal.h"
-namespace arrow {
-namespace flight {
+namespace arrow::flight {
namespace bp = boost::process;
namespace fs = boost::filesystem;
@@ -90,25 +89,6 @@ Status ResolveCurrentExecutable(fs::path* out) {
}
}
-class ErrorRecordBatchReader : public RecordBatchReader {
- public:
- ErrorRecordBatchReader() : schema_(arrow::schema({})) {}
-
- std::shared_ptr<Schema> schema() const override { return schema_; }
-
- Status ReadNext(std::shared_ptr<RecordBatch>* out) override {
- *out = nullptr;
- return Status::OK();
- }
-
- Status Close() override {
- // This should be propagated over DoGet to the client
- return Status::IOError("Expected error");
- }
-
- private:
- std::shared_ptr<Schema> schema_;
-};
} // namespace
void TestServer::Start(const std::vector<std::string>& extra_args) {
@@ -171,364 +151,6 @@ int TestServer::port() const { return port_; }
const std::string& TestServer::unix_sock() const { return unix_sock_; }
-Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
- if (ticket.ticket == "ticket-ints-1") {
- RecordBatchVector batches;
- RETURN_NOT_OK(ExampleIntBatches(&batches));
- ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
- return Status::OK();
- } else if (ticket.ticket == "ticket-floats-1") {
- RecordBatchVector batches;
- RETURN_NOT_OK(ExampleFloatBatches(&batches));
- ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
- return Status::OK();
- } else if (ticket.ticket == "ticket-dicts-1") {
- RecordBatchVector batches;
- RETURN_NOT_OK(ExampleDictBatches(&batches));
- ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
- return Status::OK();
- } else if (ticket.ticket == "ticket-large-batch-1") {
- RecordBatchVector batches;
- RETURN_NOT_OK(ExampleLargeBatches(&batches));
- ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
- return Status::OK();
- } else {
- return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket);
- }
-}
-
-class FlightTestServer : public FlightServerBase {
- Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
- std::unique_ptr<FlightListing>* listings) override {
- std::vector<FlightInfo> flights = ExampleFlightInfo();
- if (criteria && criteria->expression != "") {
- // For test purposes, if we get criteria, return no results
- flights.clear();
- }
- *listings = std::make_unique<SimpleFlightListing>(flights);
- return Status::OK();
- }
-
- Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
- std::unique_ptr<FlightInfo>* out) override {
- // Test that Arrow-C++ status codes make it through the transport
- if (request.type == FlightDescriptor::DescriptorType::CMD &&
- request.cmd == "status-outofmemory") {
- return Status::OutOfMemory("Sentinel");
- }
-
- std::vector<FlightInfo> flights = ExampleFlightInfo();
-
- for (const auto& info : flights) {
- if (info.descriptor().Equals(request)) {
- *out = std::make_unique<FlightInfo>(info);
- return Status::OK();
- }
- }
- return Status::Invalid("Flight not found: ", request.ToString());
- }
-
- Status DoGet(const ServerCallContext& context, const Ticket& request,
- std::unique_ptr<FlightDataStream>* data_stream) override {
- // Test for ARROW-5095
- if (request.ticket == "ARROW-5095-fail") {
- return Status::UnknownError("Server-side error");
- }
- if (request.ticket == "ARROW-5095-success") {
- return Status::OK();
- }
- if (request.ticket == "ARROW-13253-DoGet-Batch") {
- // Make batch > 2GiB in size
- ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch());
- ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch}));
- *data_stream = std::make_unique<RecordBatchStream>(std::move(reader));
- return Status::OK();
- }
- if (request.ticket == "ticket-stream-error") {
- auto reader = std::make_shared<ErrorRecordBatchReader>();
- *data_stream = std::make_unique<RecordBatchStream>(std::move(reader));
- return Status::OK();
- }
-
- std::shared_ptr<RecordBatchReader> batch_reader;
- RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
-
- *data_stream = std::make_unique<RecordBatchStream>(batch_reader);
- return Status::OK();
- }
-
- Status DoPut(const ServerCallContext&, std::unique_ptr<FlightMessageReader> reader,
- std::unique_ptr<FlightMetadataWriter> writer) override {
- return reader->ToRecordBatches().status();
- }
-
- Status DoExchange(const ServerCallContext& context,
- std::unique_ptr<FlightMessageReader> reader,
- std::unique_ptr<FlightMessageWriter> writer) override {
- // Test various scenarios for a DoExchange
- if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) {
- return Status::Invalid("Must provide a command descriptor");
- }
-
- const std::string& cmd = reader->descriptor().cmd;
- if (cmd == "error") {
- // Immediately return an error to the client.
- return Status::NotImplemented("Expected error");
- } else if (cmd == "get") {
- return RunExchangeGet(std::move(reader), std::move(writer));
- } else if (cmd == "put") {
- return RunExchangePut(std::move(reader), std::move(writer));
- } else if (cmd == "counter") {
- return RunExchangeCounter(std::move(reader), std::move(writer));
- } else if (cmd == "total") {
- return RunExchangeTotal(std::move(reader), std::move(writer));
- } else if (cmd == "echo") {
- return RunExchangeEcho(std::move(reader), std::move(writer));
- } else if (cmd == "large_batch") {
- return RunExchangeLargeBatch(std::move(reader), std::move(writer));
- } else if (cmd == "TestUndrained") {
- ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
- return Status::OK();
- } else {
- return Status::NotImplemented("Scenario not implemented: ", cmd);
- }
- }
-
- // A simple example - act like DoGet.
- Status RunExchangeGet(std::unique_ptr<FlightMessageReader> reader,
- std::unique_ptr<FlightMessageWriter> writer) {
- RETURN_NOT_OK(writer->Begin(ExampleIntSchema()));
- RecordBatchVector batches;
- RETURN_NOT_OK(ExampleIntBatches(&batches));
- for (const auto& batch : batches) {
- RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
- }
- return Status::OK();
- }
-
- // A simple example - act like DoPut
- Status RunExchangePut(std::unique_ptr<FlightMessageReader> reader,
- std::unique_ptr<FlightMessageWriter> writer) {
- ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
- if (!schema->Equals(ExampleIntSchema(), false)) {
- return Status::Invalid("Schema is not as expected");
- }
- RecordBatchVector batches;
- RETURN_NOT_OK(ExampleIntBatches(&batches));
- FlightStreamChunk chunk;
- for (const auto& batch : batches) {
- ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
- if (!chunk.data) {
- return Status::Invalid("Expected another batch");
- }
- if (!batch->Equals(*chunk.data)) {
- return Status::Invalid("Batch does not match");
- }
- }
- ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
- if (chunk.data || chunk.app_metadata) {
- return Status::Invalid("Too many batches");
- }
-
- RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done")));
- return Status::OK();
- }
-
- // Read some number of record batches from the client, send a
- // metadata message back with the count, then echo the batches back.
- Status RunExchangeCounter(std::unique_ptr<FlightMessageReader> reader,
- std::unique_ptr<FlightMessageWriter> writer) {
- std::vector<std::shared_ptr<RecordBatch>> batches;
- FlightStreamChunk chunk;
- int chunks = 0;
- while (true) {
- ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
- if (!chunk.data && !chunk.app_metadata) {
- break;
- }
- if (chunk.data) {
- batches.push_back(chunk.data);
- chunks++;
- }
- }
-
- // Echo back the number of record batches read.
- std::shared_ptr<Buffer> buf = Buffer::FromString(std::to_string(chunks));
- RETURN_NOT_OK(writer->WriteMetadata(buf));
- // Echo the record batches themselves.
- if (chunks > 0) {
- ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
- RETURN_NOT_OK(writer->Begin(schema));
-
- for (const auto& batch : batches) {
- RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
- }
- }
-
- return Status::OK();
- }
-
- // Read int64 batches from the client, each time sending back a
- // batch with a running sum of columns.
- Status RunExchangeTotal(std::unique_ptr<FlightMessageReader> reader,
- std::unique_ptr<FlightMessageWriter> writer) {
- FlightStreamChunk chunk{};
- ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
- // Ensure the schema contains only int64 columns
- for (const auto& field : schema->fields()) {
- if (field->type()->id() != Type::type::INT64) {
- return Status::Invalid("Field is not INT64: ", field->name());
- }
- }
- std::vector<int64_t> sums(schema->num_fields());
- std::vector<std::shared_ptr<Array>> columns(schema->num_fields());
- RETURN_NOT_OK(writer->Begin(schema));
- while (true) {
- ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
- if (!chunk.data && !chunk.app_metadata) {
- break;
- }
- if (chunk.data) {
- if (!chunk.data->schema()->Equals(schema, false)) {
- // A compliant client implementation would make this impossible
- return Status::Invalid("Schemas are incompatible");
- }
-
- // Update the running totals
- auto builder = std::make_shared<Int64Builder>();
- int col_index = 0;
- for (const auto& column : chunk.data->columns()) {
- auto arr = std::dynamic_pointer_cast<Int64Array>(column);
- if (!arr) {
- return MakeFlightError(FlightStatusCode::Internal, "Could not cast array");
- }
- for (int row = 0; row < column->length(); row++) {
- if (!arr->IsNull(row)) {
- sums[col_index] += arr->Value(row);
- }
- }
-
- builder->Reset();
- RETURN_NOT_OK(builder->Append(sums[col_index]));
- RETURN_NOT_OK(builder->Finish(&columns[col_index]));
-
- col_index++;
- }
-
- // Echo the totals to the client
- auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns);
- RETURN_NOT_OK(writer->WriteRecordBatch(*response));
- }
- }
- return Status::OK();
- }
-
- // Echo the client's messages back.
- Status RunExchangeEcho(std::unique_ptr<FlightMessageReader> reader,
- std::unique_ptr<FlightMessageWriter> writer) {
- FlightStreamChunk chunk;
- bool begun = false;
- while (true) {
- ARROW_ASSIGN_OR_RAISE(chunk, reader->Next());
- if (!chunk.data && !chunk.app_metadata) {
- break;
- }
- if (!begun && chunk.data) {
- begun = true;
- RETURN_NOT_OK(writer->Begin(chunk.data->schema()));
- }
- if (chunk.data && chunk.app_metadata) {
- RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata));
- } else if (chunk.data) {
- RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
- } else if (chunk.app_metadata) {
- RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata));
- }
- }
- return Status::OK();
- }
-
- // Regression test for ARROW-13253
- Status RunExchangeLargeBatch(std::unique_ptr<FlightMessageReader>,
- std::unique_ptr<FlightMessageWriter> writer) {
- ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch());
- RETURN_NOT_OK(writer->Begin(batch->schema()));
- return writer->WriteRecordBatch(*batch);
- }
-
- Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out) {
- std::vector<Result> results;
- for (int i = 0; i < 3; ++i) {
- Result result;
- std::string value = action.body->ToString() + "-part" + std::to_string(i);
- result.body = Buffer::FromString(std::move(value));
- results.push_back(result);
- }
- *out = std::make_unique<SimpleResultStream>(std::move(results));
- return Status::OK();
- }
-
- Status RunAction2(std::unique_ptr<ResultStream>* out) {
- // Empty
- *out = std::make_unique<SimpleResultStream>(std::vector<Result>{});
- return Status::OK();
- }
-
- Status ListIncomingHeaders(const ServerCallContext& context, const Action& action,
- std::unique_ptr<ResultStream>* out) {
- std::vector<Result> results;
- std::string_view prefix(*action.body);
- for (const auto& header : context.incoming_headers()) {
- if (header.first.substr(0, prefix.size()) != prefix) {
- continue;
- }
- Result result;
- result.body = Buffer::FromString(std::string(header.first) + ": " +
- std::string(header.second));
- results.push_back(result);
- }
- *out = std::make_unique<SimpleResultStream>(std::move(results));
- return Status::OK();
- }
-
- Status DoAction(const ServerCallContext& context, const Action& action,
- std::unique_ptr<ResultStream>* out) override {
- if (action.type == "action1") {
- return RunAction1(action, out);
- } else if (action.type == "action2") {
- return RunAction2(out);
- } else if (action.type == "list-incoming-headers") {
- return ListIncomingHeaders(context, action, out);
- } else {
- return Status::NotImplemented(action.type);
- }
- }
-
- Status ListActions(const ServerCallContext& context,
- std::vector<ActionType>* out) override {
- std::vector<ActionType> actions = ExampleActionTypes();
- *out = std::move(actions);
- return Status::OK();
- }
-
- Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request,
- std::unique_ptr<SchemaResult>* schema) override {
- std::vector<FlightInfo> flights = ExampleFlightInfo();
-
- for (const auto& info : flights) {
- if (info.descriptor().Equals(request)) {
- *schema = std::make_unique<SchemaResult>(info.serialized_schema());
- return Status::OK();
- }
- }
- return Status::Invalid("Flight not found: ", request.ToString());
- }
-};
-
-std::unique_ptr<FlightServerBase> ExampleTestServer() {
- return std::make_unique<FlightTestServer>();
-}
-
FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
const std::vector<FlightEndpoint>& endpoints,
int64_t total_records, int64_t total_bytes, bool ordered,
@@ -701,109 +323,6 @@ std::vector<ActionType> ExampleActionTypes() {
return {{"drop", "drop a dataset"}, {"cache", "cache a dataset"}};
}
-TestServerAuthHandler::TestServerAuthHandler(const std::string& username,
- const std::string& password)
- : username_(username), password_(password) {}
-
-TestServerAuthHandler::~TestServerAuthHandler() {}
-
-Status TestServerAuthHandler::Authenticate(const ServerCallContext& context,
- ServerAuthSender* outgoing,
- ServerAuthReader* incoming) {
- std::string token;
- RETURN_NOT_OK(incoming->Read(&token));
- if (token != password_) {
- return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
- }
- RETURN_NOT_OK(outgoing->Write(username_));
- return Status::OK();
-}
-
-Status TestServerAuthHandler::IsValid(const ServerCallContext& context,
- const std::string& token,
- std::string* peer_identity) {
- if (token != password_) {
- return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
- }
- *peer_identity = username_;
- return Status::OK();
-}
-
-TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username,
- const std::string& password) {
- basic_auth_.username = username;
- basic_auth_.password = password;
-}
-
-TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {}
-
-Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext& context,
- ServerAuthSender* outgoing,
- ServerAuthReader* incoming) {
- std::string token;
- RETURN_NOT_OK(incoming->Read(&token));
- ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth, BasicAuth::Deserialize(token));
- if (incoming_auth.username != basic_auth_.username ||
- incoming_auth.password != basic_auth_.password) {
- return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
- }
- RETURN_NOT_OK(outgoing->Write(basic_auth_.username));
- return Status::OK();
-}
-
-Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context,
- const std::string& token,
- std::string* peer_identity) {
- if (token != basic_auth_.username) {
- return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
- }
- *peer_identity = basic_auth_.username;
- return Status::OK();
-}
-
-TestClientAuthHandler::TestClientAuthHandler(const std::string& username,
- const std::string& password)
- : username_(username), password_(password) {}
-
-TestClientAuthHandler::~TestClientAuthHandler() {}
-
-Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing,
- ClientAuthReader* incoming) {
- RETURN_NOT_OK(outgoing->Write(password_));
- std::string username;
- RETURN_NOT_OK(incoming->Read(&username));
- if (username != username_) {
- return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
- }
- return Status::OK();
-}
-
-Status TestClientAuthHandler::GetToken(std::string* token) {
- *token = password_;
- return Status::OK();
-}
-
-TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username,
- const std::string& password) {
- basic_auth_.username = username;
- basic_auth_.password = password;
-}
-
-TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {}
-
-Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing,
- ClientAuthReader* incoming) {
- ARROW_ASSIGN_OR_RAISE(std::string pb_result, basic_auth_.SerializeToString());
- RETURN_NOT_OK(outgoing->Write(pb_result));
- RETURN_NOT_OK(incoming->Read(&token_));
- return Status::OK();
-}
-
-Status TestClientBasicAuthHandler::GetToken(std::string* token) {
- *token = token_;
- return Status::OK();
-}
-
Status ExampleTlsCertificates(std::vector<CertKeyPair>* out) {
std::string root;
RETURN_NOT_OK(GetTestResourceRoot(&root));
@@ -860,5 +379,4 @@ Status ExampleTlsCertificateRoot(CertKeyPair* out) {
}
}
-} // namespace flight
-} // namespace arrow
+} // namespace arrow::flight
diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h
index c0b42d9b90..15ba6145ec 100644
--- a/cpp/src/arrow/flight/test_util.h
+++ b/cpp/src/arrow/flight/test_util.h
@@ -32,9 +32,7 @@
#include "arrow/testing/util.h"
#include "arrow/flight/client.h"
-#include "arrow/flight/client_auth.h"
#include "arrow/flight/server.h"
-#include "arrow/flight/server_auth.h"
#include "arrow/flight/types.h"
#include "arrow/flight/visibility.h"
@@ -95,10 +93,6 @@ class ARROW_FLIGHT_EXPORT TestServer {
std::shared_ptr<::boost::process::child> server_process_;
};
-/// \brief Create a simple Flight server for testing
-ARROW_FLIGHT_EXPORT
-std::unique_ptr<FlightServerBase> ExampleTestServer();
-
// Helper to initialize a server and matching client with callbacks to
// populate options.
template <typename T, typename... Args>
@@ -195,65 +189,6 @@ FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descript
int64_t total_records, int64_t total_bytes, bool ordered,
std::string app_metadata);
-// ----------------------------------------------------------------------
-// A pair of authentication handlers that check for a predefined password
-// and set the peer identity to a predefined username.
-
-class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler {
- public:
- explicit TestServerAuthHandler(const std::string& username,
- const std::string& password);
- ~TestServerAuthHandler() override;
- Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing,
- ServerAuthReader* incoming) override;
- Status IsValid(const ServerCallContext& context, const std::string& token,
- std::string* peer_identity) override;
-
- private:
- std::string username_;
- std::string password_;
-};
-
-class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler {
- public:
- explicit TestServerBasicAuthHandler(const std::string& username,
- const std::string& password);
- ~TestServerBasicAuthHandler() override;
- Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing,
- ServerAuthReader* incoming) override;
- Status IsValid(const ServerCallContext& context, const std::string& token,
- std::string* peer_identity) override;
-
- private:
- BasicAuth basic_auth_;
-};
-
-class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler {
- public:
- explicit TestClientAuthHandler(const std::string& username,
- const std::string& password);
- ~TestClientAuthHandler() override;
- Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override;
- Status GetToken(std::string* token) override;
-
- private:
- std::string username_;
- std::string password_;
-};
-
-class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler {
- public:
- explicit TestClientBasicAuthHandler(const std::string& username,
- const std::string& password);
- ~TestClientBasicAuthHandler() override;
- Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override;
- Status GetToken(std::string* token) override;
-
- private:
- BasicAuth basic_auth_;
- std::string token_;
-};
-
ARROW_FLIGHT_EXPORT
Status ExampleTlsCertificates(std::vector<CertKeyPair>* out);