This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 446b6ae6359 Add SageMaker Unified Studio domain_id, project_id, 
domain_region as new parameters to SageMakerNotebookOperator (#62147)
446b6ae6359 is described below

commit 446b6ae63595aad6c9ed31a9ecd96c554508d5f0
Author: Nikita Arbuzov <[email protected]>
AuthorDate: Wed Feb 18 19:43:07 2026 -0500

    Add SageMaker Unified Studio domain_id, project_id, domain_region as new 
parameters to SageMakerNotebookOperator (#62147)
    
    * add SageMaker Unified Studio domain_id, project_id, domain_region as new 
parameters to SageMakerNotebookOperator
    
    * Update sagemaker-studio version requirement
    
    * Add unit tests for domain_id, project_id, domain_region params and fix 
static check failures
    ---------
    
    Co-authored-by: Rui Jiang <[email protected]>
    Co-authored-by: Rui Jiang <[email protected]>
---
 docs/spelling_wordlist.txt                         |  2 +
 .../amazon/aws/hooks/sagemaker_unified_studio.py   | 44 ++++++++--
 .../aws/operators/sagemaker_unified_studio.py      | 36 +++++++-
 .../aws/hooks/test_sagemaker_unified_studio.py     | 98 +++++++++++++++++-----
 .../aws/operators/test_sagemaker_unified_studio.py | 76 +++++++++++++++++
 5 files changed, 222 insertions(+), 34 deletions(-)

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 56bfd885844..205e68fd66a 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -541,6 +541,7 @@ DisplayVideo
 distcp
 distro
 distros
+dkr
 Dlp
 dlp
 DlpJob
@@ -2096,6 +2097,7 @@ xcomresult
 XComs
 Xero
 Xiaodong
+xlarge
 xml
 xmltodict
 xpath
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
index 0895ffc3f8b..ac06fcbacf2 100644
--- 
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
+++ 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
@@ -40,24 +40,43 @@ class SageMakerNotebookHook(BaseHook):
         from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio 
import SageMakerNotebookHook
 
         notebook_hook = SageMakerNotebookHook(
+            execution_name="notebook_execution",
+            domain_id="dzd-example123456",
+            project_id="example123456",
             input_config={"input_path": "path/to/notebook.ipynb", 
"input_params": {"param1": "value1"}},
             output_config={"output_uri": "folder/output/location/prefix", 
"output_formats": "NOTEBOOK"},
-            execution_name="notebook_execution",
+            domain_region="us-east-1",
             waiter_delay=10,
             waiter_max_attempts=1440,
         )
 
     :param execution_name: The name of the notebook job to be executed, this 
is same as task_id.
+    :param domain_id: The domain ID for Amazon SageMaker Unified Studio. 
Optional - if not provided,
+        the SDK will attempt to resolve it from the environment.
+    :param project_id: The project ID for Amazon SageMaker Unified Studio. 
Optional - if not provided,
+        the SDK will attempt to resolve it from the environment.
     :param input_config: Configuration for the input file.
         Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': 
{'param1': 'value1'}}
     :param output_config: Configuration for the output format. It should 
include an output_formats parameter to specify the output format.
         Example: {'output_formats': ['NOTEBOOK']}
-    :param compute: compute configuration to use for the notebook execution. 
This is a required attribute if the execution is on a remote compute.
-        Example: {"instance_type": "ml.m5.large", "volume_size_in_gb": 30, 
"volume_kms_key_id": "", "image_details": {"ecr_uri": "string"}, 
"container_entrypoint": ["string"]}
+    :param domain_region: The AWS region for the domain. If not provided, the 
default AWS region will be used.
+    :param compute: compute configuration to use for the notebook execution. 
This is a required attribute
+        if the execution is on a remote compute.
+        Example::
+
+            {
+                "instance_type": "ml.c5.xlarge",
+                "image_details": {
+                    "image_name": "sagemaker-distribution-prod",
+                    "image_version": "3",
+                    "ecr_uri": 
"123456123456.dkr.ecr.us-west-2.amazonaws.com/ImageName:latest",
+                },
+            }
+
     :param termination_condition: conditions to match to terminate the remote 
execution.
-        Example: {"MaxRuntimeInSeconds": 3600}
+        Example: ``{"MaxRuntimeInSeconds": 3600}``
     :param tags: tags to be associated with the remote execution runs.
-        Example: {"md_analytics": "logs"}
+        Example: ``{"md_analytics": "logs"}``
     :param waiter_delay: Interval in seconds to check the task execution 
status.
     :param waiter_max_attempts: Number of attempts to wait before returning 
FAILED.
     """
@@ -66,7 +85,10 @@ class SageMakerNotebookHook(BaseHook):
         self,
         execution_name: str,
         input_config: dict | None = None,
+        domain_id: str | None = None,
+        project_id: str | None = None,
         output_config: dict | None = None,
+        domain_region: str | None = None,
         compute: dict | None = None,
         termination_condition: dict | None = None,
         tags: dict | None = None,
@@ -78,6 +100,9 @@ class SageMakerNotebookHook(BaseHook):
         super().__init__(*args, **kwargs)
         self._sagemaker_studio = 
SageMakerStudioAPI(self._get_sagemaker_studio_config())
         self.execution_name = execution_name
+        self.domain_id = domain_id
+        self.project_id = project_id
+        self.domain_region = domain_region
         self.input_config = input_config or {}
         self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
         self.compute = compute
@@ -114,17 +139,20 @@ class SageMakerNotebookHook(BaseHook):
         start_execution_params = {
             "execution_name": self.execution_name,
             "execution_type": "NOTEBOOK",
+            "domain_id": self.domain_id,
+            "project_id": self.project_id,
             "input_config": self._format_start_execution_input_config(),
             "output_config": self._format_start_execution_output_config(),
             "termination_condition": self.termination_condition,
             "tags": self.tags,
         }
+
+        if self.domain_region:
+            start_execution_params["domain_region"] = self.domain_region
+
         if self.compute:
             start_execution_params["compute"] = self.compute
-        else:
-            start_execution_params["compute"] = {"instance_type": 
"ml.m6i.xlarge"}
 
-        print(start_execution_params)
         return 
self._sagemaker_studio.execution_client.start_execution(**start_execution_params)
 
     def wait_for_execution_completion(self, execution_id, context):
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
 
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
index 85de6dd1c42..70310873670 100644
--- 
a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
+++ 
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
@@ -48,8 +48,11 @@ class SageMakerNotebookOperator(BaseOperator):
 
         notebook_operator = SageMakerNotebookOperator(
             task_id="notebook_task",
+            domain_id="dzd-example123456",
+            project_id="example123456",
             input_config={"input_path": "path/to/notebook.ipynb", 
"input_params": ""},
             output_config={"output_format": "ipynb"},
+            domain_region="us-east-1",
             wait_for_completion=True,
             waiter_delay=10,
             waiter_max_attempts=1440,
@@ -63,12 +66,28 @@ class SageMakerNotebookOperator(BaseOperator):
     :param output_config:  Configuration for the output format. It should 
include an output_format parameter to control
         the format of the notebook execution output.
         Example: {"output_formats": ["NOTEBOOK"]}
-    :param compute: compute configuration to use for the notebook execution. 
This is a required attribute if the execution is on a remote compute.
-        Example: {"instance_type": "ml.m5.large", "volume_size_in_gb": 30, 
"volume_kms_key_id": "", "image_details": {"ecr_uri": "string"}, 
"container_entrypoint": ["string"]}
+    :param domain_id: The domain ID for Amazon SageMaker Unified Studio. 
Optional - if not provided,
+        the SDK will attempt to resolve it from the environment.
+    :param project_id: The project ID for Amazon SageMaker Unified Studio. 
Optional - if not provided,
+        the SDK will attempt to resolve it from the environment.
+    :param domain_region: The AWS region for the domain. If not provided, the 
default AWS region will be used.
+    :param compute: compute configuration to use for the notebook execution. 
This is a required attribute
+        if the execution is on a remote compute.
+        Example::
+
+            {
+                "instance_type": "ml.c5.xlarge",
+                "image_details": {
+                    "image_name": "sagemaker-distribution-prod",
+                    "image_version": "3",
+                    "ecr_uri": 
"123456123456.dkr.ecr.us-west-2.amazonaws.com/ImageName:latest",
+                },
+            }
+
     :param termination_condition: conditions to match to terminate the remote 
execution.
-        Example: { "MaxRuntimeInSeconds": 3600 }
+        Example: ``{"MaxRuntimeInSeconds": 3600}``
     :param tags: tags to be associated with the remote execution runs.
-        Example: { "md_analytics": "logs" }
+        Example: ``{"md_analytics": "logs"}``
     :param wait_for_completion: Indicates whether to wait for the notebook 
execution to complete. If True, wait for completion; if False, don't wait.
     :param waiter_delay: Interval in seconds to check the notebook execution 
status.
     :param waiter_max_attempts: Number of attempts to wait before returning 
FAILED.
@@ -87,7 +106,10 @@ class SageMakerNotebookOperator(BaseOperator):
         self,
         task_id: str,
         input_config: dict,
+        domain_id: str | None = None,
+        project_id: str | None = None,
         output_config: dict | None = None,
+        domain_region: str | None = None,
         compute: dict | None = None,
         termination_condition: dict | None = None,
         tags: dict | None = None,
@@ -99,8 +121,11 @@ class SageMakerNotebookOperator(BaseOperator):
     ):
         super().__init__(task_id=task_id, **kwargs)
         self.execution_name = task_id
+        self.domain_id = domain_id
+        self.project_id = project_id
         self.input_config = input_config
         self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
+        self.domain_region = domain_region
         self.compute = compute or {}
         self.termination_condition = termination_condition or {}
         self.tags = tags or {}
@@ -119,9 +144,12 @@ class SageMakerNotebookOperator(BaseOperator):
             raise AirflowException("input_path is a required field in the 
input_config")
 
         return SageMakerNotebookHook(
+            domain_id=self.domain_id,
+            project_id=self.project_id,
             input_config=self.input_config,
             output_config=self.output_config,
             execution_name=self.execution_name,
+            domain_region=self.domain_region,
             compute=self.compute,
             termination_condition=self.termination_condition,
             tags=self.tags,
diff --git 
a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py 
b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
index be81da9b282..fd27931d228 100644
--- 
a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
+++ 
b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
@@ -200,28 +200,6 @@ class TestSageMakerNotebookHook:
         with pytest.raises(AirflowException, match="context is required"):
             self.hook._set_xcom_s3_path(self.s3Path, {})
 
-    def test_start_notebook_execution_default_compute(self):
-        """Test that default compute uses ml.m6i.xlarge instance type."""
-        hook_without_compute = SageMakerNotebookHook(
-            input_config={
-                "input_path": "test-data/notebook/test_notebook.ipynb",
-                "input_params": {"key": "value"},
-            },
-            output_config={"output_formats": ["NOTEBOOK"]},
-            execution_name="test-execution",
-            waiter_delay=10,
-        )
-        hook_without_compute._sagemaker_studio = MagicMock()
-        hook_without_compute._sagemaker_studio.execution_client = 
MagicMock(spec=ExecutionClient)
-        
hook_without_compute._sagemaker_studio.execution_client.start_execution.return_value
 = {
-            "executionId": "123456"
-        }
-
-        hook_without_compute.start_notebook_execution()
-
-        call_kwargs = 
hook_without_compute._sagemaker_studio.execution_client.start_execution.call_args[1]
-        assert call_kwargs["compute"] == {"instance_type": "ml.m6i.xlarge"}
-
     def test_start_notebook_execution_custom_compute(self):
         """Test that custom compute config is used when provided."""
         custom_compute = {"instance_type": "ml.c5.xlarge", 
"volume_size_in_gb": 50}
@@ -245,3 +223,79 @@ class TestSageMakerNotebookHook:
 
         call_kwargs = 
hook_with_compute._sagemaker_studio.execution_client.start_execution.call_args[1]
         assert call_kwargs["compute"] == custom_compute
+
+    def test_init_with_domain_id_project_id_domain_region(self):
+        """Test that domain_id, project_id, and domain_region are stored on 
the hook."""
+        hook = SageMakerNotebookHook(
+            execution_name="test-execution",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+            domain_id="dzd-example123456",
+            project_id="proj-example123456",
+            domain_region="us-east-1",
+        )
+        hook._sagemaker_studio = MagicMock()
+
+        assert hook.domain_id == "dzd-example123456"
+        assert hook.project_id == "proj-example123456"
+        assert hook.domain_region == "us-east-1"
+
+    def test_init_domain_params_default_to_none(self):
+        """When domain params are not provided they default to None so the SDK 
can resolve them."""
+        hook = SageMakerNotebookHook(
+            execution_name="test-execution",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+        )
+        hook._sagemaker_studio = MagicMock()
+
+        assert hook.domain_id is None
+        assert hook.project_id is None
+        assert hook.domain_region is None
+
+    def test_start_notebook_execution_includes_domain_id_and_project_id(self):
+        """domain_id and project_id are always forwarded to start_execution."""
+        hook = SageMakerNotebookHook(
+            execution_name="test-execution",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+            domain_id="dzd-example123456",
+            project_id="proj-example123456",
+        )
+        hook._sagemaker_studio = MagicMock()
+        hook._sagemaker_studio.execution_client = 
MagicMock(spec=ExecutionClient)
+        hook._sagemaker_studio.execution_client.start_execution.return_value = 
{"execution_id": "abc"}
+
+        hook.start_notebook_execution()
+
+        call_kwargs = 
hook._sagemaker_studio.execution_client.start_execution.call_args[1]
+        assert call_kwargs["domain_id"] == "dzd-example123456"
+        assert call_kwargs["project_id"] == "proj-example123456"
+
+    def 
test_start_notebook_execution_includes_domain_region_when_provided(self):
+        """domain_region is conditionally added to start_execution params only 
when set."""
+        hook = SageMakerNotebookHook(
+            execution_name="test-execution",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+            domain_region="eu-west-1",
+        )
+        hook._sagemaker_studio = MagicMock()
+        hook._sagemaker_studio.execution_client = 
MagicMock(spec=ExecutionClient)
+        hook._sagemaker_studio.execution_client.start_execution.return_value = 
{"execution_id": "abc"}
+
+        hook.start_notebook_execution()
+
+        call_kwargs = 
hook._sagemaker_studio.execution_client.start_execution.call_args[1]
+        assert call_kwargs["domain_region"] == "eu-west-1"
+
+    def 
test_start_notebook_execution_omits_domain_region_when_not_provided(self):
+        """domain_region must NOT appear in start_execution params when it is 
None."""
+        hook = SageMakerNotebookHook(
+            execution_name="test-execution",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+        )
+        hook._sagemaker_studio = MagicMock()
+        hook._sagemaker_studio.execution_client = 
MagicMock(spec=ExecutionClient)
+        hook._sagemaker_studio.execution_client.start_execution.return_value = 
{"execution_id": "abc"}
+
+        hook.start_notebook_execution()
+
+        call_kwargs = 
hook._sagemaker_studio.execution_client.start_execution.call_args[1]
+        assert "domain_region" not in call_kwargs
diff --git 
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
 
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
index 94542d03a5a..2f402475d17 100644
--- 
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
+++ 
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
@@ -45,6 +45,29 @@ class TestSageMakerNotebookOperator:
         }
         assert operator.output_config == {"output_format": "ipynb"}
 
+    def test_init_with_domain_id_project_id_domain_region(self):
+        operator = SageMakerNotebookOperator(
+            task_id="test_id",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+            domain_id="dzd-example123456",
+            project_id="proj-example123456",
+            domain_region="us-east-1",
+        )
+
+        assert operator.domain_id == "dzd-example123456"
+        assert operator.project_id == "proj-example123456"
+        assert operator.domain_region == "us-east-1"
+
+    def test_init_domain_params_default_to_none(self):
+        operator = SageMakerNotebookOperator(
+            task_id="test_id",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+        )
+
+        assert operator.domain_id is None
+        assert operator.project_id is None
+        assert operator.domain_region is None
+
     def test_only_required_params_init(self):
         operator = SageMakerNotebookOperator(
             task_id="test_id",
@@ -54,6 +77,59 @@ class TestSageMakerNotebookOperator:
         )
         assert isinstance(operator, SageMakerNotebookOperator)
 
+    
@patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
+    def test_execute_passes_domain_id_project_id_domain_region_to_hook(self, 
mock_notebook_hook):
+        mock_hook_instance = mock_notebook_hook.return_value
+        mock_hook_instance.start_notebook_execution.return_value = {
+            "execution_id": "123456",
+            "executionType": "test",
+        }
+
+        operator = SageMakerNotebookOperator(
+            task_id="test_id",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+            domain_id="dzd-example123456",
+            project_id="proj-example123456",
+            domain_region="us-west-2",
+        )
+
+        operator.execute({})
+
+        mock_notebook_hook.assert_called_once_with(
+            domain_id="dzd-example123456",
+            project_id="proj-example123456",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+            output_config={"output_formats": ["NOTEBOOK"]},
+            execution_name="test_id",
+            domain_region="us-west-2",
+            compute={},
+            termination_condition={},
+            tags={},
+            waiter_delay=10,
+            waiter_max_attempts=1440,
+        )
+
+    
@patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
+    def test_execute_passes_none_domain_params_to_hook(self, 
mock_notebook_hook):
+        """When domain_id/project_id/domain_region are omitted, None is 
forwarded so the SDK resolves them."""
+        mock_hook_instance = mock_notebook_hook.return_value
+        mock_hook_instance.start_notebook_execution.return_value = {
+            "execution_id": "123456",
+            "executionType": "test",
+        }
+
+        operator = SageMakerNotebookOperator(
+            task_id="test_id",
+            input_config={"input_path": "path/to/notebook.ipynb"},
+        )
+
+        operator.execute({})
+
+        call_kwargs = mock_notebook_hook.call_args[1]
+        assert call_kwargs["domain_id"] is None
+        assert call_kwargs["project_id"] is None
+        assert call_kwargs["domain_region"] is None
+
     
@patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
     def test_execute_success(self, mock_notebook_hook):  # Mock the 
NotebookHook and its execute method
         mock_hook_instance = mock_notebook_hook.return_value

Reply via email to