Hi

A few weeks ago I decided to spend a little time on a small hack which
would make it possible to use Python generator expressions and avoid
the penalty imposed when executing logic on the client side.

E.g., why shouldn't it be possible to write this kind of code:

  cost = sum(e.salary * e.comission for e in company.employees)

instead of:

  result = company.employees.find(Employee.company_id == Company.id)
  cost = result.sum(Employee.salary * Employee.comission)

The former version is far more readable but unfortunately a lot slower
as the database grows, because every row is loaded and summed in Python
instead of being aggregated by the database.

Attached is a small proof of concept I wrote which, with the help of a
decorator and a metaclass, makes it possible to mark methods which
will then be optimized.

I'm not suggesting that this should be included in Storm itself, it's
pretty hackish as it is right now. I'm just proving that it can be done
and making sure that this can be found by anyone having similar ideas.

Johan

import compiler
from compiler import pycodegen, misc
from compiler.ast import CallFunc, Node, Name, Assign, Return, GenExpr
import inspect
import new

from storm.references import ReferenceSet

class SQLReplacer(object):
    def __init__(self, class_name, function, reference_sets):
        self.class_name = class_name
        self.function = function
        self.reference_sets = reference_sets
        source = inspect.getsource(function)
        tree = compiler.parse('class %s:\n%s' % (class_name, source))

        m = __import__(function.__module__, '', '', ' ')
        misc.set_filename('<%s>'% m.__file__, tree)

        klass = tree.node.nodes[0]
        decorator = klass.code.nodes[0]

        self.parse(decorator.code)
        self.tree = tree

    def get_func(self):
        gen = pycodegen.ModuleCodeGenerator(self.tree)
        code = gen.getCode()
        locs = dict(sqlmagic=lambda x: x)
        exec code in locs
        func = getattr(locs[self.class_name], self.function.__name__)
        return new.function(func.func_code,
                            self.function.func_globals)

    def parse(self, node):
        for child in node.getChildren():
            if isinstance(child, Node):
                child.parentNode = node
                self.parse(child)

        node_name = node.__class__.__name__
        func = getattr(self, node_name, None)
        if func:
            func(node)

    def _get_genexr(self, node):
        while True:
            if not hasattr(node, 'parentNode'):
                return
            if isinstance(node.parentNode, GenExpr):
                return node.parentNode
            node = node.parentNode

    def _generate_body(self, klass, reference, column, aggregator):
        result_string = (
            "__%(Class)s = self.__class__.%(refname)s._remote_key1[0].cls\n"
            "__reference = self.%(refname)s\n"
            "__key = getattr(self, __reference._relation.local_key[0].name)\n"
            "__result = __reference.find(__%(Class)s.%(column)s == __key)\n"
            "__value = __result.%(aggregator)s()") % dict(Class=klass,
                                                          refname=reference,
                                                          column=column,
                                                          aggregator=aggregator)
        module = compiler.parse(result_string)
        stmt = module.node
        return stmt.nodes

    def Getattr(self, node):
        if (isinstance(node.expr, Name) and
            node.expr.name != 'self' or
            node.attrname not in self.reference_sets):
            return

        genexpr = self._get_genexr(node)
        if not genexpr:
            return
        aggregator_node = genexpr.parentNode
        if not isinstance(aggregator_node, CallFunc):
            return

        aggregator = aggregator_node.node.name
        if aggregator not in ['sum', 'max', 'min']:
            print 'Unsupported aggregator: %s' % aggregator
            return
        aggregator_args = genexpr.code.expr

        construct = genexpr.parentNode.parentNode
        if isinstance(construct, Return):
            construct.value = Name('__value')
        elif isinstance(construct, Assign):
            construct.expr = Name('__value')
        else:
            print 'Unsupported construct: %r' % (construct,)
            return

        reference = self.reference_sets[node.attrname]
        refcls = reference._remote_key1.cls.__name__
        refname = reference._remote_key1.name
        self.visit(genexpr.code.expr, refcls)

        generated = self._generate_body(refcls, node.attrname, refname,
                                        aggregator)
        generated_call = generated[-1].expr
        generated_call.args.append(aggregator_args)

        parent = construct.parentNode
        pos = parent.nodes.index(construct)
        parent.nodes = parent.nodes[:pos] + generated + parent.nodes[pos:]

    def visit(self, node, refname):
        if isinstance(node, Name):
            node.name = '__' + refname

        if isinstance(node, Node):
            for child in node.getChildren():
                self.visit(child, refname)

