Luke Opperman skrev:
Ok, want to let you all in on something magical that made my weekend. This is
an extension to my earlier thread on using SQLBuilder more consistently in
SQLObject.

I wrote something similar a couple of weeks ago. I didn't submit it to the mailing list because, in my opinion, it was not clean enough to be submitted upstream.

The idea behind it is to avoid SQL views and instead define your own composite objects, which use columns from different tables.
SQLBuilder integration works: you can access ViewableSubclass.q.column, which is
translated into the real expression defined in the columns dictionary.
A GROUP BY clause is added automatically when any column uses an SQL function.
At some point I plan to add a whitelist of aggregate functions, so that non-aggregate
functions can be used as well (where GROUP BY makes less sense).

It does not implement any kind of caching yet, so a new object is always returned. I plan to solve this once it becomes a real problem for me.

I'm attaching the implementation (viewable.py) and an example (view.py).
I'm still using an old version of SQLObject, based on a fairly old 0.7 SVN revision, so it might not work with the current trunk.

Johan
import copy

from sqlobject.dbconnection import Iteration
from sqlobject.declarative import DeclarativeMeta, setup_attributes
from sqlobject.sqlbuilder import (SQLCall, SQLObjectField, SQLObjectTable,
                                  NoDefault, AND)
from sqlobject.col import SOIntCol
from sqlobject.sresults import SelectResults
from sqlobject.styles import underToMixed
from sqlobject.classregistry import registry


class ViewableMeta(object):
    """Default ``sqlmeta`` for Viewable classes.

    Holds the table/column metadata that the viewable code reads and
    writes (``table``, ``idName``, ``columns``, ``columnList``,
    ``columnNames``).  ``defaultOrder`` and ``parentClass`` are presumably
    here for compatibility with SQLObject's own sqlmeta -- verify.
    """
    # Table the view's id column belongs to; filled in per view.
    table = None
    idName = 'id'
    defaultOrder = None
    parentClass = None

    # NOTE: these containers are class attributes, so in-place mutation
    # affects every class that uses this sqlmeta directly.
    columns = {}
    columnList = []
    columnNames = []


class DynamicViewColumn(object):
    """Column metadata for a view entry that is not a plain table field.

    Used for computed expressions (for example aggregate function calls)
    that have no underlying SQLObject column to copy.
    """

    def __init__(self, cls, name):
        # No real column exists, so the view name serves as the Python
        # name, the original name and the database name alike.
        self.soClass = cls
        self.name = name
        self.origName = name
        self.dbName = name


class SQLObjectView(object):
    """The ``q`` namespace of a Viewable class.

    ``View.q.<name>`` returns the SQLBuilder expression registered for view
    column ``<name>``; ``View.q.id`` returns a field for the view's id column.
    """

    def __init__(self, cls):
        self.cls = cls

    def __getattr__(self, attr):
        # Dunder lookups come from protocols (copy, pickle, hasattr) and must
        # never be treated as columns.  'cls' is guarded because a half-built
        # instance (e.g. during copy) lacks it, and resolving it through
        # self.cls would otherwise recurse forever.
        if attr == 'cls' or (attr.startswith('__') and attr.endswith('__')):
            raise AttributeError(attr)

        cls = self.cls
        if attr == 'id':
            return SQLObjectField(cls.sqlmeta.table,
                                  cls.sqlmeta.idName, attr)
        try:
            column = cls.sqlmeta.columns[attr]
        except KeyError:
            # __getattr__ must raise AttributeError so getattr(obj, name,
            # default) and hasattr() behave as expected.
            raise AttributeError('%r has no column %r' % (cls, attr))
        return column.value


