Author: Armin Rigo <ar...@tunes.org> Branch: stmgc-c8-dictiter Changeset: r80553:47629bb038b7 Date: 2015-11-05 19:27 +0100 http://bitbucket.org/pypy/pypy/changeset/47629bb038b7/
Log: Iterators on hashtables diff --git a/pypy/module/pypystm/hashtable.py b/pypy/module/pypystm/hashtable.py --- a/pypy/module/pypystm/hashtable.py +++ b/pypy/module/pypystm/hashtable.py @@ -2,6 +2,7 @@ The class pypystm.hashtable, mapping integers to objects. """ +from pypy.interpreter.error import OperationError from pypy.interpreter.baseobjspace import W_Root from pypy.interpreter.typedef import TypeDef from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault @@ -78,6 +79,57 @@ for i in range(count)] return space.newlist(lst_w) + def iterkeys_w(self, space): + return W_HashtableIterKeys(self.h) + + def itervalues_w(self, space): + return W_HashtableIterValues(self.h) + + def iteritems_w(self, space): + return W_HashtableIterItems(self.h) + + +class W_BaseHashtableIter(W_Root): + _immutable_fields_ = ["hiter"] + + def __init__(self, hobj): + self.hiter = hobj.iterentries() + + def descr_iter(self, space): + return self + + def descr_length_hint(self, space): + # xxx overestimate: doesn't remove the items already yielded, + # and uses the faster len_estimate() + return space.wrap(self.hiter.hashtable.len_estimate()) + + def next_entry(self, space): + try: + return self.hiter.next() + except StopIteration: + raise OperationError(space.w_StopIteration, space.w_None) + + def _cleanup_(self): + raise Exception("seeing a prebuilt %r object" % ( + self.__class__,)) + +class W_HashtableIterKeys(W_BaseHashtableIter): + def descr_next(self, space): + entry = self.next_entry(space) + return space.wrap(intmask(entry.index)) + +class W_HashtableIterValues(W_BaseHashtableIter): + def descr_next(self, space): + entry = self.next_entry(space) + return cast_gcref_to_instance(W_Root, entry.object) + +class W_HashtableIterItems(W_BaseHashtableIter): + def descr_next(self, space): + entry = self.next_entry(space) + return space.newtuple([ + space.wrap(intmask(entry.index)), + cast_gcref_to_instance(W_Root, entry.object)]) + def W_Hashtable___new__(space, 
w_subtype): r = space.allocate_instance(W_Hashtable, w_subtype) @@ -98,4 +150,30 @@ keys = interp2app(W_Hashtable.keys_w), values = interp2app(W_Hashtable.values_w), items = interp2app(W_Hashtable.items_w), + + __iter__ = interp2app(W_Hashtable.iterkeys_w), + iterkeys = interp2app(W_Hashtable.iterkeys_w), + itervalues = interp2app(W_Hashtable.itervalues_w), + iteritems = interp2app(W_Hashtable.iteritems_w), ) + +W_HashtableIterKeys.typedef = TypeDef( + "hashtable_iterkeys", + __iter__ = interp2app(W_HashtableIterKeys.descr_iter), + next = interp2app(W_HashtableIterKeys.descr_next), + __length_hint__ = interp2app(W_HashtableIterKeys.descr_length_hint), + ) + +W_HashtableIterValues.typedef = TypeDef( + "hashtable_itervalues", + __iter__ = interp2app(W_HashtableIterValues.descr_iter), + next = interp2app(W_HashtableIterValues.descr_next), + __length_hint__ = interp2app(W_HashtableIterValues.descr_length_hint), + ) + +W_HashtableIterItems.typedef = TypeDef( + "hashtable_iteritems", + __iter__ = interp2app(W_HashtableIterItems.descr_iter), + next = interp2app(W_HashtableIterItems.descr_next), + __length_hint__ = interp2app(W_HashtableIterItems.descr_length_hint), + ) diff --git a/pypy/module/pypystm/test/test_hashtable.py b/pypy/module/pypystm/test/test_hashtable.py --- a/pypy/module/pypystm/test/test_hashtable.py +++ b/pypy/module/pypystm/test/test_hashtable.py @@ -55,3 +55,13 @@ assert sorted(h.keys()) == [42, 43] assert sorted(h.values()) == ["bar", "foo"] assert sorted(h.items()) == [(42, "foo"), (43, "bar")] + + def test_iterator(self): + import pypystm + h = pypystm.hashtable() + h[42] = "foo" + h[43] = "bar" + assert sorted(h) == [42, 43] + assert sorted(h.iterkeys()) == [42, 43] + assert sorted(h.itervalues()) == ["bar", "foo"] + assert sorted(h.iteritems()) == [(42, "foo"), (43, "bar")] diff --git a/rpython/rlib/rstm.py b/rpython/rlib/rstm.py --- a/rpython/rlib/rstm.py +++ b/rpython/rlib/rstm.py @@ -223,11 +223,13 @@ # 
____________________________________________________________ _STM_HASHTABLE_P = rffi.COpaquePtr('stm_hashtable_t') +_STM_HASHTABLE_TABLE_P = rffi.COpaquePtr('stm_hashtable_table_t') _STM_HASHTABLE_ENTRY = lltype.GcStruct('HASHTABLE_ENTRY', ('index', lltype.Unsigned), ('object', llmemory.GCREF)) _STM_HASHTABLE_ENTRY_P = lltype.Ptr(_STM_HASHTABLE_ENTRY) +_STM_HASHTABLE_ENTRY_PP = rffi.CArrayPtr(_STM_HASHTABLE_ENTRY_P) _STM_HASHTABLE_ENTRY_ARRAY = lltype.GcArray(_STM_HASHTABLE_ENTRY_P) @dont_look_inside @@ -245,6 +247,11 @@ lltype.nullptr(_STM_HASHTABLE_ENTRY_ARRAY)) @dont_look_inside +def _ll_hashtable_len_estimate(h): + return llop.stm_hashtable_length_upper_bound(lltype.Signed, + h.ll_raw_hashtable) + +@dont_look_inside def _ll_hashtable_list(h): upper_bound = llop.stm_hashtable_length_upper_bound(lltype.Signed, h.ll_raw_hashtable) @@ -264,6 +271,27 @@ def _ll_hashtable_writeobj(h, entry, value): llop.stm_hashtable_write_entry(lltype.Void, h, entry, value) +@dont_look_inside +def _ll_hashtable_iterentries(h): + rgc.register_custom_trace_hook(_HASHTABLE_ITER_OBJ, + lambda_hashtable_iter_trace) + table = llop.stm_hashtable_iter(_STM_HASHTABLE_TABLE_P, h.ll_raw_hashtable) + hiter = lltype.malloc(_HASHTABLE_ITER_OBJ) + hiter.hashtable = h # for keepalive + hiter.table = table + hiter.prev = lltype.nullptr(_STM_HASHTABLE_ENTRY_PP.TO) + return hiter + +def _ll_hashiter_next(hiter): + entrypp = llop.stm_hashtable_iter_next(_STM_HASHTABLE_ENTRY_PP, + hiter.hashtable, + hiter.table, + hiter.prev) + if not entrypp: + raise StopIteration + hiter.prev = entrypp + return entrypp[0] + _HASHTABLE_OBJ = lltype.GcStruct('HASHTABLE_OBJ', ('ll_raw_hashtable', _STM_HASHTABLE_P), hints={'immutable': True}, @@ -271,11 +299,19 @@ adtmeths={'get': _ll_hashtable_get, 'set': _ll_hashtable_set, 'len': _ll_hashtable_len, + 'len_estimate': _ll_hashtable_len_estimate, 'list': _ll_hashtable_list, 'lookup': _ll_hashtable_lookup, - 'writeobj': _ll_hashtable_writeobj}) + 'writeobj': 
_ll_hashtable_writeobj, + 'iterentries': _ll_hashtable_iterentries}) NULL_HASHTABLE = lltype.nullptr(_HASHTABLE_OBJ) +_HASHTABLE_ITER_OBJ = lltype.GcStruct('HASHTABLE_ITER_OBJ', + ('hashtable', lltype.Ptr(_HASHTABLE_OBJ)), + ('table', _STM_HASHTABLE_TABLE_P), + ('prev', _STM_HASHTABLE_ENTRY_PP), + adtmeths={'next': _ll_hashiter_next}) + def _ll_hashtable_trace(gc, obj, callback, arg): from rpython.memory.gctransform.stmframework import get_visit_function visit_fn = get_visit_function(callback, arg) @@ -288,6 +324,15 @@ llop.stm_hashtable_free(lltype.Void, h.ll_raw_hashtable) lambda_hashtable_finlz = lambda: _ll_hashtable_finalizer +def _ll_hashtable_iter_trace(gc, obj, callback, arg): + from rpython.memory.gctransform.stmframework import get_visit_function + addr = obj + llmemory.offsetof(_HASHTABLE_ITER_OBJ, 'hashtable') + gc._trace_callback(callback, arg, addr) + visit_fn = get_visit_function(callback, arg) + addr = obj + llmemory.offsetof(_HASHTABLE_ITER_OBJ, 'table') + llop.stm_hashtable_iter_tracefn(lltype.Void, addr.address[0], visit_fn) +lambda_hashtable_iter_trace = lambda: _ll_hashtable_iter_trace + _false = CDefinedIntSymbolic('0', default=0) # remains in the C code @dont_look_inside @@ -344,6 +389,9 @@ items = [self.lookup(key) for key, v in self._content.items() if v.object != NULL_GCREF] return len(items) + def len_estimate(self): + return len(self._content) + def list(self): items = [self.lookup(key) for key, v in self._content.items() if v.object != NULL_GCREF] count = len(items) @@ -359,6 +407,9 @@ assert isinstance(entry, EntryObjectForTest) self.set(entry.key, nvalue) + def iterentries(self): + return IterEntriesForTest(self, (v for v in self._content.itervalues() if v.object != NULL_GCREF)) + class EntryObjectForTest(object): def __init__(self, hashtable, key): self.hashtable = hashtable @@ -374,6 +425,14 @@ object = property(_getobj, _setobj) +class IterEntriesForTest(object): + def __init__(self, hashtable, iterator): + self.hashtable = hashtable + self.iterator = iterator + + def 
next(self): + return next(self.iterator) + # ____________________________________________________________ _STM_QUEUE_P = rffi.COpaquePtr('stm_queue_t') diff --git a/rpython/translator/stm/funcgen.py b/rpython/translator/stm/funcgen.py --- a/rpython/translator/stm/funcgen.py +++ b/rpython/translator/stm/funcgen.py @@ -398,9 +398,28 @@ arg0 = funcgen.expr(op.args[0]) arg1 = funcgen.expr(op.args[1]) arg2 = funcgen.expr(op.args[2]) - return ('stm_hashtable_tracefn(%s, (stm_hashtable_t *)%s, ' + return ('stm_hashtable_tracefn(%s, (stm_hashtable_t *)%s,' ' (void(*)(object_t**))%s);' % (arg0, arg1, arg2)) +def stm_hashtable_iter(funcgen, op): + arg0 = funcgen.expr(op.args[0]) + result = funcgen.expr(op.result) + return '%s = stm_hashtable_iter(%s);' % (result, arg0) + +def stm_hashtable_iter_next(funcgen, op): + arg0 = funcgen.expr(op.args[0]) + arg1 = funcgen.expr(op.args[1]) + arg2 = funcgen.expr(op.args[2]) + result = funcgen.expr(op.result) + return ('%s = stm_hashtable_iter_next(%s, %s, %s);' % + (result, arg0, arg1, arg2)) + +def stm_hashtable_iter_tracefn(funcgen, op): + arg0 = funcgen.expr(op.args[0]) + arg1 = funcgen.expr(op.args[1]) + return ('stm_hashtable_iter_tracefn((stm_hashtable_table_t *)%s,' + ' (void(*)(object_t**))%s);' % (arg0, arg1)) + def stm_queue_create(funcgen, op): result = funcgen.expr(op.result) return '%s = stm_queue_create();' % (result,) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit