Author: Carl Friedrich Bolz <cfb...@gmx.de>
Branch: better-storesink
Changeset: r87168:8c2f9afa87bb
Date: 2016-09-13 14:07 +0200
http://bitbucket.org/pypy/pypy/changeset/8c2f9afa87bb/

Log:    make cast_pointer introduce a union to be able to track things
        better

        this requires the keys of the heapcache to also store the
        concretetype of the variable where the actual setfield/getfield is
        performed

diff --git a/rpython/translator/backendopt/cse.py b/rpython/translator/backendopt/cse.py
--- a/rpython/translator/backendopt/cse.py
+++ b/rpython/translator/backendopt/cse.py
@@ -2,6 +2,7 @@
 
 from rpython.translator.backendopt import support
 from rpython.rtyper.lltypesystem.lloperation import llop
+from rpython.rtyper.lltypesystem import lltype
 from rpython.flowspace.model import mkentrymap, Variable, Constant
 from rpython.translator.backendopt import removenoops
 from rpython.translator import simplify
@@ -48,7 +49,11 @@
                 self.heapcache.copy())
 
     def _var_rep(self, var):
-        var = self.new_unions.get(var, var)
+        while True:
+            newvar = self.new_unions.get(var, None)
+            if newvar is None:
+                break
+            var = newvar # can take several dereferences
         return self.variable_families.find_rep(var)
 
     def _key_with_replacement(self, key, index, var):
@@ -105,7 +110,7 @@
             firstlinkarg = self._var_rep(firstlink.args[argindex])
             for key, res in self.purecache.iteritems():
                 (opname, concretetype, args) = key
-                if args[0] != firstlinkarg: # XXX other args
+                if self._var_rep(args[0]) != firstlinkarg: # XXX other args
                     continue
                 results = [res]
                 for linkindex, (link, cache) in enumerate(tuples):
@@ -145,21 +150,21 @@
             # bit slow, but probably ok
             firstlinkarg = self._var_rep(firstlink.args[argindex])
             for key, res in self.heapcache.iteritems():
-                (arg, fieldname) = key
-                if arg != firstlinkarg:
+                (arg, concretetype, fieldname) = key
+                if self._var_rep(arg) != firstlinkarg:
                     continue
                 results = [res]
                 for linkindex, (link, cache) in enumerate(tuples):
                     if linkindex == 0:
                         continue
                     otherarg = cache._var_rep(link.args[argindex])
-                    newkey = (otherarg, fieldname)
+                    newkey = (otherarg, concretetype, fieldname)
                     otherres = cache.heapcache.get(newkey, None)
                     if otherres is None:
                         break
                     results.append(otherres)
                 else:
-                    newkey = (self._var_rep(inputarg), fieldname)
+                    newkey = (self._var_rep(inputarg), concretetype, fieldname)
                     newres = self._merge_results(tuples, results, backedges)
                     heapcache[newkey] = newres
 
@@ -179,7 +184,7 @@
 
     def _clear_heapcache_for(self, concretetype, fieldname):
         for k in self.heapcache.keys():
-            if k[0].concretetype == concretetype and k[1] == fieldname:
+            if k[1] == concretetype and k[2] == fieldname:
                 del self.heapcache[k]
 
     def _clear_heapcache_for_effects_of_op(self, op):
@@ -194,7 +199,7 @@
         else:
             for k in self.heapcache.keys():
                 # XXX slow
-                key = ('struct', k[0].concretetype, k[1])
+                key = ('struct', k[1], k[2])
                 if key in effects:
                     del self.heapcache[k]
 
@@ -217,17 +222,19 @@
             # heap operations
             if op.opname == 'getfield':
                 fieldname = op.args[1].value
+                concretetype = op.args[0].concretetype
                 arg0 = representative_arg(op.args[0])
                 res = None
                 if isinstance(arg0, Constant):
-                    PTRTYPE = arg0.concretetype.TO
+                    PTRTYPE = concretetype.TO
                     if PTRTYPE._immutable_field(fieldname):
                         # can constant-fold:
                         FIELDTYPE = getattr(PTRTYPE, fieldname)
-                        value = getattr(arg0.value, fieldname)
+                        const = lltype.cast_pointer(concretetype, arg0.value)
+                        value = getattr(const, fieldname)
                         res = Constant(value, FIELDTYPE)
                 if res is None:
-                    tup = (arg0, fieldname)
+                    tup = (arg0, op.args[0].concretetype, fieldname)
                     res = self.heapcache.get(tup, None)
                 if res is not None:
                     op.opname = 'same_as'
@@ -238,10 +245,11 @@
                     self.heapcache[tup] = op.result
                 continue
             if op.opname == 'setfield':
+                concretetype = op.args[0].concretetype
                 target = representative_arg(op.args[0])
                 field = op.args[1].value
-                self._clear_heapcache_for(target.concretetype, field)
-                self.heapcache[target, field] = op.args[2]
+                self._clear_heapcache_for(concretetype, field)
+                self.heapcache[target, concretetype, field] = op.args[2]
                 continue
             if has_side_effects(op):
                 self._clear_heapcache_for_effects_of_op(op)
@@ -260,6 +268,11 @@
                 self.new_unions[op.result] = res
             else:
                 self.purecache[key] = op.result
+            if op.opname == "cast_pointer":
+                # cast_pointer is a pretty strange operation! it introduces
+                # more aliases, that confuse the CSE pass. Therefore we unify
+                # the two variables in new_unions, to improve the folding.
+                self.new_unions[op.result] = op.args[0]
         return added_same_as
 
 def _merge(tuples, variable_families, analyzer, loop_blocks, backedges):
diff --git a/rpython/translator/backendopt/test/test_cse.py b/rpython/translator/backendopt/test/test_cse.py
--- a/rpython/translator/backendopt/test/test_cse.py
+++ b/rpython/translator/backendopt/test/test_cse.py
@@ -419,9 +419,45 @@
                     a.x = 2
                 # here a is a subclass of B
                 res += a.x
+            else:
+                a = A()
+            res += a.__class__ is A
             return res
         self.check(f, [int], getfield=0)
 
+    def test_cast_pointer_leading_to_constant(self):
+        class Cls(object):
+            pass
+        class Sub(Cls):
+            pass
+        cls1 = Cls()
+        cls2 = Sub()
+        cls2.user_overridden_class = True
+        cls3 = Sub()
+        cls3.user_overridden_class = False
+        class A(object):
+            pass
+        def f(i):
+            res = 0
+            if i > 20:
+                a = A()
+                a.cls = cls1
+                return 1
+            elif i > 30:
+                a = A()
+                a.cls = cls2
+                cls = a.cls
+                assert type(cls) is Sub
+                return cls.user_overridden_class
+            else:
+                a = A()
+                a.cls = cls3
+                cls = a.cls
+                assert type(cls) is Sub
+                return cls.user_overridden_class
+        self.check(f, [int], getfield=2)
+
+
 
 def fakevar(name='v'):
     var = Variable(name)
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to