This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 446b6ae6359 Add SageMaker Unified Studio domain_id, project_id,
domain_region as new parameters to SageMakerNotebookOperator (#62147)
446b6ae6359 is described below
commit 446b6ae63595aad6c9ed31a9ecd96c554508d5f0
Author: Nikita Arbuzov <[email protected]>
AuthorDate: Wed Feb 18 19:43:07 2026 -0500
Add SageMaker Unified Studio domain_id, project_id, domain_region as new
parameters to SageMakerNotebookOperator (#62147)
* Add SageMaker Unified Studio domain_id, project_id, domain_region as new
parameters to SageMakerNotebookOperator
* Update sagemaker-studio version requirement
* Add unit tests for domain_id, project_id, domain_region params and fix
static check failures
---------
Co-authored-by: Rui Jiang <[email protected]>
Co-authored-by: Rui Jiang <[email protected]>
---
docs/spelling_wordlist.txt | 2 +
.../amazon/aws/hooks/sagemaker_unified_studio.py | 44 ++++++++--
.../aws/operators/sagemaker_unified_studio.py | 36 +++++++-
.../aws/hooks/test_sagemaker_unified_studio.py | 98 +++++++++++++++++-----
.../aws/operators/test_sagemaker_unified_studio.py | 76 +++++++++++++++++
5 files changed, 222 insertions(+), 34 deletions(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 56bfd885844..205e68fd66a 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -541,6 +541,7 @@ DisplayVideo
distcp
distro
distros
+dkr
Dlp
dlp
DlpJob
@@ -2096,6 +2097,7 @@ xcomresult
XComs
Xero
Xiaodong
+xlarge
xml
xmltodict
xpath
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
index 0895ffc3f8b..ac06fcbacf2 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
@@ -40,24 +40,43 @@ class SageMakerNotebookHook(BaseHook):
from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio
import SageMakerNotebookHook
notebook_hook = SageMakerNotebookHook(
+ execution_name="notebook_execution",
+ domain_id="dzd-example123456",
+ project_id="example123456",
input_config={"input_path": "path/to/notebook.ipynb",
"input_params": {"param1": "value1"}},
output_config={"output_uri": "folder/output/location/prefix",
"output_formats": "NOTEBOOK"},
- execution_name="notebook_execution",
+ domain_region="us-east-1",
waiter_delay=10,
waiter_max_attempts=1440,
)
:param execution_name: The name of the notebook job to be executed, this
is same as task_id.
+ :param domain_id: The domain ID for Amazon SageMaker Unified Studio.
Optional - if not provided,
+ the SDK will attempt to resolve it from the environment.
+ :param project_id: The project ID for Amazon SageMaker Unified Studio.
Optional - if not provided,
+ the SDK will attempt to resolve it from the environment.
:param input_config: Configuration for the input file.
Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params':
{'param1': 'value1'}}
:param output_config: Configuration for the output format. It should
include an output_formats parameter to specify the output format.
Example: {'output_formats': ['NOTEBOOK']}
- :param compute: compute configuration to use for the notebook execution.
This is a required attribute if the execution is on a remote compute.
- Example: {"instance_type": "ml.m5.large", "volume_size_in_gb": 30,
"volume_kms_key_id": "", "image_details": {"ecr_uri": "string"},
"container_entrypoint": ["string"]}
+ :param domain_region: The AWS region for the domain. If not provided, the
default AWS region will be used.
+ :param compute: compute configuration to use for the notebook execution.
This is a required attribute
+ if the execution is on a remote compute.
+ Example::
+
+ {
+ "instance_type": "ml.c5.xlarge",
+ "image_details": {
+ "image_name": "sagemaker-distribution-prod",
+ "image_version": "3",
+ "ecr_uri":
"123456123456.dkr.ecr.us-west-2.amazonaws.com/ImageName:latest",
+ },
+ }
+
:param termination_condition: conditions to match to terminate the remote
execution.
- Example: {"MaxRuntimeInSeconds": 3600}
+ Example: ``{"MaxRuntimeInSeconds": 3600}``
:param tags: tags to be associated with the remote execution runs.
- Example: {"md_analytics": "logs"}
+ Example: ``{"md_analytics": "logs"}``
:param waiter_delay: Interval in seconds to check the task execution
status.
:param waiter_max_attempts: Number of attempts to wait before returning
FAILED.
"""
@@ -66,7 +85,10 @@ class SageMakerNotebookHook(BaseHook):
self,
execution_name: str,
input_config: dict | None = None,
+ domain_id: str | None = None,
+ project_id: str | None = None,
output_config: dict | None = None,
+ domain_region: str | None = None,
compute: dict | None = None,
termination_condition: dict | None = None,
tags: dict | None = None,
@@ -78,6 +100,9 @@ class SageMakerNotebookHook(BaseHook):
super().__init__(*args, **kwargs)
self._sagemaker_studio =
SageMakerStudioAPI(self._get_sagemaker_studio_config())
self.execution_name = execution_name
+ self.domain_id = domain_id
+ self.project_id = project_id
+ self.domain_region = domain_region
self.input_config = input_config or {}
self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
self.compute = compute
@@ -114,17 +139,20 @@ class SageMakerNotebookHook(BaseHook):
start_execution_params = {
"execution_name": self.execution_name,
"execution_type": "NOTEBOOK",
+ "domain_id": self.domain_id,
+ "project_id": self.project_id,
"input_config": self._format_start_execution_input_config(),
"output_config": self._format_start_execution_output_config(),
"termination_condition": self.termination_condition,
"tags": self.tags,
}
+
+ if self.domain_region:
+ start_execution_params["domain_region"] = self.domain_region
+
if self.compute:
start_execution_params["compute"] = self.compute
- else:
- start_execution_params["compute"] = {"instance_type":
"ml.m6i.xlarge"}
- print(start_execution_params)
return
self._sagemaker_studio.execution_client.start_execution(**start_execution_params)
def wait_for_execution_completion(self, execution_id, context):
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
index 85de6dd1c42..70310873670 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py
@@ -48,8 +48,11 @@ class SageMakerNotebookOperator(BaseOperator):
notebook_operator = SageMakerNotebookOperator(
task_id="notebook_task",
+ domain_id="dzd-example123456",
+ project_id="example123456",
input_config={"input_path": "path/to/notebook.ipynb",
"input_params": ""},
output_config={"output_format": "ipynb"},
+ domain_region="us-east-1",
wait_for_completion=True,
waiter_delay=10,
waiter_max_attempts=1440,
@@ -63,12 +66,28 @@ class SageMakerNotebookOperator(BaseOperator):
:param output_config: Configuration for the output format. It should
include an output_format parameter to control
the format of the notebook execution output.
Example: {"output_formats": ["NOTEBOOK"]}
- :param compute: compute configuration to use for the notebook execution.
This is a required attribute if the execution is on a remote compute.
- Example: {"instance_type": "ml.m5.large", "volume_size_in_gb": 30,
"volume_kms_key_id": "", "image_details": {"ecr_uri": "string"},
"container_entrypoint": ["string"]}
+ :param domain_id: The domain ID for Amazon SageMaker Unified Studio.
Optional - if not provided,
+ the SDK will attempt to resolve it from the environment.
+ :param project_id: The project ID for Amazon SageMaker Unified Studio.
Optional - if not provided,
+ the SDK will attempt to resolve it from the environment.
+ :param domain_region: The AWS region for the domain. If not provided, the
default AWS region will be used.
+ :param compute: compute configuration to use for the artifact execution.
This is a required attribute
+ if the execution is on a remote compute.
+ Example::
+
+ {
+ "instance_type": "ml.c5.xlarge",
+ "image_details": {
+ "image_name": "sagemaker-distribution-prod",
+ "image_version": "3",
+ "ecr_uri":
"123456123456.dkr.ecr.us-west-2.amazonaws.com/ImageName:latest",
+ },
+ }
+
:param termination_condition: conditions to match to terminate the remote
execution.
- Example: { "MaxRuntimeInSeconds": 3600 }
+ Example: ``{"MaxRuntimeInSeconds": 3600}``
:param tags: tags to be associated with the remote execution runs.
- Example: { "md_analytics": "logs" }
+ Example: ``{"md_analytics": "logs"}``
:param wait_for_completion: Indicates whether to wait for the notebook
execution to complete. If True, wait for completion; if False, don't wait.
:param waiter_delay: Interval in seconds to check the notebook execution
status.
:param waiter_max_attempts: Number of attempts to wait before returning
FAILED.
@@ -87,7 +106,10 @@ class SageMakerNotebookOperator(BaseOperator):
self,
task_id: str,
input_config: dict,
+ domain_id: str | None = None,
+ project_id: str | None = None,
output_config: dict | None = None,
+ domain_region: str | None = None,
compute: dict | None = None,
termination_condition: dict | None = None,
tags: dict | None = None,
@@ -99,8 +121,11 @@ class SageMakerNotebookOperator(BaseOperator):
):
super().__init__(task_id=task_id, **kwargs)
self.execution_name = task_id
+ self.domain_id = domain_id
+ self.project_id = project_id
self.input_config = input_config
self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
+ self.domain_region = domain_region
self.compute = compute or {}
self.termination_condition = termination_condition or {}
self.tags = tags or {}
@@ -119,9 +144,12 @@ class SageMakerNotebookOperator(BaseOperator):
raise AirflowException("input_path is a required field in the
input_config")
return SageMakerNotebookHook(
+ domain_id=self.domain_id,
+ project_id=self.project_id,
input_config=self.input_config,
output_config=self.output_config,
execution_name=self.execution_name,
+ domain_region=self.domain_region,
compute=self.compute,
termination_condition=self.termination_condition,
tags=self.tags,
diff --git
a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
index be81da9b282..fd27931d228 100644
---
a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
+++
b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py
@@ -200,28 +200,6 @@ class TestSageMakerNotebookHook:
with pytest.raises(AirflowException, match="context is required"):
self.hook._set_xcom_s3_path(self.s3Path, {})
- def test_start_notebook_execution_default_compute(self):
- """Test that default compute uses ml.m6i.xlarge instance type."""
- hook_without_compute = SageMakerNotebookHook(
- input_config={
- "input_path": "test-data/notebook/test_notebook.ipynb",
- "input_params": {"key": "value"},
- },
- output_config={"output_formats": ["NOTEBOOK"]},
- execution_name="test-execution",
- waiter_delay=10,
- )
- hook_without_compute._sagemaker_studio = MagicMock()
- hook_without_compute._sagemaker_studio.execution_client =
MagicMock(spec=ExecutionClient)
-
hook_without_compute._sagemaker_studio.execution_client.start_execution.return_value
= {
- "executionId": "123456"
- }
-
- hook_without_compute.start_notebook_execution()
-
- call_kwargs =
hook_without_compute._sagemaker_studio.execution_client.start_execution.call_args[1]
- assert call_kwargs["compute"] == {"instance_type": "ml.m6i.xlarge"}
-
def test_start_notebook_execution_custom_compute(self):
"""Test that custom compute config is used when provided."""
custom_compute = {"instance_type": "ml.c5.xlarge",
"volume_size_in_gb": 50}
@@ -245,3 +223,79 @@ class TestSageMakerNotebookHook:
call_kwargs =
hook_with_compute._sagemaker_studio.execution_client.start_execution.call_args[1]
assert call_kwargs["compute"] == custom_compute
+
+ def test_init_with_domain_id_project_id_domain_region(self):
+ """Test that domain_id, project_id, and domain_region are stored on
the hook."""
+ hook = SageMakerNotebookHook(
+ execution_name="test-execution",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ domain_id="dzd-example123456",
+ project_id="proj-example123456",
+ domain_region="us-east-1",
+ )
+ hook._sagemaker_studio = MagicMock()
+
+ assert hook.domain_id == "dzd-example123456"
+ assert hook.project_id == "proj-example123456"
+ assert hook.domain_region == "us-east-1"
+
+ def test_init_domain_params_default_to_none(self):
+ """When domain params are not provided they default to None so the SDK
can resolve them."""
+ hook = SageMakerNotebookHook(
+ execution_name="test-execution",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ )
+ hook._sagemaker_studio = MagicMock()
+
+ assert hook.domain_id is None
+ assert hook.project_id is None
+ assert hook.domain_region is None
+
+ def test_start_notebook_execution_includes_domain_id_and_project_id(self):
+ """domain_id and project_id are always forwarded to start_execution."""
+ hook = SageMakerNotebookHook(
+ execution_name="test-execution",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ domain_id="dzd-example123456",
+ project_id="proj-example123456",
+ )
+ hook._sagemaker_studio = MagicMock()
+ hook._sagemaker_studio.execution_client =
MagicMock(spec=ExecutionClient)
+ hook._sagemaker_studio.execution_client.start_execution.return_value =
{"execution_id": "abc"}
+
+ hook.start_notebook_execution()
+
+ call_kwargs =
hook._sagemaker_studio.execution_client.start_execution.call_args[1]
+ assert call_kwargs["domain_id"] == "dzd-example123456"
+ assert call_kwargs["project_id"] == "proj-example123456"
+
+ def
test_start_notebook_execution_includes_domain_region_when_provided(self):
+ """domain_region is conditionally added to start_execution params only
when set."""
+ hook = SageMakerNotebookHook(
+ execution_name="test-execution",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ domain_region="eu-west-1",
+ )
+ hook._sagemaker_studio = MagicMock()
+ hook._sagemaker_studio.execution_client =
MagicMock(spec=ExecutionClient)
+ hook._sagemaker_studio.execution_client.start_execution.return_value =
{"execution_id": "abc"}
+
+ hook.start_notebook_execution()
+
+ call_kwargs =
hook._sagemaker_studio.execution_client.start_execution.call_args[1]
+ assert call_kwargs["domain_region"] == "eu-west-1"
+
+ def
test_start_notebook_execution_omits_domain_region_when_not_provided(self):
+ """domain_region must NOT appear in start_execution params when it is
None."""
+ hook = SageMakerNotebookHook(
+ execution_name="test-execution",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ )
+ hook._sagemaker_studio = MagicMock()
+ hook._sagemaker_studio.execution_client =
MagicMock(spec=ExecutionClient)
+ hook._sagemaker_studio.execution_client.start_execution.return_value =
{"execution_id": "abc"}
+
+ hook.start_notebook_execution()
+
+ call_kwargs =
hook._sagemaker_studio.execution_client.start_execution.call_args[1]
+ assert "domain_region" not in call_kwargs
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
index 94542d03a5a..2f402475d17 100644
---
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
+++
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py
@@ -45,6 +45,29 @@ class TestSageMakerNotebookOperator:
}
assert operator.output_config == {"output_format": "ipynb"}
+ def test_init_with_domain_id_project_id_domain_region(self):
+ operator = SageMakerNotebookOperator(
+ task_id="test_id",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ domain_id="dzd-example123456",
+ project_id="proj-example123456",
+ domain_region="us-east-1",
+ )
+
+ assert operator.domain_id == "dzd-example123456"
+ assert operator.project_id == "proj-example123456"
+ assert operator.domain_region == "us-east-1"
+
+ def test_init_domain_params_default_to_none(self):
+ operator = SageMakerNotebookOperator(
+ task_id="test_id",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ )
+
+ assert operator.domain_id is None
+ assert operator.project_id is None
+ assert operator.domain_region is None
+
def test_only_required_params_init(self):
operator = SageMakerNotebookOperator(
task_id="test_id",
@@ -54,6 +77,59 @@ class TestSageMakerNotebookOperator:
)
assert isinstance(operator, SageMakerNotebookOperator)
+
@patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
+ def test_execute_passes_domain_id_project_id_domain_region_to_hook(self,
mock_notebook_hook):
+ mock_hook_instance = mock_notebook_hook.return_value
+ mock_hook_instance.start_notebook_execution.return_value = {
+ "execution_id": "123456",
+ "executionType": "test",
+ }
+
+ operator = SageMakerNotebookOperator(
+ task_id="test_id",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ domain_id="dzd-example123456",
+ project_id="proj-example123456",
+ domain_region="us-west-2",
+ )
+
+ operator.execute({})
+
+ mock_notebook_hook.assert_called_once_with(
+ domain_id="dzd-example123456",
+ project_id="proj-example123456",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ output_config={"output_formats": ["NOTEBOOK"]},
+ execution_name="test_id",
+ domain_region="us-west-2",
+ compute={},
+ termination_condition={},
+ tags={},
+ waiter_delay=10,
+ waiter_max_attempts=1440,
+ )
+
+
@patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
+ def test_execute_passes_none_domain_params_to_hook(self,
mock_notebook_hook):
+ """When domain_id/project_id/domain_region are omitted, None is
forwarded so the SDK resolves them."""
+ mock_hook_instance = mock_notebook_hook.return_value
+ mock_hook_instance.start_notebook_execution.return_value = {
+ "execution_id": "123456",
+ "executionType": "test",
+ }
+
+ operator = SageMakerNotebookOperator(
+ task_id="test_id",
+ input_config={"input_path": "path/to/notebook.ipynb"},
+ )
+
+ operator.execute({})
+
+ call_kwargs = mock_notebook_hook.call_args[1]
+ assert call_kwargs["domain_id"] is None
+ assert call_kwargs["project_id"] is None
+ assert call_kwargs["domain_region"] is None
+
@patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook")
def test_execute_success(self, mock_notebook_hook): # Mock the
NotebookHook and its execute method
mock_hook_instance = mock_notebook_hook.return_value