/*-------------------------------------------------------------------------
 *
 * pg_stat_statements.c
 *
 *-------------------------------------------------------------------------
 */
#include "postgres.h"

#include "access/hash.h"
#include "catalog/pg_type.h"
#include "executor/executor.h"
#include "funcapi.h"
#include "mb/pg_wchar.h"
#include "miscadmin.h"
#include "optimizer/planner.h"
#include "pgstat.h"
#include "portability/instr_time.h"
#include "storage/spin.h"
#include "tcop/tcopprot.h"
#include "utils/builtins.h"
#include "utils/hsearch.h"
#include "utils/memutils.h"

PG_MODULE_MAGIC;

/* Module load/unload callbacks. */
void	_PG_init(void);
void	_PG_fini(void);

/* SQL-callable functions (version-1 calling convention). */
Datum	pg_stat_statements_reset(PG_FUNCTION_ARGS);
Datum	pg_stat_statements(PG_FUNCTION_ARGS);

PG_FUNCTION_INFO_V1(pg_stat_statements_reset);
PG_FUNCTION_INFO_V1(pg_stat_statements);

/* Hook implementations installed by _PG_init(). */
static PlannedStmt *my_planner(Query *parse, int cursorOptions,
							   ParamListInfo boundParams);
static TupleTableSlot *my_executor(QueryDesc *queryDesc,
								   ScanDirection direction, long count);

/*
 * Hook values that were in place before this module was loaded; saved by
 * _PG_init() and put back by _PG_fini().
 */
static planner_hook_type	prev_planner_hook = NULL;
static executor_hook_type	prev_executor_hook = NULL;

/*
 * Shared-memory header of the module: the LWLock that guards the statement
 * hash table, followed by a bump-allocated arena (see stat_stmt_alloc) that
 * stores the query texts.  buffer[] is a variable-length trailing array and
 * must remain the last field.
 */
typedef struct StatStmtData
{
	LWLockId	lock;		/* guards pgStatStmtHash and the arena */
	Size		buf_size;	/* usable bytes in buffer[] */
	Size		buf_used;	/* bytes of buffer[] handed out so far */
	char		buffer[1];	/* VARIABLE LENGTH ARRAY - must be last */
} StatStmtData;

/*
 * One hash-table entry per statement tag.
 *
 * The hash key is sizeof(uint32) bytes and dynahash stores the key at the
 * start of each entry, so "tag" must stay the first field.
 */
typedef struct StatStmtEntry
{
	uint32			tag;		/* hash key, computed by make_stmt_tag() */
	Oid				userid;		/* user OID, also mixed into the tag */
	Oid				datid;		/* database OID, also mixed into the tag */
	char		   *query;		/* (possibly truncated) text in the arena, or
								 * a fixed "(out of memory)" string */
	PgStat_Counter	planned;	/* times the statement was planned */
	PgStat_Counter	calls;		/* times my_executor() ran for it */
	instr_time		total_time;	/* accumulated time spent in my_executor() */
	slock_t			mutex;		/* per-entry spinlock for counter updates */
} StatStmtEntry;

/*
 * Sizing settings, filled in by stat_stmt_memsize(); max_statements == 0
 * means they have not been read yet.  statement_buffer is in bytes.
 */
static long				max_statements = 0;
static Size				statement_buffer = 0;

/* Process-local pointers to the shared state, set by stat_stmt_init(). */
static StatStmtData	   *pgStatStmt;
static HTAB			   *pgStatStmtHash;

#define DEFAULT_MAX_STATEMENTS		512
#define DEFAULT_STATEMENT_BUFFER	1024	/* in kB */
/* Longest query text stored, in bytes.  NB: the macro name is misspelled. */
#define MAX_STATEMENT_LEGNTH		2000	/* in byte */

static Size	stat_stmt_memsize(void);

/*
 * Carve "size" bytes out of the shared query-text arena.
 *
 * This is a simple bump allocator: no alignment is applied and memory is only
 * reclaimed wholesale by pg_stat_statements_reset().  It updates buf_used
 * without any locking of its own; callers in this file hold pgStatStmt->lock
 * exclusively while calling it.
 *
 * Returns NULL when the arena does not have "size" bytes left.
 */
static void *
stat_stmt_alloc(Size size)
{
	Size		new_used = pgStatStmt->buf_used + size;
	char	   *chunk;

	/* refuse requests that would run past the end of the arena */
	if (new_used > pgStatStmt->buf_size)
		return NULL;

	chunk = pgStatStmt->buffer + pgStatStmt->buf_used;
	pgStatStmt->buf_used = new_used;

	return chunk;
}

