/*
 * Copyright (c) 2010 Philip Frey, Systems Group ETH Zurich.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

/*
 * Basic RDMA operations.
 */

#include <stdlib.h>
#include <string.h>

#include "iwarp.h"
#include "iwarp_debug.h"

/*
 * Work request id counters.
 *
 * rx_wr_id supplies the wr_id of every receive work request built by
 * post_recv_sgl()/post_recv_lmr(); tx_wr_id does the same for the send,
 * write and read work requests built by the other post_*() functions.
 *
 * Both are file-scope, so the numbering is shared by all connections that
 * go through this file, and they are incremented with a plain "++" (no
 * locking or atomics are used here).
 */
static uint64_t rx_wr_id = 0;
static uint64_t tx_wr_id = 0;


////////////////////////////////////////////////////////////////////////////////
// PRIVATE HELPER METHODS
////////////////////////////////////////////////////////////////////////////////

/*
 * Basic sanity check for a connection context: it has to get past
 * ctx_null() (defined elsewhere; presumably the NULL-pointer check) and it
 * has to carry a queue pair.
 *
 * Note that the name reads as a predicate - a non-zero result means
 * "invalid".
 *
 * Return 0 if the context is usable; -1 otherwise.
 */
static inline int ctx_invalid(
		IN	const struct iw_ctx_conn *ctx_conn)
{
	/* must come first: ctx_conn is dereferenced below */
	if (ctx_null(ctx_conn)) {
		return -1;
	}

	/* a context with a qp is all we ask for here */
	if (ctx_conn->qp) {
		return 0;
	}

	dprint(DBG_ON, LOG_ERROR, "qp is NULL");
	return -1;
}


/*
 * Check that a connection context is valid (see ctx_invalid()) and that the
 * connection is in IWARP_CONNECTED state.
 *
 * As with ctx_invalid(), a non-zero result means the check FAILED.
 *
 * Return 0 if the context is valid and connected; -1 otherwise.
 */
static inline int ctx_unconnected(
		IN	const struct iw_ctx_conn *ctx_conn)
{
	/* rejects a bad context pointer as well as a missing qp */
	if (ctx_invalid(ctx_conn)) {
		return -1;
	}

	if (ctx_conn->state == IWARP_CONNECTED) {
		return 0;
	}

	/* any other state is reported together with its name */
	dprint(DBG_ON, LOG_ERROR, "connection not established (state=%s)",
			iw_state_str(ctx_conn->state));
	return -1;
}


/*
 * "Ready to receive" check: tells whether receive wrs may be posted to the
 * rq in the current state of the context.
 *
 * This is a placeholder for now - the context is not inspected and every
 * context is reported as ready.
 *
 * Return 1 if context is ready; 0 otherwise.
 */
static inline int ctx_rtr(
		IN	const struct iw_ctx_conn *ctx_conn)
{
	//TODO: check in which states the rq is available
	(void)ctx_conn;	/* unused until the state check is implemented */

	return 1;
}


/*
 * "Ready to send" check: tells whether send wrs may be posted to the sq in
 * the current state of the context.
 *
 * This is a placeholder for now - the context is not inspected and every
 * context is reported as ready.
 *
 * Return 1 if context is ready; 0 otherwise.
 */
static inline int ctx_rts(
		IN	const struct iw_ctx_conn *ctx_conn)
{
	//TODO: check in which states the sq is available
	(void)ctx_conn;	/* unused until the state check is implemented */

	return 1;
}


////////////////////////////////////////////////////////////////////////////////
// INTERFACE METHODS
////////////////////////////////////////////////////////////////////////////////

