#include <unistd.h>
#include <fcntl.h>
#include <getopt.h>
#include <limits.h>

#ifdef LWIP
#include "lwip/opt.h"
#include "lwip/init.h"
#include "lwip/sys.h"
#include "lwip/tcp.h"
#include "lwip/inet_chksum.h"
#include "lwip/tcpip.h"

#include "netif/ethernetif.h"

#include "lwip/api.h"
#include "arch/perf.h"

#include "lwip/sockets.h"
/* nonstatic debug cmd option, exported in lwipopts.h */
unsigned char debug_flags;
#else
#include <sys/socket.h>       /*  socket definitions        */
#include <sys/types.h>        /*  socket types              */
#include <sys/time.h>         /*  gettimeofday()            */
#include <arpa/inet.h>        /*  inet (3) functions        */
#include <netinet/in.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#endif

#define LISTENQ        (1024)   /*  Backlog for listen()   */


/** @todo add options for selecting netif, starting DHCP client etc */
/*
 * Long options accepted on the command line. The 'val' of every entry is the
 * matching short option character, so getopt_long() returns the same code for
 * either spelling. The debug, gateway, ipaddr and netmask options are only
 * acted upon in LWIP builds. The table is read-only and ends with an all-zero
 * terminator entry, as getopt_long() requires.
 */
static const struct option longopts[] = {
  /* turn on debugging output (if built with LWIP_DEBUG) */
  {"debug",        no_argument,       NULL, 'd'},
  /* run as server */
  {"is_server",    no_argument,       NULL, 's'},
  /* help */
  {"help",         no_argument,       NULL, 'h'},
  /* gateway address */
  {"gateway",      required_argument, NULL, 'g'},
  /* port */
  {"port",         required_argument, NULL, 'p'},
  /* ip address */
  {"ipaddr",       required_argument, NULL, 'i'},
  /* netmask */
  {"netmask",      required_argument, NULL, 'm'},
  /* server ip address */
  {"serverip",     required_argument, NULL, 'H'},
  /* message_size */
  {"message_size", required_argument, NULL, 'M'},
  {NULL,           0,                 NULL,  0}
};
/* Number of real options in longopts, excluding the terminator entry. */
#define NUM_OPTS ((sizeof(longopts) / sizeof(struct option)) - 1)
/*
 * Print the option summary: one "-<short> --<long>" line for every entry of
 * longopts, excluding its terminator.
 */
void usage(void)
{
  size_t idx;

  printf("options:\n");
  for (idx = 0; idx < NUM_OPTS; ++idx) {
    const struct option *opt = &longopts[idx];

    printf("-%c --%s\n", opt->val, opt->name);
  }
}

#ifdef LWIP
/*
 * Completion callback passed to tcpip_init(). arg points at the sys_sem_t
 * supplied by the caller of tcpip_init(); signalling it releases whoever is
 * waiting for stack initialisation to finish.
 */
static void
tcpip_init_done(void *arg)
{
  sys_sem_signal(*(sys_sem_t *)arg);
}
#endif

/**
 * Benchmark client.
 *
 * Connects to server_ip:bind_port and streams 1,000,000,000 bytes of filler
 * data in chunks of at most send_size bytes. It then sends a single 'c' byte
 * as an end-of-test marker and prints the byte count, elapsed time and
 * throughput.
 *
 * @param bind_port  server TCP port in host byte order (converted with htons()).
 * @param send_size  chunk size in bytes for each send(); must be > 0.
 * @param server_ip  server IPv4 address. It is stored as-is into
 *                   sin_addr.s_addr, i.e. it must already be in network
 *                   byte order.
 *
 * Any fatal error (bad size, allocation, socket, connect, send, close)
 * terminates the process.
 */
