This isn't at all packaged or documented or anything, but I thought it
might be apropos to the discussion.

This is what I've been using for an object wrapper.  It's grown
organically, and there are certainly things that are missing.  But it
might be of interest.  Objects all inherit from a single class, and have
attributes to declare the table and columns -- I suppose columns could
be automatically picked up, but declaring them explicitly isn't a big
deal for me.  In the near future, instead of declaring columns as a
list of names, I'll probably use a class to describe each column
(whether there's a trigger for changes, what the default value should
be, etc.).

Anyway, Document.py is an example of using the classes.  It obviously
depends on a lot of other context, but it should give you an idea of how
the classes are used.  SQLObject.py has the base class and the building
function (SQLBuild).  ParamFactory.py makes sure that requesting the same
object twice gives you the same instance rather than a duplicate.  db.py
is a thin database abstraction layer.

  Ian

from Local.SQLObject import *
import Project
import os
import User
from Request import Request
from Locations import getLocation
import SitePage
import DocumentInstance
from FileInspector import getInfo
from mx import DateTime

# Fallback for interpreters that lack the True/False builtins.  The names
# are only defined when missing, so the real bool singletons are not
# shadowed by plain ints on interpreters that have them.
try:
    True, False
except NameError:
    globals().update({'True': 1 == 1, 'False': 0 == 1})

class _Document(SQLObject):
    """A document belonging to a project; SQLBuild turns it into `Document`.

    Rows live in the `document` table of the `comment` database.
    """

    _database = 'comment'
    _table = 'document'
    _columns = ['projectID', 'title',
                'authorUserID', 'activeDocumentInstanceID']
    _extraNew = ['projectID']
    # SQLObject.set() consults `_trigger` (singular) to decide which
    # change<Name>() hooks to call.
    _trigger = ['projectID']
    _triggers = _trigger  # alias for code that reads the plural name
    _defaults = {'authorUserID': None,
                 }

    def project(self):
        return Project.Project(self.projectID())

    def author(self):
        """Return the author User, or None if no author is recorded."""
        if self.authorUserID() is None:
            return None
        return User.User(self.authorUserID())

    def titleOrName(self):
        """Return a display name for the document.

        This is the title if set, otherwise the base name of the active
        instance's source file, otherwise 'untitled document <id>'.
        """
        if self.title():
            return self.title()
        if self.activeDocumentInstanceID():
            docInst = self.activeDocumentInstance()
            if docInst.sourceFilename():
                return os.path.basename(docInst.sourceFilename())
        return 'untitled document %i' % self._id

    def newProjectID(self, projectID):
        """Hook called by new() with the project the document was created in."""
        project = Project.Project(projectID)
        project._childrenDocuments = None

    def changeProjectID(self, oldProjectID):
        """Hook called by set() after projectID changes.

        The document left one project and joined another, so the cached
        document data of both projects is stale.
        """
        for project in (Project.Project(oldProjectID), self.project()):
            project._childrenDocuments = None
            project._documentCount = None

    def activeDocumentInstance(self):
        """Return the active DocumentInstance, or None if none is set."""
        instanceID = self.activeDocumentInstanceID()
        if instanceID is None:
            return None
        return DocumentInstance.DocumentInstance(instanceID)

    def documentInstances(self):
        """Return all instances of this document, ordered by date_retrieved."""
        ids = queryFirst(self._database, """
        SELECT id
        FROM document_instance
        WHERE document_id = %s
        ORDER BY date_retrieved
        """, self._id)
        return map(DocumentInstance.DocumentInstance, ids)

    def changeNew(self):
        """Hook called once a new document row exists.

        The owning project's cached document data (the same attributes
        delDocument resets) is now stale.
        """
        project = self.project()
        project._childrenDocuments = None
        project._documentCount = None

