Author: Armin Rigo <[email protected]>
Branch: queue
Changeset: r1859:e1de0881bd13
Date: 2015-06-18 14:34 +0200
http://bitbucket.org/pypy/stmgc/changeset/e1de0881bd13/

Log:    Add task_done() and join(), following the interface of Python's Queue.Queue

diff --git a/c8/stm/queue.c b/c8/stm/queue.c
--- a/c8/stm/queue.c
+++ b/c8/stm/queue.c
@@ -25,6 +25,10 @@
            notion is per segment.)  this flag says that the queue is
            already in the tree STM_PSEGMENT->active_queues. */
         bool active;
+
+        /* counts the number of put's done in this transaction, minus
+           the number of task_done's */
+        int64_t unfinished_tasks_in_this_transaction;
     };
     char pad[64];
 } stm_queue_segment_t;
@@ -37,6 +41,10 @@
 
     /* a chained list of old entries in the queue */
     queue_entry_t *volatile old_entries;
+
+    /* total of 'unfinished_tasks_in_this_transaction' for all
+       committed transactions */
+    volatile int64_t unfinished_tasks;
 };
 
 
@@ -126,6 +134,7 @@
     queue_lock_acquire();
 
     bool added_any_old_entries = false;
+    bool finished_more_tasks = false;
     wlog_t *item;
     TREE_LOOP_FORWARD(STM_PSEGMENT->active_queues, item) {
         stm_queue_t *queue = (stm_queue_t *)item->addr;
@@ -133,6 +142,11 @@
         queue_entry_t *head, *freehead;
 
         if (at_commit) {
+            int64_t d = seg->unfinished_tasks_in_this_transaction;
+            if (d != 0) {
+                finished_more_tasks |= (d < 0);
+                __sync_add_and_fetch(&queue->unfinished_tasks, d);
+            }
             head = seg->added_in_this_transaction;
             freehead = seg->old_objects_popped;
         }
@@ -145,6 +159,7 @@
         seg->added_in_this_transaction = NULL;
         seg->added_young_limit = NULL;
         seg->old_objects_popped = NULL;
+        seg->unfinished_tasks_in_this_transaction = 0;
 
         /* free the list of entries that must disappear */
         queue_free_entries(freehead);
@@ -176,10 +191,11 @@
 
     queue_lock_release();
 
-    if (added_any_old_entries) {
-        assert(_has_mutex());
+    assert(_has_mutex());
+    if (added_any_old_entries)
         cond_broadcast(C_QUEUE_OLD_ENTRIES);
-    }
+    if (finished_more_tasks)
+        cond_broadcast(C_QUEUE_FINISHED_MORE_TASKS);
 }
 
 void stm_queue_put(object_t *qobj, stm_queue_t *queue, object_t *newitem)
@@ -195,6 +211,7 @@
     seg->added_in_this_transaction = entry;
 
     queue_activate(queue);
+    seg->unfinished_tasks_in_this_transaction++;
 
     /* add qobj to 'objects_pointing_to_nursery' if it has the
        WRITE_BARRIER flag */
@@ -285,6 +302,41 @@
     }
 }
 
+void stm_queue_task_done(stm_queue_t *queue)
+{
+    queue_activate(queue);   /* commit only folds tallies of active queues */
+    stm_queue_segment_t *mine = &queue->segs[STM_SEGMENT->segment_num - 1];
+    mine->unfinished_tasks_in_this_transaction -= 1;   /* undoes one put() */
+}
+
+int stm_queue_join(object_t *qobj, stm_queue_t *queue, stm_thread_local_t *tl)
+{
+    int64_t result;   /* put()s not yet matched by task_done() */
+    /* Queue.join()-like: waits while the count above is positive */
+#if STM_TESTS
+    result = queue->unfinished_tasks;   /* can't wait in tests */
+    result += (queue->segs[STM_SEGMENT->segment_num - 1]   /* plus our own */
+               .unfinished_tasks_in_this_transaction);    /* uncommitted tally */
+    if (result > 0)
+        return 42;   /* "would have to wait": test_queue.py raises Conflict */
+#else
+    STM_PUSH_ROOT(*tl, qobj);   /* qobj must survive the commit/restart */
+    _stm_commit_transaction();  /* our tally is folded into unfinished_tasks */
+    /* wait outside any transaction; commits with a negative tally broadcast */
+    s_mutex_lock();
+    while ((result = queue->unfinished_tasks) > 0) {
+        cond_wait(C_QUEUE_FINISHED_MORE_TASKS);
+    }
+    s_mutex_unlock();
+
+    _stm_start_transaction(tl);
+    STM_POP_ROOT(*tl, qobj);   /* 'queue' should stay alive until here */
+#endif
+
+    /* returns 1 for 'ok', or 0 for error: negative 'unfinished_tasks' */
+    return (result == 0);
+}
+
 static void queue_trace_list(queue_entry_t *entry, void trace(object_t **),
                              queue_entry_t *stop_at)
 {
diff --git a/c8/stm/sync.h b/c8/stm/sync.h
--- a/c8/stm/sync.h
+++ b/c8/stm/sync.h
@@ -7,6 +7,7 @@
     C_SEGMENT_FREE,
     C_SEGMENT_FREE_OR_SAFE_POINT,
     C_QUEUE_OLD_ENTRIES,
+    C_QUEUE_FINISHED_MORE_TASKS,
     _C_TOTAL
 };
 
