Author: Armin Rigo <[email protected]>
Branch: stm-thread
Changeset: r55055:260225d12347
Date: 2012-05-12 11:31 +0200
http://bitbucket.org/pypy/pypy/changeset/260225d12347/

Log:    Adapt test_ztranslation.

diff --git a/pypy/translator/stm/test/targetdemo.py b/pypy/translator/stm/test/targetdemo.py
deleted file mode 100644
--- a/pypy/translator/stm/test/targetdemo.py
+++ /dev/null
@@ -1,126 +0,0 @@
-from pypy.rpython.lltypesystem import lltype, rffi
-from pypy.rpython.lltypesystem.lloperation import llop
-from pypy.rlib import rstm
-from pypy.rlib.debug import debug_print, ll_assert
-
-
-class Node:
-    def __init__(self, value):
-        self.value = value
-        self.next = None
-
-class Global:
-    NUM_THREADS = 4
-    LENGTH      = 5000
-    USE_MEMORY  = False
-    anchor      = Node(-1)
-glob = Global()
-
-class Arg:
-    pass
-
-
-def add_at_end_of_chained_list(node, value):
-    x = Node(value)
-    while node.next:
-        node = node.next
-        if glob.USE_MEMORY:
-            x = Node(value)
-    newnode = x
-    node.next = newnode
-
-def check_chained_list(node):
-    seen = [0] * (glob.LENGTH+1)
-    seen[-1] = glob.NUM_THREADS
-    errors = glob.LENGTH
-    while node is not None:
-        value = node.value
-        #print value
-        if not (0 <= value < glob.LENGTH):
-            print "node.value out of bounds:", value
-            raise AssertionError
-        seen[value] += 1
-        if seen[value] > seen[value-1]:
-            errors = min(errors, value)
-        node = node.next
-    if errors < glob.LENGTH:
-        value = errors
-        print "seen[%d] = %d, seen[%d] = %d" % (value-1, seen[value-1],
-                                                value, seen[value])
-        raise AssertionError
-
-    if seen[glob.LENGTH-1] != glob.NUM_THREADS:
-        print "seen[LENGTH-1] != NUM_THREADS"
-        raise AssertionError
-    print "check ok!"
-
-
-def _check_pointer(arg1):
-    arg1.foobar = 40    # now 'arg1' is local
-    return arg1
-
-class CheckPointerEquality(rstm.Transaction):
-    def __init__(self, arg):
-        self.arg = arg
-    def run(self):
-        res = _check_pointer(self.arg)    # 'self.arg' reads a GLOBAL object
-        ll_assert(res is self.arg, "ERROR: bogus pointer equality")
-        raw1 = rffi.cast(rffi.CCHARP, self.retry_counter)
-        raw2 = rffi.cast(rffi.CCHARP, -1)
-        ll_assert(raw1 != raw2, "ERROR: retry_counter == -1")
-
-class MakeChain(rstm.Transaction):
-    def __init__(self, anchor, value):
-        self.anchor = anchor
-        self.value = value
-    def run(self):
-        add_at_end_of_chained_list(self.anchor, self.value)
-        self.value += 1
-        if self.value < glob.LENGTH:
-            return [self]       # re-schedule the same Transaction object
-
-class InitialTransaction(rstm.Transaction):
-    def run(self):
-        debug_print("InitialTransaction.run", self.retry_counter)
-        ll_assert(self.retry_counter == 0, "no reason to abort-and-retry here")
-        ll_assert(rstm.thread_id() != 0, "thread_id == 0")
-        scheduled = []
-        for i in range(glob.NUM_THREADS):
-            arg = Arg()
-            arg.foobar = 41
-            scheduled.append(CheckPointerEquality(arg))
-            scheduled.append(MakeChain(glob.anchor, 0))
-        return scheduled
-
-# __________  Entry point  __________
-
-def entry_point(argv):
-    print "hello world"
-    assert rstm.stm_is_enabled()
-    if len(argv) > 1:
-        glob.NUM_THREADS = int(argv[1])
-        if len(argv) > 2:
-            glob.LENGTH = int(argv[2])
-            if len(argv) > 3:
-                glob.USE_MEMORY = bool(int(argv[3]))
-    ll_assert(rstm.thread_id() == 0, "thread_id != 0")
-    #
-    rstm.run_all_transactions(InitialTransaction(),
-                              num_threads=glob.NUM_THREADS)
-    check_chained_list(glob.anchor.next)
-    #
-    glob.anchor.next = None
-    rstm.run_all_transactions(InitialTransaction(),
-                              num_threads=glob.NUM_THREADS)
-    check_chained_list(glob.anchor.next)
-    #
-    return 0
-
-# _____ Define and setup target ___
-
-def target(*args):
-    return entry_point, None
-
-if __name__ == '__main__':
-    import sys
-    entry_point(sys.argv)
diff --git a/pypy/translator/stm/test/targetdemo2.py b/pypy/translator/stm/test/targetdemo2.py
--- a/pypy/translator/stm/test/targetdemo2.py
+++ b/pypy/translator/stm/test/targetdemo2.py
@@ -2,6 +2,8 @@
 from pypy.module.thread import ll_thread
 from pypy.rlib import rstm
 from pypy.rlib.objectmodel import invoke_around_extcall, we_are_translated
