Author: Armin Rigo <[email protected]>
Branch: stm-gc
Changeset: r52799:077ddb94d35b
Date: 2012-02-23 13:57 +0100
http://bitbucket.org/pypy/pypy/changeset/077ddb94d35b/

Log:    Fix pointer comparison between two non-NULL objects.

diff --git a/pypy/rpython/lltypesystem/lloperation.py b/pypy/rpython/lltypesystem/lloperation.py
--- a/pypy/rpython/lltypesystem/lloperation.py
+++ b/pypy/rpython/lltypesystem/lloperation.py
@@ -403,6 +403,7 @@
     'stm_descriptor_init':    LLOp(canrun=True),
     'stm_descriptor_done':    LLOp(canrun=True),
     'stm_writebarrier':       LLOp(sideeffects=False),
+    'stm_normalize_global':   LLOp(),
     'stm_start_transaction':  LLOp(canrun=True),
     'stm_commit_transaction': LLOp(canrun=True),
 
diff --git a/pypy/rpython/memory/gc/stmgc.py b/pypy/rpython/memory/gc/stmgc.py
--- a/pypy/rpython/memory/gc/stmgc.py
+++ b/pypy/rpython/memory/gc/stmgc.py
@@ -320,6 +320,11 @@
         #
         @always_inline
         def stm_writebarrier(obj):
+            """The write barrier must be called on any object that may be
+            a global.  It looks for, and possibly makes, a local copy of
+            this object.  The result of this call is the local copy ---
+            or 'obj' itself if it is already local.
+            """
             if self.header(obj).tid & GCFLAG_GLOBAL != 0:
                 obj = _stm_write_barrier_global(obj)
             return obj
@@ -380,6 +385,23 @@
             stm_operations.tldict_add(obj, localobj)
             #
             return localobj
+        #
+        def stm_normalize_global(obj):
+            """Normalize a pointer for the purpose of equality
+            comparison with another pointer.  If 'obj' is the local
+            version of an existing global object, then returns the
+            global object.  Don't use for e.g. hashing, because if 'obj'
+            is a purely local object, it just returns 'obj' --- which
+            will change at the next commit.
+            """
+            if not obj:
+                return obj
+            tid = self.header(obj).tid
+            if tid & (GCFLAG_GLOBAL|GCFLAG_WAS_COPIED) != GCFLAG_WAS_COPIED:
+                return obj
+            # the only relevant case: it's the local copy of a global object
+            return self.header(obj).version
+        self.stm_normalize_global = stm_normalize_global
 
     # ----------
 
diff --git a/pypy/rpython/memory/gc/test/test_stmgc.py b/pypy/rpython/memory/gc/test/test_stmgc.py
--- a/pypy/rpython/memory/gc/test/test_stmgc.py
+++ b/pypy/rpython/memory/gc/test/test_stmgc.py
@@ -564,3 +564,27 @@
         s2 = llmemory.cast_adr_to_ptr(wr2.wadr, lltype.Ptr(S))
         assert s2.a == 4242
         assert s2 == tr1.s1   # tr1 is a root, so not copied yet
+
+    def test_normalize_global_null(self):
+        a = self.gc.stm_normalize_global(llmemory.NULL)
+        assert a == llmemory.NULL
+
+    def test_normalize_global_already_global(self):
+        sr1, sr1_adr = self.malloc(SR)
+        a = self.gc.stm_normalize_global(sr1_adr)
+        assert a == sr1_adr
+
+    def test_normalize_global_purely_local(self):
+        self.select_thread(1)
+        sr1, sr1_adr = self.malloc(SR)
+        a = self.gc.stm_normalize_global(sr1_adr)
+        assert a == sr1_adr
+
+    def test_normalize_global_local_copy(self):
+        sr1, sr1_adr = self.malloc(SR)
+        self.select_thread(1)
+        tr1_adr = self.gc.stm_writebarrier(sr1_adr)
+        a = self.gc.stm_normalize_global(sr1_adr)
+        assert a == sr1_adr
+        a = self.gc.stm_normalize_global(tr1_adr)
+        assert a == sr1_adr
diff --git a/pypy/rpython/memory/gctransform/stmframework.py b/pypy/rpython/memory/gctransform/stmframework.py
--- a/pypy/rpython/memory/gctransform/stmframework.py
+++ b/pypy/rpython/memory/gctransform/stmframework.py
@@ -18,6 +18,9 @@
         self.stm_writebarrier_ptr = getfn(
             self.gcdata.gc.stm_writebarrier,
             [annmodel.SomeAddress()], annmodel.SomeAddress())