class Viewable(object):
    """Base class for read-only, SQLObject-like views built from joins.

    Subclasses declare:

    * ``columns`` -- dict mapping attribute names to SQLBuilder expressions.
      It must contain an ``'id'`` entry; that expression's table is the
      table the view selects from.
    * ``joins`` -- optional list of join expressions used by every select.
    * ``clause`` -- optional expression AND-ed into every select.

    ``cls.q.<name>`` gives the expression of a column for building queries;
    ``select()`` returns instances carrying the column values as attributes.
    """
    __metaclass__ = DeclarativeMeta

    sqlmeta = ViewableMeta
    columns = {}
    joins = None
    clause = None

    def __classinit__(cls, new_attrs):
        setup_attributes(cls, new_attrs)

        # Subclasses that don't redefine ``columns`` keep the parent's view
        # definition (this hook also runs for Viewable itself).
        columns = new_attrs.get('columns')
        if not columns:
            return

        if 'id' not in columns:
            raise TypeError("You need an id column in %r" % cls)

        # Give each view its own sqlmeta and containers: the defaults on
        # ViewableMeta are class attributes and would otherwise be shared,
        # and overwritten, by every Viewable subclass.
        cls.sqlmeta = type('sqlmeta', (cls.sqlmeta,),
                           dict(columns={}, columnList=[], columnNames=[]))
        cls.sqlmeta.table = columns['id'].tableName
        cls.columns = dict(columns)

        for colName in sorted(columns):
            if colName != 'id':
                cls.addColumn(colName, columns[colName])

        cls.q = SQLObjectView(cls)

    @classmethod
    def addColumn(cls, name, query):
        """Register SQLBuilder expression ``query`` as view column ``name``.

        If ``query`` is a non-id field of a registered SQLObject table, a
        copy of that table's column object is stored (with ``origName`` set
        to ``name``) so introspection sees the real column; otherwise a
        DynamicViewColumn is created.

        Raises AssertionError if the field is not a column of its table.
        """
        col = None
        if isinstance(query, SQLObjectField) and query.fieldName != 'id':
            table = table_from_name(query.tableName)
            fieldName = query.fieldName
            # O(N) scan: columnList is not indexed by dbName.
            for candidate in table.sqlmeta.columnList:
                if candidate.dbName == fieldName:
                    break
            else:
                raise AssertionError(table.sqlmeta.table + '.' + fieldName)

            # Copy before renaming so the table's own column is untouched.
            col = copy.copy(candidate)
            col.origName = name

        if col is None:
            col = DynamicViewColumn(cls, name)

        col.value = query
        cls.sqlmeta.columns[name] = col
        cls.sqlmeta.columnList.append(col)
        cls.sqlmeta.columnNames.append(name)
        # Keep the SELECT namespace used by select() in sync.
        cls.columns[name] = query

    @classmethod
    def delColumn(cls, name):
        """Remove view column ``name``; raises KeyError if it is unknown."""
        col = cls.sqlmeta.columns.pop(name)
        cls.sqlmeta.columnList.remove(col)
        cls.sqlmeta.columnNames.remove(name)
        cls.columns.pop(name, None)

    @classmethod
    def get(cls, idValue, selectResults=None, connection=None):
        """Build an instance from an id and the remaining row values.

        ``selectResults`` holds the non-id column values of one row.
        queryForSelect emits those columns sorted by name, so they are paired
        with the column names in that same order.
        """
        if not selectResults:
            selectResults = []

        instance = cls()
        instance.id = idValue
        instance.__dict__.update(zip(sorted(cls.sqlmeta.columnNames),
                                     selectResults))
        instance._connection = connection

        return instance

    @classmethod
    def select(cls, clause=None, clauseTables=None,
               orderBy=NoDefault, limit=None,
               lazyColumns=False, reversed=False,
               distinct=False, connection=None,
               join=None, columns=None):
        """Return a ViewableSelectResults over this view.

        ``clause`` is AND-ed with the class ``clause``.  ``join`` (a single
        join or a list of them) is added after the class ``joins``.
        ``columns`` adds extra name -> expression entries to this query only.
        """
        if cls.clause:
            if clause:
                clause = AND(clause, cls.clause)
            else:
                clause = cls.clause

        # Per-query namespace: extra columns must not leak into the class.
        # NOTE(review): extra names are not in sqlmeta.columnNames, so get()
        # won't map them, and names sorting before existing columns shift
        # the value mapping -- confirm the intended use of ``columns``.
        ns = dict(cls.columns)
        if columns:
            ns.update(columns)

        joins = list(cls.joins or [])
        if join:
            if isinstance(join, (list, tuple)):
                joins.extend(join)
            else:
                joins.append(join)

        return ViewableSelectResults(cls, clause,
                                     clauseTables=clauseTables,
                                     orderBy=orderBy,
                                     limit=limit,
                                     lazyColumns=lazyColumns,
                                     reversed=reversed,
                                     distinct=distinct,
                                     connection=connection,
                                     join=joins or None,
                                     ns=ns)

    def sync(self):
        """Reload this row's column values from the database."""
        obj = self.select(
            self.q.id == self.id,
            connection=self.get_connection()).getOne()

        for attr in self.sqlmeta.columnNames:
            setattr(self, attr, getattr(obj, attr, None))

    def get_connection(self):
        """Return the connection this instance was loaded with."""
        return self._connection