+from pypy.rlib.debug import ll_assert
+from pypy.rpython.lltypesystem import rffi
 
 
 class Node:
@@ -64,6 +66,9 @@
     def run(self):
         try:
             self.value = 0
+            self.arg = Arg()
+            rstm.perform_transaction(ThreadRunner.check_ptr_equality,
+                                     ThreadRunner, self)
             rstm.perform_transaction(ThreadRunner.run_really,
                                      ThreadRunner, self)
         finally:
@@ -85,6 +90,21 @@
         self.value += 1
         return int(self.value < glob.LENGTH)
 
+    def check_ptr_equality(self, retry_counter):
+        res = _check_pointer(self.arg)    # 'self.arg' reads a GLOBAL object
+        ll_assert(res is self.arg, "ERROR: bogus pointer equality")
+        raw1 = rffi.cast(rffi.CCHARP, retry_counter)
+        raw2 = rffi.cast(rffi.CCHARP, -1)
+        ll_assert(raw1 != raw2, "ERROR: retry_counter == -1")
+        return 0
+
+class Arg:
+    foobar = 42
+
+def _check_pointer(arg1):
+    arg1.foobar = 40    # now 'arg1' is local
+    return arg1
+
 # ____________________________________________________________
 # bah, we are really missing an RPython interface to threads
 
diff --git a/pypy/translator/stm/test/test_ztranslated.py b/pypy/translator/stm/test/test_ztranslated.py
--- a/pypy/translator/stm/test/test_ztranslated.py
+++ b/pypy/translator/stm/test/test_ztranslated.py
@@ -1,28 +1,30 @@
 import py
 from pypy.rlib import rstm, rgc
 from pypy.translator.stm.test.support import CompiledSTMTests
-from pypy.translator.stm.test import targetdemo
+from pypy.translator.stm.test import targetdemo2
 
 
 class TestSTMTranslated(CompiledSTMTests):
 
     def test_targetdemo(self):
-        t, cbuilder = self.compile(targetdemo.entry_point)
+        t, cbuilder = self.compile(targetdemo2.entry_point)
         data, dataerr = cbuilder.cmdexec('4 5000', err=True)
         assert 'check ok!' in data
 
     def test_bug1(self):
         #
-        class InitialTransaction(rstm.Transaction):
-            def run(self):
-                rgc.collect(0)
+        class Foobar:
+            pass
+        def check(foobar, retry_counter):
+            rgc.collect(0)
+            return 0
         #
         class X:
             def __init__(self, count):
                 self.count = count
         def g():
             x = X(1000)
-            rstm.run_all_transactions(InitialTransaction())
+            rstm.perform_transaction(check, Foobar, Foobar())
             return x
         def entry_point(argv):
             x = X(len(argv))
@@ -36,9 +38,10 @@
 
     def test_bug2(self):
         #
-        class DoNothing(rstm.Transaction):
-            def run(self):
-                pass
+        class Foobar:
+            pass
+        def check(foobar, retry_counter):
+            return 0    # do nothing
         #
         class X2:
             pass
@@ -48,7 +51,7 @@
             x = prebuilt2[count]
             x.foobar = 2                    # 'x' becomes a local
             #
-            rstm.run_all_transactions(DoNothing())
+            rstm.perform_transaction(check, Foobar, Foobar())
                                             # 'x' becomes the global again
             #
             y = prebuilt2[count]            # same prebuilt obj
@@ -65,9 +68,10 @@
         assert '12\n12\n' in data, "got: %r" % (data,)
 
     def test_prebuilt_nongc(self):
-        class DoNothing(rstm.Transaction):
-            def run(self):
-                pass
+        class Foobar:
+            pass
+        def check(foobar, retry_counter):
+            return 0    # do nothing
         from pypy.rpython.lltypesystem import lltype
         R = lltype.GcStruct('R', ('x', lltype.Signed))
         S1 = lltype.Struct('S1', ('r', lltype.Ptr(R)))
@@ -76,7 +80,7 @@
         #                   hints={'stm_thread_local': True})
         #s2 = lltype.malloc(S2, immortal=True, flavor='raw')
         def do_stuff():
-            rstm.run_all_transactions(DoNothing())
+            rstm.perform_transaction(check, Foobar, Foobar())
             print s1.r.x
             #print s2.r.x
         do_stuff._dont_inline_ = True
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to