Author: Amaury Forgeot d'Arc <[email protected]>
Branch: py3.3
Changeset: r70807:fe0435cfe837
Date: 2014-04-21 15:10 +0200
http://bitbucket.org/pypy/pypy/changeset/fe0435cfe837/

Log:    Add an AST validator; it prevents crashes when bad AST objects are
        built and compiled.

diff --git a/pypy/interpreter/astcompiler/test/test_validate.py b/pypy/interpreter/astcompiler/test/test_validate.py
new file mode 100644
--- /dev/null
+++ b/pypy/interpreter/astcompiler/test/test_validate.py
@@ -0,0 +1,425 @@
+import os
+from pypy.interpreter.error import OperationError
+from pypy.interpreter.baseobjspace import W_Root
+from pypy.interpreter.astcompiler import ast
+from pypy.interpreter.astcompiler import validate
+
+class TestASTValidator:
+    def mod(self, mod, msg=None, mode="exec", exc=validate.ValidationError):
+        space = self.space
+        if isinstance(exc, W_Root):
+            w_exc = exc
+            exc = OperationError
+        else:
+            w_exc = None
+        with raises(exc) as cm:
+            validate.validate_ast(space, mod)
+        if w_exc is not None:
+            w_value = cm.value.get_w_value(space)
+            assert cm.value.match(space, w_exc)
+            exc_msg = str(cm.value)
+        else:
+            exc_msg = str(cm.value)
+        if msg is not None:
+            assert msg in exc_msg
+
+    def expr(self, node, msg=None, exc=validate.ValidationError):
+        mod = ast.Module([ast.Expr(node, 0, 0)])
+        self.mod(mod, msg, exc=exc)
+
+    def stmt(self, stmt, msg=None):
+        mod = ast.Module([stmt])
+        self.mod(mod, msg)
+
+    def test_module(self):
+        m = ast.Interactive([ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)])
+        self.mod(m, "must have Load context", "single")
+        m = ast.Expression(ast.Name("x", ast.Store, 0, 0))
+        self.mod(m, "must have Load context", "eval")
+
+    def _check_arguments(self, fac, check):
+        def arguments(args=None, vararg=None, varargannotation=None,
+                      kwonlyargs=None, kwarg=None, kwargannotation=None,
+                      defaults=None, kw_defaults=None):
+            if args is None:
+                args = []
+            if kwonlyargs is None:
+                kwonlyargs = []
+            if defaults is None:
+                defaults = []
+            if kw_defaults is None:
+                kw_defaults = []
+            args = ast.arguments(args, vararg, varargannotation, kwonlyargs,
+                                 kwarg, kwargannotation, defaults, kw_defaults)
+            return fac(args)
+        args = [ast.arg("x", ast.Name("x", ast.Store, 0, 0))]
+        check(arguments(args=args), "must have Load context")
+        check(arguments(varargannotation=ast.Num(self.space.wrap(3), 0, 0)),
+              "varargannotation but no vararg")
+        check(arguments(varargannotation=ast.Name("x", ast.Store, 0, 0), 
vararg="x"),
+                         "must have Load context")
+        check(arguments(kwonlyargs=args), "must have Load context")
+        check(arguments(kwargannotation=ast.Num(self.space.wrap(42), 0, 0)),
+                       "kwargannotation but no kwarg")
+        check(arguments(kwargannotation=ast.Name("x", ast.Store, 0, 0),
+                          kwarg="x"), "must have Load context")
+        check(arguments(defaults=[ast.Num(self.space.wrap(3), 0, 0)]),
+                       "more positional defaults than args")
+        check(arguments(kw_defaults=[ast.Num(self.space.wrap(4), 0, 0)]),
+                       "length of kwonlyargs is not the same as kw_defaults")
+        args = [ast.arg("x", ast.Name("x", ast.Load, 0, 0))]
+        check(arguments(args=args, defaults=[ast.Name("x", ast.Store, 0, 0)]),
+                       "must have Load context")
+        args = [ast.arg("a", ast.Name("x", ast.Load, 0, 0)),
+                ast.arg("b", ast.Name("y", ast.Load, 0, 0))]
+        check(arguments(kwonlyargs=args,
+                          kw_defaults=[None, ast.Name("x", ast.Store, 0, 0)]),
+                          "must have Load context")
+
+    def test_funcdef(self):
+        a = ast.arguments([], None, None, [], None, None, [], [])
+        f = ast.FunctionDef("x", a, [], [], None, 0, 0)
+        self.stmt(f, "empty body on FunctionDef")
+        f = ast.FunctionDef("x", a, [ast.Pass(0, 0)], [ast.Name("x", 
ast.Store, 0, 0)],
+                            None, 0, 0)
+        self.stmt(f, "must have Load context")
+        f = ast.FunctionDef("x", a, [ast.Pass(0, 0)], [],
+                            ast.Name("x", ast.Store, 0, 0), 0, 0)
+        self.stmt(f, "must have Load context")
+        def fac(args):
+            return ast.FunctionDef("x", args, [ast.Pass(0, 0)], [], None, 0, 0)
+        self._check_arguments(fac, self.stmt)
+
+    def test_classdef(self):
+        def cls(bases=None, keywords=None, starargs=None, kwargs=None,
+                body=None, decorator_list=None):
+            if bases is None:
+                bases = []
+            if keywords is None:
+                keywords = []
+            if body is None:
+                body = [ast.Pass(0, 0)]
+            if decorator_list is None:
+                decorator_list = []
+            return ast.ClassDef("myclass", bases, keywords, starargs,
+                                kwargs, body, decorator_list, 0, 0)
+        self.stmt(cls(bases=[ast.Name("x", ast.Store, 0, 0)]),
+                  "must have Load context")
+        self.stmt(cls(keywords=[ast.keyword("x", ast.Name("x", ast.Store, 0, 
0))]),
+                  "must have Load context")
+        self.stmt(cls(starargs=ast.Name("x", ast.Store, 0, 0)),
+                  "must have Load context")
+        self.stmt(cls(kwargs=ast.Name("x", ast.Store, 0, 0)),
+                  "must have Load context")
+        self.stmt(cls(body=[]), "empty body on ClassDef")
+        self.stmt(cls(body=[None]), "None disallowed")
+        self.stmt(cls(decorator_list=[ast.Name("x", ast.Store, 0, 0)]),
+                  "must have Load context")
+
+    def test_delete(self):
+        self.stmt(ast.Delete([], 0, 0), "empty targets on Delete")
+        self.stmt(ast.Delete([None], 0, 0), "None disallowed")
+        self.stmt(ast.Delete([ast.Name("x", ast.Load, 0, 0)], 0, 0),
+                  "must have Del context")
+
+    def test_assign(self):
+        self.stmt(ast.Assign([], ast.Num(self.space.wrap(3), 0, 0), 0, 0), 
"empty targets on Assign")
+        self.stmt(ast.Assign([None], ast.Num(self.space.wrap(3), 0, 0), 0, 0), 
"None disallowed")
+        self.stmt(ast.Assign([ast.Name("x", ast.Load, 0, 0)], 
ast.Num(self.space.wrap(3), 0, 0), 0, 0),
+                  "must have Store context")
+        self.stmt(ast.Assign([ast.Name("x", ast.Store, 0, 0)],
+                                ast.Name("y", ast.Store, 0, 0), 0, 0),
+                  "must have Load context")
+
+    def test_augassign(self):
+        aug = ast.AugAssign(ast.Name("x", ast.Load, 0, 0), ast.Add,
+                            ast.Name("y", ast.Load, 0, 0), 0, 0)
+        self.stmt(aug, "must have Store context")
+        aug = ast.AugAssign(ast.Name("x", ast.Store, 0, 0), ast.Add,
+                            ast.Name("y", ast.Store, 0, 0), 0, 0)
+        self.stmt(aug, "must have Load context")
+
+    def test_for(self):
+        x = ast.Name("x", ast.Store, 0, 0)
+        y = ast.Name("y", ast.Load, 0, 0)
+        p = ast.Pass(0, 0)
+        self.stmt(ast.For(x, y, [], [], 0, 0), "empty body on For")
+        self.stmt(ast.For(ast.Name("x", ast.Load, 0, 0), y, [p], [], 0, 0),
+                  "must have Store context")
+        self.stmt(ast.For(x, ast.Name("y", ast.Store, 0, 0), [p], [], 0, 0),
+                  "must have Load context")
+        e = ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)
+        self.stmt(ast.For(x, y, [e], [], 0, 0), "must have Load context")
+        self.stmt(ast.For(x, y, [p], [e], 0, 0), "must have Load context")
+
+    def test_while(self):
+        self.stmt(ast.While(ast.Num(self.space.wrap(3), 0, 0), [], [], 0, 0), 
"empty body on While")
+        self.stmt(ast.While(ast.Name("x", ast.Store, 0, 0), [ast.Pass(0, 0)], 
[], 0, 0),
+                  "must have Load context")
+        self.stmt(ast.While(ast.Num(self.space.wrap(3), 0, 0), [ast.Pass(0, 
0)],
+                             [ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)], 
0, 0),
+                             "must have Load context")
+
+    def test_if(self):
+        self.stmt(ast.If(ast.Num(self.space.wrap(3), 0, 0), [], [], 0, 0), 
"empty body on If")
+        i = ast.If(ast.Name("x", ast.Store, 0, 0), [ast.Pass(0, 0)], [], 0, 0)
+        self.stmt(i, "must have Load context")
+        i = ast.If(ast.Num(self.space.wrap(3), 0, 0), [ast.Expr(ast.Name("x", 
ast.Store, 0, 0), 0, 0)], [], 0, 0)
+        self.stmt(i, "must have Load context")
+        i = ast.If(ast.Num(self.space.wrap(3), 0, 0), [ast.Pass(0, 0)],
+                   [ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)], 0, 0)
+        self.stmt(i, "must have Load context")
+
+    @skip("enable when parser uses the new With construct")
+    def test_with(self):
+        p = ast.Pass(0, 0)
+        self.stmt(ast.With([], [p]), "empty items on With")
+        i = ast.withitem(ast.Num(self.space.wrap(3), 0, 0), None)
+        self.stmt(ast.With([i], []), "empty body on With")
+        i = ast.withitem(ast.Name("x", ast.Store, 0, 0), None)
+        self.stmt(ast.With([i], [p]), "must have Load context")
+        i = ast.withitem(ast.Num(self.space.wrap(3), 0, 0), ast.Name("x", 
ast.Load, 0, 0))
+        self.stmt(ast.With([i], [p]), "must have Store context")
+
+    def test_raise(self):
+        r = ast.Raise(None, ast.Num(self.space.wrap(3), 0, 0), 0, 0)
+        self.stmt(r, "Raise with cause but no exception")
+        r = ast.Raise(ast.Name("x", ast.Store, 0, 0), None, 0, 0)
+        self.stmt(r, "must have Load context")
+        r = ast.Raise(ast.Num(self.space.wrap(4), 0, 0), ast.Name("x", 
ast.Store, 0, 0), 0, 0)
+        self.stmt(r, "must have Load context")
+
+    @skip("enable when parser uses the new Try construct")
+    def test_try(self):
+        p = ast.Pass(0, 0)
+        t = ast.Try([], [], [], [p])
+        self.stmt(t, "empty body on Try")
+        t = ast.Try([ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)], [], [], 
[p])
+        self.stmt(t, "must have Load context")
+        t = ast.Try([p], [], [], [])
+        self.stmt(t, "Try has neither except handlers nor finalbody")
+        t = ast.Try([p], [], [p], [p])
+        self.stmt(t, "Try has orelse but no except handlers")
+        t = ast.Try([p], [ast.ExceptHandler(None, "x", [])], [], [])
+        self.stmt(t, "empty body on ExceptHandler")
+        e = [ast.ExceptHandler(ast.Name("x", ast.Store, 0, 0), "y", [p])]
+        self.stmt(ast.Try([p], e, [], []), "must have Load context")
+        e = [ast.ExceptHandler(None, "x", [p])]
+        t = ast.Try([p], e, [ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)], 
[p])
+        self.stmt(t, "must have Load context")
+        t = ast.Try([p], e, [p], [ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 
0)])
+        self.stmt(t, "must have Load context")
+
+    def test_assert(self):
+        self.stmt(ast.Assert(ast.Name("x", ast.Store, 0, 0), None, 0, 0),
+                  "must have Load context")
+        assrt = ast.Assert(ast.Name("x", ast.Load, 0, 0),
+                           ast.Name("y", ast.Store, 0, 0), 0, 0)
+        self.stmt(assrt, "must have Load context")
+
+    def test_import(self):
+        self.stmt(ast.Import([], 0, 0), "empty names on Import")
+
+    def test_importfrom(self):
+        imp = ast.ImportFrom(None, [ast.alias("x", None)], -42, 0, 0)
+        self.stmt(imp, "level less than -1")
+        self.stmt(ast.ImportFrom(None, [], 0, 0, 0), "empty names on 
ImportFrom")
+
+    def test_global(self):
+        self.stmt(ast.Global([], 0, 0), "empty names on Global")
+
+    def test_nonlocal(self):
+        self.stmt(ast.Nonlocal([], 0, 0), "empty names on Nonlocal")
+
+    def test_expr(self):
+        e = ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)
+        self.stmt(e, "must have Load context")
+
+    def test_boolop(self):
+        b = ast.BoolOp(ast.And, [], 0, 0)
+        self.expr(b, "less than 2 values")
+        b = ast.BoolOp(ast.And, [ast.Num(self.space.wrap(3), 0, 0)], 0, 0)
+        self.expr(b, "less than 2 values")
+        b = ast.BoolOp(ast.And, [ast.Num(self.space.wrap(4), 0, 0), None], 0, 
0)
+        self.expr(b, "None disallowed")
+        b = ast.BoolOp(ast.And, [ast.Num(self.space.wrap(4), 0, 0), 
ast.Name("x", ast.Store, 0, 0)], 0, 0)
+        self.expr(b, "must have Load context")
+
+    def test_unaryop(self):
+        u = ast.UnaryOp(ast.Not, ast.Name("x", ast.Store, 0, 0), 0, 0)
+        self.expr(u, "must have Load context")
+
+    def test_lambda(self):
+        a = ast.arguments([], None, None, [], None, None, [], [])
+        self.expr(ast.Lambda(a, ast.Name("x", ast.Store, 0, 0), 0, 0),
+                  "must have Load context")
+        def fac(args):
+            return ast.Lambda(args, ast.Name("x", ast.Load, 0, 0), 0, 0)
+        self._check_arguments(fac, self.expr)
+
+    def test_ifexp(self):
+        l = ast.Name("x", ast.Load, 0, 0)
+        s = ast.Name("y", ast.Store, 0, 0)
+        for args in (s, l, l), (l, s, l), (l, l, s):
+            self.expr(ast.IfExp(*(args + (0, 0))), "must have Load context")
+
+    def test_dict(self):
+        d = ast.Dict([], [ast.Name("x", ast.Load, 0, 0)], 0, 0)
+        self.expr(d, "same number of keys as values")
+        d = ast.Dict([None], [ast.Name("x", ast.Load, 0, 0)], 0, 0)
+        self.expr(d, "None disallowed")
+        d = ast.Dict([ast.Name("x", ast.Load, 0, 0)], [None], 0, 0)
+        self.expr(d, "None disallowed")
+
+    def test_set(self):
+        self.expr(ast.Set([None], 0, 0), "None disallowed")
+        s = ast.Set([ast.Name("x", ast.Store, 0, 0)], 0, 0)
+        self.expr(s, "must have Load context")
+
+    def _check_comprehension(self, fac):
+        self.expr(fac([]), "comprehension with no generators")
+        g = ast.comprehension(ast.Name("x", ast.Load, 0, 0),
+                              ast.Name("x", ast.Load, 0, 0), [])
+        self.expr(fac([g]), "must have Store context")
+        g = ast.comprehension(ast.Name("x", ast.Store, 0, 0),
+                              ast.Name("x", ast.Store, 0, 0), [])
+        self.expr(fac([g]), "must have Load context")
+        x = ast.Name("x", ast.Store, 0, 0)
+        y = ast.Name("y", ast.Load, 0, 0)
+        g = ast.comprehension(x, y, [None])
+        self.expr(fac([g]), "None disallowed")
+        g = ast.comprehension(x, y, [ast.Name("x", ast.Store, 0, 0)])
+        self.expr(fac([g]), "must have Load context")
+
+    def _simple_comp(self, fac):
+        g = ast.comprehension(ast.Name("x", ast.Store, 0, 0),
+                              ast.Name("x", ast.Load, 0, 0), [])
+        self.expr(fac(ast.Name("x", ast.Store, 0, 0), [g], 0, 0),
+                  "must have Load context")
+        def wrap(gens):
+            return fac(ast.Name("x", ast.Store, 0, 0), gens, 0, 0)
+        self._check_comprehension(wrap)
+
+    def test_listcomp(self):
+        self._simple_comp(ast.ListComp)
+
+    def test_setcomp(self):
+        self._simple_comp(ast.SetComp)
+
+    def test_generatorexp(self):
+        self._simple_comp(ast.GeneratorExp)
+
+    def test_dictcomp(self):
+        g = ast.comprehension(ast.Name("y", ast.Store, 0, 0),
+                              ast.Name("p", ast.Load, 0, 0), [])
+        c = ast.DictComp(ast.Name("x", ast.Store, 0, 0),
+                         ast.Name("y", ast.Load, 0, 0), [g], 0, 0)
+        self.expr(c, "must have Load context")
+        c = ast.DictComp(ast.Name("x", ast.Load, 0, 0),
+                         ast.Name("y", ast.Store, 0, 0), [g], 0, 0)
+        self.expr(c, "must have Load context")
+        def factory(comps):
+            k = ast.Name("x", ast.Load, 0, 0)
+            v = ast.Name("y", ast.Load, 0, 0)
+            return ast.DictComp(k, v, comps, 0, 0)
+        self._check_comprehension(factory)
+
+    def test_yield(self):
+        self.expr(ast.Yield(ast.Name("x", ast.Store, 0, 0), 0, 0), "must have 
Load")
+        self.expr(ast.YieldFrom(ast.Name("x", ast.Store, 0, 0), 0, 0), "must 
have Load")
+
+    def test_compare(self):
+        left = ast.Name("x", ast.Load, 0, 0)
+        comp = ast.Compare(left, [ast.In], [], 0, 0)
+        self.expr(comp, "no comparators")
+        comp = ast.Compare(left, [ast.In], [ast.Num(self.space.wrap(4), 0, 0), 
ast.Num(self.space.wrap(5), 0, 0)], 0, 0)
+        self.expr(comp, "different number of comparators and operands")
+        comp = ast.Compare(ast.Num(self.space.wrap("blah"), 0, 0), [ast.In], 
[left], 0, 0)
+        self.expr(comp, "non-numeric", exc=self.space.w_TypeError)
+        comp = ast.Compare(left, [ast.In], [ast.Num(self.space.wrap("blah"), 
0, 0)], 0, 0)
+        self.expr(comp, "non-numeric", exc=self.space.w_TypeError)
+
+    def test_call(self):
+        func = ast.Name("x", ast.Load, 0, 0)
+        args = [ast.Name("y", ast.Load, 0, 0)]
+        keywords = [ast.keyword("w", ast.Name("z", ast.Load, 0, 0))]
+        stararg = ast.Name("p", ast.Load, 0, 0)
+        kwarg = ast.Name("q", ast.Load, 0, 0)
+        call = ast.Call(ast.Name("x", ast.Store, 0, 0), args, keywords, 
stararg,
+                        kwarg, 0, 0)
+        self.expr(call, "must have Load context")
+        call = ast.Call(func, [None], keywords, stararg, kwarg, 0, 0)
+        self.expr(call, "None disallowed")
+        bad_keywords = [ast.keyword("w", ast.Name("z", ast.Store, 0, 0))]
+        call = ast.Call(func, args, bad_keywords, stararg, kwarg, 0, 0)
+        self.expr(call, "must have Load context")
+        call = ast.Call(func, args, keywords, ast.Name("z", ast.Store, 0, 0), 
kwarg, 0, 0)
+        self.expr(call, "must have Load context")
+        call = ast.Call(func, args, keywords, stararg,
+                        ast.Name("w", ast.Store, 0, 0), 0, 0)
+        self.expr(call, "must have Load context")
+
+    def test_num(self):
+        space = self.space
+        w_objs = space.appexec([], """():
+        class subint(int):
+            pass
+        class subfloat(float):
+            pass
+        class subcomplex(complex):
+            pass
+        return ("0", "hello", subint(), subfloat(), subcomplex())
+        """)
+        for w_obj in space.unpackiterable(w_objs):
+            self.expr(ast.Num(w_obj, 0, 0), "non-numeric", 
exc=self.space.w_TypeError)
+
+    def test_attribute(self):
+        attr = ast.Attribute(ast.Name("x", ast.Store, 0, 0), "y", ast.Load, 0, 
0)
+        self.expr(attr, "must have Load context")
+
+    def test_subscript(self):
+        sub = ast.Subscript(ast.Name("x", ast.Store, 0, 0), 
ast.Index(ast.Num(self.space.wrap(3), 0, 0)),
+                            ast.Load, 0, 0)
+        self.expr(sub, "must have Load context")
+        x = ast.Name("x", ast.Load, 0, 0)
+        sub = ast.Subscript(x, ast.Index(ast.Name("y", ast.Store, 0, 0)),
+                            ast.Load, 0, 0)
+        self.expr(sub, "must have Load context")
+        s = ast.Name("x", ast.Store, 0, 0)
+        for args in (s, None, None), (None, s, None), (None, None, s):
+            sl = ast.Slice(*args)
+            self.expr(ast.Subscript(x, sl, ast.Load, 0, 0),
+                      "must have Load context")
+        sl = ast.ExtSlice([])
+        self.expr(ast.Subscript(x, sl, ast.Load, 0, 0), "empty dims on 
ExtSlice")
+        sl = ast.ExtSlice([ast.Index(s)])
+        self.expr(ast.Subscript(x, sl, ast.Load, 0, 0), "must have Load 
context")
+
+    def test_starred(self):
+        left = ast.List([ast.Starred(ast.Name("x", ast.Load, 0, 0), ast.Store, 
0, 0)],
+                        ast.Store, 0, 0)
+        assign = ast.Assign([left], ast.Num(self.space.wrap(4), 0, 0), 0, 0)
+        self.stmt(assign, "must have Store context")
+
+    def _sequence(self, fac):
+        self.expr(fac([None], ast.Load, 0, 0), "None disallowed")
+        self.expr(fac([ast.Name("x", ast.Store, 0, 0)], ast.Load, 0, 0),
+                  "must have Load context")
+
+    def test_list(self):
+        self._sequence(ast.List)
+
+    def test_tuple(self):
+        self._sequence(ast.Tuple)
+
+    def test_stdlib_validates(self):
+        stdlib = os.path.join(os.path.dirname(ast.__file__), 
'../../../lib-python/3')
+        tests = ["os.py", "test/test_grammar.py", "test/test_unpack_ex.py"]
+        for module in tests:
+            fn = os.path.join(stdlib, module)
+            print 'compiling', fn
+            with open(fn, "r") as fp:
+                source = fp.read()
+            ec = self.space.getexecutioncontext()
+            ast_node = ec.compiler.compile_to_ast(source, fn, "exec", 0)
+            ec.compiler.validate_ast(ast_node)
diff --git a/pypy/interpreter/astcompiler/validate.py b/pypy/interpreter/astcompiler/validate.py
new file mode 100644
--- /dev/null
+++ b/pypy/interpreter/astcompiler/validate.py
@@ -0,0 +1,410 @@
+"""A visitor to validate an AST object."""
+
+from pypy.interpreter.error import OperationError, oefmt
+from pypy.interpreter.astcompiler import ast
+from rpython.tool.pairtype import pair, pairtype
+from pypy.interpreter.baseobjspace import W_Root
+
+
+def validate_ast(space, node):
+    node.walkabout(AstValidator(space))
+
+
+class ValidationError(Exception):
+    pass
+
+
+def expr_context_name(ctx):
+    if not 1 <= ctx <= len(ast.expr_context_to_class):
+        return '??'
+    return ast.expr_context_to_class[ctx - 1].typedef.name
+
+def _check_context(expected_ctx, actual_ctx):
+    if expected_ctx != actual_ctx:
+        raise ValidationError(
+            "expression must have %s context but has %s instead" %
+            (expr_context_name(expected_ctx), expr_context_name(actual_ctx)))
+
+
+class __extend__(ast.AST):
+    # Default validation hooks for every AST node.  ast.expr and the
+    # context-carrying expression classes below override them.
+
+    def check_context(self, visitor, ctx):
+        """Context checks only apply to expression nodes."""
+        raise AssertionError("should only be on expressions")
+
+    def walkabout_with_ctx(self, visitor, ctx):
+        """Visit this node normally; ``ctx`` is ignored by default."""
+        self.walkabout(visitor)  # With "load" context.
+
+
+class __extend__(ast.expr):
+
+    def check_context(self, visitor, ctx):
+        if ctx != ast.Load:
+            raise ValidationError(
+                "expression which can't be assigned to in %s context" %
+                expr_context_name(ctx))
+        
+
+class __extend__(ast.Name):
+
+    def check_context(self, visitor, ctx):
+        _check_context(ctx, self.ctx)
+
+
+class __extend__(ast.List):
+
+    def check_context(self, visitor, ctx):
+        _check_context(ctx, self.ctx)
+
+    def walkabout_with_ctx(self, visitor, ctx):
+        visitor._validate_exprs(self.elts, ctx)
+
+
+class __extend__(ast.Tuple):
+
+    def check_context(self, visitor, ctx):
+        _check_context(ctx, self.ctx)
+
+    def walkabout_with_ctx(self, visitor, ctx):
+        visitor._validate_exprs(self.elts, ctx)
+
+
+class __extend__(ast.Starred):
+
+    def check_context(self, visitor, ctx):
+        _check_context(ctx, self.ctx)
+
+    def walkabout_with_ctx(self, visitor, ctx):
+        visitor._validate_expr(self.value, ctx)
+
+
+class __extend__(ast.Subscript):
+
+    def check_context(self, visitor, ctx):
+        _check_context(ctx, self.ctx)
+
+
+class __extend__(ast.Attribute):
+
+    def check_context(self, visitor, ctx):
+        _check_context(ctx, self.ctx)
+
+
+class AstValidator(ast.ASTVisitor):
+    def __init__(self, space):
+        self.space = space
+
+    def _validate_stmts(self, stmts):
+        if not stmts:
+            return
+        for stmt in stmts:
+            if not stmt:
+                raise ValidationError("None disallowed in statement list")
+            stmt.walkabout(self)
+
+    def _len(self, node):
+        if node is None:
+            return 0
+        return len(node)
+
+    def _validate_expr(self, expr, ctx=ast.Load):
+        expr.check_context(self, ctx)
+        expr.walkabout_with_ctx(self, ctx)
+
+    def _validate_exprs(self, exprs, ctx=ast.Load, null_ok=False):
+        if not exprs:
+            return
+        for expr in exprs:
+            if expr:
+                self._validate_expr(expr, ctx)
+            elif not null_ok:
+                raise ValidationError("None disallowed in expression list")
+
+    def _validate_body(self, body, owner):
+        self._validate_nonempty_seq(body, "body", owner)
+        self._validate_stmts(body)
+
+    def _validate_nonempty_seq(self, seq, what, owner):
+        if not seq:
+            raise ValidationError("empty %s on %s" % (what, owner))
+
+    def _validate_nonempty_seq_s(self, seq, what, owner):
+        if not seq:
+            raise ValidationError("empty %s on %s" % (what, owner))
+
+    def visit_Interactive(self, node):
+        self._validate_stmts(node.body)
+
+    def visit_Module(self, node):
+        self._validate_stmts(node.body)
+
+    def visit_Expression(self, node):
+        self._validate_expr(node.body)
+
+    # Statements
+
+    def visit_arg(self, node):
+        if node.annotation:
+            self._validate_expr(node.annotation)
+
+    def visit_arguments(self, node):
+        self.visit_sequence(node.args)
+        if node.varargannotation:
+            if not node.vararg:
+                raise ValidationError("varargannotation but no vararg on 
arguments")
+            self._validate_expr(node.varargannotation)
+        self.visit_sequence(node.kwonlyargs)
+        if node.kwargannotation:
+            if not node.kwarg:
+                raise ValidationError("kwargannotation but no kwarg on 
arguments")
+            self._validate_expr(node.kwargannotation)
+        if self._len(node.defaults) > self._len(node.args):
+            raise ValidationError("more positional defaults than args on 
arguments")
+        if self._len(node.kw_defaults) != self._len(node.kwonlyargs):
+            raise ValidationError("length of kwonlyargs is not the same as "
+                                  "kw_defaults on arguments")
+        self._validate_exprs(node.defaults)
+        self._validate_exprs(node.kw_defaults, null_ok=True)
+
+    def visit_FunctionDef(self, node):
+        self._validate_body(node.body, "FunctionDef")
+        node.args.walkabout(self)
+        self._validate_exprs(node.decorator_list)
+        if node.returns:
+            self._validate_expr(node.returns)
+
+    def visit_keyword(self, node):
+        self._validate_expr(node.value)
+
+    def visit_ClassDef(self, node):
+        self._validate_body(node.body, "ClassDef")
+        self._validate_exprs(node.bases)
+        self.visit_sequence(node.keywords)
+        self._validate_exprs(node.decorator_list)
+        if node.starargs:
+            self._validate_expr(node.starargs)
+        if node.kwargs:
+            self._validate_expr(node.kwargs)
+
+    def visit_Return(self, node):
+        if node.value:
+            self._validate_expr(node.value)
+
+    def visit_Delete(self, node):
+        self._validate_nonempty_seq(node.targets, "targets", "Delete")
+        self._validate_exprs(node.targets, ast.Del)
+
+    def visit_Assign(self, node):
+        self._validate_nonempty_seq(node.targets, "targets", "Assign")
+        self._validate_exprs(node.targets, ast.Store)
+        self._validate_expr(node.value)
+
+    def visit_AugAssign(self, node):
+        self._validate_expr(node.target, ast.Store)
+        self._validate_expr(node.value)
+
+    def visit_For(self, node):
+        self._validate_expr(node.target, ast.Store)
+        self._validate_expr(node.iter)
+        self._validate_body(node.body, "For")
+        self._validate_stmts(node.orelse)
+
+    def visit_While(self, node):
+        self._validate_expr(node.test)
+        self._validate_body(node.body, "While")
+        self._validate_stmts(node.orelse)
+
+    def visit_If(self, node):
+        self._validate_expr(node.test)
+        self._validate_body(node.body, "If")
+        self._validate_stmts(node.orelse)
+
+    def visit_With(self, node):
+        self._validate_expr(node.context_expr)
+        if node.optional_vars:
+            self._validate_expr(node.optional_vars, ast.Store)
+        self._validate_body(node.body, "With")
+
+    def visit_Raise(self, node):
+        if node.exc:
+            self._validate_expr(node.exc)
+            if node.cause:
+                self._validate_expr(node.cause)
+        elif node.cause:
+            raise ValidationError("Raise with cause but no exception")
+
+    def visit_TryExcept(self, node):
+        self._validate_body(node.body, "TryExcept")
+        for handler in node.handlers:
+            handler.walkabout(self)
+        self._validate_stmts(node.orelse)
+
+    def visit_TryFinally(self, node):
+        self._validate_body(node.body, "TryFinally")
+        self._validate_body(node.finalbody, "TryFinally")
+
+    def visit_ExceptHandler(self, node):
+        if node.type:
+            self._validate_expr(node.type)
+        self._validate_body(node.body, "ExceptHandler")
+
+    def visit_Assert(self, node):
+        self._validate_expr(node.test)
+        if node.msg:
+            self._validate_expr(node.msg)
+
+    def visit_Import(self, node):
+        self._validate_nonempty_seq(node.names, "names", "Import")
+        
+    def visit_ImportFrom(self, node):
+        if node.level < -1:
+            raise ValidationError("ImportFrom level less than -1")
+        self._validate_nonempty_seq(node.names, "names", "ImportFrom")
+        
+    def visit_Global(self, node):
+        self._validate_nonempty_seq_s(node.names, "names", "Global")
+
+    def visit_Nonlocal(self, node):
+        self._validate_nonempty_seq_s(node.names, "names", "Nonlocal")
+
+    def visit_Expr(self, node):
+        self._validate_expr(node.value)
+
+    def visit_Pass(self, node):
+        pass
+
+    def visit_Break(self, node):
+        pass
+
+    def visit_Continue(self, node):
+        pass
+
+    # Expressions
+
+    def visit_Name(self, node):
+        pass
+
+    def visit_Ellipsis(self, node):
+        pass
+
+    def visit_BoolOp(self, node):
+        if len(node.values) < 2:
+            raise ValidationError("BoolOp with less than 2 values")
+        self._validate_exprs(node.values)
+
+    def visit_UnaryOp(self, node):
+        self._validate_expr(node.operand)
+
+    def visit_BinOp(self, node):
+        self._validate_expr(node.left)
+        self._validate_expr(node.right)
+
+    def visit_Lambda(self, node):
+        node.args.walkabout(self)
+        self._validate_expr(node.body)
+
+    def visit_IfExp(self, node):
+        self._validate_expr(node.test)
+        self._validate_expr(node.body)
+        self._validate_expr(node.orelse)
+
+    def visit_Dict(self, node):
+        if self._len(node.keys) != self._len(node.values):
+            raise ValidationError(
+                "Dict doesn't have the same number of keys as values")
+        self._validate_exprs(node.keys)
+        self._validate_exprs(node.values)
+
+    def visit_Set(self, node):
+        self._validate_exprs(node.elts)
+
+    def _validate_comprehension(self, generators):
+        if not generators:
+            raise ValidationError("comprehension with no generators")
+        for comp in generators:
+            self._validate_expr(comp.target, ast.Store)
+            self._validate_expr(comp.iter)
+            self._validate_exprs(comp.ifs)
+
+    def visit_ListComp(self, node):
+        self._validate_comprehension(node.generators)
+        self._validate_expr(node.elt)
+
+    def visit_SetComp(self, node):
+        self._validate_comprehension(node.generators)
+        self._validate_expr(node.elt)
+
+    def visit_GeneratorExp(self, node):
+        self._validate_comprehension(node.generators)
+        self._validate_expr(node.elt)
+
+    def visit_DictComp(self, node):
+        self._validate_comprehension(node.generators)
+        self._validate_expr(node.key)
+        self._validate_expr(node.value)
+
+    def visit_Yield(self, node):
+        if node.value:
+            self._validate_expr(node.value)
+
+    def visit_YieldFrom(self, node):
+        self._validate_expr(node.value)
+
+    def visit_Compare(self, node):
+        if not node.comparators:
+            raise ValidationError("Compare with no comparators")
+        if len(node.comparators) != len(node.ops):
+            raise ValidationError("Compare has a different number "
+                                  "of comparators and operands")
+        self._validate_exprs(node.comparators)
+        self._validate_expr(node.left)
+
+    def visit_Call(self, node):
+        self._validate_expr(node.func)
+        self._validate_exprs(node.args)
+        self.visit_sequence(node.keywords)
+        if node.starargs:
+            self._validate_expr(node.starargs)
+        if node.kwargs:
+            self._validate_expr(node.kwargs)
+
+    def visit_Num(self, node):
+        space = self.space
+        w_type = space.type(node.n)
+        if w_type not in [space.w_int, space.w_float, space.w_complex]:
+            raise oefmt(space.w_TypeError, "non-numeric type in Num")
+
+    def visit_Str(self, node):
+        space = self.space
+        w_type = space.type(node.s)
+        if w_type != space.w_unicode:
+            raise oefmt(space.w_TypeError, "non-string type in Str")
+
+    def visit_Bytes(self, node):
+        space = self.space
+        w_type = space.type(node.s)
+        if w_type != space.w_bytes:
+            raise oefmt(space.w_TypeError, "non-bytes type in Bytes")
+
+    def visit_Attribute(self, node):
+        self._validate_expr(node.value)
+
+    def visit_Subscript(self, node):
+        node.slice.walkabout(self)
+        self._validate_expr(node.value)
+
+    # Subscripts
+    def visit_Slice(self, node):
+        if node.lower:
+            self._validate_expr(node.lower)
+        if node.upper:
+            self._validate_expr(node.upper)
+        if node.step:
+            self._validate_expr(node.step)
+
+    def visit_ExtSlice(self, node):
+        self._validate_nonempty_seq(node.dims, "dims", "ExtSlice")
+        for dim in node.dims:
+            dim.walkabout(self)
+
+    def visit_Index(self, node):
+        self._validate_expr(node.value)
diff --git a/pypy/interpreter/pycompiler.py b/pypy/interpreter/pycompiler.py
--- a/pypy/interpreter/pycompiler.py
+++ b/pypy/interpreter/pycompiler.py
@@ -6,7 +6,7 @@
 from pypy.interpreter import pycode
 from pypy.interpreter.pyparser import future, pyparse, error as parseerror
 from pypy.interpreter.astcompiler import (astbuilder, codegen, consts, misc,
-                                          optimize, ast)
+                                          optimize, ast, validate)
 from pypy.interpreter.error import OperationError
 
 
@@ -136,6 +136,9 @@
                                  e.wrap_info(space))
         return code
 
+    def validate_ast(self, node):
+        validate.validate_ast(self.space, node)
+
     def compile_to_ast(self, source, filename, mode, flags):
         info = pyparse.CompileInfo(filename, mode, flags)
         return self._compile_to_ast(source, info)
diff --git a/pypy/module/__builtin__/compiling.py b/pypy/module/__builtin__/compiling.py
--- a/pypy/module/__builtin__/compiling.py
+++ b/pypy/module/__builtin__/compiling.py
@@ -56,6 +56,7 @@
     # XXX: optimize flag is not used
 
     if ast_node is not None:
+        ec.compiler.validate_ast(ast_node)
         code = ec.compiler.compile_ast(ast_node, filename, mode, flags)
     elif flags & consts.PyCF_ONLY_AST:
         ast_node = ec.compiler.compile_to_ast(source, filename, mode, flags)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to