#!/usr/bin/env python
'''
Polynomial code with fast decoding
'''

from mpi4py import MPI
import numpy as np
import random
import time

##################### Parameters ########################
# Use one master (rank 0) and N workers (ranks 1..N).
N = 5

# Matrix division: A and B are both split into p row blocks; A is further split
# into m column blocks and B into n column blocks. The encoding in the master is
# hard-wired for m = n = p = 2 (checked below).
m = 2
n = 2
p = 2

# Order of the roots of unity used as evaluation points. ell must be at least N
# so that the N workers get distinct points. The decoder extracts a block of
# A^T*B as (value mod s_cod**(2*ell)) // s_cod**ell, so s_cod**ell must also
# exceed every entry of the useful blocks and of the low-order "garbage" terms.
ell = 8

# Input matrix size - A: s by r, B: s by t. The product computed is A^T * B (r by t).
s = 10
r = 10
t = 10

# Recovery threshold: the master decodes from the T fastest workers.
T = m*n

# Base of the code and the powers of it needed for decoding.
s_cod = 2
s_cod_2ell = pow(s_cod, 2*ell)
s_cod_ell = pow(s_cod, ell)

# Evaluation points: worker i uses exp(2*pi*1j*i/ell), i.e. the first N of the
# ell-th roots of unity. The encoder scales them by s_cod.
var = np.array([np.exp(1j*2*np.pi*i/ell) for i in range(N)])

# Fail fast on configurations the rest of the script cannot handle.
if (m, n, p) != (2, 2, 2):
    raise ValueError("the encoding is hard-wired for m = n = p = 2")
if N < T:
    raise ValueError("need at least T = m*n = %d workers, got N = %d" % (T, N))
if ell < N:
    raise ValueError("ell (%d) must be >= N (%d) for distinct evaluation points" % (ell, N))
if s % p or r % m or t % n:
    raise ValueError("s, r, t must be divisible by p, m, n respectively")

comm = MPI.COMM_WORLD

# Check for a wrong number of MPI processes: one master plus N workers.
if comm.size != N+1:
    import sys
    # Report once (from rank 0) rather than once per process, and flush before
    # aborting: Abort may terminate the processes without flushing Python's
    # buffered stdout, which would lose the message. The barrier keeps the other
    # ranks from aborting (and taking rank 0 down) before the message is written.
    if comm.rank == 0:
        sys.stdout.write("The number of MPI processes mismatches the number of workers "
                         "(expected %d, got %d).\n" % (N+1, comm.size))
        sys.stdout.flush()
    comm.Barrier()
    comm.Abort(1)
	
def _vandermonde_inverse(points):
    """Return the inverse of the Vandermonde matrix V[q, k] = points[q]**k.

    Uses the closed-form triangular factorization V^-1 = U^-1 * L^-1 (the
    "NASA paper" method) rather than a general matrix inversion.

    Parameters
    ----------
    points : sequence of complex
        Distinct evaluation points x_0, ..., x_{T-1}.

    Returns
    -------
    numpy.ndarray
        (T, T) complex128 array; np.dot(V_inv, values) gives the coefficients
        (lowest degree first) of the polynomial interpolating the values.
    """
    size = len(points)

    # Lower-triangular factor: L_inv[i, j] = prod_{k <= i, k != j} 1/(x_j - x_k).
    # For i = j = 0 the product is empty, giving 1.
    L_inv = np.zeros((size, size), dtype=np.complex128)
    for i in range(size):
        for j in range(i + 1):
            prod = 1 + 0j
            for k in range(i + 1):
                if k != j:
                    prod = prod * (1.0 / (points[j] - points[k]))
            L_inv[i, j] = prod

    # Unit upper-triangular factor from the recurrence
    # U_inv[i, j] = U_inv[i-1, j-1] - U_inv[i, j-1] * x_{j-1}.
    U_inv = np.zeros((size, size), dtype=np.complex128)
    for i in range(size):
        for j in range(size):
            if i == j:
                U_inv[i, j] = 1
            elif j == 0:
                U_inv[i, j] = 0
            elif i == 0:
                U_inv[i, j] = -U_inv[i, j - 1] * points[j - 1]
            else:
                U_inv[i, j] = U_inv[i - 1, j - 1] - U_inv[i, j - 1] * points[j - 1]

    return np.dot(U_inv, L_inv)


