#$Id: sa_gentestbase.py 764 2007-01-29 19:15:16Z sdobrev $
from sqlalchemy import *
import unittest

class Base( object):
    def __str__( self):
        '''Render as "ClassName( name=value ... )" for every name in self.props.

        self.props is not defined by Base itself; it must be supplied by the
        subclass (or the instance).  An attribute that is not set is shown as
        <notset>.  An attribute that is itself a Base instance is shown as '>'
        followed by str() of its .data attribute, instead of its full str().'''
        pieces = [ self.__class__.__name__, '(' ]
        for name in self.props:
            value = getattr( self, name, '<notset>')
            if isinstance( value, Base):
                # reference to another Base object: show only its .data
                value = '>' + str( value.data)
            pieces.append( ' ' + name + '=' + str( value))
        pieces.append( ' )')
        return ''.join( pieces)

class config:
    '''Global harness options; defaults below, overwritten by setup() from sys.argv.'''
    # set engine echo and print each test id in Test_AB0.setUp (argv flag: echo)
    echo = False
    # print each expected table's rows in Test_AB0.query (argv flag: dump)
    dump = False
    # print test ids and expected single results (argv flag: debug)
    debug = False
    # call session.clear() before querying in Test_AB0.query (argv flag no_session_clear turns it off)
    session_clear = True
    # keep one engine across tests instead of disposing it in tearDown (argv flag: reuse_db)
    reuse_db = False
    # URL passed to create_engine() (argv: db=URL)
    db = 'sqlite:///:memory:'
    # how many times Test_AB0.run executes each test (argv: repeat=N)
    repeat = 1

class Test_AB0( unittest.TestCase):
    _db = None
    def setUp(me):
        if config.debug or config.echo:
            print '=====', me.id()

        if config.reuse_db and me._db:
            db = me._db
        else:
            db = create_engine( config.db)
            if config.reuse_db:
                me._db = db

        me.db = db
        me.meta = BoundMetaData( db)
        me.meta.engine.echo = config.echo

    def tearDown(me):
        me.meta.drop_all()
        me.meta = None
        #destroy ALL caches
        clear_mappers()

        if not config.reuse_db:
            me.db.dispose()
        me.db = None

    def query( me, session, expects, idname ='id'):
        if config.debug:
            for item in expects: print item['exp_single']
        if config.session_clear: session.clear()
        if config.dump:
            for item in expects: print '\n'.join( item['table'].select().execute() )
        for item in expects:
            me.query1( session, idname=idname, **item)

    def query1( me, session, idname, klas, table, oid, exp_single, exp_multi):
        #single
        q = session.query( klas).get_by( **{idname: oid})
        me.assertEqual( exp_single, str(q),
                klas.__name__+'.getby_'+idname+'(): exp=%(exp_single)s, res=%(q)s' % locals()
            )

        #multiple
        q = session.query( klas).select()
        x = [ str(z) for z in q ]
        x.sort()
        exp_multi.sort()
        me.assertEqual( exp_multi, x,
                klas.__name__+'.select(): exp=%(exp_multi)s, res=%(x)s' % locals()
            )

    def run( self, *a, **k):
        for i in range( config.repeat):
            unittest.TestCase.run( self, *a,**k)


def setup():
    import sys
#    sys.setrecursionlimit( 600)
    for k in [ 'echo', 'debug', 'dump', 'no_session_clear', 'reuse_db' ]:
        v = k in sys.argv
        if v: sys.argv.remove(k)
        if k.startswith('no_'):
            k = k[3:]
            v = not v
        setattr( config, k, v)

    for a in sys.argv[1:]:
        kv = a.split('=')
        if len(kv)==2:
            k,v = kv
            if k=='db':
                config.db = v
                sys.argv.remove(a)
            elif k=='repeat':
                config.repeat = int(v)
                sys.argv.remove(a)

    print 'config:', ', '.join( '%s=%s' % (k,v) for k,v in config.__dict__.iteritems() if not k.startswith('__') )
# vim:ts=4:sw=4:expandtab
