Hi,

As most of you already know, SQLObject's select method does not
return a list of the selected instances directly; instead, it returns a
SelectResults instance, which can then be used to further refine the
search query.
`SelectResults.filter` can be used to add another "where" clause.
`SelectResults.order_by` can be used to specify the sort order.
Slicing the instance is also translated into the equivalent SQL
(result[10:20] becomes OFFSET 10 LIMIT 10).

I've attached a simple work-alike version of this for SQLAlchemy. It's
currently not integrated with SQLAlchemy at all, since I'm not sure
what the best way to do that is, or whether it should be done at all.

It currently works like this:

res = SelectResults(mapper, table.c.column == "something")
res = res.order_by([table.c.column]) #add an order clause

for x in res[:10]:  # Fetch and print the top ten instances
  print x.column2

x = list(res) # execute the query

# Count how many instances have column2 > 42
# and column == "something"
print res.filter(table.c.column2 > 42).count()

You can find more examples in the unit tests or in the SQLObject
documentation :)

Cheers,
Jonas
from sqlalchemy.sql import and_, select, func


class SelectResults(object):
    """Lazy, chainable query over a mapper, modelled on SQLObject's
    ``SelectResults``.

    Nothing is fetched from the mapper until the instance is iterated (or
    indexed with an integer, a negative bound, or a stepped slice).
    ``filter`` and ``order_by`` return new instances with an extended query,
    and slicing with non-negative bounds is translated into SQL
    ``OFFSET``/``LIMIT``.

    :param mapper: object providing ``select(clause, **ops)`` (whose result
        is iterated) and ``count(clause)``.
    :param clause: optional WHERE criterion; ``None`` selects all rows.
    :param ops: extra keyword options forwarded to ``mapper.select``, e.g.
        ``order_by``, ``limit`` and ``offset``. The mapping is copied, never
        mutated.
    """

    # Options that only choose which rows are returned, or in which order.
    # They are dropped for single-row aggregate queries (see _aggregate).
    _WINDOW_OPS = ('order_by', 'limit', 'offset')

    def __init__(self, mapper, clause=None, ops=None):
        self._mapper = mapper
        self._clause = clause
        # Private copy: clones and the caller's dict never share state.
        self._ops = dict(ops or {})

    def count(self):
        """Return the number of rows matching the filter criterion.

        Ordering and slicing are not taken into account.
        """
        return self._mapper.count(self._clause)

    def _aggregate(self, fn, col):
        """Return the scalar result of ``SELECT fn(col) WHERE <criterion>``.

        Like ``count()``, this ignores ORDER BY / LIMIT / OFFSET. An
        aggregate query yields a single row, so an OFFSET > 0 would skip it
        (producing ``None``). Some databases (e.g. PostgreSQL) also reject
        ORDER BY on a column that is not aggregated.
        """
        ops = {}
        for key, value in self._ops.items():
            if key not in self._WINDOW_OPS:
                ops[key] = value
        return select([fn(col)], self._clause, **ops).scalar()

    def min(self, col):
        """Return the minimum of ``col`` over the rows matching the filter."""
        return self._aggregate(func.min, col)

    def max(self, col):
        """Return the maximum of ``col`` over the rows matching the filter."""
        return self._aggregate(func.max, col)

    def sum(self, col):
        """Return the sum of ``col`` over the rows matching the filter."""
        return self._aggregate(func.sum, col)

    def avg(self, col):
        """Return the average of ``col`` over the rows matching the filter."""
        return self._aggregate(func.avg, col)

    def clone(self):
        """Return an independent copy with the same mapper, criterion and
        options."""
        return SelectResults(self._mapper, self._clause, self._ops)

    def filter(self, clause):
        """Return a copy whose criterion is the current one AND ``clause``."""
        new = self.clone()
        if self._clause is None:
            # Nothing to combine with, so don't wrap a lone clause in and_().
            new._clause = clause
        else:
            new._clause = and_(self._clause, clause)
        return new

    def order_by(self, order_by):
        """Return a copy sorted by ``order_by``, replacing any previous
        ordering.

        ``order_by`` is passed through to ``mapper.select``. Callers in
        this file pass a list such as ``[table.c.col]``.
        """
        new = self.clone()
        new._ops['order_by'] = order_by
        return new

    def limit(self, limit):
        """Return at most ``limit`` rows of this result; same as
        ``self[:limit]``."""
        return self[:limit]

    def offset(self, offset):
        """Skip the first ``offset`` rows of this result; same as
        ``self[offset:]``."""
        return self[offset:]

    def _narrow(self, start, stop):
        """Restrict this instance, in place, to rows ``[start:stop]`` of its
        current window, composing with any OFFSET/LIMIT already set.

        ``start`` must be >= 0, and ``stop`` must be ``None`` or >= 0.
        """
        offset = self._ops.get('offset') or 0
        limit = self._ops.get('limit')
        # Number of rows requested by the new slice (None = unbounded).
        if stop is None:
            new_limit = None
        else:
            new_limit = max(stop - start, 0)  # reversed slice -> empty
        if limit is not None:
            # The new window cannot extend past the end of the existing one.
            remaining = max(limit - start, 0)
            if new_limit is None or remaining < new_limit:
                new_limit = remaining
        if offset + start:
            self._ops['offset'] = offset + start
        else:
            self._ops.pop('offset', None)
        if new_limit is not None:
            self._ops['limit'] = new_limit

    def __getitem__(self, item):
        """Index or slice the result set.

        - ``res[a:b]`` with non-negative bounds returns a new SelectResults
          selecting that window *of the current window*. Chained slices such
          as ``res[10:20][2:5]`` therefore compose like list slicing.
        - A slice with a step returns a list: the window is fetched, then
          the step is applied in Python.
        - Negative bounds or a negative index cannot be expressed with
          OFFSET/LIMIT, so the whole result is fetched and sliced or indexed
          as a list.
        - A non-negative integer index fetches a single row. IndexError is
          raised when it is out of range.
        """
        if isinstance(item, slice):
            start, stop, step = item.start, item.stop, item.step
            if (start is not None and start < 0) or \
               (stop is not None and stop < 0):
                return list(self)[item]
            res = self.clone()
            res._narrow(start or 0, stop)
            if step is not None:
                return list(res)[::step]
            return res
        if item < 0:
            # self[item:item + 1] is always empty for item == -1, so negative
            # indexes are resolved against the materialised list instead.
            return list(self)[item]
        return list(self[item:item + 1])[0]

    def __iter__(self):
        """Execute the query via ``mapper.select`` and iterate the rows."""
        return iter(self._mapper.select(self._clause, **self._ops))

