RNHTTR commented on code in PR #66612: URL: https://github.com/apache/airflow/pull/66612#discussion_r3311667444
########## providers/informatica/src/airflow/providers/informatica/lineage/resolver.py: ########## @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Any + +from airflow.providers.informatica.lineage.sql_parser import TableRef, parse_sql_tables + +log = logging.getLogger(__name__) + +try: + from airflow.providers.common.sql.operators.sql import BaseSQLOperator as _BaseSQLOperator + + _HAS_BASE_SQL_OPERATOR = True +except ImportError: + _BaseSQLOperator = None # type: ignore[assignment, misc] + _HAS_BASE_SQL_OPERATOR = False + +# Operator attribute names scanned in order to locate a connection ID. +# conn_id_field (BaseSQLOperator) is tried first; this list is the fallback. +_CONN_ID_ATTRS: tuple[str, ...] = ( + "conn_id", + "source_conn_id", + "mysql_conn_id", + "postgres_conn_id", + "mssql_conn_id", + "oracle_conn_id", + "sqlite_conn_id", + "snowflake_conn_id", + "databricks_conn_id", + "exasol_conn_id", + "hiveserver2_conn_id", +) + +# Keyword fragments found in a conn_id string mapped to sqlglot dialect names. 
+_CONN_TYPE_TO_DIALECT: dict[str, str] = { + "postgres": "postgres", + "redshift": "redshift", + "mysql": "mysql", + "mssql": "tsql", + "snowflake": "snowflake", + "bigquery": "bigquery", + "databricks": "databricks", + "sqlite": "sqlite", + "oracle": "oracle", + "trino": "trino", + "presto": "presto", + "hive": "hive", + "spark": "spark", +} + +# Operator attribute names checked as explicit write-target table when SQL +# parsing yields no targets (e.g. GenericTransfer, HiveToMySqlOperator). +_TARGET_TABLE_ATTRS: tuple[str, ...] = ( + "destination_table", + "mysql_table", + "hive_table", + "target_table", +) + + +class BaseLineageResolver(ABC): + """Base class for operator lineage resolvers.""" + + @abstractmethod + def resolve(self, task: Any) -> tuple[list[TableRef], list[TableRef]] | None: + """Return ``(source_refs, target_refs)`` or ``None`` if the resolver does not apply.""" + + +class SQLLineageResolver(BaseLineageResolver): + """ + Resolves lineage for any operator that exposes a ``sql`` attribute. + + Detection is tiered: + + - Tier 1: operators inheriting from ``BaseSQLOperator`` — ``conn_id_field`` + points to the right connection attribute. + - Tier 2: operators with a ``sql`` attribute but no ``BaseSQLOperator`` + base (e.g. ``GenericTransfer``, ``BaseSQLToGCSOperator``) — dialect is + inferred from the first recognizable connection ID string found. + + Returns ``None`` when there is no SQL, when Jinja templates are detected, + or when parsing produces no table references. 
+ """ + + def resolve(self, task: Any) -> tuple[list[TableRef], list[TableRef]] | None: + sql = getattr(task, "sql", None) + if not sql: + return None + dialect = _infer_dialect(task) + default_database: str | None = getattr(task, "database", None) + sources, targets = parse_sql_tables(sql, dialect=dialect) + if not targets: Review Comment: Candidly, I don't have the EDC knowledge to be able to understand this, but my AI review mentioned the following: --- The explicit target fallback for operators like `GenericTransfer` looks like it mishandles schema-qualified table names. For example, a task may define the source query in `sql` and the target table separately: GenericTransfer( sql="SELECT * FROM orders", destination_table="public.customer_segment_snapshot", ) When SQL parsing finds sources but no write target, the fallback currently builds: TableRef(table="public.customer_segment_snapshot") Then the EDC lookup receives an empty schema/catalog and searches for a table literally named `public.customer_segment_snapshot`. In most catalogs, that object would be represented as table `customer_segment_snapshot` in schema `public`, so this can fail to resolve or resolve incorrectly. I think the fallback should parse qualified table names before creating the `TableRef`, so this becomes: TableRef(table="customer_segment_snapshot", schema="public") A focused test with a schema-qualified `destination_table` would probably catch this. ########## providers/informatica/src/airflow/providers/informatica/plugins/listener.py: ########## @@ -30,107 +33,214 @@ _informatica_listener: InformaticaListener | None = None +class InformaticaLineageResolutionError(RuntimeError): + """Raised when an EDC object cannot be resolved for a lineage URI.""" + + +def _resolve_uri_to_object_id(hook: InformaticaLineageExtractor, uri: str) -> str: + """ + Resolve an EDC lineage URI to an Informatica catalog object ID. + + Manual lineage entries are treated as concrete object identifiers/URIs. 
+ They are validated directly via ``get_object`` instead of being reparsed + and looked up again with ``find_object_id``. + """ + log = logging.getLogger(__name__) + try: + obj = hook.get_object(uri) + except InformaticaEDCError as exc: + raise InformaticaLineageResolutionError( + f"Failed to resolve EDC object for URI {uri!r}: {exc}" + ) from exc + + object_id = obj.get("id") if isinstance(obj, dict) else None + if not object_id: + raise InformaticaLineageResolutionError( + f"Could not resolve EDC object for URI {uri!r}. Ensure the object exists in the Informatica catalog." + ) + log.debug("Resolved URI %r to EDC object_id=%s", uri, object_id) + return object_id + + class InformaticaListener: """Informatica listener sends events on task instance state changes to Informatica EDC for lineage tracking.""" def __init__(self): - self._executor = None self.log = logging.getLogger(__name__) self.hook = InformaticaLineageExtractor(edc_hook=InformaticaEDCHook()) - # self.extractor_manager = ExtractorManager() + # Cache: _cache_key(ti) -> (valid_inlets, valid_outlets) + # Populated by on_task_instance_running (pre-validation), consumed by + # on_task_instance_success and cleared by on_task_instance_failed. 
+ self._resolved_cache: dict[tuple, tuple[list[tuple[str, str]], list[tuple[str, str]]]] = {} + + @staticmethod + def _cache_key(task_instance: TaskInstance) -> tuple: + dag_id = getattr(task_instance, "dag_id", None) + if dag_id is None: + task = getattr(task_instance, "task", None) + dag_id = getattr(task, "dag_id", None) + return ( + dag_id, + getattr(task_instance, "run_id", None), + task_instance.task_id, + getattr(task_instance, "map_index", -1), + getattr(task_instance, "try_number", None), + ) @hookimpl def on_task_instance_success( self, previous_state: TaskInstanceState, task_instance: TaskInstance, *args, **kwargs ): - self._handle_lineage(task_instance, state="success") + key = self._cache_key(task_instance) + cached = self._resolved_cache.pop(key, None) + if cached is None: + # Running hook was skipped (e.g. operator disabled) - nothing to do. + return + valid_inlets, valid_outlets = cached + self._create_lineage_links(valid_inlets, valid_outlets, task_instance.task_id) @hookimpl def on_task_instance_failed( self, previous_state: TaskInstanceState, task_instance: TaskInstance, *args, **kwargs ): - self._handle_lineage(task_instance, state="failed") + # Clean up cache entry so stale entries do not accumulate. + self._resolved_cache.pop(self._cache_key(task_instance), None) @hookimpl def on_task_instance_running( self, previous_state: TaskInstanceState, task_instance: TaskInstance, *args, **kwargs ): - self._handle_lineage(task_instance, state="running") - - def _handle_lineage(self, task_instance: TaskInstance, state: str): """ - Handle lineage resolution for inlets and outlets. + Validate and pre-resolve all inlet/outlet URIs before the task executes. - For each inlet and outlet, resolve Informatica EDC object IDs using getObject. - If valid, collect and create lineage links between all valid inlets and outlets. + Raises :class:`InformaticaLineageResolutionError` if any URI or table cannot + be resolved in the Informatica catalog. 
This causes Airflow to fail the task + immediately - before the operator ``execute()`` is called. Review Comment: The listener doesn't actually have the ability to fail the task prior to execution. Airflow merely logs the exception raised by the listener; it does not fail the task. You can add a `pre_execute` method on the Operator instead, and have a shared function that can be used by the Listener and the Operator. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
