https://github.com/python/cpython/commit/f653caa5a88d3b5027a8f286ff3a3ccd9e6fe4ed
commit: f653caa5a88d3b5027a8f286ff3a3ccd9e6fe4ed
branch: main
author: Peter Lazorchak <lazorch...@gmail.com>
committer: Fidget-Spinner <kenjin4...@gmail.com>
date: 2024-01-11T13:33:05+08:00
summary:

gh-89811: Check for valid tp_version_tag in specializer (GH-113558)

files:
A Misc/NEWS.d/next/Core and Builtins/2024-01-03-12-19-37.gh-issue-89811.cZOj6d.rst
M Lib/test/test_type_cache.py
M Modules/_testcapimodule.c
M Python/specialize.c

diff --git a/Lib/test/test_type_cache.py b/Lib/test/test_type_cache.py
index 72587ecc11b6f3..95b55009c7187d 100644
--- a/Lib/test/test_type_cache.py
+++ b/Lib/test/test_type_cache.py
@@ -1,5 +1,6 @@
 """ Tests for the internal type cache in CPython. """
 import unittest
+import dis
 from test import support
 from test.support import import_helper
 try:
@@ -8,8 +9,11 @@
     _clear_type_cache = None
 
 # Skip this test if the _testcapi module isn't available.
-type_get_version = import_helper.import_module('_testcapi').type_get_version
-type_assign_version = import_helper.import_module('_testcapi').type_assign_version
+_testcapi = import_helper.import_module("_testcapi")
+type_get_version = _testcapi.type_get_version
+type_assign_specific_version_unsafe = _testcapi.type_assign_specific_version_unsafe
+type_assign_version = _testcapi.type_assign_version
+type_modified = _testcapi.type_modified
 
 
 @support.cpython_only
@@ -56,6 +60,183 @@ class C:
         self.assertNotEqual(type_get_version(C), 0)
         self.assertNotEqual(type_get_version(C), c_ver)
 
