This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2c4928da40 introduce a method to convert dictionaries to boto-style
key-value lists (#28816)
2c4928da40 is described below
commit 2c4928da40667cd4d52030b8b79419175948cb85
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Jan 24 15:45:16 2023 -0800
introduce a method to convert dictionaries to boto-style key-value lists
(#28816)
* accept either dict or list for tags
---
airflow/providers/amazon/aws/hooks/s3.py | 28 ++++++++++------
airflow/providers/amazon/aws/hooks/sagemaker.py | 5 ++-
airflow/providers/amazon/aws/operators/rds.py | 32 ++++++++++--------
airflow/providers/amazon/aws/operators/s3.py | 4 +--
.../providers/amazon/aws/operators/sagemaker.py | 4 +--
airflow/providers/amazon/aws/utils/tags.py | 38 ++++++++++++++++++++++
tests/providers/amazon/aws/hooks/test_s3.py | 9 +++++
7 files changed, 89 insertions(+), 31 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/s3.py
b/airflow/providers/amazon/aws/hooks/s3.py
index 89c9261cb6..f88274747d 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -43,6 +43,7 @@ from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.helpers import chunks
T = TypeVar("T", bound=Callable)
@@ -1063,36 +1064,43 @@ class S3Hook(AwsBaseHook):
@provide_bucket_name
def put_bucket_tagging(
self,
- tag_set: list[dict[str, str]] | None = None,
+ tag_set: dict[str, str] | list[dict[str, str]] | None = None,
key: str | None = None,
value: str | None = None,
bucket_name: str | None = None,
) -> None:
"""
- Overwrites the existing TagSet with provided tags. Must provide
either a TagSet or a key/value pair.
+ Overwrites the existing TagSet with provided tags.
+ Must provide a TagSet, a key/value pair, or both.
.. seealso::
- :external+boto3:py:meth:`S3.Client.put_bucket_tagging`
- :param tag_set: A List containing the key/value pairs for the tags.
+ :param tag_set: A dictionary containing the key/value pairs for the
tags,
+ or a list already formatted for the API
:param key: The Key for the new TagSet entry.
:param value: The Value for the new TagSet entry.
:param bucket_name: The name of the bucket.
+
:return: None
"""
- self.log.info("S3 Bucket Tag Info:\tKey: %s\tValue: %s\tSet: %s", key,
value, tag_set)
- if not tag_set:
- tag_set = []
+ formatted_tags = format_tags(tag_set)
+
if key and value:
- tag_set.append({"Key": key, "Value": value})
- elif not tag_set or (key or value):
- message = "put_bucket_tagging() requires either a predefined
TagSet or a key/value pair."
+ formatted_tags.append({"Key": key, "Value": value})
+ elif key or value:
+ message = (
+ "Key and Value must be specified as a pair. "
+ f"Only one of the two had a value (key: '{key}', value:
'{value}')"
+ )
self.log.error(message)
raise ValueError(message)
+ self.log.info("Tagging S3 Bucket %s with %s", bucket_name,
formatted_tags)
+
try:
s3_client = self.get_conn()
- s3_client.put_bucket_tagging(Bucket=bucket_name,
Tagging={"TagSet": tag_set})
+ s3_client.put_bucket_tagging(Bucket=bucket_name,
Tagging={"TagSet": formatted_tags})
except ClientError as e:
self.log.error(e)
raise e
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py
b/airflow/providers/amazon/aws/hooks/sagemaker.py
index c5aeb3d9ed..4c731f2051 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -35,6 +35,7 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils import timezone
@@ -1100,9 +1101,7 @@ class SageMakerHook(AwsBaseHook):
:return: the ARN of the pipeline execution launched.
"""
- if pipeline_params is None:
- pipeline_params = {}
- formatted_params = [{"Name": kvp[0], "Value": kvp[1]} for kvp in
pipeline_params.items()]
+ formatted_params = format_tags(pipeline_params, key_label="Name")
try:
res = self.conn.start_pipeline_execution(
diff --git a/airflow/providers/amazon/aws/operators/rds.py
b/airflow/providers/amazon/aws/operators/rds.py
index c10e969c8b..2f2cf58438 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -25,6 +25,7 @@ from mypy_boto3_rds.type_defs import TagTypeDef
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.utils.rds import RdsDbType
+from airflow.providers.amazon.aws.utils.tags import format_tags
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -64,7 +65,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
:param db_type: Type of the DB - either "instance" or "cluster"
:param db_identifier: The identifier of the instance or cluster that you
want to create the snapshot of
:param db_snapshot_identifier: The identifier for the DB snapshot
- :param tags: A list of tags in format `[{"Key": "something", "Value":
"something"},]
+ :param tags: A dictionary of tags or a list of tags in format `[{"Key":
"...", "Value": "..."},]`
`USER Tagging
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param wait_for_completion: If True, waits for creation of the DB
snapshot to complete. (default: True)
"""
@@ -77,7 +78,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
db_type: str,
db_identifier: str,
db_snapshot_identifier: str,
- tags: Sequence[TagTypeDef] | None = None,
+ tags: Sequence[TagTypeDef] | dict | None = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_conn_id",
**kwargs,
@@ -86,7 +87,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
self.db_type = RdsDbType(db_type)
self.db_identifier = db_identifier
self.db_snapshot_identifier = db_snapshot_identifier
- self.tags = tags or []
+ self.tags = tags
self.wait_for_completion = wait_for_completion
def execute(self, context: Context) -> str:
@@ -97,11 +98,12 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
self.db_snapshot_identifier,
)
+ formatted_tags = format_tags(self.tags)
if self.db_type.value == "instance":
create_instance_snap = self.hook.conn.create_db_snapshot(
DBInstanceIdentifier=self.db_identifier,
DBSnapshotIdentifier=self.db_snapshot_identifier,
- Tags=self.tags,
+ Tags=formatted_tags,
)
create_response = json.dumps(create_instance_snap, default=str)
if self.wait_for_completion:
@@ -110,7 +112,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
create_cluster_snap = self.hook.conn.create_db_cluster_snapshot(
DBClusterIdentifier=self.db_identifier,
DBClusterSnapshotIdentifier=self.db_snapshot_identifier,
- Tags=self.tags,
+ Tags=formatted_tags,
)
create_response = json.dumps(create_cluster_snap, default=str)
if self.wait_for_completion:
@@ -132,7 +134,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
:param source_db_snapshot_identifier: The identifier of the source snapshot
:param target_db_snapshot_identifier: The identifier of the target snapshot
:param kms_key_id: The AWS KMS key identifier for an encrypted DB snapshot
- :param tags: A list of tags in format `[{"Key": "something", "Value":
"something"},]
+ :param tags: A dictionary of tags or a list of tags in format `[{"Key":
"...", "Value": "..."},]`
`USER Tagging
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param copy_tags: Whether to copy all tags from the source snapshot to the
target snapshot (default False)
:param pre_signed_url: The URL that contains a Signature Version 4 signed
request
@@ -159,7 +161,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
source_db_snapshot_identifier: str,
target_db_snapshot_identifier: str,
kms_key_id: str = "",
- tags: Sequence[TagTypeDef] | None = None,
+ tags: Sequence[TagTypeDef] | dict | None = None,
copy_tags: bool = False,
pre_signed_url: str = "",
option_group_name: str = "",
@@ -175,7 +177,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
self.source_db_snapshot_identifier = source_db_snapshot_identifier
self.target_db_snapshot_identifier = target_db_snapshot_identifier
self.kms_key_id = kms_key_id
- self.tags = tags or []
+ self.tags = tags
self.copy_tags = copy_tags
self.pre_signed_url = pre_signed_url
self.option_group_name = option_group_name
@@ -190,12 +192,13 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
self.target_db_snapshot_identifier,
)
+ formatted_tags = format_tags(self.tags)
if self.db_type.value == "instance":
copy_instance_snap = self.hook.conn.copy_db_snapshot(
SourceDBSnapshotIdentifier=self.source_db_snapshot_identifier,
TargetDBSnapshotIdentifier=self.target_db_snapshot_identifier,
KmsKeyId=self.kms_key_id,
- Tags=self.tags,
+ Tags=formatted_tags,
CopyTags=self.copy_tags,
PreSignedUrl=self.pre_signed_url,
OptionGroupName=self.option_group_name,
@@ -212,7 +215,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier,
TargetDBClusterSnapshotIdentifier=self.target_db_snapshot_identifier,
KmsKeyId=self.kms_key_id,
- Tags=self.tags,
+ Tags=formatted_tags,
CopyTags=self.copy_tags,
PreSignedUrl=self.pre_signed_url,
SourceRegion=self.source_region,
@@ -403,7 +406,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
`USER Events
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Events.Messages.html>`__
:param source_ids: The list of identifiers of the event sources for which
events are returned
:param enabled: A value that indicates whether to activate the
subscription (default True)l
- :param tags: A list of tags in format `[{"Key": "something", "Value":
"something"},]
+ :param tags: A dictionary of tags or a list of tags in format `[{"Key":
"...", "Value": "..."},]`
`USER Tagging
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param wait_for_completion: If True, waits for creation of the
subscription to complete. (default: True)
"""
@@ -426,7 +429,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
event_categories: Sequence[str] | None = None,
source_ids: Sequence[str] | None = None,
enabled: bool = True,
- tags: Sequence[TagTypeDef] | None = None,
+ tags: Sequence[TagTypeDef] | dict | None = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
@@ -439,12 +442,13 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
self.event_categories = event_categories or []
self.source_ids = source_ids or []
self.enabled = enabled
- self.tags = tags or []
+ self.tags = tags
self.wait_for_completion = wait_for_completion
def execute(self, context: Context) -> str:
self.log.info("Creating event subscription '%s' to '%s'",
self.subscription_name, self.sns_topic_arn)
+ formatted_tags = format_tags(self.tags)
create_subscription = self.hook.conn.create_event_subscription(
SubscriptionName=self.subscription_name,
SnsTopicArn=self.sns_topic_arn,
@@ -452,7 +456,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
EventCategories=self.event_categories,
SourceIds=self.source_ids,
Enabled=self.enabled,
- Tags=self.tags,
+ Tags=formatted_tags,
)
if self.wait_for_completion:
diff --git a/airflow/providers/amazon/aws/operators/s3.py
b/airflow/providers/amazon/aws/operators/s3.py
index d748da67bb..d9ab1ab75c 100644
--- a/airflow/providers/amazon/aws/operators/s3.py
+++ b/airflow/providers/amazon/aws/operators/s3.py
@@ -163,7 +163,7 @@ class S3PutBucketTaggingOperator(BaseOperator):
If a key is provided, a value must be provided as well.
:param value: The value portion of the key/value pair for a tag to be
added.
If a value is provided, a key must be provided as well.
- :param tag_set: A List of key/value pairs.
+ :param tag_set: A dictionary containing the tags, or a List of key/value
pairs.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -179,7 +179,7 @@ class S3PutBucketTaggingOperator(BaseOperator):
bucket_name: str,
key: str | None = None,
value: str | None = None,
- tag_set: list[dict[str, str]] | None = None,
+ tag_set: dict | list[dict[str, str]] | None = None,
aws_conn_id: str | None = "aws_default",
**kwargs,
) -> None:
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index f0191f2f53..aa6130e3f8 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -28,6 +28,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import
AwsBaseHook
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
+from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.json import AirflowJsonEncoder
if TYPE_CHECKING:
@@ -1090,11 +1091,10 @@ class
SageMakerCreateExperimentOperator(SageMakerBaseOperator):
def execute(self, context: Context) -> str:
sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
- tags_set = [{"Key": kvp[0], "Value": kvp[1]} for kvp in
self.tags.items()]
params = {
"ExperimentName": self.name,
"Description": self.description,
- "Tags": tags_set,
+ "Tags": format_tags(self.tags),
}
ans = sagemaker_hook.conn.create_experiment(**trim_none_values(params))
arn = ans["ExperimentArn"]
diff --git a/airflow/providers/amazon/aws/utils/tags.py
b/airflow/providers/amazon/aws/utils/tags.py
new file mode 100644
index 0000000000..c8afb124b6
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/tags.py
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+
+def format_tags(source: Any, *, key_label: str = "Key", value_label: str =
"Value"):
+ """
+ If given a dictionary, formats it as an array of objects with a key and a
value field to be passed to boto
+ calls that expect this format.
+ Else, assumes that it's already in the right format and returns it as is.
We do not validate
+ the format here since it's done by boto anyway, and the error wouldn't be
clearer if thrown from here.
+
+ :param source: a dict from which keys and values are read
+ :param key_label: optional, the label to use for keys if not "Key"
+ :param value_label: optional, the label to use for values if not "Value"
+ """
+ if source is None:
+ return []
+ elif isinstance(source, dict):
+ return [{key_label: kvp[0], value_label: kvp[1]} for kvp in
source.items()]
+ else:
+ return source
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py
b/tests/providers/amazon/aws/hooks/test_s3.py
index 6ee49ffb26..6eaec58dbf 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -734,6 +734,15 @@ class TestAwsS3Hook:
assert hook.get_bucket_tagging(bucket_name="new_bucket") == tag_set
+ @mock_s3
+ def test_put_bucket_tagging_with_dict(self):
+ hook = S3Hook()
+ hook.create_bucket(bucket_name="new_bucket")
+ tag_set = {"Color": "Green"}
+ hook.put_bucket_tagging(bucket_name="new_bucket", tag_set=tag_set)
+
+ assert hook.get_bucket_tagging(bucket_name="new_bucket") == [{"Key":
"Color", "Value": "Green"}]
+
@mock_s3
def test_put_bucket_tagging_with_pair(self):
hook = S3Hook()