mobuchowski commented on code in PR #31816:
URL: https://github.com/apache/airflow/pull/31816#discussion_r1280791684


##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -658,11 +708,66 @@ def execute_complete(self, context, event=None):
         else:
             self.log.info(event["message"])
         transform_config = self.config.get("Transform", self.config)
+        self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
+        self.serialized_tranform = serialize(
+            self.hook.describe_transform_job(transform_config["TransformJobName"])
+        )
         return {
-            "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
-            "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+            "Model": self.serialized_model,
+            "Transform": self.serialized_tranform,
         }
 
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Returns OpenLineage data gathered from SageMaker's API response saved by transform job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        model_package_arn = None
+        transform_input = None
+        transform_output = None
+
+        try:
+            model_package_arn = self.serialized_model["PrimaryContainer"]["ModelPackageName"]
+        except KeyError:
+            self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+        try:
+            transform_input = self.serialized_tranform["TransformInput"]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ]
+            transform_output = self.serialized_tranform["TransformOutput"]["S3OutputPath"]
+        except KeyError:
+            self.log.error("Cannot find some required input/output details.", exc_info=True)
+
+        inputs = []
+
+        if transform_input is not None:
+            inputs.append(self.path_to_s3_dataset(transform_input))
+
+        if model_package_arn is not None:
+            model_data_urls = self._get_model_data_urls(model_package_arn)
+            for model_data_url in model_data_urls:
+                inputs.append(self.path_to_s3_dataset(model_data_url))
+
+        output = []
+        if transform_output is not None:
+            output.append(self.path_to_s3_dataset(transform_output))
+
+        return OperatorLineage(inputs=inputs, outputs=output)
+
+    def _get_model_data_urls(self, model_package_arn):
+        model_data_urls = []
+        try:
+            model_containers = self.hook.get_conn().describe_model_package(
+                ModelPackageName=model_package_arn
+            )["InferenceSpecification"]["Containers"]
+
+            for container in model_containers:
+                model_data_urls.append(container["ModelDataUrl"])
+        except Exception:

Review Comment:
   Done.



##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -658,11 +708,66 @@ def execute_complete(self, context, event=None):
         else:
             self.log.info(event["message"])
         transform_config = self.config.get("Transform", self.config)
+        self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
+        self.serialized_tranform = serialize(
+            self.hook.describe_transform_job(transform_config["TransformJobName"])
+        )
         return {
-            "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
-            "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+            "Model": self.serialized_model,
+            "Transform": self.serialized_tranform,
         }
 
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Returns OpenLineage data gathered from SageMaker's API response saved by transform job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        model_package_arn = None
+        transform_input = None
+        transform_output = None
+
+        try:
+            model_package_arn = self.serialized_model["PrimaryContainer"]["ModelPackageName"]
+        except KeyError:
+            self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+        try:
+            transform_input = self.serialized_tranform["TransformInput"]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ]
+            transform_output = self.serialized_tranform["TransformOutput"]["S3OutputPath"]
+        except KeyError:
+            self.log.error("Cannot find some required input/output details.", exc_info=True)
+
+        inputs = []
+
+        if transform_input is not None:
+            inputs.append(self.path_to_s3_dataset(transform_input))
+
+        if model_package_arn is not None:
+            model_data_urls = self._get_model_data_urls(model_package_arn)
+            for model_data_url in model_data_urls:
+                inputs.append(self.path_to_s3_dataset(model_data_url))
+
+        output = []
+        if transform_output is not None:
+            output.append(self.path_to_s3_dataset(transform_output))
+
+        return OperatorLineage(inputs=inputs, outputs=output)
+
+    def _get_model_data_urls(self, model_package_arn):

Review Comment:
   Done.



##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -576,6 +620,8 @@ def __init__(
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.serialized_model: dict[Any, Any] | None = None
+        self.serialized_tranform: dict[Any, Any] | None = None

Review Comment:
   The same ⬆️.



##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -576,6 +620,8 @@ def __init__(
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.serialized_model: dict[Any, Any] | None = None
+        self.serialized_tranform: dict[Any, Any] | None = None

Review Comment:
   The same ⬆️.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to