This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c609477260 Implements SqlToSlackApiFileOperator (#26374)
c609477260 is described below

commit c60947726082905f7b369a90e8d37fcdb873149f
Author: Andrey Anshin <[email protected]>
AuthorDate: Thu Nov 17 00:14:27 2022 +0400

    Implements SqlToSlackApiFileOperator (#26374)
---
 airflow/providers/slack/transfers/sql_to_slack.py  | 186 ++++++++++++++++++---
 airflow/providers/slack/utils/__init__.py          |  40 ++++-
 .../providers/slack/transfers/test_sql_to_slack.py | 164 +++++++++++++++++-
 tests/providers/slack/utils/test_utils.py          |  54 +++++-
 4 files changed, 418 insertions(+), 26 deletions(-)

diff --git a/airflow/providers/slack/transfers/sql_to_slack.py 
b/airflow/providers/slack/transfers/sql_to_slack.py
index bdd1cddd2b..cf5c01b22c 100644
--- a/airflow/providers/slack/transfers/sql_to_slack.py
+++ b/airflow/providers/slack/transfers/sql_to_slack.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
 
 from pandas import DataFrame
@@ -25,13 +26,59 @@ from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.models import BaseOperator
 from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.slack.hooks.slack import SlackHook
 from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook
+from airflow.providers.slack.utils import parse_filename
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class SqlToSlackOperator(BaseOperator):
+class BaseSqlToSlackOperator(BaseOperator):
+    """
+    Operator implements base sql methods for SQL to Slack Transfer operators.
+
+    :param sql: The SQL query to be executed
+    :param sql_conn_id: reference to a specific DB-API Connection.
+    :param sql_hook_params: Extra config params to be passed to the underlying 
hook.
+        Should match the desired hook constructor params.
+    :param parameters: The parameters to pass to the SQL query.
+    """
+
+    def __init__(
+        self,
+        *,
+        sql: str,
+        sql_conn_id: str,
+        sql_hook_params: dict | None = None,
+        parameters: Iterable | Mapping | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.sql_conn_id = sql_conn_id
+        self.sql_hook_params = sql_hook_params
+        self.sql = sql
+        self.parameters = parameters
+
+    def _get_hook(self) -> DbApiHook:
+        self.log.debug("Get connection for %s", self.sql_conn_id)
+        conn = BaseHook.get_connection(self.sql_conn_id)
+        hook = conn.get_hook(hook_params=self.sql_hook_params)
+        if not callable(getattr(hook, "get_pandas_df", None)):
+            raise AirflowException(
+                "This hook is not supported. The hook class must have 
get_pandas_df method."
+            )
+        return hook
+
+    def _get_query_results(self) -> DataFrame:
+        sql_hook = self._get_hook()
+
+        self.log.info("Running SQL query: %s", self.sql)
+        df = sql_hook.get_pandas_df(self.sql, parameters=self.parameters)
+        return df
+
+
+class SqlToSlackOperator(BaseSqlToSlackOperator):
     """
     Executes an SQL statement in a given SQL connection and sends the results 
to Slack. The results of the
     query are rendered into the 'slack_message' parameter as a Pandas 
dataframe using a JINJA variable called
@@ -79,12 +126,10 @@ class SqlToSlackOperator(BaseOperator):
         **kwargs,
     ) -> None:
 
-        super().__init__(**kwargs)
+        super().__init__(
+            sql=sql, sql_conn_id=sql_conn_id, sql_hook_params=sql_hook_params, 
parameters=parameters, **kwargs
+        )
 
-        self.sql_conn_id = sql_conn_id
-        self.sql_hook_params = sql_hook_params
-        self.sql = sql
-        self.parameters = parameters
         self.slack_conn_id = slack_conn_id
         self.slack_webhook_token = slack_webhook_token
         self.slack_channel = slack_channel
@@ -97,23 +142,6 @@ class SqlToSlackOperator(BaseOperator):
                 "SqlToSlackOperator requires either a `slack_conn_id` or a 
`slack_webhook_token` argument"
             )
 
-    def _get_hook(self) -> DbApiHook:
-        self.log.debug("Get connection for %s", self.sql_conn_id)
-        conn = BaseHook.get_connection(self.sql_conn_id)
-        hook = conn.get_hook(hook_params=self.sql_hook_params)
-        if not callable(getattr(hook, "get_pandas_df", None)):
-            raise AirflowException(
-                "This hook is not supported. The hook class must have 
get_pandas_df method."
-            )
-        return hook
-
-    def _get_query_results(self) -> DataFrame:
-        sql_hook = self._get_hook()
-
-        self.log.info("Running SQL query: %s", self.sql)
-        df = sql_hook.get_pandas_df(self.sql, parameters=self.parameters)
-        return df
-
     def _render_and_send_slack_message(self, context, df) -> None:
         # Put the dataframe into the context and render the JINJA template 
fields
         context[self.results_df_name] = df
@@ -157,3 +185,115 @@ class SqlToSlackOperator(BaseOperator):
         self._render_and_send_slack_message(context, df)
 
         self.log.debug("Finished sending SQL data to Slack")
+
+
+class SqlToSlackApiFileOperator(BaseSqlToSlackOperator):
+    """
+    Executes an SQL statement in a given SQL connection and sends the results 
to Slack API as file.
+
+    :param sql: The SQL query to be executed
+    :param sql_conn_id: reference to a specific DB-API Connection.
+    :param slack_conn_id: :ref:`Slack API Connection <howto/connection:slack>`.
+    :param slack_filename: Filename for display in slack.
+        Should contain a supported file extension, as listed in ``SUPPORTED_FILE_FORMATS``.
+        It is also possible to set compression in extension:
+        ``filename.csv.gzip``, ``filename.json.zip``, etc.
+    :param sql_hook_params: Extra config params to be passed to the underlying 
hook.
+        Should match the desired hook constructor params.
+    :param parameters: The parameters to pass to the SQL query.
+    :param slack_channels: Comma-separated list of channel names or IDs where 
the file will be shared.
+         If this parameter is omitted, the file will be sent to the workspace.
+    :param slack_initial_comment: The message text introducing the file in 
specified ``slack_channels``.
+    :param slack_title: Title of file.
+    :param df_kwargs: Keyword arguments forwarded to 
``pandas.DataFrame.to_{format}()`` method.
+
+    Example:
+     .. code-block:: python
+
+        SqlToSlackApiFileOperator(
+            task_id="sql_to_slack",
+            sql="SELECT 1 a, 2 b, 3 c",
+            sql_conn_id="sql-connection",
+            slack_conn_id="slack-api-connection",
+            slack_filename="awesome.json.gz",
+            slack_channels="#random,#general",
+            slack_initial_comment="Awesome load to compressed multiline JSON.",
+            df_kwargs={
+                "orient": "records",
+                "lines": True,
+            },
+        )
+    """
+
+    template_fields: Sequence[str] = (
+        "sql",
+        "slack_channels",
+        "slack_filename",
+        "slack_initial_comment",
+        "slack_title",
+    )
+    template_ext: Sequence[str] = (".sql", ".jinja", ".j2")
+    template_fields_renderers = {"sql": "sql", "slack_message": "jinja"}
+
+    SUPPORTED_FILE_FORMATS: Sequence[str] = ("csv", "json", "html")
+
+    def __init__(
+        self,
+        *,
+        sql: str,
+        sql_conn_id: str,
+        sql_hook_params: dict | None = None,
+        parameters: Iterable | Mapping | None = None,
+        slack_conn_id: str,
+        slack_filename: str,
+        slack_channels: str | Sequence[str] | None = None,
+        slack_initial_comment: str | None = None,
+        slack_title: str | None = None,
+        df_kwargs: dict | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            sql=sql, sql_conn_id=sql_conn_id, sql_hook_params=sql_hook_params, 
parameters=parameters, **kwargs
+        )
+        self.slack_conn_id = slack_conn_id
+        self.slack_filename = slack_filename
+        self.slack_channels = slack_channels
+        self.slack_initial_comment = slack_initial_comment
+        self.slack_title = slack_title
+        self.df_kwargs = df_kwargs or {}
+
+    def execute(self, context: Context) -> None:
+        # Parse file format from filename
+        output_file_format, _ = parse_filename(
+            filename=self.slack_filename,
+            supported_file_formats=self.SUPPORTED_FILE_FORMATS,
+        )
+
+        slack_hook = SlackHook(slack_conn_id=self.slack_conn_id)
+        with NamedTemporaryFile(mode="w+", suffix=f"_{self.slack_filename}") 
as fp:
+            # tempfile.NamedTemporaryFile is used only to create and remove the
+            # temporary file; pandas will open the file in the correct mode itself,
+            # depending on the file type.
+            # So we close the file descriptor here to avoid accidentally writing
+            # anything.
+            fp.close()
+
+            output_file_name = fp.name
+            output_file_format = output_file_format.upper()
+            df_result = self._get_query_results()
+            if output_file_format == "CSV":
+                df_result.to_csv(output_file_name, **self.df_kwargs)
+            elif output_file_format == "JSON":
+                df_result.to_json(output_file_name, **self.df_kwargs)
+            elif output_file_format == "HTML":
+                df_result.to_html(output_file_name, **self.df_kwargs)
+            else:
+                # Not expected that this error happen. This only possible
+                # if SUPPORTED_FILE_FORMATS extended and no actual 
implementation for specific format.
+                raise AirflowException(f"Unexpected output file format: 
{output_file_format}")
+
+            slack_hook.send_file(
+                channels=self.slack_channels,
+                file=output_file_name,
+                filename=self.slack_filename,
+                initial_comment=self.slack_initial_comment,
+                title=self.slack_title,
+            )
diff --git a/airflow/providers/slack/utils/__init__.py 
b/airflow/providers/slack/utils/__init__.py
index dda12656d4..1071de6299 100644
--- a/airflow/providers/slack/utils/__init__.py
+++ b/airflow/providers/slack/utils/__init__.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any
+from typing import Any, Sequence
 
 from airflow.utils.types import NOTSET
 
@@ -77,3 +77,41 @@ class ConnectionExtraConfig:
         if value != default:
             value = int(value)
         return value
+
+
+def parse_filename(
+    filename: str, supported_file_formats: Sequence[str], fallback: str | None 
= None
+) -> tuple[str, str | None]:
+    """
+    Parse filetype and compression from given filename.
+
+    :param filename: filename to parse.
+    :param supported_file_formats: list of supported file extensions.
+    :param fallback: fallback to given file format.
+    :returns: filetype and compression (if specified)
+    """
+    if not filename:
+        raise ValueError("Expected 'filename' parameter is missing.")
+    if fallback and fallback not in supported_file_formats:
+        raise ValueError(f"Invalid fallback value {fallback!r}, expected one 
of {supported_file_formats}.")
+
+    parts = filename.rsplit(".", 2)
+    try:
+        if len(parts) == 1:
+            raise ValueError(f"No file extension specified in filename 
{filename!r}.")
+        if parts[-1] in supported_file_formats:
+            return parts[-1], None
+        elif len(parts) == 2:
+            raise ValueError(
+                f"Unsupported file format {parts[-1]!r}, expected one of 
{supported_file_formats}."
+            )
+        else:
+            if parts[-2] not in supported_file_formats:
+                raise ValueError(
+                    f"Unsupported file format '{parts[-2]}.{parts[-1]}', "
+                    f"expected one of {supported_file_formats} with 
compression extension."
+                )
+            return parts[-2], parts[-1]
+    except ValueError as ex:
+        if fallback:
+            return fallback, None
+        raise ex from None
diff --git a/tests/providers/slack/transfers/test_sql_to_slack.py 
b/tests/providers/slack/transfers/test_sql_to_slack.py
index 307469460b..23efa895a2 100644
--- a/tests/providers/slack/transfers/test_sql_to_slack.py
+++ b/tests/providers/slack/transfers/test_sql_to_slack.py
@@ -23,7 +23,11 @@ import pytest
 
 from airflow.exceptions import AirflowException
 from airflow.models import DAG, Connection
-from airflow.providers.slack.transfers.sql_to_slack import SqlToSlackOperator
+from airflow.providers.slack.transfers.sql_to_slack import (
+    BaseSqlToSlackOperator,
+    SqlToSlackApiFileOperator,
+    SqlToSlackOperator,
+)
 from airflow.utils import timezone
 
 TEST_DAG_ID = "sql_to_slack_unit_test"
@@ -31,6 +35,77 @@ TEST_TASK_ID = "sql_to_slack_unit_test_task"
 DEFAULT_DATE = timezone.datetime(2017, 1, 1)
 
 
+class TestBaseSqlToSlackOperator:
+    def setup_method(self):
+        self.default_op_kwargs = {
+            "sql": "SELECT 1",
+            "sql_conn_id": "test-sql-conn-id",
+            "sql_hook_params": None,
+            "parameters": None,
+        }
+
+    def test_execute_not_implemented(self):
+        """Test that no base implementation for 
``BaseSqlToSlackOperator.execute()``."""
+        op = BaseSqlToSlackOperator(task_id="test_base_not_implements", 
**self.default_op_kwargs)
+        with pytest.raises(NotImplementedError):
+            op.execute(mock.MagicMock())
+
+    
@mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
+    @mock.patch("airflow.models.connection.Connection.get_hook")
+    @pytest.mark.parametrize("conn_type", ["postgres", "snowflake"])
+    @pytest.mark.parametrize("sql_hook_params", [None, {"foo": "bar"}])
+    def test_get_hook(self, mock_get_hook, mock_get_conn, conn_type, 
sql_hook_params):
+        class SomeDummyHook:
+            """Hook which implements ``get_pandas_df`` method"""
+
+            def get_pandas_df(self):
+                pass
+
+        expected_hook = SomeDummyHook()
+        mock_get_conn.return_value = 
Connection(conn_id=f"test_connection_{conn_type}", conn_type=conn_type)
+        mock_get_hook.return_value = expected_hook
+        op_kwargs = {
+            **self.default_op_kwargs,
+            "sql_hook_params": sql_hook_params,
+        }
+        op = BaseSqlToSlackOperator(task_id="test_get_hook", **op_kwargs)
+        hook = op._get_hook()
+        mock_get_hook.assert_called_once_with(hook_params=sql_hook_params)
+        assert hook == expected_hook
+
+    
@mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
+    @mock.patch("airflow.models.connection.Connection.get_hook")
+    def test_get_not_supported_hook(self, mock_get_hook, mock_get_conn):
+        class SomeDummyHook:
+            """Hook which not implemented ``get_pandas_df`` method"""
+
+        mock_get_conn.return_value = Connection(conn_id="test_connection", 
conn_type="test_connection")
+        mock_get_hook.return_value = SomeDummyHook()
+        op = BaseSqlToSlackOperator(task_id="test_get_not_supported_hook", 
**self.default_op_kwargs)
+        error_message = r"This hook is not supported. The hook class must have 
get_pandas_df method\."
+        with pytest.raises(AirflowException, match=error_message):
+            op._get_hook()
+
+    
@mock.patch("airflow.providers.slack.transfers.sql_to_slack.BaseSqlToSlackOperator._get_hook")
+    @pytest.mark.parametrize("sql", ["SELECT 42", "SELECT 1 FROM DUMMY WHERE 
col = ?"])
+    @pytest.mark.parametrize("parameters", [None, {"col": "spam-egg"}])
+    def test_get_query_results(self, mock_op_get_hook, sql, parameters):
+        test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1])
+        mock_get_pandas_df = mock.MagicMock(return_value=test_df)
+        mock_hook = mock.MagicMock()
+        mock_hook.get_pandas_df = mock_get_pandas_df
+        mock_op_get_hook.return_value = mock_hook
+        op_kwargs = {
+            **self.default_op_kwargs,
+            "sql": sql,
+            "parameters": parameters,
+        }
+        op = BaseSqlToSlackOperator(task_id="test_get_query_results", 
**op_kwargs)
+        df = op._get_query_results()
+        mock_get_pandas_df.assert_called_once_with(sql, parameters=parameters)
+        assert df is test_df
+
+
 class TestSqlToSlackOperator:
     def setup_method(self):
         self.example_dag = DAG(TEST_DAG_ID, start_date=DEFAULT_DATE)
@@ -215,3 +290,90 @@ class TestSqlToSlackOperator:
         assert hook.database == "database"
         assert hook.role == "role"
         assert hook.schema == "schema"
+
+
+class TestSqlToSlackApiFileOperator:
+    def setup_method(self):
+        self.default_op_kwargs = {
+            "sql": "SELECT 1",
+            "sql_conn_id": "test-sql-conn-id",
+            "slack_conn_id": "test-slack-conn-id",
+            "sql_hook_params": None,
+            "parameters": None,
+        }
+
+    
@mock.patch("airflow.providers.slack.transfers.sql_to_slack.BaseSqlToSlackOperator._get_query_results")
+    @mock.patch("airflow.providers.slack.transfers.sql_to_slack.SlackHook")
+    @pytest.mark.parametrize(
+        "filename,df_method",
+        [
+            ("awesome.json", "to_json"),
+            ("awesome.json.zip", "to_json"),
+            ("awesome.csv", "to_csv"),
+            ("awesome.csv.xz", "to_csv"),
+            ("awesome.html", "to_html"),
+        ],
+    )
+    @pytest.mark.parametrize("df_kwargs", [None, {}, {"foo": "bar"}])
+    @pytest.mark.parametrize("channels", ["#random", "#random,#general", None])
+    @pytest.mark.parametrize("initial_comment", [None, "Test Comment"])
+    @pytest.mark.parametrize("title", [None, "Test File Title"])
+    def test_send_file(
+        self,
+        mock_slack_hook_cls,
+        mock_get_query_results,
+        filename,
+        df_method,
+        df_kwargs,
+        channels,
+        initial_comment,
+        title,
+    ):
+        # Mock Hook
+        mock_send_file = mock.MagicMock()
+        mock_slack_hook_cls.return_value.send_file = mock_send_file
+
+        # Mock returns pandas.DataFrame and expected method
+        mock_df = mock.MagicMock()
+        mock_df_output_method = mock.MagicMock()
+        setattr(mock_df, df_method, mock_df_output_method)
+        mock_get_query_results.return_value = mock_df
+
+        op_kwargs = {
+            **self.default_op_kwargs,
+            "slack_conn_id": "expected-test-slack-conn-id",
+            "slack_filename": filename,
+            "slack_channels": channels,
+            "slack_initial_comment": initial_comment,
+            "slack_title": title,
+            "df_kwargs": df_kwargs,
+        }
+        op = SqlToSlackApiFileOperator(task_id="test_send_file", **op_kwargs)
+        op.execute(mock.MagicMock())
+
+        
mock_slack_hook_cls.assert_called_once_with(slack_conn_id="expected-test-slack-conn-id")
+        mock_get_query_results.assert_called_once_with()
+        mock_df_output_method.assert_called_once_with(mock.ANY, **(df_kwargs 
or {}))
+        mock_send_file.assert_called_once_with(
+            channels=channels,
+            filename=filename,
+            initial_comment=initial_comment,
+            title=title,
+            file=mock.ANY,
+        )
+
+    @pytest.mark.parametrize(
+        "filename",
+        [
+            "foo.parquet",
+            "bat.parquet.snappy",
+            "spam.xml",
+            "egg.xlsx",
+        ],
+    )
+    def test_unsupported_format(self, filename):
+        op = SqlToSlackApiFileOperator(
+            task_id="test_send_file", slack_filename=filename, 
**self.default_op_kwargs
+        )
+        with pytest.raises(ValueError):
+            op.execute(mock.MagicMock())
diff --git a/tests/providers/slack/utils/test_utils.py 
b/tests/providers/slack/utils/test_utils.py
index d794c80f60..bff3dbc658 100644
--- a/tests/providers/slack/utils/test_utils.py
+++ b/tests/providers/slack/utils/test_utils.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import pytest
 
-from airflow.providers.slack.utils import ConnectionExtraConfig
+from airflow.providers.slack.utils import ConnectionExtraConfig, parse_filename
 
 
 class TestConnectionExtra:
@@ -92,3 +92,55 @@ class TestConnectionExtra:
         )
         assert extra_config.getint("int_arg_1") == 42
         assert extra_config.getint("int_arg_2") == 9000
+
+
+class TestParseFilename:
+    SUPPORTED_FORMAT = ("so", "dll", "exe", "sh")
+
+    def test_error_parse_without_extension(self):
+        with pytest.raises(ValueError, match="No file extension specified in 
filename"):
+            assert parse_filename("Untitled File", self.SUPPORTED_FORMAT)
+
+    @pytest.mark.parametrize(
+        "filename,expected_format",
+        [
+            ("libc.so", "so"),
+            ("kernel32.dll", "dll"),
+            ("xxx.mp4.exe", "exe"),
+            ("init.sh", "sh"),
+        ],
+    )
+    def test_parse_first_level(self, filename, expected_format):
+        assert parse_filename(filename, self.SUPPORTED_FORMAT) == 
(expected_format, None)
+
+    @pytest.mark.parametrize("filename", ["New File.txt", "cats-memes.mp4"])
+    def test_error_parse_first_level(self, filename):
+        with pytest.raises(ValueError, match="Unsupported file format"):
+            assert parse_filename(filename, self.SUPPORTED_FORMAT)
+
+    @pytest.mark.parametrize(
+        "filename,expected",
+        [
+            ("libc.so.6", ("so", "6")),
+            ("kernel32.dll.zip", ("dll", "zip")),
+            ("explorer.exe.7z", ("exe", "7z")),
+            ("init.sh.gz", ("sh", "gz")),
+        ],
+    )
+    def test_parse_second_level(self, filename, expected):
+        assert parse_filename(filename, self.SUPPORTED_FORMAT) == expected
+
+    @pytest.mark.parametrize("filename", ["example.so.tar.gz", "w.i.e.r.d"])
+    def test_error_parse_second_level(self, filename):
+        with pytest.raises(ValueError, match="Unsupported file format.*with 
compression extension."):
+            assert parse_filename(filename, self.SUPPORTED_FORMAT)
+
+    @pytest.mark.parametrize("filename", ["Untitled File", "New File.txt", 
"example.so.tar.gz"])
+    @pytest.mark.parametrize("fallback", SUPPORTED_FORMAT)
+    def test_fallback(self, filename, fallback):
+        assert parse_filename(filename, self.SUPPORTED_FORMAT, fallback) == 
(fallback, None)
+
+    @pytest.mark.parametrize("filename", ["Untitled File", "New File.txt", 
"example.so.tar.gz"])
+    def test_wrong_fallback(self, filename):
+        with pytest.raises(ValueError, match="Invalid fallback value"):
+            assert parse_filename(filename, self.SUPPORTED_FORMAT, "mp4")

Reply via email to