This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 6f956dc  Allow S3ToSnowflakeOperator to omit schema (#15817)
6f956dc is described below

commit 6f956dc99b6c6393f7b50e9da9f778b5cf0bef88
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu May 13 19:48:17 2021 +0800

    Allow S3ToSnowflakeOperator to omit schema (#15817)
    
    Fix #12001.
    
    As a drive-by cleanup, I also rewrote the text-building logic a bit to remove superfluous spaces and blank lines in the query string.
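    
    A minimal usage sketch of the new behaviour (task ID, table, stage, file
    format, and connection ID below are illustrative, not taken from this
    commit): with schema omitted, COPY targets whatever schema the Snowflake
    connection or session is configured with:
    
        from airflow.providers.snowflake.transfers.s3_to_snowflake import (
            S3ToSnowflakeOperator,
        )
        
        copy_task = S3ToSnowflakeOperator(
            task_id="copy_from_s3",      # illustrative task ID
            s3_keys=["1.csv", "2.csv"],  # optional; omit to load the whole prefix
            table="my_table",            # illustrative table name
            stage="my_stage",            # pre-created Snowflake stage
            file_format="(type=csv)",    # rendered as file_format=(type=csv)
            # schema is deliberately omitted; the session/user default applies
            snowflake_conn_id="snowflake_default",
        )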
---
 .../snowflake/transfers/s3_to_snowflake.py         | 42 +++++++++-------------
 .../snowflake/transfers/test_s3_to_snowflake.py    | 31 ++++++++--------
 2 files changed, 31 insertions(+), 42 deletions(-)

diff --git a/airflow/providers/snowflake/transfers/s3_to_snowflake.py b/airflow/providers/snowflake/transfers/s3_to_snowflake.py
index cebf546..be6e9a3 100644
--- a/airflow/providers/snowflake/transfers/s3_to_snowflake.py
+++ b/airflow/providers/snowflake/transfers/s3_to_snowflake.py
@@ -79,7 +79,7 @@ class S3ToSnowflakeOperator(BaseOperator):
         stage: str,
         prefix: Optional[str] = None,
         file_format: str,
-        schema: str,  # TODO: shouldn't be required, rely on session/user defaults
+        schema: Optional[str] = None,
         columns_array: Optional[list] = None,
         warehouse: Optional[str] = None,
         database: Optional[str] = None,
@@ -117,33 +117,23 @@ class S3ToSnowflakeOperator(BaseOperator):
             session_parameters=self.session_parameters,
         )
 
-        files = ""
-        if self.s3_keys:
-            files = "files=({})".format(", ".join(f"'{key}'" for key in self.s3_keys))
+        if self.schema:
+            into = f"{self.schema}.{self.table}"
+        else:
+            into = self.table
+        if self.columns_array:
+            into = f"{into}({','.join(self.columns_array)})"
 
-        # we can extend this based on stage
-        base_sql = """
-                    FROM @{stage}/{prefix}
-                    {files}
-                    file_format={file_format}
-                """.format(
-            stage=self.stage,
-            prefix=(self.prefix if self.prefix else ""),
-            files=files,
-            file_format=self.file_format,
-        )
+        sql_parts = [
+            f"COPY INTO {into}",
+            f"FROM @{self.stage}/{self.prefix or ''}",
+        ]
+        if self.s3_keys:
+            files = ", ".join(f"'{key}'" for key in self.s3_keys)
+            sql_parts.append(f"files=({files})")
+        sql_parts.append(f"file_format={self.file_format}")
 
-        if self.columns_array:
-            copy_query = """
-                COPY INTO {schema}.{table}({columns}) {base_sql}
-            """.format(
-                schema=self.schema, table=self.table, columns=",".join(self.columns_array), base_sql=base_sql
-            )
-        else:
-            copy_query = f"""
-                COPY INTO {self.schema}.{self.table} {base_sql}
-            """
-        copy_query = "\n".join(line.strip() for line in copy_query.splitlines())
+        copy_query = "\n".join(sql_parts)
 
         self.log.info('Executing COPY command...')
         snowflake_hook.run(copy_query, self.autocommit)
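
For reference, a standalone sketch of the rewritten query builder with sample
values (illustrative, mirroring the new execute() logic above):

    # Sample stand-ins for the operator's attributes.
    schema, table = None, "my_table"
    columns_array = ["col1", "col2"]
    stage, prefix = "my_stage", "data/"
    s3_keys = ["1.csv", "2.csv"]
    file_format = "(type=csv)"

    # Qualify the COPY INTO target with the schema only when one is given.
    into = f"{schema}.{table}" if schema else table
    if columns_array:
        into = f"{into}({','.join(columns_array)})"

    sql_parts = [
        f"COPY INTO {into}",
        f"FROM @{stage}/{prefix or ''}",
    ]
    if s3_keys:
        files = ", ".join(f"'{key}'" for key in s3_keys)
        sql_parts.append(f"files=({files})")
    sql_parts.append(f"file_format={file_format}")

    print("\n".join(sql_parts))
    # COPY INTO my_table(col1,col2)
    # FROM @my_stage/data/
    # files=('1.csv', '2.csv')
    # file_format=(type=csv)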
diff --git a/tests/providers/snowflake/transfers/test_s3_to_snowflake.py b/tests/providers/snowflake/transfers/test_s3_to_snowflake.py
index 37e6f3b..07defa1 100644
--- a/tests/providers/snowflake/transfers/test_s3_to_snowflake.py
+++ b/tests/providers/snowflake/transfers/test_s3_to_snowflake.py
@@ -27,12 +27,12 @@ class TestS3ToSnowflakeTransfer:
     @pytest.mark.parametrize("columns_array", [None, ['col1', 'col2', 'col3']])
     @pytest.mark.parametrize("s3_keys", [None, ['1.csv', '2.csv']])
     @pytest.mark.parametrize("prefix", [None, 'prefix'])
+    @pytest.mark.parametrize("schema", [None, 'schema'])
     @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.run")
-    def test_execute(self, mock_run, prefix, s3_keys, columns_array):
+    def test_execute(self, mock_run, schema, prefix, s3_keys, columns_array):
         table = 'table'
         stage = 'stage'
         file_format = 'file_format'
-        schema = 'schema'
 
         S3ToSnowflakeOperator(
             s3_keys=s3_keys,
@@ -46,22 +46,21 @@ class TestS3ToSnowflakeTransfer:
             dag=None,
         ).execute(None)
 
-        files = None
+        copy_query = "COPY INTO "
+        if schema:
+            copy_query += f"{schema}.{table}"
+        else:
+            copy_query += table
+        if columns_array:
+            copy_query += f"({','.join(columns_array)})"
+
+        copy_query += f"\nFROM @{stage}/{prefix or ''}"
+
         if s3_keys:
-            files = "files=({})".format(", ".join(f"'{key}'" for key in s3_keys))
-        base_sql = f"""
-                FROM @{stage}/{prefix if prefix else ''}
-                {files if files else ''}
-                file_format={file_format}
-            """
+            files = ", ".join(f"'{key}'" for key in s3_keys)
+            copy_query += f"\nfiles=({files})"
 
-        columns = None
-        if columns_array:
-            columns = f"({','.join(columns_array)})"
-        copy_query = f"""
-                COPY INTO {schema}.{table}{columns if columns else ''} {base_sql}
-            """
-        copy_query = "\n".join(line.strip() for line in copy_query.splitlines())
+        copy_query += f"\nfile_format={file_format}"
 
         mock_run.assert_called_once()
         assert mock_run.call_args[0][0] == copy_query
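
Stacked pytest.mark.parametrize decorators expand into the full cross-product
of their values, so the added schema axis doubles the test matrix from 8 to
16 cases. A minimal sketch of the mechanism (test name is illustrative):

    import pytest

    @pytest.mark.parametrize("schema", [None, 'schema'])
    @pytest.mark.parametrize("prefix", [None, 'prefix'])
    def test_axes(schema, prefix):
        # Runs four times, once per (schema, prefix) combination.
        assert schema in (None, 'schema')
        assert prefix in (None, 'prefix')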
