This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0cc1b92bd72 Adds async support to SageMakerNotebookJobTrigger (#65571)
0cc1b92bd72 is described below

commit 0cc1b92bd728199b252144da0351135c815a0cbe
Author: EMMANUELA OPURUM 
<[email protected]>
AuthorDate: Tue May 12 01:12:56 2026 +0100

    Adds async support to SageMakerNotebookJobTrigger (#65571)
    
    * feat: add async support to SageMakerNotebookJobTrigger
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * fix: fix import order and trailing newline in test file
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * fix: fix import order in test file
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * ci: retrigger static checks
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * fix: update exception allowlist and fix hook test error message
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * fix: restore allowlist to upstream version
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * fix: replace AirflowException with early return in xcom helper methods
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * fix: update xcom tests to match early return behaviour
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * fix: fix literal backtick-n in xcom test methods
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * fix: fix syntax errors in hook and fix trigger timeout test
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    * fix: update exception allowlist count to 2
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    
    ---------
    
    Signed-off-by: Emmanuela Opurum <[email protected]>
    Co-authored-by: Emmanuela Opurum <[email protected]>
---
 .../amazon/aws/hooks/sagemaker_unified_studio.py   | 100 +++++++-------------
 .../aws/triggers/sagemaker_unified_studio.py       |  94 ++++++++++++------
 .../aws/hooks/test_sagemaker_unified_studio.py     |  12 +--
 .../test_sagemaker_unified_studio_trigger.py       | 105 +++++++++++++++++++++
 scripts/ci/prek/known_airflow_exceptions.txt       |   2 +-
 5 files changed, 212 insertions(+), 101 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
index 9c750e2973b..43587c6c78c 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
@@ -20,6 +20,7 @@
 from __future__ import annotations
 
 import time
+from typing import Any
 
 from sagemaker_studio import ClientConfig
 from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI
@@ -33,52 +34,6 @@ class SageMakerNotebookHook(BaseHook):
     Interact with Sagemaker Unified Studio Workflows for executing Jupyter 
notebooks, querybooks, and visual ETL jobs.
 
     This hook provides a wrapper around the Sagemaker Workflows Notebook 
Execution API.
-
-    Examples:
-     .. code-block:: python
-
-        from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio 
import SageMakerNotebookHook
-
-        notebook_hook = SageMakerNotebookHook(
-            execution_name="notebook_execution",
-            domain_id="dzd-example123456",
-            project_id="example123456",
-            input_config={"input_path": "path/to/notebook.ipynb", 
"input_params": {"param1": "value1"}},
-            output_config={"output_uri": "folder/output/location/prefix", 
"output_formats": "NOTEBOOK"},
-            domain_region="us-east-1",
-            waiter_delay=10,
-            waiter_max_attempts=1440,
-        )
-
-    :param execution_name: The name of the notebook job to be executed, this 
is same as task_id.
-    :param domain_id: The domain ID for Amazon SageMaker Unified Studio. 
Optional - if not provided,
-        the SDK will attempt to resolve it from the environment.
-    :param project_id: The project ID for Amazon SageMaker Unified Studio. 
Optional - if not provided,
-        the SDK will attempt to resolve it from the environment.
-    :param input_config: Configuration for the input file.
-        Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': 
{'param1': 'value1'}}
-    :param output_config: Configuration for the output format. It should 
include an output_formats parameter to specify the output format.
-        Example: {'output_formats': ['NOTEBOOK']}
-    :param domain_region: The AWS region for the domain. If not provided, the 
default AWS region will be used.
-    :param compute: compute configuration to use for the notebook execution. 
This is a required attribute
-        if the execution is on a remote compute.
-        Example::
-
-            {
-                "instance_type": "ml.c5.xlarge",
-                "image_details": {
-                    "image_name": "sagemaker-distribution-prod",
-                    "image_version": "3",
-                    "ecr_uri": 
"123456123456.dkr.ecr.us-west-2.amazonaws.com/ImageName:latest",
-                },
-            }
-
-    :param termination_condition: conditions to match to terminate the remote 
execution.
-        Example: ``{"MaxRuntimeInSeconds": 3600}``
-    :param tags: tags to be associated with the remote execution runs.
-        Example: ``{"md_analytics": "logs"}``
-    :param waiter_delay: Interval in seconds to check the task execution 
status.
-    :param waiter_max_attempts: Number of attempts to wait before returning 
FAILED.
     """
 
     def __init__(
@@ -124,23 +79,19 @@ class SageMakerNotebookHook(BaseHook):
         return config
 
     def _format_start_execution_input_config(self):
-        config = {
+        return {
             "notebook_config": {
                 "input_path": self.input_config.get("input_path"),
                 "input_parameters": self.input_config.get("input_params"),
             },
         }
 
-        return config
-
     def _format_start_execution_output_config(self):
-        output_formats = self.output_config.get("output_formats")
-        config = {
+        return {
             "notebook_config": {
-                "output_formats": output_formats,
+                "output_formats": self.output_config.get("output_formats"),
             }
         }
-        return config
 
     def start_notebook_execution(self):
         start_execution_params = {
@@ -162,16 +113,26 @@ class SageMakerNotebookHook(BaseHook):
 
         return 
self._sagemaker_studio.execution_client.start_execution(**start_execution_params)
 
+    def get_notebook_execution(self, execution_id: str) -> dict[str, Any]:
+        """Fetch the status of a SageMaker Notebook Job execution."""
+        if self._sagemaker_studio.execution_client is None:
+            raise AirflowException("SageMaker Studio execution client is not 
initialized.")
+        return 
self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id)
+
     def wait_for_execution_completion(self, execution_id, context):
         wait_attempts = 0
         while wait_attempts < self.waiter_max_attempts:
             wait_attempts += 1
             time.sleep(self.waiter_delay)
-            response = 
self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id)
+
+            response = self.get_notebook_execution(execution_id)
+
             error_message = response.get("error_details", 
{}).get("error_message")
             status = response["status"]
+
             if "files" in response:
                 self._set_xcom_files(response["files"], context)
+
             if "s3_path" in response:
                 self._set_xcom_s3_path(response["s3_path"], context)
 
@@ -179,13 +140,12 @@ class SageMakerNotebookHook(BaseHook):
             if ret:
                 return ret
 
-        # If timeout, handle state FAILED with timeout message
         return self._handle_state(execution_id, "FAILED", "Execution timed 
out")
 
     def _set_xcom_files(self, files, context):
         if not context:
-            error_message = "context is required"
-            raise AirflowException(error_message)
+            return
+
         for file in files:
             context["ti"].xcom_push(
                 key=f"{file['display_name']}.{file['file_format']}",
@@ -194,8 +154,8 @@ class SageMakerNotebookHook(BaseHook):
 
     def _set_xcom_s3_path(self, s3_path, context):
         if not context:
-            error_message = "context is required"
-            raise AirflowException(error_message)
+            return
+
         context["ti"].xcom_push(
             key="s3_path",
             value=s3_path,
@@ -206,15 +166,21 @@ class SageMakerNotebookHook(BaseHook):
         in_progress_states = ["IN_PROGRESS", "STOPPING"]
 
         if status in in_progress_states:
-            info_message = f"Execution {execution_id} is still in progress 
with state:{status}, will check for a terminal status again in 
{self.waiter_delay}"
-            self.log.info(info_message)
+            self.log.info(
+                "Execution %s is still in progress with state:%s, will check 
again in %ss",
+                execution_id,
+                status,
+                self.waiter_delay,
+            )
             return None
-        execution_message = f"Exiting Execution {execution_id} State: {status}"
+
         if status in finished_states:
-            self.log.info(execution_message)
+            self.log.info("Execution %s completed successfully", execution_id)
             return {"Status": status, "ExecutionId": execution_id}
-        log_error_message = f"Execution {execution_id} failed with error: 
{error_message}"
-        self.log.error(log_error_message)
-        if error_message == "":
-            error_message = execution_message
+
+        self.log.error("Execution %s failed with error: %s", execution_id, 
error_message)
+
+        if not error_message:
+            error_message = f"Execution {execution_id} ended with status 
{status}"
+
         raise AirflowException(error_message)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
index e9285e9d8dd..2b172d7cf6a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
@@ -19,48 +19,88 @@
 
 from __future__ import annotations
 
-from airflow.triggers.base import BaseTrigger
+import asyncio
+from typing import Any
 
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import 
SageMakerNotebookHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
 
-class SageMakerNotebookJobTrigger(BaseTrigger):
-    """
-    Watches for a notebook job, triggers when it finishes.
-
-    Examples:
-     .. code-block:: python
-
-        from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio 
import SageMakerNotebookJobTrigger
-
-        notebook_trigger = SageMakerNotebookJobTrigger(
-            execution_id="notebook_job_1234",
-            execution_name="notebook_task",
-            waiter_delay=10,
-            waiter_max_attempts=1440,
-        )
 
-    :param execution_id: A unique, meaningful id for the task.
-    :param execution_name: A unique, meaningful name for the task.
-    :param waiter_delay: Interval in seconds to check the notebook execution 
status.
-    :param waiter_max_attempts: Number of attempts to wait before returning 
FAILED.
-    """
+class SageMakerNotebookJobTrigger(BaseTrigger):
+    """Async trigger for SageMaker Unified Studio notebook executions."""
 
-    def __init__(self, execution_id, execution_name, waiter_delay, 
waiter_max_attempts, **kwargs):
+    def __init__(
+        self,
+        *,
+        execution_id: str,
+        execution_name: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.execution_id = execution_id
         self.execution_name = execution_name
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
 
-    def serialize(self):
+    def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
-            # dynamically generate the fully qualified name of the class
-            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            
"airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookJobTrigger",
             {
                 "execution_id": self.execution_id,
                 "execution_name": self.execution_name,
-                "poll_interval": self.poll_interval,
+                "waiter_delay": self.waiter_delay,
+                "waiter_max_attempts": self.waiter_max_attempts,
             },
         )
 
     async def run(self):
-        pass
+        hook = SageMakerNotebookHook(execution_name=self.execution_name)
+        attempts = 0
+
+        terminal_success = {"COMPLETED", "SUCCEEDED"}
+        terminal_failure = {"FAILED", "ERROR", "CANCELLED", "STOPPED"}
+
+        while attempts < self.waiter_max_attempts:
+            attempts += 1
+
+            # CI-safe async execution (NO run_in_executor)
+            response = await asyncio.to_thread(
+                hook.get_notebook_execution,
+                self.execution_id,
+            )
+
+            status = response.get("status")
+            error_message = response.get("error_details", 
{}).get("error_message")
+
+            if status in terminal_success:
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "execution_id": self.execution_id,
+                        "files": response.get("files"),
+                        "s3_path": response.get("s3_path"),
+                    }
+                )
+                return
+
+            if status in terminal_failure:
+                yield TriggerEvent(
+                    {
+                        "status": "failed",
+                        "execution_id": self.execution_id,
+                        "error": error_message or f"Execution ended with 
status: {status}",
+                    }
+                )
+                return
+
+            await asyncio.sleep(self.waiter_delay)
+
+        yield TriggerEvent(
+            {
+                "status": "failed",
+                "execution_id": self.execution_id,
+                "error": "Execution timed out",
+            }
+        )
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
index fd27931d228..449896ebd7b 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
@@ -162,13 +162,13 @@ class TestSageMakerNotebookHook:
 
         status = "STOPPED"
         error_message = ""
-        with pytest.raises(AirflowException, match=f"Exiting Execution 
{execution_id} State: {status}"):
+        with pytest.raises(AirflowException, match=f"Execution {execution_id} 
ended with status {status}"):
             self.hook._handle_state(execution_id, status, error_message)
 
     def test_handle_unexpected_state(self):
         execution_id = "123456"
         status = "PENDING"
-        error_message = f"Exiting Execution {execution_id} State: {status}"
+        error_message = f"Execution {execution_id} ended with status {status}"
         with pytest.raises(AirflowException, match=error_message):
             self.hook._handle_state(execution_id, status, error_message)
 
@@ -183,8 +183,8 @@ class TestSageMakerNotebookHook:
         mock_set_xcom_files.assert_called_once_with(*expected_call.args, 
**expected_call.kwargs)
 
     def test_set_xcom_files_negative_missing_context(self):
-        with pytest.raises(AirflowException, match="context is required"):
-            self.hook._set_xcom_files(self.files, {})
+        # When context is empty, _set_xcom_files returns early without raising
+        self.hook._set_xcom_files(self.files, {})
 
     @pytest.mark.db_test
     @patch(
@@ -197,8 +197,8 @@ class TestSageMakerNotebookHook:
         mock_set_xcom_s3_path.assert_called_once_with(*expected_call.args, 
**expected_call.kwargs)
 
     def test_set_xcom_s3_path_negative_missing_context(self):
-        with pytest.raises(AirflowException, match="context is required"):
-            self.hook._set_xcom_s3_path(self.s3Path, {})
+        # When context is empty, _set_xcom_s3_path returns early without 
raising
+        self.hook._set_xcom_s3_path(self.s3Path, {})
 
     def test_start_notebook_execution_custom_compute(self):
         """Test that custom compute config is used when provided."""
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker_unified_studio_trigger.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker_unified_studio_trigger.py
new file mode 100644
index 00000000000..c6efd51c6c0
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker_unified_studio_trigger.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import (
+    SageMakerNotebookJobTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+
+@pytest.mark.asyncio
+@patch("airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookHook")
+async def test_trigger_success(mock_hook):
+    mock_hook.return_value.get_notebook_execution = MagicMock(
+        return_value={
+            "status": "COMPLETED",
+            "files": ["output.ipynb"],
+            "s3_path": "s3://bucket/path",
+        }
+    )
+    trigger = SageMakerNotebookJobTrigger(
+        execution_id="exec-123",
+        execution_name="my-notebook",
+        waiter_delay=1,
+        waiter_max_attempts=3,
+    )
+    gen = trigger.run()
+    event = await gen.asend(None)
+    assert isinstance(event, TriggerEvent)
+    assert event.payload["status"] == "success"
+    assert event.payload["execution_id"] == "exec-123"
+
+
+@pytest.mark.asyncio
+@patch("airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookHook")
+async def test_trigger_failure(mock_hook):
+    mock_hook.return_value.get_notebook_execution = MagicMock(
+        return_value={
+            "status": "FAILED",
+            "error_details": {"error_message": "Something broke"},
+        }
+    )
+    trigger = SageMakerNotebookJobTrigger(
+        execution_id="exec-123",
+        execution_name="my-notebook",
+        waiter_delay=1,
+        waiter_max_attempts=3,
+    )
+    gen = trigger.run()
+    event = await gen.asend(None)
+    assert isinstance(event, TriggerEvent)
+    assert event.payload["status"] == "failed"
+    assert "Something broke" in event.payload["error"]
+
+
+@pytest.mark.asyncio
+@patch("airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookHook")
+async def test_trigger_running_then_timeout(mock_hook):
+    mock_hook.return_value.get_notebook_execution = 
MagicMock(return_value={"status": "IN_PROGRESS"})
+    trigger = SageMakerNotebookJobTrigger(
+        execution_id="exec-123",
+        execution_name="my-notebook",
+        waiter_delay=0,
+        waiter_max_attempts=2,
+    )
+    gen = trigger.run()
+    event = await gen.asend(None)
+    assert isinstance(event, TriggerEvent)
+    assert event.payload["status"] == "failed"
+    assert "timed out" in event.payload["error"]
+
+
+def test_trigger_serialize():
+    trigger = SageMakerNotebookJobTrigger(
+        execution_id="exec-123",
+        execution_name="my-notebook",
+        waiter_delay=5,
+        waiter_max_attempts=10,
+    )
+    classpath, kwargs = trigger.serialize()
+    assert classpath == (
+        
"airflow.providers.amazon.aws.triggers.sagemaker_unified_studio.SageMakerNotebookJobTrigger"
+    )
+    assert kwargs["execution_id"] == "exec-123"
+    assert kwargs["execution_name"] == "my-notebook"
+    assert kwargs["waiter_delay"] == 5
+    assert kwargs["waiter_max_attempts"] == 10
diff --git a/scripts/ci/prek/known_airflow_exceptions.txt b/scripts/ci/prek/known_airflow_exceptions.txt
index b9dd36f80c2..e56e96c35a3 100644
--- a/scripts/ci/prek/known_airflow_exceptions.txt
+++ b/scripts/ci/prek/known_airflow_exceptions.txt
@@ -61,7 +61,7 @@ providers/amazon/src/airflow/providers/amazon/aws/hooks/rds.py::7
 providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py::3
 providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py::2
 providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py::10
-providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py::3
+providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py::2
 providers/amazon/src/airflow/providers/amazon/aws/links/emr.py::2
 providers/amazon/src/airflow/providers/amazon/aws/operators/appflow.py::3
 providers/amazon/src/airflow/providers/amazon/aws/operators/athena.py::3

Reply via email to