def queryForSelect(conn, select):
    """Return the SQL SELECT statement for a Viewable select.

    ``select.ops['ns']`` is the view's name -> SQLBuilder expression mapping
    and must include ``'id'``.  The select list is the id expression first
    (aliased ``id``), then the other columns sorted by name.  With the
    ``lazyColumns`` option, only ``<table>.<idName>`` is selected.  If any
    column is an SQLCall (a function call such as SUM), a GROUP BY over the
    id and all non-call columns is appended.

    ``conn`` is the SQLObject DB connection; its helpers
    ``_fixTablesForJoins``, ``_addJoins``, ``_addWhereClause`` and
    ``_queryAddLimitOffset`` build the rest of the statement.
    """
    ops = select.ops
    join = ops.get('join')
    cls = select.sourceClass
    if join:
        tables = conn._fixTablesForJoins(select)
    else:
        tables = select.tables

    # Compute the namespace up front so every branch below (including the
    # lazyColumns one and the GROUP BY scan) can use it.  Copy it so the
    # dict handed over by Viewable.select is never mutated.
    ns = ops['ns'].copy()
    idExpr = ns.pop('id')
    names = sorted(ns)

    if ops.get('distinct', False):
        q = 'SELECT DISTINCT '
    else:
        q = 'SELECT '

    if ops.get('lazyColumns', 0):
        q += "%s.%s FROM %s" % (
            cls.sqlmeta.table, cls.sqlmeta.idName,
            ", ".join(tables))
    else:
        # Join a list so a view whose only column is the id doesn't leave a
        # dangling ", " before FROM.
        selectItems = ['%s AS id' % (idExpr,)]
        selectItems.extend(['%s AS %s' % (ns[name], name) for name in names])
        q += '%s FROM %s' % (', '.join(selectItems), ", ".join(tables))

    if join:
        q += conn._addJoins(select, tables)

        if not tables:
            q = q[:-1]

    q += " WHERE"
    q = conn._addWhereClause(select, q, limit=0)

    # NOTE(review): GROUP BY is appended after _addWhereClause; if that
    # helper emits ORDER BY in this SQLObject version the SQL is invalid.
    aggregates = [name for name in names if isinstance(ns[name], SQLCall)]
    if aggregates:
        groupItems = [str(ns[name]) for name in names
                      if name not in aggregates]
        groupItems.append(str(idExpr))
        q += " GROUP BY %s" % ', '.join(groupItems)

    start = ops.get('start', 0)
    end = ops.get('end', None)
    if start or end:
        q = conn._queryAddLimitOffset(q, start, end)

    return q


class ViewableIteration(Iteration):
    """Iteration that executes the Viewable SQL built by queryForSelect.

    Replaces Iteration.__init__ entirely (the base initializer is not
    called): it opens a cursor on ``rawconn``, builds the query and runs it
    through ``dbconn._executeRetry``.
    """

    def __init__(self, dbconn, rawconn, select, keepConnection=False):
        self.dbconn, self.rawconn = dbconn, rawconn
        self.select = select
        self.keepConnection = keepConnection

        cursor = rawconn.cursor()
        self.cursor = cursor
        sql = queryForSelect(dbconn, select)
        self.query = sql

        if dbconn.debug:
            dbconn.printDebug(rawconn, sql, 'Select')
        dbconn._executeRetry(rawconn, cursor, sql)


class ViewableSelectResults(SelectResults):
    """SelectResults whose SQL and iteration go through queryForSelect."""

    def __init__(self, sourceClass, clause, clauseTables=None,
                 **ops):
        SelectResults.__init__(self, sourceClass, clause, clauseTables, **ops)

        # The table the view joins from has to be the last entry in the
        # FROM clause; move it there if it isn't already.
        fromTable = sourceClass.sqlmeta.table
        if self.tables[-1] == fromTable:
            return
        self.tables.remove(fromTable)
        self.tables.append(fromTable)

    def __str__(self):
        connection = self._getConnection()
        return queryForSelect(connection, self)

    def lazyIter(self):
        # All rows are fetched into a list before an iterator is returned.
        connection = self._getConnection()
        rows = list(ViewableIteration(connection,
                                      connection.getConnection(),
                                      self,
                                      keepConnection=True))
        return iter(rows)


_cache = {}
def table_from_name(name):
    """Return the registered SQLObject class whose table name is ``name``.

    Lookups are served from ``_cache``.  On a miss, the cache is refreshed
    from ``registry(None).allClasses()`` before retrying.  This means classes
    registered after an earlier lookup are still found; previously the cache
    was filled only once and such classes stayed missing forever.

    Raises KeyError if no registered class uses that table.
    """
    if name not in _cache:
        # O(N) refresh, only paid on a miss; hits remain O(1).
        for table in registry(None).allClasses():
            _cache[table.sqlmeta.table] = table
    return _cache[name]

# -*- Mode: Python; coding: iso-8859-1 -*-
# vi:si:et:sw=4:sts=4:ts=4

