Author: Maciej Fijalkowski <[email protected]>
Branch: rdict-experiments-3
Changeset: r67244:1e3bab885783
Date: 2013-10-09 15:46 +0200
http://bitbucket.org/pypy/pypy/changeset/1e3bab885783/

Log:    (fijal, arigo) whack whack whack until we make the first test pass

diff --git a/rpython/rtyper/lltypesystem/rdict.py b/rpython/rtyper/lltypesystem/rdict.py
--- a/rpython/rtyper/lltypesystem/rdict.py
+++ b/rpython/rtyper/lltypesystem/rdict.py
@@ -1,43 +1,153 @@
+import sys
 from rpython.tool.pairtype import pairtype
 from rpython.flowspace.model import Constant
 from rpython.rtyper.rdict import AbstractDictRepr, AbstractDictIteratorRepr
-from rpython.rtyper.lltypesystem import lltype
-from rpython.rlib import objectmodel, jit
+from rpython.rtyper.lltypesystem import lltype, llmemory, rffi
+from rpython.rlib import objectmodel, jit, rgc
 from rpython.rlib.debug import ll_assert
-from rpython.rlib.rarithmetic import r_uint, intmask, LONG_BIT
+from rpython.rlib.rarithmetic import r_uint, intmask
 from rpython.rtyper import rmodel
 from rpython.rtyper.error import TyperError
+from rpython.rtyper.annlowlevel import llhelper
 
 
-HIGHEST_BIT = r_uint(intmask(1 << (LONG_BIT - 1)))
-MASK = r_uint(intmask(HIGHEST_BIT - 1))
-
 # ____________________________________________________________
 #
 #  generic implementation of RPython dictionary, with parametric DICTKEY and
-#  DICTVALUE types.
+#  DICTVALUE types. The basic implementation is a sparse array of indexes
+#  plus a dense array of structs that contain keys and values. struct looks
+#  like that:
 #
-#  XXX for immutable dicts, the array should be inlined and
-#      resize_counter and everused are not needed.
 #
 #    struct dictentry {
 #        DICTKEY key;
+#        DICTVALUE value;
+#        long f_hash;        # (optional) key hash, if hard to recompute
 #        bool f_valid;      # (optional) the entry is filled
-#        bool f_everused;   # (optional) the entry is or has ever been filled
-#        DICTVALUE value;
-#        int f_hash;        # (optional) key hash, if hard to recompute
 #    }
 #
 #    struct dicttable {
 #        int num_items;
+#        int num_used_items;
 #        int resize_counter;
-#        Array *entries;
+#        {byte, short, int, long} *indexes;
+#        dictentry *entries;
+#        lookup_function; # one of the four possible functions for different
+#                         # size dicts
 #        (Function DICTKEY, DICTKEY -> bool) *fnkeyeq;
 #        (Function DICTKEY -> int) *fnkeyhash;
 #    }
 #
 #
 
