This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d64bcd4c73 openlineage: fix duplicate naming in docs with list of supported operators (#41183)
d64bcd4c73 is described below
commit d64bcd4c7339293068a6465f6feb331b4195ab0a
Author: Kacper Muda <[email protected]>
AuthorDate: Thu Aug 1 13:53:01 2024 +0200
openlineage: fix duplicate naming in docs with list of supported operators (#41183)
Signed-off-by: Kacper Muda <[email protected]>
---
docs/exts/providers_extensions.py | 74 ++++++++++++++++++++++++++++-----------
1 file changed, 53 insertions(+), 21 deletions(-)
diff --git a/docs/exts/providers_extensions.py b/docs/exts/providers_extensions.py
index 87c37f6640..5de6dcb8eb 100644
--- a/docs/exts/providers_extensions.py
+++ b/docs/exts/providers_extensions.py
@@ -35,6 +35,34 @@ from docs.exts.operators_and_hooks_ref import (
)
+def get_import_mappings(tree):
+ """Retrieve a mapping of local import names to their fully qualified
module paths from an AST tree.
+
+ :param tree: The AST tree to analyze for import statements.
+
+ :return: A dictionary where the keys are the local names (aliases) used in
the current module
+ and the values are the fully qualified names of the imported modules
or their members.
+
+ Example:
+ >>> import ast
+ >>> code = '''
+ ... import os
+ ... import numpy as np
+ ... from collections import defaultdict
+ ... from datetime import datetime as dt
+ ... '''
+ >>> get_import_mappings(ast.parse(code))
+ {'os': 'os', 'np': 'numpy', 'defaultdict': 'collections.defaultdict',
'dt': 'datetime.datetime'}
+ """
+ imports = {}
+ for node in ast.walk(tree):
+ if isinstance(node, (ast.Import, ast.ImportFrom)):
+ for alias in node.names:
+ module_prefix = f"{node.module}." if hasattr(node, "module")
and node.module else ""
+ imports[alias.asname or alias.name] =
f"{module_prefix}{alias.name}"
+ return imports
+
+
def _get_module_class_registry(
module_filepath: str, class_extras: dict[str, Any]
) -> dict[str, dict[str, Any]]:
@@ -52,11 +80,16 @@ def _get_module_class_registry(
with open(module_filepath) as file:
ast_obj = ast.parse(file.read())
+ module_name = module_filepath.replace("/", ".").replace(".py",
"").lstrip(".")
+ import_mappings = get_import_mappings(ast_obj)
module_class_registry = {
- node.name: {
- "module_filepath": module_filepath,
+ f"{module_name}.{node.name}": {
"methods": {n.name for n in ast.walk(node) if isinstance(n,
ast.FunctionDef)},
- "base_classes": [b.id for b in node.bases if isinstance(b,
ast.Name)],
+ "base_classes": [
+ import_mappings.get(b.id, f"{module_name}.{b.id}")
+ for b in node.bases
+ if isinstance(b, ast.Name)
+ ],
**class_extras,
}
for node in ast_obj.body
@@ -66,11 +99,11 @@ def _get_module_class_registry(
def _has_method(
- class_name: str, method_names: Iterable[str], class_registry: dict[str,
dict[str, Any]]
+ class_path: str, method_names: Iterable[str], class_registry: dict[str,
dict[str, Any]]
) -> bool:
"""Determines if a class or its bases in the registry have any of the
specified methods.
- :param class_name: The name of the class to check.
+ :param class_path: The path of the class to check.
:param method_names: A list of names of methods to search for.
:param class_registry: A dictionary representing the class registry, where
each key is a class name
and the value is its metadata.
@@ -78,20 +111,20 @@ def _has_method(
Example:
>>> example_class_registry = {
- ... "MyClass": {"methods": {"foo", "bar"}, "base_classes":
["BaseClass"]},
- ... "BaseClass": {"methods": {"base_foo"}, "base_classes": []},
+ ... "some.module.MyClass": {"methods": {"foo", "bar"}, "base_classes":
["BaseClass"]},
+ ... "another.module.BaseClass": {"methods": {"base_foo"},
"base_classes": []},
... }
- >>> _has_method("MyClass", ["foo"], example_class_registry)
+ >>> _has_method("some.module.MyClass", ["foo"], example_class_registry)
True
- >>> _has_method("MyClass", ["base_foo"], example_class_registry)
+ >>> _has_method("some.module.MyClass", ["base_foo"],
example_class_registry)
True
- >>> _has_method("MyClass", ["not_a_method"], example_class_registry)
+ >>> _has_method("some.module.MyClass", ["not_a_method"],
example_class_registry)
False
"""
- if class_name in class_registry:
- if any(method in class_registry[class_name]["methods"] for method in
method_names):
+ if class_path in class_registry:
+ if any(method in class_registry[class_path]["methods"] for method in
method_names):
return True
- for base_name in class_registry[class_name]["base_classes"]:
+ for base_name in class_registry[class_path]["base_classes"]:
if _has_method(base_name, method_names, class_registry):
return True
return False
@@ -133,28 +166,27 @@ def _render_openlineage_supported_classes_content():
class_registry = _get_providers_class_registry()
# These excluded classes will be included in docs directly
- class_registry.pop("DbApiHook")
- class_registry.pop("SQLExecuteQueryOperator")
+ class_registry.pop("airflow.providers.common.sql.hooks.sql.DbApiHook")
+
class_registry.pop("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator")
providers: dict[str, dict[str, list[str]]] = {}
db_hooks: list[tuple[str, str]] = []
- for class_name, info in class_registry.items():
+ for class_path, info in class_registry.items():
+ class_name = class_path.split(".")[-1]
if class_name.startswith("_"):
continue
- module_name = info["module_filepath"].replace("/", ".").replace(".py",
"").lstrip(".")
- class_path = f"{module_name}.{class_name}"
provider_entry = providers.setdefault(info["provider_name"],
{"operators": []})
if class_name.lower().endswith("operator"):
if _has_method(
- class_name=class_name,
+ class_path=class_path,
method_names=openlineage_operator_methods,
class_registry=class_registry,
):
provider_entry["operators"].append(class_path)
elif class_name.lower().endswith("hook"):
if _has_method(
- class_name=class_name,
+ class_path=class_path,
method_names=openlineage_db_hook_methods,
class_registry=class_registry,
):
@@ -164,7 +196,7 @@ def _render_openlineage_supported_classes_content():
providers = {
provider: {key: sorted(set(value), key=lambda x: x.split(".")[-1]) for
key, value in details.items()}
for provider, details in sorted(providers.items())
- if any(details.values())
+ if any(details.values()) # This filters out providers with empty
'operators'
}
db_hooks = sorted({db_type: hook for db_type, hook in db_hooks}.items(),
key=lambda x: x[0])