/*
 * Collect exactly num_wcs work completions from the send or the receive
 * completion queue of a connection, blocking on the matching completion
 * channel for as long as completions are still outstanding.
 *
 * cq        selects the queue: IW_SCQ (send cq) or IW_RCQ (receive cq).
 * num_wcs   number of completions to collect; must be at least 1.
 * wc        caller provided array with room for num_wcs entries; it is
 *           filled front to back in the order the completions are polled.
 * ctx_conn  connection context; must be valid and in IWARP_CONNECTED state.
 *           Its scq_acks/rcq_acks counter is updated.
 *
 * The cq is polled first; only if completions are missing does the function
 * wait for a completion event, re-arm the notification and poll again.
 * Received events are counted in scq_acks/rcq_acks and acknowledged in one
 * batch once MAX_PENDING_ACKS is reached. Events counted below that
 * threshold stay pending when this function returns - presumably they are
 * acknowledged elsewhere (e.g. on teardown); confirm against the code that
 * destroys the cq.
 *
 * Return 0 once all num_wcs completions are stored in wc; -1 on invalid
 * input or on any verbs error. In the error case wc may be partially filled
 * and the number of valid entries is not reported.
 */
int await_completions(
		IN		enum iw_cq_type				 cq,
		IN		int							 num_wcs,
		IN OUT	struct ibv_wc				*wc,
		IN OUT	struct iw_ctx_conn			*ctx_conn)
{
	int						 ret, missing;
	struct ibv_comp_channel	*target_comp_channel;
	struct ibv_cq			*target_cq, *event_cq;
	void					*event_cq_ctx;
	uint32_t				*target_acks;

	/* sanitize input; ctx_unconnected() starts with ctx_invalid(), which in
	 * turn starts with ctx_null(), so no separate NULL check is needed */
	if (ctx_unconnected(ctx_conn)) {
		//TODO: can a completion arrive on a cq of an unconnected qp?
		return -1;
	}
	if (num_wcs <= 0) {
		dprint(DBG_ON, LOG_ERROR, "num_wcs must be at least 1");
		return -1;
	}
	if (!wc) {
		dprint(DBG_ON, LOG_ERROR, "array for work completions is NULL");
		return -1;
	}

	/* determine target cq and channel */
	if (cq == IW_SCQ) {
		target_comp_channel = ctx_conn->send_comp_channel;
		target_cq = ctx_conn->send_cq;
		target_acks = &ctx_conn->scq_acks;
	} else if (cq == IW_RCQ) {
		target_comp_channel = ctx_conn->recv_comp_channel;
		target_cq = ctx_conn->receive_cq;
		target_acks = &ctx_conn->rcq_acks;
	} else {
		dprint(DBG_ON, LOG_ERROR, "invalid cq type (%d)", cq);
		return -1;
	}

	/* poll to test if there are completions on the cq */
	ret = ibv_poll_cq(target_cq, num_wcs, wc);
	if (ret < 0) {
		dprint(DBG_ON, LOG_ERROR, "error polling the cq");
		return -1;
	}
	missing = num_wcs - ret;
	/* wc is a struct ibv_wc pointer, so this skips the ret entries that
	 * have just been filled */
	wc += ret;

	while (missing > 0) {

		dprint(DBG_OPS, LOG_INFO, "%d of %d completions received",
				num_wcs - missing, num_wcs);

		/* some (or all) wcs are still missing:
		 * wait for the next event on the completion channel */
		ret = (*lib->ops.s_ibv_get_cq_event)(target_comp_channel,
				&event_cq, &event_cq_ctx);
		if (ret) {
			dprint(DBG_ON, LOG_ERROR, "failed to get cq event from completion"
					" channel (%m)");
			return -1;
		}
		if (event_cq != target_cq) {
			/* the event belongs to another cq: acknowledge it on that cq
			 * right away instead of counting it against the target cq,
			 * which would over-acknowledge the target cq and leave the
			 * other cq with an event that is never acknowledged */
			(*lib->ops.s_ibv_ack_cq_events)(event_cq, 1);
			dprint(DBG_ON, LOG_ERROR, "cq mismatch");
			return -1;
		}
		(*target_acks)++;

		/* acknowledge the events only when a threshold is reached */
		if (*target_acks >= MAX_PENDING_ACKS) {
			dprint(DBG_OPS, LOG_DEBUG, "ACKing CQ events");
			(*lib->ops.s_ibv_ack_cq_events)(target_cq, *target_acks);
			*target_acks = 0;
		}

		/* request notification for the next event; this is done before
		 * polling so that a completion arriving after the poll below still
		 * raises an event for the next round */
		ret = ibv_req_notify_cq(target_cq, 0);
		if (ret) {
			dprint(DBG_ON, LOG_ERROR, "failed to request notification for next"
					" event (%m)");
			return -1;
		}

		/* get next wcs from target cq */
		ret = ibv_poll_cq(target_cq, missing, wc);
		if (ret < 0) {
			dprint(DBG_ON, LOG_ERROR, "error polling the cq");
			return -1;
		}
		missing -= ret;
		wc += ret;

	}

	dprint(DBG_OPS, LOG_INFO, "all completions received");

	return 0;

}


