Repository: thrift Updated Branches: refs/heads/master b9641e094 -> e841b3dac
THRIFT-162 Thrift structures are unhashable, preventing them from being used as set elements Client: Python Patch: David Reiss, Nobuaki Sukegawa This closes #714 Project: http://git-wip-us.apache.org/repos/asf/thrift/repo Commit: http://git-wip-us.apache.org/repos/asf/thrift/commit/e841b3da Tree: http://git-wip-us.apache.org/repos/asf/thrift/tree/e841b3da Diff: http://git-wip-us.apache.org/repos/asf/thrift/diff/e841b3da Branch: refs/heads/master Commit: e841b3dac619a5e5d3523d059d48db1a12e41360 Parents: b9641e0 Author: Nobuaki Sukegawa <[email protected]> Authored: Tue Nov 17 11:01:17 2015 +0900 Committer: Roger Meier <[email protected]> Committed: Sat Nov 28 00:08:07 2015 +0100 ---------------------------------------------------------------------- compiler/cpp/src/generate/t_py_generator.cc | 261 ++++++++++++++++------- lib/py/src/Thrift.py | 20 ++ lib/py/src/protocol/TBase.py | 27 ++- lib/py/src/protocol/TProtocol.py | 28 +-- lib/py/src/protocol/fastbinary.c | 131 +++++++----- test/DebugProtoTest.thrift | 14 +- test/ThriftTest.thrift | 5 +- test/py/RunClientServer.py | 1 + test/py/TestFrozen.py | 116 ++++++++++ 9 files changed, 451 insertions(+), 152 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/thrift/blob/e841b3da/compiler/cpp/src/generate/t_py_generator.cc ---------------------------------------------------------------------- diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc index 44816ab..49c0b57 100644 --- a/compiler/cpp/src/generate/t_py_generator.cc +++ b/compiler/cpp/src/generate/t_py_generator.cc @@ -65,8 +65,9 @@ public: if (gen_dynamic_) { gen_newstyle_ = 0; // dynamic is newstyle gen_dynbaseclass_ = "TBase"; + gen_dynbaseclass_frozen_ = "TFrozenBase"; gen_dynbaseclass_exc_ = "TExceptionBase"; - import_dynbase_ = "from thrift.protocol.TBase import TBase, TExceptionBase, TTransport\n"; + import_dynbase_ = "from thrift.protocol.TBase import TBase, TFrozenBase, TExceptionBase, TTransport\n"; } iter = parsed_options.find("dynbase"); @@ -75,6 +76,11 @@ public: gen_dynbaseclass_ = (iter->second); } + iter = parsed_options.find("dynfrozen"); + if (iter != parsed_options.end()) { + gen_dynbaseclass_frozen_ = (iter->second); + } + iter = parsed_options.find("dynexc"); if (iter != parsed_options.end()) { gen_dynbaseclass_exc_ = (iter->second); @@ -142,8 +148,7 @@ public: void generate_py_struct(t_struct* tstruct, bool is_exception); void generate_py_struct_definition(std::ofstream& out, t_struct* tstruct, - bool is_xception = false, - bool is_result = false); + bool is_xception = false); void generate_py_struct_reader(std::ofstream& out, t_struct* tstruct); void generate_py_struct_writer(std::ofstream& out, t_struct* tstruct); void generate_py_struct_required_validator(std::ofstream& out, t_struct* tstruct); @@ -166,8 +171,7 @@ public: void generate_deserialize_field(std::ofstream& out, t_field* tfield, - std::string prefix = "", - bool inclass = false); + std::string prefix = ""); void generate_deserialize_struct(std::ofstream& out, t_struct* tstruct, std::string prefix = ""); @@ -244,7 +248,12 @@ public: return real_module; } + static bool is_immutable(t_type* ttype) { + return ttype->annotations_.find("python.immutable") != ttype->annotations_.end(); + } + private: + /** * True if we should generate new-style classes. */ @@ -257,6 +266,7 @@ private: bool gen_dynbase_; std::string gen_dynbaseclass_; + std::string gen_dynbaseclass_frozen_; std::string gen_dynbaseclass_exc_; std::string import_dynbase_; @@ -353,14 +363,12 @@ void t_py_generator::init_generator() { py_autogen_comment() << endl << py_imports() << endl << render_includes() << endl << - render_fastbinary_includes() << - endl << endl; + render_fastbinary_includes(); f_consts_ << py_autogen_comment() << endl << py_imports() << endl << - "from .ttypes import *" << endl << - endl; + "from .ttypes import *" << endl; } /** @@ -372,9 +380,6 @@ string t_py_generator::render_includes() { for (size_t i = 0; i < includes.size(); ++i) { result += "import " + get_real_py_module(includes[i], gen_twisted_) + ".ttypes\n"; } - if (includes.size() > 0) { - result += "\n"; - } return result; } @@ -413,7 +418,7 @@ string t_py_generator::py_autogen_comment() { * Prints standard thrift imports */ string t_py_generator::py_imports() { - return string("from thrift.Thrift import TType, TMessageType, TException, TApplicationException"); + return string("from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException"); } /** @@ -443,7 +448,7 @@ void t_py_generator::generate_typedef(t_typedef* ttypedef) { void t_py_generator::generate_enum(t_enum* tenum) { std::ostringstream to_string_mapping, from_string_mapping; - f_types_ << "class " << tenum->get_name() << (gen_newstyle_ ? "(object)" : "") + f_types_ << endl << endl << "class " << tenum->get_name() << (gen_newstyle_ ? "(object)" : "") << (gen_dynamic_ ? "(" + gen_dynbaseclass_ + ")" : "") << ":" << endl; indent_up(); generate_python_docstring(f_types_, tenum); @@ -468,7 +473,7 @@ void t_py_generator::generate_enum(t_enum* tenum) { indent_down(); f_types_ << endl; - f_types_ << to_string_mapping.str() << endl << from_string_mapping.str() << endl; + f_types_ << to_string_mapping.str() << endl << from_string_mapping.str(); } /** @@ -547,6 +552,9 @@ string t_py_generator::render_const_value(t_type* type, t_const_value* value) { } else if (type->is_map()) { t_type* ktype = ((t_map*)type)->get_key_type(); t_type* vtype = ((t_map*)type)->get_val_type(); + if (is_immutable(type)) { + out << "TFrozenDict("; + } out << "{" << endl; indent_up(); const map<t_const_value*, t_const_value*>& val = value->get_map(); @@ -560,6 +568,9 @@ string t_py_generator::render_const_value(t_type* type, t_const_value* value) { } indent_down(); indent(out) << "}"; + if (is_immutable(type)) { + out << ")"; + } } else if (type->is_list() || type->is_set()) { t_type* etype; if (type->is_list()) { @@ -568,9 +579,16 @@ string t_py_generator::render_const_value(t_type* type, t_const_value* value) { etype = ((t_set*)type)->get_elem_type(); } if (type->is_set()) { + if (is_immutable(type)) { + out << "frozen"; + } out << "set("; } - out << "[" << endl; + if (is_immutable(type) || type->is_set()) { + out << "(" << endl; + } else { + out << "[" << endl; + } indent_up(); const vector<t_const_value*>& val = value->get_list(); vector<t_const_value*>::const_iterator v_iter; @@ -580,7 +598,11 @@ string t_py_generator::render_const_value(t_type* type, t_const_value* value) { out << "," << endl; } indent_down(); - indent(out) << "]"; + if (is_immutable(type) || type->is_set()) { + indent(out) << ")"; + } else { + indent(out) << "]"; + } if (type->is_set()) { out << ")"; } @@ -622,26 +644,26 @@ void t_py_generator::generate_py_struct(t_struct* tstruct, bool is_exception) { */ void t_py_generator::generate_py_struct_definition(ofstream& out, t_struct* tstruct, - bool is_exception, - bool is_result) { - (void)is_result; + bool is_exception) { const vector<t_field*>& members = tstruct->get_members(); const vector<t_field*>& sorted_members = tstruct->get_sorted_members(); vector<t_field*>::const_iterator m_iter; - out << std::endl << "class " << tstruct->get_name(); + out << endl << endl << "class " << tstruct->get_name(); if (is_exception) { if (gen_dynamic_) { out << "(" << gen_dynbaseclass_exc_ << ")"; } else { out << "(TException)"; } - } else { - if (gen_newstyle_) { - out << "(object)"; - } else if (gen_dynamic_) { + } else if (gen_dynamic_) { + if (is_immutable(tstruct)) { + out << "(" << gen_dynbaseclass_frozen_ << ")"; + } else { out << "(" << gen_dynbaseclass_ << ")"; } + } else if (gen_newstyle_) { + out << "(object)"; } out << ":" << endl; indent_up(); @@ -670,13 +692,13 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, */ if (gen_slots_) { - indent(out) << "__slots__ = [ " << endl; + indent(out) << "__slots__ = (" << endl; indent_up(); for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) { indent(out) << "'" << (*m_iter)->get_name() << "'," << endl; } indent_down(); - indent(out) << " ]" << endl << endl; + indent(out) << ")" << endl << endl; } // TODO(dreiss): Look into generating an empty tuple instead of None @@ -691,7 +713,7 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) { for (; sorted_keys_pos != (*m_iter)->get_key(); sorted_keys_pos++) { - indent(out) << "None, # " << sorted_keys_pos << endl; + indent(out) << "None, # " << sorted_keys_pos << endl; } indent(out) << "(" << (*m_iter)->get_key() << ", " << type_to_enum((*m_iter)->get_type()) @@ -700,16 +722,17 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, << ", " << type_to_spec_args((*m_iter)->get_type()) << ", " << render_field_default_value(*m_iter) << ", " << ")," - << " # " << sorted_keys_pos << endl; + << " # " << sorted_keys_pos << endl; sorted_keys_pos++; } indent_down(); - indent(out) << ")" << endl << endl; + indent(out) << ")" << endl; } else { indent(out) << "thrift_spec = None" << endl; } + out << endl; if (members.size() > 0) { out << indent() << "def __init__(self,"; @@ -729,10 +752,23 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, if (!type->is_base_type() && !type->is_enum() && (*m_iter)->get_value() != NULL) { indent(out) << "if " << (*m_iter)->get_name() << " is " << "self.thrift_spec[" << (*m_iter)->get_key() << "][4]:" << endl; - indent(out) << " " << (*m_iter)->get_name() << " = " << render_field_default_value(*m_iter) + indent_up(); + indent(out) << (*m_iter)->get_name() << " = " << render_field_default_value(*m_iter) << endl; + indent_down(); + } + + if (is_immutable(tstruct)) { + if (gen_newstyle_ || gen_dynamic_) { + indent(out) << "super(" << tstruct->get_name() << ", self).__setattr__('" + << (*m_iter)->get_name() << "', " << (*m_iter)->get_name() << ")" << endl; + } else { + indent(out) << "self.__dict__['" << (*m_iter)->get_name() + << "'] = " << (*m_iter)->get_name() << endl; + } + } else { + indent(out) << "self." << (*m_iter)->get_name() << " = " << (*m_iter)->get_name() << endl; } - indent(out) << "self." << (*m_iter)->get_name() << " = " << (*m_iter)->get_name() << endl; } indent_down(); @@ -740,6 +776,26 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, out << endl; } + if (is_immutable(tstruct)) { + out << indent() << "def __setattr__(self, *args):" << endl + << indent() << " raise TypeError(\"can't modify immutable instance\")" << endl + << endl; + out << indent() << "def __delattr__(self, *args):" << endl + << indent() << " raise TypeError(\"can't modify immutable instance\")" << endl + << endl; + + // Hash all of the members in order, and also hash in the class + // to avoid collisions for stuff like single-field structures. + out << indent() << "def __hash__(self):" << endl + << indent() << " return hash(self.__class__) ^ hash(("; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + out << "self." << (*m_iter)->get_name() << ", "; + } + + out << "))" << endl << endl; + } + if (!gen_dynamic_) { generate_py_struct_reader(out, tstruct); generate_py_struct_writer(out, tstruct); @@ -759,7 +815,7 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, out << indent() << "def __repr__(self):" << endl << indent() << " L = ['%s=%r' % (key, value)" << endl << - indent() << " for key, value in self.__dict__.items()]" << endl << + indent() << " for key, value in self.__dict__.items()]" << endl << indent() << " return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl << endl; @@ -794,7 +850,7 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, << indent() << " return True" << endl << endl; out << indent() << "def __ne__(self, other):" << endl << indent() - << " return not (self == other)" << endl << endl; + << " return not (self == other)" << endl; } indent_down(); } @@ -806,18 +862,30 @@ void t_py_generator::generate_py_struct_reader(ofstream& out, t_struct* tstruct) const vector<t_field*>& fields = tstruct->get_members(); vector<t_field*>::const_iterator f_iter; - indent(out) << "def read(self, iprot):" << endl; + if (is_immutable(tstruct)) { + out << indent() << "@classmethod" << endl << indent() << "def read(cls, iprot):" << endl; + } else { + indent(out) << "def read(self, iprot):" << endl; + } indent_up(); + const char* id = is_immutable(tstruct) ? "cls" : "self"; + indent(out) << "if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated " "and isinstance(iprot.trans, TTransport.CReadableTransport) " - "and self.thrift_spec is not None " + "and " << id << ".thrift_spec is not None " "and fastbinary is not None:" << endl; indent_up(); - indent(out) << "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))" - << endl; - indent(out) << "return" << endl; + if (is_immutable(tstruct)) { + indent(out) + << "return fastbinary.decode_binary(None, iprot.trans, (cls, cls.thrift_spec))" + << endl; + } else { + indent(out) << "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))" + << endl; + indent(out) << "return" << endl; + } indent_down(); indent(out) << "iprot.readStructBegin()" << endl; @@ -850,7 +918,11 @@ void t_py_generator::generate_py_struct_reader(ofstream& out, t_struct* tstruct) indent_up(); indent(out) << "if ftype == " << type_to_enum((*f_iter)->get_type()) << ":" << endl; indent_up(); - generate_deserialize_field(out, *f_iter, "self."); + if (is_immutable(tstruct)) { + generate_deserialize_field(out, *f_iter); + } else { + generate_deserialize_field(out, *f_iter, "self."); + } indent_down(); out << indent() << "else:" << endl << indent() << " iprot.skip(ftype)" << endl; indent_down(); @@ -866,6 +938,16 @@ void t_py_generator::generate_py_struct_reader(ofstream& out, t_struct* tstruct) indent(out) << "iprot.readStructEnd()" << endl; + if (is_immutable(tstruct)) { + indent(out) << "return cls(" << endl; + indent_up(); + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + indent(out) << (*f_iter)->get_name() << "=" << (*f_iter)->get_name() << "," << endl; + } + indent_down(); + indent(out) << ")" << endl; + } + indent_down(); out << endl; } @@ -916,7 +998,6 @@ void t_py_generator::generate_py_struct_writer(ofstream& out, t_struct* tstruct) indent_down(); generate_py_struct_required_validator(out, tstruct); - out << endl; } void t_py_generator::generate_py_struct_required_validator(ofstream& out, t_struct* tstruct) { @@ -962,7 +1043,7 @@ void t_py_generator::generate_service(t_service* tservice) { f_service_ << "import logging" << endl << "from .ttypes import *" << endl << "from thrift.Thrift import TProcessor" << endl - << render_fastbinary_includes() << endl; + << render_fastbinary_includes(); if (gen_twisted_) { f_service_ << "from zope.interface import Interface, implements" << endl @@ -974,8 +1055,6 @@ void t_py_generator::generate_service(t_service* tservice) { f_service_ << "from thrift.transport import TTransport" << endl; } - f_service_ << endl; - // Generate the three main parts of the service generate_service_interface(tservice); generate_service_client(tservice); @@ -996,7 +1075,7 @@ void t_py_generator::generate_service_helpers(t_service* tservice) { vector<t_function*> functions = tservice->get_functions(); vector<t_function*>::iterator f_iter; - f_service_ << "# HELPER FUNCTIONS AND STRUCTURES" << endl; + f_service_ << endl << "# HELPER FUNCTIONS AND STRUCTURES" << endl; for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { t_struct* ts = (*f_iter)->get_arglist(); @@ -1024,7 +1103,7 @@ void t_py_generator::generate_py_function_helpers(t_function* tfunction) { for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { result.append(*f_iter); } - generate_py_struct_definition(f_service_, &result, false, true); + generate_py_struct_definition(f_service_, &result, false); } } @@ -1047,7 +1126,7 @@ void t_py_generator::generate_service_interface(t_service* tservice) { } } - f_service_ << "class Iface" << extends_if << ":" << endl; + f_service_ << endl << endl << "class Iface" << extends_if << ":" << endl; indent_up(); generate_python_docstring(f_service_, tservice); vector<t_function*> functions = tservice->get_functions(); @@ -1055,17 +1134,22 @@ void t_py_generator::generate_service_interface(t_service* tservice) { f_service_ << indent() << "pass" << endl; } else { vector<t_function*>::iterator f_iter; + bool first = true; for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_service_ << endl; + } f_service_ << indent() << "def " << function_signature(*f_iter, true) << ":" << endl; indent_up(); generate_python_docstring(f_service_, (*f_iter)); - f_service_ << indent() << "pass" << endl << endl; + f_service_ << indent() << "pass" << endl; indent_down(); } } indent_down(); - f_service_ << endl; } /** @@ -1089,6 +1173,8 @@ void t_py_generator::generate_service_client(t_service* tservice) { } } + f_service_ << endl << endl; + if (gen_twisted_) { f_service_ << "class Client" << extends_client << ":" << endl << " implements(Iface)" << endl << endl; @@ -1111,7 +1197,7 @@ void t_py_generator::generate_service_client(t_service* tservice) { if (gen_twisted_) { f_service_ << indent() << " self._transport = transport" << endl << indent() << " self._oprot_factory = oprot_factory" << endl << indent() - << " self._seqid = 0" << endl << indent() << " self._reqs = {}" << endl << endl; + << " self._seqid = 0" << endl << indent() << " self._reqs = {}" << endl; } else if (gen_tornado_) { f_service_ << indent() << " self._transport = transport" << endl << indent() << " self._iprot_factory = iprot_factory" << endl << indent() @@ -1119,28 +1205,26 @@ void t_py_generator::generate_service_client(t_service* tservice) { << indent() << " else iprot_factory)" << endl << indent() << " self._seqid = 0" << endl << indent() << " self._reqs = {}" << endl << indent() << " self._transport.io_loop.spawn_callback(self._start_receiving)" - << endl << endl; + << endl; } else { f_service_ << indent() << " self._iprot = self._oprot = iprot" << endl << indent() << " if oprot is not None:" << endl << indent() << " self._oprot = oprot" - << endl << indent() << " self._seqid = 0" << endl << endl; + << endl << indent() << " self._seqid = 0" << endl; } } else { if (gen_twisted_) { f_service_ << indent() << " " << extends - << ".Client.__init__(self, transport, oprot_factory)" << endl << endl; + << ".Client.__init__(self, transport, oprot_factory)" << endl; } else if (gen_tornado_) { f_service_ << indent() << " " << extends - << ".Client.__init__(self, transport, iprot_factory, oprot_factory)" << endl - << endl; + << ".Client.__init__(self, transport, iprot_factory, oprot_factory)" << endl; } else { - f_service_ << indent() << " " << extends << ".Client.__init__(self, iprot, oprot)" << endl - << endl; + f_service_ << indent() << " " << extends << ".Client.__init__(self, iprot, oprot)" << endl; } } if (gen_tornado_ && extends.empty()) { - f_service_ << + f_service_ << endl << indent() << "@gen.engine" << endl << indent() << "def _start_receiving(self):" << endl << indent() << " while True:" << endl << @@ -1164,8 +1248,7 @@ void t_py_generator::generate_service_client(t_service* tservice) { indent() << " except Exception as e:" << endl << indent() << " future.set_exception(e)" << endl << indent() << " else:" << endl << - indent() << " future.set_result(result)" << endl << - endl; + indent() << " future.set_result(result)" << endl; } // Generate client method implementations @@ -1177,6 +1260,7 @@ void t_py_generator::generate_service_client(t_service* tservice) { vector<t_field*>::const_iterator fld_iter; string funname = (*f_iter)->get_name(); + f_service_ << endl; // Open function indent(f_service_) << "def " << function_signature(*f_iter, false) << ":" << endl; indent_up(); @@ -1386,12 +1470,10 @@ void t_py_generator::generate_service_client(t_service* tservice) { // Close function indent_down(); - f_service_ << endl; } } indent_down(); - f_service_ << endl; } /** @@ -1560,6 +1642,8 @@ void t_py_generator::generate_service_server(t_service* tservice) { extends_processor = extends + ".Processor, "; } + f_service_ << endl << endl; + // Generate the header portion if (gen_twisted_) { f_service_ << "class Processor(" << extends_processor << "TProcessor):" << endl @@ -1630,15 +1714,14 @@ void t_py_generator::generate_service_server(t_service* tservice) { } indent_down(); - f_service_ << endl; // Generate the process subfunctions for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_service_ << endl; generate_process_function(tservice, *f_iter); } indent_down(); - f_service_ << endl; } /** @@ -1860,7 +1943,10 @@ void t_py_generator::generate_process_function(t_service* tservice, t_function* } f_service_ << "args." << (*f_iter)->get_name(); } - f_service_ << ")" << endl << indent() << "msg_type = TMessageType.REPLY" << endl; + f_service_ << ")" << endl; + if (!tfunction->is_oneway()) { + f_service_ << indent() << "msg_type = TMessageType.REPLY" << endl; + } indent_down(); f_service_ << indent() @@ -1900,7 +1986,6 @@ void t_py_generator::generate_process_function(t_service* tservice, t_function* // Close function indent_down(); - f_service_ << endl; } } @@ -1909,9 +1994,7 @@ void t_py_generator::generate_process_function(t_service* tservice, t_function* */ void t_py_generator::generate_deserialize_field(ofstream& out, t_field* tfield, - string prefix, - bool inclass) { - (void)inclass; + string prefix) { t_type* type = get_true_type(tfield->get_type()); if (type->is_void()) { @@ -1932,7 +2015,6 @@ void t_py_generator::generate_deserialize_field(ofstream& out, switch (tbase) { case t_base_type::TYPE_VOID: throw "compiler error: cannot serialize void field in a struct: " + name; - break; case t_base_type::TYPE_STRING: if (((t_base_type*)type)->is_binary()) { out << "readBinary()"; @@ -1979,8 +2061,12 @@ void t_py_generator::generate_deserialize_field(ofstream& out, * Generates an unserializer for a struct, calling read() */ void t_py_generator::generate_deserialize_struct(ofstream& out, t_struct* tstruct, string prefix) { - out << indent() << prefix << " = " << type_name(tstruct) << "()" << endl << indent() << prefix - << ".read(iprot)" << endl; + if (is_immutable(tstruct)) { + out << indent() << prefix << " = " << type_name(tstruct) << ".read(iprot)" << endl; + } else { + out << indent() << prefix << " = " << type_name(tstruct) << "()" << endl + << indent() << prefix << ".read(iprot)" << endl; + } } /** @@ -2030,9 +2116,18 @@ void t_py_generator::generate_deserialize_container(ofstream& out, t_type* ttype // Read container end if (ttype->is_map()) { indent(out) << "iprot.readMapEnd()" << endl; + if (is_immutable(ttype)) { + indent(out) << prefix << " = TFrozenDict(" << prefix << ")" << endl; + } } else if (ttype->is_set()) { indent(out) << "iprot.readSetEnd()" << endl; + if (is_immutable(ttype)) { + indent(out) << prefix << " = frozenset(" << prefix << ")" << endl; + } } else if (ttype->is_list()) { + if (is_immutable(ttype)) { + indent(out) << prefix << " = tuple(" << prefix << ")" << endl; + } indent(out) << "iprot.readListEnd()" << endl; } } @@ -2178,7 +2273,7 @@ void t_py_generator::generate_serialize_container(ofstream& out, t_type* ttype, if (ttype->is_map()) { string kiter = tmp("kiter"); string viter = tmp("viter"); - indent(out) << "for " << kiter << "," << viter << " in " << prefix << ".items():" << endl; + indent(out) << "for " << kiter << ", " << viter << " in " << prefix << ".items():" << endl; indent_up(); generate_serialize_map_element(out, (t_map*)ttype, kiter, viter); indent_down(); @@ -2456,18 +2551,21 @@ string t_py_generator::type_to_spec_args(t_type* ttype) { } else if (ttype->is_struct() || ttype->is_xception()) { return "(" + type_name(ttype) + ", " + type_name(ttype) + ".thrift_spec)"; } else if (ttype->is_map()) { - return "(" + type_to_enum(((t_map*)ttype)->get_key_type()) + "," - + type_to_spec_args(((t_map*)ttype)->get_key_type()) + "," - + type_to_enum(((t_map*)ttype)->get_val_type()) + "," - + type_to_spec_args(((t_map*)ttype)->get_val_type()) + ")"; + return "(" + type_to_enum(((t_map*)ttype)->get_key_type()) + ", " + + type_to_spec_args(((t_map*)ttype)->get_key_type()) + ", " + + type_to_enum(((t_map*)ttype)->get_val_type()) + ", " + + type_to_spec_args(((t_map*)ttype)->get_val_type()) + ", " + + (is_immutable(ttype) ? "True" : "False") + ")"; } else if (ttype->is_set()) { - return "(" + type_to_enum(((t_set*)ttype)->get_elem_type()) + "," - + type_to_spec_args(((t_set*)ttype)->get_elem_type()) + ")"; + return "(" + type_to_enum(((t_set*)ttype)->get_elem_type()) + ", " + + type_to_spec_args(((t_set*)ttype)->get_elem_type()) + ", " + + (is_immutable(ttype) ? "True" : "False") + ")"; } else if (ttype->is_list()) { - return "(" + type_to_enum(((t_list*)ttype)->get_elem_type()) + "," - + type_to_spec_args(((t_list*)ttype)->get_elem_type()) + ")"; + return "(" + type_to_enum(((t_list*)ttype)->get_elem_type()) + ", " + + type_to_spec_args(((t_list*)ttype)->get_elem_type()) + ", " + + (is_immutable(ttype) ? "True" : "False") + ")"; } throw "INVALID TYPE IN type_to_spec_args: " + ttype->get_name(); @@ -2484,6 +2582,7 @@ THRIFT_REGISTER_GENERATOR( " slots: Generate code using slots for instance members.\n" " dynamic: Generate dynamic code, less code generated but slower.\n" " dynbase=CLS Derive generated classes from class CLS instead of TBase.\n" + " dynfrozen=CLS Derive generated immutable classes from class CLS instead of TFrozenBase.\n" " dynexc=CLS Derive generated exceptions from CLS instead of TExceptionBase.\n" " dynimport='from foo.bar import CLS'\n" " Add an import line to generated code to find the dynbase class.\n") http://git-wip-us.apache.org/repos/asf/thrift/blob/e841b3da/lib/py/src/Thrift.py ---------------------------------------------------------------------- diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py index 9890af7..cbb9184 100644 --- a/lib/py/src/Thrift.py +++ b/lib/py/src/Thrift.py @@ -168,3 +168,23 @@ class TApplicationException(TException): oprot.writeFieldEnd() oprot.writeFieldStop() oprot.writeStructEnd() + + +class TFrozenDict(dict): + """A dictionary that is "frozen" like a frozenset""" + + def __init__(self, *args, **kwargs): + super(TFrozenDict, self).__init__(*args, **kwargs) + # Sort the items so they will be in a consistent order. + # XOR in the hash of the class so we don't collide with + # the hash of a list of tuples. + self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items()))) + + def __setitem__(self, *args): + raise TypeError("Can't modify frozen TFreezableDict") + + def __delitem__(self, *args): + raise TypeError("Can't modify frozen TFreezableDict") + + def __hash__(self): + return self.__hashval http://git-wip-us.apache.org/repos/asf/thrift/blob/e841b3da/lib/py/src/protocol/TBase.py ---------------------------------------------------------------------- diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py index 118a679..4f71e11 100644 --- a/lib/py/src/protocol/TBase.py +++ b/lib/py/src/protocol/TBase.py @@ -27,7 +27,7 @@ except: class TBase(object): - __slots__ = [] + __slots__ = () def __repr__(self): L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__] @@ -68,4 +68,27 @@ class TBase(object): class TExceptionBase(TBase, Exception): - __slots__ = [] + pass + + +class TFrozenBase(TBase): + def __setitem__(self, *args): + raise TypeError("Can't modify frozen struct") + + def __delitem__(self, *args): + raise TypeError("Can't modify frozen struct") + + def __hash__(self, *args): + return hash(self.__class__) ^ hash(self.__slots__) + + @classmethod + def read(cls, iprot): + if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + isinstance(iprot.trans, TTransport.CReadableTransport) and + cls.thrift_spec is not None and + fastbinary is not None): + self = cls() + return fastbinary.decode_binary(None, + iprot.trans, + (self.__class__, self.thrift_spec)) + return iprot.readStruct(cls, cls.thrift_spec, True) http://git-wip-us.apache.org/repos/asf/thrift/blob/e841b3da/lib/py/src/protocol/TProtocol.py ---------------------------------------------------------------------- diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index ca22c48..1d703e3 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -17,7 +17,7 @@ # under the License. # -from thrift.Thrift import TException, TType +from thrift.Thrift import TException, TType, TFrozenDict import six from ..compat import binary_to_str, str_to_binary @@ -108,9 +108,6 @@ class TProtocolBase: def writeBinary(self, str_val): pass - def writeBinary(self, str_val): - return self.writeString(str_val) - def readMessageBegin(self): pass @@ -171,9 +168,6 @@ class TProtocolBase: def readBinary(self): pass - def readBinary(self): - return self.readString() - def skip(self, ttype): if ttype == TType.STOP: return @@ -264,6 +258,7 @@ class TProtocolBase: def readContainerList(self, spec): results = [] ttype, tspec = spec[0], spec[1] + is_immutable = spec[2] r_handler = self._ttype_handlers(ttype, spec)[0] reader = getattr(self, r_handler) (list_type, list_len) = self.readListBegin() @@ -279,11 +274,12 @@ class TProtocolBase: val = val_reader(tspec) results.append(val) self.readListEnd() - return results + return tuple(results) if is_immutable else results def readContainerSet(self, spec): results = set() ttype, tspec = spec[0], spec[1] + is_immutable = spec[2] r_handler = self._ttype_handlers(ttype, spec)[0] reader = getattr(self, r_handler) (set_type, set_len) = self.readSetBegin() @@ -297,7 +293,7 @@ class TProtocolBase: for idx in range(set_len): results.add(val_reader(tspec)) self.readSetEnd() - return results + return frozenset(results) if is_immutable else results def readContainerStruct(self, spec): (obj_class, obj_spec) = spec @@ -309,6 +305,7 @@ class TProtocolBase: results = dict() key_ttype, key_spec = spec[0], spec[1] val_ttype, val_spec = spec[2], spec[3] + is_immutable = spec[4] (map_ktype, map_vtype, map_len) = self.readMapBegin() # TODO: compare types we just decoded with thrift_spec and # abort/skip if types disagree @@ -328,9 +325,11 @@ class TProtocolBase: # i.e. this fails: d=dict(); d[[0,1]] = 2 results[k_val] = v_val self.readMapEnd() - return results + return TFrozenDict(results) if is_immutable else results - def readStruct(self, obj, thrift_spec): + def readStruct(self, obj, thrift_spec, is_immutable=False): + if is_immutable: + fields = {} self.readStructBegin() while True: (fname, ftype, fid) = self.readFieldBegin() @@ -345,11 +344,16 @@ class TProtocolBase: fname = field[2] fspec = field[3] val = self.readFieldByTType(ftype, fspec) - setattr(obj, fname, val) + if is_immutable: + fields[fname] = val + else: + setattr(obj, fname, val) else: self.skip(ftype) self.readFieldEnd() self.readStructEnd() + if is_immutable: + return obj(**fields) def writeContainerStruct(self, val, spec): val.write(self) http://git-wip-us.apache.org/repos/asf/thrift/blob/e841b3da/lib/py/src/protocol/fastbinary.c ---------------------------------------------------------------------- diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c index 93c4911..a17019b 100644 --- a/lib/py/src/protocol/fastbinary.c +++ b/lib/py/src/protocol/fastbinary.c @@ -124,11 +124,6 @@ typedef enum TType { #define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) #define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) -// Py_ssize_t was not defined before Python 2.5 -#if (PY_VERSION_HEX < 0x02050000) -typedef int Py_ssize_t; -#endif - /** * A cache of the spec_args for a set or list, * so we don't have to keep calling PyTuple_GET_ITEM. @@ -136,6 +131,7 @@ typedef int Py_ssize_t; typedef struct { TType element_type; PyObject* typeargs; + bool immutable; } SetListTypeArgs; /** @@ -147,6 +143,7 @@ typedef struct { TType vtag; PyObject* ktypeargs; PyObject* vtypeargs; + bool immutable; } MapTypeArgs; /** @@ -156,6 +153,7 @@ typedef struct { typedef struct { PyObject* klass; PyObject* spec; + bool immutable; } StructTypeArgs; /** @@ -233,8 +231,8 @@ parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { static bool parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { - if (PyTuple_Size(typeargs) != 2) { - PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); + if (PyTuple_Size(typeargs) != 3) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 3 for list/set type args"); return false; } @@ -245,13 +243,15 @@ parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 2); + return true; } static bool parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { - if (PyTuple_Size(typeargs) != 4) { - PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); + if (PyTuple_Size(typeargs) != 5) { + PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for typeargs to map"); return false; } @@ -267,6 +267,7 @@ parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 4); return true; } @@ -289,7 +290,7 @@ parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { // i'd like to use ParseArgs here, but it seems to be a bottleneck. if (PyTuple_Size(spec_tuple) != 5) { - PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); + PyErr_Format(PyExc_TypeError, "expecting 5 arguments for spec tuple but got %d", PyTuple_Size(spec_tuple)); return false; } @@ -885,11 +886,21 @@ skip(DecodeBuffer* input, TType type) { static PyObject* decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); -static bool -decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { +static PyObject* +decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq) { int spec_seq_len = PyTuple_Size(spec_seq); + bool immutable = output == Py_None; + PyObject* kwargs = NULL; if (spec_seq_len == -1) { - return false; + return NULL; + } + + if (immutable) { + kwargs = PyDict_New(); + if (!kwargs) { + PyErr_SetString(PyExc_TypeError, "failed to prepare kwargument storage"); + return NULL; + } } while (true) { @@ -901,14 +912,14 @@ decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { type = readByte(input); if (type == -1) { - return false; + goto error; } if (type == T_STOP) { break; } tag = readI16(input); if (INT_CONV_ERROR_OCCURRED(tag)) { - return false; + goto error; } if (tag >= 0 && tag < spec_seq_len) { item_spec = PyTuple_GET_ITEM(spec_seq, tag); @@ -918,19 +929,19 @@ decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { if (item_spec == Py_None) { if (!skip(input, type)) { - return false; + goto error; } else { continue; } } if (!parse_struct_item_spec(&parsedspec, item_spec)) { - return false; + goto error; } if (parsedspec.type != type) { if (!skip(input, type)) { PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); - return false; + goto error; } else { continue; } @@ -938,16 +949,34 @@ decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); if (fieldval == NULL) { - return false; + goto error; } - if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { + if ((immutable && PyDict_SetItem(kwargs, parsedspec.attrname, fieldval) == -1) + || (!immutable && PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1)) { Py_DECREF(fieldval); - return false; + goto error; } Py_DECREF(fieldval); } - return true; + if (immutable) { + PyObject* args = PyTuple_New(0); + PyObject* ret = NULL; + if (!args) { + PyErr_SetString(PyExc_TypeError, "failed to prepare argument storage"); + goto error; + } + ret = PyObject_Call(klass, args, kwargs); + Py_DECREF(kwargs); + Py_DECREF(args); + return ret; + } + Py_INCREF(output); + return output; + + error: + Py_XDECREF(kwargs); + return NULL; } @@ -1033,6 +1062,7 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { int32_t len; PyObject* ret = NULL; int i; + bool use_tuple = false; if (!parse_set_list_args(&parsedargs, typeargs)) { return NULL; @@ -1047,7 +1077,8 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { return NULL; } - ret = PyList_New(len); + use_tuple = type == T_LIST && parsedargs.immutable; + ret = use_tuple ? PyTuple_New(len) : PyList_New(len); if (!ret) { return NULL; } @@ -1058,20 +1089,18 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { Py_DECREF(ret); return NULL; } - PyList_SET_ITEM(ret, i, item); + if (use_tuple) { + PyTuple_SET_ITEM(ret, i, item); + } else { + PyList_SET_ITEM(ret, i, item); + } } // TODO(dreiss): Consider biting the bullet and making two separate cases // for list and set, avoiding this post facto conversion. if (type == T_SET) { PyObject* setret; -#if (PY_VERSION_HEX < 0x02050000) - // hack needed for older versions - setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); -#else - // official version - setret = PySet_New(ret); -#endif + setret = parsedargs.immutable ? PyFrozenSet_New(ret) : PySet_New(ret); Py_DECREF(ret); return setret; } @@ -1131,6 +1160,22 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { goto error; } + if (parsedargs.immutable) { + PyObject* thrift = PyImport_ImportModule("thrift.Thrift"); + PyObject* cls = NULL; + PyObject* arg = NULL; + if (!thrift) { + goto error; + } + cls = PyObject_GetAttrString(thrift, "TFrozenDict"); + if (!cls) { + goto error; + } + arg = PyTuple_New(1); + PyTuple_SET_ITEM(arg, 0, ret); + return PyObject_CallObject(cls, arg); + } + return ret; error: @@ -1140,22 +1185,12 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { case T_STRUCT: { StructTypeArgs parsedargs; - PyObject* ret; + PyObject* ret; if (!parse_struct_args(&parsedargs, typeargs)) { return NULL; } - ret = PyObject_CallObject(parsedargs.klass, NULL); - if (!ret) { - return NULL; - } - - if (!decode_struct(input, ret, parsedargs.spec)) { - Py_DECREF(ret); - return NULL; - } - - return ret; + return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec); } case T_STOP: @@ -1179,6 +1214,7 @@ decode_binary(PyObject *self, PyObject *args) { PyObject* typeargs = NULL; StructTypeArgs parsedargs; DecodeBuffer input = {0, 0}; + PyObject* ret = NULL; if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { return NULL; @@ -1192,14 +1228,9 @@ decode_binary(PyObject *self, PyObject *args) { return NULL; } - if (!decode_struct(&input, output_obj, parsedargs.spec)) { - free_decodebuf(&input); - return NULL; - } - + ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec); free_decodebuf(&input); - - Py_RETURN_NONE; + return ret; } /* ====== END READING FUNCTIONS ====== */ http://git-wip-us.apache.org/repos/asf/thrift/blob/e841b3da/test/DebugProtoTest.thrift ---------------------------------------------------------------------- diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift index 50ae4c1..e7119c4 100644 --- a/test/DebugProtoTest.thrift +++ b/test/DebugProtoTest.thrift @@ -72,11 +72,15 @@ struct Backwards { } struct Empty { -} +} ( + python.immutable = "", +) struct Wrapper { 1: Empty foo -} +} ( + python.immutable = "", +) struct RandomStuff { 1: i32 a, @@ -153,9 +157,9 @@ struct CompactProtoTestStruct { 42: map<byte, binary> byte_binary_map; 43: map<byte, bool> byte_boolean_map; // collections as keys - 44: map<list<byte>, byte> list_byte_map; - 45: map<set<byte>, byte> set_byte_map; - 46: map<map<byte,byte>, byte> map_byte_map; + 44: map<list<byte> (python.immutable = ""), byte> list_byte_map; + 45: map<set<byte> (python.immutable = ""), byte> set_byte_map; + 46: map<map<byte,byte> (python.immutable = ""), byte> map_byte_map; // collections as values 47: map<byte, map<byte,byte>> byte_map_map; 48: map<byte, set<byte>> byte_set_map; http://git-wip-us.apache.org/repos/asf/thrift/blob/e841b3da/test/ThriftTest.thrift ---------------------------------------------------------------------- diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift index 414f9a5..a58ed97 100644 --- a/test/ThriftTest.thrift +++ b/test/ThriftTest.thrift @@ -105,12 +105,13 @@ struct Insanity { 1: map<Numberz, UserId> userMap, 2: list<Xtruct> xtructs -} +} (python.immutable= "") struct CrazyNesting { 1: string string_field, 2: optional set<Insanity> set_field, - 3: required list< map<set<i32>,map<i32,set<list<map<Insanity,string>>>>>> list_field, + 3: required list<map<set<i32> (python.immutable = ""), + map<i32,set<list<map<Insanity,string>(python.immutable = "")> (python.immutable = "")>>>> list_field, 4: binary binary_field } http://git-wip-us.apache.org/repos/asf/thrift/blob/e841b3da/test/py/RunClientServer.py ---------------------------------------------------------------------- diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py index fa2a264..f084a41 100755 --- a/test/py/RunClientServer.py +++ b/test/py/RunClientServer.py @@ -37,6 +37,7 @@ DEFAULT_LIBDIR_GLOB = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*') DEFAULT_LIBDIR_PY3 = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib') SCRIPTS = [ + 'TestFrozen.py', 'TSimpleJSONProtocolTest.py', 'SerializationTest.py', 'TestEof.py', http://git-wip-us.apache.org/repos/asf/thrift/blob/e841b3da/test/py/TestFrozen.py ---------------------------------------------------------------------- diff --git a/test/py/TestFrozen.py b/test/py/TestFrozen.py new file mode 100755 index 0000000..76750ad --- /dev/null +++ b/test/py/TestFrozen.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from DebugProtoTest.ttypes import CompactProtoTestStruct, Empty, Wrapper +from thrift.Thrift import TFrozenDict +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol +import collections +import unittest + + +class TestFrozenBase(unittest.TestCase): + def _roundtrip(self, src, dst): + otrans = TTransport.TMemoryBuffer() + optoro = self.protocol(otrans) + src.write(optoro) + itrans = TTransport.TMemoryBuffer(otrans.getvalue()) + iproto = self.protocol(itrans) + return dst.read(iproto) or dst + + def test_dict_is_hashable_only_after_frozen(self): + d0 = {} + self.assertFalse(isinstance(d0, collections.Hashable)) + d1 = TFrozenDict(d0) + self.assertTrue(isinstance(d1, collections.Hashable)) + + def test_struct_with_collection_fields(self): + pass + + def test_set(self): + """Test that annotated set field can be serialized and deserialized""" + x = CompactProtoTestStruct(set_byte_map={ + frozenset([42, 100, -100]): 99, + frozenset([0]): 100, + frozenset([]): 0, + }) + x2 = self._roundtrip(x, CompactProtoTestStruct()) + self.assertEqual(x2.set_byte_map[frozenset([42, 100, -100])], 99) + self.assertEqual(x2.set_byte_map[frozenset([0])], 100) + self.assertEqual(x2.set_byte_map[frozenset([])], 0) + + def test_map(self): + """Test that annotated map field can be serialized and deserialized""" + x = CompactProtoTestStruct(map_byte_map={ + TFrozenDict({42: 42, 100: -100}): 99, + TFrozenDict({0: 0}): 100, + TFrozenDict({}): 0, + }) + x2 = self._roundtrip(x, CompactProtoTestStruct()) + self.assertEqual(x2.map_byte_map[TFrozenDict({42: 42, 100: -100})], 99) + self.assertEqual(x2.map_byte_map[TFrozenDict({0: 0})], 100) + self.assertEqual(x2.map_byte_map[TFrozenDict({})], 0) + + def test_list(self): + """Test that annotated list field can be serialized and deserialized""" + x = CompactProtoTestStruct(list_byte_map={ + (42, 100, -100): 99, + (0,): 100, + (): 0, + }) + x2 = self._roundtrip(x, CompactProtoTestStruct()) + self.assertEqual(x2.list_byte_map[(42, 100, -100)], 99) + self.assertEqual(x2.list_byte_map[(0,)], 100) + self.assertEqual(x2.list_byte_map[()], 0) + + def test_empty_struct(self): + """Test that annotated empty struct can be serialized and deserialized""" + x = CompactProtoTestStruct(empty_struct_field=Empty()) + x2 = self._roundtrip(x, CompactProtoTestStruct()) + self.assertEqual(x2.empty_struct_field, Empty()) + + def test_struct(self): + """Test that annotated struct can be serialized and deserialized""" + x = Wrapper(foo=Empty()) + self.assertEqual(x.foo, Empty()) + x2 = self._roundtrip(x, Wrapper) + self.assertEqual(x2.foo, Empty()) + + +class TestFrozen(TestFrozenBase): + def protocol(self, trans): + return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(trans) + + +class TestFrozenAccelerated(TestFrozenBase): + def protocol(self, trans): + return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(trans) + + +def suite(): + suite = unittest.TestSuite() + loader = unittest.TestLoader() + suite.addTest(loader.loadTestsFromTestCase(TestFrozen)) + suite.addTest(loader.loadTestsFromTestCase(TestFrozenAccelerated)) + return suite + +if __name__ == "__main__": + unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
