Hi everybody,

We have talked on and off about making a front-end compiler for VIFF and
today I figured that I would try making such a guy...

So far it can only parse a simple language and spit out something which
looks like equivalent Python code in which both branches of
if-statements are run. So a program like this:

  if a:
    x = y
    if b:
      x = z
    else:
      x = w
    fi
  fi

is transformed into this:

  x = (a * y + (1 - a) * x)
  x = (a * b * z + (1 - a * b) * x)
  x = ((1 - a * b) * w + (1 - (1 - a * b)) * x)

which one could plug into a VIFF skeleton and run (not yet done).

The idea is that the conditions in if-statements are pushed down to all
assignments done in the then- and else-branches, nothing more.

This is just a quick test to make people start thinking about what we
want in such a language and to make people think about what kind of
transformation we can do and which we cannot do.

The program is attached below -- you will need to grab simplegeneric and
PLY (Python Lex-Yacc) to use it, but both modules are easy to install:

  http://pypi.python.org/pypi/simplegeneric
  http://www.dabeaz.com/ply/

You run the program by giving it the name of a file to parse and it will
tell you the results on stdout.

#!/usr/bin/python

from pprint import pprint
from simplegeneric import generic
from ply import lex, yacc

# Reserved words of the toy language.  The lexer matches them with the
# VARIABLE rule and then re-tags them (see t_VARIABLE).
keywords = ("print", "if", "else", "fi")

# All token types the lexer may emit: the basic tokens followed by one
# upper-cased token type per reserved word (PRINT, IF, ELSE, FI).
tokens = (
    "NUMBER", "VARIABLE",
    "PLUS", "MINUS", "TIMES",
    "EQUAL", "COLON", "NEWLINE",
) + tuple(word.upper() for word in keywords)

# Regular expressions for the simple, string-defined tokens.
t_NUMBER = r'\d+'
t_PLUS = r'\+'
t_MINUS = r'-'
t_TIMES = r'\*'
t_EQUAL = r'='
t_COLON = r':'

# Lexer rule for identifiers.  The docstring is the token's regular
# expression, so it must stay exactly as written.
def t_VARIABLE(t):
    r"""[a-z]+"""
    word = t.value
    # Reserved words match the same pattern; give them their own token
    # type (the upper-cased word) instead of VARIABLE.
    if word in keywords:
        t.type = word.upper()
    return t

# Lexer rule for a line break.  The docstring is the token's regular
# expression, so it must stay exactly as written.
def t_NEWLINE(t):
    r"""\n"""
    # Keep the lexer's line counter up to date; the token itself is
    # still returned because the grammar uses NEWLINE as a terminator.
    lexer = t.lexer
    lexer.lineno = lexer.lineno + 1
    return t

def t_error(t):
    # Lexer error hook: report the first character of the unmatched
    # input, then skip that one character so lexing can continue.
    offending = t.value[0]
    print("Illegal character '%s'" % offending)
    t.lexer.skip(1)

# Characters the lexer skips between tokens: spaces and tabs (PLY's
# special ``t_ignore`` setting).  Newlines are not listed here; they
# are handled by the t_NEWLINE rule.
t_ignore   = ' \t'

class BinOp(object):
    """Base class for binary operator nodes of the parse tree.

    Stores the two operands as ``lhs`` and ``rhs``.
    """

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs

    def __repr__(self):
        # The concrete subclass name is used, so Add/Sub/Mul print as
        # themselves.  Operands are formatted with %s (str), not %r.
        kind = type(self).__name__
        return "%s(%s, %s)" % (kind, self.lhs, self.rhs)

class Add(BinOp):
    """Binary node for ``lhs + rhs``."""


class Sub(BinOp):
    """Binary node for ``lhs - rhs``."""


class Mul(BinOp):
    """Binary node for ``lhs * rhs``."""

class AssignStmt(object):
    """Parse-tree node for an assignment ``lhs = rhs``."""

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs

    def __repr__(self):
        operands = (self.lhs, self.rhs)
        return "AssignStmt(%r, %r)" % operands

