This is an automated email from the ASF dual-hosted git repository.

xiazcy pushed a commit to branch python-auth
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git

commit 8273106b893e1c68cbd9b967ab742369703e2b6e
Author: Yang Xia <[email protected]>
AuthorDate: Mon Aug 19 10:05:15 2024 -0700

    implement pluggable auth in python with basic and sigv4 as reference
---
 .../gremlin_python/driver/aiohttp/transport.py     | 19 ++---
 .../src/main/python/gremlin_python/driver/auth.py  | 81 ++++++++++++++++++++++
 .../main/python/gremlin_python/driver/client.py    | 18 ++---
 .../driver/driver_remote_connection.py             | 15 ++--
 .../main/python/gremlin_python/driver/protocol.py  | 21 ++----
 .../main/python/gremlin_python/driver/request.py   |  8 +--
 gremlin-python/src/main/python/tests/conftest.py   | 17 +++--
 .../src/main/python/tests/driver/test_auth.py      | 52 ++++++++++++++
 .../src/main/python/tests/driver/test_client.py    |  5 +-
 9 files changed, 175 insertions(+), 61 deletions(-)

diff --git a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py
index 2f4164f1ca..18f8dfdb90 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py
@@ -46,6 +46,7 @@ class AiohttpHTTPTransport(AbstractBaseTransport):
         self._client_session = None
         self._http_req_resp = None
         self._enable_ssl = False
+        self._url = None
 
         # Set all inner variables to parameters passed in.
         self._aiohttp_kwargs = kwargs
@@ -65,17 +66,17 @@ class AiohttpHTTPTransport(AbstractBaseTransport):
         self.close()
 
     def connect(self, url, headers=None):
+        self._url = url
         # Inner function to perform async connect.
         async def async_connect():
-            # Start client session and use it to send all HTTP requests. Base url is the endpoint, headers are set here
-            # Base url can only parse basic url with no path, see https://github.com/aio-libs/aiohttp/issues/6647
+            # Start client session and use it to send all HTTP requests. Headers can be set here.
             if self._enable_ssl:
                 # ssl context is established through tcp connector
                 tcp_conn = aiohttp.TCPConnector(ssl_context=self._ssl_context)
                 self._client_session = aiohttp.ClientSession(connector=tcp_conn,
-                                                             base_url=url, headers=headers, loop=self._loop)
+                                                             headers=headers, loop=self._loop)
             else:
-                self._client_session = aiohttp.ClientSession(base_url=url, headers=headers, loop=self._loop)
+                self._client_session = aiohttp.ClientSession(headers=headers, loop=self._loop)
 
         # Execute the async connect synchronously.
         self._loop.run_until_complete(async_connect())
@@ -83,13 +84,13 @@ class AiohttpHTTPTransport(AbstractBaseTransport):
     def write(self, message):
         # Inner function to perform async write.
         async def async_write():
-            basic_auth = None
-            # basic password authentication for https connections
+            # To pass url into message for request authentication processing
+            message.update({'url': self._url})
             if message['auth']:
-                basic_auth = aiohttp.BasicAuth(message['auth']['username'], message['auth']['password'])
+                message['auth'].apply(message)
+
             async with async_timeout.timeout(self._write_timeout):
