This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d751bd8796 Enabling TLS arguments for FastAPI (#42395)
d751bd8796 is described below

commit d751bd87960795a8977f7459f2a670b574deda6a
Author: JoshuaXOng <[email protected]>
AuthorDate: Sun Sep 22 22:43:21 2024 +1000

    Enabling TLS arguments for FastAPI (#42395)
    
    * Added TLS arguments for FastAPI
    
    * Uplifted FastAPI TLS tests
---
 airflow/cli/commands/fastapi_api_command.py    | 25 ++++++++++++++
 tests/cli/commands/test_fastapi_api_command.py | 45 ++++++++++++++++++++++++--
 2 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/airflow/cli/commands/fastapi_api_command.py b/airflow/cli/commands/fastapi_api_command.py
index d50d454347..c11da959ce 100644
--- a/airflow/cli/commands/fastapi_api_command.py
+++ b/airflow/cli/commands/fastapi_api_command.py
@@ -36,6 +36,7 @@ from uvicorn.workers import UvicornWorker
 from airflow import settings
 from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
 from airflow.cli.commands.webserver_command import GunicornMonitor
+from airflow.exceptions import AirflowConfigException
 from airflow.utils import cli as cli_utils
 from airflow.utils.cli import setup_locations
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
@@ -124,6 +125,10 @@ def fastapi_api(args):
             "python:airflow.api_fastapi.gunicorn_config",
         ]
 
+        ssl_cert, ssl_key = _get_ssl_cert_and_key_filepaths(args)
+        if ssl_cert and ssl_key:
+            run_args += ["--certfile", ssl_cert, "--keyfile", ssl_key]
+
         if args.access_logformat and args.access_logformat.strip():
             run_args += ["--access-logformat", str(args.access_logformat)]
 
@@ -199,3 +204,23 @@ def fastapi_api(args):
             should_setup_logging=True,
             pid_file=monitor_pid_file,
         )
+
+
+def _get_ssl_cert_and_key_filepaths(cli_arguments) -> tuple[str | None, str | None]:
+    error_template_1 = "Need both, have provided {} but not {}"
+    error_template_2 = "SSL related file does not exist {}"
+
+    ssl_cert, ssl_key = cli_arguments.ssl_cert, cli_arguments.ssl_key
+    if ssl_cert and ssl_key:
+        if not os.path.isfile(ssl_cert):
+            raise AirflowConfigException(error_template_2.format(ssl_cert))
+        if not os.path.isfile(ssl_key):
+            raise AirflowConfigException(error_template_2.format(ssl_key))
+
+        return (ssl_cert, ssl_key)
+    elif ssl_cert:
+        raise AirflowConfigException(error_template_1.format("SSL certificate", "SSL key"))
+    elif ssl_key:
+        raise AirflowConfigException(error_template_1.format("SSL key", "SSL certificate"))
+
+    return (None, None)
diff --git a/tests/cli/commands/test_fastapi_api_command.py b/tests/cli/commands/test_fastapi_api_command.py
index 529c67f5ed..1c1af3342f 100644
--- a/tests/cli/commands/test_fastapi_api_command.py
+++ b/tests/cli/commands/test_fastapi_api_command.py
@@ -27,13 +27,14 @@ import pytest
 from rich.console import Console
 
 from airflow.cli.commands import fastapi_api_command
+from airflow.exceptions import AirflowConfigException
 from tests.cli.commands._common_cli_classes import _CommonCLIGunicornTestClass
 
 console = Console(width=400, color_system="standard")
 
 
 @pytest.mark.db_test
-class TestCliInternalAPI(_CommonCLIGunicornTestClass):
+class TestCliFastAPI(_CommonCLIGunicornTestClass):
     main_process_regexp = r"airflow fastapi-api"
 
     @pytest.mark.execution_timeout(210)
@@ -46,7 +47,7 @@ class TestCliInternalAPI(_CommonCLIGunicornTestClass):
         stderr = parent_path / "airflow-fastapi-api.err"
         logfile = parent_path / "airflow-fastapi-api.log"
         try:
-            # Run internal-api as daemon in background. Note that the wait method is not called.
+            # Run fastapi-api as daemon in background. Note that the wait method is not called.
             console.print("[magenta]Starting airflow fastapi-api --daemon")
             env = os.environ.copy()
             proc = subprocess.Popen(
@@ -123,7 +124,9 @@ class TestCliInternalAPI(_CommonCLIGunicornTestClass):
                 close_fds=True,
             )
 
-    def test_cli_fastapi_api_args(self):
+    def test_cli_fastapi_api_args(self, ssl_cert_and_key):
+        cert_path, key_path = ssl_cert_and_key
+
         with mock.patch("subprocess.Popen") as Popen, mock.patch.object(
             fastapi_api_command, "GunicornMonitor"
         ):
@@ -134,6 +137,10 @@ class TestCliInternalAPI(_CommonCLIGunicornTestClass):
                     "custom_log_format",
                     "--pid",
                     "/tmp/x.pid",
+                    "--ssl-cert",
+                    str(cert_path),
+                    "--ssl-key",
+                    str(key_path),
                 ]
             )
             fastapi_api_command.fastapi_api(args)
@@ -161,6 +168,10 @@ class TestCliInternalAPI(_CommonCLIGunicornTestClass):
                     "-",
                     "--config",
                     "python:airflow.api_fastapi.gunicorn_config",
+                    "--certfile",
+                    str(cert_path),
+                    "--keyfile",
+                    str(key_path),
                     "--access-logformat",
                     "custom_log_format",
                     "airflow.api_fastapi.app:cached_app()",
@@ -168,3 +179,31 @@ class TestCliInternalAPI(_CommonCLIGunicornTestClass):
                 ],
                 close_fds=True,
             )
+
+    @pytest.mark.parametrize(
+        "ssl_arguments, error_pattern",
+        [
+            (["--ssl-cert", "_.crt", "--ssl-key", "_.key"], "does not exist _.crt"),
+            (["--ssl-cert", "_.crt"], "Need both.*certificate.*key"),
+            (["--ssl-key", "_.key"], "Need both.*key.*certificate"),
+        ],
+    )
+    def test_get_ssl_cert_and_key_filepaths_with_incorrect_usage(self, ssl_arguments, error_pattern):
+        args = self.parser.parse_args(["fastapi-api"] + ssl_arguments)
+        with pytest.raises(AirflowConfigException, match=error_pattern):
+            fastapi_api_command._get_ssl_cert_and_key_filepaths(args)
+
+    def test_get_ssl_cert_and_key_filepaths_with_correct_usage(self, ssl_cert_and_key):
+        cert_path, key_path = ssl_cert_and_key
+
+        args = self.parser.parse_args(
+            ["fastapi-api"] + ["--ssl-cert", str(cert_path), "--ssl-key", str(key_path)]
+        )
+        assert fastapi_api_command._get_ssl_cert_and_key_filepaths(args) == (str(cert_path), str(key_path))
+
+    @pytest.fixture
+    def ssl_cert_and_key(self, tmp_path):
+        cert_path, key_path = tmp_path / "_.crt", tmp_path / "_.key"
+        cert_path.touch()
+        key_path.touch()
+        return cert_path, key_path

Reply via email to