This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b6ff085679 Amazon Bedrock - Fix system test (#38887)
b6ff085679 is described below
commit b6ff085679c283cd3ccc3edf20dd3e6b0eaec967
Author: D. Ferruzzi <[email protected]>
AuthorDate: Wed Apr 10 10:40:34 2024 -0700
Amazon Bedrock - Fix system test (#38887)
---
.../system/providers/amazon/aws/example_bedrock.py | 92 +++++++++++-----------
1 file changed, 46 insertions(+), 46 deletions(-)
diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py
index 12e2461547..e25bbb8ed7 100644
--- a/tests/system/providers/amazon/aws/example_bedrock.py
+++ b/tests/system/providers/amazon/aws/example_bedrock.py
@@ -18,12 +18,12 @@ from __future__ import annotations
import json
from datetime import datetime
+from os import environ
-from botocore.exceptions import ClientError
-
-from airflow.decorators import task
+from airflow.decorators import task, task_group
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
+from airflow.operators.empty import EmptyOperator
from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
from airflow.providers.amazon.aws.operators.bedrock import (
BedrockCustomizeModelOperator,
@@ -35,6 +35,7 @@ from airflow.providers.amazon.aws.operators.s3 import (
S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor
+from airflow.utils.edgemodifier import Label
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
@@ -44,10 +45,10 @@ sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).bu
DAG_ID = "example_bedrock"
-# Creating a custom model takes nearly two hours. If SKIP_LONG_TASKS is True then set
-# the trigger rule to an improbable state. This way we can still have the code snippets
-# for docs, and we can manually run the full tests occasionally.
-SKIP_LONG_TASKS = True
+# Creating a custom model takes nearly two hours. If SKIP_LONG_TASKS
+# is True then these tasks will be skipped. This way we can still have
+# the code snippets for docs, and we can manually run the full tests.
+SKIP_LONG_TASKS = environ.get("SKIP_LONG_SYSTEM_TEST_TASKS", default=True)
LLAMA_MODEL_ID = "meta.llama2-13b-chat-v1"
PROMPT = "What color is an orange?"
@@ -61,15 +62,41 @@ HYPERPARAMETERS = {
}
-@task
-def delete_custom_model(model_name: str):
- try:
- BedrockHook().conn.delete_custom_model(modelIdentifier=model_name)
- except ClientError as e:
- if SKIP_LONG_TASKS and (e.response["Error"]["Code"] == "ValidationException"):
- # There is no model to delete. Since we skipped making one, that's fine.
- return
- raise e
+@task_group
+def customize_model_workflow():
+ # [START howto_operator_customize_model]
+ customize_model = BedrockCustomizeModelOperator(
+ task_id="customize_model",
+ job_name=custom_model_job_name,
+ custom_model_name=custom_model_name,
+ role_arn=test_context[ROLE_ARN_KEY],
+ base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
+ hyperparameters=HYPERPARAMETERS,
+ training_data_uri=training_data_uri,
+ output_data_uri=f"s3://{bucket_name}/myOutputData",
+ )
+ # [END howto_operator_customize_model]
+
+ # [START howto_sensor_customize_model]
+ await_custom_model_job = BedrockCustomizeModelCompletedSensor(
+ task_id="await_custom_model_job",
+ job_name=custom_model_job_name,
+ )
+ # [END howto_sensor_customize_model]
+
+ @task
+ def delete_custom_model():
+ BedrockHook().conn.delete_custom_model(modelIdentifier=custom_model_name)
+
+ @task.branch
+ def run_or_skip():
+ return end_workflow.task_id if SKIP_LONG_TASKS else customize_model.task_id
+
+ run_or_skip = run_or_skip()
+ end_workflow = EmptyOperator(task_id="end_workflow", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
+
+ chain(run_or_skip, Label("Long-running tasks skipped"), end_workflow)
+ chain(run_or_skip, customize_model, await_custom_model_job, delete_custom_model(), end_workflow)
with DAG(
@@ -95,7 +122,7 @@ with DAG(
upload_training_data = S3CreateObjectOperator(
task_id="upload_data",
s3_bucket=bucket_name,
- s3_key=training_data_uri,
+ s3_key=input_data_s3_key,
data=json.dumps(TRAIN_DATA),
)
@@ -115,30 +142,6 @@ with DAG(
)
# [END howto_operator_invoke_titan_model]
- # [START howto_operator_customize_model]
- customize_model = BedrockCustomizeModelOperator(
- task_id="customize_model",
- job_name=custom_model_job_name,
- custom_model_name=custom_model_name,
- role_arn=test_context[ROLE_ARN_KEY],
- base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
- hyperparameters=HYPERPARAMETERS,
- training_data_uri=training_data_uri,
- output_data_uri=f"s3://{bucket_name}/myOutputData",
- )
- # [END howto_operator_customize_model]
-
- # [START howto_sensor_customize_model]
- await_custom_model_job = BedrockCustomizeModelCompletedSensor(
- task_id="await_custom_model_job",
- job_name=custom_model_job_name,
- )
- # [END howto_sensor_customize_model]
-
- if SKIP_LONG_TASKS:
- customize_model.trigger_rule = TriggerRule.ALL_SKIPPED
- await_custom_model_job.trigger_rule = TriggerRule.ALL_SKIPPED
-
delete_bucket = S3DeleteBucketOperator(
task_id="delete_bucket",
trigger_rule=TriggerRule.ALL_DONE,
@@ -152,12 +155,9 @@ with DAG(
create_bucket,
upload_training_data,
# TEST BODY
- invoke_llama_model,
- invoke_titan_model,
- customize_model,
- await_custom_model_job,
+ [invoke_llama_model, invoke_titan_model],
+ customize_model_workflow(),
# TEST TEARDOWN
- delete_custom_model(custom_model_name),
delete_bucket,
)