This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 94257f48f4 Expose SQL to GCS Metadata (#24382)
94257f48f4 is described below

commit 94257f48f4a3f123918b0d55c34753c7c413eb74
Author: Peter Wicks <[email protected]>
AuthorDate: Mon Jun 13 00:55:12 2022 -0600

    Expose SQL to GCS Metadata (#24382)
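
    The change adds an upload_metadata flag to BaseSQLToGCSOperator: when it is
    enabled, each uploaded data file carries a row_count entry in its GCS blob
    metadata, and execute() now returns a summary dict (bucket, total_row_count,
    total_files, and a per-file list) that Airflow pushes to XCom.

    A minimal usage sketch of the new flag on one of the concrete subclasses;
    the DAG id, connection ids, bucket and query below are illustrative
    assumptions, not part of this commit:

        import pendulum
        from airflow import DAG
        from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

        with DAG(
            dag_id="example_sql_to_gcs_metadata",  # hypothetical DAG id
            start_date=pendulum.datetime(2022, 6, 1, tz="UTC"),
            schedule_interval=None,
            catchup=False,
        ) as dag:
            export = MySQLToGCSOperator(
                task_id="export_orders",
                sql="SELECT * FROM orders",            # hypothetical query
                bucket="my-export-bucket",             # hypothetical bucket
                filename="orders/export_{}.json",
                export_format="json",
                mysql_conn_id="mysql_default",
                gcp_conn_id="google_cloud_default",
                upload_metadata=True,  # new flag: attach row_count as blob metadata
            )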
---
 .../providers/google/cloud/transfers/sql_to_gcs.py |  42 ++++++-
 .../google/cloud/transfers/test_mssql_to_gcs.py    |   6 +-
 .../google/cloud/transfers/test_mysql_to_gcs.py    |  14 +--
 .../google/cloud/transfers/test_oracle_to_gcs.py   |   6 +-
 .../google/cloud/transfers/test_postgres_to_gcs.py |   6 +-
 .../google/cloud/transfers/test_presto_to_gcs.py   |  12 +-
 .../google/cloud/transfers/test_sql_to_gcs.py      | 127 +++++++++++++++++++--
 .../google/cloud/transfers/test_trino_to_gcs.py    |  12 +-
 8 files changed, 185 insertions(+), 40 deletions(-)

diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index 46e1ad505d..c204479024 100644
--- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -71,6 +71,7 @@ class BaseSQLToGCSOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param upload_metadata: whether to upload the row count metadata as blob metadata
     :param exclude_columns: set of columns to exclude from transmission
     """
 
@@ -104,6 +105,7 @@ class BaseSQLToGCSOperator(BaseOperator):
         gcp_conn_id: str = 'google_cloud_default',
         delegate_to: Optional[str] = None,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        upload_metadata: bool = False,
         exclude_columns=None,
         **kwargs,
     ) -> None:
@@ -125,6 +127,7 @@ class BaseSQLToGCSOperator(BaseOperator):
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
+        self.upload_metadata = upload_metadata
         self.exclude_columns = exclude_columns
 
     def execute(self, context: 'Context'):
@@ -144,6 +147,9 @@ class BaseSQLToGCSOperator(BaseOperator):
             schema_file['file_handle'].close()
 
         counter = 0
+        files = []
+        total_row_count = 0
+        total_files = 0
         self.log.info('Writing local data files')
         for file_to_upload in self._write_local_data_files(cursor):
             # Flush file before uploading
@@ -154,8 +160,29 @@ class BaseSQLToGCSOperator(BaseOperator):
 
             self.log.info('Removing local file')
             file_to_upload['file_handle'].close()
+
+            # Metadata to be outputted to Xcom
+            total_row_count += file_to_upload['file_row_count']
+            total_files += 1
+            files.append(
+                {
+                    'file_name': file_to_upload['file_name'],
+                    'file_mime_type': file_to_upload['file_mime_type'],
+                    'file_row_count': file_to_upload['file_row_count'],
+                }
+            )
+
             counter += 1
 
+        file_meta = {
+            'bucket': self.bucket,
+            'total_row_count': total_row_count,
+            'total_files': total_files,
+            'files': files,
+        }
+
+        return file_meta
+
     def convert_types(self, schema, col_type_dict, row, stringify_dict=False) -> list:
         """Convert values from DBAPI to output-friendly formats."""
         return [
@@ -188,6 +215,7 @@ class BaseSQLToGCSOperator(BaseOperator):
             'file_name': self.filename.format(file_no),
             'file_handle': tmp_file_handle,
             'file_mime_type': file_mime_type,
+            'file_row_count': 0,
         }
 
         if self.export_format == 'csv':
@@ -197,6 +225,7 @@ class BaseSQLToGCSOperator(BaseOperator):
             parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
 
         for row in cursor:
+            file_to_upload['file_row_count'] += 1
             if self.export_format == 'csv':
                 row = self.convert_types(schema, col_type_dict, row)
                 if self.null_marker is not None:
@@ -232,6 +261,7 @@ class BaseSQLToGCSOperator(BaseOperator):
                     'file_name': self.filename.format(file_no),
                     'file_handle': tmp_file_handle,
                     'file_mime_type': file_mime_type,
+                    'file_row_count': 0,
                 }
                 if self.export_format == 'csv':
                     csv_writer = self._configure_csv_file(tmp_file_handle, schema)
@@ -239,7 +269,9 @@ class BaseSQLToGCSOperator(BaseOperator):
                     parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
         if self.export_format == 'parquet':
             parquet_writer.close()
-        yield file_to_upload
+        # Last file may have 0 rows, don't yield if empty
+        if file_to_upload['file_row_count'] > 0:
+            yield file_to_upload
 
     def _configure_csv_file(self, file_handle, schema):
         """Configure a csv writer with the file_handle and write schema
@@ -350,10 +382,16 @@ class BaseSQLToGCSOperator(BaseOperator):
             delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
+        is_data_file = file_to_upload.get('file_name') != self.schema_filename
+        metadata = None
+        if is_data_file and self.upload_metadata:
+            metadata = {'row_count': file_to_upload['file_row_count']}
+
         hook.upload(
             self.bucket,
             file_to_upload.get('file_name'),
             file_to_upload.get('file_handle').name,
             mime_type=file_to_upload.get('file_mime_type'),
-            gzip=self.gzip if file_to_upload.get('file_name') != self.schema_filename else False,
+            gzip=self.gzip if is_data_file else False,
+            metadata=metadata,
         )
diff --git a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
index b388f4548c..8b9d820221 100644
--- a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
@@ -97,7 +97,7 @@ class TestMsSqlToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert JSON_FILENAME.format(0) == obj
             assert 'application/json' == mime_type
@@ -126,7 +126,7 @@ class TestMsSqlToGoogleCloudStorageOperator(unittest.TestCase):
             JSON_FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert 'application/json' == mime_type
             assert GZIP == gzip
@@ -154,7 +154,7 @@ class TestMsSqlToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, 'rb') as file:
                     assert b''.join(SCHEMA_JSON) == file.read()
diff --git a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
index c006c230d3..8d87ea9867 100644
--- a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
@@ -124,7 +124,7 @@ class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert JSON_FILENAME.format(0) == obj
             assert 'application/json' == mime_type
@@ -158,7 +158,7 @@ class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert CSV_FILENAME.format(0) == obj
             assert 'text/csv' == mime_type
@@ -193,7 +193,7 @@ class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert CSV_FILENAME.format(0) == obj
             assert 'text/csv' == mime_type
@@ -228,7 +228,7 @@ class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert CSV_FILENAME.format(0) == obj
             assert 'text/csv' == mime_type
@@ -257,7 +257,7 @@ class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase):
             JSON_FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert 'application/json' == mime_type
             assert not gzip
@@ -285,7 +285,7 @@ class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 assert not gzip
                 with open(tmp_filename, 'rb') as file:
@@ -311,7 +311,7 @@ class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 assert not gzip
                 with open(tmp_filename, 'rb') as file:
diff --git a/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py b/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
index a49c224c7a..b90510cbae 100644
--- a/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
@@ -70,7 +70,7 @@ class TestOracleToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert JSON_FILENAME.format(0) == obj
             assert 'application/json' == mime_type
@@ -99,7 +99,7 @@ class TestOracleToGoogleCloudStorageOperator(unittest.TestCase):
             JSON_FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert 'application/json' == mime_type
             assert GZIP == gzip
@@ -127,7 +127,7 @@ class TestOracleToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, 'rb') as file:
                     assert b''.join(SCHEMA_JSON) == file.read()
diff --git a/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py b/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py
index ff653292c4..e8007fc427 100644
--- a/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py
@@ -92,7 +92,7 @@ class TestPostgresToGoogleCloudStorageOperator(unittest.TestCase):
         assert op.bucket == BUCKET
         assert op.filename == FILENAME
 
-    def _assert_uploaded_file_content(self, bucket, obj, tmp_filename, mime_type, gzip):
+    def _assert_uploaded_file_content(self, bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
         assert BUCKET == bucket
         assert FILENAME.format(0) == obj
         assert 'application/json' == mime_type
@@ -159,7 +159,7 @@ class TestPostgresToGoogleCloudStorageOperator(unittest.TestCase):
             FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert 'application/json' == mime_type
             assert not gzip
@@ -183,7 +183,7 @@ class TestPostgresToGoogleCloudStorageOperator(unittest.TestCase):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, 'rb') as file:
                     assert SCHEMA_JSON == file.read()
diff --git a/tests/providers/google/cloud/transfers/test_presto_to_gcs.py b/tests/providers/google/cloud/transfers/test_presto_to_gcs.py
index 80a5a50386..46b76621f2 100644
--- a/tests/providers/google/cloud/transfers/test_presto_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_presto_to_gcs.py
@@ -65,7 +65,7 @@ class TestPrestoToGCSOperator(unittest.TestCase):
     @patch("airflow.providers.google.cloud.transfers.presto_to_gcs.PrestoHook")
     @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
     def test_save_as_json(self, mock_gcs_hook, mock_presto_hook):
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert FILENAME.format(0) == obj
             assert "application/json" == mime_type
@@ -120,7 +120,7 @@ class TestPrestoToGCSOperator(unittest.TestCase):
             FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert "application/json" == mime_type
             assert not gzip
@@ -160,7 +160,7 @@ class TestPrestoToGCSOperator(unittest.TestCase):
     def test_save_as_json_with_schema_file(self, mock_gcs_hook, mock_presto_hook):
         """Test writing schema files."""
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, "rb") as file:
                     assert SCHEMA_JSON == file.read()
@@ -199,7 +199,7 @@ class TestPrestoToGCSOperator(unittest.TestCase):
     @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
     @patch("airflow.providers.google.cloud.transfers.presto_to_gcs.PrestoHook")
     def test_save_as_csv(self, mock_presto_hook, mock_gcs_hook):
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert FILENAME.format(0) == obj
             assert "text/csv" == mime_type
@@ -255,7 +255,7 @@ class TestPrestoToGCSOperator(unittest.TestCase):
             FILENAME.format(1): b"".join([CSV_LINES[0], CSV_LINES[3]]),
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert "text/csv" == mime_type
             assert not gzip
@@ -296,7 +296,7 @@ class TestPrestoToGCSOperator(unittest.TestCase):
     def test_save_as_csv_with_schema_file(self, mock_gcs_hook, mock_presto_hook):
         """Test writing schema files."""
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, "rb") as file:
                     assert SCHEMA_JSON == file.read()
diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
index 824ab8ff31..918450e0e5 100644
--- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -127,8 +127,20 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
             gzip=True,
             schema=SCHEMA,
             gcp_conn_id='google_cloud_default',
+            upload_metadata=True,
         )
-        operator.execute(context=dict())
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 3,
+            'files': [
+                {'file_name': 'test_results_0.csv', 'file_mime_type': 'text/csv', 'file_row_count': 1},
+                {'file_name': 'test_results_1.csv', 'file_mime_type': 'text/csv', 'file_row_count': 1},
+                {'file_name': 'test_results_2.csv', 'file_mime_type': 'text/csv', 'file_row_count': 1},
+            ],
+        }
 
         mock_query.assert_called_once()
         mock_writerow.assert_has_calls(
@@ -142,16 +154,25 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
                 mock.call(COLUMNS),
             ]
         )
-        mock_flush.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call(), mock.call()])
+        mock_flush.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call()])
         csv_calls = []
         for i in range(0, 3):
             csv_calls.append(
-                mock.call(BUCKET, FILENAME.format(i), TMP_FILE_NAME, mime_type='text/csv', gzip=True)
+                mock.call(
+                    BUCKET,
+                    FILENAME.format(i),
+                    TMP_FILE_NAME,
+                    mime_type='text/csv',
+                    gzip=True,
+                    metadata={'row_count': 1},
+                )
             )
-        json_call = mock.call(BUCKET, SCHEMA_FILE, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
+        json_call = mock.call(
+            BUCKET, SCHEMA_FILE, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False, metadata=None
+        )
         upload_calls = [json_call, csv_calls[0], csv_calls[1], csv_calls[2]]
         mock_upload.assert_has_calls(upload_calls)
-        mock_close.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call(), mock.call()])
+        mock_close.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call()])
 
         mock_query.reset_mock()
         mock_flush.reset_mock()
@@ -165,7 +186,16 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         operator = DummySQLToGCSOperator(
             sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="json", schema=SCHEMA
         )
-        operator.execute(context=dict())
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 1,
+            'files': [
+                {'file_name': 'test_results_0.csv', 'file_mime_type': 'application/json', 'file_row_count': 3}
+            ],
+        }
 
         mock_query.assert_called_once()
         mock_write.assert_has_calls(
@@ -180,7 +210,59 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         )
         mock_flush.assert_called_once()
         mock_upload.assert_called_once_with(
-            BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type=APP_JSON, gzip=False
+            BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type=APP_JSON, gzip=False, metadata=None
+        )
+        mock_close.assert_called_once()
+
+        mock_query.reset_mock()
+        mock_flush.reset_mock()
+        mock_upload.reset_mock()
+        mock_close.reset_mock()
+        cursor_mock.reset_mock()
+
+        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
+        # Test Metadata Upload
+        operator = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            export_format="json",
+            schema=SCHEMA,
+            upload_metadata=True,
+        )
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 1,
+            'files': [
+                {'file_name': 'test_results_0.csv', 'file_mime_type': 'application/json', 'file_row_count': 3}
+            ],
+        }
+
+        mock_query.assert_called_once()
+        mock_write.assert_has_calls(
+            [
+                mock.call(OUTPUT_DATA),
+                mock.call(b"\n"),
+                mock.call(OUTPUT_DATA),
+                mock.call(b"\n"),
+                mock.call(OUTPUT_DATA),
+                mock.call(b"\n"),
+            ]
+        )
+
+        mock_flush.assert_called_once()
+        mock_upload.assert_called_once_with(
+            BUCKET,
+            FILENAME.format(0),
+            TMP_FILE_NAME,
+            mime_type=APP_JSON,
+            gzip=False,
+            metadata={'row_count': 3},
         )
         mock_close.assert_called_once()
 
@@ -196,12 +278,30 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         operator = DummySQLToGCSOperator(
             sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="parquet", schema=SCHEMA
         )
-        operator.execute(context=dict())
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 1,
+            'files': [
+                {
+                    'file_name': 'test_results_0.csv',
+                    'file_mime_type': 'application/octet-stream',
+                    'file_row_count': 3,
+                }
+            ],
+        }
 
         mock_query.assert_called_once()
         mock_flush.assert_called_once()
         mock_upload.assert_called_once_with(
-            BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
+            BUCKET,
+            FILENAME.format(0),
+            TMP_FILE_NAME,
+            mime_type='application/octet-stream',
+            gzip=False,
+            metadata=None,
         )
         mock_close.assert_called_once()
 
@@ -217,7 +317,14 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
             export_format="csv",
             null_marker="NULL",
         )
-        operator.execute(context=dict())
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 1,
+            'files': [{'file_name': 'test_results_0.csv', 'file_mime_type': 'text/csv', 'file_row_count': 3}],
+        }
 
         mock_writerow.assert_has_calls(
             [
diff --git a/tests/providers/google/cloud/transfers/test_trino_to_gcs.py b/tests/providers/google/cloud/transfers/test_trino_to_gcs.py
index 1e5443f679..50828a36ea 100644
--- a/tests/providers/google/cloud/transfers/test_trino_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_trino_to_gcs.py
@@ -65,7 +65,7 @@ class TestTrinoToGCSOperator(unittest.TestCase):
     @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
     @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
     def test_save_as_json(self, mock_gcs_hook, mock_trino_hook):
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert FILENAME.format(0) == obj
             assert "application/json" == mime_type
@@ -120,7 +120,7 @@ class TestTrinoToGCSOperator(unittest.TestCase):
             FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert "application/json" == mime_type
             assert not gzip
@@ -160,7 +160,7 @@ class TestTrinoToGCSOperator(unittest.TestCase):
     def test_save_as_json_with_schema_file(self, mock_gcs_hook, mock_trino_hook):
         """Test writing schema files."""
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, "rb") as file:
                     assert SCHEMA_JSON == file.read()
@@ -199,7 +199,7 @@ class TestTrinoToGCSOperator(unittest.TestCase):
     @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
     @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
     def test_save_as_csv(self, mock_trino_hook, mock_gcs_hook):
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert FILENAME.format(0) == obj
             assert "text/csv" == mime_type
@@ -255,7 +255,7 @@ class TestTrinoToGCSOperator(unittest.TestCase):
             FILENAME.format(1): b"".join([CSV_LINES[0], CSV_LINES[3]]),
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert "text/csv" == mime_type
             assert not gzip
@@ -296,7 +296,7 @@ class TestTrinoToGCSOperator(unittest.TestCase):
     def test_save_as_csv_with_schema_file(self, mock_gcs_hook, mock_trino_hook):
         """Test writing schema files."""
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, "rb") as file:
                     assert SCHEMA_JSON == file.read()
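
For reference, a sketch of how a downstream task could consume the metadata that
execute() now returns (Airflow pushes the returned dict to XCom as the task's
return value). The task function, ids and wiring below are illustrative
assumptions, not part of this commit:

    from airflow.decorators import task

    @task
    def report_export_stats(file_meta: dict) -> None:
        # file_meta mirrors the dict built in BaseSQLToGCSOperator.execute():
        # {'bucket': ..., 'total_row_count': ..., 'total_files': ..., 'files': [...]}
        print(
            f"Uploaded {file_meta['total_files']} file(s) with "
            f"{file_meta['total_row_count']} row(s) to gs://{file_meta['bucket']}"
        )
        for f in file_meta["files"]:
            print(f"  {f['file_name']}: {f['file_row_count']} row(s), {f['file_mime_type']}")

    # Wiring inside the DAG from the sketch in the commit message above:
    # report_export_stats(export.output)  # export.output is the operator's XCom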