def delDocument(document):
    """Delete a document row together with all of its document instances.

    `document` is anything SitePage.getID accepts (presumably a Document or
    its ID -- confirm against SitePage).  The owning project's cached
    `_childrenDocuments` and `_documentCount` are reset first.
    """
    docID = SitePage.getID(document)
    doomed = Document(docID)
    owner = doomed.project()
    owner._childrenDocuments = None
    owner._documentCount = None
    # Remove dependent instances before the document row itself.
    for instance in doomed.documentInstances():
        DocumentInstance.DocumentInstance.delete(instance.id())
    query(_Document._database, """
    DELETE FROM document
    WHERE id = %s
    """, docID)

# Public factory for documents: Document(id) returns the cached _Document
# for that row, Document.new(...) inserts a new one, and
# Document.delete(doc) is delDocument.
Document = SQLBuild(_Document, delete=delDocument)
from threading import Lock

class ParamFactory:
    """Instance cache keyed by constructor arguments.

    Calling the factory with the same positional arguments always returns
    the same object; the first such call builds it with klass(*args).
    Keyword arguments given to __init__ are attached to the factory as
    attributes (SQLBuild uses this for `new`, `delete`, ...).
    """

    def __init__(self, klass, **extraMethods):
        self._lock = Lock()
        self._cache = {}
        self._klass = klass
        for name, func in extraMethods.items():
            setattr(self, name, func)

    def __call__(self, *args):
        """Return the cached instance for `args`, creating it if needed.

        If the constructor raises, nothing is cached and the exception
        propagates; the lock is always released.
        """
        self._lock.acquire()
        try:
            if args not in self._cache:
                # Construct while holding the lock so two threads asking
                # for the same key cannot both build an instance.
                self._cache[args] = self._klass(*args)
            return self._cache[args]
        finally:
            self._lock.release()

    def allInstances(self):
        """Return a list snapshot of every instance created so far."""
        self._lock.acquire()
        try:
            return list(self._cache.values())
        finally:
            self._lock.release()
    
from db import *
import ParamFactory
import threading
import re

# Fallback for interpreters that lack the True/False builtins.  The names
# are only defined when missing, so the real bool singletons are not
# shadowed by plain ints on interpreters that have them.
try:
    True, False
except NameError:
    globals().update({'True': 1 == 1, 'False': 0 == 1})

class SQLObject:
    """
    Base class for objects backed by one SQL row, identified by its id.

    All column values are loaded together on first access (_autoInit)
    and cached; set() writes changes back and updates the cache.

    SPECIAL ATTRIBUTES:

    _database = database name passed to the query helpers
    _table = table name
    _columns = dictionary of attribute->SQL column names (SQLBuild also
        accepts a list/tuple of attribute names and derives the columns)
    _defaults = dictionary of attribute->default values used by new();
        a callable default is called to produce the value
    _extraNew = list of keyword names accepted by new(); for each one
        given, new<Name>(value) is called after the row is inserted
    _alternateID = dictionary of attributes->creation names
        that allow non-ID id's (name aliases)
    _trigger = list of attributes for which change<Name>(oldValue)
        is called when set() changes the attribute's value

    SPECIAL METHODS:

    _processNewInput:
        Class method (called like def _processNewInput(klass, kw))
        Takes the class and dictionary of input to .new(),
        and fiddles as it chooses, returning the new dictionary.
        (class method, as this is called before instantiation)
    _processSetInput:
        Like _processNewInput, except called on .set()
        (with (self, kw))
    changeNew:
        Called when a new object is created (not like __init__,
        but when an entirely new object is created)
    """

    def __init__(self, id):
        self._id = id
        self._autoInitDone = False
        self._writeLock = threading.Lock()

    def id(self):
        return self._id

    def set(self, **kw):
        """Update attributes, writing column changes to the database.

        Column names are written in one UPDATE.  A name `foo` is accepted
        for a column `fooID`, with the value run through getID().  Any
        other name calls set<Name>(value).  Afterwards change<Name>(old)
        is called for each column in _trigger whose value changed.
        """
        self._autoInit()
        # NOTE(review): threading.Lock is not reentrant, so a set<Name>()
        # helper that calls self.set() again would deadlock here.
        self._writeLock.acquire()
        try:
            if self._processSetInput:
                kw = self._processSetInput(kw)
            # Iterate over a copy: the loop renames keys in kw.
            for name, value in list(kw.items()):
                idName = '%sID' % name
                if idName in self._columns and name not in self._columns:
                    kw[idName] = getID(value)
                    del kw[name]
            sets = []
            triggers = []
            for name, value in kw.items():
                if name not in self._columns:
                    getattr(self, prefix('set', name))(value)
                    continue
                if value != self._autoValues[name]:
                    triggers.append((name, self._autoValues[name]))
                sets.append('%s = %s' % (self._columns[name], sqlQuote(value)))
            if sets:
                query(self._database,
                      """
                      UPDATE %s
                      SET %s
                      WHERE id = %s
                      """ % (self._table, ', \n'.join(sets), self._id))
            if self._autoInitDone:
                # Only column values belong in the row cache.
                for name, value in kw.items():
                    if name in self._columns:
                        self._autoValues[name] = value
            if self._trigger:
                for trigger, oldValue in triggers:
                    if trigger in self._trigger:
                        getattr(self, prefix('change', trigger))(oldValue)
        finally:
            self._writeLock.release()

    def _autoInit(self):
        """Load every column of this row into self._autoValues, once.

        Returns True after loading (the generated accessors rely on a true
        result) and None if already loaded.  Raises AssertionError if the
        row does not exist.
        """
        if self._autoInitDone: return
        items = list(self._columns.items())
        row = queryOne(self._database, """
        SELECT %s
        FROM %s
        WHERE id = %s
        """ % (', '.join([column for attrName, column in items]),
               self._table,
               self._id))
        if not row:
            # Raised explicitly rather than via `assert` so the check also
            # runs under python -O.
            raise AssertionError('The object %s by the ID %s does not exist'
                                 % (self.__class__.__name__, self._id))
        values = {}
        for (attrName, column), value in zip(items, row):
            values[attrName] = value
        self._autoValues = values
        self._autoInitDone = True
        return True

    def invalidateCache(self):
        """Force the next attribute access to reload the row."""
        self._autoInitDone = False

    def changeNew(self):
        """Hook called by new() after the row is created; no-op here."""
        pass

def sqlBuildAttr(klass):
    """Add an accessor method to `klass` for every column in _columns.

    Callers invoke it with no arguments (e.g. doc.title()).  It loads the
    row on first use via _autoInit and returns the cached value.
    """
    for attrName in klass._columns.keys():
        # Bind attrName as a default argument so each accessor keeps its own
        # name; a plain closure over the loop variable would see only the
        # last one.
        def accessor(self, _attrName=attrName):
            if not self._autoInitDone:
                self._autoInit()
            return self._autoValues[_attrName]
        setattr(klass, attrName, accessor)

def SQLBuild(klass, **extraMethods):
    """Finish a SQLObject subclass and return a caching factory for it.

    The optional special attributes are filled in with empty defaults.  A
    list/tuple `_columns` is normalized to an attribute->column dict, one
    accessor per column is added, and the class is wrapped in a
    ParamFactory.ParamFactory.  The factory gets a `new` method
    (NewSQLObject), one method per `_alternateID` entry
    (ByAttributeBuilder), and every keyword in `extraMethods` as an
    attribute.
    """
    if not hasattr(klass, '_defaults'):
        klass._defaults = {}
    # Note the leading underscore: the class attribute is `_extraNew`.
    if not hasattr(klass, '_extraNew'):
        klass._extraNew = []
    if not hasattr(klass, '_alternateID'):
        klass._alternateID = {}
    if not hasattr(klass, '_trigger'):
        klass._trigger = []
    if not hasattr(klass, '_processNewInput'):
        klass._processNewInput = None
    if not hasattr(klass, '_processSetInput'):
        klass._processSetInput = None
    if isinstance(klass._columns, (list, tuple)):
        columns = {}
        for name in klass._columns:
            columns[name] = _translateName(name)
        klass._columns = columns

    sqlBuildAttr(klass)
    newClass = NewSQLObject(klass)
    extraMethods['new'] = newClass
    # Helpers that need a reference to the factory once it exists.
    needFactory = []
    if klass._alternateID:
        for attribute, methodName in klass._alternateID.items():
            builder = ByAttributeBuilder(klass, attribute)
            extraMethods[methodName] = builder
            needFactory.append(builder)
    needFactory.append(newClass)
    factory = ParamFactory.ParamFactory(klass, **extraMethods)
    for helper in needFactory:
        helper._factory = factory
    return factory

