This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new f0e14539dc GH-35442: [C++][FlightRPC] Pass ServerCallContext instead
of CallHeaders to ServerMiddlewareFactory::StartCall() (#35454)
f0e14539dc is described below
commit f0e14539dc546b53ed8614b128ed6df5e3b5feac
Author: Sutou Kouhei <[email protected]>
AuthorDate: Sun May 7 05:08:48 2023 +0900
GH-35442: [C++][FlightRPC] Pass ServerCallContext instead of CallHeaders to
ServerMiddlewareFactory::StartCall() (#35454)
### Rationale for this change
Because `StartCall()` is also invoked for an RPC call, just like handlers such as
`ListFlights()` and `DoGet()`.
If we pass `ServerCallContext` instead of `CallHeaders`, implementers can
also get other information such as the client address. For example,
https://github.com/apache/arrow-flight-sql-postgresql will use it via
`ServerCallContext::peer()`.
### What changes are included in this PR?
Pass `ServerCallContext` instead of `CallHeaders` while keeping backward
compatibility: implementers can still override the old (now deprecated)
signature.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
Yes, but the change is backward compatible.
* Closes: #35442
Authored-by: Sutou Kouhei <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
---
cpp/src/arrow/flight/CMakeLists.txt | 1 +
cpp/src/arrow/flight/flight_test.cc | 20 ++++++-------
.../flight/integration_tests/test_integration.cc | 4 +--
cpp/src/arrow/flight/server_auth.h | 3 +-
cpp/src/arrow/flight/server_middleware.cc | 35 ++++++++++++++++++++++
cpp/src/arrow/flight/server_middleware.h | 24 ++++++++++++++-
cpp/src/arrow/flight/server_tracing_middleware.cc | 13 ++++----
cpp/src/arrow/flight/transport/grpc/grpc_server.cc | 3 +-
8 files changed, 80 insertions(+), 23 deletions(-)
diff --git a/cpp/src/arrow/flight/CMakeLists.txt
b/cpp/src/arrow/flight/CMakeLists.txt
index 2a88e5f8ec..917c0c3321 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -199,6 +199,7 @@ set(ARROW_FLIGHT_SRCS
serialization_internal.cc
server.cc
server_auth.cc
+ server_middleware.cc
server_tracing_middleware.cc
transport.cc
transport_server.cc
diff --git a/cpp/src/arrow/flight/flight_test.cc
b/cpp/src/arrow/flight/flight_test.cc
index 09e9d8c156..d56dc81e35 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -452,7 +452,7 @@ class TestTls : public ::testing::Test {
// A server middleware that rejects all calls.
class RejectServerMiddlewareFactory : public ServerMiddlewareFactory {
- Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
return MakeFlightError(FlightStatusCode::Unauthenticated, "All calls are
rejected");
}
@@ -484,7 +484,7 @@ class CountingServerMiddlewareFactory : public
ServerMiddlewareFactory {
public:
CountingServerMiddlewareFactory() : successful_(0), failed_(0) {}
- Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
*middleware = std::make_shared<CountingServerMiddleware>(&successful_,
&failed_);
return Status::OK();
@@ -517,10 +517,10 @@ class TracingTestServerMiddlewareFactory : public
ServerMiddlewareFactory {
public:
TracingTestServerMiddlewareFactory() {}
- Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
iter_pair =
- incoming_headers.equal_range("x-tracing-span-id");
+ context.incoming_headers().equal_range("x-tracing-span-id");
if (iter_pair.first != iter_pair.second) {
const std::string_view& value = (*iter_pair.first).second;
*middleware =
std::make_shared<TracingTestServerMiddleware>(std::string(value));
@@ -578,10 +578,10 @@ class HeaderAuthServerMiddlewareFactory : public
ServerMiddlewareFactory {
public:
HeaderAuthServerMiddlewareFactory() {}
- Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
std::string username, password;
- ParseBasicHeader(incoming_headers, username, password);
+ ParseBasicHeader(context.incoming_headers(), username, password);
if ((username == kValidUsername) && (password == kValidPassword)) {
*middleware = std::make_shared<HeaderAuthServerMiddleware>();
} else if ((username == kInvalidUsername) && (password ==
kInvalidPassword)) {
@@ -619,13 +619,13 @@ class BearerAuthServerMiddlewareFactory : public
ServerMiddlewareFactory {
public:
BearerAuthServerMiddlewareFactory() : isValid_(false) {}
- Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
iter_pair =
- incoming_headers.equal_range(kAuthHeader);
+ context.incoming_headers().equal_range(kAuthHeader);
if (iter_pair.first != iter_pair.second) {
- *middleware =
- std::make_shared<BearerAuthServerMiddleware>(incoming_headers,
&isValid_);
+ *middleware = std::make_shared<BearerAuthServerMiddleware>(
+ context.incoming_headers(), &isValid_);
}
return Status::OK();
}
diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc
b/cpp/src/arrow/flight/integration_tests/test_integration.cc
index f6af142978..9a300d1bd2 100644
--- a/cpp/src/arrow/flight/integration_tests/test_integration.cc
+++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc
@@ -143,10 +143,10 @@ class TestServerMiddleware : public ServerMiddleware {
class TestServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
- Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
iter_pair =
- incoming_headers.equal_range("x-middleware");
+ context.incoming_headers().equal_range("x-middleware");
std::string received = "";
if (iter_pair.first != iter_pair.second) {
const std::string_view& value = (*iter_pair.first).second;
diff --git a/cpp/src/arrow/flight/server_auth.h
b/cpp/src/arrow/flight/server_auth.h
index 3d4787c0c7..93d3352ba2 100644
--- a/cpp/src/arrow/flight/server_auth.h
+++ b/cpp/src/arrow/flight/server_auth.h
@@ -21,6 +21,7 @@
#include <string>
+#include "arrow/flight/type_fwd.h"
#include "arrow/flight/visibility.h"
#include "arrow/status.h"
@@ -28,8 +29,6 @@ namespace arrow {
namespace flight {
-class ServerCallContext;
-
/// \brief A reader for messages from the client during an
/// authentication handshake.
class ARROW_FLIGHT_EXPORT ServerAuthReader {
diff --git a/cpp/src/arrow/flight/server_middleware.cc
b/cpp/src/arrow/flight/server_middleware.cc
new file mode 100644
index 0000000000..d7ace580dc
--- /dev/null
+++ b/cpp/src/arrow/flight/server_middleware.cc
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/server.h"
+
+namespace arrow {
+namespace flight {
+
+Status ServerMiddlewareFactory::StartCall(const CallInfo& info,
+ const ServerCallContext& context,
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ // Forward to the deprecated overload so legacy subclasses keep working.
+ // TODO: Make this pure virtual once the deprecated overload is removed.
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ return StartCall(info, context.incoming_headers(), middleware);
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/server_middleware.h
b/cpp/src/arrow/flight/server_middleware.h
index 26431aff01..030f1a17c2 100644
--- a/cpp/src/arrow/flight/server_middleware.h
+++ b/cpp/src/arrow/flight/server_middleware.h
@@ -24,6 +24,7 @@
#include <string>
#include "arrow/flight/middleware.h"
+#include "arrow/flight/type_fwd.h"
#include "arrow/flight/visibility.h" // IWYU pragma: keep
#include "arrow/status.h"
@@ -61,6 +62,22 @@ class ARROW_FLIGHT_EXPORT ServerMiddlewareFactory {
public:
virtual ~ServerMiddlewareFactory() = default;
+ /// \brief A callback for the start of a new call.
+ ///
+ /// Return a non-OK status to reject the call with the given status.
+ ///
+ /// \param[in] info Information about the call.
+ /// \param[in] context The server call context (incoming headers, peer, etc.).
+ /// \param[out] middleware The middleware instance for this call. If
+ /// null, no middleware will be added to this call instance from
+ /// this factory.
+ /// \return Status A non-OK status will reject the call with the
+ /// given status. Middleware previously in the chain will have
+ /// their CallCompleted callback called. Other middleware
+ /// factories will not be called.
+ virtual Status StartCall(const CallInfo& info, const ServerCallContext& context,
+ std::shared_ptr<ServerMiddleware>* middleware);
+
/// \brief A callback for the start of a new call.
///
/// Return a non-OK status to reject the call with the given status.
@@ -75,8 +92,14 @@ class ARROW_FLIGHT_EXPORT ServerMiddlewareFactory {
/// given status. Middleware previously in the chain will have
/// their CallCompleted callback called. Other middleware
/// factories will not be called.
+ /// \deprecated Deprecated in 13.0.0. Override the StartCall()
+ /// overload that takes a ServerCallContext instead.
+ ARROW_DEPRECATED("Deprecated in 13.0.0. Use ServerCallContext overload instead.")
virtual Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
- std::shared_ptr<ServerMiddleware>* middleware) = 0;
+ std::shared_ptr<ServerMiddleware>* middleware) {
+ return Status::NotImplemented(typeid(*this).name(),
+ "::StartCall() isn't implemented");
+ }
};
} // namespace flight
diff --git a/cpp/src/arrow/flight/server_tracing_middleware.cc
b/cpp/src/arrow/flight/server_tracing_middleware.cc
index 6587db1d1d..b5326d88a4 100644
--- a/cpp/src/arrow/flight/server_tracing_middleware.cc
+++ b/cpp/src/arrow/flight/server_tracing_middleware.cc
@@ -16,6 +16,7 @@
// under the License.
#include "arrow/flight/server_tracing_middleware.h"
+#include "arrow/flight/server.h"
#include <string>
#include <string_view>
@@ -122,19 +123,19 @@ class TracingServerMiddleware::Impl {
class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
virtual ~TracingServerMiddlewareFactory() = default;
- Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
constexpr char kServiceName[] = "arrow.flight.protocol.FlightService";
- FlightServerCarrier carrier(incoming_headers);
- auto context = otel::context::RuntimeContext::GetCurrent();
+ FlightServerCarrier carrier(context.incoming_headers());
+ auto otel_context = otel::context::RuntimeContext::GetCurrent();
auto propagator =
otel::context::propagation::GlobalTextMapPropagator::GetGlobalPropagator();
- auto new_context = propagator->Extract(carrier, context);
+ auto new_otel_context = propagator->Extract(carrier, otel_context);
otel::trace::StartSpanOptions options;
options.kind = otel::trace::SpanKind::kServer;
- options.parent = otel::trace::GetSpan(new_context)->GetContext();
+ options.parent = otel::trace::GetSpan(new_otel_context)->GetContext();
auto* tracer = arrow::internal::tracing::GetTracer();
auto method_name = ToString(info.method);
@@ -167,7 +168,7 @@ class TracingServerMiddleware::Impl {
class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
virtual ~TracingServerMiddlewareFactory() = default;
- Status StartCall(const CallInfo&, const CallHeaders&,
+ Status StartCall(const CallInfo&, const ServerCallContext&,
std::shared_ptr<ServerMiddleware>* middleware) override {
std::unique_ptr<TracingServerMiddleware::Impl> impl(
new TracingServerMiddleware::Impl());
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
index 09d702cd84..dcf9c3f8c9 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
@@ -323,8 +323,7 @@ class GrpcServiceHandler final : public
FlightService::Service {
GrpcAddServerHeaders outgoing_headers(context);
for (const auto& factory : middleware_) {
std::shared_ptr<ServerMiddleware> instance;
- Status result =
- factory.second->StartCall(info, flight_context.incoming_headers(),
&instance);
+ Status result = factory.second->StartCall(info, flight_context,
&instance);
if (!result.ok()) {
// Interceptor rejected call, end the request on all existing
// interceptors