josh-fell commented on a change in pull request #18447: URL: https://github.com/apache/airflow/pull/18447#discussion_r720321218
########## File path: airflow/providers/amazon/aws/example_dags/example_redshift.py ########## @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This is an example dag for using `RedshiftSQLOperator` to authenticate with Amazon Redshift +then execute a simple select statement +""" +# [START redshift_operator_howto_guide] +from airflow import DAG +from airflow.providers.amazon.aws.operators.redshift import RedshiftSQLOperator +from airflow.utils.dates import days_ago + +with DAG(dag_id="example_redshift", start_date=days_ago(1), schedule_interval=None, tags=['example']) as dag: Review comment: There is an ongoing cleanup effort across all example DAGs to move away from using `days_ago(n)` for `start_date` in favor of a static value, since a static `start_date` is best practice. New DAGs should use a static `start_date` value. We've been using `datetime(2021, 1, 1)`, but the exact value doesn't really matter. Sorry for not catching this earlier. 
########## File path: airflow/providers/amazon/aws/hooks/redshift.py ########## @@ -126,3 +135,101 @@ def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier: ClusterIdentifier=cluster_identifier, ) return response['Snapshot'] if response['Snapshot'] else None + + +class RedshiftSQLHook(DbApiHook): + """ + Execute statements against Amazon Redshift, using redshift_connector + + This hook requires the redshift_conn_id connection. This connection must + be initialized with the host, port, login, password. Additional connection Review comment: If the connection requires `host`, `port`, `login` and `password`, it would be a good idea to add logic to the `_get_conn_params()` method that verifies these fields and raises an exception if any of them are missing. ########## File path: airflow/providers/amazon/aws/hooks/redshift.py ########## @@ -126,3 +135,101 @@ def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier: ClusterIdentifier=cluster_identifier, ) return response['Snapshot'] if response['Snapshot'] else None + + +class RedshiftSQLHook(DbApiHook): + """ + Execute statements against Amazon Redshift, using redshift_connector + + This hook requires the redshift_conn_id connection. This connection must + be initialized with the host, port, login, password. Additional connection + options can be passed to extra as a JSON string. + + :param redshift_conn_id: reference to + :ref:`Amazon Redshift connection id<howto/connection:redshift>` + :type redshift_conn_id: str + + .. 
note:: + get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift + """ + + conn_name_attr = 'redshift_conn_id' + default_conn_name = 'redshift_default' + conn_type = 'redshift+redshift_connector' + hook_name = 'Amazon Redshift' + supports_autocommit = True + + @staticmethod + def get_ui_field_behavior() -> Dict: + """Returns custom field behavior""" + return { + "hidden_fields": [], + "relabeling": {'login': 'User', 'schema': 'Database'}, + } + + @cached_property + def conn(self): + return self.get_connection(self.redshift_conn_id) # type: ignore[attr-defined] + + def _get_conn_params(self) -> Dict[str, Union[str, int]]: + """Helper method to retrieve connection args""" + conn = self.conn + + conn_params: Dict[str, Union[str, int]] = {} + + if conn.login: + conn_params['user'] = conn.login + if conn.password: + conn_params['password'] = conn.password + if conn.host: + conn_params['host'] = conn.host + if conn.port: + conn_params['port'] = conn.port + if conn.schema: + conn_params['database'] = conn.schema + + return conn_params + + def get_uri(self) -> str: + """ + Override DbApiHook get_uri method for get_sqlalchemy_engine() + + .. note:: + Value passed to connection extra parameter will be excluded + from returned uri but passed to get_sqlalchemy_engine() + by default + """ + from sqlalchemy.engine.url import URL + + conn_params = self._get_conn_params() + + conn = self.conn + + conn_type = conn.conn_type or RedshiftSQLHook.conn_type + + if 'user' in conn_params: + conn_params['username'] = conn_params.pop('user') Review comment: Can the statement in `_get_conn_params()` for `user` assignment be changed to have the `conn_params` key named "username" instead of deleting the "user" and adding the "username" key here? 
From this: ```python if conn.login: conn_params['user'] = conn.login ``` to: ```python if conn.login: conn_params['username'] = conn.login ``` ########## File path: tests/providers/amazon/aws/hooks/test_redshift.py ########## @@ -103,3 +106,47 @@ def test_cluster_status_returns_available_cluster(self): hook = RedshiftHook(aws_conn_id='aws_default') status = hook.cluster_status('test_cluster') assert status == 'available' + + +class TestRedshiftSQLHookConn(unittest.TestCase): + def setUp(self): + super().setUp() + + self.connection = Connection(login='login', password='password', host='host', port=5439, schema="dev") + + self.db_hook = RedshiftSQLHook() + self.db_hook.get_connection = mock.Mock() + self.db_hook.get_connection.return_value = self.connection + + def test_get_uri(self): Review comment: If connection-parameter verification logic is added to the `RedshiftSQLHook._get_conn_params()` method, adding a test validating the logic should be included here as well. ########## File path: airflow/providers/amazon/aws/hooks/redshift.py ########## @@ -126,3 +135,101 @@ def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier: ClusterIdentifier=cluster_identifier, ) return response['Snapshot'] if response['Snapshot'] else None + + +class RedshiftSQLHook(DbApiHook): + """ + Execute statements against Amazon Redshift, using redshift_connector + + This hook requires the redshift_conn_id connection. This connection must + be initialized with the host, port, login, password. Additional connection Review comment: Is `schema` (aka `database` in this connection's case) required as well? ########## File path: airflow/providers/amazon/aws/operators/redshift.py ########## @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional, Union + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.redshift import RedshiftSQLHook + + +class RedshiftSQLOperator(BaseOperator): + """ + Executes SQL Statements against an Amazon Redshift cluster + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:RedshiftSQLOperator` + + :param sql: the sql code to be executed + :type sql: Can receive a str representing a sql statement, + a list of str (sql statements) + :param redshift_conn_id: reference to + :ref:`Amazon Redshift connection id<howto/connection:redshift>` + :type redshift_conn_id: str + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: dict or iterable + :param autocommit: if True, each command is automatically committed. + (default value: False) + :type autocommit: bool + """ + + template_fields = ('sql',) + template_ext = ('.sql',) + + def __init__( + self, + *, + sql: Union[str, List[str]], + redshift_conn_id: str = 'redshift_default', Review comment: I know this has been the pattern in operators for a long time, but it seems cleaner not to hardcode the default value for `redshift_conn_id` and to use `RedshiftSQLHook.default_conn_name` instead. -- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