class PrintStmt(object):
    """Parse-tree node for ``print expr``; the operand is kept in ``expr``."""

    def __init__(self, expr):
        self.expr = expr

    def __repr__(self):
        template = "PrintStmt(%r)"
        return template % self.expr

class IfStmt(object):
    """Parse-tree node for an if-statement.

    ``bool_expr`` is the condition, ``then_stmts`` the statements of the
    then-branch and ``else_stmts`` those of the else-branch.  A missing
    (or empty) else-branch is stored as a fresh empty list.
    """

    def __init__(self, bool_expr, then_stmts, else_stmts=None):
        self.bool_expr = bool_expr
        self.then_stmts = then_stmts
        # Any falsy value (None, empty sequence) becomes a new [].
        self.else_stmts = else_stmts or []

    def __repr__(self):
        fields = (self.bool_expr, self.then_stmts, self.else_stmts)
        return "IfStmt(%r, %r, %r)" % fields

# Grammar rule (the docstring is the production and must not change):
# a program consisting of a single statement becomes a one-element list.
def p_program(p):
    """program : statement"""
    only_statement = p[1]
    p[0] = [only_statement]

# Grammar rule (the docstring is the production and must not change):
# a longer program is the preceding statements plus one more statement.
def p_program_cont(p):
    """program : program statement"""
    # Build a new list; the list of the shorter program is left untouched.
    statements = p[1] + []
    statements.append(p[2])
    p[0] = statements

# Grammar rule (the docstring is the production and must not change):
# every statement is terminated by NEWLINE; the terminator is dropped
# and the inner node is passed through unchanged.
def p_statement(p):
    """statement : assignment NEWLINE
                 | print NEWLINE
                 | if NEWLINE"""
    inner = p[1]
    p[0] = inner

# Grammar rule (the docstring is the production and must not change):
# builds an AssignStmt from the variable name and the expression.
def p_assignment(p):
    """assignment : VARIABLE EQUAL expr"""
    target, value = p[1], p[3]
    p[0] = AssignStmt(target, value)

# Grammar rule (the docstring is the production and must not change):
# builds a PrintStmt from the expression following the keyword.
def p_print(p):
    """print : PRINT expr"""
    operand = p[2]
    p[0] = PrintStmt(operand)

# Grammar rule (the docstring is the production and must not change):
# an if-statement without else-branch.
def p_if(p):
    """if : IF expr COLON NEWLINE program FI"""
    condition, then_branch = p[2], p[5]
    p[0] = IfStmt(condition, then_branch)

# Grammar rule (the docstring is the production and must not change):
# an if-statement with both a then- and an else-branch.
def p_if_else(p):
    """if : IF expr COLON NEWLINE program ELSE COLON NEWLINE program FI"""
    condition, then_branch, else_branch = p[2], p[5], p[9]
    p[0] = IfStmt(condition, then_branch, else_branch)

# Grammar rule (the docstring is the production and must not change):
# left-associative addition and subtraction.
def p_expr(p):
    """expr : expr PLUS term
            | expr MINUS term"""
    left, operator, right = p[1], p[2], p[3]
    # Anything other than '+' is treated as subtraction.
    if operator != '+':
        p[0] = Sub(left, right)
        return
    p[0] = Add(left, right)

# Grammar rule (the docstring is the production and must not change):
# an expression that is just a term is passed through unchanged.
def p_expr_term(p):
    """expr : term"""
    term = p[1]
    p[0] = term

# Grammar rule (the docstring is the production and must not change):
# left-associative multiplication.
def p_term(p):
    """term : term TIMES factor"""
    left, right = p[1], p[3]
    p[0] = Mul(left, right)

# Grammar rule (the docstring is the production and must not change):
# a term that is just a factor is passed through unchanged.
def p_term_factor(p):
    """term : factor"""
    factor = p[1]
    p[0] = factor

# Grammar rule (the docstring is the production and must not change):
# a factor is a number or a variable; the token's value is used as the
# node as-is, with no conversion.
def p_factor(p):
    """factor : NUMBER
              | VARIABLE"""
    leaf = p[1]
    p[0] = leaf

