import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicReference;

import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;

public class JDKSSLEngineTest {

    /**
     * Application-data destination passed to {@code unwrap}. No application data is expected in this test,
     * but some SSLEngine implementations check the destination's free space before unwrapping a record, so
     * it must be generously sized. Note: {@code 2 ^ 16} in Java is XOR (== 18), hence the shift.
     */
    private static final ByteBuffer[] UNUSED_BUFFER_ARRAY = {ByteBuffer.allocate(1 << 16)};

    /**
     * Empty application-data source passed to {@code wrap}, so wrapping can only ever emit records the
     * engine itself queued (handshake, alert, close_notify) and never encrypts stray bytes as app data.
     */
    private static final ByteBuffer[] EMPTY_BUFFER_ARRAY = {ByteBuffer.allocate(0)};

    /**
     * Starts a handshake between {@code clientEngine} and {@code serverEngine}, closes the server's outbound
     * side before the handshake completes, and checks how both engines report inbound/outbound closure while
     * the resulting close_notify/alert exchange plays out.
     *
     * @param clientEngine engine to run in client mode
     * @param serverEngine engine to run in server mode; presumably configured by the caller with key material
     *                     compatible with {@code clientEngine} (TODO confirm)
     * @throws Exception if an engine operation fails in a way the test does not expect
     */
    public void testForJDK(SSLEngine clientEngine, SSLEngine serverEngine) throws Exception {
        clientEngine.setUseClientMode(true);
        serverEngine.setUseClientMode(false);

        SSLSession clientSession = clientEngine.getSession();
        SSLSession serverSession = serverEngine.getSession();
        // Read buffers start empty in fill mode (position 0). Write buffers start empty in drain mode
        // (position == limit) so wrap()'s "no pending writes" precondition holds.
        AtomicReference<ByteBuffer> clientReadBuffer = new AtomicReference<>(ByteBuffer.allocate(clientSession.getPacketBufferSize()));
        AtomicReference<ByteBuffer> clientWriteBuffer = new AtomicReference<>(ByteBuffer.allocate(clientSession.getPacketBufferSize()));
        clientWriteBuffer.get().position(clientWriteBuffer.get().limit());
        AtomicReference<ByteBuffer> serverReadBuffer = new AtomicReference<>(ByteBuffer.allocate(serverSession.getPacketBufferSize()));
        AtomicReference<ByteBuffer> serverWriteBuffer = new AtomicReference<>(ByteBuffer.allocate(serverSession.getPacketBufferSize()));
        serverWriteBuffer.get().position(serverWriteBuffer.get().limit());

        clientEngine.beginHandshake();
        serverEngine.beginHandshake();

        assert clientEngine.getHandshakeStatus() == NEED_WRAP;
        assert serverEngine.getHandshakeStatus() == NEED_UNWRAP;

        write(clientEngine, clientWriteBuffer);

        sendData(clientWriteBuffer.get(), serverReadBuffer.get());

        read(serverEngine, serverReadBuffer, serverWriteBuffer);
        sendData(serverWriteBuffer.get(), clientReadBuffer.get());

        while (serverEngine.getHandshakeStatus() == NEED_WRAP) {
            write(serverEngine, serverWriteBuffer);
            sendData(serverWriteBuffer.get(), clientReadBuffer.get());
        }

        assert serverEngine.getHandshakeStatus() == NEED_UNWRAP;

        read(clientEngine, clientReadBuffer, clientWriteBuffer);
        sendData(clientWriteBuffer.get(), serverReadBuffer.get());

        while (clientEngine.getHandshakeStatus() == NEED_WRAP) {
            write(clientEngine, clientWriteBuffer);
            sendData(clientWriteBuffer.get(), serverReadBuffer.get());
        }

        assert clientEngine.getHandshakeStatus() == NEED_UNWRAP;

        serverEngine.closeOutbound();

        // Need to send alert and close_notify
        assert serverEngine.getHandshakeStatus() == NEED_WRAP;
        assert !serverEngine.isInboundDone() : "Inbound should not be done because we have not received close_notify";
        assert !serverEngine.isOutboundDone() : "Outbound should not be done because we have not produced alert/close_notify";

        while (serverEngine.getHandshakeStatus() == NEED_WRAP) {
            write(serverEngine, serverWriteBuffer);
            sendData(serverWriteBuffer.get(), clientReadBuffer.get());
        }

        assert !serverEngine.isInboundDone() : "Inbound should not be done because we have not received close_notify";
        assert serverEngine.isOutboundDone() : "Outbound should be done because we have produced alert/close_notify";

        try {
            read(clientEngine, clientReadBuffer, clientWriteBuffer);
        } catch (SSLException e) {
            // EXCEPTION THAT MUST BE HANDLED BEFORE RESPONDING WITH CLOSE_NOTIFY (OR ALERT)
            assert e.getMessage() != null && e.getMessage().startsWith("Received close_notify during handshake") : e;
        }

        assert clientEngine.isInboundDone() : "Inbound should be done because we have received alert/close_notify";
        assert !clientEngine.isOutboundDone() : "Outbound should not be done because we have not produced alert/close_notify";

        // THESE ARE DELAYED HANDSHAKE MESSAGES FROM THE CLIENT (ClientKeyExchange/ChangeCipherSpec)

        read(serverEngine, serverReadBuffer, serverWriteBuffer);

        assert !serverEngine.isInboundDone() : "Inbound should not be done because we have not received close_notify";
        if (false) {
            // THIS ASSERTION CURRENTLY FAILS BECAUSE THE FINAL HANDSHAKE MESSAGES PUT THE SERVER OUT OF THE CLOSING PROCESS
            assert serverEngine.isOutboundDone() : "Outbound should be done because we have produced alert/close_notify";
        }

        // SEND CLOSE_NOTIFY/ALERT FROM CLIENT
        while (clientEngine.getHandshakeStatus() == NEED_WRAP) {
            write(clientEngine, clientWriteBuffer);
            sendData(clientWriteBuffer.get(), serverReadBuffer.get());
        }

        assert clientEngine.isInboundDone() : "Inbound should be done because we have received alert/close_notify";
        assert clientEngine.isOutboundDone() : "Outbound should be done because we have produced alert/close_notify";

        try {
            read(serverEngine, serverReadBuffer, serverWriteBuffer);
        } catch (SSLException e) {
            // THIS IS THE ALERT FROM THE CLIENT THAT WAS PRODUCED WHEN IT RECEIVED A CLOSE_NOTIFY DURING HANDSHAKE
            // THE CLIENT SHOULD MAYBE NOT SEND THIS ALERT AND SHOULD JUST RESPOND WITH CLOSE_NOTIFY
            assert e.getMessage() != null && e.getMessage().startsWith("Received fatal alert: unexpected_message") : e;
        }

        assert serverEngine.isInboundDone() : "Inbound should be done because we have received alert/close_notify";
        if (false) {
            // THIS ASSERTION CURRENTLY FAILS BECAUSE THE FINAL HANDSHAKE MESSAGES PUT THE SERVER OUT OF THE CLOSING PROCESS
            assert serverEngine.isOutboundDone() : "Outbound should be done because we have produced alert/close_notify";
        }
    }

