o-nikolas commented on code in PR #62240:
URL: https://github.com/apache/airflow/pull/62240#discussion_r2843070355
##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py:
##########
@@ -30,7 +30,7 @@
class SageMakerNotebookHook(BaseHook):
"""
- Interact with Sagemaker Unified Studio Workflows.
+ Interact with Sagemaker Unified Studio Workflows for Jupyter notebook
execution.
Review Comment:
Also Querybook and visual ETL jobs? Or is that not the case?
##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook Run
hook."""
+
+from __future__ import annotations
+
+import time
+import uuid
+
+import boto3
+
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+TWELVE_HOURS_IN_MINUTES = 12 * 60
+
+
+class SageMakerUnifiedStudioNotebookHook(BaseHook):
+ """
+ Interact with Sagemaker Unified Studio Workflows for asynchronous notebook
execution.
+
+ This hook provides a wrapper around the DataZone StartNotebookRun /
GetNotebookRun APIs.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import (
+ SageMakerUnifiedStudioNotebookHook,
+ )
+
+ hook = SageMakerUnifiedStudioNotebookHook(
+ domain_id="dzd_example",
+ project_id="proj_example",
+ waiter_delay=10,
+ )
+
+ :param domain_id: The ID of the DataZone domain containing the notebook.
+ :param project_id: The ID of the DataZone project containing the notebook.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ Example: {"param1": "value1", "param2": "value2"}
+ :param compute_configuration: Compute config to use for the notebook
execution.
+ Example: {"instance_type": "ml.m5.large"}
+ :param waiter_delay: Interval in seconds to poll the notebook run status.
+ :param timeout_configuration: Timeout settings for the notebook execution.
+ When provided, the maximum number of poll attempts is derived from
+ ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours.
+ Example: {"run_timeout_in_minutes": 720}
+ :param workflow_name: Name of the workflow (DAG) that triggered this run.
Review Comment:
Is this a param of the constructor? I don't see it used anywhere.
##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook Run
hook."""
+
+from __future__ import annotations
+
+import time
+import uuid
+
+import boto3
+
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+TWELVE_HOURS_IN_MINUTES = 12 * 60
+
+
+class SageMakerUnifiedStudioNotebookHook(BaseHook):
+ """
+ Interact with Sagemaker Unified Studio Workflows for asynchronous notebook
execution.
+
+ This hook provides a wrapper around the DataZone StartNotebookRun /
GetNotebookRun APIs.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import (
+ SageMakerUnifiedStudioNotebookHook,
+ )
+
+ hook = SageMakerUnifiedStudioNotebookHook(
+ domain_id="dzd_example",
+ project_id="proj_example",
+ waiter_delay=10,
+ )
+
+ :param domain_id: The ID of the DataZone domain containing the notebook.
+ :param project_id: The ID of the DataZone project containing the notebook.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ Example: {"param1": "value1", "param2": "value2"}
+ :param compute_configuration: Compute config to use for the notebook
execution.
+ Example: {"instance_type": "ml.m5.large"}
+ :param waiter_delay: Interval in seconds to poll the notebook run status.
+ :param timeout_configuration: Timeout settings for the notebook execution.
+ When provided, the maximum number of poll attempts is derived from
+ ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours.
+ Example: {"run_timeout_in_minutes": 720}
+ :param workflow_name: Name of the workflow (DAG) that triggered this run.
+ """
+
+ def __init__(
+ self,
+ domain_id: str,
+ project_id: str,
+ waiter_delay: int = 10,
+ timeout_configuration: dict | None = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.waiter_delay = waiter_delay
+ self.timeout_configuration = timeout_configuration
+ run_timeout = (timeout_configuration or {}).get(
+ "run_timeout_in_minutes", TWELVE_HOURS_IN_MINUTES
+ ) # Default timeout is 12 hours
+ self.waiter_max_attempts = int(run_timeout * 60 / self.waiter_delay)
+ self._client = None
+
+ @property
+ def client(self):
+ """Lazy-initialized boto3 DataZone client."""
+ if self._client is None:
+ self._client = boto3.client("datazone")
+ self._validate_api_availability()
+ return self._client
+
+ def _validate_api_availability(self):
+ """
+ Verify that the NotebookRun APIs are available in the installed
boto3/botocore version.
+
+ :raises AirflowException: If the required APIs are not available.
+ """
+ required_methods = ("start_notebook_run", "get_notebook_run")
+ for method_name in required_methods:
+ if not hasattr(self._client, method_name):
+ raise AirflowException(
+ f"The '{method_name}' API is not available in the
installed boto3/botocore version. "
+ "Please upgrade boto3/botocore to a version that supports
the DataZone "
+ "NotebookRun APIs."
+ )
+
+ def start_notebook_run(
+ self,
+ notebook_id: str,
+ client_token: str | None = None,
+ notebook_parameters: dict | None = None,
+ compute_configuration: dict | None = None,
+ timeout_configuration: dict | None = None,
+ workflow_name: str | None = None,
+ ) -> dict:
+ """
+ Start an asynchronous notebook run via the DataZone StartNotebookRun
API.
+
+ :param notebook_id: The ID of the notebook to execute.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ :param compute_configuration: Compute config (e.g. instance_type).
+ :param timeout_configuration: Timeout settings
(run_timeout_in_minutes).
+ :param workflow_name: Name of the workflow (DAG) that triggered this
run.
+ :return: The StartNotebookRun API response dict.
+ """
+ params: dict = {
+ "domain_id": self.domain_id,
+ "project_id": self.project_id,
+ "notebook_id": notebook_id,
+ "client_token": client_token or str(uuid.uuid4()),
+ }
+
+ if notebook_parameters:
+ params["notebook_parameters"] = notebook_parameters
+ if compute_configuration:
+ params["compute_configuration"] = compute_configuration
+ if timeout_configuration:
+ params["timeout_configuration"] = timeout_configuration
+ if workflow_name:
+ params["trigger_source"] = {"type": "workflow", "workflow_name":
workflow_name}
+
+ log_message = f"Starting notebook run for notebook {notebook_id} in
domain {self.domain_id}"
+ self.log.info(log_message)
+ return self.client.start_notebook_run(**params)
+
+ def get_notebook_run(self, notebook_run_id: str) -> dict:
+ """
+ Get the status of a notebook run via the DataZone GetNotebookRun API.
+
+ :param notebook_run_id: The ID of the notebook run.
+ :return: The GetNotebookRun API response dict.
+ """
+ return self.client.get_notebook_run(
+ domain_id=self.domain_id,
+ notebook_run_id=notebook_run_id,
+ )
+
+ def wait_for_notebook_run(self, notebook_run_id: str) -> dict:
+ """
+ Poll GetNotebookRun until the run reaches a terminal state.
+
+ :param notebook_run_id: The ID of the notebook run to monitor.
+ :return: A dict with Status and NotebookRunId on success.
+ :raises AirflowException: If the run fails or times out.
+ """
+ for _attempt in range(1, self.waiter_max_attempts + 1):
Review Comment:
Nit: Why not use a zero-based index and avoid the math? Or, even better, just
use `range` with the number itself:
```suggestion
for _attempt in range(self.waiter_max_attempts):
```
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook operator.
+
+This operator supports asynchronous notebook execution in SageMaker Unified
+Studio.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
+ SageMakerUnifiedStudioLink,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.common.compat.sdk import AirflowException,
BaseOperator, conf
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookOperator(BaseOperator):
+ """
+ Execute a notebook asynchronously in SageMaker Unified Studio.
+
+ This operator calls the DataZone StartNotebookRun API to kick off
+ headless notebook execution, and, when not configured otherwise, polls
+ the GetNotebookRun API until the run reaches a terminal state.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.operators.sagemaker_unified_studio_notebook import
(
+ SageMakerUnifiedStudioNotebookOperator,
+ )
+
+ notebook_operator = SageMakerUnifiedStudioNotebookOperator(
+ task_id="run_notebook",
+ notebook_id="nb-1234567890",
+ domain_id="dzd_example",
+ project_id="proj_example",
+ notebook_parameters={"param1": "value1"},
+ compute_configuration={"instance_type": "ml.m5.large"},
+ timeout_configuration={"run_timeout_in_minutes": 1440},
+ )
+
+ :param task_id: A unique, meaningful id for the task.
+ :param notebook_id: The ID of the notebook to execute.
+ :param domain_id: The ID of the SageMaker Unified Studio domain containing
the notebook.
+ :param project_id: The ID of the SageMaker Unified Studio project
containing the notebook.
+ :param client_token: Optional idempotency token. Auto-generated if not
provided.
+ :param notebook_parameters: Optional dict of parameters to pass to the
notebook.
+ :param compute_configuration: Optional compute config.
+ Example: {"instance_type": "ml.m5.large"}
+ :param timeout_configuration: Optional timeout settings.
+ Example: {"run_timeout_in_minutes": 1440}
+ :param wait_for_completion: If True, wait for the notebook run to finish
before
+ completing the task. If False, the operator returns immediately after
starting
+ the run. (default: True)
+ :param waiter_delay: Interval in seconds to poll the notebook run status
(default: 10).
+ :param deferrable: If True, the operator will defer polling to the trigger,
+ freeing up the worker slot while waiting. (default: False)
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerUnifiedStudioNotebookOperator`
+ """
+
+ operator_extra_links = (SageMakerUnifiedStudioLink(),)
+
+ def __init__(
+ self,
+ task_id: str,
+ notebook_id: str,
+ domain_id: str,
+ project_id: str,
+ client_token: str | None = None,
+ notebook_parameters: dict | None = None,
+ compute_configuration: dict | None = None,
+ timeout_configuration: dict | None = None,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 10,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ **kwargs,
+ ):
+ super().__init__(task_id=task_id, **kwargs)
+ self.notebook_id = notebook_id
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.client_token = client_token
+ self.notebook_parameters = notebook_parameters
+ self.compute_configuration = compute_configuration
+ self.timeout_configuration = timeout_configuration
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.deferrable = deferrable
+
+ @cached_property
+ def hook(self) -> SageMakerUnifiedStudioNotebookHook:
+ return SageMakerUnifiedStudioNotebookHook(
Review Comment:
When you update to use the correct base class you will not need to define
this. However, since your hook takes params, you'll need to implement
`_hook_parameters()` like this:
https://github.com/apache/airflow/blob/42cb911345d157dac27aa89d3524a696ece3b69b/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py#L60-L62
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook operator.
+
+This operator supports asynchronous notebook execution in SageMaker Unified
+Studio.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
+ SageMakerUnifiedStudioLink,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.common.compat.sdk import AirflowException,
BaseOperator, conf
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookOperator(BaseOperator):
+ """
+ Execute a notebook asynchronously in SageMaker Unified Studio.
+
+ This operator calls the DataZone StartNotebookRun API to kick off
+ headless notebook execution, and, when not configured otherwise, polls
+ the GetNotebookRun API until the run reaches a terminal state.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.operators.sagemaker_unified_studio_notebook import
(
+ SageMakerUnifiedStudioNotebookOperator,
+ )
+
+ notebook_operator = SageMakerUnifiedStudioNotebookOperator(
+ task_id="run_notebook",
+ notebook_id="nb-1234567890",
+ domain_id="dzd_example",
+ project_id="proj_example",
+ notebook_parameters={"param1": "value1"},
+ compute_configuration={"instance_type": "ml.m5.large"},
+ timeout_configuration={"run_timeout_in_minutes": 1440},
+ )
+
+ :param task_id: A unique, meaningful id for the task.
+ :param notebook_id: The ID of the notebook to execute.
+ :param domain_id: The ID of the SageMaker Unified Studio domain containing
the notebook.
+ :param project_id: The ID of the SageMaker Unified Studio project
containing the notebook.
+ :param client_token: Optional idempotency token. Auto-generated if not
provided.
+ :param notebook_parameters: Optional dict of parameters to pass to the
notebook.
+ :param compute_configuration: Optional compute config.
+ Example: {"instance_type": "ml.m5.large"}
+ :param timeout_configuration: Optional timeout settings.
+ Example: {"run_timeout_in_minutes": 1440}
+ :param wait_for_completion: If True, wait for the notebook run to finish
before
+ completing the task. If False, the operator returns immediately after
starting
+ the run. (default: True)
+ :param waiter_delay: Interval in seconds to poll the notebook run status
(default: 10).
+ :param deferrable: If True, the operator will defer polling to the trigger,
+ freeing up the worker slot while waiting. (default: False)
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerUnifiedStudioNotebookOperator`
+ """
+
+ operator_extra_links = (SageMakerUnifiedStudioLink(),)
+
+ def __init__(
+ self,
+ task_id: str,
+ notebook_id: str,
+ domain_id: str,
+ project_id: str,
+ client_token: str | None = None,
+ notebook_parameters: dict | None = None,
+ compute_configuration: dict | None = None,
+ timeout_configuration: dict | None = None,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 10,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ **kwargs,
+ ):
+ super().__init__(task_id=task_id, **kwargs)
+ self.notebook_id = notebook_id
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.client_token = client_token
+ self.notebook_parameters = notebook_parameters
+ self.compute_configuration = compute_configuration
+ self.timeout_configuration = timeout_configuration
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.deferrable = deferrable
+
+ @cached_property
+ def hook(self) -> SageMakerUnifiedStudioNotebookHook:
+ return SageMakerUnifiedStudioNotebookHook(
+ domain_id=self.domain_id,
+ project_id=self.project_id,
+ waiter_delay=self.waiter_delay,
+ timeout_configuration=self.timeout_configuration,
+ )
+
+ def execute(self, context: Context):
+ if not self.notebook_id:
+ raise AirflowException("notebook_id is required")
+ if not self.domain_id:
+ raise AirflowException("domain_id is required")
+ if not self.project_id:
+ raise AirflowException("project_id is required")
+
+ workflow_name = context["dag"].dag_id # Workflow name is the same as
the dag_id
+ response = self.hook.start_notebook_run(
+ notebook_id=self.notebook_id,
+ client_token=self.client_token,
+ notebook_parameters=self.notebook_parameters,
+ compute_configuration=self.compute_configuration,
+ timeout_configuration=self.timeout_configuration,
+ workflow_name=workflow_name,
+ )
+ notebook_run_id = response["notebook_run_id"]
+ log_message = f"Started notebook run {notebook_run_id} for notebook
{self.notebook_id}"
Review Comment:
Looks like this is pretty pervasive, so I won't keep calling it out, but
there is no need for the useless variable declaration.
##########
providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio_notebook.py:
##########
Review Comment:
Have you run this dag yet? Does it run successfully to completion?
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook operator.
+
+This operator supports asynchronous notebook execution in SageMaker Unified
+Studio.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
+ SageMakerUnifiedStudioLink,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.common.compat.sdk import AirflowException,
BaseOperator, conf
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookOperator(BaseOperator):
Review Comment:
```suggestion
class
SageMakerUnifiedStudioNotebookOperator(AwsBaseOperator[SageMakerUnifiedStudioNotebookHook]):
```
##########
providers/amazon/docs/operators/sagemakerunifiedstudio.rst:
##########
@@ -19,40 +19,60 @@
Amazon SageMaker Unified Studio
===============================
-`Amazon SageMaker Unified Studio
<https://aws.amazon.com/sagemaker/unified-studio/>`__ is a unified development
experience that
+`Amazon SageMaker Unified Studio
<https://aws.amazon.com/sagemaker/unified-studio/>`__ (SMUS) is a unified
development experience that
brings together AWS data, analytics, artificial intelligence (AI), and machine
learning (ML) services.
It provides a place to build, deploy, execute, and monitor end-to-end
workflows from a single interface.
This helps drive collaboration across teams and facilitate agile development.
-Airflow provides operators to orchestrate Notebooks, Querybooks, and Visual
ETL jobs within SageMaker Unified Studio Workflows.
+Airflow provides different operators for running artifacts in SageMaker
Unified Studio. Be sure to read the descriptions
+below to understand which operator is best suited for your use case.
Review Comment:
This is a bit informal — maybe drop the first part? I honestly don't think
it's even necessary, but I don't feel too strongly about it.
```suggestion
Airflow provides different operators for running artifacts in SageMaker
Unified Studio. Read the descriptions
below to understand which operator is best suited for your use case.
```
##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook Run
hook."""
+
+from __future__ import annotations
+
+import time
+import uuid
+
+import boto3
+
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+TWELVE_HOURS_IN_MINUTES = 12 * 60
+
+
+class SageMakerUnifiedStudioNotebookHook(BaseHook):
Review Comment:
Any reason you did not use the `AwsBaseHook`?
Example usage:
https://github.com/apache/airflow/blob/42cb911345d157dac27aa89d3524a696ece3b69b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sqs.py#L26-L41
##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook Run
hook."""
+
+from __future__ import annotations
+
+import time
+import uuid
+
+import boto3
+
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+TWELVE_HOURS_IN_MINUTES = 12 * 60
+
+
+class SageMakerUnifiedStudioNotebookHook(BaseHook):
+ """
+ Interact with Sagemaker Unified Studio Workflows for asynchronous notebook
execution.
+
+ This hook provides a wrapper around the DataZone StartNotebookRun /
GetNotebookRun APIs.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import (
+ SageMakerUnifiedStudioNotebookHook,
+ )
+
+ hook = SageMakerUnifiedStudioNotebookHook(
+ domain_id="dzd_example",
+ project_id="proj_example",
+ waiter_delay=10,
+ )
+
+ :param domain_id: The ID of the DataZone domain containing the notebook.
+ :param project_id: The ID of the DataZone project containing the notebook.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ Example: {"param1": "value1", "param2": "value2"}
+ :param compute_configuration: Compute config to use for the notebook
execution.
+ Example: {"instance_type": "ml.m5.large"}
+ :param waiter_delay: Interval in seconds to poll the notebook run status.
+ :param timeout_configuration: Timeout settings for the notebook execution.
+ When provided, the maximum number of poll attempts is derived from
+ ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours.
+ Example: {"run_timeout_in_minutes": 720}
+ :param workflow_name: Name of the workflow (DAG) that triggered this run.
+ """
+
+ def __init__(
+ self,
+ domain_id: str,
+ project_id: str,
+ waiter_delay: int = 10,
+ timeout_configuration: dict | None = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.waiter_delay = waiter_delay
+ self.timeout_configuration = timeout_configuration
+ run_timeout = (timeout_configuration or {}).get(
+ "run_timeout_in_minutes", TWELVE_HOURS_IN_MINUTES
+ ) # Default timeout is 12 hours
+ self.waiter_max_attempts = int(run_timeout * 60 / self.waiter_delay)
+ self._client = None
+
+ @property
+ def client(self):
+ """Lazy-initialized boto3 DataZone client."""
+ if self._client is None:
+ self._client = boto3.client("datazone")
+ self._validate_api_availability()
+ return self._client
+
+ def _validate_api_availability(self):
+ """
+ Verify that the NotebookRun APIs are available in the installed
boto3/botocore version.
+
+ :raises AirflowException: If the required APIs are not available.
+ """
+ required_methods = ("start_notebook_run", "get_notebook_run")
+ for method_name in required_methods:
+ if not hasattr(self._client, method_name):
+ raise AirflowException(
+ f"The '{method_name}' API is not available in the
installed boto3/botocore version. "
+ "Please upgrade boto3/botocore to a version that supports
the DataZone "
+ "NotebookRun APIs."
+ )
+
+ def start_notebook_run(
+ self,
+ notebook_id: str,
+ client_token: str | None = None,
+ notebook_parameters: dict | None = None,
+ compute_configuration: dict | None = None,
+ timeout_configuration: dict | None = None,
+ workflow_name: str | None = None,
+ ) -> dict:
+ """
+ Start an asynchronous notebook run via the DataZone StartNotebookRun
API.
+
+ :param notebook_id: The ID of the notebook to execute.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ :param compute_configuration: Compute config (e.g. instance_type).
+ :param timeout_configuration: Timeout settings
(run_timeout_in_minutes).
+ :param workflow_name: Name of the workflow (DAG) that triggered this
run.
+ :return: The StartNotebookRun API response dict.
+ """
+ params: dict = {
+ "domain_id": self.domain_id,
+ "project_id": self.project_id,
+ "notebook_id": notebook_id,
+ "client_token": client_token or str(uuid.uuid4()),
+ }
+
+ if notebook_parameters:
+ params["notebook_parameters"] = notebook_parameters
+ if compute_configuration:
+ params["compute_configuration"] = compute_configuration
+ if timeout_configuration:
+ params["timeout_configuration"] = timeout_configuration
+ if workflow_name:
+ params["trigger_source"] = {"type": "workflow", "workflow_name":
workflow_name}
+
+ log_message = f"Starting notebook run for notebook {notebook_id} in
domain {self.domain_id}"
+ self.log.info(log_message)
Review Comment:
```suggestion
self.log.info(f"Starting notebook run for notebook {notebook_id} in
domain {self.domain_id}")
```
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook operator.
+
+This operator supports asynchronous notebook execution in SageMaker Unified
+Studio.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
+ SageMakerUnifiedStudioLink,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.common.compat.sdk import AirflowException,
BaseOperator, conf
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookOperator(BaseOperator):
+ """
+ Execute a notebook asynchronously in SageMaker Unified Studio.
Review Comment:
```suggestion
Execute a notebook in SageMaker Unified Studio.
```
I'd drop the mention of async; it's implied, and it invites confusion with the
deferrable/non-deferrable execution paths (it makes it seem like this only
supports deferrable).
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook operator.
+
+This operator supports asynchronous notebook execution in SageMaker Unified
+Studio.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
+ SageMakerUnifiedStudioLink,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.common.compat.sdk import AirflowException,
BaseOperator, conf
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookOperator(BaseOperator):
+ """
+ Execute a notebook asynchronously in SageMaker Unified Studio.
+
+ This operator calls the DataZone StartNotebookRun API to kick off
+ headless notebook execution, and, when not configured otherwise, polls
+ the GetNotebookRun API until the run reaches a terminal state.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.operators.sagemaker_unified_studio_notebook import
(
+ SageMakerUnifiedStudioNotebookOperator,
+ )
+
+ notebook_operator = SageMakerUnifiedStudioNotebookOperator(
+ task_id="run_notebook",
+ notebook_id="nb-1234567890",
+ domain_id="dzd_example",
+ project_id="proj_example",
+ notebook_parameters={"param1": "value1"},
+ compute_configuration={"instance_type": "ml.m5.large"},
+ timeout_configuration={"run_timeout_in_minutes": 1440},
+ )
+
+ :param task_id: A unique, meaningful id for the task.
Review Comment:
There is no need to pull out args from the base init; just pass all kwargs
through to super. You can add `*` to the beginning of the init signature
below (after `self`) to ensure all params are keyword-only.
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook operator.
+
+This operator supports asynchronous notebook execution in SageMaker Unified
+Studio.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
+ SageMakerUnifiedStudioLink,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.common.compat.sdk import AirflowException,
BaseOperator, conf
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookOperator(BaseOperator):
+ """
+ Execute a notebook asynchronously in SageMaker Unified Studio.
+
+ This operator calls the DataZone StartNotebookRun API to kick off
+ headless notebook execution, and, when not configured otherwise, polls
Review Comment:
```suggestion
headless notebook execution. When not configured otherwise, polls
```
##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook Run
hook."""
+
+from __future__ import annotations
+
+import time
+import uuid
+
+import boto3
+
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+TWELVE_HOURS_IN_MINUTES = 12 * 60
+
+
+class SageMakerUnifiedStudioNotebookHook(BaseHook):
+ """
+ Interact with Sagemaker Unified Studio Workflows for asynchronous notebook
execution.
+
+ This hook provides a wrapper around the DataZone StartNotebookRun /
GetNotebookRun APIs.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import (
+ SageMakerUnifiedStudioNotebookHook,
+ )
+
+ hook = SageMakerUnifiedStudioNotebookHook(
+ domain_id="dzd_example",
+ project_id="proj_example",
+ waiter_delay=10,
+ )
+
+ :param domain_id: The ID of the DataZone domain containing the notebook.
+ :param project_id: The ID of the DataZone project containing the notebook.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ Example: {"param1": "value1", "param2": "value2"}
+ :param compute_configuration: Compute config to use for the notebook
execution.
+ Example: {"instance_type": "ml.m5.large"}
+ :param waiter_delay: Interval in seconds to poll the notebook run status.
+ :param timeout_configuration: Timeout settings for the notebook execution.
+ When provided, the maximum number of poll attempts is derived from
+ ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours.
+ Example: {"run_timeout_in_minutes": 720}
+ :param workflow_name: Name of the workflow (DAG) that triggered this run.
+ """
+
+ def __init__(
+ self,
+ domain_id: str,
+ project_id: str,
+ waiter_delay: int = 10,
+ timeout_configuration: dict | None = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.waiter_delay = waiter_delay
+ self.timeout_configuration = timeout_configuration
+ run_timeout = (timeout_configuration or {}).get(
+ "run_timeout_in_minutes", TWELVE_HOURS_IN_MINUTES
+ ) # Default timeout is 12 hours
+ self.waiter_max_attempts = int(run_timeout * 60 / self.waiter_delay)
+ self._client = None
+
+ @property
+ def client(self):
+ """Lazy-initialized boto3 DataZone client."""
+ if self._client is None:
+ self._client = boto3.client("datazone")
+ self._validate_api_availability()
+ return self._client
+
+ def _validate_api_availability(self):
+ """
+ Verify that the NotebookRun APIs are available in the installed
boto3/botocore version.
+
+ :raises AirflowException: If the required APIs are not available.
+ """
+ required_methods = ("start_notebook_run", "get_notebook_run")
+ for method_name in required_methods:
+ if not hasattr(self._client, method_name):
+ raise AirflowException(
+ f"The '{method_name}' API is not available in the
installed boto3/botocore version. "
+ "Please upgrade boto3/botocore to a version that supports
the DataZone "
+ "NotebookRun APIs."
+ )
+
+ def start_notebook_run(
+ self,
+ notebook_id: str,
+ client_token: str | None = None,
+ notebook_parameters: dict | None = None,
+ compute_configuration: dict | None = None,
+ timeout_configuration: dict | None = None,
+ workflow_name: str | None = None,
+ ) -> dict:
+ """
+ Start an asynchronous notebook run via the DataZone StartNotebookRun
API.
+
+ :param notebook_id: The ID of the notebook to execute.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ :param compute_configuration: Compute config (e.g. instance_type).
+ :param timeout_configuration: Timeout settings
(run_timeout_in_minutes).
+ :param workflow_name: Name of the workflow (DAG) that triggered this
run.
+ :return: The StartNotebookRun API response dict.
+ """
+ params: dict = {
+ "domain_id": self.domain_id,
+ "project_id": self.project_id,
+ "notebook_id": notebook_id,
+ "client_token": client_token or str(uuid.uuid4()),
+ }
+
+ if notebook_parameters:
+ params["notebook_parameters"] = notebook_parameters
+ if compute_configuration:
+ params["compute_configuration"] = compute_configuration
+ if timeout_configuration:
+ params["timeout_configuration"] = timeout_configuration
+ if workflow_name:
+ params["trigger_source"] = {"type": "workflow", "workflow_name":
workflow_name}
+
+ log_message = f"Starting notebook run for notebook {notebook_id} in
domain {self.domain_id}"
+ self.log.info(log_message)
+ return self.client.start_notebook_run(**params)
+
+ def get_notebook_run(self, notebook_run_id: str) -> dict:
+ """
+ Get the status of a notebook run via the DataZone GetNotebookRun API.
+
+ :param notebook_run_id: The ID of the notebook run.
+ :return: The GetNotebookRun API response dict.
+ """
+ return self.client.get_notebook_run(
+ domain_id=self.domain_id,
+ notebook_run_id=notebook_run_id,
+ )
+
+ def wait_for_notebook_run(self, notebook_run_id: str) -> dict:
+ """
+ Poll GetNotebookRun until the run reaches a terminal state.
+
+ :param notebook_run_id: The ID of the notebook run to monitor.
+ :return: A dict with Status and NotebookRunId on success.
+ :raises AirflowException: If the run fails or times out.
+ """
+ for _attempt in range(1, self.waiter_max_attempts + 1):
+ time.sleep(self.waiter_delay)
+ response = self.get_notebook_run(notebook_run_id)
+ status = response.get("status")
+ error_message = response.get("errorMessage", "")
+
+ ret = self._handle_state(notebook_run_id, status, error_message)
+ if ret:
+ return ret
+
+ return self._handle_state(notebook_run_id, "FAILED", "Execution timed
out")
+
+ def _handle_state(self, notebook_run_id: str, status: str, error_message:
str) -> dict | None:
+ """
+ Evaluate the current notebook run state and return or raise
accordingly.
+
+ :param notebook_run_id: The ID of the notebook run.
+ :param status: The current status string.
+ :param error_message: Error message from the API response, if any.
+ :return: A dict with Status and NotebookRunId on success, None if
still in progress.
+ :raises AirflowException: If the run has failed.
+ """
+ in_progress_states = {"QUEUED", "STARTING", "RUNNING", "STOPPING"}
+ finished_states = {"SUCCEEDED", "STOPPED"}
+ failure_states = {"FAILED"}
+
+ if status in in_progress_states:
+ log_message = (
+ f"Notebook run {notebook_run_id} is still in progress with
state: {status}, "
+ f"will check for a terminal status again in
{self.waiter_delay}s"
+ )
+ self.log.info(log_message)
+ return None
+
+ execution_message = f"Exiting notebook run {notebook_run_id}. State:
{status}"
+
+ if status in finished_states:
+ self.log.info(execution_message)
+ return {"Status": status, "NotebookRunId": notebook_run_id}
+
+ if status in failure_states:
+ log_message = f"Notebook run {notebook_run_id} failed with error:
{error_message}"
Review Comment:
Here and throughout, let's avoid useless declarations of intermediate
variables; it cleans up the code.
##########
providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook Run sensor.
+
+This sensor polls the DataZone GetNotebookRun API until the notebook run
+reaches a terminal state.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.common.compat.sdk import AirflowException,
BaseSensorOperator
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookSensor(BaseSensorOperator):
Review Comment:
Same as above, use the AwsBaseSensor with hook mixin:
```suggestion
class
SageMakerUnifiedStudioNotebookSensor(AwsBaseSensor[SageMakerUnifiedStudioNotebookHook]):
```
##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook Run
hook."""
+
+from __future__ import annotations
+
+import time
+import uuid
+
+import boto3
+
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+TWELVE_HOURS_IN_MINUTES = 12 * 60
+
+
+class SageMakerUnifiedStudioNotebookHook(BaseHook):
+ """
+ Interact with Sagemaker Unified Studio Workflows for asynchronous notebook
execution.
+
+ This hook provides a wrapper around the DataZone StartNotebookRun /
GetNotebookRun APIs.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import (
+ SageMakerUnifiedStudioNotebookHook,
+ )
+
+ hook = SageMakerUnifiedStudioNotebookHook(
+ domain_id="dzd_example",
+ project_id="proj_example",
+ waiter_delay=10,
+ )
+
+ :param domain_id: The ID of the DataZone domain containing the notebook.
+ :param project_id: The ID of the DataZone project containing the notebook.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ Example: {"param1": "value1", "param2": "value2"}
+ :param compute_configuration: Compute config to use for the notebook
execution.
+ Example: {"instance_type": "ml.m5.large"}
+ :param waiter_delay: Interval in seconds to poll the notebook run status.
+ :param timeout_configuration: Timeout settings for the notebook execution.
+ When provided, the maximum number of poll attempts is derived from
+ ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours.
+ Example: {"run_timeout_in_minutes": 720}
+ :param workflow_name: Name of the workflow (DAG) that triggered this run.
+ """
+
+ def __init__(
+ self,
+ domain_id: str,
+ project_id: str,
+ waiter_delay: int = 10,
+ timeout_configuration: dict | None = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.waiter_delay = waiter_delay
+ self.timeout_configuration = timeout_configuration
+ run_timeout = (timeout_configuration or {}).get(
+ "run_timeout_in_minutes", TWELVE_HOURS_IN_MINUTES
+ ) # Default timeout is 12 hours
+ self.waiter_max_attempts = int(run_timeout * 60 / self.waiter_delay)
+ self._client = None
+
+ @property
+ def client(self):
+ """Lazy-initialized boto3 DataZone client."""
+ if self._client is None:
+ self._client = boto3.client("datazone")
+ self._validate_api_availability()
+ return self._client
+
+ def _validate_api_availability(self):
+ """
+ Verify that the NotebookRun APIs are available in the installed
boto3/botocore version.
+
+ :raises AirflowException: If the required APIs are not available.
+ """
+ required_methods = ("start_notebook_run", "get_notebook_run")
+ for method_name in required_methods:
+ if not hasattr(self._client, method_name):
+ raise AirflowException(
+ f"The '{method_name}' API is not available in the
installed boto3/botocore version. "
+ "Please upgrade boto3/botocore to a version that supports
the DataZone "
+ "NotebookRun APIs."
+ )
+
+ def start_notebook_run(
+ self,
+ notebook_id: str,
+ client_token: str | None = None,
+ notebook_parameters: dict | None = None,
+ compute_configuration: dict | None = None,
+ timeout_configuration: dict | None = None,
+ workflow_name: str | None = None,
+ ) -> dict:
+ """
+ Start an asynchronous notebook run via the DataZone StartNotebookRun
API.
+
+ :param notebook_id: The ID of the notebook to execute.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ :param compute_configuration: Compute config (e.g. instance_type).
+ :param timeout_configuration: Timeout settings
(run_timeout_in_minutes).
+ :param workflow_name: Name of the workflow (DAG) that triggered this
run.
+ :return: The StartNotebookRun API response dict.
+ """
+ params: dict = {
+ "domain_id": self.domain_id,
+ "project_id": self.project_id,
+ "notebook_id": notebook_id,
+ "client_token": client_token or str(uuid.uuid4()),
+ }
+
+ if notebook_parameters:
+ params["notebook_parameters"] = notebook_parameters
+ if compute_configuration:
+ params["compute_configuration"] = compute_configuration
+ if timeout_configuration:
+ params["timeout_configuration"] = timeout_configuration
+ if workflow_name:
+ params["trigger_source"] = {"type": "workflow", "workflow_name":
workflow_name}
+
+ log_message = f"Starting notebook run for notebook {notebook_id} in
domain {self.domain_id}"
+ self.log.info(log_message)
+ return self.client.start_notebook_run(**params)
+
+ def get_notebook_run(self, notebook_run_id: str) -> dict:
+ """
+ Get the status of a notebook run via the DataZone GetNotebookRun API.
+
+ :param notebook_run_id: The ID of the notebook run.
+ :return: The GetNotebookRun API response dict.
+ """
+ return self.client.get_notebook_run(
+ domain_id=self.domain_id,
+ notebook_run_id=notebook_run_id,
+ )
+
+ def wait_for_notebook_run(self, notebook_run_id: str) -> dict:
+ """
+ Poll GetNotebookRun until the run reaches a terminal state.
+
+ :param notebook_run_id: The ID of the notebook run to monitor.
+ :return: A dict with Status and NotebookRunId on success.
+ :raises AirflowException: If the run fails or times out.
+ """
+ for _attempt in range(1, self.waiter_max_attempts + 1):
+ time.sleep(self.waiter_delay)
+ response = self.get_notebook_run(notebook_run_id)
+ status = response.get("status")
+ error_message = response.get("errorMessage", "")
+
+ ret = self._handle_state(notebook_run_id, status, error_message)
+ if ret:
+ return ret
+
+ return self._handle_state(notebook_run_id, "FAILED", "Execution timed
out")
+
+ def _handle_state(self, notebook_run_id: str, status: str, error_message:
str) -> dict | None:
Review Comment:
We're using `state` everywhere else, so the `status` comparisons here look a bit strange.
```suggestion
def _handle_state(self, notebook_run_id: str, state: str, error_message:
str) -> dict | None:
```
##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Trigger for monitoring SageMaker Unified Studio Notebook runs
asynchronously."""
+
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from functools import partial
+from typing import Any
+
+import boto3
+
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+IN_PROGRESS_STATES = {"QUEUED", "STARTING", "RUNNING", "STOPPING"}
+FINISHED_STATES = {"SUCCEEDED", "STOPPED"}
+FAILURE_STATES = {"FAILED"}
+
+TWELVE_HOURS_IN_MINUTES = 12 * 60
+
+
+class SageMakerUnifiedStudioNotebookTrigger(BaseTrigger):
Review Comment:
@ferruzzi Can you review this class? In this case, do you think it's possible
for them to implement a custom Waiter and then use the `AwsBaseWaiterTrigger`?
##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook Run
hook."""
+
+from __future__ import annotations
+
+import time
+import uuid
+
+import boto3
+
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+TWELVE_HOURS_IN_MINUTES = 12 * 60
+
+
+class SageMakerUnifiedStudioNotebookHook(BaseHook):
+ """
+ Interact with Sagemaker Unified Studio Workflows for asynchronous notebook
execution.
+
+ This hook provides a wrapper around the DataZone StartNotebookRun /
GetNotebookRun APIs.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import (
+ SageMakerUnifiedStudioNotebookHook,
+ )
+
+ hook = SageMakerUnifiedStudioNotebookHook(
+ domain_id="dzd_example",
+ project_id="proj_example",
+ waiter_delay=10,
+ )
+
+ :param domain_id: The ID of the DataZone domain containing the notebook.
+ :param project_id: The ID of the DataZone project containing the notebook.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ Example: {"param1": "value1", "param2": "value2"}
+ :param compute_configuration: Compute config to use for the notebook
execution.
+ Example: {"instance_type": "ml.m5.large"}
+ :param waiter_delay: Interval in seconds to poll the notebook run status.
+ :param timeout_configuration: Timeout settings for the notebook execution.
+ When provided, the maximum number of poll attempts is derived from
+ ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours.
+ Example: {"run_timeout_in_minutes": 720}
+ :param workflow_name: Name of the workflow (DAG) that triggered this run.
+ """
+
+ def __init__(
+ self,
+ domain_id: str,
+ project_id: str,
+ waiter_delay: int = 10,
+ timeout_configuration: dict | None = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.waiter_delay = waiter_delay
+ self.timeout_configuration = timeout_configuration
+ run_timeout = (timeout_configuration or {}).get(
+ "run_timeout_in_minutes", TWELVE_HOURS_IN_MINUTES
+ ) # Default timeout is 12 hours
+ self.waiter_max_attempts = int(run_timeout * 60 / self.waiter_delay)
+ self._client = None
+
+ @property
+ def client(self):
+ """Lazy-initialized boto3 DataZone client."""
+ if self._client is None:
+ self._client = boto3.client("datazone")
+ self._validate_api_availability()
+ return self._client
+
+ def _validate_api_availability(self):
+ """
+ Verify that the NotebookRun APIs are available in the installed
boto3/botocore version.
+
+ :raises AirflowException: If the required APIs are not available.
+ """
+ required_methods = ("start_notebook_run", "get_notebook_run")
+ for method_name in required_methods:
+ if not hasattr(self._client, method_name):
+ raise AirflowException(
+ f"The '{method_name}' API is not available in the
installed boto3/botocore version. "
+ "Please upgrade boto3/botocore to a version that supports
the DataZone "
+ "NotebookRun APIs."
+ )
+
+ def start_notebook_run(
+ self,
+ notebook_id: str,
+ client_token: str | None = None,
+ notebook_parameters: dict | None = None,
+ compute_configuration: dict | None = None,
+ timeout_configuration: dict | None = None,
+ workflow_name: str | None = None,
+ ) -> dict:
+ """
+ Start an asynchronous notebook run via the DataZone StartNotebookRun
API.
+
+ :param notebook_id: The ID of the notebook to execute.
+ :param client_token: Idempotency token. Auto-generated if not provided.
+ :param notebook_parameters: Parameters to pass to the notebook.
+ :param compute_configuration: Compute config (e.g. instance_type).
+ :param timeout_configuration: Timeout settings
(run_timeout_in_minutes).
+ :param workflow_name: Name of the workflow (DAG) that triggered this
run.
+ :return: The StartNotebookRun API response dict.
+ """
+ params: dict = {
+ "domain_id": self.domain_id,
+ "project_id": self.project_id,
+ "notebook_id": notebook_id,
+ "client_token": client_token or str(uuid.uuid4()),
+ }
+
+ if notebook_parameters:
+ params["notebook_parameters"] = notebook_parameters
+ if compute_configuration:
+ params["compute_configuration"] = compute_configuration
+ if timeout_configuration:
+ params["timeout_configuration"] = timeout_configuration
+ if workflow_name:
+ params["trigger_source"] = {"type": "workflow", "workflow_name":
workflow_name}
+
+ log_message = f"Starting notebook run for notebook {notebook_id} in
domain {self.domain_id}"
+ self.log.info(log_message)
+ return self.client.start_notebook_run(**params)
+
+ def get_notebook_run(self, notebook_run_id: str) -> dict:
+ """
+ Get the status of a notebook run via the DataZone GetNotebookRun API.
+
+ :param notebook_run_id: The ID of the notebook run.
+ :return: The GetNotebookRun API response dict.
+ """
+ return self.client.get_notebook_run(
+ domain_id=self.domain_id,
+ notebook_run_id=notebook_run_id,
+ )
+
+ def wait_for_notebook_run(self, notebook_run_id: str) -> dict:
+ """
+ Poll GetNotebookRun until the run reaches a terminal state.
+
+ :param notebook_run_id: The ID of the notebook run to monitor.
+ :return: A dict with Status and NotebookRunId on success.
+ :raises AirflowException: If the run fails or times out.
+ """
+ for _attempt in range(1, self.waiter_max_attempts + 1):
+ time.sleep(self.waiter_delay)
+ response = self.get_notebook_run(notebook_run_id)
+ status = response.get("status")
+ error_message = response.get("errorMessage", "")
+
+ ret = self._handle_state(notebook_run_id, status, error_message)
+ if ret:
+ return ret
+
+ return self._handle_state(notebook_run_id, "FAILED", "Execution timed
out")
+
+ def _handle_state(self, notebook_run_id: str, status: str, error_message:
str) -> dict | None:
+ """
+ Evaluate the current notebook run state and return or raise
accordingly.
+
+ :param notebook_run_id: The ID of the notebook run.
+ :param status: The current status string.
+ :param error_message: Error message from the API response, if any.
+ :return: A dict with Status and NotebookRunId on success, None if
still in progress.
+ :raises AirflowException: If the run has failed.
+ """
+ in_progress_states = {"QUEUED", "STARTING", "RUNNING", "STOPPING"}
+ finished_states = {"SUCCEEDED", "STOPPED"}
+ failure_states = {"FAILED"}
+
+ if status in in_progress_states:
+ log_message = (
+ f"Notebook run {notebook_run_id} is still in progress with
state: {status}, "
+ f"will check for a terminal status again in
{self.waiter_delay}s"
+ )
+ self.log.info(log_message)
+ return None
+
+ execution_message = f"Exiting notebook run {notebook_run_id}. State:
{status}"
+
+ if status in finished_states:
+ self.log.info(execution_message)
+ return {"Status": status, "NotebookRunId": notebook_run_id}
+
+ if status in failure_states:
+ log_message = f"Notebook run {notebook_run_id} failed with error:
{error_message}"
+ self.log.error(log_message)
+ else:
+ log_message = f"Notebook run {notebook_run_id} reached unexpected
state: {status}"
+ self.log.error(log_message)
+
+ if error_message == "":
+ error_message = execution_message
+ raise AirflowException(error_message)
Review Comment:
We're trying to avoid the overuse of `AirflowException`; nothing related to
Airflow itself is failing here. Use whatever exception most closely matches
this failure — perhaps `RuntimeError` (Python's equivalent of `RuntimeException`).
This pattern appears throughout the rest of this PR; please fix those other
cases as well.
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook operator.
+
+This operator supports asynchronous notebook execution in SageMaker Unified
+Studio.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
+ SageMakerUnifiedStudioLink,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.common.compat.sdk import AirflowException,
BaseOperator, conf
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookOperator(BaseOperator):
+ """
+ Execute a notebook asynchronously in SageMaker Unified Studio.
+
+ This operator calls the DataZone StartNotebookRun API to kick off
+ headless notebook execution, and, when not configured otherwise, polls
+ the GetNotebookRun API until the run reaches a terminal state.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.operators.sagemaker_unified_studio_notebook import
(
+ SageMakerUnifiedStudioNotebookOperator,
+ )
+
+ notebook_operator = SageMakerUnifiedStudioNotebookOperator(
+ task_id="run_notebook",
+ notebook_id="nb-1234567890",
+ domain_id="dzd_example",
+ project_id="proj_example",
+ notebook_parameters={"param1": "value1"},
+ compute_configuration={"instance_type": "ml.m5.large"},
+ timeout_configuration={"run_timeout_in_minutes": 1440},
+ )
+
+ :param task_id: A unique, meaningful id for the task.
+ :param notebook_id: The ID of the notebook to execute.
+ :param domain_id: The ID of the SageMaker Unified Studio domain containing
the notebook.
+ :param project_id: The ID of the SageMaker Unified Studio project
containing the notebook.
+ :param client_token: Optional idempotency token. Auto-generated if not
provided.
+ :param notebook_parameters: Optional dict of parameters to pass to the
notebook.
+ :param compute_configuration: Optional compute config.
+ Example: {"instance_type": "ml.m5.large"}
+ :param timeout_configuration: Optional timeout settings.
+ Example: {"run_timeout_in_minutes": 1440}
+ :param wait_for_completion: If True, wait for the notebook run to finish
before
+ completing the task. If False, the operator returns immediately after
starting
+ the run. (default: True)
+ :param waiter_delay: Interval in seconds to poll the notebook run status
(default: 10).
+ :param deferrable: If True, the operator will defer polling to the trigger,
+ freeing up the worker slot while waiting. (default: False)
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerUnifiedStudioNotebookOperator`
+ """
+
+ operator_extra_links = (SageMakerUnifiedStudioLink(),)
+
+ def __init__(
+ self,
+ task_id: str,
+ notebook_id: str,
+ domain_id: str,
+ project_id: str,
+ client_token: str | None = None,
+ notebook_parameters: dict | None = None,
+ compute_configuration: dict | None = None,
+ timeout_configuration: dict | None = None,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 10,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ **kwargs,
+ ):
+ super().__init__(task_id=task_id, **kwargs)
+ self.notebook_id = notebook_id
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.client_token = client_token
+ self.notebook_parameters = notebook_parameters
+ self.compute_configuration = compute_configuration
+ self.timeout_configuration = timeout_configuration
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.deferrable = deferrable
+
+ @cached_property
+ def hook(self) -> SageMakerUnifiedStudioNotebookHook:
+ return SageMakerUnifiedStudioNotebookHook(
+ domain_id=self.domain_id,
+ project_id=self.project_id,
+ waiter_delay=self.waiter_delay,
+ timeout_configuration=self.timeout_configuration,
+ )
+
+ def execute(self, context: Context):
+ if not self.notebook_id:
+ raise AirflowException("notebook_id is required")
+ if not self.domain_id:
+ raise AirflowException("domain_id is required")
+ if not self.project_id:
+ raise AirflowException("project_id is required")
+
+ workflow_name = context["dag"].dag_id # Workflow name is the same as
the dag_id
+ response = self.hook.start_notebook_run(
+ notebook_id=self.notebook_id,
+ client_token=self.client_token,
+ notebook_parameters=self.notebook_parameters,
+ compute_configuration=self.compute_configuration,
+ timeout_configuration=self.timeout_configuration,
+ workflow_name=workflow_name,
+ )
+ notebook_run_id = response["notebook_run_id"]
+ log_message = f"Started notebook run {notebook_run_id} for notebook
{self.notebook_id}"
+ self.log.info(log_message)
+
+ if self.deferrable:
+ self.defer(
+ trigger=SageMakerUnifiedStudioNotebookTrigger(
+ notebook_run_id=notebook_run_id,
+ domain_id=self.domain_id,
+ project_id=self.project_id,
+ waiter_delay=self.waiter_delay,
+ timeout_configuration=self.timeout_configuration,
+ ),
+ method_name="execute_complete",
+ )
+ elif self.wait_for_completion:
+ self.hook.wait_for_notebook_run(notebook_run_id)
+ log_message = f"Notebook run {notebook_run_id} completed for
notebook {self.notebook_id}"
Review Comment:
Doesn't the wait function already log the completion?
##########
providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook Run sensor.
+
+This sensor polls the DataZone GetNotebookRun API until the notebook run
+reaches a terminal state.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.common.compat.sdk import AirflowException,
BaseSensorOperator
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookSensor(BaseSensorOperator):
+ """
+ Polls a SageMakerUnifiedStudio Workflow asynchronous Notebook execution
until it reaches a terminal state.
+
+ 'SUCCEEDED', 'FAILED', 'STOPPED'
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.sensors.sagemaker_unified_studio_notebook import (
+ SageMakerUnifiedStudioNotebookSensor,
+ )
+
+ notebook_sensor = SageMakerUnifiedStudioNotebookSensor(
+ task_id="wait_for_notebook",
+ domain_id="dzd_example",
+ project_id="proj_example",
+ notebook_run_id="nr-1234567890",
+ )
+
+ :param domain_id: The ID of the SageMaker Unified Studio domain containing
the notebook.
+ :param project_id: The ID of the SageMaker Unified Studio project
containing the notebook.
+ :param notebook_run_id: The ID of the notebook run to monitor.
+ This is returned by the ``SageMakerUnifiedStudioNotebookOperator``.
+ """
+
+ def __init__(
+ self,
+ *,
+ domain_id: str,
+ project_id: str,
+ notebook_run_id: str,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.notebook_run_id = notebook_run_id
+ self.success_states = ["SUCCEEDED"]
+ self.in_progress_states = ["QUEUED", "STARTING", "RUNNING", "STOPPING"]
+
+ @cached_property
+ def hook(self) -> SageMakerUnifiedStudioNotebookHook:
+ return SageMakerUnifiedStudioNotebookHook(
+ domain_id=self.domain_id,
+ project_id=self.project_id,
+ )
Review Comment:
Same as above: you don't need to implement this yourself — override
`_hook_parameters()` instead.
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py:
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the Amazon SageMaker Unified Studio Notebook operator.
+
+This operator supports asynchronous notebook execution in SageMaker Unified
+Studio.
+"""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookHook,
+)
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
+ SageMakerUnifiedStudioLink,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio_notebook
import (
+ SageMakerUnifiedStudioNotebookTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.common.compat.sdk import AirflowException,
BaseOperator, conf
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class SageMakerUnifiedStudioNotebookOperator(BaseOperator):
+ """
+ Execute a notebook asynchronously in SageMaker Unified Studio.
+
+ This operator calls the DataZone StartNotebookRun API to kick off
+ headless notebook execution, and, when not configured otherwise, polls
+ the GetNotebookRun API until the run reaches a terminal state.
+
+ Examples:
+ .. code-block:: python
+
+ from
airflow.providers.amazon.aws.operators.sagemaker_unified_studio_notebook import
(
+ SageMakerUnifiedStudioNotebookOperator,
+ )
+
+ notebook_operator = SageMakerUnifiedStudioNotebookOperator(
+ task_id="run_notebook",
+ notebook_id="nb-1234567890",
+ domain_id="dzd_example",
+ project_id="proj_example",
+ notebook_parameters={"param1": "value1"},
+ compute_configuration={"instance_type": "ml.m5.large"},
+ timeout_configuration={"run_timeout_in_minutes": 1440},
+ )
+
+ :param task_id: A unique, meaningful id for the task.
+ :param notebook_id: The ID of the notebook to execute.
+ :param domain_id: The ID of the SageMaker Unified Studio domain containing
the notebook.
+ :param project_id: The ID of the SageMaker Unified Studio project
containing the notebook.
+ :param client_token: Optional idempotency token. Auto-generated if not
provided.
+ :param notebook_parameters: Optional dict of parameters to pass to the
notebook.
+ :param compute_configuration: Optional compute config.
+ Example: {"instance_type": "ml.m5.large"}
+ :param timeout_configuration: Optional timeout settings.
+ Example: {"run_timeout_in_minutes": 1440}
+ :param wait_for_completion: If True, wait for the notebook run to finish
before
+ completing the task. If False, the operator returns immediately after
starting
+ the run. (default: True)
+ :param waiter_delay: Interval in seconds to poll the notebook run status
(default: 10).
+ :param deferrable: If True, the operator will defer polling to the trigger,
+ freeing up the worker slot while waiting. (default: False)
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerUnifiedStudioNotebookOperator`
+ """
+
+ operator_extra_links = (SageMakerUnifiedStudioLink(),)
+
+ def __init__(
+ self,
+ task_id: str,
+ notebook_id: str,
+ domain_id: str,
+ project_id: str,
+ client_token: str | None = None,
+ notebook_parameters: dict | None = None,
+ compute_configuration: dict | None = None,
+ timeout_configuration: dict | None = None,
+ wait_for_completion: bool = True,
+ waiter_delay: int = 10,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ **kwargs,
+ ):
+ super().__init__(task_id=task_id, **kwargs)
+ self.notebook_id = notebook_id
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.client_token = client_token
+ self.notebook_parameters = notebook_parameters
+ self.compute_configuration = compute_configuration
+ self.timeout_configuration = timeout_configuration
+ self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.deferrable = deferrable
+
+ @cached_property
+ def hook(self) -> SageMakerUnifiedStudioNotebookHook:
+ return SageMakerUnifiedStudioNotebookHook(
+ domain_id=self.domain_id,
+ project_id=self.project_id,
+ waiter_delay=self.waiter_delay,
+ timeout_configuration=self.timeout_configuration,
+ )
+
+ def execute(self, context: Context):
+ if not self.notebook_id:
+ raise AirflowException("notebook_id is required")
Review Comment:
1) Same as earlier: don't use `AirflowException` here — Airflow itself is not
failing.
2) Why validate these required (non-optional) parameters here? This seems like
overly defensive code.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]