-                self._http_req_resp = await self._client_session.post(url="/gremlin",
-                                                                      auth=basic_auth,
+                self._http_req_resp = await self._client_session.post(url=self._url,
                                                                       data=message['payload'],
                                                                       headers=message['headers'],
                                                                       **self._aiohttp_kwargs)
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/auth.py b/gremlin-python/src/main/python/gremlin_python/driver/auth.py
new file mode 100644
index 0000000000..678999e08d
--- /dev/null
+++ b/gremlin-python/src/main/python/gremlin_python/driver/auth.py
@@ -0,0 +1,81 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import abc
+
+
+class Auth(metaclass=abc.ABCMeta):
+
+    @abc.abstractmethod
+    def apply(self, request):
+        """Applies the necessary authentication operations to the request and returns the modified request."""
+        pass
+
+    @staticmethod
+    def basic(username, password):
+        return BasicAuth(username, password)
+
+    @staticmethod
+    def sigv4(region_name, aws_access_key_id='', aws_secret_access_key='', session_token='', service_name=''):
+        return SigV4Auth(region_name, aws_access_key_id, aws_secret_access_key, session_token, service_name)
+
+
+class BasicAuth(Auth):
+
+    def __init__(self, username, password):
+        self._username = username
+        self._password = password
+
+    def apply(self, request):
+        from aiohttp import BasicAuth as aiohttpBasicAuth
+
+        return request['headers'].update({'authorization': aiohttpBasicAuth(self._username, self._password).encode()})
+
+
+class SigV4Auth(Auth):
+
+    def __init__(self, region_name, aws_access_key_id='', aws_secret_access_key='', session_token='',
+                 service_name=''):
+        import os
+
+        self._region_name = region_name
+        self._aws_access_key_id = aws_access_key_id if aws_access_key_id else os.environ.get('AWS_ACCESS_KEY_ID')
+        self._aws_secret_access_key = aws_secret_access_key if aws_secret_access_key \
+            else os.environ.get('AWS_SECRET_ACCESS_KEY')
+        self._session_token = session_token if session_token else os.environ.get('AWS_SESSION_TOKEN')
+        self._service_name = service_name if service_name else "neptune-db"
+
+    def apply(self, request):
+        from botocore.auth import SigV4Auth as botocoreSigV4Auth
+        from botocore.awsrequest import AWSRequest
+        from types import SimpleNamespace
+
+        assert ((self._aws_access_key_id is not None and self._aws_secret_access_key is not None)
+                or self._session_token is not None), \
+            ('No credentials or session token found, please ensure access key and secret key or session tokens '
+             'are provided or set as environment variables.')
+
+        creds = SimpleNamespace(
+            access_key=self._aws_access_key_id, secret_key=self._aws_secret_access_key, token=self._session_token,
+            region=self._region_name,
+        )
+        aws_request = AWSRequest(method="POST", url=request['url'], data=request['payload'])
+        botocoreSigV4Auth(creds, self._service_name, self._region_name).add_auth(aws_request)
+        request['headers'].update(aws_request.headers)
+        request['payload'] = aws_request.data
+        return request
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py b/gremlin-python/src/main/python/gremlin_python/driver/client.py
index 0e217acda4..16bd6f8270 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/client.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py
@@ -19,12 +19,9 @@
 import logging
 import warnings
 import queue
-import re
 from concurrent.futures import ThreadPoolExecutor
 
 from gremlin_python.driver import connection, protocol, request, serializer
-from gremlin_python.process import traversal
-from gremlin_python.driver.request import TokensV4
 
 log = logging.getLogger("gremlinpython")
 
@@ -44,13 +41,10 @@ class Client:
 
     def __init__(self, url, traversal_source, protocol_factory=None,
                  transport_factory=None, pool_size=None, max_workers=None,
-                 message_serializer=None, username="", password="", headers=None,
+                 message_serializer=None, auth=None, headers=None,
                  enable_user_agent_on_connect=True, **transport_kwargs):
         log.info("Creating Client with url '%s'", url)
 
-        # check via url that we are using http protocol
-        self._use_http = re.search('^http', url)
-
         self._closed = False
         self._url = url
         self._headers = headers
@@ -62,8 +56,7 @@ class Client:
             message_serializer = serializer.GraphBinarySerializersV4()
 
         self._message_serializer = message_serializer
-        self._username = username
-        self._password = password
+        self._auth = auth
 
         if transport_factory is None:
             try:
@@ -82,8 +75,7 @@ class Client:
             def protocol_factory():
                 return protocol.GremlinServerHTTPProtocol(
                     self._message_serializer,
-                    username=self._username,
-                    password=self._password)
+                    auth=self._auth)
         self._protocol_factory = protocol_factory
 
         if pool_size is None:
