This is an automated email from the ASF dual-hosted git repository. onikolas pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 2c4928da40 introduce a method to convert dictionaries to boto-style key-value lists (#28816) 2c4928da40 is described below commit 2c4928da40667cd4d52030b8b79419175948cb85 Author: Raphaël Vandon <114772123+vandonr-...@users.noreply.github.com> AuthorDate: Tue Jan 24 15:45:16 2023 -0800 introduce a method to convert dictionaries to boto-style key-value lists (#28816) * accept either dict of list for tags --- airflow/providers/amazon/aws/hooks/s3.py | 28 ++++++++++------ airflow/providers/amazon/aws/hooks/sagemaker.py | 5 ++- airflow/providers/amazon/aws/operators/rds.py | 32 ++++++++++-------- airflow/providers/amazon/aws/operators/s3.py | 4 +-- .../providers/amazon/aws/operators/sagemaker.py | 4 +-- airflow/providers/amazon/aws/utils/tags.py | 38 ++++++++++++++++++++++ tests/providers/amazon/aws/hooks/test_s3.py | 9 +++++ 7 files changed, 89 insertions(+), 31 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 89c9261cb6..f88274747d 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -43,6 +43,7 @@ from botocore.exceptions import ClientError from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.utils.tags import format_tags from airflow.utils.helpers import chunks T = TypeVar("T", bound=Callable) @@ -1063,36 +1064,43 @@ class S3Hook(AwsBaseHook): @provide_bucket_name def put_bucket_tagging( self, - tag_set: list[dict[str, str]] | None = None, + tag_set: dict[str, str] | list[dict[str, str]] | None = None, key: str | None = None, value: str | None = None, bucket_name: str | None = None, ) -> None: """ - Overwrites the existing TagSet with provided tags. Must provide either a TagSet or a key/value pair. 
+ Overwrites the existing TagSet with provided tags. + Must provide a TagSet, a key/value pair, or both. .. seealso:: - :external+boto3:py:meth:`S3.Client.put_bucket_tagging` - :param tag_set: A List containing the key/value pairs for the tags. + :param tag_set: A dictionary containing the key/value pairs for the tags, + or a list already formatted for the API :param key: The Key for the new TagSet entry. :param value: The Value for the new TagSet entry. :param bucket_name: The name of the bucket. + :return: None """ - self.log.info("S3 Bucket Tag Info:\tKey: %s\tValue: %s\tSet: %s", key, value, tag_set) - if not tag_set: - tag_set = [] + formatted_tags = format_tags(tag_set) + if key and value: - tag_set.append({"Key": key, "Value": value}) - elif not tag_set or (key or value): - message = "put_bucket_tagging() requires either a predefined TagSet or a key/value pair." + formatted_tags.append({"Key": key, "Value": value}) + elif key or value: + message = ( + "Key and Value must be specified as a pair. 
" + f"Only one of the two had a value (key: '{key}', value: '{value}')" + ) self.log.error(message) raise ValueError(message) + self.log.info("Tagging S3 Bucket %s with %s", bucket_name, formatted_tags) + try: s3_client = self.get_conn() - s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": tag_set}) + s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": formatted_tags}) except ClientError as e: self.log.error(e) raise e diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index c5aeb3d9ed..4c731f2051 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -35,6 +35,7 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.utils.tags import format_tags from airflow.utils import timezone @@ -1100,9 +1101,7 @@ class SageMakerHook(AwsBaseHook): :return: the ARN of the pipeline execution launched. 
""" - if pipeline_params is None: - pipeline_params = {} - formatted_params = [{"Name": kvp[0], "Value": kvp[1]} for kvp in pipeline_params.items()] + formatted_params = format_tags(pipeline_params, key_label="Name") try: res = self.conn.start_pipeline_execution( diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py index c10e969c8b..2f2cf58438 100644 --- a/airflow/providers/amazon/aws/operators/rds.py +++ b/airflow/providers/amazon/aws/operators/rds.py @@ -25,6 +25,7 @@ from mypy_boto3_rds.type_defs import TagTypeDef from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.rds import RdsHook from airflow.providers.amazon.aws.utils.rds import RdsDbType +from airflow.providers.amazon.aws.utils.tags import format_tags if TYPE_CHECKING: from airflow.utils.context import Context @@ -64,7 +65,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator): :param db_type: Type of the DB - either "instance" or "cluster" :param db_identifier: The identifier of the instance or cluster that you want to create the snapshot of :param db_snapshot_identifier: The identifier for the DB snapshot - :param tags: A list of tags in format `[{"Key": "something", "Value": "something"},] + :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]` `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__ :param wait_for_completion: If True, waits for creation of the DB snapshot to complete. 
(default: True) """ @@ -77,7 +78,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator): db_type: str, db_identifier: str, db_snapshot_identifier: str, - tags: Sequence[TagTypeDef] | None = None, + tags: Sequence[TagTypeDef] | dict | None = None, wait_for_completion: bool = True, aws_conn_id: str = "aws_conn_id", **kwargs, @@ -86,7 +87,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator): self.db_type = RdsDbType(db_type) self.db_identifier = db_identifier self.db_snapshot_identifier = db_snapshot_identifier - self.tags = tags or [] + self.tags = tags self.wait_for_completion = wait_for_completion def execute(self, context: Context) -> str: @@ -97,11 +98,12 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator): self.db_snapshot_identifier, ) + formatted_tags = format_tags(self.tags) if self.db_type.value == "instance": create_instance_snap = self.hook.conn.create_db_snapshot( DBInstanceIdentifier=self.db_identifier, DBSnapshotIdentifier=self.db_snapshot_identifier, - Tags=self.tags, + Tags=formatted_tags, ) create_response = json.dumps(create_instance_snap, default=str) if self.wait_for_completion: @@ -110,7 +112,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator): create_cluster_snap = self.hook.conn.create_db_cluster_snapshot( DBClusterIdentifier=self.db_identifier, DBClusterSnapshotIdentifier=self.db_snapshot_identifier, - Tags=self.tags, + Tags=formatted_tags, ) create_response = json.dumps(create_cluster_snap, default=str) if self.wait_for_completion: @@ -132,7 +134,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator): :param source_db_snapshot_identifier: The identifier of the source snapshot :param target_db_snapshot_identifier: The identifier of the target snapshot :param kms_key_id: The AWS KMS key identifier for an encrypted DB snapshot - :param tags: A list of tags in format `[{"Key": "something", "Value": "something"},] + :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]` `USER Tagging 
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__ :param copy_tags: Whether to copy all tags from the source snapshot to the target snapshot (default False) :param pre_signed_url: The URL that contains a Signature Version 4 signed request @@ -159,7 +161,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator): source_db_snapshot_identifier: str, target_db_snapshot_identifier: str, kms_key_id: str = "", - tags: Sequence[TagTypeDef] | None = None, + tags: Sequence[TagTypeDef] | dict | None = None, copy_tags: bool = False, pre_signed_url: str = "", option_group_name: str = "", @@ -175,7 +177,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator): self.source_db_snapshot_identifier = source_db_snapshot_identifier self.target_db_snapshot_identifier = target_db_snapshot_identifier self.kms_key_id = kms_key_id - self.tags = tags or [] + self.tags = tags self.copy_tags = copy_tags self.pre_signed_url = pre_signed_url self.option_group_name = option_group_name @@ -190,12 +192,13 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator): self.target_db_snapshot_identifier, ) + formatted_tags = format_tags(self.tags) if self.db_type.value == "instance": copy_instance_snap = self.hook.conn.copy_db_snapshot( SourceDBSnapshotIdentifier=self.source_db_snapshot_identifier, TargetDBSnapshotIdentifier=self.target_db_snapshot_identifier, KmsKeyId=self.kms_key_id, - Tags=self.tags, + Tags=formatted_tags, CopyTags=self.copy_tags, PreSignedUrl=self.pre_signed_url, OptionGroupName=self.option_group_name, @@ -212,7 +215,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator): SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier, TargetDBClusterSnapshotIdentifier=self.target_db_snapshot_identifier, KmsKeyId=self.kms_key_id, - Tags=self.tags, + Tags=formatted_tags, CopyTags=self.copy_tags, PreSignedUrl=self.pre_signed_url, SourceRegion=self.source_region, @@ -403,7 +406,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator): `USER Events 
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Events.Messages.html>`__ :param source_ids: The list of identifiers of the event sources for which events are returned :param enabled: A value that indicates whether to activate the subscription (default True)l - :param tags: A list of tags in format `[{"Key": "something", "Value": "something"},] + :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]` `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__ :param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True) """ @@ -426,7 +429,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator): event_categories: Sequence[str] | None = None, source_ids: Sequence[str] | None = None, enabled: bool = True, - tags: Sequence[TagTypeDef] | None = None, + tags: Sequence[TagTypeDef] | dict | None = None, wait_for_completion: bool = True, aws_conn_id: str = "aws_default", **kwargs, @@ -439,12 +442,13 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator): self.event_categories = event_categories or [] self.source_ids = source_ids or [] self.enabled = enabled - self.tags = tags or [] + self.tags = tags self.wait_for_completion = wait_for_completion def execute(self, context: Context) -> str: self.log.info("Creating event subscription '%s' to '%s'", self.subscription_name, self.sns_topic_arn) + formatted_tags = format_tags(self.tags) create_subscription = self.hook.conn.create_event_subscription( SubscriptionName=self.subscription_name, SnsTopicArn=self.sns_topic_arn, @@ -452,7 +456,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator): EventCategories=self.event_categories, SourceIds=self.source_ids, Enabled=self.enabled, - Tags=self.tags, + Tags=formatted_tags, ) if self.wait_for_completion: diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py index d748da67bb..d9ab1ab75c 
100644 --- a/airflow/providers/amazon/aws/operators/s3.py +++ b/airflow/providers/amazon/aws/operators/s3.py @@ -163,7 +163,7 @@ class S3PutBucketTaggingOperator(BaseOperator): If a key is provided, a value must be provided as well. :param value: The value portion of the key/value pair for a tag to be added. If a value is provided, a key must be provided as well. - :param tag_set: A List of key/value pairs. + :param tag_set: A dictionary containing the tags, or a List of key/value pairs. :param aws_conn_id: The Airflow connection used for AWS credentials. If this is None or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or @@ -179,7 +179,7 @@ class S3PutBucketTaggingOperator(BaseOperator): bucket_name: str, key: str | None = None, value: str | None = None, - tag_set: list[dict[str, str]] | None = None, + tag_set: dict | list[dict[str, str]] | None = None, aws_conn_id: str | None = "aws_default", **kwargs, ) -> None: diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index f0191f2f53..aa6130e3f8 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -28,6 +28,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.providers.amazon.aws.utils import trim_none_values from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus +from airflow.providers.amazon.aws.utils.tags import format_tags from airflow.utils.json import AirflowJsonEncoder if TYPE_CHECKING: @@ -1090,11 +1091,10 @@ class SageMakerCreateExperimentOperator(SageMakerBaseOperator): def execute(self, context: Context) -> str: sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id) - tags_set = [{"Key": kvp[0], "Value": kvp[1]} for kvp in self.tags.items()] params = { "ExperimentName": 
self.name, "Description": self.description, - "Tags": tags_set, + "Tags": format_tags(self.tags), } ans = sagemaker_hook.conn.create_experiment(**trim_none_values(params)) arn = ans["ExperimentArn"] diff --git a/airflow/providers/amazon/aws/utils/tags.py b/airflow/providers/amazon/aws/utils/tags.py new file mode 100644 index 0000000000..c8afb124b6 --- /dev/null +++ b/airflow/providers/amazon/aws/utils/tags.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + + +def format_tags(source: Any, *, key_label: str = "Key", value_label: str = "Value"): + """ + If given a dictionary, formats it as an array of objects with a key and a value field to be passed to boto + calls that expect this format. + Else, assumes that it's already in the right format and returns it as is. We do not validate + the format here since it's done by boto anyway, and the error wouldn't be clearer if thrown from here. 
+ + :param source: a dict from which keys and values are read + :param key_label: optional, the label to use for keys if not "Key" + :param value_label: optional, the label to use for values if not "Value" + """ + if source is None: + return [] + elif isinstance(source, dict): + return [{key_label: kvp[0], value_label: kvp[1]} for kvp in source.items()] + else: + return source diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py index 6ee49ffb26..6eaec58dbf 100644 --- a/tests/providers/amazon/aws/hooks/test_s3.py +++ b/tests/providers/amazon/aws/hooks/test_s3.py @@ -734,6 +734,15 @@ class TestAwsS3Hook: assert hook.get_bucket_tagging(bucket_name="new_bucket") == tag_set + @mock_s3 + def test_put_bucket_tagging_with_dict(self): + hook = S3Hook() + hook.create_bucket(bucket_name="new_bucket") + tag_set = {"Color": "Green"} + hook.put_bucket_tagging(bucket_name="new_bucket", tag_set=tag_set) + + assert hook.get_bucket_tagging(bucket_name="new_bucket") == [{"Key": "Color", "Value": "Green"}] + @mock_s3 def test_put_bucket_tagging_with_pair(self): hook = S3Hook()