lidavidm commented on a change in pull request #8959:
URL: https://github.com/apache/arrow/pull/8959#discussion_r548194584



##########
File path: python/pyarrow/_flight.pyx
##########
@@ -118,12 +118,18 @@ cdef class FlightCallOptions(_Weakrefable):
         write_options : pyarrow.ipc.IpcWriteOptions, optional
             IPC write options. The default options can be controlled
             by environment variables (see pyarrow.ipc).
-
+        headers : vector[pair[c_string, c_string]], optional

Review comment:
       This type hint should use Python conventions (e.g. `List[Tuple[str, str]]`).

##########
File path: python/pyarrow/_flight.pyx
##########
@@ -1150,6 +1156,38 @@ cdef class FlightClient(_Weakrefable):
                 self.client.get().Authenticate(deref(c_options),
                                                move(handler)))
 
+    def authenticateBasicToken(self, username, password,

Review comment:
       Let's follow Python naming conventions — this should use snake_case (i.e. `authenticate_basic_token`).

##########
File path: python/pyarrow/tests/test_flight.py
##########
@@ -506,6 +505,162 @@ def get_token(self):
         return self.token
 
 
+class NoopAuthHandler(ServerAuthHandler):
+    """A no-op auth handler."""
+
+    def authenticate(self, outgoing, incoming):
+        """Do nothing."""
+
+    def is_valid(self, token):
+        """
+        Returning an empty string.
+        Returning None causes Type error.
+        """
+        return ""
+
+
+def case_insensitive_header_lookup(headers, lookup_key):

Review comment:
       This is specifically to extract an authentication header based on the exception it raises; let's make sure the name reflects that.

##########
File path: python/pyarrow/_flight.pyx
##########
@@ -1871,7 +1909,6 @@ cdef CStatus _server_authenticate(void* self, CServerAuthSender* outgoing,
         reader.poison()

Review comment:
       Not a typo.

##########
File path: python/pyarrow/tests/test_flight.py
##########
@@ -996,6 +1152,100 @@ def test_token_auth_invalid():
             client.authenticate(TokenClientAuthHandler('test', 'wrong'))
 
 
+header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory()
+no_op_auth_handler = NoopAuthHandler()
+
+
+def test_authenticate_basic_token():
+    """Test authenticateBasicToken with bearer token and auth headers."""
+    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+        "auth": HeaderAuthServerMiddlewareFactory()
+    }) as server:
+        client = FlightClient(('localhost', server.port))
+        token_pair = client.authenticateBasicToken(b'test', b'password')
+        assert token_pair[0] == b'authorization'
+        assert token_pair[1] == b'Bearer ' + b'token1234'
+
+
+def test_authenticate_basic_token_invalid_password():
+    """Test authenticateBasicToken with an invalid password."""
+    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+        "auth": HeaderAuthServerMiddlewareFactory()
+    }) as server:
+        client = FlightClient(('localhost', server.port))
+        with pytest.raises(flight.FlightUnauthenticatedError):
+            client.authenticateBasicToken(b'test', b'badpassword')
+
+
+def test_authenticate_basic_token_and_action():
+    """Test authenticateBasicToken and doAction after authentication."""
+    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+        "auth": HeaderAuthServerMiddlewareFactory()
+    }) as server:
+        client = FlightClient(('localhost', server.port))
+        token_pair = client.authenticateBasicToken(b'test', b'password')
+        assert token_pair[0] == b'authorization'
+        assert token_pair[1] == b'Bearer ' + b'token1234'
+        options = flight.FlightCallOptions(headers=[token_pair])
+        result = list(client.do_action(
+            action=flight.Action('test-action', b''), options=options))
+        assert result[0].body.to_pybytes() == b'token1234'
+
+
+def test_authenticate_basic_token_with_client_middleware():
+    """Test authenticateBasicToken with client middleware
+       to intercept authorization header returned by the
+       HTTP header auth enabled server.
+    """
+    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+        "auth": HeaderAuthServerMiddlewareFactory()
+    }) as server:
+        client_auth_middleware = ClientHeaderAuthMiddlewareFactory()
+        client = FlightClient(
+            ('localhost', server.port),
+            middleware=[client_auth_middleware]
+        )
+        encoded_credentials = base64.b64encode(b'test:password')
+        options = flight.FlightCallOptions(headers=[
+            (b'authorization', b'Basic ' + encoded_credentials)
+        ])
+        result = list(client.do_action(
+            action=flight.Action('test-action', b''), options=options))
+        assert result[0].body.to_pybytes() == b'token1234'
+        assert client_auth_middleware.call_credential[0] == b'authorization'
+        assert client_auth_middleware.call_credential[1] == \
+            b'Bearer ' + b'token1234'
+        result2 = list(client.do_action(
+            action=flight.Action('test-action', b''), options=options))
+        assert result2[0].body.to_pybytes() == b'token1234'
+        assert client_auth_middleware.call_credential[0] == b'authorization'
+        assert client_auth_middleware.call_credential[1] == \
+            b'Bearer ' + b'token1234'
+
+
+def test_arbitrary_headers_in_flight_call_options():
+    """Test passing multiple arbitrary headers to the middleware."""
+    with ArbitraryHeadersFlightServer(
+            auth_handler=no_op_auth_handler,
+            middleware={
+                "auth": HeaderAuthServerMiddlewareFactory(),
+                "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
+            }) as server:
+        client = FlightClient(('localhost', server.port))
+        token_pair = client.authenticateBasicToken(b'test', b'password')
+        assert token_pair[0] == b'authorization'
+        assert token_pair[1] == b'Bearer ' + b'token1234'

