#!/usr/bin/python
# relational algebra in Python 2.3

# Lets you construct simple relational algebra queries in Python and
# translates them on the fly to SQL.  You can then iterate over the
# query results, which are Python dictionaries.  I've only tried this
# in Python 2.3; it may also work in other versions of Python 2.

# This is just the barest of bare bones, but I implemented it in one
# evening.  You can see what it supports by looking at the test()
# function below.

# Credit goes to Avi Bryant's vaporware "Roe" for Smalltalk, to
# SchemeQL, and to E. F. Codd for inspiration.

# A short list of the most egregiously missing pieces:
# - union, intersect, difference
# - more joins than just the simple inner join
# - a more convenient simple inner join
# - aggregate operations (and thus 'group by')
# - other computations in the list of output columns
# - aliases for output columns
# - relops other than '=' in the where list
# - more convenient syntax for specifying foo.column('baz'),
#   e.g. foo['baz'] or foo.baz
# - support for query.column('foo.bar')
# - insertion and update of data

import types

class WideningProjection(Exception):
        """project() asked for a column that an earlier projection dropped."""

class AliasingAmbiguous(Exception):
        """as() could not tell which table to alias (not exactly one table)."""

class AmbiguousAttributeName(Exception):
        """A column name matched zero or several columns, or a "select *"
        query has no single table to resolve it against."""

class CantProjectColumnsFromAnotherTable(Exception):
        """project() got a column whose table alias is not in the query."""

