#!/usr/bin/python
# relational algebra in Python 2.3
# Allows you to construct simple relational algebra queries in Python
# and translates them on the fly to SQL. Then you can iterate over
# the query results, which are Python dictionaries. I've tested this
# in Python 2.3; it may also work in other versions of Python 2, but
# note that Python 2.6 and later reserve the word 'as', which this
# module uses as a method name.
# This is just the barest of bare bones here, but I implemented it in
# one evening. You can see the kind of stuff it supports by looking
# at the test() function below.
# Credit goes to Avi Bryant's vaporware "Roe" for Smalltalk, to
# SchemeQL, and to E. F. Codd for inspiration.
# A short list of the most egregiously missing pieces:
# - union, intersect, difference
# - more joins than just the simple inner join
# - a more convenient simple inner join
# - aggregate operations (and thus 'group by')
# - other computations in the list of output columns
# - aliases for output columns
# - relops other than '=' in the where list
# - more convenient syntax for specifying foo.column('baz'),
# e.g. foo['baz'] or foo.baz
# - support for query.column('foo.bar')
# - insertion and update of data
import types
class WideningProjection(Exception):
    """A projection named a column outside the query's current projection."""


class AliasingAmbiguous(Exception):
    """Aliasing was requested on a query that does not have exactly one table."""


class AmbiguousAttributeName(Exception):
    """An attribute name did not resolve to exactly one column."""


class CantProjectColumnsFromAnotherTable(Exception):
    """A projected column belongs to a table that is not part of the query."""
