vincbeck commented on code in PR #28964:
URL: https://github.com/apache/airflow/pull/28964#discussion_r1080580457
########## airflow/providers/amazon/aws/transfers/s3_to_sql.py: ##########
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from tempfile import NamedTemporaryFile
+from typing import TYPE_CHECKING, Callable, Iterable, Sequence
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class S3ToSqlOperator(BaseOperator):
+    """
+    Loads Data from S3 into a SQL Database.
+    You need to provide a parser function that takes a filename as an input
+    and returns a iterable of rows

Review Comment:
   ```suggestion
       and returns an iterable of rows
   ```

########## tests/providers/amazon/aws/transfers/test_s3_to_sql.py: ##########
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from sqlalchemy import or_
+
+from airflow import configuration, models
+from airflow.providers.amazon.aws.transfers.s3_to_sql import S3ToSqlOperator
+from airflow.utils import db
+from airflow.utils.session import create_session
+
+
+class TestS3ToSqlTransfer:
+    def setup_method(self):
+        configuration.conf.load_test_config()
+
+        db.merge_conn(
+            models.Connection(
+                conn_id="s3_test",
+                conn_type="s3",
+                schema="test",
+                extra='{"aws_access_key_id": "aws_access_key_id", "aws_secret_access_key":'
+                ' "aws_secret_access_key"}',
+            )
+        )
+        db.merge_conn(
+            models.Connection(
+                conn_id="sql_test",
+                conn_type="postgres",
+                host="some.host.com",
+                schema="test_db",
+                login="user",
+                password="password",
+            )
+        )
+
+        self.s3_to_sql_transfer_kwargs = {
+            "task_id": "s3_to_sql_task",
+            "aws_conn_id": "s3_test",
+            "sql_conn_id": "sql_test",
+            "s3_key": "test/test.csv",
+            "s3_bucket": "testbucket",
+            "table": "sql_table",
+            "column_list": ["Column1", "Column2"],
+            "schema": "sql_schema",
+            "commit_every": 5000,
+        }
+
+    @pytest.fixture()
+    def mock_parser(self):
+        return MagicMock()
+
+    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
+    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.DbApiHook.insert_rows")
+    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
+    def test_execute(self, mock_get_key, mock_insert_rows, mock_tempfile, mock_parser):

Review Comment:
   For the sake of better coverage, could you add a test for the case where the hook does not have the `insert_rows` method defined?
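For reference, a minimal sketch of such a test could look like the block below. It is illustrative only: it assumes the operator resolves the SQL hook via `BaseHook.get_connection(...).get_hook()` and raises `AirflowException` when the returned hook exposes no `insert_rows`, and it reuses the `s3_to_sql_transfer_kwargs` and `mock_parser` fixture defined above. The test name and patch targets are assumptions, not part of the PR.

```python
# Hypothetical sketch; requires: from airflow.exceptions import AirflowException

    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook.get_connection")
    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
    def test_execute_with_hook_missing_insert_rows(
        self, mock_get_key, mock_get_connection, mock_tempfile, mock_parser
    ):
        # The connection resolves to a hook object with no attributes at all,
        # in particular no `insert_rows`, so the operator is expected to fail fast.
        mock_get_connection.return_value.get_hook.return_value = MagicMock(spec=[])

        with pytest.raises(AirflowException):
            S3ToSqlOperator(
                parser=mock_parser,
                **self.s3_to_sql_transfer_kwargs,
            ).execute({})
```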
########## airflow/providers/amazon/aws/transfers/s3_to_sql.py: ##########
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from tempfile import NamedTemporaryFile
+from typing import TYPE_CHECKING, Callable, Iterable, Sequence
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class S3ToSqlOperator(BaseOperator):
+    """
+    Loads Data from S3 into a SQL Database.
+    You need to provide a parser function that takes a filename as an input
+    and returns a iterable of rows
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:S3ToSqlOperator`
+
+    :param schema: reference to a specific schema in SQL database
+    :param table: reference to a specific table in SQL database
+    :param s3_bucket: reference to a specific S3 bucket
+    :param s3_key: reference to a specific S3 key
+    :param sql_conn_id: reference to a specific SQL database. Must be of type DBApiHook
+    :param aws_conn_id: reference to a specific S3 / AWS connection
+    :param column_list: list of column names to use in the insert SQL.
+    :param commit_every: The maximum number of rows to insert in one
+        transaction. Set to `0` to insert all rows in one transaction.
+    :param parser: parser function that takes a filepath as input and returns an iterable.
+        e.g. to use a CSV parser that yields rows line-by-line, pass the following
+        function:
+
+        def parse_csv(filepath):

Review Comment:
   Love it!

########## tests/system/providers/amazon/aws/example_s3_to_sql.py: ##########
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.exceptions import AirflowException
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.s3 import (
+    S3CreateBucketOperator,
+    S3DeleteBucketOperator,
+    S3DeleteObjectsOperator,
+    S3CreateObjectOperator,
+)
+from airflow.providers.amazon.aws.transfers.s3_to_sql import S3ToSqlOperator
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder, ENV_ID_KEY
+import os
+
+sys_test_context_task = SystemTestContextBuilder().build()
+
+DAG_ID = "example_s3_to_sql"
+
+SQL_TABLE_NAME = "cocktails"
+SQL_COLUMN_LIST = ["cocktail_id", "cocktail_name", "base_spirit"]
+SAMPLE_DATA = """
+1;Caipirinha;Cachaca\n
+2;Bramble;Gin\n
+3;Daiquiri;Rum
+"""
+
+with DAG(
+    dag_id=DAG_ID,
+    start_date=datetime(2023, 1, 1),
+    schedule="@once",
+    catchup=False,
+    tags=["example"],
+) as dag:
+
+    test_context = sys_test_context_task()
+    env_id = test_context[ENV_ID_KEY]
+
+    s3_bucket_name = f"{env_id}-bucket"
+    s3_key = f"{env_id}/files/cocktail_list.csv"
+
+    create_bucket = S3CreateBucketOperator(
+        task_id="create_bucket",
+        bucket_name=s3_bucket_name,
+    )
+
+    create_object = S3CreateObjectOperator(
+        task_id="create_object",
+        s3_bucket=s3_bucket_name,
+        s3_key=s3_key,
+        data=SAMPLE_DATA,
+        replace=True,
+    )
+
+    create_table = SQLExecuteQueryOperator(
+        task_id="create_sample_table",
+        sql=f"""
+            CREATE TABLE IF NOT EXISTS {SQL_TABLE_NAME} (
+            cocktail_id INT NOT NULL,
+            cocktail_name VARCHAR NOT NULL,
+            base_spirit VARCHAR NOT NULL);
+        """,
+    )
+
+    # [START howto_transfer_s3_to_sql]
+    #
+    # This operator requires a parser method. The Parser should take a filename as input
+    # and return an iterable of rows.
+    # This example parser uses the builtin csv library and returns a list of rows
+    #
+    def parse_csv_to_list(filepath):
+        import csv
+
+        with open(filepath, newline="") as file:
+            return [row for row in csv.reader(file)]
+
+    transfer_s3_to_sql = S3ToSqlOperator(
+        task_id="transfer_s3_to_sql",
+        s3_bucket=s3_bucket_name,
+        s3_key=s3_key,
+        table=SQL_TABLE_NAME,
+        column_list=SQL_COLUMN_LIST,
+        parser=parse_csv_to_list,
+    )
+    # [END howto_transfer_s3_to_sql]
+
+    # [START howto_transfer_s3_to_sql_generator]
+    #
+    # As the parser can return any kind of iterator, a generator is also allowed.
+    # This example parser returns a generator which prevents python from loading
+    # the whole file into memory.
+    #
+
+    def parse_csv_to_generator(filepath):
+        import csv
+
+        with open(filepath, newline="") as file:
+            yield from csv.reader(file)
+
+    transfer_s3_to_sql_generator = S3ToSqlOperator(
+        task_id="transfer_s3_to_sql_paser_to_generator",
+        s3_bucket=s3_bucket_name,
+        s3_key=s3_key,
+        table=SQL_TABLE_NAME,
+        column_list=SQL_COLUMN_LIST,
+        parser=parse_csv_to_generator,
+    )
+    # [END howto_transfer_s3_to_sql_generator]
+
+    drop_table = SQLExecuteQueryOperator(
+        trigger_rule=TriggerRule.ALL_DONE, task_id="drop_table", sql=f"DROP TABLE {SQL_TABLE_NAME}"
+    )
+
+    delete_s3_objects = S3DeleteObjectsOperator(
+        trigger_rule=TriggerRule.ALL_DONE,
+        task_id="delete_objects",
+        bucket=s3_bucket_name,
+        keys=s3_key,
+    )
+
+    delete_s3_bucket = S3DeleteBucketOperator(
+        trigger_rule=TriggerRule.ALL_DONE,
+        task_id="delete_bucket",
+        bucket_name=s3_bucket_name,
+        force_delete=True,
+    )
+
+    @task(trigger_rule=TriggerRule.ONE_FAILED, retries=0)
+    def watcher():

Review Comment:
   This task is already defined in `tests/system/utils/watcher.py`. No need to redefine it; you can just import it.
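For reference, the usual pattern in the other system test DAGs is roughly the following (a sketch; it assumes the standard `tests.system` layout and that the setup/teardown tasks above are wired together with `chain`):

```python
# At module level, next to the other imports
from tests.system.utils.watcher import watcher

# ... inside the `with DAG(...)` block, after chaining the tasks:
# This test needs watcher in order to properly mark success/failure
# when teardown tasks with trigger rules are part of the DAG
list(dag.tasks) >> watcher()
```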
########## airflow/providers/amazon/aws/transfers/s3_to_sql.py: ##########
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+from typing_extensions import Literal
+
+try:
+    import csv as csv
+except ImportError as e:
+    from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException from e
+
+
+class S3ToSqlOperator(BaseOperator):
+    """
+    Loads Data from S3 into a SQL Database.
+    Data should be readable as CSV.
+
+    This operator downloads a file from an S3, reads it via `csv.reader`
+    and inserts the data into a SQL database using `insert_rows` method.
+    All SQL hooks are supported, as long as it is of type DbApiHook
+
+    Extra arguments can be passed to it by using csv_reader_kwargs parameter.
+    (e.g. Use different quoting or delimiters)
+    Here you will find a list of all kwargs
+    https://docs.python.org/3/library/csv.html#csv.reader
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:S3ToSqlOperator`
+
+    :param schema: reference to a specific schema in SQL database
+    :param table: reference to a specific table in SQL database
+    :param s3_bucket: reference to a specific S3 bucket
+    :param s3_key: reference to a specific S3 key
+    :param sql_conn_id: reference to a specific SQL database. Must be of type DBApiHook
+    :param aws_conn_id: reference to a specific S3 / AWS connection
+    :param column_list: list of column names.
+        Set to `infer` if column names should be read from first line of CSV file (default)
+    :param skip_first_line: If first line of CSV file should be skipped.
+        If `column_list` is set to 'infer', this is ignored
+    :param commit_every: The maximum number of rows to insert in one
+        transaction. Set to `0` to insert all rows in one transaction.
+    :param csv_reader_kwargs: key word arguments to pass to csv.reader().
+        This lets you control how the CSV is read.
+        e.g. To use a different delimiter, pass the following dict:
+        {'delimiter' : ';'}
+    """
+
+    template_fields: Sequence[str] = (
+        "s3_bucket",
+        "s3_key",
+        "schema",
+        "table",
+        "column_list",
+        "sql_conn_id",
+    )
+    template_ext: Sequence[str] = ()
+    ui_color = "#f4a460"
+
+    def __init__(
+        self,
+        *,
+        s3_key: str,
+        s3_bucket: str,
+        table: str,
+        column_list: Literal["infer"] | list[str] | None = "infer",
+        commit_every: int = 1000,
+        schema: str | None = None,
+        skip_first_row: bool = False,
+        sql_conn_id: str = "sql_default",
+        aws_conn_id: str = "aws_default",
+        csv_reader_kwargs: dict[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.s3_bucket = s3_bucket
+        self.s3_key = s3_key
+        self.table = table
+        self.schema = schema
+        self.aws_conn_id = aws_conn_id
+        self.sql_conn_id = sql_conn_id
+        self.column_list = column_list
+        self.commit_every = commit_every
+        self.skip_first_row = skip_first_row
+        if csv_reader_kwargs:
+            self.csv_reader_kwargs = csv_reader_kwargs
+        else:
+            self.csv_reader_kwargs = {}
+
+    def execute(self, context: Context) -> None:
+
+        self.log.info("Loading %s to SQL table %s...", self.s3_key, self.table)
+
+        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
+        self._file = s3_hook.download_file(key=self.s3_key, bucket_name=self.s3_bucket)
+
+        hook = self._get_hook()
+        try:
+            # open with newline='' as recommended
+            # https://docs.python.org/3/library/csv.html#csv.reader
+            with open(self._file, newline="") as file:
+                reader = csv.reader(file, **self.csv_reader_kwargs)
+
+                if self.column_list == "infer":
+                    self.column_list = list(next(reader))
+                    self.log.info("Column Names inferred from csv: %s", self.column_list)
+                elif self.skip_first_row:
+                    next(reader)
+
+                hook.insert_rows(
+                    table=self.table,
+                    schema=self.schema,
+                    target_fields=self.column_list,
+                    rows=reader,
+                    commit_every=self.commit_every,
+                )
+
+        finally:
+            # Remove file downloaded from s3 to be idempotent.
+            os.remove(self._file)
+
+    def _get_hook(self) -> DbApiHook:

Review Comment:
   There is some work going on to standardize hook access in the Amazon provider package; see https://github.com/apache/airflow/pull/29001. I agree with you that it is not necessary to store the hook in a property, but (and this is only my personal opinion) using `@cached_property` makes the code cleaner.
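As a sketch only, a cached-property variant of the hook accessor on the operator could look like the method below. It assumes `_get_hook` resolves the connection through `BaseHook.get_connection(...).get_hook()` and validates that the hook implements `insert_rows`; the property name and the exact validation are illustrative, not what the PR necessarily does.

```python
# Hypothetical sketch of a cached hook property on S3ToSqlOperator.
# Requires: from functools import cached_property
# (or airflow.compat.functools, depending on the minimum supported Python version)

    @cached_property
    def db_hook(self) -> DbApiHook:
        self.log.debug("Get connection for %s", self.sql_conn_id)
        hook = BaseHook.get_connection(self.sql_conn_id).get_hook()
        # Fail early if the resolved hook cannot insert rows
        if not callable(getattr(hook, "insert_rows", None)):
            raise AirflowException("The SQL hook must be a DbApiHook implementing `insert_rows`")
        return hook
```

`execute()` would then reference `self.db_hook.insert_rows(...)` directly, and the connection lookup happens at most once per task instance.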
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
