delete mode 100644 testenv/FTPServer.py delete mode 100644 testenv/HTTPServer.py create mode 100644 testenv/misc/constants.py create mode 100644 testenv/server/__init__.py create mode 100644 testenv/server/ftp/__init__.py create mode 100644 testenv/server/ftp/ftp_server.py create mode 100644 testenv/server/http/__init__.py create mode 100644 testenv/server/http/http_server.py
diff --git a/testenv/ChangeLog b/testenv/ChangeLog index 73b92e7..390becc 100644 --- a/testenv/ChangeLog +++ b/testenv/ChangeLog @@ -1,4 +1,13 @@ 2014-03-13 Zihang Chen <[email protected]> + * server: (new package) package for the server classes + * server.http: (new package) package for the HTTP server + * server.ftp: (new package) package for the FTP server + * HTTPServer.py: Move to server/http/http_server.py. Also change the + CERTFILE to '../certs/wget-cert.pem'. + * FTPServer.py: Move to server/ftp/ftp_server.py. + * WgetTest.py: Update imports for the relocated server classes. + (HTTP, HTTPS): These two string constants are moved to misc/constants.py. +2014-03-13 Zihang Chen <[email protected]> * conf: (new package) package for rule classes and hook methods * WgetTest.py: (CommonMethods.Authentication): Move to conf/authentication.py. diff --git a/testenv/FTPServer.py b/testenv/FTPServer.py deleted file mode 100644 index f7d7771..0000000 --- a/testenv/FTPServer.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -import re -import threading -import socket -import pyftpdlib.__main__ -from pyftpdlib.ioloop import IOLoop -import pyftpdlib.handlers as Handle -from pyftpdlib.servers import FTPServer -from pyftpdlib.authorizers import DummyAuthorizer -from pyftpdlib._compat import PY3, u, b, getcwdu, callable -class FTPDHandler (Handle.FTPHandler): - - def ftp_LIST (self, path): - try: - iterator = self.run_as_current_user(self.fs.get_list_dir, path) - except (OSError, FilesystemError): - err = sys.exc_info()[1] - why = _strerror (err) - self.respond ('550 %s.
' % why) - else: - if self.isRule ("Bad List") is True: - iter_list = list () - for flist in iterator: - line = re.compile (r'(\s+)').split (flist.decode ('utf-8')) - line[8] = '0' - iter_l = ''.join (line).encode ('utf-8') - iter_list.append (iter_l) - iterator = (n for n in iter_list) - producer = Handle.BufferedIteratorProducer (iterator) - self.push_dtp_data (producer, isproducer=True, cmd="LIST") - return path - - def ftp_PASV (self, line): - if self._epsvall: - self.respond ("501 PASV not allowed after EPSV ALL.") - return - self._make_epasv(extmode=False) - if self.isRule ("FailPASV") is True: - del self.server.global_rules["FailPASV"] - self.socket.close () - - def isRule (self, rule): - rule_obj = self.server.global_rules[rule] - return False if not rule_obj else rule_obj[0] - -class FTPDServer (FTPServer): - - def set_global_rules (self, rules): - self.global_rules = rules - -class FTPd(threading.Thread): - """A threaded FTP server used for running tests. - - This is basically a modified version of the FTPServer class which - wraps the polling loop into a thread. - - The instance returned can be used to start(), stop() and - eventually re-start() the server. 
- """ - handler = FTPDHandler - server_class = FTPDServer - - def __init__(self, addr=None): - os.mkdir ('server') - os.chdir ('server') - try: - HOST = socket.gethostbyname ('localhost') - except socket.error: - HOST = 'localhost' - USER = 'user' - PASSWD = '12345' - HOME = getcwdu () - - threading.Thread.__init__(self) - self.__serving = False - self.__stopped = False - self.__lock = threading.Lock() - self.__flag = threading.Event() - if addr is None: - addr = (HOST, 0) - - authorizer = DummyAuthorizer() - authorizer.add_user(USER, PASSWD, HOME, perm='elradfmwM') # full perms - authorizer.add_anonymous(HOME) - self.handler.authorizer = authorizer - # lowering buffer sizes = more cycles to transfer data - # = less false positive test failures - self.handler.dtp_handler.ac_in_buffer_size = 32768 - self.handler.dtp_handler.ac_out_buffer_size = 32768 - self.server = self.server_class(addr, self.handler) - self.host, self.port = self.server.socket.getsockname()[:2] - os.chdir ('..') - - def set_global_rules (self, rules): - self.server.set_global_rules (rules) - - def __repr__(self): - status = [self.__class__.__module__ + "." + self.__class__.__name__] - if self.__serving: - status.append('active') - else: - status.append('inactive') - status.append('%s:%s' % self.server.socket.getsockname()[:2]) - return '<%s at %#x>' % (' '.join(status), id(self)) - - @property - def running(self): - return self.__serving - - def start(self, timeout=0.001): - """Start serving until an explicit stop() request. - Polls for shutdown every 'timeout' seconds. 
- """ - if self.__serving: - raise RuntimeError("Server already started") - if self.__stopped: - # ensure the server can be started again - FTPd.__init__(self, self.server.socket.getsockname(), self.handler) - self.__timeout = timeout - threading.Thread.start(self) - self.__flag.wait() - - def run(self): - self.__serving = True - self.__flag.set() - while self.__serving: - self.__lock.acquire() - self.server.serve_forever(timeout=self.__timeout, blocking=False) - self.__lock.release() - self.server.close_all() - - def stop(self): - """Stop serving (also disconnecting all currently connected - clients) by telling the serve_forever() loop to stop and - waits until it does. - """ - if not self.__serving: - raise RuntimeError("Server not started yet") - self.__serving = False - self.__stopped = True - self.join() - - -def mk_file_sys (file_list): - os.chdir ('server') - for name, content in file_list.items (): - file_h = open (name, 'w') - file_h.write (content) - file_h.close () - os.chdir ('..') - -def filesys (): - fileSys = dict () - os.chdir ('server') - for parent, dirs, files in os.walk ('.'): - for filename in files: - file_handle = open (filename, 'r') - file_content = file_handle.read () - fileSys[filename] = file_content - os.chdir ('..') - return fileSys diff --git a/testenv/HTTPServer.py b/testenv/HTTPServer.py deleted file mode 100644 index e554a10..0000000 --- a/testenv/HTTPServer.py +++ /dev/null @@ -1,467 +0,0 @@ -from http.server import HTTPServer, BaseHTTPRequestHandler -from socketserver import BaseServer -from posixpath import basename, splitext -from base64 import b64encode -from random import random -from hashlib import md5 -import threading -import socket -import re -import ssl -import os - - -class InvalidRangeHeader (Exception): - - """ Create an Exception for handling of invalid Range Headers. 
""" - # TODO: Eliminate this exception and use only ServerError - - def __init__ (self, err_message): - self.err_message = err_message - -class ServerError (Exception): - def __init__ (self, err_message): - self.err_message = err_message - - -class StoppableHTTPServer (HTTPServer): - - request_headers = list () - - """ Define methods for configuring the Server. """ - - def server_conf (self, filelist, conf_dict): - """ Set Server Rules and File System for this instance. """ - self.server_configs = conf_dict - self.fileSys = filelist - - def server_sett (self, settings): - for settings_key in settings: - setattr (self.RequestHandlerClass, settings_key, settings[settings_key]) - - def get_req_headers (self): - return self.request_headers - -class HTTPSServer (StoppableHTTPServer): - - def __init__ (self, address, handler): - BaseServer.__init__ (self, address, handler) - print (os.getcwd()) - CERTFILE = os.path.abspath (os.path.join ('..', 'certs', 'wget-cert.pem')) - print (CERTFILE) - fop = open (CERTFILE) - print (fop.readline()) - self.socket = ssl.wrap_socket ( - sock = socket.socket (self.address_family, self.socket_type), - ssl_version = ssl.PROTOCOL_TLSv1, - certfile = CERTFILE, - server_side = True - ) - self.server_bind () - self.server_activate () - -class WgetHTTPRequestHandler (BaseHTTPRequestHandler): - - """ Define methods for handling Test Checks. """ - - def get_rule_list (self, name): - r_list = self.rules.get (name) if name in self.rules else None - return r_list - - -class _Handler (WgetHTTPRequestHandler): - - """ Define Handler Methods for different Requests. """ - - InvalidRangeHeader = InvalidRangeHeader - protocol_version = 'HTTP/1.1' - - """ Define functions for various HTTP Requests. 
""" - - def do_HEAD (self): - self.send_head ("HEAD") - - def do_GET (self): - content, start = self.send_head ("GET") - if content: - if start is None: - self.wfile.write (content.encode ('utf-8')) - else: - self.wfile.write (content.encode ('utf-8')[start:]) - - def do_POST (self): - path = self.path[1:] - self.rules = self.server.server_configs.get (path) - if not self.custom_response (): - return (None, None) - if path in self.server.fileSys: - body_data = self.get_body_data () - self.send_response (200) - self.send_header ("Content-type", "text/plain") - content = self.server.fileSys.pop (path) + "\n" + body_data - total_length = len (content) - self.server.fileSys[path] = content - self.send_header ("Content-Length", total_length) - self.finish_headers () - try: - self.wfile.write (content.encode ('utf-8')) - except Exception: - pass - else: - self.send_put (path) - - def do_PUT (self): - path = self.path[1:] - self.rules = self.server.server_configs.get (path) - if not self.custom_response (): - return (None, None) - self.server.fileSys.pop (path, None) - self.send_put (path) - - """ End of HTTP Request Method Handlers. """ - - """ Helper functions for the Handlers. 
""" - - def parse_range_header (self, header_line, length): - if header_line is None: - return None - if not header_line.startswith ("bytes="): - raise InvalidRangeHeader ("Cannot parse header Range: %s" % - (header_line)) - regex = re.match (r"^bytes=(\d*)\-$", header_line) - range_start = int (regex.group (1)) - if range_start >= length: - raise InvalidRangeHeader ("Range Overflow") - return range_start - - def get_body_data (self): - cLength_header = self.headers.get ("Content-Length") - cLength = int (cLength_header) if cLength_header is not None else 0 - body_data = self.rfile.read (cLength).decode ('utf-8') - return body_data - - def send_put (self, path): - body_data = self.get_body_data () - self.send_response (201) - self.server.fileSys[path] = body_data - self.send_header ("Content-type", "text/plain") - self.send_header ("Content-Length", len (body_data)) - self.finish_headers () - try: - self.wfile.write (body_data.encode ('utf-8')) - except Exception: - pass - - def SendHeader (self, header_obj): - pass -# headers_list = header_obj.headers -# for header_line in headers_list: -# print (header_line + " : " + headers_list[header_line]) -# self.send_header (header_line, headers_list[header_line]) - - def send_cust_headers (self): - header_obj = self.get_rule_list ('SendHeader') - if header_obj: - for header in header_obj.headers: - self.send_header (header, header_obj.headers[header]) - - def finish_headers (self): - self.send_cust_headers () - self.end_headers () - - def Response (self, resp_obj): - self.send_response (resp_obj.response_code) - self.finish_headers () - raise ServerError ("Custom Response code sent.") - - def custom_response (self): - codes = self.get_rule_list ('Response') - if codes: - self.send_response (codes.response_code) - self.finish_headers () - return False - else: - return True - - def base64 (self, data): - string = b64encode (data.encode ('utf-8')) - return string.decode ('utf-8') - - def send_challenge (self, auth_type): - if 
auth_type == "Both": - self.send_challenge ("Digest") - self.send_challenge ("Basic") - return - if auth_type == "Basic": - challenge_str = 'Basic realm="Wget-Test"' - elif auth_type == "Digest" or auth_type == "Both_inline": - self.nonce = md5 (str (random ()).encode ('utf-8')).hexdigest () - self.opaque = md5 (str (random ()).encode ('utf-8')).hexdigest () - challenge_str = 'Digest realm="Test", nonce="%s", opaque="%s"' %( - self.nonce, - self.opaque) - challenge_str += ', qop="auth"' - if auth_type == "Both_inline": - challenge_str = 'Basic realm="Wget-Test", ' + challenge_str - self.send_header ("WWW-Authenticate", challenge_str) - - def authorize_Basic (self, auth_header, auth_rule): - if auth_header is None or auth_header.split(' ')[0] != 'Basic': - return False - else: - self.user = auth_rule.auth_user - self.passw = auth_rule.auth_pass - auth_str = "Basic " + self.base64 (self.user + ":" + self.passw) - return True if auth_str == auth_header else False - - def parse_auth_header (self, auth_header): - n = len("Digest ") - auth_header = auth_header[n:].strip() - items = auth_header.split(", ") - key_values = [i.split("=", 1) for i in items] - key_values = [(k.strip(), v.strip().replace('"', '')) for k, v in key_values] - return dict(key_values) - - def KD (self, secret, data): - return self.H (secret + ":" + data) - - def H (self, data): - return md5 (data.encode ('utf-8')).hexdigest () - - def A1 (self): - return "%s:%s:%s" % (self.user, "Test", self.passw) - - def A2 (self, params): - return "%s:%s" % (self.command, params["uri"]) - - def check_response (self, params): - if "qop" in params: - data_str = params['nonce'] \ - + ":" + params['nc'] \ - + ":" + params['cnonce'] \ - + ":" + params['qop'] \ - + ":" + self.H (self.A2 (params)) - else: - data_str = params['nonce'] + ":" + self.H (self.A2 (params)) - resp = self.KD (self.H (self.A1 ()), data_str) - - return True if resp == params['response'] else False - - def authorize_Digest (self, auth_header, 
auth_rule): - if auth_header is None or auth_header.split(' ')[0] != 'Digest': - return False - else: - self.user = auth_rule.auth_user - self.passw = auth_rule.auth_pass - params = self.parse_auth_header (auth_header) - pass_auth = True - if self.user != params['username'] or \ - self.nonce != params['nonce'] or self.opaque != params['opaque']: - pass_auth = False - req_attribs = ['username', 'realm', 'nonce', 'uri', 'response'] - for attrib in req_attribs: - if not attrib in params: - pass_auth = False - if not self.check_response (params): - pass_auth = False - return pass_auth - - def authorize_Both (self, auth_header, auth_rule): - return False - - def authorize_Both_inline (self, auth_header, auth_rule): - return False - - def Authentication (self, auth_rule): - try: - self.handle_auth (auth_rule) - except ServerError as se: - self.send_response (401, "Authorization Required") - self.send_challenge (auth_rule.auth_type) - self.finish_headers () - raise ServerError (se.__str__()) - - def handle_auth (self, auth_rule): - is_auth = True - auth_header = self.headers.get ("Authorization") - required_auth = auth_rule.auth_type - if required_auth == "Both" or required_auth == "Both_inline": - auth_type = auth_header.split(' ')[0] if auth_header else required_auth - else: - auth_type = required_auth - try: - assert hasattr (self, "authorize_" + auth_type) - is_auth = getattr (self, "authorize_" + auth_type) (auth_header, auth_rule) - except AssertionError: - raise ServerError ("Authentication Mechanism " + auth_rule + " not supported") - except AttributeError as ae: - raise ServerError (ae.__str__()) - if is_auth is False: - raise ServerError ("Unable to Authenticate") - - def is_authorized (self): - is_auth = True - auth_rule = self.get_rule_list ('Authentication') - if auth_rule: - auth_header = self.headers.get ("Authorization") - req_auth = auth_rule.auth_type - if req_auth == "Both" or req_auth == "Both_inline": - auth_type = auth_header.split(' ')[0] if 
auth_header else req_auth - else: - auth_type = req_auth - assert hasattr (self, "authorize_" + auth_type) - is_auth = getattr (self, "authorize_" + auth_type) (auth_header, auth_rule) - if is_auth is False: - self.send_response (401) - self.send_challenge (auth_type) - self.finish_headers () - return is_auth - - def ExpectHeader (self, header_obj): - exp_headers = header_obj.headers - for header_line in exp_headers: - header_recd = self.headers.get (header_line) - if header_recd is None or header_recd != exp_headers[header_line]: - self.send_error (400, "Expected Header " + header_line + " not found") - self.finish_headers () - raise ServerError ("Header " + header_line + " not found") - - def expect_headers (self): - """ This is modified code to handle a few changes. Should be removed ASAP """ - exp_headers_obj = self.get_rule_list ('ExpectHeader') - if exp_headers_obj: - exp_headers = exp_headers_obj.headers - for header_line in exp_headers: - header_re = self.headers.get (header_line) - if header_re is None or header_re != exp_headers[header_line]: - self.send_error (400, 'Expected Header not Found') - self.end_headers () - return False - return True - - def RejectHeader (self, header_obj): - rej_headers = header_obj.headers - for header_line in rej_headers: - header_recd = self.headers.get (header_line) - if header_recd is not None and header_recd == rej_headers[header_line]: - self.send_error (400, 'Blackisted Header ' + header_line + ' received') - self.finish_headers () - raise ServerError ("Header " + header_line + ' received') - - def reject_headers (self): - rej_headers = self.get_rule_list ("RejectHeader") - if rej_headers: - rej_headers = rej_headers.headers - for header_line in rej_headers: - header_re = self.headers.get (header_line) - if header_re is not None and header_re == rej_headers[header_line]: - self.send_error (400, 'Blacklisted Header was Sent') - self.end_headers () - return False - return True - - def __log_request (self, method): - req 
= method + " " + self.path - self.server.request_headers.append (req) - - def send_head (self, method): - """ Common code for GET and HEAD Commands. - This method is overriden to use the fileSys dict. - - The method variable contains whether this was a HEAD or a GET Request. - According to RFC 2616, the server should not differentiate between - the two requests, however, we use it here for a specific test. - """ - - if self.path == "/": - path = "index.html" - else: - path = self.path[1:] - - self.__log_request (method) - - if path in self.server.fileSys: - self.rules = self.server.server_configs.get (path) - - for rule_name in self.rules: - try: - assert hasattr (self, rule_name) - getattr (self, rule_name) (self.rules [rule_name]) - except AssertionError as ae: - msg = "Method " + rule_name + " not defined" - self.send_error (500, msg) - return (None, None) - except ServerError as se: - print (se.__str__()) - return (None, None) - - content = self.server.fileSys.get (path) - content_length = len (content) - try: - self.range_begin = self.parse_range_header ( - self.headers.get ("Range"), content_length) - except InvalidRangeHeader as ae: - # self.log_error("%s", ae.err_message) - if ae.err_message == "Range Overflow": - self.send_response (416) - self.finish_headers () - return (None, None) - else: - self.range_begin = None - if self.range_begin is None: - self.send_response (200) - else: - self.send_response (206) - self.send_header ("Accept-Ranges", "bytes") - self.send_header ("Content-Range", - "bytes %d-%d/%d" % (self.range_begin, - content_length - 1, - content_length)) - content_length -= self.range_begin - cont_type = self.guess_type (path) - self.send_header ("Content-type", cont_type) - self.send_header ("Content-Length", content_length) - self.finish_headers () - return (content, self.range_begin) - else: - self.send_error (404, "Not Found") - return (None, None) - - def guess_type (self, path): - base_name = basename ("/" + path) - name, ext = 
splitext (base_name) - extension_map = { - ".txt" : "text/plain", - ".css" : "text/css", - ".html" : "text/html" - } - if ext in extension_map: - return extension_map[ext] - else: - return "text/plain" - - -class HTTPd (threading.Thread): - server_class = StoppableHTTPServer - handler = _Handler - def __init__ (self, addr=None): - threading.Thread.__init__ (self) - if addr is None: - addr = ('localhost', 0) - self.server_inst = self.server_class (addr, self.handler) - self.server_address = self.server_inst.socket.getsockname()[:2] - - def run (self): - self.server_inst.serve_forever () - - def server_conf (self, file_list, server_rules): - self.server_inst.server_conf (file_list, server_rules) - - def server_sett (self, settings): - self.server_inst.server_sett (settings) - -class HTTPSd (HTTPd): - - server_class = HTTPSServer - -# vim: set ts=4 sts=4 sw=4 tw=80 et : diff --git a/testenv/WgetTest.py b/testenv/WgetTest.py index 6076012..92e4138 100644 --- a/testenv/WgetTest.py +++ b/testenv/WgetTest.py @@ -8,15 +8,11 @@ import time from subprocess import call from difflib import unified_diff -import HTTPServer import conf from exc.test_failed import TestFailed from misc.colour_terminal import print_red, print_green, print_blue - - -HTTP = "HTTP" -HTTPS = "HTTPS" - +from misc.constants import HTTP, HTTPS +from server.http import http_server """ Class that defines methods common to both HTTP and FTP Tests. 
""" @@ -220,12 +216,12 @@ class HTTPTest (CommonMethods): self.hook_call(post_hook, 'Post Test Function') def init_HTTP_Server (self): - server = HTTPServer.HTTPd () + server = http_server.HTTPd () server.start () return server def init_HTTPS_Server (self): - server = HTTPServer.HTTPSd () + server = http_server.HTTPSd () server.start () return server diff --git a/testenv/misc/constants.py b/testenv/misc/constants.py new file mode 100644 index 0000000..5fad2f8 --- /dev/null +++ b/testenv/misc/constants.py @@ -0,0 +1,3 @@ + +HTTP = "HTTP" +HTTPS = "HTTPS" \ No newline at end of file diff --git a/testenv/server/__init__.py b/testenv/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/testenv/server/ftp/__init__.py b/testenv/server/ftp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/testenv/server/ftp/ftp_server.py b/testenv/server/ftp/ftp_server.py new file mode 100644 index 0000000..f7d7771 --- /dev/null +++ b/testenv/server/ftp/ftp_server.py @@ -0,0 +1,162 @@ +import os +import re +import threading +import socket +import pyftpdlib.__main__ +from pyftpdlib.ioloop import IOLoop +import pyftpdlib.handlers as Handle +from pyftpdlib.servers import FTPServer +from pyftpdlib.authorizers import DummyAuthorizer +from pyftpdlib._compat import PY3, u, b, getcwdu, callable + +class FTPDHandler (Handle.FTPHandler): + + def ftp_LIST (self, path): + try: + iterator = self.run_as_current_user(self.fs.get_list_dir, path) + except (OSError, FilesystemError): + err = sys.exc_info()[1] + why = _strerror (err) + self.respond ('550 %s. 
' % why) + else: + if self.isRule ("Bad List") is True: + iter_list = list () + for flist in iterator: + line = re.compile (r'(\s+)').split (flist.decode ('utf-8')) + line[8] = '0' + iter_l = ''.join (line).encode ('utf-8') + iter_list.append (iter_l) + iterator = (n for n in iter_list) + producer = Handle.BufferedIteratorProducer (iterator) + self.push_dtp_data (producer, isproducer=True, cmd="LIST") + return path + + def ftp_PASV (self, line): + if self._epsvall: + self.respond ("501 PASV not allowed after EPSV ALL.") + return + self._make_epasv(extmode=False) + if self.isRule ("FailPASV") is True: + del self.server.global_rules["FailPASV"] + self.socket.close () + + def isRule (self, rule): + rule_obj = self.server.global_rules[rule] + return False if not rule_obj else rule_obj[0] + +class FTPDServer (FTPServer): + + def set_global_rules (self, rules): + self.global_rules = rules + +class FTPd(threading.Thread): + """A threaded FTP server used for running tests. + + This is basically a modified version of the FTPServer class which + wraps the polling loop into a thread. + + The instance returned can be used to start(), stop() and + eventually re-start() the server. 
+ """ + handler = FTPDHandler + server_class = FTPDServer + + def __init__(self, addr=None): + os.mkdir ('server') + os.chdir ('server') + try: + HOST = socket.gethostbyname ('localhost') + except socket.error: + HOST = 'localhost' + USER = 'user' + PASSWD = '12345' + HOME = getcwdu () + + threading.Thread.__init__(self) + self.__serving = False + self.__stopped = False + self.__lock = threading.Lock() + self.__flag = threading.Event() + if addr is None: + addr = (HOST, 0) + + authorizer = DummyAuthorizer() + authorizer.add_user(USER, PASSWD, HOME, perm='elradfmwM') # full perms + authorizer.add_anonymous(HOME) + self.handler.authorizer = authorizer + # lowering buffer sizes = more cycles to transfer data + # = less false positive test failures + self.handler.dtp_handler.ac_in_buffer_size = 32768 + self.handler.dtp_handler.ac_out_buffer_size = 32768 + self.server = self.server_class(addr, self.handler) + self.host, self.port = self.server.socket.getsockname()[:2] + os.chdir ('..') + + def set_global_rules (self, rules): + self.server.set_global_rules (rules) + + def __repr__(self): + status = [self.__class__.__module__ + "." + self.__class__.__name__] + if self.__serving: + status.append('active') + else: + status.append('inactive') + status.append('%s:%s' % self.server.socket.getsockname()[:2]) + return '<%s at %#x>' % (' '.join(status), id(self)) + + @property + def running(self): + return self.__serving + + def start(self, timeout=0.001): + """Start serving until an explicit stop() request. + Polls for shutdown every 'timeout' seconds. 
+ """ + if self.__serving: + raise RuntimeError("Server already started") + if self.__stopped: + # ensure the server can be started again + FTPd.__init__(self, self.server.socket.getsockname(), self.handler) + self.__timeout = timeout + threading.Thread.start(self) + self.__flag.wait() + + def run(self): + self.__serving = True + self.__flag.set() + while self.__serving: + self.__lock.acquire() + self.server.serve_forever(timeout=self.__timeout, blocking=False) + self.__lock.release() + self.server.close_all() + + def stop(self): + """Stop serving (also disconnecting all currently connected + clients) by telling the serve_forever() loop to stop and + waits until it does. + """ + if not self.__serving: + raise RuntimeError("Server not started yet") + self.__serving = False + self.__stopped = True + self.join() + + +def mk_file_sys (file_list): + os.chdir ('server') + for name, content in file_list.items (): + file_h = open (name, 'w') + file_h.write (content) + file_h.close () + os.chdir ('..') + +def filesys (): + fileSys = dict () + os.chdir ('server') + for parent, dirs, files in os.walk ('.'): + for filename in files: + file_handle = open (filename, 'r') + file_content = file_handle.read () + fileSys[filename] = file_content + os.chdir ('..') + return fileSys diff --git a/testenv/server/http/__init__.py b/testenv/server/http/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/testenv/server/http/http_server.py b/testenv/server/http/http_server.py new file mode 100644 index 0000000..946fb79 --- /dev/null +++ b/testenv/server/http/http_server.py @@ -0,0 +1,467 @@ +from http.server import HTTPServer, BaseHTTPRequestHandler +from socketserver import BaseServer +from posixpath import basename, splitext +from base64 import b64encode +from random import random +from hashlib import md5 +import threading +import socket +import re +import ssl +import os + + +class InvalidRangeHeader (Exception): + + """ Create an Exception for handling of invalid Range 
Headers. """ + # TODO: Eliminate this exception and use only ServerError + + def __init__ (self, err_message): + self.err_message = err_message + +class ServerError (Exception): + def __init__ (self, err_message): + self.err_message = err_message + + +class StoppableHTTPServer (HTTPServer): + + request_headers = list () + + """ Define methods for configuring the Server. """ + + def server_conf (self, filelist, conf_dict): + """ Set Server Rules and File System for this instance. """ + self.server_configs = conf_dict + self.fileSys = filelist + + def server_sett (self, settings): + for settings_key in settings: + setattr (self.RequestHandlerClass, settings_key, settings[settings_key]) + + def get_req_headers (self): + return self.request_headers + +class HTTPSServer (StoppableHTTPServer): + + def __init__ (self, address, handler): + BaseServer.__init__ (self, address, handler) + print (os.getcwd()) + CERTFILE = os.path.abspath (os.path.join ('../', 'certs', 'wget-cert.pem')) + print (CERTFILE) + fop = open (CERTFILE) + print (fop.readline()) + self.socket = ssl.wrap_socket ( + sock = socket.socket (self.address_family, self.socket_type), + ssl_version = ssl.PROTOCOL_TLSv1, + certfile = CERTFILE, + server_side = True + ) + self.server_bind () + self.server_activate () + +class WgetHTTPRequestHandler (BaseHTTPRequestHandler): + + """ Define methods for handling Test Checks. """ + + def get_rule_list (self, name): + r_list = self.rules.get (name) if name in self.rules else None + return r_list + + +class _Handler (WgetHTTPRequestHandler): + + """ Define Handler Methods for different Requests. """ + + InvalidRangeHeader = InvalidRangeHeader + protocol_version = 'HTTP/1.1' + + """ Define functions for various HTTP Requests. 
""" + + def do_HEAD (self): + self.send_head ("HEAD") + + def do_GET (self): + content, start = self.send_head ("GET") + if content: + if start is None: + self.wfile.write (content.encode ('utf-8')) + else: + self.wfile.write (content.encode ('utf-8')[start:]) + + def do_POST (self): + path = self.path[1:] + self.rules = self.server.server_configs.get (path) + if not self.custom_response (): + return (None, None) + if path in self.server.fileSys: + body_data = self.get_body_data () + self.send_response (200) + self.send_header ("Content-type", "text/plain") + content = self.server.fileSys.pop (path) + "\n" + body_data + total_length = len (content) + self.server.fileSys[path] = content + self.send_header ("Content-Length", total_length) + self.finish_headers () + try: + self.wfile.write (content.encode ('utf-8')) + except Exception: + pass + else: + self.send_put (path) + + def do_PUT (self): + path = self.path[1:] + self.rules = self.server.server_configs.get (path) + if not self.custom_response (): + return (None, None) + self.server.fileSys.pop (path, None) + self.send_put (path) + + """ End of HTTP Request Method Handlers. """ + + """ Helper functions for the Handlers. 
""" + + def parse_range_header (self, header_line, length): + if header_line is None: + return None + if not header_line.startswith ("bytes="): + raise InvalidRangeHeader ("Cannot parse header Range: %s" % + (header_line)) + regex = re.match (r"^bytes=(\d*)\-$", header_line) + range_start = int (regex.group (1)) + if range_start >= length: + raise InvalidRangeHeader ("Range Overflow") + return range_start + + def get_body_data (self): + cLength_header = self.headers.get ("Content-Length") + cLength = int (cLength_header) if cLength_header is not None else 0 + body_data = self.rfile.read (cLength).decode ('utf-8') + return body_data + + def send_put (self, path): + body_data = self.get_body_data () + self.send_response (201) + self.server.fileSys[path] = body_data + self.send_header ("Content-type", "text/plain") + self.send_header ("Content-Length", len (body_data)) + self.finish_headers () + try: + self.wfile.write (body_data.encode ('utf-8')) + except Exception: + pass + + def SendHeader (self, header_obj): + pass +# headers_list = header_obj.headers +# for header_line in headers_list: +# print (header_line + " : " + headers_list[header_line]) +# self.send_header (header_line, headers_list[header_line]) + + def send_cust_headers (self): + header_obj = self.get_rule_list ('SendHeader') + if header_obj: + for header in header_obj.headers: + self.send_header (header, header_obj.headers[header]) + + def finish_headers (self): + self.send_cust_headers () + self.end_headers () + + def Response (self, resp_obj): + self.send_response (resp_obj.response_code) + self.finish_headers () + raise ServerError ("Custom Response code sent.") + + def custom_response (self): + codes = self.get_rule_list ('Response') + if codes: + self.send_response (codes.response_code) + self.finish_headers () + return False + else: + return True + + def base64 (self, data): + string = b64encode (data.encode ('utf-8')) + return string.decode ('utf-8') + + def send_challenge (self, auth_type): + if 
auth_type == "Both": + self.send_challenge ("Digest") + self.send_challenge ("Basic") + return + if auth_type == "Basic": + challenge_str = 'Basic realm="Wget-Test"' + elif auth_type == "Digest" or auth_type == "Both_inline": + self.nonce = md5 (str (random ()).encode ('utf-8')).hexdigest () + self.opaque = md5 (str (random ()).encode ('utf-8')).hexdigest () + challenge_str = 'Digest realm="Test", nonce="%s", opaque="%s"' %( + self.nonce, + self.opaque) + challenge_str += ', qop="auth"' + if auth_type == "Both_inline": + challenge_str = 'Basic realm="Wget-Test", ' + challenge_str + self.send_header ("WWW-Authenticate", challenge_str) + + def authorize_Basic (self, auth_header, auth_rule): + if auth_header is None or auth_header.split(' ')[0] != 'Basic': + return False + else: + self.user = auth_rule.auth_user + self.passw = auth_rule.auth_pass + auth_str = "Basic " + self.base64 (self.user + ":" + self.passw) + return True if auth_str == auth_header else False + + def parse_auth_header (self, auth_header): + n = len("Digest ") + auth_header = auth_header[n:].strip() + items = auth_header.split(", ") + key_values = [i.split("=", 1) for i in items] + key_values = [(k.strip(), v.strip().replace('"', '')) for k, v in key_values] + return dict(key_values) + + def KD (self, secret, data): + return self.H (secret + ":" + data) + + def H (self, data): + return md5 (data.encode ('utf-8')).hexdigest () + + def A1 (self): + return "%s:%s:%s" % (self.user, "Test", self.passw) + + def A2 (self, params): + return "%s:%s" % (self.command, params["uri"]) + + def check_response (self, params): + if "qop" in params: + data_str = params['nonce'] \ + + ":" + params['nc'] \ + + ":" + params['cnonce'] \ + + ":" + params['qop'] \ + + ":" + self.H (self.A2 (params)) + else: + data_str = params['nonce'] + ":" + self.H (self.A2 (params)) + resp = self.KD (self.H (self.A1 ()), data_str) + + return True if resp == params['response'] else False + + def authorize_Digest (self, auth_header, 
auth_rule): + if auth_header is None or auth_header.split(' ')[0] != 'Digest': + return False + else: + self.user = auth_rule.auth_user + self.passw = auth_rule.auth_pass + params = self.parse_auth_header (auth_header) + pass_auth = True + if self.user != params['username'] or \ + self.nonce != params['nonce'] or self.opaque != params['opaque']: + pass_auth = False + req_attribs = ['username', 'realm', 'nonce', 'uri', 'response'] + for attrib in req_attribs: + if not attrib in params: + pass_auth = False + if not self.check_response (params): + pass_auth = False + return pass_auth + + def authorize_Both (self, auth_header, auth_rule): + return False + + def authorize_Both_inline (self, auth_header, auth_rule): + return False + + def Authentication (self, auth_rule): + try: + self.handle_auth (auth_rule) + except ServerError as se: + self.send_response (401, "Authorization Required") + self.send_challenge (auth_rule.auth_type) + self.finish_headers () + raise ServerError (se.__str__()) + + def handle_auth (self, auth_rule): + is_auth = True + auth_header = self.headers.get ("Authorization") + required_auth = auth_rule.auth_type + if required_auth == "Both" or required_auth == "Both_inline": + auth_type = auth_header.split(' ')[0] if auth_header else required_auth + else: + auth_type = required_auth + try: + assert hasattr (self, "authorize_" + auth_type) + is_auth = getattr (self, "authorize_" + auth_type) (auth_header, auth_rule) + except AssertionError: + raise ServerError ("Authentication Mechanism " + auth_rule + " not supported") + except AttributeError as ae: + raise ServerError (ae.__str__()) + if is_auth is False: + raise ServerError ("Unable to Authenticate") + + def is_authorized (self): + is_auth = True + auth_rule = self.get_rule_list ('Authentication') + if auth_rule: + auth_header = self.headers.get ("Authorization") + req_auth = auth_rule.auth_type + if req_auth == "Both" or req_auth == "Both_inline": + auth_type = auth_header.split(' ')[0] if 
auth_header else req_auth + else: + auth_type = req_auth + assert hasattr (self, "authorize_" + auth_type) + is_auth = getattr (self, "authorize_" + auth_type) (auth_header, auth_rule) + if is_auth is False: + self.send_response (401) + self.send_challenge (auth_type) + self.finish_headers () + return is_auth + + def ExpectHeader (self, header_obj): + exp_headers = header_obj.headers + for header_line in exp_headers: + header_recd = self.headers.get (header_line) + if header_recd is None or header_recd != exp_headers[header_line]: + self.send_error (400, "Expected Header " + header_line + " not found") + self.finish_headers () + raise ServerError ("Header " + header_line + " not found") + + def expect_headers (self): + """ This is modified code to handle a few changes. Should be removed ASAP """ + exp_headers_obj = self.get_rule_list ('ExpectHeader') + if exp_headers_obj: + exp_headers = exp_headers_obj.headers + for header_line in exp_headers: + header_re = self.headers.get (header_line) + if header_re is None or header_re != exp_headers[header_line]: + self.send_error (400, 'Expected Header not Found') + self.end_headers () + return False + return True + + def RejectHeader (self, header_obj): + rej_headers = header_obj.headers + for header_line in rej_headers: + header_recd = self.headers.get (header_line) + if header_recd is not None and header_recd == rej_headers[header_line]: + self.send_error (400, 'Blackisted Header ' + header_line + ' received') + self.finish_headers () + raise ServerError ("Header " + header_line + ' received') + + def reject_headers (self): + rej_headers = self.get_rule_list ("RejectHeader") + if rej_headers: + rej_headers = rej_headers.headers + for header_line in rej_headers: + header_re = self.headers.get (header_line) + if header_re is not None and header_re == rej_headers[header_line]: + self.send_error (400, 'Blacklisted Header was Sent') + self.end_headers () + return False + return True + + def __log_request (self, method): + req 
= method + " " + self.path + self.server.request_headers.append (req) + + def send_head (self, method): + """ Common code for GET and HEAD Commands. + This method is overriden to use the fileSys dict. + + The method variable contains whether this was a HEAD or a GET Request. + According to RFC 2616, the server should not differentiate between + the two requests, however, we use it here for a specific test. + """ + + if self.path == "/": + path = "index.html" + else: + path = self.path[1:] + + self.__log_request (method) + + if path in self.server.fileSys: + self.rules = self.server.server_configs.get (path) + + for rule_name in self.rules: + try: + assert hasattr (self, rule_name) + getattr (self, rule_name) (self.rules [rule_name]) + except AssertionError as ae: + msg = "Method " + rule_name + " not defined" + self.send_error (500, msg) + return (None, None) + except ServerError as se: + print (se.__str__()) + return (None, None) + + content = self.server.fileSys.get (path) + content_length = len (content) + try: + self.range_begin = self.parse_range_header ( + self.headers.get ("Range"), content_length) + except InvalidRangeHeader as ae: + # self.log_error("%s", ae.err_message) + if ae.err_message == "Range Overflow": + self.send_response (416) + self.finish_headers () + return (None, None) + else: + self.range_begin = None + if self.range_begin is None: + self.send_response (200) + else: + self.send_response (206) + self.send_header ("Accept-Ranges", "bytes") + self.send_header ("Content-Range", + "bytes %d-%d/%d" % (self.range_begin, + content_length - 1, + content_length)) + content_length -= self.range_begin + cont_type = self.guess_type (path) + self.send_header ("Content-type", cont_type) + self.send_header ("Content-Length", content_length) + self.finish_headers () + return (content, self.range_begin) + else: + self.send_error (404, "Not Found") + return (None, None) + + def guess_type (self, path): + base_name = basename ("/" + path) + name, ext = 
splitext (base_name) + extension_map = { + ".txt" : "text/plain", + ".css" : "text/css", + ".html" : "text/html" + } + if ext in extension_map: + return extension_map[ext] + else: + return "text/plain" + + +class HTTPd (threading.Thread): + server_class = StoppableHTTPServer + handler = _Handler + def __init__ (self, addr=None): + threading.Thread.__init__ (self) + if addr is None: + addr = ('localhost', 0) + self.server_inst = self.server_class (addr, self.handler) + self.server_address = self.server_inst.socket.getsockname()[:2] + + def run (self): + self.server_inst.serve_forever () + + def server_conf (self, file_list, server_rules): + self.server_inst.server_conf (file_list, server_rules) + + def server_sett (self, settings): + self.server_inst.server_sett (settings) + +class HTTPSd (HTTPd): + + server_class = HTTPSServer + +# vim: set ts=4 sts=4 sw=4 tw=80 et : -- 1.8.3.2
