from sqlalchemy import *
from sqlalchemy import event
from sqlalchemy.orm import *
from sqlalchemy.ext.declarative import declarative_base
import collections
from sqlalchemy.orm.collections import collection, collection_adapter,\
    InstrumentedList, prepare_instrumentation, _instrument_class
from sqlalchemy.ext.associationproxy import association_proxy, _AssociationCollection
from sqlalchemy.event import dispatcher
 

class GroupedByKeyList(InstrumentedList):
    '''List of objects that all share one constant grouping key.

    Objects appended to the list are mutated through ``key_setter`` so that
    ``key_getter(obj) == key_value`` holds for every member.

    When an object is removed through :meth:`remove`, ``key_setter(obj,
    key_null_value)`` is called on it (unless ``mutate=False`` is passed).

    If ``key_attribute`` is set, the list also listens for ``"set"`` events on
    that attribute of the member's class via :mod:`sqlalchemy.event`. A member
    is dropped from the list as soon as its key attribute is assigned a value
    different from ``key_value``. This requires the attribute to be one that
    :func:`sqlalchemy.event.listen` accepts for ``"set"`` events.

    Without ``key_attribute`` the list can hold any object accepted by the
    supplied ``key_getter`` and ``key_setter`` callables.
    '''

    def _default_key_getter(self, obj):
        '''Read the grouping key from ``obj.<key_attribute>``.'''
        return getattr(obj, self.key_attribute)

    def _default_key_setter(self, obj, value):
        '''Write ``value`` to ``obj.<key_attribute>``.'''
        return setattr(obj, self.key_attribute, value)

    def __init__(self, key_value, key_attribute=None, key_getter=None, key_setter=None, key_null_value=None):
        '''Initialize the list.

        Either ``key_attribute`` or both ``key_getter`` and ``key_setter``
        must be supplied.

        :param key_value: the key shared by every object contained in the list
        :param key_attribute: attribute name used to build the default getter
            and setter
        :param key_getter: callable taking the contained object and returning
            a value comparable to ``key_value``; overrides the default getter
            ``getattr(object, key_attribute)``
        :param key_setter: callable taking the object and a value; overrides
            the default setter. It is called with ``key_value`` on append and
            with ``key_null_value`` on removal
        :param key_null_value: value passed to ``key_setter`` when an object
            is removed from the list
        :raises ValueError: if neither ``key_attribute`` nor the
            getter/setter pair is given
        '''
        if not (key_attribute or (key_getter and key_setter)):
            raise ValueError("please supply either key_attribute or (key_getter and key_setter)")
        super(GroupedByKeyList, self).__init__()
        # append() compares against self.key_val, so the constructor
        # argument has to be stored under that name.
        self.key_val = key_value
        # Classes whose key attribute already carries a listener for this
        # list; prevents registering one more listener per appended object.
        self._listened_classes = set()
        if key_attribute:
            self.key_getter = self._default_key_getter
            self.key_setter = self._default_key_setter
            self.key_attribute = key_attribute
        # Explicit callables take precedence over the attribute defaults.
        if key_setter:
            self.key_setter = key_setter
        if key_getter:
            self.key_getter = key_getter
        self.key_null_value = key_null_value

    def append(self, obj):
        '''Append ``obj``, forcing its key to this list's ``key_value``.

        If the key does not match yet it is rewritten through ``key_setter``.
        Otherwise, if ``obj`` is not already a member and ``key_attribute`` is
        set, a ``"set"`` listener is registered on the attribute so that later
        key changes evict the object from this list.
        '''
        if not self.key_val == self.key_getter(obj):
            self.key_setter(obj, self.key_val)
        elif obj not in self:
            # NOTE(review): the listener is only registered on this branch,
            # i.e. when the key already matched; confirm whether objects whose
            # key had to be rewritten above should be tracked as well.
            key_attribute = getattr(self, 'key_attribute', None)
            owner = type(obj)
            if key_attribute and owner not in self._listened_classes:
                def set_listener(target, value, oldvalue, initiator):
                    # The incoming ``value`` is compared rather than re-reading
                    # the target, which may still hold the old key. The
                    # listener sits on the class attribute and fires for every
                    # instance, hence the membership guard.
                    if value != self.key_val and target in self:
                        # Drop the member without resetting its key: the key
                        # was just changed deliberately by the caller.
                        super(GroupedByKeyList, self).remove(target)
                # Attribute events are registered on the class-bound
                # attribute, not on the plain value held by the instance.
                event.listen(getattr(owner, key_attribute), "set", set_listener)
                self._listened_classes.add(owner)
        super(GroupedByKeyList, self).append(obj)

    def remove(self, obj, mutate=True):
        '''Remove ``obj``; when ``mutate`` is true also reset its key.

        :param obj: member to remove
        :param mutate: if true, call ``key_setter(obj, key_null_value)``
        '''
        # Remove first: resetting the key may trigger the "set" listener,
        # which would otherwise evict obj before this method gets to it and
        # make the removal below fail.
        super(GroupedByKeyList, self).remove(obj)
        if mutate:
            self.key_setter(obj, self.key_null_value)

