This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e0a99c5f4 Use base aws classes in Amazon EventBridge Operators (#36765)
1e0a99c5f4 is described below

commit 1e0a99c5f482a5d243db8908200bdfe157fd0a06
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Mon Jan 15 03:16:10 2024 +0400

    Use base aws classes in Amazon EventBridge Operators (#36765)
---
 .../providers/amazon/aws/operators/eventbridge.py  | 137 +++++++++------------
 .../operators/eventbridge.rst                      |   5 +
 .../amazon/aws/operators/test_eventbridge.py       |  92 ++++++++++++--
 3 files changed, 142 insertions(+), 92 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/eventbridge.py b/airflow/providers/amazon/aws/operators/eventbridge.py
index 70fb8a05bb..5d02d275a0 100644
--- a/airflow/providers/amazon/aws/operators/eventbridge.py
+++ b/airflow/providers/amazon/aws/operators/eventbridge.py
@@ -16,19 +16,19 @@
 # under the License.
 from __future__ import annotations
 
-from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.eventbridge import EventBridgeHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.utils.helpers import prune_dict
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class EventBridgePutEventsOperator(BaseOperator):
+class EventBridgePutEventsOperator(AwsBaseOperator[EventBridgeHook]):
     """
     Put Events onto Amazon EventBridge.
 
@@ -38,32 +38,25 @@ class EventBridgePutEventsOperator(BaseOperator):
 
     :param entries: the list of events to be put onto EventBridge, each event 
is a dict (required)
     :param endpoint_id: the URL subdomain of the endpoint
-    :param aws_conn_id: the AWS connection to use
-    :param region_name: the region where events are to be sent
-
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("entries", "endpoint_id", "aws_conn_id", 
"region_name")
+    aws_hook_class = EventBridgeHook
+    template_fields: Sequence[str] = aws_template_fields("entries", 
"endpoint_id")
 
-    def __init__(
-        self,
-        *,
-        entries: list[dict],
-        endpoint_id: str | None = None,
-        aws_conn_id: str = "aws_default",
-        region_name: str | None = None,
-        **kwargs,
-    ):
+    def __init__(self, *, entries: list[dict], endpoint_id: str | None = None, 
**kwargs):
         super().__init__(**kwargs)
         self.entries = entries
         self.endpoint_id = endpoint_id
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-
-    @cached_property
-    def hook(self) -> EventBridgeHook:
-        """Create and return an EventBridgeHook."""
-        return EventBridgeHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
 
     def execute(self, context: Context):
         response = self.hook.conn.put_events(
@@ -90,7 +83,7 @@ class EventBridgePutEventsOperator(BaseOperator):
             return [e["EventId"] for e in response["Entries"]]
 
 
-class EventBridgePutRuleOperator(BaseOperator):
+class EventBridgePutRuleOperator(AwsBaseOperator[EventBridgeHook]):
     """
     Create or update a specified EventBridge rule.
 
@@ -106,12 +99,20 @@ class EventBridgePutRuleOperator(BaseOperator):
     :param schedule_expression: the scheduling expression (for example, a cron 
or rate expression)
     :param state: indicates whether rule is set to be "ENABLED" or "DISABLED"
     :param tags: list of key-value pairs to associate with the rule
-    :param region: the region where rule is to be created or updated
-
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = (
-        "aws_conn_id",
+    aws_hook_class = EventBridgeHook
+    template_fields: Sequence[str] = aws_template_fields(
         "name",
         "description",
         "event_bus_name",
@@ -120,7 +121,6 @@ class EventBridgePutRuleOperator(BaseOperator):
         "schedule_expression",
         "state",
         "tags",
-        "region_name",
     )
 
     def __init__(
@@ -134,8 +134,6 @@ class EventBridgePutRuleOperator(BaseOperator):
         schedule_expression: str | None = None,
         state: str | None = None,
         tags: list | None = None,
-        region_name: str | None = None,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -144,16 +142,9 @@ class EventBridgePutRuleOperator(BaseOperator):
         self.event_bus_name = event_bus_name
         self.event_pattern = event_pattern
         self.role_arn = role_arn
-        self.region_name = region_name
         self.schedule_expression = schedule_expression
         self.state = state
         self.tags = tags
-        self.aws_conn_id = aws_conn_id
-
-    @cached_property
-    def hook(self) -> EventBridgeHook:
-        """Create and return an EventBridgeHook."""
-        return EventBridgeHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
 
     def execute(self, context: Context):
         self.log.info('Sending rule "%s" to EventBridge.', self.name)
@@ -170,7 +161,7 @@ class EventBridgePutRuleOperator(BaseOperator):
         )
 
 
-class EventBridgeEnableRuleOperator(BaseOperator):
+class EventBridgeEnableRuleOperator(AwsBaseOperator[EventBridgeHook]):
     """
     Enable an EventBridge Rule.
 
@@ -180,32 +171,25 @@ class EventBridgeEnableRuleOperator(BaseOperator):
 
     :param name: the name of the rule to enable
     :param event_bus_name: the name or ARN of the event bus associated with 
the rule (default if omitted)
-    :param aws_conn_id: the AWS connection to use
-    :param region_name: the region of the rule to be enabled
-
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("name", "event_bus_name", "region_name", 
"aws_conn_id")
+    aws_hook_class = EventBridgeHook
+    template_fields: Sequence[str] = aws_template_fields("name", 
"event_bus_name")
 
-    def __init__(
-        self,
-        *,
-        name: str,
-        event_bus_name: str | None = None,
-        region_name: str | None = None,
-        aws_conn_id: str = "aws_default",
-        **kwargs,
-    ):
+    def __init__(self, *, name: str, event_bus_name: str | None = None, 
**kwargs):
         super().__init__(**kwargs)
         self.name = name
         self.event_bus_name = event_bus_name
-        self.region_name = region_name
-        self.aws_conn_id = aws_conn_id
-
-    @cached_property
-    def hook(self) -> EventBridgeHook:
-        """Create and return an EventBridgeHook."""
-        return EventBridgeHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
 
     def execute(self, context: Context):
         self.hook.conn.enable_rule(
@@ -220,7 +204,7 @@ class EventBridgeEnableRuleOperator(BaseOperator):
         self.log.info('Enabled rule "%s"', self.name)
 
 
-class EventBridgeDisableRuleOperator(BaseOperator):
+class EventBridgeDisableRuleOperator(AwsBaseOperator[EventBridgeHook]):
     """
     Disable an EventBridge Rule.
 
@@ -230,32 +214,25 @@ class EventBridgeDisableRuleOperator(BaseOperator):
 
     :param name: the name of the rule to disable
     :param event_bus_name: the name or ARN of the event bus associated with 
the rule (default if omitted)
-    :param aws_conn_id: the AWS connection to use
-    :param region_name: the region of the rule to be disabled
-
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("name", "event_bus_name", "region_name", 
"aws_conn_id")
+    aws_hook_class = EventBridgeHook
+    template_fields: Sequence[str] = aws_template_fields("name", 
"event_bus_name")
 
-    def __init__(
-        self,
-        *,
-        name: str,
-        event_bus_name: str | None = None,
-        region_name: str | None = None,
-        aws_conn_id: str = "aws_default",
-        **kwargs,
-    ):
+    def __init__(self, *, name: str, event_bus_name: str | None = None, 
**kwargs):
         super().__init__(**kwargs)
         self.name = name
         self.event_bus_name = event_bus_name
-        self.region_name = region_name
-        self.aws_conn_id = aws_conn_id
-
-    @cached_property
-    def hook(self) -> EventBridgeHook:
-        """Create and return an EventBridgeHook."""
-        return EventBridgeHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
 
     def execute(self, context: Context):
         self.hook.conn.disable_rule(
diff --git a/docs/apache-airflow-providers-amazon/operators/eventbridge.rst b/docs/apache-airflow-providers-amazon/operators/eventbridge.rst
index 302f3657ec..453e5af310 100644
--- a/docs/apache-airflow-providers-amazon/operators/eventbridge.rst
+++ b/docs/apache-airflow-providers-amazon/operators/eventbridge.rst
@@ -31,6 +31,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_eventbridge.py b/tests/providers/amazon/aws/operators/test_eventbridge.py
index 9527439777..4dcd068a81 100644
--- a/tests/providers/amazon/aws/operators/test_eventbridge.py
+++ b/tests/providers/amazon/aws/operators/test_eventbridge.py
@@ -41,12 +41,28 @@ RULE_NAME = "match_s3_events"
 
 class TestEventBridgePutEventsOperator:
     def test_init(self):
-        operator = EventBridgePutEventsOperator(
+        op = EventBridgePutEventsOperator(
             task_id="put_events_job",
             entries=ENTRIES,
+            aws_conn_id="fake-conn-id",
+            region_name="eu-central-1",
+            verify="/spam/egg.pem",
+            botocore_config={"read_timeout": 42},
         )
-
-        assert operator.entries == ENTRIES
+        assert op.entries == ENTRIES
+        assert op.hook.client_type == "events"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "eu-central-1"
+        assert op.hook._verify == "/spam/egg.pem"
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = EventBridgePutEventsOperator(task_id="put_events_job", 
entries=ENTRIES)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
     @mock.patch.object(EventBridgeHook, "conn")
     def test_execute(self, mock_conn: MagicMock):
@@ -83,11 +99,31 @@ class TestEventBridgePutEventsOperator:
 
 class TestEventBridgePutRuleOperator:
     def test_init(self):
-        operator = EventBridgePutRuleOperator(
+        op = EventBridgePutRuleOperator(
+            task_id="events_put_rule_job",
+            name=RULE_NAME,
+            event_pattern=EVENT_PATTERN,
+            aws_conn_id="fake-conn-id",
+            region_name="eu-west-1",
+            verify="/spam/egg.pem",
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.event_pattern == EVENT_PATTERN
+        assert op.hook.client_type == "events"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify == "/spam/egg.pem"
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = EventBridgePutRuleOperator(
             task_id="events_put_rule_job", name=RULE_NAME, 
event_pattern=EVENT_PATTERN
         )
-
-        assert operator.event_pattern == EVENT_PATTERN
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
     @mock.patch.object(EventBridgeHook, "conn")
     def test_execute(self, mock_conn: MagicMock):
@@ -117,12 +153,28 @@ class TestEventBridgePutRuleOperator:
 
 class TestEventBridgeEnableRuleOperator:
     def test_init(self):
-        operator = EventBridgeDisableRuleOperator(
+        op = EventBridgeEnableRuleOperator(
             task_id="enable_rule_task",
             name=RULE_NAME,
+            aws_conn_id="fake-conn-id",
+            region_name="us-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
         )
-
-        assert operator.name == RULE_NAME
+        assert op.name == RULE_NAME
+        assert op.hook.client_type == "events"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = EventBridgeEnableRuleOperator(task_id="enable_rule_task", 
name=RULE_NAME)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
     @mock.patch.object(EventBridgeHook, "conn")
     def test_enable_rule(self, mock_conn: MagicMock):
@@ -137,12 +189,28 @@ class TestEventBridgeEnableRuleOperator:
 
 class TestEventBridgeDisableRuleOperator:
     def test_init(self):
-        operator = EventBridgeDisableRuleOperator(
+        op = EventBridgeDisableRuleOperator(
             task_id="disable_rule_task",
             name=RULE_NAME,
+            aws_conn_id="fake-conn-id",
+            region_name="ca-west-1",
+            verify=True,
+            botocore_config={"read_timeout": 42},
         )
-
-        assert operator.name == RULE_NAME
+        assert op.name == RULE_NAME
+        assert op.hook.client_type == "events"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "ca-west-1"
+        assert op.hook._verify is True
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = EventBridgeDisableRuleOperator(task_id="disable_rule_task", 
name=RULE_NAME)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
     @mock.patch.object(EventBridgeHook, "conn")
     def test_disable_rule(self, mock_conn: MagicMock):

Reply via email to