"""tests of joined-eager loaded attributes"""

from test.lib.testing import eq_, is_, is_not_
import sqlalchemy as sa
from test.lib import testing
from sqlalchemy.orm import joinedload, deferred, undefer, \
    joinedload_all, backref, eagerload, Session, immediateload
from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, \
    func
from test.lib.schema import Table, Column
from sqlalchemy.orm import mapper, relationship, create_session, \
    lazyload, joinedload, aliased, column_property, query
from sqlalchemy.sql import operators
from test.lib.testing import eq_, assert_raises, \
    assert_raises_message
from test.lib.assertsql import CompiledSQL
from test.lib import fixtures
from test.orm import _fixtures
from sqlalchemy.util import OrderedDict as odict
import datetime

class CacheableQuery(query.Query):
    """Query subclass that can "bake" its compiled statement for reuse.

    ``bake()`` compiles the query once and stores the resulting context.
    Afterwards, ``_compile_context()`` still compiles a fresh context, but
    when the stored one was built with the same ``labels`` flag, its
    statement object is swapped in.  The same statement object is thereby
    reused across executions.
    """

    def __init__(self, *p, **kw):
        # Baked context, and the ``labels`` flag it was compiled with.
        # Both stay None until bake() is called.
        self._cached_context = self._cached_context_labels = None
        super(CacheableQuery, self).__init__(*p, **kw)

    @query._generative()
    def bake(self, labels=True):
        """Compile this query now and remember the compiled context.

        Decorated with ``query._generative()``.  Like the other generative
        Query methods, this presumably operates on a copy that is returned
        to the caller -- TODO confirm against ``query._generative``.
        """
        # Go straight to the base implementation so that any previously
        # baked statement is not substituted into the new bake.
        baked = super(CacheableQuery, self)._compile_context(labels)
        self._cached_context = baked
        self._cached_context_labels = labels

    def _compile_context(self, labels=True):
        """Compile a fresh context, reusing the baked statement if possible.

        If a baked context exists and was compiled with the same ``labels``
        value, its statement replaces the freshly compiled one.  The
        returned context's statement therefore does not reflect query
        changes made after ``bake()``.
        """
        fresh = super(CacheableQuery, self)._compile_context(labels)
        baked = self._cached_context
        if not baked or self._cached_context_labels != labels:
            return fresh
        fresh.statement = baked.statement
        return fresh

class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
    """Joined-eager loading combined with a baked (pre-compiled) query."""

    run_inserts = 'once'
    run_deletes = None

    def test_baked(self):
        """A baked query reuses one compiled statement across executions.

        The query is baked once and then executed twice.  Each execution
        uses a separate session and a different bind value.  The
        ``compiled_cache`` passed through execution_options must hold
        exactly one entry after both runs.  A single run would leave one
        entry whether or not the statement were reused, so the second run
        is what actually exercises the baking.
        """
        users, Address, addresses, User = (self.tables.users,
                                self.classes.Address,
                                self.tables.addresses,
                                self.classes.User)

        mapper(User, users, properties={
            'address': relationship(mapper(Address, addresses),
                                    order_by=Address.id, uselist=False)
        })

        cache = {}
        sess = create_session(query_cls=CacheableQuery)
        q = sess.query(User).\
            filter(User.id == sa.bindparam("id")).\
            options(joinedload(User.address)).\
            execution_options(compiled_cache=cache).\
            bake()
        sess.close()

        # First execution: populates the compiled cache.
        sess = create_session(query_cls=CacheableQuery)
        u = q.with_session(sess).params(id=7).first()
        sess.close()
        eq_(User(id=7, address=Address(id=1, email_address='jack@bean.com')), u)
        eq_(1, len(cache))

        # Second execution, from a fresh session and with a different bind
        # value.  The baked statement object is reused, so the cache must
        # not grow.  User 9 has a single address, which keeps uselist=False
        # valid.
        sess = create_session(query_cls=CacheableQuery)
        u = q.with_session(sess).params(id=9).first()
        sess.close()
        eq_(User(id=9, address=Address(id=5, email_address='fred@fred.com')), u)
        eq_(1, len(cache))