# Instrument GroupedByKeyList at import time.
# prepare_instrumentation() is given a zero-argument factory (the lambda
# below). GroupedByKeyList.__init__ requires arguments, so a throwaway
# instance is built by hand, returned from that factory, and deleted again
# once prepare_instrumentation() has run.
# The bare string following the assignment is a no-op expression that the
# original author used as an explanatory note.
__instance_for_instrumentation=GroupedByKeyList("Just for the sake of prepare_instrumentation", key_attribute="blah")
''' This is just usaed in order to make it possible call
prepare_instrumentation which other instantiates the object using
the no-arg constructor, which is not available for :class:`GroupedByKeyList`
TODO: add a prepare_instrumentation which accepts a class argument,
or check whether supplied type is already a class (as opposed to factory callable)
'''
prepare_instrumentation( lambda : __instance_for_instrumentation)    
# The placeholder instance is no longer needed after the call above.
del __instance_for_instrumentation
# Explicitly instrument the class itself as well.
_instrument_class(GroupedByKeyList)

class GroupByKeyCollection(collections.defaultdict):
    '''Dict of :class:`GroupedByKeyList` groups, keyed by each member's key.

    A missing key is filled on first access with a new
    :class:`GroupedByKeyList` bound to that key (see :meth:`__missing__`).
    :meth:`add` and :meth:`remove` are marked as the SQLAlchemy collection
    appender and remover and file a member under ``key_getter(member)``.
    :meth:`iterate` yields the members of all groups.
    '''

    def _default_key_getter(self, obj):
        '''Read the grouping key from ``obj.<key_attribute>``.'''
        return getattr(obj, self.key_attribute)

    def _default_key_setter(self, obj, value):
        '''Write ``value`` to ``obj.<key_attribute>``.'''
        return setattr(obj, self.key_attribute, value)

    def __init__(self, key_attribute=None, key_getter=None, key_setter=None, key_null_value=None):
        '''Initialize an empty collection.

        :param key_attribute: name of the member attribute holding the key
            (required)
        :param key_getter: optional callable ``getter(obj)``; defaults to
            reading ``key_attribute``
        :param key_setter: optional callable ``setter(obj, value)``; defaults
            to writing ``key_attribute``
        :param key_null_value: key value that marks a member as belonging to
            no group; also handed to every group list
        :raises ValueError: if ``key_attribute`` is not supplied
        '''
        super(GroupByKeyCollection, self).__init__()
        if not key_attribute:
            raise ValueError("please supply the key_attribute for the contained objects")
        self.key_attribute = key_attribute
        self.key_null_value = key_null_value
        # Each accessor falls back to its own default independently, so a
        # custom setter is never overwritten and the getter is always set.
        self.key_getter = key_getter or self._default_key_getter
        self.key_setter = key_setter or self._default_key_setter

    def __missing__(self, key):
        '''Create, store and return the group list for a new ``key``.'''
        group = GroupedByKeyList(key, key_attribute=self.key_attribute,
                                 key_getter=self.key_getter,
                                 key_setter=self.key_setter,
                                 key_null_value=self.key_null_value)

        def remove_listener(target, value, initiator):
            # Re-file a member that left its group under its current key,
            # unless that key is the "no group" marker.
            new_key = self.key_getter(value)
            if new_key != self.key_null_value:
                self[new_key].append(value)

        group.parent = self
        # NOTE(review): this registers a 'remove' listener on the list
        # instance itself; confirm that sqlalchemy.event accepts such a
        # target - TODO verify.
        event.listen(group, 'remove', remove_listener)
        # Store through the base class: the new group is empty, so the
        # append events fired by this class's __setitem__ are not needed.
        collections.defaultdict.__setitem__(self, key, group)
        return group

    @collection.appender
    def add(self, value, _sa_initiator=None):
        '''Append ``value`` to the group selected by its key.'''
        key = self.key_getter(value)
        self[key].append(value)

    @collection.remover
    def remove(self, value, _sa_initiator=None):
        '''Remove ``value`` from the group selected by its key.'''
        key = self.key_getter(value)
        self[key].remove(value)

    @collection.internally_instrumented
    def __setitem__(self, key, value):
        '''Store the iterable ``value`` under ``key``, firing append events.'''
        adapter = collection_adapter(self)
        # the collection API usually provides these events transparently, but due to
        # the unusual structure, we pretty much have to fire them ourselves
        # for each item.
        for item in value:
            adapter.fire_append_event(item, None)
        collections.defaultdict.__setitem__(self, key, value)

    @collection.internally_instrumented
    def __delitem__(self, key, value=None):
        '''Delete the group stored under ``key``, firing remove events.

        :param key: key of the group to delete
        :param value: iterable of items to fire remove events for; kept only
            for backward compatibility and defaults to the stored group
        :raises KeyError: if ``key`` is not present
        '''
        if value is None:
            # Membership test first so that a missing key raises instead of
            # being created by __missing__.
            if key not in self:
                raise KeyError(key)
            value = collections.defaultdict.__getitem__(self, key)
        adapter = collection_adapter(self)
        for item in value:
            adapter.fire_remove_event(item, None)
        collections.defaultdict.__delitem__(self, key)

    @collection.iterator
    def iterate(self):
        '''Yield every member of every group.'''
        for group in self.values():
            for item in group:
                yield item

    @collection.converter
    def _convert(self, target):
        '''Yield every item of every value of the mapping ``target``.'''
        for group in target.values():
            for item in group:
                yield item

    def update(self, k):
        '''Not supported; always raises :class:`NotImplementedError`.'''
        raise NotImplementedError()


