https://github.com/python/cpython/commit/605022aeb69ae19cae1c020a6993ab5c433ce907
commit: 605022aeb69ae19cae1c020a6993ab5c433ce907
branch: main
author: ggqlq <124190229+gg...@users.noreply.github.com>
committer: picnixz <10796600+picn...@users.noreply.github.com>
date: 2025-05-19T12:15:04Z
summary:

gh-131178: Add tests for `http.server` command-line interface (#132540)

files:
M Lib/http/server.py
M Lib/test/test_httpservers.py

diff --git a/Lib/http/server.py b/Lib/http/server.py
index abf9f87a1fc711..f6d1b998f4201e 100644
--- a/Lib/http/server.py
+++ b/Lib/http/server.py
@@ -1000,7 +1000,7 @@ def test(HandlerClass=BaseHTTPRequestHandler,
             sys.exit(0)
 
 
-if __name__ == '__main__':
+def _main(args=None):
     import argparse
     import contextlib
 
@@ -1024,7 +1024,7 @@ def test(HandlerClass=BaseHTTPRequestHandler,
     parser.add_argument('port', default=8000, type=int, nargs='?',
                         help='bind to this port '
                              '(default: %(default)s)')
-    args = parser.parse_args()
+    args = parser.parse_args(args)
 
     if not args.tls_cert and args.tls_key:
         parser.error("--tls-key requires --tls-cert to be set")
@@ -1064,3 +1064,7 @@ def finish_request(self, request, client_address):
         tls_key=args.tls_key,
         tls_password=tls_key_password,
     )
+
+
+if __name__ == '__main__':
+    _main()
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py
index df5d2a7bedc4b2..429d9005bd7605 100644
--- a/Lib/test/test_httpservers.py
+++ b/Lib/test/test_httpservers.py
@@ -8,6 +8,7 @@
      SimpleHTTPRequestHandler
 from http import server, HTTPStatus
 
+import contextlib
 import os
 import socket
 import sys
@@ -20,6 +21,7 @@
 import html
 import http, http.client
 import urllib.parse
+import urllib.request
 import tempfile
 import time
 import datetime
@@ -32,6 +34,8 @@
 from test.support import (
     is_apple, import_helper, os_helper, threading_helper
 )
+from test.support.script_helper import kill_python, spawn_python
+from test.support.socket_helper import find_unused_port
 
 try:
     import ssl
@@ -1281,6 +1285,296 @@ def test_server_test_ipv4(self, _):
             self.assertEqual(mock_server.address_family, socket.AF_INET)
 
 
+class CommandLineTestCase(unittest.TestCase):
+    """Check how http.server._main() forwards command-line options to test().
+
+    ``http.server.test`` is patched in every test, so no server is started;
+    the tests only check the arguments passed to it or that parsing fails.
+    """
+
+    default_port = 8000
+    default_bind = None
+    default_protocol = 'HTTP/1.0'
+    default_handler = SimpleHTTPRequestHandler
+    default_server = unittest.mock.ANY
+    tls_cert = certdata_file('ssl_cert.pem')
+    tls_key = certdata_file('ssl_key.pem')
+    tls_password = 'somepass'
+    tls_cert_options = ['--tls-cert']
+    tls_key_options = ['--tls-key']
+    tls_password_options = ['--tls-password-file']
+    # Keyword arguments _main() is expected to pass to test() when no
+    # option overrides them.
+    args = {
+        'HandlerClass': default_handler,
+        'ServerClass': default_server,
+        'protocol': default_protocol,
+        'port': default_port,
+        'bind': default_bind,
+        'tls_cert': None,
+        'tls_key': None,
+        'tls_password': None,
+    }
+
+    def setUp(self):
+        super().setUp()
+        # mkstemp() creates the file atomically, unlike the deprecated
+        # mktemp() which only returns a name that may be raced.
+        fd, self.tls_password_file = tempfile.mkstemp()
+        self.addCleanup(os_helper.unlink, self.tls_password_file)
+        with os.fdopen(fd, 'wb') as f:
+            f.write(self.tls_password.encode())
+
+    def invoke_httpd(self, *args, stdout=None, stderr=None):
+        """Call server._main(args) with stdout/stderr redirected.
+
+        Return the captured (stdout, stderr) text.  Callers that expect
+        SystemExit should pass their own StringIO objects to inspect them.
+        """
+        stdout = StringIO() if stdout is None else stdout
+        stderr = StringIO() if stderr is None else stderr
+        with contextlib.redirect_stdout(stdout), \
+            contextlib.redirect_stderr(stderr):
+            server._main(args)
+        return stdout.getvalue(), stderr.getvalue()
+
+    @mock.patch('http.server.test')
+    def test_port_flag(self, mock_func):
+        ports = [8000, 65535]
+        for port in ports:
+            with self.subTest(port=port):
+                self.invoke_httpd(str(port))
+                call_args = self.args | dict(port=port)
+                mock_func.assert_called_once_with(**call_args)
+                mock_func.reset_mock()
+
+    @mock.patch('http.server.test')
+    def test_directory_flag(self, mock_func):
+        options = ['-d', '--directory']
+        directories = ['.', '/foo', '\\bar', '/',
+                       'C:\\', 'C:\\foo', 'C:\\bar',
+                       '/home/user', './foo/foo2', 'D:\\foo\\bar']
+        for flag in options:
+            for directory in directories:
+                with self.subTest(flag=flag, directory=directory):
+                    self.invoke_httpd(flag, directory)
+                    mock_func.assert_called_once_with(**self.args)
+                    mock_func.reset_mock()
+
+    @mock.patch('http.server.test')
+    def test_bind_flag(self, mock_func):
+        options = ['-b', '--bind']
+        bind_addresses = ['localhost', '127.0.0.1', '::1',
+                          '0.0.0.0', '8.8.8.8']
+        for flag in options:
+            for bind_address in bind_addresses:
+                with self.subTest(flag=flag, bind_address=bind_address):
+                    self.invoke_httpd(flag, bind_address)
+                    call_args = self.args | dict(bind=bind_address)
+                    mock_func.assert_called_once_with(**call_args)
+                    mock_func.reset_mock()
+
+    @mock.patch('http.server.test')
+    def test_protocol_flag(self, mock_func):
+        options = ['-p', '--protocol']
+        protocols = ['HTTP/1.0', 'HTTP/1.1', 'HTTP/2.0', 'HTTP/3.0']
+        for flag in options:
+            for protocol in protocols:
+                with self.subTest(flag=flag, protocol=protocol):
+                    self.invoke_httpd(flag, protocol)
+                    call_args = self.args | dict(protocol=protocol)
+                    mock_func.assert_called_once_with(**call_args)
+                    mock_func.reset_mock()
+
+    @unittest.skipIf(ssl is None, "requires ssl")
+    @mock.patch('http.server.test')
+    def test_tls_cert_and_key_flags(self, mock_func):
+        for tls_cert_option in self.tls_cert_options:
+            for tls_key_option in self.tls_key_options:
+                self.invoke_httpd(tls_cert_option, self.tls_cert,
+                                  tls_key_option, self.tls_key)
+                call_args = self.args | {
+                    'tls_cert': self.tls_cert,
+                    'tls_key': self.tls_key,
+                }
+                mock_func.assert_called_once_with(**call_args)
+                mock_func.reset_mock()
+
+    @unittest.skipIf(ssl is None, "requires ssl")
+    @mock.patch('http.server.test')
+    def test_tls_cert_and_key_and_password_flags(self, mock_func):
+        for tls_cert_option in self.tls_cert_options:
+            for tls_key_option in self.tls_key_options:
+                for tls_password_option in self.tls_password_options:
+                    self.invoke_httpd(tls_cert_option,
+                                      self.tls_cert,
+                                      tls_key_option,
+                                      self.tls_key,
+                                      tls_password_option,
+                                      self.tls_password_file)
+                    call_args = self.args | {
+                        'tls_cert': self.tls_cert,
+                        'tls_key': self.tls_key,
+                        'tls_password': self.tls_password,
+                    }
+                    mock_func.assert_called_once_with(**call_args)
+                    mock_func.reset_mock()
+
+    @unittest.skipIf(ssl is None, "requires ssl")
+    @mock.patch('http.server.test')
+    def test_missing_tls_cert_flag(self, mock_func):
+        # --tls-key and --tls-password-file need --tls-cert; _main() must
+        # exit via the parser before test() is reached.  Pass a password
+        # file that really exists so only the missing cert can cause it.
+        for tls_key_option in self.tls_key_options:
+            with self.assertRaises(SystemExit):
+                self.invoke_httpd(tls_key_option, self.tls_key)
+            mock_func.assert_not_called()
+
+        for tls_password_option in self.tls_password_options:
+            with self.assertRaises(SystemExit):
+                self.invoke_httpd(tls_password_option, self.tls_password_file)
+            mock_func.assert_not_called()
+
+    @unittest.skipIf(ssl is None, "requires ssl")
+    @mock.patch('http.server.test')
+    def test_invalid_password_file(self, mock_func):
+        non_existent_file = 'non_existent_file'
+        for tls_password_option in self.tls_password_options:
+            for tls_cert_option in self.tls_cert_options:
+                with self.assertRaises(SystemExit):
+                    self.invoke_httpd(tls_cert_option,
+                                      self.tls_cert,
+                                      tls_password_option,
+                                      non_existent_file)
+                mock_func.assert_not_called()
+
+    @mock.patch('http.server.test')
+    def test_no_arguments(self, mock_func):
+        self.invoke_httpd()
+        mock_func.assert_called_once_with(**self.args)
+
+    @mock.patch('http.server.test')
+    def test_help_flag(self, _):
+        options = ['-h', '--help']
+        for option in options:
+            stdout, stderr = StringIO(), StringIO()
+            with self.assertRaises(SystemExit):
+                self.invoke_httpd(option, stdout=stdout, stderr=stderr)
+            self.assertIn('usage', stdout.getvalue())
+            self.assertEqual(stderr.getvalue(), '')
+
+    @mock.patch('http.server.test')
+    def test_unknown_flag(self, _):
+        stdout, stderr = StringIO(), StringIO()
+        with self.assertRaises(SystemExit):
+            self.invoke_httpd('--unknown-flag', stdout=stdout, stderr=stderr)
+        self.assertEqual(stdout.getvalue(), '')
+        self.assertIn('error', stderr.getvalue())
+
+
+class CommandLineRunTimeTestCase(unittest.TestCase):
+    """Run ``python -m http.server`` in a subprocess and fetch a file."""
+
+    served_data = os.urandom(32)
+    served_file_name = 'served_filename'
+    tls_cert = certdata_file('ssl_cert.pem')
+    tls_key = certdata_file('ssl_key.pem')
+    tls_password = 'somepass'
+
+    def setUp(self):
+        super().setUp()
+        # Created in the current directory, which the spawned server is
+        # expected to serve by default (no --directory is passed).
+        with open(self.served_file_name, 'wb') as f:
+            f.write(self.served_data)
+        self.addCleanup(os_helper.unlink, self.served_file_name)
+        fd, self.tls_password_file = tempfile.mkstemp()
+        self.addCleanup(os_helper.unlink, self.tls_password_file)
+        with os.fdopen(fd, 'wb') as f:
+            f.write(self.tls_password.encode())
+
+    def fetch_file(self, path, context=None):
+        """GET *path* and return the response body as bytes.
+
+        *context* is an optional ssl.SSLContext for https:// URLs.  Plain
+        HTTP needs none, so the HTTP test also runs on builds without ssl.
+        """
+        req = urllib.request.Request(path, method='GET')
+        with urllib.request.urlopen(req, context=context) as res:
+            return res.read()
+
+    def parse_cli_output(self, output):
+        """Return (scheme, host, port) parsed from *output*.
+
+        The URL must appear in parentheses, e.g. "(http://127.0.0.1:8000/)";
+        return (None, None, None) when *output* contains no such URL.
+        """
+        matches = re.search(r'\((https?)://([^/:]+):(\d+)/?\)', output)
+        if matches is None:
+            return None, None, None
+        return matches.group(1), matches.group(2), int(matches.group(3))
+
+    def wait_for_server(self, proc, protocol, port, bind, timeout=50):
+        """Check the server process output.
+
+        Return True if the server was successfully started
+        and is listening on the given port and bind address.
+
+        *timeout* is a retry count, not seconds: each empty or unparsable
+        line costs one retry and a 0.1 second sleep.  The first line that
+        contains a URL decides the result.
+        """
+        while timeout > 0:
+            line = proc.stdout.readline()
+            if not line:
+                time.sleep(0.1)
+                timeout -= 1
+                continue
+            protocol_, host_, port_ = self.parse_cli_output(line)
+            if not protocol_ or not host_ or not port_:
+                time.sleep(0.1)
+                timeout -= 1
+                continue
+            if protocol_ == protocol and host_ == bind and port_ == port:
+                return True
+            break
+        return False
+
+    def test_http_client(self):
+        port = find_unused_port()
+        bind = '127.0.0.1'
+        proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind,
+                            bufsize=1, text=True)
+        # Cleanups run in reverse order: terminate the server first, then
+        # hand the process to kill_python().
+        self.addCleanup(kill_python, proc)
+        self.addCleanup(proc.terminate)
+        self.assertTrue(self.wait_for_server(proc, 'http', port, bind))
+        res = self.fetch_file(f'http://{bind}:{port}/{self.served_file_name}')
+        self.assertEqual(res, self.served_data)
+
+    @unittest.skipIf(ssl is None, "requires ssl")
+    def test_https_client(self):
+        context = ssl.create_default_context()
+        # allow self-signed certificates
+        context.check_hostname = False
+        context.verify_mode = ssl.CERT_NONE
+        port = find_unused_port()
+        bind = '127.0.0.1'
+        proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind,
+                            '--tls-cert', self.tls_cert,
+                            '--tls-key', self.tls_key,
+                            '--tls-password-file', self.tls_password_file,
+                            bufsize=1, text=True)
+        self.addCleanup(kill_python, proc)
+        self.addCleanup(proc.terminate)
+        self.assertTrue(self.wait_for_server(proc, 'https', port, bind))
+        url = f'https://{bind}:{port}/{self.served_file_name}'
+        res = self.fetch_file(url, context=context)
+        self.assertEqual(res, self.served_data)
+
+
 def setUpModule():
     unittest.addModuleCleanup(os.chdir, os.getcwd())
 

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: arch...@mail-archive.com

Reply via email to