class query:
        """A relational-algebra query that translates itself to SQL.

        A query is never modified in place: project(), as_(), select() and
        the * operator each return a new query built by clonebut().

        Attributes:
          tables  -- list of (tablename, alias) pairs for the FROM clause
          where   -- list of (variable, relop, value) triples, ANDed
                     together; each operand is anything namefor() accepts
          columns -- list of (alias, attrname) pairs; empty means "*"
          dbconn  -- connection passed to query_results by __iter__()

        The code sticks to syntax Python 2.3 accepts while avoiding what
        Python 2.6+/3 reject (a method named 'as', types.StringTypes).
        """

        # Python 2 has two string types and two integer types; Python 3
        # has one of each and no 'unicode' or 'long' builtins.
        try:
                _string_types = (str, unicode)
        except NameError:
                _string_types = (str,)
        try:
                _integer_types = (int, long)
        except NameError:
                _integer_types = (int,)

        def __init__(self, tables, dbconn=None, where=None, columns=None):
                # None instead of [] defaults, so separately built queries
                # never share one default list object.
                if where is None: where = []
                if columns is None: columns = []
                (self.tables, self.where, self.columns, self.dbconn) = (
                        tables, where, columns, dbconn)

        def clonebut(self, **overrides):
                """Return a new query like this one except for *overrides*.

                Keywords may be any of tables, where, columns, dbconn.
                """
                defaults = {
                        'tables': self.tables,
                        'where': self.where,
                        'columns': self.columns,
                        'dbconn': self.dbconn,
                }
                defaults.update(overrides)
                return self.__class__(**defaults)

        def sql(self):
                """Return the SQL select statement for this query."""
                return 'select %(columns)s from %(tables)s%(where)s' % {
                        'columns': ', '.join(self.columnnames()),
                        'tables': self.tablespecs(),
                        'where': self.whereclause()
                }

        def default_table(self):
                """Return the alias of the query's only table, else None.

                namefor() renders columns of this table unqualified.
                """
                if len(self.tables) != 1: return None
                else: return self.tables[0][1]

        def _output_columns(self):
                """Return self.columns, or one (alias, '*') pair per table
                when the query selects everything."""
                return self.columns or [
                        (alias, '*') for tablename, alias in self.tables]

        def columnnames(self):
                """Return the SQL text of each output column."""
                return [self.namefor(column)
                        for column in self._output_columns()]

        def tablespecs(self):
                """Return the FROM list, e.g. "foo, foo as baz"."""
                rv = []
                for tablename, alias in self.tables:
                        if alias != tablename:
                                rv.append('%s as %s' % (tablename, alias))
                        else: rv.append(tablename)
                return ', '.join(rv)

        def whereclause(self):
                """Return " where ..." with all conditions ANDed, or ''."""
                if not self.where: return ''
                rv = []
                for variable, relop, value in self.where:
                        rv.append('%s %s %s' % (
                                self.namefor(variable), relop,
                                self.namefor(value)))
                return ' where ' + ' and '.join(rv)

        def project(self, *columns):
                """Return a query keeping only *columns*.

                Each item is a column name (resolved with column()) or an
                object with tablename/attrname attributes.  Raises
                CantProjectColumnsFromAnotherTable for a column of a table
                not in this query, and WideningProjection for a column an
                earlier projection already dropped.
                """
                ncolumns = []
                available_tables = [alias for tablename, alias in self.tables]
                for item in columns:
                        if isinstance(item, self._string_types):
                                item = self.column(item)
                        if item.tablename not in available_tables:
                                raise CantProjectColumnsFromAnotherTable(
                                        item.tablename,
                                        item.attrname,
                                        available_tables
                                )
                        ncolumns.append((item.tablename, item.attrname))
                if self.columns:
                        for col in ncolumns:
                                if col not in self.columns:
                                        raise WideningProjection(col, self.columns)
                return self.clonebut(columns=ncolumns)

        def _retarget(self, thing, old, new):
                """Return *thing* re-pointed from table alias *old* to *new*.

                A column reference (a (tablename, attrname) pair or an object
                with tablename/attrname) naming *old* comes back as the pair
                (new, attrname); anything else is returned unchanged.
                """
                if isinstance(thing, tuple):
                        tablename, attrname = thing
                elif hasattr(thing, 'tablename'):
                        tablename, attrname = thing.tablename, thing.attrname
                else:
                        return thing
                if tablename != old: return thing
                return (new, attrname)

        def as_(self, alias):
                """Return this single-table query with its table aliased.

                Column and where references to the old alias are rewritten
                to *alias*, so the SQL stays valid: a query on foo with the
                condition a = 3, aliased to bar, renders
                "select * from foo as bar where a = 3".  Also available under
                its original name 'as' (see the setattr after this class).

                Raises AliasingAmbiguous unless the query has one table.
                """
                oldalias = self.default_table()
                if not oldalias: raise AliasingAmbiguous(alias, self.tables)
                # Alias the real table name, not its current alias, so that
                # aliasing twice still renders "foo as <alias>".
                tablename = self.tables[0][0]
                columns = [self._retarget(column, oldalias, alias)
                        for column in self.columns]
                where = [(self._retarget(variable, oldalias, alias), relop,
                          self._retarget(value, oldalias, alias))
                        for variable, relop, value in self.where]
                return self.clonebut(tables=[(tablename, alias)],
                        columns=columns, where=where)

        def __mul__(self, other):
                """Return the cartesian product of this query and *other*.

                Tables, where conditions and columns are concatenated.  When
                only one side is projected, the other side's "*" is kept as
                alias.* per table instead of being silently dropped.

                NOTE: alias clashes are not checked; alias one side first,
                e.g. foo * foo.as_('baz').
                """
                if self.columns or other.columns:
                        columns = (self._output_columns()
                                + other._output_columns())
                else:
                        # both sides select everything; so does the product
                        columns = []
                return self.clonebut(
                        tables=(self.tables + other.tables),
                        where=(self.where + other.where),
                        columns=columns
                )

        def select(self, *conditions, **sugarclauses):
                """Return this query restricted by extra where conditions.

                *conditions* are ready-made (variable, relop, value) triples,
                e.g. as built by tablecolumn ==.  Each keyword name=value
                adds "name = value", resolving name with column().  Keyword
                clauses come first, sorted by name so the generated SQL does
                not depend on dict ordering.
                """
                names = list(sugarclauses.keys())
                names.sort()
                eqclauses = [(self.column(name), '=', sugarclauses[name])
                        for name in names]
                return self.clonebut(
                        where=self.where + eqclauses + list(conditions))

        def column(self, attrname):
                """Return a tablecolumn for output column *attrname*.

                A "select *" query defers to starattr().  Otherwise exactly
                one projected column must be named *attrname*; zero or
                several matches raise AmbiguousAttributeName.
                """
                if not self.columns: return self.starattr(attrname)
                candidates = [(table, col) for (table, col) in self.columns
                        if col == attrname]
                if len(candidates) == 1:
                        return tablecolumn(self, candidates[0])
                else:
                        raise AmbiguousAttributeName(attrname, candidates)

        def starattr(self, attrname):
                """Return a tablecolumn for *attrname* in a "select *" query.

                Needs a single table; otherwise raises AmbiguousAttributeName
                with one (alias, attrname) candidate per table.
                """
                default = self.default_table()
                if not default:
                        raise AmbiguousAttributeName(attrname,
                                [(alias, attrname)
                                 for tablename, alias in self.tables])
                return tablecolumn(self, (default, attrname))

        def namefor(self, thing):
                """Return the SQL text for *thing*.

                Strings become quoted literals; integers (not bools) become
                numeric literals.  A (tablename, attrname) pair, or an object
                with tablename/attrname attributes, is a column reference:
                bare attrname for the default table, else
                "tablename.attrname".  Anything else raises TypeError.

                NOTE: table and column names are inserted verbatim; only
                string literals are escaped.
                """
                if isinstance(thing, self._string_types):
                        # quoting that works for Postgres and MySQL
                        # but breaks standard databases
                        return "'%s'" % thing.replace('\\', '\\\\').replace("'", "''")
                if (isinstance(thing, self._integer_types)
                                and not isinstance(thing, bool)):
                        return str(thing)
                if isinstance(thing, tuple):
                        tablename, attrname = thing
                elif hasattr(thing, 'tablename'):
                        tablename, attrname = thing.tablename, thing.attrname
                else:
                        raise TypeError(thing)
                if tablename == self.default_table(): return attrname
                return '%s.%s' % (tablename, attrname)

        def __iter__(self):
                """Run the query on self.dbconn and iterate over its rows."""
                return query_results(self.sql(), self.dbconn)