class AssociationGBK(_AssociationCollection):
    '''Association-proxy collection over a mapping of key -> list of items.

    Reads translate each stored item through ``_get``; writes build the
    stored items through ``_create`` before handing them to the underlying
    collection (``self.col``).
    '''

    def __init__(self, lazy_collection, creator, value_attr, parent):
        '''Set up the proxy with the parent's default getter/setter pair.

        ``value_attr`` is accepted but not used by this constructor.
        '''
        default_getter, default_setter = parent._default_getset(
            parent.collection_class)
        super(AssociationGBK, self).__init__(
            lazy_collection, creator, default_getter, default_setter, parent)

    def _create(self, key, value):
        '''Build a stored item by calling ``creator(key, value)``.'''
        make = self.creator
        return make(key, value)

    def _get(self, object):
        '''Return the proxied value of a stored item.'''
        read = self.getter
        return read(object)

    def _set(self, object, key, value):
        '''Apply the setter to a stored item.'''
        write = self.setter
        return write(object, key, value)

    def __getitem__(self, key):
        '''Return the proxied values stored under ``key`` as a list.'''
        proxied = []
        for member in self.col[key]:
            proxied.append(self._get(member))
        return proxied

    def __setitem__(self, key, value):
        '''Create a stored item for each element and store them under ``key``.'''
        created = [self._create(key, element) for element in value]
        self.col[key] = created

    def add(self, key, item):
        '''Create a stored item for ``item`` and add it to the collection.'''
        appender = self.col.add
        appender(self._create(key, item))

    def items(self):
        '''Return a generator of ``(key, proxied values list)`` pairs.'''
        def pair_for(key):
            return (key, [self._get(member) for member in self.col[key]])
        return (pair_for(key) for key in self.col)

    def update(self, kw):
        '''Assign every ``key -> values`` entry of the mapping ``kw``.'''
        for key, members in kw.items():
            self[key] = members

    def clear(self):
        '''Clear the underlying collection.'''
        underlying = self.col
        underlying.clear()

    def copy(self):
        '''Return a plain ``dict`` built from :meth:`items`.'''
        pairs = self.items()
        return dict(pairs)

    def __repr__(self):
        '''Return the ``repr`` of a plain ``dict`` built from :meth:`items`.'''
        snapshot = dict(self.items())
        return repr(snapshot)


