mobuchowski commented on code in PR #47736:
URL: https://github.com/apache/airflow/pull/47736#discussion_r1999188457
##########
providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py:
##########
@@ -78,3 +94,262 @@ def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
hostname = fix_account_name(hostname)
# else - its new hostname, just return it
return urlunparse((parts.scheme, hostname, parts.path, parts.params,
parts.query, parts.fragment))
+
+
def _get_parent_run_facet(task_instance):
    """
    Build the OpenLineage ParentRunFacet for a given Airflow task instance.

    The facet ties OpenLineage events emitted by child jobs - such as queries
    executed inside external systems (e.g. Snowflake) on behalf of the Airflow
    task - back to the originating Airflow task run, enabling better lineage
    tracking and observability.

    The run_id here must be computed exactly the way the OpenLineage listener
    computes it; only then do the generated ids align and the child events get
    properly connected to the Airflow task's own events.
    """
    from openlineage.client.facet_v2 import parent_run

    from airflow.providers.openlineage.conf import namespace
    from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter

    def _resolve_logical_date():
        # todo: remove when min airflow version >= 3.0
        if AIRFLOW_V_3_0_PLUS:
            dag_run = task_instance.get_template_context()["dag_run"]
            return dag_run.logical_date or dag_run.run_after

        # Older Airflow versions expose execution_date instead of logical_date;
        # fall back only when the attribute is genuinely absent.
        missing = object()
        value = getattr(task_instance, "logical_date", missing)
        return task_instance.execution_date if value is missing else value

    def _resolve_try_number():
        # todo: remove when min airflow version >= 2.10.0
        return task_instance.try_number if AIRFLOW_V_2_10_PLUS else task_instance.try_number - 1

    # Recreate the exact OpenLineage run id the listener generated for this task instance.
    parent_run_id = OpenLineageAdapter.build_task_instance_run_id(
        dag_id=task_instance.dag_id,
        task_id=task_instance.task_id,
        logical_date=_resolve_logical_date(),
        try_number=_resolve_try_number(),
        map_index=task_instance.map_index,
    )

    parent_job = parent_run.Job(
        namespace=namespace(),
        name=f"{task_instance.dag_id}.{task_instance.task_id}",
    )
    return parent_run.ParentRunFacet(run=parent_run.Run(runId=parent_run_id), job=parent_job)
+
+
+def _run_single_query_with_hook(hook: SnowflakeHook, sql: str) -> list[dict]:
+ """Execute a query against Snowflake without adding extra logging or
instrumentation."""
+ with closing(hook.get_conn()) as conn:
+ hook.set_autocommit(conn, False)
+ with hook._get_cursor(conn, return_dictionaries=True) as cur:
+ cur.execute(sql)
+ result = cur.fetchall()
+ conn.commit()
+ return result
+
+
def _get_queries_details_from_snowflake(
    hook: SnowflakeHook, query_ids: list[str]
) -> dict[str, dict[str, str]]:
    """
    Retrieve execution details for specific queries from Snowflake's query history.

    Returns a mapping of QUERY_ID to its history row (status, timings, text,
    error info); empty dict when *query_ids* is empty or nothing matched.
    """
    if not query_ids:
        return {}

    # NOTE(review): ids are interpolated directly into the SQL text; this
    # assumes they are trusted, Snowflake-generated query ids - confirm callers
    # never pass user-controlled values here.
    if len(query_ids) == 1:
        query_condition = f"= '{query_ids[0]}'"
    else:
        query_condition = f"IN {tuple(query_ids)}"

    query = (
        "SELECT "
        "QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE "
        "FROM "
        "table(information_schema.query_history()) "
        f"WHERE "
        f"QUERY_ID {query_condition}"
        f";"
    )

    rows = _run_single_query_with_hook(hook=hook, sql=query)
    if not rows:
        return {}
    return {row["QUERY_ID"]: row for row in rows}
+
+
def _create_snowflake_event_pair(
    job_namespace: str,
    job_name: str,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
    is_successful: bool,
    run_facets: dict | None = None,
    job_facets: dict | None = None,
) -> tuple[RunEvent, RunEvent]:
    """
    Create a pair of OpenLineage RunEvents for a Snowflake job execution.

    Both events share the same freshly generated run id and job identity, so
    they describe one run: a START event at *start_time* and a terminal event
    at *end_time* (COMPLETE when *is_successful*, FAIL otherwise).
    """
    from openlineage.client.event_v2 import Job, Run, RunEvent, RunState
    from openlineage.client.uuid import generate_new_uuid

    shared_run = Run(runId=str(generate_new_uuid()), facets=run_facets or {})
    shared_job = Job(namespace=job_namespace, name=job_name, facets=job_facets or {})

    def _event(state, moment):
        # Both events reuse the same run/job objects so they are linked.
        return RunEvent(eventType=state, eventTime=moment.isoformat(), run=shared_run, job=shared_job)

    terminal_state = RunState.COMPLETE if is_successful else RunState.FAIL
    return _event(RunState.START, start_time), _event(terminal_state, end_time)
+
+
def _check_openlineage_modules_are_importable():
    """
    Ensure that required OpenLineage modules are importable.

    This function checks whether the necessary OpenLineage modules are available for use.
    Since the Snowflake provider does not directly require OpenLineage, users must ensure
    that the relevant dependencies are installed to enable some OpenLineage features.

    Raises:
        AirflowOptionalProviderFeatureException: when either the OpenLineage
            provider package or the OpenLineage python client is missing,
            with a message naming the package version to install.
    """
    # airflow.exceptions is part of core Airflow and always importable;
    # import it once instead of repeating the import in each except block.
    from airflow.exceptions import AirflowOptionalProviderFeatureException

    try:
        from airflow.providers.openlineage.conf import namespace  # noqa: F401
    except ModuleNotFoundError as e:
        msg = "Please install `apache-airflow-providers-openlineage>=1.7.0`"
        # Chain the original import error so the root cause stays visible,
        # instead of stuffing the exception object into the message args.
        raise AirflowOptionalProviderFeatureException(msg) from e

    try:
        from openlineage.client.event_v2 import Job, Run, RunEvent, RunState  # noqa: F401
        from openlineage.client.facet_v2 import job_type_job, parent_run  # noqa: F401
        from openlineage.client.uuid import generate_new_uuid  # noqa: F401
    except ModuleNotFoundError as e:
        msg = "Please install `openlineage-python>=1.15.0`"
        raise AirflowOptionalProviderFeatureException(msg) from e
 Review Comment:
We could add an optional parameter to this check so callers can require a
particular provider or OpenLineage client version 🙂
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]