Hi
A few weeks ago I decided to spend a little time on a small hack that
would make it possible to use Python generator expressions (list-comprehension
syntax) while avoiding the penalty of executing that logic on the client side.
E.g., why shouldn't it be possible to write this kind of code:
cost = sum(e.salary * e.comission for e in company.employees)
instead of:
result = company.employees.find(Employee.company_id == Company.id)
cost = result.sum(Employee.salary * Employee.comission)
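The former version is far more readable, but unfortunately it gets a lot slower
as the database grows: every employee row has to be loaded into Python and
summed there, instead of letting the database compute a single aggregate.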
Attached is a small proof of concept I wrote which, with the help of a
decorator and a metaclass, makes it possible to mark methods that will
then be optimized.
I'm not suggesting that this should be included in Storm itself; it's
pretty hackish as it is right now. I'm just proving that it can be done
and making sure that it can be found by anyone with similar ideas.
Johan
import compiler
import inspect
import new
from compiler import pycodegen, misc
from compiler.ast import (AssName, Assign, CallFunc, GenExpr, Name, Node,
                          Return)

from storm.references import ReferenceSet
class SQLReplacer(object):
    """Rewrite a method so aggregates over a ReferenceSet are computed in SQL.

    The method's source is re-parsed with the Python 2 ``compiler`` package.
    Each statement of the form::

        x = agg(<expr> for <var> in self.<refset>)    # or: return agg(...)

    where ``agg`` is one of SUPPORTED_AGGREGATORS is rewritten into a Storm
    query whose ResultSet method ``agg`` is called with ``<expr>`` (with the
    loop variable replaced by the remote class).  get_func() compiles the
    modified tree back into a function.

    :param class_name: name of the class being created; the method source is
        wrapped in a dummy ``class <class_name>:`` so it parses at its
        original indentation.
    :param function: the plain function object marked with ``@sqlmagic``.
    :param reference_sets: mapping of attribute name -> ReferenceSet found
        in the class namespace.
    """

    # Aggregators forwarded to the ResultSet method of the same name.
    SUPPORTED_AGGREGATORS = ('sum', 'max', 'min')

    def __init__(self, class_name, function, reference_sets):
        self.class_name = class_name
        self.function = function
        self.reference_sets = reference_sets
        source = inspect.getsource(function)
        tree = compiler.parse('class %s:\n%s' % (class_name, source))
        # __import__ with a non-empty fromlist returns the leaf module rather
        # than the top-level package.
        module = __import__(function.__module__, '', '', ' ')
        misc.set_filename('<%s>' % module.__file__, tree)
        klass = tree.node.nodes[0]
        # The dummy class body holds one statement: the (decorated) Function
        # node of the method; its ``code`` is the method body.
        function_node = klass.code.nodes[0]
        self.parse(function_node.code)
        self.tree = tree

    def get_func(self):
        """Compile the rewritten tree and return it as a plain function.

        The result uses the original function's globals, name and default
        argument values.  NOTE(review): the code is recompiled inside a
        module-level dummy class, so names the method takes from an enclosing
        function's scope are looked up as globals instead.
        """
        gen = pycodegen.ModuleCodeGenerator(self.tree)
        code = gen.getCode()
        # Execute the dummy class statement in a copy of the method's globals
        # so default-argument expressions resolve; the marker decorator is
        # replaced by the identity.
        namespace = dict(self.function.func_globals)
        namespace['sqlmagic'] = lambda func: func
        exec(code, namespace)
        compiled = getattr(namespace[self.class_name], self.function.__name__)
        return new.function(compiled.func_code,
                            self.function.func_globals,
                            self.function.__name__,
                            self.function.func_defaults)

    def parse(self, node):
        """Walk *node* depth-first, recording parents and dispatching.

        Every child Node gets a ``parentNode`` attribute pointing at *node*.
        Children are processed before their parent; then, if this class has
        a method named after the node's class (e.g. ``Getattr``), that
        method is called with the node.
        """
        for child in node.getChildren():
            if isinstance(child, Node):
                child.parentNode = node
                self.parse(child)
        handler = getattr(self, node.__class__.__name__, None)
        if handler:
            handler(node)

    def _get_genexr(self, node):
        """Return the nearest GenExpr ancestor of *node*, or None."""
        parent = getattr(node, 'parentNode', None)
        while parent is not None:
            if isinstance(parent, GenExpr):
                return parent
            parent = getattr(parent, 'parentNode', None)
        return None

    def _generate_body(self, klass, reference, column, aggregator):
        """Return the statement nodes that compute ``__value`` via Storm.

        :param klass: name of the remote class (e.g. ``'Employee'``), bound
            locally as ``__<klass>``.
        :param reference: name of the ReferenceSet attribute on ``self``.
        :param column: name of the remote column compared with the local key
            (e.g. ``'company_id'``).
        :param aggregator: ResultSet method to call, e.g. ``'sum'``.

        The last statement is ``__value = __result.<aggregator>()``; the
        caller appends the aggregated expression to that call's arguments.
        NOTE(review): the generated code reads Storm internals
        (``_remote_key1``, ``_relation.local_key``).
        """
        template = (
            "__%(Class)s = self.__class__.%(refname)s._remote_key1[0].cls\n"
            "__reference = self.%(refname)s\n"
            "__key = getattr(self, __reference._relation.local_key[0].name)\n"
            "__result = __reference.find(__%(Class)s.%(column)s == __key)\n"
            "__value = __result.%(aggregator)s()")
        source = template % dict(Class=klass,
                                 refname=reference,
                                 column=column,
                                 aggregator=aggregator)
        return compiler.parse(source).node.nodes

    def Getattr(self, node):
        """Rewrite the aggregate that iterates over ``self.<refset>``.

        Called by parse() for every Getattr node.  The tree is left untouched
        unless all of the following hold:

        * *node* is ``self.<name>`` and ``<name>`` is a known ReferenceSet;
        * it is the iterable of a generator expression with exactly one
          ``for`` clause, no ``if`` filter and a simple loop variable;
        * that generator expression is the only argument of a call to one of
          SUPPORTED_AGGREGATORS;
        * the call's value is directly assigned or returned.

        The statement is then preceded by the statements built by
        _generate_body(), and its aggregate call is replaced by ``__value``.
        """
        # Only ``self.<refset>``: ``other.<refset>`` or ``f().<refset>`` must
        # not be rewritten in terms of ``self``.
        if not (isinstance(node.expr, Name) and node.expr.name == 'self'):
            return
        if node.attrname not in self.reference_sets:
            return
        genexpr = self._get_genexr(node)
        if not genexpr:
            return
        inner = genexpr.code
        # ``if`` filters or nested ``for`` clauses have no SQL translation
        # here and would otherwise be silently dropped, changing the result.
        if len(inner.quals) != 1:
            return
        qual = inner.quals[0]
        if qual.ifs or qual.iter is not node:
            return
        if not isinstance(qual.assign, AssName):
            return
        call = genexpr.parentNode
        if not isinstance(call, CallFunc):
            return
        # Any other argument (e.g. sum's start value) would be lost.
        if (len(call.args) != 1 or call.args[0] is not genexpr or
                call.star_args is not None or call.dstar_args is not None):
            return
        if not isinstance(call.node, Name):
            return
        aggregator = call.node.name
        if aggregator not in self.SUPPORTED_AGGREGATORS:
            print('Unsupported aggregator: %s' % aggregator)
            return
        construct = call.parentNode
        if not isinstance(construct, (Return, Assign)):
            print('Unsupported construct: %r' % (construct,))
            return

        reference = self.reference_sets[node.attrname]
        # NOTE(review): relies on Storm internals -- at class-creation time
        # _remote_key1 is presumably still the bare remote column (e.g.
        # Employee.company_id), unlike the runtime code which indexes [0].
        remote_key = reference._remote_key1
        remote_cls_name = remote_key.cls.__name__
        remote_column = remote_key.name

        # Turn ``e.salary * e.comission`` into
        # ``__Employee.salary * __Employee.comission``; only the loop
        # variable is renamed so other names keep their meaning.
        aggregated_expr = inner.expr
        self.visit(aggregated_expr, remote_cls_name, qual.assign.name)

        generated = self._generate_body(remote_cls_name, node.attrname,
                                        remote_column, aggregator)
        # Last generated statement: ``__value = __result.<agg>()``.
        generated[-1].expr.args.append(aggregated_expr)

        if isinstance(construct, Return):
            construct.value = Name('__value')
        else:
            construct.expr = Name('__value')
        # Splice the query statements in front of the rewritten statement.
        parent = construct.parentNode
        pos = parent.nodes.index(construct)
        parent.nodes[pos:pos] = generated

    def visit(self, node, refname, target=None):
        """Rename Name nodes under *node* (inclusive) to ``__<refname>``.

        :param target: when given, only Names equal to *target* (the
            generator's loop variable) are renamed; ``None`` renames every
            Name, as before.
        """
        if isinstance(node, Name) and (target is None or node.name == target):
            node.name = '__' + refname
        if isinstance(node, Node):
            for child in node.getChildren():
                self.visit(child, refname, target)
class MagicMeta(type):
    """Metaclass that rewrites ``@sqlmagic`` methods when a class is created.

    ReferenceSet attributes defined directly in the class body are collected
    and handed to SQLReplacer together with each attribute that carries the
    ``__magic__`` marker; the marked attribute is replaced by the function
    SQLReplacer produces.  ReferenceSets inherited from base classes are not
    looked at.
    """

    def __new__(mcs, name, bases, ns):
        reference_sets = dict((attr, value) for attr, value in ns.items()
                              if isinstance(value, ReferenceSet))
        # Iterate over a snapshot because entries are replaced below.
        for attr, value in list(ns.items()):
            if hasattr(value, '__magic__'):
                replacer = SQLReplacer(name, value, reference_sets)
                ns[attr] = replacer.get_func()
        return type.__new__(mcs, name, bases, ns)
class Magic(object):
    """Base class whose subclasses are created by MagicMeta.

    Inheriting from it is enough for ``@sqlmagic`` methods in a subclass to
    be rewritten when the subclass is defined.
    """
    # Python 2 metaclass hook.
    __metaclass__ = MagicMeta
def sqlmagic(func):
    """Flag *func* for rewriting by MagicMeta and hand it back unchanged.

    The flag is a ``__magic__`` attribute set to True; MagicMeta checks for
    it with ``hasattr`` when the owning class is built.
    """
    setattr(func, '__magic__', True)
    return func
def test(employee_count=5000):
    """Compare a plain aggregate with its @sqlmagic rewrite.

    Builds an in-memory SQLite database holding one company with
    *employee_count* employees. Then, for Company.calculate_to_pay()
    (evaluated in Python) and Company.calculate_to_pay_magic() (rewritten
    by SQLReplacer), it prints the result and the elapsed wall-clock time.
    Finally it prints whether the two results are equal.
    """
    import time

    from storm import database
    from storm.database import create_database
    from storm.properties import Int, Unicode, Float
    from storm.store import Store
    from storm.references import Reference, ReferenceSet

    db = create_database("sqlite:")
    store = Store(db)

    class Employee(object):
        __storm_table__ = "employee"
        salary = Float()
        comission = Float()
        company_id = Int()
        id = Int(primary=True)
        name = Unicode()

        def __init__(self, name, salary, comission):
            self.name = name
            self.salary = salary
            self.comission = comission

    class Company(Magic):
        __storm_table__ = "company"
        id = Int(primary=True)
        name = Unicode()
        employees = ReferenceSet(id, Employee.company_id)

        def __init__(self, name):
            self.name = name

        @sqlmagic
        def calculate_to_pay_magic(self):
            # This is compiled into:
            # __Employee = self.__class__.employees._remote_key1[0].cls
            # __reference = self.employees
            # __key = getattr(self, __reference._relation.local_key[0].name)
            # __result = __reference.find(__Employee.company_id == __key)
            # __value = __result.sum(__Employee.salary * __Employee.comission)
            # x = __value
            x = sum(e.salary * e.comission for e in self.employees)
            return x

        def calculate_to_pay(self):
            return sum(e.salary * e.comission for e in self.employees)

    Employee.company = Reference(Employee.company_id, Company.id)

    def timed(label, method):
        """Print *label*, the result of *method()* and the time it took."""
        print('')
        print('%s ----' % label)
        start = time.time()
        result = method()
        print('Result: %s' % (result,))
        print('%2.4f sec' % (time.time() - start))
        return result

    print('Creating example data')
    store.execute("CREATE TABLE company "
                  "(id INTEGER PRIMARY KEY, name VARCHAR)", noresult=True)
    store.execute("CREATE TABLE employee "
                  "(id INTEGER PRIMARY KEY, name VARCHAR, company_id INTEGER, "
                  "salary FLOAT, comission FLOAT)", noresult=True)
    sweets = store.add(Company(u"Sweets Inc."))
    store.flush()
    for i in range(employee_count):
        employee = store.add(Employee(u"Mike Mayer %d" % i, 200 + i, 10 + i))
        sweets.employees.add(employee)
    store.flush()

    # database.DEBUG = True  # uncomment to enable Storm's debug mode

    normal = timed('Normal query', sweets.calculate_to_pay)
    magic = timed('SQL magic query', sweets.calculate_to_pay_magic)
    print('')
    print('Results match: %s' % (normal == magic,))
if __name__ == '__main__':
    # Executed as a script: run the demo/benchmark defined in test().
    test()
--
storm mailing list
[email protected]
Modify settings or unsubscribe at:
https://lists.ubuntu.com/mailman/listinfo/storm