def p_error(p):
    # Parser error hook: print a fixed message followed by whatever the
    # parser passed in as the offending item.
    for line in ("Syntax error!", p):
        print(line)

# Build the PLY lexer and parser at import time.  Per the PLY
# documentation these calls collect the ``tokens`` tuple, the ``t_*``
# rules and the ``p_*`` rules defined above in this module.
lex.lex()
yacc.yacc()


### Printing code ###

@generic
def code(node):
    """Render a parse-tree node as source text.

    This is the fallback of the generic function: any node type without
    a registered handler is rendered with ``str(node)``.
    """
    return str(node)

@code.when_type(list)
def code_result(result):
    """Render a list of nodes, one rendered item per line."""
    rendered = [code(item) for item in result]
    return "\n".join(rendered)

@code.when_type(Add)
def code_add(node):
    """Render an addition, always wrapped in parentheses."""
    left = code(node.lhs)
    right = code(node.rhs)
    return "(%s + %s)" % (left, right)

@code.when_type(Sub)
def code_sub(node):
    """Render a subtraction, always wrapped in parentheses."""
    left = code(node.lhs)
    right = code(node.rhs)
    return "(%s - %s)" % (left, right)

@code.when_type(Mul)
def code_mul(node):
    """Render a multiplication.

    No parentheses are added here; Add and Sub operands already come
    back parenthesized from their own handlers.
    """
    left = code(node.lhs)
    right = code(node.rhs)
    return "%s * %s" % (left, right)

@code.when_type(AssignStmt)
def code_assignment(node):
    """Render an assignment as ``target = value``."""
    target = code(node.lhs)
    value = code(node.rhs)
    return "%s = %s" % (target, value)

@code.when_type(PrintStmt)
def code_print(node):
    """Render a print statement as ``print <expression>``.

    ``node`` is a PrintStmt; the returned string contains the rendered
    operand expression.
    """
    # PrintStmt stores its operand in ``expr`` (it has no ``stmt``
    # attribute).  Render it through code() so operator nodes come out
    # as source text, like in the other statement handlers.
    return "print %s" % code(node.expr)

@code.when_type(IfStmt)
def code_if(node):
    """Render an if-statement in the input language's ``if ... fi`` form.

    The branch statements are emitted one per line without indentation;
    the ``else:`` part is left out when the else-branch is empty.
    """
    lines = ["if %s:" % code(node.bool_expr)]
    lines.extend([code(stmt) for stmt in node.then_stmts])
    if node.else_stmts:
        lines.append("else:")
        lines.extend([code(stmt) for stmt in node.else_stmts])
    lines.append("fi")
    return "\n".join(lines)


### Desugaring of if-statements ###

@generic
def desugar_if(node, expr=None):
    """Remove if-statements by pushing their conditions into assignments.

    ``expr`` is the condition accumulated from the enclosing
    if-statements, or None outside any if-statement.  This is the
    fallback of the generic function: node types without a registered
    handler are returned unchanged.
    """
    return node

@desugar_if.when_type(list)
def desugar_if_list(node, expr=None):
    """Desugar every statement of a list under the condition ``expr``.

    A statement that desugars to a list (an if-statement) is spliced
    into the result one level deep; everything else is appended as is.
    """
    flattened = []
    for statement in node:
        lowered = desugar_if(statement, expr)
        if isinstance(lowered, list):
            flattened.extend(lowered)
        else:
            flattened.append(lowered)
    return flattened

