/*
 * Wrapper functions around the PG core aggregate support functions that allow
 * definition of custom aggregates.
 *
 * We need a custom aggregate on the shard node that performs no
 * "finalization". Since the core functions have the "internal" type in the
 * catalog, they cannot be used in a CREATE AGGREGATE command. Therefore we
 * need to define a custom transition (state) type as well as functions that
 * work with it.
 */
#include "postgres.h"
#include "fmgr.h"
#include "nodes/nodes.h"
#include "nodes/execnodes.h"

/* PostgreSQL module magic block (macro provided by fmgr.h). */
PG_MODULE_MAGIC;

/*
 * Transition state of the partial aggregates defined in this file: the
 * in-core aggregation state plus cached FmgrInfo of the core functions that
 * operate on it.
 */
typedef struct PartialAggState
{
	/*
	 * The actual, in-core aggregation state, as returned by fn_get_state.
	 *
	 * TODO Use union as soon as additional types are needed.
	 */
	Datum		state;

	/*
	 * It makes sense to cache FmgrInfo of the frequently called functions.
	 *
	 * Core transition function that creates and then advances "state" (e.g.
	 * numeric_accum).
	 */
	FmgrInfo	fn_get_state;

	/* Function to serialize the core state (e.g. numeric_avg_serialize). */
	FmgrInfo	fn_serialize;

	/*
	 * Function to convert the output of fn_serialize to cstring. It stays
	 * unset (fn_oid == InvalidOid) until partial_state_out() looks it up.
	 */
	FmgrInfo	fn_out;
} PartialAggState;

/* SQL-callable functions (each has a PG_FUNCTION_INFO_V1 record below). */
extern Datum partial_state_in(PG_FUNCTION_ARGS);
extern Datum partial_state_out(PG_FUNCTION_ARGS);
extern Datum partial_state_to_bytea(PG_FUNCTION_ARGS);
extern Datum numeric_accum_partial(PG_FUNCTION_ARGS);
//extern Datum int4_accum_partial(PG_FUNCTION_ARGS);

/* Internal helpers shared by the functions above. */
static PartialAggState * accum_partial_common_init(PG_FUNCTION_ARGS,
												   Oid fn_get_state);
static Datum partial_state_out_common(PartialAggState *state,
	FunctionCallInfo fcinfo);

PG_FUNCTION_INFO_V1(partial_state_in);
Datum
partial_state_in(PG_FUNCTION_ARGS)
{
	elog(ERROR, "partial_state_in() not implemented");
}

/*
 * Text output function for the custom state type.
 *
 * Serializes the core state via partial_state_out_common() (fn_serialize)
 * and renders the resulting bytea as a cstring using byteaout (OID 31).
 * The byteaout FmgrInfo is cached in state->fn_out on first use.
 *
 * Returns NULL for a NULL state.
 */
PG_FUNCTION_INFO_V1(partial_state_out);
Datum
partial_state_out(PG_FUNCTION_ARGS)
{
	PartialAggState *state;
	Datum		result_bytea;
	Datum		result;

	if (PG_ARGISNULL(0))
		PG_RETURN_NULL();

	state = (PartialAggState *) PG_GETARG_POINTER(0);
	result_bytea = partial_state_out_common(state, fcinfo);

	if (!OidIsValid(state->fn_out.fn_oid))
	{
		/* 31 ~ byteaout */
		fmgr_info(31, &state->fn_out);
	}
	result = FunctionCall1(&state->fn_out, result_bytea);

	/* FunctionCall1() already yields a Datum; return it as such. */
	PG_RETURN_DATUM(result);
}

/*
 * Cast from the custom state type to bytea.
 *
 * It ensures that the output of partial aggregation is treated as bytea
 * rather than as "state"; otherwise partial_state_out() would be invoked by
 * the node performing the final aggregation. The state is serialized with
 * partial_state_out_common() and a private copy of the resulting bytea is
 * returned. A NULL state yields NULL.
 *
 * TODO Consider making partial_state_out() generic enough so that it can be
 * called there too.
 */
PG_FUNCTION_INFO_V1(partial_state_to_bytea);
Datum
partial_state_to_bytea(PG_FUNCTION_ARGS)
{
	if (!PG_ARGISNULL(0))
	{
		PartialAggState *pstate = (PartialAggState *) PG_GETARG_POINTER(0);
		Datum		serialized = partial_state_out_common(pstate, fcinfo);

		PG_RETURN_BYTEA_P(DatumGetByteaPCopy(serialized));
	}

	PG_RETURN_NULL();
}

PG_FUNCTION_INFO_V1(numeric_accum_partial);
Datum
numeric_accum_partial(PG_FUNCTION_ARGS)
{
	PartialAggState *state;

	if (PG_ARGISNULL(0))
	{
		/* 1833 ~ numeric_accum */
		state = accum_partial_common_init(fcinfo, 1833);

		/* numeric_avg_serialize */
		fmgr_info(2740, &state->fn_serialize);
	}
	else
	{
		state = (PartialAggState *) PG_GETARG_POINTER(0);

		Assert(OidIsValid(state->fn_get_state.fn_oid));
		Assert(state->state != (Datum) 0);

		if (!PG_ARGISNULL(1))
			state->state = FunctionCall2(&state->fn_get_state,
										 state->state, PG_GETARG_DATUM(1));
	}

	PG_RETURN_POINTER(state);
}

