This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch glue-pass-params
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 1dc055af16c1f37784afc6c3c799cf9108f0e2b7
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Fri Dec 5 13:30:17 2025 +0100

    glue pass openlineage params

    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 .../airflow/providers/amazon/aws/operators/glue.py |  10 +
 .../providers/amazon/aws/utils/openlineage.py      | 150 ++++++++++++++
 .../unit/amazon/aws/utils/test_openlineage.py      | 224 +++++++++++++++++++++
 .../common/compat/openlineage/utils/spark.py       |   6 +
 .../tests/unit/dbt/cloud/utils/test_openlineage.py |   2 +-
 .../airflow/providers/openlineage/utils/spark.py   |  53 ++++-
 6 files changed, 436 insertions(+), 9 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
index 5c75a43d52b..0036a94c5c3 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
@@ -37,6 +37,9 @@ from airflow.providers.amazon.aws.triggers.glue import (
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.providers.amazon.aws.utils.openlineage import (
+    inject_parent_job_information_into_glue_script_args,
+)

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -139,10 +142,14 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
         job_poll_interval: int | float = 6,
         waiter_delay: int = 60,
         waiter_max_attempts: int = 75,
+        openlineage_inject_parent_job_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_parent_job_info", fallback=False
+        ),
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.job_name = job_name
+        self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
         self.job_desc = job_desc
         self.script_location = script_location
         self.concurrent_run_limit = concurrent_run_limit or 1
@@ -217,6 +224,9 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):

         :return: the current Glue job ID.
         """
+        if self._openlineage_inject_parent_job_info:
+            self.log.debug("Injecting OpenLineage parent job information into Glue script_args.")
+            self.script_args = inject_parent_job_information_into_glue_script_args(self.script_args, context)
         self.log.info(
             "Initializing AWS Glue Job: %s. Wait for completion: %s",
             self.job_name,
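For context, this is how the new flag is meant to be used from a DAG: a minimal sketch with hypothetical task/job names and S3 paths. Only openlineage_inject_parent_job_info comes from this commit; when omitted, it falls back to the [openlineage] spark_inject_parent_job_info config option.

    from airflow.providers.amazon.aws.operators.glue import GlueJobOperator

    submit_glue_job = GlueJobOperator(
        task_id="submit_glue_job",                     # hypothetical task id
        job_name="example-glue-job",                   # hypothetical Glue job name
        script_location="s3://example-bucket/etl.py",  # hypothetical script path
        script_args={"--input": "s3://example-bucket/in"},
        # When True, execute() rewrites script_args via
        # inject_parent_job_information_into_glue_script_args() before starting the job.
        openlineage_inject_parent_job_info=True,
    )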
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/openlineage.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/openlineage.py
index be5703e2f6e..2c54f1d4eb0 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/utils/openlineage.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/openlineage.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations

+import logging
 from typing import TYPE_CHECKING, Any

 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
@@ -28,9 +29,13 @@ from airflow.providers.common.compat.openlineage.facet import (
     SchemaDatasetFacet,
     SchemaDatasetFacetFields,
 )
+from airflow.providers.common.compat.openlineage.utils.spark import get_parent_job_information

 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+    from airflow.utils.context import Context
+
+log = logging.getLogger(__name__)


 def get_facets_from_redshift_table(
@@ -136,3 +141,152 @@ def get_identity_column_lineage_facet(
         }
     )
     return column_lineage_facet
+
+
+def _parse_glue_customer_env_vars(env_vars_string: str | None) -> dict[str, str]:
+    """
+    Parse the --customer-driver-env-vars format into a dict.
+
+    Format: "KEY1=VAL1,KEY2=\"val2,val2 val2\""
+    - Simple values: KEY=VALUE
+    - Values with commas/spaces: KEY="value with, spaces"
+
+    Args:
+        env_vars_string: The environment variables string from Glue script args.
+
+    Returns:
+        Dict of key-value pairs.
+    """
+    if not env_vars_string:
+        return {}
+
+    result: dict[str, str] = {}
+    current = ""
+    in_quotes = False
+
+    for char in env_vars_string:
+        if char == '"' and (not current or current[-1] != "\\"):
+            in_quotes = not in_quotes
+            current += char
+        elif char == "," and not in_quotes:
+            if "=" in current:
+                key, value = current.split("=", 1)
+                # Strip surrounding quotes if present
+                value = value.strip()
+                if value.startswith('"') and value.endswith('"'):
+                    value = value[1:-1]
+                result[key.strip()] = value
+            current = ""
+        else:
+            current += char
+
+    # Handle last element
+    if current and "=" in current:
+        key, value = current.split("=", 1)
+        value = value.strip()
+        if value.startswith('"') and value.endswith('"'):
+            value = value[1:-1]
+        result[key.strip()] = value
+
+    return result
+
+
+def _format_glue_customer_env_vars(env_vars: dict[str, str]) -> str:
+    """
+    Format a dict back into the --customer-driver-env-vars string format.
+
+    - Values containing commas, spaces, or quotes need quoting
+    - Quotes within values need escaping
+
+    Args:
+        env_vars: Dict of environment variables.
+
+    Returns:
+        String in format "KEY1=VAL1,KEY2=\"val2\""
+    """
+    parts = []
+    for key, value in env_vars.items():
+        # Quote if contains special chars
+        if "," in value or " " in value or '"' in value:
+            escaped_value = value.replace('"', '\\"')
+            parts.append(f'{key}="{escaped_value}"')
+        else:
+            parts.append(f"{key}={value}")
+    return ",".join(parts)
+
+
+def _is_parent_job_info_present_in_glue_env_vars(script_args: dict[str, Any]) -> bool:
+    """
+    Check if any OpenLineage parent job env vars are already set.
+
+    Args:
+        script_args: The Glue job's script_args dict.
+
+    Returns:
+        True if any OL parent job env vars are present.
+    """
+    # Check --customer-driver-env-vars
+    driver_env_vars_str = script_args.get("--customer-driver-env-vars", "")
+    driver_env_vars = _parse_glue_customer_env_vars(driver_env_vars_str)
+
+    # Also check --customer-executor-env-vars
+    executor_env_vars_str = script_args.get("--customer-executor-env-vars", "")
+    executor_env_vars = _parse_glue_customer_env_vars(executor_env_vars_str)
+
+    all_env_vars = {**driver_env_vars, **executor_env_vars}
+
+    # Check if ANY OpenLineage parent env var is present
+    return any(
+        key.startswith("OPENLINEAGE_PARENT") or key.startswith("OPENLINEAGE_ROOT_PARENT")
+        for key in all_env_vars
+    )
+
+
+def inject_parent_job_information_into_glue_script_args(
+    script_args: dict[str, Any], context: Context
+) -> dict[str, Any]:
+    """
+    Inject OpenLineage parent job info into Glue script_args.
+
+    The parent job information is injected via the --customer-driver-env-vars argument,
+    which sets environment variables in the Spark driver process.
+
+    - If OpenLineage provider is not available, skip injection
+    - If user already set any OPENLINEAGE_PARENT_* or OPENLINEAGE_ROOT_PARENT_* env vars,
+      skip injection to preserve user-provided values
+    - Merge with existing --customer-driver-env-vars if present
+    - Return new dict (don't mutate original)
+
+    Args:
+        script_args: The Glue job's script_args dict.
+        context: Airflow task context.
+
+    Returns:
+        Modified script_args with OpenLineage env vars injected.
+    """
+    info = get_parent_job_information(context)
+    if info is None:
+        return script_args
+
+    if _is_parent_job_info_present_in_glue_env_vars(script_args):
+        log.info("OpenLineage parent job information already present in Glue env vars. Skipping injection.")
+        return script_args
+
+    existing_env_vars_str = script_args.get("--customer-driver-env-vars", "")
+    existing_env_vars = _parse_glue_customer_env_vars(existing_env_vars_str)
+
+    ol_env_vars = {
+        "OPENLINEAGE_PARENT_JOB_NAMESPACE": info.parent_job_namespace,
+        "OPENLINEAGE_PARENT_JOB_NAME": info.parent_job_name,
+        "OPENLINEAGE_PARENT_RUN_ID": info.parent_run_id,
+        "OPENLINEAGE_ROOT_PARENT_JOB_NAMESPACE": info.root_parent_job_namespace,
+        "OPENLINEAGE_ROOT_PARENT_JOB_NAME": info.root_parent_job_name,
+        "OPENLINEAGE_ROOT_PARENT_RUN_ID": info.root_parent_run_id,
+    }
+
+    merged_env_vars = {**existing_env_vars, **ol_env_vars}
+
+    new_script_args = {**script_args}
+    new_script_args["--customer-driver-env-vars"] = _format_glue_customer_env_vars(merged_env_vars)
+
+    return new_script_args
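To illustrate the quoting rules the two helpers above implement, a small round-trip sketch (values chosen for illustration only):

    from airflow.providers.amazon.aws.utils.openlineage import (
        _format_glue_customer_env_vars,
        _parse_glue_customer_env_vars,
    )

    # A value containing a comma is quoted on the wire...
    env_vars = {"SIMPLE": "x", "WITH_COMMA": "a,b"}
    encoded = _format_glue_customer_env_vars(env_vars)
    assert encoded == 'SIMPLE=x,WITH_COMMA="a,b"'
    # ...and parsing strips the quotes back off, restoring the original dict.
    assert _parse_glue_customer_env_vars(encoded) == env_vars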
diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_openlineage.py b/providers/amazon/tests/unit/amazon/aws/utils/test_openlineage.py
index 3790a590238..4561241a2f8 100644
--- a/providers/amazon/tests/unit/amazon/aws/utils/test_openlineage.py
+++ b/providers/amazon/tests/unit/amazon/aws/utils/test_openlineage.py
@@ -168,3 +168,227 @@ def test_get_identity_column_lineage_facet_no_input_datasets():
         ValueError, match="When providing `field_names` You must provide at least one `input_dataset`."
     ):
         get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets)
+
+
+# --- Glue OpenLineage tests ---
+
+from datetime import datetime
+from unittest.mock import MagicMock
+
+from airflow.providers.amazon.aws.utils.openlineage import (
+    _format_glue_customer_env_vars,
+    _is_parent_job_info_present_in_glue_env_vars,
+    _parse_glue_customer_env_vars,
+    inject_parent_job_information_into_glue_script_args,
+)
+
+EXAMPLE_CONTEXT = {
+    "ti": MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        try_number=1,
+        map_index=-1,
+        logical_date=datetime(2024, 11, 11),
+        dag_run=MagicMock(logical_date=datetime(2024, 11, 11), clear_number=0),
+    )
+}
+
+
+class TestParseGlueCustomerEnvVars:
+    def test_empty_string(self):
+        assert _parse_glue_customer_env_vars("") == {}
+
+    def test_none(self):
+        assert _parse_glue_customer_env_vars(None) == {}
+
+    def test_simple_single(self):
+        result = _parse_glue_customer_env_vars("KEY1=VAL1")
+        assert result == {"KEY1": "VAL1"}
+
+    def test_simple_multiple(self):
+        result = _parse_glue_customer_env_vars("KEY1=VAL1,KEY2=VAL2")
+        assert result == {"KEY1": "VAL1", "KEY2": "VAL2"}
+
+    def test_quoted_value_with_comma(self):
+        result = _parse_glue_customer_env_vars('KEY1=VAL1,KEY2="val with, comma"')
+        assert result == {"KEY1": "VAL1", "KEY2": "val with, comma"}
+
+    def test_quoted_value_with_space(self):
+        result = _parse_glue_customer_env_vars('KEY1="value with spaces"')
+        assert result == {"KEY1": "value with spaces"}
+
+    def test_value_with_equals(self):
+        result = _parse_glue_customer_env_vars("KEY1=val=ue")
+        assert result == {"KEY1": "val=ue"}
+
+
+class TestFormatGlueCustomerEnvVars:
+    def test_empty(self):
+        result = _format_glue_customer_env_vars({})
+        assert result == ""
+
+    def test_simple_single(self):
+        result = _format_glue_customer_env_vars({"KEY1": "VAL1"})
+        assert result == "KEY1=VAL1"
+
+    def test_simple_multiple(self):
+        result = _format_glue_customer_env_vars({"KEY1": "VAL1", "KEY2": "VAL2"})
+        assert "KEY1=VAL1" in result
+        assert "KEY2=VAL2" in result
+
+    def test_value_with_comma_gets_quoted(self):
+        result = _format_glue_customer_env_vars({"KEY": "val,ue"})
+        assert result == 'KEY="val,ue"'
+
+    def test_value_with_space_gets_quoted(self):
+        result = _format_glue_customer_env_vars({"KEY": "val ue"})
+        assert result == 'KEY="val ue"'
+
+    def test_roundtrip(self):
+        original = {"KEY1": "simple", "KEY2": "has, comma", "KEY3": "has space"}
+        formatted = _format_glue_customer_env_vars(original)
+        parsed = _parse_glue_customer_env_vars(formatted)
+        assert parsed == original
+
+
+class TestIsParentJobInfoPresentInGlueEnvVars:
+    def test_empty_returns_false(self):
+        assert _is_parent_job_info_present_in_glue_env_vars({}) is False
+
+    def test_unrelated_vars_return_false(self):
+        script_args = {"--customer-driver-env-vars": "MY_VAR=value"}
+        assert _is_parent_job_info_present_in_glue_env_vars(script_args) is False
+
+    def test_parent_var_returns_true(self):
+        script_args = {"--customer-driver-env-vars": "OPENLINEAGE_PARENT_JOB_NAME=test"}
+        assert _is_parent_job_info_present_in_glue_env_vars(script_args) is True
+
+    def test_root_parent_var_returns_true(self):
+        script_args = {"--customer-driver-env-vars": "OPENLINEAGE_ROOT_PARENT_RUN_ID=123"}
+        assert _is_parent_job_info_present_in_glue_env_vars(script_args) is True
+
+    def test_executor_env_vars_also_checked(self):
+        script_args = {"--customer-executor-env-vars": "OPENLINEAGE_PARENT_RUN_ID=123"}
+        assert _is_parent_job_info_present_in_glue_env_vars(script_args) is True
+
+    def test_no_script_args_key(self):
+        script_args = {"--other-arg": "value"}
+        assert _is_parent_job_info_present_in_glue_env_vars(script_args) is False
+
+
+class TestInjectParentJobInformationIntoGlueScriptArgs:
+    @mock.patch(
+        "airflow.providers.amazon.aws.utils.openlineage.get_parent_job_information",
+        return_value=None,
+    )
+    def test_skips_when_openlineage_not_available(self, mock_get_parent_info):
+        result = inject_parent_job_information_into_glue_script_args({}, EXAMPLE_CONTEXT)
+        assert result == {}
+
+    @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_parent_job_information")
+    def test_injects_into_empty_script_args(self, mock_get_parent_info):
+        from airflow.providers.openlineage.utils.spark import ParentJobInformation
+
+        mock_get_parent_info.return_value = ParentJobInformation(
+            parent_job_namespace="default",
+            parent_job_name="dag_id.task_id",
+            parent_run_id="uuid-123",
+            root_parent_job_namespace="default",
+            root_parent_job_name="dag_id",
+            root_parent_run_id="uuid-456",
+        )
+
+        result = inject_parent_job_information_into_glue_script_args({}, EXAMPLE_CONTEXT)
+
+        assert "--customer-driver-env-vars" in result
+        env_vars = _parse_glue_customer_env_vars(result["--customer-driver-env-vars"])
+        assert env_vars["OPENLINEAGE_PARENT_JOB_NAMESPACE"] == "default"
+        assert env_vars["OPENLINEAGE_PARENT_JOB_NAME"] == "dag_id.task_id"
+        assert env_vars["OPENLINEAGE_PARENT_RUN_ID"] == "uuid-123"
+        assert env_vars["OPENLINEAGE_ROOT_PARENT_JOB_NAMESPACE"] == "default"
+        assert env_vars["OPENLINEAGE_ROOT_PARENT_JOB_NAME"] == "dag_id"
+        assert env_vars["OPENLINEAGE_ROOT_PARENT_RUN_ID"] == "uuid-456"
+
+    @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_parent_job_information")
+    def test_merges_with_existing_env_vars(self, mock_get_parent_info):
+        from airflow.providers.openlineage.utils.spark import ParentJobInformation
+
+        mock_get_parent_info.return_value = ParentJobInformation(
+            parent_job_namespace="default",
+            parent_job_name="dag.task",
+            parent_run_id="uuid-123",
+            root_parent_job_namespace="default",
+            root_parent_job_name="dag",
+            root_parent_run_id="uuid-456",
+        )
+
+        existing = {"--customer-driver-env-vars": "EXISTING_VAR=value"}
+        result = inject_parent_job_information_into_glue_script_args(existing, EXAMPLE_CONTEXT)
+
+        env_vars = _parse_glue_customer_env_vars(result["--customer-driver-env-vars"])
+        assert "EXISTING_VAR" in env_vars
+        assert env_vars["EXISTING_VAR"] == "value"
+        assert "OPENLINEAGE_PARENT_JOB_NAME" in env_vars
+
+    @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_parent_job_information")
+    def test_preserves_other_script_args(self, mock_get_parent_info):
+        from airflow.providers.openlineage.utils.spark import ParentJobInformation
+
+        mock_get_parent_info.return_value = ParentJobInformation(
+            parent_job_namespace="default",
+            parent_job_name="dag.task",
+            parent_run_id="uuid-123",
+            root_parent_job_namespace="default",
+            root_parent_job_name="dag",
+            root_parent_run_id="uuid-456",
+        )
+
+        existing = {"--input": "s3://bucket/input", "--output": "s3://bucket/output"}
+        result = inject_parent_job_information_into_glue_script_args(existing, EXAMPLE_CONTEXT)
+
+        assert result["--input"] == "s3://bucket/input"
+        assert result["--output"] == "s3://bucket/output"
+        assert "--customer-driver-env-vars" in result
+
+    @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_parent_job_information")
+    def test_preserves_existing_customer_driver_env_vars(self, mock_get_parent_info):
+        """Key test: verify that existing user-defined env vars are preserved when injecting OL vars."""
+        from airflow.providers.openlineage.utils.spark import ParentJobInformation
+
+        mock_get_parent_info.return_value = ParentJobInformation(
+            parent_job_namespace="default",
+            parent_job_name="dag.task",
+            parent_run_id="uuid-123",
+            root_parent_job_namespace="default",
+            root_parent_job_name="dag",
+            root_parent_run_id="uuid-456",
+        )
+
+        # User has set multiple custom env vars including ones with special characters
+        existing = {
+            "--customer-driver-env-vars": 'MY_VAR=value1,ANOTHER_VAR=value2,COMPLEX_VAR="value with, comma"',
+            "--other-arg": "other_value",
+        }
+        result = inject_parent_job_information_into_glue_script_args(existing, EXAMPLE_CONTEXT)
+
+        # Parse the resulting env vars
+        env_vars = _parse_glue_customer_env_vars(result["--customer-driver-env-vars"])
+
+        # Verify ALL original user env vars are preserved
+        assert env_vars["MY_VAR"] == "value1"
+        assert env_vars["ANOTHER_VAR"] == "value2"
+        assert env_vars["COMPLEX_VAR"] == "value with, comma"
+
+        # Verify OpenLineage env vars were added
+        assert env_vars["OPENLINEAGE_PARENT_JOB_NAMESPACE"] == "default"
+        assert env_vars["OPENLINEAGE_PARENT_JOB_NAME"] == "dag.task"
+        assert env_vars["OPENLINEAGE_PARENT_RUN_ID"] == "uuid-123"
+        assert env_vars["OPENLINEAGE_ROOT_PARENT_JOB_NAMESPACE"] == "default"
+        assert env_vars["OPENLINEAGE_ROOT_PARENT_JOB_NAME"] == "dag"
+        assert env_vars["OPENLINEAGE_ROOT_PARENT_RUN_ID"] == "uuid-456"
+
+        # Verify other script args are unchanged
+        assert result["--other-arg"] == "other_value"
+
+        # Total count: 3 user vars + 6 OL vars = 9 vars
+        assert len(env_vars) == 9
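One behavior the injector's docstring promises but the suite above does not yet pin down is the skip-if-present path; a sketch of such a test, reusing this file's existing mock import and EXAMPLE_CONTEXT (not part of this commit):

    @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_parent_job_information")
    def test_skips_when_user_already_set_parent_vars(mock_get_parent_info):
        # Any pre-set OPENLINEAGE_PARENT_* var should make injection a no-op.
        existing = {"--customer-driver-env-vars": "OPENLINEAGE_PARENT_JOB_NAME=manual"}
        result = inject_parent_job_information_into_glue_script_args(existing, EXAMPLE_CONTEXT)
        assert result == existing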
diff --git a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
index 1028bf3debf..414acce2081 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/openlineage/utils/spark.py
@@ -24,12 +24,14 @@ log = logging.getLogger(__name__)
 if TYPE_CHECKING:
     from airflow.providers.openlineage.utils.spark import (
+        get_parent_job_information,
         inject_parent_job_information_into_spark_properties,
         inject_transport_information_into_spark_properties,
     )
     from airflow.sdk import Context

 try:
     from airflow.providers.openlineage.utils.spark import (
+        get_parent_job_information,
         inject_parent_job_information_into_spark_properties,
         inject_transport_information_into_spark_properties,
     )
@@ -49,8 +51,12 @@ except ImportError:
         )
         return properties

+    def get_parent_job_information(context: Context) -> None:
+        return None
+

 __all__ = [
     "inject_parent_job_information_into_spark_properties",
     "inject_transport_information_into_spark_properties",
+    "get_parent_job_information",
 ]

diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py b/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
index e5d9431f8b3..35d615ca82e 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
@@ -28,7 +28,7 @@ from airflow.exceptions import AirflowOptionalProviderFeatureException
 from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook
 from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
 from airflow.providers.dbt.cloud.utils.openlineage import generate_openlineage_events_from_dbt_cloud_run
-from airflow.providers.openlineage.extractors import OperatorLineage
+from airflow.providers.openlineage.extractors import OperatorLineage

 TASK_ID = "dbt_test"
 DAG_ID = "dbt_dag"
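The common.compat fallback above lets downstream providers import the helper unconditionally and degrade to a no-op when the OpenLineage provider is missing or too old; a sketch of that call pattern (build_child_job_env is illustrative, not part of any provider):

    from airflow.providers.common.compat.openlineage.utils.spark import get_parent_job_information

    def build_child_job_env(context) -> dict[str, str]:
        # The compat shim returns None when the real helper is unavailable,
        # so callers can skip injection without wrapping the import in try/except.
        info = get_parent_job_information(context)
        if info is None:
            return {}
        return {
            "OPENLINEAGE_PARENT_JOB_NAME": info.parent_job_name,
            "OPENLINEAGE_PARENT_RUN_ID": info.parent_run_id,
        }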
diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
index a92ac25eab2..3803496fb55 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
@@ -18,7 +18,7 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, NamedTuple

 from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
 from airflow.providers.openlineage.plugins.macros import (
@@ -35,6 +35,43 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)


+class ParentJobInformation(NamedTuple):
+    """Container for OpenLineage parent job information."""
+
+    parent_job_namespace: str
+    parent_job_name: str
+    parent_run_id: str
+    root_parent_job_namespace: str
+    root_parent_job_name: str
+    root_parent_run_id: str
+
+
+def get_parent_job_information(context: Context) -> ParentJobInformation | None:
+    """
+    Retrieve parent job information from the Airflow context.
+
+    This function extracts OpenLineage parent job details from the task instance,
+    which can be used by various integrations (Spark, Glue, etc.) to propagate
+    lineage information to child jobs.
+
+    Args:
+        context: The Airflow context containing task instance information.
+
+    Returns:
+        ParentJobInformation containing namespace, job name, and run IDs
+        for both parent and root parent.
+    """
+    ti = context["ti"]
+    return ParentJobInformation(
+        parent_job_namespace=lineage_job_namespace(),
+        parent_job_name=lineage_job_name(ti),  # type: ignore[arg-type]
+        parent_run_id=lineage_run_id(ti),  # type: ignore[arg-type]
+        root_parent_job_namespace=lineage_job_namespace(),
+        root_parent_job_name=lineage_root_job_name(ti),  # type: ignore[arg-type]
+        root_parent_run_id=lineage_root_run_id(ti),  # type: ignore[arg-type]
+    )
+
+
 def _get_parent_job_information_as_spark_properties(context: Context) -> dict:
     """
     Retrieve parent job information as Spark properties.
@@ -45,14 +82,14 @@ def _get_parent_job_information_as_spark_properties(context: Context) -> dict:
     Returns:
         Spark properties with the parent job information.
     """
-    ti = context["ti"]
+    info = get_parent_job_information(context)
     return {
-        "spark.openlineage.parentJobNamespace": lineage_job_namespace(),
-        "spark.openlineage.parentJobName": lineage_job_name(ti),  # type: ignore[arg-type]
-        "spark.openlineage.parentRunId": lineage_run_id(ti),  # type: ignore[arg-type]
-        "spark.openlineage.rootParentRunId": lineage_root_run_id(ti),  # type: ignore[arg-type]
-        "spark.openlineage.rootParentJobName": lineage_root_job_name(ti),  # type: ignore[arg-type]
-        "spark.openlineage.rootParentJobNamespace": lineage_job_namespace(),
+        "spark.openlineage.parentJobNamespace": info.parent_job_namespace,
+        "spark.openlineage.parentJobName": info.parent_job_name,
+        "spark.openlineage.parentRunId": info.parent_run_id,
+        "spark.openlineage.rootParentRunId": info.root_parent_run_id,
+        "spark.openlineage.rootParentJobName": info.root_parent_job_name,
+        "spark.openlineage.rootParentJobNamespace": info.root_parent_job_namespace,
     }
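For reference, a sketch pairing a ParentJobInformation instance with the Spark properties the refactored private helper derives from it (field values borrowed from the unit tests above, not from a live run):

    from airflow.providers.openlineage.utils.spark import ParentJobInformation

    info = ParentJobInformation(
        parent_job_namespace="default",
        parent_job_name="dag_id.task_id",
        parent_run_id="uuid-123",
        root_parent_job_namespace="default",
        root_parent_job_name="dag_id",
        root_parent_run_id="uuid-456",
    )
    # _get_parent_job_information_as_spark_properties(context) maps these fields
    # one-to-one onto the spark.openlineage.* keys:
    expected_properties = {
        "spark.openlineage.parentJobNamespace": info.parent_job_namespace,
        "spark.openlineage.parentJobName": info.parent_job_name,
        "spark.openlineage.parentRunId": info.parent_run_id,
        "spark.openlineage.rootParentRunId": info.root_parent_run_id,
        "spark.openlineage.rootParentJobName": info.root_parent_job_name,
        "spark.openlineage.rootParentJobNamespace": info.root_parent_job_namespace,
    }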
