This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0cc1b92bd72 Adds async support to SageMakerNotebookJobTrigger (#65571)
0cc1b92bd72 is described below
commit 0cc1b92bd728199b252144da0351135c815a0cbe
Author: EMMANUELA OPURUM <[email protected]>
AuthorDate: Tue May 12 01:12:56 2026 +0100
Adds async support to SageMakerNotebookJobTrigger (#65571)
* feat: add async support to SageMakerNotebookJobTrigger
Signed-off-by: Emmanuela Opurum <[email protected]>
* fix: fix import order and trailing newline in test file
Signed-off-by: Emmanuela Opurum <[email protected]>
* fix: fix import order in test file
Signed-off-by: Emmanuela Opurum <[email protected]>
* ci: retrigger static checks
Signed-off-by: Emmanuela Opurum <[email protected]>
* fix: update exception allowlist and fix hook test error message
Signed-off-by: Emmanuela Opurum <[email protected]>
* fix: restore allowlist to upstream version
Signed-off-by: Emmanuela Opurum <[email protected]>
* fix: replace AirflowException with early return in xcom helper methods
Signed-off-by: Emmanuela Opurum <[email protected]>
* fix: update xcom tests to match early return behaviour
Signed-off-by: Emmanuela Opurum <[email protected]>
* fix: fix literal backtick-n in xcom test methods
Signed-off-by: Emmanuela Opurum <[email protected]>
* fix: fix syntax errors in hook and fix trigger timeout test
Signed-off-by: Emmanuela Opurum <[email protected]>
* fix: update exception allowlist count to 2
Signed-off-by: Emmanuela Opurum <[email protected]>
---------
Signed-off-by: Emmanuela Opurum <[email protected]>
Co-authored-by: Emmanuela Opurum <[email protected]>
---
.../amazon/aws/hooks/sagemaker_unified_studio.py | 100 +++++++-------------
.../aws/triggers/sagemaker_unified_studio.py | 94 ++++++++++++------
.../aws/hooks/test_sagemaker_unified_studio.py | 12 +--
.../test_sagemaker_unified_studio_trigger.py | 105 +++++++++++++++++++++
scripts/ci/prek/known_airflow_exceptions.txt | 2 +-
5 files changed, 212 insertions(+), 101 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
index 9c750e2973b..43587c6c78c 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
@@ -20,6 +20,7 @@
from __future__ import annotations
import time
+from typing import Any
from sagemaker_studio import ClientConfig
from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI
@@ -33,52 +34,6 @@ class SageMakerNotebookHook(BaseHook):
Interact with Sagemaker Unified Studio Workflows for executing Jupyter
notebooks, querybooks, and visual ETL jobs.
This hook provides a wrapper around the Sagemaker Workflows Notebook
Execution API.
-
- Examples:
- .. code-block:: python
-
- from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook
-
- notebook_hook = SageMakerNotebookHook(
- execution_name="notebook_execution",
- domain_id="dzd-example123456",
- project_id="example123456",
- input_config={"input_path": "path/to/notebook.ipynb", "input_params": {"param1": "value1"}},
- output_config={"output_uri": "folder/output/location/prefix", "output_formats": "NOTEBOOK"},
- domain_region="us-east-1",
- waiter_delay=10,
- waiter_max_attempts=1440,
- )
-
- :param execution_name: The name of the notebook job to be executed, this is same as task_id.
- :param domain_id: The domain ID for Amazon SageMaker Unified Studio. Optional - if not provided,
- the SDK will attempt to resolve it from the environment.
- :param project_id: The project ID for Amazon SageMaker Unified Studio. Optional - if not provided,
- the SDK will attempt to resolve it from the environment.
- :param input_config: Configuration for the input file.
- Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}}
- :param output_config: Configuration for the output format. It should include an output_formats parameter to specify the output format.
- Example: {'output_formats': ['NOTEBOOK']}
- :param domain_region: The AWS region for the domain. If not provided, the default AWS region will be used.
- :param compute: compute configuration to use for the notebook execution. This is a required attribute
- if the execution is on a remote compute.
- Example::
-
- {
- "instance_type": "ml.c5.xlarge",
- "image_details": {
- "image_name": "sagemaker-distribution-prod",
- "image_version": "3",
- "ecr_uri": "123456123456.dkr.ecr.us-west-2.amazonaws.com/ImageName:latest",
- },
- }
-
- :param termination_condition: conditions to match to terminate the remote execution.
- Example: ``{"MaxRuntimeInSeconds": 3600}``
- :param tags: tags to be associated with the remote execution runs.
- Example: ``{"md_analytics": "logs"}``
- :param waiter_delay: Interval in seconds to check the task execution status.
- :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
"""
def __init__(
@@ -124,23 +79,19 @@ class SageMakerNotebookHook(BaseHook):
return config
def _format_start_execution_input_config(self):
- config = {
+ return {
"notebook_config": {
"input_path": self.input_config.get("input_path"),
"input_parameters": self.input_config.get("input_params"),
},
}
- return config
-
def _format_start_execution_output_config(self):
- output_formats = self.output_config.get("output_formats")
- config = {
+ return {
"notebook_config": {
- "output_formats": output_formats,
+ "output_formats": self.output_config.get("output_formats"),
}
}
- return config
def start_notebook_execution(self):
start_execution_params = {
@@ -162,16 +113,26 @@ class SageMakerNotebookHook(BaseHook):
return self._sagemaker_studio.execution_client.start_execution(**start_execution_params)
+ def get_notebook_execution(self, execution_id: str) -> dict[str, Any]:
+ """Fetch the status of a SageMaker Notebook Job execution."""
+ if self._sagemaker_studio.execution_client is None:
+ raise AirflowException("SageMaker Studio execution client is not initialized.")
+ return self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id)
+
def wait_for_execution_completion(self, execution_id, context):
wait_attempts = 0
while wait_attempts < self.waiter_max_attempts:
wait_attempts += 1
time.sleep(self.waiter_delay)
- response = self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id)
+
+ response = self.get_notebook_execution(execution_id)
+
error_message = response.get("error_details", {}).get("error_message")
status = response["status"]
+
if "files" in response:
self._set_xcom_files(response["files"], context)
+
if "s3_path" in response:
self._set_xcom_s3_path(response["s3_path"], context)
@@ -179,13 +140,12 @@ class SageMakerNotebookHook(BaseHook):
if ret:
return ret
- # If timeout, handle state FAILED with timeout message
return self._handle_state(execution_id, "FAILED", "Execution timed out")
def _set_xcom_files(self, files, context):
if not context:
- error_message = "context is required"
- raise AirflowException(error_message)
+ return
+
for file in files:
context["ti"].xcom_push(
key=f"{file['display_name']}.{file['file_format']}",
@@ -194,8 +154,8 @@ class SageMakerNotebookHook(BaseHook):
def _set_xcom_s3_path(self, s3_path, context):
if not context:
- error_message = "context is required"
- raise AirflowException(error_message)
+ return
+
context["ti"].xcom_push(
key="s3_path",
value=s3_path,
@@ -206,15 +166,21 @@ class SageMakerNotebookHook(BaseHook):
in_progress_states = ["IN_PROGRESS", "STOPPING"]
if status in in_progress_states:
- info_message = f"Execution {execution_id} is still in progress with state:{status}, will check for a terminal status again in {self.waiter_delay}"
- self.log.info(info_message)
+ self.log.info(
+ "Execution %s is still in progress with state:%s, will check again in %ss",
+ execution_id,
+ status,
+ self.waiter_delay,
+ )
return None
- execution_message = f"Exiting Execution {execution_id} State: {status}"
+
if status in finished_states:
- self.log.info(execution_message)
+ self.log.info("Execution %s completed successfully", execution_id)
return {"Status": status, "ExecutionId": execution_id}
- log_error_message = f"Execution {execution_id} failed with error: {error_message}"
- self.log.error(log_error_message)
- if error_message == "":
- error_message = execution_message
+
+ self.log.error("Execution %s failed with error: %s", execution_id, error_message)
+
+ if not error_message:
+ error_message = f"Execution {execution_id} ended with status {status}"
+
raise AirflowException(error_message)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
index e9285e9d8dd..2b172d7cf6a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
@@ -19,48 +19,88 @@
from __future__ import annotations
-from airflow.triggers.base import BaseTrigger
+import asyncio
+from typing import Any
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
-class SageMakerNotebookJobTrigger(BaseTrigger):
- """
- Watches for a notebook job, triggers when it finishes.
-
- Examples:
- .. code-block:: python
-
- from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import SageMakerNotebookJobTrigger
-
- notebook_trigger = SageMakerNotebookJobTrigger(
- execution_id="notebook_job_1234",
- execution_name="notebook_task",
- waiter_delay=10,
- waiter_max_attempts=1440,
- )
- :param execution_id: A unique, meaningful id for the task.
- :param execution_name: A unique, meaningful name for the task.
- :param waiter_delay: Interval in seconds to check the notebook execution status.
- :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
- """
+class SageMakerNotebookJobTrigger(BaseTrigger):
+ """Async trigger for SageMaker Unified Studio notebook executions."""
- def __init__(self, execution_id, execution_name, waiter_delay, waiter_max_attempts, **kwargs):
+ def __init__(
+ self,
+ *,
+ execution_id: str,
+ execution_name: str,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.execution_id = execution_id
self.execution_name = execution_name
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
- def serialize(self):
+ def serialize(self) -> tuple[str, dict[str, Any]]:
return (
- # dynamically generate the fully qualified name of the class
- self.__class__.__module__ + "." + self.__class__.__qualname__,
+ "airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookJobTrigger",
{
"execution_id": self.execution_id,
"execution_name": self.execution_name,
- "poll_interval": self.poll_interval,
+ "waiter_delay": self.waiter_delay,
+ "waiter_max_attempts": self.waiter_max_attempts,
},
)
async def run(self):
- pass
+ hook = SageMakerNotebookHook(execution_name=self.execution_name)
+ attempts = 0
+
+ terminal_success = {"COMPLETED", "SUCCEEDED"}
+ terminal_failure = {"FAILED", "ERROR", "CANCELLED", "STOPPED"}
+
+ while attempts < self.waiter_max_attempts:
+ attempts += 1
+
+ # CI-safe async execution (NO run_in_executor)
+ response = await asyncio.to_thread(
+ hook.get_notebook_execution,
+ self.execution_id,
+ )
+
+ status = response.get("status")
+ error_message = response.get("error_details", {}).get("error_message")
+
+ if status in terminal_success:
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "execution_id": self.execution_id,
+ "files": response.get("files"),
+ "s3_path": response.get("s3_path"),
+ }
+ )
+ return
+
+ if status in terminal_failure:
+ yield TriggerEvent(
+ {
+ "status": "failed",
+ "execution_id": self.execution_id,
+ "error": error_message or f"Execution ended with status: {status}",
+ }
+ )
+ return
+
+ await asyncio.sleep(self.waiter_delay)
+
+ yield TriggerEvent(
+ {
+ "status": "failed",
+ "execution_id": self.execution_id,
+ "error": "Execution timed out",
+ }
+ )
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
index fd27931d228..449896ebd7b 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
@@ -162,13 +162,13 @@ class TestSageMakerNotebookHook:
status = "STOPPED"
error_message = ""
- with pytest.raises(AirflowException, match=f"Exiting Execution {execution_id} State: {status}"):
+ with pytest.raises(AirflowException, match=f"Execution {execution_id} ended with status {status}"):
self.hook._handle_state(execution_id, status, error_message)
def test_handle_unexpected_state(self):
execution_id = "123456"
status = "PENDING"
- error_message = f"Exiting Execution {execution_id} State: {status}"
+ error_message = f"Execution {execution_id} ended with status {status}"
with pytest.raises(AirflowException, match=error_message):
self.hook._handle_state(execution_id, status, error_message)
@@ -183,8 +183,8 @@ class TestSageMakerNotebookHook:
mock_set_xcom_files.assert_called_once_with(*expected_call.args, **expected_call.kwargs)
def test_set_xcom_files_negative_missing_context(self):
- with pytest.raises(AirflowException, match="context is required"):
- self.hook._set_xcom_files(self.files, {})
+ # When context is empty, _set_xcom_files returns early without raising
+ self.hook._set_xcom_files(self.files, {})
@pytest.mark.db_test
@patch(
@@ -197,8 +197,8 @@ class TestSageMakerNotebookHook:
mock_set_xcom_s3_path.assert_called_once_with(*expected_call.args, **expected_call.kwargs)
def test_set_xcom_s3_path_negative_missing_context(self):
- with pytest.raises(AirflowException, match="context is required"):
- self.hook._set_xcom_s3_path(self.s3Path, {})
+ # When context is empty, _set_xcom_s3_path returns early without raising
+ self.hook._set_xcom_s3_path(self.s3Path, {})
def test_start_notebook_execution_custom_compute(self):
"""Test that custom compute config is used when provided."""
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker_unified_studio_trigger.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker_unified_studio_trigger.py
new file mode 100644
index 00000000000..c6efd51c6c0
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker_unified_studio_trigger.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import (
+ SageMakerNotebookJobTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+
+@pytest.mark.asyncio
+@patch("airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookHook")
+async def test_trigger_success(mock_hook):
+ mock_hook.return_value.get_notebook_execution = MagicMock(
+ return_value={
+ "status": "COMPLETED",
+ "files": ["output.ipynb"],
+ "s3_path": "s3://bucket/path",
+ }
+ )
+ trigger = SageMakerNotebookJobTrigger(
+ execution_id="exec-123",
+ execution_name="my-notebook",
+ waiter_delay=1,
+ waiter_max_attempts=3,
+ )
+ gen = trigger.run()
+ event = await gen.asend(None)
+ assert isinstance(event, TriggerEvent)
+ assert event.payload["status"] == "success"
+ assert event.payload["execution_id"] == "exec-123"
+
+
+@pytest.mark.asyncio
+@patch("airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookHook")
+async def test_trigger_failure(mock_hook):
+ mock_hook.return_value.get_notebook_execution = MagicMock(
+ return_value={
+ "status": "FAILED",
+ "error_details": {"error_message": "Something broke"},
+ }
+ )
+ trigger = SageMakerNotebookJobTrigger(
+ execution_id="exec-123",
+ execution_name="my-notebook",
+ waiter_delay=1,
+ waiter_max_attempts=3,
+ )
+ gen = trigger.run()
+ event = await gen.asend(None)
+ assert isinstance(event, TriggerEvent)
+ assert event.payload["status"] == "failed"
+ assert "Something broke" in event.payload["error"]
+
+
+@pytest.mark.asyncio
+@patch("airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookHook")
+async def test_trigger_running_then_timeout(mock_hook):
+ mock_hook.return_value.get_notebook_execution = MagicMock(return_value={"status": "IN_PROGRESS"})
+ trigger = SageMakerNotebookJobTrigger(
+ execution_id="exec-123",
+ execution_name="my-notebook",
+ waiter_delay=0,
+ waiter_max_attempts=2,
+ )
+ gen = trigger.run()
+ event = await gen.asend(None)
+ assert isinstance(event, TriggerEvent)
+ assert event.payload["status"] == "failed"
+ assert "timed out" in event.payload["error"]
+
+
+def test_trigger_serialize():
+ trigger = SageMakerNotebookJobTrigger(
+ execution_id="exec-123",
+ execution_name="my-notebook",
+ waiter_delay=5,
+ waiter_max_attempts=10,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == (
+ "airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookJobTrigger"
+ )
+ assert kwargs["execution_id"] == "exec-123"
+ assert kwargs["execution_name"] == "my-notebook"
+ assert kwargs["waiter_delay"] == 5
+ assert kwargs["waiter_max_attempts"] == 10
diff --git a/scripts/ci/prek/known_airflow_exceptions.txt b/scripts/ci/prek/known_airflow_exceptions.txt
index b9dd36f80c2..e56e96c35a3 100644
--- a/scripts/ci/prek/known_airflow_exceptions.txt
+++ b/scripts/ci/prek/known_airflow_exceptions.txt
@@ -61,7 +61,7 @@
providers/amazon/src/airflow/providers/amazon/aws/hooks/rds.py::7
providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py::3
providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py::2
providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py::10
-providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py::3
+providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py::2
providers/amazon/src/airflow/providers/amazon/aws/links/emr.py::2
providers/amazon/src/airflow/providers/amazon/aws/operators/appflow.py::3
providers/amazon/src/airflow/providers/amazon/aws/operators/athena.py::3