/*
 * Post one receive work request described by a scatter/gather list.
 *
 * dst       scatter/gather list that receives the data; NULL posts a
 *           receive with an empty sgl (a warning is logged).
 * ctx_conn  connection context; must be valid (see ctx_invalid()). If the
 *           context has an iw_srq the wr goes to that shared receive queue,
 *           otherwise to the rq of the connection's qp.
 *
 * Return the wr_id assigned to the work request, or -1 on error. The id
 * comes from the 64 bit counter rx_wr_id but is returned as int, so ids
 * that do not fit into an int are not returned faithfully. The counter is
 * advanced even when posting fails.
 */
int post_recv_sgl(
		IN	const struct iw_sgl			*dst,
		IN	const struct iw_ctx_conn	*ctx_conn)
{
	struct ibv_recv_wr	 recv_req;
	struct ibv_recv_wr	*rejected;
	int					 rc;

	/* the context must be usable and ready to receive */
	if (ctx_invalid(ctx_conn) || !ctx_rtr(ctx_conn)) {
		return -1;
	}

	/* assemble the (single) receive work request */
	recv_req.next = NULL;
	recv_req.wr_id = rx_wr_id++;
	if (dst) {
		recv_req.sg_list = dst->sg_list;
		recv_req.num_sge = dst->num_sge;
	} else {
		recv_req.sg_list = NULL;
		recv_req.num_sge = 0;
		dprint(DBG_OPS, LOG_WARNING, "dst is NULL; using zero-length sgl");
	}

	/* hand it to the srq if there is one, to the qp's own rq otherwise */
	rc = ctx_conn->iw_srq
			? ibv_post_srq_recv(ctx_conn->iw_srq->srq, &recv_req, &rejected)
			: ibv_post_recv(ctx_conn->qp, &recv_req, &rejected);
	if (rc) {
		dprint(DBG_ON, LOG_ERROR, "failed to post rdma receive wr");
		return -1;
	}

	dprint(DBG_OPS, LOG_INFO, "rdma receive posted successfully to rq");

	return recv_req.wr_id;
}


/*
 * Post one receive work request that places the data into a window of a
 * local memory region.
 *
 * dst       local memory region that receives the data; it must have
 *           IBV_ACCESS_LOCAL_WRITE set. NULL posts a receive with an empty
 *           sgl (a warning is logged; dst_off and len are ignored).
 * dst_off   offset of the window from dst->buf.
 * len       length of the window. Neither dst_off nor len is checked
 *           against the size of the region here.
 * ctx_conn  connection context; must be valid (see ctx_invalid()). If the
 *           context has an iw_srq the wr goes to that shared receive queue,
 *           otherwise to the rq of the connection's qp.
 *
 * Return the wr_id assigned to the work request, or -1 on error. The id
 * comes from the 64 bit counter rx_wr_id but is returned as int, so ids
 * that do not fit into an int are not returned faithfully. The counter is
 * advanced even when posting fails.
 */
