This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c609477260 Implements SqlToSlackApiFileOperator (#26374)
c609477260 is described below
commit c60947726082905f7b369a90e8d37fcdb873149f
Author: Andrey Anshin <[email protected]>
AuthorDate: Thu Nov 17 00:14:27 2022 +0400
Implements SqlToSlackApiFileOperator (#26374)
---
airflow/providers/slack/transfers/sql_to_slack.py | 186 ++++++++++++++++++---
airflow/providers/slack/utils/__init__.py | 40 ++++-
.../providers/slack/transfers/test_sql_to_slack.py | 164 +++++++++++++++++-
tests/providers/slack/utils/test_utils.py | 54 +++++-
4 files changed, 418 insertions(+), 26 deletions(-)
diff --git a/airflow/providers/slack/transfers/sql_to_slack.py
b/airflow/providers/slack/transfers/sql_to_slack.py
index bdd1cddd2b..cf5c01b22c 100644
--- a/airflow/providers/slack/transfers/sql_to_slack.py
+++ b/airflow/providers/slack/transfers/sql_to_slack.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
from pandas import DataFrame
@@ -25,13 +26,59 @@ from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.slack.hooks.slack import SlackHook
from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook
+from airflow.providers.slack.utils import parse_filename
if TYPE_CHECKING:
from airflow.utils.context import Context
-class SqlToSlackOperator(BaseOperator):
+class BaseSqlToSlackOperator(BaseOperator):
+ """
+ Operator implements base sql methods for SQL to Slack Transfer operators.
+
+ :param sql: The SQL query to be executed
+ :param sql_conn_id: reference to a specific DB-API Connection.
+ :param sql_hook_params: Extra config params to be passed to the underlying
hook.
+ Should match the desired hook constructor params.
+ :param parameters: The parameters to pass to the SQL query.
+ """
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ sql_conn_id: str,
+ sql_hook_params: dict | None = None,
+ parameters: Iterable | Mapping | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.sql_conn_id = sql_conn_id
+ self.sql_hook_params = sql_hook_params
+ self.sql = sql
+ self.parameters = parameters
+
+ def _get_hook(self) -> DbApiHook:
+ self.log.debug("Get connection for %s", self.sql_conn_id)
+ conn = BaseHook.get_connection(self.sql_conn_id)
+ hook = conn.get_hook(hook_params=self.sql_hook_params)
+ if not callable(getattr(hook, "get_pandas_df", None)):
+ raise AirflowException(
+ "This hook is not supported. The hook class must have
get_pandas_df method."
+ )
+ return hook
+
+ def _get_query_results(self) -> DataFrame:
+ sql_hook = self._get_hook()
+
+ self.log.info("Running SQL query: %s", self.sql)
+ df = sql_hook.get_pandas_df(self.sql, parameters=self.parameters)
+ return df
+
+
+class SqlToSlackOperator(BaseSqlToSlackOperator):
"""
Executes an SQL statement in a given SQL connection and sends the results
to Slack. The results of the
query are rendered into the 'slack_message' parameter as a Pandas
dataframe using a JINJA variable called
@@ -79,12 +126,10 @@ class SqlToSlackOperator(BaseOperator):
**kwargs,
) -> None:
- super().__init__(**kwargs)
+ super().__init__(
+ sql=sql, sql_conn_id=sql_conn_id, sql_hook_params=sql_hook_params,
parameters=parameters, **kwargs
+ )
- self.sql_conn_id = sql_conn_id
- self.sql_hook_params = sql_hook_params
- self.sql = sql
- self.parameters = parameters
self.slack_conn_id = slack_conn_id
self.slack_webhook_token = slack_webhook_token
self.slack_channel = slack_channel
@@ -97,23 +142,6 @@ class SqlToSlackOperator(BaseOperator):
"SqlToSlackOperator requires either a `slack_conn_id` or a
`slack_webhook_token` argument"
)
- def _get_hook(self) -> DbApiHook:
- self.log.debug("Get connection for %s", self.sql_conn_id)
- conn = BaseHook.get_connection(self.sql_conn_id)
- hook = conn.get_hook(hook_params=self.sql_hook_params)
- if not callable(getattr(hook, "get_pandas_df", None)):
- raise AirflowException(
- "This hook is not supported. The hook class must have
get_pandas_df method."
- )
- return hook
-
- def _get_query_results(self) -> DataFrame:
- sql_hook = self._get_hook()
-
- self.log.info("Running SQL query: %s", self.sql)
- df = sql_hook.get_pandas_df(self.sql, parameters=self.parameters)
- return df
-
def _render_and_send_slack_message(self, context, df) -> None:
# Put the dataframe into the context and render the JINJA template
fields
context[self.results_df_name] = df
@@ -157,3 +185,115 @@ class SqlToSlackOperator(BaseOperator):
self._render_and_send_slack_message(context, df)
self.log.debug("Finished sending SQL data to Slack")
+
+
+class SqlToSlackApiFileOperator(BaseSqlToSlackOperator):
+ """
+ Executes an SQL statement in a given SQL connection and sends the results
to Slack API as file.
+
+ :param sql: The SQL query to be executed
+ :param sql_conn_id: reference to a specific DB-API Connection.
+ :param slack_conn_id: :ref:`Slack API Connection <howto/connection:slack>`.
+ :param slack_filename: Filename for display in slack.
+ Should contain a supported extension which is referenced in
``SUPPORTED_FILE_FORMATS``.
+ It is also possible to set compression in extension:
+ ``filename.csv.gzip``, ``filename.json.zip``, etc.
+ :param sql_hook_params: Extra config params to be passed to the underlying
hook.
+ Should match the desired hook constructor params.
+ :param parameters: The parameters to pass to the SQL query.
+ :param slack_channels: Comma-separated list of channel names or IDs where
the file will be shared.
+ If this parameter is omitted, then the file will be sent to the workspace.
+ :param slack_initial_comment: The message text introducing the file in
specified ``slack_channels``.
+ :param slack_title: Title of file.
+ :param df_kwargs: Keyword arguments forwarded to
``pandas.DataFrame.to_{format}()`` method.
+
+ Example:
+ .. code-block:: python
+
+ SqlToSlackApiFileOperator(
+ task_id="sql_to_slack",
+ sql="SELECT 1 a, 2 b, 3 c",
+ sql_conn_id="sql-connection",
+ slack_conn_id="slack-api-connection",
+ slack_filename="awesome.json.gz",
+ slack_channels="#random,#general",
+ slack_initial_comment="Awesome load to compressed multiline JSON.",
+ df_kwargs={
+ "orient": "records",
+ "lines": True,
+ },
+ )
+ """
+
+ template_fields: Sequence[str] = (
+ "sql",
+ "slack_channels",
+ "slack_filename",
+ "slack_initial_comment",
+ "slack_title",
+ )
+ template_ext: Sequence[str] = (".sql", ".jinja", ".j2")
+ template_fields_renderers = {"sql": "sql", "slack_message": "jinja"}
+
+ SUPPORTED_FILE_FORMATS: Sequence[str] = ("csv", "json", "html")
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ sql_conn_id: str,
+ sql_hook_params: dict | None = None,
+ parameters: Iterable | Mapping | None = None,
+ slack_conn_id: str,
+ slack_filename: str,
+ slack_channels: str | Sequence[str] | None = None,
+ slack_initial_comment: str | None = None,
+ slack_title: str | None = None,
+ df_kwargs: dict | None = None,
+ **kwargs,
+ ):
+ super().__init__(
+ sql=sql, sql_conn_id=sql_conn_id, sql_hook_params=sql_hook_params,
parameters=parameters, **kwargs
+ )
+ self.slack_conn_id = slack_conn_id
+ self.slack_filename = slack_filename
+ self.slack_channels = slack_channels
+ self.slack_initial_comment = slack_initial_comment
+ self.slack_title = slack_title
+ self.df_kwargs = df_kwargs or {}
+
+ def execute(self, context: Context) -> None:
+ # Parse file format from filename
+ output_file_format, _ = parse_filename(
+ filename=self.slack_filename,
+ supported_file_formats=self.SUPPORTED_FILE_FORMATS,
+ )
+
+ slack_hook = SlackHook(slack_conn_id=self.slack_conn_id)
+ with NamedTemporaryFile(mode="w+", suffix=f"_{self.slack_filename}")
as fp:
+ # tempfile.NamedTemporaryFile is used only to create and remove the
temporary file;
+ # pandas will open the file in the correct mode itself depending on
the file type.
+ # So we close the file descriptor here to avoid incidentally writing
anything.
+ fp.close()
+
+ output_file_name = fp.name
+ output_file_format = output_file_format.upper()
+ df_result = self._get_query_results()
+ if output_file_format == "CSV":
+ df_result.to_csv(output_file_name, **self.df_kwargs)
+ elif output_file_format == "JSON":
+ df_result.to_json(output_file_name, **self.df_kwargs)
+ elif output_file_format == "HTML":
+ df_result.to_html(output_file_name, **self.df_kwargs)
+ else:
+ # This error is not expected to happen. It is only possible
+ # if SUPPORTED_FILE_FORMATS is extended without an actual
implementation for the specific format.
+ raise AirflowException(f"Unexpected output file format:
{output_file_format}")
+
+ slack_hook.send_file(
+ channels=self.slack_channels,
+ file=output_file_name,
+ filename=self.slack_filename,
+ initial_comment=self.slack_initial_comment,
+ title=self.slack_title,
+ )
diff --git a/airflow/providers/slack/utils/__init__.py
b/airflow/providers/slack/utils/__init__.py
index dda12656d4..1071de6299 100644
--- a/airflow/providers/slack/utils/__init__.py
+++ b/airflow/providers/slack/utils/__init__.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import warnings
-from typing import Any
+from typing import Any, Sequence
from airflow.utils.types import NOTSET
@@ -77,3 +77,41 @@ class ConnectionExtraConfig:
if value != default:
value = int(value)
return value
+
+
+def parse_filename(
+ filename: str, supported_file_formats: Sequence[str], fallback: str | None
= None
+) -> tuple[str, str | None]:
+ """
+ Parse filetype and compression from given filename.
+
+ :param filename: filename to parse.
+ :param supported_file_formats: list of supported file extensions.
+ :param fallback: fallback to given file format.
+ :returns: filetype and compression (if specified)
+ """
+ if not filename:
+ raise ValueError("Expected 'filename' parameter is missing.")
+ if fallback and fallback not in supported_file_formats:
+ raise ValueError(f"Invalid fallback value {fallback!r}, expected one
of {supported_file_formats}.")
+
+ parts = filename.rsplit(".", 2)
+ try:
+ if len(parts) == 1:
+ raise ValueError(f"No file extension specified in filename
{filename!r}.")
+ if parts[-1] in supported_file_formats:
+ return parts[-1], None
+ elif len(parts) == 2:
+ raise ValueError(
+ f"Unsupported file format {parts[-1]!r}, expected one of
{supported_file_formats}."
+ )
+ else:
+ if parts[-2] not in supported_file_formats:
+ raise ValueError(
+ f"Unsupported file format '{parts[-2]}.{parts[-1]}', "
+ f"expected one of {supported_file_formats} with
compression extension."
+ )
+ return parts[-2], parts[-1]
+ except ValueError as ex:
+ if fallback:
+ return fallback, None
+ raise ex from None
diff --git a/tests/providers/slack/transfers/test_sql_to_slack.py
b/tests/providers/slack/transfers/test_sql_to_slack.py
index 307469460b..23efa895a2 100644
--- a/tests/providers/slack/transfers/test_sql_to_slack.py
+++ b/tests/providers/slack/transfers/test_sql_to_slack.py
@@ -23,7 +23,11 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.models import DAG, Connection
-from airflow.providers.slack.transfers.sql_to_slack import SqlToSlackOperator
+from airflow.providers.slack.transfers.sql_to_slack import (
+ BaseSqlToSlackOperator,
+ SqlToSlackApiFileOperator,
+ SqlToSlackOperator,
+)
from airflow.utils import timezone
TEST_DAG_ID = "sql_to_slack_unit_test"
@@ -31,6 +35,77 @@ TEST_TASK_ID = "sql_to_slack_unit_test_task"
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
+class TestBaseSqlToSlackOperator:
+ def setup_method(self):
+ self.default_op_kwargs = {
+ "sql": "SELECT 1",
+ "sql_conn_id": "test-sql-conn-id",
+ "sql_hook_params": None,
+ "parameters": None,
+ }
+
+ def test_execute_not_implemented(self):
+ """Test that no base implementation for
``BaseSqlToSlackOperator.execute()``."""
+ op = BaseSqlToSlackOperator(task_id="test_base_not_implements",
**self.default_op_kwargs)
+ with pytest.raises(NotImplementedError):
+ op.execute(mock.MagicMock())
+
+
@mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
+ @mock.patch("airflow.models.connection.Connection.get_hook")
+ @pytest.mark.parametrize("conn_type", ["postgres", "snowflake"])
+ @pytest.mark.parametrize("sql_hook_params", [None, {"foo": "bar"}])
+ def test_get_hook(self, mock_get_hook, mock_get_conn, conn_type,
sql_hook_params):
+ class SomeDummyHook:
+ """Hook which implements ``get_pandas_df`` method"""
+
+ def get_pandas_df(self):
+ pass
+
+ expected_hook = SomeDummyHook()
+ mock_get_conn.return_value =
Connection(conn_id=f"test_connection_{conn_type}", conn_type=conn_type)
+ mock_get_hook.return_value = expected_hook
+ op_kwargs = {
+ **self.default_op_kwargs,
+ "sql_hook_params": sql_hook_params,
+ }
+ op = BaseSqlToSlackOperator(task_id="test_get_hook", **op_kwargs)
+ hook = op._get_hook()
+ mock_get_hook.assert_called_once_with(hook_params=sql_hook_params)
+ assert hook == expected_hook
+
+
@mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
+ @mock.patch("airflow.models.connection.Connection.get_hook")
+ def test_get_not_supported_hook(self, mock_get_hook, mock_get_conn):
+ class SomeDummyHook:
+ """Hook which not implemented ``get_pandas_df`` method"""
+
+ mock_get_conn.return_value = Connection(conn_id="test_connection",
conn_type="test_connection")
+ mock_get_hook.return_value = SomeDummyHook()
+ op = BaseSqlToSlackOperator(task_id="test_get_not_supported_hook",
**self.default_op_kwargs)
+ error_message = r"This hook is not supported. The hook class must have
get_pandas_df method\."
+ with pytest.raises(AirflowException, match=error_message):
+ op._get_hook()
+
+
@mock.patch("airflow.providers.slack.transfers.sql_to_slack.BaseSqlToSlackOperator._get_hook")
+ @pytest.mark.parametrize("sql", ["SELECT 42", "SELECT 1 FROM DUMMY WHERE
col = ?"])
+ @pytest.mark.parametrize("parameters", [None, {"col": "spam-egg"}])
+ def test_get_query_results(self, mock_op_get_hook, sql, parameters):
+ test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1])
+ mock_get_pandas_df = mock.MagicMock(return_value=test_df)
+ mock_hook = mock.MagicMock()
+ mock_hook.get_pandas_df = mock_get_pandas_df
+ mock_op_get_hook.return_value = mock_hook
+ op_kwargs = {
+ **self.default_op_kwargs,
+ "sql": sql,
+ "parameters": parameters,
+ }
+ op = BaseSqlToSlackOperator(task_id="test_get_query_results",
**op_kwargs)
+ df = op._get_query_results()
+ mock_get_pandas_df.assert_called_once_with(sql, parameters=parameters)
+ assert df is test_df
+
+
class TestSqlToSlackOperator:
def setup_method(self):
self.example_dag = DAG(TEST_DAG_ID, start_date=DEFAULT_DATE)
@@ -215,3 +290,90 @@ class TestSqlToSlackOperator:
assert hook.database == "database"
assert hook.role == "role"
assert hook.schema == "schema"
+
+
+class TestSqlToSlackApiFileOperator:
+ def setup_method(self):
+ self.default_op_kwargs = {
+ "sql": "SELECT 1",
+ "sql_conn_id": "test-sql-conn-id",
+ "slack_conn_id": "test-slack-conn-id",
+ "sql_hook_params": None,
+ "parameters": None,
+ }
+
+
@mock.patch("airflow.providers.slack.transfers.sql_to_slack.BaseSqlToSlackOperator._get_query_results")
+ @mock.patch("airflow.providers.slack.transfers.sql_to_slack.SlackHook")
+ @pytest.mark.parametrize(
+ "filename,df_method",
+ [
+ ("awesome.json", "to_json"),
+ ("awesome.json.zip", "to_json"),
+ ("awesome.csv", "to_csv"),
+ ("awesome.csv.xz", "to_csv"),
+ ("awesome.html", "to_html"),
+ ],
+ )
+ @pytest.mark.parametrize("df_kwargs", [None, {}, {"foo": "bar"}])
+ @pytest.mark.parametrize("channels", ["#random", "#random,#general", None])
+ @pytest.mark.parametrize("initial_comment", [None, "Test Comment"])
+ @pytest.mark.parametrize("title", [None, "Test File Title"])
+ def test_send_file(
+ self,
+ mock_slack_hook_cls,
+ mock_get_query_results,
+ filename,
+ df_method,
+ df_kwargs,
+ channels,
+ initial_comment,
+ title,
+ ):
+ # Mock Hook
+ mock_send_file = mock.MagicMock()
+ mock_slack_hook_cls.return_value.send_file = mock_send_file
+
+ # Mock returns pandas.DataFrame and expected method
+ mock_df = mock.MagicMock()
+ mock_df_output_method = mock.MagicMock()
+ setattr(mock_df, df_method, mock_df_output_method)
+ mock_get_query_results.return_value = mock_df
+
+ op_kwargs = {
+ **self.default_op_kwargs,
+ "slack_conn_id": "expected-test-slack-conn-id",
+ "slack_filename": filename,
+ "slack_channels": channels,
+ "slack_initial_comment": initial_comment,
+ "slack_title": title,
+ "df_kwargs": df_kwargs,
+ }
+ op = SqlToSlackApiFileOperator(task_id="test_send_file", **op_kwargs)
+ op.execute(mock.MagicMock())
+
+
mock_slack_hook_cls.assert_called_once_with(slack_conn_id="expected-test-slack-conn-id")
+ mock_get_query_results.assert_called_once_with()
+ mock_df_output_method.assert_called_once_with(mock.ANY, **(df_kwargs
or {}))
+ mock_send_file.assert_called_once_with(
+ channels=channels,
+ filename=filename,
+ initial_comment=initial_comment,
+ title=title,
+ file=mock.ANY,
+ )
+
+ @pytest.mark.parametrize(
+ "filename",
+ [
+ "foo.parquet",
+ "bat.parquet.snappy",
+ "spam.xml",
+ "egg.xlsx",
+ ],
+ )
+ def test_unsupported_format(self, filename):
+ op = SqlToSlackApiFileOperator(
+ task_id="test_send_file", slack_filename=filename,
**self.default_op_kwargs
+ )
+ with pytest.raises(ValueError):
+ op.execute(mock.MagicMock())
diff --git a/tests/providers/slack/utils/test_utils.py
b/tests/providers/slack/utils/test_utils.py
index d794c80f60..bff3dbc658 100644
--- a/tests/providers/slack/utils/test_utils.py
+++ b/tests/providers/slack/utils/test_utils.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import pytest
-from airflow.providers.slack.utils import ConnectionExtraConfig
+from airflow.providers.slack.utils import ConnectionExtraConfig, parse_filename
class TestConnectionExtra:
@@ -92,3 +92,55 @@ class TestConnectionExtra:
)
assert extra_config.getint("int_arg_1") == 42
assert extra_config.getint("int_arg_2") == 9000
+
+
+class TestParseFilename:
+ SUPPORTED_FORMAT = ("so", "dll", "exe", "sh")
+
+ def test_error_parse_without_extension(self):
+ with pytest.raises(ValueError, match="No file extension specified in
filename"):
+ assert parse_filename("Untitled File", self.SUPPORTED_FORMAT)
+
+ @pytest.mark.parametrize(
+ "filename,expected_format",
+ [
+ ("libc.so", "so"),
+ ("kernel32.dll", "dll"),
+ ("xxx.mp4.exe", "exe"),
+ ("init.sh", "sh"),
+ ],
+ )
+ def test_parse_first_level(self, filename, expected_format):
+ assert parse_filename(filename, self.SUPPORTED_FORMAT) ==
(expected_format, None)
+
+ @pytest.mark.parametrize("filename", ["New File.txt", "cats-memes.mp4"])
+ def test_error_parse_first_level(self, filename):
+ with pytest.raises(ValueError, match="Unsupported file format"):
+ assert parse_filename(filename, self.SUPPORTED_FORMAT)
+
+ @pytest.mark.parametrize(
+ "filename,expected",
+ [
+ ("libc.so.6", ("so", "6")),
+ ("kernel32.dll.zip", ("dll", "zip")),
+ ("explorer.exe.7z", ("exe", "7z")),
+ ("init.sh.gz", ("sh", "gz")),
+ ],
+ )
+ def test_parse_second_level(self, filename, expected):
+ assert parse_filename(filename, self.SUPPORTED_FORMAT) == expected
+
+ @pytest.mark.parametrize("filename", ["example.so.tar.gz", "w.i.e.r.d"])
+ def test_error_parse_second_level(self, filename):
+ with pytest.raises(ValueError, match="Unsupported file format.*with
compression extension."):
+ assert parse_filename(filename, self.SUPPORTED_FORMAT)
+
+ @pytest.mark.parametrize("filename", ["Untitled File", "New File.txt",
"example.so.tar.gz"])
+ @pytest.mark.parametrize("fallback", SUPPORTED_FORMAT)
+ def test_fallback(self, filename, fallback):
+ assert parse_filename(filename, self.SUPPORTED_FORMAT, fallback) ==
(fallback, None)
+
+ @pytest.mark.parametrize("filename", ["Untitled File", "New File.txt",
"example.so.tar.gz"])
+ def test_wrong_fallback(self, filename):
+ with pytest.raises(ValueError, match="Invalid fallback value"):
+ assert parse_filename(filename, self.SUPPORTED_FORMAT, "mp4")