jason810496 commented on code in PR #51586: URL: https://github.com/apache/airflow/pull/51586#discussion_r2140285753
########## airflow-core/src/airflow/utils/db_discovery.py: ########## @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +import socket +import time + +from sqlalchemy.engine.url import make_url + +from airflow.configuration import conf + +logger = logging.getLogger(__name__) + + +class DbDiscoveryStatus: + """Enum with the return value for `check_db_discovery_if_needed()`.""" + + # The hostname resolves. + OK = "ok" + # There has been some temporary DNS lookup blip and the connection will probably recover. + # Causes: a dns timeout or a temporary network issue. + TEMPORARY_ERROR = "dns_temporary_failure" + # Unknown hostname or service, this is permanent and the connection can't be recovered. + # Causes: a cmd or config typo, a hostname that doesn't exist. + UNKNOWN_HOSTNAME = "unknown_hostname" + # Unknown hostname or service, this is permanent and the connection can't be recovered. + # Causes: Failed DNS server or config typo. + PERMANENT_ERROR = "dns_permanent_failure" + # Some other error. 
+ UNKNOWN_ERROR = "unknown_error" + + +# db status - how long ago it was retrieved +db_health_status: tuple[str, float] = (DbDiscoveryStatus.OK, 0.0) + +# TODO: For now, this is used for testing +# but it can also be used to add stats. +db_retry_count: int = 0 + + +def get_sleep_time(retry_attempt: int, initial_wait: float, max_wait: float) -> float: + return min(initial_wait * (2**retry_attempt), max_wait) + + +def _retry_exponential_backoff(retry_attempt: int, initial_wait: float, max_wait: float) -> None: + sleep_time = get_sleep_time(retry_attempt, initial_wait, max_wait) + unit_str = "second" if sleep_time == float(1) else "seconds" + logger.info("Sleeping for %.2f %s.", sleep_time, unit_str) + time.sleep(sleep_time) Review Comment: IMO, we can leverage `tenacity` for retrying, as `tenacity` is already used in airflow core for a while. e.g.`run_with_db_retries` https://github.com/apache/airflow/blob/c4d7edb250dd988037c3e0eb024dbb77ab14c165/airflow-core/src/airflow/utils/retries.py#L33-L48 ########## airflow-core/src/airflow/config_templates/config.yml: ########## @@ -679,6 +679,39 @@ database: type: integer example: ~ default: "10000" + check_db_discovery: + description: | + Whether to check the db discovery before creating a new session. + If enabled, the dns lookup to the db hostname will be checked with retries. + Accepts ``True`` or ``False``. + version_added: 3.1.0 + type: string Review Comment: ```suggestion type: boolean ``` ########## airflow-core/tests/unit/utils/test_db_discovery_status.py: ########## @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import socket + +import pytest + +from airflow.utils import db_discovery +from airflow.utils.db_discovery import DbDiscoveryStatus + + +class TestDbDiscoveryStatus: + @pytest.mark.parametrize( + "retry, expected_sleep_time", + [ + pytest.param(0, 0.5, id="attempt-1"), + pytest.param(1, 1, id="attempt-2"), + pytest.param(2, 2, id="attempt-3"), + pytest.param(3, 4, id="attempt-4"), + pytest.param(4, 8, id="attempt-5"), + pytest.param(5, 15, id="attempt-6"), + pytest.param(6, 15, id="attempt-7"), + ], + ) + def test_get_sleep_time(self, retry: int, expected_sleep_time: float): + sleep = db_discovery.get_sleep_time(retry, 0.5, 15) + assert sleep == expected_sleep_time + + @pytest.mark.parametrize( + "error_code, expected_status", + [ + (socket.EAI_FAIL, DbDiscoveryStatus.PERMANENT_ERROR), + (socket.EAI_AGAIN, DbDiscoveryStatus.TEMPORARY_ERROR), + (socket.EAI_NONAME, DbDiscoveryStatus.UNKNOWN_HOSTNAME), + (socket.EAI_SYSTEM, DbDiscoveryStatus.UNKNOWN_ERROR), + ], + ) + def test_check_dns_resolution_with_retries(self, monkeypatch, error_code, expected_status): + def raise_exc(*args, **kwargs): + # The error message isn't important because the validation is based on the error code. + raise socket.gaierror(error_code, "patched failure") + + monkeypatch.setattr(socket, "getaddrinfo", raise_exc) Review Comment: nit: recommend using the `@patch` decorator to prevent side effects. 
########## airflow-core/tests/unit/core/test_db_discovery.py: ########## @@ -0,0 +1,164 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +import logging +import os +import shutil +import socket +import time + +import pytest +from sqlalchemy import text + +from airflow import settings +from airflow.utils import db_discovery +from airflow.utils.db_discovery import DbDiscoveryStatus + +log = logging.getLogger(__name__) + + +def dispose_connection_pool(): + """Dispose any cached sockets so that the next query will force a new connect.""" + settings.engine.dispose() + # Wait for SqlAlchemy. + time.sleep(0.5) + + +def make_db_test_call(): + """ + Create a session and execute a query. + + It will establish a new connection if there isn't one available. + New connections use DNS lookup. 
+ """ + from airflow.utils.session import create_session + + with create_session() as session: + session.execute(text("SELECT 1")) + + +def assert_query_raises_exc(expected_error_msg: str, expected_status: str, expected_retry_num: int): + with pytest.raises(socket.gaierror, match=expected_error_msg): + make_db_test_call() + + assert len(db_discovery.db_health_status) == 2 + + assert db_discovery.db_health_status[0] == expected_status + assert db_discovery.db_retry_count == expected_retry_num + + [email protected]("postgres") +class TestDbDiscoveryIntegration: + @pytest.fixture + def patch_getaddrinfo_for_eai_fail(self, monkeypatch): + import socket + + def always_fail(*args, **kwargs): + # The error message isn't important, as long as the error code is EAI_FAIL. + raise socket.gaierror(socket.EAI_FAIL, "permanent failure") + + monkeypatch.setattr(socket, "getaddrinfo", always_fail) + + def test_dns_resolution_blip(self): + os.environ["AIRFLOW__DATABASE__CHECK_DB_DISCOVERY"] = "True" + + resolv_file = "/etc/resolv.conf" + resolv_backup = "/tmp/resolv.conf.bak" + + # Back up the original file so that it can later be restored. + shutil.copy(resolv_file, resolv_backup) + + try: + # Replace the IP with a bad resolver. + with open(resolv_file, "w", encoding="utf-8") as fh: + fh.write("nameserver 10.255.255.1\noptions timeout:1 attempts:1 ndots:0\n") + + # New connection + DNS lookup. + dispose_connection_pool() + assert_query_raises_exc( + expected_error_msg="Temporary failure in name resolution", + expected_status=DbDiscoveryStatus.TEMPORARY_ERROR, + expected_retry_num=3, + ) + + finally: + # Reset the values for the next tests. + db_discovery.db_health_status = (DbDiscoveryStatus.OK, 0.0) + db_discovery.db_retry_count = 0 + + # Restore the original file. 
+ with contextlib.suppress(Exception): + shutil.copy(resolv_backup, resolv_file) + + def test_permanent_dns_failure(self, patch_getaddrinfo_for_eai_fail): + os.environ["AIRFLOW__DATABASE__CHECK_DB_DISCOVERY"] = "True" + + try: + # New connection + DNS lookup. + dispose_connection_pool() + assert_query_raises_exc( + expected_error_msg="permanent failure", + expected_status=DbDiscoveryStatus.PERMANENT_ERROR, + expected_retry_num=0, + ) + + finally: + # Reset the values for the next tests. + db_discovery.db_health_status = (DbDiscoveryStatus.OK, 0.0) + db_discovery.db_retry_count = 0 + + def test_invalid_hostname_in_config(self): + os.environ["AIRFLOW__DATABASE__CHECK_DB_DISCOVERY"] = "True" + os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = ( + "postgresql+psycopg2://postgres:airflow@invalid/airflow" + ) + + try: + # New connection + DNS lookup. + dispose_connection_pool() + assert_query_raises_exc( + expected_error_msg="Name or service not known", + expected_status=DbDiscoveryStatus.UNKNOWN_HOSTNAME, + expected_retry_num=0, + ) + finally: + os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = ( + "postgresql+psycopg2://postgres:airflow@postgres/airflow" + ) + + # Reset the values for the next tests. + db_discovery.db_health_status = (DbDiscoveryStatus.OK, 0.0) + db_discovery.db_retry_count = 0 + + @pytest.mark.parametrize( + "check_enabled", + [ + pytest.param(True, id="check-enabled"), + pytest.param(False, id="check-disabled"), + ], + ) + def test_no_errors(self, check_enabled: bool): + os.environ["AIRFLOW__DATABASE__CHECK_DB_DISCOVERY"] = str(check_enabled) + + dispose_connection_pool() + make_db_test_call() + + # No status checks and no retries. + assert db_discovery.db_health_status[0] == DbDiscoveryStatus.OK + assert db_discovery.db_retry_count == 0 Review Comment: nit: Should we mock `check_db_discovery_with_retries` here and check if called. 
########## airflow-core/tests/unit/core/test_db_discovery.py: ########## @@ -0,0 +1,164 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +import logging +import os +import shutil +import socket +import time + +import pytest +from sqlalchemy import text + +from airflow import settings +from airflow.utils import db_discovery +from airflow.utils.db_discovery import DbDiscoveryStatus + +log = logging.getLogger(__name__) + + +def dispose_connection_pool(): + """Dispose any cached sockets so that the next query will force a new connect.""" + settings.engine.dispose() + # Wait for SqlAlchemy. + time.sleep(0.5) + + +def make_db_test_call(): + """ + Create a session and execute a query. + + It will establish a new connection if there isn't one available. + New connections use DNS lookup. 
+ """ + from airflow.utils.session import create_session + + with create_session() as session: + session.execute(text("SELECT 1")) + + +def assert_query_raises_exc(expected_error_msg: str, expected_status: str, expected_retry_num: int): + with pytest.raises(socket.gaierror, match=expected_error_msg): + make_db_test_call() + + assert len(db_discovery.db_health_status) == 2 + + assert db_discovery.db_health_status[0] == expected_status + assert db_discovery.db_retry_count == expected_retry_num + + [email protected]("postgres") +class TestDbDiscoveryIntegration: + @pytest.fixture + def patch_getaddrinfo_for_eai_fail(self, monkeypatch): + import socket + + def always_fail(*args, **kwargs): + # The error message isn't important, as long as the error code is EAI_FAIL. + raise socket.gaierror(socket.EAI_FAIL, "permanent failure") + + monkeypatch.setattr(socket, "getaddrinfo", always_fail) + + def test_dns_resolution_blip(self): + os.environ["AIRFLOW__DATABASE__CHECK_DB_DISCOVERY"] = "True" + + resolv_file = "/etc/resolv.conf" + resolv_backup = "/tmp/resolv.conf.bak" + + # Back up the original file so that it can later be restored. + shutil.copy(resolv_file, resolv_backup) + + try: + # Replace the IP with a bad resolver. + with open(resolv_file, "w", encoding="utf-8") as fh: + fh.write("nameserver 10.255.255.1\noptions timeout:1 attempts:1 ndots:0\n") + + # New connection + DNS lookup. + dispose_connection_pool() + assert_query_raises_exc( + expected_error_msg="Temporary failure in name resolution", + expected_status=DbDiscoveryStatus.TEMPORARY_ERROR, + expected_retry_num=3, + ) + + finally: + # Reset the values for the next tests. + db_discovery.db_health_status = (DbDiscoveryStatus.OK, 0.0) + db_discovery.db_retry_count = 0 + + # Restore the original file. 
+ with contextlib.suppress(Exception): + shutil.copy(resolv_backup, resolv_file) + + def test_permanent_dns_failure(self, patch_getaddrinfo_for_eai_fail): + os.environ["AIRFLOW__DATABASE__CHECK_DB_DISCOVERY"] = "True" + + try: + # New connection + DNS lookup. + dispose_connection_pool() + assert_query_raises_exc( + expected_error_msg="permanent failure", + expected_status=DbDiscoveryStatus.PERMANENT_ERROR, + expected_retry_num=0, + ) + + finally: + # Reset the values for the next tests. + db_discovery.db_health_status = (DbDiscoveryStatus.OK, 0.0) + db_discovery.db_retry_count = 0 + + def test_invalid_hostname_in_config(self): + os.environ["AIRFLOW__DATABASE__CHECK_DB_DISCOVERY"] = "True" + os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = ( + "postgresql+psycopg2://postgres:airflow@invalid/airflow" + ) Review Comment: nit: we can replace the setup/teardown for airflow config with the `@conf_vars` context-manager decorator. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