int post_recv_lmr(
		IN	const struct iw_lmr			*dst,
		IN	uint32_t					 dst_off,
		IN	uint32_t					 len,
		IN	const struct iw_ctx_conn	*ctx_conn)
{
	struct ibv_recv_wr	 recv_req;
	struct ibv_recv_wr	*rejected;
	struct ibv_sge		 window;
	int					 rc;

	/* the context must be usable and ready to receive */
	if (ctx_invalid(ctx_conn) || !ctx_rtr(ctx_conn)) {
		return -1;
	}

	/* incoming data is written locally, so the region has to allow that */
	if (dst != NULL && (dst->access & IBV_ACCESS_LOCAL_WRITE) == 0) {
		dprint(DBG_ON, LOG_ERROR, "rdma receive destination memory region must"
				" have at least IBV_ACCESS_LOCAL_WRITE permission");
		return -1;
	}

	/* assemble the (single) receive work request */
	recv_req.next = NULL;
	recv_req.wr_id = rx_wr_id++;
	if (dst) {
		window.addr = (uint64_t)dst->buf + dst_off;
		window.length = len;
		window.lkey = dst->mr->lkey;
		recv_req.sg_list = &window;
		recv_req.num_sge = 1;
	} else {
		recv_req.sg_list = NULL;
		recv_req.num_sge = 0;
		dprint(DBG_OPS, LOG_WARNING, "dst is NULL; using zero-length sgl");
	}

	/* hand it to the srq if there is one, to the qp's own rq otherwise */
	rc = ctx_conn->iw_srq
			? ibv_post_srq_recv(ctx_conn->iw_srq->srq, &recv_req, &rejected)
			: ibv_post_recv(ctx_conn->qp, &recv_req, &rejected);
	if (rc) {
		dprint(DBG_ON, LOG_ERROR, "failed to post rdma receive wr");
		return -1;
	}

	dprint(DBG_OPS, LOG_INFO, "rdma receive posted successfully to rq");

	return recv_req.wr_id;
}


/*
 * Post one SEND work request whose payload is described by a
 * scatter/gather list.
 *
 * src       scatter/gather list holding the payload; NULL posts a
 *           zero-length send (a warning is logged).
 * flags     send flags that are passed through unchanged to the wr.
 * ctx_conn  connection context; must be valid and in IWARP_CONNECTED state.
 *
 * Return the wr_id assigned to the work request, or -1 on error. The id
 * comes from the 64 bit counter tx_wr_id but is returned as int, so ids
 * that do not fit into an int are not returned faithfully. The counter is
 * advanced even when posting fails.
 */
int post_send_sgl(
		IN	const struct iw_sgl			*src,
		IN	enum ibv_send_flags			 flags,
		IN	const struct iw_ctx_conn	*ctx_conn)
{
	int					ret;
	struct ibv_send_wr	send_wr, *bad_wr = NULL;

	/* sanitize input */
	if (ctx_unconnected(ctx_conn) || !ctx_rts(ctx_conn)) {
		return -1;
	}

	/* start from an all-zero wr: the members that are not set below (the
	 * immediate data and the opcode specific union) must not carry stack
	 * garbage into the provider */
	memset(&send_wr, 0, sizeof(send_wr));

	/* create send work request */
	send_wr.wr_id = tx_wr_id++;
	send_wr.next = NULL;
	if (!src) {
		send_wr.sg_list = NULL;
		send_wr.num_sge = 0;
		dprint(DBG_OPS, LOG_WARNING, "src is NULL; doing zero-length send");
	} else {
		send_wr.sg_list = src->sg_list;
		send_wr.num_sge = src->num_sge;
	}
	send_wr.opcode = IBV_WR_SEND;
	send_wr.send_flags = flags;

	/* post send work request */
	ret = ibv_post_send(ctx_conn->qp, &send_wr, &bad_wr);
	if (ret) {
		dprint(DBG_ON, LOG_ERROR, "failed to post rdma send wr");
		return -1;
	}

	dprint(DBG_OPS, LOG_INFO, "rdma send posted successfully to sq");

	return send_wr.wr_id;

}

