This is an automated email from the ASF dual-hosted git repository. ferruzzi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 7ed31d5fdf Amazon Bedrock - Model Customization Jobs (#38693) 7ed31d5fdf is described below commit 7ed31d5fdf510e00528522ea313a20b19e498522 Author: D. Ferruzzi <ferru...@amazon.com> AuthorDate: Mon Apr 8 13:22:16 2024 -0700 Amazon Bedrock - Model Customization Jobs (#38693) * Amazon Bedrock - Customize Model Operator/Sensor/Waiter/Trigger --- airflow/providers/amazon/aws/hooks/bedrock.py | 20 +++ airflow/providers/amazon/aws/operators/bedrock.py | 161 ++++++++++++++++++++- airflow/providers/amazon/aws/sensors/bedrock.py | 110 ++++++++++++++ airflow/providers/amazon/aws/triggers/bedrock.py | 61 ++++++++ airflow/providers/amazon/aws/waiters/bedrock.json | 42 ++++++ airflow/providers/amazon/provider.yaml | 6 + .../operators/bedrock.rst | 38 +++++ tests/providers/amazon/aws/hooks/test_bedrock.py | 36 ++++- .../providers/amazon/aws/operators/test_bedrock.py | 161 ++++++++++++++++++--- tests/providers/amazon/aws/sensors/test_bedrock.py | 95 ++++++++++++ .../providers/amazon/aws/triggers/test_bedrock.py | 53 +++++++ tests/providers/amazon/aws/waiters/test_bedrock.py | 70 +++++++++ .../system/providers/amazon/aws/example_bedrock.py | 106 +++++++++++++- 13 files changed, 929 insertions(+), 30 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py index 11bacd9414..96636eb952 100644 --- a/airflow/providers/amazon/aws/hooks/bedrock.py +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -19,6 +19,26 @@ from __future__ import annotations from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +class BedrockHook(AwsBaseHook): + """ + Interact with Amazon Bedrock. + + Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock") <Bedrock.Client>`. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + client_type = "bedrock" + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = self.client_type + super().__init__(*args, **kwargs) + + class BedrockRuntimeHook(AwsBaseHook): """ Interact with the Amazon Bedrock Runtime. diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index d8eaf9e5d3..ee34a9aef7 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -19,10 +19,17 @@ from __future__ import annotations import json from typing import TYPE_CHECKING, Any, Sequence -from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook +from botocore.exceptions import ClientError + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger +from airflow.providers.amazon.aws.utils import validate_execute_complete_event from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.utils.helpers import prune_dict +from airflow.utils.timezone import utcnow if TYPE_CHECKING: from airflow.utils.context import Context @@ -91,3 +98,155 @@ class BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]): self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data) self.log.info("Bedrock model response: %s", response_body) return response_body + + +class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]): + """ + Create a fine-tuning job to customize a base model. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockCustomizeModelOperator` + + :param job_name: A unique name for the fine-tuning job. + :param custom_model_name: A name for the custom model being created. + :param role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon Bedrock can assume + to perform tasks on your behalf. + :param base_model_id: Name of the base model. + :param training_data_uri: The S3 URI where the training data is stored. + :param output_data_uri: The S3 URI where the output data is stored. + :param hyperparameters: Parameters related to tuning the model. + :param ensure_unique_job_name: If set to true, operator will check whether a model customization + job already exists for the name in the config and append the current timestamp if there is a + name conflict. (Default: True) + :param customization_job_kwargs: Any optional parameters to pass to the API. + + :param wait_for_completion: Whether to wait for the model customization job to complete. (default: True) + :param waiter_delay: Time in seconds to wait between status checks. (default: 120) + :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 75) + :param deferrable: If True, the operator will wait asynchronously for the model customization job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. 
See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + aws_hook_class = BedrockHook + template_fields: Sequence[str] = aws_template_fields( + "job_name", + "custom_model_name", + "role_arn", + "base_model_id", + "hyperparameters", + "ensure_unique_job_name", + "customization_job_kwargs", + ) + + def __init__( + self, + job_name: str, + custom_model_name: str, + role_arn: str, + base_model_id: str, + training_data_uri: str, + output_data_uri: str, + hyperparameters: dict[str, str], + ensure_unique_job_name: bool = True, + customization_job_kwargs: dict[str, Any] | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 120, + waiter_max_attempts: int = 75, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + self.job_name = job_name + self.custom_model_name = custom_model_name + self.role_arn = role_arn + self.base_model_id = base_model_id + self.training_data_config = {"s3Uri": training_data_uri} + self.output_data_config = {"s3Uri": output_data_uri} + self.hyperparameters = hyperparameters + self.ensure_unique_job_name = ensure_unique_job_name + self.customization_job_kwargs = customization_job_kwargs or {} + + self.valid_action_if_job_exists: set[str] = {"timestamp", "fail"} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: + event = validate_execute_complete_event(event) + + if event["status"] != "success": + raise AirflowException(f"Error while running job: {event}") + + self.log.info("Bedrock model customization job `%s` complete.", 
self.job_name) + return self.hook.conn.get_model_customization_job(jobIdentifier=event["job_name"])["jobArn"] + + def execute(self, context: Context) -> dict: + response = {} + retry = True + while retry: + # If there is a name conflict and ensure_unique_job_name is True, append the current timestamp + # to the name and retry until there is no name conflict. + # - Break the loop when the API call returns success. + # - If the API returns an exception other than a name conflict, raise that exception. + # - If the API returns a name conflict and ensure_unique_job_name is false, raise that exception. + try: + # Ensure the loop is executed at least once, and not repeat unless explicitly set to do so. + retry = False + self.log.info("Creating Bedrock model customization job '%s'.", self.job_name) + + response = self.hook.conn.create_model_customization_job( + jobName=self.job_name, + customModelName=self.custom_model_name, + roleArn=self.role_arn, + baseModelIdentifier=self.base_model_id, + trainingDataConfig=self.training_data_config, + outputDataConfig=self.output_data_config, + hyperParameters=self.hyperparameters, + **self.customization_job_kwargs, + ) + except ClientError as error: + if error.response["Error"]["Message"] != "The provided job name is currently in use.": + raise error + if not self.ensure_unique_job_name: + raise error + retry = True + self.job_name = f"{self.job_name}-{int(utcnow().timestamp())}" + self.log.info("Changed job name to '%s' to avoid collision.", self.job_name) + + if response["ResponseMetadata"]["HTTPStatusCode"] != 201: + raise AirflowException(f"Bedrock model customization job creation failed: {response}") + + task_description = f"Bedrock model customization job {self.job_name} to complete." 
+ if self.deferrable: + self.log.info("Deferring for %s", task_description) + self.defer( + trigger=BedrockCustomizeModelCompletedTrigger( + job_name=self.job_name, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + elif self.wait_for_completion: + self.log.info("Waiting for %s", task_description) + self.hook.get_waiter("model_customization_job_complete").wait( + jobIdentifier=self.job_name, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return response["jobArn"] diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py new file mode 100644 index 0000000000..43a8846c73 --- /dev/null +++ b/airflow/providers/amazon/aws/sensors/bedrock.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.configuration import conf +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor +from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields + +if TYPE_CHECKING: + from airflow.utils.context import Context + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook + + +class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]): + """ + Poll the state of the model customization job until it reaches a terminal state; fails if the job fails. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:BedrockCustomizeModelCompletedSensor` + + + :param job_name: The name of the Bedrock model customization job. + + :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore + module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) + :param max_retries: Number of times before returning the current state. (default: 75) + :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. 
See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + INTERMEDIATE_STATES = ("InProgress",) + FAILURE_STATES = ("Failed", "Stopping", "Stopped") + SUCCESS_STATES = ("Completed",) + FAILURE_MESSAGE = "Bedrock model customization job sensor failed." + + aws_hook_class = BedrockHook + template_fields: Sequence[str] = aws_template_fields("job_name") + ui_color = "#66c3ff" + + def __init__( + self, + *, + job_name: str, + max_retries: int = 75, + poke_interval: int = 120, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.job_name = job_name + self.poke_interval = poke_interval + self.max_retries = max_retries + self.deferrable = deferrable + + def execute(self, context: Context) -> Any: + if self.deferrable: + self.defer( + trigger=BedrockCustomizeModelCompletedTrigger( + job_name=self.job_name, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + ), + method_name="poke", + ) + else: + super().execute(context=context) + + def poke(self, context: Context) -> bool: + state = self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"] + self.log.info("Job '%s' state: %s", self.job_name, state) + + if state in self.FAILURE_STATES: + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(self.FAILURE_MESSAGE) + raise AirflowException(self.FAILURE_MESSAGE) + + return state not in self.INTERMEDIATE_STATES diff --git a/airflow/providers/amazon/aws/triggers/bedrock.py b/airflow/providers/amazon/aws/triggers/bedrock.py new file mode 100644 index 0000000000..ae4805ed70 --- /dev/null +++ 
b/airflow/providers/amazon/aws/triggers/bedrock.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + + +class BedrockCustomizeModelCompletedTrigger(AwsBaseWaiterTrigger): + """ + Trigger when a Bedrock model customization job is complete. + + :param job_name: The name of the Bedrock model customization job. + :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120) + :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75) + :param aws_conn_id: The Airflow connection used for AWS credentials. 
+ """ + + def __init__( + self, + *, + job_name: str, + waiter_delay: int = 120, + waiter_max_attempts: int = 75, + aws_conn_id: str | None = None, + ) -> None: + super().__init__( + serialized_fields={"job_name": job_name}, + waiter_name="model_customization_job_complete", + waiter_args={"jobIdentifier": job_name}, + failure_message="Bedrock model customization failed.", + status_message="Status of Bedrock model customization job is", + status_queries=["status"], + return_key="job_name", + return_value=job_name, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return BedrockHook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/waiters/bedrock.json b/airflow/providers/amazon/aws/waiters/bedrock.json new file mode 100644 index 0000000000..c44b7c0589 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/bedrock.json @@ -0,0 +1,42 @@ +{ + "version": 2, + "waiters": { + "model_customization_job_complete": { + "delay": 120, + "maxAttempts": 75, + "operation": "GetModelCustomizationJob", + "acceptors": [ + { + "matcher": "path", + "argument": "status", + "expected": "InProgress", + "state": "retry" + }, + { + "matcher": "path", + "argument": "status", + "expected": "Completed", + "state": "success" + }, + { + "matcher": "path", + "argument": "status", + "expected": "Failed", + "state": "failure" + }, + { + "matcher": "path", + "argument": "status", + "expected": "Stopping", + "state": "failure" + }, + { + "matcher": "path", + "argument": "status", + "expected": "Stopped", + "state": "failure" + } + ] + } + } +} diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 4c4f7cf597..dc072b324e 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -454,6 +454,9 @@ sensors: - integration-name: AWS Batch python-modules: - airflow.providers.amazon.aws.sensors.batch + - 
integration-name: Amazon Bedrock + python-modules: + - airflow.providers.amazon.aws.sensors.bedrock - integration-name: Amazon CloudFormation python-modules: - airflow.providers.amazon.aws.sensors.cloud_formation @@ -650,6 +653,9 @@ triggers: - integration-name: AWS Batch python-modules: - airflow.providers.amazon.aws.triggers.batch + - integration-name: Amazon Bedrock + python-modules: + - airflow.providers.amazon.aws.triggers.bedrock - integration-name: Amazon EC2 python-modules: - airflow.providers.amazon.aws.triggers.ec2 diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst b/docs/apache-airflow-providers-amazon/operators/bedrock.rst index 3e84cbc445..411deba79f 100644 --- a/docs/apache-airflow-providers-amazon/operators/bedrock.rst +++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst @@ -65,6 +65,44 @@ To invoke an Amazon Titan model you would use: For details on the different formats, see `Inference parameters for foundation models <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`__ +.. _howto/operator:BedrockCustomizeModelOperator: + +Customize an existing Amazon Bedrock Model +========================================== + +To create a fine-tuning job to customize a base model, you can use +:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockCustomizeModelOperator`. + +Model-customization jobs are asynchronous and the completion time depends on the base model +and the training/validation data size. To monitor the state of the job, you can use the +"model_customization_job_complete" Waiter, the +:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockCustomizeModelCompletedSensor` Sensor, +or the :class:`~airflow.providers.amazon.aws.triggers.bedrock.BedrockCustomizeModelCompletedTrigger` Trigger. + + +.. 
exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_customize_model] + :end-before: [END howto_operator_customize_model] + + +Sensors +------- + +.. _howto/sensor:BedrockCustomizeModelCompletedSensor: + +Wait for an Amazon Bedrock customize model job +============================================== + +To wait on the state of an Amazon Bedrock customize model job until it reaches a terminal state you can use +:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockCustomizeModelCompletedSensor` + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_customize_model] + :end-before: [END howto_sensor_customize_model] Reference --------- diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py index 73612aacbc..16752477d5 100644 --- a/tests/providers/amazon/aws/hooks/test_bedrock.py +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -16,7 +16,41 @@ # under the License. 
from __future__ import annotations -from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook +from unittest import mock + +import pytest +from botocore.exceptions import ClientError + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook + +JOB_NAME = "testJobName" +EXPECTED_STATUS = "InProgress" + + +@pytest.fixture +def mock_conn(): + with mock.patch.object(BedrockHook, "conn") as _conn: + _conn.get_model_customization_job.return_value = {"jobName": JOB_NAME, "status": EXPECTED_STATUS} + yield _conn + + +class TestBedrockHook: + VALIDATION_EXCEPTION_ERROR = ClientError( + error_response={"Error": {"Code": "ValidationException", "Message": ""}}, + operation_name="GetModelCustomizationJob", + ) + + UNEXPECTED_EXCEPTION = ClientError( + error_response={"Error": {"Code": "ExpiredTokenException", "Message": ""}}, + operation_name="GetModelCustomizationJob", + ) + + def setup_method(self): + self.hook = BedrockHook() + + def test_conn_returns_a_boto3_connection(self): + assert self.hook.conn is not None + assert self.hook.conn.meta.service_model.service_name == "bedrock" class TestBedrockRuntimeHook: diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index f6274de48f..2371877b4d 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -18,42 +18,155 @@ from __future__ import annotations import json -from typing import Generator +from typing import TYPE_CHECKING, Generator from unittest import mock import pytest +from botocore.exceptions import ClientError from moto import mock_aws -from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook -from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator - -MODEL_ID = "meta.llama2-13b-chat-v1" -PROMPT = "A very important question." -GENERATED_RESPONSE = "An important answer." 
-MOCK_RESPONSE = json.dumps( - { - "generation": GENERATED_RESPONSE, - "prompt_token_count": len(PROMPT), - "generation_token_count": len(GENERATED_RESPONSE), - "stop_reason": "stop", - } +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook +from airflow.providers.amazon.aws.operators.bedrock import ( + BedrockCustomizeModelOperator, + BedrockInvokeModelOperator, ) - -@pytest.fixture -def runtime_hook() -> Generator[BedrockRuntimeHook, None, None]: - with mock_aws(): - yield BedrockRuntimeHook(aws_conn_id="aws_default") +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection class TestBedrockInvokeModelOperator: - @mock.patch.object(BedrockRuntimeHook, "conn") - def test_invoke_model_prompt_good_combinations(self, mock_conn): - mock_conn.invoke_model.return_value["body"].read.return_value = MOCK_RESPONSE + MODEL_ID = "meta.llama2-13b-chat-v1" + TEST_PROMPT = "A very important question." + GENERATED_RESPONSE = "An important answer." 
+ + @pytest.fixture + def mock_runtime_conn(self) -> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(BedrockRuntimeHook, "conn") as _conn: + _conn.invoke_model.return_value["body"].read.return_value = json.dumps( + { + "generation": self.GENERATED_RESPONSE, + "prompt_token_count": len(self.TEST_PROMPT), + "generation_token_count": len(self.GENERATED_RESPONSE), + "stop_reason": "stop", + } + ) + yield _conn + + @pytest.fixture + def runtime_hook(self) -> Generator[BedrockRuntimeHook, None, None]: + with mock_aws(): + yield BedrockRuntimeHook(aws_conn_id="aws_default") + + def test_invoke_model_prompt_good_combinations(self, mock_runtime_conn): operator = BedrockInvokeModelOperator( - task_id="test_task", model_id=MODEL_ID, input_data={"input_data": {"prompt": PROMPT}} + task_id="test_task", + model_id=self.MODEL_ID, + input_data={"input_data": {"prompt": self.TEST_PROMPT}}, ) response = operator.execute({}) - assert response["generation"] == GENERATED_RESPONSE + assert response["generation"] == self.GENERATED_RESPONSE + + +class TestBedrockCustomizeModelOperator: + CUSTOMIZE_JOB_ARN = "valid_arn" + CUSTOMIZE_JOB_NAME = "testModelJob" + + @pytest.fixture + def mock_conn(self) -> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(BedrockHook, "conn") as _conn: + _conn.create_model_customization_job.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 201}, + "jobArn": self.CUSTOMIZE_JOB_ARN, + } + _conn.get_model_customization_job.return_value = { + "jobName": self.CUSTOMIZE_JOB_NAME, + "status": "InProgress", + } + yield _conn + + @pytest.fixture + def bedrock_hook(self) -> Generator[BedrockHook, None, None]: + with mock_aws(): + hook = BedrockHook(aws_conn_id="aws_default") + yield hook + + def setup_method(self): + self.operator = BedrockCustomizeModelOperator( + task_id="test_task", + job_name=self.CUSTOMIZE_JOB_NAME, + custom_model_name="testModelName", + role_arn="valid_arn", + base_model_id="base_model_id", + 
hyperparameters={ + "epochCount": "1", + "batchSize": "1", + "learningRate": ".0005", + "learningRateWarmupSteps": "0", + }, + training_data_uri="s3://uri", + output_data_uri="s3://uri/output", + ) + self.operator.defer = mock.MagicMock() + + @pytest.mark.parametrize( + "wait_for_completion, deferrable", + [ + pytest.param(False, False, id="no_wait"), + pytest.param(True, False, id="wait"), + pytest.param(False, True, id="defer"), + ], + ) + @mock.patch.object(BedrockHook, "get_waiter") + def test_customize_model_wait_combinations( + self, _, wait_for_completion, deferrable, mock_conn, bedrock_hook + ): + self.operator.wait_for_completion = wait_for_completion + self.operator.deferrable = deferrable + + response = self.operator.execute({}) + + assert response == self.CUSTOMIZE_JOB_ARN + assert bedrock_hook.get_waiter.call_count == wait_for_completion + assert self.operator.defer.call_count == deferrable + + conflict_msg = "The provided job name is currently in use." + conflict_exception = ClientError( + error_response={"Error": {"Message": conflict_msg, "Code": "ValidationException"}}, + operation_name="UnitTest", + ) + success = {"ResponseMetadata": {"HTTPStatusCode": 201}, "jobArn": CUSTOMIZE_JOB_ARN} + + @pytest.mark.parametrize( + "side_effect, ensure_unique_name", + [ + pytest.param([conflict_exception, success], True, id="conflict_and_ensure_unique"), + pytest.param([conflict_exception, success], False, id="conflict_and_not_ensure_unique"), + pytest.param( + [conflict_exception, conflict_exception, success], + True, + id="multiple_conflict_and_ensure_unique", + ), + pytest.param( + [conflict_exception, conflict_exception, success], + False, + id="multiple_conflict_and_not_ensure_unique", + ), + pytest.param([success], True, id="no_conflict_and_ensure_unique"), + pytest.param([success], False, id="no_conflict_and_not_ensure_unique"), + ], + ) + @mock.patch.object(BedrockHook, "get_waiter") + def test_ensure_unique_job_name(self, _, side_effect, 
ensure_unique_name, mock_conn, bedrock_hook): + mock_conn.create_model_customization_job.side_effect = side_effect + expected_call_count = len(side_effect) if ensure_unique_name else 1 + self.operator.wait_for_completion = False + + response = self.operator.execute({}) + + assert response == self.CUSTOMIZE_JOB_ARN + mock_conn.create_model_customization_job.call_count == expected_call_count + bedrock_hook.get_waiter.assert_not_called() + self.operator.defer.assert_not_called() diff --git a/tests/providers/amazon/aws/sensors/test_bedrock.py b/tests/providers/amazon/aws/sensors/test_bedrock.py new file mode 100644 index 0000000000..dab0f94ad3 --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_bedrock.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor
+
+
+@pytest.fixture
+def mock_get_job_state():  # patch BedrockHook.get_customize_model_job_state and yield the mock
+    with mock.patch.object(BedrockHook, "get_customize_model_job_state") as mock_state:
+        yield mock_state
+
+
+class TestBedrockCustomizeModelCompletedSensor:  # unit tests for the sensor's hook wiring and poke()
+    JOB_NAME = "test_job_name"
+
+    def setup_method(self):  # build the shared kwargs and a sensor with aws_conn_id=None
+        self.default_op_kwargs = dict(
+            task_id="test_bedrock_customize_model_sensor",
+            job_name=self.JOB_NAME,
+            poke_interval=5,
+            max_retries=1,
+        )
+        self.sensor = BedrockCustomizeModelCompletedSensor(**self.default_op_kwargs, aws_conn_id=None)
+
+    def test_base_aws_op_attributes(self):  # defaults first, then explicit AWS settings reach the hook
+        op = BedrockCustomizeModelCompletedSensor(**self.default_op_kwargs)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+        op = BedrockCustomizeModelCompletedSensor(
+            **self.default_op_kwargs,
+            aws_conn_id="aws-test-custom-conn",
+            region_name="eu-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.aws_conn_id == "aws-test-custom-conn"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+    @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.SUCCESS_STATES))
+    @mock.patch.object(BedrockHook, "conn")  # replace the hook's client with a mock
+    def test_poke_success_states(self, mock_conn, state):
+        mock_conn.get_model_customization_job.return_value = {"status": state}
+        assert self.sensor.poke({}) is True  # a success status means the sensor is done
+
+    @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.INTERMEDIATE_STATES))
+    @mock.patch.object(BedrockHook, "conn")  # replace the hook's client with a mock
+    def test_poke_intermediate_states(self, mock_conn, state):
+        mock_conn.get_model_customization_job.return_value = {"status": state}
+        assert self.sensor.poke({}) is False  # an intermediate status means keep poking
+
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception",
+        [
+            pytest.param(False, AirflowException, id="not-soft-fail"),
+            pytest.param(True, AirflowSkipException, id="soft-fail"),  # soft_fail expects a skip instead
+        ],
+    )
+    @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.FAILURE_STATES))
+    @mock.patch.object(BedrockHook, "conn")  # replace the hook's client with a mock
+    def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
+        mock_conn.get_model_customization_job.return_value = {"status": state}
+        sensor = BedrockCustomizeModelCompletedSensor(
+            **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail
+        )
+
+        with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+            sensor.poke({})
diff --git a/tests/providers/amazon/aws/triggers/test_bedrock.py b/tests/providers/amazon/aws/triggers/test_bedrock.py
new file mode 100644
index 0000000000..0a54c56a77
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_bedrock.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger
+from airflow.triggers.base import TriggerEvent
+
+BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.bedrock."
+
+
+class TestBedrockCustomizeModelCompletedTrigger:  # unit tests for trigger serialization and run()
+    JOB_NAME = "test_job"
+
+    def test_serialization(self):
+        """Assert that arguments and classpath are correctly serialized."""
+        trigger = BedrockCustomizeModelCompletedTrigger(job_name=self.JOB_NAME)
+        classpath, kwargs = trigger.serialize()
+        assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockCustomizeModelCompletedTrigger"
+        assert kwargs.get("job_name") == self.JOB_NAME
+
+    @pytest.mark.asyncio
+    @mock.patch.object(BedrockHook, "get_waiter")
+    @mock.patch.object(BedrockHook, "async_conn")
+    async def test_run_success(self, mock_async_conn, mock_get_waiter):  # run() yields one success event
+        mock_async_conn.__aenter__.return_value = mock.MagicMock()
+        mock_get_waiter().wait = AsyncMock()  # awaitable stand-in for the waiter's wait()
+        trigger = BedrockCustomizeModelCompletedTrigger(job_name=self.JOB_NAME)
+
+        generator = trigger.run()
+        response = await generator.asend(None)  # advance the async generator to its first yield
+
+        assert response == TriggerEvent({"status": "success", "job_name": self.JOB_NAME})
+        assert mock_get_waiter().wait.call_count == 1
diff --git a/tests/providers/amazon/aws/waiters/test_bedrock.py b/tests/providers/amazon/aws/waiters/test_bedrock.py
new file mode 100644
index 0000000000..00521ee013
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_bedrock.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+import boto3
+import botocore
+import pytest
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor
+
+
+class TestBedrockCustomWaiters:  # checks the custom waiter name is listed by the hook
+    def test_service_waiters(self):
+        assert "model_customization_job_complete" in BedrockHook().list_waiters()
+
+
+class TestBedrockCustomWaitersBase:  # shared setup for the waiter tests below
+    @pytest.fixture(autouse=True)
+    def mock_conn(self, monkeypatch):  # autouse: BedrockHook.conn becomes this boto3 client
+        self.client = boto3.client("bedrock")
+        monkeypatch.setattr(BedrockHook, "conn", self.client)
+
+
+class TestModelCustomizationJobCompleteWaiter(TestBedrockCustomWaitersBase):
+    WAITER_NAME = "model_customization_job_complete"
+
+    @pytest.fixture
+    def mock_get_job(self):  # patch get_model_customization_job on the shared client
+        with mock.patch.object(self.client, "get_model_customization_job") as m:
+            yield m
+
+    @pytest.mark.parametrize("state", BedrockCustomizeModelCompletedSensor.SUCCESS_STATES)
+    def test_model_customization_job_complete(self, state, mock_get_job):  # wait() must not raise
+        mock_get_job.return_value = {"status": state}
+
+        BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_id")
+
+    @pytest.mark.parametrize("state", BedrockCustomizeModelCompletedSensor.FAILURE_STATES)
+    def test_model_customization_job_failed(self, state, mock_get_job):  # wait() must raise
+        mock_get_job.return_value = {"status": state}
+
+        with pytest.raises(botocore.exceptions.WaiterError):
+            BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_id")
+
+    def test_model_customization_job_wait(self, mock_get_job):
+        wait = {"status": "InProgress"}
+        success = {"status": "Completed"}
+        mock_get_job.side_effect = [wait, wait, success]  # two in-progress polls, then completion
+
+        BedrockHook().get_waiter(self.WAITER_NAME).wait(
+            jobIdentifier="job_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}
+        )
diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py
index e86e5a2e92..12e2461547 100644
--- a/tests/system/providers/amazon/aws/example_bedrock.py
+++ b/tests/system/providers/amazon/aws/example_bedrock.py
@@ -16,17 +16,61 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from datetime import datetime
 
+from botocore.exceptions import ClientError
+
+from airflow.decorators import task
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.operators.bedrock import (
+    BedrockCustomizeModelOperator,
+    BedrockInvokeModelOperator,
+)
+from airflow.providers.amazon.aws.operators.s3 import (
+    S3CreateBucketOperator,
+    S3CreateObjectOperator,
+    S3DeleteBucketOperator,
+)
+from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor
+from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
 
-sys_test_context_task = SystemTestContextBuilder().build()
+# Externally fetched variables:
+ROLE_ARN_KEY = "ROLE_ARN"
+sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
 
 DAG_ID = "example_bedrock"
+
+# Creating a custom model takes nearly two hours.
If SKIP_LONG_TASKS is True then set
+# the trigger rule to an improbable state. This way we can still have the code snippets
+# for docs, and we can manually run the full tests occasionally.
+SKIP_LONG_TASKS = True
+
+LLAMA_MODEL_ID = "meta.llama2-13b-chat-v1"
 PROMPT = "What color is an orange?"
+TITAN_MODEL_ID = "amazon.titan-text-express-v1"
+TRAIN_DATA = {"prompt": "what is AWS", "completion": "it's Amazon Web Services"}
+HYPERPARAMETERS = {
+    "epochCount": "1",
+    "batchSize": "1",
+    "learningRate": ".0005",
+    "learningRateWarmupSteps": "0",
+}
+
+
+@task
+def delete_custom_model(model_name: str):  # teardown task: delete the custom model by name
+    try:
+        BedrockHook().conn.delete_custom_model(modelIdentifier=model_name)
+    except ClientError as e:
+        if SKIP_LONG_TASKS and (e.response["Error"]["Code"] == "ValidationException"):
+            # There is no model to delete. Since we skipped making one, that's fine.
+            return
+        raise  # any other client error is a real teardown failure; re-raise it unchanged
+
 
 with DAG(
     dag_id=DAG_ID,
@@ -37,11 +81,28 @@ with DAG(
 ) as dag:
     test_context = sys_test_context_task()
     env_id = test_context["ENV_ID"]
+    bucket_name = f"{env_id}-bedrock"
+    input_data_s3_key = f"{env_id}/train.jsonl"
+    training_data_uri = f"s3://{bucket_name}/{input_data_s3_key}"
+    custom_model_name = f"CustomModel{env_id}"
+    custom_model_job_name = f"CustomizeModelJob{env_id}"
+
+    create_bucket = S3CreateBucketOperator(
+        task_id="create_bucket",
+        bucket_name=bucket_name,
+    )
+
+    upload_training_data = S3CreateObjectOperator(
+        task_id="upload_data",
+        s3_bucket=bucket_name,
+        s3_key=input_data_s3_key,  # relative key: a full s3:// URL must not be combined with s3_bucket
+        data=json.dumps(TRAIN_DATA),
+    )
 
     # [START howto_operator_invoke_llama_model]
     invoke_llama_model = BedrockInvokeModelOperator(
         task_id="invoke_llama",
-        model_id="meta.llama2-13b-chat-v1",
+        model_id=LLAMA_MODEL_ID,
         input_data={"prompt": PROMPT},
     )
     # [END howto_operator_invoke_llama_model]
@@ -49,18 +110,55 @@ with DAG(
     # [START howto_operator_invoke_titan_model]
     invoke_titan_model = BedrockInvokeModelOperator(
         task_id="invoke_titan",
-        model_id="amazon.titan-text-express-v1",
+        model_id=TITAN_MODEL_ID,
         input_data={"inputText": PROMPT},
     )
     # [END howto_operator_invoke_titan_model]
 
+    # [START howto_operator_customize_model]
+    customize_model = BedrockCustomizeModelOperator(
+        task_id="customize_model",
+        job_name=custom_model_job_name,
+        custom_model_name=custom_model_name,
+        role_arn=test_context[ROLE_ARN_KEY],
+        base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
+        hyperparameters=HYPERPARAMETERS,
+        training_data_uri=training_data_uri,
+        output_data_uri=f"s3://{bucket_name}/myOutputData",
+    )
+    # [END howto_operator_customize_model]
+
+    # [START howto_sensor_customize_model]
+    await_custom_model_job = BedrockCustomizeModelCompletedSensor(
+        task_id="await_custom_model_job",
+        job_name=custom_model_job_name,
+    )
+    # [END howto_sensor_customize_model]
+
+    if SKIP_LONG_TASKS:  # keep the snippets in the DAG but make the long-running tasks not trigger
+        customize_model.trigger_rule = TriggerRule.ALL_SKIPPED
+        await_custom_model_job.trigger_rule = TriggerRule.ALL_SKIPPED
+
+    delete_bucket = S3DeleteBucketOperator(
+        task_id="delete_bucket",
+        trigger_rule=TriggerRule.ALL_DONE,
+        bucket_name=bucket_name,
+        force_delete=True,
+    )
+
     chain(
         # TEST SETUP
         test_context,
+        create_bucket,
+        upload_training_data,
         # TEST BODY
         invoke_llama_model,
         invoke_titan_model,
+        customize_model,
+        await_custom_model_job,
         # TEST TEARDOWN
+        delete_custom_model(custom_model_name),
+        delete_bucket,
     )
 
     from tests.system.utils.watcher import watcher