This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 50a68c6c4e Append region info to S3ToRedshiftOperator if present
(#32328)
50a68c6c4e is described below
commit 50a68c6c4ecb0a45272be7df7939ded6f28cf2f9
Author: Richard Goodman <[email protected]>
AuthorDate: Tue Jul 11 23:06:25 2023 +0100
Append region info to S3ToRedshiftOperator if present (#32328)
* Append region info to S3ToRedshiftOperator if present
It's possible to copy from S3 into Redshift across different regions;
however, you are currently unable to do so with the
S3ToRedshiftOperator. This PR makes that possible by checking
whether the AWS connection passed in has the region set in the extras
part of the connection config. If it is set, it will be used, in line
with the syntax defined
[here](https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html)
* Update tests to make assertion checking valid
Following on from the discussion in the PR, the way the assertion is
currently done is redundant, as it asserts static variables == other
static variables. Instead, the query is now compared to what gets
generated by the `_build_copy_query` function; this has been reflected
for all applicable test cases in this file.
---
.../amazon/aws/transfers/s3_to_redshift.py | 13 ++-
.../amazon/aws/transfers/test_s3_to_redshift.py | 116 ++++++++++++++++-----
2 files changed, 98 insertions(+), 31 deletions(-)
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 0d2a059f6e..b42b2d8cbb 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -119,13 +119,16 @@ class S3ToRedshiftOperator(BaseOperator):
if arg in self.redshift_data_api_kwargs.keys():
raise AirflowException(f"Cannot include param '{arg}' in
Redshift Data API kwargs")
- def _build_copy_query(self, copy_destination: str, credentials_block: str,
copy_options: str) -> str:
+ def _build_copy_query(
+ self, copy_destination: str, credentials_block: str, region_info: str,
copy_options: str
+ ) -> str:
column_names = "(" + ", ".join(self.column_list) + ")" if
self.column_list else ""
return f"""
COPY {copy_destination} {column_names}
FROM 's3://{self.s3_bucket}/{self.s3_key}'
credentials
'{credentials_block}'
+ {region_info}
{copy_options};
"""
@@ -139,7 +142,9 @@ class S3ToRedshiftOperator(BaseOperator):
else:
redshift_hook =
RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
-
+ region_info = ""
+ if conn.extra_dejson.get("region", False):
+ region_info = f"region '{conn.extra_dejson['region']}'"
if conn.extra_dejson.get("role_arn", False):
credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
else:
@@ -151,7 +156,9 @@ class S3ToRedshiftOperator(BaseOperator):
destination = f"{self.schema}.{self.table}"
copy_destination = f"#{self.table}" if self.method == "UPSERT" else
destination
- copy_statement = self._build_copy_query(copy_destination,
credentials_block, copy_options)
+ copy_statement = self._build_copy_query(
+ copy_destination, credentials_block, region_info, copy_options
+ )
sql: str | Iterable[str]
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index 33d38b94e7..f73acf661e 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -63,17 +63,19 @@ class TestS3ToRedshiftTransfer:
dag=None,
)
op.execute(None)
- copy_query = """
+ expected_copy_query = """
COPY schema.table
FROM 's3://bucket/key'
credentials
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
;
"""
+ actual_copy_query = mock_run.call_args.args[0]
+
assert mock_run.call_count == 1
- assert access_key in copy_query
- assert secret_key in copy_query
- assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0],
copy_query)
+ assert access_key in actual_copy_query
+ assert secret_key in actual_copy_query
+ assert_equal_ignore_multiple_spaces(actual_copy_query,
expected_copy_query)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection")
@@ -110,17 +112,19 @@ class TestS3ToRedshiftTransfer:
dag=None,
)
op.execute(None)
- copy_query = """
+ expected_copy_query = """
COPY schema.table (column_1, column_2)
FROM 's3://bucket/key'
credentials
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
;
"""
+ actual_copy_query = mock_run.call_args.args[0]
+
assert mock_run.call_count == 1
- assert access_key in copy_query
- assert secret_key in copy_query
- assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0],
copy_query)
+ assert access_key in actual_copy_query
+ assert secret_key in actual_copy_query
+ assert_equal_ignore_multiple_spaces(actual_copy_query,
expected_copy_query)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection")
@@ -263,18 +267,20 @@ class TestS3ToRedshiftTransfer:
dag=None,
)
op.execute(None)
- copy_statement = """
+ expected_copy_query = """
COPY schema.table
FROM 's3://bucket/key'
credentials
'aws_access_key_id=ASIA_aws_access_key_id;aws_secret_access_key=aws_secret_access_key;token=aws_secret_token'
;
"""
- assert access_key in copy_statement
- assert secret_key in copy_statement
- assert token in copy_statement
+ actual_copy_query = mock_run.call_args.args[0]
+
+ assert access_key in actual_copy_query
+ assert secret_key in actual_copy_query
+ assert token in actual_copy_query
assert mock_run.call_count == 1
- assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0],
copy_statement)
+ assert_equal_ignore_multiple_spaces(actual_copy_query,
expected_copy_query)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection")
@@ -312,17 +318,68 @@ class TestS3ToRedshiftTransfer:
dag=None,
)
op.execute(None)
- copy_statement = """
+ expected_copy_query = """
COPY schema.table
FROM 's3://bucket/key'
credentials
'aws_iam_role=arn:aws:iam::112233445566:role/myRole'
;
"""
+ actual_copy_query = mock_run.call_args.args[0]
+
+ assert extra["role_arn"] in actual_copy_query
+ assert mock_run.call_count == 1
+ assert_equal_ignore_multiple_spaces(actual_copy_query,
expected_copy_query)
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
+ @mock.patch("boto3.session.Session")
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+ def test_different_region(self, mock_run, mock_session, mock_connection,
mock_hook):
+ access_key = "aws_access_key_id"
+ secret_key = "aws_secret_access_key"
+ extra = {"region": "eu-central-1"}
+ mock_session.return_value = Session(access_key, secret_key)
+ mock_session.return_value.access_key = access_key
+ mock_session.return_value.secret_key = secret_key
+ mock_session.return_value.token = None
+
+ mock_connection.return_value = Connection(extra=extra)
+ mock_hook.return_value = Connection(extra=extra)
+
+ schema = "schema"
+ table = "table"
+ s3_bucket = "bucket"
+ s3_key = "key"
+ copy_options = ""
+
+ op = S3ToRedshiftOperator(
+ schema=schema,
+ table=table,
+ s3_bucket=s3_bucket,
+ s3_key=s3_key,
+ copy_options=copy_options,
+ redshift_conn_id="redshift_conn_id",
+ aws_conn_id="aws_conn_id",
+ task_id="task_id",
+ dag=None,
+ )
+ op.execute(None)
+ expected_copy_query = """
+ COPY schema.table
+ FROM 's3://bucket/key'
+ credentials
+
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+ region 'eu-central-1'
+ ;
+ """
+ actual_copy_query = mock_run.call_args.args[0]
- assert extra["role_arn"] in copy_statement
+ assert access_key in actual_copy_query
+ assert secret_key in actual_copy_query
+ assert extra["region"] in actual_copy_query
assert mock_run.call_count == 1
- assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0],
copy_statement)
+ assert_equal_ignore_multiple_spaces(actual_copy_query,
expected_copy_query)
def test_template_fields_overrides(self):
assert S3ToRedshiftOperator.template_fields == (
@@ -420,19 +477,10 @@ class TestS3ToRedshiftTransfer:
),
)
op.execute(None)
- copy_query = """
- COPY schema.table
- FROM 's3://bucket/key'
- credentials
-
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
- ;
- """
+
mock_run.assert_not_called()
- assert access_key in copy_query
- assert secret_key in copy_query
mock_rs.execute_statement.assert_called_once()
- # test with all args besides sql
_call = deepcopy(mock_rs.execute_statement.call_args.kwargs)
_call.pop("Sql")
assert _call == dict(
@@ -443,8 +491,20 @@ class TestS3ToRedshiftTransfer:
StatementName=statement_name,
WithEvent=False,
)
+
+ expected_copy_query = """
+ COPY schema.table
+ FROM 's3://bucket/key'
+ credentials
+
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
+ ;
+ """
+ actual_copy_query = mock_rs.execute_statement.call_args.kwargs["Sql"]
+
mock_rs.describe_statement.assert_called_once_with(
Id="STATEMENT_ID",
)
- # test sql arg
-
assert_equal_ignore_multiple_spaces(mock_rs.execute_statement.call_args.kwargs["Sql"],
copy_query)
+
+ assert access_key in actual_copy_query
+ assert secret_key in actual_copy_query
+ assert_equal_ignore_multiple_spaces(actual_copy_query,
expected_copy_query)