# HG changeset patch
# User Martin Geisler <[EMAIL PROTECTED]>
# Date 1225723352 -3600
# Node ID f44d1f9c2c0051c4e4ac70164fdf483c4885af2a
# Parent  77e1cc1fb9b9a776affafaa0b2727589e378b3bd
Experimental switching of protocols at runtime.
I got tired of seeing stringReceived in ShareExchanger always test
self.peer_id, so I decided to split it into two protocols: one for
establishing the ID, and another for exchanging shares.

This patch seems to work, except that it breaks the test suite quite
badly. I haven't been able to fix it, so now I'm just sending it out
here for you to look at.

diff --git a/viff/runtime.py b/viff/runtime.py
--- a/viff/runtime.py
+++ b/viff/runtime.py
@@ -51,6 +51,36 @@
 from twisted.protocols.basic import Int16StringReceiver
 
 
+def switch_protocol(old, new):
+    """Switch a connection from one :class:`IntNStringReceiver`
+    protocol to another.
+
+    The factory and any data buffered (but not yet parsed) by *old*
+    are carried over to *new*, and *old* is paused so that it parses
+    no further strings. The transport is then redirected to deliver
+    future data to *new*, *new* is attached to the transport with
+    :meth:`makeConnection` (which sets its ``transport`` attribute
+    and calls its :meth:`connectionMade`), and finally the
+    carried-over data is fed through *new*.
+    """
+    # Copy important attributes from the old protocol.
+    # NOTE(review): this relies on IntNStringReceiver keeping its
+    # unparsed bytes in ``recvd``; confirm for the Twisted version.
+    new.factory = old.factory
+    new.recvd = old.recvd
+    # Put the old protocol on pause. This is necessary to stop it
+    # from calling stringReceived on the old protocol in case it had
+    # buffered enough data for two strings.
+    old.paused = True
+    # Make the transport deliver future data to the new protocol.
+    old.transport.protocol = new
+    # Hand the transport to the new protocol. Unlike assigning
+    # ``transport`` directly, makeConnection also sets ``connected``.
+    new.makeConnection(old.transport)
+    # Trigger delivery of any data buffered by the old protocol.
+    new.dataReceived("")
+
+
 class Share(Deferred):
     """A shared number.
 
@@ -246,6 +271,51 @@
     share_list.addCallback(filter_results)
     return share_list
 
+
+class EstablishID(Int16StringReceiver):
+    """Exchange player IDs with a newly connected peer.
+
+    Our own ID is sent as soon as the connection is made. The
+    peer's ID is checked against the serial number of its
+    certificate (when the transport provides one) and then passed
+    to the factory's :meth:`identify_peer`.
+    """
+
+    def connectionMade(self):
+        """Send our own player ID to the peer."""
+        self.sendString(str(self.factory.runtime.id))
+
+    def stringReceived(self, string):
+        """Called when an ID is received.
+
+        The connection is dropped, without identifying the peer,
+        if *string* is not a valid integer or does not match the
+        serial number of the peer's certificate.
+        """
+        try:
+            peer_id = int(string)
+        except ValueError:
+            print "Peer sent malformed ID %r, aborting!" % string
+            self.transport.loseConnection()
+            return
+        try:
+            cert = self.transport.getPeerCertificate()
+        except AttributeError:
+            # Transport has no getPeerCertificate (presumably
+            # a non-SSL connection), so there is nothing to check.
+            cert = None
+        if cert:
+            # The player IDs are stored in the serial number of the
+            # certificate -- this makes it easy to check that the
+            # player is who he claims to be.
+            if cert.get_serial_number() != peer_id:
+                print "Peer %s claims to be %d, aborting!" \
+                    % (cert.get_subject(), peer_id)
+                self.transport.loseConnection()
+                # Never hand an impostor to the factory.
+                return
+        self.factory.identify_peer(self, peer_id)
+
 
 class ShareExchanger(Int16StringReceiver):
     """Send and receive shares.
@@ -259,14 +306,10 @@
     """
 
     def __init__(self):
-        self.peer_id = None
         self.lost_connection = Deferred()
         #: Data expected to be received in the future.
         self.incoming_data = {}
 
-    def connectionMade(self):
-        self.sendString(str(self.factory.runtime.id))
-
     def connectionLost(self, reason):
         reason.trap(ConnectionDone)
         self.lost_connection.callback(self)
@@ -278,37 +321,20 @@
         and a data part. The data is passed the appropriate Deferred
         in :class:`self.incoming_data`.
         """
