This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage-sagemaker-operators
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 0d8ea29f54a5906b403128cf41bc44edeb9d1b90
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Fri Jun 9 14:14:56 2023 +0200

    openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, 
Transform and Training operators
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 .../providers/amazon/aws/operators/sagemaker.py    | 142 +++++++++++++++++++--
 dev/breeze/tests/test_selective_checks.py          |   9 +-
 generated/provider_dependencies.json               |   1 +
 .../aws/operators/test_sagemaker_processing.py     |  50 +++++++-
 .../aws/operators/test_sagemaker_training.py       |  32 ++++-
 .../aws/operators/test_sagemaker_transform.py      |  30 +++++
 6 files changed, 243 insertions(+), 21 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py 
b/airflow/providers/amazon/aws/operators/sagemaker.py
index ac1b7a73d2..44b20bb854 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -155,6 +155,14 @@ class SageMakerBaseOperator(BaseOperator):
         """Return SageMakerHook."""
         return SageMakerHook(aws_conn_id=self.aws_conn_id)
 
+    @staticmethod
+    def path_to_s3_dataset(path):
+        from openlineage.client.run import Dataset
+
+        path = path.replace("s3://", "")
+        split_path = path.split("/")
+        return Dataset(namespace=f"s3://{split_path[0]}", 
name="/".join(split_path[1:]), facets={})
+
 
 class SageMakerProcessingOperator(SageMakerBaseOperator):
     """
@@ -222,6 +230,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
         self.deferrable = deferrable
+        self.processing_job: dict[Any, Any] | None = None
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -279,14 +288,53 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return {"Processing": 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.processing_job = {
+            "Processing": 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        }
+        return self.processing_job
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        return {"Processing": 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.processing_job = {
+            "Processing": 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        }
+        return self.processing_job
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Returns OpenLineage data gathered from SageMaker's API response 
saved by processing job."""
+        from airflow.providers.openlineage.extractors.base import 
OperatorLineage
+
+        inputs, outputs = [], []
+        try:
+            inputs, outputs = self._extract_s3_dataset_identifiers(
+                
processing_inputs=self.processing_job["Processing"]["ProcessingInputs"],
+                
processing_outputs=self.processing_job["Processing"]["ProcessingOutputConfig"]["Outputs"],
+            )
+        except KeyError:
+            self.log.exception("Could not find input/output information in 
XCom.")
+
+        return OperatorLineage(
+            inputs=inputs,
+            outputs=outputs,
+        )
+
+    def _extract_s3_dataset_identifiers(self, processing_inputs, 
processing_outputs):
+        inputs, outputs = [], []
+        try:
+            for processing_input in processing_inputs:
+                
inputs.append(self.path_to_s3_dataset(processing_input["S3Input"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Cannot find S3 input details.", exc_info=True)
+
+        try:
+            for processing_output in processing_outputs:
+                
outputs.append(self.path_to_s3_dataset(processing_output["S3Output"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Cannot find S3 output details.", exc_info=True)
+        return inputs, outputs
 
 
 class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
@@ -576,6 +624,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.transform_data: dict[Any, Any] | None = None
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -646,11 +695,11 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 ),
                 method_name="execute_complete",
             )
-
-        return {
+        self.transform_data = {
             "Model": 
serialize(self.hook.describe_model(transform_config["ModelName"])),
             "Transform": 
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
         }
+        return self.transform_data
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
@@ -658,10 +707,61 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         else:
             self.log.info(event["message"])
         transform_config = self.config.get("Transform", self.config)
-        return {
+        self.transform_data = {
             "Model": 
serialize(self.hook.describe_model(transform_config["ModelName"])),
             "Transform": 
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
         }
+        return self.transform_data
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Returns OpenLineage data gathered from SageMaker's API response 
saved by transform job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        model_package_arn = None
+        transform_input = None
+        transform_output = None
+
+        try:
+            model_package_arn = 
self.transform_data["Model"]["PrimaryContainer"]["ModelPackageName"]
+        except KeyError:
+            self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+        try:
+            transform = self.transform_data["Transform"]
+            transform_input = 
transform["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"]
+            transform_output = transform["TransformOutput"]["S3OutputPath"]
+        except KeyError:
+            self.log.error("Cannot find some required input/output details.", 
exc_info=True)
+
+        inputs = []
+
+        if transform_input is not None:
+            inputs.append(self.path_to_s3_dataset(transform_input))
+
+        if model_package_arn is not None:
+            model_data_urls = self._get_model_data_urls(model_package_arn)
+            for model_data_url in model_data_urls:
+                inputs.append(self.path_to_s3_dataset(model_data_url))
+
+        output = []
+        if transform_output is not None:
+            output.append(self.path_to_s3_dataset(transform_output))
+
+        return OperatorLineage(inputs=inputs, outputs=output)
+
+    def _get_model_data_urls(self, model_package_arn):
+        model_data_urls = []
+        try:
+            model_containers = self.hook.get_conn().describe_model_package(
+                ModelPackageName=model_package_arn
+            )["InferenceSpecification"]["Containers"]
+
+            for container in model_containers:
+                model_data_urls.append(container["ModelDataUrl"])
+        except Exception:
+            self.log.exception("Cannot retrieve model details.", exc_info=True)
+
+        return model_data_urls
 
 
 class SageMakerTuningOperator(SageMakerBaseOperator):
@@ -888,6 +988,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.training_data: dict[Any, Any] | None = None
 
     def expand_role(self) -> None:
         """Expands an IAM role name into an ARN."""
@@ -948,16 +1049,39 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        result = {"Training": 
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
-        return result
+        self.training_data = {
+            "Training": 
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))
+        }
+        return self.training_data
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        result = {"Training": 
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
-        return result
+        self.training_data = {
+            "Training": 
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))
+        }
+        return self.training_data
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Returns OpenLineage data gathered from SageMaker's API response 
saved by training job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        inputs, outputs = [], []
+        try:
+            for input_data in 
self.training_data["Training"]["InputDataConfig"]:
+                
inputs.append(self.path_to_s3_dataset(input_data["DataSource"]["S3DataSource"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Issues extracting inputs.")
+
+        try:
+            outputs.append(
+                
self.path_to_s3_dataset(self.training_data["Training"]["ModelArtifacts"]["S3ModelArtifacts"])
+            )
+        except KeyError:
+            self.log.exception("Issues extracting outputs.")
+        return OperatorLineage(inputs=inputs, outputs=outputs)
 
 
 class SageMakerDeleteModelOperator(SageMakerBaseOperator):
diff --git a/dev/breeze/tests/test_selective_checks.py 
b/dev/breeze/tests/test_selective_checks.py
index 7437642e4e..5df464429c 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -312,7 +312,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, 
str], stderr: str):
             {
                 "affected-providers-list-as-string": "amazon apache.hive 
cncf.kubernetes "
                 "common.sql exasol ftp google http imap microsoft.azure "
-                "mongo mysql postgres salesforce ssh",
+                "mongo mysql openlineage postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -326,7 +326,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, 
str], stderr: str):
                 "run-amazon-tests": "true",
                 "parallel-test-types-list-as-string": "Providers[amazon] 
Always "
                 
"Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,imap,microsoft.azure,"
-                "mongo,mysql,postgres,salesforce,ssh] Providers[google]",
+                "mongo,mysql,openlineage,postgres,salesforce,ssh] 
Providers[google]",
             },
             id="Providers tests run including amazon tests if amazon provider 
files changed",
         ),
@@ -354,7 +354,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, 
str], stderr: str):
             {
                 "affected-providers-list-as-string": "amazon apache.hive 
cncf.kubernetes "
                 "common.sql exasol ftp google http imap microsoft.azure "
-                "mongo mysql postgres salesforce ssh",
+                "mongo mysql openlineage postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -368,7 +368,8 @@ def assert_outputs_are_printed(expected_outputs: dict[str, 
str], stderr: str):
                 "upgrade-to-newer-dependencies": "false",
                 "parallel-test-types-list-as-string": "Providers[amazon] 
