vincbeck commented on code in PR #25717:
URL: https://github.com/apache/airflow/pull/25717#discussion_r954260289


##########
airflow/providers/apache/drill/operators/drill.py:
##########
@@ -15,16 +15,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Sequence
 
-from airflow.models import BaseOperator
-from airflow.providers.apache.drill.hooks.drill import DrillHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 
 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    pass

Review Comment:
   This `if` looks useless to me.
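   Something like this is what I mean, a rough sketch assuming nothing else in this module needs type-only imports:

   ```python
   # The whole `if TYPE_CHECKING: pass` block can simply be dropped,
   # since nothing is imported for type checking anymore.
   import warnings
   from typing import Sequence

   from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
   ```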



##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -63,13 +75,13 @@ class BaseSQLOperator(BaseOperator):
 
     The provided method is .get_db_hook(). The default behavior will try to
     retrieve the DB hook based on connection type.
-    You can custom the behavior by overriding the .get_db_hook() method.
+    You can customize the behavior by overriding the .get_db_hook() method.
     """
 
     def __init__(
         self,
         *,
-        conn_id: Optional[str] = None,
+        conn_id: str,

Review Comment:
   This does not seem backward compatible to me: `conn_id` used to default to `None`, so making it required will break existing callers.
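   A rough sketch of what I mean, keeping the old default (only the signature matters here, the rest of the class is assumed unchanged):

   ```python
   from typing import Optional

   from airflow.models import BaseOperator


   class BaseSQLOperator(BaseOperator):
       def __init__(
           self,
           *,
           conn_id: Optional[str] = None,  # keep the old default for backward compatibility
           **kwargs,
       ) -> None:
           super().__init__(**kwargs)
           self.conn_id = conn_id
   ```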



##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -112,6 +124,61 @@ def get_db_hook(self) -> DbApiHook:
         return self._hook
 
 
+class SQLExecuteQueryOperator(BaseSQLOperator):
+    """
+    Executes SQL code in a specific database
+    :param sql: the SQL code or string pointing to a template file to be executed (templated).
+    File must have a '.sql' extension.
+    :param handler: (optional) the function that will be applied to the cursor.
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :param conn_id: reference to a specific database

Review Comment:
   Some parameters are missing from the docstring (`autocommit`, `split_statements` and `return_last`, at least).
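   A rough sketch of what the completed docstring could look like (the wording is just a suggestion):

   ```python
   class SQLExecuteQueryOperator(BaseSQLOperator):
       """
       Executes SQL code in a specific database.

       :param sql: the SQL code or string pointing to a template file to be executed (templated).
           File must have a '.sql' extension.
       :param autocommit: (optional) if True, each SQL statement is committed as soon as it runs.
       :param parameters: (optional) the parameters to render the SQL query with.
       :param handler: (optional) the function that will be applied to the cursor.
       :param split_statements: (optional) if True, split the SQL string into individual statements.
       :param return_last: (optional) if True, return only the result of the last statement.
       :param conn_id: reference to a specific database.
       """
   ```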



##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -112,6 +124,61 @@ def get_db_hook(self) -> DbApiHook:
         return self._hook
 
 
+class SQLExecuteQueryOperator(BaseSQLOperator):
+    """
+    Executes SQL code in a specific database
+    :param sql: the SQL code or string pointing to a template file to be executed (templated).
+    File must have a '.sql' extension.
+    :param handler: (optional) the function that will be applied to the cursor.
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :param conn_id: reference to a specific database
+    """
+
+    template_fields: Sequence[str] = ('sql', 'parameters')
+    template_ext: Sequence[str] = ('.sql',)
+    ui_color = '#cdaaed'
+
+    def __init__(
+        self,
+        *,
+        sql: Union[str, List[str]],
+        autocommit: bool = False,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        handler: Callable[[Any], Any] = fetch_all_handler,
+        split_statements: bool = False,
+        return_last: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.sql = sql
+        self.autocommit = autocommit
+        self.parameters = parameters
+        self.handler = handler
+        self.split_statements = split_statements
+        self.return_last = return_last
+
+    def execute(self, context):
+        self.log.info('Executing: %s', self.sql)
+        hook = self.get_db_hook()
+        if self.do_xcom_push:
+            output = hook.run(
+                self.sql,
+                autocommit=self.autocommit,
+                parameters=self.parameters,
+                handler=self.handler,
+                split_statements=self.split_statements,
+                return_last=self.return_last,
+            )
+        else:
+            output = hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+
+        if hasattr(self, '_process_output'):
+            for out in output:
+                self._process_output(*out)
+
+        return output

Review Comment:
   Afterwards, it is up to the user whether or not to use the returned value (via XCom).
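   For example, a downstream task could pull it like this (a minimal sketch; the "run_query" task id is made up):

   ```python
   from airflow.decorators import task
   from airflow.operators.python import get_current_context


   @task
   def consume_results():
       # Pull the rows returned by a (hypothetical) "run_query" SQLExecuteQueryOperator task.
       ti = get_current_context()["ti"]
       rows = ti.xcom_pull(task_ids="run_query") or []
       for row in rows:
           print(row)
   ```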



##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -112,6 +124,61 @@ def get_db_hook(self) -> DbApiHook:
         return self._hook
 
 
+class SQLExecuteQueryOperator(BaseSQLOperator):

Review Comment:
   Please create unit tests for this operator. You should not rely on the unit tests written for the different DB operators, because at some point we are going to remove them (they are now deprecated). Along with the operators we remove, we will also remove the unit tests associated with them.
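   A minimal sketch of what such a test could look like (the operator and its defaults come from this PR, everything else is illustrative):

   ```python
   from unittest import mock

   from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator


   @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
   def test_execute_returns_query_result(mock_get_db_hook):
       # Mock the hook so no real database connection is needed.
       mock_get_db_hook.return_value.run.return_value = [("row1",), ("row2",)]

       operator = SQLExecuteQueryOperator(
           task_id="test_sql_execute",
           conn_id="some_conn",
           sql="SELECT 1",
           do_xcom_push=True,
       )
       result = operator.execute(context={})

       # The operator should delegate to hook.run() and return its output.
       mock_get_db_hook.return_value.run.assert_called_once()
       assert result == [("row1",), ("row2",)]
   ```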



##########
airflow/providers/jdbc/operators/jdbc.py:
##########
@@ -15,18 +15,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import warnings
+from typing import TYPE_CHECKING, Sequence
 
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Sequence, Union
-
-from airflow.models import BaseOperator
-from airflow.providers.common.sql.hooks.sql import fetch_all_handler
-from airflow.providers.jdbc.hooks.jdbc import JdbcHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 
 if TYPE_CHECKING:

Review Comment:
   Needed?



##########
airflow/providers/trino/operators/trino.py:
##########
@@ -16,19 +16,19 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains the Trino operator."""
-
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Any, Sequence
 
 from trino.exceptions import TrinoQueryError
 
-from airflow.models import BaseOperator
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.providers.trino.hooks.trino import TrinoHook
 
 if TYPE_CHECKING:

Review Comment:
   Needed?



##########
airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -80,53 +80,44 @@ class DatabricksSqlOperator(BaseOperator):
     def __init__(
         self,
         *,
-        sql: Union[str, Iterable[str]],
         databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
         http_path: Optional[str] = None,
         sql_endpoint_name: Optional[str] = None,
-        parameters: Optional[Union[Iterable, Mapping]] = None,
         session_configuration=None,
         http_headers: Optional[List[Tuple[str, str]]] = None,
         catalog: Optional[str] = None,
         schema: Optional[str] = None,
-        do_xcom_push: bool = False,
         output_path: Optional[str] = None,
         output_format: str = 'csv',
         csv_params: Optional[Dict[str, Any]] = None,
         client_parameters: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
-        """Creates a new ``DatabricksSqlOperator``."""
-        super().__init__(**kwargs)
+        super().__init__(conn_id=databricks_conn_id, **kwargs)
         self.databricks_conn_id = databricks_conn_id
-        self.sql = sql
-        self._http_path = http_path
-        self._sql_endpoint_name = sql_endpoint_name
         self._output_path = output_path
         self._output_format = output_format
         self._csv_params = csv_params
-        self.parameters = parameters
-        self.do_xcom_push = do_xcom_push
-        self.session_config = session_configuration
-        self.http_headers = http_headers
-        self.catalog = catalog
-        self.schema = schema
-        self.client_parameters = client_parameters or {}
 
-    def _get_hook(self) -> DatabricksSqlHook:
-        return DatabricksSqlHook(
-            self.databricks_conn_id,
-            http_path=self._http_path,
-            session_configuration=self.session_config,
-            sql_endpoint_name=self._sql_endpoint_name,
-            http_headers=self.http_headers,
-            catalog=self.catalog,
-            schema=self.schema,
-            caller="DatabricksSqlOperator",
-            **self.client_parameters,
-        )
+        client_parameters = {} if client_parameters is None else client_parameters
+        hook_params = kwargs.pop('hook_params', {})
 
-    def _format_output(self, schema, results):
+        self.hook_params = {
+            'http_path': http_path,
+            'session_configuration': session_configuration,
+            'sql_endpoint_name': sql_endpoint_name,
+            'http_headers': http_headers,
+            'catalog': catalog,
+            'schema': schema,
+            'caller': "DatabricksSqlOperator",
+            **client_parameters,
+            **hook_params,
+        }
+
+    def get_db_hook(self) -> DatabricksSqlHook:
+        return DatabricksSqlHook(self.databricks_conn_id, **self.hook_params)

Review Comment:
   In the other operators, defining only the `hook_params` seems to be enough, e.g. `kwargs['hook_params'] = {'schema': schema, **hook_params}`. Why do you need to override `get_db_hook` in this case?
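   A rough sketch of the alternative I have in mind (only a subset of the parameters is shown, and it assumes `BaseSQLOperator` builds the hook from `hook_params`, as it does for the other operators):

   ```python
   from typing import Any, Dict, Optional

   from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
   from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook


   class DatabricksSqlOperator(SQLExecuteQueryOperator):
       def __init__(
           self,
           *,
           databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
           http_path: Optional[str] = None,
           schema: Optional[str] = None,
           client_parameters: Optional[Dict[str, Any]] = None,
           **kwargs,
       ) -> None:
           # Merge everything into hook_params instead of overriding get_db_hook().
           hook_params = kwargs.pop('hook_params', {})
           kwargs['hook_params'] = {
               'http_path': http_path,
               'schema': schema,
               'caller': 'DatabricksSqlOperator',
               **(client_parameters or {}),
               **hook_params,
           }
           super().__init__(conn_id=databricks_conn_id, **kwargs)
   ```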



##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -112,6 +124,61 @@ def get_db_hook(self) -> DbApiHook:
         return self._hook
 
 
+class SQLExecuteQueryOperator(BaseSQLOperator):
+    """
+    Executes SQL code in a specific database
+    :param sql: the SQL code or string pointing to a template file to be executed (templated).
+    File must have a '.sql' extension.
+    :param handler: (optional) the function that will be applied to the cursor.
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :param conn_id: reference to a specific database
+    """
+
+    template_fields: Sequence[str] = ('sql', 'parameters')
+    template_ext: Sequence[str] = ('.sql',)
+    ui_color = '#cdaaed'
+
+    def __init__(
+        self,
+        *,
+        sql: Union[str, List[str]],
+        autocommit: bool = False,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        handler: Callable[[Any], Any] = fetch_all_handler,
+        split_statements: bool = False,
+        return_last: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.sql = sql
+        self.autocommit = autocommit
+        self.parameters = parameters
+        self.handler = handler
+        self.split_statements = split_statements
+        self.return_last = return_last
+
+    def execute(self, context):
+        self.log.info('Executing: %s', self.sql)
+        hook = self.get_db_hook()
+        if self.do_xcom_push:
+            output = hook.run(
+                self.sql,
+                autocommit=self.autocommit,
+                parameters=self.parameters,
+                handler=self.handler,
+                split_statements=self.split_statements,
+                return_last=self.return_last,
+            )
+        else:
+            output = hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+
+        if hasattr(self, '_process_output'):
+            for out in output:
+                self._process_output(*out)
+
+        return output

Review Comment:
   This `if` is not really clear to me. Why not split the statements if `do_xcom_push` is `False`? How is that related? Same for `return_last`. I would do it differently:
   
   ```suggestion
               output = hook.run(
                   self.sql,
                   autocommit=self.autocommit,
                   parameters=self.parameters,
                   handler=self.handler,
                   split_statements=self.split_statements,
                   return_last=self.return_last,
               )
   
           if hasattr(self, '_process_output'):
               for out in output:
                   self._process_output(*out)
   
           return output
   ```



##########
airflow/providers/postgres/operators/postgres.py:
##########
@@ -15,19 +15,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Mapping, Optional, Sequence
 
 from psycopg2.sql import SQL, Identifier
 
-from airflow.models import BaseOperator
-from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.www import utils as wwwutils
 
 if TYPE_CHECKING:

Review Comment:
   Needed?



##########
airflow/providers/exasol/operators/exasol.py:
##########
@@ -15,16 +15,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Sequence, Optional
 
-from airflow.models import BaseOperator
-from airflow.providers.exasol.hooks.exasol import ExasolHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 
 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    pass

Review Comment:
   Is it needed?



##########
airflow/providers/vertica/operators/vertica.py:
##########
@@ -15,16 +15,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import TYPE_CHECKING, Any, Iterable, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Any, Sequence
 
-from airflow.models import BaseOperator
-from airflow.providers.vertica.hooks.vertica import VerticaHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 
 if TYPE_CHECKING:

Review Comment:
   Needed?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
