This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 401e7bd531 Bugfix: Fix RDS triggers parameters so that they handle
serialization/deserialization (#34222)
401e7bd531 is described below
commit 401e7bd53119e204bf68c75dca28b1e35676c056
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Sep 12 10:59:26 2023 -0700
Bugfix: Fix RDS triggers parameters so that they handle
serialization/deserialization (#34222)
---
airflow/providers/amazon/aws/triggers/rds.py | 59 ++--
tests/providers/amazon/aws/triggers/test_athena.py | 35 --
tests/providers/amazon/aws/triggers/test_batch.py | 64 ----
tests/providers/amazon/aws/triggers/test_ecs.py | 35 --
tests/providers/amazon/aws/triggers/test_eks.py | 100 ------
tests/providers/amazon/aws/triggers/test_emr.py | 88 -----
.../amazon/aws/triggers/test_emr_serverless.py | 92 -----
.../amazon/aws/triggers/test_emr_trigger.py | 77 -----
tests/providers/amazon/aws/triggers/test_glue.py | 38 ---
.../amazon/aws/triggers/test_lambda_function.py | 54 ---
tests/providers/amazon/aws/triggers/test_rds.py | 84 -----
.../amazon/aws/triggers/test_redshift_cluster.py | 81 -----
.../amazon/aws/triggers/test_serialization.py | 375 +++++++++++++++++++++
tests/providers/amazon/aws/triggers/test_sqs.py | 18 -
.../amazon/aws/triggers/test_step_function.py | 53 ---
15 files changed, 415 insertions(+), 838 deletions(-)
diff --git a/airflow/providers/amazon/aws/triggers/rds.py
b/airflow/providers/amazon/aws/triggers/rds.py
index ebc80ba700..78a6dfa16a 100644
--- a/airflow/providers/amazon/aws/triggers/rds.py
+++ b/airflow/providers/amazon/aws/triggers/rds.py
@@ -100,12 +100,12 @@ class RdsDbInstanceTrigger(BaseTrigger):
_waiter_arg = {
- RdsDbType.INSTANCE: "DBInstanceIdentifier",
- RdsDbType.CLUSTER: "DBClusterIdentifier",
+ RdsDbType.INSTANCE.value: "DBInstanceIdentifier",
+ RdsDbType.CLUSTER.value: "DBClusterIdentifier",
}
_status_paths = {
- RdsDbType.INSTANCE: ["DBInstances[].DBInstanceStatus",
"DBInstances[].StatusInfos"],
- RdsDbType.CLUSTER: ["DBClusters[].Status"],
+ RdsDbType.INSTANCE.value: ["DBInstances[].DBInstanceStatus",
"DBInstances[].StatusInfos"],
+ RdsDbType.CLUSTER.value: ["DBClusters[].Status"],
}
@@ -129,20 +129,27 @@ class RdsDbAvailableTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts: int,
aws_conn_id: str,
response: dict[str, Any],
- db_type: RdsDbType,
+ db_type: RdsDbType | str,
region_name: str | None = None,
) -> None:
+ # allow passing enums for users,
+ # but we can only rely on strings because (de-)serialization doesn't
support enums
+ if isinstance(db_type, RdsDbType):
+ db_type_str = db_type.value
+ else:
+ db_type_str = db_type
+
super().__init__(
serialized_fields={
"db_identifier": db_identifier,
"response": response,
- "db_type": db_type,
+ "db_type": db_type_str,
},
- waiter_name=f"db_{db_type.value}_available",
- waiter_args={_waiter_arg[db_type]: db_identifier},
+ waiter_name=f"db_{db_type_str}_available",
+ waiter_args={_waiter_arg[db_type_str]: db_identifier},
failure_message="Error while waiting for DB to be available",
status_message="DB initialization in progress",
- status_queries=_status_paths[db_type],
+ status_queries=_status_paths[db_type_str],
return_key="response",
return_value=response,
waiter_delay=waiter_delay,
@@ -175,20 +182,27 @@ class RdsDbDeletedTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts: int,
aws_conn_id: str,
response: dict[str, Any],
- db_type: RdsDbType,
+ db_type: RdsDbType | str,
region_name: str | None = None,
) -> None:
+ # allow passing enums for users,
+ # but we can only rely on strings because (de-)serialization doesn't
support enums
+ if isinstance(db_type, RdsDbType):
+ db_type_str = db_type.value
+ else:
+ db_type_str = db_type
+
super().__init__(
serialized_fields={
"db_identifier": db_identifier,
"response": response,
- "db_type": db_type,
+ "db_type": db_type_str,
},
- waiter_name=f"db_{db_type.value}_deleted",
- waiter_args={_waiter_arg[db_type]: db_identifier},
+ waiter_name=f"db_{db_type_str}_deleted",
+ waiter_args={_waiter_arg[db_type_str]: db_identifier},
failure_message="Error while deleting DB",
status_message="DB deletion in progress",
- status_queries=_status_paths[db_type],
+ status_queries=_status_paths[db_type_str],
return_key="response",
return_value=response,
waiter_delay=waiter_delay,
@@ -221,20 +235,27 @@ class RdsDbStoppedTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts: int,
aws_conn_id: str,
response: dict[str, Any],
- db_type: RdsDbType,
+ db_type: RdsDbType | str,
region_name: str | None = None,
) -> None:
+ # allow passing enums for users,
+ # but we can only rely on strings because (de-)serialization doesn't
support enums
+ if isinstance(db_type, RdsDbType):
+ db_type_str = db_type.value
+ else:
+ db_type_str = db_type
+
super().__init__(
serialized_fields={
"db_identifier": db_identifier,
"response": response,
- "db_type": db_type,
+ "db_type": db_type_str,
},
- waiter_name=f"db_{db_type.value}_stopped",
- waiter_args={_waiter_arg[db_type]: db_identifier},
+ waiter_name=f"db_{db_type_str}_stopped",
+ waiter_args={_waiter_arg[db_type_str]: db_identifier},
failure_message="Error while stopping DB",
status_message="DB is being stopped",
- status_queries=_status_paths[db_type],
+ status_queries=_status_paths[db_type_str],
return_key="response",
return_value=response,
waiter_delay=waiter_delay,
diff --git a/tests/providers/amazon/aws/triggers/test_athena.py
b/tests/providers/amazon/aws/triggers/test_athena.py
deleted file mode 100644
index d18bdc1553..0000000000
--- a/tests/providers/amazon/aws/triggers/test_athena.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
-
-
-class TestAthenaTrigger:
- def test_serialize_recreate(self):
- trigger = AthenaTrigger("query_id", 1, 5, "aws connection")
-
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
diff --git a/tests/providers/amazon/aws/triggers/test_batch.py
b/tests/providers/amazon/aws/triggers/test_batch.py
deleted file mode 100644
index 6ceee61332..0000000000
--- a/tests/providers/amazon/aws/triggers/test_batch.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.batch import (
- BatchCreateComputeEnvironmentTrigger,
- BatchJobTrigger,
-)
-
-BATCH_JOB_ID = "job_id"
-POLL_INTERVAL = 5
-MAX_ATTEMPT = 5
-AWS_CONN_ID = "aws_batch_job_conn"
-AWS_REGION = "us-east-2"
-pytest.importorskip("aiobotocore")
-
-
-class TestBatchTrigger:
- @pytest.mark.parametrize(
- "trigger",
- [
- BatchJobTrigger(
- job_id=BATCH_JOB_ID,
- waiter_delay=POLL_INTERVAL,
- waiter_max_attempts=MAX_ATTEMPT,
- aws_conn_id=AWS_CONN_ID,
- region_name=AWS_REGION,
- ),
- BatchCreateComputeEnvironmentTrigger(
- compute_env_arn="my_arn",
- waiter_delay=POLL_INTERVAL,
- waiter_max_attempts=MAX_ATTEMPT,
- aws_conn_id=AWS_CONN_ID,
- region_name=AWS_REGION,
- ),
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py
b/tests/providers/amazon/aws/triggers/test_ecs.py
index 4b0a58a84a..046a2634e2 100644
--- a/tests/providers/amazon/aws/triggers/test_ecs.py
+++ b/tests/providers/amazon/aws/triggers/test_ecs.py
@@ -27,8 +27,6 @@ from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.ecs import (
- ClusterActiveTrigger,
- ClusterInactiveTrigger,
TaskDoneTrigger,
)
@@ -97,36 +95,3 @@ class TestTaskDoneTrigger:
assert response.payload["status"] == "success"
assert response.payload["task_arn"] == "my_task_arn"
-
-
-class TestClusterTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- ClusterActiveTrigger(
- cluster_arn="my_arn",
- aws_conn_id="my_conn",
- waiter_delay=1,
- waiter_max_attempts=2,
- region_name="my_region",
- ),
- ClusterInactiveTrigger(
- cluster_arn="my_arn",
- aws_conn_id="my_conn",
- waiter_delay=1,
- waiter_max_attempts=2,
- region_name="my_region",
- ),
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
diff --git a/tests/providers/amazon/aws/triggers/test_eks.py
b/tests/providers/amazon/aws/triggers/test_eks.py
deleted file mode 100644
index 023f8d2a97..0000000000
--- a/tests/providers/amazon/aws/triggers/test_eks.py
+++ /dev/null
@@ -1,100 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.eks import (
- EksCreateClusterTrigger,
- EksCreateFargateProfileTrigger,
- EksCreateNodegroupTrigger,
- EksDeleteClusterTrigger,
- EksDeleteFargateProfileTrigger,
- EksDeleteNodegroupTrigger,
-)
-
-TEST_CLUSTER_IDENTIFIER = "test-cluster"
-TEST_FARGATE_PROFILE_NAME = "test-fargate-profile"
-TEST_NODEGROUP_NAME = "test-nodegroup"
-TEST_WAITER_DELAY = 10
-TEST_WAITER_MAX_ATTEMPTS = 10
-TEST_AWS_CONN_ID = "test-aws-id"
-TEST_REGION = "test-region"
-
-
-class TestEksTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- EksCreateFargateProfileTrigger(
- cluster_name=TEST_CLUSTER_IDENTIFIER,
- fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- ),
- EksDeleteFargateProfileTrigger(
- cluster_name=TEST_CLUSTER_IDENTIFIER,
- fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- ),
- EksCreateNodegroupTrigger(
- cluster_name=TEST_CLUSTER_IDENTIFIER,
- nodegroup_name=TEST_NODEGROUP_NAME,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- region_name=TEST_REGION,
- ),
- EksDeleteNodegroupTrigger(
- cluster_name=TEST_CLUSTER_IDENTIFIER,
- nodegroup_name=TEST_NODEGROUP_NAME,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- region_name=TEST_REGION,
- ),
- EksCreateClusterTrigger(
- cluster_name=TEST_CLUSTER_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_DELAY,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- ),
- EksDeleteClusterTrigger(
- cluster_name=TEST_CLUSTER_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_DELAY,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- force_delete_compute=True,
- ),
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
diff --git a/tests/providers/amazon/aws/triggers/test_emr.py
b/tests/providers/amazon/aws/triggers/test_emr.py
deleted file mode 100644
index 5a1369e89d..0000000000
--- a/tests/providers/amazon/aws/triggers/test_emr.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.emr import (
- EmrAddStepsTrigger,
- EmrContainerTrigger,
- EmrCreateJobFlowTrigger,
- EmrStepSensorTrigger,
- EmrTerminateJobFlowTrigger,
-)
-
-TEST_JOB_FLOW_ID = "test-job-flow-id"
-TEST_POLL_INTERVAL = 10
-TEST_MAX_ATTEMPTS = 10
-TEST_AWS_CONN_ID = "test-aws-id"
-VIRTUAL_CLUSTER_ID = "vzwemreks"
-JOB_ID = "job-1234"
-POLL_INTERVAL = 60
-TARGET_STATE = ["TERMINATED"]
-STEP_ID = "s-1234"
-
-
-class TestEmrTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- EmrAddStepsTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- step_ids=["my_step1", "my_step2"],
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_POLL_INTERVAL,
- waiter_max_attempts=TEST_MAX_ATTEMPTS,
- ),
- EmrCreateJobFlowTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
- ),
- EmrTerminateJobFlowTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPTS,
- ),
- EmrContainerTrigger(
- virtual_cluster_id=VIRTUAL_CLUSTER_ID,
- job_id=JOB_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- poll_interval=POLL_INTERVAL,
- ),
- EmrStepSensorTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- step_id=STEP_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=POLL_INTERVAL,
- ),
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
- assert instance.hook().aws_conn_id == TEST_AWS_CONN_ID
diff --git a/tests/providers/amazon/aws/triggers/test_emr_serverless.py
b/tests/providers/amazon/aws/triggers/test_emr_serverless.py
deleted file mode 100644
index 39f1a9a0ab..0000000000
--- a/tests/providers/amazon/aws/triggers/test_emr_serverless.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.emr import (
- EmrServerlessCancelJobsTrigger,
- EmrServerlessCreateApplicationTrigger,
- EmrServerlessDeleteApplicationTrigger,
- EmrServerlessStartApplicationTrigger,
- EmrServerlessStartJobTrigger,
- EmrServerlessStopApplicationTrigger,
-)
-
-TEST_APPLICATION_ID = "test-application-id"
-TEST_WAITER_DELAY = 10
-TEST_WAITER_MAX_ATTEMPTS = 10
-TEST_AWS_CONN_ID = "test-aws-id"
-AWS_CONN_ID = "aws_emr_conn"
-TEST_JOB_ID = "test-job-id"
-
-
-class TestEmrTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- EmrServerlessCreateApplicationTrigger(
- application_id=TEST_APPLICATION_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- ),
- EmrServerlessStartApplicationTrigger(
- application_id=TEST_APPLICATION_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- ),
- EmrServerlessStopApplicationTrigger(
- application_id=TEST_APPLICATION_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- ),
- EmrServerlessDeleteApplicationTrigger(
- application_id=TEST_APPLICATION_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- ),
- EmrServerlessCancelJobsTrigger(
- application_id=TEST_APPLICATION_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- ),
- EmrServerlessStartJobTrigger(
- application_id=TEST_APPLICATION_ID,
- job_id=TEST_JOB_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- ),
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
diff --git a/tests/providers/amazon/aws/triggers/test_emr_trigger.py
b/tests/providers/amazon/aws/triggers/test_emr_trigger.py
deleted file mode 100644
index 187a948fb5..0000000000
--- a/tests/providers/amazon/aws/triggers/test_emr_trigger.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.emr import (
- EmrContainerTrigger,
- EmrCreateJobFlowTrigger,
- EmrStepSensorTrigger,
- EmrTerminateJobFlowTrigger,
-)
-
-TEST_JOB_FLOW_ID = "test_job_flow_id"
-TEST_STEP_IDS = ["step1", "step2"]
-TEST_AWS_CONN_ID = "test-aws-id"
-TEST_MAX_ATTEMPTS = 10
-TEST_POLL_INTERVAL = 10
-
-
-class TestEmrTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- EmrCreateJobFlowTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_POLL_INTERVAL,
- waiter_max_attempts=TEST_MAX_ATTEMPTS,
- ),
- EmrTerminateJobFlowTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_POLL_INTERVAL,
- waiter_max_attempts=TEST_MAX_ATTEMPTS,
- ),
- EmrContainerTrigger(
- virtual_cluster_id="my_cluster_id",
- job_id=TEST_JOB_FLOW_ID,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_POLL_INTERVAL,
- waiter_max_attempts=TEST_MAX_ATTEMPTS,
- ),
- EmrStepSensorTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- step_id="my_step",
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_POLL_INTERVAL,
- waiter_max_attempts=TEST_MAX_ATTEMPTS,
- ),
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
diff --git a/tests/providers/amazon/aws/triggers/test_glue.py
b/tests/providers/amazon/aws/triggers/test_glue.py
index 9e6c6652b2..428974c980 100644
--- a/tests/providers/amazon/aws/triggers/test_glue.py
+++ b/tests/providers/amazon/aws/triggers/test_glue.py
@@ -25,7 +25,6 @@ from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
from airflow.providers.amazon.aws.triggers.glue import
GlueCatalogPartitionTrigger, GlueJobCompleteTrigger
-from airflow.providers.amazon.aws.triggers.glue_crawler import
GlueCrawlerCompleteTrigger
class TestGlueJobTrigger:
@@ -73,44 +72,7 @@ class TestGlueJobTrigger:
assert get_state_mock.call_count == 3
-class TestGlueCrawlerTrigger:
- def test_serialize_recreate(self):
- trigger = GlueCrawlerCompleteTrigger(
- crawler_name="my_crawler", waiter_delay=2, aws_conn_id="my_conn_id"
- )
-
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
-
-
class TestGlueCatalogPartitionSensorTrigger:
- def test_serialize_recreate(self):
- trigger = GlueCatalogPartitionTrigger(
- database_name="my_database",
- table_name="my_table",
- expression="my_expression",
- aws_conn_id="my_conn_id",
- )
-
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
-
@pytest.mark.asyncio
@mock.patch.object(GlueCatalogHook, "async_get_partitions")
async def test_poke(self, mock_async_get_partitions):
diff --git a/tests/providers/amazon/aws/triggers/test_lambda_function.py
b/tests/providers/amazon/aws/triggers/test_lambda_function.py
deleted file mode 100644
index c06a99d42e..0000000000
--- a/tests/providers/amazon/aws/triggers/test_lambda_function.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.lambda_function import
LambdaCreateFunctionCompleteTrigger
-
-TEST_FUNCTION_NAME = "test-function-name"
-TEST_FUNCTION_ARN = "test-function-arn"
-TEST_WAITER_DELAY = 10
-TEST_WAITER_MAX_ATTEMPTS = 10
-TEST_AWS_CONN_ID = "test-conn-id"
-TEST_REGION_NAME = "test-region-name"
-
-
-class TestLambdaFunctionTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- LambdaCreateFunctionCompleteTrigger(
- function_name=TEST_FUNCTION_NAME,
- function_arn=TEST_FUNCTION_ARN,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- )
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
diff --git a/tests/providers/amazon/aws/triggers/test_rds.py
b/tests/providers/amazon/aws/triggers/test_rds.py
deleted file mode 100644
index 57db41a5e0..0000000000
--- a/tests/providers/amazon/aws/triggers/test_rds.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.rds import (
- RdsDbAvailableTrigger,
- RdsDbDeletedTrigger,
- RdsDbStoppedTrigger,
-)
-from airflow.providers.amazon.aws.utils.rds import RdsDbType
-
-TEST_DB_INSTANCE_IDENTIFIER = "test-db-instance-identifier"
-TEST_WAITER_DELAY = 10
-TEST_WAITER_MAX_ATTEMPTS = 10
-TEST_AWS_CONN_ID = "test-aws-id"
-TEST_REGION = "sa-east-1"
-TEST_RESPONSE = {
- "DBInstance": {
- "DBInstanceIdentifier": "test-db-instance-identifier",
- "DBInstanceStatus": "test-db-instance-status",
- }
-}
-
-
-class TestRdsTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- RdsDbAvailableTrigger(
- db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- response=TEST_RESPONSE,
- db_type=RdsDbType.INSTANCE,
- ),
- RdsDbDeletedTrigger(
- db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- response=TEST_RESPONSE,
- db_type=RdsDbType.INSTANCE,
- ),
- RdsDbStoppedTrigger(
- db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- response=TEST_RESPONSE,
- db_type=RdsDbType.INSTANCE,
- ),
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
deleted file mode 100644
index af5e4a9da1..0000000000
--- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.redshift_cluster import (
- RedshiftCreateClusterSnapshotTrigger,
- RedshiftCreateClusterTrigger,
- RedshiftDeleteClusterTrigger,
- RedshiftPauseClusterTrigger,
- RedshiftResumeClusterTrigger,
-)
-
-TEST_CLUSTER_IDENTIFIER = "test-cluster"
-TEST_POLL_INTERVAL = 10
-TEST_MAX_ATTEMPT = 10
-TEST_AWS_CONN_ID = "test-aws-id"
-
-
-class TestRedshiftClusterTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- RedshiftCreateClusterTrigger(
- cluster_identifier=TEST_CLUSTER_IDENTIFIER,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempt=TEST_MAX_ATTEMPT,
- aws_conn_id=TEST_AWS_CONN_ID,
- ),
- RedshiftPauseClusterTrigger(
- cluster_identifier=TEST_CLUSTER_IDENTIFIER,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPT,
- aws_conn_id=TEST_AWS_CONN_ID,
- ),
- RedshiftCreateClusterSnapshotTrigger(
- cluster_identifier=TEST_CLUSTER_IDENTIFIER,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPT,
- aws_conn_id=TEST_AWS_CONN_ID,
- ),
- RedshiftResumeClusterTrigger(
- cluster_identifier=TEST_CLUSTER_IDENTIFIER,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPT,
- aws_conn_id=TEST_AWS_CONN_ID,
- ),
- RedshiftDeleteClusterTrigger(
- cluster_identifier=TEST_CLUSTER_IDENTIFIER,
- poll_interval=TEST_POLL_INTERVAL,
- max_attempts=TEST_MAX_ATTEMPT,
- aws_conn_id=TEST_AWS_CONN_ID,
- ),
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
diff --git a/tests/providers/amazon/aws/triggers/test_serialization.py
b/tests/providers/amazon/aws/triggers/test_serialization.py
new file mode 100644
index 0000000000..e4a343e98f
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_serialization.py
@@ -0,0 +1,375 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.jobs.triggerer_job_runner import TriggerRunner
+from airflow.models import Trigger
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+from airflow.providers.amazon.aws.triggers.batch import BatchCreateComputeEnvironmentTrigger, BatchJobTrigger
+from airflow.providers.amazon.aws.triggers.ecs import ClusterActiveTrigger, ClusterInactiveTrigger
+from airflow.providers.amazon.aws.triggers.eks import (
+ EksCreateClusterTrigger,
+ EksCreateFargateProfileTrigger,
+ EksCreateNodegroupTrigger,
+ EksDeleteClusterTrigger,
+ EksDeleteFargateProfileTrigger,
+ EksDeleteNodegroupTrigger,
+)
+from airflow.providers.amazon.aws.triggers.emr import (
+ EmrAddStepsTrigger,
+ EmrContainerTrigger,
+ EmrCreateJobFlowTrigger,
+ EmrServerlessCancelJobsTrigger,
+ EmrServerlessCreateApplicationTrigger,
+ EmrServerlessDeleteApplicationTrigger,
+ EmrServerlessStartApplicationTrigger,
+ EmrServerlessStartJobTrigger,
+ EmrServerlessStopApplicationTrigger,
+ EmrStepSensorTrigger,
+ EmrTerminateJobFlowTrigger,
+)
+from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger
+from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
+from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
+from airflow.providers.amazon.aws.triggers.rds import (
+ RdsDbAvailableTrigger,
+ RdsDbDeletedTrigger,
+ RdsDbStoppedTrigger,
+)
+from airflow.providers.amazon.aws.triggers.redshift_cluster import (
+ RedshiftCreateClusterSnapshotTrigger,
+ RedshiftCreateClusterTrigger,
+ RedshiftDeleteClusterTrigger,
+ RedshiftPauseClusterTrigger,
+ RedshiftResumeClusterTrigger,
+)
+from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
+from airflow.providers.amazon.aws.utils.rds import RdsDbType
+from airflow.serialization.serialized_objects import BaseSerialization
+
+BATCH_JOB_ID = "job_id"
+
+TEST_CLUSTER_IDENTIFIER = "test-cluster"
+TEST_FARGATE_PROFILE_NAME = "test-fargate-profile"
+TEST_NODEGROUP_NAME = "test-nodegroup"
+
+TEST_JOB_FLOW_ID = "test-job-flow-id"
+VIRTUAL_CLUSTER_ID = "vzwemreks"
+JOB_ID = "job-1234"
+TARGET_STATE = ["TERMINATED"]
+STEP_ID = "s-1234"
+
+TEST_APPLICATION_ID = "test-application-id"
+TEST_JOB_ID = "test-job-id"
+
+TEST_FUNCTION_NAME = "test-function-name"
+
+TEST_DB_INSTANCE_IDENTIFIER = "test-db-instance-identifier"
+TEST_RESPONSE = {
+ "DBInstance": {
+ "DBInstanceIdentifier": "test-db-instance-identifier",
+ "DBInstanceStatus": "test-db-instance-status",
+ }
+}
+
+TEST_SQS_QUEUE = "test-sqs-queue"
+TEST_MAX_MESSAGES = 1
+TEST_NUM_BATCHES = 1
+TEST_WAIT_TIME_SECONDS = 1
+TEST_VISIBILITY_TIMEOUT = 1
+TEST_MESSAGE_FILTERING = "literal"
+TEST_MESSAGE_FILTERING_MATCH_VALUES = "test"
+TEST_MESSAGE_FILTERING_CONFIG = "test-message-filtering-config"
+TEST_DELETE_MESSAGE_ON_RECEPTION = False
+
+TEST_ARN = "test-aws-arn"
+
+WAITER_DELAY = 5
+MAX_ATTEMPTS = 5
+AWS_CONN_ID = "aws_batch_job_conn"
+AWS_REGION = "us-east-2"
+
+
+pytest.importorskip("aiobotocore")
+
+
+def gen_test_name(trigger):
+ """Gives to tests the name of the class being tested."""
+ return trigger.__class__.__name__
+
+
+class TestTriggersSerialization:
+ @pytest.mark.parametrize(
+ "trigger",
+ [
+ AthenaTrigger("query_id", 1, 5, "aws connection"),
+ BatchJobTrigger(
+ job_id=BATCH_JOB_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ ),
+ BatchCreateComputeEnvironmentTrigger(
+ compute_env_arn="my_arn",
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ ),
+ ClusterActiveTrigger(
+ cluster_arn="my_arn",
+ aws_conn_id="my_conn",
+ waiter_delay=1,
+ waiter_max_attempts=2,
+ region_name="my_region",
+ ),
+ ClusterInactiveTrigger(
+ cluster_arn="my_arn",
+ aws_conn_id="my_conn",
+ waiter_delay=1,
+ waiter_max_attempts=2,
+ region_name="my_region",
+ ),
+ EksCreateFargateProfileTrigger(
+ cluster_name=TEST_CLUSTER_IDENTIFIER,
+ fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ EksDeleteFargateProfileTrigger(
+ cluster_name=TEST_CLUSTER_IDENTIFIER,
+ fargate_profile_name=TEST_FARGATE_PROFILE_NAME,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ EksCreateNodegroupTrigger(
+ cluster_name=TEST_CLUSTER_IDENTIFIER,
+ nodegroup_name=TEST_NODEGROUP_NAME,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ region_name=AWS_REGION,
+ ),
+ EksDeleteNodegroupTrigger(
+ cluster_name=TEST_CLUSTER_IDENTIFIER,
+ nodegroup_name=TEST_NODEGROUP_NAME,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ region_name=AWS_REGION,
+ ),
+ EksCreateClusterTrigger(
+ cluster_name=TEST_CLUSTER_IDENTIFIER,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=WAITER_DELAY,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ ),
+ EksDeleteClusterTrigger(
+ cluster_name=TEST_CLUSTER_IDENTIFIER,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=WAITER_DELAY,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ force_delete_compute=True,
+ ),
+ EmrAddStepsTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ step_ids=["my_step1", "my_step2"],
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ EmrCreateJobFlowTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ aws_conn_id=AWS_CONN_ID,
+ poll_interval=WAITER_DELAY,
+ max_attempts=MAX_ATTEMPTS,
+ ),
+ EmrTerminateJobFlowTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ aws_conn_id=AWS_CONN_ID,
+ poll_interval=WAITER_DELAY,
+ max_attempts=MAX_ATTEMPTS,
+ ),
+ EmrContainerTrigger(
+ virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+ job_id=JOB_ID,
+ aws_conn_id=AWS_CONN_ID,
+ poll_interval=WAITER_DELAY,
+ ),
+ EmrStepSensorTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ step_id=STEP_ID,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ ),
+ EmrServerlessCreateApplicationTrigger(
+ application_id=TEST_APPLICATION_ID,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ EmrServerlessStartApplicationTrigger(
+ application_id=TEST_APPLICATION_ID,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ EmrServerlessStopApplicationTrigger(
+ application_id=TEST_APPLICATION_ID,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ EmrServerlessDeleteApplicationTrigger(
+ application_id=TEST_APPLICATION_ID,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ EmrServerlessCancelJobsTrigger(
+ application_id=TEST_APPLICATION_ID,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ EmrServerlessStartJobTrigger(
+ application_id=TEST_APPLICATION_ID,
+ job_id=TEST_JOB_ID,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ GlueCrawlerCompleteTrigger(crawler_name="my_crawler", waiter_delay=2, aws_conn_id="my_conn_id"),
+ GlueCatalogPartitionTrigger(
+ database_name="my_database",
+ table_name="my_table",
+ expression="my_expression",
+ aws_conn_id="my_conn_id",
+ ),
+ LambdaCreateFunctionCompleteTrigger(
+ function_name=TEST_FUNCTION_NAME,
+ function_arn=TEST_ARN,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ ),
+ RedshiftCreateClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=WAITER_DELAY,
+ max_attempt=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ ),
+ RedshiftPauseClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=WAITER_DELAY,
+ max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ ),
+ RedshiftCreateClusterSnapshotTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=WAITER_DELAY,
+ max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ ),
+ RedshiftResumeClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=WAITER_DELAY,
+ max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ ),
+ RedshiftDeleteClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=WAITER_DELAY,
+ max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ ),
+ RdsDbAvailableTrigger(
+ db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ response=TEST_RESPONSE,
+ db_type=RdsDbType.INSTANCE,
+ ),
+ RdsDbDeletedTrigger(
+ db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ response=TEST_RESPONSE,
+ db_type=RdsDbType.INSTANCE,
+ ),
+ RdsDbStoppedTrigger(
+ db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ response=TEST_RESPONSE,
+ db_type=RdsDbType.INSTANCE,
+ ),
+ SqsSensorTrigger(
+ sqs_queue=TEST_SQS_QUEUE,
+ aws_conn_id=AWS_CONN_ID,
+ max_messages=TEST_MAX_MESSAGES,
+ num_batches=TEST_NUM_BATCHES,
+ wait_time_seconds=TEST_WAIT_TIME_SECONDS,
+ visibility_timeout=TEST_VISIBILITY_TIMEOUT,
+ message_filtering=TEST_MESSAGE_FILTERING,
+ message_filtering_match_values=TEST_MESSAGE_FILTERING_MATCH_VALUES,
+ message_filtering_config=TEST_MESSAGE_FILTERING_CONFIG,
+ delete_message_on_reception=TEST_DELETE_MESSAGE_ON_RECEPTION,
+ waiter_delay=WAITER_DELAY,
+ ),
+ StepFunctionsExecutionCompleteTrigger(
+ execution_arn=TEST_ARN,
+ aws_conn_id=AWS_CONN_ID,
+ waiter_delay=WAITER_DELAY,
+ waiter_max_attempts=MAX_ATTEMPTS,
+ region_name=AWS_REGION,
+ ),
+ ],
+ ids=gen_test_name,
+ )
+ def test_serialize_recreate(self, trigger):
+ # generate the DB object from the trigger
+ trigger_db: Trigger = Trigger.from_object(trigger)
+
+ # serialize/deserialize using the same method that is used when inserting in DB
+ json_params = BaseSerialization.serialize(trigger_db.kwargs)
+ retrieved_params = BaseSerialization.deserialize(json_params)
+
+ # recreate a new trigger object from the data we would have in DB
+ clazz = TriggerRunner().get_trigger_by_classpath(trigger_db.classpath)
+ # noinspection PyArgumentList
+ instance = clazz(**retrieved_params)
+
+ # recreate a DB column object from the new trigger so that we can easily compare attributes
+ trigger_db_2: Trigger = Trigger.from_object(instance)
+
+ assert trigger_db.classpath == trigger_db_2.classpath
+ assert trigger_db.kwargs == trigger_db_2.kwargs
diff --git a/tests/providers/amazon/aws/triggers/test_sqs.py b/tests/providers/amazon/aws/triggers/test_sqs.py
index 74ffd837a1..48bbf6bde6 100644
--- a/tests/providers/amazon/aws/triggers/test_sqs.py
+++ b/tests/providers/amazon/aws/triggers/test_sqs.py
@@ -50,24 +50,6 @@ trigger = SqsSensorTrigger(
class TestSqsTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- trigger,
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2
-
@pytest.mark.asyncio
async def test_poke(self):
sqs_trigger = trigger
diff --git a/tests/providers/amazon/aws/triggers/test_step_function.py b/tests/providers/amazon/aws/triggers/test_step_function.py
deleted file mode 100644
index d0c25e096f..0000000000
--- a/tests/providers/amazon/aws/triggers/test_step_function.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
-
-TEST_EXECUTION_ARN = "test-execution-arn"
-TEST_WAITER_DELAY = 10
-TEST_WAITER_MAX_ATTEMPTS = 10
-TEST_AWS_CONN_ID = "test-conn-id"
-TEST_REGION_NAME = "test-region-name"
-
-
-class TestStepFunctionsTriggers:
- @pytest.mark.parametrize(
- "trigger",
- [
- StepFunctionsExecutionCompleteTrigger(
- execution_arn=TEST_EXECUTION_ARN,
- aws_conn_id=TEST_AWS_CONN_ID,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- region_name=TEST_REGION_NAME,
- )
- ],
- )
- def test_serialize_recreate(self, trigger):
- class_path, args = trigger.serialize()
-
- class_name = class_path.split(".")[-1]
- clazz = globals()[class_name]
- instance = clazz(**args)
-
- class_path2, args2 = instance.serialize()
-
- assert class_path == class_path2
- assert args == args2