Hi friends of VIFF,

I've finally completed the patch for a two-threaded VIFF where most of the VIFF code runs in a separate thread. The patch is against the tip of my repository: http://hg.viff.dk/mkeller

It turned out to be not as straightforward as I thought. I had to use recursion, as in the hack, but I refined it to ensure that the recursion limit isn't exceeded.

Benchmarks:
Unfortunately, this solution is slower than the hack, e.g. one AES block encryption takes 4 seconds compared to 3 seconds with the hack. On the other hand, the preprocessing time in the actively secure multiplication is linear rather than quadratic, whereas the online time is significantly larger:

        two-threaded            hack                    original
(n,t)   online  preprocessing   online  preprocessing   online  preproc.
(4,1)   6       22              4       17              4       20
(7,2)   10      37              6       29              6       42
(10,3)  13      53              8       42              8       82
(13,4)  17      68              10      56              10      136
(16,5)  20      84              12      68              12      208
(19,6)  23      106             13      83              14      287
(22,7)  26      120             15      98              17      377

I did some profiling and didn't find an obvious reason why the two-threaded version is slower. Therefore, I guess that the reason is the multi-threading implementation of Python (which could be better, as mentioned in the discussion about the hack). The guess is also supported by the fact that having a separate thread for every callback, which I also tried, turned out to be really slow.

Unit tests:
All unit tests pass, even the previously skipped test_multiple_callbacks. This is because I added the @increment_pc decorator to schedule_callback(). This of course changes the program counters heavily, but I didn't experience any problems.

Best regards,
Marcel
diff -r e2759515f57f apps/aes.py
--- a/apps/aes.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/apps/aes.py	Tue Apr 21 17:25:45 2009 +0200
@@ -82,7 +82,7 @@
         rt.shutdown()
 
     g = gather_shares(opened_ciphertext)
-    g.addCallback(fin)
+    rt.schedule_callback(g, fin)
 
 def share_key(rt):
     key =  []
diff -r e2759515f57f apps/benchmark.py
--- a/apps/benchmark.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/apps/benchmark.py	Tue Apr 21 17:25:45 2009 +0200
@@ -91,6 +91,7 @@
     print "Total time used: %.3f sec" % (stop-start)
     print "Time per %s operation: %.0f ms" % (what, 1000*(stop-start) / count)
     print "*" * 6
+    sys.stdout.flush()
 
 
 operations = ["mul", "compToft05", "compToft07", "eq"]
@@ -174,7 +175,7 @@
             # run with no preprocessing. So they are quite brittle.
             if self.operation == operator.mul:
                 key = ("generate_triples", (Zp,))
-                desc = [(i, 1, 0) for i in range(2, 2 + count)]
+                desc = [(2, 2, 2 * count + 1, 2, i, 1, 0) for i in range(1, count + 1)]
                 program_desc.setdefault(key, []).extend(desc)
             elif isinstance(self.rt, ComparisonToft05Mixin):
                 key = ("generate_triples", (GF256,))
@@ -228,7 +229,8 @@
         if seconds > 0:
             print "Starting test in %d" % seconds
             sys.stdout.flush()
-            reactor.callLater(1, self.countdown, None, seconds - 1)
+            time.sleep(1)
+            self.countdown(None, seconds - 1)
         else:
             print "Starting test now"
             sys.stdout.flush()
@@ -255,6 +257,7 @@
     def run_test(self, _):
         c_shares = []
         record_start("parallel test")
+        sys.stdout.flush()
         while self.a_shares and self.b_shares:
             a = self.a_shares.pop()
             b = self.b_shares.pop()
diff -r e2759515f57f apps/millionaires.py
--- a/apps/millionaires.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/apps/millionaires.py	Tue Apr 21 17:25:45 2009 +0200
@@ -97,10 +97,10 @@
         # the argument (which is None since self.results_ready does
         # not return anything), so we throw it away using a lambda
         # expressions which ignores its first argument.
-        results.addCallback(lambda _: runtime.synchronize())
+        runtime.schedule_callback(results, lambda _: runtime.synchronize())
         # The next callback shuts the runtime down, killing the
         # connections between the players.