@@ -163,11 +155,11 @@ class Client:
 
         if isinstance(message, str):
             log.debug("fields='%s', gremlin='%s'", str(fields), str(message))
-            message = request.RequestMessageV4(fields=fields, gremlin=message)
+            message = request.RequestMessage(fields=fields, gremlin=message)
 
         conn = self._pool.get(True)
         if request_options:
-            message.fields.update({token: request_options[token] for token in TokensV4
+            message.fields.update({token: request_options[token] for token in request.Tokens
                                    if token in request_options and token != 'bindings'})
             if 'bindings' in request_options:
                 if 'bindings' in message.fields:
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
index 24ae16becb..a48dc0087b 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
@@ -21,9 +21,8 @@ from concurrent.futures import Future
 import warnings
 
 from gremlin_python.driver import client, serializer
-from gremlin_python.driver.remote_connection import (
-    RemoteConnection, RemoteTraversal)
-from gremlin_python.driver.request import TokensV4
+from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal
+from gremlin_python.driver.request import Tokens
 
 log = logging.getLogger("gremlinpython")
 
@@ -34,7 +33,7 @@ class DriverRemoteConnection(RemoteConnection):
 
     def __init__(self, url, traversal_source="g", protocol_factory=None,
                  transport_factory=None, pool_size=None, max_workers=None,
-                 username="", password="",
+                 auth=None,
                  message_serializer=None, headers=None,
                  enable_user_agent_on_connect=True, **transport_kwargs):
         log.info("Creating DriverRemoteConnection with url '%s'", str(url))
@@ -44,8 +43,7 @@ class DriverRemoteConnection(RemoteConnection):
         self.__transport_factory = transport_factory
         self.__pool_size = pool_size
         self.__max_workers = max_workers
-        self.__username = username
-        self.__password = password
+        self.__auth = auth
         self.__message_serializer = message_serializer
         self.__headers = headers
         self.__enable_user_agent_on_connect = enable_user_agent_on_connect
@@ -59,8 +57,7 @@ class DriverRemoteConnection(RemoteConnection):
                                      pool_size=pool_size,
                                      max_workers=max_workers,
                                      message_serializer=message_serializer,
-                                     username=username,
-                                     password=password,
+                                     auth=auth,
                                      headers=headers,
                                      enable_user_agent_on_connect=enable_user_agent_on_connect,
                                      **transport_kwargs)
@@ -121,7 +118,7 @@ class DriverRemoteConnection(RemoteConnection):
     def extract_request_options(gremlin_lang):
         request_options = {}
         for os in gremlin_lang.options_strategies:
-            request_options.update({token: os.configuration[token] for token in TokensV4
+            request_options.update({token: os.configuration[token] for token in Tokens
                                     if token in os.configuration})
         if gremlin_lang.parameters is not None and len(gremlin_lang.parameters) > 0:
             request_options["params"] = gremlin_lang.parameters
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py
index a35c004a2a..69c2d921ed 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py
@@ -19,6 +19,7 @@
 import json
 import logging
 import abc
+from gremlin_python.driver.auth import Auth
 
 log = logging.getLogger("gremlinpython")
 
@@ -54,12 +55,9 @@ class AbstractBaseProtocol(metaclass=abc.ABCMeta):
 
 class GremlinServerHTTPProtocol(AbstractBaseProtocol):
 
-    def __init__(self,
-                 message_serializer,
-                 username='', password=''):
+    def __init__(self, message_serializer, auth=None):
+        self._auth = auth
         self._message_serializer = message_serializer
-        self._username = username
-        self._password = password
         self._response_msg = {'status': {'code': 0,
                                          'message': '',
                                          'exception': ''},
@@ -71,18 +69,13 @@ class GremlinServerHTTPProtocol(AbstractBaseProtocol):
         super(GremlinServerHTTPProtocol, self).connection_made(transport)
 
     def write(self, request_message):
-
-        basic_auth = {}
-        if self._username and self._password:
-            basic_auth['username'] = self._username
-            basic_auth['password'] = self._password
-
         content_type = str(self._message_serializer.version, encoding='utf-8')
+
         message = {
-            'headers': {'CONTENT-TYPE': content_type,
-                        'ACCEPT': content_type},
+            'headers': {'content-type': content_type,
+                        'accept': content_type},
             'payload': self._message_serializer.serialize_message(request_message),
-            'auth': basic_auth
+            'auth': self._auth
         }
 
         self._transport.write(message)
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/request.py b/gremlin-python/src/main/python/gremlin_python/driver/request.py
index 9993802f06..5e04f0bb85 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/request.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/request.py
@@ -20,8 +20,8 @@ import collections
 
 __author__ = 'David M. Brown ([email protected])'
 
-RequestMessageV4 = collections.namedtuple(
-    'RequestMessageV4', ['fields', 'gremlin'])
+RequestMessage = collections.namedtuple(
+    'RequestMessage', ['fields', 'gremlin'])
 
-TokensV4 = ['batchSize', 'bindings', 'g', 'gremlin', 'language',
-            'evaluationTimeout', 'materializeProperties', 'timeoutMs', 'userAgent']
+Tokens = ['batchSize', 'bindings', 'g', 'gremlin', 'language',
+          'evaluationTimeout', 'materializeProperties', 'timeoutMs', 'userAgent']
diff --git a/gremlin-python/src/main/python/tests/conftest.py b/gremlin-python/src/main/python/tests/conftest.py
index f362a1f515..c2a82aa066 100644
--- a/gremlin-python/src/main/python/tests/conftest.py
+++ b/gremlin-python/src/main/python/tests/conftest.py
@@ -24,18 +24,17 @@ import pytest
 import logging
 import queue
 
-
 from gremlin_python.driver.client import Client
 from gremlin_python.driver.connection import Connection
 from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection
 from gremlin_python.driver.protocol import GremlinServerHTTPProtocol
 from gremlin_python.driver.serializer import GraphBinarySerializersV4
 from gremlin_python.driver.aiohttp.transport import AiohttpHTTPTransport
-
+from gremlin_python.driver.auth import Auth
 
 """HTTP server testing variables"""
-gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 'http://localhost:{}/')
-gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 'https://localhost:{}/')
+gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 'http://localhost:{}/gremlin')
+gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 'https://localhost:{}/gremlin')
 anonymous_url = gremlin_server_url.format(45940)
 basic_url = gremlin_basic_auth_url.format(45941)
 