class MagicMeta(type):
    """Metaclass that rewrites @sqlmagic methods when a class is created.

    The ReferenceSets defined directly in the class body are collected.
    Each method carrying a ``__magic__`` attribute is then replaced by the
    function that SQLReplacer builds from it.
    """

    def __new__(mcs, name, bases, ns):
        reference_sets = dict((attr, value) for attr, value in ns.items()
                              if isinstance(value, ReferenceSet))

        marked = [(attr, value) for attr, value in ns.items()
                  if hasattr(value, '__magic__')]
        for attr, method in marked:
            replacer = SQLReplacer(name, method, reference_sets)
            ns[attr] = replacer.get_func()

        return type.__new__(mcs, name, bases, ns)


class Magic(object):
    """Base class for models whose @sqlmagic methods should be rewritten.

    Its metaclass is MagicMeta (Python 2 ``__metaclass__`` hook), so every
    subclass body is processed by MagicMeta.__new__ at class creation.
    """
    __metaclass__ = MagicMeta

def sqlmagic(func):
    """Mark *func* for rewriting by MagicMeta and return it unchanged.

    The marker is a ``__magic__ = True`` attribute on the function object.
    """
    setattr(func, '__magic__', True)
    return func


def test():
    from storm import database
    from storm.database import create_database
    from storm.properties import Int, Unicode, Float
    from storm.store import Store
    from storm.references import Reference, ReferenceSet

    db = create_database("sqlite:")
    store = Store(db)

    class Employee(object):
        __storm_table__ = "employee"
        salary = Float()
        comission = Float()
        company_id = Int()
        id = Int(primary=True)
        name = Unicode()

        def __init__(self, name, salary, comission):
            self.name = name
            self.salary = salary
            self.comission = comission

    class Company(Magic):
        __storm_table__ = "company"
        id = Int(primary=True)
        name = Unicode()
        employees = ReferenceSet(id, Employee.company_id)

        def __init__(self, name):
            self.name = name

        @sqlmagic
        def calculate_to_pay_magic(self):
            # This is complied into:
            #  __Employee = self.__class__.employees._remote_key1[0].cls
            #  __reference = self.employees
            #  __key = getattr(self, __reference._relation.local_key[0].name)
            #  __result = __reference.find(__Employee.company_id == __key)
            #  __value = __result.sum(__Employee.salary * __Employee.comission)
            #  x = __value
            x = sum(e.salary * e.comission for e in self.employees)

            return x

        def calculate_to_pay(self):
            return sum(e.salary * e.comission for e in self.employees)

    Employee.company = Reference(Employee.company_id, Company.id)

    print 'Creating example data'
    store.execute("CREATE TABLE company "
                  "(id INTEGER PRIMARY KEY, name VARCHAR)", noresult=True)

    store.execute("CREATE TABLE employee "
                  "(id INTEGER PRIMARY KEY, name VARCHAR, company_id INTEGER, "
                  "salary FLOAT, comission FLOAT)")

    sweets = store.add(Company(u"Sweets Inc."))
    store.flush()

    for i in range(5000):
        mike = store.add(Employee(u"Mike Mayer %d" % i, 200+i, 10+i))
        sweets.employees.add(mike)

    store.flush()

    #database.DEBUG = True

    import time

    print
    print 'Normal query ----'
    t = time.time()
    print 'Result:', sweets.calculate_to_pay()
    print '%2.4f sec' % (time.time()-t)

    print
    print 'SQL magic query ----'
    t = time.time()
    print 'Result:', sweets.calculate_to_pay_magic()
    print '%2.4f sec' % (time.time()-t)

if __name__ == "__main__":
    # Run the demo only when this file is executed as a script.
    test()
-- 
storm mailing list
[email protected]
Modify settings or unsubscribe at: 
https://lists.ubuntu.com/mailman/listinfo/storm

Reply via email to