    /**
     * Simulates the network by moving every readable byte of {@code from} (drain mode) into {@code to}
     * (fill mode). Throws {@link java.nio.BufferOverflowException} if {@code to} lacks room.
     */
    private void sendData(ByteBuffer from, ByteBuffer to) {
        to.put(from);
        assert !from.hasRemaining() : "Should have consumed data in 'from' buffer";
    }

    /**
     * Advances {@code sslEngine}'s handshake, wrapping outbound records into the write buffer if it has
     * been fully drained.
     */
    private void write(SSLEngine sslEngine, AtomicReference<ByteBuffer> writeBufferRef) throws SSLException {
        handshake(sslEngine, writeBufferRef);
    }

    /**
     * Unwraps as many complete records from the read buffer as the engine will accept, advancing the
     * handshake after each one.
     *
     * <p>The read buffer is in fill mode (its position marks the end of received-but-unconsumed bytes) both
     * on entry and on every exit, including when {@code unwrap} throws. This lets more data be appended later.
     *
     * @return {@code true} if the engine reported CLOSED, {@code false} if it stopped for any other reason
     *         (buffer empty, only a partial record buffered, or the engine consumed nothing)
     * @throws SSLException if the engine rejects the inbound data or a handshake wrap fails
     */
    private boolean read(SSLEngine sslEngine, AtomicReference<ByteBuffer> readBufferRef, AtomicReference<ByteBuffer> writeBufferRef)
        throws SSLException {
        boolean continueUnwrap = true;
        while (continueUnwrap && readBufferRef.get().position() > 0) {
            ByteBuffer readBuffer = readBufferRef.get();
            readBuffer.flip();
            SSLEngineResult result;
            try {
                result = sslEngine.unwrap(readBuffer, UNUSED_BUFFER_ARRAY);
            } finally {
                // Back to fill mode, keeping any bytes the engine did not consume.
                readBuffer.compact();
            }
            switch (result.getStatus()) {
                case OK:
                    break;
                case BUFFER_UNDERFLOW:
                    // Only a partial record is buffered. Make sure a whole packet will fit, then wait for
                    // more data; retrying now would spin forever on the same bytes.
                    ensureBufferSize(sslEngine, readBufferRef);
                    return false;
                case BUFFER_OVERFLOW:
                    throw new AssertionError("Should not happen during handshake!");
                case CLOSED:
                    assert sslEngine.isInboundDone() : "We received close_notify so read should be done";
                    return true;
                default:
                    throw new IllegalStateException("Unexpected UNWRAP result: " + result.getStatus());
            }

            handshake(sslEngine, writeBufferRef);
            continueUnwrap = result.bytesConsumed() > 0;
        }
        return false;
    }