void
_PG_init(void)
{
	RequestAddinShmemSpace(stat_stmt_memsize());
	RequestAddinLWLocks(1);

	prev_planner_hook = planner_hook;
	planner_hook = my_planner;

	prev_executor_hook = executor_hook;
	executor_hook = my_executor;
}

/*
 * Module unload callback: put back the hooks saved by _PG_init().
 */
void
_PG_fini(void)
{
	executor_hook = prev_executor_hook;
	planner_hook = prev_planner_hook;
}

/*
 * Compute the 32-bit tag identifying a (user, database, query text) triple.
 *
 * The tag is used directly as the hash-table key, so two triples that produce
 * the same tag share one statistics entry.  The component hashes are
 * therefore combined order-sensitively: the running value is rotated left by
 * one bit before each XOR.  A plain XOR would make (userid, datid) and
 * (datid, userid) collide, and would cancel both OID hashes whenever
 * userid == datid.
 */
static uint32
make_stmt_tag(Oid userid, Oid datid, const char *query)
{
	uint32		tag;

	tag = DatumGetUInt32(hash_any((const unsigned char *) query,
								  strlen(query)));
	tag = ((tag << 1) | (tag >> 31)) ^ oid_hash(&datid, sizeof(Oid));
	tag = ((tag << 1) | (tag >> 31)) ^ oid_hash(&userid, sizeof(Oid));

	return tag;
}

static Size
stat_stmt_memsize(void)
{
	if (max_statements == 0)
	{
		const char *option;

		option = GetConfigOption("statspack.max_statements");
		if (option)
			max_statements = pg_atoi((char *)option, sizeof(int32), 1);
		else
			max_statements = DEFAULT_MAX_STATEMENTS;

		option = GetConfigOption("statspack.statement_buffer");
		if (option)
			statement_buffer = pg_atoi((char *)option, sizeof(int32), 1) * 1024L;
		else
			statement_buffer = DEFAULT_STATEMENT_BUFFER * 1024L;
	}

	return sizeof(StatStmtData) + statement_buffer +
		hash_estimate_size(max_statements, sizeof(StatStmtEntry));
}

/*
 * Attach this backend to the module's shared state, creating and
 * initializing it if this is the first process to get here.
 *
 * Sets pgStatStmt and pgStatStmtHash.  Raises ERROR if the shared structures
 * cannot be obtained.
 */
static void
stat_stmt_init(void)
{
	StatStmtData   *ptr;
	bool			found;
	HASHCTL			info = { 0 };

	elog(DEBUG1, "pg_stat_statements: init");

	/* make sure max_statements and statement_buffer are loaded here */
	(void) stat_stmt_memsize();

	LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);

	/*
	 * This struct holds only the header and the query-text arena.  The hash
	 * table is allocated separately by ShmemInitHash() below, so its
	 * estimated size must not be requested here as well.  Otherwise the
	 * module would consume more shared memory than _PG_init() reserved.
	 */
	ptr = ShmemInitStruct("Statement stat",
						  offsetof(StatStmtData, buffer) + statement_buffer,
						  &found);
	if (!ptr)
		elog(ERROR, "out of shared memory");
	if (!found)
	{
		/* first process to attach: initialize the header */
		ptr->lock = LWLockAssign();
		ptr->buf_size = statement_buffer;
		ptr->buf_used = 0;
	}

	/* keys are uint32 tags, which are the same width as an Oid */
	info.keysize = sizeof(uint32);
	info.entrysize = sizeof(StatStmtEntry);
	info.hash = oid_hash;
	pgStatStmtHash = ShmemInitHash("Statement stat entries",
								   max_statements, max_statements,
								   &info,
								   HASH_ELEM | HASH_FUNCTION);
	if (!pgStatStmtHash)
		elog(ERROR, "out of shared memory");

	LWLockRelease(AddinShmemInitLock);

	pgStatStmt = ptr;
}

/*
 * Record that the statement identified by "id" was planned.
 *
 * If no entry exists yet, one is created for (userid, datid, query), and up
 * to MAX_STATEMENT_LEGNTH bytes of the query text are copied into the shared
 * arena.  If the hash table is full, the call is silently ignored.
 */
