This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7ed31d5fdf Amazon Bedrock - Model Customization Jobs (#38693)
7ed31d5fdf is described below
commit 7ed31d5fdf510e00528522ea313a20b19e498522
Author: D. Ferruzzi <[email protected]>
AuthorDate: Mon Apr 8 13:22:16 2024 -0700
Amazon Bedrock - Model Customization Jobs (#38693)
* Amazon Bedrock - Customize Model Operator/Sensor/Waiter/Trigger
---
airflow/providers/amazon/aws/hooks/bedrock.py | 20 +++
airflow/providers/amazon/aws/operators/bedrock.py | 161 ++++++++++++++++++++-
airflow/providers/amazon/aws/sensors/bedrock.py | 110 ++++++++++++++
airflow/providers/amazon/aws/triggers/bedrock.py | 61 ++++++++
airflow/providers/amazon/aws/waiters/bedrock.json | 42 ++++++
airflow/providers/amazon/provider.yaml | 6 +
.../operators/bedrock.rst | 38 +++++
tests/providers/amazon/aws/hooks/test_bedrock.py | 36 ++++-
.../providers/amazon/aws/operators/test_bedrock.py | 161 ++++++++++++++++++---
tests/providers/amazon/aws/sensors/test_bedrock.py | 95 ++++++++++++
.../providers/amazon/aws/triggers/test_bedrock.py | 53 +++++++
tests/providers/amazon/aws/waiters/test_bedrock.py | 70 +++++++++
.../system/providers/amazon/aws/example_bedrock.py | 106 +++++++++++++-
13 files changed, 929 insertions(+), 30 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py
b/airflow/providers/amazon/aws/hooks/bedrock.py
index 11bacd9414..96636eb952 100644
--- a/airflow/providers/amazon/aws/hooks/bedrock.py
+++ b/airflow/providers/amazon/aws/hooks/bedrock.py
@@ -19,6 +19,26 @@ from __future__ import annotations
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+class BedrockHook(AwsBaseHook):
+ """
+ Interact with Amazon Bedrock.
+
+ Provide thin wrapper around
:external+boto3:py:class:`boto3.client("bedrock") <Bedrock.Client>`.
+
+ Additional arguments (such as ``aws_conn_id``) may be specified and
+ are passed down to the underlying AwsBaseHook.
+
+ .. seealso::
+ - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+ """
+
+ client_type = "bedrock"
+
+ def __init__(self, *args, **kwargs) -> None:
+ kwargs["client_type"] = self.client_type
+ super().__init__(*args, **kwargs)
+
+
class BedrockRuntimeHook(AwsBaseHook):
"""
Interact with the Amazon Bedrock Runtime.
diff --git a/airflow/providers/amazon/aws/operators/bedrock.py
b/airflow/providers/amazon/aws/operators/bedrock.py
index d8eaf9e5d3..ee34a9aef7 100644
--- a/airflow/providers/amazon/aws/operators/bedrock.py
+++ b/airflow/providers/amazon/aws/operators/bedrock.py
@@ -19,10 +19,17 @@ from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Sequence
-from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
+from botocore.exceptions import ClientError
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook,
BedrockRuntimeHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.bedrock import
BedrockCustomizeModelCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.utils.helpers import prune_dict
+from airflow.utils.timezone import utcnow
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -91,3 +98,155 @@ class
BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]):
self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data)
self.log.info("Bedrock model response: %s", response_body)
return response_body
+
+
+class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]):
+ """
+ Create a fine-tuning job to customize a base model.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:BedrockCustomizeModelOperator`
+
+ :param job_name: A unique name for the fine-tuning job.
+ :param custom_model_name: A name for the custom model being created.
+ :param role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon
Bedrock can assume
+ to perform tasks on your behalf.
+ :param base_model_id: Name of the base model.
+ :param training_data_uri: The S3 URI where the training data is stored.
+ :param output_data_uri: The S3 URI where the output data is stored.
+ :param hyperparameters: Parameters related to tuning the model.
+ :param ensure_unique_job_name: If set to true, operator will check whether
a model customization
+ job already exists for the name in the config and append the current
timestamp if there is a
+ name conflict. (Default: True)
+ :param customization_job_kwargs: Any optional parameters to pass to the
API.
+
+    :param wait_for_completion: Whether to wait for the model customization job to complete. (default:
True)
+ :param waiter_delay: Time in seconds to wait between status checks.
(default: 120)
+ :param waiter_max_attempts: Maximum number of attempts to check for job
completion. (default: 75)
+    :param deferrable: If True, the operator will wait asynchronously for the
job to complete.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used.
If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default
boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore
client. See:
+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ aws_hook_class = BedrockHook
+ template_fields: Sequence[str] = aws_template_fields(
+ "job_name",
+ "custom_model_name",
+ "role_arn",
+ "base_model_id",
+ "hyperparameters",
+ "ensure_unique_job_name",
+ "customization_job_kwargs",
+ )
+
+ def __init__(
+ self,
+ job_name: str,
+ custom_model_name: str,
+ role_arn: str,
+ base_model_id: str,
+ training_data_uri: str,
+ output_data_uri: str,
+ hyperparameters: dict[str, str],
+ ensure_unique_job_name: bool = True,
+ customization_job_kwargs: dict[str, Any] | None = None,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 120,
+ waiter_max_attempts: int = 75,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
+
+ self.job_name = job_name
+ self.custom_model_name = custom_model_name
+ self.role_arn = role_arn
+ self.base_model_id = base_model_id
+ self.training_data_config = {"s3Uri": training_data_uri}
+ self.output_data_config = {"s3Uri": output_data_uri}
+ self.hyperparameters = hyperparameters
+ self.ensure_unique_job_name = ensure_unique_job_name
+ self.customization_job_kwargs = customization_job_kwargs or {}
+
+ self.valid_action_if_job_exists: set[str] = {"timestamp", "fail"}
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
+ event = validate_execute_complete_event(event)
+
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+
+ self.log.info("Bedrock model customization job `%s` complete.",
self.job_name)
+ return
self.hook.conn.get_model_customization_job(jobIdentifier=event["job_name"])["jobArn"]
+
+ def execute(self, context: Context) -> dict:
+ response = {}
+ retry = True
+ while retry:
+ # If there is a name conflict and ensure_unique_job_name is True,
append the current timestamp
+ # to the name and retry until there is no name conflict.
+ # - Break the loop when the API call returns success.
+ # - If the API returns an exception other than a name conflict,
raise that exception.
+ # - If the API returns a name conflict and ensure_unique_job_name
is false, raise that exception.
+ try:
+ # Ensure the loop is executed at least once, and not repeat
unless explicitly set to do so.
+ retry = False
+ self.log.info("Creating Bedrock model customization job
'%s'.", self.job_name)
+
+ response = self.hook.conn.create_model_customization_job(
+ jobName=self.job_name,
+ customModelName=self.custom_model_name,
+ roleArn=self.role_arn,
+ baseModelIdentifier=self.base_model_id,
+ trainingDataConfig=self.training_data_config,
+ outputDataConfig=self.output_data_config,
+ hyperParameters=self.hyperparameters,
+ **self.customization_job_kwargs,
+ )
+ except ClientError as error:
+ if error.response["Error"]["Message"] != "The provided job
name is currently in use.":
+ raise error
+ if not self.ensure_unique_job_name:
+ raise error
+ retry = True
+ self.job_name = f"{self.job_name}-{int(utcnow().timestamp())}"
+ self.log.info("Changed job name to '%s' to avoid collision.",
self.job_name)
+
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 201:
+ raise AirflowException(f"Bedrock model customization job creation
failed: {response}")
+
+ task_description = f"Bedrock model customization job {self.job_name}
to complete."
+ if self.deferrable:
+ self.log.info("Deferring for %s", task_description)
+ self.defer(
+ trigger=BedrockCustomizeModelCompletedTrigger(
+ job_name=self.job_name,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+ elif self.wait_for_completion:
+ self.log.info("Waiting for %s", task_description)
+ self.hook.get_waiter("model_customization_job_complete").wait(
+ jobIdentifier=self.job_name,
+ WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts":
self.waiter_max_attempts},
+ )
+
+ return response["jobArn"]
diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py
b/airflow/providers/amazon/aws/sensors/bedrock.py
new file mode 100644
index 0000000000..43a8846c73
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/bedrock.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.bedrock import
BedrockCustomizeModelCompletedTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+
+
+class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]):
+ """
+ Poll the state of the model customization job until it reaches a terminal
state; fails if the job fails.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the
guide:
+ :ref:`howto/sensor:BedrockCustomizeModelCompletedSensor`
+
+
+ :param job_name: The name of the Bedrock model customization job.
+
+ :param deferrable: If True, the sensor will operate in deferrable mode.
This mode requires aiobotocore
+ module to be installed.
+ (default: False, but can be overridden in config file by setting
default_deferrable to True)
+ :param max_retries: Number of times before returning the current state.
(default: 75)
+ :param poke_interval: Polling period in seconds to check for the status of
the job. (default: 120)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used.
If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default
boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore
client. See:
+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ INTERMEDIATE_STATES = ("InProgress",)
+ FAILURE_STATES = ("Failed", "Stopping", "Stopped")
+ SUCCESS_STATES = ("Completed",)
+ FAILURE_MESSAGE = "Bedrock model customization job sensor failed."
+
+ aws_hook_class = BedrockHook
+ template_fields: Sequence[str] = aws_template_fields("job_name")
+ ui_color = "#66c3ff"
+
+ def __init__(
+ self,
+ *,
+ job_name: str,
+ max_retries: int = 75,
+ poke_interval: int = 120,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.job_name = job_name
+ self.poke_interval = poke_interval
+ self.max_retries = max_retries
+ self.deferrable = deferrable
+
+ def execute(self, context: Context) -> Any:
+ if self.deferrable:
+ self.defer(
+ trigger=BedrockCustomizeModelCompletedTrigger(
+ job_name=self.job_name,
+ waiter_delay=int(self.poke_interval),
+ waiter_max_attempts=self.max_retries,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="poke",
+ )
+ else:
+ super().execute(context=context)
+
+ def poke(self, context: Context) -> bool:
+ state =
self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"]
+ self.log.info("Job '%s' state: %s", self.job_name, state)
+
+ if state in self.FAILURE_STATES:
+ # TODO: remove this if block when min_airflow_version is set to
higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(self.FAILURE_MESSAGE)
+ raise AirflowException(self.FAILURE_MESSAGE)
+
+ return state not in self.INTERMEDIATE_STATES
diff --git a/airflow/providers/amazon/aws/triggers/bedrock.py
b/airflow/providers/amazon/aws/triggers/bedrock.py
new file mode 100644
index 0000000000..ae4805ed70
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/bedrock.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class BedrockCustomizeModelCompletedTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger when a Bedrock model customization job is complete.
+
+ :param job_name: The name of the Bedrock model customization job.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts. (default: 120)
+ :param waiter_max_attempts: The maximum number of attempts to be made.
(default: 75)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ *,
+ job_name: str,
+ waiter_delay: int = 120,
+ waiter_max_attempts: int = 75,
+ aws_conn_id: str | None = None,
+ ) -> None:
+ super().__init__(
+ serialized_fields={"job_name": job_name},
+ waiter_name="model_customization_job_complete",
+ waiter_args={"jobIdentifier": job_name},
+ failure_message="Bedrock model customization failed.",
+ status_message="Status of Bedrock model customization job is",
+ status_queries=["status"],
+ return_key="job_name",
+ return_value=job_name,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return BedrockHook(aws_conn_id=self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/waiters/bedrock.json
b/airflow/providers/amazon/aws/waiters/bedrock.json
new file mode 100644
index 0000000000..c44b7c0589
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/bedrock.json
@@ -0,0 +1,42 @@
+{
+ "version": 2,
+ "waiters": {
+ "model_customization_job_complete": {
+ "delay": 120,
+ "maxAttempts": 75,
+ "operation": "GetModelCustomizationJob",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "InProgress",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Completed",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Failed",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Stopping",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "status",
+ "expected": "Stopped",
+ "state": "failure"
+ }
+ ]
+ }
+ }
+}
diff --git a/airflow/providers/amazon/provider.yaml
b/airflow/providers/amazon/provider.yaml
index 4c4f7cf597..dc072b324e 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -454,6 +454,9 @@ sensors:
- integration-name: AWS Batch
python-modules:
- airflow.providers.amazon.aws.sensors.batch
+ - integration-name: Amazon Bedrock
+ python-modules:
+ - airflow.providers.amazon.aws.sensors.bedrock
- integration-name: Amazon CloudFormation
python-modules:
- airflow.providers.amazon.aws.sensors.cloud_formation
@@ -650,6 +653,9 @@ triggers:
- integration-name: AWS Batch
python-modules:
- airflow.providers.amazon.aws.triggers.batch
+ - integration-name: Amazon Bedrock
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.bedrock
- integration-name: Amazon EC2
python-modules:
- airflow.providers.amazon.aws.triggers.ec2
diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst
b/docs/apache-airflow-providers-amazon/operators/bedrock.rst
index 3e84cbc445..411deba79f 100644
--- a/docs/apache-airflow-providers-amazon/operators/bedrock.rst
+++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst
@@ -65,6 +65,44 @@ To invoke an Amazon Titan model you would use:
For details on the different formats, see `Inference parameters for foundation
models
<https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`__
+.. _howto/operator:BedrockCustomizeModelOperator:
+
+Customize an existing Amazon Bedrock Model
+==========================================
+
+To create a fine-tuning job to customize a base model, you can use
+:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockCustomizeModelOperator`.
+
+Model-customization jobs are asynchronous and the completion time depends on
the base model
+and the training/validation data size. To monitor the state of the job, you
can use the
+"model_customization_job_complete" Waiter, the
+:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockCustomizeModelCompletedSensor`
Sensor,
+or the
:class:`~airflow.providers.amazon.aws.triggers.bedrock.BedrockCustomizeModelCompletedTrigger`
Trigger.
+
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_customize_model]
+ :end-before: [END howto_operator_customize_model]
+
+
+Sensors
+-------
+
+.. _howto/sensor:BedrockCustomizeModelCompletedSensor:
+
+Wait for an Amazon Bedrock customize model job
+==============================================
+
+To wait on the state of an Amazon Bedrock customize model job until it reaches
a terminal state you can use
+:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockCustomizeModelCompletedSensor`
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_customize_model]
+ :end-before: [END howto_sensor_customize_model]
Reference
---------
diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py
b/tests/providers/amazon/aws/hooks/test_bedrock.py
index 73612aacbc..16752477d5 100644
--- a/tests/providers/amazon/aws/hooks/test_bedrock.py
+++ b/tests/providers/amazon/aws/hooks/test_bedrock.py
@@ -16,7 +16,41 @@
# under the License.
from __future__ import annotations
-from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
+from unittest import mock
+
+import pytest
+from botocore.exceptions import ClientError
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook,
BedrockRuntimeHook
+
+JOB_NAME = "testJobName"
+EXPECTED_STATUS = "InProgress"
+
+
[email protected]
+def mock_conn():
+ with mock.patch.object(BedrockHook, "conn") as _conn:
+ _conn.get_model_customization_job.return_value = {"jobName": JOB_NAME,
"status": EXPECTED_STATUS}
+ yield _conn
+
+
+class TestBedrockHook:
+ VALIDATION_EXCEPTION_ERROR = ClientError(
+ error_response={"Error": {"Code": "ValidationException", "Message":
""}},
+ operation_name="GetModelCustomizationJob",
+ )
+
+ UNEXPECTED_EXCEPTION = ClientError(
+ error_response={"Error": {"Code": "ExpiredTokenException", "Message":
""}},
+ operation_name="GetModelCustomizationJob",
+ )
+
+ def setup_method(self):
+ self.hook = BedrockHook()
+
+ def test_conn_returns_a_boto3_connection(self):
+ assert self.hook.conn is not None
+ assert self.hook.conn.meta.service_model.service_name == "bedrock"
class TestBedrockRuntimeHook:
diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py
b/tests/providers/amazon/aws/operators/test_bedrock.py
index f6274de48f..2371877b4d 100644
--- a/tests/providers/amazon/aws/operators/test_bedrock.py
+++ b/tests/providers/amazon/aws/operators/test_bedrock.py
@@ -18,42 +18,155 @@
from __future__ import annotations
import json
-from typing import Generator
+from typing import TYPE_CHECKING, Generator
from unittest import mock
import pytest
+from botocore.exceptions import ClientError
from moto import mock_aws
-from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
-from airflow.providers.amazon.aws.operators.bedrock import
BedrockInvokeModelOperator
-
-MODEL_ID = "meta.llama2-13b-chat-v1"
-PROMPT = "A very important question."
-GENERATED_RESPONSE = "An important answer."
-MOCK_RESPONSE = json.dumps(
- {
- "generation": GENERATED_RESPONSE,
- "prompt_token_count": len(PROMPT),
- "generation_token_count": len(GENERATED_RESPONSE),
- "stop_reason": "stop",
- }
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook,
BedrockRuntimeHook
+from airflow.providers.amazon.aws.operators.bedrock import (
+ BedrockCustomizeModelOperator,
+ BedrockInvokeModelOperator,
)
-
[email protected]
-def runtime_hook() -> Generator[BedrockRuntimeHook, None, None]:
- with mock_aws():
- yield BedrockRuntimeHook(aws_conn_id="aws_default")
+if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
class TestBedrockInvokeModelOperator:
- @mock.patch.object(BedrockRuntimeHook, "conn")
- def test_invoke_model_prompt_good_combinations(self, mock_conn):
- mock_conn.invoke_model.return_value["body"].read.return_value =
MOCK_RESPONSE
+ MODEL_ID = "meta.llama2-13b-chat-v1"
+ TEST_PROMPT = "A very important question."
+ GENERATED_RESPONSE = "An important answer."
+
+ @pytest.fixture
+ def mock_runtime_conn(self) -> Generator[BaseAwsConnection, None, None]:
+ with mock.patch.object(BedrockRuntimeHook, "conn") as _conn:
+ _conn.invoke_model.return_value["body"].read.return_value =
json.dumps(
+ {
+ "generation": self.GENERATED_RESPONSE,
+ "prompt_token_count": len(self.TEST_PROMPT),
+ "generation_token_count": len(self.GENERATED_RESPONSE),
+ "stop_reason": "stop",
+ }
+ )
+ yield _conn
+
+ @pytest.fixture
+ def runtime_hook(self) -> Generator[BedrockRuntimeHook, None, None]:
+ with mock_aws():
+ yield BedrockRuntimeHook(aws_conn_id="aws_default")
+
+ def test_invoke_model_prompt_good_combinations(self, mock_runtime_conn):
operator = BedrockInvokeModelOperator(
- task_id="test_task", model_id=MODEL_ID, input_data={"input_data":
{"prompt": PROMPT}}
+ task_id="test_task",
+ model_id=self.MODEL_ID,
+ input_data={"input_data": {"prompt": self.TEST_PROMPT}},
)
response = operator.execute({})
- assert response["generation"] == GENERATED_RESPONSE
+ assert response["generation"] == self.GENERATED_RESPONSE
+
+
+class TestBedrockCustomizeModelOperator:
+ CUSTOMIZE_JOB_ARN = "valid_arn"
+ CUSTOMIZE_JOB_NAME = "testModelJob"
+
+ @pytest.fixture
+ def mock_conn(self) -> Generator[BaseAwsConnection, None, None]:
+ with mock.patch.object(BedrockHook, "conn") as _conn:
+ _conn.create_model_customization_job.return_value = {
+ "ResponseMetadata": {"HTTPStatusCode": 201},
+ "jobArn": self.CUSTOMIZE_JOB_ARN,
+ }
+ _conn.get_model_customization_job.return_value = {
+ "jobName": self.CUSTOMIZE_JOB_NAME,
+ "status": "InProgress",
+ }
+ yield _conn
+
+ @pytest.fixture
+ def bedrock_hook(self) -> Generator[BedrockHook, None, None]:
+ with mock_aws():
+ hook = BedrockHook(aws_conn_id="aws_default")
+ yield hook
+
+ def setup_method(self):
+ self.operator = BedrockCustomizeModelOperator(
+ task_id="test_task",
+ job_name=self.CUSTOMIZE_JOB_NAME,
+ custom_model_name="testModelName",
+ role_arn="valid_arn",
+ base_model_id="base_model_id",
+ hyperparameters={
+ "epochCount": "1",
+ "batchSize": "1",
+ "learningRate": ".0005",
+ "learningRateWarmupSteps": "0",
+ },
+ training_data_uri="s3://uri",
+ output_data_uri="s3://uri/output",
+ )
+ self.operator.defer = mock.MagicMock()
+
+ @pytest.mark.parametrize(
+ "wait_for_completion, deferrable",
+ [
+ pytest.param(False, False, id="no_wait"),
+ pytest.param(True, False, id="wait"),
+ pytest.param(False, True, id="defer"),
+ ],
+ )
+ @mock.patch.object(BedrockHook, "get_waiter")
+ def test_customize_model_wait_combinations(
+ self, _, wait_for_completion, deferrable, mock_conn, bedrock_hook
+ ):
+ self.operator.wait_for_completion = wait_for_completion
+ self.operator.deferrable = deferrable
+
+ response = self.operator.execute({})
+
+ assert response == self.CUSTOMIZE_JOB_ARN
+ assert bedrock_hook.get_waiter.call_count == wait_for_completion
+ assert self.operator.defer.call_count == deferrable
+
+ conflict_msg = "The provided job name is currently in use."
+ conflict_exception = ClientError(
+ error_response={"Error": {"Message": conflict_msg, "Code":
"ValidationException"}},
+ operation_name="UnitTest",
+ )
+ success = {"ResponseMetadata": {"HTTPStatusCode": 201}, "jobArn":
CUSTOMIZE_JOB_ARN}
+
+ @pytest.mark.parametrize(
+ "side_effect, ensure_unique_name",
+ [
+ pytest.param([conflict_exception, success], True,
id="conflict_and_ensure_unique"),
+ pytest.param([conflict_exception, success], False,
id="conflict_and_not_ensure_unique"),
+ pytest.param(
+ [conflict_exception, conflict_exception, success],
+ True,
+ id="multiple_conflict_and_ensure_unique",
+ ),
+ pytest.param(
+ [conflict_exception, conflict_exception, success],
+ False,
+ id="multiple_conflict_and_not_ensure_unique",
+ ),
+ pytest.param([success], True, id="no_conflict_and_ensure_unique"),
+ pytest.param([success], False,
id="no_conflict_and_not_ensure_unique"),
+ ],
+ )
+ @mock.patch.object(BedrockHook, "get_waiter")
+ def test_ensure_unique_job_name(self, _, side_effect, ensure_unique_name,
mock_conn, bedrock_hook):
+ mock_conn.create_model_customization_job.side_effect = side_effect
+ expected_call_count = len(side_effect) if ensure_unique_name else 1
+ self.operator.wait_for_completion = False
+
+ response = self.operator.execute({})
+
+ assert response == self.CUSTOMIZE_JOB_ARN
+        assert mock_conn.create_model_customization_job.call_count ==
expected_call_count
+ bedrock_hook.get_waiter.assert_not_called()
+ self.operator.defer.assert_not_called()
diff --git a/tests/providers/amazon/aws/sensors/test_bedrock.py
b/tests/providers/amazon/aws/sensors/test_bedrock.py
new file mode 100644
index 0000000000..dab0f94ad3
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_bedrock.py
@@ -0,0 +1,95 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.sensors.bedrock import
BedrockCustomizeModelCompletedSensor
+
+
[email protected]
+def mock_get_job_state():
+ with mock.patch.object(BedrockHook, "get_customize_model_job_state") as
mock_state:
+ yield mock_state
+
+
+class TestBedrockCustomizeModelCompletedSensor:
+ JOB_NAME = "test_job_name"
+
+ def setup_method(self):
+ self.default_op_kwargs = dict(
+ task_id="test_bedrock_customize_model_sensor",
+ job_name=self.JOB_NAME,
+ poke_interval=5,
+ max_retries=1,
+ )
+ self.sensor =
BedrockCustomizeModelCompletedSensor(**self.default_op_kwargs, aws_conn_id=None)
+
+ def test_base_aws_op_attributes(self):
+ op = BedrockCustomizeModelCompletedSensor(**self.default_op_kwargs)
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
+
+ op = BedrockCustomizeModelCompletedSensor(
+ **self.default_op_kwargs,
+ aws_conn_id="aws-test-custom-conn",
+ region_name="eu-west-1",
+ verify=False,
+ botocore_config={"read_timeout": 42},
+ )
+ assert op.hook.aws_conn_id == "aws-test-custom-conn"
+ assert op.hook._region_name == "eu-west-1"
+ assert op.hook._verify is False
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
+
+ @pytest.mark.parametrize("state",
list(BedrockCustomizeModelCompletedSensor.SUCCESS_STATES))
+ @mock.patch.object(BedrockHook, "conn")
+ def test_poke_success_states(self, mock_conn, state):
+ mock_conn.get_model_customization_job.return_value = {"status": state}
+ assert self.sensor.poke({}) is True
+
+ @pytest.mark.parametrize("state",
list(BedrockCustomizeModelCompletedSensor.INTERMEDIATE_STATES))
+ @mock.patch.object(BedrockHook, "conn")
+ def test_poke_intermediate_states(self, mock_conn, state):
+ mock_conn.get_model_customization_job.return_value = {"status": state}
+ assert self.sensor.poke({}) is False
+
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception",
+ [
+ pytest.param(False, AirflowException, id="not-soft-fail"),
+ pytest.param(True, AirflowSkipException, id="soft-fail"),
+ ],
+ )
+ @pytest.mark.parametrize("state",
list(BedrockCustomizeModelCompletedSensor.FAILURE_STATES))
+ @mock.patch.object(BedrockHook, "conn")
+ def test_poke_failure_states(self, mock_conn, state, soft_fail,
expected_exception):
+ mock_conn.get_model_customization_job.return_value = {"status": state}
+ sensor = BedrockCustomizeModelCompletedSensor(
+ **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail
+ )
+
+ with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+ sensor.poke({})
diff --git a/tests/providers/amazon/aws/triggers/test_bedrock.py
b/tests/providers/amazon/aws/triggers/test_bedrock.py
new file mode 100644
index 0000000000..0a54c56a77
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_bedrock.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.triggers.bedrock import
BedrockCustomizeModelCompletedTrigger
+from airflow.triggers.base import TriggerEvent
+
+BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.bedrock."
+
+
+class TestBedrockCustomizeModelCompletedTrigger:
+ JOB_NAME = "test_job"
+
+ def test_serialization(self):
+ """Assert that arguments and classpath are correctly serialized."""
+ trigger = BedrockCustomizeModelCompletedTrigger(job_name=self.JOB_NAME)
+ classpath, kwargs = trigger.serialize()
+ assert classpath == BASE_TRIGGER_CLASSPATH +
"BedrockCustomizeModelCompletedTrigger"
+ assert kwargs.get("job_name") == self.JOB_NAME
+
+ @pytest.mark.asyncio
+ @mock.patch.object(BedrockHook, "get_waiter")
+ @mock.patch.object(BedrockHook, "async_conn")
+ async def test_run_success(self, mock_async_conn, mock_get_waiter):
+ mock_async_conn.__aenter__.return_value = mock.MagicMock()
+ mock_get_waiter().wait = AsyncMock()
+ trigger = BedrockCustomizeModelCompletedTrigger(job_name=self.JOB_NAME)
+
+ generator = trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "job_name":
self.JOB_NAME})
+ assert mock_get_waiter().wait.call_count == 1
diff --git a/tests/providers/amazon/aws/waiters/test_bedrock.py
b/tests/providers/amazon/aws/waiters/test_bedrock.py
new file mode 100644
index 0000000000..00521ee013
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_bedrock.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+import boto3
+import botocore
+import pytest
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.sensors.bedrock import
BedrockCustomizeModelCompletedSensor
+
+
+class TestBedrockCustomWaiters:
+ def test_service_waiters(self):
+ assert "model_customization_job_complete" in
BedrockHook().list_waiters()
+
+
+class TestBedrockCustomWaitersBase:
+ @pytest.fixture(autouse=True)
+ def mock_conn(self, monkeypatch):
+ self.client = boto3.client("bedrock")
+ monkeypatch.setattr(BedrockHook, "conn", self.client)
+
+
+class TestModelCustomizationJobCompleteWaiter(TestBedrockCustomWaitersBase):
+ WAITER_NAME = "model_customization_job_complete"
+
+ @pytest.fixture
+ def mock_get_job(self):
+ with mock.patch.object(self.client, "get_model_customization_job") as
m:
+ yield m
+
+ @pytest.mark.parametrize("state",
BedrockCustomizeModelCompletedSensor.SUCCESS_STATES)
+ def test_model_customization_job_complete(self, state, mock_get_job):
+ mock_get_job.return_value = {"status": state}
+
+ BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_id")
+
+ @pytest.mark.parametrize("state",
BedrockCustomizeModelCompletedSensor.FAILURE_STATES)
+ def test_model_customization_job_failed(self, state, mock_get_job):
+ mock_get_job.return_value = {"status": state}
+
+ with pytest.raises(botocore.exceptions.WaiterError):
+
BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_id")
+
+ def test_model_customization_job_wait(self, mock_get_job):
+ wait = {"status": "InProgress"}
+ success = {"status": "Completed"}
+ mock_get_job.side_effect = [wait, wait, success]
+
+ BedrockHook().get_waiter(self.WAITER_NAME).wait(
+ jobIdentifier="job_id", WaiterConfig={"Delay": 0.01,
"MaxAttempts": 3}
+ )
diff --git a/tests/system/providers/amazon/aws/example_bedrock.py
b/tests/system/providers/amazon/aws/example_bedrock.py
index e86e5a2e92..12e2461547 100644
--- a/tests/system/providers/amazon/aws/example_bedrock.py
+++ b/tests/system/providers/amazon/aws/example_bedrock.py
@@ -16,17 +16,61 @@
# under the License.
from __future__ import annotations
+import json
from datetime import datetime
+from botocore.exceptions import ClientError
+
+from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.operators.bedrock import
BedrockInvokeModelOperator
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.operators.bedrock import (
+ BedrockCustomizeModelOperator,
+ BedrockInvokeModelOperator,
+)
+from airflow.providers.amazon.aws.operators.s3 import (
+ S3CreateBucketOperator,
+ S3CreateObjectOperator,
+ S3DeleteBucketOperator,
+)
+from airflow.providers.amazon.aws.sensors.bedrock import
BedrockCustomizeModelCompletedSensor
+from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
-sys_test_context_task = SystemTestContextBuilder().build()
+# Externally fetched variables:
+ROLE_ARN_KEY = "ROLE_ARN"
+sys_test_context_task =
SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
DAG_ID = "example_bedrock"
+
+# Creating a custom model takes nearly two hours. If SKIP_LONG_TASKS is True
then set
+# the trigger rule to an improbable state. This way we can still have the
code snippets
+# for docs, and we can manually run the full tests occasionally.
+SKIP_LONG_TASKS = True
+
+LLAMA_MODEL_ID = "meta.llama2-13b-chat-v1"
PROMPT = "What color is an orange?"
+TITAN_MODEL_ID = "amazon.titan-text-express-v1"
+TRAIN_DATA = {"prompt": "what is AWS", "completion": "it's Amazon Web
Services"}
+HYPERPARAMETERS = {
+ "epochCount": "1",
+ "batchSize": "1",
+ "learningRate": ".0005",
+ "learningRateWarmupSteps": "0",
+}
+
+
+@task
+def delete_custom_model(model_name: str):
+ try:
+ BedrockHook().conn.delete_custom_model(modelIdentifier=model_name)
+ except ClientError as e:
+ if SKIP_LONG_TASKS and (e.response["Error"]["Code"] ==
"ValidationException"):
+ # There is no model to delete. Since we skipped making one,
that's fine.
+ return
+ raise e
+
with DAG(
dag_id=DAG_ID,
@@ -37,11 +81,28 @@ with DAG(
) as dag:
test_context = sys_test_context_task()
env_id = test_context["ENV_ID"]
+ bucket_name = f"{env_id}-bedrock"
+ input_data_s3_key = f"{env_id}/train.jsonl"
+ training_data_uri = f"s3://{bucket_name}/{input_data_s3_key}"
+ custom_model_name = f"CustomModel{env_id}"
+ custom_model_job_name = f"CustomizeModelJob{env_id}"
+
+ create_bucket = S3CreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=bucket_name,
+ )
+
+ upload_training_data = S3CreateObjectOperator(
+ task_id="upload_data",
+ s3_bucket=bucket_name,
+ s3_key=input_data_s3_key,
+ data=json.dumps(TRAIN_DATA),
+ )
# [START howto_operator_invoke_llama_model]
invoke_llama_model = BedrockInvokeModelOperator(
task_id="invoke_llama",
- model_id="meta.llama2-13b-chat-v1",
+ model_id=LLAMA_MODEL_ID,
input_data={"prompt": PROMPT},
)
# [END howto_operator_invoke_llama_model]
@@ -49,18 +110,55 @@ with DAG(
# [START howto_operator_invoke_titan_model]
invoke_titan_model = BedrockInvokeModelOperator(
task_id="invoke_titan",
- model_id="amazon.titan-text-express-v1",
+ model_id=TITAN_MODEL_ID,
input_data={"inputText": PROMPT},
)
# [END howto_operator_invoke_titan_model]
+ # [START howto_operator_customize_model]
+ customize_model = BedrockCustomizeModelOperator(
+ task_id="customize_model",
+ job_name=custom_model_job_name,
+ custom_model_name=custom_model_name,
+ role_arn=test_context[ROLE_ARN_KEY],
+
base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
+ hyperparameters=HYPERPARAMETERS,
+ training_data_uri=training_data_uri,
+ output_data_uri=f"s3://{bucket_name}/myOutputData",
+ )
+ # [END howto_operator_customize_model]
+
+ # [START howto_sensor_customize_model]
+ await_custom_model_job = BedrockCustomizeModelCompletedSensor(
+ task_id="await_custom_model_job",
+ job_name=custom_model_job_name,
+ )
+ # [END howto_sensor_customize_model]
+
+ if SKIP_LONG_TASKS:
+ customize_model.trigger_rule = TriggerRule.ALL_SKIPPED
+ await_custom_model_job.trigger_rule = TriggerRule.ALL_SKIPPED
+
+ delete_bucket = S3DeleteBucketOperator(
+ task_id="delete_bucket",
+ trigger_rule=TriggerRule.ALL_DONE,
+ bucket_name=bucket_name,
+ force_delete=True,
+ )
+
chain(
# TEST SETUP
test_context,
+ create_bucket,
+ upload_training_data,
# TEST BODY
invoke_llama_model,
invoke_titan_model,
+ customize_model,
+ await_custom_model_job,
# TEST TEARDOWN
+ delete_custom_model(custom_model_name),
+ delete_bucket,
)
from tests.system.utils.watcher import watcher