#!/usr/bin/python

# Copyright 2007, 2008 VIFF Development Team.
#
# This file is part of VIFF, the Virtual Ideal Functionality Framework.
#
# VIFF is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License (LGPL) as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# VIFF is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.

# This code can be used to generate shared RSA keys of any desired
# length. The implementation is based on the algorithm described
# in "Efficient Generation of Shared RSA keys" written by
# Dan Boneh and Matthew Franklin in 1997.
#
# Some adjustments have been made, the first one found in the step
# "Trial division", which is specially implemented for 3 players,
# although it can be extended to an arbitrary number of players.
# The second change is that the trial division for N covers a larger
# span than the one used in the article, and that each player checks
# a different span instead of all players checking the same ones.
#
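# In this example the protocol is run between three players. The
# runtime is created with the ComparisonToft07Mixin (Toft's protocol
# for secure integer comparison), although the key generation itself
# performs no secure comparisons.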
#
# Give a player configuration file as a command line argument or run
# the example with '--help' for help with the command line options.

# import the necessary modules
import random
import math
import gmpy
import time

from optparse import OptionParser
from twisted.internet import reactor

from viff.field import GF
from viff.runtime import Runtime, create_runtime, gather_shares, make_runtime_class, Share
from viff.comparison import ComparisonToft07Mixin, Toft05Runtime
from viff.config import load_config
from viff.util import rand, find_prime
from viff.equality import ProbabilisticEqualityMixin

# We start by defining the protocol, it will be started at the bottom
# of the file.

class Protocol:
    """Three-player generation of a shared RSA key.

    Implements the key generation of Boneh and Franklin, "Efficient
    Generation of Shared RSA keys" (1997), on top of a VIFF runtime.  The
    protocol is a chain of Twisted callbacks: each step shares and/or opens
    values and registers the next step as the callback on the gathered
    results.  A failed check restarts the chain from generate_p().

    Every step hard-codes the player ids 1, 2 and 3, so exactly three
    players are supported.  All players are expected to perform the same
    sequence of share/open operations.
    """

    # Secret material (the shares of p and q, blinding values, g, r_z) must
    # come from a cryptographically secure generator: the module-level
    # ``random`` functions use a predictable Mersenne Twister.
    _rng = random.SystemRandom()

    def __init__(self, runtime):
        """Set up the benchmark parameters and start generating a key.

        runtime is the VIFF runtime delivered by create_runtime(); its
        ``id`` selects this player's role and its
        ``options.security_parameter`` is used to size the field.
        """
        # CHANGEABLE VARIABLES
        #**********************

        # the total number of keys to generate for the benchmark
        self.rounds = 1
        # set True to run the decryption benchmark after key generation
        self.decrypt_benchmark_active = True
        # the number of decryption rounds performed if the benchmark is active
        self.decrypt_rounds = 10
        # the number of bits in N (p and q get roughly bits_N / 2 bits each)
        self.bits_N = 128

        # m is the message used to check for correct decryption
        self.m = 2

        # small primes in (2, bound1] are tested secret shared on p and q
        self.bound1 = 12
        # the spans of primes each player tests N against locally; player 1
        # checks the largest span since statistically it fails most often
        self.bound2_p1 = 15000 # 12-15000 = 1749 primes
        self.bound2_p2 = 17500 # 15000-17500 = 260 primes
        self.bound2_p3 = 20000 # 17500-20000 = 253 primes

        # VARIABLES NOT TO BE CHANGED
        #*****************************

        # time1/time2 measure the wall-clock time for generating one key.
        # time.time() is used because time.clock() reports CPU time on Unix,
        # which leaves out the time spent waiting for the other players.
        self.time1 = time.time()
        self.time2 = 0
        # the number of keys generated so far
        self.completed_rounds = 0
        # the time spent on each generated key
        self.times = []
        # the number of keys that passed the trial decryption; if this ends
        # up below self.rounds, the protocol is flawed
        self.correct_decryptions = 0

        # decrypt_time1/2 measure the time of one benchmark decryption
        self.decrypt_time1 = 0
        self.decrypt_time2 = 0
        # the time spent on each benchmark decryption
        self.decrypt_times = []
        # the number of benchmark decryptions done so far
        self.decrypt_tries = 0

        # save the runtime for later use
        self.runtime = runtime

        # bit_length is the number of bits in each player's share of p and q
        self.bit_length = self.bits_N // 2 - 1
        # shares are drawn as 4*x (+3 for player 1), so x stays below
        # 2**bit_length / 4
        self.numeric_length = 2 ** self.bit_length // 4

        # primes tested secret shared on every candidate p and q
        self.prime_list_b1 = self.get_primes(2, self.bound1)

        # primes tested locally on N; every player uses a different span
        if self.runtime.id == 1:
            self.prime_list_b2 = self.get_primes(self.bound1, self.bound2_p1)
        elif self.runtime.id == 2:
            self.prime_list_b2 = self.get_primes(self.bound2_p1, self.bound2_p2)
        else:
            self.prime_list_b2 = self.get_primes(self.bound2_p2, self.bound2_p3)

        print("length of list b2 = " + str(len(self.prime_list_b2)))

        # index of the prime in prime_list_b1 currently being tested
        self.prime_pointer = 0

        # how many times each step function has run, for debugging
        self.function_count_names = [
            "generate_p", "generate_q", "trial_division_p",
            "check_trial_division_p", "trial_division_q",
            "check_trial_division_q", "check_n", "primality_test_N",
            "check_primality_test_N", "generate_g", "check_g", "check_v",
            "generate_z", "check_z", "generate_l", "generate_d",
            "check_decrypt"]
        self.function_count = [0] * len(self.function_count_names)

        # input_bits must be large enough for every value appearing in the
        # protocol; if it is too small, values wrap around Zp.modulus and
        # the outputs are bogus
        input_bits = int(self.bits_N * 3.5)
        security = runtime.options.security_parameter

        # For the comparison protocol to work, we need a field modulus
        # bigger than 2**(l+1) + 2**(l+k+1), where l is the bit length of
        # the input numbers and k is the security parameter.  The prime
        # must also be a Blum prime (p % 4 == 3), which
        # find_prime(..., blum=True) provides.
        self.Zp = GF(find_prime(2**(input_bits + 1)
                                + 2**(input_bits + security + 1),
                                blum=True))

        # start the protocol with each player generating its share of p
        self.generate_p()

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    def get_primes(self, min, max):
        """Return the primes p with min < p <= max, in increasing order."""
        # the parameter names shadow builtins but are kept for compatibility
        primes = []
        prime = int(gmpy.next_prime(min))
        while prime <= max:
            primes.append(prime)
            prime = int(gmpy.next_prime(prime))
        return primes

    def _powmod(self, base, exponent, modulus):
        """Return base**exponent % modulus as an int, computed with gmpy."""
        return int(pow(gmpy.mpz(base), gmpy.mpz(exponent), gmpy.mpz(modulus)))

    def _share_sum(self, value):
        """Shamir-share value from players 1-3 and return the sum share."""
        s1, s2, s3 = self.runtime.shamir_share([1, 2, 3], self.Zp, value)
        return s1 + s2 + s3

    def _share_product(self, value):
        """Shamir-share value from players 1-3 and return the product share."""
        s1, s2, s3 = self.runtime.shamir_share([1, 2, 3], self.Zp, value)
        return s1 * s2 * s3

    def _open_then(self, callback, *shares):
        """Open shares (in order) and call callback with the opened values."""
        opened = [self.runtime.open(share) for share in shares]
        results = gather_shares(opened)
        # addCallback waits for all values, then calls the next step
        results.addCallback(callback)
        return results

    def _random_share_mod4(self):
        """Draw this player's additive share of a candidate prime.

        Player 1's share is 3 mod 4 and every other share is 0 mod 4, so
        the sum of the shares is 3 mod 4.
        """
        share = 4 * self._rng.randint(1, self.numeric_length - 1)
        if self.runtime.id == 1:
            share += 3
        return share

    def _decryption_share(self):
        """Return this player's partial decryption c^d_i mod N of c = m^e.

        Also stores the ciphertext in self.c.  Player 1's d is negative,
        so player 1 stores the inverse of c mod N and raises it to -d.
        """
        self.c = self._powmod(self.m, self.e, self.n_revealed)
        if self.runtime.id == 1:
            self.c = gmpy.divm(1, self.c, self.n_revealed)
            exponent = -self.d
        else:
            exponent = self.d
        return self._powmod(self.c, exponent, self.n_revealed)

    # ------------------------------------------------------------------
    # Step 1: pick candidate shares of p and q
    # ------------------------------------------------------------------

    def generate_p(self):
        """Draw a fresh private share of p and start trial division on it."""
        self.function_count[0] += 1
        self.p = self._random_share_mod4()
        self.trial_division_p()

    def generate_q(self):
        """Draw a fresh private share of q and start trial division on it."""
        self.function_count[1] += 1
        self.q = self._random_share_mod4()
        self.trial_division_q()

    # ------------------------------------------------------------------
    # Step 2: secret-shared trial division by the small primes in b1
    # ------------------------------------------------------------------

    def _shared_trial_division(self, secret, callback):
        """Test in secret whether the current small prime divides the total.

        Each player contributes (secret mod prime) and a random blinding
        value.  With s the sum of the three residues (0 <= s < 3*prime)
        and r the sum of the blinding values, the opened value is
            s * (s - prime) * (s - 2*prime) * r
        which is zero exactly when prime divides the shared total (provided
        r != 0 in the field).  Otherwise it looks random and reveals nothing
        about the private shares.
        """
        prime_num = self.prime_list_b1[self.prime_pointer]
        residue_tot = self._share_sum(secret % prime_num)
        blind_tot = self._share_sum(self._rng.randint(1, self.Zp.modulus - 1))
        reveal = (residue_tot * (residue_tot - prime_num)
                  * (residue_tot - 2 * prime_num) * blind_tot)
        self._open_then(callback, reveal)

    def _after_trial_division(self, results, regenerate, next_test, on_done):
        """Restart, test the next prime, or move on after all primes pass."""
        if results[0].value == 0:
            # the current prime divides the candidate: pick a new one
            self.prime_pointer = 0
            regenerate()
            return
        self.prime_pointer += 1
        if self.prime_pointer >= len(self.prime_list_b1):
            self.prime_pointer = 0
            on_done()
        else:
            next_test()

    def trial_division_p(self):
        """Test the shared p against the current small prime."""
        self.function_count[2] += 1
        self._shared_trial_division(self.p, self.check_trial_division_p)

    def check_trial_division_p(self, results):
        """Handle a p trial division; when all primes pass, generate q."""
        self.function_count[3] += 1
        self._after_trial_division(results, self.generate_p,
                                   self.trial_division_p, self.generate_q)

    def trial_division_q(self):
        """Test the shared q against the current small prime."""
        self.function_count[4] += 1
        self._shared_trial_division(self.q, self.check_trial_division_q)

    def check_trial_division_q(self, results):
        """Handle a q trial division; when all primes pass, reveal N."""
        self.function_count[5] += 1
        self._after_trial_division(results, self.generate_q,
                                   self.trial_division_q, self._reveal_n)

    def _reveal_n(self):
        """Share the accepted p_i and q_i, then open N = p * q."""
        self.ptot = self._share_sum(self.p)
        self.qtot = self._share_sum(self.q)
        self._open_then(self.check_n, self.ptot * self.qtot)

    # ------------------------------------------------------------------
    # Step 3: local trial division of the public N
    # ------------------------------------------------------------------

    def check_n(self, results):
        """Store the revealed N and continue with the local test on N."""
        self.function_count[6] += 1
        self.n_revealed = results[0].value
        # NOTE(review): self.phi costs a secure multiplication but is not
        # read anywhere in this file; kept for compatibility.
        self.phi = (self.ptot - 1) * (self.qtot - 1)
        self.primality_test_N()

    def primality_test_N(self):
        """Check N against this player's span of primes and share the outcome.

        N is public, so every player divides it locally by its own primes
        in prime_list_b2.  A small factor means N is not a product of two
        large primes.  The per-player 0/1 failure flags are summed and
        opened.
        """
        self.function_count[7] += 1
        test_failed = int(any(self.n_revealed % prime == 0
                              for prime in self.prime_list_b2))
        self._open_then(self.check_primality_test_N,
                        self._share_sum(test_failed))

    def check_primality_test_N(self, results):
        """Continue with g if no player found a factor, else restart."""
        self.function_count[8] += 1
        if results[0].value == 0:
            self.generate_g()
        else:
            self.prime_pointer = 0
            self.generate_p()

    # ------------------------------------------------------------------
    # Step 4: distributed biprimality test
    # ------------------------------------------------------------------

    def generate_g(self):
        """Let player 1 pick a random g in [1, N-1] and open it to all."""
        self.function_count[9] += 1
        if self.runtime.id == 1:
            g = self._rng.randint(1, self.n_revealed - 1)
            self.g = self.runtime.shamir_share([1], self.Zp, g)
        else:
            # players other than 1 only receive player 1's input
            self.g = self.runtime.shamir_share([1], self.Zp)
        self._open_then(self.check_g, self.g)

    def check_g(self, results):
        """Require Jacobi(g/N) == 1, then share and multiply the v_i.

        Player 1 computes v_1 = g^((N - p_1 - q_1 + 1)/4) mod N.  The other
        players compute v_i = (g^-1)^((p_i + q_i)/4) mod N.  The exponents
        are exact multiples of 4 because of the mod-4 choice of the shares.
        """
        self.function_count[10] += 1
        self.g = results[0].value
        n = self.n_revealed
        if gmpy.jacobi(self.g, n) != 1:
            self.generate_g()
            return

        if self.runtime.id == 1:
            # player 1's private part of phi: N - p_1 - q_1 + 1
            self.phi_i = n - self.p - self.q + 1
            self.v = self._powmod(self.g, self.phi_i // 4, n)
        else:
            # every other player's private part of phi: -(p_i + q_i)
            self.phi_i = -(self.p + self.q)
            # gmpy.divm(1, a, b) is the inverse of a mod b
            self.inverse_v = int(gmpy.divm(1, self.g, n))
            self.v = self._powmod(self.inverse_v, -self.phi_i // 4, n)

        self._open_then(self.check_v, self._share_product(self.v))

    def check_v(self, results):
        """Accept if the product of the v_i is +-1 mod N, else restart."""
        self.function_count[11] += 1
        v = results[0].value % self.n_revealed
        if v in (1, self.n_revealed - 1):
            self.generate_z()
        else:
            self.prime_pointer = 0
            self.generate_p()

    def generate_z(self):
        """Open z = r * (p + q - 1) for a jointly random r (alternative step)."""
        self.function_count[12] += 1
        self.r_z = self._rng.randint(1, self.n_revealed - 1)
        r_tot = self._share_sum(self.r_z)
        z = r_tot * (-1 + (self.ptot + self.qtot))
        self._open_then(self.check_z, z)

    def check_z(self, results):
        """Accept N if gcd(z, N) == 1 and fix e; otherwise restart."""
        self.function_count[13] += 1
        z = results[0].value % self.n_revealed
        if gmpy.gcd(z, self.n_revealed) == 1:
            # public exponent e: the commonly used 2^16 + 1 = 65537
            self.e = 2**16 + 1
            self.generate_l()
        else:
            self.prime_pointer = 0
            self.generate_p()

    # ------------------------------------------------------------------
    # Step 5: shared private exponent d
    # ------------------------------------------------------------------

    def generate_l(self):
        """Share phi_i mod e and open the total l."""
        self.function_count[14] += 1
        self.l = self.phi_i % self.e
        print("\n\nPRIVATE VARIABLES")
        print("self.l = " + str(self.l))
        self._open_then(self.generate_d, self._share_sum(self.l))

    def generate_d(self, results):
        """Derive this player's share d_i of the private exponent.

        zeta = l^-1 mod e and d_i = floor(-zeta * phi_i / e).  The flooring
        makes the sum of the d_i off by a small amount, which
        check_decrypt corrects through a trial decryption.  If l is not
        invertible mod e, the whole protocol restarts.
        """
        self.function_count[15] += 1
        l_tot = results[0].value % self.e
        try:
            zeta = gmpy.divm(1, l_tot, self.e)
        except ZeroDivisionError:
            # not invertible (often a badly chosen e): start all over, and
            # stop here since zeta is undefined
            print("not invertable mod e")
            self.prime_pointer = 0
            self.generate_p()
            return

        self.d = int(-(zeta * self.phi_i) // self.e)
        print("self.p = " + str(self.p))
        print("self.q = " + str(self.q))
        print("self.d = " + str(self.d))
        print("N (public) = " + str(self.n_revealed))

        # each player opens its own c^d_i mod N for the trial decryption
        self.decrypt = self._decryption_share()
        c1, c2, c3 = self.runtime.shamir_share([1, 2, 3], self.Zp, self.decrypt)
        self._open_then(self.check_decrypt, c1, c2, c3)

    def _trial_decrypt(self, results):
        """Correct player 3's d_i by trying offsets 0, 1 and 2.

        Player 3 does this because player 1's d is negative and would need
        extra inversions.  The rounding error is at most (players - 1) = 2.
        """
        c1, c2, c3 = [result.value for result in results]
        for offset in range(3):
            tmp_decrypt = c1 * c2 * c3 % self.n_revealed
            print("Decryption = " + str(tmp_decrypt))
            if tmp_decrypt == self.m:
                print("d found, with +r = " + str(offset))
                self.correct_decryptions += 1
                print("Correct decryptions: " + str(self.correct_decryptions)
                      + " / " + str(self.rounds))
                break
            # bump player 3's d by one and recompute its partial decryption
            self.d += 1
            c3 = self._powmod(self.c, self.d, self.n_revealed)

    def _finish_round(self):
        """Record timing, then start another key or finish the benchmark."""
        self.time2 = time.time()
        self.completed_rounds += 1
        print("Completed rounds: " + str(self.completed_rounds)
              + " / " + str(self.rounds))
        self.times.append(self.time2 - self.time1)

        if self.completed_rounds != self.rounds:
            # more keys to generate: reset the state and start over
            self.prime_pointer = 0
            self.decrypt_tries = 0
            self.time1 = time.time()
            self.generate_p()
            return

        print("\n\nBENCHMARKS FOR VALID KEY GENERATION")
        print("times = " + str(self.times))
        print("Average: " + str(sum(self.times) / (self.rounds)))
        print("Correct decryptions: " + str(self.correct_decryptions)
              + " / " + str(self.rounds))
        print("\n")
        for name, count in zip(self.function_count_names, self.function_count):
            print(str(name) + ": " + str(count) + ", avg: "
                  + str(int(count // self.rounds)))

        if self.decrypt_benchmark_active:
            self.decrypt_benchmark()
        else:
            # the protocol is finished, synchronize the shutdown
            self.runtime.shutdown()

    def check_decrypt(self, results):
        """Fix d by trial decryption (player 3), then close the round."""
        self.function_count[16] += 1
        if self.runtime.id == 3:
            self._trial_decrypt(results)
        self._finish_round()

    # ------------------------------------------------------------------
    # Decryption benchmark
    # ------------------------------------------------------------------

    def decrypt_benchmark(self):
        """Time one shared decryption of c = m^e mod N."""
        self.decrypt_time1 = time.time()
        self.decrypt = self._decryption_share()
        self._open_then(self.check_decrypt_benchmark,
                        self._share_product(self.decrypt))

    def check_decrypt_benchmark(self, results):
        """Check one benchmark decryption and run the next one or finish.

        The opened value is the same for every player, so all players take
        the same branch.  A wrong decryption shuts the protocol down instead
        of leaving the reactor waiting forever.
        """
        tmp_decrypt = results[0].value % self.n_revealed
        if tmp_decrypt != self.m:
            print("incorrect decryption for m = " + str(self.m))
            self.runtime.shutdown()
            return

        self.decrypt_time2 = time.time()
        self.decrypt_tries += 1
        self.decrypt_times.append(self.decrypt_time2 - self.decrypt_time1)

        if self.decrypt_tries < self.decrypt_rounds:
            # use a new message so the same m is not decrypted every time
            self.m += 1
            self.decrypt_benchmark()
            return

        print("\n\nBENCHMARK FOR DECRYPTION")
        print("times = " + str(self.decrypt_times))
        print("average decrypt time = "
              + str(sum(self.decrypt_times) / self.decrypt_rounds))
        # the protocol is finished, synchronize the shutdown
        self.runtime.shutdown()


if __name__ == "__main__":
    # Parse command line arguments.
    parser = OptionParser()
    Runtime.add_options(parser)
    options, args = parser.parse_args()

    if not args:
        # parser.error() prints the usage message and exits the process
        parser.error("you must specify a config file")
    player_id, players = load_config(args[0])

    # Create a deferred Runtime (threshold 1) and ask it to run our
    # protocol when ready.
    runtime_class = make_runtime_class(mixins=[ComparisonToft07Mixin])
    pre_runtime = create_runtime(player_id, players, 1, options, runtime_class)
    pre_runtime.addCallback(Protocol)

    # Start the Twisted event loop.
    reactor.run()
