This is an automated email from the ASF dual-hosted git repository. xiazcy pushed a commit to branch python-auth in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 8273106b893e1c68cbd9b967ab742369703e2b6e Author: Yang Xia <[email protected]> AuthorDate: Mon Aug 19 10:05:15 2024 -0700 implement pluggable auth in python with basic and sigv4 as reference --- .../gremlin_python/driver/aiohttp/transport.py | 19 ++--- .../src/main/python/gremlin_python/driver/auth.py | 81 ++++++++++++++++++++++ .../main/python/gremlin_python/driver/client.py | 18 ++--- .../driver/driver_remote_connection.py | 15 ++-- .../main/python/gremlin_python/driver/protocol.py | 21 ++---- .../main/python/gremlin_python/driver/request.py | 8 +-- gremlin-python/src/main/python/tests/conftest.py | 17 +++-- .../src/main/python/tests/driver/test_auth.py | 52 ++++++++++++++ .../src/main/python/tests/driver/test_client.py | 5 +- 9 files changed, 175 insertions(+), 61 deletions(-) diff --git a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py index 2f4164f1ca..18f8dfdb90 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py @@ -46,6 +46,7 @@ class AiohttpHTTPTransport(AbstractBaseTransport): self._client_session = None self._http_req_resp = None self._enable_ssl = False + self._url = None # Set all inner variables to parameters passed in. self._aiohttp_kwargs = kwargs @@ -65,17 +66,17 @@ class AiohttpHTTPTransport(AbstractBaseTransport): self.close() def connect(self, url, headers=None): + self._url = url # Inner function to perform async connect. async def async_connect(): - # Start client session and use it to send all HTTP requests. Base url is the endpoint, headers are set here - # Base url can only parse basic url with no path, see https://github.com/aio-libs/aiohttp/issues/6647 + # Start client session and use it to send all HTTP requests. Headers can be set here. if self._enable_ssl: # ssl context is established through tcp connector tcp_conn = aiohttp.TCPConnector(ssl_context=self._ssl_context) self._client_session = aiohttp.ClientSession(connector=tcp_conn, - base_url=url, headers=headers, loop=self._loop) + headers=headers, loop=self._loop) else: - self._client_session = aiohttp.ClientSession(base_url=url, headers=headers, loop=self._loop) + self._client_session = aiohttp.ClientSession(headers=headers, loop=self._loop) # Execute the async connect synchronously. self._loop.run_until_complete(async_connect()) @@ -83,13 +84,13 @@ class AiohttpHTTPTransport(AbstractBaseTransport): def write(self, message): # Inner function to perform async write. async def async_write(): - basic_auth = None - # basic password authentication for https connections + # To pass url into message for request authentication processing + message.update({'url': self._url}) if message['auth']: - basic_auth = aiohttp.BasicAuth(message['auth']['username'], message['auth']['password']) + message['auth'].apply(message) + async with async_timeout.timeout(self._write_timeout): - self._http_req_resp = await self._client_session.post(url="/gremlin", - auth=basic_auth, + self._http_req_resp = await self._client_session.post(url=self._url, data=message['payload'], headers=message['headers'], **self._aiohttp_kwargs) diff --git a/gremlin-python/src/main/python/gremlin_python/driver/auth.py b/gremlin-python/src/main/python/gremlin_python/driver/auth.py new file mode 100644 index 0000000000..678999e08d --- /dev/null +++ b/gremlin-python/src/main/python/gremlin_python/driver/auth.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import abc + + +class Auth(metaclass=abc.ABCMeta): + + @abc.abstractmethod + def apply(self, request): + """Applies the necessary authentication operations to the request and returns the modified request.""" + pass + + @staticmethod + def basic(username, password): + return BasicAuth(username, password) + + @staticmethod + def sigv4(region_name, aws_access_key_id='', aws_secret_access_key='', session_token='', service_name=''): + return SigV4Auth(region_name, aws_access_key_id, aws_secret_access_key, session_token, service_name) + + +class BasicAuth(Auth): + + def __init__(self, username, password): + self._username = username + self._password = password + + def apply(self, request): + from aiohttp import BasicAuth as aiohttpBasicAuth + + return request['headers'].update({'authorization': aiohttpBasicAuth(self._username, self._password).encode()}) + + +class SigV4Auth(Auth): + + def __init__(self, region_name, aws_access_key_id='', aws_secret_access_key='', session_token='', + service_name=''): + import os + + self._region_name = region_name + self._aws_access_key_id = aws_access_key_id if aws_access_key_id else os.environ.get('AWS_ACCESS_KEY_ID') + self._aws_secret_access_key = aws_secret_access_key if aws_secret_access_key \ + else os.environ.get('AWS_SECRET_ACCESS_KEY') + self._session_token = session_token if session_token else os.environ.get('AWS_SESSION_TOKEN') + self._service_name = service_name if service_name else "neptune-db" + + def apply(self, request): + from botocore.auth import SigV4Auth as botocoreSigV4Auth + from botocore.awsrequest import AWSRequest + from types import SimpleNamespace + + assert ((self._aws_access_key_id is not None and self._aws_secret_access_key is not None) + or self._session_token is not None), \ + ('No credentials or session token found, please ensure access key and secret key or session tokens ' + 'are provided or set as environment variables.') + + creds = SimpleNamespace( + access_key=self._aws_access_key_id, secret_key=self._aws_secret_access_key, token=self._session_token, + region=self._region_name, + ) + aws_request = AWSRequest(method="POST", url=request['url'], data=request['payload']) + botocoreSigV4Auth(creds, self._service_name, self._region_name).add_auth(aws_request) + request['headers'].update(aws_request.headers) + request['payload'] = aws_request.data + return request diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py b/gremlin-python/src/main/python/gremlin_python/driver/client.py index 0e217acda4..16bd6f8270 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/client.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py @@ -19,12 +19,9 @@ import logging import warnings import queue -import re from concurrent.futures import ThreadPoolExecutor from gremlin_python.driver import connection, protocol, request, serializer -from gremlin_python.process import traversal -from gremlin_python.driver.request import TokensV4 log = logging.getLogger("gremlinpython") @@ -44,13 +41,10 @@ class Client: def __init__(self, url, traversal_source, protocol_factory=None, transport_factory=None, pool_size=None, max_workers=None, - message_serializer=None, username="", password="", headers=None, + message_serializer=None, auth=None, headers=None, enable_user_agent_on_connect=True, **transport_kwargs): log.info("Creating Client with url '%s'", url) - # check via url that we are using http protocol - self._use_http = re.search('^http', url) - self._closed = False self._url = url self._headers = headers @@ -62,8 +56,7 @@ class Client: message_serializer = serializer.GraphBinarySerializersV4() self._message_serializer = message_serializer - self._username = username - self._password = password + self._auth = auth if transport_factory is None: try: @@ -82,8 +75,7 @@ class Client: def protocol_factory(): return protocol.GremlinServerHTTPProtocol( self._message_serializer, - username=self._username, - password=self._password) + auth=self._auth) self._protocol_factory = protocol_factory if pool_size is None: @@ -163,11 +155,11 @@ class Client: if isinstance(message, str): log.debug("fields='%s', gremlin='%s'", str(fields), str(message)) - message = request.RequestMessageV4(fields=fields, gremlin=message) + message = request.RequestMessage(fields=fields, gremlin=message) conn = self._pool.get(True) if request_options: - message.fields.update({token: request_options[token] for token in TokensV4 + message.fields.update({token: request_options[token] for token in request.Tokens if token in request_options and token != 'bindings'}) if 'bindings' in request_options: if 'bindings' in message.fields: diff --git a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py index 24ae16becb..a48dc0087b 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py @@ -21,9 +21,8 @@ from concurrent.futures import Future import warnings from gremlin_python.driver import client, serializer -from gremlin_python.driver.remote_connection import ( - RemoteConnection, RemoteTraversal) -from gremlin_python.driver.request import TokensV4 +from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal +from gremlin_python.driver.request import Tokens log = logging.getLogger("gremlinpython") @@ -34,7 +33,7 @@ class DriverRemoteConnection(RemoteConnection): def __init__(self, url, traversal_source="g", protocol_factory=None, transport_factory=None, pool_size=None, max_workers=None, - username="", password="", + auth=None, message_serializer=None, headers=None, enable_user_agent_on_connect=True, **transport_kwargs): log.info("Creating DriverRemoteConnection with url '%s'", str(url)) @@ -44,8 +43,7 @@ class DriverRemoteConnection(RemoteConnection): self.__transport_factory = transport_factory self.__pool_size = pool_size self.__max_workers = max_workers - self.__username = username - self.__password = password + self.__auth = auth self.__message_serializer = message_serializer self.__headers = headers self.__enable_user_agent_on_connect = enable_user_agent_on_connect @@ -59,8 +57,7 @@ class DriverRemoteConnection(RemoteConnection): pool_size=pool_size, max_workers=max_workers, message_serializer=message_serializer, - username=username, - password=password, + auth=auth, headers=headers, enable_user_agent_on_connect=enable_user_agent_on_connect, **transport_kwargs) @@ -121,7 +118,7 @@ class DriverRemoteConnection(RemoteConnection): def extract_request_options(gremlin_lang): request_options = {} for os in gremlin_lang.options_strategies: - request_options.update({token: os.configuration[token] for token in TokensV4 + request_options.update({token: os.configuration[token] for token in Tokens if token in os.configuration}) if gremlin_lang.parameters is not None and len(gremlin_lang.parameters) > 0: request_options["params"] = gremlin_lang.parameters diff --git a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py index a35c004a2a..69c2d921ed 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py @@ -19,6 +19,7 @@ import json import logging import abc +from gremlin_python.driver.auth import Auth log = logging.getLogger("gremlinpython") @@ -54,12 +55,9 @@ class AbstractBaseProtocol(metaclass=abc.ABCMeta): class GremlinServerHTTPProtocol(AbstractBaseProtocol): - def __init__(self, - message_serializer, - username='', password=''): + def __init__(self, message_serializer, auth=None): + self._auth = auth self._message_serializer = message_serializer - self._username = username - self._password = password self._response_msg = {'status': {'code': 0, 'message': '', 'exception': ''}, @@ -71,18 +69,13 @@ class GremlinServerHTTPProtocol(AbstractBaseProtocol): super(GremlinServerHTTPProtocol, self).connection_made(transport) def write(self, request_message): - - basic_auth = {} - if self._username and self._password: - basic_auth['username'] = self._username - basic_auth['password'] = self._password - content_type = str(self._message_serializer.version, encoding='utf-8') + message = { - 'headers': {'CONTENT-TYPE': content_type, - 'ACCEPT': content_type}, + 'headers': {'content-type': content_type, + 'accept': content_type}, 'payload': self._message_serializer.serialize_message(request_message), - 'auth': basic_auth + 'auth': self._auth } self._transport.write(message) diff --git a/gremlin-python/src/main/python/gremlin_python/driver/request.py b/gremlin-python/src/main/python/gremlin_python/driver/request.py index 9993802f06..5e04f0bb85 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/request.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/request.py @@ -20,8 +20,8 @@ import collections __author__ = 'David M. Brown ([email protected])' -RequestMessageV4 = collections.namedtuple( - 'RequestMessageV4', ['fields', 'gremlin']) +RequestMessage = collections.namedtuple( + 'RequestMessage', ['fields', 'gremlin']) -TokensV4 = ['batchSize', 'bindings', 'g', 'gremlin', 'language', - 'evaluationTimeout', 'materializeProperties', 'timeoutMs', 'userAgent'] +Tokens = ['batchSize', 'bindings', 'g', 'gremlin', 'language', + 'evaluationTimeout', 'materializeProperties', 'timeoutMs', 'userAgent'] diff --git a/gremlin-python/src/main/python/tests/conftest.py b/gremlin-python/src/main/python/tests/conftest.py index f362a1f515..c2a82aa066 100644 --- a/gremlin-python/src/main/python/tests/conftest.py +++ b/gremlin-python/src/main/python/tests/conftest.py @@ -24,18 +24,17 @@ import pytest import logging import queue - from gremlin_python.driver.client import Client from gremlin_python.driver.connection import Connection from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection from gremlin_python.driver.protocol import GremlinServerHTTPProtocol from gremlin_python.driver.serializer import GraphBinarySerializersV4 from gremlin_python.driver.aiohttp.transport import AiohttpHTTPTransport - +from gremlin_python.driver.auth import Auth """HTTP server testing variables""" -gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 'http://localhost:{}/') -gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 'https://localhost:{}/') +gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 'http://localhost:{}/gremlin') +gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 'https://localhost:{}/gremlin') anonymous_url = gremlin_server_url.format(45940) basic_url = gremlin_basic_auth_url.format(45941) @@ -44,7 +43,6 @@ verbose_logging = False logging.basicConfig(format='%(asctime)s [%(levelname)8s] [%(filename)15s:%(lineno)d - %(funcName)10s()] - %(message)s', level=logging.DEBUG if verbose_logging else logging.INFO) - """ Tests below are for the HTTP server with GraphBinaryV4 """ @@ -52,7 +50,7 @@ Tests below are for the HTTP server with GraphBinaryV4 def connection(request): protocol = GremlinServerHTTPProtocol( GraphBinarySerializersV4(), - username='stephen', password='password') + auth=Auth.basic('stephen', 'password')) executor = concurrent.futures.ThreadPoolExecutor(5) pool = queue.Queue() try: @@ -91,7 +89,8 @@ def authenticated_client(request): # turn off certificate verification for testing purposes only ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) ssl_opts.verify_mode = ssl.CERT_NONE - client = Client(basic_url, 'gmodern', username='stephen', password='password', + client = Client(basic_url, 'gmodern', + auth=Auth.basic('stephen', 'password'), transport_factory=lambda: AiohttpHTTPTransport(ssl_options=ssl_opts)) else: raise ValueError("Invalid authentication option - " + request.param) @@ -144,7 +143,7 @@ def remote_connection_crew(request): return remote_conn -# TODO: revisit once auth is updated +# TODO: revisit once testing for multiple types of auth is enabled @pytest.fixture(params=['basic']) def remote_connection_authenticated(request): try: @@ -153,7 +152,7 @@ def remote_connection_authenticated(request): ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) ssl_opts.verify_mode = ssl.CERT_NONE remote_conn = DriverRemoteConnection(basic_url, 'gmodern', - username='stephen', password='password', + auth=Auth.basic('stephen', 'password'), transport_factory=lambda: AiohttpHTTPTransport(ssl_options=ssl_opts)) else: raise ValueError("Invalid authentication option - " + request.param) diff --git a/gremlin-python/src/main/python/tests/driver/test_auth.py b/gremlin-python/src/main/python/tests/driver/test_auth.py new file mode 100644 index 0000000000..fbd516366f --- /dev/null +++ b/gremlin-python/src/main/python/tests/driver/test_auth.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from aiohttp import BasicAuth as aiohttpBasicAuth + +from src.main.python.gremlin_python.driver.auth import Auth + + +def create_mock_request(): + return {'headers': + {'content-type': 'application/vnd.graphbinary-v4.0', + 'accept': 'application/vnd.graphbinary-v4.0'}, + 'payload': b'', + 'url': 'https://test_url:8182/gremlin'} + + +class TestAuth(object): + + def test_basic_auth_request(self): + mock_request = create_mock_request() + assert 'authorization' not in mock_request['headers'] + Auth.basic('username', 'password').apply(mock_request) + assert 'authorization' in mock_request['headers'] + assert aiohttpBasicAuth('username', 'password').encode() == mock_request['headers']['authorization'] + + def test_sigv4_auth_request(self): + mock_request = create_mock_request() + assert 'Authorization' not in mock_request['headers'] + assert 'X-Amz-Date' not in mock_request['headers'] + Auth.sigv4('us-west-2', 'MOCK_ID', 'MOCK_KEY').apply(mock_request) + print(mock_request) + assert mock_request['headers']['X-Amz-Date'] is not None + assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=MOCK_ID') + assert 'us-west-2/neptune-db/aws4_request' in mock_request['headers']['Authorization'] + assert 'Signature=' in mock_request['headers']['Authorization'] + + diff --git a/gremlin-python/src/main/python/tests/driver/test_client.py b/gremlin-python/src/main/python/tests/driver/test_client.py index ce180d7bf3..41b8a87a7a 100644 --- a/gremlin-python/src/main/python/tests/driver/test_client.py +++ b/gremlin-python/src/main/python/tests/driver/test_client.py @@ -25,7 +25,7 @@ import pytest from gremlin_python.driver.client import Client from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection from gremlin_python.driver.protocol import GremlinServerError -from gremlin_python.driver.request import RequestMessageV4 +from gremlin_python.driver.request import RequestMessage from gremlin_python.process.graph_traversal import __, GraphTraversalSource from gremlin_python.process.traversal import TraversalStrategies, Parameter from gremlin_python.process.strategies import OptionsStrategy @@ -41,8 +41,7 @@ test_no_auth_url = gremlin_server_url.format(45940) def create_basic_request_message(traversal, source='gmodern'): - msg = RequestMessageV4(fields={'g': source}, gremlin=traversal.gremlin_lang.get_gremlin()) - return msg + return RequestMessage(fields={'g': source}, gremlin=traversal.gremlin_lang.get_gremlin()) def test_connection(connection):