static void client(int bind_port, int send_size, int server_ip)
{
  /* Payload size of one measurement run. */
  const int total_to_send = 1000000000;
  int send_socket;
  struct sockaddr_in servaddr;
  int sec, usec, total_bytes_sent = 0;
  float elapsed_time;
  struct timeval time1, time2;
  int bytes_remaining = total_to_send;
  char *buf;
  int len;

  /* buf[send_size - 1] is written below, so a non-positive size would be out of bounds. */
  if (send_size <= 0) {
    printf("rishi_perf: client: invalid message size %d\n", send_size);
    exit(1);
  }

#ifdef LWIP
  buf = mem_malloc(send_size);
#else
  buf = malloc(send_size);
#endif
  if (buf == NULL) {
    printf("malloc of %d bytes failed. Please reduce message size or increase MEM_SIZE\n", send_size);
    exit(1);
  }

  /* Filler payload: 'z' (0x7a) bytes with a trailing NUL. */
  memset(buf, 0x7a, send_size);
  buf[send_size - 1] = 0;

  send_socket = socket(AF_INET, SOCK_STREAM, 0);
  if (send_socket < 0) {
    printf("rishi_perf: client: Error creating data socket\n");
    exit(1);
  }

  memset(&servaddr, 0, sizeof(servaddr));
  servaddr.sin_family      = AF_INET;
  servaddr.sin_addr.s_addr = server_ip;
  servaddr.sin_port        = htons(bind_port);
  if (connect(send_socket, (struct sockaddr *)&servaddr, sizeof(servaddr)) < 0) {
    printf("rishi_perf: client: data socket connect failed\n");
    exit(1);
  }

  /* start timer */
  gettimeofday(&time1, NULL);
  while (bytes_remaining > 0) {
    /* Never send past the end of the run. */
    int chunk = (bytes_remaining < send_size) ? bytes_remaining : send_size;

    len = send(send_socket, buf, chunk, 0);
    if (len < 0) {
      printf("rishi_perf: client: error in sending data\n");
      exit(1);
    }
    if (len == 0) {
      /* No progress was made; stop instead of spinning forever. */
      break;
    }
    /* A short send is legal on a stream socket: count it and keep going. */
    total_bytes_sent += len;
    bytes_remaining -= len;
    printf("sent %d, left=%d\n", len, bytes_remaining);
  }
  gettimeofday(&time2, NULL);

  /* End-of-test marker. This is best effort: the close() below also ends the stream. */
  buf[0] = 'c';
  send(send_socket, buf, 1, 0);

  /* Borrow a second so that the microsecond difference is non-negative. */
  if (time2.tv_usec < time1.tv_usec) {
    time2.tv_usec += 1000000;
    time2.tv_sec  -= 1;
  }
  sec  = time2.tv_sec - time1.tv_sec;
  usec = time2.tv_usec - time1.tv_usec;
  elapsed_time = (float)sec + ((float)usec / (float)1000000.0);
#ifdef LWIP
  LWIP_PLATFORM_DIAG(("Bytes Sent=%d Time = %f\nBytes/sec = %f\n", total_bytes_sent, elapsed_time, total_bytes_sent / elapsed_time));
#else
  printf("Bytes Sent=%d Time = %f\nBytes/sec = %f\n", total_bytes_sent, elapsed_time, total_bytes_sent / elapsed_time);
#endif

#ifdef LWIP
  mem_free(buf);
#else
  free(buf);
#endif

  /* Close connection and discard connection identifier. */
  if (close(send_socket) < 0) {
    fprintf(stderr, "rishi_perf: client: Error calling close()\n");
    exit(EXIT_FAILURE);
  }
}

/**
 * Benchmark server.
 *
 * Listens on bind_port (all local addresses) and handles one connection at a
 * time. For each connection it counts the bytes received in recv_size-sized
 * reads. It stops when the peer closes, when recv() fails, or when a read
 * begins with 'c' (the client's end marker). It then sends the peer a text
 * summary (bytes, elapsed time, bytes/sec), prints the same summary locally,
 * and closes the connection.
 *
 * @param bind_port  TCP port to listen on, host byte order.
 * @param recv_size  receive buffer size in bytes; must be > 0.
 *
 * Setup failures terminate the process. The function returns only when
 * accept() fails; it then releases the listening socket and the buffer.
 */