# 'as' is a reserved word from Python 2.6 on, so the method can no longer be
# defined under that name.  Publish as_() under its original name as well,
# so q.as('x') (Python 2.3-2.5) and getattr(q, 'as') keep working.
setattr(query, 'as', query.__dict__['as_'])

class tablecolumn:
        """A reference to column *attrname* of table alias *tablename*.

        Built by query.column()/query.starattr().  Comparing one with ==
        does not test equality: it returns a (self, '=', other) triple
        for query.select(), e.g. bar.column('id') == foo.column('barid').
        """
        def __init__(self, query, name):
                # name is a (tablename, attrname) pair.  It is unpacked here
                # rather than in the signature because tuple parameters were
                # removed in Python 3 (PEP 3113).
                tablename, attrname = name
                (self.query, self.tablename, self.attrname) = (
                        query, tablename, attrname)
        def __eq__(self, other):
                return self, '=', other
        def name(self):
                """Return this column's SQL text as rendered by its query."""
                return self.query.namefor(self)

class query_results:
        """Iterator over the rows produced by running *sql* on *dbconn*.

        The statement is executed immediately.  dbconn must offer query(sql)
        and use_result(); the result must offer fetch_row(how=1), returning
        an empty sequence once exhausted, as the _mysql connection used in
        the notes at the end of this file does.
        """
        def __init__(self, sql, dbconn):
                dbconn.query(sql)
                self.results = dbconn.use_result()
        def __iter__(self): return self
        def next(self):
                rows = self.results.fetch_row(how=1)  # how=1 returns dicts
                if not rows: raise StopIteration
                return rows[0]
        # Python 3 calls the iterator method __next__ instead of next.
        __next__ = next

def table(tablename, dbconn=None):
        """Return a "select *" query over the single table *tablename*.

        The table serves as its own alias.  *dbconn* is only used when the
        resulting query is iterated.
        """
        spec = (tablename, tablename)
        return query(tables=[spec], dbconn=dbconn)

def ok(a, b):
        """Fail loudly unless a == b.

        Assertion helper for test(): on a mismatch it raises AssertionError
        whose single argument is the (a, b) pair, so both values show up in
        the traceback.  Like any assert, the check is skipped under -O.
        """
        assert a == b, (a, b)

def test():
        """Self-test: check the SQL generated for a handful of queries.

        Raises AssertionError (via ok()) at the first mismatch.
        """
        # getattr(q, 'as') instead of q.as(...): 'as' is a reserved word
        # from Python 2.6 on, so it cannot be written with dot syntax.
        foo = table("foo")
        ok(foo.sql(), 'select * from foo')
        ok(foo.project('a', 'b').sql(), 'select a, b from foo')
        ok(getattr(foo, 'as')('bar').sql(), 'select * from foo as bar')
        ok((foo * getattr(foo, 'as')('baz')).sql(),
                'select foo.*, baz.* from foo, foo as baz')
        ok(foo.select(a=3).sql(), 'select * from foo where a = 3')
        ok(foo.select(a=5, b='asdf').sql(),
                "select * from foo where a = 5 and b = 'asdf'")
        ok(foo.select(b="Can't").sql(), "select * from foo where b = 'Can''t'")
        bar = table('bar')
        joinq = (foo * bar).select(
                bar.column('id') == foo.column('barid')).project(
                foo.column('a'), foo.column('b'), bar.column('d'))
        ok(joinq.sql(),
                "select foo.a, foo.b, bar.d from foo, bar where bar.id = foo.barid")
        child = getattr(foo, 'as')('child')
        ok((foo * child).select(foo.column('id') == child.column('parentid'))
                .sql(), "select foo.*, child.* from foo, foo as child "
                "where foo.id = child.parentid")

# Run the self-test whenever this module is loaded (including on import);
# a wrong SQL translation makes ok() raise AssertionError right away.
test()

# I tested the MySQL connectivity as follows:
# import _mysql, relalg
# db = _mysql.connect(db='kragen')
# q = relalg.table('foo', db)
# list(q)
# list(q.select(a=3))
# list(q.select(a=4))

# Before this, I'd had to install the python-mysql Debian package to
# get the _mysql module, and I'd had to create the database, grant
# myself access, create a table, and put stuff in the table.  For
# future reference, MySQL has no bare 'varchar' type (it needs a
# length, as in varchar(255); 'text' needs none), and a default MySQL
# installation allows "mysql -u root" to create databases and
# "grant all on newdatabasename.* to ''@'localhost'".
# Surprisingly, the '' around the @ really *are* important.

# [Mailing-list archive footer, kept as a comment so the file parses:
#  "Reply via email to"]