This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 62f9e68a54 openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, Transform and Training operators (#31816)
62f9e68a54 is described below
commit 62f9e68a54d1223d169551ed301651cf0068e004
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Wed Aug 2 13:52:12 2023 +0200
openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, Transform and Training operators (#31816)
Signed-off-by: Maciej Obuchowski <[email protected]>
---
.../providers/amazon/aws/operators/sagemaker.py | 154 +++++++++++++++++++--
dev/breeze/tests/test_selective_checks.py | 9 +-
generated/provider_dependencies.json | 1 +
.../aws/operators/test_sagemaker_processing.py | 50 ++++++-
.../aws/operators/test_sagemaker_training.py | 32 ++++-
.../aws/operators/test_sagemaker_transform.py | 30 ++++
6 files changed, 249 insertions(+), 27 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 83a1e4f3d2..173de73933 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -40,13 +40,14 @@ from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.json import AirflowJsonEncoder
if TYPE_CHECKING:
+ from airflow.providers.openlineage.extractors.base import OperatorLineage
from airflow.utils.context import Context
DEFAULT_CONN_ID: str = "aws_default"
CHECK_INTERVAL_SECOND: int = 30
-def serialize(result: dict) -> str:
+def serialize(result: dict) -> dict:
return json.loads(json.dumps(result, cls=AirflowJsonEncoder))
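[Note: a minimal sketch, not part of the commit, of why the corrected annotation is dict: a json.dumps/json.loads round trip through AirflowJsonEncoder returns a plain dict, not a str. The sample values below are illustrative assumptions.]

    import json
    from datetime import datetime

    from airflow.utils.json import AirflowJsonEncoder

    result = {"Status": "Completed", "CreationTime": datetime(2023, 8, 2)}
    # Non-JSON-native types (e.g. datetime) are rendered by AirflowJsonEncoder,
    # then parsed back into plain JSON types, so the round trip yields a dict.
    round_tripped = json.loads(json.dumps(result, cls=AirflowJsonEncoder))
    assert isinstance(round_tripped, dict)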
@@ -158,6 +159,14 @@ class SageMakerBaseOperator(BaseOperator):
"""Return SageMakerHook."""
return SageMakerHook(aws_conn_id=self.aws_conn_id)
+ @staticmethod
+ def path_to_s3_dataset(path):
+ from openlineage.client.run import Dataset
+
+ path = path.replace("s3://", "")
+ split_path = path.split("/")
+ return Dataset(namespace=f"s3://{split_path[0]}", name="/".join(split_path[1:]), facets={})
+
class SageMakerProcessingOperator(SageMakerBaseOperator):
"""
@@ -225,6 +234,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
self.max_attempts = max_attempts or 60
self.max_ingestion_time = max_ingestion_time
self.deferrable = deferrable
+ self.serialized_job: dict
def _create_integer_fields(self) -> None:
"""Set fields which should be cast to integers."""
@@ -282,14 +292,48 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
method_name="execute_complete",
)
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job =
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job =
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
+
+ def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
+ """Returns OpenLineage data gathered from SageMaker's API response saved by processing job."""
+ from airflow.providers.openlineage.extractors.base import OperatorLineage
+
+ inputs = []
+ outputs = []
+ try:
+ inputs, outputs = self._extract_s3_dataset_identifiers(
+ processing_inputs=self.serialized_job["ProcessingInputs"],
+ processing_outputs=self.serialized_job["ProcessingOutputConfig"]["Outputs"],
+ )
+ except KeyError:
+ self.log.exception("Could not find input/output information in Xcom.")
+
+ return OperatorLineage(inputs=inputs, outputs=outputs)
+
+ def _extract_s3_dataset_identifiers(self, processing_inputs, processing_outputs):
+ inputs = []
+ outputs = []
+ try:
+ for processing_input in processing_inputs:
+ inputs.append(self.path_to_s3_dataset(processing_input["S3Input"]["S3Uri"]))
+ except KeyError:
+ self.log.exception("Cannot find S3 input details.", exc_info=True)
+
+ try:
+ for processing_output in processing_outputs:
+ outputs.append(self.path_to_s3_dataset(processing_output["S3Output"]["S3Uri"]))
+ except KeyError:
+ self.log.exception("Cannot find S3 output details.", exc_info=True)
+ return inputs, outputs
class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
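[Note: for orientation, a trimmed sketch of the describe_processing_job payload shape that _extract_s3_dataset_identifiers walks. The keys come from the code above; the values are illustrative.]

    serialized_job = {
        "ProcessingInputs": [{"S3Input": {"S3Uri": "s3://input-bucket/input-path"}}],
        "ProcessingOutputConfig": {
            "Outputs": [{"S3Output": {"S3Uri": "s3://output-bucket/output-path"}}]
        },
    }
    # Each S3Uri is converted to an OpenLineage Dataset via path_to_s3_dataset;
    # a missing key only logs an exception and skips that side of the lineage.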
@@ -579,6 +623,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
Provided value: '{action_if_job_exists}'."
)
self.deferrable = deferrable
+ self.serialized_model: dict
+ self.serialized_tranform: dict
def _create_integer_fields(self) -> None:
"""Set fields which should be cast to integers."""
@@ -650,10 +696,11 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
method_name="execute_complete",
)
- return {
- "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
- }
+ self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
+ self.serialized_tranform = serialize(
+ self.hook.describe_transform_job(transform_config["TransformJobName"])
+ )
+ return {"Model": self.serialized_model, "Transform": self.serialized_tranform}
def execute_complete(self, context, event=None):
if event["status"] != "success":
@@ -661,10 +708,62 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
else:
self.log.info(event["message"])
transform_config = self.config.get("Transform", self.config)
- return {
- "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
- }
+ self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
+ self.serialized_tranform = serialize(
+ self.hook.describe_transform_job(transform_config["TransformJobName"])
+ )
+ return {"Model": self.serialized_model, "Transform": self.serialized_tranform}
+
+ def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
+ """Returns OpenLineage data gathered from SageMaker's API response saved by transform job."""
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ model_package_arn = None
+ transform_input = None
+ transform_output = None
+
+ try:
+ model_package_arn = self.serialized_model["PrimaryContainer"]["ModelPackageName"]
+ except KeyError:
+ self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+ try:
+ transform_input = self.serialized_tranform["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"]
+ transform_output = self.serialized_tranform["TransformOutput"]["S3OutputPath"]
+ except KeyError:
+ self.log.error("Cannot find some required input/output details.", exc_info=True)
+
+ inputs = []
+
+ if transform_input is not None:
+ inputs.append(self.path_to_s3_dataset(transform_input))
+
+ if model_package_arn is not None:
+ model_data_urls = self._get_model_data_urls(model_package_arn)
+ for model_data_url in model_data_urls:
+ inputs.append(self.path_to_s3_dataset(model_data_url))
+
+ outputs = []
+ if transform_output is not None:
+ outputs.append(self.path_to_s3_dataset(transform_output))
+
+ return OperatorLineage(inputs=inputs, outputs=outputs)
+
+ def _get_model_data_urls(self, model_package_arn) -> list:
+ model_data_urls = []
+ try:
+ model_containers = self.hook.get_conn().describe_model_package(
+ ModelPackageName=model_package_arn
+ )["InferenceSpecification"]["Containers"]
+
+ for container in model_containers:
+ model_data_urls.append(container["ModelDataUrl"])
+ except KeyError:
+ self.log.exception("Cannot retrieve model details.", exc_info=True)
+
+ return model_data_urls
class SageMakerTuningOperator(SageMakerBaseOperator):
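[Note: similarly, a trimmed sketch of the DescribeModelPackage response shape that _get_model_data_urls expects; values are illustrative, mirroring the transform test below.]

    response = {
        "InferenceSpecification": {
            "Containers": [{"ModelDataUrl": "s3://model-bucket/model-path"}],
        },
    }
    # One input dataset per container's model artifact:
    model_data_urls = [c["ModelDataUrl"] for c in response["InferenceSpecification"]["Containers"]]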
@@ -891,6 +990,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
Provided value: '{action_if_job_exists}'."
)
self.deferrable = deferrable
+ self.serialized_training_data: dict
def expand_role(self) -> None:
"""Expands an IAM role name into an ARN."""
@@ -951,16 +1051,40 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
method_name="execute_complete",
)
- result = {"Training":
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
- return result
+ self.serialized_training_data = serialize(
+ self.hook.describe_training_job(self.config["TrainingJobName"])
+ )
+ return {"Training": self.serialized_training_data}
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])
- result = {"Training":
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
- return result
+ self.serialized_training_data = serialize(
+ self.hook.describe_training_job(self.config["TrainingJobName"])
+ )
+ return {"Training": self.serialized_training_data}
+
+ def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
+ """Returns OpenLineage data gathered from SageMaker's API response saved by training job."""
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ inputs = []
+ outputs = []
+ try:
+ for input_data in self.serialized_training_data["InputDataConfig"]:
+ inputs.append(self.path_to_s3_dataset(input_data["DataSource"]["S3DataSource"]["S3Uri"]))
+ except KeyError:
+ self.log.exception("Issues extracting inputs.")
+
+ try:
+ outputs.append(
+ self.path_to_s3_dataset(self.serialized_training_data["ModelArtifacts"]["S3ModelArtifacts"])
+ )
+ except KeyError:
+ self.log.exception("Issues extracting outputs.")
+ return OperatorLineage(inputs=inputs, outputs=outputs)
class SageMakerDeleteModelOperator(SageMakerBaseOperator):
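[Note: a minimal sketch, not part of the commit, of exercising the new training hook by hand. The task id, job name, and S3 paths are made up, and serialized_training_data is normally populated by execute().]

    from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator

    op = SageMakerTrainingOperator(task_id="train", config={"TrainingJobName": "my-job"})
    op.serialized_training_data = {
        "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/train"}}}],
        "ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model"},
    }
    lineage = op.get_openlineage_facets_on_complete(task_instance=None)
    # lineage.inputs  -> [Dataset(namespace="s3://bucket", name="train", facets={})]
    # lineage.outputs -> [Dataset(namespace="s3://bucket", name="model", facets={})]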
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index e3f9bacb98..3cdf3dc7bb 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -312,7 +312,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
{
"affected-providers-list-as-string": "amazon apache.hive
cncf.kubernetes "
"common.sql exasol ftp google http imap microsoft.azure "
- "mongo mysql postgres salesforce ssh",
+ "mongo mysql openlineage postgres salesforce ssh",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"python-versions": "['3.8']",
@@ -326,7 +326,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"run-amazon-tests": "true",
"parallel-test-types-list-as-string": "Providers[amazon]
Always "
"Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,imap,microsoft.azure,"
- "mongo,mysql,postgres,salesforce,ssh] Providers[google]",
+ "mongo,mysql,openlineage,postgres,salesforce,ssh]
Providers[google]",
},
id="Providers tests run including amazon tests if amazon provider
files changed",
),
@@ -354,7 +354,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
{
"affected-providers-list-as-string": "amazon apache.hive
cncf.kubernetes "
"common.sql exasol ftp google http imap microsoft.azure "
- "mongo mysql postgres salesforce ssh",
+ "mongo mysql openlineage postgres salesforce ssh",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"python-versions": "['3.8']",
@@ -368,7 +368,8 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"upgrade-to-newer-dependencies": "false",
"parallel-test-types-list-as-string": "Providers[amazon]
Always "
"Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,"
-
"http,imap,microsoft.azure,mongo,mysql,postgres,salesforce,ssh]
Providers[google]",
+
"http,imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh] "
+ "Providers[google]",
},
id="Providers tests run including amazon tests if amazon provider
files changed",
),
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 42146c2767..02b5d59a13 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -47,6 +47,7 @@
"imap",
"microsoft.azure",
"mongo",
+ "openlineage",
"salesforce",
"ssh"
],
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 1d73d44bdf..817761c014 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -20,12 +20,17 @@ from unittest import mock
import pytest
from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import (
+ SageMakerBaseOperator,
+ SageMakerProcessingOperator,
+)
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
CREATE_PROCESSING_PARAMS: dict = {
"AppSpecification": {
@@ -238,14 +243,16 @@ class TestSageMakerProcessingOperator:
action_if_job_exists="not_fail_or_increment",
)
- @mock.patch.object(SageMakerHook, "create_processing_job")
- @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
- def test_operator_defer(self, mock_job_exists, mock_processing):
- mock_processing.return_value = {
+ @mock.patch.object(
+ SageMakerHook,
+ "create_processing_job",
+ return_value={
"ProcessingJobArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
- }
- mock_job_exists.return_value = False
+ },
+ )
+ @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
+ def test_operator_defer(self, mock_job_exists, mock_processing):
sagemaker_operator = SageMakerProcessingOperator(
**self.processing_config_kwargs,
config=CREATE_PROCESSING_PARAMS,
@@ -255,3 +262,32 @@ class TestSageMakerProcessingOperator:
with pytest.raises(TaskDeferred) as exc:
sagemaker_operator.execute(context=None)
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+ @mock.patch.object(
+ SageMakerHook,
+ "describe_processing_job",
+ return_value={
+ "ProcessingInputs": [{"S3Input": {"S3Uri": "s3://input-bucket/input-path"}}],
+ "ProcessingOutputConfig": {
+ "Outputs": [{"S3Output": {"S3Uri": "s3://output-bucket/output-path"}}]
+ },
+ },
+ )
+ @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
+ @mock.patch.object(
+ SageMakerHook,
+ "create_processing_job",
+ return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
+ )
+ @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
+ def test_operator_openlineage_data(self, check_job_exists, mock_processing, _, mock_desc):
+ sagemaker = SageMakerProcessingOperator(
+ **self.processing_config_kwargs,
+ config=CREATE_PROCESSING_PARAMS,
+ deferrable=True,
+ )
+ sagemaker.execute(context=None)
+ assert sagemaker.get_openlineage_facets_on_complete(None) == OperatorLineage(
+ inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+ outputs=[Dataset(namespace="s3://output-bucket", name="output-path")],
+ )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index e551317d33..9cb50de7c8 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -20,12 +20,14 @@ from unittest import mock
import pytest
from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator, SageMakerTrainingOperator
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
EXPECTED_INTEGER_FIELDS: list[list[str]] = [
["ResourceConfig", "InstanceCount"],
@@ -127,3 +129,31 @@ class TestSageMakerTrainingOperator:
with pytest.raises(TaskDeferred) as exc:
self.sagemaker.execute(context=None)
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+ @mock.patch.object(
+ SageMakerHook,
+ "describe_training_job",
+ return_value={
+ "InputDataConfig": [
+ {
+ "DataSource": {"S3DataSource": {"S3Uri": "s3://input-bucket/input-path"}},
+ }
+ ],
+ "ModelArtifacts": {"S3ModelArtifacts": "s3://model-bucket/model-path"},
+ },
+ )
+ @mock.patch.object(
+ SageMakerHook,
+ "create_training_job",
+ return_value={
+ "TrainingJobArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ },
+ )
+ @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
+ def test_execute_openlineage_data(self, mock_exists, mock_training, mock_desc):
+ self.sagemaker.execute(None)
+ assert self.sagemaker.get_openlineage_facets_on_complete(None) == OperatorLineage(
+ inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+ outputs=[Dataset(namespace="s3://model-bucket", name="model-path")],
+ )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 76a4d877b6..9a9af38b36 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -22,12 +22,14 @@ from unittest import mock
import pytest
from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
EXPECTED_INTEGER_FIELDS: list[list[str]] = [
["Transform", "TransformResources", "InstanceCount"],
@@ -178,3 +180,31 @@ class TestSageMakerTransformOperator:
with pytest.raises(TaskDeferred) as exc:
self.sagemaker.execute(context=None)
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+ @mock.patch.object(SageMakerHook, "describe_transform_job")
+ @mock.patch.object(SageMakerHook, "create_model")
+ @mock.patch.object(SageMakerHook, "describe_model")
+ @mock.patch.object(SageMakerHook, "get_conn")
+ @mock.patch.object(SageMakerHook, "create_transform_job")
+ def test_operator_lineage_data(self, mock_transform, mock_conn,
mock_model, _, mock_desc):
+ self.sagemaker.check_if_job_exists = False
+ mock_conn.return_value.describe_model_package.return_value = {
+ "InferenceSpecification": {"Containers": [{"ModelDataUrl":
"s3://model-bucket/model-path"}]},
+ }
+ mock_model.return_value = {"PrimaryContainer": {"ModelPackageName":
"package-name"}}
+ mock_desc.return_value = {
+ "TransformInput": {"DataSource": {"S3DataSource": {"S3Uri":
"s3://input-bucket/input-path"}}},
+ "TransformOutput": {"S3OutputPath":
"s3://output-bucket/output-path"},
+ }
+ mock_transform.return_value = {
+ "TransformJobArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ self.sagemaker.execute(None)
+ assert self.sagemaker.get_openlineage_facets_on_complete(None) ==
OperatorLineage(
+ inputs=[
+ Dataset(namespace="s3://input-bucket", name="input-path"),
+ Dataset(namespace="s3://model-bucket", name="model-path"),
+ ],
+ outputs=[Dataset(namespace="s3://output-bucket",
name="output-path")],
+ )