#cython: boundscheck=False


import numpy.random as rn
from copy import copy

######################################################################
# Point: a mutable 3-D coordinate

cdef class Point:
    """
    A mutable point in 3-D space; the x, y, z coordinates are passed in
    as C floats.
    """

    def __init__(self, float x=0, float y=0, float z=0):
        self.x = x
        self.y = y
        self.z = z

    def __repr__(self):
        # Report all three coordinates in (x, y, z) order.
        return "Point(%f, %f, %f)" % (self.x, self.y, self.z)

    cdef void set(self, float x, float y, float z):
        """Overwrites all three coordinates in place."""
        self.x = x
        self.y = y
        self.z = z

    def __copy__(self):
        # An independent Point with the same coordinates.
        return Point(self.x , self.y, self.z)

# Shared origin that Box() uses as the default for both corners; Box only
# copies its coordinates and never keeps a reference to the object.
cdef Point _zero_point = Point()


######################################################################
# Spatial objects (shapes with a uniform density) used by the gravitation code

cdef class SpatialObject:
    """
    The base class for all geometric objects.

    Holds a geometry-type tag and a uniform density.  Subclasses override
    sample(), sample_nonrandom(), isInside() and volume(); mass() is then
    volume() * density().
    """

    def __cinit__(self):
        # `null` is presumably the "no specific geometry" member of the
        # GeomType enum declared in the .pxd -- TODO confirm.
        self.gt = null

    cdef _setType(self, GeomType gt):
        self.gt = gt

    cpdef GeomType gtype(self):
        return self.gt

    def setDensity(self, float d):
        self._setDensity(d)

    cdef void _setDensity(self, float d):
        self._density = d

    cpdef float density(self):
        return self._density

    cpdef float mass(self):
        """Returns volume() * density()."""
        return self.volume()*self.density()
        

    ########################################
    # Methods meant to be overridden by subclasses
    
    cpdef ar sample(self, size_t requested_size):
        """
        Returns sample points from inside the geometric shape as an
        (m, 3) float32 array of (x, y, z) rows.

        This default implementation ignores requested_size and simply
        returns sample_nonrandom(1).
        """
        return self.sample_nonrandom(1)

    cpdef ar sample_nonrandom(self, size_t depth):
        """Default: a single (0, 0, 0) row as a (1, 3) float32 array."""
        return np.zeros( (1, 3), dtype=np.float32)

    cpdef bint isInside(self, Point p):
        # The base object has no extent, so nothing is inside it.
        return False

    cpdef float volume(self):
        return 0

    def __copy__(self):
        # Carry the density over; a bare SpatialObject() would reset it.
        cdef SpatialObject other = SpatialObject()
        other._setDensity(self._density)
        return other

    

######################################################################
# A single point mass

cdef class PointMass(SpatialObject):
    """
    A point mass.  It reports a volume of 1, so its density and its mass
    are the same number, but it otherwise behaves like a single point.
    """

    def __cinit__(self):
        # Tag this object's geometry type.
        self._setType(point_mass)

    def __init__(self, Point pt not None, float mass):
        # Keep a private copy of the location so later edits to the
        # caller's Point do not move this mass.
        self.p = copy(pt)
        # volume() is 1, so storing the mass as the density makes
        # mass() == volume() * density() return `mass`.
        self.setDensity(mass)

    def __copy__(self):
        return PointMass(self.p, self.density())

    cpdef ar sample(self, size_t requested_size):
        """Returns the point itself as a (1, 3) array; the size is ignored."""
        return self.sample_nonrandom(1)

    cpdef ar sample_nonrandom(self, size_t depth):
        """Returns a (1, 3) float32 array holding the point's (x, y, z)."""
        return np.array([[self.p.x, self.p.y, self.p.z]], dtype=np.float32)

    cpdef bint isInside(self, Point p):
        # A point has no interior.
        return False

    cpdef float volume(self):
        # Nominal unit volume so that mass() equals density().
        return 1

    cpdef Point point(self):
        """Returns the stored location (the internal object, not a copy)."""
        return self.p



