/**
* (C) 2004 Jay Dolan. jasonthomasdolan@yahoo.com.
*
*  bpm.c: Calculate the beats per minute of an MP3 file.  Or try.
*
*  This software is provided free of charge, and with no warranty.
*  For more information, see the GNU General Public License.
*/

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

#include <unistd.h>

/* Decoded audio is consumed as an ordinary stdio stream (see getWav). */
typedef FILE Wav;

/*
 * Analysis parameters.  main() may override g_startTime and g_endTime
 * from the command line; the rest are fixed defaults.
 */
int g_sampleRate = 44100;     /* expected sample rate of the decoded input, in Hz */
int g_channels = 2;           /* expected number of interleaved input channels (2 = stereo) */
int g_downsampleFactor = 10;  /* analysis keeps 1/n of the input samples */
int g_startTime = 0;          /* begin analysis n seconds into the track */
int g_endTime = 10;           /* stop analysis n seconds into the track */
int g_minBpm = 120;           /* lowest tempo tried, in beats per minute */
int g_maxBpm = 145;           /* highest tempo tried, in beats per minute */

/**
*  Writes p_in into p_out as one single-quoted POSIX shell word, so that no character
*  in it (spaces, $, `, ", \ ...) is interpreted by the shell.  Embedded single quotes
*  are written as '\''.  Returns 0 on success, -1 if p_outSize is too small.
*/
static int shellQuote(const char *p_in, char *p_out, size_t p_outSize){

	size_t len = 0;
	const char *c;

	if(p_outSize < 3)  //need room for at least '' plus the terminator
		return -1;

	p_out[len++] = '\'';

	for(c = p_in; *c != '\0'; c++){
		if(*c == '\''){  //close the quote, emit an escaped quote, reopen: '\''
			if(len + 4 >= p_outSize)
				return -1;
			p_out[len++] = '\'';
			p_out[len++] = '\\';
			p_out[len++] = '\'';
			p_out[len++] = '\'';
		}
		else{
			if(len + 1 >= p_outSize)
				return -1;
			p_out[len++] = *c;
		}
	}

	if(len + 2 > p_outSize)  //closing quote plus terminator
		return -1;

	p_out[len++] = '\'';
	p_out[len] = '\0';

	return 0;
}

/**
*  Returns a read stream carrying the output of "mpg123 -s" for p_fileName, i.e. the
*  decoded audio as raw (headerless) 16 bit linear PCM in host byte order (interleaved
*  when the source is stereo).  The stream comes from popen, so it is a pipe: it cannot
*  be seeked and must be closed with pclose.  Exits with status 1 if p_fileName is NULL,
*  cannot be opened, is too long, or the command cannot be started.
*/
Wav *getWav(char *p_fileName){

	char quoted[1024], command[1100];
	Wav *wav;
	int n;

	if(p_fileName == NULL){
		fprintf(stderr, "No file specified\n");
		exit(1);
	}

	if((wav = fopen(p_fileName, "r")) == NULL){  //ensure file exists and is readable
		fprintf(stderr, "Unable to open file \"%s\"\n", p_fileName);
		exit(1);
	}

	fclose(wav);

	//quote the name so an unusual or hostile file name cannot inject shell syntax
	if(shellQuote(p_fileName, quoted, sizeof quoted) != 0){
		fprintf(stderr, "File name too long \"%s\"\n", p_fileName);
		exit(1);
	}

	n = snprintf(command, sizeof command, "mpg123 -s %s 2> /dev/null", quoted);
	if(n < 0 || (size_t)n >= sizeof command){  //never run a truncated command
		fprintf(stderr, "File name too long \"%s\"\n", p_fileName);
		exit(1);
	}

	if((wav = popen(command, "r")) == NULL){  //open pipe to command
		fprintf(stderr, "Unable to execute command \"%s\"\n", command);
		exit(1);
	}

	return wav;
}

/**
*  Returns the mean absolute amplitude difference per sample between p_audio and
*  p_audio shifted by p_shift samples, measured over the p_samples - p_shift samples
*  where the two overlap.  A small value suggests p_shift is close to a whole number
*  of beat periods (a heuristic, not a guarantee).
*
*  Returns INT_MAX, meaning "no usable comparison", when p_audio is NULL, p_shift is
*  negative, or p_shift >= p_samples leaves no overlap.  Callers that look for the
*  minimum therefore never select such a shift.
*/
int phaseShift(short *p_audio, int p_samples, int p_shift){

	long long diff = 0;  //wide accumulator: up to 65535 per sample overflows int quickly
	int i, overlap;

	if(p_audio == NULL || p_shift < 0 || p_shift >= p_samples)
		return INT_MAX;

	overlap = p_samples - p_shift;

	for(i = 0; i < overlap; i++)
		diff += abs(p_audio[i] - p_audio[i + p_shift]);

	return (int)(diff / overlap);  //mean is at most 65535, so it fits in an int
}

/**
*  Skips p_frames interleaved frames (g_channels 16 bit samples each) of p_wav by
*  reading and discarding them.  fseek is not used because getWav returns a pipe,
*  and seeking a pipe fails.  Returns 1 on success, 0 if EOF or a read error came first.
*/
static int skipFrames(Wav *p_wav, long long p_frames){

	short discard[4096];
	const size_t chunk = sizeof discard / sizeof discard[0];
	int channels = g_channels > 0 ? g_channels : 1;
	long long remaining = p_frames * channels;  //individual samples still to skip

	while(remaining > 0){
		size_t want = remaining < (long long)chunk ? (size_t)remaining : chunk;

		if(fread(discard, sizeof(short), want, p_wav) < want)
			return 0;

		remaining -= (long long)want;
	}

	return 1;
}

