josh-fell commented on a change in pull request #18447:
URL: https://github.com/apache/airflow/pull/18447#discussion_r720321218



##########
File path: airflow/providers/amazon/aws/example_dags/example_redshift.py
##########
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This is an example dag for using `RedshiftSQLOperator` to authenticate with Amazon Redshift
+then execute a simple select statement
+"""
+# [START redshift_operator_howto_guide]
+from airflow import DAG
+from airflow.providers.amazon.aws.operators.redshift import RedshiftSQLOperator
+from airflow.utils.dates import days_ago
+
+with DAG(dag_id="example_redshift", start_date=days_ago(1), schedule_interval=None, tags=['example']) as dag:

Review comment:
       There is an ongoing cleanup effort across all example DAGs to move away from using
`days_ago(n)` for `start_date` in favor of a static value, since that is the best practice.
New DAGs should use a static `start_date`; we've been using `datetime(2021, 1, 1)`, but the
exact value doesn't really matter. Sorry for not catching this earlier.
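   For example, just a sketch of what the DAG header could look like (the exact date is arbitrary):
   ```python
   from datetime import datetime

   from airflow import DAG

   with DAG(
       dag_id="example_redshift",
       start_date=datetime(2021, 1, 1),  # static value instead of days_ago(1)
       schedule_interval=None,
       tags=['example'],
   ) as dag:
       ...
   ```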

##########
File path: airflow/providers/amazon/aws/hooks/redshift.py
##########
@@ -126,3 +135,101 @@ def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier:
             ClusterIdentifier=cluster_identifier,
         )
         return response['Snapshot'] if response['Snapshot'] else None
+
+
+class RedshiftSQLHook(DbApiHook):
+    """
+    Execute statements against Amazon Redshift, using redshift_connector
+
+    This hook requires the redshift_conn_id connection. This connection must
+    be initialized with the host, port, login, password. Additional connection

Review comment:
       If the connection requires `host`, `port`, `login`, and `password`, it would be a good
idea to add logic to the `_get_conn_params()` method that verifies them and raises an
exception if any are missing.
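   Just to illustrate, a rough sketch of what that could look like (the exception type and the
exact message are only suggestions, not the final implementation):
   ```python
   from airflow.exceptions import AirflowException

   def _get_conn_params(self) -> Dict[str, Union[str, int]]:
       """Helper method to retrieve connection args"""
       conn = self.conn

       # Suggested check: fail fast if any required connection attribute is missing
       missing = [attr for attr in ("login", "password", "host", "port") if not getattr(conn, attr)]
       if missing:
           raise AirflowException(f"Missing required connection parameter(s): {', '.join(missing)}")

       conn_params: Dict[str, Union[str, int]] = {
           'user': conn.login,
           'password': conn.password,
           'host': conn.host,
           'port': conn.port,
       }
       if conn.schema:
           conn_params['database'] = conn.schema

       return conn_params
   ```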

##########
File path: airflow/providers/amazon/aws/hooks/redshift.py
##########
@@ -126,3 +135,101 @@ def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier:
             ClusterIdentifier=cluster_identifier,
         )
         return response['Snapshot'] if response['Snapshot'] else None
+
+
+class RedshiftSQLHook(DbApiHook):
+    """
+    Execute statements against Amazon Redshift, using redshift_connector
+
+    This hook requires the redshift_conn_id connection. This connection must
+    be initialized with the host, port, login, password. Additional connection
+    options can be passed to extra as a JSON string.
+
+    :param redshift_conn_id: reference to
+        :ref:`Amazon Redshift connection id<howto/connection:redshift>`
+    :type redshift_conn_id: str
+
+    .. note::
+        get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift
+    """
+
+    conn_name_attr = 'redshift_conn_id'
+    default_conn_name = 'redshift_default'
+    conn_type = 'redshift+redshift_connector'
+    hook_name = 'Amazon Redshift'
+    supports_autocommit = True
+
+    @staticmethod
+    def get_ui_field_behavior() -> Dict:
+        """Returns custom field behavior"""
+        return {
+            "hidden_fields": [],
+            "relabeling": {'login': 'User', 'schema': 'Database'},
+        }
+
+    @cached_property
+    def conn(self):
+        return self.get_connection(self.redshift_conn_id)  # type: ignore[attr-defined]
+
+    def _get_conn_params(self) -> Dict[str, Union[str, int]]:
+        """Helper method to retrieve connection args"""
+        conn = self.conn
+
+        conn_params: Dict[str, Union[str, int]] = {}
+
+        if conn.login:
+            conn_params['user'] = conn.login
+        if conn.password:
+            conn_params['password'] = conn.password
+        if conn.host:
+            conn_params['host'] = conn.host
+        if conn.port:
+            conn_params['port'] = conn.port
+        if conn.schema:
+            conn_params['database'] = conn.schema
+
+        return conn_params
+
+    def get_uri(self) -> str:
+        """
+        Override DbApiHook get_uri method for get_sqlalchemy_engine()
+
+        .. note::
+            Value passed to connection extra parameter will be excluded
+            from returned uri but passed to get_sqlalchemy_engine()
+            by default
+        """
+        from sqlalchemy.engine.url import URL
+
+        conn_params = self._get_conn_params()
+
+        conn = self.conn
+
+        conn_type = conn.conn_type or RedshiftSQLHook.conn_type
+
+        if 'user' in conn_params:
+            conn_params['username'] = conn_params.pop('user')

Review comment:
       Could the assignment for `user` in `_get_conn_params()` be changed so the `conn_params`
key is named "username" from the start, instead of deleting the "user" key and adding a
"username" key here?
   
   From this:
   ```python
   if conn.login:
       conn_params['user'] = conn.login
   ```
   to:
   ```python
   if conn.login:
       conn_params['username'] = conn.login
   ```

##########
File path: tests/providers/amazon/aws/hooks/test_redshift.py
##########
@@ -103,3 +106,47 @@ def test_cluster_status_returns_available_cluster(self):
         hook = RedshiftHook(aws_conn_id='aws_default')
         status = hook.cluster_status('test_cluster')
         assert status == 'available'
+
+
+class TestRedshiftSQLHookConn(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+
+        self.connection = Connection(login='login', password='password', host='host', port=5439, schema="dev")
+
+        self.db_hook = RedshiftSQLHook()
+        self.db_hook.get_connection = mock.Mock()
+        self.db_hook.get_connection.return_value = self.connection
+
+    def test_get_uri(self):

Review comment:
       If connection-parameter verification logic is added to the
`RedshiftSQLHook._get_conn_params()` method, a test validating that logic should be added
here as well.
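   Something along these lines could work, assuming the suggested verification raises
`AirflowException` when a required attribute is missing (the exception type is only a
placeholder, and it would need to be imported in this test module):
   ```python
   def test_get_conn_params_raises_on_missing_required_param(self):
       # No password on this connection -> the suggested verification logic should raise
       self.db_hook.get_connection.return_value = Connection(login='login', host='host', port=5439)
       with self.assertRaises(AirflowException):
           self.db_hook._get_conn_params()
   ```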

##########
File path: airflow/providers/amazon/aws/hooks/redshift.py
##########
@@ -126,3 +135,101 @@ def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier:
             ClusterIdentifier=cluster_identifier,
         )
         return response['Snapshot'] if response['Snapshot'] else None
+
+
+class RedshiftSQLHook(DbApiHook):
+    """
+    Execute statements against Amazon Redshift, using redshift_connector
+
+    This hook requires the redshift_conn_id connection. This connection must
+    be initialized with the host, port, login, password. Additional connection

Review comment:
       Is `schema` (aka `database` in this connection's case) required as well?

##########
File path: airflow/providers/amazon/aws/operators/redshift.py
##########
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import List, Optional, Union
+
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.redshift import RedshiftSQLHook
+
+
+class RedshiftSQLOperator(BaseOperator):
+    """
+    Executes SQL Statements against an Amazon Redshift cluster
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:RedshiftSQLOperator`
+
+    :param sql: the sql code to be executed
+    :type sql: Can receive a str representing a sql statement,
+        a list of str (sql statements)
+    :param redshift_conn_id: reference to
+        :ref:`Amazon Redshift connection id<howto/connection:redshift>`
+    :type redshift_conn_id: str
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: dict or iterable
+    :param autocommit: if True, each command is automatically committed.
+        (default value: False)
+    :type autocommit: bool
+    """
+
+    template_fields = ('sql',)
+    template_ext = ('.sql',)
+
+    def __init__(
+        self,
+        *,
+        sql: Union[str, List[str]],
+        redshift_conn_id: str = 'redshift_default',

Review comment:
       I know this has been the pattern in operators for a long time, but it seems cleaner not
to hardcode the default value for `redshift_conn_id` and instead use
`RedshiftSQLHook.default_conn_name`.
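   Just to illustrate the suggestion, from this:
   ```python
   redshift_conn_id: str = 'redshift_default',
   ```
   to:
   ```python
   redshift_conn_id: str = RedshiftSQLHook.default_conn_name,
   ```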



