/*
 * $Header: $
 * $Revision: $
 * $Date: $
 *
 */
package org.apache.commons.httpclient;

import java.security.*;
import java.util.*;
import java.io.*;
import javax.crypto.*;
import javax.crypto.spec.*;

/**
 * <p>Provides an implementation of the NTLM authentication
 * protocol.</p>
 * <p>This class provides methods for generating authentication
 * challenge responses for the NTLM authentication protocol.  The NTLM
 * protocol is a proprietary Microsoft protocol and as such no RFC
 * exists for it.  This class is based upon the reverse engineering
 * efforts of a wide range of people.</p>
 * <p>Only the LanManager (LM) response is computed; the NT response
 * field of the type 3 message is sent with length zero.</p>
 * <p>Each message is assembled in a buffer local to the call, so
 * concurrent callers do not share any mutable state.</p>
 *
 * @author <a href="mailto:adrian@ephox.com">Adrian Sutton</a>
 * @version $Revision: $ $Date: $
 *
 */
public final class NTLM {

	/** Character encoding used for every string placed on the wire. */
	private static final String ENCODING = "ASCII";

	/** Signature that starts every NTLMSSP message (a NUL byte follows it). */
	private static final String SIGNATURE = "NTLMSSP";

	/** Flags sent in type 1 and type 3 messages (bytes 6, 82, 0, 0 on the wire). */
	private static final int FLAGS = 0x00005206;

	/** Size of the fixed part of a type 1 message; the host string follows it. */
	private static final int TYPE1_HEADER_LENGTH = 32;

	/** Size of the fixed part of a type 3 message; the domain string follows it. */
	private static final int TYPE3_HEADER_LENGTH = 64;

	/** Position of the server nonce within a decoded type 2 message. */
	private static final int NONCE_OFFSET = 24;

	/** Length of the server nonce in bytes. */
	private static final int NONCE_LENGTH = 8;

	/**
	 * <p>Returns the response for the given message.</p>
	 * <p>If no message (null, empty or whitespace only) has been received yet,
	 * a type 1 message is returned.  Otherwise the message is treated as
	 * a base64 encoded type 2 message and a type 3 message is returned.</p>
	 *
	 * @param message the message that was received from the server.
	 * @param username the username to authenticate with.
	 * @param password the password to authenticate with.
	 * @param host the host name of the client.
	 * @param domain the NT domain to authenticate in.
	 * @return the base64 encoded response message.
	 *
	 * @throws UnsupportedEncodingException if ASCII encoding is not
	 * supported by the JVM.
	 * @throws HttpException if computing the password hash fails.
	 * @throws IllegalArgumentException if the server message is too short
	 * to contain a nonce.
	 */
	public static final String getResponseFor(String message,
			String username, String password, String host, String domain)
			throws UnsupportedEncodingException, HttpException {
		if (message == null || message.trim().equals("")) {
			return getType1Message(host, domain);
		}
		return getType3Message(username, password, host, domain,
				parseType2Message(message));
	}

	/** Upper-cases the string with locale-independent rules and encodes it as ASCII.
	 * Using a fixed locale avoids e.g. the Turkish dotted/dotless i mapping.
	 * @param s the string to convert.
	 * @return the ASCII bytes of the upper-cased string.
	 */
	private static byte[] toUpperAscii(String s)
			throws UnsupportedEncodingException {
		return s.toUpperCase(Locale.ENGLISH).getBytes(ENCODING);
	}

	/** Appends all of the given bytes to the output.
	 * @param out the message being built.
	 * @param bytes the bytes to add.
	 */
	private static void addBytes(ByteArrayOutputStream out, byte[] bytes) {
		out.write(bytes, 0, bytes.length);
	}

	/** Appends the low 16 bits of the given number in little endian order.
	 * @param out the message being built.
	 * @param num the number to add.
	 */
	private static void addShort(ByteArrayOutputStream out, int num) {
		addBytes(out, convertShort(num));
	}