-        results.addCallback(lambda _: runtime.shutdown())
+        runtime.schedule_callback(results, lambda _: runtime.shutdown())
 
     def results_ready(self, results):
         # Since this method is called as a callback above, the results
diff -r e2759515f57f viff/active.py
--- a/viff/active.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/viff/active.py	Tue Apr 21 17:25:45 2009 +0200
@@ -501,6 +501,7 @@
         result = Share(self, share_x.field)
         # This is the Deferred we will do processing on.
         triple = self.get_triple(share_x.field)
+        triple.addCallback(gather_shares)
         self.schedule_callback(triple, finish_mul)
         # We add the result to the chains in triple.
         triple.chainDeferred(result)
diff -r e2759515f57f viff/aes.py
--- a/viff/aes.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/viff/aes.py	Tue Apr 21 17:25:45 2009 +0200
@@ -374,9 +374,9 @@
                 trigger.addCallback(progress, i, time.time())
 
                 if (i < self.rounds - 1):
-                    self.runtime.schedule_callback(trigger, round, state, i + 1)
+                    self.runtime.schedule_complex_callback(trigger, round, state, i + 1)
                 else:
-                    self.runtime.schedule_callback(trigger, final_round, state)
+                    self.runtime.schedule_complex_callback(trigger, final_round, state)
 
             prep_progress(i, start_round)
 
diff -r e2759515f57f viff/equality.py
--- a/viff/equality.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/viff/equality.py	Tue Apr 21 17:25:45 2009 +0200
@@ -74,7 +74,7 @@
                 xj = (-1) * (1/Zp(2)) * (bj - 1)
             return xj
 
-        x = [cj.addCallback(finish, bj) for cj, bj in zip(c, b)]
+        x = [self.schedule_callback(cj, finish, bj) for cj, bj in zip(c, b)]
 
         # Take the product (this is here the same as the "and") of all
         # the x'es
diff -r e2759515f57f viff/passive.py
--- a/viff/passive.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/viff/passive.py	Tue Apr 21 17:25:45 2009 +0200
@@ -98,7 +98,7 @@
                         d = Share(self, share.field, (share.field(peer_id), share))
                     else:
                         d = self._expect_share(peer_id, share.field)
-                        self.schedule_callback(d, lambda s, peer_id: (s.field(peer_id), s), peer_id)
+                        d.addCallback(lambda s, peer_id: (s.field(peer_id), s), peer_id)
                     deferreds.append(d)
                 return recombine(deferreds)
 
diff -r e2759515f57f viff/runtime.py
--- a/viff/runtime.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/viff/runtime.py	Tue Apr 21 17:25:45 2009 +0200
@@ -39,9 +39,10 @@
 from collections import deque
 
 from viff.field import GF256, FieldElement
-from viff.util import wrapper, rand, deep_wait, track_memory_usage
+from viff.util import wrapper, rand, deep_wait, track_memory_usage, \
+     clone_deferred
 
-from twisted.internet import reactor
+from twisted.internet import reactor, selectreactor
 from twisted.internet.task import LoopingCall
 from twisted.internet.error import ConnectionDone, CannotListenError
 from twisted.internet.defer import Deferred, DeferredList, gatherResults, succeed
@@ -49,6 +50,10 @@
 from twisted.internet.protocol import ReconnectingClientFactory, ServerFactory
 from twisted.protocols.basic import Int16StringReceiver
 
