import com.datastax.driver.core.*;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;

import java.util.*;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;

public class CassandraTest2 {

    /**
     * Load-test entry point: executes {@code maxExecutions} randomly-bound counter UPDATEs with at
     * most {@code parallelism} requests in flight, printing progress once per second.
     *
     * <p>Arguments: {@code <maxExecutions> <parallelism> [contactPoint...]}. When no contact points
     * are supplied on the command line, the built-in defaults are used.
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err.println("Usage: CassandraTest2 <maxExecutions> <parallelism> [contactPoint...]");
            System.exit(1);
            return;
        }
        final int maxExecutions = Integer.parseInt(args[0]);
        final int parallelism = Integer.parseInt(args[1]);

        final CassandraTest2 test = new CassandraTest2();
        Timer timer = null;
        try {
            // Connect
            System.out.println("Connecting...");
            // NOTE(review): IP1/IP2/IP3 are not declared in this class; presumably placeholders for
            // real host addresses. Contact points given on the command line take precedence.
            String[] contactPoints = args.length > 2
                    ? Arrays.copyOfRange(args, 2, args.length)
                    : new String[]{ IP1, IP2, IP3};
            String keyspace = "stresscql";
            String cql = "UPDATE counttest SET " +
                    "count1_column = count1_column + ?, " +
                    "count2_column = count2_column + ?, " +
                    "count3_column = count3_column + ?, " +
                    "count4_column = count4_column + ?, " +
                    "count5_column = count5_column + ? " +
                    "WHERE key_column = ? AND cluster_column = ?";
            test.prepare(contactPoints, keyspace, cql, maxExecutions, parallelism);

            // Prepare statements
            System.out.println("Preparing...");
            List<BoundStatement> statements = test.generateStatements();

            // Progress reporter. Daemon so it can never keep the JVM alive on its own.
            timer = new Timer(true);
            timer.schedule(new TimerTask() {
                private int lastCount = 0;

                @Override
                public void run() {
                    // Read the counter once so all printed numbers are consistent.
                    int done = test.executionsSoFar.get();
                    int delta = done - lastCount;
                    lastCount = done;
                    // long math avoids overflow of 100 * done; guard the maxExecutions == 0 case.
                    long percent = maxExecutions == 0 ? 100L : 100L * done / maxExecutions;
                    System.out.println("Rate = " + delta + " inserts/second - ExecutionsSoFar = " + done + " (" + percent + "%)");
                }
            }, 1000, 1000);

            // Execute (returns only once every request has completed)
            System.out.println("Executing...");
            test.execute(statements);
            System.out.println("Finished - ExecutionsSoFar = " + test.executionsSoFar.get());
        } finally {
            // Cleanup runs even if connecting, preparing or executing throws.
            if (timer != null) {
                timer.cancel();
            }
            test.close();
        }

        System.exit(0);
    }

    private final Random random = new Random();

    private Cluster cluster;
    private Session session;
    /** Bounds the number of in-flight async requests to {@link #parallelism}. */
    private Semaphore throttle;
    private PreparedStatement prepared;
    private int maxExecutions;
    private int parallelism;

    /**
     * Completed requests (success or failure). Incremented from async callbacks that may run
     * concurrently, so it must be atomic; {@code volatile int} + {@code ++} would lose updates.
     */
    private final AtomicInteger executionsSoFar = new AtomicInteger();

    /**
     * Validates the run parameters, connects to the cluster and prepares {@code cql}.
     *
     * @param contactPoints hosts passed to {@code Cluster.Builder.addContactPoints}
     * @param keyspace keyspace to connect to (quoted, so matched case-sensitively)
     * @param cql statement to prepare
     * @param maxExecutions number of statements to generate; must be >= 0
     * @param parallelism maximum number of concurrent in-flight requests; must be >= 1
     * @throws IllegalArgumentException if {@code maxExecutions} or {@code parallelism} is out of range
     */
    private void prepare(String[] contactPoints, String keyspace, String cql, int maxExecutions, int parallelism) {
        // Validate before connecting: parallelism < 1 would make execute() block forever.
        if (maxExecutions < 0) {
            throw new IllegalArgumentException("maxExecutions must be >= 0, got " + maxExecutions);
        }
        if (parallelism < 1) {
            throw new IllegalArgumentException("parallelism must be >= 1, got " + parallelism);
        }
        this.maxExecutions = maxExecutions;
        this.parallelism = parallelism;
        this.throttle = new Semaphore(parallelism, false);
        this.cluster = Cluster.builder().addContactPoints(contactPoints).build();
        this.session = cluster.connect("\"" + keyspace + "\"");
        this.prepared = session.prepare(cql);
    }

    /**
     * Executes all statements asynchronously with at most {@link #parallelism} in flight, and
     * returns only after every request has completed (successfully or not).
     */
    private void execute(List<BoundStatement> statements) {
        for (BoundStatement statement : statements) {
            // Blocks while `parallelism` requests are already outstanding.
            throttle.acquireUninterruptibly();
            Futures.addCallback(session.executeAsync(statement), new FutureCallback<ResultSet>() {
                @Override
                public void onSuccess(ResultSet result) {
                    // Count before releasing so the final drain observes a complete count.
                    executionsSoFar.incrementAndGet();
                    throttle.release();
                }

                @Override
                public void onFailure(Throwable throwable) {
                    executionsSoFar.incrementAndGet();
                    System.out.println("Save exception = " + throwable);
                    throttle.release();
                }
            });
        }
        // Drain: once all permits are back, no request is still in flight. Without this the
        // caller would close the session while the last requests were pending.
        throttle.acquireUninterruptibly(parallelism);
        throttle.release(parallelism);
    }

    /** Closes the session and cluster if they were opened; safe to call after a partial prepare(). */
    private void close() {
        if (session != null) {
            session.close();
        }
        if (cluster != null) {
            cluster.close();
        }
    }

    /**
     * Builds {@link #maxExecutions} statements from {@link #prepared}, each bound with random
     * increments in [0, 10) for the five counters and a random key in [0, 5000000) / cluster
     * column in [0, 4096).
     */
    private List<BoundStatement> generateStatements() {
        List<BoundStatement> statements = new ArrayList<>(maxExecutions);
        for (int i = 0; i < maxExecutions; i++) {
            // NOTE(review): key is bound as a Long and cluster column as an Integer; this assumes
            // the table declares them as bigint / int respectively — confirm against the schema.
            long keyColumn = getRandom(0, 5000000);
            int clusterColumn = getRandom(0, 4096);
            long count1 = getRandom(0, 10);
            long count2 = getRandom(0, 10);
            long count3 = getRandom(0, 10);
            long count4 = getRandom(0, 10);
            long count5 = getRandom(0, 10);
            // Bind order must match the '?' markers in the prepared CQL.
            statements.add(prepared.bind(
                    count1,
                    count2,
                    count3,
                    count4,
                    count5,
                    keyColumn,
                    clusterColumn
            ));
        }
        return statements;
    }

    /** Returns a pseudo-random int in [from, to) (assuming from < to). */
    private int getRandom(int from, int to) {
        return (int) (from + random.nextDouble() * (to - from));
    }
}
