This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 1e0a99c5f4 Use base aws classes in Amazon EventBridge Operators (#36765) 1e0a99c5f4 is described below commit 1e0a99c5f482a5d243db8908200bdfe157fd0a06 Author: Andrey Anshin <andrey.ans...@taragol.is> AuthorDate: Mon Jan 15 03:16:10 2024 +0400 Use base aws classes in Amazon EventBridge Operators (#36765) --- .../providers/amazon/aws/operators/eventbridge.py | 137 +++++++++------------ .../operators/eventbridge.rst | 5 + .../amazon/aws/operators/test_eventbridge.py | 92 ++++++++++++-- 3 files changed, 142 insertions(+), 92 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/eventbridge.py b/airflow/providers/amazon/aws/operators/eventbridge.py index 70fb8a05bb..5d02d275a0 100644 --- a/airflow/providers/amazon/aws/operators/eventbridge.py +++ b/airflow/providers/amazon/aws/operators/eventbridge.py @@ -16,19 +16,19 @@ # under the License. from __future__ import annotations -from functools import cached_property from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.eventbridge import EventBridgeHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.utils.helpers import prune_dict if TYPE_CHECKING: from airflow.utils.context import Context -class EventBridgePutEventsOperator(BaseOperator): +class EventBridgePutEventsOperator(AwsBaseOperator[EventBridgeHook]): """ Put Events onto Amazon EventBridge. 
@@ -38,32 +38,25 @@ class EventBridgePutEventsOperator(BaseOperator): :param entries: the list of events to be put onto EventBridge, each event is a dict (required) :param endpoint_id: the URL subdomain of the endpoint - :param aws_conn_id: the AWS connection to use - :param region_name: the region where events are to be sent - + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("entries", "endpoint_id", "aws_conn_id", "region_name") + aws_hook_class = EventBridgeHook + template_fields: Sequence[str] = aws_template_fields("entries", "endpoint_id") - def __init__( - self, - *, - entries: list[dict], - endpoint_id: str | None = None, - aws_conn_id: str = "aws_default", - region_name: str | None = None, - **kwargs, - ): + def __init__(self, *, entries: list[dict], endpoint_id: str | None = None, **kwargs): super().__init__(**kwargs) self.entries = entries self.endpoint_id = endpoint_id - self.aws_conn_id = aws_conn_id - self.region_name = region_name - - @cached_property - def hook(self) -> EventBridgeHook: - """Create and return an EventBridgeHook.""" - return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) def execute(self, context: Context): response = self.hook.conn.put_events( @@ -90,7 +83,7 @@ class 
EventBridgePutEventsOperator(BaseOperator): return [e["EventId"] for e in response["Entries"]] -class EventBridgePutRuleOperator(BaseOperator): +class EventBridgePutRuleOperator(AwsBaseOperator[EventBridgeHook]): """ Create or update a specified EventBridge rule. @@ -106,12 +99,20 @@ class EventBridgePutRuleOperator(BaseOperator): :param schedule_expression: the scheduling expression (for example, a cron or rate expression) :param state: indicates whether rule is set to be "ENABLED" or "DISABLED" :param tags: list of key-value pairs to associate with the rule - :param region: the region where rule is to be created or updated - + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ( - "aws_conn_id", + aws_hook_class = EventBridgeHook + template_fields: Sequence[str] = aws_template_fields( "name", "description", "event_bus_name", @@ -120,7 +121,6 @@ class EventBridgePutRuleOperator(BaseOperator): "schedule_expression", "state", "tags", - "region_name", ) def __init__( @@ -134,8 +134,6 @@ class EventBridgePutRuleOperator(BaseOperator): schedule_expression: str | None = None, state: str | None = None, tags: list | None = None, - region_name: str | None = None, - aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -144,16 +142,9 @@ class EventBridgePutRuleOperator(BaseOperator): self.event_bus_name = event_bus_name self.event_pattern = event_pattern self.role_arn = role_arn - self.region_name = region_name self.schedule_expression = schedule_expression self.state = state self.tags = tags - self.aws_conn_id = aws_conn_id - - @cached_property - def hook(self) -> EventBridgeHook: - """Create and return an EventBridgeHook.""" - return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) def execute(self, context: Context): self.log.info('Sending rule "%s" to EventBridge.', self.name) @@ -170,7 +161,7 @@ class EventBridgePutRuleOperator(BaseOperator): ) -class EventBridgeEnableRuleOperator(BaseOperator): +class EventBridgeEnableRuleOperator(AwsBaseOperator[EventBridgeHook]): """ Enable an EventBridge Rule. @@ -180,32 +171,25 @@ class EventBridgeEnableRuleOperator(BaseOperator): :param name: the name of the rule to enable :param event_bus_name: the name or ARN of the event bus associated with the rule (default if omitted) - :param aws_conn_id: the AWS connection to use - :param region_name: the region of the rule to be enabled - + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. 
If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("name", "event_bus_name", "region_name", "aws_conn_id") + aws_hook_class = EventBridgeHook + template_fields: Sequence[str] = aws_template_fields("name", "event_bus_name") - def __init__( - self, - *, - name: str, - event_bus_name: str | None = None, - region_name: str | None = None, - aws_conn_id: str = "aws_default", - **kwargs, - ): + def __init__(self, *, name: str, event_bus_name: str | None = None, **kwargs): super().__init__(**kwargs) self.name = name self.event_bus_name = event_bus_name - self.region_name = region_name - self.aws_conn_id = aws_conn_id - - @cached_property - def hook(self) -> EventBridgeHook: - """Create and return an EventBridgeHook.""" - return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) def execute(self, context: Context): self.hook.conn.enable_rule( @@ -220,7 +204,7 @@ class EventBridgeEnableRuleOperator(BaseOperator): self.log.info('Enabled rule "%s"', self.name) -class EventBridgeDisableRuleOperator(BaseOperator): +class EventBridgeDisableRuleOperator(AwsBaseOperator[EventBridgeHook]): """ Disable an EventBridge Rule. 
@@ -230,32 +214,25 @@ class EventBridgeDisableRuleOperator(BaseOperator): :param name: the name of the rule to disable :param event_bus_name: the name or ARN of the event bus associated with the rule (default if omitted) - :param aws_conn_id: the AWS connection to use - :param region_name: the region of the rule to be disabled - + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("name", "event_bus_name", "region_name", "aws_conn_id") + aws_hook_class = EventBridgeHook + template_fields: Sequence[str] = aws_template_fields("name", "event_bus_name") - def __init__( - self, - *, - name: str, - event_bus_name: str | None = None, - region_name: str | None = None, - aws_conn_id: str = "aws_default", - **kwargs, - ): + def __init__(self, *, name: str, event_bus_name: str | None = None, **kwargs): super().__init__(**kwargs) self.name = name self.event_bus_name = event_bus_name - self.region_name = region_name - self.aws_conn_id = aws_conn_id - - @cached_property - def hook(self) -> EventBridgeHook: - """Create and return an EventBridgeHook.""" - return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) def execute(self, context: Context): self.hook.conn.disable_rule( diff --git 
a/docs/apache-airflow-providers-amazon/operators/eventbridge.rst b/docs/apache-airflow-providers-amazon/operators/eventbridge.rst index 302f3657ec..453e5af310 100644 --- a/docs/apache-airflow-providers-amazon/operators/eventbridge.rst +++ b/docs/apache-airflow-providers-amazon/operators/eventbridge.rst @@ -31,6 +31,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/tests/providers/amazon/aws/operators/test_eventbridge.py b/tests/providers/amazon/aws/operators/test_eventbridge.py index 9527439777..4dcd068a81 100644 --- a/tests/providers/amazon/aws/operators/test_eventbridge.py +++ b/tests/providers/amazon/aws/operators/test_eventbridge.py @@ -41,12 +41,28 @@ RULE_NAME = "match_s3_events" class TestEventBridgePutEventsOperator: def test_init(self): - operator = EventBridgePutEventsOperator( + op = EventBridgePutEventsOperator( task_id="put_events_job", entries=ENTRIES, + aws_conn_id="fake-conn-id", + region_name="eu-central-1", + verify="/spam/egg.pem", + botocore_config={"read_timeout": 42}, ) - - assert operator.entries == ENTRIES + assert op.entries == ENTRIES + assert op.hook.client_type == "events" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "eu-central-1" + assert op.hook._verify == "/spam/egg.pem" + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + op = EventBridgePutEventsOperator(task_id="put_events_job", entries=ENTRIES) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None @mock.patch.object(EventBridgeHook, "conn") def test_execute(self, mock_conn: MagicMock): @@ -83,11 +99,31 @@ class TestEventBridgePutEventsOperator: class TestEventBridgePutRuleOperator: def test_init(self): - operator = 
EventBridgePutRuleOperator( + op = EventBridgePutRuleOperator( + task_id="events_put_rule_job", + name=RULE_NAME, + event_pattern=EVENT_PATTERN, + aws_conn_id="fake-conn-id", + region_name="eu-west-1", + verify="/spam/egg.pem", + botocore_config={"read_timeout": 42}, + ) + assert op.event_pattern == EVENT_PATTERN + assert op.hook.client_type == "events" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify == "/spam/egg.pem" + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + op = EventBridgePutRuleOperator( task_id="events_put_rule_job", name=RULE_NAME, event_pattern=EVENT_PATTERN ) - - assert operator.event_pattern == EVENT_PATTERN + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None @mock.patch.object(EventBridgeHook, "conn") def test_execute(self, mock_conn: MagicMock): @@ -117,12 +153,28 @@ class TestEventBridgePutRuleOperator: class TestEventBridgeEnableRuleOperator: def test_init(self): - operator = EventBridgeDisableRuleOperator( + op = EventBridgeEnableRuleOperator( task_id="enable_rule_task", name=RULE_NAME, + aws_conn_id="fake-conn-id", + region_name="us-west-1", + verify=False, + botocore_config={"read_timeout": 42}, ) - - assert operator.name == RULE_NAME + assert op.name == RULE_NAME + assert op.hook.client_type == "events" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "us-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + op = EventBridgeEnableRuleOperator(task_id="enable_rule_task", name=RULE_NAME) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None 
@mock.patch.object(EventBridgeHook, "conn") def test_enable_rule(self, mock_conn: MagicMock): @@ -137,12 +189,28 @@ class TestEventBridgeEnableRuleOperator: class TestEventBridgeDisableRuleOperator: def test_init(self): - operator = EventBridgeDisableRuleOperator( + op = EventBridgeDisableRuleOperator( task_id="disable_rule_task", name=RULE_NAME, + aws_conn_id="fake-conn-id", + region_name="ca-west-1", + verify=True, + botocore_config={"read_timeout": 42}, ) - - assert operator.name == RULE_NAME + assert op.name == RULE_NAME + assert op.hook.client_type == "events" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "ca-west-1" + assert op.hook._verify is True + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + op = EventBridgeDisableRuleOperator(task_id="disable_rule_task", name=RULE_NAME) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None @mock.patch.object(EventBridgeHook, "conn") def test_disable_rule(self, mock_conn: MagicMock):