class NewSQLObject:
    """Callable that INSERTs a new row and returns the cached object for it.

    SQLBuild installs an instance as the factory's `new` method and sets
    `_factory` to that factory afterwards.
    """

    def __init__(self, klass):
        self._klass = klass

    def __call__(self, **kw):
        """Insert a row built from `kw` and return its object.

        Column values may be given as `foo=obj` for a column `fooID`.
        Missing columns are filled from _defaults, and every column must
        end up with a value (AssertionError otherwise).  Names in
        _extraNew trigger new<Name>(value) after the insert.  Any other
        unknown name raises KeyError.
        """
        klass = self._klass
        if klass._processNewInput:
            process = klass._processNewInput
            # On a class this is an unbound method; call the underlying
            # function with the class itself as the first argument.
            process = getattr(process, 'im_func', process)
            kw = process(klass, kw)
        # Accept e.g. project=<Project> for a column named projectID.
        for columnName in klass._columns.keys():
            shortName = columnName[:-2]
            if columnName.endswith('ID') and shortName in kw:
                kw[columnName] = getID(kw[shortName])
                del kw[shortName]
        for name, default in klass._defaults.items():
            if name not in kw:
                if callable(default):
                    default = default()
                kw[name] = default
        for name in klass._columns.keys():
            if name not in kw:
                # Explicit raise (not `assert`) so this survives python -O.
                raise AssertionError(
                    'You must provide a value for the column %s' % name)
        # Only real columns go into the INSERT; _extraNew names that are not
        # columns are applied after the row exists.
        names = []
        values = []
        for name, value in kw.items():
            if name in klass._columns:
                names.append(klass._columns[name])
                values.append(sqlQuote(value))
            elif name not in klass._extraNew:
                raise KeyError(name)
        newID = queryInsert(
            klass._database,
            """
            INSERT INTO %s (%s)
            VALUES (%s)
            """ % (klass._table,
                   ', '.join(names),
                   ', '.join(values)))
        obj = self._factory(newID)

        for name in klass._extraNew:
            if name in kw:
                getattr(obj, prefix('new', name))(kw[name])

        obj.changeNew()
        return obj

class ByAttributeBuilder:
    """Callable that finds an object by an alternate unique attribute.

    SQLBuild installs one per `_alternateID` entry and sets `_factory`.
    Each value->ID mapping found is cached for the life of the builder.
    """

    def __init__(self, klass, attribute):
        self._klass = klass
        self._attribute = attribute
        self._cached = {}

    def __call__(self, id):
        """Return the object whose attribute equals `id`.

        Raises KeyError if no row matches.
        """
        if id not in self._cached:
            klass = self._klass
            # _columns maps attribute names to SQL column names; fall back to
            # the attribute name if it is not a declared column.
            column = klass._columns.get(self._attribute, self._attribute)
            # %%s survives this formatting as %s, which queryOne fills with
            # the SQL-quoted value of `id`.
            row = queryOne(klass._database, """
            SELECT id
            FROM %s
            WHERE %s = %%s
            """ % (klass._table, column), id)
            if not row:
                raise KeyError('No %s with %s = %s'
                               % (klass.__name__, self._attribute, repr(id)))
            self._cached[id] = row[0]
        return self._factory(self._cached[id])
        
# Types accepted as raw IDs by getID()/getObject().  `long` and `unicode`
# only exist on some interpreters, so the tuple is built defensively.
try:
    _rawIDTypes = (int, long, str, unicode)
except NameError:
    _rawIDTypes = (int, str)

