//Written in the D programming language

/++
    Module contains helper functions for unit tests.

    Copyright: Copyright Jonathan M Davis 2010
    License:   <a href="http://www.boost.org/LICENSE_1_0.txt">Boost License 1.0</a>.
    Authors:   Jonathan M Davis

             Copyright Jonathan M Davis 2010.
    Distributed under the Boost Software License, Version 1.0.
       (See accompanying file LICENSE_1_0.txt or copy at
             http://www.boost.org/LICENSE_1_0.txt)
+/
module unittests;

version(unittest)
{

import core.exception;
import std.array;
import std.conv;
import std.string;

/++
    Captures a source file name and line number.

    Allows file and line info to be passed to unittesting functions which
    have variadic arguments and thus can't use default arguments for file and
    line info. That way, the line with the test is the line where the
    error is reported rather than a line inside the unittesting function.

    Call it as LineInfo() to record the caller's __FILE__ and __LINE__, or as
    LineInfo(file, line) to supply them explicitly.
 +/
struct LineInfo
{
    /// File to report on test failure.
    string file;

    /// Line to report on test failure.
    size_t line;

    /++
        Creates a LineInfo. The default arguments are evaluated at the call
        site, so LineInfo() records the file and line of the caller.

        There is deliberately no constructor. A struct that has both a
        constructor and a static opCall callable with no arguments is rejected
        by current D compilers ("static opCall is hidden by constructors"),
        and LineInfo() would then be unusable.
     +/
    static LineInfo opCall(string file = __FILE__, size_t line = __LINE__)
    {
        // Built field-by-field: writing LineInfo(file, line) here would
        // resolve to this opCall again and recurse forever.
        LineInfo info;
        info.file = file;
        info.line = line;
        return info;
    }
}


/++
    Asserts that calling func with the given arguments throws a Throwable of
    type E (or a subtype of E).

    A matching Throwable is caught and does not escape assertExcThrown().
    Any other Throwable propagates unchanged. E may be an Error type as well
    as an Exception type, including AssertError.

    Params:
        E    = The Throwable type expected to be thrown.
        func = The function to be tested.
        li   = Pass LineInfo() here so that a test failure is reported at the
               line that called assertExcThrown(), not at a line inside it.
        args = Arguments (if any) to be passed to func.

    Throws:
        AssertError if func returns normally.

    Examples:
    --------------------
    assertExcThrown!(Exception, myfunc)(LineInfo(), param1, param2);
    --------------------
+/
void assertExcThrown(E : Throwable, alias func, T...)(in LineInfo li, T args)
    if(__traits(compiles, func(args)))
{
    try
    {
        func(args);
    }
    catch(E)
    {
        // The expected Throwable was thrown, so the test passes.
        return;
    }

    // Reaching this point means func returned without throwing E.
    immutable funcName = __traits(identifier, func);
    throw new AssertError(format("assertExcThrown() failed: No %s was thrown from %s()", E.stringof, funcName), li.file, li.line);
}


/++
    Asserts that calling func with the given arguments throws a Throwable of
    type E (or a subtype of E), adding msg to the failure report.

    A matching Throwable is caught and does not escape assertExcThrown().
    Any other Throwable propagates unchanged. E may be an Error type as well
    as an Exception type, including AssertError.

    Params:
        msg  = Message appended to the failure report.
        E    = The Throwable type expected to be thrown.
        func = The function to be tested.
        li   = Pass LineInfo() here so that a test failure is reported at the
               line that called assertExcThrown(), not at a line inside it.
        args = Arguments (if any) to be passed to func.

    Throws:
        AssertError if func returns normally.

    Examples:
    --------------------
    assertExcThrown!("My test failed!", Exception, myfunc)(LineInfo(), param1, param2);
    --------------------
+/
void assertExcThrown(string msg, E : Throwable, alias func, T...)(in LineInfo li, T args)
    if(__traits(compiles, func(args)))
{
    try
    {
        func(args);
    }
    catch(E)
    {
        // The expected Throwable was thrown, so the test passes.
        return;
    }

    // Reaching this point means func returned without throwing E.
    immutable funcName = __traits(identifier, func);
    throw new AssertError(format("assertExcThrown() failed: No %s was thrown from %s(): %s", E.stringof, funcName, msg), li.file, li.line);
}


unittest
{
    // Throws whatever Throwable it is handed.
    void throwExc(Throwable t)
    {
        throw t;
    }

    // Returns without throwing anything.
    void nothrowExc()
    {
    }

    // Evaluates expr and reports whether it failed with an AssertError.
    // Any other Throwable propagates out of this helper.
    bool failsWithAssertError(lazy void expr)
    {
        try
            expr();
        catch(AssertError)
            return true;

        return false;
    }

    // The expected Throwable is thrown, so no failure may be reported.
    assert(!failsWithAssertError(assertExcThrown!(Exception, throwExc)(LineInfo(), new Exception("It's an Exception"))));
    assert(!failsWithAssertError(assertExcThrown!("It's a message", Exception, throwExc)(LineInfo(), new Exception("It's an Exception"))));
    assert(!failsWithAssertError(assertExcThrown!(AssertError, throwExc)(LineInfo(), new AssertError("It's an AssertError", __FILE__, __LINE__))));
    assert(!failsWithAssertError(assertExcThrown!("It's a message", AssertError, throwExc)(LineInfo(), new AssertError("It's an AssertError", __FILE__, __LINE__))));

    // Nothing is thrown, so every call must report a failure.
    assert(failsWithAssertError(assertExcThrown!(Exception, nothrowExc)(LineInfo())));
    assert(failsWithAssertError(assertExcThrown!("It's a message", Exception, nothrowExc)(LineInfo())));
    assert(failsWithAssertError(assertExcThrown!(AssertError, nothrowExc)(LineInfo())));
    assert(failsWithAssertError(assertExcThrown!("It's a message", AssertError, nothrowExc)(LineInfo())));
}

/++
    Asserts that calling func with the given arguments does not throw a
    Throwable of type E (or a subtype of E).

    If such a Throwable is thrown, it is caught and replaced by an AssertError
    reporting the test failure. Any other Throwable propagates unchanged.
    E may be an Error type as well as an Exception type, including AssertError.

    Params:
        E    = The Throwable type that must not be thrown.
        func = The function to be tested.
        li   = Pass LineInfo() here so that a test failure is reported at the
               line that called assertExcNotThrown(), not at a line inside it.
        args = Arguments (if any) to be passed to func.

    Throws:
        AssertError if func throws E.

    Examples:
    --------------------
    assertExcNotThrown!(Exception, myfunc)(LineInfo(), param1, param2);
    --------------------
+/
void assertExcNotThrown(E : Throwable, alias func, T...)(in LineInfo li, T args)
    if(__traits(compiles, func(args)))
{
    bool caught = false;

    try
        func(args);
    catch(E)
        caught = true;

    // The forbidden Throwable was swallowed above; report the failure instead.
    if(caught)
        throw new AssertError(format("assertExcNotThrown() failed: %s was thrown from %s()", E.stringof, __traits(identifier, func)), li.file, li.line);
}

/++
    Asserts that calling func with the given arguments does not throw a
    Throwable of type E (or a subtype of E), adding msg to the failure report.

    If such a Throwable is thrown, it is caught and replaced by an AssertError
    reporting the test failure. Any other Throwable propagates unchanged.
    E may be an Error type as well as an Exception type, including AssertError.

    Params:
        msg  = Message appended to the failure report.
        E    = The Throwable type that must not be thrown.
        func = The function to be tested.
        li   = Pass LineInfo() here so that a test failure is reported at the
               line that called assertExcNotThrown(), not at a line inside it.
        args = Arguments (if any) to be passed to func.

    Throws:
        AssertError if func throws E.

    Examples:
    --------------------
    assertExcNotThrown!("My test failed!", Exception, myfunc)(LineInfo(), param1, param2);
    --------------------
+/
void assertExcNotThrown(string msg, E : Throwable, alias func, T...)(in LineInfo li, T args)
    if(__traits(compiles, func(args)))
{
    bool caught = false;

    try
        func(args);
    catch(E)
        caught = true;

    // The forbidden Throwable was swallowed above; report the failure instead.
    if(caught)
        throw new AssertError(format("assertExcNotThrown() failed: %s was thrown from %s(): %s", E.stringof, __traits(identifier, func), msg), li.file, li.line);
}

unittest
{
    // Throws whatever Throwable it is handed.
    void throwExc(Throwable t)
    {
        throw t;
    }

    // Returns without throwing anything.
    void nothrowExc()
    {
    }

    // Evaluates expr and reports whether it failed with an AssertError.
    // Any other Throwable propagates out of this helper.
    bool failsWithAssertError(lazy void expr)
    {
        try
            expr();
        catch(AssertError)
            return true;

        return false;
    }

    // Nothing is thrown, so no failure may be reported.
    assert(!failsWithAssertError(assertExcNotThrown!(Exception, nothrowExc)(LineInfo())));
    assert(!failsWithAssertError(assertExcNotThrown!("It's a message", Exception, nothrowExc)(LineInfo())));
    assert(!failsWithAssertError(assertExcNotThrown!(AssertError, nothrowExc)(LineInfo())));
    assert(!failsWithAssertError(assertExcNotThrown!("It's a message", AssertError, nothrowExc)(LineInfo())));

    // The forbidden Throwable is thrown, so every call must report a failure.
    assert(failsWithAssertError(assertExcNotThrown!(Exception, throwExc)(LineInfo(), new Exception("It's an Exception"))));
    assert(failsWithAssertError(assertExcNotThrown!("It's a message", Exception, throwExc)(LineInfo(), new Exception("It's an Exception"))));
    assert(failsWithAssertError(assertExcNotThrown!(AssertError, throwExc)(LineInfo(), new AssertError("It's an AssertError", __FILE__, __LINE__))));
    assert(failsWithAssertError(assertExcNotThrown!("It's a message", AssertError, throwExc)(LineInfo(), new AssertError("It's an AssertError", __FILE__, __LINE__))));
}


/++
    Asserts that two values are equal according to ==. Using this instead of
    a plain assert gives better error messages on test failure, because both
    values are included in the message.

    The file and line number parameters can be set, but they mainly exist so
    that the error message on test failure names the file and line where
    assertEqual() was called rather than a file and line inside assertEqual().

    Params:
        actual   = The value to test.
        expected = The value that actual is supposed to be equal to.
        msg      = Optional message to output on test failure.
        file     = The file to list on test failure.
        line     = The line number to list on test failure.

    Throws:
        AssertError if actual != expected.

    Examples:
    --------------------
    assertEqual(myfunc(), 7);
    assertEqual(myfunc(), 7, "My test failed!");
    --------------------
+/
void assertEqual(T, U)(in T actual, in U expected, string msg = null, string file = __FILE__, size_t line = __LINE__)
    if(__traits(compiles, actual != expected) &&
       __traits(compiles, to!string(actual)) &&
       __traits(compiles, to!string(expected)))
{
    if(actual != expected)
    {
        // The message names this function (previously it said "assertEquals").
        auto failure = format("assertEqual() failed: actual [%s], expected [%s]", actual, expected);

        // Append the caller's message only when one was supplied.
        if(!msg.empty)
            failure ~= " : " ~ msg;

        throw new AssertError(failure, file, line);
    }
}


unittest
{
    // Small wrapper type. assertEqual takes its arguments as `in` (const), so
    // a const toString overload is provided next to the mutable one.
    struct IntWrapper
    {
        int value;

        this(int initial)
        {
            value = initial;
        }

        string toString()
        {
            return to!string(value);
        }

        string toString() const
        {
            return to!string(value);
        }
    }

    // Equal values (including int vs. double) must pass.
    assertExcNotThrown!(AssertError, assertEqual)(LineInfo(), 6, 6);
    assertExcNotThrown!(AssertError, assertEqual)(LineInfo(), 6, 6.0);
    assertExcNotThrown!(AssertError, assertEqual)(LineInfo(), IntWrapper(6), IntWrapper(6));

    // Unequal values must fail, whichever side holds the larger value.
    assertExcThrown!(AssertError, assertEqual)(LineInfo(), 6, 7);
    assertExcThrown!(AssertError, assertEqual)(LineInfo(), 6, 6.1);
    assertExcThrown!(AssertError, assertEqual)(LineInfo(), IntWrapper(6), IntWrapper(7));
    assertExcThrown!(AssertError, assertEqual)(LineInfo(), IntWrapper(7), IntWrapper(6));
}


/++
    Asserts that two values are not equal according to ==. Using this instead
    of a plain assert gives better error messages on test failure, because the
    offending value is included in the message.

    The file and line number parameters can be set, but they mainly exist so
    that the error message on test failure names the file and line where
    assertNotEqual() was called rather than a file and line inside
    assertNotEqual().

    Params:
        actual   = The value to test.
        expected = The value that actual is not supposed to be equal to.
        msg      = Optional message to output on test failure.
        file     = The file to list on test failure.
        line     = The line number to list on test failure.

    Throws:
        AssertError if actual == expected.

    Examples:
    --------------------
    assertNotEqual(myfunc(), 7);
    assertNotEqual(myfunc(), 7, "My test failed!");
    --------------------
+/
void assertNotEqual(T, U)(in T actual, in U expected, string msg = null, string file = __FILE__, size_t line = __LINE__)
    if(__traits(compiles, actual == expected) && __traits(compiles, to!string(actual)))
{
    if(actual == expected)
    {
        // The message names this function (previously it said "assertNotEquals").
        // Only actual is printed: on failure it compares equal to expected.
        auto failure = format("assertNotEqual() failed: value [%s]", actual);

        // Append the caller's message only when one was supplied.
        if(!msg.empty)
            failure ~= " : " ~ msg;

        throw new AssertError(failure, file, line);
    }
}

unittest
{
    // Small wrapper type. assertNotEqual takes its arguments as `in` (const),
    // so a const toString overload is provided next to the mutable one.
    struct IntWrapper
    {
        int value;

        this(int initial)
        {
            value = initial;
        }

        string toString()
        {
            return to!string(value);
        }

        string toString() const
        {
            return to!string(value);
        }
    }

    // Equal values (including int vs. double) must fail.
    assertExcThrown!(AssertError, assertNotEqual)(LineInfo(), 6, 6);
    assertExcThrown!(AssertError, assertNotEqual)(LineInfo(), 6, 6.0);
    assertExcThrown!(AssertError, assertNotEqual)(LineInfo(), IntWrapper(6), IntWrapper(6));

    // Unequal values must pass, whichever side holds the larger value.
    assertExcNotThrown!(AssertError, assertNotEqual)(LineInfo(), 6, 7);
    assertExcNotThrown!(AssertError, assertNotEqual)(LineInfo(), 6, 6.1);
    assertExcNotThrown!(AssertError, assertNotEqual)(LineInfo(), IntWrapper(6), IntWrapper(7));
    assertExcNotThrown!(AssertError, assertNotEqual)(LineInfo(), IntWrapper(7), IntWrapper(6));
}

/++
    Asserts that lhs is equal to, less than, or greater than rhs according to
    lhs.opCmp(rhs). Using this instead of asserting on the result of opCmp()
    directly gives better error messages on test failure, because both values
    and the relationship that actually holds are included in the message.

    lhs must have an opCmp member accepting rhs. Built-in types such as int,
    which have no opCmp member, are not accepted.

    The file and line number parameters can be set, but they mainly exist so
    that the error message on test failure names the file and line where
    assertOpCmp() was called rather than a file and line inside assertOpCmp().

    Params:
        op    = The relationship to test. Must be "==", "<", or ">".
        lhs   = The value which will have opCmp() called on it.
        rhs   = The value passed to opCmp().
        msg   = Optional message to output on test failure.
        file  = The file to list on test failure.
        line  = The line number to list on test failure.

    Throws:
        AssertError if lhs.opCmp(rhs) does not satisfy op.

    Examples:
    --------------------
    assertOpCmp!("<")(myfunc(), otherValue);
    assertOpCmp!("==")(myfunc(), otherValue);
    assertOpCmp!(">")(myfunc(), otherValue);
    assertOpCmp!("<")(myfunc(), otherValue, "My test failed!");
    assertOpCmp!(">")(myfunc(), otherValue, "My test failed!");
    assertOpCmp!("==")(myfunc(), otherValue, "My test failed!");
    --------------------
+/
void assertOpCmp(string op, T, U)(in T lhs, in U rhs, string msg = null, string file = __FILE__, size_t line = __LINE__)
    // Every clause is joined with &&. The last clause used to be joined with
    // a comma, which discarded all the earlier checks.
    if((op == "==" || op == "<" || op == ">") &&
       __traits(compiles, lhs.opCmp(rhs)) &&
       __traits(compiles, to!string(lhs)) &&
       __traits(compiles, to!string(rhs)))
{
    immutable result = lhs.opCmp(rhs);

    static if(op == "==")
        immutable passed = result == 0;
    else static if(op == "<")
        immutable passed = result < 0;
    else static if(op == ">")
        immutable passed = result > 0;
    else
        static assert(0, "assertOpCmp: op must be \"==\", \"<\", or \">\"");

    if(passed)
        return;

    // The relationship that actually holds between lhs and rhs.
    string actualOp;
    if(result < 0)
        actualOp = "<";
    else if(result == 0)
        actualOp = "==";
    else
        actualOp = ">";

    auto failure = format("assertOpCmp!(\"%s\")() failed: [%s] %s [%s]", op, lhs, actualOp, rhs);

    // Append the caller's message only when one was supplied.
    if(!msg.empty)
        failure ~= " : " ~ msg;

    throw new AssertError(failure, file, line);
}

unittest
{
    // Wrapper type with a three-way opCmp. assertOpCmp takes its arguments as
    // `in` (const), so opCmp and one toString overload are const.
    struct IntWrapper
    {
        int value;

        this(int initial)
        {
            value = initial;
        }

        int opCmp(const ref IntWrapper rhs) const
        {
            if(value == rhs.value)
                return 0;

            return value < rhs.value ? -1 : 1;
        }

        string toString()
        {
            return to!string(value);
        }

        string toString() const
        {
            return to!string(value);
        }
    }

    alias assertOpCmp!("==", IntWrapper, IntWrapper) cmpEqual;
    alias assertOpCmp!("<", IntWrapper, IntWrapper) cmpLess;
    alias assertOpCmp!(">", IntWrapper, IntWrapper) cmpGreater;

    // Relationships that hold must not fail.
    assertExcNotThrown!(AssertError, cmpEqual)(LineInfo(), IntWrapper(6), IntWrapper(6));
    assertExcNotThrown!(AssertError, cmpEqual)(LineInfo(), IntWrapper(0), IntWrapper(0));
    assertExcNotThrown!(AssertError, cmpLess)(LineInfo(), IntWrapper(0), IntWrapper(6));
    assertExcNotThrown!(AssertError, cmpLess)(LineInfo(), IntWrapper(6), IntWrapper(7));
    assertExcNotThrown!(AssertError, cmpGreater)(LineInfo(), IntWrapper(6), IntWrapper(0));
    assertExcNotThrown!(AssertError, cmpGreater)(LineInfo(), IntWrapper(7), IntWrapper(6));

    // Relationships that do not hold must fail.
    assertExcThrown!(AssertError, cmpEqual)(LineInfo(), IntWrapper(6), IntWrapper(7));
    assertExcThrown!(AssertError, cmpEqual)(LineInfo(), IntWrapper(7), IntWrapper(6));
    assertExcThrown!(AssertError, cmpLess)(LineInfo(), IntWrapper(6), IntWrapper(6));
    assertExcThrown!(AssertError, cmpLess)(LineInfo(), IntWrapper(7), IntWrapper(6));
    assertExcThrown!(AssertError, cmpGreater)(LineInfo(), IntWrapper(6), IntWrapper(6));
    assertExcThrown!(AssertError, cmpGreater)(LineInfo(), IntWrapper(6), IntWrapper(7));
}

}
