# HG changeset patch
# User Janus Dam Nielsen <janus.niel...@alexandra.dk>
# Date 1245395036 -7200
# Node ID a07740da4582869d11ead0f56ae055965aa2b4b0
# Parent  07a8329e75322d482dae15186422dd75e9ddb653
Implementation of the basic multiplication command.

diff --git a/viff/orlandi.py b/viff/orlandi.py
--- a/viff/orlandi.py
+++ b/viff/orlandi.py
@@ -15,6 +15,8 @@
 # You should have received a copy of the GNU Lesser General Public
 # License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
 
+from twisted.internet.defer import DeferredList, gatherResults
+
 from viff.runtime import Runtime, increment_pc, Share, ShareList, gather_shares
 from viff.util import rand, dprint
 
@@ -442,6 +444,21 @@
              return results[0]
         return results       
 
+    @increment_pc
+    def mul(self, share_x, share_y):
+        """Multiply share_x and share_y using a triple from _get_triple().
+
+        Communication cost: ???.
+        """
+        # TODO: Communication cost?
+        has_share = isinstance(share_x, Share) or isinstance(share_y, Share)
+        assert has_share, \
+            "At least one of share_x and share_y must be a Share."
+        # Prefer the field of share_x; fall back to the one of share_y.
+        field = getattr(share_x, "field", getattr(share_y, "field", None))
+        ta, tb, tc = self._get_triple(field)
+        return self._basic_multiplication(share_x, share_y, ta, tb, tc)
+
     def _additive_constant(self, zero, field_element):
         """Greate an additive constant.
 
@@ -488,6 +505,97 @@
         Cz = Cx * Cy
         return (zi, (rhozi1, rhozi2), Cz)
 
+    def _cmul(self, share_x, share_y, field):
+        """Multiply a share by a constant.
+
+        Exactly one of share_x and share_y is expected to be a Share; the
+        other one is the constant. Returns None if both arguments are
+        Shares, and a new share holding the product otherwise.
+        """
+        def scale(share_tuple, constant):
+            zi, rhoz, Cz = self._const_mul(constant, share_tuple)
+            return OrlandiShare(self, field, zi, rhoz, Cz)
+
+        x_is_share = isinstance(share_x, Share)
+        y_is_share = isinstance(share_y, Share)
+        if x_is_share and y_is_share:
+            # Two shares cannot be multiplied locally.
+            return None
+        if x_is_share:
+            share, constant = share_x, share_y
+        else:
+            assert y_is_share, \
+                "At least one of the arguments must be a share."
+            share, constant = share_y, share_x
+        # Clone first to avoid changing the share passed in by the caller.
+        result = share.clone()
+        result.addCallback(scale, constant)
+        return result
+
+    def _const_mul(self, c, x):
+        """Multiply the share-tuple x = (xi, (rho1, rho2), Cx) by constant c."""
+        xi, (rhoi1, rhoi2), Cx = x
+        zi = xi * c
+        rhoz = (rhoi1 * c, rhoi2 * c)
+        Cz = Cx  # TODO: exponentiation; for now Cx is passed through as is.
+        return (zi, rhoz, Cz)
+
+
+    def _get_triple(self, field):
+        """Return a hard-coded triple (a, b, c) of OrlandiShares over field."""
+        z0 = field(0)
+        share = lambda v: OrlandiShare(self, field, field(v), (z0, z0), z0)
+        # NOTE(review): c = 24 is not a * b = 8; presumably fits 3 summed shares.
+        return (share(2), share(4), share(24))
+
+    @increment_pc
+    def _basic_multiplication(self, share_x, share_y, triple_a, triple_b, triple_c):
+        """Multiply share_x and share_y given a triple ([a], [b], [c]).
+
+        Communication cost: ???.
+        If one operand is not a Share, the result of _cmul() is returned.
+        d = Open([x] - [a])
+        e = Open([y] - [b])
+        [z] = e[x] + d[y] - [de] + [c]
+        """
+        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
+            "At least one of share_x and share_y must be a Share."
+
+        field = getattr(share_x, "field", getattr(share_y, "field", None))
+
+        # Constant operand: multiply locally, the triple is not needed.
+        cmul_result = self._cmul(share_x, share_y, field)
+        if cmul_result is not None:
+            return cmul_result
+
+        def multiply((x, y, d, e, c)):
+            # [de]
+            de = self._additive_constant(field(0), d * e)
+            # e[x]
+            t1 = self._const_mul(e, x)
+            # d[y]
+            t2 = self._const_mul(d, y)
+            # d[y] - [de]
+            t3 = self._minus(t2, de)
+            # d[y] - [de] + [c]
+            t4 = self._plus(t3, c)
+            # [z] = e[x] + d[y] - [de] + [c]
+            zi, rhoz, Cz = self._plus(t1, t4)
+            return OrlandiShare(self, field, zi, rhoz, Cz)
+
+        # d = Open([x] - [a])
+        d = self.open(share_x - triple_a)
+        # e = Open([y] - [b])
+        e = self.open(share_y - triple_b)
+        result = gather_shares([share_x, share_y, d, e, triple_c])
+        result.addCallbacks(multiply, self.error_handler)
+
+        # do actual communication
+        self.activate_reactor()
+
+        return result
+
     def error_handler(self, ex):
         print "Error: ", ex
         return ex
+
diff --git a/viff/test/test_orlandi_runtime.py b/viff/test/test_orlandi_runtime.py
--- a/viff/test/test_orlandi_runtime.py
+++ b/viff/test/test_orlandi_runtime.py
@@ -252,3 +252,126 @@
         d2.addCallback(check)
         return DeferredList([d1, d2])
 
+    @protocol
+    def test_basic_multiply(self, runtime):
+        """Test _basic_multiplication() on two shares and a triple."""
+
+        x1, y1 = 42, 7
+        expected = x1 * y1
+
+        def verify(opened):
+            self.assertEquals(opened, expected)
+
+        share_x = runtime.shift([2], self.Zp, x1)
+        share_y = runtime.shift([3], self.Zp, y1)
+
+        triple = runtime._get_triple(self.Zp)
+        product = runtime._basic_multiplication(share_x, share_y, *triple)
+        result = runtime.open(product)
+        result.addCallback(verify)
+        return result
+
+    @protocol
+    def test_mul_mul(self, runtime):
+        """Test multiplication of two shares through the * operator."""
+
+        x1, y1 = 42, 7
+        expected = x1 * y1
+
+        def verify(opened):
+            self.assertEquals(opened, expected)
+
+        share_x = runtime.shift([2], self.Zp, x1)
+        share_y = runtime.shift([3], self.Zp, y1)
+
+        # Open the product and compare it with the plain result.
+        result = runtime.open(share_x * share_y)
+        result.addCallback(verify)
+        return result
+
+    @protocol
+    def test_basic_multiply_constant_right(self, runtime):
+        """Test _basic_multiplication() with a constant right operand."""
+
+        x1, y1 = 42, 7
+        expected = x1 * y1
+
+        def verify(opened):
+            self.assertEquals(opened, expected)
+
+        share_x = runtime.shift([1], self.Zp, x1)
+
+        triple = runtime._get_triple(self.Zp)
+        product = runtime._basic_multiplication(share_x, y1, *triple)
+        result = runtime.open(product)
+        result.addCallback(verify)
+        return result
+
+    @protocol
+    def test_basic_multiply_constant_left(self, runtime):
+        """Test _basic_multiplication() with a constant left operand."""
+
+        x1, y1 = 42, 7
+        expected = x1 * y1
+
+        def verify(opened):
+            self.assertEquals(opened, expected)
+
+        share_x = runtime.shift([1], self.Zp, x1)
+
+        triple = runtime._get_triple(self.Zp)
+        product = runtime._basic_multiplication(y1, share_x, *triple)
+        result = runtime.open(product)
+        result.addCallback(verify)
+        return result
+
+    @protocol
+    def test_constant_multiplication_constant_left(self, runtime):
+        """Test _cmul() with the constant as the left operand."""
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        # _cmul() takes no triple, so none is fetched here.
+        z2 = runtime._cmul(y1, x2, self.Zp)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_constant_multiplication_constant_right(self, runtime):
+        """Test _cmul() with the constant as the right operand."""
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        # _cmul() takes no triple, so none is fetched here.
+        z2 = runtime._cmul(x2, y1, self.Zp)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_constant_multiplication_constant_None(self, runtime):
+        """Test that _cmul() returns None when both operands are shares."""
+
+        x1 = 42
+        y1 = 7
+
+        x2 = runtime.shift([1], self.Zp, x1)
+        y2 = runtime.shift([1], self.Zp, y1)
+
+        # _cmul() takes no triple, so none is fetched here.
+        z2 = runtime._cmul(y2, x2, self.Zp)
+        self.assertEquals(z2, None)
+        return z2
_______________________________________________
viff-devel mailing list (http://viff.dk/)
viff-devel@viff.dk
http://lists.viff.dk/listinfo.cgi/viff-devel-viff.dk

Reply via email to