This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b6ff085679 Amazon Bedrock - Fix system test (#38887)
b6ff085679 is described below

commit b6ff085679c283cd3ccc3edf20dd3e6b0eaec967
Author: D. Ferruzzi <ferru...@amazon.com>
AuthorDate: Wed Apr 10 10:40:34 2024 -0700

    Amazon Bedrock - Fix system test (#38887)
---
 .../system/providers/amazon/aws/example_bedrock.py | 92 +++++++++++-----------
 1 file changed, 46 insertions(+), 46 deletions(-)

diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py
index 12e2461547..e25bbb8ed7 100644
--- a/tests/system/providers/amazon/aws/example_bedrock.py
+++ b/tests/system/providers/amazon/aws/example_bedrock.py
@@ -18,12 +18,12 @@ from __future__ import annotations
 
 import json
 from datetime import datetime
+from os import environ
 
-from botocore.exceptions import ClientError
-
-from airflow.decorators import task
+from airflow.decorators import task, task_group
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
+from airflow.operators.empty import EmptyOperator
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
 from airflow.providers.amazon.aws.operators.bedrock import (
     BedrockCustomizeModelOperator,
@@ -35,6 +35,7 @@ from airflow.providers.amazon.aws.operators.s3 import (
     S3DeleteBucketOperator,
 )
 from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor
+from airflow.utils.edgemodifier import Label
 from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
 
@@ -44,10 +45,10 @@ sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).bu
 
 DAG_ID = "example_bedrock"
 
-# Creating a custom model takes nearly two hours. If SKIP_LONG_TASKS is True then set
-# the trigger rule to an improbable state.  This way we can still have the code snippets
-# for docs, and we can manually run the full tests occasionally.
-SKIP_LONG_TASKS = True
+# Creating a custom model takes nearly two hours, so those tasks are skipped (keeping the
+# snippets for docs) unless SKIP_LONG_SYSTEM_TEST_TASKS is "false" or "0". os.environ values
+# are strings, so parse explicitly: any non-empty string, even "False", is truthy.
+SKIP_LONG_TASKS = environ.get("SKIP_LONG_SYSTEM_TEST_TASKS", "true").lower() not in ("false", "0")
 
 LLAMA_MODEL_ID = "meta.llama2-13b-chat-v1"
 PROMPT = "What color is an orange?"
@@ -61,15 +62,41 @@ HYPERPARAMETERS = {
 }
 
 
-@task
-def delete_custom_model(model_name: str):
-    try:
-        BedrockHook().conn.delete_custom_model(modelIdentifier=model_name)
-    except ClientError as e:
-        if SKIP_LONG_TASKS and (e.response["Error"]["Code"] == "ValidationException"):
-            # There is no model to delete.  Since we skipped making one, that's fine.
-            return
-        raise e
+@task_group
+def customize_model_workflow():
+    # [START howto_operator_customize_model]
+    customize_model = BedrockCustomizeModelOperator(
+        task_id="customize_model",
+        job_name=custom_model_job_name,
+        custom_model_name=custom_model_name,
+        role_arn=test_context[ROLE_ARN_KEY],
+        base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
+        hyperparameters=HYPERPARAMETERS,
+        training_data_uri=training_data_uri,
+        output_data_uri=f"s3://{bucket_name}/myOutputData",
+    )
+    # [END howto_operator_customize_model]
+
+    # [START howto_sensor_customize_model]
+    await_custom_model_job = BedrockCustomizeModelCompletedSensor(
+        task_id="await_custom_model_job",
+        job_name=custom_model_job_name,
+    )
+    # [END howto_sensor_customize_model]
+
+    @task
+    def delete_custom_model():
+        BedrockHook().conn.delete_custom_model(modelIdentifier=custom_model_name)
+
+    @task.branch
+    def run_or_skip():
+        return end_workflow.task_id if SKIP_LONG_TASKS else customize_model.task_id
+
+    branch = run_or_skip()
+    end_workflow = EmptyOperator(task_id="end_workflow", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
+    # end_workflow joins both paths; its trigger rule lets it run whichever path was skipped.
+    chain(branch, Label("Long-running tasks skipped"), end_workflow)
+    chain(branch, customize_model, await_custom_model_job, delete_custom_model(), end_workflow)
 
 
 with DAG(
@@ -95,7 +122,7 @@ with DAG(
     upload_training_data = S3CreateObjectOperator(
         task_id="upload_data",
         s3_bucket=bucket_name,
-        s3_key=training_data_uri,
+        s3_key=input_data_s3_key,
         data=json.dumps(TRAIN_DATA),
     )
 
@@ -115,30 +142,6 @@ with DAG(
     )
     # [END howto_operator_invoke_titan_model]
 
-    # [START howto_operator_customize_model]
-    customize_model = BedrockCustomizeModelOperator(
-        task_id="customize_model",
-        job_name=custom_model_job_name,
-        custom_model_name=custom_model_name,
-        role_arn=test_context[ROLE_ARN_KEY],
-        base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
-        hyperparameters=HYPERPARAMETERS,
-        training_data_uri=training_data_uri,
-        output_data_uri=f"s3://{bucket_name}/myOutputData",
-    )
-    # [END howto_operator_customize_model]
-
-    # [START howto_sensor_customize_model]
-    await_custom_model_job = BedrockCustomizeModelCompletedSensor(
-        task_id="await_custom_model_job",
-        job_name=custom_model_job_name,
-    )
-    # [END howto_sensor_customize_model]
-
-    if SKIP_LONG_TASKS:
-        customize_model.trigger_rule = TriggerRule.ALL_SKIPPED
-        await_custom_model_job.trigger_rule = TriggerRule.ALL_SKIPPED
-
     delete_bucket = S3DeleteBucketOperator(
         task_id="delete_bucket",
         trigger_rule=TriggerRule.ALL_DONE,
@@ -152,12 +155,9 @@ with DAG(
         create_bucket,
         upload_training_data,
         # TEST BODY
-        invoke_llama_model,
-        invoke_titan_model,
-        customize_model,
-        await_custom_model_job,
+        [invoke_llama_model, invoke_titan_model],
+        customize_model_workflow(),
         # TEST TEARDOWN
-        delete_custom_model(custom_model_name),
         delete_bucket,
     )
 

Reply via email to