This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage-sagemaker-operators
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 0d8ea29f54a5906b403128cf41bc44edeb9d1b90
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Fri Jun 9 14:14:56 2023 +0200

    openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, Transform and Training operators

    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 .../providers/amazon/aws/operators/sagemaker.py    | 142 +++++++++++++++++++--
 dev/breeze/tests/test_selective_checks.py          |   9 +-
 generated/provider_dependencies.json               |   1 +
 .../aws/operators/test_sagemaker_processing.py     |  50 +++++++-
 .../aws/operators/test_sagemaker_training.py       |  32 ++++-
 .../aws/operators/test_sagemaker_transform.py      |  30 +++++
 6 files changed, 243 insertions(+), 21 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index ac1b7a73d2..44b20bb854 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -155,6 +155,14 @@ class SageMakerBaseOperator(BaseOperator):
         """Return SageMakerHook."""
         return SageMakerHook(aws_conn_id=self.aws_conn_id)
 
+    @staticmethod
+    def path_to_s3_dataset(path):
+        from openlineage.client.run import Dataset
+
+        path = path.replace("s3://", "")
+        split_path = path.split("/")
+        return Dataset(namespace=f"s3://{split_path[0]}", name="/".join(split_path[1:]), facets={})
+
 
 class SageMakerProcessingOperator(SageMakerBaseOperator):
     """
@@ -222,6 +230,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
         self.deferrable = deferrable
+        self.processing_job: dict[Any, Any] | None = None
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -279,14 +288,53 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.processing_job = {
+            "Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        }
+        return self.processing_job
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.processing_job = {
+            "Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        }
+        return self.processing_job
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Return OpenLineage data gathered from SageMaker's API response saved by the processing job."""
+        from airflow.providers.openlineage.extractors.base import OperatorLineage
+
+        inputs, outputs = [], []
+        try:
+            inputs, outputs = self._extract_s3_dataset_identifiers(
+                processing_inputs=self.processing_job["Processing"]["ProcessingInputs"],
+                processing_outputs=self.processing_job["Processing"]["ProcessingOutputConfig"]["Outputs"],
+            )
+        except KeyError:
+            self.log.exception("Could not find input/output information in XCom.")
+
+        return OperatorLineage(
+            inputs=inputs,
+            outputs=outputs,
+        )
+
+    def _extract_s3_dataset_identifiers(self, processing_inputs, processing_outputs):
+        inputs, outputs = [], []
+        try:
+            for processing_input in processing_inputs:
+                inputs.append(self.path_to_s3_dataset(processing_input["S3Input"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Cannot find S3 input details.", exc_info=True)
+
+        try:
+            for processing_output in processing_outputs:
+                outputs.append(self.path_to_s3_dataset(processing_output["S3Output"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Cannot find S3 output details.", exc_info=True)
+        return inputs, outputs
 
 
 class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
@@ -576,6 +624,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.transform_data: dict[Any, Any] | None = None
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -646,11 +695,11 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 ),
                 method_name="execute_complete",
             )
-
-        return {
+        self.transform_data = {
             "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
             "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
         }
+        return self.transform_data
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
@@ -658,10 +707,61 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         else:
             self.log.info(event["message"])
         transform_config = self.config.get("Transform", self.config)
-        return {
+        self.transform_data = {
             "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
             "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
         }
+        return self.transform_data
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Return OpenLineage data gathered from SageMaker's API response saved by the transform job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        model_package_arn = None
+        transform_input = None
+        transform_output = None
+
+        try:
+            model_package_arn = self.transform_data["Model"]["PrimaryContainer"]["ModelPackageName"]
+        except KeyError:
+            self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+        try:
+            transform = self.transform_data["Transform"]
+            transform_input = transform["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"]
+            transform_output = transform["TransformOutput"]["S3OutputPath"]
+        except KeyError:
+            self.log.error("Cannot find some required input/output details.", exc_info=True)
+
+        inputs = []
+
+        if transform_input is not None:
+            inputs.append(self.path_to_s3_dataset(transform_input))
+
+        if model_package_arn is not None:
+            model_data_urls = self._get_model_data_urls(model_package_arn)
+            for model_data_url in model_data_urls:
+                inputs.append(self.path_to_s3_dataset(model_data_url))
+
+        output = []
+        if transform_output is not None:
+            output.append(self.path_to_s3_dataset(transform_output))
+
+        return OperatorLineage(inputs=inputs, outputs=output)
+
+    def _get_model_data_urls(self, model_package_arn):
+        model_data_urls = []
+        try:
+            model_containers = self.hook.get_conn().describe_model_package(
+                ModelPackageName=model_package_arn
+            )["InferenceSpecification"]["Containers"]
+
+            for container in model_containers:
+                model_data_urls.append(container["ModelDataUrl"])
+        except Exception:
+            self.log.exception("Cannot retrieve model details.", exc_info=True)
+
+        return model_data_urls
 
 
 class SageMakerTuningOperator(SageMakerBaseOperator):
@@ -888,6 +988,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.training_data: dict[Any, Any] | None = None
 
     def expand_role(self) -> None:
         """Expands an IAM role name into an ARN."""
@@ -948,16 +1049,39 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        result = {"Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
-        return result
+        self.training_data = {
+            "Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))
+        }
+        return self.training_data
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        result = {"Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
-        return result
+        self.training_data = {
+            "Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))
+        }
+        return self.training_data
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Return OpenLineage data gathered from SageMaker's API response saved by the training job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        inputs, outputs = [], []
+        try:
+            for input_data in self.training_data["Training"]["InputDataConfig"]:
+                inputs.append(self.path_to_s3_dataset(input_data["DataSource"]["S3DataSource"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Issues extracting inputs.")
+
+        try:
+            outputs.append(
+                self.path_to_s3_dataset(self.training_data["Training"]["ModelArtifacts"]["S3ModelArtifacts"])
+            )
+        except KeyError:
+            self.log.exception("Issues extracting outputs.")
+        return OperatorLineage(inputs=inputs, outputs=outputs)
 
 
 class SageMakerDeleteModelOperator(SageMakerBaseOperator):
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index 7437642e4e..5df464429c 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -312,7 +312,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
                 "common.sql exasol ftp google http imap microsoft.azure "
-                "mongo mysql postgres salesforce ssh",
+                "mongo mysql openlineage postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -326,7 +326,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "run-amazon-tests": "true",
                 "parallel-test-types-list-as-string": "Providers[amazon] Always "
                 "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,imap,microsoft.azure,"
-                "mongo,mysql,postgres,salesforce,ssh] Providers[google]",
+                "mongo,mysql,openlineage,postgres,salesforce,ssh] Providers[google]",
             },
             id="Providers tests run including amazon tests if amazon provider files changed",
         ),
@@ -354,7 +354,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
                 "common.sql exasol ftp google http imap microsoft.azure "
-                "mongo mysql postgres salesforce ssh",
+                "mongo mysql openlineage postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -368,7 +368,8 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "upgrade-to-newer-dependencies": "false",
"parallel-test-types-list-as-string": "Providers[amazon] Always " "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp," - "http,imap,microsoft.azure,mongo,mysql,postgres,salesforce,ssh] Providers[google]", + "http,imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh] " + "Providers[google]", }, id="Providers tests run including amazon tests if amazon provider files changed", ), diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 09c2211b3a..1ca5605d98 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -47,6 +47,7 @@ "imap", "microsoft.azure", "mongo", + "openlineage", "salesforce", "ssh" ], diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index 1d73d44bdf..817761c014 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -20,12 +20,17 @@ from unittest import mock import pytest from botocore.exceptions import ClientError +from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.providers.amazon.aws.operators import sagemaker -from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator +from airflow.providers.amazon.aws.operators.sagemaker import ( + SageMakerBaseOperator, + SageMakerProcessingOperator, +) from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger +from airflow.providers.openlineage.extractors import OperatorLineage CREATE_PROCESSING_PARAMS: dict = { "AppSpecification": { @@ -238,14 +243,16 @@ class TestSageMakerProcessingOperator: action_if_job_exists="not_fail_or_increment", ) - @mock.patch.object(SageMakerHook, "create_processing_job") - @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists") - def test_operator_defer(self, mock_job_exists, mock_processing): - mock_processing.return_value = { + @mock.patch.object( + SageMakerHook, + "create_processing_job", + return_value={ "ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}, - } - mock_job_exists.return_value = False + }, + ) + @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False) + def test_operator_defer(self, mock_job_exists, mock_processing): sagemaker_operator = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS, @@ -255,3 +262,32 @@ class TestSageMakerProcessingOperator: with pytest.raises(TaskDeferred) as exc: sagemaker_operator.execute(context=None) assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger" + + @mock.patch.object( + SageMakerHook, + "describe_processing_job", + return_value={ + "ProcessingInputs": [{"S3Input": {"S3Uri": "s3://input-bucket/input-path"}}], + "ProcessingOutputConfig": { + "Outputs": [{"S3Output": {"S3Uri": "s3://output-bucket/output-path"}}] + }, + }, + ) + @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0) + @mock.patch.object( + SageMakerHook, + "create_processing_job", + return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, + ) + @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False) + def 
+    def test_operator_openlineage_data(self, check_job_exists, mock_processing, _, mock_desc):
+        sagemaker = SageMakerProcessingOperator(
+            **self.processing_config_kwargs,
+            config=CREATE_PROCESSING_PARAMS,
+            deferrable=True,
+        )
+        sagemaker.execute(context=None)
+        assert sagemaker.get_openlineage_facets_on_complete(None) == OperatorLineage(
+            inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+            outputs=[Dataset(namespace="s3://output-bucket", name="output-path")],
+        )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index e551317d33..9cb50de7c8 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -20,12 +20,14 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator, SageMakerTrainingOperator
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["ResourceConfig", "InstanceCount"],
@@ -127,3 +129,31 @@ class TestSageMakerTrainingOperator:
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_training_job",
+        return_value={
+            "InputDataConfig": [
+                {
+                    "DataSource": {"S3DataSource": {"S3Uri": "s3://input-bucket/input-path"}},
+                }
+            ],
+            "ModelArtifacts": {"S3ModelArtifacts": "s3://model-bucket/model-path"},
+        },
+    )
+    @mock.patch.object(
+        SageMakerHook,
+        "create_training_job",
+        return_value={
+            "TrainingJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        },
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
+    def test_execute_openlineage_data(self, mock_exists, mock_training, mock_desc):
+        self.sagemaker.execute(None)
+        assert self.sagemaker.get_openlineage_facets_on_complete(None) == OperatorLineage(
+            inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+            outputs=[Dataset(namespace="s3://model-bucket", name="model-path")],
+        )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 76a4d877b6..9a9af38b36 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -22,12 +22,14 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["Transform", "TransformResources", "InstanceCount"],
@@ -178,3 +180,31 @@ class TestSageMakerTransformOperator:
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+    @mock.patch.object(SageMakerHook, "describe_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    @mock.patch.object(SageMakerHook, "describe_model")
+    @mock.patch.object(SageMakerHook, "get_conn")
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    def test_operator_lineage_data(self, mock_transform, mock_conn, mock_model, _, mock_desc):
+        self.sagemaker.check_if_job_exists = False
+        mock_conn.return_value.describe_model_package.return_value = {
+            "InferenceSpecification": {"Containers": [{"ModelDataUrl": "s3://model-bucket/model-path"}]},
+        }
+        mock_model.return_value = {"PrimaryContainer": {"ModelPackageName": "package-name"}}
+        mock_desc.return_value = {
+            "TransformInput": {"DataSource": {"S3DataSource": {"S3Uri": "s3://input-bucket/input-path"}}},
+            "TransformOutput": {"S3OutputPath": "s3://output-bucket/output-path"},
+        }
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.execute(None)
+        assert self.sagemaker.get_openlineage_facets_on_complete(None) == OperatorLineage(
+            inputs=[
+                Dataset(namespace="s3://input-bucket", name="input-path"),
+                Dataset(namespace="s3://model-bucket", name="model-path"),
+            ],
+            outputs=[Dataset(namespace="s3://output-bucket", name="output-path")],
+        )
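
Note: the heart of the change is the new SageMakerBaseOperator.path_to_s3_dataset helper, which maps an S3 URI to an OpenLineage Dataset by using the bucket as the namespace and the object key as the name. The following is only a minimal standalone sketch of that mapping, copied from the diff above, with an example URI taken from the test fixtures rather than from any real job:

    from openlineage.client.run import Dataset

    def path_to_s3_dataset(path: str) -> Dataset:
        # "s3://input-bucket/input-path" -> namespace="s3://input-bucket", name="input-path"
        path = path.replace("s3://", "")
        split_path = path.split("/")
        return Dataset(namespace=f"s3://{split_path[0]}", name="/".join(split_path[1:]), facets={})

    print(path_to_s3_dataset("s3://input-bucket/input-path"))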
