package com.syntegra.peopleservices.servlet.filter;

import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletRequest;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.OutputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import java.util.Enumeration;

import org.apache.log4j.*;

/**
 * Servlet filter that GZIP-compresses response bodies for clients that accept it.
 * <p>
 * The downstream response is buffered in a {@code GenericResponseWrapper}. Once the chain
 * returns, the buffered bytes are written to the real response either GZIP-compressed
 * (with {@code Content-Encoding: gzip} and a matching {@code Content-Length}) or unchanged.
 * Compression is skipped when the buffered body is empty, when the request is a
 * server-side include, or when the client's {@code Accept-Encoding} does not allow gzip.
 * {@code Vary: Accept-Encoding} is added to every response.
 * <p>
 * Based on http://www.orionserver.com/tutorials/filters/5.html
 */
public class GZIPFilter2 extends GenericFilter {

    // Bound to this class (not GZIPFilter) so log output is attributed correctly.
    private static final Logger logger = Logger.getLogger(GZIPFilter2.class.getName());

    public GZIPFilter2() {}

    /**
     * Runs the rest of the chain against a buffering wrapper, then writes the buffered body
     * to {@code response}, GZIP-compressed when {@link #isCached}, {@link #isIncluded} and
     * {@link #accepts} all allow it. The response output stream is flushed and closed.
     *
     * @param request  the incoming request; must be an {@code HttpServletRequest}
     * @param response the outgoing response; must be an {@code HttpServletResponse}
     * @param chain    the remaining filter chain
     * @throws IOException      if writing the response or compressing the body fails
     * @throws ServletException if a downstream filter or servlet fails
     */
    public void doFilter(final ServletRequest request,
                         final ServletResponse response,
                         FilterChain chain)
                                throws IOException, ServletException {

        HttpServletResponse httpResponse = (HttpServletResponse) response;
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        // Buffer everything downstream writes so it can be compressed in one piece.
        GenericResponseWrapper wrapper = new GenericResponseWrapper(httpResponse);
        chain.doFilter(request, wrapper);
        // The body depends on Accept-Encoding, so caches must key on it;
        // see http://nagoya.apache.org/bugzilla/show_bug.cgi?id=2820
        httpResponse.addHeader("Vary", "Accept-Encoding");
        OutputStream out = response.getOutputStream();

        byte[] origBytes = wrapper.getData();
        if (!isCached(wrapper) && !isIncluded(request) && accepts(httpRequest, "gzip")) {
            byte[] compressed = gzip(origBytes);

            if (logger.isDebugEnabled()) {
                // Round-trip the compressed body so the log shows what the client will see.
                logger.debug("decompressed data...");
                logger.debug(new String(gunzip(compressed)));
            }

            httpResponse.addHeader("Content-Encoding", "gzip");
            response.setContentLength(compressed.length);
            out.write(compressed);

            if (logger.isDebugEnabled()) {
                logger.debug("GZIP filtered data totals...");
                logger.debug("Orig data size: "+origBytes.length+" bytes");
                logger.debug("GZIP data size: "+compressed.length+" bytes");
            }
        }
        else {
            out.write(origBytes);

            if (logger.isDebugEnabled()) {
                logger.debug("Bypassed GZIP filtering...");
                logger.debug("Orig data size: "+origBytes.length+" bytes");
            }
        }
        out.flush();
        response.flushBuffer();
        out.close();
    }

    /**
     * Returns true when this request is being dispatched as a server-side include.
     * Detection relies on the {@code javax.servlet.include.request_uri} request attribute.
     */
    protected boolean isIncluded(ServletRequest request) {
        String uri = (String) request.getAttribute("javax.servlet.include.request_uri");
        if (uri == null) {
            if (logger.isDebugEnabled()) logger.debug("request is unique (invoked by client-side request)");
            return false;
        }
        else {
            if (logger.isDebugEnabled()) logger.debug("request is included in another request (invoked by server-side request)");
            return true;
        }
    }

    /**
     * Returns true when the buffered response body is empty, so there is nothing to compress.
     * An empty body is treated as a "cached" response (presumably e.g. a 304 Not Modified;
     * NOTE(review): confirm that is the intended meaning).
     */
    protected boolean isCached(GenericResponseWrapper wrapper) {
        if (wrapper.getData().length > 0) {
            if (logger.isDebugEnabled()) logger.debug("non-cached response");
            return false;
        }
        else {
            if (logger.isDebugEnabled()) logger.debug("cached response (empty)");
            return true;
        }
    }

