This is an automated email from the ASF dual-hosted git repository.

kenhuuu pushed a commit to branch v4-py-interceptor
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git

commit dbaa799b4808cc74e92230db8519b0bd901643c7
Author: Ken Hu <106191785+kenh...@users.noreply.github.com>
AuthorDate: Fri Oct 18 09:14:42 2024 -0700

    First working version of the interceptor
---
 .../src/main/python/examples/connections.py        |   4 +-
 .../main/python/gremlin_python/driver/client.py    |  19 +--
 .../driver/driver_remote_connection.py             |  17 +--
 .../main/python/gremlin_python/driver/protocol.py  |  33 ++++--
 gremlin-python/src/main/python/radish/terrain.py   |   4 +-
 gremlin-python/src/main/python/tests/conftest.py   |  39 ++++++
 .../src/main/python/tests/driver/test_client.py    |  11 ++
 .../tests/driver/test_driver_remote_connection.py  |   8 +-
 .../src/main/python/tests/driver/test_protocol.py  | 131 +++++++++++++++++++++
 .../main/python/tests/process/test_traversal.py    |   3 +-
 .../tests/structure/io/test_functionalityio.py     |   8 +-
 11 files changed, 243 insertions(+), 34 deletions(-)

diff --git a/gremlin-python/src/main/python/examples/connections.py 
b/gremlin-python/src/main/python/examples/connections.py
index f268e6c27d..39f42519e6 100644
--- a/gremlin-python/src/main/python/examples/connections.py
+++ b/gremlin-python/src/main/python/examples/connections.py
@@ -22,7 +22,7 @@ sys.path.append("..")
 from gremlin_python.process.anonymous_traversal import traversal
 from gremlin_python.process.strategies import *
 from gremlin_python.driver.driver_remote_connection import 
DriverRemoteConnection
-from gremlin_python.driver.serializer import GraphBinarySerializersV1
+from gremlin_python.driver.serializer import GraphBinarySerializersV4
 
 
 def main():
@@ -84,7 +84,7 @@ def with_configs():
     rc = DriverRemoteConnection(
         'ws://localhost:8182/gremlin', 'g',
         username="", password="", kerberized_service='',
-        message_serializer=GraphBinarySerializersV1(), graphson_reader=None,
+        response_serializer=GraphBinarySerializersV4(), graphson_reader=None,
         graphson_writer=None, headers=None, session=None,
         enable_user_agent_on_connect=True
     )
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py 
b/gremlin-python/src/main/python/gremlin_python/driver/client.py
index 73fd96588f..8f948b18a0 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/client.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py
@@ -41,8 +41,10 @@ class Client:
 
     def __init__(self, url, traversal_source, protocol_factory=None,
                  transport_factory=None, pool_size=None, max_workers=None,
-                 message_serializer=None, auth=None, headers=None,
-                 enable_user_agent_on_connect=True, 
enable_bulked_result=False, **transport_kwargs):
+                 request_serializer=serializer.GraphBinarySerializersV4(),
+                 response_serializer=None, interceptors=None, auth=None,
+                 headers=None, enable_user_agent_on_connect=True,
+                 enable_bulked_result=False, **transport_kwargs):
         log.info("Creating Client with url '%s'", url)
 
         self._closed = False
@@ -53,11 +55,11 @@ class Client:
         self._traversal_source = traversal_source
         if "max_content_length" not in transport_kwargs:
             transport_kwargs["max_content_length"] = 10 * 1024 * 1024
-        if message_serializer is None:
-            message_serializer = serializer.GraphBinarySerializersV4()
+        if response_serializer is None:
+            response_serializer = serializer.GraphBinarySerializersV4()
 
-        self._message_serializer = message_serializer
         self._auth = auth
+        self._response_serializer = response_serializer
 
         if transport_factory is None:
             try:
@@ -75,8 +77,8 @@ class Client:
         if protocol_factory is None:
             def protocol_factory():
                 return protocol.GremlinServerHTTPProtocol(
-                    self._message_serializer,
-                    auth=self._auth)
+                    request_serializer, response_serializer, auth=self._auth,
+                    interceptors=interceptors)
         self._protocol_factory = protocol_factory
 
         if pool_size is None:
@@ -95,6 +97,9 @@ class Client:
     @property
     def available_pool_size(self):
         return self._pool.qsize()
+
+    def response_serializer(self):  # never None: defaults to GraphBinarySerializersV4
+        return self._response_serializer
 
     @property
     def executor(self):
diff --git 
a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
 
