https://github.com/python/cpython/commit/0990d55725cb649e74739c983b67cf08c58e8439
commit: 0990d55725cb649e74739c983b67cf08c58e8439
branch: main
author: Dino Viehland <[email protected]>
committer: DinoV <[email protected]>
date: 2024-01-30T09:33:36-08:00
summary:

gh-112075: refactor dictionary lookup functions for better re-usability (#114629)

Refactor dict lookup functions to use force-inlined helpers

files:
M Objects/dictobject.c

diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index c5477ab15f8dc9..23d7e9b5e38a35 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -874,11 +874,11 @@ lookdict_index(PyDictKeysObject *k, Py_hash_t hash, Py_ssize_t index)
     Py_UNREACHABLE();
 }
 
-// Search non-Unicode key from Unicode table
-static Py_ssize_t
-unicodekeys_lookup_generic(PyDictObject *mp, PyDictKeysObject* dk, PyObject *key, Py_hash_t hash)
+static inline Py_ALWAYS_INLINE Py_ssize_t
+do_lookup(PyDictObject *mp, PyDictKeysObject *dk, PyObject *key, Py_hash_t hash,
+          Py_ssize_t (*check_lookup)(PyDictObject *, PyDictKeysObject *, void *, Py_ssize_t ix, PyObject *key, Py_hash_t))
 {
-    PyDictUnicodeEntry *ep0 = DK_UNICODE_ENTRIES(dk);
+    void *ep0 = _DK_ENTRIES(dk);
     size_t mask = DK_MASK(dk);
     size_t perturb = hash;
     size_t i = (size_t)hash & mask;
@@ -886,73 +886,26 @@ unicodekeys_lookup_generic(PyDictObject *mp, PyDictKeysObject* dk, PyObject *key
     for (;;) {
         ix = dictkeys_get_index(dk, i);
         if (ix >= 0) {
-            PyDictUnicodeEntry *ep = &ep0[ix];
-            assert(ep->me_key != NULL);
-            assert(PyUnicode_CheckExact(ep->me_key));
-            if (ep->me_key == key) {
+            Py_ssize_t cmp = check_lookup(mp, dk, ep0, ix, key, hash);
+            if (cmp < 0) {
+                return cmp;
+            } else if (cmp) {
                 return ix;
             }
-            if (unicode_get_hash(ep->me_key) == hash) {
-                PyObject *startkey = ep->me_key;
-                Py_INCREF(startkey);
-                int cmp = PyObject_RichCompareBool(startkey, key, Py_EQ);
-                Py_DECREF(startkey);
-                if (cmp < 0) {
-                    return DKIX_ERROR;
-                }
-                if (dk == mp->ma_keys && ep->me_key == startkey) {
-                    if (cmp > 0) {
-                        return ix;
-                    }
-                }
-                else {
-                    /* The dict was mutated, restart */
-                    return DKIX_KEY_CHANGED;
-                }
-            }
         }
         else if (ix == DKIX_EMPTY) {
             return DKIX_EMPTY;
         }
         perturb >>= PERTURB_SHIFT;
         i = mask & (i*5 + perturb + 1);
-    }
-    Py_UNREACHABLE();
-}
 
-// Search Unicode key from Unicode table.
-static Py_ssize_t _Py_HOT_FUNCTION
-unicodekeys_lookup_unicode(PyDictKeysObject* dk, PyObject *key, Py_hash_t hash)
-{
-    PyDictUnicodeEntry *ep0 = DK_UNICODE_ENTRIES(dk);
-    size_t mask = DK_MASK(dk);
-    size_t perturb = hash;
-    size_t i = (size_t)hash & mask;
-    Py_ssize_t ix;
-    for (;;) {
-        ix = dictkeys_get_index(dk, i);
-        if (ix >= 0) {
-            PyDictUnicodeEntry *ep = &ep0[ix];
-            assert(ep->me_key != NULL);
-            assert(PyUnicode_CheckExact(ep->me_key));
-            if (ep->me_key == key ||
-                    (unicode_get_hash(ep->me_key) == hash && unicode_eq(ep->me_key, key))) {
-                return ix;
-            }
-        }
-        else if (ix == DKIX_EMPTY) {
-            return DKIX_EMPTY;
-        }
-        perturb >>= PERTURB_SHIFT;
-        i = mask & (i*5 + perturb + 1);
         // Manual loop unrolling
         ix = dictkeys_get_index(dk, i);
         if (ix >= 0) {
-            PyDictUnicodeEntry *ep = &ep0[ix];
-            assert(ep->me_key != NULL);
-            assert(PyUnicode_CheckExact(ep->me_key));
-            if (ep->me_key == key ||
-                    (unicode_get_hash(ep->me_key) == hash && unicode_eq(ep->me_key, key))) {
+            Py_ssize_t cmp = check_lookup(mp, dk, ep0, ix, key, hash);
+            if (cmp < 0) {
+                return cmp;
+            } else if (cmp) {
                 return ix;
             }
         }
@@ -965,49 +918,94 @@ unicodekeys_lookup_unicode(PyDictKeysObject* dk, PyObject *key, Py_hash_t hash)
     Py_UNREACHABLE();
 }
 
-// Search key from Generic table.
+static inline Py_ALWAYS_INLINE Py_ssize_t
+compare_unicode_generic(PyDictObject *mp, PyDictKeysObject *dk,
+                        void *ep0, Py_ssize_t ix, PyObject *key, Py_hash_t hash)
+{
+    PyDictUnicodeEntry *ep = &((PyDictUnicodeEntry *)ep0)[ix];
+    assert(ep->me_key != NULL);
+    assert(PyUnicode_CheckExact(ep->me_key));
+    assert(!PyUnicode_CheckExact(key));
+    // TODO: Thread safety
+
+    if (unicode_get_hash(ep->me_key) == hash) {
+        PyObject *startkey = ep->me_key;
+        Py_INCREF(startkey);
+        int cmp = PyObject_RichCompareBool(startkey, key, Py_EQ);
+        Py_DECREF(startkey);
+        if (cmp < 0) {
+            return DKIX_ERROR;
+        }
+        if (dk == mp->ma_keys && ep->me_key == startkey) {
+            return cmp;
+        }
+        else {
+            /* The dict was mutated, restart */
+            return DKIX_KEY_CHANGED;
+        }
+    }
+    return 0;
+}
+
+// Search non-Unicode key from Unicode table
 static Py_ssize_t
-dictkeys_generic_lookup(PyDictObject *mp, PyDictKeysObject* dk, PyObject *key, Py_hash_t hash)
+unicodekeys_lookup_generic(PyDictObject *mp, PyDictKeysObject* dk, PyObject *key, Py_hash_t hash)
 {
-    PyDictKeyEntry *ep0 = DK_ENTRIES(dk);
-    size_t mask = DK_MASK(dk);
-    size_t perturb = hash;
-    size_t i = (size_t)hash & mask;
-    Py_ssize_t ix;
-    for (;;) {
-        ix = dictkeys_get_index(dk, i);
-        if (ix >= 0) {
-            PyDictKeyEntry *ep = &ep0[ix];
-            assert(ep->me_key != NULL);
-            if (ep->me_key == key) {
-                return ix;
-            }
-            if (ep->me_hash == hash) {
-                PyObject *startkey = ep->me_key;
-                Py_INCREF(startkey);
-                int cmp = PyObject_RichCompareBool(startkey, key, Py_EQ);
-                Py_DECREF(startkey);
-                if (cmp < 0) {
-                    return DKIX_ERROR;
-                }
-                if (dk == mp->ma_keys && ep->me_key == startkey) {
-                    if (cmp > 0) {
-                        return ix;
-                    }
-                }
-                else {
-                    /* The dict was mutated, restart */
-                    return DKIX_KEY_CHANGED;
-                }
-            }
+    return do_lookup(mp, dk, key, hash, compare_unicode_generic);
+}
+
+static inline Py_ALWAYS_INLINE Py_ssize_t
+compare_unicode_unicode(PyDictObject *mp, PyDictKeysObject *dk,
+                        void *ep0, Py_ssize_t ix, PyObject *key, Py_hash_t hash)
+{
+    PyDictUnicodeEntry *ep = &((PyDictUnicodeEntry *)ep0)[ix];
+    assert(ep->me_key != NULL);
+    assert(PyUnicode_CheckExact(ep->me_key));
+    if (ep->me_key == key ||
+            (unicode_get_hash(ep->me_key) == hash && unicode_eq(ep->me_key, key))) {
+        return 1;
+    }
+    return 0;
+}
+
+static Py_ssize_t _Py_HOT_FUNCTION
+unicodekeys_lookup_unicode(PyDictKeysObject* dk, PyObject *key, Py_hash_t hash)
+{
+    return do_lookup(NULL, dk, key, hash, compare_unicode_unicode);
+}
+
+static inline Py_ALWAYS_INLINE Py_ssize_t
+compare_generic(PyDictObject *mp, PyDictKeysObject *dk,
+                void *ep0, Py_ssize_t ix, PyObject *key, Py_hash_t hash)
+{
+    PyDictKeyEntry *ep = &((PyDictKeyEntry *)ep0)[ix];
+    assert(ep->me_key != NULL);
+    if (ep->me_key == key) {
+        return 1;
+    }
+    if (ep->me_hash == hash) {
+        PyObject *startkey = ep->me_key;
+        Py_INCREF(startkey);
+        int cmp = PyObject_RichCompareBool(startkey, key, Py_EQ);
+        Py_DECREF(startkey);
+        if (cmp < 0) {
+            return DKIX_ERROR;
         }
-        else if (ix == DKIX_EMPTY) {
-            return DKIX_EMPTY;
+        if (dk == mp->ma_keys && ep->me_key == startkey) {
+            return cmp;
+        }
+        else {
+            /* The dict was mutated, restart */
+            return DKIX_KEY_CHANGED;
         }
-        perturb >>= PERTURB_SHIFT;
-        i = mask & (i*5 + perturb + 1);
     }
-    Py_UNREACHABLE();
+    return 0;
+}
+
+static Py_ssize_t
 dictkeys_generic_lookup(PyDictObject *mp, PyDictKeysObject* dk, PyObject *key, Py_hash_t hash)
+{
+    return do_lookup(mp, dk, key, hash, compare_generic);
 }
 
 /* Lookup a string in a (all unicode) dict keys.

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to