+    def test_type_assign_specific_version(self):
+        """meta-test for type_assign_specific_version_unsafe"""
+        class C:
+            pass
+
+        type_assign_version(C)
+        orig_version = type_get_version(C)
+        self.assertNotEqual(orig_version, 0)
+
+        type_modified(C)
+        type_assign_specific_version_unsafe(C, orig_version + 5)
+        type_assign_version(C)  # this should do nothing
+
+        new_version = type_get_version(C)
+        self.assertEqual(new_version, orig_version + 5)
+
+        _clear_type_cache()
+
+
+@support.cpython_only
+class TypeCacheWithSpecializationTests(unittest.TestCase):
+    def tearDown(self):
+        _clear_type_cache()
+
+    def _assign_and_check_valid_version(self, user_type):
+        type_modified(user_type)
+        type_assign_version(user_type)
+        self.assertNotEqual(type_get_version(user_type), 0)
+
+    def _assign_and_check_version_0(self, user_type):
+        type_modified(user_type)
+        type_assign_specific_version_unsafe(user_type, 0)
+        self.assertEqual(type_get_version(user_type), 0)
+
+    def _all_opnames(self, func):
+        return {instr.opname for instr in dis.Bytecode(func, adaptive=True)}
+
+    def _check_specialization(self, func, arg, opname, *, should_specialize):
+        self.assertIn(opname, self._all_opnames(func))
+
+        for _ in range(100):
+            func(arg)
+
+        if should_specialize:
+            self.assertNotIn(opname, self._all_opnames(func))
+        else:
+            self.assertIn(opname, self._all_opnames(func))
+
+    def test_class_load_attr_specialization_user_type(self):
+        class A:
+            def foo(self):
+                pass
+
+        self._assign_and_check_valid_version(A)
+
+        def load_foo_1(type_):
+            type_.foo
+
+        self._check_specialization(load_foo_1, A, "LOAD_ATTR", should_specialize=True)
+        del load_foo_1
+
+        self._assign_and_check_version_0(A)
+
+        def load_foo_2(type_):
+            return type_.foo
+
+        self._check_specialization(load_foo_2, A, "LOAD_ATTR", should_specialize=False)
+
+    def test_class_load_attr_specialization_static_type(self):
+        self._assign_and_check_valid_version(str)
+        self._assign_and_check_valid_version(bytes)
+
+        def get_capitalize_1(type_):
+            return type_.capitalize
+
+        self._check_specialization(get_capitalize_1, str, "LOAD_ATTR", should_specialize=True)
+        self.assertEqual(get_capitalize_1(str)('hello'), 'Hello')
+        self.assertEqual(get_capitalize_1(bytes)(b'hello'), b'Hello')
+        del get_capitalize_1
+
+        # Permanently overflow the static type version counter, and force str and bytes
+        # to have tp_version_tag == 0
+        for _ in range(2**16):
+            type_modified(str)
+            type_assign_version(str)
+            type_modified(bytes)
+            type_assign_version(bytes)
+
+        self.assertEqual(type_get_version(str), 0)
+        self.assertEqual(type_get_version(bytes), 0)
+
+        def get_capitalize_2(type_):
+            return type_.capitalize
+
+        self._check_specialization(get_capitalize_2, str, "LOAD_ATTR", should_specialize=False)
+        self.assertEqual(get_capitalize_2(str)('hello'), 'Hello')
+        self.assertEqual(get_capitalize_2(bytes)(b'hello'), b'Hello')
+
+    def test_property_load_attr_specialization_user_type(self):
+        class G:
+            @property
+            def x(self):
+                return 9
+
+        self._assign_and_check_valid_version(G)
+
+        def load_x_1(instance):
+            instance.x
+
+        self._check_specialization(load_x_1, G(), "LOAD_ATTR", should_specialize=True)
+        del load_x_1
+
+        self._assign_and_check_version_0(G)
+
+        def load_x_2(instance):
+            instance.x
+
+        self._check_specialization(load_x_2, G(), "LOAD_ATTR", should_specialize=False)
+
+    def test_store_attr_specialization_user_type(self):
+        class B:
+            __slots__ = ("bar",)
+
+        self._assign_and_check_valid_version(B)
+
+        def store_bar_1(instance):
+            instance.bar = 10
+
+        self._check_specialization(store_bar_1, B(), "STORE_ATTR", should_specialize=True)
+        del store_bar_1
+
+        self._assign_and_check_version_0(B)
+
+        def store_bar_2(instance):
+            instance.bar = 10
+
+        self._check_specialization(store_bar_2, B(), "STORE_ATTR", should_specialize=False)
+
+    def test_class_call_specialization_user_type(self):
+        class F:
+            def __init__(self):
+                pass
+
+        self._assign_and_check_valid_version(F)
+
+        def call_class_1(type_):
+            type_()
+
+        self._check_specialization(call_class_1, F, "CALL", should_specialize=True)
+        del call_class_1
+
+        self._assign_and_check_version_0(F)
+
+        def call_class_2(type_):
+            type_()
+
+        self._check_specialization(call_class_2, F, "CALL", should_specialize=False)
+
+    def test_to_bool_specialization_user_type(self):
+        class H:
+            pass
+
+        self._assign_and_check_valid_version(H)
+
+        def to_bool_1(instance):
+            not instance
+
+        self._check_specialization(to_bool_1, H(), "TO_BOOL", should_specialize=True)
+        del to_bool_1
+
+        self._assign_and_check_version_0(H)
+
+        def to_bool_2(instance):
+            not instance
+
+        self._check_specialization(to_bool_2, H(), "TO_BOOL", should_specialize=False)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/NEWS.d/next/Core and Builtins/2024-01-03-12-19-37.gh-issue-89811.cZOj6d.rst b/Misc/NEWS.d/next/Core and Builtins/2024-01-03-12-19-37.gh-issue-89811.cZOj6d.rst
new file mode 100644
index 00000000000000..90bd9814faffd5
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2024-01-03-12-19-37.gh-issue-89811.cZOj6d.rst
@@ -0,0 +1,2 @@
+Check for a valid ``tp_version_tag`` before performing bytecode specializations that
+rely on this value being usable.
diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c
index 398570ff8e05c6..a0b21b7efbd971 100644
--- a/Modules/_testcapimodule.c
+++ b/Modules/_testcapimodule.c
@@ -2409,6 +2409,32 @@ type_get_version(PyObject *self, PyObject *type)
     return res;
 }
 
+static PyObject *
+type_modified(PyObject *self, PyObject *type)
+{
+    if (!PyType_Check(type)) {
+        PyErr_SetString(PyExc_TypeError, "argument must be a type");
+        return NULL;
+    }
+    PyType_Modified((PyTypeObject *)type);
+    Py_RETURN_NONE;
+}
+
+// Circumvents standard version assignment machinery - use with caution and only on
+// short-lived heap types. "O!" rejects non-type arguments; "I" matches unsigned int.
+static PyObject *
+type_assign_specific_version_unsafe(PyObject *self, PyObject *args)
+{
+    PyTypeObject *type;
+    unsigned int version;
+    if (!PyArg_ParseTuple(args, "O!I:type_assign_specific_version_unsafe", &PyType_Type, &type, &version)) {
+        return NULL;
+    }
+    assert(!PyType_HasFeature(type, Py_TPFLAGS_IMMUTABLETYPE));
+    type->tp_version_tag = version;
+    type->tp_flags |= Py_TPFLAGS_VALID_VERSION_TAG;
+    Py_RETURN_NONE;
+}
 
 static PyObject *
 type_assign_version(PyObject *self, PyObject *type)
