Author: Amaury Forgeot d'Arc <[email protected]>
Branch: py3.3
Changeset: r70807:fe0435cfe837
Date: 2014-04-21 15:10 +0200
http://bitbucket.org/pypy/pypy/changeset/fe0435cfe837/
Log: Add an AST validator; it will prevent crashes when bad AST objects are
built and compiled.
diff --git a/pypy/interpreter/astcompiler/test/test_validate.py
b/pypy/interpreter/astcompiler/test/test_validate.py
new file mode 100644
--- /dev/null
+++ b/pypy/interpreter/astcompiler/test/test_validate.py
@@ -0,0 +1,425 @@
+import os
+from pypy.interpreter.error import OperationError
+from pypy.interpreter.baseobjspace import W_Root
+from pypy.interpreter.astcompiler import ast
+from pypy.interpreter.astcompiler import validate
+
+class TestASTValidator:
+ def mod(self, mod, msg=None, mode="exec", exc=validate.ValidationError):
+ # Assert that validating the AST `mod` raises `exc` and, when `msg` is
+ # given, that `msg` occurs in the string form of the raised exception.
+ # `exc` is either a Python exception class or a wrapped (W_Root)
+ # exception object; in the latter case an OperationError that matches
+ # it is expected instead.
+ # NOTE(review): `mode` is accepted but never used in this helper.
+ space = self.space
+ if isinstance(exc, W_Root):
+ w_exc = exc
+ exc = OperationError
+ else:
+ w_exc = None
+ with raises(exc) as cm:
+ validate.validate_ast(space, mod)
+ if w_exc is not None:
+ # NOTE(review): w_value is assigned but not read afterwards.
+ w_value = cm.value.get_w_value(space)
+ assert cm.value.match(space, w_exc)
+ exc_msg = str(cm.value)
+ else:
+ exc_msg = str(cm.value)
+ if msg is not None:
+ assert msg in exc_msg
+
+ def expr(self, node, msg=None, exc=validate.ValidationError):
+ mod = ast.Module([ast.Expr(node, 0, 0)])
+ self.mod(mod, msg, exc=exc)
+
+ def stmt(self, stmt, msg=None):
+ mod = ast.Module([stmt])
+ self.mod(mod, msg)
+
+ def test_module(self):
+ m = ast.Interactive([ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)])
+ self.mod(m, "must have Load context", "single")
+ m = ast.Expression(ast.Name("x", ast.Store, 0, 0))
+ self.mod(m, "must have Load context", "eval")
+
+ def _check_arguments(self, fac, check):
+ def arguments(args=None, vararg=None, varargannotation=None,
+ kwonlyargs=None, kwarg=None, kwargannotation=None,
+ defaults=None, kw_defaults=None):
+ if args is None:
+ args = []
+ if kwonlyargs is None:
+ kwonlyargs = []
+ if defaults is None:
+ defaults = []
+ if kw_defaults is None:
+ kw_defaults = []
+ args = ast.arguments(args, vararg, varargannotation, kwonlyargs,
+ kwarg, kwargannotation, defaults, kw_defaults)
+ return fac(args)
+ args = [ast.arg("x", ast.Name("x", ast.Store, 0, 0))]
+ check(arguments(args=args), "must have Load context")
+ check(arguments(varargannotation=ast.Num(self.space.wrap(3), 0, 0)),
+ "varargannotation but no vararg")
+ check(arguments(varargannotation=ast.Name("x", ast.Store, 0, 0),
vararg="x"),
+ "must have Load context")
+ check(arguments(kwonlyargs=args), "must have Load context")
+ check(arguments(kwargannotation=ast.Num(self.space.wrap(42), 0, 0)),
+ "kwargannotation but no kwarg")
+ check(arguments(kwargannotation=ast.Name("x", ast.Store, 0, 0),
+ kwarg="x"), "must have Load context")
+ check(arguments(defaults=[ast.Num(self.space.wrap(3), 0, 0)]),
+ "more positional defaults than args")
+ check(arguments(kw_defaults=[ast.Num(self.space.wrap(4), 0, 0)]),
+ "length of kwonlyargs is not the same as kw_defaults")
+ args = [ast.arg("x", ast.Name("x", ast.Load, 0, 0))]
+ check(arguments(args=args, defaults=[ast.Name("x", ast.Store, 0, 0)]),
+ "must have Load context")
+ args = [ast.arg("a", ast.Name("x", ast.Load, 0, 0)),
+ ast.arg("b", ast.Name("y", ast.Load, 0, 0))]
+ check(arguments(kwonlyargs=args,
+ kw_defaults=[None, ast.Name("x", ast.Store, 0, 0)]),
+ "must have Load context")
+
+ def test_funcdef(self):
+ a = ast.arguments([], None, None, [], None, None, [], [])
+ f = ast.FunctionDef("x", a, [], [], None, 0, 0)
+ self.stmt(f, "empty body on FunctionDef")
+ f = ast.FunctionDef("x", a, [ast.Pass(0, 0)], [ast.Name("x",
ast.Store, 0, 0)],
+ None, 0, 0)
+ self.stmt(f, "must have Load context")
+ f = ast.FunctionDef("x", a, [ast.Pass(0, 0)], [],
+ ast.Name("x", ast.Store, 0, 0), 0, 0)
+ self.stmt(f, "must have Load context")
+ def fac(args):
+ return ast.FunctionDef("x", args, [ast.Pass(0, 0)], [], None, 0, 0)
+ self._check_arguments(fac, self.stmt)
+
+ def test_classdef(self):
+ def cls(bases=None, keywords=None, starargs=None, kwargs=None,
+ body=None, decorator_list=None):
+ if bases is None:
+ bases = []
+ if keywords is None:
+ keywords = []
+ if body is None:
+ body = [ast.Pass(0, 0)]
+ if decorator_list is None:
+ decorator_list = []
+ return ast.ClassDef("myclass", bases, keywords, starargs,
+ kwargs, body, decorator_list, 0, 0)
+ self.stmt(cls(bases=[ast.Name("x", ast.Store, 0, 0)]),
+ "must have Load context")
+ self.stmt(cls(keywords=[ast.keyword("x", ast.Name("x", ast.Store, 0,
0))]),
+ "must have Load context")
+ self.stmt(cls(starargs=ast.Name("x", ast.Store, 0, 0)),
+ "must have Load context")
+ self.stmt(cls(kwargs=ast.Name("x", ast.Store, 0, 0)),
+ "must have Load context")
+ self.stmt(cls(body=[]), "empty body on ClassDef")
+ self.stmt(cls(body=[None]), "None disallowed")
+ self.stmt(cls(decorator_list=[ast.Name("x", ast.Store, 0, 0)]),
+ "must have Load context")
+
+ def test_delete(self):
+ self.stmt(ast.Delete([], 0, 0), "empty targets on Delete")
+ self.stmt(ast.Delete([None], 0, 0), "None disallowed")
+ self.stmt(ast.Delete([ast.Name("x", ast.Load, 0, 0)], 0, 0),
+ "must have Del context")
+
+ def test_assign(self):
+ self.stmt(ast.Assign([], ast.Num(self.space.wrap(3), 0, 0), 0, 0),
"empty targets on Assign")
+ self.stmt(ast.Assign([None], ast.Num(self.space.wrap(3), 0, 0), 0, 0),
"None disallowed")
+ self.stmt(ast.Assign([ast.Name("x", ast.Load, 0, 0)],
ast.Num(self.space.wrap(3), 0, 0), 0, 0),
+ "must have Store context")
+ self.stmt(ast.Assign([ast.Name("x", ast.Store, 0, 0)],
+ ast.Name("y", ast.Store, 0, 0), 0, 0),
+ "must have Load context")
+
+ def test_augassign(self):
+ aug = ast.AugAssign(ast.Name("x", ast.Load, 0, 0), ast.Add,
+ ast.Name("y", ast.Load, 0, 0), 0, 0)
+ self.stmt(aug, "must have Store context")
+ aug = ast.AugAssign(ast.Name("x", ast.Store, 0, 0), ast.Add,
+ ast.Name("y", ast.Store, 0, 0), 0, 0)
+ self.stmt(aug, "must have Load context")
+
+ def test_for(self):
+ x = ast.Name("x", ast.Store, 0, 0)
+ y = ast.Name("y", ast.Load, 0, 0)
+ p = ast.Pass(0, 0)
+ self.stmt(ast.For(x, y, [], [], 0, 0), "empty body on For")
+ self.stmt(ast.For(ast.Name("x", ast.Load, 0, 0), y, [p], [], 0, 0),
+ "must have Store context")
+ self.stmt(ast.For(x, ast.Name("y", ast.Store, 0, 0), [p], [], 0, 0),
+ "must have Load context")
+ e = ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)
+ self.stmt(ast.For(x, y, [e], [], 0, 0), "must have Load context")
+ self.stmt(ast.For(x, y, [p], [e], 0, 0), "must have Load context")
+
+ def test_while(self):
+ self.stmt(ast.While(ast.Num(self.space.wrap(3), 0, 0), [], [], 0, 0),
"empty body on While")
+ self.stmt(ast.While(ast.Name("x", ast.Store, 0, 0), [ast.Pass(0, 0)],
[], 0, 0),
+ "must have Load context")
+ self.stmt(ast.While(ast.Num(self.space.wrap(3), 0, 0), [ast.Pass(0,
0)],
+ [ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)],
0, 0),
+ "must have Load context")
+
+ def test_if(self):
+ self.stmt(ast.If(ast.Num(self.space.wrap(3), 0, 0), [], [], 0, 0),
"empty body on If")
+ i = ast.If(ast.Name("x", ast.Store, 0, 0), [ast.Pass(0, 0)], [], 0, 0)
+ self.stmt(i, "must have Load context")
+ i = ast.If(ast.Num(self.space.wrap(3), 0, 0), [ast.Expr(ast.Name("x",
ast.Store, 0, 0), 0, 0)], [], 0, 0)
+ self.stmt(i, "must have Load context")
+ i = ast.If(ast.Num(self.space.wrap(3), 0, 0), [ast.Pass(0, 0)],
+ [ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)], 0, 0)
+ self.stmt(i, "must have Load context")
+
+ @skip("enable when parser uses the new With construct")
+ def test_with(self):
+ p = ast.Pass(0, 0)
+ self.stmt(ast.With([], [p]), "empty items on With")
+ i = ast.withitem(ast.Num(self.space.wrap(3), 0, 0), None)
+ self.stmt(ast.With([i], []), "empty body on With")
+ i = ast.withitem(ast.Name("x", ast.Store, 0, 0), None)
+ self.stmt(ast.With([i], [p]), "must have Load context")
+ i = ast.withitem(ast.Num(self.space.wrap(3), 0, 0), ast.Name("x",
ast.Load, 0, 0))
+ self.stmt(ast.With([i], [p]), "must have Store context")
+
+ def test_raise(self):
+ r = ast.Raise(None, ast.Num(self.space.wrap(3), 0, 0), 0, 0)
+ self.stmt(r, "Raise with cause but no exception")
+ r = ast.Raise(ast.Name("x", ast.Store, 0, 0), None, 0, 0)
+ self.stmt(r, "must have Load context")
+ r = ast.Raise(ast.Num(self.space.wrap(4), 0, 0), ast.Name("x",
ast.Store, 0, 0), 0, 0)
+ self.stmt(r, "must have Load context")
+
+ @skip("enable when parser uses the new Try construct")
+ def test_try(self):
+ p = ast.Pass(0, 0)
+ t = ast.Try([], [], [], [p])
+ self.stmt(t, "empty body on Try")
+ t = ast.Try([ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)], [], [],
[p])
+ self.stmt(t, "must have Load context")
+ t = ast.Try([p], [], [], [])
+ self.stmt(t, "Try has neither except handlers nor finalbody")
+ t = ast.Try([p], [], [p], [p])
+ self.stmt(t, "Try has orelse but no except handlers")
+ t = ast.Try([p], [ast.ExceptHandler(None, "x", [])], [], [])
+ self.stmt(t, "empty body on ExceptHandler")
+ e = [ast.ExceptHandler(ast.Name("x", ast.Store, 0, 0), "y", [p])]
+ self.stmt(ast.Try([p], e, [], []), "must have Load context")
+ e = [ast.ExceptHandler(None, "x", [p])]
+ t = ast.Try([p], e, [ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)],
[p])
+ self.stmt(t, "must have Load context")
+ t = ast.Try([p], e, [p], [ast.Expr(ast.Name("x", ast.Store, 0, 0), 0,
0)])
+ self.stmt(t, "must have Load context")
+
+ def test_assert(self):
+ self.stmt(ast.Assert(ast.Name("x", ast.Store, 0, 0), None, 0, 0),
+ "must have Load context")
+ assrt = ast.Assert(ast.Name("x", ast.Load, 0, 0),
+ ast.Name("y", ast.Store, 0, 0), 0, 0)
+ self.stmt(assrt, "must have Load context")
+
+ def test_import(self):
+ self.stmt(ast.Import([], 0, 0), "empty names on Import")
+
+    def test_importfrom(self):
+        # An import level below -1 is rejected.
+        imp = ast.ImportFrom(None, [ast.alias("x", None)], -42, 0, 0)
+        self.stmt(imp, "level less than -1")
+        # An empty list of imported names is rejected.  The expected message
+        # is kept on one line so the literal is not split.
+        self.stmt(ast.ImportFrom(None, [], 0, 0, 0),
+                  "empty names on ImportFrom")
+
+ def test_global(self):
+ self.stmt(ast.Global([], 0, 0), "empty names on Global")
+
+ def test_nonlocal(self):
+ self.stmt(ast.Nonlocal([], 0, 0), "empty names on Nonlocal")
+
+ def test_expr(self):
+ e = ast.Expr(ast.Name("x", ast.Store, 0, 0), 0, 0)
+ self.stmt(e, "must have Load context")
+
+ def test_boolop(self):
+ b = ast.BoolOp(ast.And, [], 0, 0)
+ self.expr(b, "less than 2 values")
+ b = ast.BoolOp(ast.And, [ast.Num(self.space.wrap(3), 0, 0)], 0, 0)
+ self.expr(b, "less than 2 values")
+ b = ast.BoolOp(ast.And, [ast.Num(self.space.wrap(4), 0, 0), None], 0,
0)
+ self.expr(b, "None disallowed")
+ b = ast.BoolOp(ast.And, [ast.Num(self.space.wrap(4), 0, 0),
ast.Name("x", ast.Store, 0, 0)], 0, 0)
+ self.expr(b, "must have Load context")
+
+ def test_unaryop(self):
+ u = ast.UnaryOp(ast.Not, ast.Name("x", ast.Store, 0, 0), 0, 0)
+ self.expr(u, "must have Load context")
+
+ def test_lambda(self):
+ a = ast.arguments([], None, None, [], None, None, [], [])
+ self.expr(ast.Lambda(a, ast.Name("x", ast.Store, 0, 0), 0, 0),
+ "must have Load context")
+ def fac(args):
+ return ast.Lambda(args, ast.Name("x", ast.Load, 0, 0), 0, 0)
+ self._check_arguments(fac, self.expr)
+
+ def test_ifexp(self):
+ l = ast.Name("x", ast.Load, 0, 0)
+ s = ast.Name("y", ast.Store, 0, 0)
+ for args in (s, l, l), (l, s, l), (l, l, s):
+ self.expr(ast.IfExp(*(args + (0, 0))), "must have Load context")
+
+ def test_dict(self):
+ d = ast.Dict([], [ast.Name("x", ast.Load, 0, 0)], 0, 0)
+ self.expr(d, "same number of keys as values")
+ d = ast.Dict([None], [ast.Name("x", ast.Load, 0, 0)], 0, 0)
+ self.expr(d, "None disallowed")
+ d = ast.Dict([ast.Name("x", ast.Load, 0, 0)], [None], 0, 0)
+ self.expr(d, "None disallowed")
+
+ def test_set(self):
+ self.expr(ast.Set([None], 0, 0), "None disallowed")
+ s = ast.Set([ast.Name("x", ast.Store, 0, 0)], 0, 0)
+ self.expr(s, "must have Load context")
+
+ def _check_comprehension(self, fac):
+ self.expr(fac([]), "comprehension with no generators")
+ g = ast.comprehension(ast.Name("x", ast.Load, 0, 0),
+ ast.Name("x", ast.Load, 0, 0), [])
+ self.expr(fac([g]), "must have Store context")
+ g = ast.comprehension(ast.Name("x", ast.Store, 0, 0),
+ ast.Name("x", ast.Store, 0, 0), [])
+ self.expr(fac([g]), "must have Load context")
+ x = ast.Name("x", ast.Store, 0, 0)
+ y = ast.Name("y", ast.Load, 0, 0)
+ g = ast.comprehension(x, y, [None])
+ self.expr(fac([g]), "None disallowed")
+ g = ast.comprehension(x, y, [ast.Name("x", ast.Store, 0, 0)])
+ self.expr(fac([g]), "must have Load context")
+
+ def _simple_comp(self, fac):
+ g = ast.comprehension(ast.Name("x", ast.Store, 0, 0),
+ ast.Name("x", ast.Load, 0, 0), [])
+ self.expr(fac(ast.Name("x", ast.Store, 0, 0), [g], 0, 0),
+ "must have Load context")
+ def wrap(gens):
+ return fac(ast.Name("x", ast.Store, 0, 0), gens, 0, 0)
+ self._check_comprehension(wrap)
+
+ def test_listcomp(self):
+ self._simple_comp(ast.ListComp)
+
+ def test_setcomp(self):
+ self._simple_comp(ast.SetComp)
+
+ def test_generatorexp(self):
+ self._simple_comp(ast.GeneratorExp)
+
+ def test_dictcomp(self):
+ g = ast.comprehension(ast.Name("y", ast.Store, 0, 0),
+ ast.Name("p", ast.Load, 0, 0), [])
+ c = ast.DictComp(ast.Name("x", ast.Store, 0, 0),
+ ast.Name("y", ast.Load, 0, 0), [g], 0, 0)
+ self.expr(c, "must have Load context")
+ c = ast.DictComp(ast.Name("x", ast.Load, 0, 0),
+ ast.Name("y", ast.Store, 0, 0), [g], 0, 0)
+ self.expr(c, "must have Load context")
+ def factory(comps):
+ k = ast.Name("x", ast.Load, 0, 0)
+ v = ast.Name("y", ast.Load, 0, 0)
+ return ast.DictComp(k, v, comps, 0, 0)
+ self._check_comprehension(factory)
+
+    def test_yield(self):
+        # The operand of both yield forms must be in Load context.
+        self.expr(ast.Yield(ast.Name("x", ast.Store, 0, 0), 0, 0),
+                  "must have Load")
+        self.expr(ast.YieldFrom(ast.Name("x", ast.Store, 0, 0), 0, 0),
+                  "must have Load")
+
+ def test_compare(self):
+ left = ast.Name("x", ast.Load, 0, 0)
+ comp = ast.Compare(left, [ast.In], [], 0, 0)
+ self.expr(comp, "no comparators")
+ comp = ast.Compare(left, [ast.In], [ast.Num(self.space.wrap(4), 0, 0),
ast.Num(self.space.wrap(5), 0, 0)], 0, 0)
+ self.expr(comp, "different number of comparators and operands")
+ comp = ast.Compare(ast.Num(self.space.wrap("blah"), 0, 0), [ast.In],
[left], 0, 0)
+ self.expr(comp, "non-numeric", exc=self.space.w_TypeError)
+ comp = ast.Compare(left, [ast.In], [ast.Num(self.space.wrap("blah"),
0, 0)], 0, 0)
+ self.expr(comp, "non-numeric", exc=self.space.w_TypeError)
+
+ def test_call(self):
+ func = ast.Name("x", ast.Load, 0, 0)
+ args = [ast.Name("y", ast.Load, 0, 0)]
+ keywords = [ast.keyword("w", ast.Name("z", ast.Load, 0, 0))]
+ stararg = ast.Name("p", ast.Load, 0, 0)
+ kwarg = ast.Name("q", ast.Load, 0, 0)
+ call = ast.Call(ast.Name("x", ast.Store, 0, 0), args, keywords,
stararg,
+ kwarg, 0, 0)
+ self.expr(call, "must have Load context")
+ call = ast.Call(func, [None], keywords, stararg, kwarg, 0, 0)
+ self.expr(call, "None disallowed")
+ bad_keywords = [ast.keyword("w", ast.Name("z", ast.Store, 0, 0))]
+ call = ast.Call(func, args, bad_keywords, stararg, kwarg, 0, 0)
+ self.expr(call, "must have Load context")
+ call = ast.Call(func, args, keywords, ast.Name("z", ast.Store, 0, 0),
kwarg, 0, 0)
+ self.expr(call, "must have Load context")
+ call = ast.Call(func, args, keywords, stararg,
+ ast.Name("w", ast.Store, 0, 0), 0, 0)
+ self.expr(call, "must have Load context")
+
+ def test_num(self):
+ space = self.space
+ w_objs = space.appexec([], """():
+ class subint(int):
+ pass
+ class subfloat(float):
+ pass
+ class subcomplex(complex):
+ pass
+ return ("0", "hello", subint(), subfloat(), subcomplex())
+ """)
+ for w_obj in space.unpackiterable(w_objs):
+ self.expr(ast.Num(w_obj, 0, 0), "non-numeric",
exc=self.space.w_TypeError)
+
+ def test_attribute(self):
+ attr = ast.Attribute(ast.Name("x", ast.Store, 0, 0), "y", ast.Load, 0,
0)
+ self.expr(attr, "must have Load context")
+
+ def test_subscript(self):
+ sub = ast.Subscript(ast.Name("x", ast.Store, 0, 0),
ast.Index(ast.Num(self.space.wrap(3), 0, 0)),
+ ast.Load, 0, 0)
+ self.expr(sub, "must have Load context")
+ x = ast.Name("x", ast.Load, 0, 0)
+ sub = ast.Subscript(x, ast.Index(ast.Name("y", ast.Store, 0, 0)),
+ ast.Load, 0, 0)
+ self.expr(sub, "must have Load context")
+ s = ast.Name("x", ast.Store, 0, 0)
+ for args in (s, None, None), (None, s, None), (None, None, s):
+ sl = ast.Slice(*args)
+ self.expr(ast.Subscript(x, sl, ast.Load, 0, 0),
+ "must have Load context")
+ sl = ast.ExtSlice([])
+ self.expr(ast.Subscript(x, sl, ast.Load, 0, 0), "empty dims on
ExtSlice")
+ sl = ast.ExtSlice([ast.Index(s)])
+ self.expr(ast.Subscript(x, sl, ast.Load, 0, 0), "must have Load
context")
+
+ def test_starred(self):
+ left = ast.List([ast.Starred(ast.Name("x", ast.Load, 0, 0), ast.Store,
0, 0)],
+ ast.Store, 0, 0)
+ assign = ast.Assign([left], ast.Num(self.space.wrap(4), 0, 0), 0, 0)
+ self.stmt(assign, "must have Store context")
+
+ def _sequence(self, fac):
+ self.expr(fac([None], ast.Load, 0, 0), "None disallowed")
+ self.expr(fac([ast.Name("x", ast.Store, 0, 0)], ast.Load, 0, 0),
+ "must have Load context")
+
+ def test_list(self):
+ self._sequence(ast.List)
+
+ def test_tuple(self):
+ self._sequence(ast.Tuple)
+
+ def test_stdlib_validates(self):
+ stdlib = os.path.join(os.path.dirname(ast.__file__),
'../../../lib-python/3')
+ tests = ["os.py", "test/test_grammar.py", "test/test_unpack_ex.py"]
+ for module in tests:
+ fn = os.path.join(stdlib, module)
+ print 'compiling', fn
+ with open(fn, "r") as fp:
+ source = fp.read()
+ ec = self.space.getexecutioncontext()
+ ast_node = ec.compiler.compile_to_ast(source, fn, "exec", 0)
+ ec.compiler.validate_ast(ast_node)
diff --git a/pypy/interpreter/astcompiler/validate.py
b/pypy/interpreter/astcompiler/validate.py
new file mode 100644
--- /dev/null
+++ b/pypy/interpreter/astcompiler/validate.py
@@ -0,0 +1,410 @@
+"""A visitor to validate an AST object."""
+
+from pypy.interpreter.error import OperationError, oefmt
+from pypy.interpreter.astcompiler import ast
+from rpython.tool.pairtype import pair, pairtype
+from pypy.interpreter.baseobjspace import W_Root
+
+
+def validate_ast(space, node):
+ # Entry point: walk `node` with a fresh AstValidator bound to `space`.
+ # Problems are reported by the visitor raising ValidationError, or an
+ # OperationError built with oefmt (e.g. the TypeError raised for Num).
+ node.walkabout(AstValidator(space))
+
+
+class ValidationError(Exception):
+ pass
+
+
+def expr_context_name(ctx):
+ # Return a printable name for the expression-context constant `ctx`,
+ # for use in error messages.  Valid constants are 1-based indexes into
+ # ast.expr_context_to_class; anything outside that range yields '??'.
+ if not 1 <= ctx <= len(ast.expr_context_to_class):
+ return '??'
+ return ast.expr_context_to_class[ctx - 1].typedef.name
+
+def _check_context(expected_ctx, actual_ctx):
+ # Raise ValidationError naming both contexts when the context stored on
+ # an expression (`actual_ctx`) differs from the one its position in the
+ # tree requires (`expected_ctx`).  Returns None when they agree.
+ if expected_ctx != actual_ctx:
+ raise ValidationError(
+ "expression must have %s context but has %s instead" %
+ (expr_context_name(expected_ctx), expr_context_name(actual_ctx)))
+
+
+class __extend__(ast.AST):
+
+ def check_context(self, visitor, ctx):
+ raise AssertionError("should only be on expressions")
+
+ def walkabout_with_ctx(self, visitor, ctx):
+ self.walkabout(visitor) # With "load" context.
+
+
+class __extend__(ast.expr):
+
+ def check_context(self, visitor, ctx):
+ if ctx != ast.Load:
+ raise ValidationError(
+ "expression which can't be assigned to in %s context" %
+ expr_context_name(ctx))
+
+
+class __extend__(ast.Name):
+
+ def check_context(self, visitor, ctx):
+ _check_context(ctx, self.ctx)
+
+
+class __extend__(ast.List):
+
+ def check_context(self, visitor, ctx):
+ _check_context(ctx, self.ctx)
+
+ def walkabout_with_ctx(self, visitor, ctx):
+ visitor._validate_exprs(self.elts, ctx)
+
+
+class __extend__(ast.Tuple):
+
+ def check_context(self, visitor, ctx):
+ _check_context(ctx, self.ctx)
+
+ def walkabout_with_ctx(self, visitor, ctx):
+ visitor._validate_exprs(self.elts, ctx)
+
+
+class __extend__(ast.Starred):
+
+ def check_context(self, visitor, ctx):
+ _check_context(ctx, self.ctx)
+
+ def walkabout_with_ctx(self, visitor, ctx):
+ visitor._validate_expr(self.value, ctx)
+
+
+class __extend__(ast.Subscript):
+
+ def check_context(self, visitor, ctx):
+ _check_context(ctx, self.ctx)
+
+
+class __extend__(ast.Attribute):
+
+ def check_context(self, visitor, ctx):
+ _check_context(ctx, self.ctx)
+
+
+class AstValidator(ast.ASTVisitor):
+ def __init__(self, space):
+ self.space = space
+
+ def _validate_stmts(self, stmts):
+ if not stmts:
+ return
+ for stmt in stmts:
+ if not stmt:
+ raise ValidationError("None disallowed in statement list")
+ stmt.walkabout(self)
+
+ def _len(self, node):
+ # Length of a sequence attribute, counting a missing (None) sequence
+ # as empty so callers can compare sizes without a separate None check.
+ if node is None:
+ return 0
+ return len(node)
+
+ def _validate_expr(self, expr, ctx=ast.Load):
+ expr.check_context(self, ctx)
+ expr.walkabout_with_ctx(self, ctx)
+
+ def _validate_exprs(self, exprs, ctx=ast.Load, null_ok=False):
+ # Validate every expression in `exprs` against context `ctx`.
+ # A None or empty sequence is accepted silently.  A falsy entry inside
+ # the sequence raises ValidationError unless `null_ok` is true, in
+ # which case the entry is skipped (used for kw_defaults).
+ if not exprs:
+ return
+ for expr in exprs:
+ if expr:
+ self._validate_expr(expr, ctx)
+ elif not null_ok:
+ raise ValidationError("None disallowed in expression list")
+
+ def _validate_body(self, body, owner):
+ self._validate_nonempty_seq(body, "body", owner)
+ self._validate_stmts(body)
+
+ def _validate_nonempty_seq(self, seq, what, owner):
+ if not seq:
+ raise ValidationError("empty %s on %s" % (what, owner))
+
+ def _validate_nonempty_seq_s(self, seq, what, owner):
+ # Same check and message as _validate_nonempty_seq; in this file it is
+ # called only for the `names` of Global and Nonlocal.
+ # NOTE(review): presumably kept as a separate method because those
+ # names are plain strings rather than AST nodes (RPython list typing)
+ # -- confirm before merging the two helpers.
+ if not seq:
+ raise ValidationError("empty %s on %s" % (what, owner))
+
+ def visit_Interactive(self, node):
+ self._validate_stmts(node.body)
+
+ def visit_Module(self, node):
+ self._validate_stmts(node.body)
+
+ def visit_Expression(self, node):
+ self._validate_expr(node.body)
+
+ # Statements
+
+ def visit_arg(self, node):
+ if node.annotation:
+ self._validate_expr(node.annotation)
+
+    def visit_arguments(self, node):
+        """Validate a function's formal-parameter description.
+
+        Visits the positional args, checks the vararg annotation, visits
+        the keyword-only args, checks the kwarg annotation, compares the
+        sizes of the default lists with their argument lists, and finally
+        validates the default expressions.  Raises ValidationError on the
+        first inconsistency found.
+        """
+        self.visit_sequence(node.args)
+        if node.varargannotation:
+            if not node.vararg:
+                raise ValidationError(
+                    "varargannotation but no vararg on arguments")
+            self._validate_expr(node.varargannotation)
+        self.visit_sequence(node.kwonlyargs)
+        if node.kwargannotation:
+            if not node.kwarg:
+                raise ValidationError(
+                    "kwargannotation but no kwarg on arguments")
+            self._validate_expr(node.kwargannotation)
+        # self._len() counts a None sequence as empty.
+        if self._len(node.defaults) > self._len(node.args):
+            raise ValidationError(
+                "more positional defaults than args on arguments")
+        if self._len(node.kw_defaults) != self._len(node.kwonlyargs):
+            raise ValidationError("length of kwonlyargs is not the same as "
+                                  "kw_defaults on arguments")
+        self._validate_exprs(node.defaults)
+        # Entries of kw_defaults may be None, hence null_ok.
+        self._validate_exprs(node.kw_defaults, null_ok=True)
+
+ def visit_FunctionDef(self, node):
+ self._validate_body(node.body, "FunctionDef")
+ node.args.walkabout(self)
+ self._validate_exprs(node.decorator_list)
+ if node.returns:
+ self._validate_expr(node.returns)
+
+ def visit_keyword(self, node):
+ self._validate_expr(node.value)
+
+ def visit_ClassDef(self, node):
+ self._validate_body(node.body, "ClassDef")
+ self._validate_exprs(node.bases)
+ self.visit_sequence(node.keywords)
+ self._validate_exprs(node.decorator_list)
+ if node.starargs:
+ self._validate_expr(node.starargs)
+ if node.kwargs:
+ self._validate_expr(node.kwargs)
+
+ def visit_Return(self, node):
+ if node.value:
+ self._validate_expr(node.value)
+
+ def visit_Delete(self, node):
+ self._validate_nonempty_seq(node.targets, "targets", "Delete")
+ self._validate_exprs(node.targets, ast.Del)
+
+ def visit_Assign(self, node):
+ self._validate_nonempty_seq(node.targets, "targets", "Assign")
+ self._validate_exprs(node.targets, ast.Store)
+ self._validate_expr(node.value)
+
+ def visit_AugAssign(self, node):
+ self._validate_expr(node.target, ast.Store)
+ self._validate_expr(node.value)
+
+ def visit_For(self, node):
+ self._validate_expr(node.target, ast.Store)
+ self._validate_expr(node.iter)
+ self._validate_body(node.body, "For")
+ self._validate_stmts(node.orelse)
+
+ def visit_While(self, node):
+ self._validate_expr(node.test)
+ self._validate_body(node.body, "While")
+ self._validate_stmts(node.orelse)
+
+ def visit_If(self, node):
+ self._validate_expr(node.test)
+ self._validate_body(node.body, "If")
+ self._validate_stmts(node.orelse)
+
+ def visit_With(self, node):
+ self._validate_expr(node.context_expr)
+ if node.optional_vars:
+ self._validate_expr(node.optional_vars, ast.Store)
+ self._validate_body(node.body, "With")
+
+ def visit_Raise(self, node):
+ if node.exc:
+ self._validate_expr(node.exc)
+ if node.cause:
+ self._validate_expr(node.cause)
+ elif node.cause:
+ raise ValidationError("Raise with cause but no exception")
+
+    def visit_TryExcept(self, node):
+        """Validate the body, each handler and the orelse of a TryExcept."""
+        self._validate_body(node.body, "TryExcept")
+        # Guard the loop the same way the other helpers guard their
+        # sequences, so a missing (None) handler list is not iterated.
+        if node.handlers:
+            for handler in node.handlers:
+                handler.walkabout(self)
+        self._validate_stmts(node.orelse)
+
+ def visit_TryFinally(self, node):
+ self._validate_body(node.body, "TryFinally")
+ self._validate_body(node.finalbody, "TryFinally")
+
+ def visit_ExceptHandler(self, node):
+ if node.type:
+ self._validate_expr(node.type)
+ self._validate_body(node.body, "ExceptHandler")
+
+ def visit_Assert(self, node):
+ self._validate_expr(node.test)
+ if node.msg:
+ self._validate_expr(node.msg)
+
+ def visit_Import(self, node):
+ self._validate_nonempty_seq(node.names, "names", "Import")
+
+ def visit_ImportFrom(self, node):
+ if node.level < -1:
+ raise ValidationError("ImportFrom level less than -1")
+ self._validate_nonempty_seq(node.names, "names", "ImportFrom")
+
+ def visit_Global(self, node):
+ self._validate_nonempty_seq_s(node.names, "names", "Global")
+
+ def visit_Nonlocal(self, node):
+ self._validate_nonempty_seq_s(node.names, "names", "Nonlocal")
+
+ def visit_Expr(self, node):
+ self._validate_expr(node.value)
+
+ def visit_Pass(self, node):
+ pass
+
+ def visit_Break(self, node):
+ pass
+
+ def visit_Continue(self, node):
+ pass
+
+ # Expressions
+
+ def visit_Name(self, node):
+ pass
+
+ def visit_Ellipsis(self, node):
+ pass
+
+    def visit_BoolOp(self, node):
+        """Require at least two operands, each valid in Load context.
+
+        Raises ValidationError when fewer than two values are present.
+        """
+        # self._len() counts a None operand list as empty, so such a node
+        # is reported as invalid instead of failing inside len().
+        if self._len(node.values) < 2:
+            raise ValidationError("BoolOp with less than 2 values")
+        self._validate_exprs(node.values)
+
+ def visit_UnaryOp(self, node):
+ self._validate_expr(node.operand)
+
+ def visit_BinOp(self, node):
+ self._validate_expr(node.left)
+ self._validate_expr(node.right)
+
+ def visit_Lambda(self, node):
+ node.args.walkabout(self)
+ self._validate_expr(node.body)
+
+ def visit_IfExp(self, node):
+ self._validate_expr(node.test)
+ self._validate_expr(node.body)
+ self._validate_expr(node.orelse)
+
+ def visit_Dict(self, node):
+ if self._len(node.keys) != self._len(node.values):
+ raise ValidationError(
+ "Dict doesn't have the same number of keys as values")
+ self._validate_exprs(node.keys)
+ self._validate_exprs(node.values)
+
+ def visit_Set(self, node):
+ self._validate_exprs(node.elts)
+
+ def _validate_comprehension(self, generators):
+ if not generators:
+ raise ValidationError("comprehension with no generators")
+ for comp in generators:
+ self._validate_expr(comp.target, ast.Store)
+ self._validate_expr(comp.iter)
+ self._validate_exprs(comp.ifs)
+
+ def visit_ListComp(self, node):
+ self._validate_comprehension(node.generators)
+ self._validate_expr(node.elt)
+
+ def visit_SetComp(self, node):
+ self._validate_comprehension(node.generators)
+ self._validate_expr(node.elt)
+
+ def visit_GeneratorExp(self, node):
+ self._validate_comprehension(node.generators)
+ self._validate_expr(node.elt)
+
+ def visit_DictComp(self, node):
+ self._validate_comprehension(node.generators)
+ self._validate_expr(node.key)
+ self._validate_expr(node.value)
+
+ def visit_Yield(self, node):
+ if node.value:
+ self._validate_expr(node.value)
+
+ def visit_YieldFrom(self, node):
+ self._validate_expr(node.value)
+
+    def visit_Compare(self, node):
+        """Validate a comparison chain.
+
+        Raises ValidationError when there are no comparators or when the
+        number of comparators differs from the number of operators; then
+        validates the comparators and the left operand in Load context.
+        """
+        if not node.comparators:
+            raise ValidationError("Compare with no comparators")
+        # self._len() counts a None `ops` list as empty, so the mismatch
+        # is reported as a ValidationError instead of failing inside len().
+        if self._len(node.comparators) != self._len(node.ops):
+            raise ValidationError("Compare has a different number "
+                                  "of comparators and operands")
+        self._validate_exprs(node.comparators)
+        self._validate_expr(node.left)
+
+ def visit_Call(self, node):
+ self._validate_expr(node.func)
+ self._validate_exprs(node.args)
+ self.visit_sequence(node.keywords)
+ if node.starargs:
+ self._validate_expr(node.starargs)
+ if node.kwargs:
+ self._validate_expr(node.kwargs)
+
+ def visit_Num(self, node):
+ space = self.space
+ w_type = space.type(node.n)
+ if w_type not in [space.w_int, space.w_float, space.w_complex]:
+ raise oefmt(space.w_TypeError, "non-numeric type in Num")
+
+ def visit_Str(self, node):
+ space = self.space
+ w_type = space.type(node.s)
+ if w_type != space.w_unicode:
+ raise oefmt(space.w_TypeError, "non-string type in Str")
+
+ def visit_Bytes(self, node):
+ space = self.space
+ w_type = space.type(node.s)
+ if w_type != space.w_bytes:
+ raise oefmt(space.w_TypeError, "non-bytes type in Bytes")
+
+ def visit_Attribute(self, node):
+ self._validate_expr(node.value)
+
+ def visit_Subscript(self, node):
+ node.slice.walkabout(self)
+ self._validate_expr(node.value)
+
+ # Subscripts
+ def visit_Slice(self, node):
+ if node.lower:
+ self._validate_expr(node.lower)
+ if node.upper:
+ self._validate_expr(node.upper)
+ if node.step:
+ self._validate_expr(node.step)
+
+ def visit_ExtSlice(self, node):
+ self._validate_nonempty_seq(node.dims, "dims", "ExtSlice")
+ for dim in node.dims:
+ dim.walkabout(self)
+
+ def visit_Index(self, node):
+ self._validate_expr(node.value)
diff --git a/pypy/interpreter/pycompiler.py b/pypy/interpreter/pycompiler.py
--- a/pypy/interpreter/pycompiler.py
+++ b/pypy/interpreter/pycompiler.py
@@ -6,7 +6,7 @@
from pypy.interpreter import pycode
from pypy.interpreter.pyparser import future, pyparse, error as parseerror
from pypy.interpreter.astcompiler import (astbuilder, codegen, consts, misc,
- optimize, ast)
+ optimize, ast, validate)
from pypy.interpreter.error import OperationError
@@ -136,6 +136,9 @@
e.wrap_info(space))
return code
+ def validate_ast(self, node):
+ validate.validate_ast(self.space, node)
+
def compile_to_ast(self, source, filename, mode, flags):
info = pyparse.CompileInfo(filename, mode, flags)
return self._compile_to_ast(source, info)
diff --git a/pypy/module/__builtin__/compiling.py
b/pypy/module/__builtin__/compiling.py
--- a/pypy/module/__builtin__/compiling.py
+++ b/pypy/module/__builtin__/compiling.py
@@ -56,6 +56,7 @@
# XXX: optimize flag is not used
if ast_node is not None:
+ ec.compiler.validate_ast(ast_node)
code = ec.compiler.compile_ast(ast_node, filename, mode, flags)
elif flags & consts.PyCF_ONLY_AST:
ast_node = ec.compiler.compile_to_ast(source, filename, mode, flags)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit