ferruzzi commented on code in PR #31816:
URL: https://github.com/apache/airflow/pull/31816#discussion_r1279713333
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -647,9 +693,13 @@ def execute(self, context: Context) -> dict:
method_name="execute_complete",
)
+ self.serialized_model =
serialize(self.hook.describe_model(transform_config["ModelName"]))
+ self.serialized_tranform = serialize(
+
self.hook.describe_transform_job(transform_config["TransformJobName"])
+ )
return {
- "Model":
serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform":
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+ "Model": self.serialized_model,
Review Comment:
Nitpick: (here and below) For consistency I'd suggest you drop the trailing
comma so this collapses to one line now that it can fit.
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -279,14 +288,49 @@ def execute(self, context: Context) -> dict:
method_name="execute_complete",
)
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job =
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job =
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Returns OpenLineage data gathered from SageMaker's API response
saved by processing job."""
+ from airflow.providers.openlineage.extractors.base import
OperatorLineage
+
+ inputs, outputs = [], []
+ try:
+ inputs, outputs = self._extract_s3_dataset_identifiers(
+ processing_inputs=self.serialized_job["ProcessingInputs"],
+
processing_outputs=self.serialized_job["ProcessingOutputConfig"]["Outputs"],
+ )
+ except KeyError:
+ self.log.exception("Could not find input/output information in
Xcom.")
+
+ return OperatorLineage(
+ inputs=inputs,
+ outputs=outputs,
+ )
Review Comment:
Dropping the trailing comma lets that stay on one line as you did elsewhere.
```suggestion
return OperatorLineage(inputs=inputs, outputs=outputs)
```
##########
tests/providers/amazon/aws/operators/test_sagemaker_processing.py:
##########
@@ -238,14 +243,16 @@ def test_action_if_job_exists_validation(self,
mock_client):
action_if_job_exists="not_fail_or_increment",
)
- @mock.patch.object(SageMakerHook, "create_processing_job")
-
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
- def test_operator_defer(self, mock_job_exists, mock_processing):
- mock_processing.return_value = {
+ @mock.patch.object(
+ SageMakerHook,
+ "create_processing_job",
+ return_value={
"ProcessingJobArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
- }
- mock_job_exists.return_value = False
+ },
+ )
+ @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists",
return_value=False)
Review Comment:
These are very nice, thanks for cleaning them up.
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -279,14 +288,49 @@ def execute(self, context: Context) -> dict:
method_name="execute_complete",
)
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job =
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.serialized_job =
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ return {"Processing": self.serialized_job}
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Returns OpenLineage data gathered from SageMaker's API response
saved by processing job."""
+ from airflow.providers.openlineage.extractors.base import
OperatorLineage
+
+ inputs, outputs = [], []
Review Comment:
Non-blocking nitpick: I know this works in Python, but I don't think
I've seen it used elsewhere in the Airflow code; the community style
preference seems to be to put each declaration on its own line.
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -658,11 +708,66 @@ def execute_complete(self, context, event=None):
else:
self.log.info(event["message"])
transform_config = self.config.get("Transform", self.config)
+ self.serialized_model =
serialize(self.hook.describe_model(transform_config["ModelName"]))
+ self.serialized_tranform = serialize(
+
self.hook.describe_transform_job(transform_config["TransformJobName"])
+ )
return {
- "Model":
serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform":
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+ "Model": self.serialized_model,
+ "Transform": self.serialized_tranform,
}
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Returns OpenLineage data gathered from SageMaker's API response
saved by transform job."""
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ model_package_arn = None
+ transform_input = None
+ transform_output = None
+
+ try:
+ model_package_arn =
self.serialized_model["PrimaryContainer"]["ModelPackageName"]
+ except KeyError:
+ self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+ try:
+ transform_input =
self.serialized_tranform["TransformInput"]["DataSource"]["S3DataSource"][
+ "S3Uri"
+ ]
+ transform_output =
self.serialized_tranform["TransformOutput"]["S3OutputPath"]
+ except KeyError:
+ self.log.error("Cannot find some required input/output details.",
exc_info=True)
+
+ inputs = []
+
+ if transform_input is not None:
+ inputs.append(self.path_to_s3_dataset(transform_input))
+
+ if model_package_arn is not None:
+ model_data_urls = self._get_model_data_urls(model_package_arn)
+ for model_data_url in model_data_urls:
+ inputs.append(self.path_to_s3_dataset(model_data_url))
+
+ output = []
Review Comment:
Nitpick: you used `outputs` above, and generally the name of a list is
pluralized as a hint. You may want to do the same here.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]