/*
 * Post one SEND work request whose payload is a window of a local memory
 * region.
 *
 * src       local memory region holding the payload; NULL posts a
 *           zero-length send (a warning is logged; src_off and len are
 *           ignored).
 * src_off   offset of the window from src->buf.
 * len       length of the window. Neither src_off nor len is checked
 *           against the size of the region here.
 * flags     send flags that are passed through unchanged to the wr.
 * ctx_conn  connection context; must be valid and in IWARP_CONNECTED state.
 *
 * Return the wr_id assigned to the work request, or -1 on error. The id
 * comes from the 64 bit counter tx_wr_id but is returned as int, so ids
 * that do not fit into an int are not returned faithfully. The counter is
 * advanced even when posting fails.
 */
int post_send_lmr(
		IN	const struct iw_lmr			*src,
		IN	uint32_t					 src_off,
		IN	uint32_t					 len,
		IN	enum ibv_send_flags			 flags,
		IN	const struct iw_ctx_conn	*ctx_conn)
{
	int					ret;
	struct ibv_send_wr	send_wr, *bad_wr = NULL;
	struct ibv_sge		sge;

	/* sanitize input */
	if (ctx_unconnected(ctx_conn) || !ctx_rts(ctx_conn)) {
		return -1;
	}

	/* start from an all-zero wr: the members that are not set below (the
	 * immediate data and the opcode specific union) must not carry stack
	 * garbage into the provider */
	memset(&send_wr, 0, sizeof(send_wr));

	/* create send work request */
	send_wr.wr_id = tx_wr_id++;
	send_wr.next = NULL;
	if (!src) {
		send_wr.sg_list = NULL;
		send_wr.num_sge = 0;
		dprint(DBG_OPS, LOG_WARNING, "src is NULL; doing zero-length send");
	} else {
		/* a single sge covering [buf + src_off, buf + src_off + len) */
		sge.addr = (uint64_t)src->buf + src_off;
		sge.length = len;
		sge.lkey = src->mr->lkey;
		send_wr.sg_list = &sge;
		send_wr.num_sge = 1;
	}
	send_wr.opcode = IBV_WR_SEND;
	send_wr.send_flags = flags;

	/* post send work request */
	ret = ibv_post_send(ctx_conn->qp, &send_wr, &bad_wr);
	if (ret) {
		dprint(DBG_ON, LOG_ERROR, "failed to post rdma send wr");
		return -1;
	}

	dprint(DBG_OPS, LOG_INFO, "rdma send posted successfully to sq");

	return send_wr.wr_id;

}


/*
 * Post one RDMA WRITE work request that copies the data described by a
 * local scatter/gather list into a remote memory region.
 *
 * src       local scatter/gather list holding the data.
 * dst       remote memory region (address and rkey) to write to.
 * dst_off   offset added to dst->addr.
 * flags     send flags that are passed through unchanged to the wr.
 * ctx_conn  connection context; must be valid and in IWARP_CONNECTED state.
 *
 * If src or dst is NULL a zero-length write with remote address 0 and
 * rkey 0 is posted instead (a warning is logged).
 *
 * Return the wr_id assigned to the work request, or -1 on error. The id
 * comes from the 64 bit counter tx_wr_id but is returned as int, so ids
 * that do not fit into an int are not returned faithfully. The counter is
 * advanced even when posting fails.
 */
