Author: Armin Rigo <[email protected]>
Branch: stmgc-c7
Changeset: r76005:9d4605f70a31
Date: 2015-02-19 23:36 +0100
http://bitbucket.org/pypy/pypy/changeset/9d4605f70a31/

Log:    Slightly change the API in transaction.py again. It now feels more
        natural.

diff --git a/lib_pypy/pypy_test/test_transaction.py 
b/lib_pypy/pypy_test/test_transaction.py
--- a/lib_pypy/pypy_test/test_transaction.py
+++ b/lib_pypy/pypy_test/test_transaction.py
@@ -8,9 +8,10 @@
 def test_simple_random_order():
     for x in range(N):
         lst = []
-        with transaction.TransactionQueue():
-            for i in range(10):
-                transaction.add(lst.append, i)
+        tq = transaction.TransactionQueue()
+        for i in range(10):
+            tq.add(lst.append, i)
+        tq.run()
         if VERBOSE:
             print lst
         assert sorted(lst) == range(10), lst
@@ -22,9 +23,10 @@
             lst.append(i)
             i += 1
             if i < 10:
-                transaction.add(do_stuff, i)
-        with transaction.TransactionQueue():
-            transaction.add(do_stuff, 0)
+                tq.add(do_stuff, i)
+        tq = transaction.TransactionQueue()
+        tq.add(do_stuff, 0)
+        tq.run()
         if VERBOSE:
             print lst
         assert lst == range(10), lst
@@ -36,10 +38,11 @@
             lsts[i].append(j)
             j += 1
             if j < 10:
-                transaction.add(do_stuff, i, j)
-        with transaction.TransactionQueue():
-            for i in range(5):
-                transaction.add(do_stuff, i, 0)
+                tq.add(do_stuff, i, j)
+        tq = transaction.TransactionQueue()
+        for i in range(5):
+            tq.add(do_stuff, i, 0)
+        tq.run()
         if VERBOSE:
             print lsts
         assert lsts == (range(10),) * 5, lsts
@@ -53,14 +56,15 @@
             lsts[i].append(j)
             j += 1
             if j < 5:
-                transaction.add(do_stuff, i, j)
+                tq.add(do_stuff, i, j)
             else:
                 lsts[i].append('foo')
                 raise FooError
+        tq = transaction.TransactionQueue()
+        for i in range(10):
+            tq.add(do_stuff, i, 0)
         try:
-            with transaction.TransactionQueue():
-                for i in range(10):
-                    transaction.add(do_stuff, i, 0)
+            tq.run()
         except FooError:
             pass
         else:
@@ -78,19 +82,66 @@
 
 
 def test_number_of_transactions_reported():
-    py.test.skip("not reimplemented")
-    with transaction.TransactionQueue():
-        transaction.add(lambda: None)
-    assert transaction.number_of_transactions_in_last_run() == 1
+    tq = transaction.TransactionQueue()
+    tq.add(lambda: None)
+    tq.add(lambda: None)
+    tq.run()
+    assert tq.number_of_transactions_executed() == 2
+
+    tq.run()
+    assert tq.number_of_transactions_executed() == 2
+
+    tq.add(lambda: None)
+    tq.run()
+    assert tq.number_of_transactions_executed() == 3
+
+    tq.add(lambda: some_name_that_is_not_defined)
+    try:
+        tq.run()
+    except NameError:
+        pass
+    else:
+        raise AssertionError("should have raised NameError")
+    assert tq.number_of_transactions_executed() == 4
 
     def add_transactions(l):
         if l:
             for x in range(l[0]):
-                transaction.add(add_transactions, l[1:])
+                tq.add(add_transactions, l[1:])
 
-    with transaction.TransactionQueue():
-        transaction.add(add_transactions, [10, 10, 10])
-    assert transaction.number_of_transactions_in_last_run() == 1111
+    tq = transaction.TransactionQueue()
+    tq.add(add_transactions, [10, 10, 10])
+    tq.run()
+    assert tq.number_of_transactions_executed() == 1111
+
+def test_unexecuted_transactions_after_exception():
+    class FooError(Exception):
+        pass
+    class BarError(Exception):
+        pass
+    def raiseme(exc):
+        raise exc
+    seen = []
+    tq = transaction.TransactionQueue()
+    tq.add(raiseme, FooError)
+    tq.add(raiseme, BarError)
+    tq.add(seen.append, 42)
+    tq.add(seen.append, 42)
+    try:
+        tq.run()
+    except (FooError, BarError), e:
+        seen_exc = e.__class__
+    else:
+        raise AssertionError("should have raised FooError or BarError")
+    try:
+        tq.run()
+    except (FooError, BarError), e:
+        assert e.__class__ != seen_exc
+    else:
+        raise AssertionError("unexecuted transactions have disappeared")
+    for i in range(2):
+        tq.run()
+        assert seen == [42, 42]
 
 
 def test_stmidset():
