# -*- coding: utf-8 -*-
"""
Created on Sun Jul 05 17:45:23 2015

@author: schae
"""
import numpy as np
import matplotlib.pyplot as plt
import time

import pycuda.driver as drv
from pycuda.compiler import SourceModule

# -- initialize the device
import pycuda.autoinit

# Launch configuration: CUDA threads per block along x.
threads = 16

# Grid size: every field is stored as an array of ny rows by nx columns.
nx = 200
ny = 200

# Grid spacing. NOTE: the CUDA kernels hard-code unit spacing (their central
# differences simply divide by 2), so changing dx/dy here has no effect.
dx = 1
dy = 1

# Initial state, allocated directly as float32 to match the kernels' float*
# arguments (avoids building float64 arrays and then copying them).
eta = np.zeros((ny, nx), dtype=np.float32)   # water surface elevation
u = np.zeros((ny, nx), dtype=np.float32)     # x-velocity
v = np.zeros((ny, nx), dtype=np.float32)     # y-velocity
h = np.full((ny, nx), 25, dtype=np.float32)  # depth used in eta_update's flux terms
h[:, nx // 2:] = 5                           # right half is shallower (a step)
eta[25:150, 25:50] = 3                       # initial rectangular hump of elevation
# A pair of CUDA events that bracket the timed region; the elapsed time is
# read later with start.time_till(end).
start, end = drv.Event(), drv.Event()


kernel_code= """
  __global__ void eta_update(float *eta, float *eta_prev, float *u, float *v, float *h, float dt, int nx, int ny)
  {
    unsigned int ix = threadIdx.x + blockIdx.x * blockDim.x;
    unsigned int iy = blockIdx.y;
    int I = iy*nx+ix;
    
    if (I>=nx*ny){return;}
    
    if ( (I>nx) && (I < nx*ny-1-nx) && (I%nx !=0) && (I%nx != ny-1) )
    {    eta[I]=eta_prev[I]-dt*((u[I+1]*h[I+1]-u[I-1]*h[I-1])/2+(v[I+nx]*h[I+nx]-v[I-nx]*h[I-nx])/2);
    }
  }
  
  __global__ void u_update(float *u, float *u_prev, float *v, float *eta, float dt, int nx, int ny)
  {
    unsigned int ix = threadIdx.x + blockIdx.x * blockDim.x;
    unsigned int iy = blockIdx.y;
    int I = iy*nx+ix;
    
    if (I>=nx*ny){return;}
    
    if ( (I>nx) && (I < nx*ny-1-nx) && (I%nx !=0) && (I%nx != ny-1) )
    {    u[I]=u_prev[I]-dt*9.81*((eta[I+1]-eta[I-1])/(2)+
        ((u_prev[I]-abs(u_prev[I]))*(u_prev[I+1]-u_prev[I])+(u_prev[I]+abs(u_prev[I]))*(u_prev[I]-u_prev[I-1]))/2+
        ((v[I]-abs(v[I]))*(u_prev[I+nx]-u_prev[I])+(v[I]+abs(v[I]))*(u_prev[I]-u_prev[I-nx]))/2);
    }
  }
  
  __global__ void v_update(float *v, float *v_prev, float *u, float *eta, float dt, int nx, int ny)
  {
    unsigned int ix = threadIdx.x + blockIdx.x * blockDim.x;
    unsigned int iy = blockIdx.y;
    int I = iy*nx+ix;
    
    if (I>=nx*ny){return;}
    
    if ( (I>nx) && (I < nx*ny-1-nx) && (I%nx !=0) && (I%nx != ny-1) )
    {    v[I]=v_prev[I]-dt*9.81*((eta[I+nx]-eta[I-nx])/2+
        ((u[I]-abs(u[I]))*(v_prev[I+1]-v_prev[I])+(u[I]+abs(u[I]))*(v_prev[I]-v_prev[I-1]))/2+
        ((v_prev[I]-abs(v_prev[I]))*(v_prev[I+nx]-v_prev[I])+(v_prev[I]+abs(v_prev[I]))*(v_prev[I]-v_prev[I-nx]))/2);
    }
  }
  __global__ void Shapiro(float *eta, float *eta_prev, int nx, int ny)
  {
    unsigned int ix = threadIdx.x + blockIdx.x * blockDim.x;
    unsigned int iy = blockIdx.y;
    int I = iy*nx+ix;
    
    if (I>=nx*ny){return;}
    
    if ( (I>nx) && (I < nx*ny-1-nx) && (I%nx !=0) && (I%nx != ny-1) )
    {    eta[I]=0.5*eta_prev[I]+0.5*0.25*(eta_prev[I+1]+eta_prev[I-1]+eta_prev[I+nx]+eta_prev[I-nx]);
    }
  }
  """

# Compile the CUDA source once, then look up a callable handle for each
# kernel by the name it has in kernel_code (looked up in this order).
mod = SourceModule(kernel_code)

u_update, v_update, eta_update, shapiro = map(
    mod.get_function,
    ("u_update",     # update x-velocities
     "v_update",     # update y-velocities
     "eta_update",   # update water surface elevation
     "Shapiro"),     # filter surface elevation
)
dt = 0.05    # timestep
n_iter = 10  # number of time steps

# One thread per grid cell: each block covers `threads` consecutive columns
# of a single row, so ceil(nx / threads) blocks are needed along x and ny
# along y.
block = (threads, 1, 1)
grid = ((nx + threads - 1) // threads, ny)

# Scalar kernel arguments converted once to the C types the kernels expect.
dt32 = np.float32(dt)
nx32 = np.int32(nx)
ny32 = np.int32(ny)

print('Commence wave propagation')
start.record()  # start timing
t = time.time()

# Device buffers: the current state, plus a separate "old" buffer for each
# updated field so a kernel reads the previous values while writing new ones.
u_gpu = drv.mem_alloc(u.nbytes)
v_gpu = drv.mem_alloc(v.nbytes)
eta_gpu = drv.mem_alloc(eta.nbytes)
h_gpu = drv.mem_alloc(h.nbytes)
u_old_gpu = drv.mem_alloc(u.nbytes)
v_old_gpu = drv.mem_alloc(v.nbytes)
eta_old_gpu = drv.mem_alloc(eta.nbytes)
drv.memcpy_htod(u_gpu, u)
drv.memcpy_htod(v_gpu, v)
drv.memcpy_htod(eta_gpu, eta)
drv.memcpy_htod(h_gpu, h)

for _ in range(n_iter):
    # Copy the current values into the *_old buffer on the device before each
    # update. A plain assignment (u_old_gpu = u_gpu) would only rebind the
    # Python name to the same allocation, so the kernel would read neighbour
    # cells that other threads of the same launch are overwriting.
    drv.memcpy_dtod(u_old_gpu, u_gpu, u.nbytes)
    u_update(u_gpu, u_old_gpu, v_gpu, eta_gpu, dt32, nx32, ny32,
             grid=grid, block=block)

    drv.memcpy_dtod(v_old_gpu, v_gpu, v.nbytes)
    v_update(v_gpu, v_old_gpu, u_gpu, eta_gpu, dt32, nx32, ny32,
             grid=grid, block=block)

    drv.memcpy_dtod(eta_old_gpu, eta_gpu, eta.nbytes)
    eta_update(eta_gpu, eta_old_gpu, u_gpu, v_gpu, h_gpu, dt32, nx32, ny32,
               grid=grid, block=block)

    drv.memcpy_dtod(eta_old_gpu, eta_gpu, eta.nbytes)
    shapiro(eta_gpu, eta_old_gpu, nx32, ny32, grid=grid, block=block)

# Kernel launches return before the GPU finishes; wait for completion before
# reading the wall clock, otherwise only launch overhead is measured.
end.record()  # end timing
end.synchronize()
t = time.time() - t
print(t)

drv.memcpy_dtoh(eta, eta_gpu)
drv.memcpy_dtoh(u, u_gpu)
drv.memcpy_dtoh(v, v_gpu)

secs = start.time_till(end) * 1e-3  # time_till returns milliseconds
print("SourceModule time")
print(secs)

plt.pcolor(eta)
plt.show()