def getID(obj):
    """Return the integer ID for `obj`.

    Integers and numeric strings are converted with int(); anything else is
    assumed to be an object with an id() method.
    """
    if isinstance(obj, _rawIDTypes):
        return int(obj)
    return obj.id()

def getObject(obj, klass):
    """Return `obj` as an object.

    Raw IDs (integers or numeric strings) become klass(int(obj)); anything
    else is assumed to already be an object and is returned unchanged.
    """
    if isinstance(obj, _rawIDTypes):
        return klass(int(obj))
    return obj

_translateRE = re.compile(r'_.')
_idRE = re.compile('id$', re.I)
def _translateName(name):
    name = _translateRE.sub(lambda x: x.group(0)[1].upper(), name)
    return _idRE.sub('ID', name)

_translateRE = re.compile(r'[A-Z]')
def _translateName(name):
    if name.endswith('ID'):
        tail = '_id'
        name = name[:-2]
    else:
        tail = ''
    name = _translateRE.sub(lambda x: '_%s' % x.group(0).lower(), name)
    return name + tail

def switchBoolean(kw, switchKey, answerKey, trueValue, falseValue):
    """Replace a boolean flag in `kw` with one of two concrete values.

    If `switchKey` is present, `answerKey` is set to `trueValue` or
    `falseValue` depending on the flag's truth, and then `switchKey` is
    removed.  `kw` is modified in place and also returned.
    """
    if not kw.has_key(switchKey):
        return kw
    if kw[switchKey]:
        chosen = trueValue
    else:
        chosen = falseValue
    kw[answerKey] = chosen
    del kw[switchKey]
    return kw

def prefix(p, s):
    """Return `p` followed by `s` with its first letter capitalized.

    Used to build hook method names, e.g. prefix('set', 'title') -> 'setTitle'.
    """
    # s[:1] rather than s[0] so that an empty name yields just the prefix.
    return p + s[:1].upper() + s[1:]
import MySQLdb, re, string
try:
    import mx.DateTime.ISO
    isoStr = mx.DateTime.ISO.strGMT
    from mx.DateTime import DateTimeType
except ImportError:
    import DateTime.ISO
    isoStr = DateTime.ISO.strGMT
    from DateTime import DateTimeType

oldIsoStr = isoStr
def isoStr(val):
    """Format `val` with the imported ISO GMT formatter and drop everything
    from the first '+' on (presumably a '+HHMM' offset suffix)."""
    return oldIsoStr(val).split('+', 1)[0]

# Database name -> [username, password] for every MySQL database this
# module connects to.
# NOTE(review): credentials are hard-coded in source; consider loading them
# from configuration kept out of version control.
databases = {
    'murphyarts_emily': ['murphyarts_emily', 'password'],
    'comment': ['comment', 'password'],
    }


# Compatibility shim for interpreters without the True/False builtins.
True, False = (1==1), (0==1)
# Set _debug to True to print every substituted query (see sqlSubViewer).
_debug = False
#_debug = True

# One connection per database, opened on localhost at import time and
# shared by every caller of getCursor().
# NOTE(review): if several threads use these connections, confirm that the
# MySQLdb version in use allows sharing a connection between threads.
_dbs = {}
for (dbName, (username, password)) in databases.items():
    _dbs[dbName] = MySQLdb.connect(host='localhost', db=dbName,
                                   user=username, passwd=password)

def getCursor(dbName, named=False):
    """Return a new cursor on the shared connection for `dbName`.

    With a true `named`, the cursor is created with MySQLdb.DictCursor.
    """
    connection = _dbs[dbName]
    if not named:
        return connection.cursor()
    return connection.cursor(MySQLdb.DictCursor)

def queryAll(dbName, s, *vars, **kw):
    """Run `s` with the quoted `vars` substituted and return every row.

    Keyword arguments go to getCursor (e.g. named=True).
    """
    return query(dbName, s, *vars, **kw).fetchall()

