This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new f0e14539dc GH-35442: [C++][FlightRPC] Pass ServerCallContext instead 
of CallHeaders to ServerMiddlewareFactory::StartCall() (#35454)
f0e14539dc is described below

commit f0e14539dc546b53ed8614b128ed6df5e3b5feac
Author: Sutou Kouhei <[email protected]>
AuthorDate: Sun May 7 05:08:48 2023 +0900

    GH-35442: [C++][FlightRPC] Pass ServerCallContext instead of CallHeaders to 
ServerMiddlewareFactory::StartCall() (#35454)
    
    ### Rationale for this change
    
    Because it's also an RPC call like others such as `ListFlights()` and 
`DoGet()`.
    
    If we pass `ServerCallContext` instead of `CallHeaders`, implementers can 
also get other information such as the client address. For example, 
https://github.com/apache/arrow-flight-sql-postgresql will use it via 
`ServerCallContext::peer()`.
    
    ### What changes are included in this PR?
    
    Passes `ServerCallContext` instead of `CallHeaders` but keeps backward 
compatibility.
    Implementers can still use the old signature.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    Yes. But this is still backward compatible.
    * Closes: #35442
    
    Authored-by: Sutou Kouhei <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 cpp/src/arrow/flight/CMakeLists.txt                |  1 +
 cpp/src/arrow/flight/flight_test.cc                | 20 ++++++-------
 .../flight/integration_tests/test_integration.cc   |  4 +--
 cpp/src/arrow/flight/server_auth.h                 |  3 +-
 cpp/src/arrow/flight/server_middleware.cc          | 35 ++++++++++++++++++++++
 cpp/src/arrow/flight/server_middleware.h           | 24 ++++++++++++++-
 cpp/src/arrow/flight/server_tracing_middleware.cc  | 13 ++++----
 cpp/src/arrow/flight/transport/grpc/grpc_server.cc |  3 +-
 8 files changed, 80 insertions(+), 23 deletions(-)

diff --git a/cpp/src/arrow/flight/CMakeLists.txt 
b/cpp/src/arrow/flight/CMakeLists.txt
index 2a88e5f8ec..917c0c3321 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -199,6 +199,7 @@ set(ARROW_FLIGHT_SRCS
     serialization_internal.cc
     server.cc
     server_auth.cc
+    server_middleware.cc
     server_tracing_middleware.cc
     transport.cc
     transport_server.cc
diff --git a/cpp/src/arrow/flight/flight_test.cc 
b/cpp/src/arrow/flight/flight_test.cc
index 09e9d8c156..d56dc81e35 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -452,7 +452,7 @@ class TestTls : public ::testing::Test {
 
 // A server middleware that rejects all calls.
 class RejectServerMiddlewareFactory : public ServerMiddlewareFactory {
-  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+  Status StartCall(const CallInfo& info, const ServerCallContext& context,
                    std::shared_ptr<ServerMiddleware>* middleware) override {
     return MakeFlightError(FlightStatusCode::Unauthenticated, "All calls are 
rejected");
   }
@@ -484,7 +484,7 @@ class CountingServerMiddlewareFactory : public 
ServerMiddlewareFactory {
  public:
   CountingServerMiddlewareFactory() : successful_(0), failed_(0) {}
 
-  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+  Status StartCall(const CallInfo& info, const ServerCallContext& context,
                    std::shared_ptr<ServerMiddleware>* middleware) override {
     *middleware = std::make_shared<CountingServerMiddleware>(&successful_, 
&failed_);
     return Status::OK();
@@ -517,10 +517,10 @@ class TracingTestServerMiddlewareFactory : public 
ServerMiddlewareFactory {
  public:
   TracingTestServerMiddlewareFactory() {}
 
-  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+  Status StartCall(const CallInfo& info, const ServerCallContext& context,
                    std::shared_ptr<ServerMiddleware>* middleware) override {
     const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& 
iter_pair =
-        incoming_headers.equal_range("x-tracing-span-id");
+        context.incoming_headers().equal_range("x-tracing-span-id");
     if (iter_pair.first != iter_pair.second) {
       const std::string_view& value = (*iter_pair.first).second;
       *middleware = 
std::make_shared<TracingTestServerMiddleware>(std::string(value));
@@ -578,10 +578,10 @@ class HeaderAuthServerMiddlewareFactory : public 
ServerMiddlewareFactory {
  public:
   HeaderAuthServerMiddlewareFactory() {}
 
-  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+  Status StartCall(const CallInfo& info, const ServerCallContext& context,
                    std::shared_ptr<ServerMiddleware>* middleware) override {
     std::string username, password;
-    ParseBasicHeader(incoming_headers, username, password);
+    ParseBasicHeader(context.incoming_headers(), username, password);
     if ((username == kValidUsername) && (password == kValidPassword)) {
       *middleware = std::make_shared<HeaderAuthServerMiddleware>();
     } else if ((username == kInvalidUsername) && (password == 
kInvalidPassword)) {
@@ -619,13 +619,13 @@ class BearerAuthServerMiddlewareFactory : public 
ServerMiddlewareFactory {
  public:
   BearerAuthServerMiddlewareFactory() : isValid_(false) {}
 
-  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+  Status StartCall(const CallInfo& info, const ServerCallContext& context,
                    std::shared_ptr<ServerMiddleware>* middleware) override {
     const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& 
iter_pair =
-        incoming_headers.equal_range(kAuthHeader);
+        context.incoming_headers().equal_range(kAuthHeader);
     if (iter_pair.first != iter_pair.second) {
-      *middleware =
-          std::make_shared<BearerAuthServerMiddleware>(incoming_headers, 
&isValid_);
+      *middleware = std::make_shared<BearerAuthServerMiddleware>(
+          context.incoming_headers(), &isValid_);
     }
     return Status::OK();
   }
diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc 
b/cpp/src/arrow/flight/integration_tests/test_integration.cc
index f6af142978..9a300d1bd2 100644
--- a/cpp/src/arrow/flight/integration_tests/test_integration.cc
+++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc
@@ -143,10 +143,10 @@ class TestServerMiddleware : public ServerMiddleware {
 
 class TestServerMiddlewareFactory : public ServerMiddlewareFactory {
  public:
-  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+  Status StartCall(const CallInfo& info, const ServerCallContext& context,
                    std::shared_ptr<ServerMiddleware>* middleware) override {
     const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& 
iter_pair =
-        incoming_headers.equal_range("x-middleware");
+        context.incoming_headers().equal_range("x-middleware");
     std::string received = "";
     if (iter_pair.first != iter_pair.second) {
       const std::string_view& value = (*iter_pair.first).second;
diff --git a/cpp/src/arrow/flight/server_auth.h 
b/cpp/src/arrow/flight/server_auth.h
index 3d4787c0c7..93d3352ba2 100644
--- a/cpp/src/arrow/flight/server_auth.h
+++ b/cpp/src/arrow/flight/server_auth.h
@@ -21,6 +21,7 @@
 
 #include <string>
 
+#include "arrow/flight/type_fwd.h"
 #include "arrow/flight/visibility.h"
 #include "arrow/status.h"
 
@@ -28,8 +29,6 @@ namespace arrow {
 
 namespace flight {
 
-class ServerCallContext;
-
 /// \brief A reader for messages from the client during an
 /// authentication handshake.
 class ARROW_FLIGHT_EXPORT ServerAuthReader {
diff --git a/cpp/src/arrow/flight/server_middleware.cc 
b/cpp/src/arrow/flight/server_middleware.cc
new file mode 100644
index 0000000000..d7ace580dc
--- /dev/null
+++ b/cpp/src/arrow/flight/server_middleware.cc
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/server.h"
+
+namespace arrow {
+namespace flight {
+
+Status ServerMiddlewareFactory::StartCall(const CallInfo& info,
+                                          const ServerCallContext& context,
+                                          std::shared_ptr<ServerMiddleware>* 
middleware) {
+  // TODO: We can make this pure virtual function when we remove
+  // the deprecated version.
+  ARROW_SUPPRESS_DEPRECATION_WARNING
+  return StartCall(info, context.incoming_headers(), middleware);
+  ARROW_UNSUPPRESS_DEPRECATION_WARNING
+}
+
+}  // namespace flight
+}  // namespace arrow
diff --git a/cpp/src/arrow/flight/server_middleware.h 
b/cpp/src/arrow/flight/server_middleware.h
index 26431aff01..030f1a17c2 100644
--- a/cpp/src/arrow/flight/server_middleware.h
+++ b/cpp/src/arrow/flight/server_middleware.h
@@ -24,6 +24,7 @@
 #include <string>
 
 #include "arrow/flight/middleware.h"
+#include "arrow/flight/type_fwd.h"
 #include "arrow/flight/visibility.h"  // IWYU pragma: keep
 #include "arrow/status.h"
 
@@ -61,6 +62,22 @@ class ARROW_FLIGHT_EXPORT ServerMiddlewareFactory {
  public:
   virtual ~ServerMiddlewareFactory() = default;
 
+  /// \brief A callback for the start of a new call.
+  ///
+  /// Return a non-OK status to reject the call with the given status.
+  ///
+  /// \param[in] info Information about the call.
+  /// \param[in] context The call context.
+  /// \param[out] middleware The middleware instance for this call. If
+  ///     null, no middleware will be added to this call instance from
+  ///     this factory.
+  /// \return Status A non-OK status will reject the call with the
+  ///     given status. Middleware previously in the chain will have
+  ///     their CallCompleted callback called. Other middleware
+  ///     factories will not be called.
+  virtual Status StartCall(const CallInfo& info, const ServerCallContext& 
context,
+                           std::shared_ptr<ServerMiddleware>* middleware);
+
   /// \brief A callback for the start of a new call.
   ///
   /// Return a non-OK status to reject the call with the given status.
@@ -75,8 +92,13 @@ class ARROW_FLIGHT_EXPORT ServerMiddlewareFactory {
   ///     given status. Middleware previously in the chain will have
   ///     their CallCompleted callback called. Other middleware
   ///     factories will not be called.
+  /// \deprecated Deprecated in 13.0.0. Implement the StartCall()
+  /// with ServerCallContext version instead.
+  ARROW_DEPRECATED("Deprecated in 13.0.0. Use ServerCallContext overload 
instead.")
   virtual Status StartCall(const CallInfo& info, const CallHeaders& 
incoming_headers,
-                           std::shared_ptr<ServerMiddleware>* middleware) = 0;
+                           std::shared_ptr<ServerMiddleware>* middleware) {
+    return Status::NotImplemented(typeid(this).name(), "::StartCall() isn't 
implemented");
+  }
 };
 
 }  // namespace flight
diff --git a/cpp/src/arrow/flight/server_tracing_middleware.cc 
b/cpp/src/arrow/flight/server_tracing_middleware.cc
index 6587db1d1d..b5326d88a4 100644
--- a/cpp/src/arrow/flight/server_tracing_middleware.cc
+++ b/cpp/src/arrow/flight/server_tracing_middleware.cc
@@ -16,6 +16,7 @@
 // under the License.
 
 #include "arrow/flight/server_tracing_middleware.h"
+#include "arrow/flight/server.h"
 
 #include <string>
 #include <string_view>
@@ -122,19 +123,19 @@ class TracingServerMiddleware::Impl {
 class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
  public:
   virtual ~TracingServerMiddlewareFactory() = default;
-  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+  Status StartCall(const CallInfo& info, const ServerCallContext& context,
                    std::shared_ptr<ServerMiddleware>* middleware) override {
     constexpr char kServiceName[] = "arrow.flight.protocol.FlightService";
 
-    FlightServerCarrier carrier(incoming_headers);
-    auto context = otel::context::RuntimeContext::GetCurrent();
+    FlightServerCarrier carrier(context.incoming_headers());
+    auto otel_context = otel::context::RuntimeContext::GetCurrent();
     auto propagator =
         
otel::context::propagation::GlobalTextMapPropagator::GetGlobalPropagator();
-    auto new_context = propagator->Extract(carrier, context);
+    auto new_otel_context = propagator->Extract(carrier, otel_context);
 
     otel::trace::StartSpanOptions options;
     options.kind = otel::trace::SpanKind::kServer;
-    options.parent = otel::trace::GetSpan(new_context)->GetContext();
+    options.parent = otel::trace::GetSpan(new_otel_context)->GetContext();
 
     auto* tracer = arrow::internal::tracing::GetTracer();
     auto method_name = ToString(info.method);
@@ -167,7 +168,7 @@ class TracingServerMiddleware::Impl {
 class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
  public:
   virtual ~TracingServerMiddlewareFactory() = default;
-  Status StartCall(const CallInfo&, const CallHeaders&,
+  Status StartCall(const CallInfo&, const ServerCallContext&,
                    std::shared_ptr<ServerMiddleware>* middleware) override {
     std::unique_ptr<TracingServerMiddleware::Impl> impl(
         new TracingServerMiddleware::Impl());
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc 
b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
index 09d702cd84..dcf9c3f8c9 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
@@ -323,8 +323,7 @@ class GrpcServiceHandler final : public 
FlightService::Service {
     GrpcAddServerHeaders outgoing_headers(context);
     for (const auto& factory : middleware_) {
       std::shared_ptr<ServerMiddleware> instance;
-      Status result =
-          factory.second->StartCall(info, flight_context.incoming_headers(), 
&instance);
+      Status result = factory.second->StartCall(info, flight_context, 
&instance);
       if (!result.ok()) {
         // Interceptor rejected call, end the request on all existing
         // interceptors

Reply via email to