static void
stat_stmt_planned(uint32 id, Oid userid, Oid datid, const char *query)
{
	volatile StatStmtEntry *stat;

	/* Attach to shared memory on first use in this backend */
	if (!pgStatStmt)
		stat_stmt_init();

	/* Fast path: look the entry up under a shared lock */
	LWLockAcquire(pgStatStmt->lock, LW_SHARED);

	stat = hash_search(pgStatStmtHash, &id, HASH_FIND, NULL);
	if (!stat)
	{
		bool		found;

		/*
		 * New entry.  Re-acquire the lock in exclusive mode.  Another backend
		 * may insert the same entry in the gap; the "found" check handles it.
		 */
		LWLockRelease(pgStatStmt->lock);
		LWLockAcquire(pgStatStmt->lock, LW_EXCLUSIVE);

		stat = hash_search(pgStatStmtHash, &id, HASH_ENTER_NULL, &found);
		if (stat == NULL)
		{
			/* out of memory: hash table is full, skip this statement */
			LWLockRelease(pgStatStmt->lock);
			return;
		}

		if (!found)
		{
			Size	len;

			/* the tag field was filled in by hash_search as the key */
			stat->userid = userid;
			stat->datid = datid;

			/*
			 * Truncate long texts on a character boundary, so a multibyte
			 * character is never split into invalidly encoded text.
			 */
			len = strlen(query);
			if (len > MAX_STATEMENT_LEGNTH)
				len = pg_mbcliplen(query, (int) len, MAX_STATEMENT_LEGNTH);
			stat->query = stat_stmt_alloc(len + 1);
			if (stat->query)
			{
				memcpy(stat->query, query, len);
				stat->query[len] = '\0';
			}
			else
			{
				/*
				 * NOTE(review): this stores a pointer to a string literal in
				 * this process's image; it is only valid in other backends
				 * if they map the executable at the same address.
				 */
				stat->query = "(out of memory)";
			}

			stat->planned = 0;
			stat->calls = 0;
			INSTR_TIME_SET_ZERO(stat->total_time);
			SpinLockInit(&stat->mutex);
		}
	}

	/*
	 * On the fast path we hold only a shared lock, so other backends may
	 * update this entry concurrently.  Guard the counter with the per-entry
	 * spinlock, as stat_stmt_executed() does for its counters.
	 */
	SpinLockAcquire(&stat->mutex);
	stat->planned += 1;
	SpinLockRelease(&stat->mutex);

	LWLockRelease(pgStatStmt->lock);
}

/*
 * Add one execution taking "duration" to the entry tagged "id".
 *
 * Only a shared LWLock is taken; the counters themselves are updated under
 * the entry's spinlock.  Nothing is recorded if this backend has not attached
 * to shared memory yet, or if no entry with that tag exists.
 *
 * Returns true if an entry was updated, false otherwise.
 */
static bool
stat_stmt_executed(uint32 id, instr_time duration)
{
	volatile StatStmtEntry *entry;
	bool		found;

	if (pgStatStmt == NULL)
		return false;

	LWLockAcquire(pgStatStmt->lock, LW_SHARED);

	entry = hash_search(pgStatStmtHash, &id, HASH_FIND, NULL);
	found = (entry != NULL);
	if (found)
	{
		SpinLockAcquire(&entry->mutex);
		entry->calls += 1;
		INSTR_TIME_ADD(entry->total_time, duration);
		SpinLockRelease(&entry->mutex);
	}

	LWLockRelease(pgStatStmt->lock);

	return found;
}

/*
 * Planner hook: plan the query, then tag the plan and count the planning.
 *
 * Planning is delegated to any planner hook that was installed before this
 * module, or to standard_planner() if there was none, so other extensions'
 * hooks keep working.  The tag is derived from debug_query_string (presumably
 * the whole client-supplied query text, which may hold several statements)
 * together with the current user and database.  Nothing is tagged when
 * debug_query_string is NULL.
 */
static PlannedStmt *
my_planner(Query *parse, int cursorOptions, ParamListInfo boundParams)
{
	PlannedStmt	   *result;
	const char	   *query = debug_query_string;

	if (prev_planner_hook)
		result = prev_planner_hook(parse, cursorOptions, boundParams);
	else
		result = standard_planner(parse, cursorOptions, boundParams);

	if (query)
	{
		Oid	userid = GetUserId();

		result->tag = make_stmt_tag(userid, MyDatabaseId, query);
		stat_stmt_planned(result->tag, userid, MyDatabaseId, query);
	}

	return result;
}

/*
 * Executor hook: run the plan and add the elapsed time to its statistics.
 *
 * Execution is delegated to any executor hook that was installed before this
 * module, or to ExecutePlan() if there was none.  The elapsed wall-clock time
 * is then charged to the entry tagged queryDesc->plannedstmt->tag.  If the
 * executor raises an error, nothing is recorded.
 */