+from Queue import Queue, Empty
+from threading import Lock, Event
+import sys
+
 
 class Share(Deferred):
     """A shared number.
@@ -78,6 +83,7 @@
         self.field = field
         if value is not None:
             self.callback(value)
+        self.priority = 0
 
     def __add__(self, other):
         """Addition."""
@@ -212,14 +218,19 @@
 
         for index, share in enumerate(shares):
             share.addCallbacks(self._callback_fired, self._callback_fired,
-                               callbackArgs=(index, True),
+                               callbackArgs=(index, True, share.priority),
                                errbackArgs=(index, False))
 
-    def _callback_fired(self, result, index, success):
+    def _callback_fired(self, result, index, success, priority=0):
         self.results[index] = (success, result)
         self.missing_shares -= 1
         if not self.called and self.missing_shares == 0:
-            self.callback(self.results)
+            if self.priority >= priority:
+                self.callback(self.results)
+            else:
+                self.pause()
+                self.callback(self.results)
+                self.runtime.deferred_queue.put((self, None))
         return result
 
 
@@ -266,6 +277,9 @@
         self.lost_connection = Deferred()
         #: Data expected to be received in the future.
         self.incoming_data = {}
+        #: Lock to protect :attr:`incoming_data`.
+        self.incoming_lock = Lock()
+        self.activation_counter = 0
 
     def connectionMade(self):
         self.sendString(str(self.factory.runtime.id))
@@ -301,21 +315,32 @@
             program_counter, data_type, data = marshal.loads(string)
             key = (program_counter, data_type)
 
+            self.incoming_lock.acquire()
             deq = self.incoming_data.setdefault(key, deque())
             if deq and isinstance(deq[0], Deferred):
                 deferred = deq.popleft()
                 if not deq:
                     del self.incoming_data[key]
+                self.incoming_lock.release()
+                # queue deferred
+                deferred.pause()
                 deferred.callback(data)
+                self.factory.runtime.deferred_queue.put((deferred, program_counter))
             else:
                 deq.append(data)
+                self.incoming_lock.release()
 
             # TODO: marshal.loads can raise EOFError, ValueError, and
             # TypeError. They should be handled somehow.
 
     def sendData(self, program_counter, data_type, data):
         send_data = (program_counter, data_type, data)
-        self.sendString(marshal.dumps(send_data))
+        reactor.threadCallQueue.append((self.sendString, [marshal.dumps(send_data)], {}))
+        self.activation_counter +=1
+
+        if (self.activation_counter >= 1):
+            reactor.wakeUp()
+            self.activation_counter = 0
 
     def sendShare(self, program_counter, share):
         """Send a share.
@@ -501,6 +526,15 @@
         # communicating with ourselves.
         self.add_player(player, None)
 
+        #: Blocking queue for Deferreds.
+        self.deferred_queue = Queue(0)
+        #: Current recursion depth.
+        self.recursion_depth = 0
+        #: Activation counter.
+        self.activation_counter = 0
+        #: Event is set while the reactor is waiting (select syscall).
+        self.select_event = Event()
+
     def add_player(self, player, protocol):
         self.players[player.id] = player
         self.num_players = len(self.players)
@@ -522,18 +556,18 @@
             results = [maybeDeferred(self.port.stopListening)]
             for protocol in self.protocols.itervalues():
                 results.append(protocol.lost_connection)
-                protocol.loseConnection()
+                reactor.callFromThread(protocol.loseConnection)
             return DeferredList(results)
 
-        def stop_reactor(_):
+        def stop_reactor(_, self):
             print "done."
             print "Stopping reactor... ",
-            reactor.stop()
+            reactor.callFromThread(reactor.stop)
             print "done."
 
         sync = self.synchronize()
         sync.addCallback(close_connections)
-        sync.addCallback(stop_reactor)
+        sync.addCallback(stop_reactor, self)
         return sync
 
     def wait_for(self, *vars):
@@ -542,8 +576,9 @@
         The runtime is shut down when all variables are calculated.
         """
         dl = DeferredList(vars)
-        dl.addCallback(lambda _: self.shutdown())
+        self.schedule_callback(dl, lambda _: self.shutdown())
 
+    @increment_pc
     def schedule_callback(self, deferred, func, *args, **kwargs):
         """Schedule a callback on a deferred with the correct program
         counter.
@@ -574,7 +609,64 @@
             finally:
                 self.program_counter[:] = current_pc
 
