This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0723a8f01d Introduce Amazon Bedrock service (#38602)
0723a8f01d is described below
commit 0723a8f01d1bc9eb62324a222ba34b82a8d8252c
Author: D. Ferruzzi <[email protected]>
AuthorDate: Sat Mar 30 01:54:42 2024 -0700
Introduce Amazon Bedrock service (#38602)
* Introduce Amazon Bedrock service
---
airflow/providers/amazon/aws/hooks/bedrock.py | 39 +++++++++
airflow/providers/amazon/aws/operators/bedrock.py | 93 +++++++++++++++++++++
airflow/providers/amazon/provider.yaml | 12 +++
.../operators/bedrock.rst | 72 ++++++++++++++++
.../aws/[email protected] | Bin 0 -> 12621 bytes
tests/providers/amazon/aws/hooks/test_bedrock.py | 27 ++++++
.../providers/amazon/aws/operators/test_bedrock.py | 59 +++++++++++++
.../system/providers/amazon/aws/example_bedrock.py | 76 +++++++++++++++++
8 files changed, 378 insertions(+)
diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py
b/airflow/providers/amazon/aws/hooks/bedrock.py
new file mode 100644
index 0000000000..11bacd9414
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/bedrock.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
class BedrockRuntimeHook(AwsBaseHook):
    """
    Interact with the Amazon Bedrock Runtime.

    Provide thin wrapper around
    :external+boto3:py:class:`boto3.client("bedrock-runtime") <BedrockRuntime.Client>`.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    client_type = "bedrock-runtime"

    def __init__(self, *args, **kwargs) -> None:
        # Pin the boto3 client type so the base hook always builds a
        # bedrock-runtime client, regardless of caller-supplied kwargs.
        kwargs.update(client_type=self.client_type)
        super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/bedrock.py
b/airflow/providers/amazon/aws/operators/bedrock.py
new file mode 100644
index 0000000000..d8eaf9e5d3
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/bedrock.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.helpers import prune_dict
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
class BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]):
    """
    Invoke the specified Bedrock model to run inference using the input provided.

    Use InvokeModel to run inference for text models, image models, and embedding models.
    To see the format and content of the input_data field for different models, refer to
    `Inference parameters docs <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`_.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BedrockInvokeModelOperator`

    :param model_id: The ID of the Bedrock model. (templated)
    :param input_data: Input data in the format specified in the content-type request header. (templated)
    :param content_type: The MIME type of the input data in the request. (templated) Default: application/json
    :param accept_type: The desired MIME type of the inference body in the response.
        (templated) Default: application/json

    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is ``None`` or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param verify: Whether or not to verify SSL certificates. See:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    """

    aws_hook_class = BedrockRuntimeHook
    template_fields: Sequence[str] = aws_template_fields(
        "model_id", "input_data", "content_type", "accept_type"
    )

    def __init__(
        self,
        model_id: str,
        input_data: dict[str, Any],
        content_type: str | None = None,
        accept_type: str | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_id = model_id
        self.input_data = input_data
        self.content_type = content_type
        self.accept_type = accept_type

    def execute(self, context: Context) -> dict[str, str | int]:
        # These are optional values which the API defaults to "application/json" if not provided here;
        # prune_dict drops the None entries so boto3 never sees them.
        invoke_kwargs = prune_dict({"contentType": self.content_type, "accept": self.accept_type})

        response = self.hook.conn.invoke_model(
            body=json.dumps(self.input_data),
            modelId=self.model_id,
            **invoke_kwargs,
        )

        # The response body is a streaming object; read it fully before decoding.
        response_body = json.loads(response["body"].read())
        self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data)
        self.log.info("Bedrock model response: %s", response_body)
        return response_body
diff --git a/airflow/providers/amazon/provider.yaml
b/airflow/providers/amazon/provider.yaml
index e2b0df930e..4c4f7cf597 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -142,6 +142,12 @@ integrations:
- /docs/apache-airflow-providers-amazon/operators/athena/athena_boto.rst
- /docs/apache-airflow-providers-amazon/operators/athena/athena_sql.rst
tags: [aws]
+ - integration-name: Amazon Bedrock
+ external-doc-url: https://aws.amazon.com/bedrock/
+ logo: /integration-logos/aws/[email protected]
+ how-to-guide:
+ - /docs/apache-airflow-providers-amazon/operators/bedrock.rst
+ tags: [aws]
- integration-name: Amazon Chime
external-doc-url: https://aws.amazon.com/chime/
logo: /integration-logos/aws/Amazon-Chime-light-bg.png
@@ -363,6 +369,9 @@ operators:
- integration-name: AWS Batch
python-modules:
- airflow.providers.amazon.aws.operators.batch
+ - integration-name: Amazon Bedrock
+ python-modules:
+ - airflow.providers.amazon.aws.operators.bedrock
- integration-name: Amazon CloudFormation
python-modules:
- airflow.providers.amazon.aws.operators.cloud_formation
@@ -514,6 +523,9 @@ hooks:
python-modules:
- airflow.providers.amazon.aws.hooks.athena
- airflow.providers.amazon.aws.hooks.athena_sql
+ - integration-name: Amazon Bedrock
+ python-modules:
+ - airflow.providers.amazon.aws.hooks.bedrock
- integration-name: Amazon Chime
python-modules:
- airflow.providers.amazon.aws.hooks.chime
diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst
b/docs/apache-airflow-providers-amazon/operators/bedrock.rst
new file mode 100644
index 0000000000..3e84cbc445
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst
@@ -0,0 +1,72 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+==============
+Amazon Bedrock
+==============
+
+`Amazon Bedrock <https://aws.amazon.com/bedrock/>`__ is a fully managed
service that
+offers a choice of high-performing foundation models (FMs) from leading AI
companies
+like AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI, and Amazon
via a
+single API, along with a broad set of capabilities you need to build
generative AI
+applications with security, privacy, and responsible AI.
+
+Prerequisite Tasks
+------------------
+
+.. include:: ../_partials/prerequisite_tasks.rst
+
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
+Operators
+---------
+
+.. _howto/operator:BedrockInvokeModelOperator:
+
+Invoke an existing Amazon Bedrock Model
+=======================================
+
+To invoke an existing Amazon Bedrock model, you can use
+:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockInvokeModelOperator`.
+
+Note that every model family has different input and output formats.
+For example, to invoke a Meta Llama model you would use:
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_invoke_llama_model]
+ :end-before: [END howto_operator_invoke_llama_model]
+
+To invoke an Amazon Titan model you would use:
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_invoke_titan_model]
+ :end-before: [END howto_operator_invoke_titan_model]
+
+For details on the different formats, see `Inference parameters for foundation models
+<https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`__.
+
+
+Reference
+---------
+
+* `AWS boto3 library documentation for Amazon Bedrock
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock.html>`__
diff --git a/docs/integration-logos/aws/[email protected]
b/docs/integration-logos/aws/[email protected]
new file mode 100644
index 0000000000..e6af4b7276
Binary files /dev/null and
b/docs/integration-logos/aws/[email protected] differ
diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py
b/tests/providers/amazon/aws/hooks/test_bedrock.py
new file mode 100644
index 0000000000..73612aacbc
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_bedrock.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
+
+
class TestBedrockRuntimeHook:
    def test_conn_returns_a_boto3_connection(self):
        hook = BedrockRuntimeHook()
        conn = hook.conn

        # The hook must hand back a live boto3 client bound to bedrock-runtime.
        assert conn is not None
        assert conn.meta.service_model.service_name == "bedrock-runtime"
diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py
b/tests/providers/amazon/aws/operators/test_bedrock.py
new file mode 100644
index 0000000000..f6274de48f
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_bedrock.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+from typing import Generator
+from unittest import mock
+
+import pytest
+from moto import mock_aws
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
+from airflow.providers.amazon.aws.operators.bedrock import
BedrockInvokeModelOperator
+
MODEL_ID = "meta.llama2-13b-chat-v1"
PROMPT = "A very important question."
GENERATED_RESPONSE = "An important answer."
# Canned InvokeModel response body, serialized the way the real API returns it
# (a JSON document read from the response's streaming "body").
MOCK_RESPONSE = json.dumps(
    {
        "generation": GENERATED_RESPONSE,
        "prompt_token_count": len(PROMPT),
        "generation_token_count": len(GENERATED_RESPONSE),
        "stop_reason": "stop",
    }
)


@pytest.fixture
def runtime_hook() -> Generator[BedrockRuntimeHook, None, None]:
    # Build the hook inside moto's mock_aws context so no real AWS
    # credentials or network calls are needed.
    # NOTE(review): this fixture is not referenced by the test class below — confirm
    # whether it is intended for future tests or can be removed.
    with mock_aws():
        yield BedrockRuntimeHook(aws_conn_id="aws_default")
+
+
class TestBedrockInvokeModelOperator:
    @mock.patch.object(BedrockRuntimeHook, "conn")
    def test_invoke_model_prompt_good_combinations(self, mock_conn):
        # Stub the streaming body so response["body"].read() yields the canned output.
        mock_conn.invoke_model.return_value["body"].read.return_value = MOCK_RESPONSE

        op = BedrockInvokeModelOperator(
            task_id="test_task",
            model_id=MODEL_ID,
            input_data={"input_data": {"prompt": PROMPT}},
        )
        result = op.execute({})

        assert result["generation"] == GENERATED_RESPONSE
diff --git a/tests/system/providers/amazon/aws/example_bedrock.py
b/tests/system/providers/amazon/aws/example_bedrock.py
new file mode 100644
index 0000000000..e86e5a2e92
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_bedrock.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow.models.baseoperator import chain
+from airflow.models.dag import DAG
+from airflow.providers.amazon.aws.operators.bedrock import
BedrockInvokeModelOperator
+from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
+
# Builds the task that supplies system-test context (ENV_ID etc.) at runtime.
sys_test_context_task = SystemTestContextBuilder().build()

DAG_ID = "example_bedrock"
PROMPT = "What color is an orange?"

with DAG(
    dag_id=DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    tags=["example"],
    catchup=False,
) as dag:
    test_context = sys_test_context_task()
    env_id = test_context["ENV_ID"]
    # NOTE(review): env_id is not used by any task below — confirm whether it is
    # needed or left over from the system-test template.

    # Each model family expects a different input schema; see the Bedrock
    # "Inference parameters" docs for the per-model formats.
    # [START howto_operator_invoke_llama_model]
    invoke_llama_model = BedrockInvokeModelOperator(
        task_id="invoke_llama",
        model_id="meta.llama2-13b-chat-v1",
        input_data={"prompt": PROMPT},
    )
    # [END howto_operator_invoke_llama_model]

    # [START howto_operator_invoke_titan_model]
    invoke_titan_model = BedrockInvokeModelOperator(
        task_id="invoke_titan",
        model_id="amazon.titan-text-express-v1",
        input_data={"inputText": PROMPT},
    )
    # [END howto_operator_invoke_titan_model]

    chain(
        # TEST SETUP
        test_context,
        # TEST BODY
        invoke_llama_model,
        invoke_titan_model,
        # TEST TEARDOWN
    )

    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)