Author: Armin Rigo <ar...@tunes.org>
Branch: c7-refactor
Changeset: r789:fa6db780b4a1
Date: 2014-02-19 17:50 +0100
http://bitbucket.org/pypy/stmgc/changeset/fa6db780b4a1/

Log:    merge heads

diff --git a/c7/test/support.py b/c7/test/support.py
--- a/c7/test/support.py
+++ b/c7/test/support.py
@@ -305,6 +305,12 @@
     lib._set_type_id(o, tid)
     return o
 
+def stm_allocate_old_refs(n):
+    o = lib._stm_allocate_old(HDR + n * WORD)
+    tid = 421420 + n
+    lib._set_type_id(o, tid)
+    return o
+
 def stm_allocate(size):
     o = lib.stm_allocate(size)
     tid = 42 + size
diff --git a/c7/test/test_random.py b/c7/test/test_random.py
--- a/c7/test/test_random.py
+++ b/c7/test/test_random.py
@@ -12,35 +12,55 @@
         print >> sys.stderr, cmd
         exec cmd in globals(), self.content
 
+        
+
 _root_numbering = 0
-def get_new_root_name():
+is_ref_type_map = {}
+def get_new_root_name(is_ref_type):
     global _root_numbering
     _root_numbering += 1
-    return "lp%d" % _root_numbering
+    r = "lp%d" % _root_numbering
+    is_ref_type_map[r] = is_ref_type
+    return r
+
 
 _global_time = 0
-def contention_management(our_trs, other_trs, wait=False):
+def contention_management(our_trs, other_trs, wait=False, objs_in_conflict=None):
+    """exact copy of logic in contention.c"""
+    
     if other_trs.start_time < our_trs.start_time:
         pass
     else:
-        other_trs.must_abort = True
+        other_trs.set_must_abort(objs_in_conflict)
         
-    if not other_trs.must_abort:
-        our_trs.must_abort = True
+    if not other_trs.check_must_abort():
+        our_trs.set_must_abort(objs_in_conflict)
     elif wait:
         # abort anyway:
-        our_trs.must_abort = True
+        our_trs.set_must_abort(objs_in_conflict)
         
 
 class TransactionState(object):
-    """maintains read/write sets"""
+    """State of a transaction running in a thread,
+    e.g. maintains read/write sets. The state will be
+    discarded on abort or pushed to other threads"""
+    
     def __init__(self, start_time):
         self.read_set = set()
         self.write_set = set()
         self.values = {}
-        self.must_abort = False
+        self._must_abort = False
         self.start_time = start_time
+        self.objs_in_conflict = set()
 
+    def set_must_abort(self, objs_in_conflict=None):
+        if objs_in_conflict is not None:
+            self.objs_in_conflict |= objs_in_conflict
+        self._must_abort = True
+
+    def check_must_abort(self):
+        return self._must_abort
+        
     def has_conflict_with(self, committed):
         return bool(self.read_set & committed.write_set)
     
@@ -53,8 +73,9 @@
             self.values.update(committed.values)
 
         if self.has_conflict_with(committed):
-            contention_management(self, committed)
-        return self.must_abort
+            contention_management(self, committed,
+                                  objs_in_conflict=self.read_set & committed.write_set)
+        return self.check_must_abort()
 
     def read_root(self, r):
         self.read_set.add(r)
@@ -69,7 +90,9 @@
         
 
 class ThreadState(object):
-    """maintains state for one thread """
+    """Maintains state for one thread. Mostly manages things
+    to be kept between transactions (e.g. saved roots) and
+    handles discarding/resetting states on transaction abort"""
     
     def __init__(self, num, global_state):
         self.num = num
@@ -114,7 +137,7 @@
 
     def update_roots(self, ex):
         assert self.roots_on_stack == self.roots_on_transaction_start
-        for r in self.saved_roots[::-1]:
+        for r in reversed(self.saved_roots):
             ex.do('%s = self.pop_root()' % r)
             self.roots_on_stack -= 1
         assert self.roots_on_stack == 0
@@ -137,7 +160,7 @@
         trs = self.transaction_state
         gtrs = self.global_state.committed_transaction_state
         self.global_state.check_for_write_read_conflicts(trs)
-        conflicts = trs.must_abort
+        conflicts = trs.check_must_abort()
         if not conflicts:
             # update global committed state w/o conflict
             assert not gtrs.update_from_committed(trs)
@@ -146,13 +169,17 @@
         return conflicts
 
     def abort_transaction(self):
-        assert self.transaction_state.must_abort
+        assert self.transaction_state.check_must_abort()
         self.roots_on_stack = self.roots_on_transaction_start
         del self.saved_roots[self.roots_on_stack:]
         self.transaction_state = None
 
         
 class GlobalState(object):
+    """Maintains the global view (in a TransactionState) on
+    objects and threads. It also handles checking for conflicts
+    between threads and pushing state to other threads"""
+    
     def __init__(self, ex, rnd):
         self.ex = ex
         self.rnd = rnd