	/** Appends the 8 byte signature ("NTLMSSP" plus NUL) and the 4 byte
	 * little endian message type.
	 * @param out the message being built.
	 * @param type the NTLM message type (1 or 3 here).
	 */
	private static void addHeader(ByteArrayOutputStream out, int type)
			throws UnsupportedEncodingException {
		addBytes(out, SIGNATURE.getBytes(ENCODING));
		out.write(0);
		addShort(out, type);
		addShort(out, 0);
	}

	/** Appends the 4 byte little endian flags field.
	 * @param out the message being built.
	 */
	private static void addFlags(ByteArrayOutputStream out) {
		addShort(out, FLAGS);
		addShort(out, FLAGS >>> 16);
	}

	/** Appends a security buffer descriptor: the length (written twice, as
	 * the length and maximum length) followed by a 4 byte offset.
	 * @param out the message being built.
	 * @param length the length of the referenced data.
	 * @param offset the offset of the referenced data from the start of the message.
	 */
	private static void addSecurityBuffer(ByteArrayOutputStream out,
			int length, int offset) {
		addShort(out, length);
		addShort(out, length);
		addShort(out, offset);
		addShort(out, offset >>> 16);
	}

	/** Base64 encodes the bytes collected so far and returns them as a String.
	 * @param out the completed message.
	 * @return the base64 encoded message.
	 */
	private static String getResponse(ByteArrayOutputStream out)
			throws UnsupportedEncodingException {
		return new String(Base64.encode(out.toByteArray()), ENCODING);
	}

	/** Creates the type 1 (negotiate) message.
	 * @param host the client host name.
	 * @param domain the NT domain.
	 * @return the base64 encoded type 1 message.
	 */
	private static String getType1Message(String host, String domain)
			throws UnsupportedEncodingException {
		byte[] hostBytes = toUpperAscii(host);
		byte[] domainBytes = toUpperAscii(domain);

		ByteArrayOutputStream out = new ByteArrayOutputStream(
				TYPE1_HEADER_LENGTH + hostBytes.length + domainBytes.length);
		addHeader(out, 1);
		addFlags(out);
		// The domain string is placed after the host string.
		addSecurityBuffer(out, domainBytes.length,
				TYPE1_HEADER_LENGTH + hostBytes.length);
		// The host string immediately follows the fixed header.
		addSecurityBuffer(out, hostBytes.length, TYPE1_HEADER_LENGTH);

		addBytes(out, hostBytes);
		addBytes(out, domainBytes);
		return getResponse(out);
	}

	/** Extracts the server nonce out of the given type 2 message.
	 * @param sMsg the String containing the base64 encoded message.
	 * @return an array of 8 bytes that the server sent to be used when
	 * hashing the password.
	 * @throws IllegalArgumentException if the decoded message is too short
	 * to contain the nonce.
	 */
	private static byte[] parseType2Message(String sMsg)
			throws UnsupportedEncodingException {
		byte[] msg = Base64.decode(sMsg.getBytes(ENCODING));
		// The message comes from the server; reject it clearly instead of
		// failing with an index error.
		if (msg == null || msg.length < NONCE_OFFSET + NONCE_LENGTH) {
			throw new IllegalArgumentException(
					"NTLM type 2 message is too short to contain a nonce: "
					+ (msg == null ? 0 : msg.length) + " bytes");
		}
		byte[] nonce = new byte[NONCE_LENGTH];
		System.arraycopy(msg, NONCE_OFFSET, nonce, 0, NONCE_LENGTH);
		return nonce;
	}

