import pymel.core as pm
import pymel.api as api

def isAnimated(node, checkParent):
    """Return True if any incoming plug of ``node`` is considered animated.

    ``node`` may be an MObject, a PyNode, or anything ``pm.PyNode`` accepts;
    it is converted to an MObject before use.  ``checkParent`` is handed
    through unchanged to the per-plug checks (which forward it to
    ``MAnimUtil.isAnimated``).
    """
    if isinstance(node, pm.api.MObject):
        mobject = node
    else:
        pynode = node if isinstance(node, pm.PyNode) else pm.PyNode(node)
        mobject = pynode.__apimobject__()

    # collect the plugs on this node that the upstream traversal reports...
    incoming = _getIncomingPlugs(mobject, api.MPlugArray())

    # one memo dict shared by the whole recursive walk for this query
    memo = {}
    return any(_isAnimatedPlug(incoming[index], checkParent, memo)
               for index in xrange(incoming.length()))

def _isAnimatedPlug(plug, checkParent, checkedPlugs):
    """Memoizing front-end for ``_isAnimatedPlug2``.

    ``checkedPlugs`` is a dict shared across one whole query.  It maps a
    plug's partial name (including the node name) to a list of
    ``[MPlug, result]`` entries; entries in a bucket are told apart with
    MPlug's own ``==`` operator.

    Returns the cached result if this plug was already visited, otherwise
    computes it via ``_isAnimatedPlug2``, stores it, and returns it.
    """
    # The MPlug wrapper has no value-based hash, and every array lookup /
    # iterator step hands back a fresh wrapper object, so using the plug
    # itself as a dict key never finds a previously stored entry.  Bucket by
    # name instead and settle any name collisions with plug equality.
    plugName = plug.partialName(True)
    bucket = checkedPlugs.setdefault(plugName, [])
    for entry in bucket:
        if entry[0] == plug:
            return entry[1]

    # set value for this plug to true - this means that if there's a cycle,
    # then we'll just assume that the plug is animated... not totally
    # accurate, but easier, and assuming that plugs ARE animated is "safer"
    # (and besides, you shouldn't have any DG cycles anyway!)
    # A copy of the plug is stored so the entry does not depend on the
    # lifetime of whatever array the caller pulled the plug from.
    entry = [api.MPlug(plug), True]
    bucket.append(entry)

    newVal = _isAnimatedPlug2(plug, checkParent, checkedPlugs)
    entry[1] = newVal
    print("_isAnimatedPlug(%s) - %s" % (plugName, newVal))
    return newVal

# kept apart from _isAnimatedPlug so that every branch below can simply
# return True/False and leave recording the result in checkedPlugs to the
# caller...
def _isAnimatedPlug2(plug, checkParent, checkedPlugs):
    """Do the actual (uncached) animation test for a single plug.

    Checks, in order: Maya's own ``MAnimUtil.isAnimated``; whether the plug
    is the 'o' / 'unw' plug of a time node; for world space attributes, the
    plugs gathered by ``_getWorldSpacePlugs``; and finally every plug found
    by walking the dependency graph upstream.  Recursion goes back through
    ``_isAnimatedPlug`` so results land in ``checkedPlugs``.
    """
    # the "easy" answer from autodesk first, then the time node special case
    # ('o' / 'unw' are presumably the outTime / unwarpedTime short names)
    if (pm.api.MAnimUtil.isAnimated(plug, checkParent)
            or (plug.node().hasFn(api.MFn.kTime)
                and plug.partialName() in ('o', 'unw'))):
        return True

    # world space attributes also depend on the parent hierarchy...
    fnAttr = pm.api.MFnAttribute(plug.attribute())
    if fnAttr.isWorldSpace():
        parentPlugs = _getWorldSpacePlugs(plug.node(), pm.api.MPlugArray())
        for index in xrange(parentPlugs.length()):
            if _isAnimatedPlug(parentPlugs[index], checkParent, checkedPlugs):
                return True

    # finally, walk upstream through the DG, plug by plug...
    depGraph = pm.api.MItDependencyGraph
    upstream = depGraph(plug, pm.api.MFn.kInvalid,
                        depGraph.kUpstream,
                        depGraph.kDepthFirst,
                        depGraph.kPlugLevel)
    upstream.setTraversalOverWorldSpaceDependents(True)

    # the first "plug" is a weird "attribute-less" plug - skip it...
    if not upstream.isDone():
        upstream.next()

    while not upstream.isDone():
        if _isAnimatedPlug(upstream.thisPlug(), checkParent, checkedPlugs):
            return True
        upstream.next()

    return False

def _getIncomingPlugs(object, plugs):
    """Append to ``plugs`` the plugs of ``object`` found by an upstream walk.

    A breadth-first, plug-level upstream iterator is started at ``object``;
    plugs are appended for as long as the iterator stays on ``object`` itself
    and collection stops at the first plug that belongs to another node.
    Returns the same ``plugs`` array that was passed in.
    """
    depGraph = pm.api.MItDependencyGraph
    walker = depGraph(object, pm.api.MFn.kInvalid,
                      depGraph.kUpstream,
                      depGraph.kBreadthFirst,
                      depGraph.kPlugLevel)

    # the first "plug" is a weird "attribute-less" plug - skip it...
    if not walker.isDone():
        walker.next()

    while not walker.isDone():
        candidate = walker.thisPlug()
        # once the walk has left the start node, we're done collecting
        if candidate.node() != object:
            break
        plugs.append(candidate)
        walker.next()
    return plugs
# second name for isAnimated, so it can be referred to explicitly as the
# "new" implementation next to isAnimated_old
isAnimated_new = isAnimated

