import matplotlib
matplotlib.use("Agg")
from yt.mods import *
import pylab
import sys
import numpy as na
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
from mpl_toolkits.axes_grid1.parasite_axes import SubplotHost


# Output path pattern, filled with (folder, first run name, last run name).
frame_template = "aaSN/%s/StarMassExp_%s_%s.eps"

# Command line: NRUNS START FINAL FOLDER NAME [NAME ...]
#   NRUNS  - number of simulation runs to process
#   START  - first dump number (inclusive)
#   FINAL  - last dump number (inclusive)
#   FOLDER - sub-directory under SciNet/ (input) and aaSN/ (output)
#   NAME   - one run directory name per run
_usage = "usage: %s NRUNS START FINAL FOLDER NAME [NAME ...]" % sys.argv[0]
if len(sys.argv) < 6:
    sys.exit(_usage)

nruns = int(sys.argv[1])
if nruns < 1 or len(sys.argv) < 5 + nruns:
    sys.exit(_usage)

# One run name per requested run.  Previously `name` was only bound when
# NRUNS == 1, so any other count crashed with a NameError further down.
name = sys.argv[5:5 + nruns]

start = int(sys.argv[2])
final = int(sys.argv[3])
freq = 1  # stride between dumps

folder = sys.argv[4]

# Number of dumps in the inclusive range [start, final].
rangee = final + 1 - start

# Per-run series; each entry is a list with one value per processed dump.
masses = []      # stellar mass as a fraction of the run's initial gas mass
massesMsun = []  # total stellar mass (printed as Msun)
number = []      # number of particles above the mass threshold
time = []        # dump time in Myr
iii = 0          # running count of dumps that contained particles (diagnostic)

for j in range(int(sys.argv[1])):
    # Reference gas mass for this run: total cell mass of the first dump.
    i = start
    pf = load("SciNet/%s/%s/DD%04i/data%04i" % (folder, name[j], i, i))
    sp = pf.h.all_data()
    gasmass = sp["CellMassMsun"].sum()

    mj = []
    mn = []
    nj = []
    tj = []
    for i in range(start, final + 1, freq):
        print("%s dump: %s" % (name[j], i))
        pf = load("SciNet/%s/%s/DD%04i/data%04i" % (folder, name[j], i, i))
        if pf is None:
            continue

        # Convert the dump time to Myr once so that both branches below
        # record the same unit (the no-star branch used to store the raw
        # code-unit time, which mixed units along the time axis).
        # presumably pf["years"] is yt's code-time -> years factor; confirm.
        t_myr = pf["InitialTime"] * pf["years"] / 1.e6

        if na.any(pf.h.grid_particle_count):
            iii += 1
            sp = pf.h.all_data()
            pmass = sp["ParticleMassMsun"]
            starmass = pmass.sum()
            # Count only particles heavier than the 5e-2 threshold.
            npart = int((pmass > 5.e-2).sum())
            # The mean is total mass over the thresholded count; guard the
            # case where particles exist but none exceeds the threshold.
            meanmass = starmass / npart if npart else float("nan")
            maxmass = pmass.max()
            print("number of particles =  %s mean mass = %s max mass =  %s"
                  % (npart, meanmass, maxmass))
            print("INITIAL TIME = %s" % (pf["InitialTime"],))
            print("total mass in stars, Msun =  %s" % (starmass,))
            print("total mass in cells, Msun =  %s"
                  % (sp["CellMassMsun"].sum(),))
            mj.append(starmass / gasmass)
            mn.append(starmass)
            nj.append(npart)
        else:
            print("No stars %s %s" % (i, iii))
            mj.append(0.0)
            mn.append(0.0)
            nj.append(0)
        tj.append(t_myr)

    print(mj)
    masses.append(mj)
    massesMsun.append(mn)
    number.append(nj)
    time.append(tj)
    print(number[j])
    

# Line widths from 2.0 to 3.0 in steps of 0.05 (21 values).  Dividing exact
# integers by 100.0 yields the same floats as the decimal literals.
linewidthnum = [hundredths / 100.0 for hundredths in range(200, 301, 5)]
# Dump the collected per-run time and stellar-mass series for inspection.
print(time)
print(massesMsun)