int post_write_sgl(
		IN	const struct iw_sgl			*src,
		IN	const struct iw_rmr			*dst,
		IN	uint32_t					 dst_off,
		IN	enum ibv_send_flags			 flags,
		IN	const struct iw_ctx_conn	*ctx_conn)
{
	int					ret;
	struct ibv_send_wr	write_wr, *bad_wr = NULL;

	/* sanitize input */
	if (ctx_unconnected(ctx_conn) || !ctx_rts(ctx_conn)) {
		return -1;
	}

	/* start from an all-zero wr: the members that are not set below (e.g.
	 * the immediate data) must not carry stack garbage into the provider */
	memset(&write_wr, 0, sizeof(write_wr));

	/* create write work request */
	write_wr.wr_id = tx_wr_id++;
	write_wr.next = NULL;
	if (!src || !dst) {
		write_wr.sg_list = NULL;
		write_wr.num_sge = 0;
		write_wr.wr.rdma.remote_addr = 0;
		write_wr.wr.rdma.rkey = 0;
		dprint(DBG_OPS, LOG_WARNING, "src or dst are NULL; doing zero-length"
				" write");
	} else {
		write_wr.sg_list = src->sg_list;
		write_wr.num_sge = src->num_sge;
		write_wr.wr.rdma.remote_addr = dst->addr + dst_off;
		write_wr.wr.rdma.rkey = dst->rkey;
	}
	write_wr.opcode = IBV_WR_RDMA_WRITE;
	write_wr.send_flags = flags;

	/* post write work request */
	ret = ibv_post_send(ctx_conn->qp, &write_wr, &bad_wr);
	if (ret) {
		dprint(DBG_ON, LOG_ERROR, "failed to post rdma write wr");
		return -1;
	}

	dprint(DBG_OPS, LOG_INFO, "rdma write successfully posted to sq");

	return write_wr.wr_id;

}


/*
 * Post one RDMA WRITE work request that copies a window of a local memory
 * region into a remote memory region.
 *
 * src       local memory region holding the data.
 * src_off   offset of the source window from src->buf.
 * dst       remote memory region (address and rkey) to write to.
 * dst_off   offset added to dst->addr.
 * len       number of bytes to write. The offsets and len are not checked
 *           against the size of either region here.
 * flags     send flags that are passed through unchanged to the wr.
 * ctx_conn  connection context; must be valid and in IWARP_CONNECTED state.
 *
 * If src or dst is NULL a zero-length write with remote address 0 and
 * rkey 0 is posted instead (a warning is logged; offsets and len are
 * ignored).
 *
 * Return the wr_id assigned to the work request, or -1 on error. The id
 * comes from the 64 bit counter tx_wr_id but is returned as int, so ids
 * that do not fit into an int are not returned faithfully. The counter is
 * advanced even when posting fails.
 */
int post_write_lmr(
		IN	const struct iw_lmr			*src,
		IN	uint32_t					 src_off,
		IN	const struct iw_rmr			*dst,
		IN	uint32_t					 dst_off,
		IN	uint32_t					 len,
		IN	enum ibv_send_flags			 flags,
		IN	const struct iw_ctx_conn	*ctx_conn)
{
	int					ret;
	struct ibv_send_wr	write_wr, *bad_wr = NULL;
	struct ibv_sge		sge;

	/* sanitize input */
	if (ctx_unconnected(ctx_conn) || !ctx_rts(ctx_conn)) {
		return -1;
	}

	/* start from an all-zero wr: the members that are not set below (e.g.
	 * the immediate data) must not carry stack garbage into the provider */
	memset(&write_wr, 0, sizeof(write_wr));

	/* create write work request */
	write_wr.wr_id = tx_wr_id++;
	write_wr.next = NULL;
	if (!src || !dst) {
		write_wr.sg_list = NULL;
		write_wr.num_sge = 0;
		write_wr.wr.rdma.remote_addr = 0;
		write_wr.wr.rdma.rkey = 0;
		dprint(DBG_OPS, LOG_WARNING, "src or dst are NULL; doing zero-length"
				" write");
	} else {
		/* a single sge covering [buf + src_off, buf + src_off + len) */
		sge.addr = (uint64_t)src->buf + src_off;
		sge.length = len;
		sge.lkey = src->mr->lkey;
		write_wr.sg_list = &sge;
		write_wr.num_sge = 1;
		write_wr.wr.rdma.remote_addr = dst->addr + dst_off;
		write_wr.wr.rdma.rkey = dst->rkey;
	}
	write_wr.opcode = IBV_WR_RDMA_WRITE;
	write_wr.send_flags = flags;

	/* post write work request */
	ret = ibv_post_send(ctx_conn->qp, &write_wr, &bad_wr);
	if (ret) {
		dprint(DBG_ON, LOG_ERROR, "failed to post rdma write wr");
		return -1;
	}

	dprint(DBG_OPS, LOG_INFO, "rdma write successfully posted to sq");

	return write_wr.wr_id;

}


