This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2852976ea6 fix: Fix parent id macro and remove unused utils (#37877)
2852976ea6 is described below

commit 2852976ea6321b152ebc631d30d5526703bc6590
Author: Kacper Muda <[email protected]>
AuthorDate: Tue Mar 5 14:11:47 2024 +0100

    fix: Fix parent id macro and remove unused utils (#37877)
---
 airflow/providers/openlineage/plugins/macros.py    | 26 +++----
 airflow/providers/openlineage/utils/utils.py       | 85 +---------------------
 tests/providers/openlineage/plugins/test_macros.py | 19 +++--
 tests/providers/openlineage/plugins/test_utils.py  | 36 ---------
 4 files changed, 20 insertions(+), 146 deletions(-)

diff --git a/airflow/providers/openlineage/plugins/macros.py 
b/airflow/providers/openlineage/plugins/macros.py
index a4039db2f4..fa05a60386 100644
--- a/airflow/providers/openlineage/plugins/macros.py
+++ b/airflow/providers/openlineage/plugins/macros.py
@@ -16,17 +16,14 @@
 # under the License.
 from __future__ import annotations
 
-import os
 import typing
 
-from airflow.configuration import conf
-from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
+from airflow.providers.openlineage.plugins.adapter import _DAG_NAMESPACE, 
OpenLineageAdapter
+from airflow.providers.openlineage.utils.utils import get_job_name
 
 if typing.TYPE_CHECKING:
     from airflow.models import TaskInstance
 
-_JOB_NAMESPACE = conf.get("openlineage", "namespace", 
fallback=os.getenv("OPENLINEAGE_NAMESPACE", "default"))
-
 
 def lineage_run_id(task_instance: TaskInstance):
     """
@@ -46,21 +43,18 @@ def lineage_run_id(task_instance: TaskInstance):
     )
 
 
-def lineage_parent_id(run_id: str, task_instance: TaskInstance):
+def lineage_parent_id(task_instance: TaskInstance):
     """
-    Macro function which returns the generated job and run id for a given task.
+    Macro function which returns a unique identifier of given task that can be 
used to create ParentRunFacet.
 
-    This can be used to forward the ids from a task to a child run so the job
-    hierarchy is preserved. Child run can create ParentRunFacet from those ids.
+    This identifier is composed of the namespace, job name, and generated run 
id for given task, structured
+    as '{namespace}/{job_name}/{run_id}'. This can be used to forward task 
information from a task to a child
+    run so the job hierarchy is preserved. Child run can easily create 
ParentRunFacet from these information.
 
     .. seealso::
         For more information on how to use this macro, take a look at the 
guide:
         :ref:`howto/macros:openlineage`
     """
-    job_name = OpenLineageAdapter.build_task_instance_run_id(
-        dag_id=task_instance.dag_id,
-        task_id=task_instance.task.task_id,
-        execution_date=task_instance.execution_date,
-        try_number=task_instance.try_number,
-    )
-    return f"{_JOB_NAMESPACE}/{job_name}/{run_id}"
+    job_name = get_job_name(task_instance.task)
+    run_id = lineage_run_id(task_instance)
+    return f"{_DAG_NAMESPACE}/{job_name}/{run_id}"
diff --git a/airflow/providers/openlineage/utils/utils.py 
b/airflow/providers/openlineage/utils/utils.py
index 1f6c723883..4f8cfbff71 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -24,7 +24,6 @@ import os
 from contextlib import suppress
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Iterable
-from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
 
 import attrs
 from attrs import asdict
@@ -42,101 +41,19 @@ from airflow.utils.context import 
AirflowContextDeprecationWarning
 from airflow.utils.log.secrets_masker import Redactable, Redacted, 
SecretsMasker, should_hide_value_for_key
 
 if TYPE_CHECKING:
-    from airflow.models import DAG, BaseOperator, Connection, DagRun, 
TaskInstance
+    from airflow.models import DAG, BaseOperator, DagRun, TaskInstance
 
 
 log = logging.getLogger(__name__)
 _NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
 
 
-def openlineage_job_name(dag_id: str, task_id: str) -> str:
-    return f"{dag_id}.{task_id}"
-
-
 def get_operator_class(task: BaseOperator) -> type:
     if task.__class__.__name__ in ("DecoratedMappedOperator", 
"MappedOperator"):
         return task.operator_class
     return task.__class__
 
 
-def to_json_encodable(task: BaseOperator) -> dict[str, object]:
-    def _task_encoder(obj):
-        from airflow.models import DAG
-
-        if isinstance(obj, datetime.datetime):
-            return obj.isoformat()
-        elif isinstance(obj, DAG):
-            return {
-                "dag_id": obj.dag_id,
-                "tags": obj.tags,
-                "schedule_interval": obj.schedule_interval,
-                "timetable": obj.timetable.serialize(),
-            }
-        else:
-            return str(obj)
-
-    return json.loads(json.dumps(task.__dict__, default=_task_encoder))
-
-
-def url_to_https(url) -> str | None:
-    # Ensure URL exists
-    if not url:
-        return None
-
-    base_url = None
-    if url.startswith("git@"):
-        part = url.split("git@")[1:2]
-        if part:
-            base_url = f'https://{part[0].replace(":", "/", 1)}'
-    elif url.startswith("https://"):
-        base_url = url
-
-    if not base_url:
-        raise ValueError(f"Unable to extract location from: {url}")
-
-    if base_url.endswith(".git"):
-        base_url = base_url[:-4]
-    return base_url
-
-
-def redacted_connection_uri(conn: Connection, filtered_params=None, 
filtered_prefixes=None):
-    """
-    Return the connection URI for the given Connection.
-
-    This method additionally filters URI by removing query parameters that are 
known to carry sensitive data
-    like username, password, access key.
-    """
-    if filtered_prefixes is None:
-        filtered_prefixes = []
-    if filtered_params is None:
-        filtered_params = []
-
-    def filter_key_params(k: str):
-        return k not in filtered_params and any(substr in k for substr in 
filtered_prefixes)
-
-    conn_uri = conn.get_uri()
-    parsed = urlparse(conn_uri)
-
-    # Remove username and password
-    netloc = f"{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "")
-    parsed = parsed._replace(netloc=netloc)
-    if parsed.query:
-        query_dict = dict(parse_qsl(parsed.query))
-        if conn.EXTRA_KEY in query_dict:
-            query_dict = json.loads(query_dict[conn.EXTRA_KEY])
-        filtered_qs = {k: v for k, v in query_dict.items() if not 
filter_key_params(k)}
-        parsed = parsed._replace(query=urlencode(filtered_qs))
-    return urlunparse(parsed)
-
-
-def get_connection(conn_id) -> Connection | None:
-    from airflow.hooks.base import BaseHook
-
-    with suppress(Exception):
-        return BaseHook.get_connection(conn_id=conn_id)
-    return None
-
-
 def get_job_name(task):
     return f"{task.dag_id}.{task.task_id}"
 
diff --git a/tests/providers/openlineage/plugins/test_macros.py 
b/tests/providers/openlineage/plugins/test_macros.py
index bea73628c1..415cea36e4 100644
--- a/tests/providers/openlineage/plugins/test_macros.py
+++ b/tests/providers/openlineage/plugins/test_macros.py
@@ -37,16 +37,15 @@ def test_lineage_run_id():
     assert actual == expected
 
 
-def test_lineage_parent_id():
+@mock.patch("airflow.providers.openlineage.plugins.macros.lineage_run_id")
+def test_lineage_parent_id(mock_run_id):
+    mock_run_id.return_value = "run_id"
     task = mock.MagicMock(
-        dag_id="dag_id", execution_date="execution_date", try_number=1, 
task=mock.MagicMock(task_id="task_id")
-    )
-    actual = lineage_parent_id(run_id="run_id", task_instance=task)
-    job_name = str(
-        uuid.uuid3(
-            uuid.NAMESPACE_URL,
-            f"{_DAG_NAMESPACE}.dag_id.task_id.execution_date.1",
-        )
+        dag_id="dag_id",
+        execution_date="execution_date",
+        try_number=1,
+        task=mock.MagicMock(task_id="task_id", dag_id="dag_id"),
     )
-    expected = f"{_DAG_NAMESPACE}/{job_name}/run_id"
+    actual = lineage_parent_id(task_instance=task)
+    expected = f"{_DAG_NAMESPACE}/dag_id.task_id/run_id"
     assert actual == expected
diff --git a/tests/providers/openlineage/plugins/test_utils.py 
b/tests/providers/openlineage/plugins/test_utils.py
index b7ced7a37c..9984007083 100644
--- a/tests/providers/openlineage/plugins/test_utils.py
+++ b/tests/providers/openlineage/plugins/test_utils.py
@@ -18,7 +18,6 @@ from __future__ import annotations
 
 import datetime
 import json
-import os
 import uuid
 from json import JSONEncoder
 from typing import Any
@@ -29,23 +28,15 @@ from openlineage.client.utils import RedactMixin
 from pkg_resources import parse_version
 
 from airflow.models import DAG as AIRFLOW_DAG, DagModel
-from airflow.operators.empty import EmptyOperator
 from airflow.providers.openlineage.utils.utils import (
     InfoJsonEncodable,
     OpenLineageRedactor,
     _is_name_redactable,
-    get_connection,
-    to_json_encodable,
-    url_to_https,
 )
 from airflow.utils import timezone
 from airflow.utils.log.secrets_masker import _secrets_masker
 from airflow.utils.state import State
 
-AIRFLOW_CONN_ID = "test_db"
-AIRFLOW_CONN_URI = "postgres://localhost:5432/testdb"
-SNOWFLAKE_CONN_URI = 
"snowflake://12345.us-east-1.snowflakecomputing.com/MyTestRole?extra__snowflake__account=12345&extra__snowflake__database=TEST_DB&extra__snowflake__insecure_mode=false&extra__snowflake__region=us-east-1&extra__snowflake__role=MyTestRole&extra__snowflake__warehouse=TEST_WH&extra__snowflake__aws_access_key_id=123456&extra__snowflake__aws_secret_access_key=abcdefg"
-
 
 class SafeStrDict(dict):
     def __str__(self):
@@ -59,21 +50,6 @@ class SafeStrDict(dict):
         return str(dict(castable))
 
 
-def test_get_connection():
-    os.environ["AIRFLOW_CONN_DEFAULT"] = AIRFLOW_CONN_URI
-
-    conn = get_connection("default")
-    assert conn.host == "localhost"
-    assert conn.port == 5432
-    assert conn.conn_type == "postgres"
-    assert conn
-
-
-def test_url_to_https_no_url():
-    assert url_to_https(None) is None
-    assert url_to_https("") is None
-
-
 @pytest.mark.db_test
 def test_get_dagrun_start_end():
     start_date = datetime.datetime(2022, 1, 1)
@@ -105,18 +81,6 @@ def test_parse_version():
     assert parse_version("2.2.4.dev0") < parse_version("2.3.0.dev0")
 
 
-def test_to_json_encodable():
-    dag = AIRFLOW_DAG(
-        dag_id="test_dag", schedule_interval="*/2 * * * *", 
start_date=datetime.datetime.now(), catchup=False
-    )
-    task = EmptyOperator(task_id="test_task", dag=dag)
-
-    encodable = to_json_encodable(task)
-    encoded = json.dumps(encodable)
-    decoded = json.loads(encoded)
-    assert decoded == encodable
-
-
 def test_safe_dict():
     assert str(SafeStrDict({"a": 1})) == str({"a": 1})
 

Reply via email to