# Fractions 0.00, 0.01, ..., 0.99 (100 values).
fractions = [step * 0.01 for step in range(100)]


# Plot stellar mass against time for every run, with a second (right-hand)
# axis expressing the same curve as a fraction of the initial gas mass.
fig = plt.figure()
# NOTE(review): aspect=1. looks inherited from the matplotlib parasite-axes
# demo; with x in [5, 7] and y in [0, 8000] an equal data aspect gives a very
# distorted panel -- confirm whether "auto" was intended.
ax_mass = SubplotHost(fig, 1, 1, 1, aspect=1.)

# twin() expects the transform from the parasite's data coordinates
# (fraction) to the host's (mass): mass = fraction * gasmass.  The previous
# 1/gasmass factor was the inverse mapping and mis-scaled the right axis.
# `gasmass` is the value left over from the last run processed above.
frac_to_mass = gasmass
aux_trans = mtransforms.Affine2D().scale(1., frac_to_mass)
ax_frac = ax_mass.twin(aux_trans)
ax_frac.set_viewlim_mode("transform")

fig.add_subplot(ax_mass)

# `time` and `massesMsun` are lists of per-run lists.  Passing them to
# plot() directly makes matplotlib treat each dump (column) as its own
# one-point data set, so draw one line per run instead.
for run_time, run_mass in zip(time, massesMsun):
    ax_mass.plot(run_time, run_mass)

ax_mass.axis["bottom"].set_label("Time (Myrs)")
ax_frac.axis["right"].set_label("Fraction of Original Gas Mass in Stars")
ax_mass.axis["left"].set_label("Mass in Stars (Msun)")
ax_mass.set_xlim(5, 7)
ax_mass.set_ylim(0, 8000)

# The output file is named after the first and last run that were processed.
plt.savefig(frame_template % (folder, name[0], name[int(sys.argv[1]) - 1]),
            format="eps")

# Example usage (arguments: NRUNS START FINAL FOLDER NAME):
#   python2.6 StellarMass.py 1 50 70 GMC GMCirbrp128


## import matplotlib.transforms as mtransforms
## import matplotlib.pyplot as plt
## from mpl_toolkits.axes_grid1.parasite_axes import SubplotHost

## obs = [["01_S1", 3.88, 0.14, 1970, 63],
##        ["01_S4", 5.6, 0.82, 1622, 150],
##        ["02_S1", 2.4, 0.54, 1570, 40],
##        ["03_S1", 4.1, 0.62, 2380, 170]]


## fig = plt.figure()

## ax_kms = SubplotHost(fig, 1,1,1, aspect=1.)

## # angular proper motion("/yr) to linear velocity(km/s) at distance=2.3kpc
## pm_to_kms = 1./206265.*2300*3.085e18/3.15e7/1.e5

## aux_trans = mtransforms.Affine2D().scale(pm_to_kms, 1.)
## ax_pm = ax_kms.twin(aux_trans)
## ax_pm.set_viewlim_mode("transform")

## fig.add_subplot(ax_kms)

## for n, ds, dse, w, we in obs:
##     time = ((2007+(10. + 4/30.)/12)-1988.5)
##     v = ds / time * pm_to_kms
##     ve = dse / time * pm_to_kms
##     ax_kms.errorbar([v], [w], xerr=[ve], yerr=[we], color="k")


## ax_kms.axis["bottom"].set_label("Linear velocity at 2.3 kpc [km/s]")
## ax_kms.axis["left"].set_label("FWHM [km/s]")
## ax_pm.axis["top"].set_label("Proper Motion [$^{''}$/yr]")
## ax_pm.axis["top"].label.set_visible(True)
## ax_pm.axis["right"].major_ticklabels.set_visible(False)

## ax_kms.set_xlim(950, 3700)
## ax_kms.set_ylim(950, 3100)
## # xlim and ylim of ax_pms will be automatically adjusted.

## plt.draw()
## plt.show()