/*
 * Post one RDMA READ work request that fetches data from a remote memory
 * region into the buffer described by a local scatter/gather list.
 *
 * dst       local scatter/gather list receiving the data; it may hold at
 *           most one sge. Its access rights are not checked here.
 * src       remote memory region (address and rkey) to read from.
 * src_off   offset added to src->addr.
 * flags     send flags for the wr. IBV_SEND_SIGNALED is always added; a
 *           warning is logged when the caller did not ask for it.
 * ctx_conn  connection context; must be valid and in IWARP_CONNECTED state.
 *
 * If dst or src is NULL a zero-length read with remote address 0 and
 * rkey 0 is posted instead (a warning is logged).
 *
 * Return the wr_id assigned to the work request, or -1 on error. The id
 * comes from the 64 bit counter tx_wr_id but is returned as int, so ids
 * that do not fit into an int are not returned faithfully. The counter is
 * advanced even when posting fails.
 */
int post_read_sgl(
		IN	const struct iw_sgl			*dst,
		IN	const struct iw_rmr			*src,
		IN	uint32_t					 src_off,
		IN	enum ibv_send_flags			 flags,
		IN	const struct iw_ctx_conn	*ctx_conn)
{
	int					ret;
	struct ibv_send_wr	read_wr, *bad_wr = NULL;

	/* sanitize input */
	if (ctx_unconnected(ctx_conn) || !ctx_rts(ctx_conn)) {
		return -1;
	}
	if (dst && dst->num_sge > 1) {
		dprint(DBG_ON, LOG_ERROR, "dst cannot have more than 1 sge");
		return -1;
	}

	/* start from an all-zero wr: the members that are not set below (e.g.
	 * the immediate data) must not carry stack garbage into the provider */
	memset(&read_wr, 0, sizeof(read_wr));

	/* create read work request */
	read_wr.wr_id = tx_wr_id++;
	read_wr.next = NULL;
	if (!dst || !src) {
		read_wr.sg_list = NULL;
		read_wr.num_sge = 0;
		read_wr.wr.rdma.remote_addr = 0;
		read_wr.wr.rdma.rkey = 0;
		dprint(DBG_OPS, LOG_WARNING, "src or dst are NULL; doing zero-length"
				" read");
	} else {
		read_wr.sg_list = dst->sg_list;
		read_wr.num_sge = dst->num_sge;
		read_wr.wr.rdma.remote_addr = (uint64_t)src->addr + src_off;
		read_wr.wr.rdma.rkey = src->rkey;
	}
	read_wr.opcode = IBV_WR_RDMA_READ;
	/* test the SIGNALED bit itself rather than "no flags at all": a caller
	 * passing only other flags is asking for an unsignaled read as well */
	if (!(flags & IBV_SEND_SIGNALED)) {
		dprint(DBG_ON, LOG_WARNING, "unsignaled read unsupported; doing"
				" signaled read instead");
	}
	read_wr.send_flags = flags | IBV_SEND_SIGNALED;

	/* post read work request */
	ret = ibv_post_send(ctx_conn->qp, &read_wr, &bad_wr);
	if (ret) {
		dprint(DBG_ON, LOG_ERROR, "failed to post rdma read wr");
		return -1;
	}

	dprint(DBG_OPS, LOG_INFO, "rdma read successfuly posted to sq");

	return read_wr.wr_id;

}