/* PG_FUNCTION_INFO_V1(int4_accum_partial); */
/* Datum */
/* int4_accum_partial(PG_FUNCTION_ARGS) */
/* { */
/* 	PartialAggState *state; */

/* 	if (PG_ARGISNULL(0)) */
/* 	{ */
/* 		/\* 1835 ~ int4_accum *\/ */
/* 		state = accum_partial_common_init(fcinfo, 1835); */

/* 		/\* array_out *\/ */
/* 		fmgr_info(2297, &state->fn_serialize); */
/* 	} */
/* 	else */
/* 	{ */
/* 		state = (PartialAggState *) PG_GETARG_POINTER(0); */

/* 		Assert(OidIsValid(state->fn_get_state.fn_oid)); */
/* 		Assert(state->state != (Datum) 0); */

/* 		if (!PG_ARGISNULL(1)) */
/* 			state->state = FunctionCall2(&state->fn_get_state, */
/* 										 state->state, PG_GETARG_DATUM(1)); */
/* 	} */

/* 	PG_RETURN_POINTER(state); */
/* } */

/*
 * Initialize the transition state on the first call of an aggregate.
 *
 * Allocates a PartialAggState in the aggregate memory context and looks up
 * fn_get_state (the core transition function) into it. It then calls that
 * function with a NULL state and the caller's first input value, and stores
 * the returned core state in state->state.
 *
 * fcinfo:       call info of the SQL-callable transition function; must be
 *               an aggregate call (ERROR otherwise). Argument 1 is the
 *               input value and may be NULL.
 * fn_get_state: OID of the core transition function.
 *
 * Raises ERROR if the core function returns NULL.
 */
static PartialAggState *
accum_partial_common_init(PG_FUNCTION_ARGS, Oid fn_get_state)
{
	FunctionCallInfoData fcinfo_core;
	MemoryContext agg_context;
	MemoryContext old_context;
	PartialAggState *state;
	bool		value_isnull = PG_ARGISNULL(1);

	if (!AggCheckCallContext(fcinfo, &agg_context))
		elog(ERROR, "aggregate function called in non-aggregate context");

	/* The wrapper must survive across calls, so it lives in agg_context. */
	old_context = MemoryContextSwitchTo(agg_context);
	state = (PartialAggState *) palloc0(sizeof(PartialAggState));
	MemoryContextSwitchTo(old_context);

	/*
	 * Tie the cached FmgrInfo's fn_mcxt to the same long-lived context as
	 * the struct that holds it, rather than to the per-call context.
	 */
	fmgr_info_cxt(fn_get_state, &state->fn_get_state, agg_context);

	/*
	 * Call the core function with a NULL state so that it creates its own
	 * state. The context (AggState) is passed through because the core
	 * function needs it for that.
	 */
	InitFunctionCallInfoData(fcinfo_core, &state->fn_get_state, 2,
							 InvalidOid, fcinfo->context, NULL);
	fcinfo_core.arg[0] = (Datum) 0;
	fcinfo_core.argnull[0] = true;

	/*
	 * The first input value may itself be NULL. Propagate that flag instead
	 * of handing the core function a (Datum) 0 marked as non-null.
	 */
	fcinfo_core.arg[1] = value_isnull ? (Datum) 0 : PG_GETARG_DATUM(1);
	fcinfo_core.argnull[1] = value_isnull;

	state->state = FunctionCallInvoke(&fcinfo_core);

	if (fcinfo_core.isnull)
		elog(ERROR, "function %u returned NULL",
			 fcinfo_core.flinfo->fn_oid);

	return state;
}

/*
 * Serialize the core aggregation state by calling state->fn_serialize on
 * state->state, and return the resulting Datum.
 *
 * fcinfo is the call info of the SQL-callable caller. Its context is passed
 * through to the serialize function, or a fake AggState is used if there is
 * none.
 *
 * Raises ERROR if fn_serialize was never initialized or returns NULL.
 */
static Datum
partial_state_out_common(PartialAggState *state, FunctionCallInfo fcinfo)
{
	FunctionCallInfoData fcinfo_core;
	Node	   *context;
	Datum		result;

	if (!OidIsValid(state->fn_serialize.fn_oid))
		elog(ERROR, "fn_serialize not initialized");

	/*
	 * Do manually what FunctionCall1Coll() would do, as we need to pass
	 * fcinfo->context to the core function (the context contains AggState).
	 */
	if (fcinfo->context == NULL)
	{
		/*
		 * Fake an AggState so that numeric_avg_serialize() does not raise
		 * ERROR. (The function does not use the state anyway.)
		 */
		context = (Node *) makeNode(AggState);
	}
	else
		context = fcinfo->context;

	InitFunctionCallInfoData(fcinfo_core, &state->fn_serialize, 1,
							 InvalidOid, context, NULL);
	fcinfo_core.arg[0] = state->state;
	fcinfo_core.argnull[0] = false;
	result = FunctionCallInvoke(&fcinfo_core);
	if (fcinfo_core.isnull)
		elog(ERROR, "function %u returned NULL",
			 fcinfo_core.flinfo->fn_oid);

	return result;
}