def query(dbName, s, *vars, **kw):
    """Run `s` with the quoted `vars` substituted and return the cursor.

    Keyword arguments go to getCursor (e.g. named=True).
    """
    cursor = getCursor(dbName, **kw)
    sql = sqlSub(s, vars)
    cursor.execute(sql)
    return cursor

def queryInsert(dbName, s, *vars, **kw):
    """Run an INSERT and return the insert id reported by the cursor."""
    cursor = query(dbName, s, *vars, **kw)
    return cursor.insert_id()

def queryFirst(dbName, *vars, **kw):
    """
    Return a list holding the first column of each result row.

    Handy for single-column queries such as ``SELECT id FROM whatever``:
    you get [1, 2, 3] instead of [(1,), (2,), (3,)].  The query string is
    the first element of `vars`.
    """
    return [row[0] for row in queryAll(dbName, *vars, **kw)]

def queryOne(dbName, s, *vars, **kw):
    """Run `s` with the quoted `vars` substituted and return the first row
    (the cursor's fetchone() result)."""
    return query(dbName, s, *vars, **kw).fetchone()

def sqlSub(s, vars):
    """Return `s` %-formatted with the SQL-quoted values in `vars`.

    With no vars, `s` is returned untouched, so a literal % in it is safe.
    """
    if not vars:
        return s
    quoted = tuple([sqlQuote(value) for value in vars])
    return s % quoted

def sqlSubViewer(s, vars):
    """Debugging stand-in for sqlSub.

    Prints the query and its quoted values, then delegates to the real
    substitution function, sqlSubDo.
    """
    quotedValues = map(sqlQuote, vars)
    print("query: %s (V: %s)" % (s, quotedValues))
    return sqlSubDo(s, vars)

# In debug mode, route every query through sqlSubViewer, which prints the
# query before delegating to the original sqlSub (kept as sqlSubDo).
if _debug:
    sqlSubDo = sqlSub
    sqlSub = sqlSubViewer

class SQLExpression:
    """A raw SQL fragment that sqlQuote() emits verbatim instead of quoting,
    e.g. SQLExpression('NOW()')."""

    def __init__(self, expr):
        self._expr = expr

    def sqlRepr(self):
        """Return the SQL text exactly as given."""
        return self._expr

    def __repr__(self):
        return '<SQL Expression: %r>' % (self._expr,)

    __str__ = __repr__

# Characters that must be backslash-escaped inside a MySQL string literal.
# The backslash itself must be escaped too; otherwise an input ending in
# "\" (or containing "\'") could terminate the literal early.
_sqlQuoteRE = re.compile(r"[\\'\x00]")
_sqlEscapes = {'\\': '\\\\', "'": "\\'", '\x00': '\\0'}
# `long` only exists on some interpreters.
try:
    _sqlIntegerTypes = (int, long)
except NameError:
    _sqlIntegerTypes = (int,)

def sqlQuote(value):
    """
    Quote a value for insertion as a SQL value.

    Strings are wrapped in ' with backslashes, single quotes and NUL bytes
    escaped.  Integers (including bools) and floats are emitted bare, and
    None becomes NULL.  Lists/tuples become (<quoted values>, ...).
    SQLExpression instances are emitted verbatim, and DateTime values
    become quoted ISO strings.  Raises ValueError for any other type.
    """
    if isinstance(value, str):
        escaped = _sqlQuoteRE.sub(lambda m: _sqlEscapes[m.group(0)], value)
        return "'%s'" % escaped
    elif value is None:
        return "NULL"
    elif isinstance(value, (tuple, list)):
        return "(" + ", ".join([sqlQuote(item) for item in value]) + ")"
    elif isinstance(value, _sqlIntegerTypes):
        return "%i" % value
    elif isinstance(value, float):
        # repr() keeps full precision; str() rounds on older Pythons.
        return repr(value)
    elif isinstance(value, SQLExpression):
        return value.sqlRepr()
    elif type(value) is DateTimeType:
        return "'%s'" % isoStr(value)
    else:
        raise ValueError("Unknown type to quote: %s for %s" %
                         (type(value), repr(value)))

Reply via email to