-        if self.peer_id is None:
-            # TODO: Handle ValueError if the string cannot be decoded.
-            self.peer_id = int(string)
-            try:
-                cert = self.transport.getPeerCertificate()
-            except AttributeError:
-                cert = None
-            if cert:
-                # The player ID are stored in the serial number of the
-                # certificate -- this makes it easy to check that the
-                # player is who he claims to be.
-                if cert.get_serial_number() != self.peer_id:
-                    print "Peer %s claims to be %d, aborting!" \
-                        % (cert.get_subject(), self.peer_id)
-                    self.transport.loseConnection()
-            self.factory.identify_peer(self)
+        program_counter, data_type, data = marshal.loads(string)
+        key = (program_counter, data_type)
+
+        deq = self.incoming_data.setdefault(key, deque())
+        if deq and isinstance(deq[0], Deferred):
+            deferred = deq.popleft()
+            if not deq:
+                del self.incoming_data[key]
+            deferred.callback(data)
         else:
-            program_counter, data_type, data = marshal.loads(string)
-            key = (program_counter, data_type)
+            deq.append(data)
 
-            deq = self.incoming_data.setdefault(key, deque())
-            if deq and isinstance(deq[0], Deferred):
-                deferred = deq.popleft()
-                if not deq:
-                    del self.incoming_data[key]
-                deferred.callback(data)
-            else:
-                deq.append(data)
-
-            # TODO: marshal.loads can raise EOFError, ValueError, and
-            # TypeError. They should be handled somehow.
+        # TODO: marshal.loads can raise EOFError, ValueError, and
+        # TypeError. They should be handled somehow.
 
     def sendData(self, program_counter, data_type, data):
         send_data = (program_counter, data_type, data)
@@ -330,7 +356,7 @@
 class ShareExchangerFactory(ReconnectingClientFactory, ServerFactory):
     """Factory for creating ShareExchanger protocols."""
 
-    protocol = ShareExchanger
+    protocol = EstablishID
 
     def __init__(self, runtime, players, protocols_ready):
         """Initialize the factory."""
@@ -339,8 +365,11 @@
         self.needed_protocols = len(players) - 1
         self.protocols_ready = protocols_ready
 
-    def identify_peer(self, protocol):
-        self.runtime.add_player(self.players[protocol.peer_id], protocol)
+    def identify_peer(self, protocol, peer_id):
+        share_exchanger = ShareExchanger()
+        # Redirect new traffic to the ShareExchanger protocol.
+        switch_protocol(protocol, share_exchanger)
+        self.runtime.add_player(self.players[peer_id], share_exchanger)
         self.needed_protocols -= 1
         if self.needed_protocols == 0:
             self.protocols_ready.callback(self.runtime)
@@ -496,6 +525,7 @@
         self.add_player(player, None)
 
     def add_player(self, player, protocol):
+        print "Player %d noticed Player %d" % (self.id, player.id)
         self.players[player.id] = player
         self.num_players = len(self.players)
         # There is no protocol for ourselves, so we wont add that:
diff --git a/viff/test/util.py b/viff/test/util.py
--- a/viff/test/util.py
+++ b/viff/test/util.py
@@ -20,7 +20,7 @@
 from twisted.internet.defer import Deferred, gatherResults, maybeDeferred
 from twisted.trial.unittest import TestCase
 
-from viff.runtime import PassiveRuntime, Share, ShareExchanger, ShareExchangerFactory
+from viff.runtime import PassiveRuntime, Share, ShareExchanger
 from viff.field import GF
 from viff.config import generate_configs, load_config
 from viff.util import rand
@@ -155,19 +155,14 @@
         result = Deferred()
 
         # Create a runtime that knows about no other players than itself.
-        # It will eventually be returned in result when the factory has
-        # determined that all needed protocols are ready.
         runtime = self.runtime_class(players[id], self.threshold)
-        factory = ShareExchangerFactory(runtime, players, result)
-        # We add the Deferred passed to ShareExchangerFactory and not
-        # the Runtime, since we want everybody to wait until all
-        # runtimes are ready.
         self.runtimes.append(result)
 
-        for peer_id in players:
+        for peer_id, peer in players.iteritems():
             if peer_id != id:
                 protocol = ShareExchanger()
-                protocol.factory = factory
+                print "Adding %s for %s to %s" % (protocol, peer, runtime)
+                runtime.add_player(peer, protocol)
 
                 # Keys for when we are the client and when we are the server.
                 client_key = (id, peer_id)
_______________________________________________
viff-patches mailing list
[email protected]
http://lists.viff.dk/listinfo.cgi/viff-patches-viff.dk

Reply via email to