diff --git a/lib_pypy/transaction.py b/lib_pypy/transaction.py
--- a/lib_pypy/transaction.py
+++ b/lib_pypy/transaction.py
@@ -108,106 +108,101 @@
     pass
 
 
-def add(f, *args, **kwds):
-    """Register a new transaction that will be done by 'f(*args, **kwds)'.
-    Must be called within the transaction in the "with TransactionQueue()"
-    block, or within a transaction started by this one, directly or
-    indirectly.
-    """
-    _thread_local.pending.append((f, args, kwds))
+class TransactionQueue(object):
+    """A queue of pending transactions.
 
-
-class TransactionQueue(object):
-    """Use in 'with TransactionQueue():'.  Creates a queue of
-    transactions.  The first transaction in the queue is the content of
-    the 'with:' block, which is immediately started.
-
-    Any transaction can register new transactions that will be run
-    after the current one is finished, using the global function add().
+    Use the add() method to register new transactions into the queue.
+    Afterwards, call run() once.  While transactions run, it is possible
+    to add() more transactions, which will run after the current one is
+    finished.  The run() call only returns when the queue is completely
+    empty.
     """
 
-    def __init__(self, nb_segments=0):
+    def __init__(self):
+        self._deque = collections.deque()
+        self._pending = self._deque
+        self._number_transactions_exec = 0
+
+    def add(self, f, *args, **kwds):
+        """Register a new transaction to be done by 'f(*args, **kwds)'.
+        """
+        # note: 'self._pending.append' can be two things here:
+        # * if we are outside run(), it is the regular deque.append method;
+        # * if we are inside run(), self._pending is a thread._local()
+        #   and then its append attribute is the append method of a
+        #   thread-local list.
+        self._pending.append((f, args, kwds))
+
+    def run(self, nb_segments=0):
+        """Run all transactions, and all transactions started by these
+        ones, recursively, until the queue is empty.  If one transaction
+        raises, run() re-raises the exception and the unexecuted transaction
+        are left in the queue.
+        """
+        if is_atomic():
+            raise TransactionError(
+                "TransactionQueue.run() cannot be called in an atomic context")
+        if not self._pending:
+            return
         if nb_segments <= 0:
             nb_segments = getsegmentlimit()
-        _thread_pool.ensure_threads(nb_segments)
 
-    def __enter__(self):
-        if hasattr(_thread_local, "pending"):
-            raise TransactionError(
-                "recursive invocation of TransactionQueue()")
-        if is_atomic():
-            raise TransactionError(
-                "invocation of TransactionQueue() from an atomic context")
-        _thread_local.pending = []
-        atomic.__enter__()
+        assert self._pending is self._deque, "broken state"
+        try:
+            self._pending = thread._local()
+            lock_done_running = thread.allocate_lock()
+            lock_done_running.acquire()
+            lock_deque = thread.allocate_lock()
+            locks = []
+            exception = []
+            args = (locks, lock_done_running, lock_deque,
+                    exception, nb_segments)
+            #
+            for i in range(nb_segments):
+                thread.start_new_thread(self._thread_runner, args)
+            #
+            # The threads run here, and they will release this lock when
+            # they are all finished.
+            lock_done_running.acquire()
+            #
+            assert len(locks) == nb_segments
+            for lock in locks:
+                lock.release()
+            #
+        finally:
+            self._pending = self._deque
+        #
+        if exception:
+            exc_type, exc_value, exc_traceback = exception
+            raise exc_type, exc_value, exc_traceback
 
-    def __exit__(self, exc_type, exc_value, traceback):
-        atomic.__exit__(exc_type, exc_value, traceback)
-        pending = _thread_local.pending
-        del _thread_local.pending
-        if exc_type is None and len(pending) > 0:
-            _thread_pool.run(pending)
+    def number_of_transactions_executed(self):
+        return self._number_transactions_exec
 
-
-# ____________________________________________________________
-
-
-class _ThreadPool(object):
-
-    def __init__(self):
-        self.lock_running = thread.allocate_lock()
-        self.lock_done_running = thread.allocate_lock()
-        self.lock_done_running.acquire()
-        self.nb_threads = 0
-        self.deque = collections.deque()
-        self.locks = []
-        self.lock_deque = thread.allocate_lock()
-        self.exception = []
-
-    def ensure_threads(self, n):
-        if n > self.nb_threads:
-            with self.lock_running:
-                for i in range(self.nb_threads, n):
-                    assert len(self.locks) == self.nb_threads
-                    self.nb_threads += 1
-                    thread.start_new_thread(self.thread_runner, ())
-                    # The newly started thread should run immediately into
-                    # the case 'if len(self.locks) == self.nb_threads:'
-                    # and release this lock.  Wait until it does.
-                    self.lock_done_running.acquire()
-
-    def run(self, pending):
-        # For now, can't run multiple threads with each an independent
-        # TransactionQueue(): they are serialized.
-        with self.lock_running:
-            assert self.exception == []
-            assert len(self.deque) == 0
-            deque = self.deque
-            with self.lock_deque:
-                deque.extend(pending)
-                try:
-                    for i in range(len(pending)):
-                        self.locks.pop().release()
-                except IndexError:     # pop from empty list
-                    pass
-            #
-            self.lock_done_running.acquire()
-            #
-            if self.exception:
-                exc_type, exc_value, exc_traceback = self.exception
-                del self.exception[:]
-                raise exc_type, exc_value, exc_traceback
-
-    def thread_runner(self):
-        deque = self.deque
+    def _thread_runner(self, locks, lock_done_running, lock_deque,
+                       exception, nb_segments):
+        pending = []
+        self._pending.append = pending.append
+        deque = self._deque
         lock = thread.allocate_lock()
         lock.acquire()
-        pending = []
-        _thread_local.pending = pending
-        lock_deque = self.lock_deque
-        exception = self.exception
+        next_transaction = None
+        count = [0]
         #
-        while True:
+        def _pause_thread():
+            self._number_transactions_exec += count[0]
+            count[0] = 0
+            locks.append(lock)
+            if len(locks) == nb_segments:
+                lock_done_running.release()
+            lock_deque.release()
+            #
+            # Now wait until our lock is released.
+            lock.acquire()
+            return len(locks) == nb_segments
+        #
+        while not exception:
+            assert next_transaction is None
             #
             # Look at the deque and try to fetch the next item on the left.
             # If empty, we add our lock to the 'locks' list.
@@ -216,13 +211,8 @@
                 next_transaction = deque.popleft()
                 lock_deque.release()
             else:
-                self.locks.append(lock)
-                if len(self.locks) == self.nb_threads:
-                    self.lock_done_running.release()
-                lock_deque.release()
-                #
-                # Now wait until our lock is released.
-                lock.acquire()
+                if _pause_thread():
+                    return
                 continue
             #
             # Now we have a next_transaction.  Run it.
@@ -230,13 +220,16 @@
             while True:
                 f, args, kwds = next_transaction
                 with atomic:
-                    if len(exception) == 0:
-                        try:
-                            with signals_enabled:
-                                f(*args, **kwds)
-                        except:
-                            exception.extend(sys.exc_info())
-                del next_transaction
+                    if exception:
+                        break
+                    next_transaction = None
+                    try:
+                        with signals_enabled:
+                            count[0] += 1
+                            f(*args, **kwds)
+                    except:
+                        exception.extend(sys.exc_info())
+                        break
                 #
                 # If no new 'pending' transactions have been added, exit
                 # this loop and go back to fetch more from the deque.
@@ -252,19 +245,27 @@
                 # single item to 'next_transaction', because it looks
                 # like a good idea to preserve some first-in-first-out
                 # approximation.)
-                with self.lock_deque:
+                with lock_deque:
                     deque.extend(pending)
                     next_transaction = deque.popleft()
                     try:
                         for i in range(1, len(pending)):
-                            self.locks.pop().release()
+                            locks.pop().release()
                     except IndexError:     # pop from empty list
                         pass
                 del pending[:]
+        #
+        # We exit here with an exception.  Re-add 'next_transaction'
+        # if it is not None.
+        lock_deque.acquire()
+        if next_transaction is not None:
+            deque.appendleft(next_transaction)
+            next_transaction = None
+        while not _pause_thread():
+            lock_deque.acquire()
 
 
-_thread_pool = _ThreadPool()
-_thread_local = thread._local()
+# ____________________________________________________________
 
 
 def XXXreport_abort_info(info):
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to