Author: Armin Rigo <[email protected]>
Branch: stmgc-c4
Changeset: r65130:bf3b552315d6
Date: 2013-06-30 21:32 +0200
http://bitbucket.org/pypy/pypy/changeset/bf3b552315d6/

Log:    threadlocals, and stm_ptr_eq()

diff --git a/rpython/rlib/rstm.py b/rpython/rlib/rstm.py
--- a/rpython/rlib/rstm.py
+++ b/rpython/rlib/rstm.py
@@ -163,6 +163,7 @@
 
     def __init__(self, Cls):
         "NOT_RPYTHON: must be prebuilt"
+        import thread
         self.Cls = Cls
         self.local = thread._local()      # <- NOT_RPYTHON
         self.unique_id = ThreadLocalReference._COUNT
@@ -174,6 +175,8 @@
     @specialize.arg(0)
     def get(self):
         if we_are_translated():
+            from rpython.rtyper.lltypesystem import rclass
+            from rpython.rtyper.annlowlevel import cast_base_ptr_to_instance
             ptr = llop.stm_threadlocalref_get(rclass.OBJECTPTR, self.unique_id)
             return cast_base_ptr_to_instance(self.Cls, ptr)
         else:
@@ -183,6 +186,7 @@
     def set(self, value):
         assert isinstance(value, self.Cls) or value is None
         if we_are_translated():
+            from rpython.rtyper.annlowlevel import cast_instance_to_base_ptr
             ptr = cast_instance_to_base_ptr(value)
             llop.stm_threadlocalref_set(lltype.Void, self.unique_id, ptr)
         else:
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py
--- a/rpython/rtyper/lltypesystem/lloperation.py
+++ b/rpython/rtyper/lltypesystem/lloperation.py
@@ -429,6 +429,7 @@
     'stm_minor_collect':      LLOp(),
     'stm_major_collect':      LLOp(),
     'stm_get_tid':            LLOp(canfold=True),
+    'stm_ptr_eq':             LLOp(canfold=True),
     'stm_id':                 LLOp(sideeffects=False),
     'stm_hash':               LLOp(sideeffects=False),
     'stm_push_root':          LLOp(),
@@ -440,6 +441,11 @@
     'stm_change_atomic':      LLOp(),
     'stm_get_atomic':         LLOp(sideeffects=False),
 
+    'stm_threadlocalref_get': LLOp(sideeffects=False),
+    'stm_threadlocalref_set': LLOp(),
+    'stm_threadlocal_get':    LLOp(sideeffects=False),
+    'stm_threadlocal_set':    LLOp(),
+
     # __________ address operations __________
 
     'boehm_malloc':         LLOp(),
diff --git a/rpython/translator/c/funcgen.py b/rpython/translator/c/funcgen.py
--- a/rpython/translator/c/funcgen.py
+++ b/rpython/translator/c/funcgen.py
@@ -601,6 +601,8 @@
     OP_STM_SET_TRANSACTION_LENGTH = _OP_STM
     OP_STM_CHANGE_ATOMIC = _OP_STM
     OP_STM_GET_ATOMIC = _OP_STM
+    OP_STM_THREADLOCAL_GET = _OP_STM
+    OP_STM_THREADLOCAL_SET = _OP_STM
 
 
     def OP_PTR_NONZERO(self, op):
diff --git a/rpython/translator/stm/funcgen.py b/rpython/translator/stm/funcgen.py
--- a/rpython/translator/stm/funcgen.py
+++ b/rpython/translator/stm/funcgen.py
@@ -71,7 +71,8 @@
     arg0 = funcgen.expr(op.args[0])
     arg1 = funcgen.expr(op.args[1])
     result = funcgen.expr(op.result)
-    return '%s = stm_pointer_equal(%s, %s);' % (result, arg0, arg1)
+    return '%s = stm_pointer_equal((gcptr)%s, (gcptr)%s);' % (
+        result, arg0, arg1)
 
 def stm_become_inevitable(funcgen, op):
     try:
@@ -135,6 +136,15 @@
     result = funcgen.expr(op.result)
     return '%s = stm_atomic(0);' % (result,)
 
+def stm_threadlocal_get(funcgen, op):
+    result = funcgen.expr(op.result)
+    return '%s = (%s)stm_thread_local_obj;' % (
+        result, cdecl(funcgen.lltypename(op.result), ''))
+
+def stm_threadlocal_set(funcgen, op):
+    arg0 = funcgen.expr(op.args[0])
+    return 'stm_thread_local_obj = (gcptr)%s;' % (arg0,)
+
 
 def op_stm(funcgen, op):
     func = globals()[op.opname]
diff --git a/rpython/translator/stm/test/test_ztranslated.py b/rpython/translator/stm/test/test_ztranslated.py
--- a/rpython/translator/stm/test/test_ztranslated.py
+++ b/rpython/translator/stm/test/test_ztranslated.py
@@ -215,12 +215,6 @@
             assert t.get() is None
             t.set(x)
             assert t.get() is x
-            assert llop.stm_threadlocalref_llcount(lltype.Signed) == 1
-            p = llop.stm_threadlocalref_lladdr(llmemory.Address, 0)
-            adr = p.address[0]
-            adr2 = cast_instance_to_base_ptr(x)
-            adr2 = llmemory.cast_ptr_to_adr(adr2)
-            assert adr == adr2
             print "ok"
             return 0
         t, cbuilder = self.compile(main)
