module DynamicCall;

import std.traits;
import std.stdio;

/// Classifies an argument or return value of a dynamic call.
/// The group comments name where a return value of that class ends up.
/// Values are spelled out because _makeCall/makeDynamicCall switch on them.
enum ParamType
{
	/// no return value
	Void           = 0,

	// --- returned in ST0 ---
	Float          = 1,
	Double         = 2,
	Real           = 3,

	// --- returned in EAX ---
	Byte           = 4,
	Word           = 5,
	DWord          = 6,
	Pointer        = 7,
	// AArray,     // not tested (could it be merged with Pointer?)

	// --- returned in EDX:EAX ---
	QWord          = 8,
	/// dynamic array; not tested
	DArray         = 9,

	// --- returned via a hidden pointer ---
	/// for returning large (>8 bytes) structs; not tested yet
	Hidden_Pointer = 10,
}

// Describes a single argument or the return value of a dynamic call.
struct Param
{
	ParamType type;	// classification of the value; selects how it is pushed or returned
	void* ptr;		// address of the value (argument) or of the storage that receives it (result)
}

// Describes everything needed to make one dynamic call (see makeDynamicCall).
struct Call
{
	Param[] input;	// arguments, pushed in index order; the last one may be passed in EAX
	Param output;	// result type, and where the result is written (ptr is unused for Void)
	void* funcptr;	// address of the function to call
}

// Makes the call described by `call`. For every return type except Void the
// result is written through call.output.ptr, which must point to storage large
// enough for that type.
// DArray and Hidden_Pointer return types are not implemented and trip an assert.
void makeDynamicCall(Call* call)
{
	switch (call.output.type) {
		case ParamType.Void:
			_makeCall!(void)(call);	// nothing to store
			break;

		case ParamType.Pointer:
			makeCall!(void*)(call);
			break;

		case ParamType.Byte:
			makeCall!(byte)(call);
			break;

		case ParamType.Word:
			makeCall!(short)(call);
			break;

		case ParamType.DWord:
			makeCall!(int)(call);
			break;

		case ParamType.QWord:
			makeCall!(long)(call);
			break;

		case ParamType.Float:
			makeCall!(float)(call);
			break;

		case ParamType.Double:
			makeCall!(double)(call);
			break;

		case ParamType.Real:
			makeCall!(real)(call);
			break;

		default:
			// DArray and Hidden_Pointer returns are not supported yet; fail loudly
			// instead of relying on an implicit switch error.
			assert(0, "makeDynamicCall: unsupported return type");
	}
}

// Convenience wrapper: performs the call through _makeCall!(T) and stores the
// T it yields at the address held in call.output.ptr.
void makeCall(T)(Call* call)
{
	T result = _makeCall!(T)(call);
	T* destination = cast(T*) call.output.ptr;
	*destination = result;
}

// Pushes the arguments described by call.input and then calls call.funcptr.
// There is deliberately no `return` statement: whatever the callee leaves in
// the return registers (EAX, EDX:EAX or ST0, per the ParamType grouping) is
// what the caller of _makeCall!(T) receives as T.
// The first numArgs-1 arguments are left on the stack by arg(); the last one
// goes through lastArg(), which may place it in EAX.
// NOTE(review): relies on the compiler not touching EAX/ESP between the last
// lastArg() call and the `call` instruction below; x86-32 only.
T _makeCall(T)(Call* call)
{
	void* funcptr = call.funcptr;
	void* argptr;

	size_t numArgs = call.input.length;

	if (numArgs != 0) {	// needed because the last parameter is passed in EAX (if possible)
		Param* param = call.input.ptr;

		// push the first numArgs-1 arguments, in index order
		// NOTE: factoring this switch into a helper (e.g. push!(arg)(param)) compiles
		// but gives wrong results at run time; presumably the helper's own epilogue
		// restores ESP and discards the pushed arguments (unverified).
		for ( ; --numArgs; ++param) {
			argptr = param.ptr;
			switch (param.type) {
				case ParamType.Byte:	// push byte
					arg(*cast(byte*)argptr);
					break;

				case ParamType.Word:	// push word
					arg(*cast(short*)argptr);
					break;

				case ParamType.Pointer:
				case ParamType.DWord:	// push dword
					arg(*cast(int*)argptr);
					break;

				case ParamType.QWord:	// push qword
					arg(*cast(long*)argptr);
					break;

				case ParamType.Float:	// push float
					arg(*cast(float*)argptr);
					break;

				case ParamType.Double:	// push double
					arg(*cast(double*)argptr);
					break;

				case ParamType.Real:	// push real
					arg(*cast(real*)argptr);
					break;

				default:
					assert(0, "_makeCall: unsupported argument type");
			}
		}

		// last argument: same as above, but may be passed in EAX
		argptr = param.ptr;
		switch (param.type) {
			case ParamType.Byte:
				lastArg(*cast(byte*)argptr);
				break;

			case ParamType.Word:
				lastArg(*cast(short*)argptr);
				break;

			case ParamType.Pointer:
			case ParamType.DWord:
				lastArg(*cast(int*)argptr);
				break;

			case ParamType.QWord:
				lastArg(*cast(long*)argptr);
				break;

			case ParamType.Float:
				lastArg(*cast(float*)argptr);
				break;

			case ParamType.Double:
				lastArg(*cast(double*)argptr);
				break;	// previously missing: fell through and also pushed a real

			case ParamType.Real:
				lastArg(*cast(real*)argptr);
				break;

			default:
				assert(0, "_makeCall: unsupported argument type");
		}
	}

	asm {
		// call it!
		call funcptr;
	}
}

// Helper that leaves one argument on the stack in a type-safe manner.
// _makeCall calls arg(x) so that the compiler emits the push of x sized for T.
// The body is a naked `ret` with no operand and no stack adjustment, so the
// pushed argument is intended to remain on the stack for the real call that
// _makeCall makes afterwards.
// extern (System) is used so that the argument isn't passed via EAX.
// NOTE(review): extern (System) means extern (Windows) on Windows but extern (C)
// elsewhere. Under the C convention the caller normally pops the arguments after
// the call, which would undo the push -- verify before relying on this on Linux.
// TODO(review): pushing all arguments explicitly would be less dependent on
// compiler code generation.
extern (System) void arg(T)(T arg)
{
	asm {
		naked;
		ret;
	}
}

// Helper that passes the final argument of a dynamic call in a type-safe manner.
// It keeps the default extern (D) linkage (unlike arg) precisely so the compiler
// may pass the argument in EAX when it fits.
// The body is a naked `ret`, so EAX (and anything pushed) is left as-is for the
// `call funcptr` in _makeCall.
// NOTE(review): assumes the compiler does not reuse EAX between this call and
// that call instruction -- re-check if code generation changes.
void lastArg(T)(T arg)
{
	asm {
		naked;
		ret;
	}
}

// For comparison, the older implementation (hand-written inline asm), kept below for reference and disabled:
/+
T _makeCall(T)(Call* call)
{
	void* funcptr = call.funcptr;
	void* argptr;
	int i = call.input.length;
	
	int eax = -1;

	foreach (ref param; call.input) {
		--i;
		argptr = param.ptr;

		switch (param.type) {
			case ParamType.Byte:
				// passing word
				asm {
					mov EDX, argptr;
					mov AL, byte ptr[EDX];
				}
				if (i != 0) {
					asm {
						push EAX;
					}
				} else {
					asm {
						mov eax, EAX;
					}
				}
				break;
				
			case ParamType.Word:
				// passing word
				asm {
					mov EDX, argptr;
					mov AX, word ptr[EDX];
				}
				if (i != 0) {
					asm {
						push EAX;
					}
				} else {
					asm {
						mov eax, EAX;
					}
				}
				break;
			
			case ParamType.Pointer:
			case ParamType.DWord:
				// passing word
				asm {
					mov EDX, argptr;
					mov EAX, dword ptr[EDX];
				}
				if (i != 0) {
					asm {
						push EAX;
					}
				} else {
					asm {
						mov eax, EAX;
					}
				}
				break;
				
			case ParamType.QWord:
				// pushing word
				asm {
					mov EDX, argptr;
					mov EAX, dword ptr[EDX+4];
					push EAX;
					mov EAX, dword ptr[EDX];
					push EAX;
				}
				break;

			case ParamType.Float:
				// pushing float
				asm {
					sub ESP, 4;
					mov EAX, dword ptr[argptr];
					fld dword ptr[EAX];
					fstp dword ptr[ESP];
				}
				break;

			case ParamType.Double:
				// pushing double
				asm {
					sub ESP, 8;
					mov EAX, qword ptr[argptr];
					fld qword ptr[EAX];
					fstp qword ptr[ESP];
				}
				break;
		}
	}

	asm {
		mov EAX, eax;
		call funcptr;
	}
}
+/

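// Attempt to move the duplicated argument-pushing switch into a separate function; it compiles but gives wrong results at run time.
// A likely (unverified) cause: arguments pushed inside push!() are discarded when push!()'s own epilogue restores ESP, and EAX may be overwritten before the real call.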
/+
void push(alias fun)(Param* param)
{
	switch (param.type) {
		case ParamType.Byte:
			fun(*cast(byte*)param.ptr);
			break;
			
		case ParamType.Word:
			fun(*cast(short*)param.ptr);
			break;

		case ParamType.Pointer:
		case ParamType.DWord:
			fun(*cast(int*)param.ptr);
			break;

		case ParamType.QWord:
			fun(*cast(long*)param.ptr);
			break;

		case ParamType.Float:
			fun(*cast(float*)param.ptr);
			break;

		case ParamType.Double:
			fun(*cast(double*)param.ptr);
			break;
			
		case ParamType.Real:
			fun(*cast(real*)param.ptr);
			break;
	}
}
+/

// Convenient templates to map from type T to corresponding ParamType enum element

// True when T is a struct whose size is exactly `size` bytes; false for any non-struct type.
template isStructSize(T, int size)
{
	static if (is(T == struct))
		enum isStructSize = (T.sizeof == size);
	else
		enum isStructSize = false;
}

// ParamTypeFromT!(T) yields the ParamType member describing T.
// Each overload below handles one ParamType; a type matching none of them
// fails to instantiate (compile-time error).

// 1-byte values: byte/ubyte/char/bool and 1-byte structs
template ParamTypeFromT(T) if (is (T == byte) || is (T == ubyte) || is (T == char) || is (T == bool) || isStructSize!(T, 1))
{
	alias ParamType.Byte ParamTypeFromT;
}

// 2-byte values: short/ushort/wchar and 2-byte structs
template ParamTypeFromT(T) if (is (T == short) || is (T == ushort) || is (T == wchar) || isStructSize!(T, 2))
{
	alias ParamType.Word ParamTypeFromT;
}

// 4-byte values: int/uint/dchar and 4-byte structs
template ParamTypeFromT(T) if (is (T == int) || is (T == uint) || is (T == dchar) || isStructSize!(T, 4))
{
	alias ParamType.DWord ParamTypeFromT;
}

// 8-byte values: long/ulong, 8-byte structs and delegates
template ParamTypeFromT(T) if (is (T == long) || is (T == ulong) || isStructSize!(T, 8) || is (T == delegate))
{
	alias ParamType.QWord ParamTypeFromT;
}

template ParamTypeFromT(T) if (is (T == float))
{
	alias ParamType.Float ParamTypeFromT;
}

template ParamTypeFromT(T) if (is (T == double))
{
	alias ParamType.Double ParamTypeFromT;
}

template ParamTypeFromT(T) if (is (T == real))
{
	alias ParamType.Real ParamTypeFromT;
}

template ParamTypeFromT(T) if (is (T == void))
{
	alias ParamType.Void ParamTypeFromT;
}

// pointers, plus class/interface references (which are pointer-sized references)
template ParamTypeFromT(T) if (isPointer!(T) || is (T == class) || is (T == interface))
{
	alias ParamType.Pointer ParamTypeFromT;
}