This is an automated email from the ASF dual-hosted git repository.

dabla pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d0800f6d1e Fix OAuth2 XOAUTH2 auth and EHLO after STARTTLS in SmtpHook (#62879)
9d0800f6d1e is described below

commit 9d0800f6d1e0ac54c4a9f09765147707e6ca1851
Author: Yoann <[email protected]>
AuthorDate: Mon Mar 9 05:28:32 2026 -0700

    Fix OAuth2 XOAUTH2 auth and EHLO after STARTTLS in SmtpHook (#62879)
    
    * fix(providers/smtp): fix OAuth2 XOAUTH2 auth in SmtpHook
    
    Three bugs fixed:
    1. get_conn() read self._auth_type (constructor default 'basic') instead
       of the self.auth_type property, so an auth_type set in the connection
       extras was ignored.
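    2. Neither get_conn() nor aget_conn() called ehlo() after STARTTLS.
       STARTTLS resets the SMTP session state, and RFC 3207 expects the client
       to re-issue EHLO afterwards, so omitting it caused the server to reject
       subsequent AUTH commands.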
    3. aget_conn() (async path) had no OAuth2 support at all, only basic auth
       via auth_login(); it now honours auth_type="oauth2" and authenticates
       with auth_xoauth2().
    
    Fixes apache/airflow#62775
    
    * fix: use auth_xoauth2() for aiosmtplib async client + ruff format
    
    * ci: retry after cache miss
    
    * fix: update test to use auth_xoauth2 matching actual async implementation
    
    * fix: add auth_xoauth2 mock to async SMTP test fixtures for lowest-deps compat
    
    * ci: retrigger CI (unrelated infra failures)
    
    * ci: retrigger CI (ghcr.io timeout)
    
    * retrigger CI
---
 .../smtp/src/airflow/providers/smtp/hooks/smtp.py  | 16 +++-
 providers/smtp/tests/unit/smtp/hooks/test_smtp.py  | 98 ++++++++++++++++++++++
 2 files changed, 112 insertions(+), 2 deletions(-)

diff --git a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py 
b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py
index ec7594d1647..b44010f0c17 100644
--- a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py
+++ b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py
@@ -127,9 +127,10 @@ class SmtpHook(BaseHook):
                 else:
                     if self.smtp_starttls:
                         self._smtp_client.starttls()
+                        self._smtp_client.ehlo()
 
                     # choose auth
-                    if self._auth_type == "oauth2":
+                    if self.auth_type == "oauth2":
                         if not self._access_token:
                             self._access_token = self._get_oauth2_token()
                         user_identity = self.smtp_user or self.from_email
@@ -172,8 +173,19 @@ class SmtpHook(BaseHook):
                 else:
                     if self.smtp_starttls:
                         await async_client.starttls()
+                        await async_client.ehlo()
 
-                    if self.smtp_user and self.smtp_password:
+                    # choose auth
+                    if self.auth_type == "oauth2":
+                        if not self._access_token:
+                            self._access_token = self._get_oauth2_token()
+                        user_identity = self.smtp_user or self.from_email
+                        if user_identity is None:
+                            raise AirflowException(
+                                "smtp_user or from_email must be set for 
OAuth2 authentication"
+                            )
+                        await async_client.auth_xoauth2(user_identity, 
self._access_token)
+                    elif self.smtp_user and self.smtp_password:
                         await async_client.auth_login(self.smtp_user, 
self.smtp_password)
                     break
 
diff --git a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py 
b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py
index 0553c9e0206..fcf6d5b30f5 100644
--- a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py
+++ b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py
@@ -486,6 +486,59 @@ class TestSmtpHook:
 
         assert not mock_conn.auth.called
 
+    @patch(smtplib_string)
+    def test_oauth2_uses_auth_type_property(self, mock_smtplib, 
create_connection_without_db):
+        """Test that get_conn reads auth_type from connection extras, not just 
__init__ arg."""
+        mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=False)
+
+        create_connection_without_db(
+            Connection(
+                conn_id="smtp_oauth2_extra",
+                conn_type=CONN_TYPE,
+                host=SMTP_HOST,
+                login=SMTP_LOGIN,
+                password=SMTP_PASSWORD,
+                port=NONSSL_PORT,
+                extra=json.dumps(
+                    dict(
+                        disable_ssl=True,
+                        from_email=FROM_EMAIL,
+                        auth_type="oauth2",
+                        access_token=ACCESS_TOKEN,
+                    )
+                ),
+            )
+        )
+
+        # Note: auth_type NOT passed to constructor -- should be read from 
extras
+        with SmtpHook(smtp_conn_id="smtp_oauth2_extra") as smtp_hook:
+            smtp_hook.send_email_smtp(
+                to=TO_EMAIL,
+                subject=TEST_SUBJECT,
+                html_content=TEST_BODY,
+                from_email=FROM_EMAIL,
+            )
+
+        assert mock_conn.auth.called
+        args, _ = mock_conn.auth.call_args
+        assert args[0] == "XOAUTH2"
+
+    @patch(smtplib_string)
+    def test_ehlo_called_after_starttls(self, mock_smtplib):
+        """Test that ehlo() is called after starttls() to re-establish session 
state."""
+        mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=False)
+        manager = Mock()
+        mock_conn.starttls = manager.starttls
+        mock_conn.ehlo = manager.ehlo
+        mock_conn.login = manager.login
+
+        with SmtpHook(smtp_conn_id=CONN_ID_NONSSL):
+            pass
+
+        # Verify ehlo is called after starttls and before login
+        expected_calls = [call.starttls(), call.ehlo(), call.login(SMTP_LOGIN, 
SMTP_PASSWORD)]
+        assert manager.mock_calls == expected_calls
+
 
 @pytest.mark.asyncio
 @pytest.mark.skipif(not AIRFLOW_V_3_1_PLUS, reason="Async support was added to 
BaseNotifier in 3.1.0")
@@ -523,6 +576,7 @@ class TestSmtpHookAsync:
         mock_client = AsyncMock(spec=aiosmtplib.SMTP)
         mock_client.starttls = AsyncMock()
         mock_client.auth_login = AsyncMock()
+        mock_client.auth_xoauth2 = AsyncMock()
         mock_client.sendmail = AsyncMock()
         mock_client.quit = AsyncMock()
         return mock_client
@@ -552,6 +606,7 @@ class TestSmtpHookAsync:
         mock_client = AsyncMock(spec=aiosmtplib.SMTP)
         mock_client.starttls = AsyncMock()
         mock_client.auth_login = AsyncMock()
+        mock_client.auth_xoauth2 = AsyncMock()
         mock_client.sendmail = AsyncMock()
         mock_client.quit = AsyncMock()
         mock_smtp.return_value = mock_client
@@ -650,3 +705,46 @@ class TestSmtpHookAsync:
             )
 
         mock_smtp_client.sendmail.assert_not_awaited()
+
+    async def test_async_ehlo_called_after_starttls(self, mock_smtp, 
mock_smtp_client, mock_get_connection):
+        """Test that ehlo() is called after starttls() in async path."""
+        async with SmtpHook(smtp_conn_id=CONN_ID_NONSSL):
+            pass
+
+        # For non-SSL, starttls is called followed by ehlo
+        assert mock_smtp_client.starttls.await_count == 1
+        assert mock_smtp_client.ehlo.await_count >= 2  # once in 
_abuild_client + once after starttls
+
+    async def test_async_oauth2_auth(
+        self, mock_smtp, mock_smtp_client, mock_get_connection, 
create_connection_without_db
+    ):
+        """Test that async path supports OAuth2 authentication."""
+        create_connection_without_db(
+            Connection(
+                conn_id=CONN_ID_OAUTH,
+                conn_type=CONN_TYPE,
+                host=SMTP_HOST,
+                login=SMTP_LOGIN,
+                password=SMTP_PASSWORD,
+                port=NONSSL_PORT,
+                extra=json.dumps(
+                    dict(
+                        disable_ssl=True,
+                        from_email=FROM_EMAIL,
+                        auth_type="oauth2",
+                        access_token=ACCESS_TOKEN,
+                    )
+                ),
+            )
+        )
+
+        async with SmtpHook(smtp_conn_id=CONN_ID_OAUTH) as hook:
+            await hook.asend_email_smtp(
+                to=TO_EMAIL,
+                subject=TEST_SUBJECT,
+                html_content=TEST_BODY,
+                from_email=FROM_EMAIL,
+            )
+
+        assert mock_smtp_client.auth_xoauth2.called
+        mock_smtp_client.auth_xoauth2.assert_awaited_once_with(SMTP_LOGIN, 
ACCESS_TOKEN)

Reply via email to