This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ff28969ff3 fix: EmrServerlessStartJobOperator not serializing DAGs
correctly when partial/expand is used. (#38022)
ff28969ff3 is described below
commit ff28969ff3370034ed9246d4ce9d0022129b3152
Author: jliu0812 <email address redacted by archive obfuscation>
AuthorDate: Mon Mar 25 16:47:53 2024 -0500
fix: EmrServerlessStartJobOperator not serializing DAGs correctly when
partial/expand is used. (#38022)
---
airflow/providers/amazon/aws/operators/emr.py | 62 +++++++++++++++++++---
.../amazon/aws/operators/test_emr_serverless.py | 55 +++++++++++++++++++
2 files changed, 111 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/emr.py
b/airflow/providers/amazon/aws/operators/emr.py
index 7c4d86c5e8..01e1567eab 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -1253,27 +1253,77 @@ class EmrServerlessStartJobOperator(BaseOperator):
op_extra_links = []
if isinstance(self, MappedOperator):
+ operator_class = self.operator_class
enable_application_ui_links = self.partial_kwargs.get(
"enable_application_ui_links"
) or self.expand_input.value.get("enable_application_ui_links")
- job_driver = self.partial_kwargs.get("job_driver") or
self.expand_input.value.get("job_driver")
+ job_driver = self.partial_kwargs.get("job_driver", {}) or
self.expand_input.value.get(
+ "job_driver", {}
+ )
configuration_overrides = self.partial_kwargs.get(
"configuration_overrides"
) or self.expand_input.value.get("configuration_overrides")
+ # Configuration overrides can either be a list or a dictionary,
depending on whether it's passed in as partial or expand.
+ if isinstance(configuration_overrides, list):
+ if any(
+ [
+ operator_class.is_monitoring_in_job_override(
+ self=operator_class,
+ config_key="s3MonitoringConfiguration",
+ job_override=job_override,
+ )
+ for job_override in configuration_overrides
+ ]
+ ):
+ op_extra_links.extend([EmrServerlessS3LogsLink()])
+ if any(
+ [
+ operator_class.is_monitoring_in_job_override(
+ self=operator_class,
+ config_key="cloudWatchLoggingConfiguration",
+ job_override=job_override,
+ )
+ for job_override in configuration_overrides
+ ]
+ ):
+ op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+ else:
+ if operator_class.is_monitoring_in_job_override(
+ self=operator_class,
+ config_key="s3MonitoringConfiguration",
+ job_override=configuration_overrides,
+ ):
+ op_extra_links.extend([EmrServerlessS3LogsLink()])
+ if operator_class.is_monitoring_in_job_override(
+ self=operator_class,
+ config_key="cloudWatchLoggingConfiguration",
+ job_override=configuration_overrides,
+ ):
+ op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+
else:
+ operator_class = self
enable_application_ui_links = self.enable_application_ui_links
configuration_overrides = self.configuration_overrides
job_driver = self.job_driver
+ if operator_class.is_monitoring_in_job_override(
+ "s3MonitoringConfiguration", configuration_overrides
+ ):
+ op_extra_links.extend([EmrServerlessS3LogsLink()])
+ if operator_class.is_monitoring_in_job_override(
+ "cloudWatchLoggingConfiguration", configuration_overrides
+ ):
+ op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+
if enable_application_ui_links:
op_extra_links.extend([EmrServerlessDashboardLink()])
- if "sparkSubmit" in job_driver:
+ if isinstance(job_driver, list):
+ if any("sparkSubmit" in ind_job_driver for ind_job_driver in
job_driver):
+ op_extra_links.extend([EmrServerlessLogsLink()])
+ elif "sparkSubmit" in job_driver:
op_extra_links.extend([EmrServerlessLogsLink()])
- if self.is_monitoring_in_job_override("s3MonitoringConfiguration",
configuration_overrides):
- op_extra_links.extend([EmrServerlessS3LogsLink()])
- if
self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration",
configuration_overrides):
- op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
return tuple(op_extra_links)
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py
b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index eed292c3cd..35eae39210 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -25,12 +25,21 @@ from botocore.exceptions import WaiterError
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook
+from airflow.providers.amazon.aws.links.emr import (
+ EmrServerlessCloudWatchLogsLink,
+ EmrServerlessDashboardLink,
+ EmrServerlessLogsLink,
+ EmrServerlessS3LogsLink,
+)
from airflow.providers.amazon.aws.operators.emr import (
EmrServerlessCreateApplicationOperator,
EmrServerlessDeleteApplicationOperator,
EmrServerlessStartJobOperator,
EmrServerlessStopApplicationOperator,
)
+from airflow.serialization.serialized_objects import (
+ SerializedBaseOperator,
+)
from airflow.utils.types import NOTSET
if TYPE_CHECKING:
@@ -1096,6 +1105,52 @@ class TestEmrServerlessStartJobOperator:
job_run_id=job_run_id,
)
+ def test_operator_extra_links_mapped_without_applicationui_enabled(  # partial()/expand() mapped operator, UI links disabled
+ self,
+ ):
+ operator = EmrServerlessStartJobOperator.partial(  # job_driver and the UI flag are static partial kwargs
+ task_id=task_id,
+ application_id=application_id,
+ execution_role_arn=execution_role_arn,
+ job_driver=spark_job_driver,
+ enable_application_ui_links=False,
+ ).expand(
+ configuration_overrides=[s3_configuration_overrides,
cloudwatch_configuration_overrides],
+ )
+
+ serialize = SerializedBaseOperator.serialize
+ deserialize = SerializedBaseOperator.deserialize_operator
+ deserialized_operator = deserialize(serialize(operator))  # round-trip the MappedOperator through serialization
+
+ assert deserialized_operator.operator_extra_links == [  # expected: monitoring links only, no dashboard/logs links with UI links off
+ EmrServerlessS3LogsLink(),
+ EmrServerlessCloudWatchLogsLink(),
+ ]
+
+ def test_operator_extra_links_mapped_with_applicationui_enabled_at_partial(  # partial()/expand() mapped operator, UI links enabled via partial()
+ self,
+ ):
+ operator = EmrServerlessStartJobOperator.partial(  # job_driver and the UI flag are static partial kwargs
+ task_id=task_id,
+ application_id=application_id,
+ execution_role_arn=execution_role_arn,
+ job_driver=spark_job_driver,
+ enable_application_ui_links=True,
+ ).expand(
+ configuration_overrides=[s3_configuration_overrides,
cloudwatch_configuration_overrides],
+ )
+
+ serialize = SerializedBaseOperator.serialize
+ deserialize = SerializedBaseOperator.deserialize_operator
+ deserialized_operator = deserialize(serialize(operator))  # round-trip the MappedOperator through serialization
+
+ assert deserialized_operator.operator_extra_links == [  # order mirrors the operator: S3, CloudWatch, then dashboard and logs links
+ EmrServerlessS3LogsLink(),
+ EmrServerlessCloudWatchLogsLink(),
+ EmrServerlessDashboardLink(),
+ EmrServerlessLogsLink(),
+ ]
+
class TestEmrServerlessDeleteOperator:
@mock.patch.object(EmrServerlessHook, "get_waiter")