	/** Creates the type 3 (authenticate) message using the given server nonce.
	 * @param user the username.
	 * @param password the password.
	 * @param host the client host name.
	 * @param domain the NT domain.
	 * @param nonce the 8 byte array the server sent.
	 * @return the base64 encoded type 3 message.
	 */
	private static String getType3Message(String user, String password,
			String host, String domain, byte[] nonce)
			throws UnsupportedEncodingException, HttpException {
		byte[] domainBytes = toUpperAscii(domain);
		byte[] hostBytes = toUpperAscii(host);
		byte[] userBytes = toUpperAscii(user);
		byte[] lmResponse = hashPassword(password, nonce);

		// Data layout after the fixed header: domain, user, host, LM response.
		int domainOffset = TYPE3_HEADER_LENGTH;
		int userOffset = domainOffset + domainBytes.length;
		int hostOffset = userOffset + userBytes.length;
		int lmOffset = hostOffset + hostBytes.length;
		int finalLength = lmOffset + lmResponse.length;

		ByteArrayOutputStream out = new ByteArrayOutputStream(finalLength);
		addHeader(out, 3);
		addSecurityBuffer(out, lmResponse.length, lmOffset);
		// No NT response is sent: zero length, pointing at the end of the message.
		addSecurityBuffer(out, 0, finalLength);
		addSecurityBuffer(out, domainBytes.length, domainOffset);
		addSecurityBuffer(out, userBytes.length, userOffset);
		addSecurityBuffer(out, hostBytes.length, hostOffset);
		// Empty session key buffer; its offset equals the total message length.
		addSecurityBuffer(out, 0, finalLength);
		addFlags(out);

		addBytes(out, domainBytes);
		addBytes(out, userBytes);
		addBytes(out, hostBytes);
		addBytes(out, lmResponse);
		return getResponse(out);
	}

	/** Creates the LanManager response for the given password using the
	 * given nonce.
	 * @param password the password to create a hash for.  Only the first
	 * 14 bytes of its upper-cased ASCII form are used.
	 * @param nonce the nonce sent by the server.
	 * @return the 24 byte LanManager response.
	 */
	private static byte[] hashPassword(String password, byte[] nonce)
			throws UnsupportedEncodingException, HttpException {
		byte[] passw = toUpperAscii(password);

		// Split the (zero padded) first 14 password bytes into two 7 byte keys.
		byte[] lmPw1 = new byte[7];
		byte[] lmPw2 = new byte[7];
		System.arraycopy(passw, 0, lmPw1, 0, Math.min(passw.length, 7));
		if (passw.length > 7) {
			System.arraycopy(passw, 7, lmPw2, 0, Math.min(passw.length - 7, 7));
		}

		// Create LanManager hashed Password: encrypt the constant with each key.
		byte[] magic = {(byte)0x4B, (byte)0x47, (byte)0x53, (byte)0x21,
				(byte)0x40, (byte)0x23, (byte)0x24, (byte)0x25};
		byte[] lmHpw1 = new DES(lmPw1).encrypt(magic);
		byte[] lmHpw2 = new DES(lmPw2).encrypt(magic);

		// 16 hash bytes followed by 5 zero bytes give the 21 byte key material.
		byte[] lmHpw = new byte[21];
		System.arraycopy(lmHpw1, 0, lmHpw, 0, lmHpw1.length);
		System.arraycopy(lmHpw2, 0, lmHpw, 8, lmHpw2.length);

		return calcResp(lmHpw, nonce);
	}

	/** Takes a 21 byte array and treats it as 3 56-bit DES keys.  The 8
	 * byte plaintext is encrypted with each key and the resulting 24
	 * bytes are returned in key order.
	 * @param keys the 21 bytes of key material.
	 * @param plaintext the 8 byte block to encrypt.
	 * @return the concatenated 24 byte result.
	 */
	private static byte[] calcResp(byte[] keys, byte[] plaintext)
			throws HttpException {
		byte[] results = new byte[24];
		for (int k = 0; k < 3; k++) {
			byte[] key = new byte[7];
			System.arraycopy(keys, k * 7, key, 0, 7);
			byte[] block = new DES(key).encrypt(plaintext);
			System.arraycopy(block, 0, results, k * 8, 8);
		}
		return results;
	}

	/** Converts the low 16 bits of the given number to a two byte array in
	 * little endian order.
	 * @param num the number to convert.
	 * @return the low byte followed by the high byte.
	 */
	private static byte[] convertShort(int num) {
		return new byte[] {(byte)(num & 0xFF), (byte)((num >>> 8) & 0xFF)};
	}
}
