Taragolis commented on code in PR #38693:
URL: https://github.com/apache/airflow/pull/38693#discussion_r1549147317


##########
airflow/providers/amazon/aws/hooks/bedrock.py:
##########
@@ -16,9 +16,54 @@
 # under the License.
 from __future__ import annotations
 
+from botocore.exceptions import ClientError
+
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 
+class BedrockHook(AwsBaseHook):
+    """
+    Interact with Amazon Bedrock.
+
+    Provide thin wrapper around 
:external+boto3:py:class:`boto3.client("bedrock") <Bedrock.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    client_type = "bedrock"
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = self.client_type
+        super().__init__(*args, **kwargs)
+
+    def _get_job_by_name(self, job_name: str):
+        return self.conn.get_model_customization_job(jobIdentifier=job_name)
+
+    def get_customize_model_job_state(self, job_name) -> str:
+        state = self._get_job_by_name(job_name)["status"]
+        self.log.info("Job '%s' state: %s", job_name, state)
+        return state
+
+    def job_name_exists(self, job_name: str) -> bool:
+        try:
+            self._get_job_by_name(job_name)
+            self.log.info("Verified that job name '%s' does exist.", job_name)
+            return True
+        except ClientError as e:
+            if e.response["Error"]["Code"] == "ValidationException":
+                self.log.info("Job name '%s' does not exist.", job_name)
+                return False
+            else:
+                raise e

Review Comment:
   ```suggestion
               raise
   ```



##########
airflow/providers/amazon/aws/waiters/bedrock.json:
##########


Review Comment:
   It is worthwhile to test the waiter separately; from time to time we have invalid 
definitions in it:
   
https://github.com/apache/airflow/tree/main/tests/providers/amazon/aws/waiters 



##########
airflow/providers/amazon/aws/operators/bedrock.py:
##########
@@ -91,3 +96,150 @@ def execute(self, context: Context) -> dict[str, str | int]:
         self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data)
         self.log.info("Bedrock model response: %s", response_body)
         return response_body
+
+
+class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]):
+    """
+    Create a fine-tuning job to customize a base model.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:BedrockCustomizeModelOperator`
+
+    :param job_name: A unique name for the fine-tuning job.
+    :param custom_model_name: A name for the custom model being created.
+    :param role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon 
Bedrock can assume
+        to perform tasks on your behalf.
+    :param base_model_id: Name of the base model.
+    :param training_data_uri: The S3 URI where the training data is stored.
+    :param output_data_uri: The S3 URI where the output data is stored.
+    :param hyperparameters: Parameters related to tuning the model.
+    :param check_if_job_exists: If set to true, operator will check whether a 
model customization
+        job already exists for the name in the config. (Default: True)
+    :param action_if_job_exists: Behavior if the job name already exists. 
Options are "timestamp" (default),
+        and "fail"
+    :param customization_job_kwargs: Any optional parameters to pass to the 
API.
+
+    :param wait_for_completion: Whether to wait for cluster to stop. (default: 
True)
+    :param waiter_delay: Time in seconds to wait between status checks.
+    :param waiter_max_attempts: Maximum number of attempts to check for job 
completion.
+    :param deferrable: If True, the operator will wait asynchronously for the 
cluster to stop.
+        This implies waiting for completion. This mode requires aiobotocore 
module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = BedrockHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+        "custom_model_name",
+        "role_arn",
+        "base_model_id",
+        "hyperparameters",
+        "check_if_job_exists",
+        "action_if_job_exists",
+        "customization_job_kwargs",
+    )
+
+    def __init__(
+        self,
+        job_name: str,
+        custom_model_name: str,
+        role_arn: str,
+        base_model_id: str,
+        training_data_uri: str,
+        output_data_uri: str,
+        hyperparameters: dict[str, str],
+        check_if_job_exists: bool = True,
+        action_if_job_exists: str = "timestamp",
+        customization_job_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+
+        self.job_name = job_name
+        self.custom_model_name = custom_model_name
+        self.role_arn = role_arn
+        self.base_model_id = base_model_id
+        self.training_data_config = {"s3Uri": training_data_uri}
+        self.output_data_config = {"s3Uri": output_data_uri}
+        self.hyperparameters = hyperparameters
+        self.check_if_job_exists = check_if_job_exists
+        self.customization_job_kwargs = customization_job_kwargs or {}
+        if action_if_job_exists in {"timestamp", "fail"}:
+            self.action_if_job_exists = action_if_job_exists
+        else:
+            raise AirflowException(
+                f"Argument action_if_job_exists accepts only 'timestamp', and 
'fail'. \
+                Provided value: '{action_if_job_exists}."
+            )

Review Comment:
   `action_if_job_exists` is a templated field, so it is better to move this validation 
to a place after the templates are rendered, e.g. into the execute method.
   
   Pros: It does not raise a spurious error on a Jinja template or an XComArg (or 
similar)
   Cons: No fast fail



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to