static void server(int bind_port, int recv_size)
{
  int listen_sock, connect_sock;
  struct sockaddr_in servaddr;
  char *buf;
  int len, first_packet;
  int sec, usec, total_bytes_received;
  float elapsed_time;
  struct timeval time1, time2;
  struct sockaddr_in peeraddr_in;
  socklen_t addrlen;
  char return_msg[1024];

  if (recv_size <= 0) {
    printf("rishi_perf: server: invalid message size %d\n", recv_size);
    exit(1);
  }

#ifdef LWIP
  buf = mem_malloc(recv_size);
#else
  buf = malloc(recv_size);
#endif
  if (buf == NULL) {
    printf("malloc of %d bytes failed. Please reduce message size or increase MEM_SIZE\n", recv_size);
    exit(1);
  }

  /* The listening socket is created exactly once, here. */
  if ((listen_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
    printf("rishi_perf: server: Error creating listening socket.\n");
    exit(1);
  }
  memset(&servaddr, 0, sizeof(servaddr));
  servaddr.sin_family      = AF_INET;
  servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
  servaddr.sin_port        = htons(bind_port);

  if (bind(listen_sock, (struct sockaddr *)&servaddr, sizeof(servaddr)) < 0) {
    printf("rishi_perf: server: Error calling bind()\n");
    exit(1);
  }

  if (listen(listen_sock, LISTENQ) < 0) {
    printf("rishi_perf: server: Error calling listen()\n");
    exit(1);
  }

  for (;;) {
    /* accept() overwrites addrlen with the peer address length, so reset it every time. */
    addrlen = sizeof(peeraddr_in);
    connect_sock = accept(listen_sock, (struct sockaddr *)&peeraddr_in, &addrlen);
    if (connect_sock < 0) {
      printf("rishi_perf: server: Error calling accept()\n");
      break;
    }

    /*
     * Start from a defined state so that a peer which closes without sending
     * anything reports 0 bytes instead of reading uninitialised values. The
     * timer is restarted when the first data actually arrives.
     */
    first_packet = 1;
    total_bytes_received = 0;
    gettimeofday(&time1, NULL);
    while ((len = recv(connect_sock, buf, recv_size, 0)) > 0) {
      if (first_packet) {
        /* start timer */
        gettimeofday(&time1, NULL);
        first_packet = 0;
      }
      total_bytes_received += len;
      if (buf[0] == 'c') break;
    }
    if (len < 0) {
      printf("rishi_perf: server: Error calling recv()\n");
    }
    gettimeofday(&time2, NULL);

    /* Borrow a second so that the microsecond difference is non-negative. */
    if (time2.tv_usec < time1.tv_usec) {
      time2.tv_usec += 1000000;
      time2.tv_sec  -= 1;
    }

    sec  = time2.tv_sec - time1.tv_sec;
    usec = time2.tv_usec - time1.tv_usec;
    elapsed_time = (float)sec + ((float)usec / (float)1000000.0);
    snprintf(return_msg, sizeof(return_msg), "Bytes Received=%d Time = %f\nBytes/sec = %f\n",
             total_bytes_received, elapsed_time, total_bytes_received / elapsed_time);
    send(connect_sock, return_msg, strlen(return_msg), 0);
#ifdef LWIP
    LWIP_PLATFORM_DIAG(("%s\n", return_msg));
#else
    printf("%s\n", return_msg);
#endif

    /* Close connection and discard connection identifier. */
    if (close(connect_sock) < 0) {
      fprintf(stderr, "rishi_perf: server: Error calling close()\n");
      exit(EXIT_FAILURE);
    }
  }

  /* Only reached when accept() fails: release what this function acquired. */
  close(listen_sock);
#ifdef LWIP
  mem_free(buf);
#else
  free(buf);
#endif
}


/**
 * Parse a decimal integer command-line argument.
 *
 * @param text  argument text (typically optarg).
 * @param min   smallest accepted value.
 * @param max   largest accepted value; must not exceed INT_MAX.
 * @param out   receives the parsed value on success; untouched on failure.
 * @return 1 if text is entirely a decimal number in [min, max], else 0.
 */
static int parse_int_arg(const char *text, long min, long max, int *out)
{
  char *end;
  long value = strtol(text, &end, 10);

  /* strtol() saturates to LONG_MIN/LONG_MAX on overflow; the range check
   * rejects those values whenever the bounds are narrower than long. */
  if (end == text || *end != '\0' || value < min || value > max) {
    return 0;
  }
  *out = (int)value;
  return 1;
}

/**
 * Convert a dotted-quad IPv4 argument into *out with inet_aton(). On malformed
 * input, print an error naming the option (what) and exit, rather than
 * continuing with an indeterminate address.
 */
static void parse_addr_arg(const char *text, const char *what, struct in_addr *out)
{
  if (inet_aton(text, out) == 0) {
    printf("rishi_perf: invalid %s address '%s'\n", what, text);
    exit(1);
  }
}

/**
 * Entry point: parse the command line, bring up the lwIP stack and its
 * network interface (LWIP builds only), then run either the server (-s) or
 * the client.
 */
int
main(int argc, char **argv)
{
#ifdef LWIP
  struct netif netif;
  sys_sem_t sem;
  static struct ip_addr ipaddr, netmask, gw;
  char ip_str[16] = {0}, nm_str[16] = {0}, gw_str[16] = {0};
#endif
  struct in_addr inaddr;
  int ch;
  int is_server = 0;
  int server_ip = 0;
  int bind_port = 12865;
  int message_size = 1024;

#ifdef LWIP
  /* startup defaults (may be overridden by one or more opts) */
  IP4_ADDR(&gw, 172,16,184,2);
  IP4_ADDR(&netmask, 255,255,255,0);
  IP4_ADDR(&ipaddr, 172,16,184,130);
  /* use debug flags defined by debug.h */
  debug_flags = LWIP_DBG_OFF;
#endif
  while ((ch = getopt_long(argc, argv, "M:dshg:i:m:p:H:", longopts, NULL)) != -1) {
    switch (ch) {
#ifdef LWIP
      case 'd':
        debug_flags |= (LWIP_DBG_ON|LWIP_DBG_TRACE|LWIP_DBG_STATE|LWIP_DBG_FRESH|LWIP_DBG_HALT);
        break;
      case 'g':
        parse_addr_arg(optarg, "gateway", &inaddr);
        gw.addr = inaddr.s_addr;
        break;
      case 'i':
        parse_addr_arg(optarg, "ip", &inaddr);
        ipaddr.addr = inaddr.s_addr;
        break;
      case 'm':
        parse_addr_arg(optarg, "netmask", &inaddr);
        netmask.addr = inaddr.s_addr;
        break;
#endif
      case 's':
        is_server = 1;
        break;
      case 'h':
        usage();
        exit(0);
        break;
      case 'p':
        if (!parse_int_arg(optarg, 1, 65535, &bind_port)) {
          printf("rishi_perf: invalid port '%s'\n", optarg);
          exit(1);
        }
        break;
      case 'H':
        parse_addr_arg(optarg, "server", &inaddr);
        server_ip = inaddr.s_addr;
        break;
      case 'M':
        /* client()/server() need at least one byte of buffer */
        if (!parse_int_arg(optarg, 1, INT_MAX, &message_size)) {
          printf("rishi_perf: invalid message size '%s'\n", optarg);
          exit(1);
        }
        break;
      default:
        /* unknown or unsupported option: report it through the exit status */
        usage();
        exit(EXIT_FAILURE);
    }
  }

#ifdef LWIP
  /* inet_ntoa() may return a static buffer, so copy each result before the next call. */
  inaddr.s_addr = ipaddr.addr;
  snprintf(ip_str, sizeof(ip_str), "%s", inet_ntoa(inaddr));
  inaddr.s_addr = netmask.addr;
  snprintf(nm_str, sizeof(nm_str), "%s", inet_ntoa(inaddr));
  inaddr.s_addr = gw.addr;
  snprintf(gw_str, sizeof(gw_str), "%s", inet_ntoa(inaddr));
  printf("Host at %s mask %s gateway %s\n", ip_str, nm_str, gw_str);
#ifdef PERF
  perf_init("/tmp/simhost.perf");
#endif /* PERF */

  /* NOTE(review): in some lwIP versions tcpip_init() itself calls lwip_init();
   * confirm that the explicit call below is intended for the version in use. */
  lwip_init();
  netif_init();
  /* Block until tcpip_init_done() signals that the tcpip thread is ready. */
  sem = sys_sem_new(0);
  tcpip_init(tcpip_init_done, &sem);
  sys_sem_wait(sem);
  sys_sem_free(sem);
  printf("TCP/IP initialized.\n");

  netif_set_default(netif_add(&netif, &ipaddr, &netmask, &gw, NULL, ethernetif_init,
                    tcpip_input));
  netif_set_up(&netif);
#endif
  if (is_server) {
    server(bind_port, message_size);
  } else {
    client(bind_port, message_size, server_ip);
  }
  return 0;
}
