https://github.com/python/cpython/commit/d065edfb66470bbf06367b3570661d0346aa6707
commit: d065edfb66470bbf06367b3570661d0346aa6707
branch: main
author: Batuhan Taskaya <[email protected]>
committer: jeremyhylton <[email protected]>
date: 2024-05-22T01:39:26Z
summary:

gh-60191: Implement ast.compare (#19211)

* bpo-15987: Implement ast.compare

Add a compare() function that compares two ASTs for structural equality. There
are two sets of attributes on AST node objects, fields and attributes. The
fields are always compared, since they represent the actual structure of the
code. The attributes can optionally be included in the comparison.
Attributes capture things like line numbers or column offsets, so comparing
them involves testing whether the layout of the program text is the same. Since
whitespace seems inessential for comparing ASTs, the default is to compare
fields but not attributes.

ASTs are just Python objects that can be modified in arbitrary ways. The API 
for ASTs is under-specified in the presence of user modifications to objects. 
The comparison respects modifications to fields and attributes, and to _fields 
and _attributes attributes. A user could create obviously malformed objects, 
and the code will probably fail with an AttributeError when that happens. (For 
example, adding "spam" to _fields but not adding a "spam" attribute to the 
object.) 

Co-authored-by: Jeremy Hylton <[email protected]>

files:
A Misc/NEWS.d/next/Library/2020-03-28-21-00-54.bpo-15987.aBL8XS.rst
M Doc/library/ast.rst
M Doc/whatsnew/3.14.rst
M Lib/ast.py
M Lib/test/test_ast.py

diff --git a/Doc/library/ast.rst b/Doc/library/ast.rst
index d4ccf282a5d00a..9ee56b92431b57 100644
--- a/Doc/library/ast.rst
+++ b/Doc/library/ast.rst
@@ -2472,6 +2472,20 @@ effects on the compilation of a program:
    .. versionadded:: 3.8
 
 
+.. function:: compare(a, b, /, *, compare_attributes=False)
+
+   Recursively compares two ASTs and returns ``True`` if they are equal,
+   ``False`` otherwise. The nodes' fields are always compared.
+
+   *compare_attributes* affects whether AST attributes are considered
+   in the comparison. If *compare_attributes* is ``False`` (default), then
+   attributes are ignored, which is useful to check whether two ASTs are
+   structurally equal even though they differ in whitespace or similar
+   details. Otherwise the attributes must all be equal as well. Attributes
+   include line numbers and column offsets.
+
+   .. versionadded:: 3.14
+
+
 .. _ast-cli:
 
 Command-Line Usage
diff --git a/Doc/whatsnew/3.14.rst b/Doc/whatsnew/3.14.rst
index 27c985bec104fe..39172ac60cf1e0 100644
--- a/Doc/whatsnew/3.14.rst
+++ b/Doc/whatsnew/3.14.rst
@@ -86,6 +86,13 @@ New Modules
 Improved Modules
 ================
 
+ast
+---
+
+Added :func:`ast.compare` for comparing two ASTs.
+(Contributed by Batuhan Taskaya and Jeremy Hylton in :issue:`15987`)
+
+
 
 Optimizations
 =============
diff --git a/Lib/ast.py b/Lib/ast.py
index d7e51aba595706..031bab43df7579 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -401,6 +401,77 @@ def walk(node):
         yield node
 
 
+def compare(
+    a,
+    b,
+    /,
+    *,
+    compare_attributes=False,
+):
+    """Recursively compares two ASTs.
+
+    compare_attributes affects whether AST attributes are considered
+    in the comparison. If compare_attributes is False (default), then
+    attributes are ignored. Otherwise they must all be equal. This
+    option is useful to check whether the ASTs are structurally equal but
+    might differ in whitespace or similar details.
+    """
+
+    def _compare(a, b):
+        # Compare two fields on an AST object, which may themselves be
+        # AST objects, lists of AST objects, or primitive ASDL types
+        # like identifiers and constants.
+        if isinstance(a, AST):
+            return compare(
+                a,
+                b,
+                compare_attributes=compare_attributes,
+            )
+        elif isinstance(a, list):
+            # If a field is repeated, then both objects will represent
+            # the value as a list.
+            if len(a) != len(b):
+                return False
+            for a_item, b_item in zip(a, b):
+                if not _compare(a_item, b_item):
+                    return False
+            else:
+                return True
+        else:
+            return type(a) is type(b) and a == b
+
+    def _compare_fields(a, b):
+        if a._fields != b._fields:
+            return False
+        for field in a._fields:
+            a_field = getattr(a, field)
+            b_field = getattr(b, field)
+            if not _compare(a_field, b_field):
+                return False
+        else:
+            return True
+
+    def _compare_attributes(a, b):
+        if a._attributes != b._attributes:
+            return False
+        # Attributes are always ints.
+        for attr in a._attributes:
+            a_attr = getattr(a, attr)
+            b_attr = getattr(b, attr)
+            if a_attr != b_attr:
+                return False
+        else:
+            return True
+
+    if type(a) is not type(b):
+        return False
+    if not _compare_fields(a, b):
+        return False
+    if compare_attributes and not _compare_attributes(a, b):
+        return False
+    return True
+
+
 class NodeVisitor(object):
     """
     A node visitor base class that walks the abstract syntax tree and calls a
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 5422c861ffb5c0..8a4374c56cbc08 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -38,6 +38,9 @@ def to_tuple(t):
         result.append(to_tuple(getattr(t, f)))
     return tuple(result)
 
+STDLIB = os.path.dirname(ast.__file__)
+STDLIB_FILES = [fn for fn in os.listdir(STDLIB) if fn.endswith(".py")]
+STDLIB_FILES.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
 
 # These tests are compiled through "exec"
 # There should be at least one test per statement
@@ -1066,6 +1069,114 @@ def test_ast_asdl_signature(self):
         expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}"
         self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions)
 
+    def test_compare_basics(self):
+        self.assertTrue(ast.compare(ast.parse("x = 10"), ast.parse("x = 10")))
+        self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("")))
+        self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("x")))
+        self.assertFalse(
+            ast.compare(ast.parse("x = 10;y = 20"), ast.parse("class C:pass"))
+        )
+
+    def test_compare_modified_ast(self):
+        # The ast API is a bit underspecified. The objects are mutable,
+        # and even _fields and _attributes are mutable. The compare() does
+        # some simple things to accommodate mutability.
+        a = ast.parse("m * x + b", mode="eval")
+        b = ast.parse("m * x + b", mode="eval")
+        self.assertTrue(ast.compare(a, b))
+
+        a._fields = a._fields + ("spam",)
+        a.spam = "Spam"
+        self.assertNotEqual(a._fields, b._fields)
+        self.assertFalse(ast.compare(a, b))
+        self.assertFalse(ast.compare(b, a))
+
+        b._fields = a._fields
+        b.spam = a.spam
+        self.assertTrue(ast.compare(a, b))
+        self.assertTrue(ast.compare(b, a))
+
+        b._attributes = b._attributes + ("eggs",)
+        b.eggs = "eggs"
+        self.assertNotEqual(a._attributes, b._attributes)
+        self.assertFalse(ast.compare(a, b, compare_attributes=True))
+        self.assertFalse(ast.compare(b, a, compare_attributes=True))
+
+        a._attributes = b._attributes
+        a.eggs = b.eggs
+        self.assertTrue(ast.compare(a, b, compare_attributes=True))
+        self.assertTrue(ast.compare(b, a, compare_attributes=True))
+
+    def test_compare_literals(self):
+        constants = (
+            -20,
+            20,
+            20.0,
+            1,
+            1.0,
+            True,
+            0,
+            False,
+            frozenset(),
+            tuple(),
+            "ABCD",
+            "abcd",
+            "中文字",
+            1e1000,
+            -1e1000,
+        )
+        for next_index, constant in enumerate(constants[:-1], 1):
+            next_constant = constants[next_index]
+            with self.subTest(literal=constant, next_literal=next_constant):
+                self.assertTrue(
+                    ast.compare(ast.Constant(constant), ast.Constant(constant))
+                )
+                self.assertFalse(
+                    ast.compare(
+                        ast.Constant(constant), ast.Constant(next_constant)
+                    )
+                )
+
+        same_looking_literal_cases = [
+            {1, 1.0, True, 1 + 0j},
+            {0, 0.0, False, 0 + 0j},
+        ]
+        for same_looking_literals in same_looking_literal_cases:
+            for literal in same_looking_literals:
+                for same_looking_literal in same_looking_literals - {literal}:
+                    self.assertFalse(
+                        ast.compare(
+                            ast.Constant(literal),
+                            ast.Constant(same_looking_literal),
+                        )
+                    )
+
+    def test_compare_fieldless(self):
+        self.assertTrue(ast.compare(ast.Add(), ast.Add()))
+        self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
+
+    def test_compare_modes(self):
+        for mode, sources in (
+            ("exec", exec_tests),
+            ("eval", eval_tests),
+            ("single", single_tests),
+        ):
+            for source in sources:
+                a = ast.parse(source, mode=mode)
+                b = ast.parse(source, mode=mode)
+                self.assertTrue(
+                    ast.compare(a, b), f"{ast.dump(a)} != {ast.dump(b)}"
+                )
+
+    def test_compare_attributes_option(self):
+        def parse(a, b):
+            return ast.parse(a), ast.parse(b)
+
+        a, b = parse("2 + 2", "2+2")
+        self.assertTrue(ast.compare(a, b))
+        self.assertTrue(ast.compare(a, b, compare_attributes=False))
+        self.assertFalse(ast.compare(a, b, compare_attributes=True))
+
     def test_positional_only_feature_version(self):
         ast.parse('def foo(x, /): ...', feature_version=(3, 8))
         ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))
@@ -1222,6 +1333,7 @@ def test_none_checks(self) -> None:
         for node, attr, source in tests:
             self.assert_none_check(node, attr, source)
 
+
 class ASTHelpers_Test(unittest.TestCase):
     maxDiff = None
 
@@ -2191,16 +2303,15 @@ def test_nameconstant(self):
 
     @support.requires_resource('cpu')
     def test_stdlib_validates(self):
-        stdlib = os.path.dirname(ast.__file__)
-        tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
-        tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
-        for module in tests:
+        # Parse and compile every file in STDLIB_FILES, then check that a
+        # second parse of the same source compares equal to the first.
+        for module in STDLIB_FILES:
             with self.subTest(module):
-                fn = os.path.join(stdlib, module)
+                fn = os.path.join(STDLIB, module)
                 with open(fn, "r", encoding="utf-8") as fp:
                     source = fp.read()
                 mod = ast.parse(source, fn)
                 compile(mod, fn, "exec")
+                # ast.compare ignores attributes (positions) by default.
+                mod2 = ast.parse(source, fn)
+                self.assertTrue(ast.compare(mod, mod2))
 
     constant_1 = ast.Constant(1)
     pattern_1 = ast.MatchValue(constant_1)
diff --git a/Misc/NEWS.d/next/Library/2020-03-28-21-00-54.bpo-15987.aBL8XS.rst 
b/Misc/NEWS.d/next/Library/2020-03-28-21-00-54.bpo-15987.aBL8XS.rst
new file mode 100644
index 00000000000000..b906393449656d
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2020-03-28-21-00-54.bpo-15987.aBL8XS.rst
@@ -0,0 +1,2 @@
+Implemented :func:`ast.compare` for comparing two ASTs. Patch by Batuhan
+Taskaya with some help from Jeremy Hylton.

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to