This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 62f9e68a54 openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, Transform and Training operators (#31816)
62f9e68a54 is described below

commit 62f9e68a54d1223d169551ed301651cf0068e004
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Wed Aug 2 13:52:12 2023 +0200

    openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, Transform and Training operators (#31816)
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 .../providers/amazon/aws/operators/sagemaker.py    | 154 +++++++++++++++++++--
 dev/breeze/tests/test_selective_checks.py          |   9 +-
 generated/provider_dependencies.json               |   1 +
 .../aws/operators/test_sagemaker_processing.py     |  50 ++++++-
 .../aws/operators/test_sagemaker_training.py       |  32 ++++-
 .../aws/operators/test_sagemaker_transform.py      |  30 ++++
 6 files changed, 249 insertions(+), 27 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 83a1e4f3d2..173de73933 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -40,13 +40,14 @@ from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.json import AirflowJsonEncoder
 
 if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors.base import OperatorLineage
     from airflow.utils.context import Context
 
 DEFAULT_CONN_ID: str = "aws_default"
 CHECK_INTERVAL_SECOND: int = 30
 
 
-def serialize(result: dict) -> str:
+def serialize(result: dict) -> dict:
     return json.loads(json.dumps(result, cls=AirflowJsonEncoder))
 
 
@@ -158,6 +159,14 @@ class SageMakerBaseOperator(BaseOperator):
         """Return SageMakerHook."""
         return SageMakerHook(aws_conn_id=self.aws_conn_id)
 
+    @staticmethod
+    def path_to_s3_dataset(path):
+        from openlineage.client.run import Dataset
+
+        path = path.replace("s3://", "")
+        split_path = path.split("/")
+        return Dataset(namespace=f"s3://{split_path[0]}", 
name="/".join(split_path[1:]), facets={})
+
 
 class SageMakerProcessingOperator(SageMakerBaseOperator):
     """
@@ -225,6 +234,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
         self.deferrable = deferrable
+        self.serialized_job: dict
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -282,14 +292,48 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return {"Processing": 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.serialized_job = 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        return {"Processing": self.serialized_job}
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        return {"Processing": 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.serialized_job = 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        return {"Processing": self.serialized_job}
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> 
OperatorLineage:
+        """Returns OpenLineage data gathered from SageMaker's API response 
saved by processing job."""
+        from airflow.providers.openlineage.extractors.base import 
OperatorLineage
+
+        inputs = []
+        outputs = []
+        try:
+            inputs, outputs = self._extract_s3_dataset_identifiers(
+                processing_inputs=self.serialized_job["ProcessingInputs"],
+                
processing_outputs=self.serialized_job["ProcessingOutputConfig"]["Outputs"],
+            )
+        except KeyError:
+            self.log.exception("Could not find input/output information in 
Xcom.")
+
+        return OperatorLineage(inputs=inputs, outputs=outputs)
+
+    def _extract_s3_dataset_identifiers(self, processing_inputs, 
processing_outputs):
+        inputs = []
+        outputs = []
+        try:
+            for processing_input in processing_inputs:
+                
inputs.append(self.path_to_s3_dataset(processing_input["S3Input"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Cannot find S3 input details", exc_info=True)
+
+        try:
+            for processing_output in processing_outputs:
+                
outputs.append(self.path_to_s3_dataset(processing_output["S3Output"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Cannot find S3 output details.", exc_info=True)
+        return inputs, outputs
 
 
 class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
@@ -579,6 +623,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.serialized_model: dict
+        self.serialized_tranform: dict
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -650,10 +696,11 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return {
-            "Model": 
serialize(self.hook.describe_model(transform_config["ModelName"])),
-            "Transform": 
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
-        }
+        self.serialized_model = 
serialize(self.hook.describe_model(transform_config["ModelName"]))
+        self.serialized_tranform = serialize(
+            
self.hook.describe_transform_job(transform_config["TransformJobName"])
+        )
+        return {"Model": self.serialized_model, "Transform": 
self.serialized_tranform}
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
@@ -661,10 +708,62 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         else:
             self.log.info(event["message"])
         transform_config = self.config.get("Transform", self.config)
-        return {
-            "Model": 
serialize(self.hook.describe_model(transform_config["ModelName"])),
-            "Transform": 
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
-        }
+        self.serialized_model = 
serialize(self.hook.describe_model(transform_config["ModelName"]))
+        self.serialized_tranform = serialize(
+            
self.hook.describe_transform_job(transform_config["TransformJobName"])
+        )
+        return {"Model": self.serialized_model, "Transform": 
self.serialized_tranform}
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> 
OperatorLineage:
+        """Returns OpenLineage data gathered from SageMaker's API response 
saved by transform job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        model_package_arn = None
+        transform_input = None
+        transform_output = None
+
+        try:
+            model_package_arn = 
self.serialized_model["PrimaryContainer"]["ModelPackageName"]
+        except KeyError:
+            self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+        try:
+            transform_input = 
self.serialized_tranform["TransformInput"]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ]
+            transform_output = 
self.serialized_tranform["TransformOutput"]["S3OutputPath"]
+        except KeyError:
+            self.log.error("Cannot find some required input/output details.", 
exc_info=True)
+
+        inputs = []
+
+        if transform_input is not None:
+            inputs.append(self.path_to_s3_dataset(transform_input))
+
+        if model_package_arn is not None:
+            model_data_urls = self._get_model_data_urls(model_package_arn)
+            for model_data_url in model_data_urls:
+                inputs.append(self.path_to_s3_dataset(model_data_url))
+
+        outputs = []
+        if transform_output is not None:
+            outputs.append(self.path_to_s3_dataset(transform_output))
+
+        return OperatorLineage(inputs=inputs, outputs=outputs)
+
+    def _get_model_data_urls(self, model_package_arn) -> list:
+        model_data_urls = []
+        try:
+            model_containers = self.hook.get_conn().describe_model_package(
+                ModelPackageName=model_package_arn
+            )["InferenceSpecification"]["Containers"]
+
+            for container in model_containers:
+                model_data_urls.append(container["ModelDataUrl"])
+        except KeyError:
+            self.log.exception("Cannot retrieve model details.", exc_info=True)
+
+        return model_data_urls
 
 
 class SageMakerTuningOperator(SageMakerBaseOperator):
@@ -891,6 +990,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.serialized_training_data: dict
 
     def expand_role(self) -> None:
         """Expands an IAM role name into an ARN."""
@@ -951,16 +1051,40 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        result = {"Training": 
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
-        return result
+        self.serialized_training_data = serialize(
+            self.hook.describe_training_job(self.config["TrainingJobName"])
+        )
+        return {"Training": self.serialized_training_data}
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        result = {"Training": 
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
-        return result
+        self.serialized_training_data = serialize(
+            self.hook.describe_training_job(self.config["TrainingJobName"])
+        )
+        return {"Training": self.serialized_training_data}
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> 
OperatorLineage:
+        """Returns OpenLineage data gathered from SageMaker's API response 
saved by training job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        inputs = []
+        outputs = []
+        try:
+            for input_data in self.serialized_training_data["InputDataConfig"]:
+                
inputs.append(self.path_to_s3_dataset(input_data["DataSource"]["S3DataSource"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Issues extracting inputs.")
+
+        try:
+            outputs.append(
+                
self.path_to_s3_dataset(self.serialized_training_data["ModelArtifacts"]["S3ModelArtifacts"])
+            )
+        except KeyError:
+            self.log.exception("Issues extracting inputs.")
+        return OperatorLineage(inputs=inputs, outputs=outputs)
 
 
 class SageMakerDeleteModelOperator(SageMakerBaseOperator):
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index e3f9bacb98..3cdf3dc7bb 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -312,7 +312,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             {
                 "affected-providers-list-as-string": "amazon apache.hive 
cncf.kubernetes "
                 "common.sql exasol ftp google http imap microsoft.azure "
-                "mongo mysql postgres salesforce ssh",
+                "mongo mysql openlineage postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -326,7 +326,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "run-amazon-tests": "true",
                 "parallel-test-types-list-as-string": "Providers[amazon] 
Always "
                 
"Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,imap,microsoft.azure,"
-                "mongo,mysql,postgres,salesforce,ssh] Providers[google]",
+                "mongo,mysql,openlineage,postgres,salesforce,ssh] 
Providers[google]",
             },
             id="Providers tests run including amazon tests if amazon provider 
files changed",
         ),
@@ -354,7 +354,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             {
                 "affected-providers-list-as-string": "amazon apache.hive 
cncf.kubernetes "
                 "common.sql exasol ftp google http imap microsoft.azure "
-                "mongo mysql postgres salesforce ssh",
+                "mongo mysql openlineage postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -368,7 +368,8 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "upgrade-to-newer-dependencies": "false",
                 "parallel-test-types-list-as-string": "Providers[amazon] 
Always "
                 "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,"
-                
"http,imap,microsoft.azure,mongo,mysql,postgres,salesforce,ssh] 
Providers[google]",
+                
"http,imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh] "
+                "Providers[google]",
             },
             id="Providers tests run including amazon tests if amazon provider 
files changed",
         ),
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 42146c2767..02b5d59a13 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -47,6 +47,7 @@
       "imap",
       "microsoft.azure",
       "mongo",
+      "openlineage",
       "salesforce",
       "ssh"
     ],
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 1d73d44bdf..817761c014 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -20,12 +20,17 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
-from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerProcessingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import (
+    SageMakerBaseOperator,
+    SageMakerProcessingOperator,
+)
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 CREATE_PROCESSING_PARAMS: dict = {
     "AppSpecification": {
@@ -238,14 +243,16 @@ class TestSageMakerProcessingOperator:
                 action_if_job_exists="not_fail_or_increment",
             )
 
-    @mock.patch.object(SageMakerHook, "create_processing_job")
-    
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
-    def test_operator_defer(self, mock_job_exists, mock_processing):
-        mock_processing.return_value = {
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={
             "ProcessingJobArn": "test_arn",
             "ResponseMetadata": {"HTTPStatusCode": 200},
-        }
-        mock_job_exists.return_value = False
+        },
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", 
return_value=False)
+    def test_operator_defer(self, mock_job_exists, mock_processing):
         sagemaker_operator = SageMakerProcessingOperator(
             **self.processing_config_kwargs,
             config=CREATE_PROCESSING_PARAMS,
@@ -255,3 +262,32 @@ class TestSageMakerProcessingOperator:
         with pytest.raises(TaskDeferred) as exc:
             sagemaker_operator.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"
+
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_processing_job",
+        return_value={
+            "ProcessingInputs": [{"S3Input": {"S3Uri": 
"s3://input-bucket/input-path"}}],
+            "ProcessingOutputConfig": {
+                "Outputs": [{"S3Output": {"S3Uri": 
"s3://output-bucket/output-path"}}]
+            },
+        },
+    )
+    @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", 
return_value=0)
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": 
{"HTTPStatusCode": 200}},
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", 
return_value=False)
+    def test_operator_openlineage_data(self, check_job_exists, 
mock_processing, _, mock_desc):
+        sagemaker = SageMakerProcessingOperator(
+            **self.processing_config_kwargs,
+            config=CREATE_PROCESSING_PARAMS,
+            deferrable=True,
+        )
+        sagemaker.execute(context=None)
+        assert sagemaker.get_openlineage_facets_on_complete(None) == 
OperatorLineage(
+            inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+            outputs=[Dataset(namespace="s3://output-bucket", 
name="output-path")],
+        )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index e551317d33..9cb50de7c8 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -20,12 +20,14 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
-from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerTrainingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerBaseOperator, SageMakerTrainingOperator
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["ResourceConfig", "InstanceCount"],
@@ -127,3 +129,31 @@ class TestSageMakerTrainingOperator:
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"
+
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_training_job",
+        return_value={
+            "InputDataConfig": [
+                {
+                    "DataSource": {"S3DataSource": {"S3Uri": 
"s3://input-bucket/input-path"}},
+                }
+            ],
+            "ModelArtifacts": {"S3ModelArtifacts": 
"s3://model-bucket/model-path"},
+        },
+    )
+    @mock.patch.object(
+        SageMakerHook,
+        "create_training_job",
+        return_value={
+            "TrainingJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        },
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", 
return_value=False)
+    def test_execute_openlineage_data(self, mock_exists, mock_training, 
mock_desc):
+        self.sagemaker.execute(None)
+        assert self.sagemaker.get_openlineage_facets_on_complete(None) == 
OperatorLineage(
+            inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+            outputs=[Dataset(namespace="s3://model-bucket", 
name="model-path")],
+        )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 76a4d877b6..9a9af38b36 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -22,12 +22,14 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerTransformOperator
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["Transform", "TransformResources", "InstanceCount"],
@@ -178,3 +180,31 @@ class TestSageMakerTransformOperator:
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"
+
+    @mock.patch.object(SageMakerHook, "describe_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    @mock.patch.object(SageMakerHook, "describe_model")
+    @mock.patch.object(SageMakerHook, "get_conn")
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    def test_operator_lineage_data(self, mock_transform, mock_conn, 
mock_model, _, mock_desc):
+        self.sagemaker.check_if_job_exists = False
+        mock_conn.return_value.describe_model_package.return_value = {
+            "InferenceSpecification": {"Containers": [{"ModelDataUrl": 
"s3://model-bucket/model-path"}]},
+        }
+        mock_model.return_value = {"PrimaryContainer": {"ModelPackageName": 
"package-name"}}
+        mock_desc.return_value = {
+            "TransformInput": {"DataSource": {"S3DataSource": {"S3Uri": 
"s3://input-bucket/input-path"}}},
+            "TransformOutput": {"S3OutputPath": 
"s3://output-bucket/output-path"},
+        }
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.execute(None)
+        assert self.sagemaker.get_openlineage_facets_on_complete(None) == 
OperatorLineage(
+            inputs=[
+                Dataset(namespace="s3://input-bucket", name="input-path"),
+                Dataset(namespace="s3://model-bucket", name="model-path"),
+            ],
+            outputs=[Dataset(namespace="s3://output-bucket", 
name="output-path")],
+        )

Reply via email to