This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1b35a07722 Add --retry and --retry-delay to "airflow db check" (#31836)
1b35a07722 is described below

commit 1b35a077221481e9bf4aeea07d1264973e7f3bf6
Author: Bruce <[email protected]>
AuthorDate: Thu Jun 15 01:54:09 2023 -0700

    Add --retry and --retry-delay to "airflow db check" (#31836)
---
 airflow/cli/cli_config.py             | 14 +++++++++++++-
 airflow/cli/commands/db_command.py    | 22 ++++++++++++++++++++--
 tests/cli/commands/test_db_command.py | 24 +++++++++++++++++++++++-
 3 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index b59fea7893..0c69571fea 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -500,6 +500,18 @@ ARG_DB_DROP_ARCHIVES = Arg(
     help="Drop the archive tables after exporting. Use with caution.",
     action="store_true",
 )
+ARG_DB_RETRY = Arg(
+    ("--retry",),
+    default=0,
+    type=positive_int(allow_zero=True),
+    help="Retry database check upon failure",
+)
+ARG_DB_RETRY_DELAY = Arg(
+    ("--retry-delay",),
+    default=1,
+    type=positive_int(allow_zero=False),
+    help="Wait time between retries in seconds",
+)
 
 # pool
 ARG_POOL_NAME = Arg(("pool",), metavar="NAME", help="Pool name")
@@ -1620,7 +1632,7 @@ DB_COMMANDS = (
         name="check",
         help="Check if the database can be reached",
         func=lazy_load_command("airflow.cli.commands.db_command.check"),
-        args=(ARG_VERBOSE,),
+        args=(ARG_VERBOSE, ARG_DB_RETRY, ARG_DB_RETRY_DELAY),
     ),
     ActionCommand(
         name="clean",
diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py
index c1fa0564a0..64d54cc22e 100644
--- a/airflow/cli/commands/db_command.py
+++ b/airflow/cli/commands/db_command.py
@@ -17,11 +17,13 @@
 """Database sub-commands."""
 from __future__ import annotations
 
+import logging
 import os
 import textwrap
 from tempfile import NamedTemporaryFile
 
 from packaging.version import parse as parse_version
+from tenacity import RetryCallState, Retrying, stop_after_attempt, wait_fixed
 
 from airflow import settings
 from airflow.exceptions import AirflowException
@@ -30,6 +32,8 @@ from airflow.utils.db import REVISION_HEADS_MAP
 from airflow.utils.db_cleanup import config_dict, drop_archived_tables, export_archived_records, run_cleanup
 from airflow.utils.process_utils import execute_interactive
 
+log = logging.getLogger(__name__)
+
 
 def initdb(args):
     """Initializes the metadata database."""
@@ -187,9 +191,23 @@ def shell(args):
 
 
 @cli_utils.action_cli(check_db=False)
-def check(_):
+def check(args):
     """Runs a check command that checks if db is available."""
-    db.check()
+    retries: int = args.retry
+    retry_delay: int = args.retry_delay
+
+    def _warn_remaining_retries(retrystate: RetryCallState):
+        remain = retries - retrystate.attempt_number
+        log.warning("%d retries remain. Will retry in %d seconds", remain, retry_delay)
+
+    for attempt in Retrying(
+        stop=stop_after_attempt(1 + retries),
+        wait=wait_fixed(retry_delay),
+        reraise=True,
+        before_sleep=_warn_remaining_retries,
+    ):
+        with attempt:
+            db.check()
 
 
 # lazily imported by CLI parser for `help` command
diff --git a/tests/cli/commands/test_db_command.py b/tests/cli/commands/test_db_command.py
index eefa88a1cc..edc691e7cd 100644
--- a/tests/cli/commands/test_db_command.py
+++ b/tests/cli/commands/test_db_command.py
@@ -17,12 +17,13 @@
 from __future__ import annotations
 
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import MagicMock, Mock, call, patch
 
 import pendulum
 import pytest
 from pytest import param
 from sqlalchemy.engine.url import make_url
+from sqlalchemy.exc import OperationalError
 
 from airflow.cli import cli_parser
 from airflow.cli.commands import db_command
@@ -271,6 +272,27 @@ class TestCliDb:
             db_command.downgrade(self.parser.parse_args(["db", "downgrade", "--to-revision", "abc"]))
             mock_dg.assert_called_with(to_revision="abc", from_revision=None, show_sql_only=False)
 
+    def test_check(self):
+        retry, retry_delay = 6, 9  # arbitrary but distinct number
+        args = self.parser.parse_args(
+            ["db", "check", "--retry", str(retry), "--retry-delay", str(retry_delay)]
+        )
+        sleep = MagicMock()
+        always_pass = Mock()
+        always_fail = Mock(side_effect=OperationalError("", None, None))
+
+        with patch("time.sleep", new=sleep), patch("airflow.utils.db.check", new=always_pass):
+            db_command.check(args)
+            always_pass.assert_called_once()
+            sleep.assert_not_called()
+
+        with patch("time.sleep", new=sleep), patch("airflow.utils.db.check", new=always_fail):
+            with pytest.raises(OperationalError):
+                db_command.check(args)
+            # With N retries there are N+1 total checks, hence N sleeps
+            always_fail.assert_has_calls([call()] * (retry + 1))
+            sleep.assert_has_calls([call(retry_delay)] * retry)
+
 
 class TestCLIDBClean:
     @classmethod

Reply via email to