
from sqlalchemy import *
import unittest

class Base( object):
    """Common base for the mapped test classes, providing a readable str().

    Each subclass declares ``props``: the attribute names to render, in order.
    An attribute that is not set renders as ``<notset>``; an attribute holding
    another ``Base`` instance renders as ``>`` followed by that object's
    ``name``.
    """
    def __str__( self):
        pieces = [self.__class__.__name__, '(']
        for attr in self.props:
            shown = getattr( self, attr, '<notset>')
            if isinstance( shown, Base):
                # related objects are shown by name only, not recursively
                shown = '>' + str( shown.name)
            pieces.append( ' ' + attr + '=' + str( shown))
        pieces.append( ' )')
        return ''.join( pieces)

class config:
    # Run-time switches for this test script.
    # When the file is run as a script, setup() overwrites every boolean flag
    # below from sys.argv. A flag becomes True only if its bare word is given;
    # session_clear is True unless "no_session_clear" is given. So the boolean
    # defaults written here only apply when setup() is not called.
    # db and repeat keep these defaults unless db=... / repeat=N is passed.
    echo = False            # sets meta.engine.echo in setUp; also prints each test id
    debug = False           # print each test id ('=====' line) in setUp
    log_sa = True           # setUp sends the 'sqlalchemy' logger at DEBUG to stdout
    session_clear = True    # set from the no_session_clear flag; NOTE(review): presumably meant to control session.clear() in the test - confirm the test reads it
    reuse_db = False        # setUp reuses a cached engine (_db) instead of creating one; tearDown skips dispose()
    leak = False            # tearDown runs gc diagnostics and prints a count of live Mapper/BoundMetaData objects
    memory = False          # run() calls memusage() after each repetition
    db = 'sqlite:///:memory:'   # engine URL handed to create_engine() in setUp
    repeat = 1              # number of times run() executes each test



