This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 6f956dc Allow S3ToSnowflakeOperator to omit schema (#15817)
6f956dc is described below
commit 6f956dc99b6c6393f7b50e9da9f778b5cf0bef88
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu May 13 19:48:17 2021 +0800
Allow S3ToSnowflakeOperator to omit schema (#15817)
Fix #12001.
As a drive-by cleanup, I also rewrote the text-building logic a bit to
remove superfluous spaces and blank lines in the query string.
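For illustration, a minimal usage sketch after this change (the task id, key,
table, stage, and file format values below are hypothetical; the point is only
that schema may now be omitted):

    S3ToSnowflakeOperator(
        task_id="s3_to_snowflake",        # hypothetical task id
        s3_keys=["data.csv"],             # hypothetical key
        table="my_table",
        stage="my_stage",
        file_format="(type = 'CSV')",
        # schema is now optional; the session/user default schema is used
        snowflake_conn_id="snowflake_default",
    )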
---
.../snowflake/transfers/s3_to_snowflake.py | 42 +++++++++-------------
.../snowflake/transfers/test_s3_to_snowflake.py | 31 ++++++++--------
2 files changed, 31 insertions(+), 42 deletions(-)
diff --git a/airflow/providers/snowflake/transfers/s3_to_snowflake.py b/airflow/providers/snowflake/transfers/s3_to_snowflake.py
index cebf546..be6e9a3 100644
--- a/airflow/providers/snowflake/transfers/s3_to_snowflake.py
+++ b/airflow/providers/snowflake/transfers/s3_to_snowflake.py
@@ -79,7 +79,7 @@ class S3ToSnowflakeOperator(BaseOperator):
         stage: str,
         prefix: Optional[str] = None,
         file_format: str,
-        schema: str,  # TODO: shouldn't be required, rely on session/user defaults
+        schema: Optional[str] = None,
         columns_array: Optional[list] = None,
         warehouse: Optional[str] = None,
         database: Optional[str] = None,
@@ -117,33 +117,23 @@ class S3ToSnowflakeOperator(BaseOperator):
             session_parameters=self.session_parameters,
         )
 
-        files = ""
-        if self.s3_keys:
-            files = "files=({})".format(", ".join(f"'{key}'" for key in self.s3_keys))
+        if self.schema:
+            into = f"{self.schema}.{self.table}"
+        else:
+            into = self.table
+        if self.columns_array:
+            into = f"{into}({','.join(self.columns_array)})"
 
-        # we can extend this based on stage
-        base_sql = """
-            FROM @{stage}/{prefix}
-            {files}
-            file_format={file_format}
-            """.format(
-            stage=self.stage,
-            prefix=(self.prefix if self.prefix else ""),
-            files=files,
-            file_format=self.file_format,
-        )
+        sql_parts = [
+            f"COPY INTO {into}",
+            f"FROM @{self.stage}/{self.prefix or ''}",
+        ]
+        if self.s3_keys:
+            files = ", ".join(f"'{key}'" for key in self.s3_keys)
+            sql_parts.append(f"files=({files})")
+        sql_parts.append(f"file_format={self.file_format}")
 
-        if self.columns_array:
-            copy_query = """
-                COPY INTO {schema}.{table}({columns}) {base_sql}
-                """.format(
-                schema=self.schema, table=self.table, columns=",".join(self.columns_array), base_sql=base_sql
-            )
-        else:
-            copy_query = f"""
-                COPY INTO {self.schema}.{self.table} {base_sql}
-                """
-        copy_query = "\n".join(line.strip() for line in copy_query.splitlines())
+        copy_query = "\n".join(sql_parts)
 
         self.log.info('Executing COPY command...')
         snowflake_hook.run(copy_query, self.autocommit)
diff --git a/tests/providers/snowflake/transfers/test_s3_to_snowflake.py b/tests/providers/snowflake/transfers/test_s3_to_snowflake.py
index 37e6f3b..07defa1 100644
--- a/tests/providers/snowflake/transfers/test_s3_to_snowflake.py
+++ b/tests/providers/snowflake/transfers/test_s3_to_snowflake.py
@@ -27,12 +27,12 @@ class TestS3ToSnowflakeTransfer:
     @pytest.mark.parametrize("columns_array", [None, ['col1', 'col2', 'col3']])
     @pytest.mark.parametrize("s3_keys", [None, ['1.csv', '2.csv']])
     @pytest.mark.parametrize("prefix", [None, 'prefix'])
+    @pytest.mark.parametrize("schema", [None, 'schema'])
     @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.run")
-    def test_execute(self, mock_run, prefix, s3_keys, columns_array):
+    def test_execute(self, mock_run, schema, prefix, s3_keys, columns_array):
         table = 'table'
         stage = 'stage'
         file_format = 'file_format'
-        schema = 'schema'
 
         S3ToSnowflakeOperator(
             s3_keys=s3_keys,
@@ -46,22 +46,21 @@ class TestS3ToSnowflakeTransfer:
             dag=None,
         ).execute(None)
 
-        files = None
+        copy_query = "COPY INTO "
+        if schema:
+            copy_query += f"{schema}.{table}"
+        else:
+            copy_query += table
+        if columns_array:
+            copy_query += f"({','.join(columns_array)})"
+
+        copy_query += f"\nFROM @{stage}/{prefix or ''}"
+
         if s3_keys:
-            files = "files=({})".format(", ".join(f"'{key}'" for key in s3_keys))
-        base_sql = f"""
-            FROM @{stage}/{prefix if prefix else ''}
-            {files if files else ''}
-            file_format={file_format}
-            """
+            files = ", ".join(f"'{key}'" for key in s3_keys)
+            copy_query += f"\nfiles=({files})"
 
-        columns = None
-        if columns_array:
-            columns = f"({','.join(columns_array)})"
-        copy_query = f"""
-            COPY INTO {schema}.{table}{columns if columns else ''} {base_sql}
-            """
-        copy_query = "\n".join(line.strip() for line in copy_query.splitlines())
+        copy_query += f"\nfile_format={file_format}"
 
         mock_run.assert_called_once()
         assert mock_run.call_args[0][0] == copy_query
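For reference, with schema='schema', prefix='prefix', s3_keys=['1.csv', '2.csv']
and columns_array=['col1', 'col2', 'col3'], the expected query assembled in the
test above reads:

    COPY INTO schema.table(col1,col2,col3)
    FROM @stage/prefix
    files=('1.csv', '2.csv')
    file_format=file_format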