
import freenet.crypt.BlockCipher;
import freenet.crypt.ciphers.Rijndael;

/**
 * Manual round-trip self-test for the {@link Rijndael} block cipher.
 *
 * <p>For every combination of key size and block size in {128, 192, 256} bits, a random key and a
 * random single block of plaintext are generated. The block is enciphered, then deciphered, and
 * the result is compared with the original plaintext. Intermediate values are printed as hex on
 * standard output; the per-combination verdict ({@code SUCCESS}/{@code FAILURE}) is printed on
 * standard error.
 */
public final class TestEncrypt {

    /** Key and block sizes, in bits, exercised by {@link #main}: every key size with every block size. */
    private static final int[] SIZES_IN_BITS = {128, 192, 256};

    /** Upper-case hexadecimal digits indexed by nibble value, used by {@link #print}. */
    private static final char[] HEX_DIGITS = {
        '0','1','2','3','4','5','6','7','8','9','A','B','C','D','E','F'
    };

    /**
     * Runs the round-trip test for every key-size / block-size combination.
     *
     * <p>The process exits with status 2 if a cipher cannot be initialised (see
     * {@link #test_symmetric}). It exits with status 1 if any combination fails to restore its
     * plaintext, so callers such as scripts can detect the failure. Otherwise it returns normally.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int failures = 0;
        for (int keySize : SIZES_IN_BITS) {
            for (int blockSize : SIZES_IN_BITS) {
                if (!runSymmetric(keySize, blockSize)) {
                    failures++;
                }
            }
        }
        if (failures != 0) {
            int total = SIZES_IN_BITS.length * SIZES_IN_BITS.length;
            System.err.println(failures + " of " + total + " combinations FAILED");
            System.exit(1);
        }
    }

    /**
     * Enciphers and deciphers one random block with a random key and reports the outcome.
     *
     * <p>Prints the key, plaintext, ciphertext and restored text, followed by {@code SUCCESS} or
     * {@code FAILURE}. Terminates the JVM with status 2 if the cipher cannot be constructed or
     * initialised.
     *
     * @param keySize   key size in bits, passed to the {@link Rijndael} constructor
     * @param blockSize block size in bits, passed to the {@link Rijndael} constructor
     */
    public static void test_symmetric(int keySize, int blockSize) {
        runSymmetric(keySize, blockSize);
    }

    /**
     * Performs one round trip (see {@link #test_symmetric}).
     *
     * @return {@code true} if the deciphered block equals the original plaintext
     */
    private static boolean runSymmetric(int keySize, int blockSize) {
        java.util.Random random = new java.util.Random();

        System.out.println("TEST SYMMETRIC: keySize=" + keySize + " / blockSize=" + blockSize);

        byte[] key = new byte[keySize / 8];
        random.nextBytes(key);
        print("Using key= 0x", key);

        BlockCipher cipher;
        try {
            cipher = new Rijndael(keySize, blockSize);
            cipher.initialize(key);
        } catch (Exception e) {
            // NOTE(review): the exception types thrown by the Rijndael constructor and
            // initialize() are not visible here, hence the broad catch.
            e.printStackTrace();
            System.err.println("Failed to init cipher");
            System.exit(2);
            return false; // not reached: System.exit does not return normally
        }

        // Sizes above are in bits (as handed to the Rijndael constructor); buffers are in bytes.
        int blockBytes = blockSize / 8;
        byte[] clearText = new byte[blockBytes];
        byte[] crypted = new byte[blockBytes];
        byte[] decrypted = new byte[blockBytes];

        random.nextBytes(clearText);
        print("Input    = 0x", clearText);

        cipher.encipher(clearText, crypted);
        print("Ciphered = 0x", crypted);

        cipher.decipher(crypted, decrypted);
        print("Restored = 0x", decrypted);

        boolean ok = java.util.Arrays.equals(clearText, decrypted);
        System.err.println(ok ? "SUCCESS" : "FAILURE");
        return ok;
    }

    /**
     * Prints {@code prefix} followed by {@code ba} as upper-case hex (two digits per byte) on
     * standard output, terminated by a newline.
     */
    private static void print(String prefix, byte[] ba) {
        java.io.PrintStream out = System.out;
        char[] buf = new char[ba.length * 2];
        int j = 0;
        for (byte b : ba) {
            buf[j++] = HEX_DIGITS[(b >>> 4) & 0x0F]; // high nibble; mask drops sign extension
            buf[j++] = HEX_DIGITS[b & 0x0F];         // low nibble
        }
        out.print(prefix);
        out.println(buf);
    }
}