#slprofile, the profile module adapted for stackless and blue

import profile
import stackless
import sys
import blue
import time
import traceback
import thread
import functools
import log

# Set as a side effect of importing this module.
# NOTE(review): presumably this makes a hook installed with sys.setprofile()
# apply to every tasklet rather than only the one that installed it.  The
# per-tasklet bookkeeping in Profile (ContextWrap/SwitchTasklet) depends on
# receiving events from all tasklets.  Confirm against the Stackless build in use.
stackless.globaltrace = True


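# Tasklet-aware profiler: keeps separate stacks/timings per tasklet and adds
# start()/stop() for profiling a stretch of code without wrapping it in a call.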
class Profile(profile.Profile):
    """``profile.Profile`` adapted for Stackless tasklets.

    The base profiler tracks a single simulated call stack (``self.cur``) and
    a single ``self.timings`` dict.  Under Stackless, execution hops between
    tasklets, each with its own stack.  This subclass therefore keeps one
    ``(cur, timings)`` pair per non-running tasklet in ``self.sleeping``.
    Whenever a profiler event arrives from a tasklet other than the last one
    seen, it swaps that pair in (see ``ContextWrap`` / ``SwitchTasklet``).
    ``TallyTimings`` merges every tasklet's results back into
    ``self.timings`` at the end of a run.

    Besides ``runctx``/``runcall`` it offers ``start()``/``stop()`` to profile
    an arbitrary stretch of execution.  The default timer is ``Timer``.
    """
    base = profile.Profile

    def __init__(self, timer=None, bias=None):
        # Set before the base constructor runs, so the tasklet-aware
        # dispatchers (which compare against self.current_tasklet) are usable
        # if it dispatches anything.
        self.current_tasklet = stackless.getcurrent()
        self.thread_id = thread.get_ident()
        if timer is None:
            timer = Timer()
        self.base.__init__(self, timer, bias)
        # tasklet -> (cur, timings) for every tasklet that is not current
        self.sleeping = {}

    def __call__(self, *args):
        "make callable, allowing an instance to be the profiler"
        return self.dispatcher(*args)

    def _setup(self):
        """Reset all per-run state and root the stack at a 'profiler' frame."""
        self.cur, self.timings = None, {}
        self.current_tasklet = stackless.getcurrent()
        self.thread_id = thread.get_ident()
        # Drop per-tasklet state from any earlier run so it cannot be merged
        # into this run's timings by TallyTimings.
        self.sleeping = {}
        self.simulate_call("profiler")

    def start(self, name="start"):
        """Start profiling from the caller onward (no-op if already running).

        ``name`` labels the simulated call frame under which everything that
        runs until ``stop()`` is recorded.
        """
        if getattr(self, "running", False):
            return
        self._setup()
        self.simulate_call(name)
        self.running = True
        sys.setprofile(self.dispatcher)

    def stop(self):
        """Stop profiling and merge every tasklet's timings into self.timings."""
        sys.setprofile(None)
        self.running = False
        self.TallyTimings()

    # Special cases for the original run commands, making sure to reset the
    # per-tasklet context before the run and merge it afterwards.
    def runctx(self, cmd, globals, locals):
        """Profile ``cmd`` via the base runctx; return what the base returns."""
        self._setup()
        try:
            return profile.Profile.runctx(self, cmd, globals, locals)
        finally:
            self.TallyTimings()

    def runcall(self, func, *args, **kw):
        """Profile ``func(*args, **kw)`` and return its result."""
        self._setup()
        try:
            return profile.Profile.runcall(self, func, *args, **kw)
        finally:
            self.TallyTimings()

    def trace_dispatch_return_extend_back(self, frame, t):
        """'return' handler that tolerates returns from unrecorded frames.

        Tasklets that were already running when profiling began, and tasklets
        first seen mid-stack (whose stack is rooted at the simulated
        'new_tasklet' frame), return from frames the profiler never entered.
        Those would fail the parent class's error checking.  When the top of
        the simulated stack is a fake frame (one made by simulate_call), the
        event is ignored and a false value is returned.  Otherwise the base
        handler runs.
        """
        if isinstance(self.cur[-2], Profile.fake_frame):
            return False
        return self.trace_dispatch_return(frame, t)

    def trace_dispatch_c_return_extend_back(self, frame, t):
        """Same guard as above, for 'c_return' and 'c_exception' events."""
        if isinstance(self.cur[-2], Profile.fake_frame):
            return False  # ignore bogus returns
        return self.trace_dispatch_return(frame, t)

    # Add "return safety" to the dispatchers.  A C call that finishes by
    # raising is reported as 'c_exception' and needs the same guard as
    # 'c_return'.
    dispatch = dict(profile.Profile.dispatch)
    dispatch.update({
        "return": trace_dispatch_return_extend_back,
        "c_return": trace_dispatch_c_return_extend_back,
        "c_exception": trace_dispatch_c_return_extend_back,
        })

    def SwitchTasklet(self, t0, t1, t):
        """Park tasklet ``t0``'s profiler state and install ``t1``'s.

        ``t`` (time since the last event) is billed to t0's top frame before
        it is saved.  A tasklet seen for the first time gets a fresh stack
        rooted at simulated 'profiler' -> 'new_tasklet' frames.
        """
        pt, it, et, fn, frame, rcur = self.cur
        self.sleeping[t0] = (pt, it + t, et, fn, frame, rcur), self.timings
        # Must be updated before simulate_call below, which goes through the
        # wrapped dispatchers; otherwise they would try to switch again.
        self.current_tasklet = t1

        try:
            self.cur, self.timings = self.sleeping.pop(t1)
        except KeyError:
            self.cur, self.timings = None, {}
            self.simulate_call("profiler")
            self.simulate_call("new_tasklet")

    def ContextWrap(f):
        """Decorator (used at class-definition time) for dispatch handlers.

        Before delegating to ``f``, swaps in the current tasklet's profiler
        state if the event came from a different tasklet than the last one.
        """
        @functools.wraps(f)
        def ContextWrapper(self, arg, t):
            current = stackless.getcurrent()
            if current is not self.current_tasklet:
                self.SwitchTasklet(self.current_tasklet, current, t)
                t = 0.0  # the time was billed to the previous tasklet
            return f(self, arg, t)
        return ContextWrapper

    # Add automatic tasklet detection to the callbacks.  This must remain a
    # *list* comprehension: in Python 2 it runs in the class scope and can see
    # ContextWrap, whereas a generator expression or dict comprehension gets
    # its own scope and would raise NameError.
    dispatch = dict([(key, ContextWrap(val)) for key, val in dispatch.iteritems()])

    def TallyTimings(self):
        """Close out the run and fold every tasklet's results into self.timings.

        Unwinds the current tasklet's stack into ``self.timings``.  Each
        sleeping tasklet's stack is unwound into that tasklet's own timings
        dict, which is then accumulated into ``self.timings``.
        ``self.sleeping`` is left empty.
        """
        sleeping = self.sleeping
        self.sleeping = {}

        # first, unwind the current tasklet's "cur"
        self.cur = self.Unwind(self.cur, self.timings)

        # The timings dicts must stay separate per tasklet until unwound: the
        # 'ns' item (recursion count of each function on that tasklet's
        # stack) is consulted by Unwind.
        for cur, timings in sleeping.itervalues():
            self.Unwind(cur, timings)

            for fn, entry in timings.iteritems():
                if fn not in self.timings:
                    self.timings[fn] = entry
                    continue
                # accumulate into self.timings; its own ns value is kept
                cc, ns, tt, ct, callers = self.timings[fn]
                for caller, count in entry[4].iteritems():
                    callers[caller] = callers.get(caller, 0) + count
                self.timings[fn] = (cc + entry[0], ns, tt + entry[2],
                                    ct + entry[3], callers)

    def Unwind(self, cur, timings):
        """Pop every frame above the root of ``cur``, tallying into ``timings``.

        Follows the same bookkeeping as profile's trace_dispatch_return()
        (also see simulate_cmd_complete()), with no extra elapsed time.  Each
        frame's time is folded into its parent.  Returns the root entry (the
        one with no parent link).
        """
        while cur[-1]:
            rpt, rit, ret, rfn, frame, rcur = cur
            ppt, pit, pet, pfn, pframe, pcur = rcur
            frame_total = rit + ret

            if rfn in timings:
                cc, ns, tt, ct, callers = timings[rfn]
            else:
                cc, ns, tt, ct, callers = 0, 0, 0, 0, {}

            if not ns:
                # outermost occurrence of rfn: credit its cumulative time
                ct = ct + frame_total
                cc = cc + 1

            if pfn in callers:
                callers[pfn] = callers[pfn] + 1  # hack: gather more
            elif pfn:
                callers[pfn] = 1

            timings[rfn] = cc, ns - 1, tt + rit, ct, callers

            cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur
        return cur


class Timer(object):
    """Clock reporting seconds elapsed since the instance was created.

    Readings come from ``blue.os.GetTime(1)``.  There is no fallback clock.
    """

    def __init__(self):
        # Tick count at construction; every reading is relative to it.
        self.startTime = blue.os.GetTime(1)

    def __call__(self):
        """Return the elapsed time since construction, as a float in seconds."""
        elapsed = float(blue.os.GetTime(1) - self.startTime)
        # NOTE(review): the 1e-7 factor implies GetTime ticks of 100 ns -- confirm.
        return elapsed * 1e-7