@@ -44,7 +43,6 @@ verbose_logging = False
 logging.basicConfig(format='%(asctime)s [%(levelname)8s] [%(filename)15s:%(lineno)d - %(funcName)10s()] - %(message)s',
                     level=logging.DEBUG if verbose_logging else logging.INFO)
 
-
 """
 Tests below are for the HTTP server with GraphBinaryV4
 """
@@ -52,7 +50,7 @@ Tests below are for the HTTP server with GraphBinaryV4
 def connection(request):
     protocol = GremlinServerHTTPProtocol(
         GraphBinarySerializersV4(),
-        username='stephen', password='password')
+        auth=Auth.basic('stephen', 'password'))
     executor = concurrent.futures.ThreadPoolExecutor(5)
     pool = queue.Queue()
     try:
@@ -91,7 +89,8 @@ def authenticated_client(request):
             # turn off certificate verification for testing purposes only
             ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
             ssl_opts.verify_mode = ssl.CERT_NONE
-            client = Client(basic_url, 'gmodern', username='stephen', password='password',
+            client = Client(basic_url, 'gmodern',
+                            auth=Auth.basic('stephen', 'password'),
                             transport_factory=lambda: AiohttpHTTPTransport(ssl_options=ssl_opts))
         else:
             raise ValueError("Invalid authentication option - " + request.param)
@@ -144,7 +143,7 @@ def remote_connection_crew(request):
         return remote_conn
 
 
-# TODO: revisit once auth is updated
+# TODO: revisit once testing for multiple types of auth is enabled
 @pytest.fixture(params=['basic'])
 def remote_connection_authenticated(request):
     try:
@@ -153,7 +152,7 @@ def remote_connection_authenticated(request):
             ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
             ssl_opts.verify_mode = ssl.CERT_NONE
             remote_conn = DriverRemoteConnection(basic_url, 'gmodern',
-                                                 username='stephen', password='password',
+                                                 auth=Auth.basic('stephen', 'password'),
                                                  transport_factory=lambda: AiohttpHTTPTransport(ssl_options=ssl_opts))
         else:
             raise ValueError("Invalid authentication option - " + request.param)
diff --git a/gremlin-python/src/main/python/tests/driver/test_auth.py b/gremlin-python/src/main/python/tests/driver/test_auth.py
new file mode 100644
index 0000000000..fbd516366f
--- /dev/null
+++ b/gremlin-python/src/main/python/tests/driver/test_auth.py
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from aiohttp import BasicAuth as aiohttpBasicAuth
+
+from src.main.python.gremlin_python.driver.auth import Auth
+
+
+def create_mock_request():
+    return {'headers':
+            {'content-type': 'application/vnd.graphbinary-v4.0',
+             'accept': 'application/vnd.graphbinary-v4.0'},
+            'payload': b'',
+            'url': 'https://test_url:8182/gremlin'}
+
+
+class TestAuth(object):
+
+    def test_basic_auth_request(self):
+        mock_request = create_mock_request()
+        assert 'authorization' not in mock_request['headers']
+        Auth.basic('username', 'password').apply(mock_request)
+        assert 'authorization' in mock_request['headers']
+        assert aiohttpBasicAuth('username', 'password').encode() == mock_request['headers']['authorization']
+
+    def test_sigv4_auth_request(self):
+        mock_request = create_mock_request()
+        assert 'Authorization' not in mock_request['headers']
+        assert 'X-Amz-Date' not in mock_request['headers']
+        Auth.sigv4('us-west-2', 'MOCK_ID', 'MOCK_KEY').apply(mock_request)
+        print(mock_request)
+        assert mock_request['headers']['X-Amz-Date'] is not None
+        assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=MOCK_ID')
+        assert 'us-west-2/neptune-db/aws4_request' in mock_request['headers']['Authorization']
+        assert 'Signature=' in mock_request['headers']['Authorization']
+
+
diff --git a/gremlin-python/src/main/python/tests/driver/test_client.py b/gremlin-python/src/main/python/tests/driver/test_client.py
index ce180d7bf3..41b8a87a7a 100644
--- a/gremlin-python/src/main/python/tests/driver/test_client.py
+++ b/gremlin-python/src/main/python/tests/driver/test_client.py
@@ -25,7 +25,7 @@ import pytest
 from gremlin_python.driver.client import Client
 from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection
 from gremlin_python.driver.protocol import GremlinServerError
-from gremlin_python.driver.request import RequestMessageV4
+from gremlin_python.driver.request import RequestMessage
 from gremlin_python.process.graph_traversal import __, GraphTraversalSource
 from gremlin_python.process.traversal import TraversalStrategies, Parameter
 from gremlin_python.process.strategies import OptionsStrategy
@@ -41,8 +41,7 @@ test_no_auth_url = gremlin_server_url.format(45940)
 
 
 def create_basic_request_message(traversal, source='gmodern'):
-    msg = RequestMessageV4(fields={'g': source}, gremlin=traversal.gremlin_lang.get_gremlin())
-    return msg
+    return RequestMessage(fields={'g': source}, gremlin=traversal.gremlin_lang.get_gremlin())
 
 
 def test_connection(connection):

Reply via email to