vandonr-amz commented on code in PR #31816:
URL: https://github.com/apache/airflow/pull/31816#discussion_r1224861901


##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -42,9 +43,19 @@
 
 
 def serialize(result: dict) -> str:
+    logging.getLogger(__name__).error(result)

Review Comment:
   Is this a personal debug log that you forgot to remove?



##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -275,14 +287,53 @@ def execute(self, context: Context) -> dict:
                 method_name="execute_complete",
             )
 
-        return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.processing_job = {
+            "Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        }
+        return self.processing_job
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.processing_job = {
+            "Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        }
+        return self.processing_job
+
+    def get_openlineage_facets_on_complete(self, task_instance):

Review Comment:
   The `task_instance` param seems to be unused? (I didn't look at the static checks, but they may report this too.)
   
   Also, rather than making this a method that should only be called after `execute` has set `self.processing_job`, which is a bit brittle, why not make the dependency explicit by expecting the serialized job as a parameter? Then it wouldn't have to be stored on the class at all.
   This could also be a class method; the only remaining use of `self` would be the log call.
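   Roughly the shape I have in mind, as an untested sketch with placeholder names (a toy stand-in, not the real operator):

```python
class SageMakerProcessingOperatorSketch:
    """Toy stand-in for the operator, just to show the call shape."""

    def execute(self):
        # Placeholder for serialize(self.hook.describe_processing_job(...)).
        job = {"Processing": {"ProcessingJobName": "demo-job"}}
        # Hand the serialized job to the caller instead of stashing it on self.
        return job

    @classmethod
    def get_openlineage_facets(cls, processing_job: dict) -> str:
        # The dependency on the serialized job is explicit, so this no longer
        # relies on execute() having run first on the same instance.
        return processing_job["Processing"]["ProcessingJobName"]
```

   The caller already holds the return value of `execute`, so it can pass it straight in.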



##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -218,6 +229,7 @@ def __init__(
         self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
         self.deferrable = deferrable
+        self.processing_job: dict[Any, Any] | None = None

Review Comment:
   I don't understand why this needs to be a dict when it seems to only ever contain one key (`"Processing"`).
   Couldn't you replace this with a `serialized_job`?



##########
tests/providers/amazon/aws/operators/test_sagemaker_training.py:
##########
@@ -127,3 +129,26 @@ def test_operator_defer(self, mock_training):
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+    @mock.patch.object(SageMakerHook, "describe_training_job")
+    @mock.patch.object(SageMakerHook, "create_training_job")
+    def test_execute_openlineage_data(self, mock_training, mock_desc):
+        mock_training.return_value = {
+            "TrainingJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        mock_desc.return_value = {
+            "InputDataConfig": [
+                {
+                    "DataSource": {"S3DataSource": {"S3Uri": "s3://input-bucket/input-path"}},
+                }
+            ],
+            "ModelArtifacts": {"S3ModelArtifacts": "s3://model-bucket/model-path"},
+        }
+        self.sagemaker.check_if_job_exists = False
+        self.sagemaker._check_if_job_exists = mock.MagicMock()

Review Comment:
   You could set this up in a decorator like the other two mocks.



##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -541,6 +592,7 @@ def __init__(
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.transform_data: dict[Any, Any] | None = None

Review Comment:
   Same remarks as above.



##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -42,9 +43,19 @@
 
 
 def serialize(result: dict) -> str:
+    logging.getLogger(__name__).error(result)
     return json.loads(json.dumps(result, cls=AirflowJsonEncoder))
 
 
+def path_to_s3_dataset(path):
+    from openlineage.client.run import Dataset
+
+    path = path.replace("s3://", "")
+    namespace = path.split("/")[0]
+    name = "/".join(path.split("/")[1:])

Review Comment:
   You could do the split only once to avoid repeating the work.
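   For example (hypothetical helper name, just to show the single split):

```python
def split_s3_uri(uri: str) -> tuple[str, str]:
    """Split an s3:// URI into (bucket, key) with a single partition call."""
    # partition splits on the first "/" only, so the key keeps its own slashes.
    bucket, _, key = uri.replace("s3://", "", 1).partition("/")
    return bucket, key
```

   `partition` always returns three parts, so a bucket-only path simply yields an empty key instead of raising.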



##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -42,9 +43,19 @@
 
 
 def serialize(result: dict) -> str:
+    logging.getLogger(__name__).error(result)
     return json.loads(json.dumps(result, cls=AirflowJsonEncoder))
 
 
+def path_to_s3_dataset(path):

Review Comment:
   You could tuck this away as a static method on `SageMakerBaseOperator`.



##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -275,14 +287,53 @@ def execute(self, context: Context) -> dict:
                 method_name="execute_complete",
             )
 
-        return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.processing_job = {
+            "Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        }
+        return self.processing_job
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.processing_job = {
+            "Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        }
+        return self.processing_job
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Returns OpenLineage data gathered from SageMaker's API response saved by processing job."""
+        from airflow.providers.openlineage.extractors.base import OperatorLineage
+
+        inputs, outputs = [], []
+        try:
+            inputs, outputs = self._get_s3_datasets(
+                processing_inputs=self.processing_job["Processing"]["ProcessingInputs"],
+                processing_outputs=self.processing_job["Processing"]["ProcessingOutputConfig"]["Outputs"],
+            )
+        except KeyError:
+            self.log.exception("Could not find input/output information in Xcom.")
+
+        return OperatorLineage(
+            inputs=inputs,
+            outputs=outputs,
+        )
+
+    def _get_s3_datasets(self, processing_inputs, processing_outputs):

Review Comment:
   You might want to rename this method. With this name, it sounds like it is doing S3 queries, when all it's doing is building objects to be used to fetch the data later on.



##########
tests/providers/amazon/aws/operators/test_sagemaker_processing.py:
##########
@@ -255,3 +257,31 @@ def test_operator_defer(self, mock_job_exists, mock_processing):
         with pytest.raises(TaskDeferred) as exc:
             sagemaker_operator.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+    @mock.patch.object(SageMakerHook, "describe_processing_job")
+    @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
+    )
+    @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")

Review Comment:
   You could use `@mock.patch.object` here as well.
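   A self-contained illustration of the `patch.object` form, with a stub class standing in for `SageMakerBaseOperator`:

```python
from unittest import mock


class OperatorStub:
    """Stub standing in for SageMakerBaseOperator, for illustration only."""

    def _check_if_job_exists(self):
        return True


@mock.patch.object(OperatorStub, "_check_if_job_exists", return_value=False)
def run_patched(mock_exists):
    # patch.object references the class attribute directly, so typos and
    # renames are easier to catch than with a dotted-path string.
    return OperatorStub()._check_if_job_exists()
```

   Outside the patched function the original method is untouched.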



##########
tests/providers/amazon/aws/operators/test_sagemaker_processing.py:
##########
@@ -255,3 +257,31 @@ def test_operator_defer(self, mock_job_exists, mock_processing):
         with pytest.raises(TaskDeferred) as exc:
             sagemaker_operator.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+    @mock.patch.object(SageMakerHook, "describe_processing_job")
+    @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
+    )
+    @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
+    def test_operator_openlineage_data(self, check_job_exists, mock_processing, _, mock_desc):
+        check_job_exists.return_value = False
+        mock_desc.return_value = {
+            "ProcessingInputs": [{"S3Input": {"S3Uri": "s3://input-bucket/input-path"}}],
+            "ProcessingOutputConfig": {
+                "Outputs": [{"S3Output": {"S3Uri": "s3://output-bucket/output-path"}}]
+            },
+        }

Review Comment:
   For consistency, you could put this in the decorator, or put the other return values in the test body; mixing the setup styles makes it harder to read.
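   e.g. moving the response into the decorator (stub hook class and illustrative values only, not the real `SageMakerHook`):

```python
from unittest import mock


class HookStub:
    """Stub standing in for SageMakerHook, for illustration only."""

    def describe_processing_job(self, name):
        raise RuntimeError("would hit the AWS API")


@mock.patch.object(
    HookStub,
    "describe_processing_job",
    return_value={"ProcessingInputs": [{"S3Input": {"S3Uri": "s3://input-bucket/input-path"}}]},
)
def run_test(mock_desc):
    # All mock setup lives in the decorator; the test body only exercises code.
    return HookStub().describe_processing_job("job-name")
```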



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
