o-nikolas commented on code in PR #45726: URL: https://github.com/apache/airflow/pull/45726#discussion_r1919248414
########## providers/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py: ########## @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook hook.""" + +import time + +from airflow import AirflowException +from airflow.hooks.base import BaseHook +from sagemaker_studio import ClientConfig +from sagemaker_studio._openapi.models import GetExecutionRequest, StartExecutionRequest +from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI + +from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner + + +class SageMakerNotebookHook(BaseHook): + """ + Interact with the Sagemaker Workflows API. + + This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API. + + Examples: + .. 
code-block:: python + + from workflows.airflow.providers.amazon.aws.hooks.notebook_hook import NotebookHook + + notebook_hook = NotebookHook( + input_config={'input_path': 'path/to/notebook.ipynb', 'input_params': {'param1': 'value1'}}, + output_config={'output_uri': 'folder/output/location/prefix', 'output_format': 'ipynb'}, + execution_name='notebook_execution', + poll_interval=10, + ) + :param execution_name: The name of the notebook job to be executed, this is same as task_id. + :param input_config: Configuration for the input file. + Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}} + :param output_config: Configuration for the output format. It should include an output_formats parameter to control Review Comment: This sentence seems to just tail off in the middle? ########## providers/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py: ########## @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains the Amazon SageMaker Unified Studio Notebook hook.""" + +import time + +from airflow import AirflowException +from airflow.hooks.base import BaseHook +from sagemaker_studio import ClientConfig +from sagemaker_studio._openapi.models import GetExecutionRequest, StartExecutionRequest +from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI + +from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner + + +class SageMakerNotebookHook(BaseHook): + """ + Interact with the Sagemaker Workflows API. + + This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API. + + Examples: + .. code-block:: python + + from workflows.airflow.providers.amazon.aws.hooks.notebook_hook import NotebookHook + + notebook_hook = NotebookHook( + input_config={'input_path': 'path/to/notebook.ipynb', 'input_params': {'param1': 'value1'}}, + output_config={'output_uri': 'folder/output/location/prefix', 'output_format': 'ipynb'}, + execution_name='notebook_execution', + poll_interval=10, + ) + :param execution_name: The name of the notebook job to be executed, this is same as task_id. + :param input_config: Configuration for the input file. + Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}} + :param output_config: Configuration for the output format. It should include an output_formats parameter to control + Example: {'output_formats': ['NOTEBOOK']} + :param compute: compute configuration to use for the notebook execution. This is an required attribute + if the execution is on a remote compute. + Example: { "InstanceType": "ml.m5.large", "VolumeSizeInGB": 30, "VolumeKmsKeyId": "", "ImageUri": "string", "ContainerEntrypoint": [ "string" ]} + :param termination_condition: conditions to match to terminate the remote execution. + Example: { "MaxRuntimeInSeconds": 3600 } + :param tags: tags to be associated with the remote execution runs. 
+ Example: { "md_analytics": "logs" } + :param poll_interval: Interval in seconds to check the notebook execution status. + """ + + def __init__( + self, + execution_name: str, + input_config: dict = {}, + output_config: dict = {"output_formats": ["NOTEBOOK"]}, + compute: dict = {}, + termination_condition: dict = {}, + tags: dict = {}, + poll_interval: int = 10, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config()) + self.execution_name = execution_name + self.input_config = input_config + self.output_config = output_config + self.compute = compute + self.termination_condition = termination_condition + self.tags = tags + self.poll_interval = poll_interval + + def _get_sagemaker_studio_config(self): + config = ClientConfig() + config.overrides["execution"] = {"local": is_local_runner()} + return config + + def _format_start_execution_input_config(self): + config = { + "notebook_config": { + "input_path": self.input_config.get("input_path"), + "input_parameters": self.input_config.get("input_params"), + }, + } + + return config + + def _format_start_execution_output_config(self): + output_formats = ( + self.output_config.get("output_formats") if self.output_config else ["NOTEBOOK"] Review Comment: This ternary is unnecessary right? There is a default value provided in the constructor so the output_config can't be empty? ########## providers/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py: ########## @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook hook.""" + +import time + +from sagemaker_studio import ClientConfig +from sagemaker_studio._openapi.models import GetExecutionRequest, StartExecutionRequest +from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI + +from airflow import AirflowException +from airflow.hooks.base import BaseHook +from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner + + +class SageMakerNotebookHook(BaseHook): + """ + Interact with the Sagemaker Workflows API. + + This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API. + + Examples: + .. code-block:: python + + from workflows.airflow.providers.amazon.aws.hooks.notebook_hook import NotebookHook + + notebook_hook = NotebookHook( + input_config={'input_path': 'path/to/notebook.ipynb', 'input_params': {'param1': 'value1'}}, + output_config={'output_uri': 'folder/output/location/prefix', 'output_format': 'ipynb'}, + execution_name='notebook_execution', + poll_interval=10, + ) + + :param execution_name: The name of the notebook job to be executed, this is same as task_id. + :param input_config: Configuration for the input file. + Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}} + :param output_config: Configuration for the output format. It should include an output_formats parameter to control + Example: {'output_formats': ['NOTEBOOK']} + :param compute: compute configuration to use for the notebook execution. 
This is an required attribute + if the execution is on a remote compute. + Example: { "InstanceType": "ml.m5.large", "VolumeSizeInGB": 30, "VolumeKmsKeyId": "", "ImageUri": "string", "ContainerEntrypoint": [ "string" ]} + :param termination_condition: conditions to match to terminate the remote execution. + Example: { "MaxRuntimeInSeconds": 3600 } + :param tags: tags to be associated with the remote execution runs. + Example: { "md_analytics": "logs" } + :param poll_interval: Interval in seconds to check the notebook execution status. + """ + + def __init__( + self, + execution_name: str, + input_config: dict = {}, + output_config: dict = {"output_formats": ["NOTEBOOK"]}, + compute: dict = {}, + termination_condition: dict = {}, + tags: dict = {}, + poll_interval: int = 10, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config()) + self.execution_name = execution_name + self.input_config = input_config + self.output_config = output_config + self.compute = compute + self.termination_condition = termination_condition + self.tags = tags + self.poll_interval = poll_interval + + def _get_sagemaker_studio_config(self): + config = ClientConfig() + config.overrides["execution"] = {"local": is_local_runner()} + return config + + def _format_start_execution_input_config(self): + config = { + "notebook_config": { + "input_path": self.input_config.get("input_path"), + "input_parameters": self.input_config.get("input_params"), + }, + } + + return config + + def _format_start_execution_output_config(self): + output_formats = ( + self.output_config.get("output_formats") if self.output_config else ["NOTEBOOK"] + ) + config = { + "notebook_config": { + "output_formats": output_formats, + } + } + return config + + def start_notebook_execution(self): + start_execution_params = { + "execution_name": self.execution_name, + "execution_type": "NOTEBOOK", + "input_config": 
self._format_start_execution_input_config(), + "output_config": self._format_start_execution_output_config(), + "termination_condition": self.termination_condition, + "tags": self.tags, + } + if self.compute: + start_execution_params["compute"] = self.compute + + request = StartExecutionRequest(**start_execution_params) + + return self._sagemaker_studio.execution_client.start_execution(request) + + def wait_for_execution_completion(self, execution_id, context): + + while True: + time.sleep(self.poll_interval) + response = self.get_execution_response(execution_id) + error_message = response.get("error_details", {}).get("error_message") + status = response["status"] + if "files" in response: + self._set_xcom_files(response["files"], context) + if "s3_path" in response: + self._set_xcom_s3_path(response["s3_path"], context) + + ret = self._handle_state(execution_id, status, error_message) + if ret: + return ret + + def _set_xcom_files(self, files, context): + if not context: + error_message = "context is required" + raise AirflowException(error_message) + for file in files: + context["ti"].xcom_push( + key=f"{file['display_name']}.{file['file_format']}", + value=file["file_path"], + ) + + def _set_xcom_s3_path(self, s3_path, context): + if not context: + error_message = "context is required" + raise AirflowException(error_message) + context["ti"].xcom_push( + key="s3_path", + value=s3_path, + ) + + def get_execution_response(self, execution_id): + response = self._sagemaker_studio.execution_client.get_execution( + GetExecutionRequest(execution_id=execution_id) + ) + return response Review Comment: ```suggestion return self._sagemaker_studio.execution_client.get_execution( GetExecutionRequest(execution_id=execution_id) ) ``` Any specific reason you're wrapping this one-liner in a method? I thought maybe for mocking purposes but I don't see anything like that below. 
########## providers/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py: ########## @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook sensor.""" + +from airflow import AirflowException +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import ( + SageMakerNotebookHook, +) +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.context import Context + + +class SageMakerNotebookSensor(BaseSensorOperator): + """ + Waits for an Sagemaker Workflows Notebook execution to reach any of the status below. Review Comment: ```suggestion Waits for a Sagemaker Workflows Notebook execution to reach any of the status below. ``` ########## providers/tests/system/amazon/aws/example_sagemaker_unified_studio.py: ########## @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from datetime import datetime + +import pytest + +from airflow.decorators import task +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) +from providers.tests.system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS + +""" +Prerequisites: The account which runs this test must manually have the following: +1. An IAM IDC organization set up in the testing region with the following user initialized: + Username: airflowTestUser + Password: airflowSystemTestP@ssword1! +2. A SageMaker Unified Studio Domain (with default VPC and roles) +3. A project within the SageMaker Unified Studio Domain + +Essentially, this test will emulate a DAG run in the shared MWAA environment inside a SageMaker Unified Studio Project. +The setup tasks will set up the project and configure the test runnner to emulate an MWAA instance. +Then, the SageMakerNotebookOperator will run a test notebook. This should spin up a SageMaker training job, run the notebook, and exit successfully. +The teardown tasks will finally delete the project and domain that was set up for this test run. 
+""" + +pytestmark = pytest.mark.skipif( + not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+" +) + +DAG_ID = "example_sagemaker_unified_studio" + +# Externally fetched variables: +DOMAIN_ID_KEY = "DOMAIN_ID" +PROJECT_ID_KEY = "PROJECT_ID" +ENVIRONMENT_ID_KEY = "ENVIRONMENT_ID" +S3_PATH_KEY = "S3_PATH" + +sys_test_context_task = ( + SystemTestContextBuilder() + .add_variable(DOMAIN_ID_KEY) + .add_variable(PROJECT_ID_KEY) + .add_variable(ENVIRONMENT_ID_KEY) + .add_variable(S3_PATH_KEY) + .build() +) + + +@task +def emulate_mwaa_environment( + domain_id: str, project_id: str, environment_id: str, s3_path: str +): + """ + Sets several environment variables in the container to emulate an MWAA environment provisioned + within SageMaker Unified Studio. + """ + AIRFLOW_PREFIX = "AIRFLOW__WORKFLOWS__" + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_DOMAIN_ID"] = domain_id + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_PROJECT_ID"] = project_id + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_ENVIRONMENT_ID"] = environment_id + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_SCOPE_NAME"] = "dev" + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_STAGE"] = "prod" + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_ENDPOINT"] = ( + "https://datazone.us-east-1.api.aws" + ) + os.environ[f"{AIRFLOW_PREFIX}PROJECT_S3_PATH"] = s3_path + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_DOMAIN_REGION"] = "us-east-1" + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + + test_env_id = test_context[ENV_ID_KEY] + domain_id = test_context[DOMAIN_ID_KEY] + project_id = test_context[PROJECT_ID_KEY] + environment_id = test_context[ENVIRONMENT_ID_KEY] + s3_path = test_context[S3_PATH_KEY] + + setup_mock_mwaa_environment = emulate_mwaa_environment( + domain_id, + project_id, + environment_id, + s3_path, + ) + + # [START howto_operator_sagemaker_unified_studio_notebook] + notebook_path = "test_notebook.ipynb" # This should be the 
path to your .ipynb, .sqlnb, or .vetl file in your project. + + run_notebook = SageMakerNotebookOperator( + task_id="initial", Review Comment: Somewhat strange task_id ########## providers/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py: ########## @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook operator.""" + +from functools import cached_property + +from airflow import AirflowException +from airflow.configuration import conf +from airflow.models import BaseOperator +from airflow.utils.context import Context + +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook +from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import SageMakerNotebookJobTrigger + + +class SageMakerNotebookOperator(BaseOperator): + """ + Provides Notebook execution functionality for Sagemaker Workflows. + + Examples: + .. 
code-block:: python + + from workflows.airflow.providers.amazon.aws.operators.sagemaker_workflows import NotebookOperator + + notebook_operator = NotebookOperator( + task_id='notebook_task', + input_config={'input_path': 'path/to/notebook.ipynb', 'input_params': ''}, + output_config={'output_format': 'ipynb'}, + wait_for_completion=True, + poll_interval=10, + max_polling_attempts=100, + ) + :param task_id: A unique, meaningful id for the task. + :param input_config: Configuration for the input file. Input path should be specified as a relative path. + The provided relative path will be automatically resolved to an absolute path within + the context of the user's home directory in the IDE. Input parms should be a dict Review Comment: ```suggestion the context of the user's home directory in the IDE. Input params should be a dict ``` ########## providers/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py: ########## @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains the Amazon SageMaker Unified Studio Notebook sensor.""" + +from airflow import AirflowException +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import ( + SageMakerNotebookHook, +) +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.context import Context + + +class SageMakerNotebookSensor(BaseSensorOperator): Review Comment: Use `AwsBaseSensor`? ########## providers/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py: ########## @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook hook.""" + +import time + +from airflow import AirflowException +from airflow.hooks.base import BaseHook +from sagemaker_studio import ClientConfig +from sagemaker_studio._openapi.models import GetExecutionRequest, StartExecutionRequest +from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI + +from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner + + +class SageMakerNotebookHook(BaseHook): + """ + Interact with the Sagemaker Workflows API. + + This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API. + + Examples: + .. 
code-block:: python + + from workflows.airflow.providers.amazon.aws.hooks.notebook_hook import NotebookHook + + notebook_hook = NotebookHook( + input_config={'input_path': 'path/to/notebook.ipynb', 'input_params': {'param1': 'value1'}}, + output_config={'output_uri': 'folder/output/location/prefix', 'output_format': 'ipynb'}, + execution_name='notebook_execution', + poll_interval=10, + ) + :param execution_name: The name of the notebook job to be executed, this is same as task_id. + :param input_config: Configuration for the input file. + Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}} + :param output_config: Configuration for the output format. It should include an output_formats parameter to control + Example: {'output_formats': ['NOTEBOOK']} + :param compute: compute configuration to use for the notebook execution. This is an required attribute Review Comment: ```suggestion :param compute: compute configuration to use for the notebook execution. This is a required attribute ``` ########## providers/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py: ########## @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains the Amazon SageMaker Unified Studio Notebook operator.""" + +from functools import cached_property + +from airflow import AirflowException +from airflow.configuration import conf +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import ( + SageMakerNotebookHook, +) +from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import ( + SageMakerNotebookJobTrigger, +) +from airflow.utils.context import Context + + +class SageMakerNotebookOperator(BaseOperator): + """ + Provides Notebook execution functionality for Sagemaker Workflows. + + Examples: + .. code-block:: python + + from workflows.airflow.providers.amazon.aws.operators.sagemaker_workflows import NotebookOperator + + notebook_operator = NotebookOperator( Review Comment: ```suggestion notebook_operator = SageMakerNotebookOperator( ``` ########## providers/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py: ########## @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains the Amazon SageMaker Unified Studio Notebook job trigger.""" + +from airflow.triggers.base import BaseTrigger + + +class SageMakerNotebookJobTrigger(BaseTrigger): + """ + Watches for a notebook job, triggers when it finishes. + + Examples: + .. code-block:: python + + from workflows.airflow.providers.amazon.aws.operators.sagemaker_workflows import NotebookOperator + + notebook_operator = NotebookJobTrigger( + execution_id='notebook_job_1234', + execution_name='notebook_task', + poll_interval=10, + ) + + :param execution_id: A unique, meaningful id for the task. + :param execution_name: A unique, meaningful name for the task. + :param poll_interval: Interval in seconds to check the notebook execution status. + """ + + def __init__(self, execution_id, execution_name, poll_interval, **kwargs): + super().__init__(**kwargs) + self.execution_id = execution_id + self.execution_name = execution_name + self.poll_interval = poll_interval + + def serialize(self): + return ( + # dynamically generate the fully qualified name of the class + self.__class__.__module__ + "." + self.__class__.__qualname__, + { + "execution_id": self.execution_id, + "execution_name": self.execution_name, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self): + pass Review Comment: Is this intentionally left stubbed out? ########## providers/tests/system/amazon/aws/example_sagemaker_unified_studio.py: ########## @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from datetime import datetime + +import pytest + +from airflow.decorators import task +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) +from providers.tests.system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS + +""" +Prerequisites: The account which runs this test must manually have the following: +1. An IAM IDC organization set up in the testing region with the following user initialized: + Username: airflowTestUser + Password: airflowSystemTestP@ssword1! +2. A SageMaker Unified Studio Domain (with default VPC and roles) +3. A project within the SageMaker Unified Studio Domain + +Essentially, this test will emulate a DAG run in the shared MWAA environment inside a SageMaker Unified Studio Project. +The setup tasks will set up the project and configure the test runnner to emulate an MWAA instance. +Then, the SageMakerNotebookOperator will run a test notebook. This should spin up a SageMaker training job, run the notebook, and exit successfully. +The teardown tasks will finally delete the project and domain that was set up for this test run. 
+""" + +pytestmark = pytest.mark.skipif( + not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+" +) + +DAG_ID = "example_sagemaker_unified_studio" + +# Externally fetched variables: +DOMAIN_ID_KEY = "DOMAIN_ID" +PROJECT_ID_KEY = "PROJECT_ID" +ENVIRONMENT_ID_KEY = "ENVIRONMENT_ID" +S3_PATH_KEY = "S3_PATH" + +sys_test_context_task = ( + SystemTestContextBuilder() + .add_variable(DOMAIN_ID_KEY) + .add_variable(PROJECT_ID_KEY) + .add_variable(ENVIRONMENT_ID_KEY) + .add_variable(S3_PATH_KEY) + .build() +) + + +@task +def emulate_mwaa_environment( + domain_id: str, project_id: str, environment_id: str, s3_path: str +): Review Comment: If run using an container based executor (like ECS or K8s) this will have no effect. A container will spin up, export these envs, then it will get torn down and the next task will run in a new container. So this test will fail to run on our ECS executor test suite. Any other way around this? Otherwise we'll need to create an image for container based tests, or at least provide these env vars through executor_config ########## providers/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py: ########## @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains the Amazon SageMaker Unified Studio Notebook hook.""" + +import time + +from sagemaker_studio import ClientConfig +from sagemaker_studio._openapi.models import GetExecutionRequest, StartExecutionRequest +from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI + +from airflow import AirflowException +from airflow.hooks.base import BaseHook +from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner + + +class SageMakerNotebookHook(BaseHook): + """ + Interact with the Sagemaker Workflows API. + + This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API. + + Examples: + .. code-block:: python + + from workflows.airflow.providers.amazon.aws.hooks.notebook_hook import NotebookHook + + notebook_hook = NotebookHook( + input_config={'input_path': 'path/to/notebook.ipynb', 'input_params': {'param1': 'value1'}}, + output_config={'output_uri': 'folder/output/location/prefix', 'output_format': 'ipynb'}, + execution_name='notebook_execution', + poll_interval=10, + ) + + :param execution_name: The name of the notebook job to be executed, this is same as task_id. + :param input_config: Configuration for the input file. + Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}} + :param output_config: Configuration for the output format. It should include an output_formats parameter to control + Example: {'output_formats': ['NOTEBOOK']} + :param compute: compute configuration to use for the notebook execution. This is an required attribute + if the execution is on a remote compute. + Example: { "InstanceType": "ml.m5.large", "VolumeSizeInGB": 30, "VolumeKmsKeyId": "", "ImageUri": "string", "ContainerEntrypoint": [ "string" ]} + :param termination_condition: conditions to match to terminate the remote execution. + Example: { "MaxRuntimeInSeconds": 3600 } + :param tags: tags to be associated with the remote execution runs. 
+ Example: { "md_analytics": "logs" } + :param poll_interval: Interval in seconds to check the notebook execution status. + """ + + def __init__( + self, + execution_name: str, + input_config: dict = {}, + output_config: dict = {"output_formats": ["NOTEBOOK"]}, + compute: dict = {}, + termination_condition: dict = {}, + tags: dict = {}, + poll_interval: int = 10, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config()) + self.execution_name = execution_name + self.input_config = input_config + self.output_config = output_config + self.compute = compute + self.termination_condition = termination_condition + self.tags = tags + self.poll_interval = poll_interval + + def _get_sagemaker_studio_config(self): + config = ClientConfig() + config.overrides["execution"] = {"local": is_local_runner()} + return config + + def _format_start_execution_input_config(self): + config = { + "notebook_config": { + "input_path": self.input_config.get("input_path"), + "input_parameters": self.input_config.get("input_params"), + }, + } + + return config + + def _format_start_execution_output_config(self): + output_formats = ( + self.output_config.get("output_formats") if self.output_config else ["NOTEBOOK"] + ) + config = { + "notebook_config": { + "output_formats": output_formats, + } + } + return config + + def start_notebook_execution(self): + start_execution_params = { + "execution_name": self.execution_name, + "execution_type": "NOTEBOOK", + "input_config": self._format_start_execution_input_config(), + "output_config": self._format_start_execution_output_config(), + "termination_condition": self.termination_condition, + "tags": self.tags, + } + if self.compute: + start_execution_params["compute"] = self.compute + + request = StartExecutionRequest(**start_execution_params) + + return self._sagemaker_studio.execution_client.start_execution(request) + + def wait_for_execution_completion(self, 
execution_id, context): + Review Comment: ```suggestion ``` ########## providers/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py: ########## @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook job trigger.""" + +from airflow.triggers.base import BaseTrigger + + +class SageMakerNotebookJobTrigger(BaseTrigger): Review Comment: Use `AwsBaseWaiterTrigger` ########## providers/tests/system/amazon/aws/example_sagemaker_unified_studio.py: ########## @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from datetime import datetime + +import pytest + +from airflow.decorators import task +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) +from providers.tests.system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS + +""" +Prerequisites: The account which runs this test must manually have the following: +1. An IAM IDC organization set up in the testing region with the following user initialized: + Username: airflowTestUser + Password: airflowSystemTestP@ssword1! Review Comment: I don't see these credentials used anywhere in the test? Is it really mandatory? If yes, we should make it configurable and provide them through the test context builder ########## providers/tests/system/amazon/aws/example_sagemaker_unified_studio.py: ########## @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from datetime import datetime + +import pytest + +from airflow.decorators import task +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) +from providers.tests.system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS + +""" +Prerequisites: The account which runs this test must manually have the following: +1. An IAM IDC organization set up in the testing region with the following user initialized: + Username: airflowTestUser + Password: airflowSystemTestP@ssword1! +2. A SageMaker Unified Studio Domain (with default VPC and roles) +3. A project within the SageMaker Unified Studio Domain + +Essentially, this test will emulate a DAG run in the shared MWAA environment inside a SageMaker Unified Studio Project. +The setup tasks will set up the project and configure the test runner to emulate an MWAA instance. +Then, the SageMakerNotebookOperator will run a test notebook. This should spin up a SageMaker training job, run the notebook, and exit successfully. +The teardown tasks will finally delete the project and domain that was set up for this test run. 
+""" + +pytestmark = pytest.mark.skipif( + not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+" +) + +DAG_ID = "example_sagemaker_unified_studio" + +# Externally fetched variables: +DOMAIN_ID_KEY = "DOMAIN_ID" +PROJECT_ID_KEY = "PROJECT_ID" +ENVIRONMENT_ID_KEY = "ENVIRONMENT_ID" +S3_PATH_KEY = "S3_PATH" + +sys_test_context_task = ( + SystemTestContextBuilder() + .add_variable(DOMAIN_ID_KEY) + .add_variable(PROJECT_ID_KEY) + .add_variable(ENVIRONMENT_ID_KEY) + .add_variable(S3_PATH_KEY) + .build() +) + + +@task +def emulate_mwaa_environment( + domain_id: str, project_id: str, environment_id: str, s3_path: str +): + """ + Sets several environment variables in the container to emulate an MWAA environment provisioned + within SageMaker Unified Studio. + """ + AIRFLOW_PREFIX = "AIRFLOW__WORKFLOWS__" + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_DOMAIN_ID"] = domain_id + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_PROJECT_ID"] = project_id + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_ENVIRONMENT_ID"] = environment_id + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_SCOPE_NAME"] = "dev" + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_STAGE"] = "prod" + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_ENDPOINT"] = ( + "https://datazone.us-east-1.api.aws" + ) + os.environ[f"{AIRFLOW_PREFIX}PROJECT_S3_PATH"] = s3_path + os.environ[f"{AIRFLOW_PREFIX}DATAZONE_DOMAIN_REGION"] = "us-east-1" + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + + test_env_id = test_context[ENV_ID_KEY] + domain_id = test_context[DOMAIN_ID_KEY] + project_id = test_context[PROJECT_ID_KEY] + environment_id = test_context[ENVIRONMENT_ID_KEY] + s3_path = test_context[S3_PATH_KEY] + + setup_mock_mwaa_environment = emulate_mwaa_environment( + domain_id, + project_id, + environment_id, + s3_path, + ) + + # [START howto_operator_sagemaker_unified_studio_notebook] + notebook_path = "test_notebook.ipynb" # This should be the 
path to your .ipynb, .sqlnb, or .vetl file in your project. + + run_notebook = SageMakerNotebookOperator( + task_id="initial", + input_config={"input_path": notebook_path, "input_params": {}}, + output_config={"output_formats": ["NOTEBOOK"]}, # optional + compute={ + "InstanceType": "ml.m5.large", + "VolumeSizeInGB": 30, + }, # optional + termination_condition={"MaxRuntimeInSeconds": 600}, # optional + tags={}, # optional + wait_for_completion=True, # optional + poll_interval=5, # optional + deferrable=False, # optional + ) + # [END howto_operator_sagemaker_unified_studio_notebook] + + chain( + # TEST SETUP + test_context, + setup_mock_mwaa_environment, + # TEST BODY + run_notebook, + ) Review Comment: No teardown needed? What if things break, get stuck, timeout? Any way to manually stop a notebook execution? ########## providers/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py: ########## @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains the Amazon SageMaker Unified Studio Notebook operator.""" + +from functools import cached_property + +from airflow import AirflowException +from airflow.configuration import conf +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import ( + SageMakerNotebookHook, +) +from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import ( + SageMakerNotebookJobTrigger, +) +from airflow.utils.context import Context + + +class SageMakerNotebookOperator(BaseOperator): + """ + Provides Notebook execution functionality for Sagemaker Workflows. + + Examples: + .. code-block:: python + + from workflows.airflow.providers.amazon.aws.operators.sagemaker_workflows import NotebookOperator + + notebook_operator = NotebookOperator( + task_id='notebook_task', + input_config={'input_path': 'path/to/notebook.ipynb', 'input_params': ''}, + output_config={'output_format': 'ipynb'}, + wait_for_completion=True, + poll_interval=10, + max_polling_attempts=100, + ) + + :param task_id: A unique, meaningful id for the task. + :param input_config: Configuration for the input file. Input path should be specified as a relative path. + The provided relative path will be automatically resolved to an absolute path within + the context of the user's home directory in the IDE. Input params should be a dict + Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params':{'key': 'value'}} + :param output_config: Configuration for the output format. It should include an output_format parameter to control + the format of the notebook execution output. + Example: {"output_formats": ["NOTEBOOK"]} + :param compute: compute configuration to use for the notebook execution. This is a required attribute + if the execution is on a remote compute. 
+ Example: { "InstanceType": "ml.m5.large", "VolumeSizeInGB": 30, "VolumeKmsKeyId": "", "ImageUri": "string", "ContainerEntrypoint": [ "string" ]} + :param termination_condition: conditions to match to terminate the remote execution. + Example: { "MaxRuntimeInSeconds": 3600 } + :param tags: tags to be associated with the remote execution runs. + Example: { "md_analytics": "logs" } + :param wait_for_completion: Indicates whether to wait for the notebook execution to complete. If True, wait for completion; if False, don't wait. + :param poll_interval: Interval in seconds to check the notebook execution status. Review Comment: Use `waiter_delay` and `waiter_max_attempts`, you can see other examples across the code base. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
