#!/usr/bin/env python

# Import(s)
import os
import vtk
import numpy as np
from shutil import copy


def get_dimensions(mha_file):
    """get_dimensions(mha_file) -> Get dimensions of the image from the MHA
    file.

    Scans the header for the ``DimSize`` field and returns its values as a
    list of ints, e.g. [224, 208, 208].

    Lines without an '=' (such as blank lines) are skipped, and values may be
    separated by any amount of whitespace.

    Raises ValueError if no ``DimSize`` field is found.
    """
    with open(mha_file, 'r') as header:
        for line in header:
            # partition never fails, unlike unpacking line.split('=')
            field, sep, value = line.partition('=')
            if sep and field.strip() == 'DimSize':
                return [int(v) for v in value.split()]
    raise ValueError('Unable to find DimSize')

def get_spacing(mha_file):
    """get_spacing(mha_file) -> Get spacing of the image from the MHA file.

    Scans the header for the ``ElementSpacing`` field and returns its values
    as a list of floats, e.g. [0.0824, 0.0837, 0.0757].

    Lines without an '=' (such as blank lines) are skipped, and values may be
    separated by any amount of whitespace.

    Raises ValueError if no ``ElementSpacing`` field is found.
    """
    with open(mha_file, 'r') as header:
        for line in header:
            # partition never fails, unlike unpacking line.split('=')
            field, sep, value = line.partition('=')
            if sep and field.strip() == 'ElementSpacing':
                return [float(v) for v in value.split()]
    raise ValueError('Unable to find ElementSpacing')

def get_rest(mha_file, segmentation):
    """get_rest(mha_file, segmentation) -> Validate the image header.

    Scans the header line by line until the ``ElementType`` field:

    - If ``segmentation`` is false and an ``Offset`` field is seen first, a
      warning is printed. The offset value itself is not parsed or used, so
      non-integer offsets no longer cause a crash.
    - ``ElementType`` must be ``MET_UCHAR`` (one unsigned byte per voxel).

    Lines without an '=' are skipped.

    Returns True when ElementType is MET_UCHAR.
    Raises ValueError if ElementType is not MET_UCHAR or is missing.
    """
    with open(mha_file, 'r') as header:
        for line in header:
            field, sep, value = line.partition('=')
            if not sep:
                continue
            field = field.strip()

            if not segmentation and field == 'Offset':
                print('')
                print('-->> WARNING: Offset of the volume was found and might be [0,0,0]. In any case CHECK IT, MIGHT BE DIFFERENT!')
                print('')

            if field == 'ElementType':
                if value.split() != ['MET_UCHAR']:
                    raise ValueError('ElementType must be in an unsigned character type to be read by the program!')
                return True
    raise ValueError('Unable to find ElementType')

def split_raw(input_, output, dimensions):
    """split_raw(input_, output, dimensions) -> Write each slice of a 3D raw
    volume to its own file.

    ``input_`` is read sequentially in chunks of dimensions[0]*dimensions[1]
    bytes, which assumes one byte per voxel. Chunk ``i`` is written to
    ``output + '.' + str(i)`` for i = 0 .. dimensions[2]-1.

    If ``input_`` is shorter than expected, the trailing slice files are
    short or empty rather than an error being raised.
    """
    bytes_per_slice = dimensions[0] * dimensions[1]
    with open(input_, 'rb') as volume:
        slice_count = dimensions[2]
        for index in range(slice_count):
            chunk = volume.read(bytes_per_slice)
            slice_path = output + '.' + str(index)
            with open(slice_path, 'wb') as destination:
                destination.write(chunk)


