#$Id: sa_gentestbase.py 652 2007-01-23 18:10:06Z sdobrev $
from sqlalchemy import *
import unittest

class Base( object):
    """Mixin giving test model objects a compact, deterministic str().

    Subclasses list the attribute names to render in ``props``.  The query
    helper looks instances up by the attribute named in ``idname``.
    """
    idname = 'id'

    def __str__( self):
        def piece( name):
            value = getattr( self, name, '<notset>')
            # a referenced Base object is shown by its ``data`` attribute
            # (prefixed with '>') instead of its own full str()
            if isinstance( value, Base):
                value = '>' + str( value.data)
            return ' ' + name + '=' + str( value)
        head = self.__class__.__name__ + '('
        return head + ''.join( [ piece( name) for name in self.props ] ) + ' )'

class config:
    """Global switches for the test harness; setup() overwrites them from argv."""
    # echo:  turn on SQL echo on the engine created in setUp
    # dump:  print raw table rows before the query checks
    # debug: print test ids and the expected strings
    echo = dump = debug = False
    # clear the session before query checks; disabled by 'no_session_clear'
    session_clear = True

class Test_AB0( unittest.TestCase):
    """Base test case: a fresh in-memory SQLite engine per test, plus a
    helper that checks the str() rendering of loaded A/B objects.
    """
    def setUp(me):
        if config.debug or config.echo:
            print( '===== %s' % me.id())
        me.db = create_engine( 'sqlite:///:memory:')
        me.meta = BoundMetaData( me.db)
        me.meta.engine.echo = config.echo

    def tearDown(me):
        me.meta = None
        try:
            #destroy ALL caches so the next test starts with an empty mapper registry
            from sqlalchemy.orm import mapperlib
            mapperlib.global_extensions[:] = []
            mapperlib.mapper_registry.clear()
        finally:
            # release the engine even if the cache cleanup above fails
            me.db.dispose()

    def _check_single( me, session, klass, ident, expected):
        """Load the ``klass`` instance whose ``klass.idname`` attribute equals
        ``ident`` and assert that its str() equals ``expected``."""
        q = session.query( klass).get_by( **{klass.idname: ident})
        me.assertEqual( expected, str(q))

    def _check_multi( me, session, klass, expected):
        """Select all ``klass`` instances and assert that their str() values
        equal ``expected`` ignoring order (``expected`` is not modified)."""
        got = sorted( str(z) for z in session.query( klass).select() )
        me.assertEqual( sorted( expected), got)

    def query( me, session, A,B, table_A,table_B, ida,idb, sa,sb, samulti,sbmulti ):
        """Check how the mapped classes ``A`` and ``B`` load through ``session``.

        ida, idb         -- ids of the single A / B instance to fetch via get_by
        sa, sb           -- expected str() of those single instances
        samulti, sbmulti -- expected str() of all A / B instances, any order;
                            the caller's sequences are left unmodified
        table_A, table_B -- tables whose rows are printed when config.dump is set
        """
        if config.debug:
            print( sa)
            print( sb)
        if config.session_clear: session.clear()
        if config.dump:
            print( list( table_A.select().execute() ))
            print( list( table_B.select().execute() ))

        #single
        me._check_single( session, A, ida, sa)
        me._check_single( session, B, idb, sb)

        #multiple
        me._check_multi( session, A, samulti)
        me._check_multi( session, B, sbmulti)

def setup( argv=None):
    """Parse harness keywords from the command line into ``config``.

    Each of 'echo', 'debug', 'dump' and 'no_session_clear' present in
    ``argv`` (default: ``sys.argv``) is removed from it in place, and the
    matching ``config`` attribute is set to whether it was present.  The
    'no_' prefix is stripped and inverts the flag.  The function also sets
    the recursion limit and prints the resulting config, sorted by name.
    """
    import sys
    if argv is None:
        argv = sys.argv
    # NOTE(review): 60 is far below the interpreter default; presumably chosen
    # so runaway recursion fails fast -- confirm it doesn't break deep but
    # legitimate call stacks.
    sys.setrecursionlimit( 60)
    for k in [ 'echo', 'debug', 'dump', 'no_session_clear']:
        v = k in argv
        if v:
            # drop every occurrence (list.remove() only drops the first one);
            # slice-assign so the same list object (e.g. sys.argv) is updated
            argv[:] = [ a for a in argv if a != k ]
        if k.startswith('no_'):
            k = k[3:]
            v = not v
        setattr( config, k, v)
    items = sorted( (k, v) for k, v in vars( config).items() if not k.startswith('__') )
    print( 'config: ' + ', '.join( '%s=%s' % kv for kv in items ))
# vim:ts=4:sw=4:expandtab
