ashb commented on code in PR #29940: URL: https://github.com/apache/airflow/pull/29940#discussion_r1133955748
########## airflow/providers/openlineage/extractors/base.py: ########## @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import Factory, define + +from airflow.utils.log.logging_mixin import LoggingMixin +from openlineage.client.facet import BaseFacet +from openlineage.client.run import Dataset + + +@define +class OperatorLineage: + """Structure returned from lineage extraction.""" + + inputs: list[Dataset] = Factory(list) + outputs: list[Dataset] = Factory(list) + run_facets: dict[str, BaseFacet] = Factory(dict) + job_facets: dict[str, BaseFacet] = Factory(dict) + + +class BaseExtractor(ABC, LoggingMixin): + """ + Abstract base extractor class. + + This is used mostly to maintain support for custom extractors. + """ + + _allowed_query_params: list[str] = [] + + def __init__(self, operator): + super().__init__() + self.operator = operator + self.patch() + + def patch(self): + # Extractor should register extension methods or patches to operator here + pass + + @classmethod + def get_operator_classnames(cls) -> list[str]: + """ + Implement this method returning list of operators that extractor works for. + Particularly, in Airflow 2 some operators are deprecated and simply subclass the new + implementation, for example BigQueryOperator: + https://github.com/apache/airflow/blob/main/airflow/contrib/operators/bigquery_operator.py + The BigQueryExtractor needs to work with both of them. + :return: + """ + raise NotImplementedError() + + def validate(self): + assert self.operator.__class__.__name__ in self.get_operator_classnames() + + @abstractmethod + def extract(self) -> OperatorLineage | None: + pass + + def extract_on_complete(self, task_instance) -> OperatorLineage | None: + return self.extract() + + @classmethod + def get_connection_uri(cls, conn): + """ + Return the connection URI for the given ID. We first attempt to lookup + the connection URI via AIRFLOW_CONN_<conn_id>, else fallback on querying + the Airflow's connection table. Review Comment: Firstly: This docstring doesn't really match the implementation. Secondly: How does this differ from just `conn.get_uri()`? It's not clear from reading the function alone. Please add some comments saying why that isn't suitable (or is it?)
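A quick sketch of the scrubbing the function actually performs, and which plain `conn.get_uri()` does not: it drops the `user:password@` portion and filters the query string. The `scrubbed_uri` helper below is hypothetical, written only to illustrate the difference; it is not part of the PR:

```python
from urllib.parse import urlparse, urlunparse


def scrubbed_uri(uri: str) -> str:
    # Drop the user:password@ portion, keeping host and port.
    parsed = urlparse(uri)
    netloc = parsed.hostname or ""
    if parsed.port:
        netloc += f":{parsed.port}"
    # Drop the query string entirely; the PR's code filters it selectively.
    return urlunparse(parsed._replace(netloc=netloc, query=""))


print(scrubbed_uri("postgres://user:secret@db.example.com:5432/mydb?sslmode=require"))
# -> postgres://db.example.com:5432/mydb
```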
########## airflow/providers/openlineage/extractors/base.py: ########## @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import Factory, define + +from airflow.utils.log.logging_mixin import LoggingMixin +from openlineage.client.facet import BaseFacet +from openlineage.client.run import Dataset + + +@define +class OperatorLineage: + """Structure returned from lineage extraction.""" + + inputs: list[Dataset] = Factory(list) + outputs: list[Dataset] = Factory(list) + run_facets: dict[str, BaseFacet] = Factory(dict) + job_facets: dict[str, BaseFacet] = Factory(dict) + + +class BaseExtractor(ABC, LoggingMixin): + """ + Abstract base extractor class. + + This is used mostly to maintain support for custom extractors. + """ + + _allowed_query_params: list[str] = [] + + def __init__(self, operator): + super().__init__() + self.operator = operator + self.patch() + + def patch(self): + # Extractor should register extension methods or patches to operator here + pass + + @classmethod + def get_operator_classnames(cls) -> list[str]: + """ + Implement this method returning list of operators that extractor works for. + Particularly, in Airflow 2 some operators are deprecated and simply subclass the new + implementation, for example BigQueryOperator: + https://github.com/apache/airflow/blob/main/airflow/contrib/operators/bigquery_operator.py + The BigQueryExtractor needs to work with both of them. + :return: + """ + raise NotImplementedError() + + def validate(self): + assert self.operator.__class__.__name__ in self.get_operator_classnames() Review Comment: This would mess up mapped operators, I think: ```suggestion assert self.operator.task_type in self.get_operator_classnames() ``` ########## airflow/providers/openlineage/extractors/extractors.py: ########## @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os + +from airflow.providers.openlineage.extractors.base import BaseExtractor, DefaultExtractor + + +class Extractors: + """ + This exposes implemented extractors, while hiding ones that require additional, unmet dependency. Patchers are a category of extractor that needs to hook up to operator's + internals during DAG creation. + """ + + def __init__(self): + # Do not expose extractors relying on external dependencies that are not installed + self.extractors = {} + self.default_extractor = DefaultExtractor + + # Comma-separated extractors in OPENLINEAGE_EXTRACTORS variable. + # Extractors should implement BaseExtractor + from airflow.providers.openlineage.utils import import_from_string + + env_extractors = os.getenv("OPENLINEAGE_EXTRACTORS") Review Comment: Possibly this should now be `conf.get("openlineage", "extractors")` or similar.
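A minimal sketch of what that could look like, assuming a future `[openlineage]` section in airflow.cfg; the `extractors` key name and the env-var fallback are assumptions here, not a settled API:

```python
import os

from airflow.configuration import conf

# Hypothetical: prefer [openlineage] extractors from airflow.cfg, falling
# back to the legacy OPENLINEAGE_EXTRACTORS environment variable.
env_extractors = conf.get(
    "openlineage",
    "extractors",
    fallback=os.getenv("OPENLINEAGE_EXTRACTORS"),
)
if env_extractors:
    for extractor_path in env_extractors.split(";"):
        print(f"would register extractor {extractor_path}")
```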
########## airflow/providers/openlineage/plugins/adapter.py: ########## @@ -0,0 +1,302 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import uuid +from typing import TYPE_CHECKING + +import requests.exceptions + +from airflow.providers.openlineage import version as OPENLINEAGE_PROVIDER_VERSION +from airflow.providers.openlineage.extractors import OperatorLineage +from airflow.providers.openlineage.utils import redact_with_exclusions +from airflow.utils.log.logging_mixin import LoggingMixin +from openlineage.client import OpenLineageClient, set_producer +from openlineage.client.facet import ( + BaseFacet, + DocumentationJobFacet, + ErrorMessageRunFacet, + NominalTimeRunFacet, + OwnershipJobFacet, + OwnershipJobFacetOwners, + ParentRunFacet, + ProcessingEngineRunFacet, + SourceCodeLocationJobFacet, +) +from openlineage.client.run import Job, Run, RunEvent, RunState + +if TYPE_CHECKING: + from airflow.models.dagrun import DagRun + + +_DAG_DEFAULT_NAMESPACE = "default" + +_DAG_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", _DAG_DEFAULT_NAMESPACE) + +_PRODUCER = f"https://github.com/apache/airflow/tree/providers-openlineage/" f"{OPENLINEAGE_PROVIDER_VERSION}" + +set_producer(_PRODUCER) + + +class OpenLineageAdapter(LoggingMixin): + """ + Adapter for translating Airflow metadata to OpenLineage events, + instead of directly creating them from Airflow code.
+ """ + + def __init__(self, client=None): + super().__init__() + self._client = client + + def get_or_create_openlineage_client(self) -> OpenLineageClient: + if not self._client: + self._client = OpenLineageClient.from_environment() + return self._client + + def build_dag_run_id(self, dag_id, dag_run_id): + return str(uuid.uuid3(uuid.NAMESPACE_URL, f"{_DAG_NAMESPACE}.{dag_id}.{dag_run_id}")) + + @staticmethod + def build_task_instance_run_id(task_id, execution_date, try_number): + return str( + uuid.uuid3( + uuid.NAMESPACE_URL, + f"{_DAG_NAMESPACE}.{task_id}.{execution_date}.{try_number}", + ) + ) + + def emit(self, event: RunEvent): + event = redact_with_exclusions(event) + try: + return self.get_or_create_openlineage_client().emit(event) + except requests.exceptions.RequestException: + self.log.exception(f"Failed to emit OpenLineage event of id {event.run.runId}") + + def start_task( + self, + run_id: str, + job_name: str, + job_description: str, + event_time: str, + parent_job_name: str | None, + parent_run_id: str | None, + code_location: str | None, + nominal_start_time: str, + nominal_end_time: str, + owners: list[str], + task: OperatorLineage | None, + run_facets: dict[str, type[BaseFacet]] | None = None, # Custom run facets + ) -> str: + """ + Emits openlineage event of type START Review Comment: ```suggestion Emits openlineage event of type START ``` ########## airflow/providers/openlineage/plugins/facets.py: ########## @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from attrs import define, field + +from airflow.providers.openlineage import version as OPENLINEAGE_AIRFLOW_VERSION +from airflow.version import version as AIRFLOW_VERSION +from openlineage.client.facet import BaseFacet +from openlineage.client.utils import RedactMixin + + +@define(slots=False) +class AirflowVersionRunFacet(BaseFacet): + """Run facet containing task and DAG info""" + + operator: str = field() + taskInfo: dict[str, object] = field() + airflowVersion: str = field() + openlineageAirflowVersion: str = field() + + _additional_skip_redact: list[str] = [ + "operator", + "airflowVersion", + "openlineageAirflowVersion", + ] + + @classmethod + def from_dagrun_and_task(cls, dagrun, task): + # task.__dict__ may contain values uncastable to str + from airflow.providers.openlineage.utils import get_operator_class, to_json_encodable + + task_info = to_json_encodable(task) + task_info["dag_run"] = to_json_encodable(dagrun) + + return cls( + operator=f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}", + taskInfo=task_info, + airflowVersion=AIRFLOW_VERSION, + openlineageAirflowVersion=OPENLINEAGE_AIRFLOW_VERSION, + ) + + +@define(slots=False) +class AirflowRunArgsRunFacet(BaseFacet): + """Run facet pointing if DAG was triggered manually""" + + externalTrigger: bool = field(default=False) + + _additional_skip_redact: list[str] = ["externalTrigger"] + + +@define(slots=False) +class AirflowMappedTaskRunFacet(BaseFacet): + """Run facet containing information about mapped tasks""" + + mapIndex: int = field() + operatorClass: str = field() + + _additional_skip_redact: list[str] = ["operatorClass"] + + @classmethod + def from_task_instance(cls, task_instance): + task = task_instance.task + from airflow.providers.openlineage.utils import get_operator_class + + return cls( + mapIndex=task_instance.map_index, + operatorClass=f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}", + ) + + +@define(slots=False) +class AirflowRunFacet(BaseFacet): + """Composite Airflow run facet.""" + + dag: dict = field() + dagRun: dict = field() + task: dict = field() + taskInstance: dict = field() + taskUuid: str = field() Review Comment: `= field()` shouldn't be needed here ```suggestion dag: dict dagRun: dict task: dict taskInstance: dict taskUuid: str ``` ########## airflow/providers/openlineage/plugins/listener.py: ########## @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import logging +from concurrent.futures import Executor, ThreadPoolExecutor +from typing import TYPE_CHECKING + +from airflow.listeners import hookimpl +from airflow.providers.openlineage.extractors import ExtractorManager +from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter +from airflow.providers.openlineage.utils import ( + get_airflow_run_facet, + get_custom_facets, + get_job_name, + get_task_location, + print_exception, +) + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.models import DagRun, TaskInstance + + +class OpenLineageListener: + """ + OpenLineage listener + Sends events on task instance and dag run starts, completes and failures. + """ + + def __init__(self): + self.log = logging.getLogger(__name__) + self.executor: Executor = None # type: ignore + self.extractor_manager = ExtractorManager() + self.adapter = OpenLineageAdapter() + + @hookimpl + def on_task_instance_running( + self, previous_state, task_instance: TaskInstance, session: Session # This will always be QUEUED + ): + if not hasattr(task_instance, "task"): + self.log.warning( + f"No task set for TI object task_id: {task_instance.task_id} - " + f"dag_id: {task_instance.dag_id} - run_id {task_instance.run_id}" + ) + return + + self.log.debug("OpenLineage listener got notification about task instance start") + dagrun = task_instance.dag_run + task = task_instance.task + dag = task.dag + + @print_exception + def on_running(): + # that's a workaround to detect task running from deferred state + # we return here because Airflow 2.3 needs task from deferred state + if task_instance.next_method is not None: + return + parent_run_id = self.adapter.build_dag_run_id(dag.dag_id, dagrun.run_id) + + task_uuid = self.adapter.build_task_instance_run_id( + task.task_id, task_instance.execution_date, task_instance.try_number + ) + + task_metadata = self.extractor_manager.extract_metadata(dagrun, task) + + self.adapter.start_task( + run_id=task_uuid, + job_name=get_job_name(task), + job_description=dag.description, + event_time=task_instance.start_date.isoformat(), + parent_job_name=dag.dag_id, + parent_run_id=parent_run_id, + code_location=get_task_location(task), + nominal_start_time=dagrun.data_interval_start.isoformat(), + nominal_end_time=dagrun.data_interval_end.isoformat(), + owners=dag.owner.split(", "), + task=task_metadata, + run_facets={ + **task_metadata.run_facets, + **get_custom_facets(dagrun, task, dagrun.external_trigger, task_instance), + **get_airflow_run_facet(dagrun, dag, task_instance, task, task_uuid), + }, + ) + + self.executor.submit(on_running) + + @hookimpl + def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session): + self.log.debug("OpenLineage listener got notification about task instance success") + + dagrun = task_instance.dag_run + task = task_instance.task + + task_uuid = OpenLineageAdapter.build_task_instance_run_id( + task.task_id, task_instance.execution_date, task_instance.try_number - 1 + ) + + @print_exception + def on_success(): + task_metadata = self.extractor_manager.extract_metadata( + dagrun, task, complete=True, task_instance=task_instance + ) + self.adapter.complete_task( + run_id=task_uuid, + job_name=get_job_name(task), + end_time=task_instance.end_date.isoformat(), + task=task_metadata, + ) + + self.executor.submit(on_success) + + @hookimpl + def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session): + self.log.debug("OpenLineage 
listener got notification about task instance failure") + + dagrun = task_instance.dag_run + task = task_instance.task + + task_uuid = OpenLineageAdapter.build_task_instance_run_id( + task.task_id, task_instance.execution_date, task_instance.try_number - 1 + ) + + @print_exception + def on_failure(): + task_metadata = self.extractor_manager.extract_metadata( + dagrun, task, complete=True, task_instance=task_instance + ) + + self.adapter.fail_task( + run_id=task_uuid, + job_name=get_job_name(task), + end_time=task_instance.end_date.isoformat(), + task=task_metadata, + ) + + self.executor.submit(on_failure) + + @hookimpl + def on_starting(self, component): + self.log.debug("on_starting: %s", component.__class__.__name__) + self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") Review Comment: That's a lot of max workers. Does it ever need more than 1? ########## airflow/providers/openlineage/plugins/facets.py: ########## @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from attrs import define, field + +from airflow.providers.openlineage import version as OPENLINEAGE_AIRFLOW_VERSION +from airflow.version import version as AIRFLOW_VERSION +from openlineage.client.facet import BaseFacet +from openlineage.client.utils import RedactMixin + + +@define(slots=False) +class AirflowVersionRunFacet(BaseFacet): + """Run facet containing task and DAG info""" + + operator: str = field() + taskInfo: dict[str, object] = field() + airflowVersion: str = field() + openlineageAirflowVersion: str = field() + + _additional_skip_redact: list[str] = [ + "operator", + "airflowVersion", + "openlineageAirflowVersion", + ] + + @classmethod + def from_dagrun_and_task(cls, dagrun, task): + # task.__dict__ may contain values uncastable to str + from airflow.providers.openlineage.utils import get_operator_class, to_json_encodable + + task_info = to_json_encodable(task) + task_info["dag_run"] = to_json_encodable(dagrun) + + return cls( + operator=f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}", + taskInfo=task_info, + airflowVersion=AIRFLOW_VERSION, + openlineageAirflowVersion=OPENLINEAGE_AIRFLOW_VERSION, + ) + + +@define(slots=False) +class AirflowRunArgsRunFacet(BaseFacet): + """Run facet pointing if DAG was triggered manually""" + + externalTrigger: bool = field(default=False) + + _additional_skip_redact: list[str] = ["externalTrigger"] + + +@define(slots=False) +class AirflowMappedTaskRunFacet(BaseFacet): + """Run facet containing information about mapped tasks""" + + mapIndex: int = field() + operatorClass: str = field() Review Comment: ```suggestion mapIndex: int operatorClass: str ``` ########## airflow/providers/openlineage/plugins/macros.py: ########## @@ -0,0 +1,67 @@ +#
Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import typing + +from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter + +if typing.TYPE_CHECKING: + from airflow.models import BaseOperator, TaskInstance + +_JOB_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", "default") + + +def lineage_run_id(task: BaseOperator, task_instance: TaskInstance): + """ + Macro function which returns the generated run id for a given task. This + can be used to forward the run id from a task to a child run so the job + hierarchy is preserved. Invoke as a jinja template, e.g. + + PythonOperator( + task_id='render_template', + python_callable=my_task_function, + op_args=['{{ lineage_run_id(task, task_instance) }}'], # lineage_run_id macro invoked Review Comment: Do you need to pass task and ti? Couldn't you look at `ti.task`? ########## airflow/providers/openlineage/utils/__init__.py: ########## @@ -0,0 +1,499 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +import importlib +import json +import logging +import os +import subprocess +from functools import wraps +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import asdict + +from airflow.models import DAG as AIRFLOW_DAG +from airflow.providers.openlineage.plugins.facets import ( + AirflowMappedTaskRunFacet, + AirflowRunArgsRunFacet, + AirflowRunFacet, + AirflowVersionRunFacet, +) + +# TODO: move this maybe to Airflow's logic? +from openlineage.client.utils import RedactMixin Review Comment: Yes please! ########## airflow/providers/openlineage/utils/__init__.py: ########## @@ -0,0 +1,499 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +import importlib +import json +import logging +import os +import subprocess +from functools import wraps +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import asdict + +from airflow.models import DAG as AIRFLOW_DAG +from airflow.providers.openlineage.plugins.facets import ( + AirflowMappedTaskRunFacet, + AirflowRunArgsRunFacet, + AirflowRunFacet, + AirflowVersionRunFacet, +) + +# TODO: move this maybe to Airflow's logic? +from openlineage.client.utils import RedactMixin + +if TYPE_CHECKING: + from airflow.models import DAG, BaseOperator, Connection, DagRun, TaskInstance + + +log = logging.getLogger(__name__) +_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + +def openlineage_job_name(dag_id: str, task_id: str) -> str: + return f"{dag_id}.{task_id}" + + +def get_operator_class(task: BaseOperator) -> type: + if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"): + return task.operator_class + return task.__class__ + + +def to_json_encodable(task: BaseOperator) -> dict[str, object]: + def _task_encoder(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + elif isinstance(obj, AIRFLOW_DAG): + return { + "dag_id": obj.dag_id, + "tags": obj.tags, + "schedule_interval": obj.schedule_interval, + } + else: + return str(obj) + + return json.loads(json.dumps(task.__dict__, default=_task_encoder)) + + +def url_to_https(url) -> str | None: + # Ensure URL exists + if not url: + return None + + base_url = None + if url.startswith("git@"): + part = url.split("git@")[1:2] + if part: + base_url = f'https://{part[0].replace(":", "/", 1)}' + elif url.startswith("https://"): + base_url = url + + if not base_url: + raise ValueError(f"Unable to extract location from: {url}") + + if base_url.endswith(".git"): + base_url = base_url[:-4] + return base_url + + +def get_location(file_path) -> str | None: + # Ensure file path exists + if not file_path: + return None + + # move to the file directory + abs_path = os.path.abspath(file_path) + file_name = os.path.basename(file_path) + cwd = os.path.dirname(abs_path) + + # get the repo url + repo_url = execute_git(cwd, ["config", "--get", "remote.origin.url"]) + + # get the repo relative path + repo_relative_path = execute_git(cwd, ["rev-parse", "--show-prefix"]) + + # get the commitId for the particular file + commit_id = execute_git(cwd, ["rev-list", "HEAD", "-1", "--", file_name]) + + # build the URL + base_url = url_to_https(repo_url) + if not base_url: + return None + + return f"{base_url}/blob/{commit_id}/{repo_relative_path}{file_name}" + + +def get_task_location(task): + try: + if hasattr(task, "file_path") and task.file_path: + return get_location(task.file_path) + else: + return get_location(task.dag.fileloc) + except Exception: + return None + + +def execute_git(cwd, params): + p = subprocess.Popen(["git"] + params, 
cwd=cwd, stdout=subprocess.PIPE, stderr=None) + p.wait(timeout=0.5) + out, err = p.communicate() + return out.decode("utf8").strip() + + +def get_connection_uri(conn): Review Comment: Second copy of this function in this PR :) ########## airflow/providers/openlineage/plugins/adapter.py: ########## @@ -0,0 +1,302 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import uuid +from typing import TYPE_CHECKING + +import requests.exceptions + +from airflow.providers.openlineage import version as OPENLINEAGE_PROVIDER_VERSION +from airflow.providers.openlineage.extractors import OperatorLineage +from airflow.providers.openlineage.utils import redact_with_exclusions +from airflow.utils.log.logging_mixin import LoggingMixin +from openlineage.client import OpenLineageClient, set_producer +from openlineage.client.facet import ( + BaseFacet, + DocumentationJobFacet, + ErrorMessageRunFacet, + NominalTimeRunFacet, + OwnershipJobFacet, + OwnershipJobFacetOwners, + ParentRunFacet, + ProcessingEngineRunFacet, + SourceCodeLocationJobFacet, +) +from openlineage.client.run import Job, Run, RunEvent, RunState + +if TYPE_CHECKING: + from airflow.models.dagrun import DagRun + + +_DAG_DEFAULT_NAMESPACE = "default" + +_DAG_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", _DAG_DEFAULT_NAMESPACE) + +_PRODUCER = f"https://github.com/apache/airflow/tree/providers-openlineage/" f"{OPENLINEAGE_PROVIDER_VERSION}" + +set_producer(_PRODUCER) + + +class OpenLineageAdapter(LoggingMixin): + """ + Adapter for translating Airflow metadata to OpenLineage events, + instead of directly creating them from Airflow code. 
+ """ + + def __init__(self, client=None): + super().__init__() + self._client = client + + def get_or_create_openlineage_client(self) -> OpenLineageClient: + if not self._client: + self._client = OpenLineageClient.from_environment() + return self._client + + def build_dag_run_id(self, dag_id, dag_run_id): + return str(uuid.uuid3(uuid.NAMESPACE_URL, f"{_DAG_NAMESPACE}.{dag_id}.{dag_run_id}")) + + @staticmethod + def build_task_instance_run_id(task_id, execution_date, try_number): + return str( + uuid.uuid3( + uuid.NAMESPACE_URL, + f"{_DAG_NAMESPACE}.{task_id}.{execution_date}.{try_number}", + ) + ) + + def emit(self, event: RunEvent): + event = redact_with_exclusions(event) + try: + return self.get_or_create_openlineage_client().emit(event) + except requests.exceptions.RequestException: + self.log.exception(f"Failed to emit OpenLineage event of id {event.run.runId}") + + def start_task( + self, + run_id: str, + job_name: str, + job_description: str, + event_time: str, + parent_job_name: str | None, + parent_run_id: str | None, + code_location: str | None, + nominal_start_time: str, + nominal_end_time: str, + owners: list[str], + task: OperatorLineage | None, + run_facets: dict[str, type[BaseFacet]] | None = None, # Custom run facets + ) -> str: + """ + Emits openlineage event of type START + :param run_id: globally unique identifier of task in dag run + :param job_name: globally unique identifier of task in dag + :param job_description: user provided description of job + :param event_time: + :param parent_job_name: the name of the parent job (typically the DAG, + but possibly a task group) + :param parent_run_id: identifier of job spawning this task + :param code_location: file path or URL of DAG file + :param nominal_start_time: scheduled time of dag run + :param nominal_end_time: following schedule of dag run + :param owners: list of owners of DAG + :param task: metadata container with information extracted from operator + :param run_facets: custom run facets + :return: Review Comment: ```suggestion :return: OpenLineage event run ID ``` (I guessed) ########## airflow/providers/openlineage/plugins/adapter.py: ########## @@ -0,0 +1,302 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import os +import uuid +from typing import TYPE_CHECKING + +import requests.exceptions + +from airflow.providers.openlineage import version as OPENLINEAGE_PROVIDER_VERSION +from airflow.providers.openlineage.extractors import OperatorLineage +from airflow.providers.openlineage.utils import redact_with_exclusions +from airflow.utils.log.logging_mixin import LoggingMixin +from openlineage.client import OpenLineageClient, set_producer +from openlineage.client.facet import ( + BaseFacet, + DocumentationJobFacet, + ErrorMessageRunFacet, + NominalTimeRunFacet, + OwnershipJobFacet, + OwnershipJobFacetOwners, + ParentRunFacet, + ProcessingEngineRunFacet, + SourceCodeLocationJobFacet, +) +from openlineage.client.run import Job, Run, RunEvent, RunState + +if TYPE_CHECKING: + from airflow.models.dagrun import DagRun + + +_DAG_DEFAULT_NAMESPACE = "default" + +_DAG_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", _DAG_DEFAULT_NAMESPACE) Review Comment: If you want to make it easier to upgrade you can do `conf.get("openlineage", "namespace", fallback=os.getenv("OPENLINEAGE_NAMESPACE", _DAG_DEFAULT_NAMESPACE))` ########## airflow/providers/openlineage/utils/__init__.py: ########## @@ -0,0 +1,499 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +import importlib +import json +import logging +import os +import subprocess +from functools import wraps +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import asdict + +from airflow.models import DAG as AIRFLOW_DAG +from airflow.providers.openlineage.plugins.facets import ( + AirflowMappedTaskRunFacet, + AirflowRunArgsRunFacet, + AirflowRunFacet, + AirflowVersionRunFacet, +) + +# TODO: move this maybe to Airflow's logic? 
+from openlineage.client.utils import RedactMixin + +if TYPE_CHECKING: + from airflow.models import DAG, BaseOperator, Connection, DagRun, TaskInstance + + +log = logging.getLogger(__name__) +_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + +def openlineage_job_name(dag_id: str, task_id: str) -> str: + return f"{dag_id}.{task_id}" + + +def get_operator_class(task: BaseOperator) -> type: + if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"): + return task.operator_class + return task.__class__ + + +def to_json_encodable(task: BaseOperator) -> dict[str, object]: + def _task_encoder(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + elif isinstance(obj, AIRFLOW_DAG): + return { + "dag_id": obj.dag_id, + "tags": obj.tags, + "schedule_interval": obj.schedule_interval, + } + else: + return str(obj) + + return json.loads(json.dumps(task.__dict__, default=_task_encoder)) + + +def url_to_https(url) -> str | None: + # Ensure URL exists + if not url: + return None + + base_url = None + if url.startswith("git@"): + part = url.split("git@")[1:2] + if part: + base_url = f'https://{part[0].replace(":", "/", 1)}' + elif url.startswith("https://"): + base_url = url + + if not base_url: + raise ValueError(f"Unable to extract location from: {url}") + + if base_url.endswith(".git"): + base_url = base_url[:-4] + return base_url + + +def get_location(file_path) -> str | None: + # Ensure file path exists + if not file_path: + return None + + # move to the file directory + abs_path = os.path.abspath(file_path) + file_name = os.path.basename(file_path) + cwd = os.path.dirname(abs_path) + + # get the repo url + repo_url = execute_git(cwd, ["config", "--get", "remote.origin.url"]) + + # get the repo relative path + repo_relative_path = execute_git(cwd, ["rev-parse", "--show-prefix"]) + + # get the commitId for the particular file + commit_id = execute_git(cwd, ["rev-list", "HEAD", "-1", "--", file_name]) + + # build the URL + base_url = url_to_https(repo_url) + if not base_url: + return None + + return f"{base_url}/blob/{commit_id}/{repo_relative_path}{file_name}" + + +def get_task_location(task): + try: + if hasattr(task, "file_path") and task.file_path: + return get_location(task.file_path) + else: + return get_location(task.dag.fileloc) + except Exception: + return None + + +def execute_git(cwd, params): + p = subprocess.Popen(["git"] + params, cwd=cwd, stdout=subprocess.PIPE, stderr=None) + p.wait(timeout=0.5) + out, err = p.communicate() + return out.decode("utf8").strip() + + +def get_connection_uri(conn): + """ + Return the connection URI for the given ID. We first attempt to lookup + the connection URI via AIRFLOW_CONN_<conn_id>, else fallback on querying + the Airflow's connection table. 
+ """ + conn_uri = conn.get_uri() + parsed = urlparse(conn_uri) + + # Remove username and password + netloc = f"{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "") + parsed = parsed._replace(netloc=netloc) + if parsed.query: + query_dict = dict(parse_qsl(parsed.query)) + if conn.EXTRA_KEY in query_dict: + query_dict = json.loads(query_dict[conn.EXTRA_KEY]) + filtered_qs = {k: v for k, v in query_dict.items() if not _filtered_query_params(k)} + parsed = parsed._replace(query=urlencode(filtered_qs)) + return urlunparse(parsed) + + +def _filtered_query_params(k: str): + unfiltered_snowflake_keys = [ + "extra__snowflake__warehouse", + "extra__snowflake__account", + "extra__snowflake__database", + ] + filtered_key_substrings = [ + "aws_access_key_id", + "aws_secret_access_key", + "extra__snowflake__", + ] + return k not in unfiltered_snowflake_keys and any(substr in k for substr in filtered_key_substrings) + + +def get_normalized_postgres_connection_uri(conn): + """ + URIs starting with postgresql:// and postgres:// are both valid + PostgreSQL connection strings. This function normalizes it to + postgres:// as canonical name according to OpenLineage spec. + """ + uri = get_connection_uri(conn) + if uri.startswith("postgresql"): + uri = uri.replace("postgresql", "postgres", 1) + return uri + + +def get_connection(conn_id) -> Connection | None: + from airflow.hooks.base import BaseHook + + try: + return BaseHook.get_connection(conn_id=conn_id) + except Exception: + return None + + +def get_job_name(task): + return f"{task.dag_id}.{task.task_id}" + + +def get_custom_facets( + dagrun, task, is_external_trigger: bool, task_instance: TaskInstance | None = None +) -> dict[str, Any]: + custom_facets = { + "airflow_runArgs": AirflowRunArgsRunFacet(is_external_trigger), + "airflow_version": AirflowVersionRunFacet.from_dagrun_and_task(dagrun, task), + } + # check for -1 comes from SmartSensor compatibility with dynamic task mapping + # this comes from Airflow code + if hasattr(task_instance, "map_index") and getattr(task_instance, "map_index") != -1: + custom_facets["airflow_mappedTask"] = AirflowMappedTaskRunFacet.from_task_instance(task_instance) + return custom_facets + + +class InfoJsonEncodable(dict): + """ + Airflow objects might not be json-encodable overall. + + The class provides additional attributes to control + what and how is encoded: + * renames: a dictionary of attribute name changes + * casts: a dictionary consisting of attribute names + and corresponding methods that should change + object value + * includes: list of attributes to be included in encoding + * excludes: list of attributes to be excluded from encoding + + Don't use both includes and excludes. 
+ """ + + renames: dict[str, str] = {} + casts: dict[str, Any] = {} + includes: list[str] = [] + excludes: list[str] = [] + + def __init__(self, obj): + self.obj = obj + self._fields = [] + + self._cast_fields() + self._rename_fields() + self._include_fields() + dict.__init__( + self, + **{field: InfoJsonEncodable._cast_basic_types(getattr(self, field)) for field in self._fields}, + ) + + @staticmethod + def _cast_basic_types(value): + if isinstance(value, datetime.datetime): + return value.isoformat() + if isinstance(value, (set, list, tuple)): + return str(list(value)) + return value + + def _rename_fields(self): + for field, renamed in self.renames.items(): + if hasattr(self.obj, field): + setattr(self, renamed, getattr(self.obj, field)) + self._fields.append(renamed) + + def _cast_fields(self): + for field, func in self.casts.items(): + setattr(self, field, func(self.obj)) + self._fields.append(field) + + def _include_fields(self): + if self.includes and self.excludes: + raise Exception("Don't use both includes and excludes.") + if self.includes: + for field in self.includes: + if field in self._fields or not hasattr(self.obj, field): + continue + setattr(self, field, getattr(self.obj, field)) + self._fields.append(field) + else: + for field, val in self.obj.__dict__.items(): + if field in self._fields or field in self.excludes or field in self.renames: + continue + setattr(self, field, val) + self._fields.append(field) + + +class DagInfo(InfoJsonEncodable): + """Defines encoding DAG object to JSON.""" + + includes = ["dag_id", "schedule_interval", "tags", "start_date"] + casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None} + renames = {"_dag_id": "dag_id"} + + +class DagRunInfo(InfoJsonEncodable): + """Defines encoding DagRun object to JSON.""" + + includes = [ + "conf", + "dag_id", + "data_interval_start", + "data_interval_end", + "external_trigger", + "run_id", + "run_type", + "start_date", + ] + + +class TaskInstanceInfo(InfoJsonEncodable): + """Defines encoding TaskInstance object to JSON.""" + + includes = ["duration", "try_number", "pool"] + casts = { + "map_index": lambda ti: ti.map_index + if hasattr(ti, "map_index") and getattr(ti, "map_index") != -1 + else None + } + + +class TaskInfo(InfoJsonEncodable): + """Defines encoding BaseOperator/AbstractOperator object to JSON.""" + + renames = { + "_BaseOperator__init_kwargs": "args", + "_BaseOperator__from_mapped": "mapped", + "_downstream_task_ids": "downstream_task_ids", + "_upstream_task_ids": "upstream_task_ids", + } + excludes = [ + "_BaseOperator__instantiated", + "_dag", + "_hook", + "_log", + "_outlets", + "_inlets", + "_lock_for_execution", + "handler", + "params", + "python_callable", + "retry_delay", + ] + casts = { + "operator_class": lambda task: f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}", # noqa + "task_group": lambda task: TaskGroupInfo(task.task_group) + if hasattr(task, "task_group") and getattr(task.task_group, "_group_id", None) + else None, + } + + +class TaskGroupInfo(InfoJsonEncodable): + """Defines encoding TaskGroup object to JSON.""" + + renames = { + "_group_id": "group_id", + } + includes = [ + "downstream_group_ids", + "downstream_task_ids", + "prefix_group_id", + "tooltip", + "upstream_group_ids", + "upstream_task_ids", + ] + + +def get_airflow_run_facet( + dag_run: DagRun, + dag: DAG, + task_instance: TaskInstance, + task: BaseOperator, + task_uuid: str, +): + return { + "airflow": json.loads( + json.dumps( + 
asdict( + AirflowRunFacet( + dag=DagInfo(dag), + dagRun=DagRunInfo(dag_run), + taskInstance=TaskInstanceInfo(task_instance), + task=TaskInfo(task), + taskUuid=task_uuid, + ) + ), + default=str, + ) + ) + } + + +def get_dagrun_start_end(dagrun: DagRun, dag: DAG): + try: + return dagrun.data_interval_start, dagrun.data_interval_end + except AttributeError: + # Airflow < 2.2 before adding data interval + pass + start = dagrun.execution_date + end = dag.following_schedule(start) + return start, end or start + + +def import_from_string(path: str): + try: + module_path, target = path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, target) + except Exception as e: + log.warning(e) + raise ImportError(f"Failed to import {path}") from e + + +def try_import_from_string(path: str): + try: + return import_from_string(path) + except ImportError: + return None + + +def redact_with_exclusions(source: Any): Review Comment: A lot of this looks like a reimplementation of the existing masking code. How is it different? ########## airflow/providers/openlineage/utils/__init__.py: ########## @@ -0,0 +1,499 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +import importlib +import json +import logging +import os +import subprocess +from functools import wraps +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import asdict + +from airflow.models import DAG as AIRFLOW_DAG +from airflow.providers.openlineage.plugins.facets import ( + AirflowMappedTaskRunFacet, + AirflowRunArgsRunFacet, + AirflowRunFacet, + AirflowVersionRunFacet, +) + +# TODO: move this maybe to Airflow's logic? +from openlineage.client.utils import RedactMixin Review Comment: It looks like you already are in `redact_with_exclusions`?
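For comparison, the existing masking helper the comments refer to lives in `airflow.utils.log.secrets_masker`; a small illustration of what it already handles (the sample values are made up and the output shown is approximate):

```python
from airflow.utils.log.secrets_masker import redact

# The built-in masker recurses into nested containers and redacts values
# whose keys look sensitive (password, secret, api_key, ...).
payload = {"connection": {"host": "db.example.com", "password": "hunter2"}}
print(redact(payload))
# Roughly: {'connection': {'host': 'db.example.com', 'password': '***'}}
```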
########## airflow/providers/openlineage/extractors/base.py: ########## @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import Factory, define + +from airflow.utils.log.logging_mixin import LoggingMixin +from openlineage.client.facet import BaseFacet +from openlineage.client.run import Dataset + + +@define +class OperatorLineage: + """Structure returned from lineage extraction.""" + + inputs: list[Dataset] = Factory(list) + outputs: list[Dataset] = Factory(list) + run_facets: dict[str, BaseFacet] = Factory(dict) + job_facets: dict[str, BaseFacet] = Factory(dict) + + +class BaseExtractor(ABC, LoggingMixin): + """ + Abstract base extractor class. + + This is used mostly to maintain support for custom extractors. + """ + + _allowed_query_params: list[str] = [] + + def __init__(self, operator): + super().__init__() + self.operator = operator + self.patch() + + def patch(self): + # Extractor should register extension methods or patches to operator here + pass + + @classmethod Review Comment: ```suggestion @classmethod @abstractmethod ``` I think? (`abstractclassmethod` is deprecated, so stacking `@abstractmethod` under `@classmethod` is the supported spelling.) ########## airflow/providers/openlineage/extractors/extractors.py: ########## @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os + +from airflow.providers.openlineage.extractors.base import BaseExtractor, DefaultExtractor + + +class Extractors: + """ + This exposes implemented extractors, while hiding ones that require additional, unmet + dependency. Patchers are a category of extractor that needs to hook up to operator's + internals during DAG creation. + """ + + def __init__(self): + # Do not expose extractors relying on external dependencies that are not installed + self.extractors = {} + self.default_extractor = DefaultExtractor + + # Comma-separated extractors in OPENLINEAGE_EXTRACTORS variable. + # Extractors should implement BaseExtractor + from airflow.providers.openlineage.utils import import_from_string + + env_extractors = os.getenv("OPENLINEAGE_EXTRACTORS") + if env_extractors is not None: + for extractor in env_extractors.split(";"): Review Comment: Commas are more often used for this purpose in Airflow ########## airflow/providers/openlineage/plugins/facets.py: ########## @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from attrs import define, field + +from airflow.providers.openlineage import version as OPENLINEAGE_AIRFLOW_VERSION +from airflow.version import version as AIRFLOW_VERSION +from openlineage.client.facet import BaseFacet +from openlineage.client.utils import RedactMixin + + +@define(slots=False) +class AirflowVersionRunFacet(BaseFacet): + """Run facet containing task and DAG info""" + + operator: str = field() + taskInfo: dict[str, object] = field() + airflowVersion: str = field() + openlineageAirflowVersion: str = field() + + _additional_skip_redact: list[str] = [ + "operator", + "airflowVersion", + "openlineageAirflowVersion", + ] + + @classmethod + def from_dagrun_and_task(cls, dagrun, task): + # task.__dict__ may contain values uncastable to str + from airflow.providers.openlineage.utils import get_operator_class, to_json_encodable + + task_info = to_json_encodable(task) + task_info["dag_run"] = to_json_encodable(dagrun) + + return cls( + operator=f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}", + taskInfo=task_info, + airflowVersion=AIRFLOW_VERSION, + openlineageAirflowVersion=OPENLINEAGE_AIRFLOW_VERSION, + ) + + +@define(slots=False) +class AirflowRunArgsRunFacet(BaseFacet): + """Run facet pointing if DAG was triggered manually""" + + externalTrigger: bool = field(default=False) + + _additional_skip_redact: list[str] = ["externalTrigger"] + + +@define(slots=False) +class AirflowMappedTaskRunFacet(BaseFacet): + """Run facet containing information about mapped tasks""" + + mapIndex: int = field() + operatorClass: str = field() + + _additional_skip_redact: list[str] = ["operatorClass"] + + @classmethod + def from_task_instance(cls, task_instance): + task = task_instance.task + from airflow.providers.openlineage.utils import get_operator_class + + return cls( + mapIndex=task_instance.map_index, + operatorClass=f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}", + ) + + +@define(slots=False) +class AirflowRunFacet(BaseFacet): + """Composite Airflow run facet.""" + + dag: dict = field() + dagRun: dict = field() + task: dict = field() + taskInstance: dict = field() + taskUuid: str = field() + + +@define(slots=False) +class UnknownOperatorInstance(RedactMixin): + """ + Describes an unknown operator - specifies the (class) name of the operator + and its properties + """ + + name: str = field() + properties: dict[str, object] = field() + type: str = field(default="operator") Review Comment: ```suggestion name: str properties: dict[str, object] type: str = "operator" ``` Same effect ########## airflow/providers/openlineage/plugins/facets.py: ########## @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from attrs import define, field + +from airflow.providers.openlineage import version as OPENLINEAGE_AIRFLOW_VERSION +from airflow.version import version as AIRFLOW_VERSION +from openlineage.client.facet import BaseFacet +from openlineage.client.utils import RedactMixin + + +@define(slots=False) +class AirflowVersionRunFacet(BaseFacet): + """Run facet containing task and DAG info""" + + operator: str = field() + taskInfo: dict[str, object] = field() + airflowVersion: str = field() + openlineageAirflowVersion: str = field() + + _additional_skip_redact: list[str] = [ + "operator", + "airflowVersion", + "openlineageAirflowVersion", + ] + + @classmethod + def from_dagrun_and_task(cls, dagrun, task): + # task.__dict__ may contain values uncastable to str + from airflow.providers.openlineage.utils import get_operator_class, to_json_encodable + + task_info = to_json_encodable(task) + task_info["dag_run"] = to_json_encodable(dagrun) + + return cls( + operator=f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}", + taskInfo=task_info, + airflowVersion=AIRFLOW_VERSION, + openlineageAirflowVersion=OPENLINEAGE_AIRFLOW_VERSION, + ) + + +@define(slots=False) +class AirflowRunArgsRunFacet(BaseFacet): + """Run facet pointing if DAG was triggered manually""" + + externalTrigger: bool = field(default=False) Review Comment: ```suggestion externalTrigger: bool = False ``` ########## airflow/providers/openlineage/plugins/macros.py: ########## @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import typing + +from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter + +if typing.TYPE_CHECKING: + from airflow.models import BaseOperator, TaskInstance + +_JOB_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", "default") + + +def lineage_run_id(task: BaseOperator, task_instance: TaskInstance): + """ + Macro function which returns the generated run id for a given task. This + can be used to forward the run id from a task to a child run so the job + hierarchy is preserved. Invoke as a jinja template, e.g. 
+
+ PythonOperator(
+ task_id='render_template',
+ python_callable=my_task_function,
+ op_args=['{{ lineage_run_id(task, task_instance) }}'], # lineage_run_id macro invoked
+ provide_context=False,
+ dag=dag
+ )
+ """
+ return OpenLineageAdapter.build_task_instance_run_id(
+ task.task_id, task_instance.execution_date, task_instance.try_number
+ )
+
+
+def lineage_parent_id(run_id: str, task: BaseOperator, task_instance: TaskInstance):
+ """
+ Macro function which returns the generated job and run id for a given task. This
+ can be used to forward the ids from a task to a child run so the job
+ hierarchy is preserved. Child run can create ParentRunFacet from those ids.
+ Invoke as a jinja template, e.g.
+
+ PythonOperator(
+ task_id='render_template',
+ python_callable=my_task_function,
+ op_args=['{{ lineage_parent_id(run_id, task, task_instance) }}'], # macro invoked

Review Comment: Where does `run_id` come from? That isn't a key in the context (at least not to my memory)

########## airflow/providers/openlineage/plugins/openlineage.py: ########## @@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+
+from airflow.plugins_manager import AirflowPlugin
+from airflow.providers.openlineage.plugins.macros import lineage_parent_id, lineage_run_id
+
+
+def _is_disabled():
+ return os.getenv("OPENLINEAGE_DISABLED", None) in [True, "true", "True"]
+
+
+if _is_disabled(): # type: ignore
+ # Provide empty plugin when OL is disabled
+ class OpenLineageProviderPlugin(AirflowPlugin):
+ """OpenLineage plugin that provides macros only"""
+
+ name = "OpenLineageProviderPlugin"
+ macros = [lineage_run_id, lineage_parent_id]
+
+else:
+ from airflow.providers.openlineage.plugins.listener import OpenLineageListener
+
+ # Provide entrypoint airflow plugin that registers listener module
+ class OpenLineageProviderPlugin(AirflowPlugin): # type: ignore
+ """OpenLineage plugin that provides listener module and macros"""
+
+ name = "OpenLineageProviderPlugin"
+ listeners = [OpenLineageListener()]
+ macros = [lineage_run_id, lineage_parent_id]

Review Comment:
```suggestion
class OpenLineageProviderPlugin(AirflowPlugin):
    name = "OpenLineageProviderPlugin"
    macros = [lineage_run_id, lineage_parent_id]
    if not _is_disabled():
        from airflow.providers.openlineage.plugins.listener import OpenLineageListener

        listeners = [OpenLineageListener()]
```

########## airflow/providers/openlineage/plugins/openlineage.py: ########## @@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os + +from airflow.plugins_manager import AirflowPlugin +from airflow.providers.openlineage.plugins.macros import lineage_parent_id, lineage_run_id + + +def _is_disabled(): Review Comment: ```suggestion def _is_disabled() -> bool: ``` ########## airflow/providers/openlineage/utils/__init__.py: ########## @@ -0,0 +1,499 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +import importlib +import json +import logging +import os +import subprocess +from functools import wraps +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import asdict + +from airflow.models import DAG as AIRFLOW_DAG +from airflow.providers.openlineage.plugins.facets import ( + AirflowMappedTaskRunFacet, + AirflowRunArgsRunFacet, + AirflowRunFacet, + AirflowVersionRunFacet, +) + +# TODO: move this maybe to Airflow's logic? 
+from openlineage.client.utils import RedactMixin + +if TYPE_CHECKING: + from airflow.models import DAG, BaseOperator, Connection, DagRun, TaskInstance + + +log = logging.getLogger(__name__) +_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + +def openlineage_job_name(dag_id: str, task_id: str) -> str: + return f"{dag_id}.{task_id}" + + +def get_operator_class(task: BaseOperator) -> type: + if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"): + return task.operator_class + return task.__class__ + + +def to_json_encodable(task: BaseOperator) -> dict[str, object]: + def _task_encoder(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + elif isinstance(obj, AIRFLOW_DAG): + return { + "dag_id": obj.dag_id, + "tags": obj.tags, + "schedule_interval": obj.schedule_interval, + } + else: + return str(obj) + + return json.loads(json.dumps(task.__dict__, default=_task_encoder)) + + +def url_to_https(url) -> str | None: + # Ensure URL exists + if not url: + return None + + base_url = None + if url.startswith("git@"): + part = url.split("git@")[1:2] + if part: + base_url = f'https://{part[0].replace(":", "/", 1)}' + elif url.startswith("https://"): + base_url = url + + if not base_url: + raise ValueError(f"Unable to extract location from: {url}") + + if base_url.endswith(".git"): + base_url = base_url[:-4] + return base_url + + +def get_location(file_path) -> str | None: + # Ensure file path exists + if not file_path: + return None + + # move to the file directory + abs_path = os.path.abspath(file_path) + file_name = os.path.basename(file_path) + cwd = os.path.dirname(abs_path) + + # get the repo url + repo_url = execute_git(cwd, ["config", "--get", "remote.origin.url"]) + + # get the repo relative path + repo_relative_path = execute_git(cwd, ["rev-parse", "--show-prefix"]) + + # get the commitId for the particular file + commit_id = execute_git(cwd, ["rev-list", "HEAD", "-1", "--", file_name]) + + # build the URL + base_url = url_to_https(repo_url) + if not base_url: + return None + + return f"{base_url}/blob/{commit_id}/{repo_relative_path}{file_name}" + + +def get_task_location(task): + try: + if hasattr(task, "file_path") and task.file_path: + return get_location(task.file_path) + else: + return get_location(task.dag.fileloc) + except Exception: + return None + + +def execute_git(cwd, params): + p = subprocess.Popen(["git"] + params, cwd=cwd, stdout=subprocess.PIPE, stderr=None) + p.wait(timeout=0.5) + out, err = p.communicate() + return out.decode("utf8").strip() + + +def get_connection_uri(conn): + """ + Return the connection URI for the given ID. We first attempt to lookup + the connection URI via AIRFLOW_CONN_<conn_id>, else fallback on querying + the Airflow's connection table. + """ + conn_uri = conn.get_uri() + parsed = urlparse(conn_uri) + + # Remove username and password + netloc = f"{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "") + parsed = parsed._replace(netloc=netloc) + if parsed.query: + query_dict = dict(parse_qsl(parsed.query)) + if conn.EXTRA_KEY in query_dict: + query_dict = json.loads(query_dict[conn.EXTRA_KEY]) + filtered_qs = {k: v for k, v in query_dict.items() if not _filtered_query_params(k)} + parsed = parsed._replace(query=urlencode(filtered_qs)) + return urlunparse(parsed) + + +def _filtered_query_params(k: str): + unfiltered_snowflake_keys = [ + "extra__snowflake__warehouse", + "extra__snowflake__account", + "extra__snowflake__database", + ] Review Comment: Do these need to live here? 
Shouldn't they be in the facet generation code inside the snowflake provider? ########## airflow/providers/openlineage/utils/__init__.py: ########## @@ -0,0 +1,499 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +import importlib +import json +import logging +import os +import subprocess +from functools import wraps +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import asdict + +from airflow.models import DAG as AIRFLOW_DAG +from airflow.providers.openlineage.plugins.facets import ( + AirflowMappedTaskRunFacet, + AirflowRunArgsRunFacet, + AirflowRunFacet, + AirflowVersionRunFacet, +) + +# TODO: move this maybe to Airflow's logic? +from openlineage.client.utils import RedactMixin + +if TYPE_CHECKING: + from airflow.models import DAG, BaseOperator, Connection, DagRun, TaskInstance + + +log = logging.getLogger(__name__) +_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + +def openlineage_job_name(dag_id: str, task_id: str) -> str: + return f"{dag_id}.{task_id}" + + +def get_operator_class(task: BaseOperator) -> type: + if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"): + return task.operator_class + return task.__class__ + + +def to_json_encodable(task: BaseOperator) -> dict[str, object]: + def _task_encoder(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + elif isinstance(obj, AIRFLOW_DAG): + return { + "dag_id": obj.dag_id, + "tags": obj.tags, + "schedule_interval": obj.schedule_interval, + } + else: + return str(obj) + + return json.loads(json.dumps(task.__dict__, default=_task_encoder)) + + +def url_to_https(url) -> str | None: + # Ensure URL exists + if not url: + return None + + base_url = None + if url.startswith("git@"): + part = url.split("git@")[1:2] + if part: + base_url = f'https://{part[0].replace(":", "/", 1)}' + elif url.startswith("https://"): + base_url = url + + if not base_url: + raise ValueError(f"Unable to extract location from: {url}") + + if base_url.endswith(".git"): + base_url = base_url[:-4] + return base_url + + +def get_location(file_path) -> str | None: + # Ensure file path exists + if not file_path: + return None + + # move to the file directory + abs_path = os.path.abspath(file_path) + file_name = os.path.basename(file_path) + cwd = os.path.dirname(abs_path) + + # get the repo url + repo_url = execute_git(cwd, ["config", "--get", "remote.origin.url"]) + + # get the repo relative path + repo_relative_path = execute_git(cwd, ["rev-parse", "--show-prefix"]) + + # get the commitId for the particular file + commit_id = execute_git(cwd, ["rev-list", "HEAD", "-1", "--", file_name]) + + # build the URL + base_url = url_to_https(repo_url) + if not base_url: 
+ return None + + return f"{base_url}/blob/{commit_id}/{repo_relative_path}{file_name}" + + +def get_task_location(task): + try: + if hasattr(task, "file_path") and task.file_path: + return get_location(task.file_path) + else: + return get_location(task.dag.fileloc) + except Exception: + return None + + +def execute_git(cwd, params): + p = subprocess.Popen(["git"] + params, cwd=cwd, stdout=subprocess.PIPE, stderr=None) + p.wait(timeout=0.5) + out, err = p.communicate() + return out.decode("utf8").strip() + + +def get_connection_uri(conn): + """ + Return the connection URI for the given ID. We first attempt to lookup + the connection URI via AIRFLOW_CONN_<conn_id>, else fallback on querying + the Airflow's connection table. + """ + conn_uri = conn.get_uri() + parsed = urlparse(conn_uri) + + # Remove username and password + netloc = f"{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "") + parsed = parsed._replace(netloc=netloc) + if parsed.query: + query_dict = dict(parse_qsl(parsed.query)) + if conn.EXTRA_KEY in query_dict: + query_dict = json.loads(query_dict[conn.EXTRA_KEY]) + filtered_qs = {k: v for k, v in query_dict.items() if not _filtered_query_params(k)} + parsed = parsed._replace(query=urlencode(filtered_qs)) + return urlunparse(parsed) + + +def _filtered_query_params(k: str): + unfiltered_snowflake_keys = [ + "extra__snowflake__warehouse", + "extra__snowflake__account", + "extra__snowflake__database", + ] + filtered_key_substrings = [ + "aws_access_key_id", + "aws_secret_access_key", + "extra__snowflake__", + ] + return k not in unfiltered_snowflake_keys and any(substr in k for substr in filtered_key_substrings) + + +def get_normalized_postgres_connection_uri(conn): + """ + URIs starting with postgresql:// and postgres:// are both valid + PostgreSQL connection strings. This function normalizes it to + postgres:// as canonical name according to OpenLineage spec. + """ + uri = get_connection_uri(conn) + if uri.startswith("postgresql"): + uri = uri.replace("postgresql", "postgres", 1) + return uri + + +def get_connection(conn_id) -> Connection | None: + from airflow.hooks.base import BaseHook + + try: + return BaseHook.get_connection(conn_id=conn_id) + except Exception: + return None + + +def get_job_name(task): + return f"{task.dag_id}.{task.task_id}" + + +def get_custom_facets( + dagrun, task, is_external_trigger: bool, task_instance: TaskInstance | None = None +) -> dict[str, Any]: + custom_facets = { + "airflow_runArgs": AirflowRunArgsRunFacet(is_external_trigger), + "airflow_version": AirflowVersionRunFacet.from_dagrun_and_task(dagrun, task), + } + # check for -1 comes from SmartSensor compatibility with dynamic task mapping + # this comes from Airflow code + if hasattr(task_instance, "map_index") and getattr(task_instance, "map_index") != -1: + custom_facets["airflow_mappedTask"] = AirflowMappedTaskRunFacet.from_task_instance(task_instance) + return custom_facets + + +class InfoJsonEncodable(dict): + """ + Airflow objects might not be json-encodable overall. + + The class provides additional attributes to control + what and how is encoded: + * renames: a dictionary of attribute name changes + * casts: a dictionary consisting of attribute names + and corresponding methods that should change + object value + * includes: list of attributes to be included in encoding + * excludes: list of attributes to be excluded from encoding + + Don't use both includes and excludes. 
+ """ + + renames: dict[str, str] = {} + casts: dict[str, Any] = {} + includes: list[str] = [] + excludes: list[str] = [] + + def __init__(self, obj): + self.obj = obj + self._fields = [] + + self._cast_fields() + self._rename_fields() + self._include_fields() + dict.__init__( + self, + **{field: InfoJsonEncodable._cast_basic_types(getattr(self, field)) for field in self._fields}, + ) + + @staticmethod + def _cast_basic_types(value): + if isinstance(value, datetime.datetime): + return value.isoformat() + if isinstance(value, (set, list, tuple)): + return str(list(value)) + return value + + def _rename_fields(self): + for field, renamed in self.renames.items(): + if hasattr(self.obj, field): + setattr(self, renamed, getattr(self.obj, field)) + self._fields.append(renamed) + + def _cast_fields(self): + for field, func in self.casts.items(): + setattr(self, field, func(self.obj)) + self._fields.append(field) + + def _include_fields(self): + if self.includes and self.excludes: + raise Exception("Don't use both includes and excludes.") + if self.includes: + for field in self.includes: + if field in self._fields or not hasattr(self.obj, field): + continue + setattr(self, field, getattr(self.obj, field)) + self._fields.append(field) + else: + for field, val in self.obj.__dict__.items(): + if field in self._fields or field in self.excludes or field in self.renames: + continue + setattr(self, field, val) + self._fields.append(field) + + +class DagInfo(InfoJsonEncodable): + """Defines encoding DAG object to JSON.""" + + includes = ["dag_id", "schedule_interval", "tags", "start_date"] + casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None} + renames = {"_dag_id": "dag_id"} + + +class DagRunInfo(InfoJsonEncodable): + """Defines encoding DagRun object to JSON.""" + + includes = [ + "conf", + "dag_id", + "data_interval_start", + "data_interval_end", + "external_trigger", + "run_id", + "run_type", + "start_date", + ] + + +class TaskInstanceInfo(InfoJsonEncodable): + """Defines encoding TaskInstance object to JSON.""" + + includes = ["duration", "try_number", "pool"] + casts = { + "map_index": lambda ti: ti.map_index + if hasattr(ti, "map_index") and getattr(ti, "map_index") != -1 + else None + } + + +class TaskInfo(InfoJsonEncodable): + """Defines encoding BaseOperator/AbstractOperator object to JSON.""" + + renames = { + "_BaseOperator__init_kwargs": "args", + "_BaseOperator__from_mapped": "mapped", + "_downstream_task_ids": "downstream_task_ids", + "_upstream_task_ids": "upstream_task_ids", + } + excludes = [ + "_BaseOperator__instantiated", + "_dag", + "_hook", + "_log", + "_outlets", + "_inlets", + "_lock_for_execution", + "handler", + "params", + "python_callable", + "retry_delay", + ] + casts = { + "operator_class": lambda task: f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}", # noqa + "task_group": lambda task: TaskGroupInfo(task.task_group) + if hasattr(task, "task_group") and getattr(task.task_group, "_group_id", None) + else None, + } + + +class TaskGroupInfo(InfoJsonEncodable): + """Defines encoding TaskGroup object to JSON.""" + + renames = { + "_group_id": "group_id", + } + includes = [ + "downstream_group_ids", + "downstream_task_ids", + "prefix_group_id", + "tooltip", + "upstream_group_ids", + "upstream_task_ids", + ] + + +def get_airflow_run_facet( + dag_run: DagRun, + dag: DAG, + task_instance: TaskInstance, + task: BaseOperator, + task_uuid: str, +): + return { + "airflow": json.loads( + json.dumps( + 
asdict( + AirflowRunFacet( + dag=DagInfo(dag), + dagRun=DagRunInfo(dag_run), + taskInstance=TaskInstanceInfo(task_instance), + task=TaskInfo(task), + taskUuid=task_uuid, + ) + ), + default=str, + ) + ) + } + + +def get_dagrun_start_end(dagrun: DagRun, dag: DAG): + try: + return dagrun.data_interval_start, dagrun.data_interval_end + except AttributeError: + # Airflow < 2.2 before adding data interval + pass + start = dagrun.execution_date + end = dag.following_schedule(start) + return start, end or start + + +def import_from_string(path: str): + try: + module_path, target = path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, target) + except Exception as e: + log.warning(e) + raise ImportError(f"Failed to import {path}") from e + + +def try_import_from_string(path: str): + try: + return import_from_string(path) + except ImportError: + return None Review Comment: `from contextlib import suppress` and then ```suggestion with suppress(ImportError): return import_from_string(path) ``` ########## airflow/providers/openlineage/utils/__init__.py: ########## @@ -0,0 +1,499 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +import importlib +import json +import logging +import os +import subprocess +from functools import wraps +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from attrs import asdict + +from airflow.models import DAG as AIRFLOW_DAG +from airflow.providers.openlineage.plugins.facets import ( + AirflowMappedTaskRunFacet, + AirflowRunArgsRunFacet, + AirflowRunFacet, + AirflowVersionRunFacet, +) + +# TODO: move this maybe to Airflow's logic? 
+from openlineage.client.utils import RedactMixin + +if TYPE_CHECKING: + from airflow.models import DAG, BaseOperator, Connection, DagRun, TaskInstance + + +log = logging.getLogger(__name__) +_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + +def openlineage_job_name(dag_id: str, task_id: str) -> str: + return f"{dag_id}.{task_id}" + + +def get_operator_class(task: BaseOperator) -> type: + if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"): + return task.operator_class + return task.__class__ + + +def to_json_encodable(task: BaseOperator) -> dict[str, object]: + def _task_encoder(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + elif isinstance(obj, AIRFLOW_DAG): + return { + "dag_id": obj.dag_id, + "tags": obj.tags, + "schedule_interval": obj.schedule_interval, + } + else: + return str(obj) + + return json.loads(json.dumps(task.__dict__, default=_task_encoder)) + + +def url_to_https(url) -> str | None: + # Ensure URL exists + if not url: + return None + + base_url = None + if url.startswith("git@"): + part = url.split("git@")[1:2] + if part: + base_url = f'https://{part[0].replace(":", "/", 1)}' + elif url.startswith("https://"): + base_url = url + + if not base_url: + raise ValueError(f"Unable to extract location from: {url}") + + if base_url.endswith(".git"): + base_url = base_url[:-4] + return base_url + + +def get_location(file_path) -> str | None: + # Ensure file path exists + if not file_path: + return None + + # move to the file directory + abs_path = os.path.abspath(file_path) + file_name = os.path.basename(file_path) + cwd = os.path.dirname(abs_path) + + # get the repo url + repo_url = execute_git(cwd, ["config", "--get", "remote.origin.url"]) + + # get the repo relative path + repo_relative_path = execute_git(cwd, ["rev-parse", "--show-prefix"]) + + # get the commitId for the particular file + commit_id = execute_git(cwd, ["rev-list", "HEAD", "-1", "--", file_name]) + + # build the URL + base_url = url_to_https(repo_url) + if not base_url: + return None + + return f"{base_url}/blob/{commit_id}/{repo_relative_path}{file_name}" + + +def get_task_location(task): + try: + if hasattr(task, "file_path") and task.file_path: + return get_location(task.file_path) + else: + return get_location(task.dag.fileloc) + except Exception: + return None + + +def execute_git(cwd, params): + p = subprocess.Popen(["git"] + params, cwd=cwd, stdout=subprocess.PIPE, stderr=None) + p.wait(timeout=0.5) + out, err = p.communicate() + return out.decode("utf8").strip() + + +def get_connection_uri(conn): + """ + Return the connection URI for the given ID. We first attempt to lookup + the connection URI via AIRFLOW_CONN_<conn_id>, else fallback on querying + the Airflow's connection table. 
+ """ + conn_uri = conn.get_uri() + parsed = urlparse(conn_uri) + + # Remove username and password + netloc = f"{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "") + parsed = parsed._replace(netloc=netloc) + if parsed.query: + query_dict = dict(parse_qsl(parsed.query)) + if conn.EXTRA_KEY in query_dict: + query_dict = json.loads(query_dict[conn.EXTRA_KEY]) + filtered_qs = {k: v for k, v in query_dict.items() if not _filtered_query_params(k)} + parsed = parsed._replace(query=urlencode(filtered_qs)) + return urlunparse(parsed) + + +def _filtered_query_params(k: str): + unfiltered_snowflake_keys = [ + "extra__snowflake__warehouse", + "extra__snowflake__account", + "extra__snowflake__database", + ] + filtered_key_substrings = [ + "aws_access_key_id", + "aws_secret_access_key", + "extra__snowflake__", + ] + return k not in unfiltered_snowflake_keys and any(substr in k for substr in filtered_key_substrings) + + +def get_normalized_postgres_connection_uri(conn): + """ + URIs starting with postgresql:// and postgres:// are both valid + PostgreSQL connection strings. This function normalizes it to + postgres:// as canonical name according to OpenLineage spec. + """ + uri = get_connection_uri(conn) + if uri.startswith("postgresql"): + uri = uri.replace("postgresql", "postgres", 1) + return uri + + +def get_connection(conn_id) -> Connection | None: + from airflow.hooks.base import BaseHook + + try: + return BaseHook.get_connection(conn_id=conn_id) + except Exception: + return None + + +def get_job_name(task): + return f"{task.dag_id}.{task.task_id}" + + +def get_custom_facets( + dagrun, task, is_external_trigger: bool, task_instance: TaskInstance | None = None +) -> dict[str, Any]: + custom_facets = { + "airflow_runArgs": AirflowRunArgsRunFacet(is_external_trigger), + "airflow_version": AirflowVersionRunFacet.from_dagrun_and_task(dagrun, task), + } + # check for -1 comes from SmartSensor compatibility with dynamic task mapping + # this comes from Airflow code + if hasattr(task_instance, "map_index") and getattr(task_instance, "map_index") != -1: + custom_facets["airflow_mappedTask"] = AirflowMappedTaskRunFacet.from_task_instance(task_instance) + return custom_facets + + +class InfoJsonEncodable(dict): + """ + Airflow objects might not be json-encodable overall. + + The class provides additional attributes to control + what and how is encoded: + * renames: a dictionary of attribute name changes + * casts: a dictionary consisting of attribute names + and corresponding methods that should change + object value + * includes: list of attributes to be included in encoding + * excludes: list of attributes to be excluded from encoding + + Don't use both includes and excludes. 
+ """ + + renames: dict[str, str] = {} + casts: dict[str, Any] = {} + includes: list[str] = [] + excludes: list[str] = [] + + def __init__(self, obj): + self.obj = obj + self._fields = [] + + self._cast_fields() + self._rename_fields() + self._include_fields() + dict.__init__( + self, + **{field: InfoJsonEncodable._cast_basic_types(getattr(self, field)) for field in self._fields}, + ) + + @staticmethod + def _cast_basic_types(value): + if isinstance(value, datetime.datetime): + return value.isoformat() + if isinstance(value, (set, list, tuple)): + return str(list(value)) + return value + + def _rename_fields(self): + for field, renamed in self.renames.items(): + if hasattr(self.obj, field): + setattr(self, renamed, getattr(self.obj, field)) + self._fields.append(renamed) + + def _cast_fields(self): + for field, func in self.casts.items(): + setattr(self, field, func(self.obj)) + self._fields.append(field) + + def _include_fields(self): + if self.includes and self.excludes: + raise Exception("Don't use both includes and excludes.") + if self.includes: + for field in self.includes: + if field in self._fields or not hasattr(self.obj, field): + continue + setattr(self, field, getattr(self.obj, field)) + self._fields.append(field) + else: + for field, val in self.obj.__dict__.items(): + if field in self._fields or field in self.excludes or field in self.renames: + continue + setattr(self, field, val) + self._fields.append(field) + + +class DagInfo(InfoJsonEncodable): + """Defines encoding DAG object to JSON.""" + + includes = ["dag_id", "schedule_interval", "tags", "start_date"] + casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None} + renames = {"_dag_id": "dag_id"} + + +class DagRunInfo(InfoJsonEncodable): + """Defines encoding DagRun object to JSON.""" + + includes = [ + "conf", + "dag_id", + "data_interval_start", + "data_interval_end", + "external_trigger", + "run_id", + "run_type", + "start_date", + ] + + +class TaskInstanceInfo(InfoJsonEncodable): + """Defines encoding TaskInstance object to JSON.""" + + includes = ["duration", "try_number", "pool"] + casts = { + "map_index": lambda ti: ti.map_index + if hasattr(ti, "map_index") and getattr(ti, "map_index") != -1 + else None + } + + +class TaskInfo(InfoJsonEncodable): + """Defines encoding BaseOperator/AbstractOperator object to JSON.""" + + renames = { + "_BaseOperator__init_kwargs": "args", + "_BaseOperator__from_mapped": "mapped", + "_downstream_task_ids": "downstream_task_ids", + "_upstream_task_ids": "upstream_task_ids", + } + excludes = [ + "_BaseOperator__instantiated", + "_dag", + "_hook", + "_log", + "_outlets", + "_inlets", + "_lock_for_execution", + "handler", + "params", + "python_callable", + "retry_delay", + ] + casts = { + "operator_class": lambda task: f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}", # noqa + "task_group": lambda task: TaskGroupInfo(task.task_group) + if hasattr(task, "task_group") and getattr(task.task_group, "_group_id", None) + else None, + } + + +class TaskGroupInfo(InfoJsonEncodable): + """Defines encoding TaskGroup object to JSON.""" + + renames = { + "_group_id": "group_id", + } + includes = [ + "downstream_group_ids", + "downstream_task_ids", + "prefix_group_id", + "tooltip", + "upstream_group_ids", + "upstream_task_ids", + ] + + +def get_airflow_run_facet( + dag_run: DagRun, + dag: DAG, + task_instance: TaskInstance, + task: BaseOperator, + task_uuid: str, +): + return { + "airflow": json.loads( + json.dumps( + 
asdict( + AirflowRunFacet( + dag=DagInfo(dag), + dagRun=DagRunInfo(dag_run), + taskInstance=TaskInstanceInfo(task_instance), + task=TaskInfo(task), + taskUuid=task_uuid, + ) + ), + default=str, + ) + ) + } + + +def get_dagrun_start_end(dagrun: DagRun, dag: DAG): + try: + return dagrun.data_interval_start, dagrun.data_interval_end + except AttributeError: + # Airflow < 2.2 before adding data interval + pass + start = dagrun.execution_date + end = dag.following_schedule(start) + return start, end or start + + +def import_from_string(path: str): + try: + module_path, target = path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, target) + except Exception as e: + log.warning(e) + raise ImportError(f"Failed to import {path}") from e + + +def try_import_from_string(path: str): + try: + return import_from_string(path) + except ImportError: + return None + + +def redact_with_exclusions(source: Any): + try: + from airflow.utils.log.secrets_masker import ( + _secrets_masker, + should_hide_value_for_key, + ) + except ImportError: + return source + import copy + + sm = copy.deepcopy(_secrets_masker()) + MAX_RECURSION_DEPTH = 20 + + def _redact(item, name: str | None, depth: int): + if depth > MAX_RECURSION_DEPTH: + return item + try: + if name and should_hide_value_for_key(name): + return sm._redact_all(item, depth) + if isinstance(item, dict): + return { + dict_key: _redact(subval, name=dict_key, depth=(depth + 1)) + for dict_key, subval in item.items() + } + elif is_dataclass(item) or (is_json_serializable(item) and hasattr(item, "__dict__")): + for dict_key, subval in item.__dict__.items(): + if _is_name_redactable(dict_key, item): + setattr( + item, + dict_key, + _redact(subval, name=dict_key, depth=(depth + 1)), + ) + return item + elif isinstance(item, str): + if sm.replacer: + return sm.replacer.sub("***", item) + return item + elif isinstance(item, (tuple, set)): + return tuple(_redact(subval, name=None, depth=(depth + 1)) for subval in item) + elif isinstance(item, list): + return [_redact(subval, name=None, depth=(depth + 1)) for subval in item] + else: + return item + except Exception as e: + log.warning( + "Unable to redact %s" "Error was: %s: %s", + repr(item), + type(e).__name__, + str(e), + ) + return item + + return _redact(source, name=None, depth=0) + + +def is_dataclass(item): + return getattr(item.__class__, "__attrs_attrs__", None) is not None Review Comment: a) this is not dataclass, but attrs, b) https://www.attrs.org/en/stable/api.html#attrs.has ```suggestion return attrs.has(item.__class__) ``` ########## airflow/providers/openlineage/utils/converters.py: ########## @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+from airflow.lineage.entities import Table
+from openlineage.client.run import Dataset
+
+
+def convert_to_dataset(obj):

Review Comment: Since Airflow has its own concept of Dataset, we should probably disambiguate them a bit.
```suggestion
def convert_to_ol_dataset(obj):
```

########## airflow/providers/openlineage/extractors/manager.py: ########## @@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.openlineage.extractors import BaseExtractor, Extractors, OperatorLineage
+from airflow.providers.openlineage.plugins.facets import (
+ UnknownOperatorAttributeRunFacet,
+ UnknownOperatorInstance,
+)
+from airflow.providers.openlineage.utils import get_operator_class
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class ExtractorManager(LoggingMixin):
+ """Class abstracting management of custom extractors."""
+
+ def __init__(self):
+ self.extractors = {}
+ self.task_to_extractor = Extractors()
+
+ def add_extractor(self, operator, extractor: type[BaseExtractor]):
+ self.task_to_extractor.add_extractor(operator, extractor)
+
+ def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=None) -> OperatorLineage:
+ extractor = self._get_extractor(task)
+ task_info = (
+ f"task_type={get_operator_class(task).__name__} "
+ f"airflow_dag_id={task.dag_id} "
+ f"task_id={task.task_id} "
+ f"airflow_run_id={dagrun.run_id} "
+ )
+
+ if extractor:
+ # Extracting advanced metadata is only possible when extractor for particular operator
+ # is defined. Without it, we can't extract any input or output data.
+ try: + self.log.debug("Using extractor %s %s", extractor.__class__.__name__, str(task_info)) + if complete: + task_metadata = extractor.extract_on_complete(task_instance) + else: + task_metadata = extractor.extract() + + self.log.debug("Found task metadata for operation %s: %s", task.task_id, str(task_metadata)) + if task_metadata: + if (not task_metadata.inputs) and (not task_metadata.outputs): + inlets = task.get_inlet_defs() + outlets = task.get_outlet_defs() + self.extract_inlets_and_outlets(task_metadata, inlets, outlets) + + return task_metadata + + except Exception as e: + self.log.exception("Failed to extract metadata %s %s", e, task_info) + else: + self.log.warning("Unable to find an extractor %s", task_info) + + # Only include the unkonwnSourceAttribute facet if there is no extractor + task_metadata = OperatorLineage( + run_facets={ + "unknownSourceAttribute": UnknownOperatorAttributeRunFacet( + unknownItems=[ + UnknownOperatorInstance( + name=get_operator_class(task).__name__, + properties={attr: value for attr, value in task.__dict__.items()}, + ) + ] + ) + }, + ) + inlets = task.get_inlet_defs() + outlets = task.get_outlet_defs() + self.extract_inlets_and_outlets(task_metadata, inlets, outlets) + return task_metadata + + return OperatorLineage() + + def _get_extractor(self, task) -> BaseExtractor | None: + # TODO: Re-enable in Extractor PR + # self.task_to_extractor.instantiate_abstract_extractors(task) + if task.task_id in self.extractors: + return self.extractors[task.task_id] Review Comment: Why do we need to cache by task_id? What process is this called in that it makes sense to cache this at all? ########## airflow/providers/openlineage/extractors/manager.py: ########## @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.openlineage.extractors import BaseExtractor, Extractors, OperatorLineage +from airflow.providers.openlineage.plugins.facets import ( + UnknownOperatorAttributeRunFacet, + UnknownOperatorInstance, +) +from airflow.providers.openlineage.utils import get_operator_class +from airflow.utils.log.logging_mixin import LoggingMixin + + +class ExtractorManager(LoggingMixin): + """Class abstracting management of custom extractors.""" + + def __init__(self): + self.extractors = {} + self.task_to_extractor = Extractors() + + def add_extractor(self, operator, extractor: type[BaseExtractor]): + self.task_to_extractor.add_extractor(operator, extractor) + + def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=None) -> OperatorLineage: + extractor = self._get_extractor(task) + task_info = ( + f"task_type={get_operator_class(task).__name__} " Review Comment: This should use `task.task_type` etc, not a custom function. 
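For example, something like this (a sketch of that change only; `task_type` is the property that `BaseOperator` and `MappedOperator` already expose with the operator's class name):
```suggestion
        task_info = (
            f"task_type={task.task_type} "
            f"airflow_dag_id={task.dag_id} "
            f"task_id={task.task_id} "
            f"airflow_run_id={dagrun.run_id} "
        )
```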
########## airflow/providers/openlineage/plugins/adapter.py: ########## @@ -0,0 +1,302 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import uuid +from typing import TYPE_CHECKING + +import requests.exceptions + +from airflow.providers.openlineage import version as OPENLINEAGE_PROVIDER_VERSION +from airflow.providers.openlineage.extractors import OperatorLineage +from airflow.providers.openlineage.utils import redact_with_exclusions +from airflow.utils.log.logging_mixin import LoggingMixin +from openlineage.client import OpenLineageClient, set_producer +from openlineage.client.facet import ( + BaseFacet, + DocumentationJobFacet, + ErrorMessageRunFacet, + NominalTimeRunFacet, + OwnershipJobFacet, + OwnershipJobFacetOwners, + ParentRunFacet, + ProcessingEngineRunFacet, + SourceCodeLocationJobFacet, +) +from openlineage.client.run import Job, Run, RunEvent, RunState + +if TYPE_CHECKING: + from airflow.models.dagrun import DagRun + + +_DAG_DEFAULT_NAMESPACE = "default" + +_DAG_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", _DAG_DEFAULT_NAMESPACE) Review Comment: Use Airflow config please. ########## airflow/providers/openlineage/extractors/manager.py: ########## @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from airflow.providers.openlineage.extractors import BaseExtractor, Extractors, OperatorLineage +from airflow.providers.openlineage.plugins.facets import ( + UnknownOperatorAttributeRunFacet, + UnknownOperatorInstance, +) +from airflow.providers.openlineage.utils import get_operator_class +from airflow.utils.log.logging_mixin import LoggingMixin + + +class ExtractorManager(LoggingMixin): + """Class abstracting management of custom extractors.""" + + def __init__(self): + self.extractors = {} + self.task_to_extractor = Extractors() + + def add_extractor(self, operator, extractor: type[BaseExtractor]): + self.task_to_extractor.add_extractor(operator, extractor) + + def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=None) -> OperatorLineage: + extractor = self._get_extractor(task) + task_info = ( + f"task_type={get_operator_class(task).__name__} " + f"airflow_dag_id={task.dag_id} " + f"task_id={task.task_id} " + f"airflow_run_id={dagrun.run_id} " + ) + + if extractor: + # Extracting advanced metadata is only possible when extractor for particular operator + # is defined. Without it, we can't extract any input or output data. + try: + self.log.debug("Using extractor %s %s", extractor.__class__.__name__, str(task_info)) + if complete: + task_metadata = extractor.extract_on_complete(task_instance) + else: + task_metadata = extractor.extract() + + self.log.debug("Found task metadata for operation %s: %s", task.task_id, str(task_metadata)) + if task_metadata: + if (not task_metadata.inputs) and (not task_metadata.outputs): + inlets = task.get_inlet_defs() + outlets = task.get_outlet_defs() + self.extract_inlets_and_outlets(task_metadata, inlets, outlets) + + return task_metadata + + except Exception as e: + self.log.exception("Failed to extract metadata %s %s", e, task_info) + else: + self.log.warning("Unable to find an extractor %s", task_info) Review Comment: Not sure (yet as I'm reading this PR) where this code is called, but would this end up warning just because that task has no registered extractor? If so this should be at debug level, not warning ########## airflow/providers/openlineage/plugins/openlineage.py: ########## @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
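If so, the one-line change would be (assuming the message itself stays the same):
```suggestion
            self.log.debug("Unable to find an extractor %s", task_info)
```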
+from __future__ import annotations + +import os + +from airflow.plugins_manager import AirflowPlugin +from airflow.providers.openlineage.plugins.macros import lineage_parent_id, lineage_run_id + + +def _is_disabled(): + return os.getenv("OPENLINEAGE_DISABLED", None) in [True, "true", "True"] Review Comment: Airflow config here too (again, can fallback to OL env var) -- using `conf.getboolean` ########## airflow/providers/openlineage/plugins/listener.py: ########## @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from concurrent.futures import Executor, ThreadPoolExecutor +from typing import TYPE_CHECKING + +from airflow.listeners import hookimpl +from airflow.providers.openlineage.extractors import ExtractorManager +from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter +from airflow.providers.openlineage.utils import ( + get_airflow_run_facet, + get_custom_facets, + get_job_name, + get_task_location, + print_exception, +) + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.models import DagRun, TaskInstance + + +class OpenLineageListener: + """ + OpenLineage listener + Sends events on task instance and dag run starts, completes and failures. 
+ """ + + def __init__(self): + self.log = logging.getLogger(__name__) + self.executor: Executor = None # type: ignore + self.extractor_manager = ExtractorManager() + self.adapter = OpenLineageAdapter() + + @hookimpl + def on_task_instance_running( + self, previous_state, task_instance: TaskInstance, session: Session # This will always be QUEUED + ): + if not hasattr(task_instance, "task"): + self.log.warning( + f"No task set for TI object task_id: {task_instance.task_id} - " + f"dag_id: {task_instance.dag_id} - run_id {task_instance.run_id}" + ) + return + + self.log.debug("OpenLineage listener got notification about task instance start") + dagrun = task_instance.dag_run + task = task_instance.task + dag = task.dag + + @print_exception + def on_running(): + # that's a workaround to detect task running from deferred state + # we return here because Airflow 2.3 needs task from deferred state + if task_instance.next_method is not None: + return + parent_run_id = self.adapter.build_dag_run_id(dag.dag_id, dagrun.run_id) + + task_uuid = self.adapter.build_task_instance_run_id( + task.task_id, task_instance.execution_date, task_instance.try_number + ) + + task_metadata = self.extractor_manager.extract_metadata(dagrun, task) + + self.adapter.start_task( + run_id=task_uuid, + job_name=get_job_name(task), + job_description=dag.description, + event_time=task_instance.start_date.isoformat(), + parent_job_name=dag.dag_id, + parent_run_id=parent_run_id, + code_location=get_task_location(task), + nominal_start_time=dagrun.data_interval_start.isoformat(), + nominal_end_time=dagrun.data_interval_end.isoformat(), + owners=dag.owner.split(", "), + task=task_metadata, + run_facets={ + **task_metadata.run_facets, + **get_custom_facets(dagrun, task, dagrun.external_trigger, task_instance), + **get_airflow_run_facet(dagrun, dag, task_instance, task, task_uuid), + }, + ) + + self.executor.submit(on_running) + + @hookimpl + def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session): + self.log.debug("OpenLineage listener got notification about task instance success") + + dagrun = task_instance.dag_run + task = task_instance.task + + task_uuid = OpenLineageAdapter.build_task_instance_run_id( + task.task_id, task_instance.execution_date, task_instance.try_number - 1 + ) + + @print_exception + def on_success(): + task_metadata = self.extractor_manager.extract_metadata( + dagrun, task, complete=True, task_instance=task_instance + ) + self.adapter.complete_task( + run_id=task_uuid, + job_name=get_job_name(task), + end_time=task_instance.end_date.isoformat(), + task=task_metadata, + ) + + self.executor.submit(on_success) + + @hookimpl + def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session): + self.log.debug("OpenLineage listener got notification about task instance failure") + + dagrun = task_instance.dag_run + task = task_instance.task + + task_uuid = OpenLineageAdapter.build_task_instance_run_id( + task.task_id, task_instance.execution_date, task_instance.try_number - 1 + ) + + @print_exception + def on_failure(): + task_metadata = self.extractor_manager.extract_metadata( + dagrun, task, complete=True, task_instance=task_instance + ) + + self.adapter.fail_task( + run_id=task_uuid, + job_name=get_job_name(task), + end_time=task_instance.end_date.isoformat(), + task=task_metadata, + ) + + self.executor.submit(on_failure) + + @hookimpl + def on_starting(self, component): + self.log.debug("on_starting: %s", component.__class__.__name__) + self.executor = 
ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
+
+ @hookimpl
+ def before_stopping(self, component):
+ self.log.debug("before_stopping: %s", component.__class__.__name__)
+ self.executor.shutdown(wait=True)

Review Comment: Can we place a timeout on this wait so it doesn't hang, blocking the task forever?
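`Executor.shutdown()` itself takes no timeout parameter, so one option is to wait on the outstanding futures with a bound first. A sketch only: it assumes a hypothetical `self._futures` list that each `submit()` call would have to append to, which the listener doesn't keep today:
```python
from concurrent.futures import ALL_COMPLETED, wait

    @hookimpl
    def before_stopping(self, component):
        self.log.debug("before_stopping: %s", component.__class__.__name__)
        # Bound the wait: give in-flight OpenLineage events up to 30s to
        # flush, then shut the pool down without blocking any further.
        wait(self._futures, timeout=30, return_when=ALL_COMPLETED)
        self.executor.shutdown(wait=False)
```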
+def get_connection_uri(conn):
+    """
+    Return a sanitized URI for the given connection.
+
+    Unlike plain ``conn.get_uri()``, the returned URI has the username and
+    password stripped from the netloc, and sensitive query parameters
+    (AWS keys, Snowflake extras) filtered out of the query string.
+    """
+    conn_uri = conn.get_uri()
+    parsed = urlparse(conn_uri)
+
+    # Remove username and password
+    netloc = f"{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "")
+    parsed = parsed._replace(netloc=netloc)
+    if parsed.query:
+        query_dict = dict(parse_qsl(parsed.query))
+        if conn.EXTRA_KEY in query_dict:
+            query_dict = json.loads(query_dict[conn.EXTRA_KEY])
+        filtered_qs = {k: v for k, v in query_dict.items() if not _filtered_query_params(k)}
+        parsed = parsed._replace(query=urlencode(filtered_qs))
+    return urlunparse(parsed)
+
+
+def _filtered_query_params(k: str):
+    # True when the parameter should be dropped from the sanitized URI.
+    unfiltered_snowflake_keys = [
+        "extra__snowflake__warehouse",
+        "extra__snowflake__account",
+        "extra__snowflake__database",
+    ]
+    filtered_key_substrings = [
+        "aws_access_key_id",
+        "aws_secret_access_key",
+        "extra__snowflake__",
+    ]
+    return k not in unfiltered_snowflake_keys and any(substr in k for substr in filtered_key_substrings)
+
+
+def get_normalized_postgres_connection_uri(conn):
+    """
+    URIs starting with postgresql:// and postgres:// are both valid
+    PostgreSQL connection strings. This function normalizes the URI to
+    postgres://, the canonical name according to the OpenLineage spec.
+    """
+    uri = get_connection_uri(conn)
+    if uri.startswith("postgresql"):
+        uri = uri.replace("postgresql", "postgres", 1)
+    return uri
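For illustration, here is roughly what the sanitizer yields for a hypothetical connection (all values below are made up, the exact query-string encoding of `get_uri()` can vary by Airflow version, and `get_connection_uri` is assumed to be in scope):

```python
from airflow.models import Connection

conn = Connection(
    conn_id="my_db",  # hypothetical connection, invented for this example
    conn_type="postgres",
    host="db.example.com",
    port=5432,
    login="user",
    password="s3cret",
    extra='{"aws_secret_access_key": "abc123", "sslmode": "require"}',
)

# conn.get_uri() keeps credentials and every extra, roughly:
#   postgres://user:s3cret@db.example.com:5432?aws_secret_access_key=abc123&sslmode=require
# get_connection_uri() drops the userinfo and the sensitive key:
#   postgres://db.example.com:5432?sslmode=require
print(get_connection_uri(conn))
```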
+def get_connection(conn_id) -> Connection | None:
+    from airflow.hooks.base import BaseHook
+
+    try:
+        return BaseHook.get_connection(conn_id=conn_id)
+    except Exception:
+        return None
+
+
+def get_job_name(task):
+    return f"{task.dag_id}.{task.task_id}"
+
+
+def get_custom_facets(
+    dagrun, task, is_external_trigger: bool, task_instance: TaskInstance | None = None
+) -> dict[str, Any]:
+    custom_facets = {
+        "airflow_runArgs": AirflowRunArgsRunFacet(is_external_trigger),
+        "airflow_version": AirflowVersionRunFacet.from_dagrun_and_task(dagrun, task),
+    }
+    # The check for -1 comes from SmartSensor compatibility with dynamic
+    # task mapping; it mirrors Airflow's own code.
+    if hasattr(task_instance, "map_index") and getattr(task_instance, "map_index") != -1:
+        custom_facets["airflow_mappedTask"] = AirflowMappedTaskRunFacet.from_task_instance(task_instance)
+    return custom_facets
+
+
+class InfoJsonEncodable(dict):
+    """
+    Airflow objects might not be JSON-encodable overall.
+
+    This class provides additional class-level attributes that control
+    what gets encoded and how:
+    * renames: a dictionary of attribute name changes
+    * casts: a dictionary mapping attribute names to callables that
+      derive the encoded value from the object
+    * includes: a list of attributes to include in the encoding
+    * excludes: a list of attributes to exclude from the encoding
+
+    Don't use both includes and excludes.
+    """
+
+    renames: dict[str, str] = {}
+    casts: dict[str, Any] = {}
+    includes: list[str] = []
+    excludes: list[str] = []
+
+    def __init__(self, obj):
+        self.obj = obj
+        self._fields = []
+
+        self._cast_fields()
+        self._rename_fields()
+        self._include_fields()
+        dict.__init__(
+            self,
+            **{field: InfoJsonEncodable._cast_basic_types(getattr(self, field)) for field in self._fields},
+        )
+
+    @staticmethod
+    def _cast_basic_types(value):
+        if isinstance(value, datetime.datetime):
+            return value.isoformat()
+        if isinstance(value, (set, list, tuple)):
+            return str(list(value))
+        return value
+
+    def _rename_fields(self):
+        for field, renamed in self.renames.items():
+            if hasattr(self.obj, field):
+                setattr(self, renamed, getattr(self.obj, field))
+                self._fields.append(renamed)
+
+    def _cast_fields(self):
+        for field, func in self.casts.items():
+            setattr(self, field, func(self.obj))
+            self._fields.append(field)
+
+    def _include_fields(self):
+        if self.includes and self.excludes:
+            raise Exception("Don't use both includes and excludes.")
+        if self.includes:
+            for field in self.includes:
+                if field in self._fields or not hasattr(self.obj, field):
+                    continue
+                setattr(self, field, getattr(self.obj, field))
+                self._fields.append(field)
+        else:
+            for field, val in self.obj.__dict__.items():
+                if field in self._fields or field in self.excludes or field in self.renames:
+                    continue
+                setattr(self, field, val)
+                self._fields.append(field)
+
+
+class DagInfo(InfoJsonEncodable):
+    """Defines encoding DAG object to JSON."""
+
+    includes = ["dag_id", "schedule_interval", "tags", "start_date"]
+    casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None}
+    renames = {"_dag_id": "dag_id"}
+
+
+class DagRunInfo(InfoJsonEncodable):
+    """Defines encoding DagRun object to JSON."""
+
+    includes = [
+        "conf",
+        "dag_id",
+        "data_interval_start",
+        "data_interval_end",
+        "external_trigger",
+        "run_id",
+        "run_type",
+        "start_date",
+    ]
+
+
+class TaskInstanceInfo(InfoJsonEncodable):
+    """Defines encoding TaskInstance object to JSON."""
+
+    includes = ["duration", "try_number", "pool"]
+    casts = {
+        "map_index": lambda ti: ti.map_index
+        if hasattr(ti, "map_index") and getattr(ti, "map_index") != -1
+        else None
+    }
+
+
+class TaskInfo(InfoJsonEncodable):
+    """Defines encoding BaseOperator/AbstractOperator object to JSON."""
+
+    renames = {
+        "_BaseOperator__init_kwargs": "args",
+        "_BaseOperator__from_mapped": "mapped",
+        "_downstream_task_ids": "downstream_task_ids",
+        "_upstream_task_ids": "upstream_task_ids",
+    }
+    excludes = [
+        "_BaseOperator__instantiated",
+        "_dag",
+        "_hook",
+        "_log",
+        "_outlets",
+        "_inlets",
+        "_lock_for_execution",
+        "handler",
+        "params",
+        "python_callable",
+        "retry_delay",
+    ]
+    casts = {
+        "operator_class": lambda task: f"{get_operator_class(task).__module__}.{get_operator_class(task).__name__}",  # noqa
+        "task_group": lambda task: TaskGroupInfo(task.task_group)
+        if hasattr(task, "task_group") and getattr(task.task_group, "_group_id", None)
+        else None,
+    }
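To see that contract in action, here is a toy, self-contained subclass; `ConnectionInfo` and `FakeConn` are invented for illustration, and `InfoJsonEncodable` as defined above is assumed to be importable:

```python
import json


class ConnectionInfo(InfoJsonEncodable):
    """Toy encoder: rename conn_id, include host/port, derive a uri field."""

    renames = {"conn_id": "id"}
    includes = ["host", "port"]
    casts = {"uri": lambda c: f"{c.host}:{c.port}"}


class FakeConn:
    def __init__(self):
        self.conn_id = "my_db"
        self.host = "db.example.com"
        self.port = 5432


print(json.dumps(ConnectionInfo(FakeConn())))
# -> {"uri": "db.example.com:5432", "id": "my_db", "host": "db.example.com", "port": 5432}
```

Because the result *is* a dict (casts applied first, then renames, then includes), it drops straight into `json.dumps` with no custom encoder.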
+class TaskGroupInfo(InfoJsonEncodable):
+    """Defines encoding TaskGroup object to JSON."""
+
+    renames = {
+        "_group_id": "group_id",
+    }
+    includes = [
+        "downstream_group_ids",
+        "downstream_task_ids",
+        "prefix_group_id",
+        "tooltip",
+        "upstream_group_ids",
+        "upstream_task_ids",
+    ]
+
+
+def get_airflow_run_facet(
+    dag_run: DagRun,
+    dag: DAG,
+    task_instance: TaskInstance,
+    task: BaseOperator,
+    task_uuid: str,
+):
+    return {
+        "airflow": json.loads(
+            json.dumps(
+                asdict(
+                    AirflowRunFacet(
+                        dag=DagInfo(dag),
+                        dagRun=DagRunInfo(dag_run),
+                        taskInstance=TaskInstanceInfo(task_instance),
+                        task=TaskInfo(task),
+                        taskUuid=task_uuid,
+                    )
+                ),
+                default=str,
+            )
+        )
+    }
+
+
+def get_dagrun_start_end(dagrun: DagRun, dag: DAG):
+    try:
+        return dagrun.data_interval_start, dagrun.data_interval_end
+    except AttributeError:
+        # Airflow < 2.2: DagRun does not have data interval fields yet
+        pass
+    start = dagrun.execution_date
+    end = dag.following_schedule(start)
+    return start, end or start
+
+
+def import_from_string(path: str):
+    try:
+        module_path, target = path.rsplit(".", 1)
+        module = importlib.import_module(module_path)
+        return getattr(module, target)
+    except Exception as e:
+        log.warning(e)
+        raise ImportError(f"Failed to import {path}") from e
+
+
+def try_import_from_string(path: str):
+    try:
+        return import_from_string(path)
+    except ImportError:
+        return None
+
+
+def redact_with_exclusions(source: Any):
+    try:
+        from airflow.utils.log.secrets_masker import (
+            _secrets_masker,
+            should_hide_value_for_key,
+        )
+    except ImportError:
+        return source
+    import copy
+
+    sm = copy.deepcopy(_secrets_masker())

Review Comment:
   Accessing `_secrets_masker` directly is a little bit naughty. Do you need to?

########## airflow/providers/openlineage/plugins/macros.py: ##########
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+import typing
+
+from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
+
+if typing.TYPE_CHECKING:
+    from airflow.models import BaseOperator, TaskInstance
+
+_JOB_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", "default")
+
+
+def lineage_run_id(task: BaseOperator, task_instance: TaskInstance):
+    """
+    Macro function that returns the generated run id for a given task. This
+    can be used to forward the run id from a task to a child run so the job
+    hierarchy is preserved. Invoke as a jinja template, e.g.
+
+    PythonOperator(
+        task_id='render_template',
+        python_callable=my_task_function,
+        op_args=['{{ lineage_run_id(task, task_instance) }}'],  # lineage_run_id macro invoked

Review Comment:
   Oh, all you seem to get out of the task is the task_id, which ti has anyway.
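A possible follow-up to that comment — a sketch only, not necessarily how the PR answered it, and assuming `build_task_instance_run_id` keeps the signature the listener above already uses — would be to drop the `task` parameter entirely:

```python
from airflow.models import TaskInstance
from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter


def lineage_run_id(task_instance: TaskInstance) -> str:
    # task_id, execution_date and try_number are all available on the
    # TaskInstance, so a separate `task` argument adds no information.
    return OpenLineageAdapter.build_task_instance_run_id(
        task_instance.task_id,
        task_instance.execution_date,
        task_instance.try_number,
    )
```

The template invocation would then shrink to `{{ lineage_run_id(task_instance) }}`.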
--
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