# Class definitions
class UltraSound3D:
    """A single 3D ultrasound frame of one patient.

    Constructing an instance has file-system side effects. Every loose
    ``<patient>fr<frame>.raw`` file in the data folder is moved, together
    with its header, into a ``<patient>/`` sub-folder. It is then split into
    per-slice files under ``<patient>/<patient>fr<frame>/``. After that, the
    geometry of the requested frame is read from its header.
    """

    def __init__(self, patient_number, frame_number, segmentation, init):
        """UltraSound3D(patient_number, frame_number, segmentation, init) ->
        Create an UltraSound3D operating on the given patient frame.

        patient_number, frame_number -- identify the frame. Both are passed
            through str() and combined into the stem '<patient>fr<frame>'.
        segmentation -- if true, use the 'Full_Segmentation_ED_frames'
            sub-folder and '.mhd' headers; otherwise use '.mha' headers.
        init -- if true, use './Examples' as the data folder; otherwise use
            './3D_Ultrasound'. Both are relative to the current directory.

        ValueError raised while parsing headers or .raw file names
        propagates. The caller's working directory is restored in every case.
        """
        main_folder = os.getcwd()
        data_root = self._data_root(segmentation, init)

        # The reorganisation uses paths relative to the data folder, so it
        # runs with that folder as cwd. The finally restores the caller's cwd
        # even if a file fails to parse.
        os.chdir(data_root)
        try:
            self._organise_raw_files(os.getcwd(), segmentation)
        finally:
            os.chdir(main_folder)

        # Set patient and frame number
        self.patient_number = str(patient_number)
        self.patient_root = os.path.join(data_root, self.patient_number)
        self.frame_number = str(frame_number)
        self.stem = self.patient_number + 'fr' + self.frame_number

        # Header of the current patient frame
        header_ext = '.mhd' if segmentation else '.mha'
        mha = self.patient_root + '/' + self.stem + header_ext

        self.dimensions = get_dimensions(mha)
        # The header's Offset is not used; the origin is fixed here.
        self.origin = [0, 0, 0]
        self.physical_spacing = get_spacing(mha)

        # Physical centre of the voxel grid: origin + spacing * (dims - 1) / 2
        x_max = np.array(self.dimensions, dtype=float) - 1
        origin = np.array(self.origin)
        spacing = np.array(self.physical_spacing)
        self.centre = origin + spacing * 0.5 * x_max

    @staticmethod
    def _data_root(segmentation, init):
        """Return the data folder (relative to the cwd) for the given flags."""
        base = './Examples' if init else './3D_Ultrasound'
        if segmentation:
            return base + '/Full_Segmentation_ED_frames'
        return base

    @staticmethod
    def _organise_raw_files(data_folder, segmentation):
        """Move each '<patient>fr<frame>.raw' (plus header) into '<patient>/'
        and split it into slices under '<patient>/<patient>fr<frame>/'.

        Must be called with ``data_folder`` as the current working directory.
        The cwd is back at ``data_folder`` after each file, even on error.
        """
        header_ext = '.mhd' if segmentation else '.mha'

        # Stems of the .raw files in the data folder, e.g. '20fr10'
        raw_files = []
        for filename in os.listdir('.'):
            root, ext = os.path.splitext(filename)
            if ext == '.raw':
                raw_files.append(root)

        for name in raw_files:
            patient, frame = name.split('fr')  # '20fr10' -> ('20', '10')

            # Make directories for patients
            try:
                os.mkdir(patient)
            except OSError:
                print('Patient %s folder already exists' % patient)

            # Move the volume and its header, then slice it into '<name>/'
            header = name + header_ext
            try:
                copy(name + '.raw', './' + patient)
                os.remove(name + '.raw')
                copy(header, './' + patient)
                os.remove(header)

                os.chdir('./' + patient)
                try:
                    os.mkdir(name)

                    print('  --> Slicing frame %s! Please wait...' % frame)
                    print('')

                    get_rest(header, segmentation)
                    dim = get_dimensions(header)

                    output = './' + name + '/' + name
                    input_ = './' + name + '.raw'
                    split_raw(input_, output, dim)
                finally:
                    # Previously skipped on error, leaving cwd in the patient folder
                    os.chdir(data_folder)
            except OSError:
                print('Patient %s (and respective frame) already exists' % patient)

    def update(self):
        """update() -> (Re)create the read-only memory map ``self.im`` of
        this frame's .raw volume."""
        # The .raw file holds dimensions[2] slices of dimensions[0]*dimensions[1]
        # bytes (see split_raw). x is the fastest-varying axis, matching the
        # extent order given to vtkImageReader2. numpy memmaps are C-ordered
        # (last axis fastest), so the shape is the reversed dimensions,
        # (z, y, x).
        path = self.raw_frame_path()
        self.im = np.memmap(path, dtype=np.uint8, mode='r', offset=0,
                            shape=tuple(self.dimensions[::-1]))

    def split_frame_path(self):
        """split_frame_path() -> Return the filename prefix of this frame's
        slice files ('<patient_root>/<stem>/<stem>').

        vtkImageReader2 appends '.<i>' to this prefix to read slice i.
        """
        filename = self.stem + '/' + self.stem
        return os.path.join(self.patient_root, filename)

    def raw_frame_path(self):
        """raw_frame_path() -> Return the full path of this frame's .RAW
        file. Its dimensions are ``self.dimensions``."""
        filename = self.patient_number + 'fr' + self.frame_number + '.raw'
        return os.path.join(self.patient_root, filename)

    def to_vtkImageReader2(self):
        """to_vtkImageReader2() -> Read this frame's slice files into an
        updated vtkImageReader2 (unsigned char, header spacing, origin
        ``self.origin``)."""
        ir = vtk.vtkImageReader2()
        ir.SetFilePrefix(self.split_frame_path())
        ir.SetDataExtent(0, self.dimensions[0] - 1,
                         0, self.dimensions[1] - 1,
                         0, self.dimensions[2] - 1)
        # Voxel size from ElementSpacing, so geometry is in physical units
        # (presumably cm, per the original author -- confirm against the data)
        ir.SetDataSpacing(*self.physical_spacing)
        ir.SetDataScalarTypeToUnsignedChar()
        ir.SetDataOrigin(*self.origin)
        ir.Update()

        return ir

    def __getitem__(self, key):
        """__getitem__(key) -> Index the memory-mapped volume, mapping it on
        first use.

        The array is indexed as [z, y, x] (shape = reversed dimensions).
        """
        # Map array if required
        if not hasattr(self, 'im'):
            self.update()

        # NOTE(review): an earlier note claimed VTK (x, y, z) corresponds to
        # (z, DIMENSIONS[2]-1-y, x). No such y-flip is applied here; confirm
        # against the data before relying on it.
        return self.im.__getitem__(key)