######################################################################
# Box 

cdef class Box(SpatialObject):
    """
    An axis-aligned box with a uniform density.

    Stored as its minimum corner (xmin, ymin, zmin), its maximum corner
    (xmax, ymax, zmax) and its side widths (wx, wy, wz).

    NOTE(review): unlike PointMass, Box never calls _setType, so gtype()
    returns whatever SpatialObject.__cinit__ set -- confirm intended.
    """

    def __init__(self, Point p1 = None, Point p2=None, float density = 0):
        """
        Creates a box spanning corners p1 and p2 (any two opposite corners,
        in any order).  If either corner is omitted, the box collapses to
        the origin.
        """
        if p1 is None or p2 is None:
            self.setCoordsDirect(_zero_point, _zero_point)
        else:
            self.setCoords(p1, p2)

        self.setDensity(density)

    def __copy__(self):
        return Box(Point(self.xmin, self.ymin, self.zmin),
                   Point(self.xmax, self.ymax, self.zmax), 
                   self.density())

    cpdef bint isInside(self, Point p):
        """True if p lies inside the box or on its surface."""
        return (p.x >= self.xmin and p.x <= self.xmax 
                and p.y >= self.ymin and p.y <= self.ymax 
                and p.z >= self.zmin and p.z <= self.zmax)

    cpdef ar sample(self, size_t target_count):
        """
        Returns target_count points drawn uniformly at random from inside
        the box, as a (target_count, 3) float32 array of (x, y, z) rows.
        """
        cdef ar[float, ndim=2] X = np.empty( (target_count, 3), dtype=np.float32)

        X[:, 0] = rn.uniform(self.xmin, self.xmin+self.wx, size=target_count)
        X[:, 1] = rn.uniform(self.ymin, self.ymin+self.wy, size=target_count)
        X[:, 2] = rn.uniform(self.zmin, self.zmin+self.wz, size=target_count)

        return X

    cpdef ar sample_nonrandom(self, size_t target_count):
        """
        Returns the centres of an n x n x n grid of cells covering the box,
        where n is the smallest integer with n**3 >= target_count, as an
        (n**3, 3) float32 array.
        """
        # As we will have n*n*n samples
        cdef size_t n = self.get_gridsize(target_count)

        cdef ar[float, ndim=2] X = np.empty( (n*n*n, 3), dtype=np.float32)
        self.setNonrandomSamples(X, n)

        return X

    cdef tuple get_nonrandom_array_size(self, size_t target_count):
        """Shape of the array sample_nonrandom(target_count) returns."""
        cdef size_t n = self.get_gridsize(target_count)
        return (n*n*n, 3)

    cdef size_t get_gridsize(self, size_t target_count):
        """Returns the smallest n with n*n*n >= target_count."""
        # Start from the floating-point cube root, then correct it with
        # exact integer arithmetic: pow() can overshoot an exact cube
        # (27 ** (1/3) evaluates to 3.0000000000000004, which ceil() rounds
        # up to 4) or undershoot it.
        cdef size_t n = <size_t>ceil(target_count ** (1.0 / 3) )
        while n > 0 and (n - 1) * (n - 1) * (n - 1) >= target_count:
            n -= 1
        while n * n * n < target_count:
            n += 1
        return n

    cdef void setNonrandomSamples(self, ar X_o, size_t grid_size):
        """
        Fills X_o with the centres of a grid_size**3 grid of cells covering
        the box, one (x, y, z) triple per cell, with the x index varying
        fastest.

        X_o must be C-contiguous and hold at least 3 * grid_size**3 float32
        values; bounds are not checked (boundscheck=False).
        """
        cdef size_t n = grid_size
        cdef size_t nx = 1, ny = n, nz = n*n
        cdef size_t xi, yi, zi, idx

        # Index a flat view manually so that any array shape with enough
        # elements can be filled (e.g. the (n**3, 3) array from
        # sample_nonrandom, or a 1-D buffer).  A 1-D buffer type cannot
        # bind to a 2-D array directly, so bind to reshape(-1); for a
        # C-contiguous array that is a view and the writes land in X_o.
        cdef ar[float, ndim=1, mode="c"] X = X_o.reshape(-1)

        for xi from 0 <= xi < n:
            for yi from 0 <= yi < n:
                for zi from 0 <= zi < n:
                    idx = 3*(xi*nx + yi*ny + zi*nz)
                    X[idx + 0] = self.xmin + (self.wx/n)*(xi + 0.5)
                    X[idx + 1] = self.ymin + (self.wy/n)*(yi + 0.5) 
                    X[idx + 2] = self.zmin + (self.wz/n)*(zi + 0.5)

    cpdef float volume(self):
        return self.wx * self.wy * self.wz

    ############################################################
    # Now some methods particular to the Box stuff 

    def setCoords(self, Point p1 not None, Point p2 not None):
        """
        Sets the coordinates of the box to the specified points.
        These points may be the coordinates of any two opposite
        corners of the box.
        """
        
        self.xmin = min(p1.x, p2.x)
        self.ymin = min(p1.y, p2.y)
        self.zmin = min(p1.z, p2.z)

        self.xmax = max(p1.x, p2.x)
        self.ymax = max(p1.y, p2.y)
        self.zmax = max(p1.z, p2.z)

        self._setWidths()


    cdef void setCoordsDirect(self, Point p1, Point p2):
        """
        Sets the coordinates of the box to the specified points
        without doing any checking for +/- differences, etc.  Faster
        but less safe: p1 must be the minimum corner and p2 the maximum.
        """

        self.xmin, self.ymin, self.zmin = p1.x, p1.y, p1.z
        self.xmax, self.ymax, self.zmax = p2.x, p2.y, p2.z

        self._setWidths()
        
    cdef void _setWidths(self):
        # Recompute side lengths from the current corners.
        self.wx = self.xmax - self.xmin
        self.wy = self.ymax - self.ymin
        self.wz = self.zmax - self.zmin