@@ -161,41 +188,48 @@
         self.committed_transaction_state = TransactionState(0)
 
     def push_state_to_other_threads(self, tr_state):
-        assert not tr_state.must_abort
+        assert not tr_state.check_must_abort()
         for ts in self.thread_states:
             other_trs = ts.transaction_state
             if other_trs is None or other_trs is tr_state:
                 continue
             other_trs.update_from_committed(tr_state, only_new=True)
 
-        if tr_state.must_abort:
-            self.ex.do('# conflict while pushing to other threads')
+        if tr_state.check_must_abort():
+            self.ex.do('# conflict while pushing to other threads: %s' %
+                       tr_state.objs_in_conflict)
 
     def check_for_write_write_conflicts(self, tr_state):
-        assert not tr_state.must_abort
+        assert not tr_state.check_must_abort()
+        for ts in self.thread_states:
+            other_trs = ts.transaction_state
+            if other_trs is None or other_trs is tr_state:
+                continue
+
+            confl_set = other_trs.write_set & tr_state.write_set
+            if confl_set:
+                contention_management(tr_state, other_trs, True,
+                                      objs_in_conflict=confl_set)
+                
+        if tr_state.check_must_abort():
+            self.ex.do('# write-write conflict: %s' %
+                       tr_state.objs_in_conflict)
+
+    def check_for_write_read_conflicts(self, tr_state):
+        assert not tr_state.check_must_abort()
         for ts in self.thread_states:
             other_trs = ts.transaction_state
             if other_trs is None or other_trs is tr_state:
                 continue
             
-            if other_trs.write_set & tr_state.write_set:
-                contention_management(tr_state, other_trs, True)
+            confl_set = other_trs.read_set & tr_state.write_set
+            if confl_set:
+                contention_management(tr_state, other_trs,
+                                      objs_in_conflict=confl_set)
                 
-        if tr_state.must_abort:
-            self.ex.do('# write-write conflict')
-
-    def check_for_write_read_conflicts(self, tr_state):
-        assert not tr_state.must_abort
-        for ts in self.thread_states:
-            other_trs = ts.transaction_state
-            if other_trs is None or other_trs is tr_state:
-                continue
-            
-            if other_trs.read_set & tr_state.write_set:
-                contention_management(tr_state, other_trs)
-                
-        if tr_state.must_abort:
-            self.ex.do('# write-read conflict')
+        if tr_state.check_must_abort():
+            self.ex.do('# write-read conflict: %s' %
+                       tr_state.objs_in_conflict)
 
 
 # ========== STM OPERATIONS ==========
@@ -224,48 +258,89 @@
             ex.do('py.test.raises(Conflict, self.commit_transaction)')
         else:
             ex.do('self.commit_transaction()')
+
+class OpAbortTransaction(Operation):
+    def do(self, ex, global_state, thread_state):
+        thread_state.transaction_state.set_must_abort()
+        thread_state.abort_transaction()
+        ex.do('self.abort_transaction()')
+
+            
             
 class OpAllocate(Operation):
     def do(self, ex, global_state, thread_state):
-        r = get_new_root_name()
+        r = get_new_root_name(False)
         thread_state.push_roots(ex)
-        ex.do('%s = stm_allocate(16)' % r)
+        size = global_state.rnd.choice([
+            16,
+            "SOME_MEDIUM_SIZE+16",
+            "SOME_LARGE_SIZE+16",
+        ])
+        ex.do('%s = stm_allocate(%s)' % (r, size))
         assert thread_state.transaction_state.write_root(r, 0) is None
         
         thread_state.pop_roots(ex)
         thread_state.register_root(r)
 
+class OpAllocateRef(Operation):
+    def do(self, ex, global_state, thread_state):
+        r = get_new_root_name(True)
+        thread_state.push_roots(ex)
+        num = global_state.rnd.randrange(1, 10)
+        ex.do('%s = stm_allocate_refs(%s)' % (r, num))
+        assert thread_state.transaction_state.write_root(r, "ffi.NULL") is None
+        
+        thread_state.pop_roots(ex)
+        thread_state.register_root(r)
+
+
 class OpForgetRoot(Operation):
     def do(self, ex, global_state, thread_state):
         r = thread_state.forget_random_root()
         ex.do('# forget %s' % r)
 
-class OpSetChar(Operation):
+class OpWrite(Operation):
     def do(self, ex, global_state, thread_state):
         r = thread_state.get_random_root()
-        v = ord(global_state.rnd.choice("abcdefghijklmnop"))
+        if is_ref_type_map[r]:
+            v = thread_state.get_random_root()
+        else:
+            v = ord(global_state.rnd.choice("abcdefghijklmnop"))
         trs = thread_state.transaction_state
         trs.write_root(r, v)
 
         global_state.check_for_write_write_conflicts(trs)
