https://github.com/python/cpython/commit/ff96b81d78c4a52fb1eb8384300af3dd0dd2db0d
commit: ff96b81d78c4a52fb1eb8384300af3dd0dd2db0d
branch: main
author: Ken Jin <[email protected]>
committer: Fidget-Spinner <[email protected]>
date: 2024-03-02T03:40:04+08:00
summary:

gh-115480: Type propagate _BINARY_OP_ADD_UNICODE (GH-115710)

files:
M Lib/test/test_capi/test_opt.py
M Python/optimizer_bytecodes.c
M Python/optimizer_cases.c.h

diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py
index f4fcdea05e96bf..a0a19225b79433 100644
--- a/Lib/test/test_capi/test_opt.py
+++ b/Lib/test/test_capi/test_opt.py
@@ -795,11 +795,14 @@ def test_float_add_constant_propagation(self):
         def testfunc(n):
             a = 1.0
             for _ in range(n):
-                a = a + 0.1
+                a = a + 0.25
+                a = a + 0.25
+                a = a + 0.25
+                a = a + 0.25
             return a
 
         res, ex = self._run_with_optimizer(testfunc, 32)
-        self.assertAlmostEqual(res, 4.2)
+        self.assertAlmostEqual(res, 33.0)
         self.assertIsNotNone(ex)
         uops = get_opnames(ex)
         guard_both_float_count = [opname for opname in iter_opnames(ex) if opname == "_GUARD_BOTH_FLOAT"]
@@ -812,11 +815,14 @@ def test_float_subtract_constant_propagation(self):
         def testfunc(n):
             a = 1.0
             for _ in range(n):
-                a = a - 0.1
+                a = a - 0.25
+                a = a - 0.25
+                a = a - 0.25
+                a = a - 0.25
             return a
 
         res, ex = self._run_with_optimizer(testfunc, 32)
-        self.assertAlmostEqual(res, -2.2)
+        self.assertAlmostEqual(res, -31.0)
         self.assertIsNotNone(ex)
         uops = get_opnames(ex)
         guard_both_float_count = [opname for opname in iter_opnames(ex) if opname == "_GUARD_BOTH_FLOAT"]
@@ -829,11 +835,14 @@ def test_float_multiply_constant_propagation(self):
         def testfunc(n):
             a = 1.0
             for _ in range(n):
-                a = a * 2.0
+                a = a * 1.0
+                a = a * 1.0
+                a = a * 1.0
+                a = a * 1.0
             return a
 
         res, ex = self._run_with_optimizer(testfunc, 32)
-        self.assertAlmostEqual(res, 2 ** 32)
+        self.assertAlmostEqual(res, 1.0)
         self.assertIsNotNone(ex)
         uops = get_opnames(ex)
         guard_both_float_count = [opname for opname in iter_opnames(ex) if opname == "_GUARD_BOTH_FLOAT"]
@@ -842,6 +851,24 @@ def testfunc(n):
         # We'll also need to verify that propagation actually occurs.
         self.assertIn("_BINARY_OP_MULTIPLY_FLOAT", uops)
 
+    def test_add_unicode_propagation(self):
+        def testfunc(n):
+            a = ""
+            for _ in range(n):
+                a + a
+                a + a
+                a + a
+                a + a
+            return a
+
+        res, ex = self._run_with_optimizer(testfunc, 32)
+        self.assertEqual(res, "")
+        self.assertIsNotNone(ex)
+        uops = get_opnames(ex)
+        guard_both_unicode_count = [opname for opname in iter_opnames(ex) if opname == "_GUARD_BOTH_UNICODE"]
+        self.assertLessEqual(len(guard_both_unicode_count), 1)
+        self.assertIn("_BINARY_OP_ADD_UNICODE", uops)
+
     def test_compare_op_type_propagation_float(self):
         def testfunc(n):
             a = 1.0
diff --git a/Python/optimizer_bytecodes.c b/Python/optimizer_bytecodes.c
index 2b47381ec76db4..786d884fc5a1a8 100644
--- a/Python/optimizer_bytecodes.c
+++ b/Python/optimizer_bytecodes.c
@@ -254,6 +254,22 @@ dummy_func(void) {
         }
     }
 
+    op(_BINARY_OP_ADD_UNICODE, (left, right -- res)) {
+        if (sym_is_const(left) && sym_is_const(right) &&
+            sym_matches_type(left, &PyUnicode_Type) && sym_matches_type(right, &PyUnicode_Type)) {
+            PyObject *temp = PyUnicode_Concat(sym_get_const(left), sym_get_const(right));
+            if (temp == NULL) {
+                goto error;
+            }
+            res = sym_new_const(ctx, temp);
+            Py_DECREF(temp);
+            OUT_OF_SPACE_IF_NULL(res);
+        }
+        else {
+            OUT_OF_SPACE_IF_NULL(res = sym_new_type(ctx, &PyUnicode_Type));
+        }
+    }
+
     op(_TO_BOOL, (value -- res)) {
         (void)value;
         res = sym_new_type(ctx, &PyBool_Type);
diff --git a/Python/optimizer_cases.c.h b/Python/optimizer_cases.c.h
index 9d7ebb80f62857..6d3488f2118589 100644
--- a/Python/optimizer_cases.c.h
+++ b/Python/optimizer_cases.c.h
@@ -446,9 +446,24 @@
         }
 
         case _BINARY_OP_ADD_UNICODE: {
+            _Py_UopsSymbol *right;
+            _Py_UopsSymbol *left;
             _Py_UopsSymbol *res;
-            res = sym_new_unknown(ctx);
-            if (res == NULL) goto out_of_space;
+            right = stack_pointer[-1];
+            left = stack_pointer[-2];
+            if (sym_is_const(left) && sym_is_const(right) &&
+                sym_matches_type(left, &PyUnicode_Type) && sym_matches_type(right, &PyUnicode_Type)) {
+                PyObject *temp = PyUnicode_Concat(sym_get_const(left), sym_get_const(right));
+                if (temp == NULL) {
+                    goto error;
+                }
+                res = sym_new_const(ctx, temp);
+                Py_DECREF(temp);
+                OUT_OF_SPACE_IF_NULL(res);
+            }
+            else {
+                OUT_OF_SPACE_IF_NULL(res = sym_new_type(ctx, &PyUnicode_Type));
+            }
             stack_pointer[-2] = res;
             stack_pointer += -1;
             break;

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to