#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <netdb.h>
#include <syslog.h>
#include <unistd.h>

#ifndef MAXHOSTNAMELEN
#define MAXHOSTNAMELEN HOST_NAME_MAX
#endif

#ifdef HEIMDAL
#include <gssapi.h>
#else
#include <gssapi/gssapi.h>
#include <gssapi/gssapi_generic.h>
#endif

#include <krb5.h>

/* Leading bytes of a raw NTLMSSP message ("NTLMSSP" followed by a NUL).
   main() compares decoded client tokens against it to recognise NTLM
   tokens that parseNegTokenInit() did not accept. */
static const unsigned char ntlmProtocol [] = {'N', 'T', 'L', 'M', 'S', 'S', 'P', 0};

/* Character-set translation table consulted by base64encode().
   NOTE(review): nothing in this source ever assigns to it, so as a
   zero-initialised static object every entry is 0 -- confirm whether it is
   meant to be filled in somewhere before base64encode() relies on it. */
static unsigned char os_toascii[256];
/* Base64 alphabet: index a 6-bit value (0..63) to get its output character. */
static const char basis_64[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/* Reverse base64 lookup indexed by an input byte: yields the 6-bit value of
   'A'-'Z', 'a'-'z', '0'-'9', '+' and '/', and 64 for every other byte
   (including '=' and NUL).  The decoders stop at the first byte mapping
   to a value above 63. */
static const unsigned char pr2six[256] =
{
    /* ASCII table */
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 64, 64, 63,
    52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64,
    64,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
    15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 64,
    64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
    41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64
};

/*
 * Buffer size base64encode() needs for `len` input bytes: four output
 * characters for every started group of three input bytes, plus one byte
 * for the terminating NUL.
 */
int base64encode_len(int len)
{
    int groups = (len + 2) / 3;

    return groups * 4 + 1;
}

/*
 * Base64-encode `len` bytes starting at `string` into `encoded`, using the
 * basis_64 alphabet and '=' padding.  `encoded` must provide at least
 * base64encode_len(len) bytes.
 *
 * The input is read as raw unsigned bytes, so binary data such as GSS-API
 * tokens (including bytes >= 0x80) is encoded correctly.
 *
 * Returns the number of bytes written to `encoded`, including the
 * terminating NUL.
 */
int base64encode(char *encoded, const char *string, int len)
{
    /* Unsigned view of the input: a plain char may be signed, and negative
       values would shift incorrectly / index out of range. */
    const unsigned char *in = (const unsigned char *) string;
    char *p = encoded;
    int i;

    /* Full 3-byte groups -> 4 characters each. */
    for (i = 0; i < len - 2; i += 3) {
	*p++ = basis_64[in[i] >> 2];
	*p++ = basis_64[((in[i] & 0x3) << 4) | (in[i + 1] >> 4)];
	*p++ = basis_64[((in[i + 1] & 0xF) << 2) | (in[i + 2] >> 6)];
	*p++ = basis_64[in[i + 2] & 0x3F];
    }

    /* Trailing 1 or 2 bytes -> 2 or 3 characters, padded to 4 with '='. */
    if (i < len) {
	*p++ = basis_64[in[i] >> 2];
	if (i == len - 1) {
	    *p++ = basis_64[(in[i] & 0x3) << 4];
	    *p++ = '=';
	}
	else {
	    *p++ = basis_64[((in[i] & 0x3) << 4) | (in[i + 1] >> 4)];
	    *p++ = basis_64[(in[i + 1] & 0xF) << 2];
	}
	*p++ = '=';
    }

    *p++ = '\0';
    return (int) (p - encoded);
}

/*
 * Buffer size base64decode() needs for `bufcoded`.
 *
 * Counts the leading run of bytes that pr2six maps to a 6-bit value (the
 * scan stops at '=', the NUL terminator or any other non-alphabet byte),
 * rounds that up to whole 4-character groups at 3 bytes per group, and
 * adds one byte for the NUL that base64decode() appends.
 */
int base64decode_len(const char *bufcoded)
{
    const unsigned char *cursor = (const unsigned char *) bufcoded;
    int symbols = 0;

    while (pr2six[cursor[symbols]] <= 63)
	symbols++;

    return ((symbols + 3) / 4) * 3 + 1;
}

/*
 * Decode the base64 text in `bufcoded` into `bufplain`.
 *
 * Decoding covers the leading run of alphabet characters; the first byte
 * that pr2six maps to 64 ('=', NUL, or anything else outside the alphabet)
 * ends the input.  `bufplain` must hold at least
 * base64decode_len(bufcoded) - 1 bytes.  No terminator is written.
 *
 * Returns the number of bytes written to `bufplain`.
 */
int base64decode_binary(unsigned char *bufplain,
			const char *bufcoded)
{
    const unsigned char *const start = (const unsigned char *) bufcoded;
    const unsigned char *src = start;
    unsigned char *dst = bufplain;
    int remaining = 0;
    int total;

    /* Length of the decodable prefix. */
    while (pr2six[start[remaining]] <= 63)
	remaining++;
    /* Output size if every group were complete; corrected at the end. */
    total = ((remaining + 3) / 4) * 3;

    /* Each group of four symbols (except the final 1..4) yields three bytes. */
    for (; remaining > 4; remaining -= 4, src += 4) {
	unsigned int a = pr2six[src[0]];
	unsigned int b = pr2six[src[1]];
	unsigned int c = pr2six[src[2]];
	unsigned int d = pr2six[src[3]];

	*dst++ = (unsigned char) (a << 2 | b >> 4);
	*dst++ = (unsigned char) (b << 4 | c >> 2);
	*dst++ = (unsigned char) (c << 6 | d);
    }

    /* Final 2..4 symbols yield 1..3 bytes.  A lone symbol carries fewer than
       eight bits and is malformed input, so it is ignored. */
    if (remaining >= 2)
	*dst++ = (unsigned char) (pr2six[src[0]] << 2 | pr2six[src[1]] >> 4);
    if (remaining >= 3)
	*dst++ = (unsigned char) (pr2six[src[1]] << 4 | pr2six[src[2]] >> 2);
    if (remaining >= 4)
	*dst++ = (unsigned char) (pr2six[src[2]] << 6 | pr2six[src[3]]);

    /* Remove the bytes the rounded-up estimate counted for a short group. */
    return total - ((4 - remaining) & 3);
}

/*
 * Decode `bufcoded` into `bufplain` and NUL-terminate the result.
 * `bufplain` must hold base64decode_len(bufcoded) bytes.
 * Returns the number of decoded bytes, not counting the added NUL.
 */
int base64decode(char *bufplain, const char *bufcoded)
{
    int decoded = base64decode_binary((unsigned char *) bufplain, bufcoded);

    bufplain[decoded] = '\0';
    return decoded;
}

/*
 * Build a readable description of a GSS-API mechanism (minor) status code.
 *
 * The result has the form "<prefix>: <msg>:<msg>:..." with one entry per
 * message returned by gss_display_status(GSS_C_MECH_CODE).  Messages that
 * would not fit in the 1024-byte scratch buffer are skipped.
 *
 * Returns a strdup()ed string owned by the caller, or NULL if strdup fails.
 */
static const char *
get_gss_error(OM_uint32 error_status, const char *prefix)
{
    OM_uint32 maj_stat, min_stat;
    OM_uint32 msg_ctx = 0;
    char buf[1024];
    size_t len;

    snprintf(buf, sizeof(buf), "%s: ", prefix);
    len = strlen(buf);
    do {
	gss_buffer_desc status_string = GSS_C_EMPTY_BUFFER;

	maj_stat = gss_display_status(&min_stat,
				      error_status,
				      GSS_C_MECH_CODE,
				      GSS_C_NO_OID,
				      &msg_ctx,
				      &status_string);
	/* On failure status_string holds nothing usable. */
	if (GSS_ERROR(maj_stat))
	    break;

	/* Room needed: the text, the ':' separator and the final NUL. */
	if (sizeof(buf) > len + status_string.length + 1) {
	    /* A gss_buffer_desc is counted, not necessarily NUL-terminated. */
	    snprintf(buf + len, sizeof(buf) - len, "%.*s:",
		     (int) status_string.length,
		     (const char *) status_string.value);
	    /* Count the separator too so the next message follows it. */
	    len += status_string.length + 1;
	}
	gss_release_buffer(&min_stat, &status_string);
    } while (msg_ctx != 0);

    return strdup(buf);
}

/*
 * Look up the name of this host: take the local name from gethostname(),
 * resolve it with getaddrinfo(), and map the first returned address back
 * to a name with getnameinfo().
 *
 * Returns a strdup()ed string owned by the caller, or NULL on failure
 * (the reason is logged with syslog at LOG_ERR).
 */
char *gethost_name(void)
{
    char hostname[MAXHOSTNAMELEN];
    struct addrinfo *hres;
    int rc;

    rc = gethostname(hostname, sizeof(hostname));
    if (rc)
    {
	syslog(LOG_ERR, "error while resolving hostname");
	return NULL;
    }
    /* gethostname() need not NUL-terminate a truncated name. */
    hostname[sizeof(hostname) - 1] = '\0';

    rc = getaddrinfo(hostname, NULL, NULL, &hres);
    if (rc != 0) {
        syslog(LOG_ERR, "error while resolving hostname with getaddrinfo: %s", gai_strerror(rc));
        return NULL;
    }

    rc = getnameinfo(hres->ai_addr, hres->ai_addrlen, hostname, sizeof(hostname), NULL, 0, 0);
    /* The address list is no longer needed on either path. */
    freeaddrinfo(hres);
    if (rc != 0) {
        syslog(LOG_ERR, "error while resolving ip address with getnameinfo: %s", gai_strerror(rc));
        return NULL;
    }

    /* Last valid index -- writing hostname[MAXHOSTNAMELEN] would overflow. */
    hostname[sizeof(hostname) - 1] = '\0';
    return strdup(hostname);
}

int main(int argc, const char **argv)
{
    char buf[6400+1];
    int length;
    char *c;
    static int err=0;
    gss_buffer_desc input_token = GSS_C_EMPTY_BUFFER;
    gss_buffer_desc output_token = GSS_C_EMPTY_BUFFER;
    unsigned char        *kerberosToken       = NULL;
    size_t                kerberosTokenLength = 0;
    unsigned char        *spnegoToken         = NULL ;
    size_t                spnegoTokenLength   = 0;
    int rc;
    char *service_name,*host_name;
    OM_uint32 major_status, minor_status;
    gss_name_t my_gss_name = GSS_C_NO_NAME;
    gss_cred_id_t my_gss_creds = GSS_C_NO_CREDENTIAL;
    gss_ctx_id_t gss_context = GSS_C_NO_CONTEXT;
    gss_cred_id_t delegated_cred = GSS_C_NO_CREDENTIAL;
    gss_name_t client_name = GSS_C_NO_NAME;
    int                   ret_flags=0,spnego_flag=0;
    gss_buffer_desc service = GSS_C_EMPTY_BUFFER;

    while (1) {
	if (fgets(buf, sizeof(buf)-1, stdin) == NULL) {
	    if (ferror(stdin)) {
		syslog(LOG_DEBUG, "fgets() failed! dying..... errno=%d (%s)\n", ferror(stdin),
		       strerror(ferror(stdin)));

		exit(1);    /* BIIG buffer */
	    }
	    exit(0);
	}

	c=memchr(buf,'\n',sizeof(buf)-1);
	if (c) {
	    *c = '\0';
	    length = c-buf;
	} else {
	    err = 1;
	}
	if (err) {
	    syslog(LOG_DEBUG, "Oversized message\n");
	    fprintf(stderr, "NA Oversized message\n");
	    err = 0;
	    continue;
	}

	syslog(LOG_DEBUG, "Got '%s' from squid (length: %d).\n",buf,length);

	if (buf[0] == '\0') {
	    syslog(LOG_DEBUG, "Invalid Request\n");
	    fprintf(stderr, "NA Invalid Request\n");
	    continue;
	}

	if (strlen(buf) < 2) {
	    syslog(LOG_DEBUG, "SPNEGO query [%s] invalid", buf);
	    (stdout, "NA SPNEGO query invalid\n");
	    continue;
	}

	if (strncmp(buf, "YR", 2) == 0) {

	} else if (strncmp(buf, "KK", 2) == 0) {

	} else {
	    syslog(LOG_DEBUG, "SPNEGO query [%s] invalid", buf);
	    fprintf(stdout, "NA SPNEGO query invalid\n");
	    continue;
	}

	if ( (strlen(buf) == 2)) {

	    /* no client data, get the negTokenInit offering
	       mechanisms */

	    syslog(LOG_DEBUG, "SPNEGO query [%s] invalid", buf);
	    fprintf(stdout, "NA SPNEGO query invalid\n");
	    continue;
	}

	if (strlen(buf) <= 3) {
	    syslog(LOG_DEBUG, "GSS-SPNEGO query [%s] invalid\n", buf);
	    fprintf(stdout, "NA GSS-SPNEGO query invalid\n");
	    continue;
	}
        
	input_token.length = base64decode_len(buf+3);
	input_token.value = malloc(input_token.length);

	input_token.length = base64decode(input_token.value, buf+3);
 
	if (( rc=parseNegTokenInit (input_token.value,
				    input_token.length,
				    &kerberosToken,
				    &kerberosTokenLength))!=0 ){
	    syslog(LOG_DEBUG, "parseNegTokenInit failed with rc=%d",rc);
        
       
	    if ( rc < 100 || rc > 199 ) {
		syslog(LOG_DEBUG, "GSS-SPNEGO query [%s] invalid\n", buf);
		fprintf(stdout, "NA GSS-SPNEGO query invalid\n");
		continue;
	    } 
	    if ((input_token.length >= sizeof ntlmProtocol + 1) &&
		(!memcmp (input_token.value, ntlmProtocol, sizeof ntlmProtocol))) {
		syslog(LOG_DEBUG, "received type %d NTLM token", (int) *((unsigned char *)input_token.value + sizeof ntlmProtocol));
		fprintf(stdout, "NA received type %d NTLM token\n",(int) *((unsigned char *)input_token.value + sizeof ntlmProtocol));
		continue;
	    } 
	    spnego_flag=0;
	} else { 
	    input_token.length=kerberosTokenLength;
	    input_token.value = malloc(input_token.length);
	    if (input_token.value == NULL) {
		syslog(LOG_DEBUG, "Not enough memory");
		fprintf(stdout, "NA Not enough memory\n");
		continue;
	    }
	    memcpy(input_token.value,kerberosToken,input_token.length);
	    spnego_flag=1;
	}
     
	host_name=gethost_name();
	if ( !host_name ) {
	    fprintf(stdout,"NA Hostname not found");
	    continue;
	}
	service_name=strdup("HTTP");
	service.value = malloc(strlen(service_name)+strlen(host_name)+2);
	snprintf(service.value,strlen(service_name)+strlen(host_name)+2,"%s@%s",service_name,host_name);
	service.length = strlen((unsigned char *)service.value);
#if HEIMDAL
	major_status = gss_import_name(&minor_status, &service,
				       GSS_C_NT_HOSTBASED_SERVICE, &my_gss_name);
#else
	major_status = gss_import_name(&minor_status, &service,
				       gss_nt_service_name, &my_gss_name);
#endif

	if (GSS_ERROR(major_status)) {
	    syslog(LOG_DEBUG, "%s Used service principal: %s", get_gss_error(minor_status,"gss_import_name() failed for service principal"),(unsigned char *)service.value);
	    fprintf(stdout, "NA %s\n",get_gss_error( minor_status,"gss_import_name() failed for service principal"));
	    continue;
	}
	major_status = gss_acquire_cred(&minor_status, my_gss_name, GSS_C_INDEFINITE,
					GSS_C_NO_OID_SET, GSS_C_ACCEPT, &my_gss_creds,
					NULL, NULL);
	if (GSS_ERROR(major_status)) {
	    syslog(LOG_DEBUG, "%s Used service principal: %s", get_gss_error(minor_status,"gss_acquire_cred() failed"),(unsigned char *)service.value);
	    fprintf(stdout, "NA %s\n",get_gss_error(minor_status,"gss_acquire_cred() failed"));
	    continue;
	}

	major_status = gss_accept_sec_context(&minor_status,
					      &gss_context,
					      my_gss_creds,
					      &input_token,
					      GSS_C_NO_CHANNEL_BINDINGS,
					      &client_name,
					      NULL,
					      &output_token,
					      &ret_flags,
					      NULL,
					      &delegated_cred);


	if (output_token.length) {
	    char *token = NULL;
	    if (spnego_flag) {
		if ((rc=makeNegTokenTarg (output_token.value,
					  output_token.length,
					  &spnegoToken,
					  &spnegoTokenLength))!=0 ) {
		    syslog(LOG_DEBUG, "makeNegTokenTarg failed with rc=%d",rc);
		    fprintf(stdout, "NA makeNegTokenTarg failed with rc=%d\n",rc);
		    continue;
		}
	    } else {
		spnegoToken = output_token.value;
		spnegoTokenLength = output_token.length;
	    }
	    token = malloc(base64encode_len(spnegoTokenLength));
	    if (token == NULL) {
		syslog(LOG_DEBUG, "Not enough memory");
		gss_release_buffer(&minor_status, &output_token);
		fprintf(stdout, "NA Not enough memory\n");
		continue;
	    }

	    base64encode(token, spnegoToken, spnegoTokenLength);
	    gss_release_buffer(&minor_status, &output_token);

	    if (GSS_ERROR(major_status)) {
		syslog(LOG_DEBUG, "%s Used service principal: %s", get_gss_error( minor_status,"gss_accept_sec_context() failed"),(unsigned char *)service.value);
		fprintf(stdout, "NA %s\n",get_gss_error( minor_status,"gss_accept_sec_context() failed"));
		continue;
	    }
	    if (major_status & GSS_S_CONTINUE_NEEDED) {
		syslog(LOG_DEBUG, "continuation needed");
		fprintf(stdout, "TT %s\n",token);
		continue;
	    }
	    major_status = gss_display_name(&minor_status, client_name, &output_token,
					    NULL);

	    gss_release_name(&minor_status, &client_name);
	    if (GSS_ERROR(major_status)) {
		syslog(LOG_DEBUG, "%s", get_gss_error(minor_status,"gss_display_name() failed"));
		fprintf(stdout, "NA %s\n",get_gss_error(minor_status,"gss_display_name() failed"));
		continue;
	    }
	    fprintf(stdout, "AF %s %s\n",token,output_token.value);
	    syslog(LOG_DEBUG, "AF %s %s\n",token,output_token.value); 
	    continue;
	} else {
	    fprintf(stdout, "NA Invalid token\n");
	    continue;
	}
    }
}