    /**
     * Wraps the engine's next outbound record(s) into the write buffer.
     *
     * <p>On entry the write buffer must be fully drained. On return it is in drain mode (flipped), holding
     * whatever the engine produced, which may be nothing. If {@code wrap} throws, the buffer is left empty
     * and the exception is rethrown.
     */
    private SSLEngineResult wrap(SSLEngine sslEngine, AtomicReference<ByteBuffer> writeBufferRef) throws SSLException {
        ByteBuffer writeBuffer = writeBufferRef.get();
        assert !writeBuffer.hasRemaining() : "Should never called with pending writes";

        writeBuffer.clear();
        while (true) {
            SSLEngineResult result;
            try {
                result = sslEngine.wrap(EMPTY_BUFFER_ARRAY, writeBuffer);
            } catch (SSLException e) {
                // position == limit: nothing readable, so no partial output is ever sent.
                writeBuffer.position(writeBuffer.limit());
                throw e;
            }

            switch (result.getStatus()) {
                case OK:
                    writeBuffer.flip();
                    return result;
                case BUFFER_UNDERFLOW:
                    throw new IllegalStateException("Should not receive BUFFER_UNDERFLOW on WRAP");
                case BUFFER_OVERFLOW:
                    // Not enough room for an entire SSL packet. Grow to the session's packet size and retry.
                    // If the buffer is already that large, give up and hand back what was produced (flipped,
                    // so an empty result reads as empty rather than as a buffer full of stale bytes).
                    int currentCapacity = writeBuffer.capacity();
                    ensureBufferSize(sslEngine, writeBufferRef);
                    writeBuffer = writeBufferRef.get();
                    if (currentCapacity == writeBuffer.capacity()) {
                        writeBuffer.flip();
                        return result;
                    }
                    break;
                case CLOSED:
                    writeBuffer.flip();
                    assert result.bytesProduced() > 0 : "WRAP during close processing should produce close message.";
                    return result;
                default:
                    throw new IllegalStateException("Unexpected WRAP result: " + result.getStatus());
            }
        }
    }

    /**
     * Replaces the buffer in {@code bufferRef} with one of at least the session's packet buffer size when it
     * is smaller, copying over the bytes already written. The buffer must be in fill mode (its position marks
     * the end of valid data). A buffer that is already large enough is left untouched.
     */
    private void ensureBufferSize(SSLEngine sslEngine, AtomicReference<ByteBuffer> bufferRef) {
        ByteBuffer currentBuffer = bufferRef.get();
        int networkPacketSize = sslEngine.getSession().getPacketBufferSize();
        if (currentBuffer.capacity() < networkPacketSize) {
            ByteBuffer newBuffer = ByteBuffer.allocate(networkPacketSize);
            currentBuffer.flip();
            newBuffer.put(currentBuffer);
            bufferRef.set(newBuffer);
        }
    }

    /**
     * Drives the handshake as far as possible without new inbound data. Delegated tasks are run inline, and
     * a record is wrapped if the engine wants to write and the write buffer has been drained. Returns as soon
     * as the engine needs input, needs the write buffer flushed, or reports a status this loop cannot act on.
     */
    private void handshake(SSLEngine sslEngine, AtomicReference<ByteBuffer> writeBufferRef) throws SSLException {
        SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();

        while (true) {
            switch (handshakeStatus) {
                case NEED_TASK:
                    runTasks(sslEngine);
                    handshakeStatus = sslEngine.getHandshakeStatus();
                    break;
                case NEED_WRAP:
                    if (!writeBufferRef.get().hasRemaining()) {
                        handshakeStatus = wrap(sslEngine, writeBufferRef).getHandshakeStatus();
                    }
                    // Continue only to run delegated tasks the wrap produced. Otherwise the caller must
                    // flush the write buffer before anything more can be wrapped.
                    if (handshakeStatus != NEED_TASK) {
                        return;
                    }
                    break;
                case NEED_UNWRAP:
                    // We UNWRAP as much as possible immediately after a read. Do not need to do it here.
                    return;
                default:
                    // NOT_HANDSHAKING, FINISHED, or a status this loop cannot act on (such as DTLS's
                    // NEED_UNWRAP_AGAIN). Stop rather than spin forever.
                    return;
            }
        }
    }

    /** Runs, on the calling thread, every delegated task the engine currently has pending. */
    private void runTasks(SSLEngine sslEngine) {
        Runnable delegatedTask;
        while ((delegatedTask = sslEngine.getDelegatedTask()) != null) {
            delegatedTask.run();
        }
    }

}
