Author: cutting
Date: Tue Jan 5 18:48:04 2010
New Revision: 896176
URL: http://svn.apache.org/viewvc?rev=896176&view=rev
Log:
Rework Python RPC. Contributed by Jeff Hammerbacher.
Added:
hadoop/avro/trunk/src/py/avro/ipc.py
hadoop/avro/trunk/src/py/avro/protocol.py
hadoop/avro/trunk/src/test/py/sample_ipc_client.py
hadoop/avro/trunk/src/test/py/sample_ipc_server.py
hadoop/avro/trunk/src/test/py/test_protocol.py
Modified:
hadoop/avro/trunk/CHANGES.txt
hadoop/avro/trunk/src/py/avro/__init__.py
Modified: hadoop/avro/trunk/CHANGES.txt
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/CHANGES.txt?rev=896176&r1=896175&r2=896176&view=diff
==============================================================================
--- hadoop/avro/trunk/CHANGES.txt (original)
+++ hadoop/avro/trunk/CHANGES.txt Tue Jan 5 18:48:04 2010
@@ -166,6 +166,8 @@
AVRO-219. Rework Python API. (Jeff Hammerbacher via cutting)
+ AVRO-264. Rework Python RPC. (Jeff Hammerbacher via cutting)
+
OPTIMIZATIONS
AVRO-172. More efficient schema processing (massie)
Modified: hadoop/avro/trunk/src/py/avro/__init__.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/__init__.py?rev=896176&r1=896175&r2=896176&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/__init__.py (original)
+++ hadoop/avro/trunk/src/py/avro/__init__.py Tue Jan 5 18:48:04 2010
@@ -14,5 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ['schema', 'io', 'datafile']
+__all__ = ['schema', 'io', 'datafile', 'protocol', 'ipc']
Added: hadoop/avro/trunk/src/py/avro/ipc.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/ipc.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/py/avro/ipc.py (added)
+++ hadoop/avro/trunk/src/py/avro/ipc.py Tue Jan 5 18:48:04 2010
@@ -0,0 +1,461 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Support for inter-process calls.
+"""
+import cStringIO
+import struct
+import socket
+from avro import io
+from avro import protocol
+from avro import schema
+
+#
+# Constants
+#
+
+# Schema of the handshake record the client side writes ahead of every call
+# request (see Requestor.write_handshake_request) and the server side reads
+# (see Responder.process_handshake).
+HANDSHAKE_REQUEST_SCHEMA = schema.parse("""\
+{
+ "type": "record",
+ "name": "HandshakeRequest", "namespace":"org.apache.avro.ipc",
+ "fields": [
+ {"name": "clientHash",
+ "type": {"type": "fixed", "name": "MD5", "size": 16}},
+ {"name": "clientProtocol", "type": ["null", "string"]},
+ {"name": "serverHash", "type": ["null", "MD5"]},
+ {"name": "meta", "type": ["null", {"type": "map", "values": "bytes"}]}
+ ]
+}""")
+
+# Schema of the handshake record the server side writes ahead of every call
+# response; 'match' tells the client whether a call response follows.
+HANDSHAKE_RESPONSE_SCHEMA = schema.parse("""\
+{
+ "type": "record",
+ "name": "HandshakeResponse", "namespace": "org.apache.avro.ipc",
+ "fields": [
+ {"name": "match",
+ "type": {"type": "enum", "name": "HandshakeMatch",
+ "symbols": ["BOTH", "CLIENT", "NONE"]}},
+ {"name": "serverProtocol", "type": ["null", "string"]},
+ {"name": "serverHash",
+ "type": ["null", {"type": "fixed", "name": "MD5", "size": 16}]},
+ {"name": "meta",
+ "type": ["null", {"type": "map", "values": "bytes"}]}
+ ]
+}
+""")
+
+# Shared datum readers/writers: the requestor writes requests and reads
+# responses, the responder does the reverse.
+HANDSHAKE_REQUESTOR_WRITER = io.DatumWriter(HANDSHAKE_REQUEST_SCHEMA)
+HANDSHAKE_REQUESTOR_READER = io.DatumReader(HANDSHAKE_RESPONSE_SCHEMA)
+HANDSHAKE_RESPONDER_WRITER = io.DatumWriter(HANDSHAKE_RESPONSE_SCHEMA)
+HANDSHAKE_RESPONDER_READER = io.DatumReader(HANDSHAKE_REQUEST_SCHEMA)
+
+# Per-call metadata (a map of bytes) that prefixes call requests and responses.
+META_SCHEMA = schema.parse('{"type": "map", "values": "bytes"}')
+META_WRITER = io.DatumWriter(META_SCHEMA)
+META_READER = io.DatumReader(META_SCHEMA)
+
+# Error schema used when a message declares no errors, and for errors raised
+# while the responder is handling a call.
+SYSTEM_ERROR_SCHEMA = schema.parse('["string"]')
+
+# protocol cache
+# Both dicts are keyed by transport.remote_name and are filled in by the
+# Requestor remote_hash / remote_protocol setters.
+REMOTE_HASHES = {}
+REMOTE_PROTOCOLS = {}
+
+# Framing: each buffer is preceded by its length as a 4-byte network-order
+# unsigned int; outgoing messages are split into buffers of at most
+# BUFFER_SIZE bytes.
+BIG_ENDIAN_INT_STRUCT = struct.Struct('!I')
+BUFFER_HEADER_LENGTH = 4
+BUFFER_SIZE = 8192
+
+#
+# Exceptions
+#
+
+class AvroRemoteException(schema.AvroException):
+  """
+  Error carried across the wire: raised for an error message sent by the
+  other side of an Avro requestor/responder interaction.
+
+  fail_msg is the error datum (or text) that describes the failure.
+  """
+  def __init__(self, fail_msg=None):
+    base_init = schema.AvroException.__init__
+    base_init(self, fail_msg)
+
+class ConnectionClosedException(schema.AvroException):
+  """Raised by SocketTransport when the socket reads or sends 0 bytes."""
+
+#
+# Base IPC Classes (Requestor/Responder)
+#
+
+class Requestor(object):
+  """Base class for the client side of a protocol interaction."""
+  def __init__(self, local_protocol, transport):
+    self._local_protocol = local_protocol
+    self._transport = transport
+    self._remote_protocol = None
+    self._remote_hash = None
+    self._send_protocol = None
+
+  # read-only properties
+  local_protocol = property(lambda self: self._local_protocol)
+  transport = property(lambda self: self._transport)
+
+  # read/write properties; the remote protocol and hash setters also record
+  # the new value in the module-level caches keyed by transport.remote_name,
+  # so that later Requestors talking to the same remote can reuse it.
+  def set_remote_protocol(self, new_remote_protocol):
+    self._remote_protocol = new_remote_protocol
+    REMOTE_PROTOCOLS[self.transport.remote_name] = self.remote_protocol
+  remote_protocol = property(lambda self: self._remote_protocol,
+                             set_remote_protocol)
+  def set_remote_hash(self, new_remote_hash):
+    self._remote_hash = new_remote_hash
+    REMOTE_HASHES[self.transport.remote_name] = self.remote_hash
+  remote_hash = property(lambda self: self._remote_hash, set_remote_hash)
+  def set_send_protocol(self, new_send_protocol):
+    self._send_protocol = new_send_protocol
+  send_protocol = property(lambda self: self._send_protocol, set_send_protocol)
+
+  def request(self, message_name, request_datum):
+    """
+    Writes a request message and reads a response or error message.
+
+    message_name names a message of the local protocol; request_datum maps
+    that message's request field names to values. Returns the decoded
+    response datum, or raises AvroRemoteException if the server answered
+    with an error. If the handshake shows the server did not process the
+    call, the request is re-sent and the result of that retry is returned.
+    """
+    # build handshake and call request
+    buffer_writer = cStringIO.StringIO()
+    buffer_encoder = io.BinaryEncoder(buffer_writer)
+    self.write_handshake_request(buffer_encoder)
+    self.write_call_request(message_name, request_datum, buffer_encoder)
+
+    # send the handshake and call request; block until call response
+    call_request = buffer_writer.getvalue()
+    call_response = self.transport.transceive(call_request)
+
+    # process the handshake and call response
+    buffer_decoder = io.BinaryDecoder(cStringIO.StringIO(call_response))
+    call_response_exists = self.read_handshake_response(buffer_decoder)
+    if call_response_exists:
+      return self.read_call_response(message_name, buffer_decoder)
+    else:
+      # The server did not process the call; retry (send_protocol has been
+      # updated by the handshake) and hand the retry's result to the caller.
+      return self.request(message_name, request_datum)
+
+  def write_handshake_request(self, encoder):
+    """
+    Writes the handshake request: the local protocol's hash, the guessed
+    server hash, and (when send_protocol is set) the full local protocol.
+    """
+    local_hash = self.local_protocol.md5
+    remote_name = self.transport.remote_name
+    remote_hash = REMOTE_HASHES.get(remote_name)
+    if remote_hash is None:
+      # Nothing cached for this remote: guess that it speaks our protocol.
+      remote_hash = local_hash
+      self.remote_protocol = self.local_protocol
+    elif self.remote_protocol is None:
+      # Another Requestor already learned this remote's protocol (the cache
+      # entries are set together during the handshake). Reuse it, otherwise
+      # a 'BOTH' answer would leave nothing to decode the response with.
+      self.remote_protocol = REMOTE_PROTOCOLS.get(remote_name)
+    request_datum = {}
+    request_datum['clientHash'] = local_hash
+    request_datum['serverHash'] = remote_hash
+    if self.send_protocol:
+      request_datum['clientProtocol'] = str(self.local_protocol)
+    HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)
+
+  def write_call_request(self, message_name, request_datum, encoder):
+    """
+    The format of a call request is:
+      * request metadata, a map with values of type bytes
+      * the message name, an Avro string, followed by
+      * the message parameters. Parameters are serialized according to
+        the message's request declaration.
+
+    Raises schema.AvroException if message_name is not in the local protocol.
+    """
+    # request metadata (not yet implemented)
+    request_metadata = {}
+    META_WRITER.write(request_metadata, encoder)
+
+    # message name
+    message = self.local_protocol.messages.get(message_name)
+    if message is None:
+      raise schema.AvroException('Unknown message: %s' % message_name)
+    encoder.write_utf8(message.name)
+
+    # message parameters
+    self.write_request(message.request, request_datum, encoder)
+
+  def write_request(self, request_fields, request_datum, encoder):
+    """
+    Writes request_datum's value for each field, in request_fields order,
+    using that field's schema. A field missing from request_datum is
+    written as None.
+    """
+    for field in request_fields:
+      datum_writer = io.DatumWriter(field.type)
+      datum_writer.write(request_datum.get(field.name), encoder)
+
+  def read_handshake_response(self, decoder):
+    """
+    Reads the handshake response and updates the remote-protocol state.
+
+    Returns True when a call response follows the handshake ('BOTH' or
+    'CLIENT': the server knew our protocol and processed the call), and
+    False for 'NONE', where the request must be re-sent with the protocol.
+    Raises schema.AvroException on a handshake failure or unknown match.
+    """
+    handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
+    match = handshake_response.get('match')
+    if match == 'BOTH':
+      self.send_protocol = False
+      return True
+    elif match == 'CLIENT':
+      if self.send_protocol:
+        raise schema.AvroException('Handshake failure.')
+      self._set_remote_from_handshake(handshake_response)
+      self.send_protocol = False
+      # The server recognized our protocol and already ran the call, so a
+      # response follows; re-sending would execute the call a second time.
+      return True
+    elif match == 'NONE':
+      if self.send_protocol:
+        raise schema.AvroException('Handshake failure.')
+      self._set_remote_from_handshake(handshake_response)
+      self.send_protocol = True
+      return False
+    else:
+      raise schema.AvroException('Unexpected match: %s' % match)
+
+  def _set_remote_from_handshake(self, handshake_response):
+    """Parses and caches the server protocol and hash from a handshake."""
+    server_protocol = handshake_response.get('serverProtocol')
+    if server_protocol is None:
+      raise schema.AvroException('Handshake response lacks serverProtocol.')
+    # serverProtocol arrives as JSON text; read_call_response needs a
+    # Protocol object so it can look up the remote message schemas.
+    self.remote_protocol = protocol.parse(server_protocol)
+    self.remote_hash = handshake_response.get('serverHash')
+
+  def read_call_response(self, message_name, decoder):
+    """
+    The format of a call response is:
+      * response metadata, a map with values of type bytes
+      * a one-byte error flag boolean, followed by either:
+        o if the error flag is false,
+          the message response, serialized per the message's response schema.
+        o if the error flag is true,
+          the error, serialized per the message's error union schema.
+
+    Returns the response datum, or raises the decoded AvroRemoteException.
+    """
+    # response metadata (read to advance the decoder; not used yet)
+    response_metadata = META_READER.read(decoder)
+
+    # remote response schema
+    remote_message_schema = self.remote_protocol.messages.get(message_name)
+    if remote_message_schema is None:
+      raise schema.AvroException('Unknown remote message: %s' % message_name)
+
+    # local response schema
+    local_message_schema = self.local_protocol.messages.get(message_name)
+    if local_message_schema is None:
+      raise schema.AvroException('Unknown local message: %s' % message_name)
+
+    # error flag
+    if not decoder.read_boolean():
+      writers_schema = remote_message_schema.response
+      readers_schema = local_message_schema.response
+      return self.read_response(writers_schema, readers_schema, decoder)
+    else:
+      writers_schema = remote_message_schema.errors or SYSTEM_ERROR_SCHEMA
+      readers_schema = local_message_schema.errors or SYSTEM_ERROR_SCHEMA
+      raise self.read_error(writers_schema, readers_schema, decoder)
+
+  def read_response(self, writers_schema, readers_schema, decoder):
+    """Decodes a response datum written with writers_schema."""
+    datum_reader = io.DatumReader(writers_schema, readers_schema)
+    return datum_reader.read(decoder)
+
+  def read_error(self, writers_schema, readers_schema, decoder):
+    """Decodes an error datum and wraps it in an AvroRemoteException."""
+    datum_reader = io.DatumReader(writers_schema, readers_schema)
+    return AvroRemoteException(datum_reader.read(decoder))
+
+class Responder(object):
+  """Base class for the server side of a protocol interaction."""
+  def __init__(self, local_protocol):
+    self._local_protocol = local_protocol
+    self._local_hash = self.local_protocol.md5
+    self._protocol_cache = {}
+    self.set_protocol_cache(self.local_hash, self.local_protocol)
+
+  # read-only properties
+  local_protocol = property(lambda self: self._local_protocol)
+  local_hash = property(lambda self: self._local_hash)
+  protocol_cache = property(lambda self: self._protocol_cache)
+
+  # utility functions to manipulate protocol cache (keyed by protocol MD5)
+  def get_protocol_cache(self, hash):
+    return self.protocol_cache.get(hash)
+  def set_protocol_cache(self, hash, protocol):
+    self.protocol_cache[hash] = protocol
+
+  def respond(self, transport):
+    """
+    Called by a server to deserialize a request, compute and serialize
+    a response or error. Compare to 'handle()' in Thrift.
+
+    Reads one framed call request from transport and returns the serialized
+    handshake response followed (unless the handshake failed) by the call
+    response, ready to be written back as one framed message.
+    """
+    call_request = transport.read_framed_message()
+    buffer_decoder = io.BinaryDecoder(cStringIO.StringIO(call_request))
+    # The handshake response and the call response are buffered separately,
+    # so an error raised while handling the call can replace a partially
+    # written call response without discarding the handshake response, which
+    # the requestor always reads first.
+    handshake_writer = cStringIO.StringIO()
+    handshake_encoder = io.BinaryEncoder(handshake_writer)
+    response_writer = cStringIO.StringIO()
+    response_encoder = io.BinaryEncoder(response_writer)
+    error = None
+    response_metadata = {}
+
+    try:
+      remote_protocol = self.process_handshake(transport, buffer_decoder,
+                                               handshake_encoder)
+      # handshake failure
+      if remote_protocol is None:
+        return handshake_writer.getvalue()
+
+      # read request using remote protocol (metadata is read to advance the
+      # decoder; it is not used yet)
+      request_metadata = META_READER.read(buffer_decoder)
+      remote_message_name = buffer_decoder.read_utf8()
+
+      # get remote and local request schemas so we can do
+      # schema resolution (one fine day)
+      remote_message = remote_protocol.messages.get(remote_message_name)
+      if remote_message is None:
+        fail_msg = 'Unknown remote message: %s' % remote_message_name
+        raise schema.AvroException(fail_msg)
+      local_message = self.local_protocol.messages.get(remote_message_name)
+      if local_message is None:
+        fail_msg = 'Unknown local message: %s' % remote_message_name
+        raise schema.AvroException(fail_msg)
+      writers_fields = remote_message.request
+      # TODO(hammer) pass reader schema
+      request = self.read_request(writers_fields, buffer_decoder)
+      # perform server logic; any failure becomes an error response
+      try:
+        response = self.invoke(local_message, request)
+      except AvroRemoteException, e:
+        error = e
+      except Exception, e:
+        error = AvroRemoteException(str(e))
+
+      # write response using local protocol
+      META_WRITER.write(response_metadata, response_encoder)
+      response_encoder.write_boolean(error is not None)
+      if error is None:
+        writers_schema = local_message.response
+        self.write_response(writers_schema, response, response_encoder)
+      else:
+        writers_schema = local_message.errors or SYSTEM_ERROR_SCHEMA
+        self.write_error(writers_schema, error, response_encoder)
+    except schema.AvroException, e:
+      error = AvroRemoteException(str(e))
+      # Discard whatever part of the call response was written and replace
+      # it with a system error.
+      # NOTE(review): the requestor decodes errors with the message's
+      # declared error schema when one exists, which may not match
+      # SYSTEM_ERROR_SCHEMA -- confirm.
+      response_writer = cStringIO.StringIO()
+      response_encoder = io.BinaryEncoder(response_writer)
+      META_WRITER.write(response_metadata, response_encoder)
+      response_encoder.write_boolean(True)
+      self.write_error(SYSTEM_ERROR_SCHEMA, error, response_encoder)
+    return handshake_writer.getvalue() + response_writer.getvalue()
+
+  def process_handshake(self, transport, decoder, encoder):
+    """
+    Reads the client's handshake request from decoder and writes the
+    handshake response to encoder.
+
+    Returns the client's protocol when it is known (cached by hash, or sent
+    in the request and cached now), else None. In that case the match is
+    'NONE' and the call must not be processed.
+    """
+    handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
+    handshake_response = {}
+
+    # determine the remote protocol
+    client_hash = handshake_request.get('clientHash')
+    client_protocol = handshake_request.get('clientProtocol')
+    remote_protocol = self.get_protocol_cache(client_hash)
+    if remote_protocol is None and client_protocol is not None:
+      remote_protocol = protocol.parse(client_protocol)
+      self.set_protocol_cache(client_hash, remote_protocol)
+
+    # evaluate remote's guess of the local protocol
+    server_hash = handshake_request.get('serverHash')
+    if remote_protocol is None:
+      handshake_response['match'] = 'NONE'
+    elif self.local_hash == server_hash:
+      handshake_response['match'] = 'BOTH'
+    else:
+      handshake_response['match'] = 'CLIENT'
+
+    # the client needs our protocol whenever its guess was not confirmed
+    if handshake_response['match'] != 'BOTH':
+      handshake_response['serverProtocol'] = str(self.local_protocol)
+      handshake_response['serverHash'] = self.local_hash
+
+    HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
+    return remote_protocol
+
+  def invoke(self, local_message, request):
+    """
+    Actual work done by server: cf. handler in thrift.
+
+    Subclasses override this. local_message is the local protocol's Message
+    being called and request is the list returned by read_request. The base
+    implementation does nothing and returns None.
+    """
+    pass
+
+  def read_request(self, writers_fields, decoder):
+    """
+    Reads one datum per writer's field, in order, and returns them as a list.
+
+    Only the writer's field schemas are used; schema resolution against the
+    local message is not implemented yet.
+    """
+    request_data = []
+    for field in writers_fields:
+      datum_reader = io.DatumReader(field.type)
+      request_data.append(datum_reader.read(decoder))
+    return request_data
+
+  def write_response(self, writers_schema, response_datum, encoder):
+    """Encodes response_datum with writers_schema."""
+    datum_writer = io.DatumWriter(writers_schema)
+    datum_writer.write(response_datum, encoder)
+
+  def write_error(self, writers_schema, error_exception, encoder):
+    """Encodes str(error_exception) with writers_schema."""
+    datum_writer = io.DatumWriter(writers_schema)
+    datum_writer.write(str(error_exception), encoder)
+
+#
+# Transport Implementations
+#
+
+class SocketTransport(object):
+  """
+  A simple socket-based Transport implementation.
+
+  A message is framed as a sequence of buffers, each preceded by its length
+  as a 4-byte big-endian unsigned int, and terminated by a zero-length
+  buffer.
+  """
+  def __init__(self, sock):
+    self._sock = sock
+
+  # read-only properties
+  sock = property(lambda self: self._sock)
+  # NOTE(review): getsockname() returns this end's address, while Requestor
+  # keys its protocol caches on remote_name; getpeername() looks like the
+  # intended call. Left unchanged pending confirmation.
+  remote_name = property(lambda self: self.sock.getsockname())
+
+  def transceive(self, request):
+    """Writes request as one framed message and returns the framed reply."""
+    self.write_framed_message(request)
+    return self.read_framed_message()
+
+  def read_framed_message(self):
+    """Reads buffers until a zero-length one; returns their concatenation."""
+    message = []
+    while True:
+      buffer_length = self.read_buffer_length()
+      if buffer_length == 0:
+        return ''.join(message)
+      message.append(self._read_exactly(buffer_length))
+
+  def write_framed_message(self, message):
+    """Writes message in buffers of at most BUFFER_SIZE bytes."""
+    message_length = len(message)
+    total_bytes_sent = 0
+    while message_length - total_bytes_sent > 0:
+      buffer_length = min(BUFFER_SIZE, message_length - total_bytes_sent)
+      self.write_buffer(message[total_bytes_sent:
+                                (total_bytes_sent + buffer_length)])
+      total_bytes_sent += buffer_length
+    # A message is always terminated by a zero-length buffer.
+    self.write_buffer_length(0)
+
+  def write_buffer(self, chunk):
+    """Writes chunk preceded by its length header."""
+    buffer_length = len(chunk)
+    self.write_buffer_length(buffer_length)
+    total_bytes_sent = 0
+    while total_bytes_sent < buffer_length:
+      bytes_sent = self.sock.send(chunk[total_bytes_sent:])
+      if bytes_sent == 0:
+        raise ConnectionClosedException("Socket sent 0 bytes.")
+      total_bytes_sent += bytes_sent
+
+  def write_buffer_length(self, n):
+    """Writes n as a 4-byte big-endian unsigned int."""
+    # sendall() either sends every byte or raises socket.error. It returns
+    # None, so there is no byte count to check.
+    self.sock.sendall(BIG_ENDIAN_INT_STRUCT.pack(n))
+
+  def read_buffer_length(self):
+    """Reads a 4-byte big-endian buffer length."""
+    # recv() may return fewer bytes than requested, so wait for the whole
+    # header before unpacking it.
+    read = self._read_exactly(BUFFER_HEADER_LENGTH)
+    return BIG_ENDIAN_INT_STRUCT.unpack(read)[0]
+
+  def _read_exactly(self, n):
+    """Reads exactly n bytes; raises ConnectionClosedException on EOF."""
+    chunks = []
+    remaining = n
+    while remaining > 0:
+      chunk = self.sock.recv(remaining)
+      if chunk == '':
+        raise ConnectionClosedException("Socket read 0 bytes.")
+      chunks.append(chunk)
+      remaining -= len(chunk)
+    return ''.join(chunks)
+
+  def close(self):
+    """Closes the underlying socket."""
+    self.sock.close()
+
+#
+# Server Implementations (none yet)
+#
+
Added: hadoop/avro/trunk/src/py/avro/protocol.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/protocol.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/py/avro/protocol.py (added)
+++ hadoop/avro/trunk/src/py/avro/protocol.py Tue Jan 5 18:48:04 2010
@@ -0,0 +1,219 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Protocol implementation.
+"""
+import cStringIO
+import md5
+try:
+ import simplejson as json
+except ImportError:
+ import json
+from avro import schema
+
+#
+# Constants
+#
+
+# Schema types allowed in a protocol's top-level "types" list; enforced by
+# Protocol._parse_types.
+# TODO(hammer): confirmed 'fixed' with Doug
+VALID_TYPE_SCHEMA_TYPES = ('enum', 'record', 'error', 'fixed')
+
+#
+# Exceptions
+#
+
+class ProtocolParseException(schema.AvroException):
+  """Raised when protocol JSON is malformed or describes an invalid protocol."""
+
+#
+# Base Classes
+#
+
+class Protocol(object):
+  """An application protocol."""
+  def _parse_types(self, types, type_names):
+    """
+    Builds schema objects for the protocol's "types" list, passing
+    type_names through to schema.make_avsc_object.
+
+    Raises ProtocolParseException for a type whose schema type is not in
+    VALID_TYPE_SCHEMA_TYPES.
+    """
+    type_objects = []
+    for type_json in types:
+      type_object = schema.make_avsc_object(type_json, type_names)
+      if type_object.type not in VALID_TYPE_SCHEMA_TYPES:
+        fail_msg = 'Type %s not an enum, fixed, record, or error.' % type_json
+        raise ProtocolParseException(fail_msg)
+      type_objects.append(type_object)
+    return type_objects
+
+  def _parse_messages(self, messages, names):
+    """
+    Builds a dict of message name -> Message from the "messages" object.
+
+    Raises ProtocolParseException for a repeated name or a message body
+    that is not dict-like.
+    """
+    message_objects = {}
+    for name, body in messages.iteritems():
+      if name in message_objects:
+        fail_msg = 'Message name "%s" repeated.' % name
+        raise ProtocolParseException(fail_msg)
+      elif not(hasattr(body, 'get') and callable(body.get)):
+        fail_msg = 'Message name "%s" has non-object body %s.' % (name, body)
+        raise ProtocolParseException(fail_msg)
+
+      request = body.get('request')
+      response = body.get('response')
+      errors = body.get('errors')
+      message_objects[name] = Message(name, request, response, errors, names)
+    return message_objects
+
+  def __init__(self, name, namespace=None, types=None, messages=None):
+    """
+    Validates the arguments, parses types then messages (so messages can
+    refer to the declared types), and computes the MD5 of str(self).
+
+    Raises ProtocolParseException on invalid arguments.
+    """
+    # Ensure valid ctor args
+    if not name:
+      fail_msg = 'Protocols must have a non-empty name.'
+      raise ProtocolParseException(fail_msg)
+    elif not isinstance(name, basestring):
+      fail_msg = 'The name property must be a string.'
+      raise ProtocolParseException(fail_msg)
+    elif namespace is not None and not isinstance(namespace, basestring):
+      fail_msg = 'The namespace property must be a string.'
+      raise ProtocolParseException(fail_msg)
+    elif types is not None and not isinstance(types, list):
+      fail_msg = 'The types property must be a list.'
+      raise ProtocolParseException(fail_msg)
+    elif (messages is not None and
+          not(hasattr(messages, 'get') and callable(messages.get))):
+      fail_msg = 'The messages property must be a JSON object.'
+      raise ProtocolParseException(fail_msg)
+
+    self._props = {}
+    self.set_prop('name', name)
+    if namespace is not None: self.set_prop('namespace', namespace)
+    type_names = {}
+    if types is not None:
+      self.set_prop('types', self._parse_types(types, type_names))
+    if messages is not None:
+      self.set_prop('messages', self._parse_messages(messages, type_names))
+    # the handshake identifies protocols by this digest
+    self._md5 = md5.new(str(self)).digest()
+
+  # read-only properties
+  name = property(lambda self: self.get_prop('name'))
+  namespace = property(lambda self: self.get_prop('namespace'))
+  fullname = property(lambda self:
+                      schema.Name.make_fullname(self.name, self.namespace))
+  types = property(lambda self: self.get_prop('types'))
+  # a protocol declared without "types" yields an empty dict
+  types_dict = property(lambda self: dict([(t.name, t)
+                                           for t in (self.types or [])]))
+  messages = property(lambda self: self.get_prop('messages'))
+  md5 = property(lambda self: self._md5)
+  props = property(lambda self: self._props)
+
+  # utility functions to manipulate properties dict
+  def get_prop(self, key):
+    return self.props.get(key)
+  def set_prop(self, key, value):
+    self.props[key] = value
+
+  def __str__(self):
+    # until we implement a JSON encoder for Schema and Message objects,
+    # we'll have to go through and call str() by hand.
+    to_dump = {}
+    to_dump['protocol'] = self.name
+    if self.namespace: to_dump['namespace'] = self.namespace
+    if self.types:
+      to_dump['types'] = [json.loads(str(t)) for t in self.types]
+    if self.messages:
+      messages_dict = {}
+      for name, body in self.messages.iteritems():
+        messages_dict[name] = json.loads(str(body))
+      to_dump['messages'] = messages_dict
+    return json.dumps(to_dump)
+
+  def __eq__(self, that):
+    """Protocols are equal when their JSON forms decode to equal values."""
+    to_cmp = json.loads(str(self))
+    return to_cmp == json.loads(str(that))
+
+  def __ne__(self, that):
+    # Python 2 does not derive != from __eq__; without this, != would
+    # compare object identity.
+    return not self.__eq__(that)
+
+class Message(object):
+  """A Protocol message."""
+  def _parse_request(self, request, names):
+    """Parses the request field list; raises if it is not a list."""
+    if not isinstance(request, list):
+      fail_msg = 'Request property not a list: %s' % request
+      raise ProtocolParseException(fail_msg)
+    return schema.RecordSchema.make_field_objects(request, names)
+
+  def _parse_response(self, response, names):
+    """
+    Parses the response schema. A bare name already present in names is
+    reused, and flagged so __str__ writes it back as a name.
+    """
+    if isinstance(response, basestring) and names.has_key(response):
+      self._response_from_names = True
+      return names.get(response)
+    else:
+      return schema.make_avsc_object(response, names)
+
+  def _parse_errors(self, errors, names):
+    """Parses the errors list as a union schema; raises if not a list."""
+    if not isinstance(errors, list):
+      fail_msg = 'Errors property not a list: %s' % errors
+      raise ProtocolParseException(fail_msg)
+    return schema.make_avsc_object(errors, names)
+
+  def __init__(self, name, request, response, errors=None, names=None):
+    # The parse helpers call names.has_key() and hand names to the schema
+    # module, so a Message built without names needs a fresh dict.
+    if names is None:
+      names = {}
+    self._name = name
+    self._response_from_names = False
+
+    self._props = {}
+    self.set_prop('request', self._parse_request(request, names))
+    self.set_prop('response', self._parse_response(response, names))
+    if errors is not None:
+      self.set_prop('errors', self._parse_errors(errors, names))
+
+  # read-only properties
+  name = property(lambda self: self._name)
+  response_from_names = property(lambda self: self._response_from_names)
+  request = property(lambda self: self.get_prop('request'))
+  response = property(lambda self: self.get_prop('response'))
+  errors = property(lambda self: self.get_prop('errors'))
+  props = property(lambda self: self._props)
+
+  # utility functions to manipulate properties dict
+  def get_prop(self, key):
+    return self.props.get(key)
+  def set_prop(self, key, value):
+    self.props[key] = value
+
+  # TODO(hammer): allow schemas and fields to be JSON Encoded!
+  def __str__(self):
+    to_dump = {}
+    to_dump['request'] = [json.loads(str(r)) for r in self.request]
+    if self.response_from_names:
+      to_dump['response'] = self.response.fullname
+    else:
+      to_dump['response'] = json.loads(str(self.response))
+    if self.errors:
+      to_dump['errors'] = json.loads(str(self.errors))
+    return json.dumps(to_dump)
+
+  def __eq__(self, that):
+    return self.name == that.name and self.props == that.props
+
+  def __ne__(self, that):
+    # Python 2 does not derive != from __eq__.
+    return not self.__eq__(that)
+
+def make_avpr_object(json_data):
+  """
+  Build Avro Protocol from data parsed out of JSON string.
+
+  json_data must be dict-like (expose a callable 'get'); anything else
+  raises ProtocolParseException.
+  """
+  is_object = hasattr(json_data, 'get') and callable(json_data.get)
+  if not is_object:
+    raise ProtocolParseException('Not a JSON object: %s' % json_data)
+  return Protocol(json_data.get('protocol'),
+                  json_data.get('namespace'),
+                  json_data.get('types'),
+                  json_data.get('messages'))
+
+def parse(json_string):
+  """
+  Constructs the Protocol from the JSON text.
+
+  Raises ProtocolParseException if json_string cannot be decoded as JSON.
+  Errors in the decoded protocol itself are raised by make_avpr_object.
+  """
+  try:
+    json_data = json.loads(json_string)
+  except (ValueError, TypeError):
+    # ValueError: malformed JSON; TypeError: json_string is not a string.
+    # Other exceptions (e.g. KeyboardInterrupt) are no longer swallowed.
+    raise ProtocolParseException('Error parsing JSON: %s' % json_string)
+
+  # construct the Avro Protocol object
+  return make_avpr_object(json_data)
+
Added: hadoop/avro/trunk/src/test/py/sample_ipc_client.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/sample_ipc_client.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/test/py/sample_ipc_client.py (added)
+++ hadoop/avro/trunk/src/test/py/sample_ipc_client.py Tue Jan 5 18:48:04 2010
@@ -0,0 +1,95 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import socket
+import sys
+
+from avro import ipc
+from avro import protocol
+from avro import schema
+
+# Sample "Mail" protocol: a Message record plus a 'send' message (takes a
+# Message, returns a string) and a 'replay' message (no args, returns a
+# string). The sample server declares the same protocol.
+MAIL_PROTOCOL_JSON = """\
+{"namespace": "example.proto",
+ "protocol": "Mail",
+
+ "types": [
+ {"name": "Message", "type": "record",
+ "fields": [
+ {"name": "to", "type": "string"},
+ {"name": "from", "type": "string"},
+ {"name": "body", "type": "string"}
+ ]
+ }
+ ],
+
+ "messages": {
+ "send": {
+ "request": [{"name": "message", "type": "Message"}],
+ "response": "string"
+ },
+ "replay": {
+ "request": [],
+ "response": "string"
+ }
+ }
+}
+"""
+MAIL_PROTOCOL = protocol.parse(MAIL_PROTOCOL_JSON)
+# (host, port) of the sample server to connect to.
+SERVER_ADDRESS = ('localhost', 9090)
+
+class UsageError(Exception):
+  """Raised when the script is invoked with the wrong arguments."""
+  def __init__(self, value):
+    self.value = value
+
+  def __str__(self):
+    return '%r' % (self.value,)
+
+def make_requestor(server_address, protocol):
+  """Connects to server_address; returns a Requestor speaking protocol."""
+  connection = socket.socket()
+  connection.connect(server_address)
+  return ipc.Requestor(protocol, ipc.SocketTransport(connection))
+
+if __name__ == '__main__':
+  if len(sys.argv) not in [4, 5]:
+    raise UsageError("Usage: <to> <from> <body> [<count>]")
+
+  # client code - attach to the server and send a message
+  # fill in the Message record
+  message = dict()
+  message['to'] = sys.argv[1]
+  message['from'] = sys.argv[2]
+  message['body'] = sys.argv[3]
+
+  # The optional <count> defaults to one message. A count that is not an
+  # integer is reported as a usage error instead of being silently ignored.
+  num_messages = 1
+  if len(sys.argv) == 5:
+    try:
+      num_messages = int(sys.argv[4])
+    except ValueError:
+      raise UsageError("<count> must be an integer: %s" % sys.argv[4])
+
+  # build the parameters for the request
+  params = {}
+  params['message'] = message
+
+  # Send the requests and print the results. Each request opens its own
+  # connection, which is closed once the response has been read.
+  for msg_count in range(num_messages):
+    requestor = make_requestor(SERVER_ADDRESS, MAIL_PROTOCOL)
+    try:
+      result = requestor.request('send', params)
+    finally:
+      requestor.transport.close()
+    print("Result: " + result)
+
+  # try out a replay message
+  requestor = make_requestor(SERVER_ADDRESS, MAIL_PROTOCOL)
+  try:
+    result = requestor.request('replay', dict())
+  finally:
+    requestor.transport.close()
+  print("Replay Result: " + result)
Added: hadoop/avro/trunk/src/test/py/sample_ipc_server.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/sample_ipc_server.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/test/py/sample_ipc_server.py (added)
+++ hadoop/avro/trunk/src/test/py/sample_ipc_server.py Tue Jan 5 18:48:04 2010
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from SocketServer import BaseRequestHandler, TCPServer
+from avro import ipc
+from avro import protocol
+from avro import schema
+
+# Sample "Mail" protocol served by MailResponder; it matches the
+# declaration in the sample client.
+MAIL_PROTOCOL_JSON = """\
+{"namespace": "example.proto",
+ "protocol": "Mail",
+
+ "types": [
+ {"name": "Message", "type": "record",
+ "fields": [
+ {"name": "to", "type": "string"},
+ {"name": "from", "type": "string"},
+ {"name": "body", "type": "string"}
+ ]
+ }
+ ],
+
+ "messages": {
+ "send": {
+ "request": [{"name": "message", "type": "Message"}],
+ "response": "string"
+ },
+ "replay": {
+ "request": [],
+ "response": "string"
+ }
+ }
+}
+"""
+MAIL_PROTOCOL = protocol.parse(MAIL_PROTOCOL_JSON)
+# (host, port) the sample server listens on.
+SERVER_ADDRESS = ('localhost', 9090)
+
+class MailResponder(ipc.Responder):
+  """Responder implementing the sample Mail protocol."""
+  def __init__(self):
+    ipc.Responder.__init__(self, MAIL_PROTOCOL)
+
+  def invoke(self, message, request):
+    """
+    Dispatches on the message name. 'send' formats the Message record
+    passed as the first request parameter, 'replay' returns a constant,
+    and any other message yields None.
+    """
+    name = message.name
+    if name == 'replay':
+      return 'replay'
+    if name == 'send':
+      return ("Sent message to %(to)s from %(from)s with body %(body)s"
+              % request[0])
+    return None
+
+class MailHandler(BaseRequestHandler):
+  """Answers one Mail protocol call per accepted TCP connection."""
+  def handle(self):
+    self.responder = MailResponder()
+    self.transport = ipc.SocketTransport(self.request)
+    reply = self.responder.respond(self.transport)
+    self.transport.write_framed_message(reply)
+
+if __name__ == '__main__':
+  # serve Mail protocol calls on SERVER_ADDRESS until interrupted
+  TCPServer(SERVER_ADDRESS, MailHandler).serve_forever()
Added: hadoop/avro/trunk/src/test/py/test_protocol.py
URL:
http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/test_protocol.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/test/py/test_protocol.py (added)
+++ hadoop/avro/trunk/src/test/py/test_protocol.py Tue Jan 5 18:48:04 2010
@@ -0,0 +1,257 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Test the protocol parsing logic.
+"""
+import unittest
+from avro import protocol
+
+class ExampleProtocol(object):
+    """One protocol test case.
+
+    protocol_string -- the protocol JSON text handed to protocol.parse.
+    valid -- whether parsing is expected to succeed.
+    name -- label used in test output; an empty (falsy) name falls back to
+        protocol_string itself.
+    comment -- free-form note; the only property that has a setter.
+    """
+
+    def __init__(self, protocol_string, valid, name='', comment=''):
+        self._protocol_string = protocol_string
+        self._valid = valid
+        if name:
+            self._name = name
+        else:
+            # No explicit label: identify the case by its protocol text.
+            self._name = protocol_string
+        self._comment = comment
+
+    def _get_protocol_string(self):
+        return self._protocol_string
+
+    def _get_valid(self):
+        return self._valid
+
+    def _get_name(self):
+        return self._name
+
+    def _get_comment(self):
+        return self._comment
+
+    def set_comment(self, new_comment):
+        self._comment = new_comment
+
+    # Read-only views of the constructor arguments.
+    protocol_string = property(_get_protocol_string)
+    valid = property(_get_valid)
+    name = property(_get_name)
+    # Read/write; set_comment also remains callable directly.
+    comment = property(_get_comment, set_comment)
+
+#
+# Example Protocols
+#
+
+# Protocols fed to protocol.parse by TestProtocol. Each entry records whether
+# parsing is expected to succeed; all current entries are expected to parse.
+EXAMPLES = [
+ # Record and error types; 'hello' declares the error it may raise.
+ ExampleProtocol("""\
+{
+ "namespace": "com.acme",
+ "protocol": "HelloWorld",
+
+ "types": [
+ {"name": "Greeting", "type": "record", "fields": [
+ {"name": "message", "type": "string"}]},
+ {"name": "Curse", "type": "error", "fields": [
+ {"name": "message", "type": "string"}]}
+ ],
+
+ "messages": {
+ "hello": {
+ "request": [{"name": "greeting", "type": "Greeting" }],
+ "response": "Greeting",
+ "errors": ["Curse"]
+ }
+ }
+}
+ """, True, name='HelloWorld'),
+ # Enum, fixed, a record with field ordering, and several messages.
+ ExampleProtocol("""\
+{"namespace": "org.apache.avro.test",
+ "protocol": "Simple",
+
+ "types": [
+ {"name": "Kind", "type": "enum", "symbols": ["FOO","BAR","BAZ"]},
+
+ {"name": "MD5", "type": "fixed", "size": 16},
+
+ {"name": "TestRecord", "type": "record",
+ "fields": [
+ {"name": "name", "type": "string", "order": "ignore"},
+ {"name": "kind", "type": "Kind", "order": "descending"},
+ {"name": "hash", "type": "MD5"}
+ ]
+ },
+
+ {"name": "TestError", "type": "error", "fields": [
+ {"name": "message", "type": "string"}
+ ]
+ }
+
+ ],
+
+ "messages": {
+
+ "hello": {
+ "request": [{"name": "greeting", "type": "string"}],
+ "response": "string"
+ },
+
+ "echo": {
+ "request": [{"name": "record", "type": "TestRecord"}],
+ "response": "TestRecord"
+ },
+
+ "add": {
+ "request": [{"name": "arg1", "type": "int"}, {"name": "arg2", "type": "int"}],
+ "response": "int"
+ },
+
+ "echoBytes": {
+ "request": [{"name": "data", "type": "bytes"}],
+ "response": "bytes"
+ },
+
+ "error": {
+ "request": [],
+ "response": "null",
+ "errors": ["TestError"]
+ }
+ }
+
+}
+ """, True, name='Simple'),
+ # Fully qualified type names and an explicit per-type namespace.
+ ExampleProtocol("""\
+{"namespace": "org.apache.avro.test.namespace",
+ "protocol": "TestNamespace",
+
+ "types": [
+ {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
+ {"name": "TestRecord", "type": "record",
+ "fields": [ {"name": "hash", "type": "org.apache.avro.test.util.MD5"} ]
+ },
+ {"name": "TestError", "namespace": "org.apache.avro.test.errors",
+ "type": "error", "fields": [ {"name": "message", "type": "string"} ]
+ }
+ ],
+
+ "messages": {
+ "echo": {
+ "request": [{"name": "record", "type": "TestRecord"}],
+ "response": "TestRecord"
+ },
+
+ "error": {
+ "request": [],
+ "response": "null",
+ "errors": ["org.apache.avro.test.errors.TestError"]
+ }
+
+ }
+
+}
+ """, True, name='TestNamespace'),
+ # No named types; messages use only bytes and null.
+ ExampleProtocol("""\
+{"namespace": "org.apache.avro.test",
+ "protocol": "BulkData",
+
+ "types": [],
+
+ "messages": {
+
+ "read": {
+ "request": [],
+ "response": "bytes"
+ },
+
+ "write": {
+ "request": [ {"name": "data", "type": "bytes"} ],
+ "response": "null"
+ }
+
+ }
+
+}
+ """, True, name='BulkData'),
+]
+
+# Subset used by the string-conversion and round-trip tests.
+VALID_EXAMPLES = [e for e in EXAMPLES if e.valid]
+
+class TestProtocol(unittest.TestCase):
+    """Check parsing, string conversion and round-tripping of EXAMPLES."""
+
+    def _print_header(self, title):
+        """Print title underlined with '=' and padded by blank lines."""
+        print('')
+        print(title)
+        print('=' * len(title))
+        print('')
+
+    def test_parse(self):
+        """Every example parses exactly when it is marked valid."""
+        self._print_header('TEST PARSE')
+
+        num_correct = 0
+        for example in EXAMPLES:
+            # Catch Exception rather than using a bare except, so that
+            # KeyboardInterrupt/SystemExit still abort the run.
+            try:
+                protocol.parse(example.protocol_string)
+            except Exception:
+                parsed = False
+            else:
+                parsed = True
+
+            if parsed == bool(example.valid):
+                num_correct += 1
+            if parsed:
+                debug_msg = "%s: PARSE SUCCESS" % example.name
+            else:
+                debug_msg = "%s: PARSE FAILURE" % example.name
+            # Printed after the try statement rather than in a finally
+            # clause, so an escaping exception cannot hit an unset or stale
+            # debug_msg.
+            print(debug_msg)
+
+        fail_msg = "Parse behavior correct on %d out of %d protocols." % \
+            (num_correct, len(EXAMPLES))
+        self.assertEqual(num_correct, len(EXAMPLES), fail_msg)
+
+    def test_valid_cast_to_string_after_parse(self):
+        """
+        Test that the string generated by an Avro Protocol object
+        is, in fact, a valid Avro protocol.
+        """
+        self._print_header('TEST CAST TO STRING')
+
+        num_correct = 0
+        for example in VALID_EXAMPLES:
+            # Valid examples must parse; a failure here errors the test.
+            protocol_data = protocol.parse(example.protocol_string)
+            try:
+                protocol.parse(str(protocol_data))
+            except Exception:
+                debug_msg = "%s: STRING CAST FAILURE" % example.name
+            else:
+                debug_msg = "%s: STRING CAST SUCCESS" % example.name
+                num_correct += 1
+            print(debug_msg)
+
+        fail_msg = "Cast to string success on %d out of %d protocols" % \
+            (num_correct, len(VALID_EXAMPLES))
+        self.assertEqual(num_correct, len(VALID_EXAMPLES), fail_msg)
+
+    def test_equivalence_after_round_trip(self):
+        """
+        1. Given a string, parse it to get Avro protocol "original".
+        2. Serialize "original" to a string and parse that string
+           to generate Avro protocol "round trip".
+        3. Ensure "original" and "round trip" protocols are equivalent.
+        """
+        self._print_header('TEST ROUND TRIP')
+
+        num_correct = 0
+        for example in VALID_EXAMPLES:
+            # Parsing, serializing and the equality check itself all count
+            # as a round-trip failure if they raise.
+            try:
+                original_protocol = protocol.parse(example.protocol_string)
+                round_trip_protocol = protocol.parse(str(original_protocol))
+                same = bool(original_protocol == round_trip_protocol)
+            except Exception:
+                same = False
+
+            if same:
+                num_correct += 1
+                debug_msg = "%s: ROUND TRIP SUCCESS" % example.name
+            else:
+                debug_msg = "%s: ROUND TRIP FAILURE" % example.name
+            print(debug_msg)
+
+        fail_msg = "Round trip success on %d out of %d protocols" % \
+            (num_correct, len(VALID_EXAMPLES))
+        self.assertEqual(num_correct, len(VALID_EXAMPLES), fail_msg)
+
+if __name__ == '__main__':
+    # Running this file directly executes the TestCase classes above.
+    unittest.main()