https://github.com/python/cpython/commit/a8cb5e4a43a0f4699590a746ca02cd688480ba15
commit: a8cb5e4a43a0f4699590a746ca02cd688480ba15
branch: main
author: Tomasz Pytel <[email protected]>
committer: JelleZijlstra <[email protected]>
date: 2025-03-19T15:29:40-07:00
summary:
gh-129598: ast: allow multi stmts for ast single with ';' (#129620)
files:
A Misc/NEWS.d/next/Library/2025-02-03-16-27-14.gh-issue-129598.0js33I.rst
M Lib/ast.py
M Lib/test/test_unparse.py
diff --git a/Lib/ast.py b/Lib/ast.py
index 0937c27bdf8a11..cb1f8dfe128ead 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -674,6 +674,7 @@ def __init__(self):
self._type_ignores = {}
self._indent = 0
self._in_try_star = False
+ self._in_interactive = False
def interleave(self, inter, f, seq):
"""Call f on each item in seq, calling inter() in between."""
@@ -702,11 +703,20 @@ def maybe_newline(self):
if self._source:
self.write("\n")
- def fill(self, text=""):
+ def maybe_semicolon(self):
+ """Adds a "; " delimiter if it isn't the start of generated source"""
+ if self._source:
+ self.write("; ")
+
+ def fill(self, text="", *, allow_semicolon=True):
"""Indent a piece of text and append it, according to the current
- indentation level"""
- self.maybe_newline()
- self.write(" " * self._indent + text)
+ indentation level, or only delineate with semicolon if applicable"""
+ if self._in_interactive and not self._indent and allow_semicolon:
+ self.maybe_semicolon()
+ self.write(text)
+ else:
+ self.maybe_newline()
+ self.write(" " * self._indent + text)
def write(self, *text):
"""Add new source parts"""
@@ -812,8 +822,17 @@ def visit_Module(self, node):
ignore.lineno: f"ignore{ignore.tag}"
for ignore in node.type_ignores
}
- self._write_docstring_and_traverse_body(node)
- self._type_ignores.clear()
+ try:
+ self._write_docstring_and_traverse_body(node)
+ finally:
+ self._type_ignores.clear()
+
+ def visit_Interactive(self, node):
+ self._in_interactive = True
+ try:
+ self._write_docstring_and_traverse_body(node)
+ finally:
+ self._in_interactive = False
def visit_FunctionType(self, node):
with self.delimit("(", ")"):
@@ -945,17 +964,17 @@ def visit_Raise(self, node):
self.traverse(node.cause)
def do_visit_try(self, node):
- self.fill("try")
+ self.fill("try", allow_semicolon=False)
with self.block():
self.traverse(node.body)
for ex in node.handlers:
self.traverse(ex)
if node.orelse:
- self.fill("else")
+ self.fill("else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)
if node.finalbody:
- self.fill("finally")
+ self.fill("finally", allow_semicolon=False)
with self.block():
self.traverse(node.finalbody)
@@ -976,7 +995,7 @@ def visit_TryStar(self, node):
self._in_try_star = prev_in_try_star
def visit_ExceptHandler(self, node):
- self.fill("except*" if self._in_try_star else "except")
+ self.fill("except*" if self._in_try_star else "except",
allow_semicolon=False)
if node.type:
self.write(" ")
self.traverse(node.type)
@@ -989,9 +1008,9 @@ def visit_ExceptHandler(self, node):
def visit_ClassDef(self, node):
self.maybe_newline()
for deco in node.decorator_list:
- self.fill("@")
+ self.fill("@", allow_semicolon=False)
self.traverse(deco)
- self.fill("class " + node.name)
+ self.fill("class " + node.name, allow_semicolon=False)
if hasattr(node, "type_params"):
self._type_params_helper(node.type_params)
with self.delimit_if("(", ")", condition = node.bases or
node.keywords):
@@ -1021,10 +1040,10 @@ def visit_AsyncFunctionDef(self, node):
def _function_helper(self, node, fill_suffix):
self.maybe_newline()
for deco in node.decorator_list:
- self.fill("@")
+ self.fill("@", allow_semicolon=False)
self.traverse(deco)
def_str = fill_suffix + " " + node.name
- self.fill(def_str)
+ self.fill(def_str, allow_semicolon=False)
if hasattr(node, "type_params"):
self._type_params_helper(node.type_params)
with self.delimit("(", ")"):
@@ -1075,7 +1094,7 @@ def visit_AsyncFor(self, node):
self._for_helper("async for ", node)
def _for_helper(self, fill, node):
- self.fill(fill)
+ self.fill(fill, allow_semicolon=False)
self.set_precedence(_Precedence.TUPLE, node.target)
self.traverse(node.target)
self.write(" in ")
@@ -1083,46 +1102,46 @@ def _for_helper(self, fill, node):
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
if node.orelse:
- self.fill("else")
+ self.fill("else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)
def visit_If(self, node):
- self.fill("if ")
+ self.fill("if ", allow_semicolon=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
# collapse nested ifs into equivalent elifs.
while node.orelse and len(node.orelse) == 1 and
isinstance(node.orelse[0], If):
node = node.orelse[0]
- self.fill("elif ")
+ self.fill("elif ", allow_semicolon=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
# final else
if node.orelse:
- self.fill("else")
+ self.fill("else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)
def visit_While(self, node):
- self.fill("while ")
+ self.fill("while ", allow_semicolon=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
if node.orelse:
- self.fill("else")
+ self.fill("else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)
def visit_With(self, node):
- self.fill("with ")
+ self.fill("with ", allow_semicolon=False)
self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
def visit_AsyncWith(self, node):
- self.fill("async with ")
+ self.fill("async with ", allow_semicolon=False)
self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
@@ -1264,7 +1283,7 @@ def visit_Name(self, node):
self.write(node.id)
def _write_docstring(self, node):
- self.fill()
+ self.fill(allow_semicolon=False)
if node.kind == "u":
self.write("u")
self._write_str_avoiding_backslashes(node.value,
quote_types=_MULTI_QUOTES)
@@ -1558,7 +1577,7 @@ def visit_Slice(self, node):
self.traverse(node.step)
def visit_Match(self, node):
- self.fill("match ")
+ self.fill("match ", allow_semicolon=False)
self.traverse(node.subject)
with self.block():
for case in node.cases:
@@ -1652,7 +1671,7 @@ def visit_withitem(self, node):
self.traverse(node.optional_vars)
def visit_match_case(self, node):
- self.fill("case ")
+ self.fill("case ", allow_semicolon=False)
self.traverse(node.pattern)
if node.guard:
self.write(" if ")
diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py
index 686649a520880e..839326f6436809 100644
--- a/Lib/test/test_unparse.py
+++ b/Lib/test/test_unparse.py
@@ -142,13 +142,13 @@ def check_invalid(self, node, raises=ValueError):
with self.subTest(node=node):
self.assertRaises(raises, ast.unparse, node)
- def get_source(self, code1, code2=None):
+ def get_source(self, code1, code2=None, **kwargs):
code2 = code2 or code1
- code1 = ast.unparse(ast.parse(code1))
+ code1 = ast.unparse(ast.parse(code1, **kwargs))
return code1, code2
- def check_src_roundtrip(self, code1, code2=None):
- code1, code2 = self.get_source(code1, code2)
+ def check_src_roundtrip(self, code1, code2=None, **kwargs):
+ code1, code2 = self.get_source(code1, code2, **kwargs)
with self.subTest(code1=code1, code2=code2):
self.assertEqual(code2, code1)
@@ -469,6 +469,120 @@ def test_type_ignore(self):
):
self.check_ast_roundtrip(statement, type_comments=True)
+ def test_unparse_interactive_semicolons(self):
+ # gh-129598: Fix ast.unparse() when ast.Interactive contains multiple
statements
+ self.check_src_roundtrip("i = 1; 'expr'; raise Exception",
mode='single')
+ self.check_src_roundtrip("i: int = 1; j: float = 0; k += l",
mode='single')
+ combinable = (
+ "'expr'",
+ "(i := 1)",
+ "import foo",
+ "from foo import bar",
+ "i = 1",
+ "i += 1",
+ "i: int = 1",
+ "return i",
+ "pass",
+ "break",
+ "continue",
+ "del i",
+ "assert i",
+ "global i",
+ "nonlocal j",
+ "await i",
+ "yield i",
+ "yield from i",
+ "raise i",
+ "type t[T] = ...",
+ "i",
+ )
+ for a in combinable:
+ for b in combinable:
+ self.check_src_roundtrip(f"{a}; {b}", mode='single')
+
+ def test_unparse_interactive_integrity_1(self):
+ # rest of unparse_interactive_integrity tests just make sure
mode='single' parse and unparse didn't break
+ self.check_src_roundtrip(
+ "if i:\n 'expr'\nelse:\n raise Exception",
+ "if i:\n 'expr'\nelse:\n raise Exception",
+ mode='single'
+ )
+ self.check_src_roundtrip(
+ "@decorator1\n@decorator2\ndef func():\n 'docstring'\n i = 1;
'expr'; raise Exception",
+ '''@decorator1\n@decorator2\ndef func():\n """docstring"""\n
i = 1\n 'expr'\n raise Exception''',
+ mode='single'
+ )
+ self.check_src_roundtrip(
+ "@decorator1\n@decorator2\nclass cls:\n 'docstring'\n i = 1;
'expr'; raise Exception",
+ '''@decorator1\n@decorator2\nclass cls:\n """docstring"""\n
i = 1\n 'expr'\n raise Exception''',
+ mode='single'
+ )
+
+ def test_unparse_interactive_integrity_2(self):
+ for statement in (
+ "def x():\n pass",
+ "def x(y):\n pass",
+ "async def x():\n pass",
+ "async def x(y):\n pass",
+ "for x in y:\n pass",
+ "async for x in y:\n pass",
+ "with x():\n pass",
+ "async with x():\n pass",
+ "def f():\n pass",
+ "def f(a):\n pass",
+ "def f(b=2):\n pass",
+ "def f(a, b):\n pass",
+ "def f(a, b=2):\n pass",
+ "def f(a=5, b=2):\n pass",
+ "def f(*, a=1, b=2):\n pass",
+ "def f(*, a=1, b):\n pass",
+ "def f(*, a, b=2):\n pass",
+ "def f(a, b=None, *, c, **kwds):\n pass",
+ "def f(a=2, *args, c=5, d, **kwds):\n pass",
+ "def f(*args, **kwargs):\n pass",
+ "class cls:\n\n def f(self):\n pass",
+ "class cls:\n\n def f(self, a):\n pass",
+ "class cls:\n\n def f(self, b=2):\n pass",
+ "class cls:\n\n def f(self, a, b):\n pass",
+ "class cls:\n\n def f(self, a, b=2):\n pass",
+ "class cls:\n\n def f(self, a=5, b=2):\n pass",
+ "class cls:\n\n def f(self, *, a=1, b=2):\n pass",
+ "class cls:\n\n def f(self, *, a=1, b):\n pass",
+ "class cls:\n\n def f(self, *, a, b=2):\n pass",
+ "class cls:\n\n def f(self, a, b=None, *, c, **kwds):\n
pass",
+ "class cls:\n\n def f(self, a=2, *args, c=5, d, **kwds):\n
pass",
+ "class cls:\n\n def f(self, *args, **kwargs):\n pass",
+ ):
+ self.check_src_roundtrip(statement, mode='single')
+
+ def test_unparse_interactive_integrity_3(self):
+ for statement in (
+ "def x():",
+ "def x(y):",
+ "async def x():",
+ "async def x(y):",
+ "for x in y:",
+ "async for x in y:",
+ "with x():",
+ "async with x():",
+ "def f():",
+ "def f(a):",
+ "def f(b=2):",
+ "def f(a, b):",
+ "def f(a, b=2):",
+ "def f(a=5, b=2):",
+ "def f(*, a=1, b=2):",
+ "def f(*, a=1, b):",
+ "def f(*, a, b=2):",
+ "def f(a, b=None, *, c, **kwds):",
+ "def f(a=2, *args, c=5, d, **kwds):",
+ "def f(*args, **kwargs):",
+ ):
+ src = statement + '\n i=1;j=2'
+ out = statement + '\n i = 1\n j = 2'
+
+ self.check_src_roundtrip(src, out, mode='single')
+
class CosmeticTestCase(ASTTestCase):
"""Test if there are cosmetic issues caused by unnecessary additions"""
diff --git a/Misc/NEWS.d/next/Library/2025-02-03-16-27-14.gh-issue-129598.0js33I.rst b/Misc/NEWS.d/next/Library/2025-02-03-16-27-14.gh-issue-129598.0js33I.rst
new file mode 100644
index 00000000000000..f59eeb236e24a2
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2025-02-03-16-27-14.gh-issue-129598.0js33I.rst
@@ -0,0 +1 @@
+Fix :func:`ast.unparse` when :class:`ast.Interactive` contains multiple statements.
_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]