This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 223b41d68f Added Amazon SageMaker Notebook hook and operators (#33219)
223b41d68f is described below

commit 223b41d68f53e7aa76588ffb8ba1e37e780d9e3b
Author: ellisms <[email protected]>
AuthorDate: Wed Aug 16 12:53:33 2023 -0400

    Added Amazon SageMaker Notebook hook and operators (#33219)
    
    
    
    ---------
    
    Co-authored-by: Vincent <[email protected]>
---
 .../providers/amazon/aws/operators/sagemaker.py    | 242 +++++++++++++++++++++
 .../operators/sagemaker.rst                        |  57 +++++
 .../aws/operators/test_sagemaker_notebook.py       | 165 ++++++++++++++
 .../amazon/aws/example_sagemaker_notebook.py       | 109 ++++++++++
 4 files changed, 573 insertions(+)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py 
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 376c1051fb..ce0fa6f7c5 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -37,6 +37,7 @@ from airflow.providers.amazon.aws.triggers.sagemaker import (
 from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
+from airflow.utils.helpers import prune_dict
 from airflow.utils.json import AirflowJsonEncoder
 
 if TYPE_CHECKING:
@@ -1523,3 +1524,244 @@ class 
SageMakerCreateExperimentOperator(SageMakerBaseOperator):
         arn = ans["ExperimentArn"]
         self.log.info("Experiment %s created successfully with ARN %s.", 
self.name, arn)
         return arn
+
+
class SageMakerCreateNotebookOperator(BaseOperator):
    """
    Create a SageMaker notebook.

    More information regarding parameters of this operator can be found here
    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_notebook_instance.html.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerCreateNotebookOperator`

    :param instance_name: The name of the notebook instance.
    :param instance_type: The type of instance to create.
    :param role_arn: The Amazon Resource Name (ARN) of the IAM role that SageMaker can assume to
        perform tasks on your behalf.
    :param volume_size_in_gb: Size in GB of the EBS root device volume of the notebook instance.
    :param volume_kms_key_id: The KMS key ID for the EBS root device volume.
    :param lifecycle_config_name: The name of the lifecycle configuration to associate with the notebook.
    :param direct_internet_access: Whether to enable direct internet access for the notebook instance
        (``"Enabled"`` or ``"Disabled"``).
    :param root_access: Whether root access is enabled for users of the notebook instance
        (``"Enabled"`` or ``"Disabled"``).
    :param wait_for_completion: Whether or not to wait for the notebook to be InService before returning.
    :param create_instance_kwargs: Additional configuration options for the create call. Entries here
        override the parameters above. A ``tags`` entry is normalized with ``format_tags`` and sent as
        the boto3 ``Tags`` parameter. The dict passed in is copied, never modified.
    :param aws_conn_id: The AWS connection ID to use.

    :return: The ARN of the created notebook.
    """

    template_fields: Sequence[str] = (
        "instance_name",
        "instance_type",
        "role_arn",
        "volume_size_in_gb",
        "volume_kms_key_id",
        "lifecycle_config_name",
        "direct_internet_access",
        "root_access",
        "wait_for_completion",
        "create_instance_kwargs",
    )

    ui_color = "#ff7300"

    def __init__(
        self,
        *,
        instance_name: str,
        instance_type: str,
        role_arn: str,
        volume_size_in_gb: int | None = None,
        volume_kms_key_id: str | None = None,
        lifecycle_config_name: str | None = None,
        direct_internet_access: str | None = None,
        root_access: str | None = None,
        create_instance_kwargs: dict[str, Any] | None = None,
        wait_for_completion: bool = True,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.instance_name = instance_name
        self.instance_type = instance_type
        self.role_arn = role_arn
        self.volume_size_in_gb = volume_size_in_gb
        self.volume_kms_key_id = volume_kms_key_id
        self.lifecycle_config_name = lifecycle_config_name
        self.direct_internet_access = direct_internet_access
        self.root_access = root_access
        self.wait_for_completion = wait_for_completion
        self.aws_conn_id = aws_conn_id
        # Copy so neither the caller's dict nor a shared default object is mutated below.
        self.create_instance_kwargs: dict[str, Any] = dict(create_instance_kwargs or {})

        tags = self.create_instance_kwargs.pop("tags", None)
        if tags is not None:
            # boto3's create_notebook_instance only accepts the capitalized ``Tags`` parameter;
            # forwarding ``tags`` would be rejected as an unknown parameter.
            self.create_instance_kwargs["Tags"] = format_tags(tags)

    @cached_property
    def hook(self) -> SageMakerHook:
        """Create and return SageMakerHook."""
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    def execute(self, context: Context):
        create_notebook_instance_kwargs = {
            "NotebookInstanceName": self.instance_name,
            "InstanceType": self.instance_type,
            "RoleArn": self.role_arn,
            "VolumeSizeInGB": self.volume_size_in_gb,
            "KmsKeyId": self.volume_kms_key_id,
            "LifecycleConfigName": self.lifecycle_config_name,
            "DirectInternetAccess": self.direct_internet_access,
            "RootAccess": self.root_access,
        }
        # Explicit extra kwargs take precedence over the named parameters (no-op when empty).
        create_notebook_instance_kwargs.update(self.create_instance_kwargs)

        self.log.info("Creating SageMaker notebook %s.", self.instance_name)
        # prune_dict drops the unset (None) optional parameters so they are not sent to boto3.
        response = self.hook.conn.create_notebook_instance(**prune_dict(create_notebook_instance_kwargs))
        notebook_arn = response["NotebookInstanceArn"]

        self.log.info("SageMaker notebook created: %s", notebook_arn)

        if self.wait_for_completion:
            self.log.info("Waiting for SageMaker notebook %s to be in service", self.instance_name)
            waiter = self.hook.conn.get_waiter("notebook_instance_in_service")
            waiter.wait(NotebookInstanceName=self.instance_name)

        return notebook_arn
+
+
class SageMakerStopNotebookOperator(BaseOperator):
    """
    Stop a notebook instance.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerStopNotebookOperator`

    :param instance_name: The name of the notebook instance to stop.
    :param wait_for_completion: Whether or not to wait for the notebook to be stopped before returning.
    :param aws_conn_id: The AWS connection ID to use.
    """

    template_fields: Sequence[str] = ("instance_name", "wait_for_completion")

    ui_color = "#ff7300"

    def __init__(
        self,
        *,
        instance_name: str,
        wait_for_completion: bool = True,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.instance_name = instance_name
        self.wait_for_completion = wait_for_completion
        self.aws_conn_id = aws_conn_id

    @cached_property
    def hook(self) -> SageMakerHook:
        """Create and return SageMakerHook."""
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    def execute(self, context: Context):
        self.log.info("Stopping SageMaker notebook %s.", self.instance_name)
        self.hook.conn.stop_notebook_instance(NotebookInstanceName=self.instance_name)

        if self.wait_for_completion:
            self.log.info("Waiting for SageMaker notebook %s to stop", self.instance_name)
            self.hook.conn.get_waiter("notebook_instance_stopped").wait(
                NotebookInstanceName=self.instance_name
            )
+
+
class SageMakerDeleteNotebookOperator(BaseOperator):
    """
    Delete a notebook instance.

    The instance must be stopped before it can be deleted.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerDeleteNotebookOperator`

    :param instance_name: The name of the notebook instance to delete.
    :param wait_for_completion: Whether or not to wait for the notebook to delete before returning.
    :param aws_conn_id: The AWS connection ID to use.
    """

    template_fields: Sequence[str] = ("instance_name", "wait_for_completion")

    ui_color = "#ff7300"

    def __init__(
        self,
        *,
        instance_name: str,
        wait_for_completion: bool = True,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.instance_name = instance_name
        self.aws_conn_id = aws_conn_id
        self.wait_for_completion = wait_for_completion

    @cached_property
    def hook(self) -> SageMakerHook:
        """Create and return SageMakerHook."""
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    def execute(self, context: Context):
        self.log.info("Deleting SageMaker notebook %s....", self.instance_name)
        self.hook.conn.delete_notebook_instance(NotebookInstanceName=self.instance_name)

        if self.wait_for_completion:
            self.log.info("Waiting for SageMaker notebook %s to delete...", self.instance_name)
            self.hook.conn.get_waiter("notebook_instance_deleted").wait(
                NotebookInstanceName=self.instance_name
            )
+
+
class SageMakerStartNoteBookOperator(BaseOperator):
    """
    Start a notebook instance.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerStartNotebookOperator`

    :param instance_name: The name of the notebook instance to start.
    :param wait_for_completion: Whether or not to wait for notebook to be InService before returning.
    :param aws_conn_id: The AWS connection ID to use.
    """

    template_fields: Sequence[str] = ("instance_name", "wait_for_completion")

    ui_color = "#ff7300"

    def __init__(
        self,
        *,
        instance_name: str,
        wait_for_completion: bool = True,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.instance_name = instance_name
        self.aws_conn_id = aws_conn_id
        self.wait_for_completion = wait_for_completion

    @cached_property
    def hook(self) -> SageMakerHook:
        """Create and return SageMakerHook."""
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    def execute(self, context: Context):
        self.log.info("Starting SageMaker notebook %s....", self.instance_name)
        self.hook.conn.start_notebook_instance(NotebookInstanceName=self.instance_name)

        if self.wait_for_completion:
            self.log.info("Waiting for SageMaker notebook %s to start...", self.instance_name)
            self.hook.conn.get_waiter("notebook_instance_in_service").wait(
                NotebookInstanceName=self.instance_name
            )
diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst 
b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
index ca4ab34d69..c77b689693 100644
--- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
+++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
@@ -222,6 +222,63 @@ This creates an experiment so that it's ready to be 
associated with processing,
     :start-after: [START howto_operator_sagemaker_experiment]
     :end-before: [END howto_operator_sagemaker_experiment]
 
+.. _howto/operator:SageMakerCreateNotebookOperator:
+
+Create a SageMaker Notebook Instance
+====================================
+
+To create a SageMaker Notebook Instance, you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerCreateNotebookOperator`.
+This creates a SageMaker Notebook Instance ready to run Jupyter notebooks.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_notebook.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sagemaker_notebook_create]
+    :end-before: [END howto_operator_sagemaker_notebook_create]
+
+.. _howto/operator:SageMakerStopNotebookOperator:
+
+Stop a SageMaker Notebook Instance
+==================================
+
+To stop a SageMaker Notebook Instance, you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerStopNotebookOperator`.
+This terminates the ML compute instance and disconnects the ML storage volume.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_notebook.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sagemaker_notebook_stop]
+    :end-before: [END howto_operator_sagemaker_notebook_stop]
+
+.. _howto/operator:SageMakerStartNotebookOperator:
+
+Start a SageMaker Notebook Instance
+===================================
+
+To launch a SageMaker Notebook Instance and re-attach an ML storage volume, you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerStartNoteBookOperator`.
+This launches a new ML compute instance with the latest version of the libraries and attaches your ML storage volume.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_notebook.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sagemaker_notebook_start]
+    :end-before: [END howto_operator_sagemaker_notebook_start]
+
+
+.. _howto/operator:SageMakerDeleteNotebookOperator:
+
+Delete a SageMaker Notebook Instance
+====================================
+
+To delete a SageMaker Notebook Instance, you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerDeleteNotebookOperator`.
+This terminates the instance and deletes the ML storage volume and network interface associated with the instance. The instance must be stopped before it can be deleted.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_notebook.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sagemaker_notebook_delete]
+    :end-before: [END howto_operator_sagemaker_notebook_delete]
+
 Sensors
 -------
 
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_notebook.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_notebook.py
new file mode 100644
index 0000000000..d622b1177f
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_notebook.py
@@ -0,0 +1,165 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from moto import mock_sagemaker
+
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.operators.sagemaker import (
+    SageMakerCreateNotebookOperator,
+    SageMakerDeleteNotebookOperator,
+    SageMakerStartNoteBookOperator,
+    SageMakerStopNotebookOperator,
+)
+
INSTANCE_NAME = "notebook"
INSTANCE_TYPE = "ml.t3.medium"
# Well-formed IAM role ARN (arn:aws:iam::<account-id>:role/<name>); the account id is a placeholder.
ROLE_ARN = "arn:aws:iam::123456789012:role/role"
+
+
@pytest.fixture
def hook() -> SageMakerHook:
    """Yield a SageMakerHook whose SageMaker calls are served by moto's in-memory mock."""
    with mock_sagemaker():
        yield SageMakerHook(aws_conn_id="aws_default")
+
+
@pytest.fixture
def create_instance_args():
    """Minimal boto3 ``create_notebook_instance`` kwargs for the test notebook."""
    return {
        "NotebookInstanceName": INSTANCE_NAME,
        "InstanceType": INSTANCE_TYPE,
        "RoleArn": ROLE_ARN,
    }
+
+
class TestSageMakerHook:
    """Check the hook's client handles the notebook-instance lifecycle calls (backed by moto)."""

    def test_conn(self):
        hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
        assert hook.aws_conn_id == "sagemaker_test_conn_id"

    def test_create_instance(self, hook: SageMakerHook, create_instance_args):
        resp = hook.conn.create_notebook_instance(**create_instance_args)
        assert resp["NotebookInstanceArn"]

    def test_start_instance(self, hook: SageMakerHook, create_instance_args):
        hook.conn.create_notebook_instance(**create_instance_args)
        resp = hook.conn.start_notebook_instance(NotebookInstanceName=INSTANCE_NAME)
        assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200

    def test_stop_instance(self, hook: SageMakerHook, create_instance_args):
        hook.conn.create_notebook_instance(**create_instance_args)
        resp = hook.conn.stop_notebook_instance(NotebookInstanceName=INSTANCE_NAME)
        assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200

    def test_delete_instance(self, hook: SageMakerHook, create_instance_args):
        hook.conn.create_notebook_instance(**create_instance_args)
        # The instance must be stopped before it can be deleted.
        hook.conn.stop_notebook_instance(NotebookInstanceName=INSTANCE_NAME)
        resp = hook.conn.delete_notebook_instance(NotebookInstanceName=INSTANCE_NAME)
        assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
+
+
class TestSagemakerCreateNotebookOperator:
    @mock.patch.object(SageMakerHook, "conn")
    def test_create_notebook_without_wait_for_completion(self, mock_hook_conn):
        operator = SageMakerCreateNotebookOperator(
            task_id="task_test",
            instance_name=INSTANCE_NAME,
            instance_type=INSTANCE_TYPE,
            role_arn=ROLE_ARN,
            wait_for_completion=False,
            volume_size_in_gb=50,
        )
        operator.execute(None)
        # Optional parameters left as None must be pruned before the boto3 call.
        mock_hook_conn.create_notebook_instance.assert_called_once_with(
            NotebookInstanceName=INSTANCE_NAME,
            InstanceType=INSTANCE_TYPE,
            RoleArn=ROLE_ARN,
            VolumeSizeInGB=50,
        )
        mock_hook_conn.get_waiter.assert_not_called()

    @mock.patch.object(SageMakerHook, "conn")
    def test_create_notebook_wait_for_completion(self, mock_hook_conn):
        operator = SageMakerCreateNotebookOperator(
            task_id="task_test",
            instance_name=INSTANCE_NAME,
            instance_type=INSTANCE_TYPE,
            role_arn=ROLE_ARN,
        )
        result = operator.execute(None)
        mock_hook_conn.create_notebook_instance.assert_called_once()
        mock_hook_conn.get_waiter.assert_called_once_with("notebook_instance_in_service")
        mock_hook_conn.get_waiter.return_value.wait.assert_called_once_with(
            NotebookInstanceName=INSTANCE_NAME
        )
        # The operator returns the ARN taken from the create response.
        assert result == mock_hook_conn.create_notebook_instance.return_value["NotebookInstanceArn"]
+
+
class TestSageMakerStopNotebookOperator:
    # SageMakerHook.conn is patched at class level, so no moto backend is needed here.
    @mock.patch.object(SageMakerHook, "conn")
    def test_stop_notebook_without_wait_for_completion(self, mock_hook_conn):
        operator = SageMakerStopNotebookOperator(
            task_id="stop_test", instance_name=INSTANCE_NAME, wait_for_completion=False
        )
        operator.execute(None)
        mock_hook_conn.stop_notebook_instance.assert_called_once_with(NotebookInstanceName=INSTANCE_NAME)
        mock_hook_conn.get_waiter.assert_not_called()

    @mock.patch.object(SageMakerHook, "conn")
    def test_stop_notebook_wait_for_completion(self, mock_hook_conn):
        operator = SageMakerStopNotebookOperator(
            task_id="stop_test", instance_name=INSTANCE_NAME, wait_for_completion=True
        )
        operator.execute(None)
        mock_hook_conn.stop_notebook_instance.assert_called_once_with(NotebookInstanceName=INSTANCE_NAME)
        mock_hook_conn.get_waiter.assert_called_once_with("notebook_instance_stopped")
        mock_hook_conn.get_waiter.return_value.wait.assert_called_once_with(
            NotebookInstanceName=INSTANCE_NAME
        )
+
+
class TestSageMakerDeleteNotebookOperator:
    # SageMakerHook.conn is patched at class level, so no moto backend is needed here.
    @mock.patch.object(SageMakerHook, "conn")
    def test_delete_notebook_without_wait_for_completion(self, mock_hook_conn):
        operator = SageMakerDeleteNotebookOperator(
            task_id="delete_test", instance_name=INSTANCE_NAME, wait_for_completion=False
        )
        operator.execute(None)
        mock_hook_conn.delete_notebook_instance.assert_called_once_with(NotebookInstanceName=INSTANCE_NAME)
        mock_hook_conn.get_waiter.assert_not_called()

    @mock.patch.object(SageMakerHook, "conn")
    def test_delete_notebook_wait_for_completion(self, mock_hook_conn):
        operator = SageMakerDeleteNotebookOperator(
            task_id="delete_test", instance_name=INSTANCE_NAME, wait_for_completion=True
        )
        operator.execute(None)
        mock_hook_conn.delete_notebook_instance.assert_called_once_with(NotebookInstanceName=INSTANCE_NAME)
        mock_hook_conn.get_waiter.assert_called_once_with("notebook_instance_deleted")
        mock_hook_conn.get_waiter.return_value.wait.assert_called_once_with(
            NotebookInstanceName=INSTANCE_NAME
        )
+
+
class TestSageMakerStartNotebookOperator:
    # SageMakerHook.conn is patched at class level, so no moto backend is needed here.
    @mock.patch.object(SageMakerHook, "conn")
    def test_start_notebook_without_wait_for_completion(self, mock_hook_conn):
        operator = SageMakerStartNoteBookOperator(
            task_id="start_test", instance_name=INSTANCE_NAME, wait_for_completion=False
        )
        operator.execute(None)
        mock_hook_conn.start_notebook_instance.assert_called_once_with(NotebookInstanceName=INSTANCE_NAME)
        mock_hook_conn.get_waiter.assert_not_called()

    @mock.patch.object(SageMakerHook, "conn")
    def test_start_notebook_wait_for_completion(self, mock_hook_conn):
        operator = SageMakerStartNoteBookOperator(
            task_id="start_test", instance_name=INSTANCE_NAME, wait_for_completion=True
        )
        operator.execute(None)
        mock_hook_conn.start_notebook_instance.assert_called_once_with(NotebookInstanceName=INSTANCE_NAME)
        mock_hook_conn.get_waiter.assert_called_once_with("notebook_instance_in_service")
        mock_hook_conn.get_waiter.return_value.wait.assert_called_once_with(
            NotebookInstanceName=INSTANCE_NAME
        )
diff --git a/tests/system/providers/amazon/aws/example_sagemaker_notebook.py 
b/tests/system/providers/amazon/aws/example_sagemaker_notebook.py
new file mode 100644
index 0000000000..d597b9b922
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_sagemaker_notebook.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerCreateNotebookOperator,
    SageMakerDeleteNotebookOperator,
    SageMakerStartNoteBookOperator,
    SageMakerStopNotebookOperator,
)
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
+
DAG_ID = "example_sagemaker_notebook"

# Externally fetched variables:
# ROLE_ARN must name an IAM role that SageMaker can use to run notebook instances.
ROLE_ARN_KEY = "ROLE_ARN"

sys_test_context_task = (
    SystemTestContextBuilder()
    .add_variable(ROLE_ARN_KEY)
    .build()
)
+
with DAG(
    DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    tags=["example"],
    catchup=False,
) as dag:
    test_context = sys_test_context_task()

    instance_name: str = f"{test_context[ENV_ID_KEY]}-test-notebook"

    role_arn = test_context[ROLE_ARN_KEY]

    # [START howto_operator_sagemaker_notebook_create]
    instance = SageMakerCreateNotebookOperator(
        task_id="create_instance",
        instance_name=instance_name,
        instance_type="ml.t3.medium",
        role_arn=role_arn,
        wait_for_completion=True,
    )
    # [END howto_operator_sagemaker_notebook_create]

    # [START howto_operator_sagemaker_notebook_stop]
    stop_instance = SageMakerStopNotebookOperator(
        task_id="stop_instance",
        instance_name=instance_name,
    )
    # [END howto_operator_sagemaker_notebook_stop]

    # [START howto_operator_sagemaker_notebook_start]
    start_instance = SageMakerStartNoteBookOperator(
        task_id="start_instance",
        instance_name=instance_name,
    )
    # [END howto_operator_sagemaker_notebook_start]

    # Teardown: the instance must be stopped before it can be deleted. Both teardown tasks
    # run even if an upstream task failed, so a created notebook instance is not left behind.
    stop_instance_before_delete = SageMakerStopNotebookOperator(
        task_id="stop_instance_before_delete",
        instance_name=instance_name,
        trigger_rule=TriggerRule.ALL_DONE,
    )

    # [START howto_operator_sagemaker_notebook_delete]
    delete_instance = SageMakerDeleteNotebookOperator(task_id="delete_instance", instance_name=instance_name)
    # [END howto_operator_sagemaker_notebook_delete]
    # Set outside the documented snippet to keep the how-to example minimal.
    delete_instance.trigger_rule = TriggerRule.ALL_DONE

    chain(
        test_context,
        # create a new instance
        instance,
        # stop the instance
        stop_instance,
        # restart the instance
        start_instance,
        # must stop before deleting
        stop_instance_before_delete,
        # delete the instance
        delete_instance,
    )

    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)

Reply via email to