vandonr-amz commented on code in PR #31816:
URL: https://github.com/apache/airflow/pull/31816#discussion_r1279745380
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -658,11 +708,66 @@ def execute_complete(self, context, event=None):
else:
self.log.info(event["message"])
transform_config = self.config.get("Transform", self.config)
+ self.serialized_model =
serialize(self.hook.describe_model(transform_config["ModelName"]))
+ self.serialized_tranform = serialize(
+
self.hook.describe_transform_job(transform_config["TransformJobName"])
+ )
return {
- "Model":
serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform":
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+ "Model": self.serialized_model,
+ "Transform": self.serialized_tranform,
}
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Returns OpenLineage data gathered from SageMaker's API response
saved by transform job."""
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ model_package_arn = None
+ transform_input = None
+ transform_output = None
+
+ try:
+ model_package_arn =
self.serialized_model["PrimaryContainer"]["ModelPackageName"]
+ except KeyError:
+ self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+ try:
+ transform_input =
self.serialized_tranform["TransformInput"]["DataSource"]["S3DataSource"][
+ "S3Uri"
+ ]
+ transform_output =
self.serialized_tranform["TransformOutput"]["S3OutputPath"]
+ except KeyError:
+ self.log.error("Cannot find some required input/output details.",
exc_info=True)
+
+ inputs = []
+
+ if transform_input is not None:
+ inputs.append(self.path_to_s3_dataset(transform_input))
+
+ if model_package_arn is not None:
+ model_data_urls = self._get_model_data_urls(model_package_arn)
+ for model_data_url in model_data_urls:
+ inputs.append(self.path_to_s3_dataset(model_data_url))
+
+ output = []
+ if transform_output is not None:
+ output.append(self.path_to_s3_dataset(transform_output))
+
+ return OperatorLineage(inputs=inputs, outputs=output)
+
+ def _get_model_data_urls(self, model_package_arn):
Review Comment:
```suggestion
def _get_model_data_urls(self, model_package_arn) -> list:
```
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -658,11 +708,66 @@ def execute_complete(self, context, event=None):
else:
self.log.info(event["message"])
transform_config = self.config.get("Transform", self.config)
+ self.serialized_model =
serialize(self.hook.describe_model(transform_config["ModelName"]))
+ self.serialized_tranform = serialize(
+
self.hook.describe_transform_job(transform_config["TransformJobName"])
+ )
return {
- "Model":
serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform":
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+ "Model": self.serialized_model,
+ "Transform": self.serialized_tranform,
}
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Returns OpenLineage data gathered from SageMaker's API response
saved by transform job."""
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ model_package_arn = None
+ transform_input = None
+ transform_output = None
+
+ try:
+ model_package_arn =
self.serialized_model["PrimaryContainer"]["ModelPackageName"]
+ except KeyError:
+ self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+ try:
+ transform_input =
self.serialized_tranform["TransformInput"]["DataSource"]["S3DataSource"][
+ "S3Uri"
+ ]
+ transform_output =
self.serialized_tranform["TransformOutput"]["S3OutputPath"]
+ except KeyError:
+ self.log.error("Cannot find some required input/output details.",
exc_info=True)
+
+ inputs = []
+
+ if transform_input is not None:
+ inputs.append(self.path_to_s3_dataset(transform_input))
+
+ if model_package_arn is not None:
+ model_data_urls = self._get_model_data_urls(model_package_arn)
+ for model_data_url in model_data_urls:
+ inputs.append(self.path_to_s3_dataset(model_data_url))
+
+ output = []
+ if transform_output is not None:
+ output.append(self.path_to_s3_dataset(transform_output))
+
+ return OperatorLineage(inputs=inputs, outputs=output)
+
+ def _get_model_data_urls(self, model_package_arn):
+ model_data_urls = []
+ try:
+ model_containers = self.hook.get_conn().describe_model_package(
+ ModelPackageName=model_package_arn
+ )["InferenceSpecification"]["Containers"]
+
+ for container in model_containers:
+ model_data_urls.append(container["ModelDataUrl"])
+ except Exception:
Review Comment:
Don't you want to catch the specific `KeyError` rather than everything, like
you did elsewhere?
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -576,6 +620,8 @@ def __init__(
Provided value: '{action_if_job_exists}'."
)
self.deferrable = deferrable
+ self.serialized_model: dict[Any, Any] | None = None
+ self.serialized_tranform: dict[Any, Any] | None = None
Review Comment:
```suggestion
self.serialized_model: dict | None = None
self.serialized_tranform: dict | None = None
```
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -222,6 +230,7 @@ def __init__(
self.max_attempts = max_attempts or 60
self.max_ingestion_time = max_ingestion_time
self.deferrable = deferrable
+ self.serialized_job: dict[Any, Any] | None = None
Review Comment:
```suggestion
self.serialized_job: dict | None = None
```
`Any, Any` doesn't bring any new information.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]