This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9939b1b3d76 Add AWS SageMaker Unified Studio Workflow Operator (#45726)
9939b1b3d76 is described below
commit 9939b1b3d76081245afca88c351d3f116bce25dc
Author: Arunav Gupta <[email protected]>
AuthorDate: Tue Mar 4 16:01:29 2025 -0500
Add AWS SageMaker Unified Studio Workflow Operator (#45726)
Adds an operator used for executing Jupyter Notebooks, Querybooks, and
Visual ETL jobs within the context of a SageMaker Unified Studio project.
---------
Co-authored-by: Niko Oliveira <[email protected]>
---
docs/spelling_wordlist.txt | 10 +
generated/provider_dependencies.json | 1 +
providers/amazon/README.rst | 1 +
.../docs/operators/sagemakerunifiedstudio.rst | 60 ++++++
providers/amazon/provider.yaml | 20 +-
providers/amazon/pyproject.toml | 1 +
.../amazon/aws/hooks/sagemaker_unified_studio.py | 188 +++++++++++++++++++
.../amazon/aws/links/sagemaker_unified_studio.py | 27 +++
.../aws/operators/sagemaker_unified_studio.py | 155 ++++++++++++++++
.../amazon/aws/sensors/sagemaker_unified_studio.py | 73 ++++++++
.../aws/triggers/sagemaker_unified_studio.py | 66 +++++++
.../amazon/aws/utils/sagemaker_unified_studio.py | 28 +++
.../airflow/providers/amazon/get_provider_info.py | 27 +++
.../amazon/aws/example_sagemaker_unified_studio.py | 166 +++++++++++++++++
.../aws/hooks/test_sagemaker_unified_studio.py | 201 +++++++++++++++++++++
.../aws/links/test_sagemaker_unified_studio.py | 32 ++++
.../unit/amazon/aws/operators/test_notebook.ipynb | 61 +++++++
.../aws/operators/test_sagemaker_unified_studio.py | 176 ++++++++++++++++++
.../aws/sensors/test_sagemaker_unified_studio.py | 105 +++++++++++
.../aws/utils/test_sagemaker_unified_studio.py | 50 +++++
.../providers/3rd-party-licenses/LICENSES-ui.txt | 89 +++++++++
tests/always/test_project_structure.py | 3 +
22 files changed, 1539 insertions(+), 1 deletion(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 2f7a408bb1a..7f94448a8d4 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -319,6 +319,7 @@ connectTimeoutMS
connexion
containerConfiguration
containerd
+ContainerEntrypoint
ContainerGroup
containerinstance
ContainerPort
@@ -693,6 +694,7 @@ Gantt
gantt
gapic
gapped
+gb
gbq
gcc
gcloud
@@ -826,6 +828,7 @@ ImageAnnotatorClient
imageORfile
imagePullPolicy
imagePullSecrets
+ImageUri
imageVersion
Imap
imap
@@ -859,6 +862,7 @@ InstanceFlexibilityPolicy
InstanceGroupConfig
InstanceSelection
instanceTemplates
+InstanceType
instantiation
integrations
interdependencies
@@ -876,6 +880,7 @@ IPv4
ipv4
IPv6
ipv6
+ipynb
iPython
irreproducible
IRSA
@@ -1050,6 +1055,7 @@ masterType
Matomo
matomo
Maxime
+MaxRuntimeInSeconds
mb
md
mediawiki
@@ -1373,6 +1379,8 @@ Qubole
qubole
QuboleCheckHook
Quboles
+querybook
+Querybooks
queryParameters
querystring
queueing
@@ -1887,8 +1895,10 @@ views
virtualenv
virtualenvs
vm
+VolumeKmsKeyId
VolumeMount
volumeMounts
+VolumeSizeInGB
vpc
WaiterModel
wape
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index e1fc18617cb..0b887e4481e 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -38,6 +38,7 @@
"jsonpath_ng>=1.5.3",
"python3-saml>=1.16.0",
"redshift_connector>=2.0.918",
+ "sagemaker-studio>=1.0.9",
"watchtower>=3.0.0,!=3.3.0,<4"
],
"devel-deps": [
diff --git a/providers/amazon/README.rst b/providers/amazon/README.rst
index 0fec71dbae1..c61e4133b02 100644
--- a/providers/amazon/README.rst
+++ b/providers/amazon/README.rst
@@ -67,6 +67,7 @@ PIP package Version required
``PyAthena`` ``>=3.0.10``
``jmespath`` ``>=0.7.0``
``python3-saml`` ``>=1.16.0``
+``sagemaker-studio`` ``>=1.0.9``
========================================== ======================
Cross provider package dependencies
diff --git a/providers/amazon/docs/operators/sagemakerunifiedstudio.rst
b/providers/amazon/docs/operators/sagemakerunifiedstudio.rst
new file mode 100644
index 00000000000..33833cf395d
--- /dev/null
+++ b/providers/amazon/docs/operators/sagemakerunifiedstudio.rst
@@ -0,0 +1,60 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+===============================
+Amazon SageMaker Unified Studio
+===============================
+
+`Amazon SageMaker Unified Studio
<https://aws.amazon.com/sagemaker/unified-studio/>`__ is a unified development
experience that
+brings together AWS data, analytics, artificial intelligence (AI), and machine
learning (ML) services.
+It provides a place to build, deploy, execute, and monitor end-to-end
workflows from a single interface.
+This helps drive collaboration across teams and facilitate agile development.
+
+Airflow provides operators to orchestrate Notebooks, Querybooks, and Visual
ETL jobs within SageMaker Unified Studio Workflows.
+
+Prerequisite Tasks
+------------------
+
+To use these operators, you must do a few things:
+
+ * Create a SageMaker Unified Studio domain and project, following the
instructions in `AWS documentation
<https://docs.aws.amazon.com/sagemaker-unified-studio/latest/userguide/getting-started.html>`__.
+ * Within your project:
+ * Navigate to the "Compute > Workflow environments" tab, and click
"Create" to create a new MWAA environment.
+ * Create a Notebook, Querybook, or Visual ETL job and save it to your
project.
+
+Operators
+---------
+
+.. _howto/operator:SageMakerNotebookOperator:
+
+Create an Amazon SageMaker Unified Studio Workflow
+==================================================
+
+To create an Amazon SageMaker Unified Studio workflow to orchestrate your
notebook, querybook, and visual ETL runs you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookOperator`.
+
+.. exampleinclude::
/../../providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_unified_studio_notebook]
+ :end-before: [END howto_operator_sagemaker_unified_studio_notebook]
+
+
+Reference
+---------
+
+* `What is Amazon SageMaker Unified Studio
<https://docs.aws.amazon.com/sagemaker-unified-studio/latest/userguide/what-is-sagemaker-unified-studio.html>`__
diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml
index 107bd35bb09..1aa2947fea5 100644
--- a/providers/amazon/provider.yaml
+++ b/providers/amazon/provider.yaml
@@ -234,6 +234,12 @@ integrations:
how-to-guide:
- /docs/apache-airflow-providers-amazon/operators/sagemaker.rst
tags: [aws]
+ - integration-name: Amazon SageMaker Unified Studio
+ external-doc-url: https://aws.amazon.com/sagemaker/unified-studio/
+ logo: /docs/integration-logos/[email protected]
+ how-to-guide:
+ -
/docs/apache-airflow-providers-amazon/operators/sagemakerunifiedstudio.rst
+ tags: [aws]
- integration-name: Amazon SecretsManager
external-doc-url: https://aws.amazon.com/secrets-manager/
logo: /docs/integration-logos/[email protected]
@@ -402,6 +408,9 @@ operators:
- integration-name: Amazon SageMaker
python-modules:
- airflow.providers.amazon.aws.operators.sagemaker
+ - integration-name: Amazon SageMaker Unified Studio
+ python-modules:
+ - airflow.providers.amazon.aws.operators.sagemaker_unified_studio
- integration-name: Amazon Simple Notification Service (SNS)
python-modules:
- airflow.providers.amazon.aws.operators.sns
@@ -503,6 +512,9 @@ sensors:
- integration-name: Amazon SageMaker
python-modules:
- airflow.providers.amazon.aws.sensors.sagemaker
+ - integration-name: Amazon SageMaker Unified Studio
+ python-modules:
+ - airflow.providers.amazon.aws.sensors.sagemaker_unified_studio
- integration-name: Amazon Simple Queue Service (SQS)
python-modules:
- airflow.providers.amazon.aws.sensors.sqs
@@ -627,6 +639,9 @@ hooks:
- integration-name: Amazon SageMaker
python-modules:
- airflow.providers.amazon.aws.hooks.sagemaker
+ - integration-name: Amazon SageMaker Unified Studio
+ python-modules:
+ - airflow.providers.amazon.aws.hooks.sagemaker_unified_studio
- integration-name: Amazon Simple Email Service (SES)
python-modules:
- airflow.providers.amazon.aws.hooks.ses
@@ -699,6 +714,9 @@ triggers:
- integration-name: Amazon SageMaker
python-modules:
- airflow.providers.amazon.aws.triggers.sagemaker
+ - integration-name: Amazon SageMaker Unified Studio
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.sagemaker_unified_studio
- integration-name: AWS Glue
python-modules:
- airflow.providers.amazon.aws.triggers.glue
@@ -734,7 +752,6 @@ triggers:
python-modules:
- airflow.providers.amazon.aws.triggers.dms
-
transfers:
- source-integration-name: Amazon DynamoDB
target-integration-name: Amazon Simple Storage Service (S3)
@@ -837,6 +854,7 @@ extra-links:
- airflow.providers.amazon.aws.links.glue.GlueJobRunDetailsLink
- airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink
- airflow.providers.amazon.aws.links.sagemaker.SageMakerTransformJobLink
+ -
airflow.providers.amazon.aws.links.sagemaker_unified_studio.SageMakerUnifiedStudioLink
- airflow.providers.amazon.aws.links.step_function.StateMachineDetailsLink
-
airflow.providers.amazon.aws.links.step_function.StateMachineExecutionsDetailsLink
-
airflow.providers.amazon.aws.links.comprehend.ComprehendPiiEntitiesDetectionLink
diff --git a/providers/amazon/pyproject.toml b/providers/amazon/pyproject.toml
index 7586425e3f3..14d54986136 100644
--- a/providers/amazon/pyproject.toml
+++ b/providers/amazon/pyproject.toml
@@ -75,6 +75,7 @@ dependencies = [
"PyAthena>=3.0.10",
"jmespath>=0.7.0",
"python3-saml>=1.16.0",
+ "sagemaker-studio>=1.0.9",
]
# The optional dependencies should be modified in place in the generated file
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
new file mode 100644
index 00000000000..4ad327b51c5
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
@@ -0,0 +1,188 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook hook."""
+
+from __future__ import annotations
+
+import time
+
+from sagemaker_studio import ClientConfig
+from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import
is_local_runner
+
+
class SageMakerNotebookHook(BaseHook):
    """
    Interact with Sagemaker Unified Studio Workflows.

    This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API.

    Examples:
     .. code-block:: python

        from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook

        notebook_hook = SageMakerNotebookHook(
            input_config={"input_path": "path/to/notebook.ipynb", "input_params": {"param1": "value1"}},
            output_config={"output_uri": "folder/output/location/prefix", "output_formats": "NOTEBOOK"},
            execution_name="notebook_execution",
            waiter_delay=10,
            waiter_max_attempts=1440,
        )

    :param execution_name: The name of the notebook job to be executed, this is same as task_id.
    :param input_config: Configuration for the input file.
        Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}}
    :param output_config: Configuration for the output format. It should include an
        output_formats parameter to specify the output format.
        Example: {'output_formats': ['NOTEBOOK']}
    :param compute: compute configuration to use for the notebook execution. This is a required
        attribute if the execution is on a remote compute.
        Example: { "instance_type": "ml.m5.large", "volume_size_in_gb": 30,
        "volume_kms_key_id": "", "image_uri": "string", "container_entrypoint": [ "string" ]}
    :param termination_condition: conditions to match to terminate the remote execution.
        Example: { "MaxRuntimeInSeconds": 3600 }
    :param tags: tags to be associated with the remote execution runs.
        Example: { "md_analytics": "logs" }
    :param waiter_delay: Interval in seconds to check the task execution status.
    :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
    """

    def __init__(
        self,
        execution_name: str,
        input_config: dict | None = None,
        output_config: dict | None = None,
        compute: dict | None = None,
        termination_condition: dict | None = None,
        tags: dict | None = None,
        waiter_delay: int = 10,
        waiter_max_attempts: int = 1440,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config())
        self.execution_name = execution_name
        self.input_config = input_config or {}
        # Default to notebook-format output when the caller does not specify one.
        self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
        self.compute = compute
        self.termination_condition = termination_condition or {}
        self.tags = tags or {}
        self.waiter_delay = waiter_delay
        self.waiter_max_attempts = waiter_max_attempts

    def _get_sagemaker_studio_config(self):
        """Build the SageMaker Studio client config, flagging local execution for local runners."""
        config = ClientConfig()
        config.overrides["execution"] = {"local": is_local_runner()}
        return config

    def _format_start_execution_input_config(self):
        """Translate the hook's input_config into the start_execution notebook_config shape."""
        config = {
            "notebook_config": {
                "input_path": self.input_config.get("input_path"),
                "input_parameters": self.input_config.get("input_params"),
            },
        }

        return config

    def _format_start_execution_output_config(self):
        """Translate the hook's output_config into the start_execution notebook_config shape."""
        output_formats = self.output_config.get("output_formats")
        config = {
            "notebook_config": {
                "output_formats": output_formats,
            }
        }
        return config

    def start_notebook_execution(self):
        """Start a notebook execution and return the service response."""
        start_execution_params = {
            "execution_name": self.execution_name,
            "execution_type": "NOTEBOOK",
            "input_config": self._format_start_execution_input_config(),
            "output_config": self._format_start_execution_output_config(),
            "termination_condition": self.termination_condition,
            "tags": self.tags,
        }
        if self.compute:
            start_execution_params["compute"] = self.compute
        else:
            # Fall back to a default instance type when no compute config was provided.
            start_execution_params["compute"] = {"instance_type": "ml.m4.xlarge"}

        # BUG FIX: use the task logger instead of a bare print() so the request
        # parameters appear in the Airflow task logs and honor logging config.
        self.log.info("Starting notebook execution with params: %s", start_execution_params)
        return self._sagemaker_studio.execution_client.start_execution(**start_execution_params)

    def wait_for_execution_completion(self, execution_id, context):
        """
        Poll the execution until it reaches a terminal state.

        Pushes any reported output files / s3_path to XCom on every poll. Returns the
        terminal-state summary dict from :meth:`_handle_state`, or raises
        AirflowException on failure. If ``waiter_max_attempts`` is exhausted, the
        execution is treated as FAILED with a timeout message.
        """
        wait_attempts = 0
        while wait_attempts < self.waiter_max_attempts:
            wait_attempts += 1
            time.sleep(self.waiter_delay)
            response = self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id)
            error_message = response.get("error_details", {}).get("error_message")
            status = response["status"]
            if "files" in response:
                self._set_xcom_files(response["files"], context)
            if "s3_path" in response:
                self._set_xcom_s3_path(response["s3_path"], context)

            ret = self._handle_state(execution_id, status, error_message)
            if ret:
                return ret

        # If timeout, handle state FAILED with timeout message
        return self._handle_state(execution_id, "FAILED", "Execution timed out")

    def _set_xcom_files(self, files, context):
        """Push each produced file to XCom keyed by ``<display_name>.<file_format>``."""
        if not context:
            error_message = "context is required"
            raise AirflowException(error_message)
        for file in files:
            context["ti"].xcom_push(
                key=f"{file['display_name']}.{file['file_format']}",
                value=file["file_path"],
            )

    def _set_xcom_s3_path(self, s3_path, context):
        """Push the execution's S3 output path to XCom under the ``s3_path`` key."""
        if not context:
            error_message = "context is required"
            raise AirflowException(error_message)
        context["ti"].xcom_push(
            key="s3_path",
            value=s3_path,
        )

    def _handle_state(self, execution_id, status, error_message):
        """
        Map an execution status to an outcome.

        Returns None while the execution is in progress, a summary dict on success,
        and raises AirflowException on any other (failure) status.
        """
        finished_states = ["COMPLETED"]
        in_progress_states = ["IN_PROGRESS", "STOPPING"]

        if status in in_progress_states:
            info_message = f"Execution {execution_id} is still in progress with state:{status}, will check for a terminal status again in {self.waiter_delay}"
            self.log.info(info_message)
            return None
        execution_message = f"Exiting Execution {execution_id} State: {status}"
        if status in finished_states:
            self.log.info(execution_message)
            return {"Status": status, "ExecutionId": execution_id}
        else:
            log_error_message = f"Execution {execution_id} failed with error: {error_message}"
            self.log.error(log_error_message)
            # Fall back to the generic exit message when the service gave no detail.
            if error_message == "":
                error_message = execution_message
            raise AirflowException(error_message)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/links/sagemaker_unified_studio.py