-        if trs.must_abort:
+        if trs.check_must_abort():
             thread_state.abort_transaction()
-            ex.do("py.test.raises(Conflict, stm_set_char, %s, %s)" % (r, repr(chr(v))))
+            if is_ref_type_map[r]:
+                ex.do("py.test.raises(Conflict, stm_set_ref, %s, 0, %s)" % (r, v))
+            else:
+                ex.do("py.test.raises(Conflict, stm_set_char, %s, %s)" % (r, repr(chr(v))))
         else:
-            ex.do("stm_set_char(%s, %s)" % (r, repr(chr(v))))
+            if is_ref_type_map[r]:
+                ex.do("stm_set_ref(%s, 0, %s)" % (r, v))
+            else:
+                ex.do("stm_set_char(%s, %s)" % (r, repr(chr(v))))
 
-class OpGetChar(Operation):
+class OpRead(Operation):
     def do(self, ex, global_state, thread_state):
         r = thread_state.get_random_root()
         trs = thread_state.transaction_state
         v = trs.read_root(r)
         #
-        ex.do("assert stm_get_char(%s) == %s" % (r, repr(chr(v))))
+        if is_ref_type_map[r]:
+            if v in thread_state.saved_roots or v in global_state.shared_roots:
+                ex.do("assert stm_get_ref(%s, 0) == %s" % (r, v))
+            else:
+                # we still need to read it (as it is in the read-set):
+                ex.do("stm_get_ref(%s, 0)" % r)
+        else:
+            ex.do("assert stm_get_char(%s) == %s" % (r, repr(chr(v))))
 
 class OpSwitchThread(Operation):
     def do(self, ex, global_state, thread_state):
         trs = thread_state.transaction_state
-        conflicts = trs is not None and trs.must_abort
+        conflicts = trs is not None and trs.check_must_abort()
         #
         if conflicts:
             thread_state.abort_transaction()
@@ -281,7 +356,7 @@
     def test_fixed_16_bytes_objects(self, seed=1010):
         rnd = random.Random(seed)
 
-        N_OBJECTS = 5
+        N_OBJECTS = 3
         N_THREADS = 2
         ex = Exec(self)
         ex.do("""
@@ -299,18 +374,41 @@
         curr_thread = global_state.thread_states[0]
 
         for i in range(N_OBJECTS):
-            r = get_new_root_name()
+            r = get_new_root_name(False)
             ex.do('%s = stm_allocate_old(16)' % r)
             global_state.committed_transaction_state.write_root(r, 0)
             global_state.shared_roots.append(r)
+
+            r = get_new_root_name(True)
+            ex.do('%s = stm_allocate_old_refs(1)' % r)
+            global_state.committed_transaction_state.write_root(r, "ffi.NULL")
+            global_state.shared_roots.append(r)
         global_state.committed_transaction_state.write_set = set()
         global_state.committed_transaction_state.read_set = set()
 
         # random steps:
+        possible_actions = [
+                OpAllocate,
+                OpAllocateRef,
+                OpWrite,
+                OpWrite,
+                OpWrite,
+                OpWrite,
+                OpRead,
+                OpRead,
+                OpRead,
+                OpRead,
+                OpRead,
+                OpRead,
+                OpCommitTransaction,
+                OpAbortTransaction,
+                OpForgetRoot,
+            ]
         remaining_steps = 200
         while remaining_steps > 0:
             remaining_steps -= 1
 
+            # make sure we are in a transaction:
             n_thread = rnd.randrange(0, N_THREADS)
             if n_thread != curr_thread.num:
                 ex.do('#')
@@ -319,23 +417,28 @@
             if curr_thread.transaction_state is None:
                 OpStartTransaction().do(ex, global_state, curr_thread)
 
-            action = rnd.choice([
-                OpAllocate,
-                OpSetChar,
-                OpSetChar,
-                OpGetChar,
-                OpGetChar,
-                OpCommitTransaction,
-                OpForgetRoot,
-            ])
+            # do something random
+            action = rnd.choice(possible_actions)
             action().do(ex, global_state, curr_thread)
+
+        # to make sure we don't have aborts in the test's teardown method,
+        # we will simply stop all running transactions
+        for ts in global_state.thread_states:
+            if ts.transaction_state is not None:
+                if curr_thread != ts:
+                    ex.do('#')
+                    curr_thread = ts
+                    OpSwitchThread().do(ex, global_state, curr_thread)
+                if curr_thread.transaction_state:
+                    # could have aborted in the switch() above
+                    OpCommitTransaction().do(ex, global_state, curr_thread)
             
 
 
     def _make_fun(seed):
         def test_fun(self):
             self.test_fixed_16_bytes_objects(seed)
-        test_fun.__name__ = 'test_fixed_16_bytes_objects_%d' % seed
+        test_fun.__name__ = 'test_random_%d' % seed
         return test_fun
 
     for _seed in range(5000, 5100):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to