Review comment:
       nit: why concatenate here (and above)? A single literal such as `b'Bearer token1234'` would be clearer.

##########
File path: python/pyarrow/_flight.pyx
##########
@@ -1150,6 +1156,38 @@ cdef class FlightClient(_Weakrefable):
                 self.client.get().Authenticate(deref(c_options),
                                                move(handler)))
 
+    def authenticateBasicToken(self, username, password,

Review comment:
       `self` is never documented — it's passed implicitly by Python (it's equivalent to `this` in Java et al.).

##########
File path: python/pyarrow/_flight.pyx
##########
@@ -1150,6 +1156,38 @@ cdef class FlightClient(_Weakrefable):
                 self.client.get().Authenticate(deref(c_options),
                                                move(handler)))
 
+    def authenticateBasicToken(self, username, password,
+                               options: FlightCallOptions = None):
+        """Authenticate to the server with HTTP basic authentication.
+
+        Parameters
+        ----------
+        username : string
+            Username to authenticate with
+        password : string
+            Password to authenticate with
+        options  : FlightCallOptions
+            Options for this call
+
+        Returns
+        -------
+        pair : pair[string, string]

Review comment:
       Same for the type hints here — use `str` and `Tuple` (e.g. `Tuple[str, str]`).

##########
File path: python/pyarrow/tests/test_flight.py
##########
@@ -506,6 +505,95 @@ def get_token(self):
         return self.token
 
 
+class NoopAuthHandler(ServerAuthHandler):
+    """A no-op auth handler."""
+
+    def authenticate(self, outgoing, incoming):
+        """Do nothing."""
+
+    def is_valid(self, token):
+        """Do nothing."""
+        return ""
+
+
+class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory):
+    """Validates incoming username and password."""
+
+    def start_call(self, info, headers):
+        auth_header = headers.get('authorization')
+        values = auth_header[0].split(' ')
+        token = ''
+
+        if values[0] == 'Basic':
+            decoded = base64.b64decode(values[1])
+            pair = decoded.decode("utf-8").split(':')
+            if not (pair[0] == 'test' and pair[1] == 'password'):
+                raise flight.FlightUnauthenticatedError('Invalid credentials')
+            token = 'token1234'
+        elif values[0] == 'Bearer':
+            token = values[1]
+            if not token == 'token1234':
+                raise flight.FlightUnauthenticatedError('Invalid credentials')
+        else:
+            raise flight.FlightUnauthenticatedError('Invalid credentials')
+
+        return HeaderAuthServerMiddleware(token)
+
+
+class HeaderAuthServerMiddleware(ServerMiddleware):
+    """A ServerMiddleware that transports incoming username and passowrd."""
+
+    def __init__(self, token):
+        self.token = token
+
+    def sending_headers(self):
+        return {'authorization': 'Bearer ' + self.token}

Review comment:
       Headers are supposed to be treated case-insensitively, so even though other languages may use `Authorization`, it will all get folded to the same case.

##########
File path: python/pyarrow/includes/libarrow_flight.pxd
##########
@@ -307,6 +308,11 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
         CStatus Authenticate(CFlightCallOptions& options,
                              unique_ptr[CClientAuthHandler] auth_handler)
 
+        CResult[pair[c_string, c_string]] AuthenticateBasicToken(
+            CFlightCallOptions& options,
+            const c_string& username,

Review comment:
       The mixture here is because this is a set of Cython type definitions for C++ code, which uses the C++ conventions (e.g. `CResult`), but the overall project is in Python, which uses Python conventions (e.g. `c_string`).

##########
File path: python/pyarrow/_flight.pyx
##########
@@ -1871,7 +1909,6 @@ cdef CStatus _server_authenticate(void* self, CServerAuthSender* outgoing,
         reader.poison()

Review comment:
       This is to ensure a server doesn't use the Python reader beyond the 
lifetime of the C++ reader it wraps. 




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to