b/providers/amazon/src/airflow/providers/amazon/aws/links/sagemaker_unified_studio.py
new file mode 100644
index 00000000000..802a1fbfff8
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/links/sagemaker_unified_studio.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK,
BaseAwsLink
+
+
class SageMakerUnifiedStudioLink(BaseAwsLink):
    """Link helper that builds the AWS console URL for Amazon SageMaker Unified Studio."""

    # Display name shown for the extra link and the XCom key it is stored under.
    name = "Amazon SageMaker Unified Studio"
    key = "sagemaker_unified_studio"
    # Console landing page, parameterized only by region.
    format_str = BASE_AWS_CONSOLE_LINK + "/datazone/home?region={region_name}"
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
new file mode 100644
index 00000000000..c872c56afa6
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook
operator."""
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import (
+ SageMakerNotebookHook,
+)
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
+ SageMakerUnifiedStudioLink,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import (
+ SageMakerNotebookJobTrigger,
+)
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
class SageMakerNotebookOperator(BaseOperator):
    """
    Provides Artifact execution functionality for Sagemaker Unified Studio Workflows.

    Examples:
     .. code-block:: python

        from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import SageMakerNotebookOperator

        notebook_operator = SageMakerNotebookOperator(
            task_id="notebook_task",
            input_config={"input_path": "path/to/notebook.ipynb", "input_params": {}},
            output_config={"output_formats": ["NOTEBOOK"]},
            wait_for_completion=True,
            waiter_delay=10,
            waiter_max_attempts=1440,
        )

    :param task_id: A unique, meaningful id for the task.
    :param input_config: Configuration for the input file. Input path should be specified as a
        relative path. The provided relative path will be automatically resolved to an absolute
        path within the context of the user's home directory in the IDE. Input params should be a dict.
        Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params':{'key': 'value'}}
    :param output_config: Configuration for the output format. It should include an
        output_formats parameter to control the format of the notebook execution output.
        Example: {"output_formats": ["NOTEBOOK"]}
    :param compute: compute configuration to use for the artifact execution. This is a required
        attribute if the execution is on a remote compute.
        Example: { "InstanceType": "ml.m5.large", "VolumeSizeInGB": 30,
        "VolumeKmsKeyId": "", "ImageUri": "string", "ContainerEntrypoint": [ "string" ]}
    :param termination_condition: conditions to match to terminate the remote execution.
        Example: { "MaxRuntimeInSeconds": 3600 }
    :param tags: tags to be associated with the remote execution runs.
        Example: { "md_analytics": "logs" }
    :param wait_for_completion: Indicates whether to wait for the notebook execution to complete.
        If True, wait for completion; if False, don't wait.
    :param waiter_delay: Interval in seconds to check the notebook execution status.
    :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
        This implies waiting for completion. This mode requires aiobotocore module to be installed.
        (default: False)

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerNotebookOperator`
    """

    operator_extra_links = (SageMakerUnifiedStudioLink(),)

    def __init__(
        self,
        task_id: str,
        input_config: dict,
        output_config: dict | None = None,
        compute: dict | None = None,
        termination_condition: dict | None = None,
        tags: dict | None = None,
        wait_for_completion: bool = True,
        waiter_delay: int = 10,
        waiter_max_attempts: int = 1440,
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
        **kwargs,
    ):
        super().__init__(task_id=task_id, **kwargs)
        # The task_id doubles as the execution name on the service side.
        self.execution_name = task_id
        self.input_config = input_config
        # Default to notebook-format output when the caller does not specify one.
        self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
        self.compute = compute or {}
        self.termination_condition = termination_condition or {}
        self.tags = tags or {}
        self.wait_for_completion = wait_for_completion
        self.waiter_delay = waiter_delay
        self.waiter_max_attempts = waiter_max_attempts
        self.deferrable = deferrable
        self.input_kwargs = kwargs

    @cached_property
    def notebook_execution_hook(self):
        """Validate the input config and build the hook (created once per task instance)."""
        if not self.input_config:
            raise AirflowException("input_config is required")

        if "input_path" not in self.input_config:
            raise AirflowException("input_path is a required field in the input_config")

        return SageMakerNotebookHook(
            input_config=self.input_config,
            output_config=self.output_config,
            execution_name=self.execution_name,
            compute=self.compute,
            termination_condition=self.termination_condition,
            tags=self.tags,
            waiter_delay=self.waiter_delay,
            waiter_max_attempts=self.waiter_max_attempts,
        )

    def execute(self, context: Context):
        """Start the notebook execution and (optionally) wait or defer until it finishes."""
        notebook_execution = self.notebook_execution_hook.start_notebook_execution()
        execution_id = notebook_execution["execution_id"]

        if self.deferrable:
            # NOTE(review): method_name points at "execute_complete", but no such
            # method is defined on this operator — confirm the deferrable path is
            # complete before relying on it.
            self.defer(
                trigger=SageMakerNotebookJobTrigger(
                    execution_id=execution_id,
                    execution_name=self.execution_name,
                    waiter_delay=self.waiter_delay,
                    waiter_max_attempts=self.waiter_max_attempts,
                ),
                method_name="execute_complete",
            )
        elif self.wait_for_completion:
            response = self.notebook_execution_hook.wait_for_execution_completion(execution_id, context)
            status = response["Status"]
            log_info_message = (
                f"Notebook Execution: {self.execution_name} Status: {status}. Run Id: {execution_id}"
            )
            self.log.info(log_info_message)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py
b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py
new file mode 100644
index 00000000000..ab32b50dbe8
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook sensor."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import (
+ SageMakerNotebookHook,
+)
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
class SageMakerNotebookSensor(BaseSensorOperator):
    """
    Waits for a Sagemaker Workflows Notebook execution to reach any of the status below.

    'FAILED', 'STOPPED', 'COMPLETED'

    :param execution_id: The Sagemaker Workflows Notebook running execution identifier
    :param execution_name: The Sagemaker Workflows Notebook unique execution name
    """

    def __init__(self, *, execution_id: str, execution_name: str, **kwargs):
        super().__init__(**kwargs)
        self.execution_id = execution_id
        self.execution_name = execution_name
        # Only COMPLETED is success; PENDING/RUNNING keep poking; anything else fails.
        self.success_state = ["COMPLETED"]
        self.in_progress_states = ["PENDING", "RUNNING"]

    def hook(self):
        """Return a hook bound to this sensor's execution name."""
        return SageMakerNotebookHook(execution_name=self.execution_name)

    # override from base sensor
    def poke(self, context=None):
        """Return True on success, False while in progress, raise on terminal failure."""
        # NOTE(review): relies on SageMakerNotebookHook.get_execution_status —
        # confirm that method exists on the hook.
        status = self.hook().get_execution_status(execution_id=self.execution_id)

        if status in self.success_state:
            log_info_message = f"Exiting Execution {self.execution_id} State: {status}"
            self.log.info(log_info_message)
            return True
        elif status in self.in_progress_states:
            return False
        else:
            error_message = f"Exiting Execution {self.execution_id} State: {status}"
            # BUG FIX: this is a terminal failure that raises below, so log it at
            # error level rather than info.
            self.log.error(error_message)
            raise AirflowException(error_message)

    def execute(self, context: Context):
        """Log the execution being polled, then delegate to the base sensor loop."""
        # This will invoke poke method in the base sensor
        log_info_message = f"Polling Sagemaker Workflows Artifact execution: {self.execution_name} and execution id: {self.execution_id}"
        self.log.info(log_info_message)
        super().execute(context=context)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
new file mode 100644
index 00000000000..e9285e9d8dd
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook job
trigger."""
+
+from __future__ import annotations
+
+from airflow.triggers.base import BaseTrigger
+
+
class SageMakerNotebookJobTrigger(BaseTrigger):
    """
    Watches for a notebook job, triggers when it finishes.

    Examples:
     .. code-block:: python

        from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import SageMakerNotebookJobTrigger

        notebook_trigger = SageMakerNotebookJobTrigger(
            execution_id="notebook_job_1234",
            execution_name="notebook_task",
            waiter_delay=10,
            waiter_max_attempts=1440,
        )

    :param execution_id: A unique, meaningful id for the task.
    :param execution_name: A unique, meaningful name for the task.
    :param waiter_delay: Interval in seconds to check the notebook execution status.
    :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
    """

    def __init__(self, execution_id, execution_name, waiter_delay, waiter_max_attempts, **kwargs):
        super().__init__(**kwargs)
        self.execution_id = execution_id
        self.execution_name = execution_name
        self.waiter_delay = waiter_delay
        self.waiter_max_attempts = waiter_max_attempts

    def serialize(self):
        """Serialize the trigger so the triggerer process can reconstruct it."""
        return (
            # dynamically generate the fully qualified name of the class
            self.__class__.__module__ + "." + self.__class__.__qualname__,
            {
                # BUG FIX: the original serialized a nonexistent ``self.poll_interval``
                # (AttributeError when the operator defers) and omitted the waiter
                # settings required to rebuild the trigger. Serialize the actual
                # constructor arguments instead.
                "execution_id": self.execution_id,
                "execution_name": self.execution_name,
                "waiter_delay": self.waiter_delay,
                "waiter_max_attempts": self.waiter_max_attempts,
            },
        )

    async def run(self):
        # NOTE(review): intentionally a stub in this commit — the deferrable path
        # does not yet poll the execution. Confirm before relying on deferral.
        pass
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/utils/sagemaker_unified_studio.py
b/providers/amazon/src/airflow/providers/amazon/aws/utils/sagemaker_unified_studio.py
new file mode 100644
index 00000000000..63862239bdd
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/utils/sagemaker_unified_studio.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains utils for the Amazon SageMaker Unified Studio Notebook
plugin."""
+
+from __future__ import annotations
+
+import os
+
# Name of the environment variable that identifies the workflows runtime.
workflows_env_key = "WORKFLOWS_ENV"


def is_local_runner():
    """Return True when the workflows environment variable marks a local runner."""
    env_value = os.environ.get(workflows_env_key, "")
    return env_value == "Local"
diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
index 69a1a80fc4d..cef82566eeb 100644
--- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
+++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
@@ -275,6 +275,15 @@ def get_provider_info():
"how-to-guide":
["/docs/apache-airflow-providers-amazon/operators/sagemaker.rst"],
"tags": ["aws"],
},
+ {
+ "integration-name": "Amazon SageMaker Unified Studio",
+ "external-doc-url":
"https://aws.amazon.com/sagemaker/unified-studio/",
+ "logo":
"/docs/integration-logos/[email protected]",
+ "how-to-guide": [
+
"/docs/apache-airflow-providers-amazon/operators/sagemakerunifiedstudio.rst"
+ ],
+ "tags": ["aws"],
+ },
{
"integration-name": "Amazon SecretsManager",
"external-doc-url": "https://aws.amazon.com/secrets-manager/",
@@ -491,6 +500,10 @@ def get_provider_info():
"integration-name": "Amazon SageMaker",
"python-modules":
["airflow.providers.amazon.aws.operators.sagemaker"],
},
+ {
+ "integration-name": "Amazon SageMaker Unified Studio",
+ "python-modules":
["airflow.providers.amazon.aws.operators.sagemaker_unified_studio"],
+ },
{
"integration-name": "Amazon Simple Notification Service (SNS)",
"python-modules":
["airflow.providers.amazon.aws.operators.sns"],
@@ -628,6 +641,10 @@ def get_provider_info():
"integration-name": "Amazon SageMaker",
"python-modules":
["airflow.providers.amazon.aws.sensors.sagemaker"],
},
+ {
+ "integration-name": "Amazon SageMaker Unified Studio",
+ "python-modules":
["airflow.providers.amazon.aws.sensors.sagemaker_unified_studio"],
+ },
{
"integration-name": "Amazon Simple Queue Service (SQS)",
"python-modules": ["airflow.providers.amazon.aws.sensors.sqs"],
@@ -781,6 +798,10 @@ def get_provider_info():
"integration-name": "Amazon SageMaker",
"python-modules":
["airflow.providers.amazon.aws.hooks.sagemaker"],
},
+ {
+ "integration-name": "Amazon SageMaker Unified Studio",
+ "python-modules":
["airflow.providers.amazon.aws.hooks.sagemaker_unified_studio"],
+ },
{
"integration-name": "Amazon Simple Email Service (SES)",
"python-modules": ["airflow.providers.amazon.aws.hooks.ses"],
@@ -878,6 +899,10 @@ def get_provider_info():
"integration-name": "Amazon SageMaker",
"python-modules":
["airflow.providers.amazon.aws.triggers.sagemaker"],
},
+ {
+ "integration-name": "Amazon SageMaker Unified Studio",
+ "python-modules":
["airflow.providers.amazon.aws.triggers.sagemaker_unified_studio"],
+ },
{
"integration-name": "AWS Glue",
"python-modules": [
@@ -1072,6 +1097,7 @@ def get_provider_info():
"airflow.providers.amazon.aws.links.glue.GlueJobRunDetailsLink",
"airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink",
"airflow.providers.amazon.aws.links.sagemaker.SageMakerTransformJobLink",
+
"airflow.providers.amazon.aws.links.sagemaker_unified_studio.SageMakerUnifiedStudioLink",
"airflow.providers.amazon.aws.links.step_function.StateMachineDetailsLink",
"airflow.providers.amazon.aws.links.step_function.StateMachineExecutionsDetailsLink",
"airflow.providers.amazon.aws.links.comprehend.ComprehendPiiEntitiesDetectionLink",
@@ -1354,6 +1380,7 @@ def get_provider_info():
"PyAthena>=3.0.10",
"jmespath>=0.7.0",
"python3-saml>=1.16.0",
+ "sagemaker-studio>=1.0.9",
],
"optional-dependencies": {
"pandas": ["pandas>=2.1.2,<2.2"],
diff --git
a/providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py
b/providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py
new file mode 100644
index 00000000000..8a4a5c14c66
--- /dev/null
+++
b/providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py
@@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+import pytest
+
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.models.dag import DAG
+from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import (
+ SageMakerNotebookOperator,
+)
+from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
+
+"""
+Prerequisites: The account which runs this test must manually have the
following:
+1. An IAM IDC organization set up in the testing region with a user initialized
+2. A SageMaker Unified Studio Domain (with default VPC and roles)
+3. A project within the SageMaker Unified Studio Domain
+4. A notebook (test_notebook.ipynb) placed in the project's s3 path
+
+This test will emulate a DAG run in the shared MWAA environment inside a
SageMaker Unified Studio Project.
+The setup tasks will set up the project and configure the test runner to
emulate an MWAA instance.
+Then, the SageMakerNotebookOperator will run a test notebook. This should spin
up a SageMaker training job, run the notebook, and exit successfully.
+"""
+
+pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Test requires
Airflow 2.10+")
+
+DAG_ID = "example_sagemaker_unified_studio"
+
+# Externally fetched variables:
+DOMAIN_ID_KEY = "DOMAIN_ID"
+PROJECT_ID_KEY = "PROJECT_ID"
+ENVIRONMENT_ID_KEY = "ENVIRONMENT_ID"
+S3_PATH_KEY = "S3_PATH"
+REGION_NAME_KEY = "REGION_NAME"
+
+sys_test_context_task = (
+ SystemTestContextBuilder()
+ .add_variable(DOMAIN_ID_KEY)
+ .add_variable(PROJECT_ID_KEY)
+ .add_variable(ENVIRONMENT_ID_KEY)
+ .add_variable(S3_PATH_KEY)
+ .add_variable(REGION_NAME_KEY)
+ .build()
+)
+
+
def get_mwaa_environment_params(
    domain_id: str,
    project_id: str,
    environment_id: str,
    s3_path: str,
    region_name: str,
):
    """
    Build the ``AIRFLOW__WORKFLOWS__*`` environment-variable mapping that
    emulates an MWAA environment provisioned inside a SageMaker Unified Studio
    project.

    :param domain_id: DataZone domain id of the Unified Studio domain.
    :param project_id: DataZone project id within the domain.
    :param environment_id: DataZone environment id for the project.
    :param s3_path: The project's S3 path.
    :param region_name: AWS region hosting the domain.
    :return: dict of environment-variable name -> value.
    """
    prefix = "AIRFLOW__WORKFLOWS__"
    return {
        f"{prefix}DATAZONE_DOMAIN_ID": domain_id,
        f"{prefix}DATAZONE_PROJECT_ID": project_id,
        f"{prefix}DATAZONE_ENVIRONMENT_ID": environment_id,
        f"{prefix}DATAZONE_SCOPE_NAME": "dev",
        f"{prefix}DATAZONE_STAGE": "prod",
        f"{prefix}DATAZONE_ENDPOINT": f"https://datazone.{region_name}.api.aws",
        f"{prefix}PROJECT_S3_PATH": s3_path,
        f"{prefix}DATAZONE_DOMAIN_REGION": region_name,
    }
+
+
@task
def mock_mwaa_environment(parameters: dict):
    """
    Sets several environment variables in the container to emulate an MWAA environment provisioned
    within SageMaker Unified Studio. When running in the ECSExecutor, this is a no-op.
    """
    import os

    os.environ.update(parameters)
+
+
with DAG(
    DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    tags=["example"],
    catchup=False,
) as dag:
    # Pull externally-provided test variables (domain/project/etc.) as a setup task.
    test_context = sys_test_context_task()

    test_env_id = test_context[ENV_ID_KEY]
    domain_id = test_context[DOMAIN_ID_KEY]
    project_id = test_context[PROJECT_ID_KEY]
    environment_id = test_context[ENVIRONMENT_ID_KEY]
    s3_path = test_context[S3_PATH_KEY]
    region_name = test_context[REGION_NAME_KEY]

    # Build the AIRFLOW__WORKFLOWS__* variables used to emulate an MWAA
    # environment provisioned inside a SageMaker Unified Studio project.
    # NOTE(review): called at DAG-parse time with XComArg references from
    # test_context — confirm these resolve as intended at runtime.
    mock_mwaa_environment_params = get_mwaa_environment_params(
        domain_id,
        project_id,
        environment_id,
        s3_path,
        region_name,
    )

    setup_mwaa_environment = mock_mwaa_environment(mock_mwaa_environment_params)

    # [START howto_operator_sagemaker_unified_studio_notebook]
    notebook_path = "test_notebook.ipynb"  # This should be the path to your .ipynb, .sqlnb, or .vetl file in your project.

    run_notebook = SageMakerNotebookOperator(
        task_id="run-notebook",
        input_config={"input_path": notebook_path, "input_params": {}},
        output_config={"output_formats": ["NOTEBOOK"]},  # optional
        compute={
            "instance_type": "ml.m5.large",
            "volume_size_in_gb": 30,
        },  # optional
        termination_condition={"max_runtime_in_seconds": 600},  # optional
        tags={},  # optional
        wait_for_completion=True,  # optional
        waiter_delay=5,  # optional
        deferrable=False,  # optional
        executor_config={  # optional
            "overrides": {"containerOverrides": {"environment": mock_mwaa_environment_params}}
        },
    )
    # [END howto_operator_sagemaker_unified_studio_notebook]

    chain(
        # TEST SETUP
        test_context,
        setup_mwaa_environment,
        # TEST BODY
        run_notebook,
    )

    from tests_common.test_utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
diff --git
a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
new file mode 100644
index 00000000000..179d997740c
--- /dev/null
+++
b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
@@ -0,0 +1,201 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+from sagemaker_studio.models.execution import ExecutionClient
+
+from airflow.exceptions import AirflowException
+from airflow.models import TaskInstance
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import (
+ SageMakerNotebookHook,
+)
+from airflow.utils.session import create_session
+
+pytestmark = pytest.mark.db_test
+
+
class TestSageMakerNotebookHook:
    """Unit tests for SageMakerNotebookHook; the SageMakerStudio SDK is mocked so no AWS calls are made."""

    @pytest.fixture(autouse=True)
    def setup(self):
        # Patch the SDK entry point so hook construction never reaches AWS;
        # every test gets a fresh hook wired to a mocked execution client.
        with patch(
            "airflow.providers.amazon.aws.hooks.sagemaker_unified_studio.SageMakerStudioAPI",
            autospec=True,
        ) as mock_sdk:
            self.execution_name = "test-execution"
            self.waiter_delay = 10
            sdk_instance = mock_sdk.return_value
            sdk_instance.execution_client = MagicMock(spec=ExecutionClient)
            sdk_instance.execution_client.start_execution.return_value = {
                "execution_id": "execution_id",
                "execution_name": "execution_name",
            }
            self.hook = SageMakerNotebookHook(
                input_config={
                    "input_path": "test-data/notebook/test_notebook.ipynb",
                    "input_params": {"key": "value"},
                },
                output_config={"output_formats": ["NOTEBOOK"]},
                execution_name=self.execution_name,
                waiter_delay=self.waiter_delay,
                compute={"instance_type": "ml.c4.2xlarge"},
            )

            self.hook._sagemaker_studio = mock_sdk
            # Shared fixtures for the XCom tests below.
            self.files = [
                {"display_name": "file1.txt", "url": "http://example.com/file1.txt"},
                {"display_name": "file2.txt", "url": "http://example.com/file2.txt"},
            ]
            self.context = {
                "ti": MagicMock(spec=TaskInstance),
            }
            self.s3Path = "S3Path"
            yield

    def test_format_input_config(self):
        # input_params must be renamed to input_parameters in the API payload.
        expected_config = {
            "notebook_config": {
                "input_path": "test-data/notebook/test_notebook.ipynb",
                "input_parameters": {"key": "value"},
            }
        }

        config = self.hook._format_start_execution_input_config()
        assert config == expected_config

    def test_format_output_config(self):
        expected_config = {
            "notebook_config": {
                "output_formats": ["NOTEBOOK"],
            }
        }

        config = self.hook._format_start_execution_output_config()
        assert config == expected_config

    def test_format_output_config_default(self):
        # When no output_config is given, the hook defaults to NOTEBOOK format.
        no_output_config_hook = SageMakerNotebookHook(
            input_config={
                "input_path": "test-data/notebook/test_notebook.ipynb",
                "input_params": {"key": "value"},
            },
            execution_name=self.execution_name,
            waiter_delay=self.waiter_delay,
        )

        no_output_config_hook._sagemaker_studio = self.hook._sagemaker_studio
        expected_config = {"notebook_config": {"output_formats": ["NOTEBOOK"]}}

        config = no_output_config_hook._format_start_execution_output_config()
        assert config == expected_config

    def test_start_notebook_execution(self):
        # The hook should return the SDK response verbatim.
        self.hook._sagemaker_studio = MagicMock()
        self.hook._sagemaker_studio.execution_client = MagicMock(spec=ExecutionClient)

        self.hook._sagemaker_studio.execution_client.start_execution.return_value = {"executionId": "123456"}
        result = self.hook.start_notebook_execution()
        assert result == {"executionId": "123456"}
        self.hook._sagemaker_studio.execution_client.start_execution.assert_called_once()

    @patch("time.sleep", return_value=None)  # To avoid actual sleep during tests
    def test_wait_for_execution_completion(self, mock_sleep):
        execution_id = "123456"
        self.hook._sagemaker_studio = MagicMock()
        self.hook._sagemaker_studio.execution_client = MagicMock(spec=ExecutionClient)
        self.hook._sagemaker_studio.execution_client.get_execution.return_value = {"status": "COMPLETED"}

        result = self.hook.wait_for_execution_completion(execution_id, {})
        assert result == {"Status": "COMPLETED", "ExecutionId": execution_id}
        self.hook._sagemaker_studio.execution_client.get_execution.assert_called()
        # Exactly one poll cycle: sleep once, then observe COMPLETED.
        mock_sleep.assert_called_once()

    @patch("time.sleep", return_value=None)
    def test_wait_for_execution_completion_failed(self, mock_sleep):
        execution_id = "123456"
        self.hook._sagemaker_studio = MagicMock()
        self.hook._sagemaker_studio.execution_client = MagicMock(spec=ExecutionClient)
        # A FAILED status with error details should surface as AirflowException.
        self.hook._sagemaker_studio.execution_client.get_execution.return_value = {
            "status": "FAILED",
            "error_details": {"error_message": "Execution failed"},
        }

        with pytest.raises(AirflowException, match="Execution failed"):
            self.hook.wait_for_execution_completion(execution_id, self.context)

    def test_handle_in_progress_state(self):
        # Non-terminal states return None so polling continues.
        execution_id = "123456"
        states = ["IN_PROGRESS", "STOPPING"]

        for status in states:
            result = self.hook._handle_state(execution_id, status, None)
            assert result is None

    def test_handle_finished_state(self):
        execution_id = "123456"
        states = ["COMPLETED"]

        for status in states:
            result = self.hook._handle_state(execution_id, status, None)
            assert result == {"Status": status, "ExecutionId": execution_id}

    def test_handle_failed_state(self):
        execution_id = "123456"
        status = "FAILED"
        error_message = "Execution failed"
        with pytest.raises(AirflowException, match=error_message):
            self.hook._handle_state(execution_id, status, error_message)

        # STOPPED with an empty message falls back to the generic exit message.
        status = "STOPPED"
        error_message = ""
        with pytest.raises(AirflowException, match=f"Exiting Execution {execution_id} State: {status}"):
            self.hook._handle_state(execution_id, status, error_message)

    def test_handle_unexpected_state(self):
        # Any status outside the known sets is treated as a failure.
        execution_id = "123456"
        status = "PENDING"
        error_message = f"Exiting Execution {execution_id} State: {status}"
        with pytest.raises(AirflowException, match=error_message):
            self.hook._handle_state(execution_id, status, error_message)

    @patch(
        "airflow.providers.amazon.aws.hooks.sagemaker_unified_studio.SageMakerNotebookHook._set_xcom_files"
    )
    def test_set_xcom_files(self, mock_set_xcom_files):
        # NOTE(review): this patches the very method it then calls, so it only
        # verifies the mock wiring — the real _set_xcom_files body is not
        # exercised; consider asserting on ti.xcom_push instead.
        with create_session():
            self.hook._set_xcom_files(self.files, self.context)
            expected_call = call(self.files, self.context)
            mock_set_xcom_files.assert_called_once_with(*expected_call.args, **expected_call.kwargs)

    def test_set_xcom_files_negative_missing_context(self):
        with pytest.raises(AirflowException, match="context is required"):
            self.hook._set_xcom_files(self.files, {})

    @patch(
        "airflow.providers.amazon.aws.hooks.sagemaker_unified_studio.SageMakerNotebookHook._set_xcom_s3_path"
    )
    def test_set_xcom_s3_path(self, mock_set_xcom_s3_path):
        # NOTE(review): same self-patching pattern as test_set_xcom_files — the
        # real implementation is not exercised here.
        with create_session():
            self.hook._set_xcom_s3_path(self.s3Path, self.context)
            expected_call = call(self.s3Path, self.context)
            mock_set_xcom_s3_path.assert_called_once_with(*expected_call.args, **expected_call.kwargs)

    def test_set_xcom_s3_path_negative_missing_context(self):
        with pytest.raises(AirflowException, match="context is required"):
            self.hook._set_xcom_s3_path(self.s3Path, {})
diff --git
a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py
b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py
new file mode 100644
index 00000000000..c55d1231fd8
--- /dev/null
+++
b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.sagemaker_unified_studio import
SageMakerUnifiedStudioLink
+from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase
+
+
class TestSageMakerUnifiedStudioLink(BaseAwsLinksTestCase):
    """Verifies the console URL produced by SageMakerUnifiedStudioLink via the shared AWS-links test base."""

    # Link class under test; BaseAwsLinksTestCase drives URL construction.
    link_class = SageMakerUnifiedStudioLink

    def test_extra_link(self):
        self.assert_extra_link_url(
            expected_url=("https://console.aws.amazon.com/datazone/home?region=us-east-1"),
            region_name="us-east-1",
            aws_partition="aws",
            job_name="test_job_name",
        )
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_notebook.ipynb
b/providers/amazon/tests/unit/amazon/aws/operators/test_notebook.ipynb
new file mode 100644
index 00000000000..395eff4ef62
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_notebook.ipynb
@@ -0,0 +1,61 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "437d7d66",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Licensed to the Apache Software Foundation (ASF) under one\n",
+ "# or more contributor license agreements. See the NOTICE file\n",
+ "# distributed with this work for additional information\n",
+ "# regarding copyright ownership. The ASF licenses this file\n",
+ "# to you under the Apache License, Version 2.0 (the\n",
+ "# \"License\"); you may not use this file except in compliance\n",
+ "# with the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing,\n",
+ "# software distributed under the License is distributed on an\n",
+ "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+ "# KIND, either express or implied. See the License for the\n",
+ "# specific language governing permissions and limitations\n",
+ "# under the License."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a734f58df854b5fa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def add(num1, num2):\n",
+ " return num1 + num2"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
new file mode 100644
index 00000000000..87e9b004db1
--- /dev/null
+++
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import (
+ SageMakerNotebookOperator,
+)
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import (
+ SageMakerNotebookJobTrigger,
+)
+
+
class TestSageMakerNotebookOperator:
    """Unit tests for SageMakerNotebookOperator; SageMakerNotebookHook is patched so no AWS calls are made."""

    def test_init(self):
        # Constructor should store configs verbatim.
        operator = SageMakerNotebookOperator(
            task_id="test_id",
            input_config={
                "notebook_path": "tests/amazon/aws/operators/test_notebook.ipynb",
            },
            output_config={"output_format": "ipynb"},
        )

        assert operator.task_id == "test_id"
        assert operator.input_config == {
            "notebook_path": "tests/amazon/aws/operators/test_notebook.ipynb",
        }
        assert operator.output_config == {"output_format": "ipynb"}

    def test_only_required_params_init(self):
        # output_config and the other kwargs are optional.
        operator = SageMakerNotebookOperator(
            task_id="test_id",
            input_config={
                "notebook_path": "tests/amazon/aws/operators/test_notebook.ipynb",
            },
        )
        assert isinstance(operator, SageMakerNotebookOperator)

    @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
    def test_execute_success(self, mock_notebook_hook):
        # Mock the NotebookHook and its execute method
        mock_hook_instance = mock_notebook_hook.return_value
        mock_hook_instance.start_notebook_execution.return_value = {
            "execution_id": "123456",
            "executionType": "test",
        }

        # Create the operator
        operator = SageMakerNotebookOperator(
            task_id="test_id",
            input_config={"input_path": "test_input_path"},
            output_config={"output_uri": "test_output_uri", "output_format": "ipynb"},
        )

        # Execute the operator
        operator.execute({})
        mock_hook_instance.start_notebook_execution.assert_called_once_with()
        # Default wait_for_completion=True: the operator waits on the new execution id.
        mock_hook_instance.wait_for_execution_completion.assert_called_once_with("123456", {})

    @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
    def test_execute_failure_missing_input_config(self, mock_notebook_hook):
        operator = SageMakerNotebookOperator(
            task_id="test_id",
            input_config={},
            output_config={"output_uri": "test_output_uri", "output_format": "ipynb"},
        )

        # Validation happens at execute() time, before the hook is used.
        with pytest.raises(AirflowException, match="input_config is required"):
            operator.execute({})

        mock_notebook_hook.assert_not_called()

    @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
    def test_execute_failure_missing_input_path(self, mock_notebook_hook):
        operator = SageMakerNotebookOperator(
            task_id="test_id",
            input_config={"invalid_key": "test_input_path"},
            output_config={"output_uri": "test_output_uri", "output_format": "ipynb"},
        )

        with pytest.raises(AirflowException, match="input_path is a required field in the input_config"):
            operator.execute({})

        mock_notebook_hook.assert_not_called()

    @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
    def test_execute_with_wait_for_completion(self, mock_notebook_hook):
        # Mock the execute and job_completion methods of NotebookHook
        mock_hook_instance = mock_notebook_hook.return_value
        mock_hook_instance.start_notebook_execution.return_value = {
            "execution_id": "123456",
            "executionType": "test",
        }
        mock_hook_instance.wait_for_execution_completion.return_value = {"Status": "COMPLETED"}

        # Create the operator with wait_for_completion set to True
        operator = SageMakerNotebookOperator(
            task_id="test_id",
            input_config={"input_path": "test_input_path"},
            output_config={"output_uri": "test_output_uri", "output_format": "ipynb"},
            wait_for_completion=True,
        )
        # Execute the operator
        operator.execute({})

        # Verify that execute and wait_for_execution_completion methods are called
        mock_hook_instance.start_notebook_execution.assert_called_once_with()
        mock_hook_instance.wait_for_execution_completion.assert_called_once_with("123456", {})

    @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
    @patch.object(SageMakerNotebookOperator, "defer")
    def test_execute_with_deferrable(self, mock_defer, mock_notebook_hook):
        # deferrable=True must hand off to a SageMakerNotebookJobTrigger instead
        # of blocking on wait_for_execution_completion.
        mock_hook_instance = mock_notebook_hook.return_value
        mock_hook_instance.start_notebook_execution.return_value = {
            "execution_id": "123456",
            "executionType": "test",
        }

        operator = SageMakerNotebookOperator(
            task_id="test_id",
            input_config={"input_path": "test_input_path"},
            output_config={"output_format": "ipynb"},
            deferrable=True,
        )

        operator.execute({})

        mock_hook_instance.start_notebook_execution.assert_called_once_with()
        mock_defer.assert_called_once()
        # The trigger must carry the execution id/name and the default waiter delay.
        trigger_call = mock_defer.call_args[1]["trigger"]
        assert isinstance(trigger_call, SageMakerNotebookJobTrigger)
        assert trigger_call.execution_id == "123456"
        assert trigger_call.execution_name == "test_id"
        assert trigger_call.waiter_delay == 10
        mock_hook_instance.wait_for_execution_completion.assert_not_called()

    @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
    def test_execute_without_wait_for_completion(self, mock_notebook_hook):
        # Mock the execute method of NotebookHook
        mock_hook_instance = mock_notebook_hook.return_value
        mock_hook_instance.start_notebook_execution.return_value = {
            "execution_id": "123456",
            "executionType": "test",
        }

        # Create the operator with wait_for_completion set to False
        operator = SageMakerNotebookOperator(
            task_id="test_id",
            input_config={"input_path": "test_input_path"},
            output_config={"output_uri": "test_output_uri", "output_format": "ipynb"},
            wait_for_completion=False,
        )

        # Execute the operator
        operator.execute({})

        # Verify that execute is called and wait_for_execution_completion is not
        mock_hook_instance.start_notebook_execution.assert_called_once_with()
        mock_hook_instance.wait_for_execution_completion.assert_not_called()
diff --git
a/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio.py
b/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio.py
new file mode 100644
index 00000000000..46b4e40bf7c
--- /dev/null
+++
b/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.sensors.sagemaker_unified_studio import (
+ SageMakerNotebookSensor,
+)
+from airflow.utils.context import Context
+
+
+class TestSageMakerNotebookSensor:
+ def test_init(self):
+ # Test the initialization of the sensor
+ sensor = SageMakerNotebookSensor(
+ task_id="test_task",
+ execution_id="test_execution_id",
+ execution_name="test_execution_name",
+ )
+ assert sensor.execution_id == "test_execution_id"
+ assert sensor.execution_name == "test_execution_name"
+ assert sensor.success_state == ["COMPLETED"]
+ assert sensor.in_progress_states == ["PENDING", "RUNNING"]
+
+
@patch("airflow.providers.amazon.aws.sensors.sagemaker_unified_studio.SageMakerNotebookHook")
+ def test_poke_success_state(self, mock_notebook_hook):
+ mock_hook_instance = mock_notebook_hook.return_value
+ mock_hook_instance.get_execution_status.return_value = "COMPLETED"
+
+ sensor = SageMakerNotebookSensor(
+ task_id="test_task",
+ execution_id="test_execution_id",
+ execution_name="test_execution_name",
+ )
+
+ # Test the poke method
+ result = sensor.poke()
+ assert result is True
+
mock_hook_instance.get_execution_status.assert_called_once_with(execution_id="test_execution_id")
+
+
@patch("airflow.providers.amazon.aws.sensors.sagemaker_unified_studio.SageMakerNotebookHook")
+ def test_poke_failure_state(self, mock_notebook_hook):
+ mock_hook_instance = mock_notebook_hook.return_value
+ mock_hook_instance.get_execution_status.return_value = "FAILED"
+
+ sensor = SageMakerNotebookSensor(
+ task_id="test_task",
+ execution_id="test_execution_id",
+ execution_name="test_execution_name",
+ )
+
+ # Test the poke method and assert exception
+ with pytest.raises(AirflowException, match="Exiting Execution
test_execution_id State: FAILED"):
+ sensor.poke()
+
+
mock_hook_instance.get_execution_status.assert_called_once_with(execution_id="test_execution_id")
+
+
+    @patch("airflow.providers.amazon.aws.sensors.sagemaker_unified_studio.SageMakerNotebookHook")
+ def test_poke_in_progress_state(self, mock_notebook_hook):
+ mock_hook_instance = mock_notebook_hook.return_value
+ mock_hook_instance.get_execution_status.return_value = "RUNNING"
+
+ sensor = SageMakerNotebookSensor(
+ task_id="test_task",
+ execution_id="test_execution_id",
+ execution_name="test_execution_name",
+ )
+
+ # Test the poke method
+ result = sensor.poke()
+ assert result is False
+
+        mock_hook_instance.get_execution_status.assert_called_once_with(execution_id="test_execution_id")
+
+ @patch.object(SageMakerNotebookSensor, "poke", return_value=True)
+ def test_execute_calls_poke(self, mock_poke):
+ # Create the sensor
+ sensor = SageMakerNotebookSensor(
+ task_id="test_task",
+ execution_id="test_execution_id",
+ execution_name="test_execution_name",
+ )
+
+ context = MagicMock(spec=Context)
+ sensor.execute(context=context)
+
+ # Assert that the poke method was called
+ mock_poke.assert_called_once_with(context)
diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/utils/test_sagemaker_unified_studio.py
new file mode 100644
index 00000000000..0730923fae1
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/utils/test_sagemaker_unified_studio.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+
+from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner, workflows_env_key
+
+
+def test_is_local_runner_false():
+ assert not is_local_runner()
+
+
+def test_is_local_runner_true():
+ os.environ[workflows_env_key] = "Local"
+ assert is_local_runner()
+
+
+def test_is_local_runner_false_with_env_var():
+ os.environ[workflows_env_key] = "False"
+ assert not is_local_runner()
+
+
+def test_is_local_runner_false_with_env_var_empty():
+ os.environ[workflows_env_key] = ""
+ assert not is_local_runner()
+
+
+def test_is_local_runner_false_with_env_var_invalid():
+ os.environ[workflows_env_key] = "random string"
+ assert not is_local_runner()
+
+
+def test_is_local_runner_false_with_string_int():
+ os.environ[workflows_env_key] = "1"
+ assert not is_local_runner()
diff --git
a/providers/fab/src/airflow/providers/3rd-party-licenses/LICENSES-ui.txt
b/providers/fab/src/airflow/providers/3rd-party-licenses/LICENSES-ui.txt
new file mode 100644
index 00000000000..7ad85fd1746
--- /dev/null
+++ b/providers/fab/src/airflow/providers/3rd-party-licenses/LICENSES-ui.txt
@@ -0,0 +1,89 @@
+Apache Airflow
+Copyright 2016-2023 The Apache Software Foundation
+
+This product includes software developed at The Apache Software
+Foundation (http://www.apache.org/).
+
+=======================================================================
+css-loader|5.2.7:
+-----
+MIT
+Copyright JS Foundation and other contributors
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+'Software'), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+webpack-contrib/css-loader
+
+
+moment|2.30.1:
+-----
+MIT
+Copyright (c) JS Foundation and other contributors
+
+Permission is hereby granted, free of charge, to any person
+obtaining a copy of this software and associated documentation
+files (the "Software"), to deal in the Software without
+restriction, including without limitation the rights to use,
+copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
+
+https://github.com/moment/moment.git
+
+
+moment-timezone|0.5.47:
+-----
+MIT
+The MIT License (MIT)
+
+Copyright (c) JS Foundation and other contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+https://github.com/moment/moment-timezone.git
+
+
diff --git a/tests/always/test_project_structure.py
b/tests/always/test_project_structure.py
index 310994cf38e..cd73fbce1ec 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -71,6 +71,7 @@ class TestProjectStructure:
"providers/amazon/tests/unit/amazon/aws/sensors/test_emr.py",
"providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker.py",
"providers/amazon/tests/unit/amazon/aws/test_exceptions.py",
+        "providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker_unified_studio.py",
"providers/amazon/tests/unit/amazon/aws/triggers/test_step_function.py",
"providers/amazon/tests/unit/amazon/aws/utils/test_rds.py",
"providers/amazon/tests/unit/amazon/aws/utils/test_sagemaker.py",
@@ -603,6 +604,8 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
         # These operations take a lot of time, there are commented out in the system tests for this reason
"airflow.providers.amazon.aws.operators.dms.DmsStartReplicationOperator",
"airflow.providers.amazon.aws.operators.dms.DmsStopReplicationOperator",
+        # These modules are used in the SageMakerNotebookOperator and therefore don't have their own examples
+        "airflow.providers.amazon.aws.sensors.sagemaker_unified_studio.SageMakerNotebookSensor",
}
DEPRECATED_CLASSES = {