This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fcd1a26a9a Allow user-specified object attributes to be used in check_fn for S3KeySensor (#39950)
fcd1a26a9a is described below

commit fcd1a26a9a006ee8ee3ac023abb247f565d36e67
Author: ellisms <[email protected]>
AuthorDate: Fri May 31 10:49:07 2024 -0400

    Allow user-specified object attributes to be used in check_fn for S3KeySensor (#39950)
    
    * Ability to specify s3 object attributes for check_fn
    
    * removed unnecessary size check
    
    * Update airflow/providers/amazon/aws/sensors/s3.py
    
    Co-authored-by: Vincent <[email protected]>
    
    ---------
    
    Co-authored-by: Vincent <[email protected]>
---
 airflow/providers/amazon/aws/sensors/s3.py    |  40 +++++++-
 tests/providers/amazon/aws/sensors/test_s3.py | 141 ++++++++++++++++++++++++++
 2 files changed, 176 insertions(+), 5 deletions(-)
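
A minimal usage sketch of the new parameter (the bucket name, key, and task id
below are hypothetical, not taken from the commit):

    from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

    # Hypothetical check_fn: succeed only once every matched key is non-empty.
    # Each dict in `files` carries the attributes named in metadata_keys.
    def check_fn(files: list) -> bool:
        return all(f.get("Size") and f["Size"] > 0 for f in files)

    sensor = S3KeySensor(
        task_id="wait_for_key",       # hypothetical task id
        bucket_name="my-bucket",      # hypothetical bucket
        bucket_key="data.csv",        # hypothetical key
        metadata_keys=["Size", "LastModified"],
        check_fn=check_fn,
    )

Per the docstring change below, a requested attribute missing from the listing
response is fetched via head_object, and its value is None if still absent.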

diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index bb7105c3fb..adcdcbf010 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -78,6 +78,11 @@ class S3KeySensor(BaseSensorOperator):
                  CA cert bundle than the one used by botocore.
     :param deferrable: Run operator in the deferrable mode
     :param use_regex: whether to use regex to check bucket
+    :param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
+        Acceptable values: Any top level attribute returned by s3.head_object. Specify * to return
+        all available attributes.
+        Default value: "Size".
+        If the requested attribute is not found, the key is still included and the value is None.
     """
 
     template_fields: Sequence[str] = ("bucket_key", "bucket_name")
@@ -93,6 +98,7 @@ class S3KeySensor(BaseSensorOperator):
         verify: str | bool | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         use_regex: bool = False,
+        metadata_keys: list[str] | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -104,14 +110,14 @@ class S3KeySensor(BaseSensorOperator):
         self.verify = verify
         self.deferrable = deferrable
         self.use_regex = use_regex
+        self.metadata_keys = metadata_keys if metadata_keys else ["Size"]
 
     def _check_key(self, key):
         bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
         self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
 
         """
-        Set variable `files` which contains a list of dict which contains only the size
-        If needed we might want to add other attributes later
+        Set variable `files` which contains a list of dict which contains attributes defined by the user
         Format: [{
             'Size': int
         }]
@@ -123,8 +129,21 @@ class S3KeySensor(BaseSensorOperator):
             if not key_matches:
                 return False
 
-            # Reduce the set of metadata to size only
-            files = [{"Size": f["Size"]} for f in key_matches]
+            # Reduce the set of metadata to requested attributes
+            files = []
+            for f in key_matches:
+                metadata = {}
+                if "*" in self.metadata_keys:
+                    metadata = self.hook.head_object(f["Key"], bucket_name)
+                else:
+                    for key in self.metadata_keys:
+                        try:
+                            metadata[key] = f[key]
+                        except KeyError:
+                            # supplied key might be from head_object response
+                            self.log.info("Key %s not found in response, performing head_object", key)
+                            metadata[key] = self.hook.head_object(f["Key"], bucket_name).get(key, None)
+                files.append(metadata)
         elif self.use_regex:
             keys = self.hook.get_file_metadata("", bucket_name)
             key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
@@ -134,7 +153,18 @@ class S3KeySensor(BaseSensorOperator):
             obj = self.hook.head_object(key, bucket_name)
             if obj is None:
                 return False
-            files = [{"Size": obj["ContentLength"]}]
+            metadata = {}
+            if "*" in self.metadata_keys:
+                metadata = self.hook.head_object(key, bucket_name)
+
+            else:
+                for key in self.metadata_keys:
+                    # backwards compatibility with original implementation
+                    if key == "Size":
+                        metadata[key] = obj.get("ContentLength")
+                    else:
+                        metadata[key] = obj.get(key, None)
+            files = [metadata]
 
         if self.check_fn is not None:
             return self.check_fn(files)
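
Worth noting from the hunks above: passing "*" in metadata_keys forwards the
complete head_object response to check_fn. A short sketch of that mode (same
hypothetical names as the sketch above):

    # With metadata_keys=["*"], each dict mirrors the full head_object
    # response, so fields such as ETag and ContentLength become available.
    def check_etag(files: list) -> bool:
        return all("ETag" in f for f in files)

    sensor = S3KeySensor(
        task_id="wait_for_etag",      # hypothetical task id
        bucket_name="my-bucket",      # hypothetical bucket
        bucket_key="data.csv",        # hypothetical key
        metadata_keys=["*"],
        check_fn=check_etag,
    )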
diff --git a/tests/providers/amazon/aws/sensors/test_s3.py b/tests/providers/amazon/aws/sensors/test_s3.py
index 2fa2e458a9..fd70f7134a 100644
--- a/tests/providers/amazon/aws/sensors/test_s3.py
+++ b/tests/providers/amazon/aws/sensors/test_s3.py
@@ -22,10 +22,12 @@ from unittest import mock
 
 import pytest
 import time_machine
+from moto import mock_aws
 
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.models import DAG, DagRun, TaskInstance
 from airflow.models.variable import Variable
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor
 from airflow.utils import timezone
 
@@ -285,6 +287,145 @@ class TestS3KeySensor:
         with pytest.raises(expected_exception, match=message):
             op.execute_complete(context={}, event={"status": "error", "message": message})
 
+    @mock_aws
+    def test_custom_metadata_default_return_vals(self):
+        def check_fn(files: list) -> bool:
+            for f in files:
+                if "Size" not in f:
+                    return False
+            return True
+
+        hook = S3Hook()
+        hook.create_bucket(bucket_name="test-bucket")
+        hook.load_string(
+            bucket_name="test-bucket",
+            key="test-key",
+            string_data="test-body",
+        )
+
+        op = S3KeySensor(
+            task_id="test-metadata",
+            bucket_key="test-key",
+            bucket_name="test-bucket",
+            metadata_keys=["Size"],
+            check_fn=check_fn,
+        )
+        assert op.poke(None) is True
+        op = S3KeySensor(
+            task_id="test-metadata",
+            bucket_key="test-key",
+            bucket_name="test-bucket",
+            metadata_keys=["Content"],
+            check_fn=check_fn,
+        )
+        assert op.poke(None) is False
+
+        op = S3KeySensor(
+            task_id="test-metadata",
+            bucket_key="test-key",
+            bucket_name="test-bucket",
+            check_fn=check_fn,
+        )
+        assert op.poke(None) is True
+
+    @mock_aws
+    def test_custom_metadata_default_custom_vals(self):
+        def check_fn(files: list) -> bool:
+            for f in files:
+                if "LastModified" not in f or "ETag" not in f or "Size" in f:
+                    return False
+            return True
+
+        hook = S3Hook()
+        hook.create_bucket(bucket_name="test-bucket")
+        hook.load_string(
+            bucket_name="test-bucket",
+            key="test-key",
+            string_data="test-body",
+        )
+
+        op = S3KeySensor(
+            task_id="test-metadata",
+            bucket_key="test-key",
+            bucket_name="test-bucket",
+            metadata_keys=["LastModified", "ETag"],
+            check_fn=check_fn,
+        )
+        assert op.poke(None) is True
+
+    @mock_aws
+    def test_custom_metadata_all_attributes(self):
+        def check_fn(files: list) -> bool:
+            hook = S3Hook()
+            metadata_keys = set(hook.head_object(bucket_name="test-bucket", key="test-key").keys())
+            test_data_keys = set(files[0].keys())
+
+            return test_data_keys == metadata_keys
+
+        hook = S3Hook()
+        hook.create_bucket(bucket_name="test-bucket")
+        hook.load_string(
+            bucket_name="test-bucket",
+            key="test-key",
+            string_data="test-body",
+        )
+
+        op = S3KeySensor(
+            task_id="test-metadata",
+            bucket_key="test-key",
+            bucket_name="test-bucket",
+            metadata_keys=["*"],
+            check_fn=check_fn,
+        )
+        assert op.poke(None) is True
+
+    @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object")
+    @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata")
+    def test_custom_metadata_wildcard(self, mock_file_metadata, mock_head_object):
+        def check_fn(files: list) -> bool:
+            for f in files:
+                if "ETag" not in f or "MissingMeta" not in f:
+                    return False
+            return True
+
+        op = S3KeySensor(
+            task_id="test-head-metadata",
+            bucket_key=["s3://test-bucket/test-key*"],
+            metadata_keys=["MissingMeta", "ETag"],
+            check_fn=check_fn,
+            wildcard_match=True,
+        )
+
+        mock_file_metadata.return_value = [{"Key": "test-key", "ETag": 0}]
+        mock_head_object.return_value = {"MissingMeta": 0, "ContentLength": 100}
+        assert op.poke(None) is True
+        mock_head_object.assert_called_once()
+
+    @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object")
+    @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.get_file_metadata")
+    def test_custom_metadata_wildcard_all_attributes(self, mock_file_metadata, mock_head_object):
+        def check_fn(files: list) -> bool:
+            for f in files:
+                if "ContentLength" not in f or "MissingMeta" not in f:
+                    return False
+            return True
+
+        op = S3KeySensor(
+            task_id="test-head-metadata",
+            bucket_key=["s3://test-bucket/test-key*"],
+            metadata_keys=["*"],
+            check_fn=check_fn,
+            wildcard_match=True,
+        )
+
+        mock_file_metadata.return_value = [{"Key": "test-key", "ETag": 0}]
+        mock_head_object.return_value = {"MissingMeta": 0, "ContentLength": 100}
+        assert op.poke(None) is True
+        mock_head_object.assert_called_once()
+
+        mock_head_object.return_value = {"MissingMeta": 0}
+        assert op.poke(None) is False
+
 
 class TestS3KeysUnchangedSensor:
     def setup_method(self):