Always "
                 "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,"
-                
"http,imap,microsoft.azure,mongo,mysql,postgres,salesforce,ssh] 
Providers[google]",
+                
"http,imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh] "
+                "Providers[google]",
             },
             id="Providers tests run including amazon tests if amazon provider 
files changed",
         ),
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index 09c2211b3a..1ca5605d98 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -47,6 +47,7 @@
       "imap",
       "microsoft.azure",
       "mongo",
+      "openlineage",
       "salesforce",
       "ssh"
     ],
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 1d73d44bdf..817761c014 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -20,12 +20,17 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
-from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerProcessingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import (
+    SageMakerBaseOperator,
+    SageMakerProcessingOperator,
+)
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 CREATE_PROCESSING_PARAMS: dict = {
     "AppSpecification": {
@@ -238,14 +243,16 @@ class TestSageMakerProcessingOperator:
                 action_if_job_exists="not_fail_or_increment",
             )
 
-    @mock.patch.object(SageMakerHook, "create_processing_job")
-    
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
-    def test_operator_defer(self, mock_job_exists, mock_processing):
-        mock_processing.return_value = {
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={
             "ProcessingJobArn": "test_arn",
             "ResponseMetadata": {"HTTPStatusCode": 200},
-        }
-        mock_job_exists.return_value = False
+        },
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", 
return_value=False)
+    def test_operator_defer(self, mock_job_exists, mock_processing):
         sagemaker_operator = SageMakerProcessingOperator(
             **self.processing_config_kwargs,
             config=CREATE_PROCESSING_PARAMS,
@@ -255,3 +262,32 @@ class TestSageMakerProcessingOperator:
         with pytest.raises(TaskDeferred) as exc:
             sagemaker_operator.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"
+
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_processing_job",
+        return_value={
+            "ProcessingInputs": [{"S3Input": {"S3Uri": 
"s3://input-bucket/input-path"}}],
+            "ProcessingOutputConfig": {
+                "Outputs": [{"S3Output": {"S3Uri": 
"s3://output-bucket/output-path"}}]
+            },
+        },
+    )
+    @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", 
return_value=0)
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": 
{"HTTPStatusCode": 200}},
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", 
return_value=False)
+    def test_operator_openlineage_data(self, check_job_exists, 
mock_processing, _, mock_desc):
+        sagemaker = SageMakerProcessingOperator(
+            **self.processing_config_kwargs,
+            config=CREATE_PROCESSING_PARAMS,
+            deferrable=True,
+        )
+        sagemaker.execute(context=None)
+        assert sagemaker.get_openlineage_facets_on_complete(None) == 
OperatorLineage(
+            inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+            outputs=[Dataset(namespace="s3://output-bucket", 
name="output-path")],
+        )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index e551317d33..9cb50de7c8 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -20,12 +20,14 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
-from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerTrainingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerBaseOperator, SageMakerTrainingOperator
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["ResourceConfig", "InstanceCount"],
@@ -127,3 +129,31 @@ class TestSageMakerTrainingOperator:
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"
+
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_training_job",
+        return_value={
+            "InputDataConfig": [
+                {
+                    "DataSource": {"S3DataSource": {"S3Uri": 
"s3://input-bucket/input-path"}},
+                }
+            ],
+            "ModelArtifacts": {"S3ModelArtifacts": 
"s3://model-bucket/model-path"},
+        },
+    )
+    @mock.patch.object(
+        SageMakerHook,
+        "create_training_job",
+        return_value={
+            "TrainingJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        },
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", 
return_value=False)
+    def test_execute_openlineage_data(self, mock_exists, mock_training, 
mock_desc):
+        self.sagemaker.execute(None)
+        assert self.sagemaker.get_openlineage_facets_on_complete(None) == 
OperatorLineage(
+            inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+            outputs=[Dataset(namespace="s3://model-bucket", 
name="model-path")],
+        )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 76a4d877b6..9a9af38b36 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -22,12 +22,14 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerTransformOperator
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["Transform", "TransformResources", "InstanceCount"],
@@ -178,3 +180,31 @@ class TestSageMakerTransformOperator:
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"
+
+    @mock.patch.object(SageMakerHook, "describe_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    @mock.patch.object(SageMakerHook, "describe_model")
+    @mock.patch.object(SageMakerHook, "get_conn")
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    def test_operator_lineage_data(self, mock_transform, mock_conn, 
mock_model, _, mock_desc):
+        self.sagemaker.check_if_job_exists = False
+        mock_conn.return_value.describe_model_package.return_value = {
+            "InferenceSpecification": {"Containers": [{"ModelDataUrl": 
"s3://model-bucket/model-path"}]},
+        }
+        mock_model.return_value = {"PrimaryContainer": {"ModelPackageName": 
"package-name"}}
+        mock_desc.return_value = {
+            "TransformInput": {"DataSource": {"S3DataSource": {"S3Uri": 
"s3://input-bucket/input-path"}}},
+            "TransformOutput": {"S3OutputPath": 
"s3://output-bucket/output-path"},
+        }
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.execute(None)
+        assert self.sagemaker.get_openlineage_facets_on_complete(None) == 
OperatorLineage(
+            inputs=[
+                Dataset(namespace="s3://input-bucket", name="input-path"),
+                Dataset(namespace="s3://model-bucket", name="model-path"),
+            ],
+            outputs=[Dataset(namespace="s3://output-bucket", 
name="output-path")],
+        )

Reply via email to