# HG changeset patch # User Janus Dam Nielsen <janus.niel...@alexandra.dk> # Date 1245395036 -7200 # Node ID a07740da4582869d11ead0f56ae055965aa2b4b0 # Parent 07a8329e75322d482dae15186422dd75e9ddb653 Implementation of the basic multiplication command.
diff --git a/viff/orlandi.py b/viff/orlandi.py
--- a/viff/orlandi.py
+++ b/viff/orlandi.py
@@ -15,6 +15,8 @@
 # You should have received a copy of the GNU Lesser General Public
 # License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
 
+from twisted.internet.defer import DeferredList, gatherResults
+
 from viff.runtime import Runtime, increment_pc, Share, ShareList, gather_shares
 from viff.util import rand, dprint
 
@@ -442,6 +444,31 @@
             return results[0]
         return results
 
+    @increment_pc
+    def mul(self, share_x, share_y):
+        """Multiplication of shares.
+
+        At least one of share_x and share_y must be a Share; the other
+        may be a public constant, in which case the product is computed
+        locally and no multiplication triple is requested.
+
+        Communication cost: ???.
+        """
+        # TODO: Communication cost?
+        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
+            "At least one of share_x and share_y must be a Share."
+
+        field = getattr(share_x, "field", getattr(share_y, "field", None))
+
+        # Constant multiplication needs no triple, so handle it before
+        # asking _get_triple for one.
+        cmul_result = self._cmul(share_x, share_y, field)
+        if cmul_result is not None:
+            return cmul_result
+
+        a, b, c = self._get_triple(field)
+        return self._basic_multiplication(share_x, share_y, a, b, c)
+
     def _additive_constant(self, zero, field_element):
         """Greate an additive constant.
 
@@ -488,6 +515,118 @@
         Cz = Cx * Cy
         return (zi, (rhozi1, rhozi2), Cz)
 
+    def _cmul(self, share_x, share_y, field):
+        """Multiplication of a share with a constant.
+
+        Exactly one of share_x and share_y should be a Share; the other
+        is treated as a public constant and the product is computed
+        locally, without any communication.
+
+        Returns None if both share_x and share_y are Shares, so the
+        caller can fall back to an interactive multiplication.
+        """
+        def constant_multiply(x, c):
+            zi, rhoz, Cz = self._const_mul(c, x)
+            return OrlandiShare(self, field, zi, rhoz, Cz)
+        if not isinstance(share_x, Share):
+            # Then share_y must be a Share => local multiplication. We
+            # clone first to avoid changing share_y.
+            assert isinstance(share_y, Share), \
+                "At least one of the arguments must be a share."
+            result = share_y.clone()
+            result.addCallback(constant_multiply, share_x)
+            return result
+        if not isinstance(share_y, Share):
+            # Likewise when share_y is a constant.
+            assert isinstance(share_x, Share), \
+                "At least one of the arguments must be a share."
+            result = share_x.clone()
+            result.addCallback(constant_multiply, share_y)
+            return result
+        # Both operands are Shares: not a constant multiplication.
+        return None
+
+    def _const_mul(self, c, x):
+        """Multiplication of a share-tuple with a constant c.
+
+        x is the tuple (xi, (rhoi1, rhoi2), Cx); the result is the
+        tuple (xi * c, (rhoi1 * c, rhoi2 * c), Cz).
+        """
+        xi, (rhoi1, rhoi2), Cx = x
+        zi = xi * c
+        rhoz = (rhoi1 * c, rhoi2 * c)
+        # TODO: exponentiation. The commitment should presumably become
+        # Cx ** c; until then it is passed through unchanged.
+        Cz = Cx
+        return (zi, rhoz, Cz)
+
+    def _get_triple(self, field):
+        """Return a placeholder multiplication triple (a, b, c).
+
+        Every player gets the same fixed shares 2, 4 and 24 with zero
+        randomness and zero commitments, so the triple is public and
+        must only be used for testing.
+
+        NOTE(review): c = a * b only holds for one particular number of
+        players. Assuming additive sharing, three players give a = 6,
+        b = 12 and c = 72 = 6 * 12.
+        """
+        n = field(0)
+        a = OrlandiShare(self, field, field(2), (n, n), n)
+        b = OrlandiShare(self, field, field(4), (n, n), n)
+        c = OrlandiShare(self, field, field(24), (n, n), n)
+        return (a, b, c)
+
+    @increment_pc
+    def _basic_multiplication(self, share_x, share_y, triple_a, triple_b, triple_c):
+        """Multiplication of shares given a triple.
+
+        Communication cost: ???.
+
+        d = Open([x] - [a])
+        e = Open([y] - [b])
+        [z] = e[x] + d[y] - [de] + [c]
+
+        If share_x or share_y is a public constant, the product is
+        computed locally by _cmul and the triple is left unused.
+        """
+        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
+            "At least one of share_x and share_y must be a Share."
+
+        field = getattr(share_x, "field", getattr(share_y, "field", None))
+
+        cmul_result = self._cmul(share_x, share_y, field)
+        if cmul_result is not None:
+            return cmul_result
+
+        def multiply((x, y, d, e, c)):
+            # [de]
+            de = self._additive_constant(field(0), d * e)
+            # e[x]
+            t1 = self._const_mul(e, x)
+            # d[y]
+            t2 = self._const_mul(d, y)
+            # d[y] - [de]
+            t3 = self._minus(t2, de)
+            # d[y] - [de] + [c]
+            t4 = self._plus(t3, c)
+            # [z] = e[x] + d[y] - [de] + [c]
+            zi, rhoz, Cz = self._plus(t1, t4)
+            return OrlandiShare(self, field, zi, rhoz, Cz)
+
+        # d = Open([x] - [a])
+        d = self.open(share_x - triple_a)
+        # e = Open([y] - [b])
+        e = self.open(share_y - triple_b)
+        result = gather_shares([share_x, share_y, d, e, triple_c])
+        result.addCallbacks(multiply, self.error_handler)
+
+        # do actual communication
+        self.activate_reactor()
+
+        return result
+
     def error_handler(self, ex):
         print "Error: ", ex
         return ex
+
diff --git a/viff/test/test_orlandi_runtime.py b/viff/test/test_orlandi_runtime.py
--- 
a/viff/test/test_orlandi_runtime.py
+++ b/viff/test/test_orlandi_runtime.py
@@ -252,3 +252,123 @@
         d2.addCallback(check)
         return DeferredList([d1, d2])
 
+    @protocol
+    def test_basic_multiply(self, runtime):
+        """Test _basic_multiplication of two shares using a triple."""
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([2], self.Zp, x1)
+        y2 = runtime.shift([3], self.Zp, y1)
+
+        a, b, c = runtime._get_triple(self.Zp)
+        z2 = runtime._basic_multiplication(x2, y2, a, b, c)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_mul_mul(self, runtime):
+        """Test multiplication of two shares with the * operator."""
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([2], self.Zp, x1)
+        y2 = runtime.shift([3], self.Zp, y1)
+
+        z2 = x2 * y2
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_basic_multiply_constant_right(self, runtime):
+        """Test _basic_multiplication with a constant right operand."""
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        a, b, c = runtime._get_triple(self.Zp)
+        z2 = runtime._basic_multiplication(x2, y1, a, b, c)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_basic_multiply_constant_left(self, runtime):
+        """Test _basic_multiplication with a constant left operand."""
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        a, b, c = runtime._get_triple(self.Zp)
+        z2 = runtime._basic_multiplication(y1, x2, a, b, c)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_constant_multiplication_constant_left(self, runtime):
+        """Test _cmul with the constant as the left operand."""
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        z2 = runtime._cmul(y1, x2, self.Zp)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_constant_multiplication_constant_right(self, runtime):
+        """Test _cmul with the constant as the right operand."""
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        z2 = runtime._cmul(x2, y1, self.Zp)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_constant_multiplication_constant_None(self, runtime):
+        """Test that _cmul returns None when both operands are shares."""
+
+        x1 = 42
+        y1 = 7
+
+        x2 = runtime.shift([1], self.Zp, x1)
+        y2 = runtime.shift([1], self.Zp, y1)
+
+        z2 = runtime._cmul(y2, x2, self.Zp)
+        self.assertEquals(z2, None)
+        return z2
_______________________________________________
viff-devel mailing list (http://viff.dk/)
viff-devel@viff.dk
http://lists.viff.dk/listinfo.cgi/viff-devel-viff.dk