b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
index dc8b95eb7a..b32d8f7ad0 100644
--- 
a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
+++ 
b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py
@@ -33,9 +33,10 @@ class DriverRemoteConnection(RemoteConnection):
 
     def __init__(self, url, traversal_source="g", protocol_factory=None,
                  transport_factory=None, pool_size=None, max_workers=None,
-                 auth=None,
-                 message_serializer=None, headers=None,
-                 enable_user_agent_on_connect=True, 
enable_bulked_result=False, **transport_kwargs):
+                 request_serializer=serializer.GraphBinarySerializersV4(),
+                 response_serializer=None, interceptors=None, auth=None,
+                 headers=None, enable_user_agent_on_connect=True,
+                 enable_bulked_result=False, **transport_kwargs):
         log.info("Creating DriverRemoteConnection with url '%s'", str(url))
         self.__url = url
         self.__traversal_source = traversal_source
@@ -44,21 +45,21 @@ class DriverRemoteConnection(RemoteConnection):
         self.__pool_size = pool_size
         self.__max_workers = max_workers
         self.__auth = auth
-        self.__message_serializer = message_serializer
         self.__headers = headers
         self.__enable_user_agent_on_connect = enable_user_agent_on_connect
         self.__enable_bulked_result = enable_bulked_result
         self.__transport_kwargs = transport_kwargs
 
