This is an automated email from the ASF dual-hosted git repository. kenhuuu pushed a commit to branch v4-py-interceptor in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit dbaa799b4808cc74e92230db8519b0bd901643c7 Author: Ken Hu <106191785+kenh...@users.noreply.github.com> AuthorDate: Fri Oct 18 09:14:42 2024 -0700 first working version on interceptor --- .../src/main/python/examples/connections.py | 4 +- .../main/python/gremlin_python/driver/client.py | 19 +-- .../driver/driver_remote_connection.py | 17 +-- .../main/python/gremlin_python/driver/protocol.py | 33 ++++-- gremlin-python/src/main/python/radish/terrain.py | 4 +- gremlin-python/src/main/python/tests/conftest.py | 39 ++++++ .../src/main/python/tests/driver/test_client.py | 11 ++ .../tests/driver/test_driver_remote_connection.py | 8 +- .../src/main/python/tests/driver/test_protocol.py | 131 +++++++++++++++++++++ .../main/python/tests/process/test_traversal.py | 3 +- .../tests/structure/io/test_functionalityio.py | 8 +- 11 files changed, 243 insertions(+), 34 deletions(-) diff --git a/gremlin-python/src/main/python/examples/connections.py b/gremlin-python/src/main/python/examples/connections.py index f268e6c27d..39f42519e6 100644 --- a/gremlin-python/src/main/python/examples/connections.py +++ b/gremlin-python/src/main/python/examples/connections.py @@ -22,7 +22,7 @@ sys.path.append("..") from gremlin_python.process.anonymous_traversal import traversal from gremlin_python.process.strategies import * from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection -from gremlin_python.driver.serializer import GraphBinarySerializersV1 +from gremlin_python.driver.serializer import GraphBinarySerializersV4 def main(): @@ -84,7 +84,7 @@ def with_configs(): rc = DriverRemoteConnection( 'ws://localhost:8182/gremlin', 'g', username="", password="", kerberized_service='', - message_serializer=GraphBinarySerializersV1(), graphson_reader=None, + response_serializer=GraphBinarySerializersV4(), graphson_reader=None, graphson_writer=None, headers=None, session=None, enable_user_agent_on_connect=True ) diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py b/gremlin-python/src/main/python/gremlin_python/driver/client.py index 73fd96588f..8f948b18a0 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/client.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py @@ -41,8 +41,10 @@ class Client: def __init__(self, url, traversal_source, protocol_factory=None, transport_factory=None, pool_size=None, max_workers=None, - message_serializer=None, auth=None, headers=None, - enable_user_agent_on_connect=True, enable_bulked_result=False, **transport_kwargs): + request_serializer=serializer.GraphBinarySerializersV4(), + response_serializer=None, interceptors=None, auth=None, + headers=None, enable_user_agent_on_connect=True, + enable_bulked_result=False, **transport_kwargs): log.info("Creating Client with url '%s'", url) self._closed = False @@ -53,11 +55,11 @@ class Client: self._traversal_source = traversal_source if "max_content_length" not in transport_kwargs: transport_kwargs["max_content_length"] = 10 * 1024 * 1024 - if message_serializer is None: - message_serializer = serializer.GraphBinarySerializersV4() + if response_serializer is None: + response_serializer = serializer.GraphBinarySerializersV4() - self._message_serializer = message_serializer self._auth = auth + self._response_serializer = response_serializer if transport_factory is None: try: @@ -75,8 +77,8 @@ class Client: if protocol_factory is None: def protocol_factory(): return protocol.GremlinServerHTTPProtocol( - self._message_serializer, - auth=self._auth) + request_serializer, response_serializer, auth=self._auth, + interceptors=interceptors) self._protocol_factory = protocol_factory if pool_size is None: @@ -95,6 +97,9 @@ class Client: @property def available_pool_size(self): return self._pool.qsize() + + def response_serializer(self): + return self._response_serializer @property def executor(self): diff --git a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py index dc8b95eb7a..b32d8f7ad0 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py @@ -33,9 +33,10 @@ class DriverRemoteConnection(RemoteConnection): def __init__(self, url, traversal_source="g", protocol_factory=None, transport_factory=None, pool_size=None, max_workers=None, - auth=None, - message_serializer=None, headers=None, - enable_user_agent_on_connect=True, enable_bulked_result=False, **transport_kwargs): + request_serializer=serializer.GraphBinarySerializersV4(), + response_serializer=None, interceptors=None, auth=None, + headers=None, enable_user_agent_on_connect=True, + enable_bulked_result=False, **transport_kwargs): log.info("Creating DriverRemoteConnection with url '%s'", str(url)) self.__url = url self.__traversal_source = traversal_source @@ -44,21 +45,21 @@ class DriverRemoteConnection(RemoteConnection): self.__pool_size = pool_size self.__max_workers = max_workers self.__auth = auth - self.__message_serializer = message_serializer self.__headers = headers self.__enable_user_agent_on_connect = enable_user_agent_on_connect self.__enable_bulked_result = enable_bulked_result self.__transport_kwargs = transport_kwargs - if message_serializer is None: - message_serializer = serializer.GraphBinarySerializersV4() + if response_serializer is None: + response_serializer = serializer.GraphBinarySerializersV4() self._client = client.Client(url, traversal_source, protocol_factory=protocol_factory, transport_factory=transport_factory, pool_size=pool_size, max_workers=max_workers, - message_serializer=message_serializer, - auth=auth, + request_serializer=request_serializer, + response_serializer=response_serializer, + interceptors=interceptors, auth=auth, headers=headers, enable_user_agent_on_connect=enable_user_agent_on_connect, enable_bulked_result=enable_bulked_result, diff --git a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py index b286c9dec9..930d2de328 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/protocol.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/protocol.py @@ -53,9 +53,19 @@ class AbstractBaseProtocol(metaclass=abc.ABCMeta): class GremlinServerHTTPProtocol(AbstractBaseProtocol): - def __init__(self, message_serializer, auth=None): + def __init__(self, request_serializer, response_serializer, + interceptors=None, auth=None): + if callable(interceptors): + interceptors = [interceptors] + elif not (isinstance(interceptors, tuple) + or isinstance(interceptors, list) + or interceptors is None): + raise TypeError("interceptors must be a callable, tuple, list or None") + self._auth = auth - self._message_serializer = message_serializer + self._interceptors = interceptors + self._request_serializer = request_serializer + self._response_serializer = response_serializer self._response_msg = {'status': {'code': 0, 'message': '', 'exception': ''}, @@ -67,15 +77,22 @@ class GremlinServerHTTPProtocol(AbstractBaseProtocol): super(GremlinServerHTTPProtocol, self).connection_made(transport) def write(self, request_message): - content_type = str(self._message_serializer.version, encoding='utf-8') - + accept = str(self._response_serializer.version, encoding='utf-8') message = { - 'headers': {'content-type': content_type, - 'accept': content_type}, - 'payload': self._message_serializer.serialize_message(request_message), + 'headers': {'accept': accept}, + 'payload': self._request_serializer.serialize_message(request_message) + if self._request_serializer is not None else request_message, 'auth': self._auth } + # The user may not want the payload to be serialized if they are using an interceptor. + if self._request_serializer is not None: + content_type = str(self._request_serializer.version, encoding='utf-8') + message['headers']['content-type'] = content_type + + for interceptor in self._interceptors or []: + message = interceptor(message) + self._transport.write(message) # data is received in chunks @@ -110,7 +127,7 @@ class GremlinServerHTTPProtocol(AbstractBaseProtocol): self._is_first_chunk = False def _decode_chunk(self, message, data_buffer, is_first_chunk): - chunk_msg = self._message_serializer.deserialize_message(data_buffer, is_first_chunk) + chunk_msg = self._response_serializer.deserialize_message(data_buffer, is_first_chunk) if 'result' in chunk_msg: msg_data = message['result']['data'] diff --git a/gremlin-python/src/main/python/radish/terrain.py b/gremlin-python/src/main/python/radish/terrain.py index cc000ef881..c54334ab0c 100644 --- a/gremlin-python/src/main/python/radish/terrain.py +++ b/gremlin-python/src/main/python/radish/terrain.py @@ -101,4 +101,6 @@ def __create_remote(server_graph_name): bulked = world.config.user_data["bulked"] == "true" if "bulked" in world.config.user_data else False - return DriverRemoteConnection(test_no_auth_url, server_graph_name, message_serializer=s, enable_bulked_result=bulked) + return DriverRemoteConnection(test_no_auth_url, server_graph_name, + request_serializer=s, response_serializer=s, + enable_bulked_result=bulked) diff --git a/gremlin-python/src/main/python/tests/conftest.py b/gremlin-python/src/main/python/tests/conftest.py index 059600f32c..46bccc0dfd 100644 --- a/gremlin-python/src/main/python/tests/conftest.py +++ b/gremlin-python/src/main/python/tests/conftest.py @@ -18,6 +18,7 @@ # import concurrent.futures +from json import dumps import os import ssl import pytest @@ -181,3 +182,41 @@ def invalid_alias_remote_connection(request): request.addfinalizer(fin) return remote_conn + + +@pytest.fixture() +def remote_connection_with_interceptor(request): + try: + remote_conn = DriverRemoteConnection(anonymous_url, 'gmodern', + request_serializer=None, + interceptors=json_interceptor) + except OSError: + pytest.skip('Gremlin Server is not running') + else: + def fin(): + remote_conn.close() + + request.addfinalizer(fin) + return remote_conn + + +@pytest.fixture() +def client_with_interceptor(request): + try: + client = Client(anonymous_url, 'gmodern', request_serializer=None, + response_serializer=GraphBinarySerializersV4(), + interceptors=json_interceptor) + except OSError: + pytest.skip('Gremlin Server is not running') + else: + def fin(): + client.close() + + request.addfinalizer(fin) + return client + + +def json_interceptor(request): + request['headers']['content-type'] = "application/json" + request['payload'] = dumps({"gremlin": "g.inject(2)", "g": "g"}) + return request diff --git a/gremlin-python/src/main/python/tests/driver/test_client.py b/gremlin-python/src/main/python/tests/driver/test_client.py index 01fdd95916..5f5431651b 100644 --- a/gremlin-python/src/main/python/tests/driver/test_client.py +++ b/gremlin-python/src/main/python/tests/driver/test_client.py @@ -26,6 +26,7 @@ from gremlin_python.driver.client import Client from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection from gremlin_python.driver.protocol import GremlinServerError from gremlin_python.driver.request import RequestMessage +from gremlin_python.driver.serializer import GraphBinarySerializersV4 from gremlin_python.process.graph_traversal import __, GraphTraversalSource from gremlin_python.process.traversal import TraversalStrategies, Parameter from gremlin_python.process.strategies import OptionsStrategy @@ -551,3 +552,13 @@ def test_client_custom_invalid_request_id_graphbinary_bytecode(client): def test_client_custom_valid_request_id_bytecode(client): query = GraphTraversalSource(Graph(), TraversalStrategies()).V().bytecode assert len(client.submit(query).all().result()) == 6 + +def test_response_serializer_never_None(): + client = Client('url', 'g', response_serializer=None) + resp_ser = client.response_serializer() + assert resp_ser is not None + + +def test_serializer_and_interceptor_forwarded(client_with_interceptor): + result = client_with_interceptor.submit("g.inject(1)").next() + assert [2] == result # interceptor changes request to g.inject(2) diff --git a/gremlin-python/src/main/python/tests/driver/test_driver_remote_connection.py b/gremlin-python/src/main/python/tests/driver/test_driver_remote_connection.py index 848ce8aa20..905e7bea29 100644 --- a/gremlin-python/src/main/python/tests/driver/test_driver_remote_connection.py +++ b/gremlin-python/src/main/python/tests/driver/test_driver_remote_connection.py @@ -40,8 +40,7 @@ class TestDriverRemoteConnection(object): # this is a temporary test for basic graphSONV4 connectivity, once all types are implemented, enable graphSON testing # in conftest.py and remove this def test_graphSONV4_temp(self): - remote_conn = DriverRemoteConnection(test_no_auth_url, 'gmodern', - message_serializer=serializer.GraphSONSerializerV4()) + remote_conn = DriverRemoteConnection(test_no_auth_url, 'gmodern') g = traversal().with_(remote_conn) assert long(6) == g.V().count().to_list()[0] # # @@ -249,3 +248,8 @@ class TestDriverRemoteConnection(object): g = traversal().with_(remote_connection_authenticated) assert long(6) == g.V().count().to_list()[0] + + def test_forwards_interceptor_serializers(self, remote_connection_with_interceptor): + g = traversal().with_(remote_connection_with_interceptor) + result = g.inject(1).next() + assert 2 == result # interceptor changes request to g.inject(2) diff --git a/gremlin-python/src/main/python/tests/driver/test_protocol.py b/gremlin-python/src/main/python/tests/driver/test_protocol.py new file mode 100644 index 0000000000..3efd12975f --- /dev/null +++ b/gremlin-python/src/main/python/tests/driver/test_protocol.py @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from gremlin_python.driver.protocol import GremlinServerHTTPProtocol +from gremlin_python.driver.serializer import GraphBinarySerializersV4 +from gremlin_python.driver.transport import AbstractBaseTransport +from gremlin_python.driver.request import RequestMessage + +class MockHTTPTransport(AbstractBaseTransport): + def connect(self, url, headers=None): + pass + + def write(self, message): + self._message = message + + def get_write(self): + return self._message + + def read(self): + pass + + def close(self): + pass + + def closed(self): + pass + +def test_none_request_serializer_valid(): + protocol = GremlinServerHTTPProtocol(None, GraphBinarySerializersV4(), interceptors=None) + mock_transport = MockHTTPTransport() + protocol.connection_made(mock_transport) + + message = RequestMessage(fields={}, gremlin="g.V()") + protocol.write(message) + written = mock_transport.get_write() + + assert written["payload"] == message + assert 'content-type' not in written["headers"] + +def test_graphbinary_request_serializer_serializes_payload(): + gb_ser = GraphBinarySerializersV4() + protocol = GremlinServerHTTPProtocol(gb_ser, gb_ser) + mock_transport = MockHTTPTransport() + protocol.connection_made(mock_transport) + + message = RequestMessage(fields={}, gremlin="g.V()") + protocol.write(message) + written = mock_transport.get_write() + + assert written["payload"] == gb_ser.serialize_message(message) + assert written["headers"]['content-type'] == str(gb_ser.version, encoding='utf-8') + +def test_interceptor_allows_tuple_and_list(): + try: + tuple = GremlinServerHTTPProtocol(None, None, interceptors=(lambda req: req)) + list = GremlinServerHTTPProtocol(None, None, interceptors=[lambda req: req]) + assert True + except: + assert False + +def test_interceptor_doesnt_allow_any_type(): + try: + protocol = GremlinServerHTTPProtocol(None, None, interceptors=1) + assert False + except TypeError: + assert True + +def test_single_interceptor_runs(): + changed_req = RequestMessage(fields={}, gremlin="changed") + def interceptor(request): + request['payload'] = changed_req + return request + + protocol = GremlinServerHTTPProtocol(None, GraphBinarySerializersV4(), + interceptors=interceptor) + mock_transport = MockHTTPTransport() + protocol.connection_made(mock_transport) + + message = RequestMessage(fields={}, gremlin="g.V()") + protocol.write(message) + written = mock_transport.get_write() + + assert written['payload'] == changed_req + +def test_interceptor_works_with_request_serializer(): + gb_ser = GraphBinarySerializersV4() + message = RequestMessage(fields={}, gremlin="g.E()") + + def assert_inteceptor(request): + assert request['payload'] == gb_ser.serialize_message(message) + request['payload'] = "changed" + return request + + protocol = GremlinServerHTTPProtocol(gb_ser, gb_ser, interceptors=assert_inteceptor) + mock_transport = MockHTTPTransport() + protocol.connection_made(mock_transport) + + protocol.write(message) + written = mock_transport.get_write() + + assert written["payload"] == "changed" + +def test_interceptors_run_sequentially(): + def three(request): request['payload'].gremlin.append(3); return request + def two(request): request['payload'].gremlin.append(2); return request + def one(request): request['payload'].gremlin.append(1); return request + protocol = GremlinServerHTTPProtocol(None, GraphBinarySerializersV4(), + interceptors=[one, two, three]) + mock_transport = MockHTTPTransport() + protocol.connection_made(mock_transport) + + message = RequestMessage(fields={}, gremlin=[]) + protocol.write(message) + written = mock_transport.get_write() + + assert written["payload"].gremlin == [1, 2, 3] diff --git a/gremlin-python/src/main/python/tests/process/test_traversal.py b/gremlin-python/src/main/python/tests/process/test_traversal.py index c7f64a5990..a69b8100fe 100644 --- a/gremlin-python/src/main/python/tests/process/test_traversal.py +++ b/gremlin-python/src/main/python/tests/process/test_traversal.py @@ -359,8 +359,7 @@ class TestTraversal(object): def create_connection_to_gtx(): - return DriverRemoteConnection(anonymous_url, 'gtx', - message_serializer=serializer.GraphBinarySerializersV4()) + return DriverRemoteConnection(anonymous_url, 'gtx') def add_node_validate_transaction_state(g, g_add_to, g_start_count, g_add_to_start_count, tx_verify_list): diff --git a/gremlin-python/src/main/python/tests/structure/io/test_functionalityio.py b/gremlin-python/src/main/python/tests/structure/io/test_functionalityio.py index 8dbd0a5af5..be74a28710 100644 --- a/gremlin-python/src/main/python/tests/structure/io/test_functionalityio.py +++ b/gremlin-python/src/main/python/tests/structure/io/test_functionalityio.py @@ -110,7 +110,7 @@ def test_uuid(remote_connection): def test_short(remote_connection): - if not isinstance(remote_connection._client._message_serializer, GraphBinarySerializersV4): + if not isinstance(remote_connection._client.response_serializer(), GraphBinarySerializersV4): return g = traversal().with_(remote_connection) @@ -126,7 +126,7 @@ def test_short(remote_connection): def test_bigint_positive(remote_connection): - if not isinstance(remote_connection._client._message_serializer, GraphBinarySerializersV4): + if not isinstance(remote_connection._client.response_serializer(), GraphBinarySerializersV4): return g = traversal().with_(remote_connection) @@ -142,7 +142,7 @@ def test_bigint_positive(remote_connection): def test_bigint_negative(remote_connection): - if not isinstance(remote_connection._client._message_serializer, GraphBinarySerializersV4): + if not isinstance(remote_connection._client.response_serializer(), GraphBinarySerializersV4): return g = traversal().with_(remote_connection) @@ -159,7 +159,7 @@ def test_bigint_negative(remote_connection): @pytest.mark.skip(reason="BigDecimal implementation needs revisiting") def test_bigdecimal(remote_connection): - if not isinstance(remote_connection._client._message_serializer, GraphBinarySerializersV4): + if not isinstance(remote_connection._client.response_serializer(), GraphBinarySerializersV4): return g = traversal().with_(remote_connection)