class AB( unittest.TestCase):
    """Regression test for a lazy many-to-one into a polymorphic hierarchy.

    Tables A, B and C form a joined-table inheritance chain: C extends B and
    B extends A, each child keyed by its parent's ``db_id``. A also carries
    ``linkA_id``, a foreign key to B, mapped as the ``linkA`` relation.
    The test saves an A linked to a B, re-loads the A, and checks that
    ``linkA`` still points at that same B.
    """

    # Engine shared between tests when config.reuse_db is set. setUp stores
    # it on the class because unittest creates a new TestCase instance for
    # every test method, so an instance attribute would never be seen again.
    _db = None

    def test_ABC_poly_1__inh_t__Alazy_B__inhB_A__Blazy___BAlazy___inhC_B__Clazy___CAlazy___CBlazy_( me):
        """Save a -> b, re-query a, and compare the loaded link with the original."""
        meta=me.meta
        table_A = Table( 'A', meta,
            # A references B here while B references A through db_id (a
            # dependency cycle), hence use_alter for this constraint.
            Column( 'linkA_id', Integer, ForeignKey( 'B.db_id',     name= 'linkA_id_fk',     use_alter= True, ), ),
            Column( 'name',   type= String, ),
            Column( 'db_id',   primary_key= True,   type= Integer, ),
            Column( 'atype',   type= String, ),     # polymorphic discriminator
        )
        table_B = Table( 'B', meta,
            Column( 'dataB',   type= String, ),
            Column( 'db_id', Integer, ForeignKey( 'A.db_id', ),   primary_key= True, ),
        )
        table_C = Table( 'C', meta,
            Column( 'dataC',   type= String, ),
            Column( 'db_id', Integer, ForeignKey( 'B.db_id', ),   primary_key= True, ),
        )

        meta.create_all()

        class A( Base):
            props = ['db_id', 'linkA', 'name']
        class B( A):
            props = ['db_id', 'linkA', 'name', 'dataB']
        class C( B):
            props = ['db_id', 'linkA', 'name', 'dataC', 'dataB']

        select_tableA= table_A.select( table_A.c.atype == 'A', ).alias( 'bz4A' )
        # Experiment toggle: 1 selects A through a plain filtered alias;
        # change to 0 to select A through the full polymorphic union instead.
        if 1:
            pu_a = select_tableA
        else:
            pu_a = polymorphic_union( {
                        'A': select_tableA,
                        'B': select( [table_A, table_B.c.dataB], table_A.c.atype == 'B', from_obj= [join( table_A, table_B, table_B.c.db_id == table_A.c.db_id, )], ).alias( 'bz4B' ),
                        'C': select( [table_A, table_B.c.dataB, table_C.c.dataC], table_A.c.atype == 'C', from_obj= [join( join( table_A, table_B, table_B.c.db_id == table_A.c.db_id, ), table_C, table_C.c.db_id == table_B.c.db_id, )], ).alias( 'bz4C' ),
                        }, None, 'pu_a', ) #tableinh

        mapper_A = mapper( A, table_A,
                    polymorphic_identity= 'A',
                    polymorphic_on= pu_a.c.atype,
                    select_table= pu_a,
                    )
        # post_update: linkA_id is written by a separate UPDATE after both
        # rows exist, because A and B rows depend on each other.
        mapper_A.add_property( 'linkA', relation( B,
                    foreign_keys= table_A.c.linkA_id,
                    lazy= True,
                    post_update= True,
                    primaryjoin= table_A.c.linkA_id == table_B.c.db_id,
                    remote_side= table_B.c.db_id,
                    uselist= False,
                    ) )

        pu_b = polymorphic_union( {
                        'B': select( [table_A, table_B.c.dataB], table_A.c.atype == 'B', from_obj= [join( table_A, table_B, table_B.c.db_id == table_A.c.db_id, )], ).alias( 'bz4B' ),
                        'C': select( [table_A, table_B.c.dataB, table_C.c.dataC], table_A.c.atype == 'C', from_obj= [join( join( table_A, table_B, table_B.c.db_id == table_A.c.db_id, ), table_C, table_C.c.db_id == table_B.c.db_id, )], ).alias( 'bz4C' ),
                        }, None, 'pu_b', ) #tableinh
        mapper_B = mapper( B, table_B,
                    inherit_condition= table_B.c.db_id == table_A.c.db_id,
                    inherits= mapper_A,
                    polymorphic_identity= 'B',
                    polymorphic_on= pu_b.c.atype,
                    select_table= pu_b,
                    )
        mapper_C = mapper( C, table_C,
                    inherit_condition= table_C.c.db_id == table_B.c.db_id,
                    inherits= mapper_B,
                    polymorphic_identity= 'C',
                    )

        #populate
        a = A()
        b = B()
        c = C()
        a.linkA = b
        a.name = 'aaaa'
        b.name = 'bbbb'
        b.dataB= 'my_b'
        c.name = 'cccc'
        c.dataC= 'my_c'

        session = create_session()
        session.save(a)
        session.save(b)
        session.save(c)
        session.flush()

        # Snapshot the original A's rendering (db_id is assigned by the flush)
        # so it can be compared with the re-loaded instance below.
        a_str = str(a)

        if config.session_clear:
            # Drop the in-memory objects so the query below has to load from
            # the database instead of returning the very same instances.
            session.clear()
        for t in table_A, table_B, table_C:
            print( 'table %s %s' % (t, list( t.select().execute())))

        oa = session.query(A).get_by_db_id( a.db_id)
        if oa is None:
            me.fail( 'no A row with db_id=%s was loaded back' % a.db_id)
        print( 'expect %s %s' % (a, a.linkA))
        print( 'result %s %s' % (oa, oa.linkA))

        me.assertEqual( oa.linkA.db_id, a.linkA.db_id)
        me.assertEqual( oa.linkA.name, a.linkA.name )
        me.assertEqual( a_str, str(oa) )

        # Historical failure notes, kept verbatim. NOTE(review): they were
        # captured from a related script ("_test-ABCD-t.py") whose schema also
        # has a table D / column dataD, so file names, line numbers and values
        # do not match this file exactly.
        """
2361 and before:
Traceback (most recent call last):
  File "_test-ABCD-t.py", line 118, in test_ABC_poly_1__inh_t__Alazy_B__inhB_A__Blazy___BAlazy___inhC_B__Clazy___CAlazy___CBlazy_
    me.assertEquals( oa.linkA.name, a.linkA.name )
AssertionError: u'ccccc' != 'bbbbb'

2362:
AssertionError: Could not find corresponding column for B.db_id in selectable SELECT "bz4C".db_id, "bz4C".name, CAST(NULL AS TEXT) AS "dataD", "bz4C"."linkA_id", "bz4C".atype, "bz4C"."dataC", "bz4C"."dataB"
FROM (SELECT "A"."linkA_id" AS "linkA_id", "A".name AS name, "A".db_id AS db_id, "A".atype AS atype, "B"."dataB" AS "dataB", "C"."dataC" AS "dataC"
FROM "A" JOIN "B" ON "B".db_id = "A".db_id JOIN "C" ON "C".db_id = "B".db_id
WHERE "A".atype = ?) AS "bz4C" UNION ALL SELECT "bz4B".db_id, "bz4B".name, CAST(NULL AS TEXT) AS "dataD", "bz4B"."linkA_id", "bz4B".atype, CAST(NULL AS TEXT) AS "dataC", "bz4B"."dataB"
FROM (SELECT "A"."linkA_id" AS "linkA_id", "A".name AS name, "A".db_id AS db_id, "A".atype AS atype, "B"."dataB" AS "dataB"
FROM "A" JOIN "B" ON "B".db_id = "A".db_id
WHERE "A".atype = ?) AS "bz4B" UNION ALL SELECT "D".db_id, "A".name, "D"."dataD", "A"."linkA_id", "A".atype, "C"."dataC", "B"."dataB"
FROM "A" JOIN "B" ON "B".db_id = "A".db_id JOIN "C" ON "C".db_id = "B".db_id JOIN "D" ON "D".db_id = "C".db_id


"""

    def setUp(me):
        """Create (or reuse) the engine and a fresh bound MetaData for one test."""
        if config.debug or config.echo:
            print( '===== %s' % me.id())

        if config.reuse_db and me._db is not None:
            db = me._db
        else:
            db = create_engine( config.db)
            if config.reuse_db:
                # Cache on the class, not the instance; see the _db comment.
                me.__class__._db = db

        if config.log_sa:
            import logging
            import sys
            # Deliberately no timestamps in the format, so log output from
            # different runs can be compared directly.
            log_format = '* SA: %(levelname)s %(message)s'
            # basicConfig is a no-op once the root logger has handlers, so
            # calling it on every setUp is harmless.
            logging.basicConfig( level=logging.DEBUG, format=log_format, stream=sys.stdout)
            logging.getLogger('sqlalchemy').setLevel( logging.DEBUG) #debug EVERYTHING!

        me.db = db
        me.meta = BoundMetaData( db)
        me.meta.engine.echo = config.echo

    def tearDown(me):
        """Drop the test's tables, clear all mappers and release the engine."""
        me.meta.drop_all()
        me.meta = None
        #destroy ALL caches
        clear_mappers()

        if not config.reuse_db:
            me.db.dispose()
        me.db = None
        if config.leak:
            import gc, sqlalchemy
            gc.set_debug( gc.DEBUG_UNCOLLECTABLE | gc.DEBUG_SAVEALL | gc.DEBUG_INSTANCES | gc.DEBUG_STATS ) #OBJECTS
            gc.collect()
            # Optional registry dumps for deeper leak hunting:
            #print "MAPPER REG:", dict(sqlalchemy.orm.mapperlib.mapper_registry)
            #print "SESION REG:", dict(sqlalchemy.orm.session._sessions)
            #print "CLASSKEYS:", dict(sqlalchemy.util.ArgSingleton.instances)
            watched = (sqlalchemy.orm.Mapper, sqlalchemy.BoundMetaData)
            alive = len( [x for x in gc.get_objects() if isinstance( x, watched)])
            print( 'gc/sqlalc %s' % alive)

    def run( self, *a, **k):
        """Run the test config.repeat times, reporting memory after each if asked."""
        for i in range( config.repeat):
            unittest.TestCase.run( self, *a, **k)
            if config.memory:
                memusage()