-        if message_serializer is None:
-            message_serializer = serializer.GraphBinarySerializersV4()
+        if response_serializer is None:
+            response_serializer = serializer.GraphBinarySerializersV4()
         self._client = client.Client(url, traversal_source,
                                      protocol_factory=protocol_factory,
                                      transport_factory=transport_factory,
                                      pool_size=pool_size,
                                      max_workers=max_workers,
-                                     message_serializer=message_serializer,
-                                     auth=auth,
+                                     request_serializer=request_serializer,
+                                     response_serializer=response_serializer,
+                                     interceptors=interceptors, auth=auth,
                                      headers=headers,
                                      
enable_user_agent_on_connect=enable_user_agent_on_connect,
                                      enable_bulked_result=enable_bulked_result,
diff --git a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py 
b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py
index b286c9dec9..930d2de328 100644
--- a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py
+++ b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py
@@ -53,9 +53,19 @@ class AbstractBaseProtocol(metaclass=abc.ABCMeta):
 
 class GremlinServerHTTPProtocol(AbstractBaseProtocol):
 
-    def __init__(self, message_serializer, auth=None):
+    def __init__(self, request_serializer, response_serializer,
+                 interceptors=None, auth=None):
+        if callable(interceptors):
+            interceptors = [interceptors]
+        elif not (isinstance(interceptors, tuple)
+                  or isinstance(interceptors, list)
+                  or interceptors is None):
+            raise TypeError("interceptors must be a callable, tuple, list or 
None")
+
         self._auth = auth
-        self._message_serializer = message_serializer
+        self._interceptors = interceptors
+        self._request_serializer = request_serializer
+        self._response_serializer = response_serializer
         self._response_msg = {'status': {'code': 0,
                                          'message': '',
                                          'exception': ''},
@@ -67,15 +77,22 @@ class GremlinServerHTTPProtocol(AbstractBaseProtocol):
         super(GremlinServerHTTPProtocol, self).connection_made(transport)
 
     def write(self, request_message):
-        content_type = str(self._message_serializer.version, encoding='utf-8')
-
+        accept = str(self._response_serializer.version, encoding='utf-8')
         message = {
-            'headers': {'content-type': content_type,
-                        'accept': content_type},
-            'payload': 
self._message_serializer.serialize_message(request_message),
+            'headers': {'accept': accept},
+            'payload': 
self._request_serializer.serialize_message(request_message)
+                if self._request_serializer is not None else request_message,
             'auth': self._auth
         }
 
+        # The user may not want the payload to be serialized if they are using 
an interceptor.
+        if self._request_serializer is not None:
+            content_type = str(self._request_serializer.version, 
encoding='utf-8')
+            message['headers']['content-type'] = content_type
+
+        for interceptor in self._interceptors or []:
+            message = interceptor(message)
+
         self._transport.write(message)
 
     # data is received in chunks
@@ -110,7 +127,7 @@ class GremlinServerHTTPProtocol(AbstractBaseProtocol):
             self._is_first_chunk = False
 
     def _decode_chunk(self, message, data_buffer, is_first_chunk):
-        chunk_msg = self._message_serializer.deserialize_message(data_buffer, 
is_first_chunk)
+        chunk_msg = self._response_serializer.deserialize_message(data_buffer, 
is_first_chunk)
 
         if 'result' in chunk_msg:
             msg_data = message['result']['data']
diff --git a/gremlin-python/src/main/python/radish/terrain.py 
b/gremlin-python/src/main/python/radish/terrain.py
index cc000ef881..c54334ab0c 100644
--- a/gremlin-python/src/main/python/radish/terrain.py
+++ b/gremlin-python/src/main/python/radish/terrain.py
@@ -101,4 +101,6 @@ def __create_remote(server_graph_name):
 
     bulked = world.config.user_data["bulked"] == "true" if "bulked" in 
world.config.user_data else False
 
-    return DriverRemoteConnection(test_no_auth_url, server_graph_name, 
message_serializer=s, enable_bulked_result=bulked)
+    return DriverRemoteConnection(test_no_auth_url, server_graph_name,
+                                  request_serializer=s, response_serializer=s,
+                                  enable_bulked_result=bulked)
diff --git a/gremlin-python/src/main/python/tests/conftest.py 
b/gremlin-python/src/main/python/tests/conftest.py
index 059600f32c..46bccc0dfd 100644
--- a/gremlin-python/src/main/python/tests/conftest.py
+++ b/gremlin-python/src/main/python/tests/conftest.py
@@ -18,6 +18,7 @@
 #
 
 import concurrent.futures
+from json import dumps
 import os
 import ssl
 import pytest
@@ -181,3 +182,41 @@ def invalid_alias_remote_connection(request):
 
         request.addfinalizer(fin)
         return remote_conn
+
+
+@pytest.fixture()
+def remote_connection_with_interceptor(request):
+    try:
+        remote_conn = DriverRemoteConnection(anonymous_url, 'gmodern', request_serializer=None,
+                                             interceptors=json_interceptor)
+    except OSError:
+        pytest.skip('Gremlin Server is not running')
+    else:
+        def fin():
+            remote_conn.close()
+
+        request.addfinalizer(fin)
+        return remote_conn
+
+
+@pytest.fixture()
+def client_with_interceptor(request):
+    try:
+        client = Client(anonymous_url, 'gmodern', interceptors=json_interceptor,
+                        request_serializer=None, response_serializer=GraphBinarySerializersV4())
+    except OSError:
+        pytest.skip('Gremlin Server is not running')
+    else:
+        def fin():
+            client.close()
+
+        request.addfinalizer(fin)
+        return client
+
+
+def json_interceptor(request):
+    """Replace the outgoing request body with a JSON-encoded ``g.inject(2)`` script."""
+    # The submitted query is discarded, so a result of 2 proves the interceptor ran.
+    request['headers']['content-type'] = "application/json"
+    request['payload'] = dumps({"gremlin": "g.inject(2)", "g": "g"})
+    return request
diff --git a/gremlin-python/src/main/python/tests/driver/test_client.py 
b/gremlin-python/src/main/python/tests/driver/test_client.py
index 01fdd95916..5f5431651b 100644
--- a/gremlin-python/src/main/python/tests/driver/test_client.py
+++ b/gremlin-python/src/main/python/tests/driver/test_client.py
@@ -26,6 +26,7 @@ from gremlin_python.driver.client import Client
 from gremlin_python.driver.driver_remote_connection import 
DriverRemoteConnection
 from gremlin_python.driver.protocol import GremlinServerError
 from gremlin_python.driver.request import RequestMessage
+from gremlin_python.driver.serializer import GraphBinarySerializersV4
 from gremlin_python.process.graph_traversal import __, GraphTraversalSource
 from gremlin_python.process.traversal import TraversalStrategies, Parameter
 from gremlin_python.process.strategies import OptionsStrategy
@@ -551,3 +552,13 @@ def 
test_client_custom_invalid_request_id_graphbinary_bytecode(client):
 def test_client_custom_valid_request_id_bytecode(client):
     query = GraphTraversalSource(Graph(), TraversalStrategies()).V().bytecode
     assert len(client.submit(query).all().result()) == 6
+
+def test_response_serializer_never_None():
+    from contextlib import closing  # ensures client.close() runs even if the assert fails
+    with closing(Client('url', 'g', response_serializer=None)) as client:
+        assert client.response_serializer() is not None
+
+
+def test_serializer_and_interceptor_forwarded(client_with_interceptor):
+    result = client_with_interceptor.submit("g.inject(1)").next()
+    assert [2] == result  # json_interceptor rewrites the request to g.inject(2)
diff --git 
a/gremlin-python/src/main/python/tests/driver/test_driver_remote_connection.py 
b/gremlin-python/src/main/python/tests/driver/test_driver_remote_connection.py
index 848ce8aa20..905e7bea29 100644
--- 
a/gremlin-python/src/main/python/tests/driver/test_driver_remote_connection.py
+++ 
b/gremlin-python/src/main/python/tests/driver/test_driver_remote_connection.py
@@ -40,8 +40,7 @@ class TestDriverRemoteConnection(object):
     # this is a temporary test for basic graphSONV4 connectivity, once all 
types are implemented, enable graphSON testing
     # in conftest.py and remove this
     def test_graphSONV4_temp(self):
-        remote_conn = DriverRemoteConnection(test_no_auth_url, 'gmodern',
-                                             
message_serializer=serializer.GraphSONSerializerV4())
+        remote_conn = DriverRemoteConnection(test_no_auth_url, 'gmodern')
         g = traversal().with_(remote_conn)
         assert long(6) == g.V().count().to_list()[0]
         # #
@@ -249,3 +248,8 @@ class TestDriverRemoteConnection(object):
         g = traversal().with_(remote_connection_authenticated)
 
         assert long(6) == g.V().count().to_list()[0]
+
+    def test_forwards_interceptor_serializers(self, remote_connection_with_interceptor):
+        g = traversal().with_(remote_connection_with_interceptor)
+        # json_interceptor replaces the submitted traversal with the script g.inject(2).
+        assert g.inject(1).next() == 2
diff --git a/gremlin-python/src/main/python/tests/driver/test_protocol.py 
b/gremlin-python/src/main/python/tests/driver/test_protocol.py
new file mode 100644
index 0000000000..3efd12975f
--- /dev/null
+++ b/gremlin-python/src/main/python/tests/driver/test_protocol.py
@@ -0,0 +1,131 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import pytest
+
+from gremlin_python.driver.protocol import GremlinServerHTTPProtocol
+from gremlin_python.driver.serializer import GraphBinarySerializersV4
+from gremlin_python.driver.transport import AbstractBaseTransport
+from gremlin_python.driver.request import RequestMessage
+
+
+class MockHTTPTransport(AbstractBaseTransport):
+    """In-memory transport that records the last message passed to write()."""
+
+    _message = None
+
+    def connect(self, url, headers=None):
+        pass
+
+    def write(self, message):
+        self._message = message
+
+    def get_write(self):
+        return self._message
+
+    def read(self):
+        pass
+
+    def close(self):
+        pass
+
+    def closed(self):
+        pass
+
+
+def _write(protocol, message):
+    """Run protocol.write(message) over a new MockHTTPTransport; return what it received."""
+    transport = MockHTTPTransport()
+    protocol.connection_made(transport)
+    protocol.write(message)
+    return transport.get_write()
+
+
+def test_none_request_serializer_valid():
+    protocol = GremlinServerHTTPProtocol(None, GraphBinarySerializersV4(), interceptors=None)
+    message = RequestMessage(fields={}, gremlin="g.V()")
+    written = _write(protocol, message)
+
+    # Without a request serializer the RequestMessage is forwarded as-is, with no content-type.
+    assert written["payload"] == message
+    assert 'content-type' not in written["headers"]
+
+
+def test_graphbinary_request_serializer_serializes_payload():
+    gb_ser = GraphBinarySerializersV4()
+    protocol = GremlinServerHTTPProtocol(gb_ser, gb_ser)
+    message = RequestMessage(fields={}, gremlin="g.V()")
+    written = _write(protocol, message)
+
+    assert written["payload"] == gb_ser.serialize_message(message)
+    assert written["headers"]['content-type'] == str(gb_ser.version, encoding='utf-8')
+
+
+def test_interceptor_allows_tuple_and_list():
+    # The trailing comma matters: (f) is just f, whereas (f,) is a one-element tuple.
+    for chain in ((lambda req: req,), [lambda req: req]):
+        protocol = GremlinServerHTTPProtocol(None, GraphBinarySerializersV4(), interceptors=chain)
+        message = RequestMessage(fields={}, gremlin="g.V()")
+        assert _write(protocol, message)["payload"] == message
+
+
+def test_interceptor_doesnt_allow_any_type():
+    with pytest.raises(TypeError):
+        GremlinServerHTTPProtocol(None, None, interceptors=1)
+
+
+def test_single_interceptor_runs():
+    changed_req = RequestMessage(fields={}, gremlin="changed")
+
+    def interceptor(request):
+        request['payload'] = changed_req
+        return request
+
+    protocol = GremlinServerHTTPProtocol(None, GraphBinarySerializersV4(),
+                                         interceptors=interceptor)
+    written = _write(protocol, RequestMessage(fields={}, gremlin="g.V()"))
+
+    assert written['payload'] == changed_req
+
+
+def test_interceptor_works_with_request_serializer():
+    gb_ser = GraphBinarySerializersV4()
+    message = RequestMessage(fields={}, gremlin="g.E()")
+
+    def assert_interceptor(request):
+        # Interceptors run after serialization, so they see the serialized payload.
+        assert request['payload'] == gb_ser.serialize_message(message)
+        request['payload'] = "changed"
+        return request
+
+    protocol = GremlinServerHTTPProtocol(gb_ser, gb_ser, interceptors=assert_interceptor)
+    written = _write(protocol, message)
+
+    assert written["payload"] == "changed"
+
+
+def test_interceptors_run_sequentially():
+    def three(request): request['payload'].gremlin.append(3); return request
+    def two(request): request['payload'].gremlin.append(2); return request
+    def one(request): request['payload'].gremlin.append(1); return request
+    protocol = GremlinServerHTTPProtocol(None, GraphBinarySerializersV4(),
+                                         interceptors=[one, two, three])
+    written = _write(protocol, RequestMessage(fields={}, gremlin=[]))
+
+    # Interceptors run in list order, not in the order they were defined.
+    assert written["payload"].gremlin == [1, 2, 3]
diff --git a/gremlin-python/src/main/python/tests/process/test_traversal.py 
b/gremlin-python/src/main/python/tests/process/test_traversal.py
index c7f64a5990..a69b8100fe 100644
--- a/gremlin-python/src/main/python/tests/process/test_traversal.py
+++ b/gremlin-python/src/main/python/tests/process/test_traversal.py
@@ -359,8 +359,7 @@ class TestTraversal(object):
 
 
 def create_connection_to_gtx():
-    return DriverRemoteConnection(anonymous_url, 'gtx',
-                                  
message_serializer=serializer.GraphBinarySerializersV4())
+    return DriverRemoteConnection(anonymous_url, 'gtx')
 
 
 def add_node_validate_transaction_state(g, g_add_to, g_start_count, 
g_add_to_start_count, tx_verify_list):
diff --git 
a/gremlin-python/src/main/python/tests/structure/io/test_functionalityio.py 
b/gremlin-python/src/main/python/tests/structure/io/test_functionalityio.py
index 8dbd0a5af5..be74a28710 100644
--- a/gremlin-python/src/main/python/tests/structure/io/test_functionalityio.py
+++ b/gremlin-python/src/main/python/tests/structure/io/test_functionalityio.py
@@ -110,7 +110,7 @@ def test_uuid(remote_connection):
 
 
 def test_short(remote_connection):
-    if not isinstance(remote_connection._client._message_serializer, 
GraphBinarySerializersV4):
+    if not isinstance(remote_connection._client.response_serializer(), 
GraphBinarySerializersV4):
         return
 
     g = traversal().with_(remote_connection)
@@ -126,7 +126,7 @@ def test_short(remote_connection):
 
 
 def test_bigint_positive(remote_connection):
-    if not isinstance(remote_connection._client._message_serializer, 
GraphBinarySerializersV4):
+    if not isinstance(remote_connection._client.response_serializer(), 
GraphBinarySerializersV4):
         return
 
     g = traversal().with_(remote_connection)
@@ -142,7 +142,7 @@ def test_bigint_positive(remote_connection):
 
 
 def test_bigint_negative(remote_connection):
-    if not isinstance(remote_connection._client._message_serializer, 
GraphBinarySerializersV4):
+    if not isinstance(remote_connection._client.response_serializer(), 
GraphBinarySerializersV4):
         return
 
     g = traversal().with_(remote_connection)
@@ -159,7 +159,7 @@ def test_bigint_negative(remote_connection):
 
 @pytest.mark.skip(reason="BigDecimal implementation needs revisiting")
 def test_bigdecimal(remote_connection):
-    if not isinstance(remote_connection._client._message_serializer, 
GraphBinarySerializersV4):
+    if not isinstance(remote_connection._client.response_serializer(), 
GraphBinarySerializersV4):
         return
 
     g = traversal().with_(remote_connection)

Reply via email to