# HG changeset patch
# User Janus Dam Nielsen <[email protected]>
# Date 1254816324 -7200
# Node ID dd15e514cbb0a59f8cf06caa55064cccde3f9b41
# Parent  2dabe8c91e557603cbfeb84d39f892f6bf4e773f
Benchmark: derive the preprocessing program automatically (imported patch autopreprocessing.patch)

diff --git a/apps/benchmark.py b/apps/benchmark.py
--- a/apps/benchmark.py
+++ b/apps/benchmark.py
@@ -63,6 +63,7 @@
 import viff.reactor
 viff.reactor.install()
 from twisted.internet import reactor
+from twisted.internet.defer import Deferred
 
 from viff.field import GF, GF256, FakeGF
 from viff.runtime import Runtime, create_runtime, gather_shares, \
@@ -87,12 +88,13 @@
     print "Started", what
 
 
-def record_stop(_, what):
+def record_stop(x, what):
     stop = time.time()
     print
     print "Total time used: %.3f sec" % (stop-start)
     print "Time per %s operation: %.0f ms" % (what, 1000*(stop-start) / count)
     print "*" * 6
+    return x  # pass the callback result through to the next callback
 
 operations = {"mul": (operator.mul,[]),
               "compToft05": (operator.ge, [ComparisonToft05Mixin]),
@@ -116,7 +118,7 @@
                   help="the name of the basic runtime to test")
 parser.add_option("-n", "--num_players", action="store_true", dest="num_players",
                   help="number of players")
-parser.add_option("--mixins", type="choice", choices=mixins.keys(),
+parser.add_option("--mixins", type="string",
                   help="operation to benchmark")
 parser.add_option("--prss", action="store_true",
                   help="use PRSS for preprocessing")
@@ -138,7 +140,7 @@
                   help="additional arguments to the runtime, the format is a comma separated list of id=value pairs e.g. --args s=1,d=0,lambda=1")
 
 parser.set_defaults(modulus=2**65, threshold=1, count=10,
-                    runtime=runtimes.keys()[0], mixins=mixins.keys(), num_players=2, prss=True,
+                    runtime=runtimes.keys()[0], mixins="", num_players=2, prss=True,
                     operation=operations.keys()[0], parallel=True, fake=False)
 
 # Add standard VIFF options.
@@ -174,53 +176,43 @@
     def __init__(self, rt, operation):
         self.rt = rt
         self.operation = operation
-        self.sync_preprocess()
-
-    def sync_preprocess(self):
-        print "Synchronizing preprocessing"
+        self.pc = None
         sys.stdout.flush()
         sync = self.rt.synchronize()
+        self.doTest(sync, lambda x: x)
         self.rt.schedule_callback(sync, self.preprocess)
+        self.doTest(sync, lambda x: self.rt.shutdown())
+
+    # Each benchmark runs in two passes chained on the same Deferred.
+    # The first doTest pass runs without preprocessing and reports the
+    # data that was missing from the pool; preprocess() then generates
+    # exactly that data, and the second doTest pass repeats the test
+    # and finally shuts the runtime down.
 
-    def preprocess(self, _):
-        program_desc = {}
-
-        if isinstance(self.rt, BasicActiveRuntime):
-            # TODO: Make this optional and maybe automatic. The
-            # program descriptions below were found by carefully
-            # studying the output reported when the benchmarks were
-            # run with no preprocessing. So they are quite brittle.
-            if self.operation == operator.mul:
-                key = ("generate_triples", (Zp,))
-                desc = [(i, 1, 0) for i in range(3, 3 + count)]
-                program_desc.setdefault(key, []).extend(desc)
-            elif isinstance(self.rt, ComparisonToft05Mixin):
-                key = ("generate_triples", (GF256,))
-                desc = sum([[(c, 64, i, 1, 1, 0) for i in range(2, 33)] +
-                            [(c, 64, i, 3, 1, 0) for i in range(17, 33)]
-                            for c in range(3 + 2*count, 3 + 3*count)],
-                           [])
-                program_desc.setdefault(key, []).extend(desc)
-            elif isinstance(self.rt, ComparisonToft07Mixin):
-                key = ("generate_triples", (Zp,))
-                desc = sum([[(c, 2, 4, i, 2, 1, 0) for i in range(1, 33)] +
-                            [(c, 2, 4, 99, 2, 1, 0)] +
-                            [(c, 2, 4, i, 1, 0) for i in range(65, 98)]
-                            for c in range(3 + 2*count, 3 + 3*count)],
-                           [])
-                program_desc.setdefault(key, []).extend(desc)
-
-        if program_desc:
+    def preprocess(self, needed_data):
+        # needed_data maps (generator, args) keys to lists of program counters.
+        if needed_data:
             print "Starting preprocessing"
             record_start("preprocessing")