+        self.stm_normalize_global_ptr = getfn(
+            self.gcdata.gc.stm_normalize_global,
+            [annmodel.SomeAddress()], annmodel.SomeAddress())
         self.stm_start_ptr = getfn(
             self.gcdata.gc.start_transaction.im_func,
             [s_gc], annmodel.s_None)
@@ -50,6 +53,15 @@
                                resulttype=llmemory.Address)
         hop.genop('cast_adr_to_ptr', [v_localadr], resultvar=op.result)
 
+    def gct_stm_normalize_global(self, hop):
+        op = hop.spaceop
+        v_adr = hop.genop('cast_ptr_to_adr',
+                          [op.args[0]], resulttype=llmemory.Address)
+        v_globaladr = hop.genop("direct_call",
+                                [self.stm_normalize_global_ptr, v_adr],
+                                resulttype=llmemory.Address)
+        hop.genop('cast_adr_to_ptr', [v_globaladr], resultvar=op.result)
+
     def gct_stm_start_transaction(self, hop):
         hop.genop("direct_call", [self.stm_start_ptr, self.c_const_gc])
 
diff --git a/pypy/translator/stm/localtracker.py b/pypy/translator/stm/localtracker.py
--- a/pypy/translator/stm/localtracker.py
+++ b/pypy/translator/stm/localtracker.py
@@ -20,7 +20,11 @@
         self.gsrc = GcSource(translator)
 
     def is_local(self, variable):
-        assert isinstance(variable, Variable)
+        if isinstance(variable, Constant):
+            if not variable.value:  # the constant NULL can be considered local
+                return True
+            self.reason = 'constant'
+            return False
         try:
             srcs = self.gsrc[variable]
         except KeyError:
diff --git a/pypy/translator/stm/transform.py b/pypy/translator/stm/transform.py
--- a/pypy/translator/stm/transform.py
+++ b/pypy/translator/stm/transform.py
@@ -102,11 +102,10 @@
             self.count_get_immutable += 1
             newoperations.append(op)
             return
-        if isinstance(op.args[0], Variable):
-            if self.localtracker.is_local(op.args[0]):
-                self.count_get_local += 1
-                newoperations.append(op)
-                return
+        if self.localtracker.is_local(op.args[0]):
+            self.count_get_local += 1
+            newoperations.append(op)
+            return
         self.count_get_nonlocal += 1
         op1 = SpaceOperation(stmopname, op.args, op.result)
         newoperations.append(op1)
@@ -153,8 +152,7 @@
         self.transform_set(newoperations, op)
 
     def stt_stm_writebarrier(self, newoperations, op):
-        if (isinstance(op.args[0], Variable) and
-            self.localtracker.is_local(op.args[0])):
+        if self.localtracker.is_local(op.args[0]):
             op = SpaceOperation('same_as', op.args, op.result)
         else:
             self.count_write_barrier += 1
@@ -183,6 +181,24 @@
             return
         newoperations.append(op)
 
+    def pointer_comparison(self, newoperations, op):
+        if (self.localtracker.is_local(op.args[0]) and
+            self.localtracker.is_local(op.args[1])):
+            newoperations.append(op)
+            return
+        nargs = []
+        for v1 in op.args:
+            if isinstance(v1, Variable):
+                v0 = v1
+                v1 = copyvar(self.translator.annotator, v0)
+                newoperations.append(
+                    SpaceOperation('stm_normalize_global', [v0], v1))
+            nargs.append(v1)
+        newoperations.append(SpaceOperation(op.opname, nargs, op.result))
+
+    stt_ptr_eq = pointer_comparison
+    stt_ptr_ne = pointer_comparison
+
 
 def transform_graph(graph):
     # for tests: only transforms one graph
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to