tifflhl commented on a change in pull request #8959:
URL: https://github.com/apache/arrow/pull/8959#discussion_r548374081
##########
File path: python/pyarrow/tests/test_flight.py
##########
@@ -506,6 +505,162 @@ def get_token(self):
return self.token
+class NoopAuthHandler(ServerAuthHandler):
+ """A no-op auth handler."""
+
+ def authenticate(self, outgoing, incoming):
+ """Do nothing."""
+
+ def is_valid(self, token):
+ """
+ Returning an empty string.
+ Returning None causes Type error.
+ """
+ return ""
+
+
+def case_insensitive_header_lookup(headers, lookup_key):
+ """Lookup the value of given key in the given headers.
+ The key lookup is case insensitive.
+ """
+ for key in headers:
+ if key.lower() == lookup_key.lower():
+ return headers.get(key)
+
+ raise flight.FlightUnauthenticatedError(
+ 'No authorization header found.')
+
+
+class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory):
+ """ClientMiddlewareFactory that creates ClientAuthHeaderMiddleware."""
+
+ def __init__(self):
+ self.call_credential = []
+
+ def start_call(self, info):
+ return ClientHeaderAuthMiddleware(self)
+
+ def set_call_credential(self, call_credential):
+ self.call_credential = call_credential
+
+
+class ClientHeaderAuthMiddleware(ClientMiddleware):
+ """
+ ClientMiddleware that extracts the authorization header
+ from the server.
+
+ This is an example of a ClientMiddleware that can extract
+ the bearer token authorization header from a HTTP header
+ authentication enabled server.
+
+ Parameters
+ ----------
+ factory : ClientHeaderAuthMiddlewareFactory
+ This factory is used to set call credentials if an
+ authorization header is found in the headers from the server.
+ """
+
+ def __init__(self, factory):
+ self.factory = factory
+
+ def received_headers(self, headers):
+ auth_header = case_insensitive_header_lookup(headers, 'Authorization')
Review comment:
The two are interchangeable.
##########
File path: python/pyarrow/tests/test_flight.py
##########
@@ -506,6 +505,162 @@ def get_token(self):
return self.token
+class NoopAuthHandler(ServerAuthHandler):
+ """A no-op auth handler."""
+
+ def authenticate(self, outgoing, incoming):
+ """Do nothing."""
+
+ def is_valid(self, token):
+ """
+ Returning an empty string.
+ Returning None causes Type error.
+ """
+ return ""
+
+
+def case_insensitive_header_lookup(headers, lookup_key):
+ """Lookup the value of given key in the given headers.
+ The key lookup is case insensitive.
+ """
+ for key in headers:
+ if key.lower() == lookup_key.lower():
+ return headers.get(key)
+
+ raise flight.FlightUnauthenticatedError(
+ 'No authorization header found.')
+
+
+class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory):
+ """ClientMiddlewareFactory that creates ClientAuthHeaderMiddleware."""
+
+ def __init__(self):
+ self.call_credential = []
+
+ def start_call(self, info):
+ return ClientHeaderAuthMiddleware(self)
+
+ def set_call_credential(self, call_credential):
+ self.call_credential = call_credential
+
+
+class ClientHeaderAuthMiddleware(ClientMiddleware):
+ """
+ ClientMiddleware that extracts the authorization header
+ from the server.
+
+ This is an example of a ClientMiddleware that can extract
+ the bearer token authorization header from a HTTP header
+ authentication enabled server.
+
+ Parameters
+ ----------
+ factory : ClientHeaderAuthMiddlewareFactory
+ This factory is used to set call credentials if an
+ authorization header is found in the headers from the server.
+ """
+
+ def __init__(self, factory):
+ self.factory = factory
+
+ def received_headers(self, headers):
+ auth_header = case_insensitive_header_lookup(headers, 'Authorization')
+ self.factory.set_call_credential([
+ b'authorization',
+ auth_header[0].encode("utf-8")])
+
+
+class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory):
+ """Validates incoming username and password."""
+
+ def start_call(self, info, headers):
+ auth_header = case_insensitive_header_lookup(
+ headers,
+ 'Authorization'
+ )
+ values = auth_header[0].split(' ')
+ token = ''
+
+ if values[0] == 'Basic':
+ decoded = base64.b64decode(values[1])
+ pair = decoded.decode("utf-8").split(':')
+ if not (pair[0] == 'test' and pair[1] == 'password'):
+ raise flight.FlightUnauthenticatedError('Invalid credentials')
Review comment:
Addressed.
##########
File path: python/pyarrow/_flight.pyx
##########
@@ -1150,6 +1156,38 @@ cdef class FlightClient(_Weakrefable):
self.client.get().Authenticate(deref(c_options),
move(handler)))
+ def authenticateBasicToken(self, username, password,
Review comment:
Addressed.
##########
File path: python/pyarrow/_flight.pyx
##########
@@ -118,12 +118,18 @@ cdef class FlightCallOptions(_Weakrefable):
write_options : pyarrow.ipc.IpcWriteOptions, optional
IPC write options. The default options can be controlled
by environment variables (see pyarrow.ipc).
-
+ headers : vector[pair[c_string, c_string]], optional
Review comment:
Addressed.
##########
File path: python/pyarrow/_flight.pyx
##########
@@ -1150,6 +1156,38 @@ cdef class FlightClient(_Weakrefable):
self.client.get().Authenticate(deref(c_options),
move(handler)))
+ def authenticateBasicToken(self, username, password,
+ options: FlightCallOptions = None):
+ """Authenticate to the server with HTTP basic authentication.
+
+ Parameters
+ ----------
+ username : string
+ Username to authenticate with
+ password : string
+ Password to authenticate with
+ options : FlightCallOptions
+ Options for this call
+
+ Returns
+ -------
+ pair : pair[string, string]
Review comment:
Addressed.
##########
File path: python/pyarrow/tests/test_flight.py
##########
@@ -506,6 +505,162 @@ def get_token(self):
return self.token
+class NoopAuthHandler(ServerAuthHandler):
+ """A no-op auth handler."""
+
+ def authenticate(self, outgoing, incoming):
+ """Do nothing."""
+
+ def is_valid(self, token):
+ """
+ Returning an empty string.
+ Returning None causes Type error.
+ """
+ return ""
+
+
+def case_insensitive_header_lookup(headers, lookup_key):
Review comment:
I modified the function so that it does not raise an exception when the header
is not found. The function is also used in a test that is not related to
authentication.
##########
File path: python/pyarrow/tests/test_flight.py
##########
@@ -996,6 +1152,100 @@ def test_token_auth_invalid():
client.authenticate(TokenClientAuthHandler('test', 'wrong'))
+header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory()
+no_op_auth_handler = NoopAuthHandler()
+
+
+def test_authenticate_basic_token():
+ """Test authenticateBasicToken with bearer token and auth headers."""
+ with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+ "auth": HeaderAuthServerMiddlewareFactory()
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ token_pair = client.authenticateBasicToken(b'test', b'password')
+ assert token_pair[0] == b'authorization'
+ assert token_pair[1] == b'Bearer ' + b'token1234'
+
+
+def test_authenticate_basic_token_invalid_password():
+ """Test authenticateBasicToken with an invalid password."""
+ with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+ "auth": HeaderAuthServerMiddlewareFactory()
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ with pytest.raises(flight.FlightUnauthenticatedError):
+ client.authenticateBasicToken(b'test', b'badpassword')
+
+
+def test_authenticate_basic_token_and_action():
+ """Test authenticateBasicToken and doAction after authentication."""
+ with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+ "auth": HeaderAuthServerMiddlewareFactory()
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ token_pair = client.authenticateBasicToken(b'test', b'password')
+ assert token_pair[0] == b'authorization'
+ assert token_pair[1] == b'Bearer ' + b'token1234'
+ options = flight.FlightCallOptions(headers=[token_pair])
+ result = list(client.do_action(
+ action=flight.Action('test-action', b''), options=options))
+ assert result[0].body.to_pybytes() == b'token1234'
+
+
+def test_authenticate_basic_token_with_client_middleware():
+ """Test authenticateBasicToken with client middleware
+ to intercept authorization header returned by the
+ HTTP header auth enabled server.
+ """
+ with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+ "auth": HeaderAuthServerMiddlewareFactory()
+ }) as server:
+ client_auth_middleware = ClientHeaderAuthMiddlewareFactory()
+ client = FlightClient(
+ ('localhost', server.port),
+ middleware=[client_auth_middleware]
+ )
+ encoded_credentials = base64.b64encode(b'test:password')
+ options = flight.FlightCallOptions(headers=[
+ (b'authorization', b'Basic ' + encoded_credentials)
+ ])
+ result = list(client.do_action(
+ action=flight.Action('test-action', b''), options=options))
+ assert result[0].body.to_pybytes() == b'token1234'
+ assert client_auth_middleware.call_credential[0] == b'authorization'
+ assert client_auth_middleware.call_credential[1] == \
+ b'Bearer ' + b'token1234'
+ result2 = list(client.do_action(
+ action=flight.Action('test-action', b''), options=options))
+ assert result2[0].body.to_pybytes() == b'token1234'
+ assert client_auth_middleware.call_credential[0] == b'authorization'
+ assert client_auth_middleware.call_credential[1] == \
+ b'Bearer ' + b'token1234'
+
+
+def test_arbitrary_headers_in_flight_call_options():
+ """Test passing multiple arbitrary headers to the middleware."""
+ with ArbitraryHeadersFlightServer(
+ auth_handler=no_op_auth_handler,
+ middleware={
+ "auth": HeaderAuthServerMiddlewareFactory(),
+ "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ token_pair = client.authenticateBasicToken(b'test', b'password')
+ assert token_pair[0] == b'authorization'
+ assert token_pair[1] == b'Bearer ' + b'token1234'
Review comment:
Addressed.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]