This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2852976ea6 fix: Fix parent id macro and remove unused utils (#37877)
2852976ea6 is described below
commit 2852976ea6321b152ebc631d30d5526703bc6590
Author: Kacper Muda <[email protected]>
AuthorDate: Tue Mar 5 14:11:47 2024 +0100
fix: Fix parent id macro and remove unused utils (#37877)
---
airflow/providers/openlineage/plugins/macros.py | 26 +++----
airflow/providers/openlineage/utils/utils.py | 85 +---------------------
tests/providers/openlineage/plugins/test_macros.py | 19 +++--
tests/providers/openlineage/plugins/test_utils.py | 36 ---------
4 files changed, 20 insertions(+), 146 deletions(-)
diff --git a/airflow/providers/openlineage/plugins/macros.py b/airflow/providers/openlineage/plugins/macros.py
index a4039db2f4..fa05a60386 100644
--- a/airflow/providers/openlineage/plugins/macros.py
+++ b/airflow/providers/openlineage/plugins/macros.py
@@ -16,17 +16,14 @@
# under the License.
from __future__ import annotations
-import os
import typing
-from airflow.configuration import conf
-from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
+from airflow.providers.openlineage.plugins.adapter import _DAG_NAMESPACE, OpenLineageAdapter
+from airflow.providers.openlineage.utils.utils import get_job_name
if typing.TYPE_CHECKING:
from airflow.models import TaskInstance
-_JOB_NAMESPACE = conf.get("openlineage", "namespace", fallback=os.getenv("OPENLINEAGE_NAMESPACE", "default"))
-
def lineage_run_id(task_instance: TaskInstance):
"""
@@ -46,21 +43,18 @@ def lineage_run_id(task_instance: TaskInstance):
)
-def lineage_parent_id(run_id: str, task_instance: TaskInstance):
+def lineage_parent_id(task_instance: TaskInstance):
"""
- Macro function which returns the generated job and run id for a given task.
+ Macro function which returns a unique identifier of given task that can be used to create ParentRunFacet.
- This can be used to forward the ids from a task to a child run so the job
- hierarchy is preserved. Child run can create ParentRunFacet from those ids.
+ This identifier is composed of the namespace, job name, and generated run id for given task, structured
+ as '{namespace}/{job_name}/{run_id}'. This can be used to forward task information from a task to a child
+ run so the job hierarchy is preserved. Child run can easily create ParentRunFacet from these information.
.. seealso::
For more information on how to use this macro, take a look at the guide:
:ref:`howto/macros:openlineage`
"""
- job_name = OpenLineageAdapter.build_task_instance_run_id(
- dag_id=task_instance.dag_id,
- task_id=task_instance.task.task_id,
- execution_date=task_instance.execution_date,
- try_number=task_instance.try_number,
- )
- return f"{_JOB_NAMESPACE}/{job_name}/{run_id}"
+ job_name = get_job_name(task_instance.task)
+ run_id = lineage_run_id(task_instance)
+ return f"{_DAG_NAMESPACE}/{job_name}/{run_id}"
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index 1f6c723883..4f8cfbff71 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -24,7 +24,6 @@ import os
from contextlib import suppress
from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable
-from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
import attrs
from attrs import asdict
@@ -42,101 +41,19 @@ from airflow.utils.context import AirflowContextDeprecationWarning
from airflow.utils.log.secrets_masker import Redactable, Redacted, SecretsMasker, should_hide_value_for_key
if TYPE_CHECKING:
- from airflow.models import DAG, BaseOperator, Connection, DagRun, TaskInstance
+ from airflow.models import DAG, BaseOperator, DagRun, TaskInstance
log = logging.getLogger(__name__)
_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
-def openlineage_job_name(dag_id: str, task_id: str) -> str:
- return f"{dag_id}.{task_id}"
-
-
def get_operator_class(task: BaseOperator) -> type:
if task.__class__.__name__ in ("DecoratedMappedOperator",
"MappedOperator"):
return task.operator_class
return task.__class__
-def to_json_encodable(task: BaseOperator) -> dict[str, object]:
- def _task_encoder(obj):
- from airflow.models import DAG
-
- if isinstance(obj, datetime.datetime):
- return obj.isoformat()
- elif isinstance(obj, DAG):
- return {
- "dag_id": obj.dag_id,
- "tags": obj.tags,
- "schedule_interval": obj.schedule_interval,
- "timetable": obj.timetable.serialize(),
- }
- else:
- return str(obj)
-
- return json.loads(json.dumps(task.__dict__, default=_task_encoder))
-
-
-def url_to_https(url) -> str | None:
- # Ensure URL exists
- if not url:
- return None
-
- base_url = None
- if url.startswith("git@"):
- part = url.split("git@")[1:2]
- if part:
- base_url = f'https://{part[0].replace(":", "/", 1)}'
- elif url.startswith("https://"):
- base_url = url
-
- if not base_url:
- raise ValueError(f"Unable to extract location from: {url}")
-
- if base_url.endswith(".git"):
- base_url = base_url[:-4]
- return base_url
-
-
-def redacted_connection_uri(conn: Connection, filtered_params=None, filtered_prefixes=None):
- """
- Return the connection URI for the given Connection.
-
- This method additionally filters URI by removing query parameters that are known to carry sensitive data
- like username, password, access key.
- """
- if filtered_prefixes is None:
- filtered_prefixes = []
- if filtered_params is None:
- filtered_params = []
-
- def filter_key_params(k: str):
- return k not in filtered_params and any(substr in k for substr in filtered_prefixes)
-
- conn_uri = conn.get_uri()
- parsed = urlparse(conn_uri)
-
- # Remove username and password
- netloc = f"{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "")
- parsed = parsed._replace(netloc=netloc)
- if parsed.query:
- query_dict = dict(parse_qsl(parsed.query))
- if conn.EXTRA_KEY in query_dict:
- query_dict = json.loads(query_dict[conn.EXTRA_KEY])
- filtered_qs = {k: v for k, v in query_dict.items() if not filter_key_params(k)}
- parsed = parsed._replace(query=urlencode(filtered_qs))
- return urlunparse(parsed)
-
-
-def get_connection(conn_id) -> Connection | None:
- from airflow.hooks.base import BaseHook
-
- with suppress(Exception):
- return BaseHook.get_connection(conn_id=conn_id)
- return None
-
-
def get_job_name(task):
return f"{task.dag_id}.{task.task_id}"
diff --git a/tests/providers/openlineage/plugins/test_macros.py b/tests/providers/openlineage/plugins/test_macros.py
index bea73628c1..415cea36e4 100644
--- a/tests/providers/openlineage/plugins/test_macros.py
+++ b/tests/providers/openlineage/plugins/test_macros.py
@@ -37,16 +37,15 @@ def test_lineage_run_id():
assert actual == expected
-def test_lineage_parent_id():
[email protected]("airflow.providers.openlineage.plugins.macros.lineage_run_id")
+def test_lineage_parent_id(mock_run_id):
+ mock_run_id.return_value = "run_id"
task = mock.MagicMock(
- dag_id="dag_id", execution_date="execution_date", try_number=1, task=mock.MagicMock(task_id="task_id")
- )
- actual = lineage_parent_id(run_id="run_id", task_instance=task)
- job_name = str(
- uuid.uuid3(
- uuid.NAMESPACE_URL,
- f"{_DAG_NAMESPACE}.dag_id.task_id.execution_date.1",
- )
+ dag_id="dag_id",
+ execution_date="execution_date",
+ try_number=1,
+ task=mock.MagicMock(task_id="task_id", dag_id="dag_id"),
)
- expected = f"{_DAG_NAMESPACE}/{job_name}/run_id"
+ actual = lineage_parent_id(task_instance=task)
+ expected = f"{_DAG_NAMESPACE}/dag_id.task_id/run_id"
assert actual == expected
diff --git a/tests/providers/openlineage/plugins/test_utils.py b/tests/providers/openlineage/plugins/test_utils.py
index b7ced7a37c..9984007083 100644
--- a/tests/providers/openlineage/plugins/test_utils.py
+++ b/tests/providers/openlineage/plugins/test_utils.py
@@ -18,7 +18,6 @@ from __future__ import annotations
import datetime
import json
-import os
import uuid
from json import JSONEncoder
from typing import Any
@@ -29,23 +28,15 @@ from openlineage.client.utils import RedactMixin
from pkg_resources import parse_version
from airflow.models import DAG as AIRFLOW_DAG, DagModel
-from airflow.operators.empty import EmptyOperator
from airflow.providers.openlineage.utils.utils import (
InfoJsonEncodable,
OpenLineageRedactor,
_is_name_redactable,
- get_connection,
- to_json_encodable,
- url_to_https,
)
from airflow.utils import timezone
from airflow.utils.log.secrets_masker import _secrets_masker
from airflow.utils.state import State
-AIRFLOW_CONN_ID = "test_db"
-AIRFLOW_CONN_URI = "postgres://localhost:5432/testdb"
-SNOWFLAKE_CONN_URI = "snowflake://12345.us-east-1.snowflakecomputing.com/MyTestRole?extra__snowflake__account=12345&extra__snowflake__database=TEST_DB&extra__snowflake__insecure_mode=false&extra__snowflake__region=us-east-1&extra__snowflake__role=MyTestRole&extra__snowflake__warehouse=TEST_WH&extra__snowflake__aws_access_key_id=123456&extra__snowflake__aws_secret_access_key=abcdefg"
-
class SafeStrDict(dict):
def __str__(self):
@@ -59,21 +50,6 @@ class SafeStrDict(dict):
return str(dict(castable))
-def test_get_connection():
- os.environ["AIRFLOW_CONN_DEFAULT"] = AIRFLOW_CONN_URI
-
- conn = get_connection("default")
- assert conn.host == "localhost"
- assert conn.port == 5432
- assert conn.conn_type == "postgres"
- assert conn
-
-
-def test_url_to_https_no_url():
- assert url_to_https(None) is None
- assert url_to_https("") is None
-
-
@pytest.mark.db_test
def test_get_dagrun_start_end():
start_date = datetime.datetime(2022, 1, 1)
@@ -105,18 +81,6 @@ def test_parse_version():
assert parse_version("2.2.4.dev0") < parse_version("2.3.0.dev0")
-def test_to_json_encodable():
- dag = AIRFLOW_DAG(
- dag_id="test_dag", schedule_interval="*/2 * * * *", start_date=datetime.datetime.now(), catchup=False
- )
- task = EmptyOperator(task_id="test_task", dag=dag)
-
- encodable = to_json_encodable(task)
- encoded = json.dumps(encodable)
- decoded = json.loads(encoded)
- assert decoded == encodable
-
-
def test_safe_dict():
assert str(SafeStrDict({"a": 1})) == str({"a": 1})