/**
*  Reads one interleaved frame (g_channels samples) from p_wav and stores the mean of
*  its channels in *p_out.  Returns 1 on success, 0 on EOF or read error.
*/
static int readFrameMean(Wav *p_wav, short *p_out){

	int channels = g_channels > 0 ? g_channels : 1;
	short sample;
	long sum = 0;
	int c;

	for(c = 0; c < channels; c++){
		if(fread(&sample, sizeof(short), 1, p_wav) != 1)
			return 0;
		sum += sample;
	}

	*p_out = (short)(sum / channels);
	return 1;
}

/**
*  Returns the beats per minute of p_wav as a float, using phase shifting, or 0 if no
*  tempo could be evaluated.
*
*  The audio between g_startTime and g_endTime is read into a mono buffer.  Each
*  buffer sample is the channel mean of one input frame, and only one frame in
*  g_downsampleFactor is kept, for performance.  For every tempo from g_minBpm to
*  g_maxBpm in 0.1 bpm steps, the buffer is compared with itself shifted by four beats
*  at that tempo.  The tempo with the smallest mean difference is returned.
*
*  g_sampleRate is taken to be frames per second per channel.  p_wav is read
*  sequentially from its current position, so it may be a pipe.
*/
float getBpm(Wav *p_wav){

	int factor = g_downsampleFactor > 0 ? g_downsampleFactor : 1;
	int sampleRate = g_sampleRate / factor;  //sample rate of the downsampled buffer
	long long wanted = ((long long)g_endTime - g_startTime) * sampleRate;
	int i, samples, shift, diff, tenths, bestDiff = INT_MAX;
	float bpm = 0;
	short *audio;

	if(wanted <= 0 || wanted > INT_MAX){
		fprintf(stderr, "Invalid analysis range\n");
		return 0;
	}
	samples = (int)wanted;

	audio = malloc(sizeof *audio * (size_t)samples);
	if(audio == NULL){
		fprintf(stderr, "Unable to allocate audio buffer\n");
		exit(1);
	}

	i = 0;
	if(skipFrames(p_wav, (long long)g_startTime * g_sampleRate)){  //advance to g_startTime
		for(; i < samples; i++){
			//keep one frame out of every `factor`, discarding the ones in between
			if(i > 0 && !skipFrames(p_wav, factor - 1))
				break;
			if(!readFrameMean(p_wav, &audio[i]))
				break;
		}
	}

	if(i < samples){  //allow a smaller set of samples than desired, but warn
		fprintf(stderr, "Warning: EOF reached before requested samples\n");
		samples = i;
	}

	//integer tenths avoid the drift of repeatedly adding 0.1 to a float
	for(tenths = g_minBpm * 10; tenths <= g_maxBpm * 10; tenths++){

		float k = tenths / 10.0f;

		if(tenths <= 0)  //non-positive tempo has no beat period
			continue;

		shift = (int)(sampleRate * 240.0 / k + 0.5);  //samples spanned by 4 beats at k bpm

		if(shift <= 0 || shift >= samples)  //shifted copy would not overlap the buffer
			continue;

		diff = phaseShift(audio, samples, shift);

		if(diff < bestDiff){  //keep the best fit over the whole range, not a local dip
			bestDiff = diff;
			bpm = k;
		}
	}

	free(audio);

	return bpm;
}

/**
*  Prints the command-line synopsis to stderr and exits with status 1.
*/
static void usage(const char *p_program){
	fprintf(stderr, "Usage: %s [-v] [-s seconds] [-e seconds] file\n", p_program);
	exit(1);
}

/**
*  Parses p_text as a non-negative whole number of seconds into *p_out.
*  Returns 1 on success, 0 if p_text is not a plain decimal number or is out of range.
*  The upper bound leaves room for main's default end time of start + 10.
*/
static int parseSeconds(const char *p_text, int *p_out){

	char *end;
	long value;

	errno = 0;
	value = strtol(p_text, &end, 10);

	if(end == p_text || *end != '\0' || errno == ERANGE || value < 0 || value > INT_MAX - 10)
		return 0;

	*p_out = (int)value;
	return 1;
}

/**
*  Entry point: bpm [-v] [-s seconds] [-e seconds] file
*
*  -v prints the version and exits.  -s sets the analysis start time.  -e sets the
*  end time, which defaults to start + 10 seconds.  Prints the file's BPM to stdout
*  and returns 0.  Invalid arguments print the usage text and exit with status 1.
*/
int main(int argc, char *args[]){

	const char *program = argc > 0 ? args[0] : "bpm";
	int opt, endGiven = 0;
	Wav *wav;

	//a single getopt loop: each getopt call consumes the next option
	while((opt = getopt(argc, args, "vs:e:")) != -1){
		switch(opt){
			case 'v':  //print version and exit
				printf("bpm 0.1, jasonthomasdolan@yahoo.com\n");
				return 0;
			case 's':
				if(!parseSeconds(optarg, &g_startTime))
					usage(program);
				break;
			case 'e':
				if(!parseSeconds(optarg, &g_endTime))
					usage(program);
				endGiven = 1;
				break;
			default:  //unknown option or missing option argument
				usage(program);
				break;
		}
	}

	if(!endGiven)  //default window is 10 seconds from the start time
		g_endTime = g_startTime + 10;

	if(g_startTime >= g_endTime || optind >= argc)  //invalid range or no file given
		usage(program);

	wav = getWav(args[optind]);  //decode the file through mpg123

	printf("%f\n", getBpm(wav));  //analyze the decoded audio

	//close pipe; the status is ignored because mpg123 may be cut off when we stop reading early
	pclose(wav);

	return 0;
}