diff --git a/rpython/translator/stm/threadlocalref.py b/rpython/translator/stm/threadlocalref.py
--- a/rpython/translator/stm/threadlocalref.py
+++ b/rpython/translator/stm/threadlocalref.py
@@ -1,10 +1,13 @@
-from rpython.rtyper.lltypesystem import lltype, llmemory
-from rpython.translator.unsimplify import varoftype
+from rpython.annotator import model as annmodel
+from rpython.rtyper import annlowlevel
+from rpython.rtyper.lltypesystem import lltype, rclass
+from rpython.rtyper.lltypesystem.lloperation import llop
 from rpython.flowspace.model import SpaceOperation, Constant
 
 
-def transform_tlref(graphs):
+def transform_tlref(t):
     ids = set()
+    graphs = t.graphs
     #
     for graph in graphs:
         for block in graph.iterblocks():
@@ -13,23 +16,35 @@
                 if (op.opname == 'stm_threadlocalref_set' or
                     op.opname == 'stm_threadlocalref_get'):
                     ids.add(op.args[0].value)
+    if not ids:
+        return
     #
     ids = sorted(ids)
-    ARRAY = lltype.FixedSizeArray(llmemory.Address, len(ids))
-    S = lltype.Struct('THREADLOCALREF', ('ptr', ARRAY),
-                      hints={'stm_thread_local': True})
-    ll_threadlocalref = lltype.malloc(S, immortal=True)
-    c_threadlocalref = Constant(ll_threadlocalref, lltype.Ptr(S))
-    c_fieldname = Constant('ptr', lltype.Void)
-    c_null = Constant(llmemory.NULL, llmemory.Address)
+    total = len(ids)
+    ARRAY = lltype.GcArray(rclass.OBJECTPTR)
     #
-    def getaddr(v_num, v_result):
-        v_array = varoftype(lltype.Ptr(ARRAY))
-        ops = [
-            SpaceOperation('getfield', [c_threadlocalref, c_fieldname],
-                           v_array),
-            SpaceOperation('direct_ptradd', [v_array, v_num], v_result)]
-        return ops
+    def ll_threadlocalref_get(index):
+        array = llop.stm_threadlocal_get(lltype.Ptr(ARRAY))
+        if not array:
+            return lltype.nullptr(rclass.OBJECTPTR.TO)
+        else:
+            return array[index]
+    #
+    def ll_threadlocalref_set(index, newvalue):
+        array = llop.stm_threadlocal_get(lltype.Ptr(ARRAY))
+        if not array:
+            array = lltype.malloc(ARRAY, total)
+            llop.stm_threadlocal_set(lltype.Void, array)
+        array[index] = newvalue
+    #
+    annhelper = annlowlevel.MixLevelHelperAnnotator(t.rtyper)
+    s_Int = annmodel.SomeInteger()
+    s_Ptr = annmodel.SomePtr(rclass.OBJECTPTR)
+    c_getter_ptr = annhelper.constfunc(ll_threadlocalref_get,
+                                       [s_Int], s_Ptr)
+    c_setter_ptr = annhelper.constfunc(ll_threadlocalref_set,
+                                       [s_Int, s_Ptr], annmodel.s_None)
+    annhelper.finish()
     #
     for graph in graphs:
         for block in graph.iterblocks():
@@ -38,26 +53,17 @@
                 if op.opname == 'stm_threadlocalref_set':
                     id = op.args[0].value
                     c_num = Constant(ids.index(id), lltype.Signed)
-                    v_addr = varoftype(lltype.Ptr(ARRAY))
-                    ops = getaddr(c_num, v_addr)
-                    ops.append(SpaceOperation('stm_threadlocalref_llset',
-                                              [v_addr, op.args[1]],
-                                              op.result))
+                    ops = [
+                        SpaceOperation('direct_call', [c_setter_ptr, c_num,
+                                                       op.args[1]],
+                                       op.result)
+                        ]
                     block.operations[i:i+1] = ops
                 elif op.opname == 'stm_threadlocalref_get':
                     id = op.args[0].value
                     c_num = Constant(ids.index(id), lltype.Signed)
-                    v_array = varoftype(lltype.Ptr(ARRAY))
                     ops = [
-                        SpaceOperation('getfield', [c_threadlocalref,
-                                                    c_fieldname],
-                                       v_array),
-                        SpaceOperation('getarrayitem', [v_array, c_num],
-                                       op.result)]
+                        SpaceOperation('direct_call', [c_getter_ptr, c_num],
+                                       op.result)
+                        ]
                     block.operations[i:i+1] = ops
-                elif op.opname == 'stm_threadlocalref_lladdr':
-                    block.operations[i:i+1] = getaddr(op.args[0], op.result)
-                elif op.opname == 'stm_threadlocalref_llcount':
-                    c_count = Constant(len(ids), lltype.Signed)
-                    op = SpaceOperation('same_as', [c_count], op.result)
-                    block.operations[i] = op
diff --git a/rpython/translator/stm/transform2.py b/rpython/translator/stm/transform2.py
--- a/rpython/translator/stm/transform2.py
+++ b/rpython/translator/stm/transform2.py
@@ -40,7 +40,7 @@
 
     def transform_threadlocalref(self):
         from rpython.translator.stm.threadlocalref import transform_tlref
-        transform_tlref(self.translator.graphs)
+        transform_tlref(self.translator)
 
     def start_log(self):
         from rpython.translator.c.support import log
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to