from testbase import PersistTest
import testbase

from sqlalchemy import *
from sqlalchemy.mapping.sresults import SelectResults


class Foo(object):
    """Bare class mapped onto the ``foo`` test table by ``assign_mapper``."""

    
class SelectResultsTest(PersistTest):
    """Compare SelectResults against the mapper's own ``select()`` result,
    using a ``foo`` table whose ``bar`` column holds 0..99."""

    def setUpAll(self):
        """Create the ``foo`` table, map ``Foo`` onto it and insert 100 rows."""
        global foo
        foo = Table('foo', testbase.db,
                    Column('id', Integer, primary_key=True),
                    Column('bar', Integer))

        assign_mapper(Foo, foo)
        foo.create()
        for i in range(100):
            Foo(bar=i)
        objectstore.commit()

    def setUp(self):
        # orig: the mapper's eager select() result, used as the expected value.
        self.orig = Foo.select()
        self.res = SelectResults(Foo.mapper)

    def tearDownAll(self):
        """Drop the ``foo`` table created in setUpAll."""
        foo.drop()

    def test_slice(self):
        assert self.res[1] == self.orig[1]
        assert list(self.res[10:20]) == self.orig[10:20]
        assert list(self.res[10:]) == self.orig[10:]
        assert list(self.res[:10]) == self.orig[:10]
        assert list(self.res[10:40:3]) == self.orig[10:40:3]
        assert list(self.res[-5:]) == self.orig[-5:]

    def test_aggregate(self):
        assert self.res.count() == 100
        assert self.res.filter(foo.c.bar<30).min(foo.c.bar) == 0
        assert self.res.filter(foo.c.bar<30).max(foo.c.bar) == 29
        assert self.res.filter(foo.c.bar<30).sum(foo.c.bar) == 435
        assert self.res.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5

    def test_filter(self):
        assert self.res.count() == 100
        assert self.res.filter(Foo.c.bar < 30).count() == 30
        res2 = self.res.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
        assert res2.count() == 19

    def test_order_by(self):
        assert self.res.order_by([Foo.c.bar])[0].bar == 0
        assert self.res.order_by([desc(Foo.c.bar)])[0].bar == 99

    def test_offset(self):
        assert list(self.res.order_by([Foo.c.bar]).offset(10))[0].bar == 10

    # Previously also named test_offset, which silently replaced the method
    # above so the offset test never ran.
    def test_limit(self):
        assert len(list(self.res.limit(10))) == 10


if __name__ == "__main__":
    # Running the module directly hands control to the project's test runner.
    testbase.main()

Reply via email to