@@ -3342,6 +3368,9 @@ static PyMethodDef TestMethods[] = {
     {"test_py_is_macros", test_py_is_macros, METH_NOARGS},
     {"test_py_is_funcs", test_py_is_funcs, METH_NOARGS},
     {"type_get_version", type_get_version, METH_O, PyDoc_STR("type->tp_version_tag")},
+    {"type_modified", type_modified, METH_O, PyDoc_STR("PyType_Modified")},
+    {"type_assign_specific_version_unsafe", type_assign_specific_version_unsafe, METH_VARARGS,
+     PyDoc_STR("forcefully assign type->tp_version_tag")},
     {"type_assign_version", type_assign_version, METH_O, PyDoc_STR("PyUnstable_Type_AssignVersionTag")},
     {"type_get_tp_bases", type_get_tp_bases, METH_O},
     {"type_get_tp_mro", type_get_tp_mro, METH_O},
diff --git a/Python/specialize.c b/Python/specialize.c
index 7b63393803b430..13e0440dd9dd0d 100644
--- a/Python/specialize.c
+++ b/Python/specialize.c
@@ -586,6 +586,7 @@ _PyCode_Quicken(PyCodeObject *code)
 static int function_kind(PyCodeObject *code);
 static bool function_check_args(PyObject *o, int expected_argcount, int opcode);
 static uint32_t function_get_version(PyObject *o, int opcode);
+static uint32_t type_get_version(PyTypeObject *t, int opcode);
 
 static int
 specialize_module_load_attr(
@@ -874,6 +875,9 @@ _Py_Specialize_LoadAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
     PyObject *descr = NULL;
     DescriptorClassification kind = analyze_descriptor(type, name, &descr, 0);
     assert(descr != NULL || kind == ABSENT || kind == GETSET_OVERRIDDEN);
+    if (type_get_version(type, LOAD_ATTR) == 0) {
+        goto fail;
+    }
     switch(kind) {
         case OVERRIDING:
             SPECIALIZATION_FAIL(LOAD_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
@@ -1057,6 +1061,9 @@ _Py_Specialize_StoreAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
     }
     PyObject *descr;
     DescriptorClassification kind = analyze_descriptor(type, name, &descr, 1);
+    if (type_get_version(type, STORE_ATTR) == 0) {
+        goto fail;
+    }
     switch(kind) {
         case OVERRIDING:
             SPECIALIZATION_FAIL(STORE_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
@@ -1183,6 +1190,9 @@ specialize_class_load_attr(PyObject *owner, _Py_CODEUNIT *instr,
     PyObject *descr = NULL;
     DescriptorClassification kind = 0;
     kind = analyze_descriptor((PyTypeObject *)owner, name, &descr, 0);
+    if (type_get_version((PyTypeObject *)owner, LOAD_ATTR) == 0) {
+        return -1;
+    }
     switch (kind) {
         case METHOD:
         case NON_DESCRIPTOR:
@@ -1455,6 +1465,18 @@ function_get_version(PyObject *o, int opcode)
     return version;
 }
 
+/* Returning 0 indicates a failure. */
+static uint32_t
+type_get_version(PyTypeObject *t, int opcode)
+{
+    uint32_t version = t->tp_version_tag;
+    if (version == 0) {
+        SPECIALIZATION_FAIL(opcode, SPEC_FAIL_OUT_OF_VERSIONS);
+        return 0;
+    }
+    return version;
+}
+
 void
 _Py_Specialize_BinarySubscr(
      PyObject *container, PyObject *sub, _Py_CODEUNIT *instr)
@@ -1726,6 +1748,9 @@ specialize_class_call(PyObject *callable, _Py_CODEUNIT *instr, int nargs)
     }
     if (tp->tp_new == PyBaseObject_Type.tp_new) {
         PyFunctionObject *init = get_init_for_simple_managed_python_class(tp);
+        if (type_get_version(tp, CALL) == 0) {
+            return -1;
+        }
         if (init != NULL) {
             if (((PyCodeObject *)init->func_code)->co_argcount != nargs+1) {
                 SPECIALIZATION_FAIL(CALL, SPEC_FAIL_WRONG_NUMBER_ARGUMENTS);
@@ -2466,7 +2491,10 @@ _Py_Specialize_ToBool(PyObject *value, _Py_CODEUNIT *instr)
             SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_OUT_OF_VERSIONS);
             goto failure;
         }
-        uint32_t version = Py_TYPE(value)->tp_version_tag;
+        uint32_t version = type_get_version(Py_TYPE(value), TO_BOOL);
+        if (version == 0) {
+            goto failure;
+        }
         instr->op.code = TO_BOOL_ALWAYS_TRUE;
         write_u32(cache->version, version);
         assert(version);

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: arch...@mail-archive.com

Reply via email to