""" plot_reject_example.py
    
    Example of using rejection sampling for Brownian motion and plotting results using ggplot2 via rpy2
    
    Comment out the 
"""

import numpy
numpy.random.seed(42)
import math
import rpy2.robjects as ro
import rpy2.robjects.numpy2ri
import rpy2.rlike.container as rlc
import rpy2.robjects.lib.ggplot2 as ggplot2

class Probz:
    """Probability density of a particle's position after one timestep.

    The density is a Gaussian centred on ``z0`` with variance ``2 * DL * tc``:

        p(z) = (4*pi*DL*tc)**(-1/2) * exp(-(z - z0)**2 / (4*DL*tc))

    Instances are callable on a scalar or a numpy array of positions.
    """

    def __init__(self, z0, tc, DL):
        # Centre of the distribution (starting position of the particle).
        self.z0 = z0
        # Timestep length.
        self.tc = tc
        # Diffusivity.
        self.DL = DL

    def __call__(self, z):
        """Evaluate the density at ``z`` (scalar or numpy array)."""
        spread = self.DL * self.tc
        norm = (4 * math.pi * spread) ** (-0.5)
        offset = z - self.z0
        return norm * numpy.exp(-offset ** 2 / 4. / spread)

def probz_reject(probz, sigma, nstdev=3, n=1000, ymx=None):
    """Draw samples from the density ``probz`` by rejection sampling.

    Candidates are drawn uniformly from the window
    ``[probz.z0 - nstdev*sigma, probz.z0 + nstdev*sigma]``, and each candidate
    ``x`` is kept when a uniform draw ``U`` satisfies ``U <= probz(x) / ymx``.
    Nothing outside the window is ever produced, so the result follows
    ``probz`` truncated to that window.

    Parameters
    ----------
    probz : Probz
        Density to sample. Must expose ``z0`` and accept a numpy array.
    sigma : float
        Standard deviation used to size the sampling window.
    nstdev : float, optional
        Half-width of the window in units of ``sigma``.
    n : int, optional
        Number of candidate draws. The number of accepted samples is random
        and at most ``n``.
    ymx : float, optional
        Height of the uniform envelope. For the samples to follow ``probz``
        it must be >= the maximum of ``probz`` over the window. Defaults to
        ``probz(probz.z0)``, the peak of a Probz density.

    Returns
    -------
    numpy.ndarray
        1-D array of accepted samples (length <= n).

    Raises
    ------
    ValueError
        If ``ymx`` is not positive.
    """
    z0 = probz.z0
    xmn = z0 - nstdev * sigma
    xmx = z0 + nstdev * sigma
    if ymx is None:
        # A Probz density is a Gaussian centred on z0, so its peak is at z0.
        # A fixed envelope such as 0.1 would be below that peak whenever
        # DL*tc < 25/pi (~7.96), flattening the top of the distribution and
        # biasing the samples.
        ymx = float(probz(z0))
    if not ymx > 0:  # also rejects NaN
        raise ValueError("ymx must be positive, got %r" % (ymx,))

    # Draw candidates and acceptance variates in the same order as before.
    xsamps = numpy.random.uniform(xmn, xmx, size=n)
    U = numpy.random.uniform(size=n)
    return xsamps[U <= probz(xsamps) / ymx]

def main():
    """Sample one Brownian-motion step and build a ggplot2 comparison plot.

    Draws particle positions from a Probz density by rejection sampling,
    then builds a ggplot2 plot containing a density-scaled histogram of the
    samples overlaid with the analytic density curve.

    Returns
    -------
    The ggplot2 plot object; call ``.plot()`` on it to render it.
    """
    DL = 2.7e-11 * (1e6) ** 2  # Diffusivity in um^2/s (2.7e-11 m^2/s scaled by (1e6 um/m)**2)
    tc = .5  # timestep, s
    nstdev = 3  # number of standard deviations to consider
    sigma = math.sqrt(2 * DL * tc)  # um

    z0 = 50
    z = numpy.linspace(z0 - nstdev * sigma, z0 + nstdev * sigma, 100)
    probz = Probz(z0, tc, DL)
    pz = probz(z)
    pz_r = probz_reject(probz, sigma, nstdev=nstdev, n=10000)

    # FloatVector builds an R numeric vector from any sequence of floats, so
    # this does not depend on numpy->R conversion being activated.
    Rz = ro.FloatVector(z)
    Rpz = ro.FloatVector(pz)
    Rpz_r = ro.FloatVector(pz_r)

    # Analytic density curve: columns 'z' and 'dens'.
    datafline = ro.DataFrame(rlc.TaggedList([Rz, Rpz], tags=('z', 'dens')))
    # Rejection samples: a single column 'z'. The trailing comma matters:
    # ('z') is just the string 'z', not a one-element tuple.
    datafhist = ro.DataFrame(rlc.TaggedList([Rpz_r], tags=('z',)))

    bw = 1  # histogram binwidth
    pp = ggplot2.ggplot(datafhist)
    pp = pp + ggplot2.geom_histogram(
        ggplot2.aes_string(x='z', y='..density..'), binwidth=bw)
    # ggplot2's signature is geom_line(mapping=NULL, data=NULL, ...): the
    # aesthetic mapping is the first positional argument, so a layer-specific
    # data frame must be passed by keyword. Passing it positionally puts it
    # in the mapping slot, and writing data=... before a positional argument
    # is a Python SyntaxError.
    pp = pp + ggplot2.geom_line(
        ggplot2.aes_string(x='z', y='dens'), data=datafline)

    return pp
    
if __name__ == '__main__':
    # Script entry point: build the figure with main() and render it.
    main().plot()