vincbeck commented on code in PR #25717:
URL: https://github.com/apache/airflow/pull/25717#discussion_r954260289
##########
airflow/providers/apache/drill/operators/drill.py:
##########
@@ -15,16 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Sequence
-from airflow.models import BaseOperator
-from airflow.providers.apache.drill.hooks.drill import DrillHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
if TYPE_CHECKING:
- from airflow.utils.context import Context
+ pass
Review Comment:
This `if` looks useless to me.
##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -63,13 +75,13 @@ class BaseSQLOperator(BaseOperator):
The provided method is .get_db_hook(). The default behavior will try to
retrieve the DB hook based on connection type.
- You can custom the behavior by overriding the .get_db_hook() method.
+ You can customize the behavior by overriding the .get_db_hook() method.
"""
def __init__(
self,
*,
- conn_id: Optional[str] = None,
+ conn_id: str,
Review Comment:
This does not seem backward compatible to me.
##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -112,6 +124,61 @@ def get_db_hook(self) -> DbApiHook:
return self._hook
+class SQLExecuteQueryOperator(BaseSQLOperator):
+ """
+ Executes SQL code in a specific database
+ :param sql: the SQL code or string pointing to a template file to be
executed (templated).
+ File must have a '.sql' extensions.
+ :param handler: (optional) the function that will be applied to the cursor.
+ :param parameters: (optional) the parameters to render the SQL query with.
+ :param conn_id: reference to a specific database
Review Comment:
Some parameters are missing from the docstring.
##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -112,6 +124,61 @@ def get_db_hook(self) -> DbApiHook:
return self._hook
+class SQLExecuteQueryOperator(BaseSQLOperator):
+ """
+ Executes SQL code in a specific database
+ :param sql: the SQL code or string pointing to a template file to be
executed (templated).
+ File must have a '.sql' extensions.
+ :param handler: (optional) the function that will be applied to the cursor.
+ :param parameters: (optional) the parameters to render the SQL query with.
+ :param conn_id: reference to a specific database
+ """
+
+ template_fields: Sequence[str] = ('sql', 'parameters')
+ template_ext: Sequence[str] = ('.sql',)
+ ui_color = '#cdaaed'
+
+ def __init__(
+ self,
+ *,
+ sql: Union[str, List[str]],
+ autocommit: bool = False,
+ parameters: Optional[Union[Mapping, Iterable]] = None,
+ handler: Callable[[Any], Any] = fetch_all_handler,
+ split_statements: bool = False,
+ return_last: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.sql = sql
+ self.autocommit = autocommit
+ self.parameters = parameters
+ self.handler = handler
+ self.split_statements = split_statements
+ self.return_last = return_last
+
+ def execute(self, context):
+ self.log.info('Executing: %s', self.sql)
+ hook = self.get_db_hook()
+ if self.do_xcom_push:
+ output = hook.run(
+ self.sql,
+ autocommit=self.autocommit,
+ parameters=self.parameters,
+ handler=self.handler,
+ split_statements=self.split_statements,
+ return_last=self.return_last,
+ )
+ else:
+ output = hook.run(self.sql, autocommit=self.autocommit,
parameters=self.parameters)
+
+ if hasattr(self, '_process_output'):
+ for out in output:
+ self._process_output(*out)
+
+ return output
Review Comment:
Afterwards, it is up to the user to decide whether or not to use the returned value (via XCom).
##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -112,6 +124,61 @@ def get_db_hook(self) -> DbApiHook:
return self._hook
+class SQLExecuteQueryOperator(BaseSQLOperator):
Review Comment:
Please create unit tests for this operator. You should not rely on the unit
tests written for the different DB-specific operators, because at some point we
are going to remove those operators (since they are now deprecated). Along with the
operators we remove, we will also remove the unit tests associated with them.
##########
airflow/providers/jdbc/operators/jdbc.py:
##########
@@ -15,18 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import warnings
+from typing import TYPE_CHECKING, Sequence
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional,
Sequence, Union
-
-from airflow.models import BaseOperator
-from airflow.providers.common.sql.hooks.sql import fetch_all_handler
-from airflow.providers.jdbc.hooks.jdbc import JdbcHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
if TYPE_CHECKING:
Review Comment:
Needed?
##########
airflow/providers/trino/operators/trino.py:
##########
@@ -16,19 +16,19 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains the Trino operator."""
-
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence,
Union
+import warnings
+from typing import TYPE_CHECKING, Any, Sequence
from trino.exceptions import TrinoQueryError
-from airflow.models import BaseOperator
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.trino.hooks.trino import TrinoHook
if TYPE_CHECKING:
Review Comment:
Needed?
##########
airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -80,53 +80,44 @@ class DatabricksSqlOperator(BaseOperator):
def __init__(
self,
*,
- sql: Union[str, Iterable[str]],
databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
http_path: Optional[str] = None,
sql_endpoint_name: Optional[str] = None,
- parameters: Optional[Union[Iterable, Mapping]] = None,
session_configuration=None,
http_headers: Optional[List[Tuple[str, str]]] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
- do_xcom_push: bool = False,
output_path: Optional[str] = None,
output_format: str = 'csv',
csv_params: Optional[Dict[str, Any]] = None,
client_parameters: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
- """Creates a new ``DatabricksSqlOperator``."""
- super().__init__(**kwargs)
+ super().__init__(conn_id=databricks_conn_id, **kwargs)
self.databricks_conn_id = databricks_conn_id
- self.sql = sql
- self._http_path = http_path
- self._sql_endpoint_name = sql_endpoint_name
self._output_path = output_path
self._output_format = output_format
self._csv_params = csv_params
- self.parameters = parameters
- self.do_xcom_push = do_xcom_push
- self.session_config = session_configuration
- self.http_headers = http_headers
- self.catalog = catalog
- self.schema = schema
- self.client_parameters = client_parameters or {}
- def _get_hook(self) -> DatabricksSqlHook:
- return DatabricksSqlHook(
- self.databricks_conn_id,
- http_path=self._http_path,
- session_configuration=self.session_config,
- sql_endpoint_name=self._sql_endpoint_name,
- http_headers=self.http_headers,
- catalog=self.catalog,
- schema=self.schema,
- caller="DatabricksSqlOperator",
- **self.client_parameters,
- )
+ client_parameters = {} if client_parameters is None else
client_parameters
+ hook_params = kwargs.pop('hook_params', {})
- def _format_output(self, schema, results):
+ self.hook_params = {
+ 'http_path': http_path,
+ 'session_configuration': session_configuration,
+ 'sql_endpoint_name': sql_endpoint_name,
+ 'http_headers': http_headers,
+ 'catalog': catalog,
+ 'schema': schema,
+ 'caller': "DatabricksSqlOperator",
+ **client_parameters,
+ **hook_params,
+ }
+
+ def get_db_hook(self) -> DatabricksSqlHook:
+ return DatabricksSqlHook(self.databricks_conn_id, **self.hook_params)
Review Comment:
In the other operators, defining only the `hook_params` seems to be enough, e.g.
`kwargs['hook_params'] = {'schema': schema, **hook_params}`. Why do you need to
override `get_db_hook` in this case?
##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -112,6 +124,61 @@ def get_db_hook(self) -> DbApiHook:
return self._hook
+class SQLExecuteQueryOperator(BaseSQLOperator):
+ """
+ Executes SQL code in a specific database
+ :param sql: the SQL code or string pointing to a template file to be
executed (templated).
+ File must have a '.sql' extensions.
+ :param handler: (optional) the function that will be applied to the cursor.
+ :param parameters: (optional) the parameters to render the SQL query with.
+ :param conn_id: reference to a specific database
+ """
+
+ template_fields: Sequence[str] = ('sql', 'parameters')
+ template_ext: Sequence[str] = ('.sql',)
+ ui_color = '#cdaaed'
+
+ def __init__(
+ self,
+ *,
+ sql: Union[str, List[str]],
+ autocommit: bool = False,
+ parameters: Optional[Union[Mapping, Iterable]] = None,
+ handler: Callable[[Any], Any] = fetch_all_handler,
+ split_statements: bool = False,
+ return_last: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.sql = sql
+ self.autocommit = autocommit
+ self.parameters = parameters
+ self.handler = handler
+ self.split_statements = split_statements
+ self.return_last = return_last
+
+ def execute(self, context):
+ self.log.info('Executing: %s', self.sql)
+ hook = self.get_db_hook()
+ if self.do_xcom_push:
+ output = hook.run(
+ self.sql,
+ autocommit=self.autocommit,
+ parameters=self.parameters,
+ handler=self.handler,
+ split_statements=self.split_statements,
+ return_last=self.return_last,
+ )
+ else:
+ output = hook.run(self.sql, autocommit=self.autocommit,
parameters=self.parameters)
+
+ if hasattr(self, '_process_output'):
+ for out in output:
+ self._process_output(*out)
+
+ return output
Review Comment:
This `if` is not really clear to me. Why not split the statements when
`do_xcom_push` is `False`? How are the two related? The same applies to `return_last`. I would
do it differently:
```suggestion
output = hook.run(
self.sql,
autocommit=self.autocommit,
parameters=self.parameters,
handler=self.handler,
split_statements=self.split_statements,
return_last=self.return_last,
)
if hasattr(self, '_process_output'):
for out in output:
self._process_output(*out)
return output
```
##########
airflow/providers/postgres/operators/postgres.py:
##########
@@ -15,19 +15,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Mapping, Optional, Sequence
from psycopg2.sql import SQL, Identifier
-from airflow.models import BaseOperator
-from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.www import utils as wwwutils
if TYPE_CHECKING:
Review Comment:
Needed?
##########
airflow/providers/exasol/operators/exasol.py:
##########
@@ -15,16 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Sequence, Optional
-from airflow.models import BaseOperator
-from airflow.providers.exasol.hooks.exasol import ExasolHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
if TYPE_CHECKING:
- from airflow.utils.context import Context
+ pass
Review Comment:
Is it needed?
##########
airflow/providers/vertica/operators/vertica.py:
##########
@@ -15,16 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Any, Iterable, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Any, Sequence
-from airflow.models import BaseOperator
-from airflow.providers.vertica.hooks.vertica import VerticaHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
if TYPE_CHECKING:
Review Comment:
Needed?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]