This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bd05901c815 feat: add async discord notifier (#56911)
bd05901c815 is described below

commit bd05901c8158667cb015befffadfc3206dc5a4c4
Author: Sebastian Daum <[email protected]>
AuthorDate: Thu Nov 20 14:05:13 2025 +0100

    feat: add async discord notifier (#56911)
---
 providers/discord/pyproject.toml                   |   2 +-
 .../providers/discord/hooks/discord_webhook.py     | 196 ++++++++++++++++-----
 .../providers/discord/notifications/discord.py     |  37 +++-
 .../providers/discord/operators/discord_webhook.py |   7 +-
 .../airflow/providers/discord/version_compat.py    |  42 +++++
 .../unit/discord/hooks/test_discord_webhook.py     | 167 ++++++++++++++----
 .../unit/discord/notifications/test_discord.py     |  19 +-
 7 files changed, 387 insertions(+), 83 deletions(-)

diff --git a/providers/discord/pyproject.toml b/providers/discord/pyproject.toml
index af9465493d2..c05c6f25315 100644
--- a/providers/discord/pyproject.toml
+++ b/providers/discord/pyproject.toml
@@ -59,7 +59,7 @@ requires-python = ">=3.10"
 # After you modify the dependencies, and rebuild your Breeze CI image with 
``breeze ci-image build``
 dependencies = [
     "apache-airflow>=2.10.0",
-    "apache-airflow-providers-common-compat>=1.8.0",
+    "apache-airflow-providers-common-compat>=1.8.0", # use next version
     "apache-airflow-providers-http",
 ]
 
diff --git 
a/providers/discord/src/airflow/providers/discord/hooks/discord_webhook.py 
b/providers/discord/src/airflow/providers/discord/hooks/discord_webhook.py
index 01cc4c616c4..5d547375ed1 100644
--- a/providers/discord/src/airflow/providers/discord/hooks/discord_webhook.py
+++ b/providers/discord/src/airflow/providers/discord/hooks/discord_webhook.py
@@ -19,10 +19,70 @@ from __future__ import annotations
 
 import json
 import re
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
-from airflow.exceptions import AirflowException
-from airflow.providers.http.hooks.http import HttpHook
+import aiohttp
+
+from airflow.providers.common.compat.connection import get_async_connection
+from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook
+
+if TYPE_CHECKING:
+    from airflow.providers.common.compat.sdk import Connection
+
+
+class DiscordCommonHandler:
+    """Contains the common functionality."""
+
+    def get_webhook_endpoint(self, conn: Connection | None, webhook_endpoint: 
str | None) -> str:
+        """
+        Return the default webhook endpoint or override if a webhook_endpoint 
is manually supplied.
+
+        :param conn: Airflow Discord connection
+        :param webhook_endpoint: The manually provided webhook endpoint
+        :return: Webhook endpoint (str) to use
+        """
+        if webhook_endpoint:
+            endpoint = webhook_endpoint
+        elif conn:
+            extra = conn.extra_dejson
+            endpoint = extra.get("webhook_endpoint", "")
+        else:
+            raise ValueError(
+                "Cannot get webhook endpoint: No valid Discord webhook 
endpoint or http_conn_id supplied."
+            )
+
+        # make sure endpoint matches the expected Discord webhook format
+        if not re.fullmatch("webhooks/[0-9]+/[a-zA-Z0-9_-]+", endpoint):
+            raise ValueError(
+                'Expected Discord webhook endpoint in the form of 
"webhooks/{webhook.id}/{webhook.token}".'
+            )
+
+        return endpoint
+
+    def build_discord_payload(
+        self, *, tts: bool, message: str, username: str | None, avatar_url: 
str | None
+    ) -> str:
+        """
+        Build a valid Discord JSON payload.
+
+        :param tts: Is a text-to-speech message
+        :param message: The message you want to send to your Discord channel
+                        (max 2000 characters)
+        :param username: Override the default username of the webhook
+        :param avatar_url: Override the default avatar of the webhook
+        :return: Discord payload (str) to send
+        """
+        if len(message) > 2000:
+            raise ValueError("Discord message length must be 2000 or fewer 
characters.")
+        payload: dict[str, Any] = {
+            "content": message,
+            "tts": tts,
+        }
+        if username:
+            payload["username"] = username
+        if avatar_url:
+            payload["avatar_url"] = avatar_url
+        return json.dumps(payload)
 
 
 class DiscordWebhookHook(HttpHook):
@@ -84,6 +144,7 @@ class DiscordWebhookHook(HttpHook):
         **kwargs: Any,
     ) -> None:
         super().__init__(*args, **kwargs)
+        self.handler = DiscordCommonHandler()
         self.http_conn_id: Any = http_conn_id
         self.webhook_endpoint = self._get_webhook_endpoint(http_conn_id, 
webhook_endpoint)
         self.message = message
@@ -100,46 +161,10 @@ class DiscordWebhookHook(HttpHook):
         :param webhook_endpoint: The manually provided webhook endpoint
         :return: Webhook endpoint (str) to use
         """
-        if webhook_endpoint:
-            endpoint = webhook_endpoint
-        elif http_conn_id:
+        conn = None
+        if not webhook_endpoint and http_conn_id:
             conn = self.get_connection(http_conn_id)
-            extra = conn.extra_dejson
-            endpoint = extra.get("webhook_endpoint", "")
-        else:
-            raise AirflowException(
-                "Cannot get webhook endpoint: No valid Discord webhook 
endpoint or http_conn_id supplied."
-            )
-
-        # make sure endpoint matches the expected Discord webhook format
-        if not re.fullmatch("webhooks/[0-9]+/[a-zA-Z0-9_-]+", endpoint):
-            raise AirflowException(
-                'Expected Discord webhook endpoint in the form of 
"webhooks/{webhook.id}/{webhook.token}".'
-            )
-
-        return endpoint
-
-    def _build_discord_payload(self) -> str:
-        """
-        Combine all relevant parameters into a valid Discord JSON payload.
-
-        :return: Discord payload (str) to send
-        """
-        payload: dict[str, Any] = {}
-
-        if self.username:
-            payload["username"] = self.username
-        if self.avatar_url:
-            payload["avatar_url"] = self.avatar_url
-
-        payload["tts"] = self.tts
-
-        if len(self.message) <= 2000:
-            payload["content"] = self.message
-        else:
-            raise AirflowException("Discord message length must be 2000 or 
fewer characters.")
-
-        return json.dumps(payload)
+        return self.handler.get_webhook_endpoint(conn, webhook_endpoint)
 
     def execute(self) -> None:
         """Execute the Discord webhook call."""
@@ -148,7 +173,9 @@ class DiscordWebhookHook(HttpHook):
             # we only need https proxy for Discord
             proxies = {"https": self.proxy}
 
-        discord_payload = self._build_discord_payload()
+        discord_payload = self.handler.build_discord_payload(
+            tts=self.tts, message=self.message, username=self.username, 
avatar_url=self.avatar_url
+        )
 
         self.run(
             endpoint=self.webhook_endpoint,
@@ -156,3 +183,86 @@ class DiscordWebhookHook(HttpHook):
             headers={"Content-type": "application/json"},
             extra_options={"proxies": proxies},
         )
+
+
+class DiscordWebhookAsyncHook(HttpAsyncHook):
+    """
+    This hook allows you to post messages to Discord using incoming webhooks 
using async HTTP.
+
+    Takes a Discord connection ID with a default relative webhook endpoint. The
+    default endpoint can be overridden using the webhook_endpoint parameter
+    (https://discordapp.com/developers/docs/resources/webhook).
+
+    Each Discord webhook can be pre-configured to use a specific username and
+    avatar_url. You can override these defaults in this hook.
+
+    :param http_conn_id: Http connection ID with host as 
"https://discord.com/api/" and
+                         default webhook endpoint in the extra field in the 
form of
+                         {"webhook_endpoint": 
"webhooks/{webhook.id}/{webhook.token}"}
+    :param webhook_endpoint: Discord webhook endpoint in the form of
+                             "webhooks/{webhook.id}/{webhook.token}"
+    :param message: The message you want to send to your Discord channel
+                    (max 2000 characters)
+    :param username: Override the default username of the webhook
+    :param avatar_url: Override the default avatar of the webhook
+    :param tts: Is a text-to-speech message
+    :param proxy: Proxy to use to make the Discord webhook call
+    """
+
+    default_headers = {
+        "Content-Type": "application/json",
+    }
+    conn_name_attr = "http_conn_id"
+    default_conn_name = "discord_default"
+    conn_type = "discord"
+    hook_name = "Async Discord"
+
+    def __init__(
+        self,
+        *,
+        http_conn_id: str = "",
+        webhook_endpoint: str | None = None,
+        message: str = "",
+        username: str | None = None,
+        avatar_url: str | None = None,
+        tts: bool = False,
+        proxy: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.http_conn_id = http_conn_id
+        self.webhook_endpoint = webhook_endpoint
+        self.message = message
+        self.username = username
+        self.avatar_url = avatar_url
+        self.tts = tts
+        self.proxy = proxy
+        self.handler = DiscordCommonHandler()
+
+    async def _get_webhook_endpoint(self) -> str:
+        """
+        Return the default webhook endpoint or override if a webhook_endpoint 
is manually supplied.
+
+        :param http_conn_id: The provided connection ID
+        :param webhook_endpoint: The manually provided webhook endpoint
+        :return: Webhook endpoint (str) to use
+        """
+        conn = None
+        if not self.webhook_endpoint and self.http_conn_id:
+            conn = await get_async_connection(self.http_conn_id)
+        return self.handler.get_webhook_endpoint(conn, self.webhook_endpoint)
+
+    async def execute(self) -> None:
+        """Execute the Discord webhook call."""
+        webhook_endpoint = await self._get_webhook_endpoint()
+        discord_payload = self.handler.build_discord_payload(
+            tts=self.tts, message=self.message, username=self.username, 
avatar_url=self.avatar_url
+        )
+
+        async with aiohttp.ClientSession(proxy=self.proxy) as session:
+            await super().run(
+                session=session,
+                endpoint=webhook_endpoint,
+                data=discord_payload,
+                headers=self.default_headers,
+            )
diff --git 
a/providers/discord/src/airflow/providers/discord/notifications/discord.py 
b/providers/discord/src/airflow/providers/discord/notifications/discord.py
index 2c8cd6f8ffe..190c146e8ef 100644
--- a/providers/discord/src/airflow/providers/discord/notifications/discord.py
+++ b/providers/discord/src/airflow/providers/discord/notifications/discord.py
@@ -20,7 +20,8 @@ from __future__ import annotations
 from functools import cached_property
 
 from airflow.providers.common.compat.notifier import BaseNotifier
-from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook
+from airflow.providers.discord.hooks.discord_webhook import 
DiscordWebhookAsyncHook, DiscordWebhookHook
+from airflow.providers.discord.version_compat import AIRFLOW_V_3_1_PLUS
 
 ICON_URL: str = (
     
"https://raw.githubusercontent.com/apache/airflow/main/airflow-core/src/airflow/ui/public/pin_100.png"
@@ -50,8 +51,13 @@ class DiscordNotifier(BaseNotifier):
         username: str = "Airflow",
         avatar_url: str = ICON_URL,
         tts: bool = False,
+        **kwargs,
     ):
-        super().__init__()
+        if AIRFLOW_V_3_1_PLUS:
+            #  Support for passing context was added in 3.1.0
+            super().__init__(**kwargs)
+        else:
+            super().__init__()
         self.discord_conn_id = discord_conn_id
         self.text = text
         self.username = username
@@ -66,11 +72,36 @@ class DiscordNotifier(BaseNotifier):
         """Discord Webhook Hook."""
         return DiscordWebhookHook(http_conn_id=self.discord_conn_id)
 
+    @cached_property
+    def hook_async(self) -> DiscordWebhookAsyncHook:
+        """Discord Webhook Async Hook."""
+        return DiscordWebhookAsyncHook(
+            http_conn_id=self.discord_conn_id,
+            message=self.text,
+            username=self.username,
+            avatar_url=self.avatar_url,
+            tts=self.tts,
+        )
+
     def notify(self, context):
-        """Send a message to a Discord channel."""
+        """
+        Send a message to a Discord channel.
+
+        :param context: the context object
+        :return: None
+        """
         self.hook.username = self.username
         self.hook.message = self.text
         self.hook.avatar_url = self.avatar_url
         self.hook.tts = self.tts
 
         self.hook.execute()
+
+    async def async_notify(self, context) -> None:
+        """
+        Send a message to a Discord channel using async HTTP.
+
+        :param context: the context object
+        :return: None
+        """
+        await self.hook_async.execute()
diff --git 
a/providers/discord/src/airflow/providers/discord/operators/discord_webhook.py 
b/providers/discord/src/airflow/providers/discord/operators/discord_webhook.py
index c2e29684a07..60d97e2caf2 100644
--- 
a/providers/discord/src/airflow/providers/discord/operators/discord_webhook.py
+++ 
b/providers/discord/src/airflow/providers/discord/operators/discord_webhook.py
@@ -93,5 +93,10 @@ class DiscordWebhookOperator(HttpOperator):
         return hook
 
     def execute(self, context: Context) -> None:
-        """Call the DiscordWebhookHook to post a message."""
+        """
+        Call the DiscordWebhookHook to post a message.
+
+        :param context: the context object
+        :return: None
+        """
         self.hook.execute()
diff --git a/providers/discord/src/airflow/providers/discord/version_compat.py 
b/providers/discord/src/airflow/providers/discord/version_compat.py
new file mode 100644
index 00000000000..d63a188207c
--- /dev/null
+++ b/providers/discord/src/airflow/providers/discord/version_compat.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# NOTE! THIS FILE IS COPIED MANUALLY IN OTHER PROVIDERS DELIBERATELY TO AVOID 
ADDING UNNECESSARY
+# DEPENDENCIES BETWEEN PROVIDERS. IF YOU WANT TO ADD CONDITIONAL CODE IN YOUR 
PROVIDER THAT DEPENDS
+# ON AIRFLOW VERSION, PLEASE COPY THIS FILE TO THE ROOT PACKAGE OF YOUR 
PROVIDER AND IMPORT
+# THOSE CONSTANTS FROM IT RATHER THAN IMPORTING THEM FROM ANOTHER PROVIDER OR 
TEST CODE
+#
+from __future__ import annotations
+
+
+def get_base_airflow_version_tuple() -> tuple[int, int, int]:
+    from packaging.version import Version
+
+    from airflow import __version__
+
+    airflow_version = Version(__version__)
+    return airflow_version.major, airflow_version.minor, airflow_version.micro
+
+
+AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
+
+
+__all__ = [
+    "AIRFLOW_V_3_0_PLUS",
+    "AIRFLOW_V_3_1_PLUS",
+]
diff --git a/providers/discord/tests/unit/discord/hooks/test_discord_webhook.py 
b/providers/discord/tests/unit/discord/hooks/test_discord_webhook.py
index fbbb2ab722e..05d496d3340 100644
--- a/providers/discord/tests/unit/discord/hooks/test_discord_webhook.py
+++ b/providers/discord/tests/unit/discord/hooks/test_discord_webhook.py
@@ -18,34 +18,87 @@
 from __future__ import annotations
 
 import json
+from unittest import mock
 
 import pytest
+from aioresponses import aioresponses
 
-from airflow.exceptions import AirflowException
 from airflow.models import Connection
-from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook
+from airflow.providers.discord.hooks.discord_webhook import (
+    DiscordCommonHandler,
+    DiscordWebhookAsyncHook,
+    DiscordWebhookHook,
+)
 
 
-class TestDiscordWebhookHook:
+@pytest.fixture
+def aioresponse():
+    """
+    Creates mock async API response.
+    """
+    with aioresponses() as async_response:
+        yield async_response
+
+
+class TestDiscordCommonHandler:
     _config = {
-        "http_conn_id": "default-discord-webhook",
-        "webhook_endpoint": "webhooks/11111/some-discord-token_111",
         "message": "your message here",
         "username": "Airflow Webhook",
         "avatar_url": "https://static-cdn.avatars.com/my-avatar-path",
         "tts": False,
-        "proxy": "https://proxy.proxy.com:8888",
     }
 
     expected_payload_dict = {
+        "content": _config["message"],
+        "tts": _config["tts"],
         "username": _config["username"],
         "avatar_url": _config["avatar_url"],
-        "tts": _config["tts"],
-        "content": _config["message"],
     }
 
     expected_payload = json.dumps(expected_payload_dict)
 
+    def test_get_webhook_endpoint_manual_token(self):
+        provided_endpoint = "webhooks/11111/some-discord-token_111"
+        handler = DiscordCommonHandler()
+        webhook_endpoint = handler.get_webhook_endpoint(None, 
provided_endpoint)
+        assert webhook_endpoint == provided_endpoint
+
+    def test_get_webhook_endpoint_invalid_url(self):
+        provided_endpoint = "https://discordapp.com/some-invalid-webhook-url"
+        handler = DiscordCommonHandler()
+        expected_message = "Expected Discord webhook endpoint in the form of"
+        with pytest.raises(ValueError, match=expected_message):
+            handler.get_webhook_endpoint(None, provided_endpoint)
+
+    def test_get_webhook_endpoint_conn_id(self):
+        conn = Connection(
+            conn_id="default-discord-webhook",
+            conn_type="discord",
+            host="https://discordapp.com/api/",
+            extra='{"webhook_endpoint": 
"webhooks/00000/some-discord-token_000"}',
+        )
+        expected_webhook_endpoint = "webhooks/00000/some-discord-token_000"
+        handler = DiscordCommonHandler()
+        webhook_endpoint = handler.get_webhook_endpoint(conn, None)
+        assert webhook_endpoint == expected_webhook_endpoint
+
+    def test_build_discord_payload(self):
+        handler = DiscordCommonHandler()
+        payload = handler.build_discord_payload(**self._config)
+        assert self.expected_payload == payload
+
+    def test_build_discord_payload_message_length(self):
+        # Given
+        config = self._config.copy()
+        # create message over the character limit
+        config["message"] = "c" * 2001
+        handler = DiscordCommonHandler()
+        expected_message = "Discord message length must be 2000 or fewer 
characters"
+        with pytest.raises(ValueError, match=expected_message):
+            handler.build_discord_payload(**config)
+
+
+class TestDiscordWebhookHook:
     @pytest.fixture(autouse=True)
     def setup_connections(self, create_connection_without_db):
         create_connection_without_db(
@@ -68,15 +121,6 @@ class TestDiscordWebhookHook:
         # Then
         assert webhook_endpoint == provided_endpoint
 
-    def test_get_webhook_endpoint_invalid_url(self):
-        # Given
-        provided_endpoint = "https://discordapp.com/some-invalid-webhook-url"
-
-        # When/Then
-        expected_message = "Expected Discord webhook endpoint in the form of"
-        with pytest.raises(AirflowException, match=expected_message):
-            DiscordWebhookHook(webhook_endpoint=provided_endpoint)
-
     def test_get_webhook_endpoint_conn_id(self):
         # Given
         conn_id = "default-discord-webhook"
@@ -89,24 +133,79 @@ class TestDiscordWebhookHook:
         # Then
         assert webhook_endpoint == expected_webhook_endpoint
 
-    def test_build_discord_payload(self):
-        # Given
-        hook = DiscordWebhookHook(**self._config)
 
-        # When
-        payload = hook._build_discord_payload()
+class TestDiscordWebhookAsyncHook:
+    @pytest.fixture(autouse=True)
+    def setup_connections(self, create_connection_without_db):
+        create_connection_without_db(
+            Connection(
+                conn_id="default-discord-webhook",
+                conn_type="discord",
+                host="https://discordapp.com/api/",
+                extra='{"webhook_endpoint": 
"webhooks/00000/some-discord-token_000"}',
+            )
+        )
 
-        # Then
-        assert self.expected_payload == payload
+    @pytest.fixture(autouse=True)
+    def mock_get_connection(self):
+        """Mock the async connection retrieval."""
+        with mock.patch(
+            
"airflow.providers.discord.hooks.discord_webhook.get_async_connection",
+            new_callable=mock.AsyncMock,
+        ) as mock_conn:
+            mock_conn.return_value = Connection(
+                conn_id="default-discord-webhook",
+                conn_type="discord",
+                host="https://discordapp.com/api/",
+                extra='{"webhook_endpoint": 
"webhooks/00000/some-discord-token_000"}',
+            )
+            yield mock_conn
 
-    def test_build_discord_payload_message_length(self):
-        # Given
-        config = self._config.copy()
-        # create message over the character limit
-        config["message"] = "c" * 2001
-        hook = DiscordWebhookHook(**config)
+    @pytest.mark.asyncio
+    async def test_manual_token_overrides_conn(self):
+        provided_endpoint = "webhooks/11111/some-discord-token_111"
+        hook = DiscordWebhookAsyncHook(webhook_endpoint=provided_endpoint)
+        webhook_endpoint = await hook._get_webhook_endpoint()
+        assert webhook_endpoint == provided_endpoint
 
-        # When/Then
-        expected_message = "Discord message length must be 2000 or fewer 
characters"
-        with pytest.raises(AirflowException, match=expected_message):
-            hook._build_discord_payload()
+    @pytest.mark.asyncio
+    async def test_get_webhook_endpoint_conn_id(self):
+        conn_id = "default-discord-webhook"
+        hook = DiscordWebhookAsyncHook(http_conn_id=conn_id)
+        expected_webhook_endpoint = "webhooks/00000/some-discord-token_000"
+        webhook_endpoint = await hook._get_webhook_endpoint()
+        assert webhook_endpoint == expected_webhook_endpoint
+
+    @pytest.mark.asyncio
+    async def test_execute_with_payload(self):
+        conn_id = "default-discord-webhook"
+        hook = DiscordWebhookAsyncHook(
+            http_conn_id=conn_id,
+            message="your message here",
+            username="Airflow Webhook",
+            avatar_url="https://static-cdn.avatars.com/my-avatar-path",
+            tts=False,
+        )
+        expected_payload_dict = {
+            "content": "your message here",
+            "tts": False,
+            "username": "Airflow Webhook",
+            "avatar_url": "https://static-cdn.avatars.com/my-avatar-path",
+        }
+
+        with mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function:
+            await hook.execute()
+            assert mocked_function.call_args.kwargs.get("data") == 
json.dumps(expected_payload_dict)
+
+    @pytest.mark.asyncio
+    async def test_execute_with_success(self, aioresponse):
+        conn_id = "default-discord-webhook"
+        hook = DiscordWebhookAsyncHook(
+            http_conn_id=conn_id,
+            message="your message here",
+            username="Airflow Webhook",
+            avatar_url="https://static-cdn.avatars.com/my-avatar-path",
+            tts=False,
+        )
+        
aioresponse.post("https://discordapp.com/api/webhooks/00000/some-discord-token_000",
 status=200)
+        await hook.execute()
diff --git a/providers/discord/tests/unit/discord/notifications/test_discord.py 
b/providers/discord/tests/unit/discord/notifications/test_discord.py
index 84db7713fa5..1181450ac78 100644
--- a/providers/discord/tests/unit/discord/notifications/test_discord.py
+++ b/providers/discord/tests/unit/discord/notifications/test_discord.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, call, patch
 
 import pytest
 
@@ -55,3 +55,20 @@ def test_discord_notifier_notify(mock_execute):
     assert notifier.hook.message == "This is a test message"
     assert notifier.hook.avatar_url == "https://example.com/avatar.png"
     assert notifier.hook.tts is False
+
+
+@pytest.mark.asyncio
+@patch(
+    
"airflow.providers.discord.notifications.discord.DiscordWebhookAsyncHook.execute",
+    new_callable=AsyncMock,
+)
+async def test_async_notifier(mock_async_hook):
+    notifier = DiscordNotifier(
+        discord_conn_id="my_discord_conn_id",
+        text="This is a test message",
+        username="test_user",
+        avatar_url="https://example.com/avatar.png",
+        tts=False,
+    )
+    await notifier.async_notify({})
+    assert mock_async_hook.mock_calls == [call()]

Reply via email to