This is an automated email from the ASF dual-hosted git repository.

xiazcy pushed a commit to branch master-http
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git


The following commit(s) were added to refs/heads/master-http by this push:
     new 44dce94bed implement pluggable auth in python with basic and sigv4 as 
reference (#2731)
44dce94bed is described below

commit 44dce94bed89f0f9cf6a0eea2ab28b3982f6de51
Author: Yang Xia <[email protected]>
AuthorDate: Wed Aug 28 14:01:59 2024 -0700

    implement pluggable auth in python with basic and sigv4 as reference (#2731)
---
 CHANGELOG.asciidoc                                 |  1 +
 .../gremlin_python/driver/aiohttp/transport.py     | 19 ++++---
 .../src/main/python/gremlin_python/driver/auth.py  | 55 ++++++++++++++++++
 .../main/python/gremlin_python/driver/client.py    | 18 ++----
 .../driver/driver_remote_connection.py             | 15 ++---
 .../main/python/gremlin_python/driver/protocol.py  | 21 ++-----
 .../main/python/gremlin_python/driver/request.py   |  8 +--
 gremlin-python/src/main/python/tests/conftest.py   | 17 +++---
 .../src/main/python/tests/driver/test_auth.py      | 66 ++++++++++++++++++++++
 .../src/main/python/tests/driver/test_client.py    |  5 +-
 10 files changed, 163 insertions(+), 62 deletions(-)

diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc
index 935424cd80..c5907b4c4b 100644
--- a/CHANGELOG.asciidoc
+++ b/CHANGELOG.asciidoc
@@ -56,6 +56,7 @@ 
image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima
 * `EmbeddedRemoteConnection` will use `Gremlinlang`, not `JavaTranslator`.
 * Java `Client` will no longer support submitting traversals. 
`DriverRemoteConnection` should be used instead.
 * Removed usage of `Bytecode` from `gremlin-python`.
+* Added `auth` module in `gremlin-python` for pluggable authentication.
 * Fixed `GremlinLangScriptEngine` handling for some strategies.
 
 == TinkerPop 3.7.0 (Gremfir Master of the Pan Flute)
diff --git 
a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py 
b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py
index 2f4164f1ca..3a5d1667e2 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py
@@ -46,6 +46,7 @@ class AiohttpHTTPTransport(AbstractBaseTransport):
         self._client_session = None
         self._http_req_resp = None
         self._enable_ssl = False
+        self._url = None
 
         # Set all inner variables to parameters passed in.
         self._aiohttp_kwargs = kwargs
@@ -65,17 +66,17 @@ class AiohttpHTTPTransport(AbstractBaseTransport):
         self.close()
 
     def connect(self, url, headers=None):
+        self._url = url
         # Inner function to perform async connect.
         async def async_connect():
-            # Start client session and use it to send all HTTP requests. Base 
url is the endpoint, headers are set here
-            # Base url can only parse basic url with no path, see 
https://github.com/aio-libs/aiohttp/issues/6647
+            # Start client session and use it to send all HTTP requests. 
Headers can be set here.
             if self._enable_ssl:
                 # ssl context is established through tcp connector
                 tcp_conn = aiohttp.TCPConnector(ssl_context=self._ssl_context)
                 self._client_session = 
aiohttp.ClientSession(connector=tcp_conn,
-                                                             base_url=url, 
headers=headers, loop=self._loop)
+                                                             headers=headers, 
loop=self._loop)
             else:
-                self._client_session = aiohttp.ClientSession(base_url=url, 
headers=headers, loop=self._loop)
+                self._client_session = aiohttp.ClientSession(headers=headers, 
loop=self._loop)
 
         # Execute the async connect synchronously.
         self._loop.run_until_complete(async_connect())
@@ -83,13 +84,13 @@ class AiohttpHTTPTransport(AbstractBaseTransport):
     def write(self, message):
         # Inner function to perform async write.
         async def async_write():
-            basic_auth = None
-            # basic password authentication for https connections
+            # To pass url into message for request authentication processing
+            message.update({'url': self._url})
             if message['auth']:
-                basic_auth = aiohttp.BasicAuth(message['auth']['username'], 
message['auth']['password'])
+                message['auth'](message)
+
             async with async_timeout.timeout(self._write_timeout):
-                self._http_req_resp = await 
self._client_session.post(url="/gremlin",
-                                                                      
auth=basic_auth,
+                self._http_req_resp = await 
self._client_session.post(url=self._url,
                                                                       
data=message['payload'],
                                                                       
headers=message['headers'],
                                                                       
**self._aiohttp_kwargs)
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/auth.py 
b/gremlin-python/src/main/python/gremlin_python/driver/auth.py
new file mode 100644
index 0000000000..dd7a1610ef
--- /dev/null
+++ b/gremlin-python/src/main/python/gremlin_python/driver/auth.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""
+Pluggable request authentication for the HTTP driver.
+
+Each factory in this module returns a callable that the transport invokes with
+the outgoing request dict (``headers``, ``payload``, ``url``, ...). The callable
+adds what the server needs to authenticate the request and returns the dict.
+"""
+
+
+def basic(username, password):
+    """
+    Create an authenticator that adds an HTTP Basic ``authorization`` header.
+
+    :param username: the user name to authenticate as.
+    :param password: the password for ``username``.
+    :return: a callable that takes the request dict, updates its ``headers`` in
+        place and returns the same dict.
+    """
+    from aiohttp import BasicAuth as aiohttpBasicAuth
+
+    # The encoded credentials never change, so build the header value once
+    # rather than on every request.
+    header_value = aiohttpBasicAuth(username, password).encode()
+
+    def apply(request):
+        request['headers'].update({'authorization': header_value})
+        # dict.update() returns None, so return the request explicitly (as sigv4() does).
+        return request
+
+    return apply
+
+
+def sigv4(region, service):
+    """
+    Create an authenticator that signs requests with AWS Signature Version 4.
+
+    Credentials are resolved through boto3's default credential chain, which
+    checks the ``AWS_ACCESS_KEY_ID``/``AWS_SECRET_ACCESS_KEY``/``AWS_SESSION_TOKEN``
+    environment variables before other sources such as shared credential files
+    or instance profiles. One boto3 ``Session`` is created per authenticator and
+    reused for every request, instead of building a new session (and repeating
+    the credential lookup) each time a request is signed.
+
+    :param region: the AWS region used in the signature's credential scope.
+    :param service: the AWS service name used in the signature's credential scope.
+    :return: a callable that takes the request dict, signs its ``url`` and
+        ``payload`` as a POST, adds the resulting headers (e.g. ``Authorization``,
+        ``X-Amz-Date``) to ``headers`` and returns the same dict.
+    :raises botocore.exceptions.NoCredentialsError: from the returned callable
+        when no AWS credentials can be found.
+    """
+    from boto3 import Session
+    from botocore.auth import SigV4Auth
+    from botocore.awsrequest import AWSRequest
+
+    session = Session(region_name=region)
+
+    def apply(request):
+        credentials = session.get_credentials()
+        # Sign with one frozen snapshot so refreshable credentials cannot rotate
+        # between reading the access key, secret key and token.
+        frozen = credentials.get_frozen_credentials() if credentials is not None else None
+        sigv4_request = AWSRequest(method="POST", url=request['url'], data=request['payload'])
+        SigV4Auth(frozen, service, region).add_auth(sigv4_request)
+        request['headers'].update(sigv4_request.headers)
+        request['payload'] = sigv4_request.data
+        return request
+
+    return apply
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py 
b/gremlin-python/src/main/python/gremlin_python/driver/client.py
index 0e217acda4..16bd6f8270 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/client.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py
@@ -19,12 +19,9 @@
 import logging
 import warnings
 import queue
-import re
 from concurrent.futures import ThreadPoolExecutor
 
 from gremlin_python.driver import connection, protocol, request, serializer
-from gremlin_python.process import traversal
-from gremlin_python.driver.request import TokensV4
 
 log = logging.getLogger("gremlinpython")
 
@@ -44,13 +41,10 @@ class Client:
 
     def __init__(self, url, traversal_source, protocol_factory=None,
                  transport_factory=None, pool_size=None, max_workers=None,
-                 message_serializer=None, username="", password="", 
headers=None,
+                 message_serializer=None, auth=None, headers=None,
                  enable_user_agent_on_connect=True, **transport_kwargs):
         log.info("Creating Client with url '%s'", url)
 
-        # check via url that we are using http protocol
-        self._use_http = re.search('^http', url)
-
         self._closed = False
         self._url = url
         self._headers = headers
@@ -62,8 +56,7 @@ class Client:
             message_serializer = serializer.GraphBinarySerializersV4()
 
         self._message_serializer = message_serializer
-        self._username = username
-        self._password = password
+        self._auth = auth
 
         if transport_factory is None:
             try:
@@ -82,8 +75,7 @@ class Client:
             def protocol_factory():
                 return protocol.GremlinServerHTTPProtocol(
                     self._message_serializer,
-                    username=self._username,
-                    password=self._password)
+                    auth=self._auth)
         self._protocol_factory = protocol_factory
 
         if pool_size is None:
@@ -163,11 +155,11 @@ class Client:
 
         if isinstance(message, str):
             log.debug("fields='%s', gremlin='%s'", str(fields), str(message))
-            message = request.RequestMessageV4(fields=fields, gremlin=message)
+            message = request.RequestMessage(fields=fields, gremlin=message)
 
         conn = self._pool.get(True)
         if request_options:
-            message.fields.update({token: request_options[token] for token in 
TokensV4
+            message.fields.update({token: request_options[token] for token in 
request.Tokens
                                    if token in request_options and token != 
'bindings'})
             if 'bindings' in request_options:
                 if 'bindings' in message.fields:
diff --git 
a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
 
b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
index 24ae16becb..a48dc0087b 100644
--- 
a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
+++ 
b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
@@ -21,9 +21,8 @@ from concurrent.futures import Future
 import warnings
 
 from gremlin_python.driver import client, serializer
-from gremlin_python.driver.remote_connection import (
-    RemoteConnection, RemoteTraversal)
-from gremlin_python.driver.request import TokensV4
+from gremlin_python.driver.remote_connection import RemoteConnection, 
RemoteTraversal
+from gremlin_python.driver.request import Tokens
 
 log = logging.getLogger("gremlinpython")
 
@@ -34,7 +33,7 @@ class DriverRemoteConnection(RemoteConnection):
 
     def __init__(self, url, traversal_source="g", protocol_factory=None,
                  transport_factory=None, pool_size=None, max_workers=None,
-                 username="", password="",
+                 auth=None,
                  message_serializer=None, headers=None,
                  enable_user_agent_on_connect=True, **transport_kwargs):
         log.info("Creating DriverRemoteConnection with url '%s'", str(url))
@@ -44,8 +43,7 @@ class DriverRemoteConnection(RemoteConnection):
         self.__transport_factory = transport_factory
         self.__pool_size = pool_size
         self.__max_workers = max_workers
-        self.__username = username
-        self.__password = password
+        self.__auth = auth
         self.__message_serializer = message_serializer
         self.__headers = headers
         self.__enable_user_agent_on_connect = enable_user_agent_on_connect
@@ -59,8 +57,7 @@ class DriverRemoteConnection(RemoteConnection):
                                      pool_size=pool_size,
                                      max_workers=max_workers,
                                      message_serializer=message_serializer,
-                                     username=username,
-                                     password=password,
+                                     auth=auth,
                                      headers=headers,
                                      
enable_user_agent_on_connect=enable_user_agent_on_connect,
                                      **transport_kwargs)
@@ -121,7 +118,7 @@ class DriverRemoteConnection(RemoteConnection):
     def extract_request_options(gremlin_lang):
         request_options = {}
         for os in gremlin_lang.options_strategies:
-            request_options.update({token: os.configuration[token] for token 
in TokensV4
+            request_options.update({token: os.configuration[token] for token 
in Tokens
                                     if token in os.configuration})
         if gremlin_lang.parameters is not None and 
len(gremlin_lang.parameters) > 0:
             request_options["params"] = gremlin_lang.parameters
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py 
b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py
index a35c004a2a..b286c9dec9 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py
@@ -16,7 +16,6 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import json
 import logging
 import abc
 
@@ -54,12 +53,9 @@ class AbstractBaseProtocol(metaclass=abc.ABCMeta):
 
 class GremlinServerHTTPProtocol(AbstractBaseProtocol):
 
-    def __init__(self,
-                 message_serializer,
-                 username='', password=''):
+    def __init__(self, message_serializer, auth=None):
+        self._auth = auth
         self._message_serializer = message_serializer
-        self._username = username
-        self._password = password
         self._response_msg = {'status': {'code': 0,
                                          'message': '',
                                          'exception': ''},
@@ -71,18 +67,13 @@ class GremlinServerHTTPProtocol(AbstractBaseProtocol):
         super(GremlinServerHTTPProtocol, self).connection_made(transport)
 
     def write(self, request_message):
-
-        basic_auth = {}
-        if self._username and self._password:
-            basic_auth['username'] = self._username
-            basic_auth['password'] = self._password
-
         content_type = str(self._message_serializer.version, encoding='utf-8')
+
         message = {
-            'headers': {'CONTENT-TYPE': content_type,
-                        'ACCEPT': content_type},
+            'headers': {'content-type': content_type,
+                        'accept': content_type},
             'payload': 
self._message_serializer.serialize_message(request_message),
-            'auth': basic_auth
+            'auth': self._auth
         }
 
         self._transport.write(message)
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/request.py 
b/gremlin-python/src/main/python/gremlin_python/driver/request.py
index 9993802f06..5e04f0bb85 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/request.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/request.py
@@ -20,8 +20,8 @@ import collections
 
 __author__ = 'David M. Brown ([email protected])'
 
-RequestMessageV4 = collections.namedtuple(
-    'RequestMessageV4', ['fields', 'gremlin'])
+RequestMessage = collections.namedtuple(
+    'RequestMessage', ['fields', 'gremlin'])
 
-TokensV4 = ['batchSize', 'bindings', 'g', 'gremlin', 'language',
-            'evaluationTimeout', 'materializeProperties', 'timeoutMs', 
'userAgent']
+Tokens = ['batchSize', 'bindings', 'g', 'gremlin', 'language',
+          'evaluationTimeout', 'materializeProperties', 'timeoutMs', 
'userAgent']
diff --git a/gremlin-python/src/main/python/tests/conftest.py 
b/gremlin-python/src/main/python/tests/conftest.py
index f362a1f515..9e353c3d17 100644
--- a/gremlin-python/src/main/python/tests/conftest.py
+++ b/gremlin-python/src/main/python/tests/conftest.py
@@ -24,18 +24,17 @@ import pytest
 import logging
 import queue
 
-
 from gremlin_python.driver.client import Client
 from gremlin_python.driver.connection import Connection
 from gremlin_python.driver.driver_remote_connection import 
DriverRemoteConnection
 from gremlin_python.driver.protocol import GremlinServerHTTPProtocol
 from gremlin_python.driver.serializer import GraphBinarySerializersV4
 from gremlin_python.driver.aiohttp.transport import AiohttpHTTPTransport
-
+from gremlin_python.driver.auth import basic, sigv4
 
 """HTTP server testing variables"""
-gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 
'http://localhost:{}/')
-gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 
'https://localhost:{}/')
+gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 
'http://localhost:{}/gremlin')
+gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 
'https://localhost:{}/gremlin')
 anonymous_url = gremlin_server_url.format(45940)
 basic_url = gremlin_basic_auth_url.format(45941)
 
@@ -44,7 +43,6 @@ verbose_logging = False
 logging.basicConfig(format='%(asctime)s [%(levelname)8s] 
[%(filename)15s:%(lineno)d - %(funcName)10s()] - %(message)s',
                     level=logging.DEBUG if verbose_logging else logging.INFO)
 
-
 """
 Tests below are for the HTTP server with GraphBinaryV4
 """
@@ -52,7 +50,7 @@ Tests below are for the HTTP server with GraphBinaryV4
 def connection(request):
     protocol = GremlinServerHTTPProtocol(
         GraphBinarySerializersV4(),
-        username='stephen', password='password')
+        auth=basic('stephen', 'password'))
     executor = concurrent.futures.ThreadPoolExecutor(5)
     pool = queue.Queue()
     try:
@@ -91,7 +89,8 @@ def authenticated_client(request):
             # turn off certificate verification for testing purposes only
             ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
             ssl_opts.verify_mode = ssl.CERT_NONE
-            client = Client(basic_url, 'gmodern', username='stephen', 
password='password',
+            client = Client(basic_url, 'gmodern',
+                            auth=basic('stephen', 'password'),
                             transport_factory=lambda: 
AiohttpHTTPTransport(ssl_options=ssl_opts))
         else:
             raise ValueError("Invalid authentication option - " + 
request.param)
@@ -144,7 +143,7 @@ def remote_connection_crew(request):
         return remote_conn
 
 
-# TODO: revisit once auth is updated
+# TODO: revisit once testing for multiple types of auth is enabled
 @pytest.fixture(params=['basic'])
 def remote_connection_authenticated(request):
     try:
@@ -153,7 +152,7 @@ def remote_connection_authenticated(request):
             ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
             ssl_opts.verify_mode = ssl.CERT_NONE
             remote_conn = DriverRemoteConnection(basic_url, 'gmodern',
-                                                 username='stephen', 
password='password',
+                                                 auth=basic('stephen', 
'password'),
                                                  transport_factory=lambda: 
AiohttpHTTPTransport(ssl_options=ssl_opts))
         else:
             raise ValueError("Invalid authentication option - " + 
request.param)
diff --git a/gremlin-python/src/main/python/tests/driver/test_auth.py 
b/gremlin-python/src/main/python/tests/driver/test_auth.py
new file mode 100644
index 0000000000..584cf1ec68
--- /dev/null
+++ b/gremlin-python/src/main/python/tests/driver/test_auth.py
@@ -0,0 +1,74 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from aiohttp import BasicAuth as aiohttpBasicAuth
+from gremlin_python.driver.auth import basic, sigv4
+
+
+def create_mock_request():
+    """Build a request dict shaped like the one the HTTP transport passes to auth callables."""
+    return {'headers':
+            {'content-type': 'application/vnd.graphbinary-v4.0',
+             'accept': 'application/vnd.graphbinary-v4.0'},
+            'payload': b'',
+            'url': 'https://test_url:8182/gremlin'}
+
+
+class TestAuth(object):
+
+    def test_basic_auth_request(self):
+        mock_request = create_mock_request()
+        assert 'authorization' not in mock_request['headers']
+        basic('username', 'password')(mock_request)
+        assert 'authorization' in mock_request['headers']
+        assert aiohttpBasicAuth('username', 'password').encode() == mock_request['headers']['authorization']
+
+    def test_sigv4_auth_request(self, monkeypatch):
+        # monkeypatch restores os.environ afterwards, so these mock credentials do not
+        # leak into other tests; any ambient session token is cleared (botocore also
+        # reads the legacy AWS_SECURITY_TOKEN name).
+        monkeypatch.setenv('AWS_ACCESS_KEY_ID', 'MOCK_ID')
+        monkeypatch.setenv('AWS_SECRET_ACCESS_KEY', 'MOCK_KEY')
+        monkeypatch.delenv('AWS_SESSION_TOKEN', raising=False)
+        monkeypatch.delenv('AWS_SECURITY_TOKEN', raising=False)
+        mock_request = create_mock_request()
+        assert 'Authorization' not in mock_request['headers']
+        assert 'X-Amz-Date' not in mock_request['headers']
+        sigv4('gremlin-east-1', 'tinkerpop-sigv4')(mock_request)
+        assert mock_request['headers']['X-Amz-Date'] is not None
+        assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=MOCK_ID')
+        assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request['headers']['Authorization']
+        assert 'Signature=' in mock_request['headers']['Authorization']
+        assert 'X-Amz-Security-Token' not in mock_request['headers']
+
+    def test_sigv4_auth_request_session_token(self, monkeypatch):
+        # Set every variable this test needs rather than relying on an earlier test
+        # having populated os.environ; monkeypatch undoes the changes afterwards.
+        monkeypatch.setenv('AWS_ACCESS_KEY_ID', 'MOCK_ID')
+        monkeypatch.setenv('AWS_SECRET_ACCESS_KEY', 'MOCK_KEY')
+        monkeypatch.setenv('AWS_SESSION_TOKEN', 'MOCK_TOKEN')
+        mock_request = create_mock_request()
+        assert 'Authorization' not in mock_request['headers']
+        assert 'X-Amz-Date' not in mock_request['headers']
+        assert 'X-Amz-Security-Token' not in mock_request['headers']
+        sigv4('gremlin-east-1', 'tinkerpop-sigv4')(mock_request)
+        assert mock_request['headers']['X-Amz-Date'] is not None
+        assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=MOCK_ID')
+        assert mock_request['headers']['X-Amz-Security-Token'] == 'MOCK_TOKEN'
+        assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request['headers']['Authorization']
+        assert 'Signature=' in mock_request['headers']['Authorization']
diff --git a/gremlin-python/src/main/python/tests/driver/test_client.py 
b/gremlin-python/src/main/python/tests/driver/test_client.py
index ce180d7bf3..41b8a87a7a 100644
--- a/gremlin-python/src/main/python/tests/driver/test_client.py
+++ b/gremlin-python/src/main/python/tests/driver/test_client.py
@@ -25,7 +25,7 @@ import pytest
 from gremlin_python.driver.client import Client
 from gremlin_python.driver.driver_remote_connection import 
DriverRemoteConnection
 from gremlin_python.driver.protocol import GremlinServerError
-from gremlin_python.driver.request import RequestMessageV4
+from gremlin_python.driver.request import RequestMessage
 from gremlin_python.process.graph_traversal import __, GraphTraversalSource
 from gremlin_python.process.traversal import TraversalStrategies, Parameter
 from gremlin_python.process.strategies import OptionsStrategy
@@ -41,8 +41,7 @@ test_no_auth_url = gremlin_server_url.format(45940)
 
 
 def create_basic_request_message(traversal, source='gmodern'):
-    msg = RequestMessageV4(fields={'g': source}, 
gremlin=traversal.gremlin_lang.get_gremlin())
-    return msg
+    return RequestMessage(fields={'g': source}, 
gremlin=traversal.gremlin_lang.get_gremlin())
 
 
 def test_connection(connection):

Reply via email to