+def get_ll_dict(DICTKEY, DICTVALUE, get_custom_eq_hash=None, DICT=None,
+                ll_fasthash_function=None, ll_hash_function=None,
+                ll_eq_function=None, method_cache={},
+                dummykeyobj=None, dummyvalueobj=None):
+    # get the actual DICT type. if DICT is None, it's created, otherwise
+    # forward reference is becoming DICT
+    if DICT is None:
+        DICT = lltype.GcForwardReference()
+    # compute the shape of the DICTENTRY structure
+    entryfields = []
+    entrymeths = {
+        'allocate': lltype.typeMethod(_ll_malloc_entries),
+        'delete': _ll_free_entries,
+        'must_clear_key':   (isinstance(DICTKEY, lltype.Ptr)
+                             and DICTKEY._needsgc()),
+        'must_clear_value': (isinstance(DICTVALUE, lltype.Ptr)
+                             and DICTVALUE._needsgc()),
+        }
+
+    # * the key
+    entryfields.append(("key", DICTKEY))
+
+    # * the state of the entry - trying to encode it as dummy objects
+    if dummykeyobj:
+        # all the state can be encoded in the key
+        entrymeths['dummy_obj'] = dummykeyobj
+        entrymeths['valid'] = ll_valid_from_key
+        entrymeths['mark_deleted'] = ll_mark_deleted_in_key
+        # the key is overwritten by 'dummy' when the entry is deleted
+        entrymeths['must_clear_key'] = False
+
+    elif dummyvalueobj:
+        # all the state can be encoded in the value
+        entrymeths['dummy_obj'] = dummyvalueobj
+        entrymeths['valid'] = ll_valid_from_value
+        entrymeths['mark_deleted'] = ll_mark_deleted_in_value
+        # value is overwritten by 'dummy' when entry is deleted
+        entrymeths['must_clear_value'] = False
+
+    else:
+        # we need a flag to know if the entry was ever used
+        entryfields.append(("f_valid", lltype.Bool))
+        entrymeths['valid'] = ll_valid_from_flag
+        entrymeths['mark_deleted'] = ll_mark_deleted_in_flag
+
+    # * the value
+    entryfields.append(("value", DICTVALUE))
+
+    if ll_fasthash_function is None:
+        entryfields.append(("f_hash", lltype.Signed))
+        entrymeths['hash'] = ll_hash_from_cache
+    else:
+        entrymeths['hash'] = ll_hash_recomputed
+        entrymeths['fasthashfn'] = ll_fasthash_function
+
+    # Build the lltype data structures
+    DICTENTRY = lltype.Struct("dictentry", *entryfields)
+    DICTENTRYARRAY = lltype.GcArray(DICTENTRY,
+                                    adtmeths=entrymeths)
+    LOOKUP_FUNC = lltype.Ptr(lltype.FuncType([lltype.Ptr(DICT), DICTKEY, lltype.Signed, lltype.Signed], lltype.Signed))
+
+
+    fields =          [ ("num_items", lltype.Signed),
+                        ("num_used_items", lltype.Signed),
+                        ("resize_counter", lltype.Signed),
+                        ("indexes", llmemory.GCREF),
+                        ("lookup_function", LOOKUP_FUNC),
+                        ("entries", lltype.Ptr(DICTENTRYARRAY)) ]
+    if get_custom_eq_hash is not None:
+        r_rdict_eqfn, r_rdict_hashfn = get_custom_eq_hash()
+        fields.extend([ ("fnkeyeq", r_rdict_eqfn.lowleveltype),
+                        ("fnkeyhash", r_rdict_hashfn.lowleveltype) ])
+        adtmeths = {
+            'keyhash':        ll_keyhash_custom,
+            'keyeq':          ll_keyeq_custom,
+            'r_rdict_eqfn':   r_rdict_eqfn,
+            'r_rdict_hashfn': r_rdict_hashfn,
+            'paranoia':       True,
+            }
+    else:
+        # figure out which functions must be used to hash and compare
+        ll_keyhash = ll_hash_function
+        ll_keyeq = ll_eq_function
+        ll_keyhash = lltype.staticAdtMethod(ll_keyhash)
+        if ll_keyeq is not None:
+            ll_keyeq = lltype.staticAdtMethod(ll_keyeq)
+        adtmeths = {
+            'keyhash':  ll_keyhash,
+            'keyeq':    ll_keyeq,
+            'paranoia': False,
+            }
+    adtmeths['KEY']   = DICTKEY
+    adtmeths['VALUE'] = DICTVALUE
+    adtmeths['allocate'] = lltype.typeMethod(_ll_malloc_dict)
+    adtmeths['empty_array'] = DICTENTRYARRAY.allocate(0)
+    adtmeths['byte_lookup_function'] = new_lookup_function(LOOKUP_FUNC,
+                                                           T=rffi.UCHAR)
+    adtmeths['short_lookup_function'] = new_lookup_function(LOOKUP_FUNC,
+                                                            T=rffi.USHORT)
+    if IS_64BIT:
+        adtmeths['int_lookup_function'] = new_lookup_function(LOOKUP_FUNC,
+                                                              T=rffi.UINT)
+    adtmeths['long_lookup_function'] = new_lookup_function(LOOKUP_FUNC,
+                                                           T=lltype.Unsigned)
+    DICT.become(lltype.GcStruct("dicttable", adtmeths=adtmeths,
+                                *fields))
+    return DICT
+
 class DictRepr(AbstractDictRepr):
 
     def __init__(self, rtyper, key_repr, value_repr, dictkey, dictvalue,
@@ -73,140 +183,29 @@
         if 'value_repr' not in self.__dict__:
             self.external_value_repr, self.value_repr = self.pickrepr(self._value_repr_computer())
         if isinstance(self.DICT, lltype.GcForwardReference):
-            self.DICTKEY = self.key_repr.lowleveltype
-            self.DICTVALUE = self.value_repr.lowleveltype
-
-            # compute the shape of the DICTENTRY structure
-            entryfields = []
-            entrymeths = {
-                'allocate': lltype.typeMethod(_ll_malloc_entries),
-                'delete': _ll_free_entries,
-                'must_clear_key':   (isinstance(self.DICTKEY, lltype.Ptr)
-                                     and self.DICTKEY._needsgc()),
-                'must_clear_value': (isinstance(self.DICTVALUE, lltype.Ptr)
-                                     and self.DICTVALUE._needsgc()),
-                }
-
-            # * the key
-            entryfields.append(("key", self.DICTKEY))
-
-            # * if NULL is not a valid ll value for the key or the value
-            #   field of the entry, it can be used as a marker for
-            #   never-used entries.  Otherwise, we need an explicit flag.
+            DICTKEY = self.key_repr.lowleveltype
+            DICTVALUE = self.value_repr.lowleveltype
+            # * we need an explicit flag if the key and the value is not
+            #   able to store dummy values
             s_key   = self.dictkey.s_value
             s_value = self.dictvalue.s_value
-            nullkeymarker = not self.key_repr.can_ll_be_null(s_key)
-            nullvaluemarker = not self.value_repr.can_ll_be_null(s_value)
-            if self.force_non_null:
-                if not nullkeymarker:
-                    rmodel.warning("%s can be null, but forcing non-null in dict key" % s_key)
-                    nullkeymarker = True
-                if not nullvaluemarker:
-                    rmodel.warning("%s can be null, but forcing non-null in dict value" % s_value)
-                    nullvaluemarker = True
-            dummykeyobj = self.key_repr.get_ll_dummyval_obj(self.rtyper,
-                                                            s_key)
-            dummyvalueobj = self.value_repr.get_ll_dummyval_obj(self.rtyper,
-                                                                s_value)
+            assert not self.force_non_null # XXX kill the flag
+            kwd = {}
+            if self.custom_eq_hash:
+                kwd['get_custom_eq_hash'] = self.custom_eq_hash
+            else:
+                kwd['ll_hash_function'] = self.key_repr.get_ll_hash_function()
+                kwd['ll_eq_function'] = self.key_repr.get_ll_eq_function()
+                kwd['ll_fasthash_function'] = self.key_repr.get_ll_fasthash_function()
+            kwd['dummykeyobj'] = self.key_repr.get_ll_dummyval_obj(self.rtyper,
+                                                                   s_key)
+            kwd['dummyvalueobj'] = self.value_repr.get_ll_dummyval_obj(
+                self.rtyper, s_value)
 
-            # * the state of the entry - trying to encode it as dummy objects
-            if nullkeymarker and dummykeyobj:
-                # all the state can be encoded in the key
-                entrymeths['everused'] = ll_everused_from_key
-                entrymeths['dummy_obj'] = dummykeyobj
-                entrymeths['valid'] = ll_valid_from_key
-                entrymeths['mark_deleted'] = ll_mark_deleted_in_key
-                # the key is overwritten by 'dummy' when the entry is deleted
-                entrymeths['must_clear_key'] = False
-
-            elif nullvaluemarker and dummyvalueobj:
-                # all the state can be encoded in the value
-                entrymeths['everused'] = ll_everused_from_value
-                entrymeths['dummy_obj'] = dummyvalueobj
-                entrymeths['valid'] = ll_valid_from_value
-                entrymeths['mark_deleted'] = ll_mark_deleted_in_value
-                # value is overwritten by 'dummy' when entry is deleted
-                entrymeths['must_clear_value'] = False
-
-            else:
-                # we need a flag to know if the entry was ever used
-                # (we cannot use a NULL as a marker for this, because
-                # the key and value will be reset to NULL to clear their
-                # reference)
-                entryfields.append(("f_everused", lltype.Bool))
-                entrymeths['everused'] = ll_everused_from_flag
-
-                # can we still rely on a dummy obj to mark deleted entries?
-                if dummykeyobj:
-                    entrymeths['dummy_obj'] = dummykeyobj
-                    entrymeths['valid'] = ll_valid_from_key
-                    entrymeths['mark_deleted'] = ll_mark_deleted_in_key
-                    # key is overwritten by 'dummy' when entry is deleted
-                    entrymeths['must_clear_key'] = False
-                elif dummyvalueobj:
-                    entrymeths['dummy_obj'] = dummyvalueobj
-                    entrymeths['valid'] = ll_valid_from_value
-                    entrymeths['mark_deleted'] = ll_mark_deleted_in_value
-                    # value is overwritten by 'dummy' when entry is deleted
-                    entrymeths['must_clear_value'] = False
-                else:
-                    entryfields.append(("f_valid", lltype.Bool))
-                    entrymeths['valid'] = ll_valid_from_flag
-                    entrymeths['mark_deleted'] = ll_mark_deleted_in_flag
-
-            # * the value
-            entryfields.append(("value", self.DICTVALUE))
-
-            # * the hash, if needed
-            if self.custom_eq_hash:
-                fasthashfn = None
-            else:
-                fasthashfn = self.key_repr.get_ll_fasthash_function()
-            if fasthashfn is None:
-                entryfields.append(("f_hash", lltype.Signed))
-                entrymeths['hash'] = ll_hash_from_cache
-            else:
-                entrymeths['hash'] = ll_hash_recomputed
-                entrymeths['fasthashfn'] = fasthashfn
-
-            # Build the lltype data structures
-            self.DICTENTRY = lltype.Struct("dictentry", *entryfields)
-            self.DICTENTRYARRAY = lltype.GcArray(self.DICTENTRY,
-                                                 adtmeths=entrymeths)
-            fields =          [ ("num_items", lltype.Signed),
-                                ("resize_counter", lltype.Signed),
-                                ("entries", lltype.Ptr(self.DICTENTRYARRAY)) ]
-            if self.custom_eq_hash:
-                self.r_rdict_eqfn, self.r_rdict_hashfn = self._custom_eq_hash_repr()
-                fields.extend([ ("fnkeyeq", self.r_rdict_eqfn.lowleveltype),
-                                ("fnkeyhash", self.r_rdict_hashfn.lowleveltype) ])
-                adtmeths = {
-                    'keyhash':        ll_keyhash_custom,
-                    'keyeq':          ll_keyeq_custom,
-                    'r_rdict_eqfn':   self.r_rdict_eqfn,
-                    'r_rdict_hashfn': self.r_rdict_hashfn,
-                    'paranoia':       True,
-                    }
-            else:
-                # figure out which functions must be used to hash and compare
-                ll_keyhash = self.key_repr.get_ll_hash_function()
-                ll_keyeq = self.key_repr.get_ll_eq_function()  # can be None
-                ll_keyhash = lltype.staticAdtMethod(ll_keyhash)
-                if ll_keyeq is not None:
-                    ll_keyeq = lltype.staticAdtMethod(ll_keyeq)
-                adtmeths = {
-                    'keyhash':  ll_keyhash,
-                    'keyeq':    ll_keyeq,
-                    'paranoia': False,
-                    }
-            adtmeths['KEY']   = self.DICTKEY
-            adtmeths['VALUE'] = self.DICTVALUE
-            adtmeths['allocate'] = lltype.typeMethod(_ll_malloc_dict)
-            self.DICT.become(lltype.GcStruct("dicttable", adtmeths=adtmeths,
-                                             *fields))
-
+            get_ll_dict(DICTKEY, DICTVALUE, DICT=self.DICT, **kwd)
 
     def convert_const(self, dictobj):
+        XXX
         from rpython.rtyper.lltypesystem import llmemory
         # get object from bound dict methods
         #dictobj = getattr(dictobj, '__self__', dictobj)
@@ -384,36 +383,57 @@
 #  be direct_call'ed from rtyped flow graphs, which means that they will
 #  get flowed and annotated, mostly with SomePtr.
 
-def ll_everused_from_flag(entries, i):
-    return entries[i].f_everused
+DICTINDEX_LONG = lltype.Ptr(lltype.GcArray(lltype.Unsigned))
+DICTINDEX_INT = lltype.Ptr(lltype.GcArray(rffi.UINT))
+DICTINDEX_SHORT = lltype.Ptr(lltype.GcArray(rffi.USHORT))
+DICTINDEX_BYTE = lltype.Ptr(lltype.GcArray(rffi.UCHAR))
 
-def ll_everused_from_key(entries, i):
-    return bool(entries[i].key)
+IS_64BIT = sys.maxint != 2 ** 31 - 1
 
-def ll_everused_from_value(entries, i):
-    return bool(entries[i].value)
+def ll_malloc_indexes_and_choose_lookup(d, n):
+    DICT = lltype.typeOf(d).TO
+    if n <= 256:
+        d.indexes = lltype.cast_opaque_ptr(llmemory.GCREF,
+                                           lltype.malloc(DICTINDEX_BYTE.TO, n,
+                                                         zero=True))
+        d.lookup_function = DICT.byte_lookup_function
+    elif n <= 65536:
+        d.indexes = lltype.cast_opaque_ptr(llmemory.GCREF,
+                                           lltype.malloc(DICTINDEX_SHORT.TO, n,
+                                                         zero=True))
+        d.lookup_function = DICT.short_lookup_function
+    elif IS_64BIT and n <= 2 ** 32:
+        d.indexes = lltype.cast_opaque_ptr(llmemory.GCREF,
+                                           lltype.malloc(DICTINDEX_INT.TO, n,
+                                                         zero=True))
+        d.lookup_function = DICT.int_lookup_function
+    else:
+        d.indexes = lltype.cast_opaque_ptr(llmemory.GCREF,
+                                           lltype.malloc(DICTINDEX_LONG.TO, n,
+                                                         zero=True))
+        d.lookup_function = DICT.long_lookup_function
 
 def ll_valid_from_flag(entries, i):
     return entries[i].f_valid
 
-def ll_mark_deleted_in_flag(entries, i):
-    entries[i].f_valid = False
-
 def ll_valid_from_key(entries, i):
     ENTRIES = lltype.typeOf(entries).TO
     dummy = ENTRIES.dummy_obj.ll_dummy_value
     return entries.everused(i) and entries[i].key != dummy
 
+def ll_valid_from_value(entries, i):
+    ENTRIES = lltype.typeOf(entries).TO
+    dummy = ENTRIES.dummy_obj.ll_dummy_value
+    return entries.everused(i) and entries[i].value != dummy
+
+def ll_mark_deleted_in_flag(entries, i):
+    entries[i].f_valid = False
+
 def ll_mark_deleted_in_key(entries, i):
     ENTRIES = lltype.typeOf(entries).TO
     dummy = ENTRIES.dummy_obj.ll_dummy_value
     entries[i].key = dummy
 
-def ll_valid_from_value(entries, i):
-    ENTRIES = lltype.typeOf(entries).TO
-    dummy = ENTRIES.dummy_obj.ll_dummy_value
-    return entries.everused(i) and entries[i].value != dummy
-
 def ll_mark_deleted_in_value(entries, i):
     ENTRIES = lltype.typeOf(entries).TO
     dummy = ENTRIES.dummy_obj.ll_dummy_value
@@ -426,9 +446,6 @@
     ENTRIES = lltype.typeOf(entries).TO
     return ENTRIES.fasthashfn(entries[i].key)
 
-def ll_get_value(d, i):
-    return d.entries[i].value
-
 def ll_keyhash_custom(d, key):
     DICT = lltype.typeOf(d).TO
     return objectmodel.hlinvoke(DICT.r_rdict_hashfn, d.fnkeyhash, key)
@@ -445,47 +462,64 @@
     return bool(d) and d.num_items != 0
 
 def ll_dict_getitem(d, key):
-    i = ll_dict_lookup(d, key, d.keyhash(key))
-    if not i & HIGHEST_BIT:
-        return ll_get_value(d, i)
+    index = d.lookup_function(d, key, d.keyhash(key), FLAG_LOOKUP)
+    if index != -1:
+        return d.entries[index].value
     else:
         raise KeyError
 
 def ll_dict_setitem(d, key, value):
     hash = d.keyhash(key)
-    i = ll_dict_lookup(d, key, hash)
-    return _ll_dict_setitem_lookup_done(d, key, value, hash, i)
+    index = d.lookup_function(d, key, hash, FLAG_STORE)
+    return _ll_dict_setitem_lookup_done(d, key, value, hash, index)
 
 # It may be safe to look inside always, it has a few branches though, and their
 # frequencies needs to be investigated.
 @jit.look_inside_iff(lambda d, key, value, hash, i: jit.isvirtual(d) and jit.isconstant(key))
 def _ll_dict_setitem_lookup_done(d, key, value, hash, i):
-    valid = (i & HIGHEST_BIT) == 0
-    i = i & MASK
     ENTRY = lltype.typeOf(d.entries).TO.OF
-    entry = d.entries[i]
-    if not d.entries.everused(i):
-        # a new entry that was never used before
-        ll_assert(not valid, "valid but not everused")
+    if i >= 0:
+        entry = d.entries[i]
+        entry.value = value
+    else:
+        if len(d.entries) == d.num_used_items:
+            ll_dict_grow(d)
+        entry = d.entries[d.num_used_items]
+        entry.key = key
+        entry.value = value
+        if hasattr(ENTRY, 'f_hash'):
+            entry.f_hash = hash
+        if hasattr(ENTRY, 'f_valid'):
+            entry.f_valid = True
+        d.num_used_items += 1
+        d.num_items += 1
         rc = d.resize_counter - 3
-        if rc <= 0:       # if needed, resize the dict -- before the insertion
+        if rc <= 0:
+            XXX
             ll_dict_resize(d)
             i = ll_dict_lookup_clean(d, hash)  # then redo the lookup for 'key'
             entry = d.entries[i]
             rc = d.resize_counter - 3
             ll_assert(rc > 0, "ll_dict_resize failed?")
         d.resize_counter = rc
-        if hasattr(ENTRY, 'f_everused'): entry.f_everused = True
-        entry.value = value
+
+def ll_dict_grow(d):
+    # This over-allocates proportional to the list size, making room
+    # for additional growth.  The over-allocation is mild, but is
+    # enough to give linear-time amortized behavior over a long
+    # sequence of appends() in the presence of a poorly-performing
+    # system malloc().
+    # The growth pattern is:  0, 4, 8, 16, 25, 35, 46, 58, 72, 88, ...
+    newsize = len(d.entries) + 1
+    if newsize < 9:
+        some = 3
     else:
-        # override an existing or deleted entry
-        entry.value = value
-        if valid:
-            return
-    entry.key = key
-    if hasattr(ENTRY, 'f_hash'):  entry.f_hash = hash
-    if hasattr(ENTRY, 'f_valid'): entry.f_valid = True
-    d.num_items += 1
+        some = 6
+    some += newsize >> 3
+    new_allocated = newsize + some
+    newitems = lltype.malloc(lltype.typeOf(d).TO.entries.TO, new_allocated)
+    rgc.ll_arraycopy(d.entries, newitems, 0, 0, len(d.entries))
+    d.entries = newitems
 
 def ll_dict_insertclean(d, key, value, hash):
     # Internal routine used by ll_dict_resize() to insert an item which is
@@ -565,67 +599,103 @@
 # ------- a port of CPython's dictobject.c's lookdict implementation -------
 PERTURB_SHIFT = 5
 
-@jit.look_inside_iff(lambda d, key, hash: jit.isvirtual(d) and jit.isconstant(key))
-def ll_dict_lookup(d, key, hash):
-    entries = d.entries
-    ENTRIES = lltype.typeOf(entries).TO
-    direct_compare = not hasattr(ENTRIES, 'no_direct_compare')
-    mask = len(entries) - 1
-    i = r_uint(hash & mask)
-    # do the first try before any looping
-    if entries.valid(i):
-        checkingkey = entries[i].key
-        if direct_compare and checkingkey == key:
-            return i   # found the entry
-        if d.keyeq is not None and entries.hash(i) == hash:
-            # correct hash, maybe the key is e.g. a different pointer to
-            # an equal object
-            found = d.keyeq(checkingkey, key)
-            if d.paranoia:
-                if (entries != d.entries or
-                    not entries.valid(i) or entries[i].key != checkingkey):
-                    # the compare did major nasty stuff to the dict: start over
-                    return ll_dict_lookup(d, key, hash)
-            if found:
-                return i   # found the entry
-        freeslot = -1
-    elif entries.everused(i):
-        freeslot = intmask(i)
-    else:
-        return i | HIGHEST_BIT # pristine entry -- lookup failed
+FREE = 0
+DELETED = 1
+VALID_OFFSET = 2
 
-    # In the loop, a deleted entry (everused and not valid) is by far
-    # (factor of 100s) the least likely outcome, so test for that last.
-    perturb = r_uint(hash)
-    while 1:
-        # compute the next index using unsigned arithmetic
-        i = (i << 2) + i + perturb + 1
-        i = i & mask
-        # keep 'i' as a signed number here, to consistently pass signed
-        # arguments to the small helper methods.
-        if not entries.everused(i):
-            if freeslot == -1:
-                freeslot = intmask(i)
-            return r_uint(freeslot) | HIGHEST_BIT
-        elif entries.valid(i):
-            checkingkey = entries[i].key
+FLAG_LOOKUP = 0
+FLAG_STORE = 1
+FLAG_DELETE = 2
+
+def new_lookup_function(LOOKUP_FUNC, T):
+    INDEXES = lltype.Ptr(lltype.GcArray(T))
+
+    @jit.look_inside_iff(lambda d, key, hash, store_flag:
+                         jit.isvirtual(d) and jit.isconstant(key))
+    def ll_dict_lookup(d, key, hash, store_flag):
+        entries = d.entries
+        indexes = lltype.cast_opaque_ptr(INDEXES, d.indexes)
+        mask = len(indexes) - 1
+        i = hash & mask
+        # do the first try before any looping
+        ENTRIES = lltype.typeOf(entries).TO
+        direct_compare = not hasattr(ENTRIES, 'no_direct_compare')
+        index = rffi.cast(lltype.Signed, indexes[i])
+        if index >= VALID_OFFSET:
+            checkingkey = entries[index - VALID_OFFSET].key
             if direct_compare and checkingkey == key:
-                return i
-            if d.keyeq is not None and entries.hash(i) == hash:
+                XXX
+                return index   # found the entry
+            if d.keyeq is not None and entries.hash(index - VALID_OFFSET) == hash:
                 # correct hash, maybe the key is e.g. a different pointer to
                 # an equal object
                 found = d.keyeq(checkingkey, key)
+                #llop.debug_print(lltype.Void, "comparing keys", ll_debugrepr(checkingkey), ll_debugrepr(key), found)
                 if d.paranoia:
-                    if (entries != d.entries or
-                        not entries.valid(i) or entries[i].key != checkingkey):
-                        # the compare did major nasty stuff to the dict:
-                        # start over
-                        return ll_dict_lookup(d, key, hash)
+                    XXX
+                    if (entries != d.entries or indexes != d.indexes or
+                        not entries.valid(ll_index_getitem(d.size, indexes, i))
+                        or entries.getitem_clean(index).key != checkingkey):
+                        # the compare did major nasty stuff to the dict: start over
+                        if d_signed_indexes(d):
+                            return ll_dict_lookup(d, key, hash,
+                                                  ll_index_getitem_signed)
+                        else:
+                            return ll_dict_lookup(d, key, hash,
+                                                  ll_index_getitem_int)
                 if found:
-                    return i   # found the entry
-        elif freeslot == -1:
-            freeslot = intmask(i)
-        perturb >>= PERTURB_SHIFT
+                    return index - VALID_OFFSET
+            freeslot = -1
+        elif index == DELETED:
+            freeslot = i
+        else:
+            # pristine entry -- lookup failed
+            if store_flag == FLAG_STORE:
+                indexes[i] = rffi.cast(T, d.num_used_items + VALID_OFFSET)
+            return -1
+
+        # In the loop, a deleted entry (everused and not valid) is by far
+        # (factor of 100s) the least likely outcome, so test for that last.
+        XXX
+        perturb = r_uint(hash)
+        while 1:
+            # compute the next index using unsigned arithmetic
+            i = r_uint(i)
+            i = (i << 2) + i + perturb + 1
+            i = intmask(i) & mask
+            index = ll_index_getitem(d.size, indexes, i)
+            # keep 'i' as a signed number here, to consistently pass signed
+            # arguments to the small helper methods.
+            if index == FREE:
+                if freeslot == -1:
+                    freeslot = i
+                return freeslot | HIGHEST_BIT
+            elif entries.valid(index):
+                checkingkey = entries.getitem_clean(index).key
+                if direct_compare and checkingkey == key:
+                    return i
+                if d.keyeq is not None and entries.hash(index) == hash:
+                    # correct hash, maybe the key is e.g. a different pointer to
+                    # an equal object
+                    found = d.keyeq(checkingkey, key)
+                    if d.paranoia:
+                        if (entries != d.entries or indexes != d.indexes or
+                            not entries.valid(ll_index_getitem(d.size, indexes, i)) or
+                            entries.getitem_clean(index).key != checkingkey):
+                            # the compare did major nasty stuff to the dict:
+                            # start over
+                            if d_signed_indexes(d):
+                                return ll_dict_lookup(d, key, hash,
+                                                      ll_index_getitem_signed)
+                            else:
+                                return ll_dict_lookup(d, key, hash,
+                                                      ll_index_getitem_int)
+                    if found:
+                        return i   # found the entry
+            elif freeslot == -1:
+                freeslot = i
+            perturb >>= PERTURB_SHIFT
+    return llhelper(LOOKUP_FUNC, ll_dict_lookup)
 
 def ll_dict_lookup_clean(d, hash):
     # a simplified version of ll_dict_lookup() which assumes that the
@@ -649,12 +719,15 @@
 
 def ll_newdict(DICT):
     d = DICT.allocate()
-    d.entries = DICT.entries.TO.allocate(DICT_INITSIZE)
+    d.entries = DICT.empty_array
+    ll_malloc_indexes_and_choose_lookup(d, DICT_INITSIZE)
     d.num_items = 0
+    d.num_used_items = 0
     d.resize_counter = DICT_INITSIZE * 2
     return d
 
 def ll_newdict_size(DICT, length_estimate):
+    xxx
     length_estimate = (length_estimate // 2) * 3
     n = DICT_INITSIZE
     while n < length_estimate:
diff --git a/rpython/rtyper/test/test_rdict.py b/rpython/rtyper/test/test_rdict.py
--- a/rpython/rtyper/test/test_rdict.py
+++ b/rpython/rtyper/test/test_rdict.py
@@ -4,7 +4,9 @@
 from rpython.rtyper.lltypesystem import rdict, rstr
 from rpython.rtyper.test.tool import BaseRtypingTest
 from rpython.rlib.objectmodel import r_dict
-from rpython.rlib.rarithmetic import r_int, r_uint, r_longlong, r_ulonglong
+from rpython.rlib.rarithmetic import r_int, r_uint, r_longlong, r_ulonglong,\
+     intmask
+from rpython.rtyper.annlowlevel import llstr, hlstr
 
 import py
 py.log.setconsumer("rtyper", py.log.STDOUT)
@@ -21,6 +23,108 @@
         assert 0 <= x < 4
         yield x
 
+def foreach_index(ll_d):
+    indexes = ll_d.indexes._obj.container._as_ptr()
+    for i in range(len(indexes)):
+        yield rffi.cast(lltype.Signed, indexes[i])
+
+def count_items(ll_d, ITEM):
+    c = 0
+    for item in foreach_index(ll_d):
+        if item == ITEM:
+            c += 1
+    return c
+
+class TestRDictDirect(object):
+    def _get_str_dict(self):
+        # STR -> lltype.Signed
+        DICT = rdict.get_ll_dict(lltype.Ptr(rstr.STR), lltype.Signed,
+                                 ll_fasthash_function=rstr.LLHelpers.ll_strhash,
+                                 ll_hash_function=rstr.LLHelpers.ll_strhash,
+                                 ll_eq_function=rstr.LLHelpers.ll_streq)
+        return DICT
+
+    def test_dict_creation(self):
+        DICT = self._get_str_dict()
+        ll_d = rdict.ll_newdict(DICT)
+        rdict.ll_dict_setitem(ll_d, llstr("abc"), 13)
+        assert count_items(ll_d, rdict.FREE) == rdict.DICT_INITSIZE - 1
+        assert rdict.ll_dict_getitem(ll_d, llstr("abc")) == 13
+
+    def test_dict_del_lastitem(self):
+        DICT = self._get_str_dict()
+        ll_d = rdict.ll_newdict(DICT)
+        py.test.raises(KeyError, rdict.ll_dict_delitem, ll_d, llstr("abc"))
+        rdict.ll_dict_setitem(ll_d, llstr("abc"), 13)
+        py.test.raises(KeyError, rdict.ll_dict_delitem, ll_d, llstr("def"))
+        rdict.ll_dict_delitem(ll_d, llstr("abc"))
+        assert count_items(ll_d, rdict.FREE) == rdict.DICT_INITSIZE - 1
+        assert count_items(ll_d, rdict.DELETED) == 1
+        py.test.raises(KeyError, rdict.ll_dict_getitem, ll_d, llstr("abc"))
+
+    def test_dict_del_not_lastitem(self):
+        DICT = self._get_str_dict()
+        ll_d = rdict.ll_newdict(DICT)
+        rdict.ll_dict_setitem(ll_d, llstr("abc"), 13)
+        rdict.ll_dict_setitem(ll_d, llstr("def"), 15)
+        rdict.ll_dict_delitem(ll_d, llstr("abc"))
+        assert count_items(ll_d, rdict.FREE) == rdict.DICT_INITSIZE - 2
+        assert count_items(ll_d, rdict.DELETED) == 1
+
+    def test_dict_resize(self):
+        DICT = self._get_str_dict()
+        ll_d = rdict.ll_newdict(DICT)
+        rdict.ll_dict_setitem(ll_d, llstr("a"), 1)
+        rdict.ll_dict_setitem(ll_d, llstr("b"), 2)
+        rdict.ll_dict_setitem(ll_d, llstr("c"), 3)
+        rdict.ll_dict_setitem(ll_d, llstr("d"), 4)
+        assert ll_d.size == 8
+        rdict.ll_dict_setitem(ll_d, llstr("e"), 5)
+        rdict.ll_dict_setitem(ll_d, llstr("f"), 6)
+        assert ll_d.size == 32
+        for item in ['a', 'b', 'c', 'd', 'e', 'f']:
+            assert rdict.ll_dict_getitem(ll_d, llstr(item)) == ord(item) - ord('a') + 1
+
+    def test_dict_iteration(self):
+        DICT = self._get_str_dict()
+        ll_d = rdict.ll_newdict(DICT)
+        rdict.ll_dict_setitem(ll_d, llstr("k"), 1)
+        rdict.ll_dict_setitem(ll_d, llstr("j"), 2)
+        ITER = rdict.get_ll_dictiter(lltype.Ptr(DICT))
+        ll_iter = rdict.ll_dictiter(ITER, ll_d)
+        ll_iterkeys = rdict.ll_dictnext_group['keys']
+        next = ll_iterkeys(lltype.Signed, ll_iter)
+        assert hlstr(next) == "k"
+        next = ll_iterkeys(lltype.Signed, ll_iter)
+        assert hlstr(next) == "j"
+        py.test.raises(StopIteration, ll_iterkeys, lltype.Signed, ll_iter)
+
+    def test_popitem(self):
+        DICT = self._get_str_dict()
+        ll_d = rdict.ll_newdict(DICT)
+        rdict.ll_dict_setitem(ll_d, llstr("k"), 1)
+        rdict.ll_dict_setitem(ll_d, llstr("j"), 2)
+        ll_elem = rdict.ll_popitem(lltype.Ptr(
+            lltype.GcStruct('x', ('item0', lltype.Ptr(rstr.STR)),
+                            ('item1', lltype.Signed))), ll_d)
+        assert hlstr(ll_elem.item0) == "j"
+        assert ll_elem.item1 == 2
+
+    def test_direct_enter_and_del(self):
+        def eq(a, b):
+            return a == b
+
+        DICT = rdict.get_ll_dict(lltype.Signed, lltype.Signed,
+                                 ll_fasthash_function=intmask,
+                                 ll_hash_function=intmask,
+                                 ll_eq_function=eq)
+        ll_d = rdict.ll_newdict(DICT)
+        numbers = [i * rdict.DICT_INITSIZE + 1 for i in range(8)]
+        for num in numbers:
+            rdict.ll_dict_setitem(ll_d, num, 1)
+            rdict.ll_dict_delitem(ll_d, num)
+            for k in foreach_index(ll_d):
+                assert k < 0
 
 class TestRdict(BaseRtypingTest):
 
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to