######################################################################
# Functions for checking whether the geometric objects overlap

cdef inline bint _overlaps_BoxBox(Box b1, Box b2):
    """
    Returns True when the axis-aligned boxes b1 and b2 intersect
    (boxes that merely touch count as overlapping).
    """
    # Separation along any single axis is enough to make the boxes
    # disjoint, so they overlap only if no axis separates them.
    return not ( (b1.xmin > b2.xmax or b1.xmax < b2.xmin)
                 or (b1.ymin > b2.ymax or b1.ymax < b2.ymin)
                 or (b1.zmin > b2.zmax or b1.zmax < b2.zmin) )

cdef inline bint _overlaps_BoxPointmass(Box b, PointMass p):
    """
    Returns True when the point mass p lies inside box b or on its surface.
    """
    # Box.isInside expects a Point, so test the point mass's location
    # rather than the PointMass object itself.
    return b.isInside(p.point())

cpdef bint overlaps(o1, o2):
    """
    Tests to see whether objects o1 and o2 overlap.

    Box/Box and Box/PointMass pairs (in either order) are supported; any
    other combination is reported as not overlapping.
    """

    if isinstance(o1, Box) and isinstance(o2, Box):
        return _overlaps_BoxBox(o1, o2)
    elif isinstance(o1, PointMass) and isinstance(o2, Box):
        # The helper takes the box first, then the point mass.
        return _overlaps_BoxPointmass(o2, o1)
    elif isinstance(o1, Box) and isinstance(o2, PointMass):
        return _overlaps_BoxPointmass(o1, o2)
    else:
        return False

