vandonr-amz commented on code in PR #31816:
URL: https://github.com/apache/airflow/pull/31816#discussion_r1224861901
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -42,9 +43,19 @@
def serialize(result: dict) -> str:
+ logging.getLogger(__name__).error(result)
Review Comment:
is this a personal debug log that you forgot to remove ?
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -275,14 +287,53 @@ def execute(self, context: Context) -> dict:
method_name="execute_complete",
)
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.processing_job = {
+ "Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ }
+ return self.processing_job
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.processing_job = {
+ "Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ }
+ return self.processing_job
+
+ def get_openlineage_facets_on_complete(self, task_instance):
Review Comment:
the `task_instance` param seems to be unused ? (I didn't look at the static
checks, but they may report this too ?)
Also, rather than making this a method that should only be called after
`execute` has set `self.processing_job`, which is a bit brittle, why not make
the dependency explicit by expecting the serialized job as a parameter ? Then
it wouldn't have to be stored in the class at all.
This could also be a class method, the only usage of `self` after that is
the log.
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -218,6 +229,7 @@ def __init__(
self.max_attempts = max_attempts or 60
self.max_ingestion_time = max_ingestion_time
self.deferrable = deferrable
+ self.processing_job: dict[Any, Any] | None = None
Review Comment:
I don't understand why this needs to be a dict when it seems to only ever
contain one key (`"Processing"`)
couldn't you replace this with a `serialized_job` ?
##########
tests/providers/amazon/aws/operators/test_sagemaker_training.py:
##########
@@ -127,3 +129,26 @@ def test_operator_defer(self, mock_training):
with pytest.raises(TaskDeferred) as exc:
self.sagemaker.execute(context=None)
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is
not a SagemakerTrigger"
+
+ @mock.patch.object(SageMakerHook, "describe_training_job")
+ @mock.patch.object(SageMakerHook, "create_training_job")
+ def test_execute_openlineage_data(self, mock_training, mock_desc):
+ mock_training.return_value = {
+ "TrainingJobArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ mock_desc.return_value = {
+ "InputDataConfig": [
+ {
+ "DataSource": {"S3DataSource": {"S3Uri":
"s3://input-bucket/input-path"}},
+ }
+ ],
+ "ModelArtifacts": {"S3ModelArtifacts":
"s3://model-bucket/model-path"},
+ }
+ self.sagemaker.check_if_job_exists = False
+ self.sagemaker._check_if_job_exists = mock.MagicMock()
Review Comment:
you could set this up in an annotation like the other 2
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -541,6 +592,7 @@ def __init__(
Provided value: '{action_if_job_exists}'."
)
self.deferrable = deferrable
+ self.transform_data: dict[Any, Any] | None = None
Review Comment:
same remarks as above
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -42,9 +43,19 @@
def serialize(result: dict) -> str:
+ logging.getLogger(__name__).error(result)
return json.loads(json.dumps(result, cls=AirflowJsonEncoder))
+def path_to_s3_dataset(path):
+ from openlineage.client.run import Dataset
+
+ path = path.replace("s3://", "")
+ namespace = path.split("/")[0]
+ name = "/".join(path.split("/")[1:])
Review Comment:
you could do the split only once to avoid repeating the work.
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -42,9 +43,19 @@
def serialize(result: dict) -> str:
+ logging.getLogger(__name__).error(result)
return json.loads(json.dumps(result, cls=AirflowJsonEncoder))
+def path_to_s3_dataset(path):
Review Comment:
you could tuck this as a static method in SageMakerBaseOperator
##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -275,14 +287,53 @@ def execute(self, context: Context) -> dict:
method_name="execute_complete",
)
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.processing_job = {
+ "Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ }
+ return self.processing_job
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])
- return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+ self.processing_job = {
+ "Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ }
+ return self.processing_job
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """Returns OpenLineage data gathered from SageMaker's API response
saved by processing job."""
+ from airflow.providers.openlineage.extractors.base import
OperatorLineage
+
+ inputs, outputs = [], []
+ try:
+ inputs, outputs = self._get_s3_datasets(
+
processing_inputs=self.processing_job["Processing"]["ProcessingInputs"],
+
processing_outputs=self.processing_job["Processing"]["ProcessingOutputConfig"]["Outputs"],
+ )
+ except KeyError:
+ self.log.exception("Could not find input/output information in
Xcom.")
+
+ return OperatorLineage(
+ inputs=inputs,
+ outputs=outputs,
+ )
+
+ def _get_s3_datasets(self, processing_inputs, processing_outputs):
Review Comment:
You might want to rename this method. With this name, it sounds like this is
doing S3 queries, when all it's doing is building objects to be used to get the
data later on.
##########
tests/providers/amazon/aws/operators/test_sagemaker_processing.py:
##########
@@ -255,3 +257,31 @@ def test_operator_defer(self, mock_job_exists,
mock_processing):
with pytest.raises(TaskDeferred) as exc:
sagemaker_operator.execute(context=None)
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is
not a SagemakerTrigger"
+
+ @mock.patch.object(SageMakerHook, "describe_processing_job")
+ @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name",
return_value=0)
+ @mock.patch.object(
+ SageMakerHook,
+ "create_processing_job",
+ return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata":
{"HTTPStatusCode": 200}},
+ )
+
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
Review Comment:
you could use `@mock.patch.object` here as well
##########
tests/providers/amazon/aws/operators/test_sagemaker_processing.py:
##########
@@ -255,3 +257,31 @@ def test_operator_defer(self, mock_job_exists,
mock_processing):
with pytest.raises(TaskDeferred) as exc:
sagemaker_operator.execute(context=None)
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is
not a SagemakerTrigger"
+
+ @mock.patch.object(SageMakerHook, "describe_processing_job")
+ @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name",
return_value=0)
+ @mock.patch.object(
+ SageMakerHook,
+ "create_processing_job",
+ return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata":
{"HTTPStatusCode": 200}},
+ )
+
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
+ def test_operator_openlineage_data(self, check_job_exists,
mock_processing, _, mock_desc):
+ check_job_exists.return_value = False
+ mock_desc.return_value = {
+ "ProcessingInputs": [{"S3Input": {"S3Uri":
"s3://input-bucket/input-path"}}],
+ "ProcessingOutputConfig": {
+ "Outputs": [{"S3Output": {"S3Uri":
"s3://output-bucket/output-path"}}]
+ },
+ }
Review Comment:
for consistency, you could put this in the annotation, or put the other
return values in the test body, but mixing the setup styles makes it harder to
read
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]