vincbeck commented on code in PR #33219:
URL: https://github.com/apache/airflow/pull/33219#discussion_r1295980649
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1523,3 +1524,243 @@ def execute(self, context: Context) -> str:
arn = ans["ExperimentArn"]
self.log.info("Experiment %s created successfully with ARN %s.",
self.name, arn)
return arn
+
+
+class SageMakerCreateNotebookOperator(BaseOperator):
+ """
+ Create a SageMaker notebook.
+
+ More information regarding parameters of this operator can be found here
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_notebook_instance.html.
+
+ .. seealso:
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerCreateNotebookOperator`
+
+ :param instance_name: The name of the notebook instance.
+ :param instance_type: The type of instance to create.
+ :param role_arn: The Amazon Resource Name (ARN) of the IAM role that
SageMaker can assume to access
+ :param volume_size_in_gb: Size in GB of the EBS root device volume of the
notebook instance.
+ :param volume_kms_key_id: The KMS key ID for the EBS root device volume.
+ :param lifecycle_config_name: The name of the lifecycle configuration to
associate with the notebook
+ :param direct_internet_access: Whether to enable direct internet access
for the notebook instance.
+ :param root_access: Whether to give the notebook instance root access to
the Amazon S3 bucket.
+ :param wait_for_completion: Whether or not to wait for the notebook to be
InService before returning
+ :param create_instance_kwargs: Additional configuration options for the
create call.
+ :param aws_conn_id: The AWS connection ID to use.
+
+ :return: The ARN of the created notebook.
+ """
+
+ template_fields: Sequence[str] = (
+ "instance_name",
+ "instance_type",
+ "role_arn",
+ "volume_size_in_gb",
+ "volume_kms_key_id",
+ "lifecycle_config_name",
+ "direct_internet_access",
+ "root_access",
+ "wait_for_completion",
+ "create_instance_kwargs",
+ )
+
+ ui_color = "#ff7300"
+
+ def __init__(
+ self,
+ *,
+ instance_name: str,
+ instance_type: str,
+ role_arn: str,
+ volume_size_in_gb: int | None = None,
+ volume_kms_key_id: str | None = None,
+ lifecycle_config_name: str | None = None,
+ direct_internet_access: str | None = None,
+ root_access: str | None = None,
+ create_instance_kwargs: dict[str, Any] = {},
+ wait_for_completion: bool = True,
+ aws_conn_id: str = "aws_default",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.instance_name = instance_name
+ self.instance_type = instance_type
+ self.role_arn = role_arn
+ self.volume_size_in_gb = volume_size_in_gb
+ self.volume_kms_key_id = volume_kms_key_id
+ self.lifecycle_config_name = lifecycle_config_name
+ self.direct_internet_access = direct_internet_access
+ self.root_access = root_access
+ self.wait_for_completion = wait_for_completion
+ self.aws_conn_id = aws_conn_id
+ self.create_instance_kwargs = create_instance_kwargs
+
+ if self.create_instance_kwargs.get("tags") is not None:
+ self.create_instance_kwargs["tags"] =
format_tags(self.create_instance_kwargs["tags"])
+
+ @cached_property
+ def hook(self) -> SageMakerHook:
+ """Create and return SageMakerHook."""
+ return SageMakerHook(aws_conn_id=self.aws_conn_id)
+
+ def execute(self, context: Context):
+
+ create_notebook_instance_kwargs = {
+ "NotebookInstanceName": self.instance_name,
+ "InstanceType": self.instance_type,
+ "RoleArn": self.role_arn,
+ "VolumeSizeInGB": self.volume_size_in_gb,
+ "KmsKeyId": self.volume_kms_key_id,
+ "LifecycleConfigName": self.lifecycle_config_name,
+ "DirectInternetAccess": self.direct_internet_access,
+ "RootAccess": self.root_access,
+ }
+ if len(self.create_instance_kwargs) > 0:
+ create_notebook_instance_kwargs.update(self.create_instance_kwargs)
+
+ self.log.info("Creating SageMaker notebook %s.", self.instance_name)
+ response =
self.hook.conn.create_notebook_instance(**prune_dict(create_notebook_instance_kwargs))
+
+ self.log.info("SageMaker notebook created: %s",
response["NotebookInstanceArn"])
+
+ if self.wait_for_completion:
+ self.log.info("Waiting for SageMaker notebook %s to be in
service", self.instance_name)
+ waiter = self.hook.conn.get_waiter("notebook_instance_in_service")
+ waiter.wait(NotebookInstanceName=self.instance_name)
+
+ return response["NotebookInstanceArn"]
+
+
+class SageMakerStopNotebookOperator(BaseOperator):
+ """
+ Stop a notebook instance.
+
+ .. seealso:
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerStopNotebookOperator`
+
+ :param instance_name: The name of the notebook instance to stop.
+ :param aws_conn_id: The AWS connection ID to use.
+ """
+
+ template_fields: Sequence[str] = ("instance_name", "wait_for_completion",
"config")
+
+ ui_color = "#ff7300"
+
+ def __init__(
+ self,
+ instance_name: str,
+ wait_for_completion: bool = True,
Review Comment:
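The `wait_for_completion` parameter is missing from the docstring of `SageMakerStopNotebookOperator`; please add a `:param wait_for_completion:` entry.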
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]