    /**
     * Returns true when the client's {@code Accept-Encoding} header(s) list the content-coding
     * {@code name} (or its {@code "x-"} alias, e.g. {@code x-gzip}) with a non-zero quality.
     * Coding names are matched case-insensitively. An entry such as {@code gzip;q=0}, which
     * explicitly refuses the coding, is not treated as acceptance. The wildcard {@code *} is
     * not honoured.
     *
     * @param request the client request
     * @param name    the content-coding to look for, e.g. {@code "gzip"}
     * @return whether the client accepts that coding
     */
    protected boolean accepts(HttpServletRequest request, String name) {
        boolean accepts = false;
        Enumeration values = request.getHeaders("Accept-Encoding");
        // getHeaders() may return null when the container disallows header access.
        while (values != null && !accepts && values.hasMoreElements()) {
            String headerValue = (String) values.nextElement();
            if (logger.isDebugEnabled()) logger.debug("current value for header 'Accept-Encoding': '"+headerValue+"'");
            accepts = listsAcceptableCoding(headerValue, name);
        }
        if (logger.isDebugEnabled()) logger.debug("client supports 'Accept-Encoding' '"+name+"': "+accepts);
        return accepts;
    }

    /**
     * Returns true when any value of {@code header} contains {@code value} as a substring.
     * <p>
     * A substring search is used (rather than equality) because Tomcat may return a
     * comma-delimited list such as {@code gzip,deflate,compress;q=0.9} as a single element;
     * see http://nagoya.apache.org/bugzilla/show_bug.cgi?id=9526.
     */
    protected boolean headerContains(HttpServletRequest request, String header, String value) {
        Enumeration accepted = request.getHeaders(header);
        if (accepted == null) {
            // Container does not allow access to header information.
            return false;
        }
        while (accepted.hasMoreElements()) {
            String headerValue = (String) accepted.nextElement();
            if (logger.isDebugEnabled()) logger.debug("current value for header '"+header+"': '"+headerValue+"'");
            if (headerValue.indexOf(value) != -1) {
                return true;
            }
        }
        return false;
    }

    // Reports whether one comma-separated Accept-Encoding value lists name/x-name with q > 0.
    private static boolean listsAcceptableCoding(String headerValue, String name) {
        String alias = "x-" + name;
        String[] entries = headerValue.split(",");
        for (int i = 0; i < entries.length; i++) {
            String entry = entries[i];
            int semi = entry.indexOf(';');
            String coding = (semi == -1 ? entry : entry.substring(0, semi)).trim();
            boolean named = coding.equalsIgnoreCase(name) || coding.equalsIgnoreCase(alias);
            if (named && (semi == -1 || hasNonZeroQuality(entry.substring(semi + 1)))) {
                return true;
            }
        }
        return false;
    }

    // Reports whether the ';'-separated parameters carry no "q=0" weight (absent q means 1).
    private static boolean hasNonZeroQuality(String params) {
        String[] parts = params.split(";");
        for (int i = 0; i < parts.length; i++) {
            String param = parts[i].trim();
            if (param.length() > 2 && param.substring(0, 2).equalsIgnoreCase("q=")) {
                try {
                    return Double.parseDouble(param.substring(2).trim()) > 0;
                } catch (NumberFormatException ignored) {
                    // Malformed weight: stay lenient, as the previous substring match was.
                    return true;
                }
            }
        }
        return true;
    }

    // GZIP-compresses data in memory; close() finishes the stream and writes the trailer.
    private static byte[] gzip(byte[] data) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        GZIPOutputStream gzout = new GZIPOutputStream(buffer);
        try {
            gzout.write(data);
        } finally {
            gzout.close();
        }
        return buffer.toByteArray();
    }

    // Decompresses GZIP data in memory (used only for debug logging).
    private static byte[] gunzip(byte[] data) throws IOException {
        GZIPInputStream gzin = new GZIPInputStream(new ByteArrayInputStream(data));
        try {
            ByteArrayOutputStream result = new ByteArrayOutputStream();
            byte[] chunk = new byte[1024];
            int n;
            while ((n = gzin.read(chunk)) != -1) {
                result.write(chunk, 0, n);
            }
            return result.toByteArray();
        } finally {
            gzin.close();
        }
    }
}