def _getWorldSpacePlugs(object, plugs):
    """Append to ``plugs`` the affects-world-space plugs above ``object``.

    For every node gathered by ``_getParentHier`` each attribute is
    inspected, and a plug is appended for every attribute whose
    ``isAffectsWorldSpace()`` is true.  Returns the same ``plugs`` array.
    """
    fnNode = api.MFnDependencyNode()
    fnAttr = api.MFnAttribute()
    ancestors = _getParentHier(object, api.MSelectionList())

    # getDependNode fills in an existing MObject rather than returning one
    ancestor = api.MObject()
    for nodeIndex in xrange(ancestors.length()):
        ancestors.getDependNode(nodeIndex, ancestor)
        fnNode.setObject(ancestor)
        for attrIndex in xrange(fnNode.attributeCount()):
            attribute = fnNode.attribute(attrIndex)
            fnAttr.setObject(attribute)
            if not fnAttr.isAffectsWorldSpace():
                continue
            plugs.append(fnNode.findPlug(attribute, True))
    return plugs

def _getParentHier(object, nodes):
    """Add every DAG ancestor of ``object`` to the MSelectionList ``nodes``.

    All DAG paths to ``object`` are walked from the immediate parent up to
    the top of the path.  If ``object`` is not a DAG node, ``nodes`` is left
    untouched.  Returns the same ``nodes`` list that was passed in.
    """
    if not object.hasFn(api.MFn.kDagNode):
        return nodes
    dagPaths = api.MDagPathArray()
    dagNode = api.MFnDagNode(object)
    dagNode.getAllPaths(dagPaths)
    for i in xrange(dagPaths.length()):
        dagPath = dagPaths[i]
        # pop off the last item - that will always be the start object
        dagPath.pop()
        while dagPath.length():
            # an MSelectionList is used to avoid duplicates, but add() only
            # skips items already in the list when mergeWithExisting is True
            # (it defaults to False) - without it, ancestors shared by
            # several instance paths would be added once per path
            nodes.add(dagPath.node(), True)
            dagPath.pop()
    return nodes

#print isAnimated('pPlane1')
#print isAnimated('polySurfaceShape1')

#pm.api.MAnimUtil.isAnimated(pm.PyNode('pPlane1').__apimobject__(), True)
#pm.api.MAnimUtil.isAnimated(pm.PyNode('pPlane1').__apimobject__(), False)

def isAnimated_old(node, checkParent):
    """Older, node-type based animation check.

    ``node`` may be an MObject, a PyNode, or anything ``pm.PyNode`` accepts.
    Walks the dependency graph upstream (depth first, plug level) from the
    node and returns True as soon as a visited node has one of a fixed set
    of function types, is an expression that reports itself animated, or
    satisfies ``MAnimUtil.isAnimated(node, checkParent)``.  Returns False if
    the walk finishes without a hit.  Prints each node / plug it checks.
    """
    if isinstance(node, pm.api.MObject):
        startObj = node
    else:
        pynode = node if isinstance(node, pm.PyNode) else pm.PyNode(node)
        startObj = pynode.__apimobject__()

    depGraph = pm.api.MItDependencyGraph
    walker = depGraph(startObj, pm.api.MFn.kInvalid,
                      depGraph.kUpstream,
                      depGraph.kDepthFirst,
                      depGraph.kPlugLevel)

    # node types whose mere presence upstream counts as "animated"
    animatedTypes = (
        api.MFn.kPluginDependNode,
        api.MFn.kConstraint,
        api.MFn.kPointConstraint,
        api.MFn.kAimConstraint,
        api.MFn.kOrientConstraint,
        api.MFn.kScaleConstraint,
        api.MFn.kGeometryConstraint,
        api.MFn.kNormalConstraint,
        api.MFn.kTangentConstraint,
        api.MFn.kParentConstraint,
        api.MFn.kPoleVectorConstraint,
        api.MFn.kTime,
        api.MFn.kJoint,
        api.MFn.kGeometryFilt,
        api.MFn.kTweak,
        api.MFn.kPolyTweak,
        api.MFn.kSubdTweak,
        api.MFn.kCluster,
        api.MFn.kFluid,
        api.MFn.kPolyBoolOp,
    )

    # note: the first item the iterator yields is NOT skipped here
    while not walker.isDone():
        current = walker.thisNode()
        currentPlug = walker.thisPlug()
        print("checking: %s - %s" % (pm.PyNode(current), pm.PyNode(currentPlug)))

        if any(current.hasFn(fnType) for fnType in animatedTypes):
            return True

        # expressions only count when they report themselves as animated
        if (current.hasFn(api.MFn.kExpression)
                and api.MFnExpression(current).isAnimated()):
            return True

        if api.MAnimUtil.isAnimated(current, checkParent):
            return True
        walker.next()

    return False

test_objs = [
    'ggGA:outerFeather17',
    'ggGA:R_feather_geo|ggGA:mid_feather_geo',
    'ggGA:R_feather_geo',
    'ggGA:feathers_geo',
]


def test(test_objs):
    """Compare isAnimated against isAnimated_old for each given object.

    Every object is checked with checkParent False and then True; a line is
    printed only when the old implementation says "animated" but the new
    one does not.
    """
    for parentFlag in (False, True):
        for name in test_objs:
            newResult = isAnimated(name, parentFlag)
            oldResult = isAnimated_old(name, parentFlag)
            # only a regression (old True, new False) is worth reporting
            if newResult or not oldResult:
                continue
            print("obj: %s - checkParent: %s - new: %s - old: %s" % (name, parentFlag, newResult, oldResult))