-        deferred.addCallback(callback_wrapper, *args, **kwargs)
+        return deferred.addCallback(callback_wrapper, *args, **kwargs)
+
+    def schedule_complex_callback(self, share, func, *args, **kwargs):
+        """Schedule a complex callback, i.e. a callback which blocks a
+        long time."""
+
+        assert isinstance(share, Share), "Only shares can have complex callbacks."
+
+        share.priority = -1
+        self.schedule_callback(share, func, *args, **kwargs)
+
+    def profile_deferred_queue_loop(self):
+        import cProfile
+        prof = cProfile.Profile()
+        prof.runcall(self.deferred_queue_loop)
+        prof.print_stats(1)
+
+    def deferred_queue_loop(self):
+        while True:
+            deferred, pc = self.deferred_queue.get()
+
+            if deferred is not None:
+                deferred.unpause()
+                from twisted.python import failure
+                if isinstance(deferred.result,failure.Failure):
+                    deferred.result.printTraceback()
+                sys.stdout.flush()
+            else:
+                return
+
+    def process_deferred_queue(self):
+        self.select_event.wait()
+
+        max_depth = 1
+
+        if (self.recursion_depth >= max_depth):
+            return
+
+        self.recursion_depth += 1
+
+        while True:
+            try:
+                deferred, pc = self.deferred_queue.get(block=False)
+            except Empty:
+                break
+            else:
+                if deferred is not None:
+                    if isinstance(deferred, Share) and \
+                           self.recursion_depth / float(max_depth) - deferred.priority > 1:
+                        self.deferred_queue.put((deferred, None))
+                        break
+                    else:
+                        deferred.unpause()
+                else:
+                    self.deferred_queue.put((None, None))
+                    break
+
+        self.recursion_depth -= 1
 
     @increment_pc
     def synchronize(self):
@@ -598,16 +690,25 @@
         pc = tuple(self.program_counter)
         key = (pc, data_type)
 
+        lock = self.protocols[peer_id].incoming_lock
+        lock.acquire()
         deq = self.protocols[peer_id].incoming_data.setdefault(key, deque())
         if deq and not isinstance(deq[0], Deferred):
             # We have already received some data from the other side.
             data = deq.popleft()
             if not deq:
                 del self.protocols[peer_id].incoming_data[key]
+            lock.release()
             deferred.callback(data)
         else:
             # We have not yet received anything from the other side.
             deq.append(deferred)
+            lock.release()
+            self.activation_counter += 1
+
+            if (self.activation_counter >= 1):
+                self.process_deferred_queue()
+                self.activation_counter = 0
 
     def _exchange_shares(self, peer_id, field_element):
         """Exchange shares with another player.
@@ -789,24 +890,13 @@
         # profiler here and stop it upon shutdown, but this triggers
         # http://bugs.python.org/issue1375 since the start and stop
         # calls are in different stack frames.
-        import hotshot
-        prof = hotshot.Profile("player-%d.prof" % id)
+        import cProfile
+        prof = cProfile.Profile()
         old_run = reactor.run
         def new_run(*args, **kwargs):
             print "Starting reactor with profiling"
             prof.runcall(old_run, *args, **kwargs)
-
-            import sys
-            import hotshot.stats
-            print "Loading profiling statistics...",
-            sys.stdout.flush()
-            stats = hotshot.stats.load("player-%d.prof" % id)
-            print "done."
-            print
-            stats.strip_dirs()
-            stats.sort_stats("time", "calls")
-            stats.print_stats(40)
-            stats.dump_stats("player-%d.pstats" % id)
+            prof.print_stats(1)
         reactor.run = new_run
 
     # This will yield a Runtime when all protocols are connected.
@@ -874,7 +964,38 @@
             print "Will connect to %s" % player
             connect(player.host, player.port)
 
-    return result
+    # Start the main callback in the separate thread.
+    deferred_runtime = clone_deferred(result)
+    deferred_runtime.pause()
+
+    def start_deferred_queue_loop(runtime, deferred_runtime, options):
+        print "Start VIFF thread"
+
+        runtime.deferred_queue.put((deferred_runtime, None))
+
+        if options and options.profile:
+            reactor.callInThread(runtime.profile_deferred_queue_loop)
+        else:
+            reactor.callInThread(runtime.deferred_queue_loop)
+
+        reactor.addSystemEventTrigger("before", "shutdown",
+                                      runtime.deferred_queue.put, (None, None))
+        return runtime
+
+    result.addCallback(start_deferred_queue_loop, deferred_runtime, options)
+
+    # Monkey patch selectreactor to know when it is waiting.
+    plain_select = selectreactor._select
+
+    def patched_select(*args, **kwargs):
+        runtime.select_event.set()
+        result = plain_select(*args, **kwargs)
+        runtime.select_event.clear()
+        return result
+
+    selectreactor._select = patched_select
+
+    return deferred_runtime
 
 if __name__ == "__main__":
     import doctest    #pragma NO COVER
diff -r e2759515f57f viff/test/test_basic_runtime.py
--- a/viff/test/test_basic_runtime.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/viff/test/test_basic_runtime.py	Tue Apr 21 17:25:45 2009 +0200
@@ -62,13 +62,13 @@
         """
 
         def verify_program_counter(_):
