This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d64bcd4c73 openlineage: fix duplicate naming in docs with list of 
supported operators (#41183)
d64bcd4c73 is described below

commit d64bcd4c7339293068a6465f6feb331b4195ab0a
Author: Kacper Muda <[email protected]>
AuthorDate: Thu Aug 1 13:53:01 2024 +0200

    openlineage: fix duplicate naming in docs with list of supported operators 
(#41183)
    
    Signed-off-by: Kacper Muda <[email protected]>
---
 docs/exts/providers_extensions.py | 74 ++++++++++++++++++++++++++++-----------
 1 file changed, 53 insertions(+), 21 deletions(-)

diff --git a/docs/exts/providers_extensions.py 
b/docs/exts/providers_extensions.py
index 87c37f6640..5de6dcb8eb 100644
--- a/docs/exts/providers_extensions.py
+++ b/docs/exts/providers_extensions.py
@@ -35,6 +35,34 @@ from docs.exts.operators_and_hooks_ref import (
 )
 
 
+def get_import_mappings(tree):
+    """Retrieve a mapping of local import names to their fully qualified 
module paths from an AST tree.
+
+    :param tree: The AST tree to analyze for import statements.
+
+    :return: A dictionary where the keys are the local names (aliases) used in 
the current module
+        and the values are the fully qualified names of the imported modules 
or their members.
+
+    Example:
+        >>> import ast
+        >>> code = '''
+        ... import os
+        ... import numpy as np
+        ... from collections import defaultdict
+        ... from datetime import datetime as dt
+        ... '''
+        >>> get_import_mappings(ast.parse(code))
+        {'os': 'os', 'np': 'numpy', 'defaultdict': 'collections.defaultdict', 
'dt': 'datetime.datetime'}
+    """
+    imports = {}
+    for node in ast.walk(tree):
+        if isinstance(node, (ast.Import, ast.ImportFrom)):
+            for alias in node.names:
+                module_prefix = f"{node.module}." if hasattr(node, "module") 
and node.module else ""
+                imports[alias.asname or alias.name] = 
f"{module_prefix}{alias.name}"
+    return imports
+
+
 def _get_module_class_registry(
     module_filepath: str, class_extras: dict[str, Any]
 ) -> dict[str, dict[str, Any]]:
@@ -52,11 +80,16 @@ def _get_module_class_registry(
     with open(module_filepath) as file:
         ast_obj = ast.parse(file.read())
 
+    module_name = module_filepath.replace("/", ".").replace(".py", 
"").lstrip(".")
+    import_mappings = get_import_mappings(ast_obj)
     module_class_registry = {
-        node.name: {
-            "module_filepath": module_filepath,
+        f"{module_name}.{node.name}": {
             "methods": {n.name for n in ast.walk(node) if isinstance(n, 
ast.FunctionDef)},
-            "base_classes": [b.id for b in node.bases if isinstance(b, 
ast.Name)],
+            "base_classes": [
+                import_mappings.get(b.id, f"{module_name}.{b.id}")
+                for b in node.bases
+                if isinstance(b, ast.Name)
+            ],
             **class_extras,
         }
         for node in ast_obj.body
@@ -66,11 +99,11 @@ def _get_module_class_registry(
 
 
 def _has_method(
-    class_name: str, method_names: Iterable[str], class_registry: dict[str, 
dict[str, Any]]
+    class_path: str, method_names: Iterable[str], class_registry: dict[str, 
dict[str, Any]]
 ) -> bool:
     """Determines if a class or its bases in the registry have any of the 
specified methods.
 
-    :param class_name: The name of the class to check.
+    :param class_path: The path of the class to check.
     :param method_names: A list of names of methods to search for.
     :param class_registry: A dictionary representing the class registry, where 
each key is a class name
                             and the value is its metadata.
@@ -78,20 +111,20 @@ def _has_method(
 
     Example:
     >>> example_class_registry = {
-    ...     "MyClass": {"methods": {"foo", "bar"}, "base_classes": 
["BaseClass"]},
-    ...     "BaseClass": {"methods": {"base_foo"}, "base_classes": []},
+    ...     "some.module.MyClass": {"methods": {"foo", "bar"}, "base_classes": 
["another.module.BaseClass"]},
+    ...     "another.module.BaseClass": {"methods": {"base_foo"}, 
"base_classes": []},
     ... }
-    >>> _has_method("MyClass", ["foo"], example_class_registry)
+    >>> _has_method("some.module.MyClass", ["foo"], example_class_registry)
     True
-    >>> _has_method("MyClass", ["base_foo"], example_class_registry)
+    >>> _has_method("some.module.MyClass", ["base_foo"], 
example_class_registry)
     True
-    >>> _has_method("MyClass", ["not_a_method"], example_class_registry)
+    >>> _has_method("some.module.MyClass", ["not_a_method"], 
example_class_registry)
     False
     """
-    if class_name in class_registry:
-        if any(method in class_registry[class_name]["methods"] for method in 
method_names):
+    if class_path in class_registry:
+        if any(method in class_registry[class_path]["methods"] for method in 
method_names):
             return True
-        for base_name in class_registry[class_name]["base_classes"]:
+        for base_name in class_registry[class_path]["base_classes"]:
             if _has_method(base_name, method_names, class_registry):
                 return True
     return False
@@ -133,28 +166,27 @@ def _render_openlineage_supported_classes_content():
 
     class_registry = _get_providers_class_registry()
     # These excluded classes will be included in docs directly
-    class_registry.pop("DbApiHook")
-    class_registry.pop("SQLExecuteQueryOperator")
+    class_registry.pop("airflow.providers.common.sql.hooks.sql.DbApiHook")
+    
class_registry.pop("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator")
 
     providers: dict[str, dict[str, list[str]]] = {}
     db_hooks: list[tuple[str, str]] = []
-    for class_name, info in class_registry.items():
+    for class_path, info in class_registry.items():
+        class_name = class_path.split(".")[-1]
         if class_name.startswith("_"):
             continue
-        module_name = info["module_filepath"].replace("/", ".").replace(".py", 
"").lstrip(".")
-        class_path = f"{module_name}.{class_name}"
         provider_entry = providers.setdefault(info["provider_name"], 
{"operators": []})
 
         if class_name.lower().endswith("operator"):
             if _has_method(
-                class_name=class_name,
+                class_path=class_path,
                 method_names=openlineage_operator_methods,
                 class_registry=class_registry,
             ):
                 provider_entry["operators"].append(class_path)
         elif class_name.lower().endswith("hook"):
             if _has_method(
-                class_name=class_name,
+                class_path=class_path,
                 method_names=openlineage_db_hook_methods,
                 class_registry=class_registry,
             ):
@@ -164,7 +196,7 @@ def _render_openlineage_supported_classes_content():
     providers = {
         provider: {key: sorted(set(value), key=lambda x: x.split(".")[-1]) for 
key, value in details.items()}
         for provider, details in sorted(providers.items())
-        if any(details.values())
+        if any(details.values())  # This filters out providers with empty 
'operators'
     }
     db_hooks = sorted({db_type: hook for db_type, hook in db_hooks}.items(), 
key=lambda x: x[0])
 

Reply via email to