package org.apache.log4j.contrib;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.StringTokenizer;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import sun.misc.BASE64Decoder;

/** <p>The AuthenticatingServlet is a servlet which requires
 * proper HTTP authentication.</p>
 *
 * @author <a href="mailto:joe@ispsoft.de">Jochen Wiedmann</a>
 */
public abstract class AuthenticatingServlet extends HttpServlet {
  protected boolean isAuthenticated(String pUsername, String pPassword) {
    return "AuthUser".equals(pUsername)  &&  "AuthPassword".equals(pPassword);
  }

  protected String getRealm() {
    return "Admin Area";
  }

  /** <p>Creates a new instance of AuthenticatingServlet.</p>
   */
  public AuthenticatingServlet() {
  }

  protected boolean checkAuthentication(HttpServletRequest pRequest) {
    // Verify whether there is a proper "Authorization" header
    String authorization = pRequest.getHeader("Authorization");
    if (authorization == null) {
      return false;
    }
    // Verify whether the authorization header has the format "Basic details";
    StringTokenizer st = new StringTokenizer(authorization);
    if (!st.hasMoreTokens()) {
      return false;
    }
    if (!"Basic".equalsIgnoreCase(st.nextToken())) {
      return false;
    }
    if (!st.hasMoreTokens()) {
      return false;
    }
    String details = st.nextToken();
    // Verify whether the details are base64 encoded
    BASE64Decoder decoder = new BASE64Decoder();
    byte[] buffer;
    try {
      buffer = decoder.decodeBuffer(details);
    } catch (IOException e) {
      return false;
    }
    String decodedDetails;
    try {
      decodedDetails = new String(buffer, "UTF8");
    } catch (UnsupportedEncodingException e) {
      return false;
    }
    // Verify whether the details have the format username:password
    st = new StringTokenizer(decodedDetails, ":");
    if (!st.hasMoreTokens()) {
      return false;
    }
    String username = st.nextToken();
    if (!st.hasMoreTokens()) {
      return false;
    }
    String password = st.nextToken();
    if (st.hasMoreTokens()) {
      return false;
    }
    return isAuthenticated(username, password);
  }

  public void rejectUnauthenticated(HttpServletResponse pResponse) {
    pResponse.setHeader("WWW-Authenticate", "Basic realm=\"" + getRealm() + "\"");
    pResponse.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
  }

  public void doGet(HttpServletRequest pRequest, HttpServletResponse pResponse) throws ServletException, IOException {
    if (!checkAuthentication(pRequest)) {
      rejectUnauthenticated(pResponse);
    } else {
      run(pRequest, pResponse, false);
    }
  }

  public void doPost(HttpServletRequest pRequest, HttpServletResponse pResponse) throws ServletException, IOException {
    if (!checkAuthentication(pRequest)) {
      rejectUnauthenticated(pResponse);
    } else {
      run(pRequest, pResponse, true);
    }
  }

  protected abstract void run(HttpServletRequest pRequest, HttpServletResponse pResponse,
                                boolean usePost) throws ServletException, IOException;
}