class query:
    """A relational-algebra query over one or more tables, rendered as SQL.

    Queries are built up with project(), as_() (also reachable under its
    original name `as`), select() and `*` (cartesian product).  Each of
    these returns a new query via clonebut() instead of modifying the
    receiver.

    tables  -- list of (tablename, alias) pairs for the FROM clause.
    dbconn  -- connection that __iter__ uses to run the query.
    where   -- list of (variable, relop, value) conditions, ANDed together.
    columns -- list of (alias, attrname) pairs to select; an empty list
               means "select *" from every table.
    """

    # Exact types that namefor() renders as SQL literals and that
    # project() treats as attribute names.  Python 3's types module has
    # no StringTypes/IntType/LongType, so fall back to the builtins there.
    _string_types = getattr(types, 'StringTypes', (str,))
    _integer_types = (getattr(types, 'IntType', int),
                      getattr(types, 'LongType', int))

    def __init__(self, tables, dbconn=None, where=None, columns=None):
        # None defaults (rather than []) so that separate queries never
        # share one default list object.
        if where is None:
            where = []
        if columns is None:
            columns = []
        (self.tables, self.where, self.columns, self.dbconn) = (
            tables, where, columns, dbconn)

    def clonebut(self, **overrides):
        """Return a copy of this query with some constructor arguments replaced.

        `overrides` may name any of tables, where, columns and dbconn; the
        other attributes are passed through (lists are shared, not copied).
        """
        defaults = {
            'tables': self.tables,
            'where': self.where,
            'columns': self.columns,
            'dbconn': self.dbconn,
        }
        defaults.update(overrides)
        return self.__class__(**defaults)

    def sql(self):
        """Return the SQL select statement for this query."""
        return 'select %(columns)s from %(tables)s%(where)s' % {
            'columns': ', '.join(self.columnnames()),
            'tables': self.tablespecs(),
            'where': self.whereclause()
        }

    def default_table(self):
        """Return the alias of the only table, or None unless there is
        exactly one table.  Columns of this table are written unqualified.
        """
        if len(self.tables) != 1:
            return None
        return self.tables[0][1]

    def _column_list(self):
        """Return the (alias, attrname) pairs this query selects.

        A query without explicit columns selects (alias, '*') per table.
        """
        return self.columns or [
            (alias, '*') for tablename, alias in self.tables]

    def columnnames(self):
        """Return the SQL text of each selected column."""
        return [self.namefor(column) for column in self._column_list()]

    def tablespecs(self):
        """Return the FROM-clause text, e.g. "foo, foo as child"."""
        specs = []
        for tablename, alias in self.tables:
            if alias != tablename:
                specs.append('%s as %s' % (tablename, alias))
            else:
                specs.append(tablename)
        return ', '.join(specs)

    def whereclause(self):
        """Return " where ..." with all conditions ANDed, or '' if none."""
        if not self.where:
            return ''
        terms = []
        for variable, relop, value in self.where:
            terms.append('%s %s %s' % (self.namefor(variable), relop,
                                       self.namefor(value)))
        return ' where ' + ' and '.join(terms)

    def project(self, *columns):
        """Return a query selecting only the given columns.

        Each item is an attribute-name string (resolved with column()) or
        a column object with tablename/attrname attributes.

        Raises CantProjectColumnsFromAnotherTable for a column whose table
        is not in this query, and WideningProjection for a column outside
        the current projection.
        """
        ncolumns = []
        available_tables = [alias for tablename, alias in self.tables]
        for item in columns:
            if type(item) in self._string_types:
                item = self.column(item)
            if item.tablename not in available_tables:
                raise CantProjectColumnsFromAnotherTable(
                    item.tablename,
                    item.attrname,
                    available_tables
                )
            ncolumns.append((item.tablename, item.attrname))
        if self.columns:
            for col in ncolumns:
                # An (alias, '*') entry -- which __mul__ adds for an
                # unprojected operand -- still offers every column of alias.
                if (col not in self.columns
                        and (col[0], '*') not in self.columns):
                    raise WideningProjection(col, self.columns)
        return self.clonebut(columns=ncolumns)

    def as_(self, alias):
        """Return this single-table query with its table renamed to `alias`.

        Projected columns and where-conditions that referred to the old
        alias are rewritten to use the new one, so the generated SQL stays
        valid.  Raises AliasingAmbiguous unless there is exactly one table.
        """
        if len(self.tables) != 1:
            raise AliasingAmbiguous(alias, self.tables)
        # Keep the real table name; only the alias changes.  (Using the
        # old alias here would turn foo-as-bar-as-baz into "bar as baz".)
        tablename, oldalias = self.tables[0]

        def rename(thing):
            # Retarget column references that point at the old alias;
            # literal values (strings, numbers) pass through unchanged.
            if type(thing) is tuple:
                if len(thing) == 2 and thing[0] == oldalias:
                    return (alias, thing[1])
                return thing
            if getattr(thing, 'tablename', None) == oldalias:
                return thing.__class__(thing.query, (alias, thing.attrname))
            return thing

        return self.clonebut(
            tables=[(tablename, alias)],
            columns=[rename(column) for column in self.columns],
            where=[(rename(variable), relop, rename(value))
                   for variable, relop, value in self.where])

    def __mul__(self, other):  # cartesian product
        """Return the cartesian product of this query and `other`.

        Tables, conditions and columns are concatenated.  If only one side
        has an explicit column list, the other side contributes alias.*
        so that its columns are not silently dropped from the result.
        """
        # XXX shouldn't we be worrying about renaming
        # conflicts here?  (e.g. foo * foo renders "from foo, foo")
        if self.columns or other.columns:
            columns = self._column_list() + other._column_list()
        else:
            columns = []
        return self.clonebut(
            tables=(self.tables + other.tables),
            where=(self.where + other.where),
            columns=columns
        )

    def select(self, *conditions, **sugarclauses):
        """Return a query restricted by additional conditions.

        Positional conditions are (variable, relop, value) triples, such as
        those built by comparing a column with ==.  Keyword arguments are
        shorthand for attribute = value.
        """
        # Keyword-argument order is not preserved before Python 3.6, so
        # sort the names to keep the generated SQL stable.
        names = list(sugarclauses.keys())
        names.sort()
        eqclauses = [(self.column(name), '=', sugarclauses[name])
                     for name in names]
        return self.clonebut(where=self.where + eqclauses + list(conditions))

    def column(self, attrname):
        """Return the column object for `attrname`.

        On a "select *" query the name is resolved by starattr(); otherwise
        it must match exactly one projected column, else (including when it
        matches none) AmbiguousAttributeName is raised.
        """
        if not self.columns:
            return self.starattr(attrname)
        candidates = [(table, col) for (table, col) in self.columns
                      if col == attrname]
        if len(candidates) == 1:
            return tablecolumn(self, candidates[0])
        raise AmbiguousAttributeName(attrname, candidates)

    def starattr(self, attrname):
        """Resolve `attrname` for a "select *" query with a single table.

        Raises AmbiguousAttributeName when the query has several tables.
        """
        default = self.default_table()
        if not default:
            # Report every alias the attribute could belong to.
            raise AmbiguousAttributeName(
                attrname,
                [(alias, attrname) for tablename, alias in self.tables])
        return tablecolumn(self, (default, attrname))

    def namefor(self, thing):
        """Render `thing` as SQL text.

        Strings become quoted literals and ints/longs are written as-is.
        (alias, attrname) tuples and column objects become column references:
        bare when alias is this query's default table, "alias.attrname"
        otherwise.  Anything else raises TypeError.
        """
        kind = type(thing)
        if kind in self._string_types:
            # Values are inlined into the SQL text (no bind parameters), so
            # this escaping is the only injection protection.  Doubling
            # backslashes suits MySQL (and PostgreSQL without
            # standard_conforming_strings) but breaks standard databases.
            return "'%s'" % thing.replace('\\', '\\\\').replace("'", "''")
        if kind in self._integer_types:
            return str(thing)
        if kind is tuple:
            tablename, attrname = thing
        elif hasattr(thing, 'tablename'):
            tablename, attrname = thing.tablename, thing.attrname
        else:
            raise TypeError(thing)
        if tablename == self.default_table():
            return attrname
        return '%s.%s' % (tablename, attrname)

    def __iter__(self):
        """Run the query on self.dbconn and iterate over the result rows."""
        return query_results(self.sql(), self.dbconn)