/*
 * Post one RDMA READ work request that fetches data from a remote memory
 * region into a window of a local memory region.
 *
 * dst       local memory region receiving the data; mandatory. It must have
 *           both IBV_ACCESS_LOCAL_WRITE and IBV_ACCESS_REMOTE_WRITE set.
 * dst_off   offset of the destination window from dst->buf.
 * src       remote memory region (address and rkey) to read from;
 *           mandatory.
 * src_off   offset added to src->addr.
 * len       number of bytes to read. The offsets and len are not checked
 *           against the size of either region here.
 * flags     send flags for the wr. IBV_SEND_SIGNALED is always added; a
 *           warning is logged when the caller did not ask for it.
 * ctx_conn  connection context; must be valid and in IWARP_CONNECTED state.
 *
 * Unlike post_read_sgl() this function refuses NULL buffers instead of
 * falling back to a zero-length read.
 *
 * Return the wr_id assigned to the work request, or -1 on error. The id
 * comes from the 64 bit counter tx_wr_id but is returned as int, so ids
 * that do not fit into an int are not returned faithfully. The counter is
 * advanced even when posting fails.
 */
int post_read_lmr(
		IN	const struct iw_lmr			*dst,
		IN	uint32_t					 dst_off,
		IN	const struct iw_rmr			*src,
		IN	uint32_t					 src_off,
		IN	uint32_t					 len,
		IN	enum ibv_send_flags			 flags,
		IN	const struct iw_ctx_conn	*ctx_conn)
{
	int					ret;
	struct ibv_send_wr	read_wr, *bad_wr = NULL;
	struct ibv_sge		sge;

	/* sanitize input */
	if (ctx_unconnected(ctx_conn) || !ctx_rts(ctx_conn)) {
		return -1;
	}
	if (!dst || !src) {
		dprint(DBG_ON, LOG_ERROR, "read operation MUST have a dst and a src"
				" buffer");
		return -1;
	}
	/* dst is known to be non-NULL from here on */
	if (!(dst->access & IBV_ACCESS_LOCAL_WRITE) ||
			!(dst->access & IBV_ACCESS_REMOTE_WRITE)) {
		dprint(DBG_ON, LOG_ERROR, "rdma read destination memory region must"
				" have at least IBV_ACCESS_LOCAL_WRITE and"
				" IBV_ACCESS_REMOTE_WRITE permissions");
		return -1;
	}

	/* start from an all-zero wr: the members that are not set below (e.g.
	 * the immediate data) must not carry stack garbage into the provider */
	memset(&read_wr, 0, sizeof(read_wr));

	/* create read work request */
	read_wr.wr_id = tx_wr_id++;
	read_wr.next = NULL;
	/* a single sge covering [buf + dst_off, buf + dst_off + len) */
	sge.addr = (uint64_t)dst->buf + dst_off;
	sge.length = len;
	sge.lkey = dst->mr->lkey;
	read_wr.sg_list = &sge;
	read_wr.num_sge = 1;
	read_wr.wr.rdma.remote_addr = (uint64_t)src->addr + src_off;
	read_wr.wr.rdma.rkey = src->rkey;
	read_wr.opcode = IBV_WR_RDMA_READ;
	/* test the SIGNALED bit itself rather than "no flags at all": a caller
	 * passing only other flags is asking for an unsignaled read as well */
	if (!(flags & IBV_SEND_SIGNALED)) {
		dprint(DBG_ON, LOG_WARNING, "unsignaled read unsupported");
	}
	read_wr.send_flags = flags | IBV_SEND_SIGNALED;

	/* post read work request */
	ret = ibv_post_send(ctx_conn->qp, &read_wr, &bad_wr);
	if (ret) {
		dprint(DBG_ON, LOG_ERROR, "failed to post rdma read wr");
		return -1;
	}

	dprint(DBG_OPS, LOG_INFO, "rdma read successfuly posted to sq");

	return read_wr.wr_id;

}
