Repository: thrift Updated Branches: refs/heads/master 051ed3c80 -> 85650612e
THRIFT-3525 py:dynamic fails to handle binary list/set/map element This closes #775 Project: http://git-wip-us.apache.org/repos/asf/thrift/repo Commit: http://git-wip-us.apache.org/repos/asf/thrift/commit/299255af Tree: http://git-wip-us.apache.org/repos/asf/thrift/tree/299255af Diff: http://git-wip-us.apache.org/repos/asf/thrift/diff/299255af Branch: refs/heads/master Commit: 299255afbb1f0ba302d3e29a76e20c0f5984f31e Parents: 1b4ebc3 Author: Nobuaki Sukegawa <[email protected]> Authored: Wed Jan 6 14:52:50 2016 +0900 Committer: Nobuaki Sukegawa <[email protected]> Committed: Mon Jan 11 11:34:20 2016 +0900 ---------------------------------------------------------------------- lib/py/src/protocol/TProtocol.py | 162 ++++++++++++---------------------- lib/py/src/protocol/fastbinary.c | 2 +- 2 files changed, 55 insertions(+), 109 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/thrift/blob/299255af/lib/py/src/protocol/TProtocol.py ---------------------------------------------------------------------- diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index be2fcea..9679ba0 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -18,10 +18,13 @@ # from thrift.Thrift import TException, TType, TFrozenDict -import six - from ..compat import binary_to_str, str_to_binary +import six +import sys +from itertools import islice +from six.moves import zip + class TProtocolException(TException): """Custom Protocol Exception class""" @@ -239,61 +242,38 @@ class TProtocolBase(object): raise TProtocolException(type=TProtocolException.INVALID_DATA, message='Invalid binary field type %d' % ttype) return ('readBinary', 'writeBinary', False) - return self._TTYPE_HANDLERS[ttype] + return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False) - def readFieldByTType(self, ttype, spec): - try: - (r_handler, w_handler, is_container) = self._ttype_handlers(ttype, spec) - except IndexError: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid field type %d' % (ttype)) - if r_handler is None: + def _read_by_ttype(self, ttype, spec, espec): + reader_name, _, is_container = self._ttype_handlers(ttype, spec) + if reader_name is None: raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid field type %d' % (ttype)) - reader = getattr(self, r_handler) - if not is_container: - return reader() - return reader(spec) + message='Invalid type %d' % (ttype)) + reader_func = getattr(self, reader_name) + read = (lambda: reader_func(espec)) if is_container else reader_func + while True: + yield read() + + def readFieldByTType(self, ttype, spec): + return self._read_by_ttype(ttype, spec, spec).next() def readContainerList(self, spec): - results = [] - ttype, tspec = spec[0], spec[1] - is_immutable = spec[2] - r_handler = self._ttype_handlers(ttype, spec)[0] - reader = getattr(self, r_handler) + ttype, tspec, is_immutable = spec (list_type, list_len) = self.readListBegin() - if tspec is None: - # list values are simple types - for idx in range(list_len): - results.append(reader()) - else: - # this is like an inlined readFieldByTType - container_reader = self._ttype_handlers(list_type, tspec)[0] - val_reader = getattr(self, container_reader) - for idx in range(list_len): - val = val_reader(tspec) - results.append(val) + # TODO: compare types we just decoded with thrift_spec + elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len) + results = (tuple if is_immutable else list)(elems) self.readListEnd() - return tuple(results) if is_immutable else results + return results def readContainerSet(self, spec): - results = set() - ttype, tspec = spec[0], spec[1] - is_immutable = spec[2] - r_handler = self._ttype_handlers(ttype, spec)[0] - reader = getattr(self, r_handler) + ttype, tspec, is_immutable = spec (set_type, set_len) = self.readSetBegin() - if tspec is None: - # set members are simple types - for idx in range(set_len): - results.add(reader()) - else: - container_reader = self._ttype_handlers(set_type, tspec)[0] - val_reader = getattr(self, container_reader) - for idx in range(set_len): - results.add(val_reader(tspec)) + # TODO: compare types we just decoded with thrift_spec + elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len) + results = (frozenset if is_immutable else set)(elems) self.readSetEnd() - return frozenset(results) if is_immutable else results + return results def readContainerStruct(self, spec): (obj_class, obj_spec) = spec @@ -302,30 +282,16 @@ class TProtocolBase(object): return obj def readContainerMap(self, spec): - results = dict() - key_ttype, key_spec = spec[0], spec[1] - val_ttype, val_spec = spec[2], spec[3] - is_immutable = spec[4] + ktype, kspec, vtype, vspec, is_immutable = spec (map_ktype, map_vtype, map_len) = self.readMapBegin() # TODO: compare types we just decoded with thrift_spec and # abort/skip if types disagree - key_reader = getattr(self, self._ttype_handlers(key_ttype, key_spec)[0]) - val_reader = getattr(self, self._ttype_handlers(val_ttype, val_spec)[0]) - # list values are simple types - for idx in range(map_len): - if key_spec is None: - k_val = key_reader() - else: - k_val = self.readFieldByTType(key_ttype, key_spec) - if val_spec is None: - v_val = val_reader() - else: - v_val = self.readFieldByTType(val_ttype, val_spec) - # this raises a TypeError with unhashable keys types - # i.e. this fails: d=dict(); d[[0,1]] = 2 - results[k_val] = v_val + keys = self._read_by_ttype(ktype, spec, kspec) + vals = self._read_by_ttype(vtype, spec, vspec) + keyvals = islice(zip(keys, vals), map_len) + results = (TFrozenDict if is_immutable else dict)(keyvals) self.readMapEnd() - return TFrozenDict(results) if is_immutable else results + return results def readStruct(self, obj, thrift_spec, is_immutable=False): if is_immutable: @@ -359,46 +325,25 @@ class TProtocolBase(object): val.write(self) def writeContainerList(self, val, spec): - self.writeListBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._ttype_handlers(spec[0], spec) - e_writer = getattr(self, w_handler) - if not is_container: - for elem in val: - e_writer(elem) - else: - for elem in val: - e_writer(elem, spec[1]) + ttype, tspec, _ = spec + self.writeListBegin(ttype, len(val)) + for _ in self._write_by_ttype(ttype, val, spec, tspec): + pass self.writeListEnd() def writeContainerSet(self, val, spec): - self.writeSetBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._ttype_handlers(spec[0], spec) - e_writer = getattr(self, w_handler) - if not is_container: - for elem in val: - e_writer(elem) - else: - for elem in val: - e_writer(elem, spec[1]) + ttype, tspec, _ = spec + self.writeSetBegin(ttype, len(val)) + for _ in self._write_by_ttype(ttype, val, spec, tspec): + pass self.writeSetEnd() def writeContainerMap(self, val, spec): - k_type = spec[0] - v_type = spec[2] - ignore, ktype_name, k_is_container = self._ttype_handlers(k_type, spec) - ignore, vtype_name, v_is_container = self._ttype_handlers(v_type, spec) - k_writer = getattr(self, ktype_name) - v_writer = getattr(self, vtype_name) - self.writeMapBegin(k_type, v_type, len(val)) - for m_key, m_val in six.iteritems(val): - if not k_is_container: - k_writer(m_key) - else: - k_writer(m_key, spec[1]) - if not v_is_container: - v_writer(m_val) - else: - v_writer(m_val, spec[3]) + ktype, kspec, vtype, vspec, _ = spec + self.writeMapBegin(ktype, vtype, len(val)) + for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec), + self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)): + pass self.writeMapEnd() def writeStruct(self, obj, thrift_spec): @@ -414,20 +359,21 @@ class TProtocolBase(object): fid = field[0] ftype = field[1] fspec = field[3] - # get the writer method for this value self.writeFieldBegin(fname, ftype, fid) self.writeFieldByTType(ftype, val, fspec) self.writeFieldEnd() self.writeFieldStop() self.writeStructEnd() + def _write_by_ttype(self, ttype, vals, spec, espec): + _, writer_name, is_container = self._ttype_handlers(ttype, spec) + writer_func = getattr(self, writer_name) + write = (lambda v: writer_func(v, espec)) if is_container else writer_func + for v in vals: + yield write(v) + def writeFieldByTType(self, ttype, val, spec): - r_handler, w_handler, is_container = self._ttype_handlers(ttype, spec) - writer = getattr(self, w_handler) - if is_container: - writer(val, spec) - else: - writer(val) + self._write_by_ttype(ttype, [val], spec, spec).next() def checkIntegerLimits(i, bits): http://git-wip-us.apache.org/repos/asf/thrift/blob/299255af/lib/py/src/protocol/fastbinary.c ---------------------------------------------------------------------- diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c index 714fb13..eaecb8c 100644 --- a/lib/py/src/protocol/fastbinary.c +++ b/lib/py/src/protocol/fastbinary.c @@ -947,7 +947,7 @@ decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* } if (parsedspec.type != type) { if (!skip(input, type)) { - PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); + PyErr_Format(PyExc_TypeError, "struct field had wrong type: expected %d but got %d", parsedspec.type, type); goto error; } else { continue;