@desugar_if.when_type(IfStmt)
def desugar_if_if(node, expr=None):
    """Replace an if-statement by the guarded statements of its branches.

    ``node`` is the IfStmt and ``expr`` the condition accumulated from
    the enclosing if-statements (None at the top level).  Conditions are
    treated as 0/1 values: the then-branch is guarded by
    ``expr * cond`` and the else-branch by ``expr * (1 - cond)``.

    The desugared branches are stored back on ``node`` and returned as
    one flat list: then-statements followed by else-statements.
    """
    cond = node.bool_expr
    not_cond = Sub(1, cond)
    if expr:
        then_guard = Mul(expr, cond)
        # The else-branch runs only when the enclosing condition holds
        # AND this condition fails.  Using 1 - (expr * cond) instead
        # would be 1 whenever the enclosing condition is false, which
        # executes the else-branch of an if that was never reached.
        else_guard = Mul(expr, not_cond)
    else:
        then_guard = cond
        else_guard = not_cond

    # Going through the list handler of desugar_if flattens the result
    # of nested if-statements; list() guarantees that handler is chosen.
    node.then_stmts = desugar_if(list(node.then_stmts), then_guard)
    node.else_stmts = desugar_if(list(node.else_stmts), else_guard)
    return node.then_stmts + node.else_stmts

@desugar_if.when_type(AssignStmt)
def desugar_if_assignment(node, expr=None):
    """Make an assignment conditional on ``expr``.

    With a condition, the right-hand side becomes
    ``expr * rhs + (1 - expr) * lhs``, i.e. the variable keeps its old
    value when the condition is 0.  The node is modified in place and
    returned; without a condition it is returned untouched.
    """
    if not expr:
        return node
    take_new = Mul(expr, node.rhs)
    keep_old = Mul(Sub(1, expr), node.lhs)
    node.rhs = Add(take_new, keep_old)
    return node

### Conversion of parse tree into nested tuples ###

@generic
def tree_to_tuples(node):
    """Convert a parse tree into nested tuples.

    This is the fallback of the generic function: node types without a
    registered handler are returned unchanged.
    """
    return node

@tree_to_tuples.when_type(list)
def tree_to_tuples_list(node):
    """Convert each element of a list; the result is again a list."""
    return [tree_to_tuples(child) for child in node]

@tree_to_tuples.when_type(IfStmt)
def tree_to_tuples_if(node):
    """Convert an IfStmt into ``('IfStmt', cond, then[, else])``.

    The fourth element is only present when the else-branch is
    non-empty.
    """
    parts = [
        'IfStmt',
        tree_to_tuples(node.bool_expr),
        tree_to_tuples(node.then_stmts),
    ]
    if node.else_stmts:
        parts.append(tree_to_tuples(node.else_stmts))
    return tuple(parts)

@tree_to_tuples.when_type(AssignStmt)
def tree_to_tuples_assign(node):
    """Convert an AssignStmt into ``('AssignStmt', lhs, rhs)``."""
    target = tree_to_tuples(node.lhs)
    value = tree_to_tuples(node.rhs)
    return ('AssignStmt', target, value)


# Command-line driver: parse the file named by the first argument and
# print the original text, the parse tree, the rendered program and the
# desugared program to stdout.
if __name__ == "__main__":
    import sys

    # Without a file name there is nothing to do; say so instead of
    # failing with an IndexError.
    if len(sys.argv) < 2:
        sys.stderr.write("Usage: %s FILE\n" % sys.argv[0])
        sys.exit(1)

    # try/finally makes sure the file is closed even if reading fails.
    fp = open(sys.argv[1], 'r')
    try:
        data = fp.read()
    finally:
        fp.close()

    # Single-argument print(...) behaves exactly like the print
    # statement under Python 2.
    print("Original program:")
    print(data)

    tree = yacc.parse(data) #, debug=2)
    if tree is None:
        # The parser produced no tree (p_error has already reported the
        # problem), so there is nothing meaningful to print below.
        sys.stderr.write("Parsing failed, nothing to translate.\n")
        sys.exit(1)

    print("Parse tree:")
    pprint(tree_to_tuples(tree))
    print("")

    print("Raw Python code:")
    print(code(tree))

    # desugar_if rewrites the assignment nodes in place, so this must
    # come after the raw program has been printed.
    tree = desugar_if(tree)
    print("Desugared Python code:")
    print(code(tree))
-- 
Martin Geisler
_______________________________________________
viff-devel mailing list (http://viff.dk/)
viff-devel@viff.dk
http://lists.viff.dk/listinfo.cgi/viff-devel-viff.dk

Reply via email to