-            self.assertEquals(runtime.program_counter, [0])
+            self.assertEquals(runtime.program_counter, [1, 0])
 
         d = Deferred()
         runtime.schedule_callback(d, verify_program_counter)
 
         runtime.synchronize()
-        self.assertEquals(runtime.program_counter, [1])
+        self.assertEquals(runtime.program_counter, [2])
 
         # Now trigger verify_program_counter.
         d.callback(None)
@@ -129,8 +129,6 @@
         d2.callback(None)
 
         return gatherResults([d1, d2])
-    test_multiple_callbacks.skip = ("TODO: Scheduling callbacks fails to "
-                                    "increment program counter!")
 
     @protocol
     def test_multi_send(self, runtime):
diff -r e2759515f57f viff/test/util.py
--- a/viff/test/util.py	Thu Mar 05 21:02:57 2009 +0100
+++ b/viff/test/util.py	Tue Apr 21 17:25:45 2009 +0200
@@ -19,16 +19,19 @@
 
 from twisted.internet.defer import Deferred, gatherResults, maybeDeferred
 from twisted.trial.unittest import TestCase
+from twisted.internet import reactor
 
 from viff.passive import PassiveRuntime
 from viff.runtime import Share, ShareExchanger, ShareExchangerFactory
 from viff.field import GF
 from viff.config import generate_configs, load_config
-from viff.util import rand
+from viff.util import rand, clone_deferred
 from viff.test.loopback import loopbackAsync
 
 from random import Random
 
+from threading import Thread
+
 
 def protocol(method):
     """Decorator for protocol tests.
@@ -44,9 +47,12 @@
     def wrapper(self):
 
         def shutdown_protocols(result, runtime):
+            # Stop VIFF thread.
+            runtime.deferred_queue.put((None, None))
             # TODO: this should use runtime.shutdown instead.
             for protocol in runtime.protocols.itervalues():
-                protocol.loseConnection()
+                reactor.callFromThread(protocol.loseConnection)
+            reactor.wakeUp()
             # If we were called as an errback, then returning the
             # result signals the original error to Trial.
             return result
@@ -127,6 +133,7 @@
         self.close_sentinels = []
 
         self.runtimes = []
+        self.real_runtimes = []
         for id in reversed(range(1, self.num_players+1)):
             _, players = load_config(configs[id])
             self.create_loopback_runtime(id, players)
@@ -143,6 +150,9 @@
         for protocol in self.protocols.itervalues():
             protocol.transport.close()
 
+        for runtime in self.real_runtimes:
+            runtime.deferred_queue.put((None, None))
+
     def create_loopback_runtime(self, id, players):
         """Create a L{Runtime} connected with a loopback.
 
@@ -163,7 +173,10 @@
         # We add the Deferred passed to ShareExchangerFactory and not
         # the Runtime, since we want everybody to wait until all
         # runtimes are ready.
-        self.runtimes.append(result)
+        deferred_runtime = clone_deferred(result)
+        deferred_runtime.pause()
+        self.runtimes.append(deferred_runtime)
+        self.real_runtimes.append(runtime)
 
         for peer_id in players:
             if peer_id != id:
@@ -189,6 +202,14 @@
                     sentinel = loopbackAsync(server, client)
                     self.close_sentinels.append(sentinel)
 
+        def start_deferred_queue_loop(runtime, deferred_runtime):
+            runtime.deferred_queue.put((deferred_runtime, None))
+            runtime.select_event.set()
+            reactor.callInThread(runtime.deferred_queue_loop)
+            return runtime
+
+        result.addCallback(start_deferred_queue_loop, deferred_runtime)
+
 
 class BinaryOperatorTestCase:
     """Test a binary operator.
_______________________________________________
viff-devel mailing list (http://viff.dk/)
viff-devel@viff.dk
http://lists.viff.dk/listinfo.cgi/viff-devel-viff.dk

Reply via email to