static TupleTableSlot *
my_executor(QueryDesc *queryDesc, ScanDirection direction, long count)
{
	TupleTableSlot	   *result;
	instr_time			starttime;
	instr_time			duration;

	INSTR_TIME_SET_CURRENT(starttime);

	if (prev_executor_hook)
		result = prev_executor_hook(queryDesc, direction, count);
	else
		result = ExecutePlan(queryDesc->estate,
							 queryDesc->planstate,
							 queryDesc->operation,
							 count,
							 direction,
							 queryDesc->dest);

	INSTR_TIME_SET_CURRENT(duration);
	INSTR_TIME_SUBTRACT(duration, starttime);

	stat_stmt_executed(queryDesc->plannedstmt->tag, duration);

	return result;
}

/*
 * SQL-callable: discard all collected statement statistics.
 *
 * Removes every hash-table entry and empties the query-text arena while
 * holding the module lock exclusively.  Always returns true.
 */
Datum
pg_stat_statements_reset(PG_FUNCTION_ARGS)
{
	/*
	 * Attach to shared memory first, as pg_stat_statements() does.  Otherwise
	 * a backend that has not attached yet would silently skip the reset
	 * even though other backends have collected statistics.
	 */
	if (!pgStatStmt)
		stat_stmt_init();

	if (pgStatStmt)
	{
		HASH_SEQ_STATUS		hash_seq;
		StatStmtEntry	   *stat;

		LWLockAcquire(pgStatStmt->lock, LW_EXCLUSIVE);

		/* deleting the element just returned by the scan is allowed */
		hash_seq_init(&hash_seq, pgStatStmtHash);
		while ((stat = hash_seq_search(&hash_seq)) != NULL)
			hash_search(pgStatStmtHash, &stat->tag, HASH_REMOVE, NULL);

		/* no entry references the arena any more; recycle all of it */
		pgStatStmt->buf_used = 0;

		LWLockRelease(pgStatStmt->lock);
	}

	PG_RETURN_BOOL(true);
}

Datum
pg_stat_statements(PG_FUNCTION_ARGS)
{
	ReturnSetInfo	   *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
	TupleDesc			tupdesc;
	Tuplestorestate    *tupstore;
	MemoryContext		per_query_ctx;
	MemoryContext		oldcontext;

	/* check to see if caller supports us returning a tuplestore */
	if (rsinfo == NULL || !IsA(rsinfo, ReturnSetInfo))
		ereport(ERROR,
				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
				 errmsg("set-valued function called in context that cannot accept a set")));
	if (!(rsinfo->allowedModes & SFRM_Materialize))
		ereport(ERROR,
				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
				 errmsg("materialize mode required, but it is not " \
						"allowed in this context")));

	/* Initialize shared data */
	if (!pgStatStmt)
		stat_stmt_init();

	per_query_ctx = rsinfo->econtext->ecxt_per_query_memory;
	oldcontext = MemoryContextSwitchTo(per_query_ctx);

	tupdesc = CreateTupleDescCopy(rsinfo->expectedDesc);
	tupstore = tuplestore_begin_heap(true, false, work_mem);
	rsinfo->returnMode = SFRM_Materialize;
	rsinfo->setResult = tupstore;
	rsinfo->setDesc = tupdesc;

	if (pgStatStmt)
	{
		HASH_SEQ_STATUS		hash_seq;
		StatStmtEntry	   *stat;

		LWLockAcquire(pgStatStmt->lock, LW_SHARED);

		hash_seq_init(&hash_seq, pgStatStmtHash);
		while ((stat = hash_seq_search(&hash_seq)) != NULL)
		{
			Datum		values[6];
			bool		nulls[6] = { 0 };
			int64		usec;
			int			i = 0;

			/* generate junk in short-term context */
			MemoryContextSwitchTo(oldcontext);

			values[i++] = ObjectIdGetDatum(stat->userid);
			values[i++] = ObjectIdGetDatum(stat->datid);

			if (stat->query == NULL)
				nulls[i++] = true;
			else
				values[i++] = CStringGetTextDatum(stat->query);

			/* XXX: Is spinlock needed here? */
			values[i++] = Int64GetDatumFast(stat->planned);
			values[i++] = Int64GetDatumFast(stat->calls);
			usec = INSTR_TIME_GET_MICROSEC(stat->total_time);
			values[i++] = Int64GetDatumFast(usec);

			/* switch to appropriate context while storing the tuple */
			MemoryContextSwitchTo(per_query_ctx);
			tuplestore_putvalues(tupstore, tupdesc, values, nulls);
		}

		LWLockRelease(pgStatStmt->lock);
	}

	/* clean up and return the tuplestore */
	tuplestore_donestoring(tupstore);

	MemoryContextSwitchTo(oldcontext);

	return (Datum) 0;
}