diff --git a/c8/stmgc.h b/c8/stmgc.h
--- a/c8/stmgc.h
+++ b/c8/stmgc.h
@@ -747,6 +747,11 @@
    transaction (this is needed to ensure correctness). */
 object_t *stm_queue_get(object_t *qobj, stm_queue_t *queue, double timeout,
                         stm_thread_local_t *tl);
+/* task_done() and join(): see https://docs.python.org/2/library/queue.html */
+void stm_queue_task_done(stm_queue_t *queue);
+/* join() commits, then waits outside any transaction: push your roots
+   first (qobj is handled internally).  Not usable in an atomic transaction! */
+int stm_queue_join(object_t *qobj, stm_queue_t *queue, stm_thread_local_t *tl);
 void stm_queue_tracefn(stm_queue_t *queue, void trace(object_t **));
 
 
diff --git a/c8/test/support.py b/c8/test/support.py
--- a/c8/test/support.py
+++ b/c8/test/support.py
@@ -225,6 +225,8 @@
 void stm_queue_put(object_t *qobj, stm_queue_t *queue, object_t *newitem);
 object_t *stm_queue_get(object_t *qobj, stm_queue_t *queue, double timeout,
                         stm_thread_local_t *tl);
+void stm_queue_task_done(stm_queue_t *queue);
+int stm_queue_join(object_t *qobj, stm_queue_t *queue, stm_thread_local_t *tl);
 void stm_queue_tracefn(stm_queue_t *queue, void trace(object_t **));
 
 void _set_queue(object_t *obj, stm_queue_t *q);
@@ -658,7 +660,9 @@
 
 def get_hashtable(o):
     assert lib._get_type_id(o) == 421419
-    return lib._get_hashtable(o)
+    h = lib._get_hashtable(o)
+    assert h
+    return h
 
 def stm_allocate_queue():
     o = lib.stm_allocate(16)
@@ -670,7 +674,9 @@
 
 def get_queue(o):
     assert lib._get_type_id(o) == 421417
-    return lib._get_queue(o)
+    q = lib._get_queue(o)
+    assert q
+    return q
 
 def stm_get_weakref(o):
     return lib._get_weakref(o)
diff --git a/c8/test/test_queue.py b/c8/test/test_queue.py
--- a/c8/test/test_queue.py
+++ b/c8/test/test_queue.py
@@ -18,7 +18,7 @@
             try:
                 assert lib._get_type_id(obj) == 421417
                 self.seen_queues -= 1
-                q = lib._get_queue(obj)
+                q = get_queue(obj)
                 lib.stm_queue_free(q)
             except:
                 self.errors.append(sys.exc_info()[2])
@@ -42,16 +42,30 @@
         return q
 
     def get(self, obj):
-        q = lib._get_queue(obj)
+        q = get_queue(obj)
         res = lib.stm_queue_get(obj, q, 0.0, self.tls[self.current_thread])
         if res == ffi.NULL:
             raise Empty
         return res
 
     def put(self, obj, newitem):
-        q = lib._get_queue(obj)
+        q = get_queue(obj)
         lib.stm_queue_put(obj, q, newitem)
 
+    def task_done(self, obj):
+        q = get_queue(obj)
+        lib.stm_queue_task_done(q)
+
+    def join(self, obj):
+        q = get_queue(obj)
+        res = lib.stm_queue_join(obj, q, self.tls[self.current_thread]);
+        if res == 1:
+            return
+        elif res == 42:
+            raise Conflict("join() cannot wait in tests")
+        else:
+            raise AssertionError("stm_queue_join error")
+
 
 class TestQueue(BaseTestQueue):
 
@@ -299,3 +313,51 @@
             self.push_root(qobj)
             stm_minor_collect()
             qobj = self.pop_root()
+
+    def test_task_done_1(self):   # everything within one transaction
+        self.start_transaction()
+        qobj = self.allocate_queue()
+        self.push_root(qobj)
+        stm_minor_collect()
+        qobj = self.pop_root()
+        self.join(qobj)                             # no put yet: returns at once
+        obj1 = stm_allocate(32)
+        self.put(qobj, obj1)                        # 1 unfinished task
+        py.test.raises(Conflict, self.join, qobj)   # ...so join would wait
+        self.get(qobj)
+        py.test.raises(Conflict, self.join, qobj)   # get() alone finishes nothing
+        self.task_done(qobj)                        # tally back to 0
+        self.join(qobj)
+
+    def test_task_done_2(self):
+        self.start_transaction()
+        qobj = self.allocate_queue()
+        self.push_root(qobj)
+        self.put(qobj, stm_allocate(32))
+        self.put(qobj, stm_allocate(32))
+        self.get(qobj)
+        self.get(qobj)
+        self.commit_transaction()
+        qobj = self.pop_root()
+        # thread 0 committed 2 puts and no task_done: 2 unfinished tasks
+        self.start_transaction()
+        py.test.raises(Conflict, self.join, qobj)
+        # thread 1 calls task_done twice in one (uncommitted) transaction
+        self.switch(1)
+        self.start_transaction()
+        py.test.raises(Conflict, self.join, qobj)
+        self.task_done(qobj)
+        py.test.raises(Conflict, self.join, qobj)
+        self.task_done(qobj)
+        self.join(qobj)
+        # thread 0 does not see thread 1's uncommitted task_done calls
+        self.switch(0)
+        py.test.raises(Conflict, self.join, qobj)
+        # committing thread 1 brings the committed count down to 0
+        self.switch(1)
+        self.commit_transaction()
+        # now join() succeeds in thread 0 as well
+        self.switch(0)
+        self.join(qobj)
+        #
+        stm_major_collect()       # to get rid of the queue object
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to