if comm.rank == 0:

    # Master

    print("Running with %d processes:" % comm.Get_size())

    # Random 0/1 input matrices (randint's upper bound is exclusive).
    A = np.matrix(np.random.randint(0, 2, (s, r)))
    B = np.matrix(np.random.randint(0, 2, (s, t)))

    # Split into p row blocks, then A into m and B into n column blocks:
    # Ahv[i][j] / Bhv[i][j] is row block i, column block j.
    Ah = np.split(A, p)
    Bh = np.split(B, p)
    Ahv = [np.split(Ah[i], m, axis=1) for i in range(p)]
    Bhv = [np.split(Bh[i], n, axis=1) for i in range(p)]

    # Encode: worker i receives the evaluation at x = s_cod*var[i]. The
    # exponents are chosen so that the blocks of A^T*B appear as the
    # coefficients of x^ell, x^(ell+1), x^(ell+2), x^(ell+3) of the product.
    # NOTE: hard-wired for m = n = p = 2.
    Aenc = []
    Benc = []
    for i in range(N):
        x = s_cod*var[i]
        Aenc.append(Ahv[0][0]*pow(x, ell) + Ahv[1][0] + Ahv[0][1]*pow(x, ell+1) + Ahv[1][1]*x)
        Benc.append(Bhv[0][0] + Bhv[1][0]*pow(x, ell) + Bhv[0][1]*pow(x, 2) + Bhv[1][1]*pow(x, ell+2))

    # One receive buffer per worker.
    Rdict = [np.zeros((r//m, t//n), dtype=np.complex128) for i in range(N)]

    # Start requests to send and receive.
    reqA = [None] * N
    reqB = [None] * N
    reqC = [None] * N

    bp_start = time.time()

    for i in range(N):
        reqA[i] = comm.Isend([Aenc[i], MPI.C_DOUBLE_COMPLEX], dest=i+1, tag=15)
        reqB[i] = comm.Isend([Benc[i], MPI.C_DOUBLE_COMPLEX], dest=i+1, tag=29)
        reqC[i] = comm.Irecv([Rdict[i], MPI.C_DOUBLE_COMPLEX], source=i+1, tag=42)

    MPI.Request.Waitall(reqA)
    MPI.Request.Waitall(reqB)

    bp_sent = time.time()
    print("Time spent sending all messages is: %f" % (bp_sent - bp_start))

    # Wait for the T = m*n fastest workers. Waitany nulls each completed
    # request, so the same worker is never returned twice.
    Crtn = [None] * N
    fastest = []
    for i in range(T):
        j = MPI.Request.Waitany(reqC)
        fastest.append(j)
        Crtn[j] = Rdict[j]
    bp_received = time.time()
    print("Time spent waiting for %d workers %s is: %f" % (T, ",".join(map(str, [w+1 for w in fastest])), (bp_received - bp_sent)))

    # Interpolate in x = s_cod*var over the points of the responding workers.
    V_inv = _vandermonde_inverse(s_cod*var[fastest])

    print("Starting decoding...")

    # Polynomial interpolation by Vandermonde inversion, applied to whole
    # blocks at once. Decoded values should be real integers.
    C_block = []
    for i in range(T):
        acc = np.zeros((r//m, t//n), dtype=np.complex128)
        for q in range(T):
            acc = acc + V_inv[i, q]*Crtn[fastest[q]]
        C_block.append(np.round(np.real(acc)).astype(np.int64))

    # Coefficient blocks come out column-major (C00, C10, C01, C11): row block j
    # of column block i is C_block[m*i + j]. The mod / floor division keeps only
    # the s_cod^ell "digit", which holds the useful block.
    C = np.hstack([
        np.vstack([np.mod(C_block[m*i + j], s_cod_2ell)//s_cod_ell for j in range(m)])
        for i in range(n)])

    # Test
    print("The product is: ")
    print(C)

    bp_done = time.time()
    print("Time spent decoding is: %f" % (bp_done - bp_received))

    # Complete the receives from the straggling workers so that no request is
    # still pending when MPI is finalized.
    MPI.Request.Waitall(reqC)

else:
    # Worker

    # Receive this worker's encoded pieces of A and B from the master.
    Ai = np.empty((s//p, r//m), dtype=np.complex128)
    Bi = np.empty((s//p, t//n), dtype=np.complex128)

    rA = comm.Irecv([Ai, MPI.C_DOUBLE_COMPLEX], source=0, tag=15)
    rB = comm.Irecv([Bi, MPI.C_DOUBLE_COMPLEX], source=0, tag=29)
    MPI.Request.Waitall([rA, rB])

    wbp_received = time.time()

    # We compute Ai^T * Bi. NOTE(review): object dtype forces pure-Python
    # (non-BLAS) arithmetic, as in the original; presumably intentional for the
    # timing measurement -- confirm before switching to a plain np.dot.
    Ci = np.dot(Ai.astype(object).T, Bi.astype(object))
    Ci_np = np.ascontiguousarray(Ci.astype(np.complex128))

    wbp_done = time.time()
    print("Worker %d computing takes: %f\n" % (comm.Get_rank(), wbp_done - wbp_received))

    sC = comm.Isend([Ci_np, MPI.C_DOUBLE_COMPLEX], dest=0, tag=42)
    sC.Wait()