##
## Copyright (C) 2007 Async Open Source <http://www.async.com.br>
## All rights reserved
##
## This program is free software; you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 2 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program; if not, write to the Free Software
## Foundation, Inc., or visit: http://www.gnu.org/.
##
## Author(s):   Johan Dahlin <[EMAIL PROTECTED]>
##

from sqlobject.viewable import Viewable
from stoqlib.domain.person import Person, PersonAdaptToSupplier
from stoqlib.domain.product import (Product, ProductAdaptToSellable,
                                    ProductAdaptToStorable,
                                    ProductStockItem,
                                    ProductSupplierInfo)
from stoqlib.domain.sellable import ASellable, SellableUnit, BaseSellableInfo
from stoqlib.domain.stock import AbstractStockItem
from sqlobject.sqlbuilder import func, AND, INNERJOINOn, LEFTJOINOn

class WarehouseView(Viewable):
    """Stock view of sellable products.

    A row combines a sellable's code, barcode, status and cost; its base
    sellable info (price, validity, description); its unit; the product id;
    the name of the product's main supplier; and the summed stock quantity.
    """

    columns = {
        'id': ASellable.q.id,
        'code': ASellable.q.code,
        'barcode': ASellable.q.barcode,
        'status': ASellable.q.status,
        'cost': ASellable.q.cost,
        'price': BaseSellableInfo.q.price,
        'is_valid_model': BaseSellableInfo.q._is_valid_model,
        'description': BaseSellableInfo.q.description,
        'unit': SellableUnit.q.description,
        'product_id': Product.q.id,
        'supplier_name': Person.q.name,
        # Function call column, so the view's SELECT gets a GROUP BY.
        'stock': func.SUM(AbstractStockItem.q.quantity +
                          AbstractStockItem.q.logic_quantity),
    }

    joins = [
        # Sellable unit; LEFT JOIN keeps sellables that have no unit.
        LEFTJOINOn(None, SellableUnit,
                   SellableUnit.q.id == ASellable.q.unitID),
        # Sellable -> product adapter -> product
        INNERJOINOn(None, ProductAdaptToSellable,
                    ProductAdaptToSellable.q.id == ASellable.q.id),
        INNERJOINOn(None, Product,
                    Product.q.id == ProductAdaptToSellable.q._originalID),
        # Product -> storable adapter -> stock items
        INNERJOINOn(None, ProductAdaptToStorable,
                    ProductAdaptToStorable.q._originalID == Product.q.id),
        INNERJOINOn(None, ProductStockItem,
                    ProductStockItem.q.storableID == ProductAdaptToStorable.q.id),
        INNERJOINOn(None, AbstractStockItem,
                    AbstractStockItem.q.id == ProductStockItem.q.id),
        # Product -> main supplier info -> supplier adapter -> person
        INNERJOINOn(None, ProductSupplierInfo,
                    AND(ProductSupplierInfo.q.productID == Product.q.id,
                        ProductSupplierInfo.q.is_main_supplier == True)),
        INNERJOINOn(None, PersonAdaptToSupplier,
                    PersonAdaptToSupplier.q.id == ProductSupplierInfo.q.supplierID),
        INNERJOINOn(None, Person,
                    Person.q.id == PersonAdaptToSupplier.q._originalID),
    ]

    # BaseSellableInfo is linked through the WHERE clause; only rows whose
    # base info has _is_valid_model set are returned.  The "== True" must
    # stay: SQLBuilder overloads == to build the SQL comparison.
    clause = AND(
        BaseSellableInfo.q.id == ASellable.q.base_sellable_infoID,
        BaseSellableInfo.q._is_valid_model == True,
    )

    @classmethod
    def select_by_branch(cls, query, branch, connection=None):
        """Select rows, limited to stock items of ``branch`` when given.

        ``query`` is an optional extra clause AND-ed with the branch filter.
        """
        if not branch:
            return cls.select(query, connection=connection)

        branch_query = AbstractStockItem.q.branchID == branch.id
        if not query:
            return cls.select(branch_query, connection=connection)
        return cls.select(AND(query, branch_query), connection=connection)

    @property
    def product(self):
        """The Product object identified by this row's product_id."""
        conn = self.get_connection()
        return Product.get(self.product_id, connection=conn)
-------------------------------------------------------------------------
Take Surveys. Earn Cash. Influence the Future of IT
Join SourceForge.net's Techsay panel and you'll get the chance to share your
opinions on IT & business topics through brief surveys-and earn cash
http://www.techsay.com/default.php?page=join.php&p=sourceforge&CID=DEVDEV
_______________________________________________
sqlobject-discuss mailing list
[email protected]
https://lists.sourceforge.net/lists/listinfo/sqlobject-discuss

Reply via email to