-            preproc = self.rt.preprocess(program_desc)
+            preproc = self.rt.preprocess(needed_data)
             preproc.addCallback(record_stop, "preprocessing")
-            self.rt.schedule_callback(preproc, self.begin)
+            return preproc
         else:
             print "Need no preprocessing"
-            self.begin(None)
+            return None
+
+    def doTest(self, d, termination_function):
+        # Chain one benchmark pass onto d and return d.
+        self.rt.schedule_callback(d, self.begin)
+        self.rt.schedule_callback(d, self.sync_test)
+        # No countdown: run_test starts as soon as all players are in sync.
+        self.rt.schedule_callback(d, self.run_test)
+        self.rt.schedule_callback(d, self.sync_test)
+        self.rt.schedule_callback(d, self.finished, termination_function)
+        return d
 
     def begin(self, _):
+        # Fresh input shares are created for every pass.
         print "Runtime ready, generating shares"
         self.a_shares = []
         self.b_shares = []
@@ -234,43 +226,49 @@
             self.a_shares.append(self.rt.input([inputter], Zp, a))
             self.b_shares.append(self.rt.input([inputter], Zp, b))
         shares_ready = gather_shares(self.a_shares + self.b_shares)
-        self.rt.schedule_callback(shares_ready, self.sync_test)
+        return shares_ready
 
-    def sync_test(self, _):
+    def sync_test(self, x):
         print "Synchronizing test start."
         sys.stdout.flush()
         sync = self.rt.synchronize()
-        self.rt.schedule_callback(sync, self.countdown, 3)
+        self.rt.schedule_callback(sync, lambda y: x)
+        return sync
 
-    def countdown(self, _, seconds):
-        if seconds > 0:
-            print "Starting test in %d" % seconds
-            sys.stdout.flush()
-            reactor.callLater(1, self.countdown, None, seconds - 1)
-        else:
-            print "Starting test now"
-            sys.stdout.flush()
-            self.run_test(None)
+    # run_test() must be overridden. It receives the result of the
+    # preceding sync_test and must return a Deferred that fires with
+    # the data that was missing from the preprocessing pool during
+    # this pass, resetting rt._needed_data for the next pass.
+    #
+    # finished() reports that data and passes it to the termination
+    # function given to doTest; its return value becomes the result
+    # of the pass (the first pass feeds it to preprocess(), the
+    # second one shuts the runtime down).
 
     def run_test(self, _):
-        raise NotImplemented("Override this abstract method in a sub class.")
+        raise NotImplementedError("Override this abstract method in a sub class.")
 
-    def finished(self, _):
+    def finished(self, needed_data, termination_function):
         sys.stdout.flush()
 
-        if self.rt._needed_data:
+        if needed_data:
             print "Missing pre-processed data:"
-            for (func, args), pcs in self.rt._needed_data.iteritems():
+            for (func, args), pcs in needed_data.iteritems():
                 print "* %s%s:" % (func, args)
                 print "  " + pformat(pcs).replace("\n", "\n  ")
 
-        self.rt.shutdown()
+        return termination_function(needed_data)
 
 # This class implements a benchmark where run_test executes all
 # operations in parallel.
 class ParallelBenchmark(Benchmark):
 
-    def run_test(self, _):
+    def run_test(self, shares):
+        # Reuse the first pass's program counter so pool lookups match.
+        if self.pc is not None:
+            self.rt.program_counter = list(self.pc)
+        else:
+            self.pc = list(self.rt.program_counter)
         c_shares = []
         record_start("parallel test")
         while self.a_shares and self.b_shares:
@@ -280,24 +278,40 @@
 
         done = gather_shares(c_shares)
         done.addCallback(record_stop, "parallel test")
-        self.rt.schedule_callback(done, self.finished)
+        def f(x):
+            needed_data = self.rt._needed_data
+            self.rt._needed_data = {}
+            return needed_data
+        done.addCallback(f)
+        return done
+
 
 # A benchmark where the operations are executed one after each other.
 class SequentialBenchmark(Benchmark):
 
     def run_test(self, _):
+        # Reuse the first pass's program counter so pool lookups match.
+        if self.pc is not None:
+            self.rt.program_counter = list(self.pc)
+        else:
+            self.pc = list(self.rt.program_counter)
+        done = Deferred()
         record_start("sequential test")
-        self.single_operation(None)
+        self.single_operation(None, done)
+        return done
 
-    def single_operation(self, _):
+    def single_operation(self, _, done):
         if self.a_shares and self.b_shares:
             a = self.a_shares.pop()
             b = self.b_shares.pop()
             c = self.operation(a, b)
-            self.rt.schedule_callback(c, self.single_operation)
+            self.rt.schedule_callback(c, self.single_operation, done)
         else:
             record_stop(None, "sequential test")
-            self.finished(None)
+            # Hand the missing-data report to doTest and start afresh.
+            needed_data = self.rt._needed_data
+            self.rt._needed_data = {}
+            done.callback(needed_data)
 
 # Identify the base runtime class.
 base_runtime_class = runtimes[options.runtime]
diff --git a/viff/active.py b/viff/active.py
--- a/viff/active.py
+++ b/viff/active.py
@@ -378,11 +378,11 @@
     def get_triple(self, field):
         # This is a waste, but this function is only called if there
         # are no pre-processed triples left.
-        count, result = self.generate_triples(field)
+        count, result = self.generate_triples(field, None)
         result.addCallback(lambda triples: triples[0])
         return result
 
-    def generate_triples(self, field):
+    def generate_triples(self, field, number_of_requested_triples=None):
         """Generate multiplication triples.
 
         These are random numbers *a*, *b*, and *c* such that ``c =
@@ -423,11 +423,11 @@
 
     @preprocess("generate_triples")
     def get_triple(self, field):
-        count, result = self.generate_triples(field)
+        count, result = self.generate_triples(field, None)
         result.addCallback(lambda triples: triples[0])
         return result
 
-    def generate_triples(self, field):
+    def generate_triples(self, field, number_of_requested_triples=None):
         """Generate a multiplication triple using PRSS.
 
         These are random numbers *a*, *b*, and *c* such that ``c =
diff --git a/viff/runtime.py b/viff/runtime.py
--- a/viff/runtime.py
+++ b/viff/runtime.py
@@ -475,17 +475,21 @@
     """
 
     def preprocess_decorator(method):
-
         @wrapper(method)
         def preprocess_wrapper(self, *args, **kwargs):
             pc = tuple(self.program_counter)
+            self.program_counter[-1] += 1  # advance past this call on hit and miss alike
             try:
                 return self._pool[pc]
             except KeyError:
                 key = (generator, args)
                 pcs = self._needed_data.setdefault(key, [])
                 pcs.append(pc)
-                return method(self, *args, **kwargs)
+                self.program_counter.append(0)  # sub-counter for calls made by method
+                try:
+                    return method(self, *args, **kwargs)
+                finally:
+                    self.program_counter.pop()  # drop exactly the level pushed above
 
         return preprocess_wrapper
     return preprocess_decorator
@@ -808,6 +812,8 @@
             func = getattr(self, generator)
             results = []
             items = 0
+            # Pass the total number of requested items as an extra argument.
+            args = tuple(args) + (len(program_counters),)
             while items < len(program_counters):
                 item_count, result = func(*args)
                 items += item_count
_______________________________________________
viff-patches mailing list
[email protected]
http://lists.viff.dk/listinfo.cgi/viff-patches-viff.dk

Reply via email to