This is an automated email from the ASF dual-hosted git repository. kenhuuu pushed a commit to branch master-http-final in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 98b0141e502aa19ba5edc80c275d7bac20ea6aa1 Author: Yang Xia <[email protected]> AuthorDate: Wed Aug 28 14:01:59 2024 -0700 implement pluggable auth in python with basic and sigv4 as reference (#2731) --- CHANGELOG.asciidoc | 1 + .../gremlin_python/driver/aiohttp/transport.py | 19 ++++--- .../src/main/python/gremlin_python/driver/auth.py | 55 ++++++++++++++++++ .../main/python/gremlin_python/driver/client.py | 18 ++---- .../driver/driver_remote_connection.py | 15 ++--- .../main/python/gremlin_python/driver/protocol.py | 21 ++----- .../main/python/gremlin_python/driver/request.py | 8 +-- gremlin-python/src/main/python/tests/conftest.py | 17 +++--- .../src/main/python/tests/driver/test_auth.py | 66 ++++++++++++++++++++++ .../src/main/python/tests/driver/test_client.py | 5 +- 10 files changed, 163 insertions(+), 62 deletions(-) diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc index 79daf15dbd..7a79c9d848 100644 --- a/CHANGELOG.asciidoc +++ b/CHANGELOG.asciidoc @@ -59,6 +59,7 @@ image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima * `EmbeddedRemoteConnection` will use `Gremlinlang`, not `JavaTranslator`. * Java `Client` will no longer support submitting traversals. `DriverRemoteConnection` should be used instead. * Removed usage of `Bytecode` from `gremlin-python`. +* Added `auth` module in `gremlin-python` for pluggable authentication. * Fixed `GremlinLangScriptEngine` handling for some strategies. * Modified the `split()` step to split a string into a list of its characters if the given separator is an empty string. diff --git a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py index 4e5eca79d0..b2827e9085 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py @@ -51,6 +51,7 @@ class AiohttpHTTPTransport(AbstractBaseTransport): self._client_session = None self._http_req_resp = None self._enable_ssl = False + self._url = None # Set all inner variables to parameters passed in. self._aiohttp_kwargs = kwargs @@ -70,17 +71,17 @@ class AiohttpHTTPTransport(AbstractBaseTransport): self.close() def connect(self, url, headers=None): + self._url = url # Inner function to perform async connect. async def async_connect(): - # Start client session and use it to send all HTTP requests. Base url is the endpoint, headers are set here - # Base url can only parse basic url with no path, see https://github.com/aio-libs/aiohttp/issues/6647 + # Start client session and use it to send all HTTP requests. Headers can be set here. if self._enable_ssl: # ssl context is established through tcp connector tcp_conn = aiohttp.TCPConnector(ssl_context=self._ssl_context) self._client_session = aiohttp.ClientSession(connector=tcp_conn, - base_url=url, headers=headers, loop=self._loop) + headers=headers, loop=self._loop) else: - self._client_session = aiohttp.ClientSession(base_url=url, headers=headers, loop=self._loop) + self._client_session = aiohttp.ClientSession(headers=headers, loop=self._loop) # Execute the async connect synchronously. self._loop.run_until_complete(async_connect()) @@ -88,13 +89,13 @@ class AiohttpHTTPTransport(AbstractBaseTransport): def write(self, message): # Inner function to perform async write. async def async_write(): - basic_auth = None - # basic password authentication for https connections + # To pass url into message for request authentication processing + message.update({'url': self._url}) if message['auth']: - basic_auth = aiohttp.BasicAuth(message['auth']['username'], message['auth']['password']) + message['auth'](message) + async with async_timeout.timeout(self._write_timeout): - self._http_req_resp = await self._client_session.post(url="/gremlin", - auth=basic_auth, + self._http_req_resp = await self._client_session.post(url=self._url, data=message['payload'], headers=message['headers'], **self._aiohttp_kwargs) diff --git a/gremlin-python/src/main/python/gremlin_python/driver/auth.py b/gremlin-python/src/main/python/gremlin_python/driver/auth.py new file mode 100644 index 0000000000..dd7a1610ef --- /dev/null +++ b/gremlin-python/src/main/python/gremlin_python/driver/auth.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +def basic(username, password): + from aiohttp import BasicAuth as aiohttpBasicAuth + + def apply(request): + return request['headers'].update({'authorization': aiohttpBasicAuth(username, password).encode()}) + + return apply + + +def sigv4(region, service): + import os + from boto3 import Session + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + + def apply(request): + access_key = os.environ.get('AWS_ACCESS_KEY_ID', '') + secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', '') + session_token = os.environ.get('AWS_SESSION_TOKEN', '') + + session = Session( + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + aws_session_token=session_token, + region_name=region + ) + + sigv4_request = AWSRequest(method="POST", url=request['url'], data=request['payload']) + SigV4Auth(session.get_credentials(), service, region).add_auth(sigv4_request) + request['headers'].update(sigv4_request.headers) + request['payload'] = sigv4_request.data + return request + + return apply + diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py b/gremlin-python/src/main/python/gremlin_python/driver/client.py index acd9ab5590..420610708e 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/client.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py @@ -19,12 +19,9 @@ import logging import warnings import queue -import re from concurrent.futures import ThreadPoolExecutor from gremlin_python.driver import connection, protocol, request, serializer -from gremlin_python.process import traversal -from gremlin_python.driver.request import TokensV4 log = logging.getLogger("gremlinpython") @@ -44,13 +41,10 @@ class Client: def __init__(self, url, traversal_source, protocol_factory=None, transport_factory=None, pool_size=None, max_workers=None, - message_serializer=None, username="", password="", headers=None, + message_serializer=None, auth=None, headers=None, enable_user_agent_on_connect=True, **transport_kwargs): log.info("Creating Client with url '%s'", url) - # check via url that we are using http protocol - self._use_http = re.search('^http', url) - self._closed = False self._url = url self._headers = headers @@ -62,8 +56,7 @@ class Client: message_serializer = serializer.GraphBinarySerializersV4() self._message_serializer = message_serializer - self._username = username - self._password = password + self._auth = auth if transport_factory is None: try: @@ -82,8 +75,7 @@ class Client: def protocol_factory(): return protocol.GremlinServerHTTPProtocol( self._message_serializer, - username=self._username, - password=self._password) + auth=self._auth) self._protocol_factory = protocol_factory if pool_size is None: @@ -163,11 +155,11 @@ class Client: if isinstance(message, str): log.debug("fields='%s', gremlin='%s'", str(fields), str(message)) - message = request.RequestMessageV4(fields=fields, gremlin=message) + message = request.RequestMessage(fields=fields, gremlin=message) conn = self._pool.get(True) if request_options: - message.fields.update({token: request_options[token] for token in TokensV4 + message.fields.update({token: request_options[token] for token in request.Tokens if token in request_options and token != 'bindings'}) if 'bindings' in request_options: if 'bindings' in message.fields: diff --git a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py index 24ae16becb..a48dc0087b 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py @@ -21,9 +21,8 @@ from concurrent.futures import Future import warnings from gremlin_python.driver import client, serializer -from gremlin_python.driver.remote_connection import ( - RemoteConnection, RemoteTraversal) -from gremlin_python.driver.request import TokensV4 +from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal +from gremlin_python.driver.request import Tokens log = logging.getLogger("gremlinpython") @@ -34,7 +33,7 @@ class DriverRemoteConnection(RemoteConnection): def __init__(self, url, traversal_source="g", protocol_factory=None, transport_factory=None, pool_size=None, max_workers=None, - username="", password="", + auth=None, message_serializer=None, headers=None, enable_user_agent_on_connect=True, **transport_kwargs): log.info("Creating DriverRemoteConnection with url '%s'", str(url)) @@ -44,8 +43,7 @@ class DriverRemoteConnection(RemoteConnection): self.__transport_factory = transport_factory self.__pool_size = pool_size self.__max_workers = max_workers - self.__username = username - self.__password = password + self.__auth = auth self.__message_serializer = message_serializer self.__headers = headers self.__enable_user_agent_on_connect = enable_user_agent_on_connect @@ -59,8 +57,7 @@ class DriverRemoteConnection(RemoteConnection): pool_size=pool_size, max_workers=max_workers, message_serializer=message_serializer, - username=username, - password=password, + auth=auth, headers=headers, enable_user_agent_on_connect=enable_user_agent_on_connect, **transport_kwargs) @@ -121,7 +118,7 @@ class DriverRemoteConnection(RemoteConnection): def extract_request_options(gremlin_lang): request_options = {} for os in gremlin_lang.options_strategies: - request_options.update({token: os.configuration[token] for token in TokensV4 + request_options.update({token: os.configuration[token] for token in Tokens if token in os.configuration}) if gremlin_lang.parameters is not None and len(gremlin_lang.parameters) > 0: request_options["params"] = gremlin_lang.parameters diff --git a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py index a35c004a2a..b286c9dec9 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. # -import json import logging import abc @@ -54,12 +53,9 @@ class AbstractBaseProtocol(metaclass=abc.ABCMeta): class GremlinServerHTTPProtocol(AbstractBaseProtocol): - def __init__(self, - message_serializer, - username='', password=''): + def __init__(self, message_serializer, auth=None): + self._auth = auth self._message_serializer = message_serializer - self._username = username - self._password = password self._response_msg = {'status': {'code': 0, 'message': '', 'exception': ''}, @@ -71,18 +67,13 @@ class GremlinServerHTTPProtocol(AbstractBaseProtocol): super(GremlinServerHTTPProtocol, self).connection_made(transport) def write(self, request_message): - - basic_auth = {} - if self._username and self._password: - basic_auth['username'] = self._username - basic_auth['password'] = self._password - content_type = str(self._message_serializer.version, encoding='utf-8') + message = { - 'headers': {'CONTENT-TYPE': content_type, - 'ACCEPT': content_type}, + 'headers': {'content-type': content_type, + 'accept': content_type}, 'payload': self._message_serializer.serialize_message(request_message), - 'auth': basic_auth + 'auth': self._auth } self._transport.write(message) diff --git a/gremlin-python/src/main/python/gremlin_python/driver/request.py b/gremlin-python/src/main/python/gremlin_python/driver/request.py index 9993802f06..5e04f0bb85 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/request.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/request.py @@ -20,8 +20,8 @@ import collections __author__ = 'David M. Brown ([email protected])' -RequestMessageV4 = collections.namedtuple( - 'RequestMessageV4', ['fields', 'gremlin']) +RequestMessage = collections.namedtuple( + 'RequestMessage', ['fields', 'gremlin']) -TokensV4 = ['batchSize', 'bindings', 'g', 'gremlin', 'language', - 'evaluationTimeout', 'materializeProperties', 'timeoutMs', 'userAgent'] +Tokens = ['batchSize', 'bindings', 'g', 'gremlin', 'language', + 'evaluationTimeout', 'materializeProperties', 'timeoutMs', 'userAgent'] diff --git a/gremlin-python/src/main/python/tests/conftest.py b/gremlin-python/src/main/python/tests/conftest.py index f362a1f515..9e353c3d17 100644 --- a/gremlin-python/src/main/python/tests/conftest.py +++ b/gremlin-python/src/main/python/tests/conftest.py @@ -24,18 +24,17 @@ import pytest import logging import queue - from gremlin_python.driver.client import Client from gremlin_python.driver.connection import Connection from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection from gremlin_python.driver.protocol import GremlinServerHTTPProtocol from gremlin_python.driver.serializer import GraphBinarySerializersV4 from gremlin_python.driver.aiohttp.transport import AiohttpHTTPTransport - +from gremlin_python.driver.auth import basic, sigv4 """HTTP server testing variables""" -gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 'http://localhost:{}/') -gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 'https://localhost:{}/') +gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 'http://localhost:{}/gremlin') +gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 'https://localhost:{}/gremlin') anonymous_url = gremlin_server_url.format(45940) basic_url = gremlin_basic_auth_url.format(45941) @@ -44,7 +43,6 @@ verbose_logging = False logging.basicConfig(format='%(asctime)s [%(levelname)8s] [%(filename)15s:%(lineno)d - %(funcName)10s()] - %(message)s', level=logging.DEBUG if verbose_logging else logging.INFO) - """ Tests below are for the HTTP server with GraphBinaryV4 """ @@ -52,7 +50,7 @@ Tests below are for the HTTP server with GraphBinaryV4 def connection(request): protocol = GremlinServerHTTPProtocol( GraphBinarySerializersV4(), - username='stephen', password='password') + auth=basic('stephen', 'password')) executor = concurrent.futures.ThreadPoolExecutor(5) pool = queue.Queue() try: @@ -91,7 +89,8 @@ def authenticated_client(request): # turn off certificate verification for testing purposes only ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) ssl_opts.verify_mode = ssl.CERT_NONE - client = Client(basic_url, 'gmodern', username='stephen', password='password', + client = Client(basic_url, 'gmodern', + auth=basic('stephen', 'password'), transport_factory=lambda: AiohttpHTTPTransport(ssl_options=ssl_opts)) else: raise ValueError("Invalid authentication option - " + request.param) @@ -144,7 +143,7 @@ def remote_connection_crew(request): return remote_conn -# TODO: revisit once auth is updated +# TODO: revisit once testing for multiple types of auth is enabled @pytest.fixture(params=['basic']) def remote_connection_authenticated(request): try: @@ -153,7 +152,7 @@ def remote_connection_authenticated(request): ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) ssl_opts.verify_mode = ssl.CERT_NONE remote_conn = DriverRemoteConnection(basic_url, 'gmodern', - username='stephen', password='password', + auth=basic('stephen', 'password'), transport_factory=lambda: AiohttpHTTPTransport(ssl_options=ssl_opts)) else: raise ValueError("Invalid authentication option - " + request.param) diff --git a/gremlin-python/src/main/python/tests/driver/test_auth.py b/gremlin-python/src/main/python/tests/driver/test_auth.py new file mode 100644 index 0000000000..584cf1ec68 --- /dev/null +++ b/gremlin-python/src/main/python/tests/driver/test_auth.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import os +from aiohttp import BasicAuth as aiohttpBasicAuth +from gremlin_python.driver.auth import basic, sigv4 + + +def create_mock_request(): + return {'headers': + {'content-type': 'application/vnd.graphbinary-v4.0', + 'accept': 'application/vnd.graphbinary-v4.0'}, + 'payload': b'', + 'url': 'https://test_url:8182/gremlin'} + + +class TestAuth(object): + + def test_basic_auth_request(self): + mock_request = create_mock_request() + assert 'authorization' not in mock_request['headers'] + basic('username', 'password')(mock_request) + assert 'authorization' in mock_request['headers'] + assert aiohttpBasicAuth('username', 'password').encode() == mock_request['headers']['authorization'] + + def test_sigv4_auth_request(self): + mock_request = create_mock_request() + assert 'Authorization' not in mock_request['headers'] + assert 'X-Amz-Date' not in mock_request['headers'] + os.environ['AWS_ACCESS_KEY_ID'] = 'MOCK_ID' + os.environ['AWS_SECRET_ACCESS_KEY'] = 'MOCK_KEY' + sigv4('gremlin-east-1', 'tinkerpop-sigv4')(mock_request) + assert mock_request['headers']['X-Amz-Date'] is not None + assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=MOCK_ID') + assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request['headers']['Authorization'] + assert 'Signature=' in mock_request['headers']['Authorization'] + + def test_sigv4_auth_request_session_token(self): + mock_request = create_mock_request() + assert 'Authorization' not in mock_request['headers'] + assert 'X-Amz-Date' not in mock_request['headers'] + assert 'X-Amz-Security-Token' not in mock_request['headers'] + os.environ['AWS_SESSION_TOKEN'] = 'MOCK_TOKEN' + sigv4('gremlin-east-1', 'tinkerpop-sigv4')(mock_request) + assert mock_request['headers']['X-Amz-Date'] is not None + assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=') + assert mock_request['headers']['X-Amz-Security-Token'] == 'MOCK_TOKEN' + assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request['headers']['Authorization'] + assert 'Signature=' in mock_request['headers']['Authorization'] + + diff --git a/gremlin-python/src/main/python/tests/driver/test_client.py b/gremlin-python/src/main/python/tests/driver/test_client.py index 5cc7130f7d..fbaaed6d34 100644 --- a/gremlin-python/src/main/python/tests/driver/test_client.py +++ b/gremlin-python/src/main/python/tests/driver/test_client.py @@ -25,7 +25,7 @@ import pytest from gremlin_python.driver.client import Client from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection from gremlin_python.driver.protocol import GremlinServerError -from gremlin_python.driver.request import RequestMessageV4 +from gremlin_python.driver.request import RequestMessage from gremlin_python.process.graph_traversal import __, GraphTraversalSource from gremlin_python.process.traversal import TraversalStrategies, Parameter from gremlin_python.process.strategies import OptionsStrategy @@ -41,8 +41,7 @@ test_no_auth_url = gremlin_server_url.format(45940) def create_basic_request_message(traversal, source='gmodern'): - msg = RequestMessageV4(fields={'g': source}, gremlin=traversal.gremlin_lang.get_gremlin()) - return msg + return RequestMessage(fields={'g': source}, gremlin=traversal.gremlin_lang.get_gremlin()) def test_connection(connection):