# Bare-word command-line flags understood by setup(). Each one sets the
# boolean attribute of the same name on `config`. A "no_" prefix sets the
# unprefixed attribute instead: False when the flag is given, True otherwise.
# NOTE: this name shadows the builtin help() within this module.
help = 'echo debug log_sa no_session_clear reuse_db leak memory'
def setup():
    """Consume this script's own options from sys.argv into `config`.

    Bare-word flags listed in `help` are removed from sys.argv and stored as
    booleans on `config`; a ``no_`` prefix stores the negation under the
    unprefixed name. ``db=URL`` and ``repeat=N`` arguments are consumed too.
    Everything else is left in sys.argv for unittest.main(). Finally, the
    resulting configuration is printed.
    """
    import sys
    for h in ['help', '-h', '--help']:
        if h in sys.argv:
            print( 'options: ' + help)
            break   # show the option list once, however many help words

    for flag in help.split():
        value = flag in sys.argv
        if value:
            sys.argv.remove(flag)
        if flag.startswith('no_'):
            flag = flag[3:]
            value = not value
        setattr( config, flag, value)

    # Iterate over a copy, because matched arguments are removed from sys.argv.
    for arg in sys.argv[1:]:
        # Split on the first '=' only, so values such as database URLs with
        # query strings (e.g. "db=sqlite:///f.db?mode=ro") stay intact.
        kv = arg.split('=', 1)
        if len(kv) != 2:
            continue
        key, value = kv
        if key == 'db':
            config.db = value
        elif key == 'repeat':
            config.repeat = int(value)
        else:
            continue
        sys.argv.remove(arg)

    settings = ['%s=%s' % (k, v) for k, v in config.__dict__.items() if not k.startswith('__')]
    print( 'config: ' + ', '.join( settings))

#_mem = ''
def memusage():
    """Print this process's VmPeak / VmRSS / VmData lines, when available.

    Reads ``/proc/<pid>/status``. Matching lines are joined, each prefixed
    with ``'; '``, and printed on one line. This is a diagnostic only: if the
    status file cannot be opened (e.g. no procfs on this OS), the report is
    skipped instead of aborting the test run.
    """
    import os
    path = '/proc/%d/status' % os.getpid()
    try:
        status = open( path)
    except (IOError, OSError):
        return
    keys = 'VmPeak VmRSS VmData'.split()   # hoisted: invariant across lines
    m = ''
    try:
        for line in status:
            line = line.strip()
            for k in keys:
                if line.startswith(k):
                    m += '; ' + line
    finally:
        status.close()
    if m:
        print( m)
#            global _mem
#            if _mem != m:
#                _mem = m
#                print m

if __name__ == '__main__':
    # Strip this script's own flags (see `help`, db=..., repeat=...) out of
    # sys.argv first, so unittest.main() only sees arguments it understands.
    setup()
    unittest.main()
