ferruzzi commented on code in PR #38693:
URL: https://github.com/apache/airflow/pull/38693#discussion_r1552560048
##########
airflow/providers/amazon/aws/operators/bedrock.py:
##########
@@ -91,3 +96,155 @@ def execute(self, context: Context) -> dict[str, str | int]:
self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data)
self.log.info("Bedrock model response: %s", response_body)
return response_body
+
+
class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]):
    """
    Create a fine-tuning job to customize a base model.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BedrockCustomizeModelOperator`

    :param job_name: A unique name for the fine-tuning job.
    :param custom_model_name: A name for the custom model being created.
    :param role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon Bedrock can assume
        to perform tasks on your behalf.
    :param base_model_id: Name of the base model.
    :param training_data_uri: The S3 URI where the training data is stored.
    :param output_data_uri: The S3 URI where the output data is stored.
    :param hyperparameters: Parameters related to tuning the model.
    :param check_if_job_exists: If set to true, operator will check whether a model customization
        job already exists for the name in the config. (Default: True)
    :param action_if_job_exists: Behavior if the job name already exists. Options are
        "timestamp" (default), and "fail".
    :param customization_job_kwargs: Any optional parameters to pass to the API.

    :param wait_for_completion: Whether to wait for the model customization job to complete.
        (default: True)
    :param waiter_delay: Time in seconds to wait between status checks. (default: 120)
    :param waiter_max_attempts: Maximum number of attempts to check for job completion.
        (default: 75)
    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
        This implies waiting for completion. This mode requires aiobotocore module to be installed.
        (default: False)
    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is ``None`` or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param verify: Whether or not to verify SSL certificates. See:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    """

    aws_hook_class = BedrockHook
    template_fields: Sequence[str] = aws_template_fields(
        "job_name",
        "custom_model_name",
        "role_arn",
        "base_model_id",
        "hyperparameters",
        "check_if_job_exists",
        "action_if_job_exists",
        "customization_job_kwargs",
    )
+
+ def __init__(
+ self,
+ job_name: str,
+ custom_model_name: str,
+ role_arn: str,
+ base_model_id: str,
+ training_data_uri: str,
+ output_data_uri: str,
+ hyperparameters: dict[str, str],
+ check_if_job_exists: bool = True,
+ action_if_job_exists: str = "timestamp",
+ customization_job_kwargs: dict[str, Any] | None = None,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 120,
+ waiter_max_attempts: int = 75,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
+
+ self.job_name = job_name
+ self.custom_model_name = custom_model_name
+ self.role_arn = role_arn
+ self.base_model_id = base_model_id
+ self.training_data_config = {"s3Uri": training_data_uri}
+ self.output_data_config = {"s3Uri": output_data_uri}
+ self.hyperparameters = hyperparameters
+ self.check_if_job_exists = check_if_job_exists
+ self.customization_job_kwargs = customization_job_kwargs or {}
+ self.action_if_job_exists = action_if_job_exists.lower()
+
+ self.valid_action_if_job_exists: set[str] = {"timestamp", "fail"}
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
+ event = validate_execute_complete_event(event)
+
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+
+ self.log.info("Bedrock model customization job `%s` complete.",
self.job_name)
+ return self.hook.get_job_arn(event["job_name"])
+
+ def _validate_action_if_job_exists(self):
+ if self.action_if_job_exists not in self.valid_action_if_job_exists:
+ raise AirflowException(
+ f"Invalid value for argument action_if_job_exists
{self.action_if_job_exists}; "
+ f"must be one of: {self.valid_action_if_job_exists}."
+ )
+
+ def execute(self, context: Context) -> dict:
+ self._validate_action_if_job_exists()
+
+ if self.check_if_job_exists and
self.hook.job_name_exists(self.job_name):
+ if self.action_if_job_exists == "fail":
+ raise AirflowException(f"A Bedrock job with name
{self.job_name} already exists.")
+ self.job_name = f"{self.job_name}-{int(utcnow().timestamp())}"
+ self.log.info("Changed job name to '%s' to avoid collision.",
self.job_name)
+
+ self.log.info("Creating Bedrock model customization job '%s'.",
self.job_name)
+
+ response = self.hook.conn.create_model_customization_job(
+ jobName=self.job_name,
+ customModelName=self.custom_model_name,
+ roleArn=self.role_arn,
+ baseModelIdentifier=self.base_model_id,
+ trainingDataConfig=self.training_data_config,
+ outputDataConfig=self.output_data_config,
+ hyperParameters=self.hyperparameters,
+ **self.customization_job_kwargs,
Review Comment:
Sorry, I don't follow. Is the suggestion to build something like a
"job_parameters" dict in the init and just pass/unpack that here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]