mobuchowski commented on code in PR #31816:
URL: https://github.com/apache/airflow/pull/31816#discussion_r1280779995
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -279,14 +288,49 @@ def execute(self, context: Context) -> dict:
method_name="execute_complete",
)
- return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])
- return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Returns OpenLineage data gathered from SageMaker's API response saved by processing job."""
+ from airflow.providers.openlineage.extractors.base import OperatorLineage
+
+ inputs, outputs = [], []
Review Comment:
Done.
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -658,11 +708,66 @@ def execute_complete(self, context, event=None):
else:
self.log.info(event["message"])
transform_config = self.config.get("Transform", self.config)
+ self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
+ self.serialized_tranform = serialize(
+ self.hook.describe_transform_job(transform_config["TransformJobName"])
+ )
return {
- "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+ "Model": self.serialized_model,
+ "Transform": self.serialized_tranform,
}
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Returns OpenLineage data gathered from SageMaker's API response saved by transform job."""
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ model_package_arn = None
+ transform_input = None
+ transform_output = None
+
+ try:
+ model_package_arn = self.serialized_model["PrimaryContainer"]["ModelPackageName"]
+ except KeyError:
+ self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+ try:
+ transform_input = self.serialized_tranform["TransformInput"]["DataSource"]["S3DataSource"][
+ "S3Uri"
+ ]
+ transform_output = self.serialized_tranform["TransformOutput"]["S3OutputPath"]
+ except KeyError:
+ self.log.error("Cannot find some required input/output details.", exc_info=True)
+
+ inputs = []
+
+ if transform_input is not None:
+ inputs.append(self.path_to_s3_dataset(transform_input))
+
+ if model_package_arn is not None:
+ model_data_urls = self._get_model_data_urls(model_package_arn)
+ for model_data_url in model_data_urls:
+ inputs.append(self.path_to_s3_dataset(model_data_url))
+
+ output = []
Review Comment:
Done.
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -279,14 +288,49 @@ def execute(self, context: Context) -> dict:
method_name="execute_complete",
)
- return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])
- return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Returns OpenLineage data gathered from SageMaker's API response saved by processing job."""
+ from airflow.providers.openlineage.extractors.base import OperatorLineage
+
+ inputs, outputs = [], []
+ try:
+ inputs, outputs = self._extract_s3_dataset_identifiers(
+ processing_inputs=self.serialized_job["ProcessingInputs"],
+ processing_outputs=self.serialized_job["ProcessingOutputConfig"]["Outputs"],
+ )
+ except KeyError:
+ self.log.exception("Could not find input/output information in Xcom.")
+
+ return OperatorLineage(
+ inputs=inputs,
+ outputs=outputs,
+ )
Review Comment:
Fixed.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]