# `as` is a reserved word from Python 2.6 on, so the method is defined as
# as_ and also registered under its original name for existing callers
# (query.as(...) on Python <= 2.5, or getattr(q, 'as') anywhere).
setattr(query, 'as', query.__dict__['as_'])
class tablecolumn:
    """A reference to column `attrname` of the table aliased `tablename`.

    Comparing a column with == does not test equality: it builds the
    (column, '=', other) condition triple that query.select() accepts.
    """

    def __init__(self, query, column):
        # `column` is a (tablename, attrname) pair.  It used to be unpacked
        # in the parameter list, which Python 3 no longer allows.
        tablename, attrname = column
        (self.query, self.tablename, self.attrname) = (
            query, tablename, attrname)

    def __eq__(self, other):
        return self, '=', other

    def name(self):
        """Return this column as SQL text, as rendered by its query."""
        return self.query.namefor(self)
class query_results:
    """Iterator over the rows returned by running `sql` on `dbconn`.

    NOTE(review): dbconn is presumably a _mysql connection, as in the usage
    notes at the end of this file; it must provide query() and use_result().
    """

    def __init__(self, sql, dbconn):
        dbconn.query(sql)
        self.results = dbconn.use_result()

    def __iter__(self):
        return self

    def next(self):
        """Return the next row; raise StopIteration once none are left."""
        rows = self.results.fetch_row(how=1)  # how=1 returns dicts
        if not rows:
            raise StopIteration
        return rows[0]

    # Python 3 calls the iterator method __next__; provide both names.
    __next__ = next
def table(tablename, dbconn=None):
    """Return a "select *" query over the single table `tablename`.

    The table is aliased to its own name, so it appears unaliased in the
    FROM clause.  `dbconn` is stored on the query for later iteration.
    """
    spec = (tablename, tablename)
    return query(dbconn=dbconn, tables=[spec])
def ok(a, b):
    """Assert that a == b, reporting the pair (a, b) on failure.

    Written out the way the assert statement is defined, so it is likewise
    skipped when Python runs with -O.
    """
    if __debug__:
        if not a == b:
            raise AssertionError((a, b))
def test():
    """Self-test: compare the SQL generated for sample queries using ok()."""
    # `as` has been a reserved word since Python 2.6, so the aliasing
    # method is looked up with getattr instead of being written foo.as(...).
    def aliased(q, name):
        return getattr(q, 'as')(name)

    foo = table("foo")
    ok(foo.sql(), 'select * from foo')
    ok(foo.project('a', 'b').sql(), 'select a, b from foo')
    ok(aliased(foo, 'bar').sql(), 'select * from foo as bar')
    ok((foo * aliased(foo, 'baz')).sql(),
       'select foo.*, baz.* from foo, foo as baz')
    ok(foo.select(a=3).sql(), 'select * from foo where a = 3')
    ok(foo.select(a=5, b='asdf').sql(),
       "select * from foo where a = 5 and b = 'asdf'")
    ok(foo.select(b="Can't").sql(), "select * from foo where b = 'Can''t'")
    bar = table('bar')
    joinq = (foo * bar).select(
        bar.column('id') == foo.column('barid')).project(
        foo.column('a'), foo.column('b'), bar.column('d'))
    ok(joinq.sql(),
       "select foo.a, foo.b, bar.d from foo, bar where bar.id = foo.barid")
    child = aliased(foo, 'child')
    ok((foo * child).select(foo.column('id') == child.column('parentid'))
       .sql(),
       "select foo.*, child.* from foo, foo as child "
       "where foo.id = child.parentid")


# The self-test runs whenever this module is loaded.
test()
# I tested the MySQL connectivity as follows:
# import _mysql, relalg
# db = _mysql.connect(db='kragen')
# q = relalg.table('foo', db)
# list(q)
# list(q.select(a=3))
# list(q.select(a=4))
# Before this, I'd had to install the python-mysql Debian package to
# get the _mysql module, and I'd had to create the database, grant
# myself access, create a table, and put stuff in the table. For
# future reference, MySQL's 'varchar' needs an explicit length (e.g. varchar(255)) while 'text' doesn't,
# and a default MySQL installation allows "mysql -u root" to create
# databases and "grant all on newdatabasename.* to ''@'localhost'".
# Surprisingly, the '' around the @ really *are* important.