Changeset: 160a16d35a5e for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=160a16d35a5e
Added Files:
    testing/tfducktest.py
Modified Files:
    testing/Mz.py.in
    testing/sqllogictest.py
Branch: default
Log Message:
multiple connections in sqllogic tests poc diffs (253 lines): diff --git a/testing/Mz.py.in b/testing/Mz.py.in --- a/testing/Mz.py.in +++ b/testing/Mz.py.in @@ -2841,7 +2841,7 @@ def main(argv) : prgreen('OK') print() print('failed={}, skipped={}'.format(failed, skipped)) - print('Ran {} test in {:7.3f}'.format(test_count - skipped, t_)) + print('Ran {} test in {:7.3f}s'.format(test_count - skipped, t_)) if verbose: for TSTDIR, TST in FAILED_TESTS: prred('ERROR\t') diff --git a/testing/sqllogictest.py b/testing/sqllogictest.py --- a/testing/sqllogictest.py +++ b/testing/sqllogictest.py @@ -41,6 +41,21 @@ skipidx = re.compile(r'create index .* \ class SQLLogicSyntaxError(Exception): pass +class SQLLogicConnection(object): + def __init__(self, conn_id, dbh, crs=None, language='sql'): + self.conn_id = conn_id + self.dbh = dbh + self.crs = crs + self.language = language + + def cursor(self): + if self.crs: + return self.crs + if self.language == 'sql': + return self.dbh.cursor() + return MapiCursor(self.dbh) + + def is_copyfrom_stmt(stmt:[str]=[]): try: index = stmt.index('<COPY_INTO_DATA>') @@ -66,6 +81,32 @@ def prepare_copyfrom_stmt(stmt:[str]=[]) except ValueError: return stmt +def parse_connection_string(s: str) -> dict: + '''parse strings like @connection(id=con1, ...) 
+ ''' + res = dict() + if not (s.startswith('@connection(') and s.endswith(')')): + raise SQLLogicSyntaxError('invalid connection string!') + params = s[12:-1].split(',') + for p in params: + p = p.strip() + try: + k, v = p.split('=') + if k == 'id': + k = 'conn_id' + assert k in ['conn_id', 'username', 'password'] + assert res.get(k) is None + res[k] = v + except (ValueError, AssertionError) as e: + raise SQLLogicSyntaxError('invalid connection paramaters definition!') + if len(res.keys()) > 1: + try: + assert res.get('username') + assert res.get('password') + except AssertionError as e: + raise SQLLogicSyntaxError('invalid connection paramaters definition, username or password missing!') + return res + class SQLLogic: def __init__(self, report=None, out=sys.stdout): self.dbh = None @@ -74,6 +115,10 @@ class SQLLogic: self.res = None self.rpt = report self.language = 'sql' + self.conn_map = dict() + self.database = None + self.hostname = None + self.port = None def __enter__(self): return self @@ -84,6 +129,9 @@ class SQLLogic: def connect(self, username='monetdb', password='monetdb', hostname='localhost', port=None, database='demo', language='sql'): self.language = language + self.hostname = hostname + self.port = port + self.database = database if language == 'sql': self.dbh = pymonetdb.connect(username=username, password=password, @@ -103,7 +151,41 @@ class SQLLogic: port=port) self.crs = MapiCursor(dbh) + def add_connection(self, conn_id, username='monetdb', password='monetdb'): + if self.conn_map.get(conn_id, None) is None: + hostname = self.hostname + port = self.port + database = self.database + language = self.language + if language == 'sql': + dbh = pymonetdb.connect(username=username, + password=password, + hostname=hostname, + port=port, + database=database, + autocommit=True) + crs = dbh.cursor() + else: + dbh = malmapi.Connection() + dbh.connect(database=database, + username=username, + password=password, + language=language, + hostname=hostname, + 
port=port) + crs = MapiCursor(dbh) + conn = SQLLogicConnection(conn_id, dbh=dbh, crs=crs, language=language) + self.conn_map[conn_id] = conn + return conn + + def get_connection(self, conn_id): + return self.conn_map.get(conn_id) + def close(self): + for k in self.conn_map: + conn = self.conn_map[k] + conn.dbh.close() + self.conn_map.clear() if self.crs: self.crs.close() self.crs = None @@ -111,6 +193,7 @@ class SQLLogic: self.dbh.close() self.dbh = None + def drop(self): if self.language != 'sql': return @@ -154,12 +237,13 @@ class SQLLogic: except pymonetdb.Error: pass - def exec_statement(self, statement, expectok, err_stmt=None, expected_err_code=None, expected_err_msg=None, expected_rowcount=None): + def exec_statement(self, statement, expectok, err_stmt=None, expected_err_code=None, expected_err_msg=None, expected_rowcount=None, conn=None): + crs = conn.cursor() if conn else self.crs if skipidx.search(statement) is not None: # skip creation of ascending or descending index return try: - affected_rowcount = self.crs.execute(statement) + affected_rowcount = crs.execute(statement) except (pymonetdb.Error, ValueError) as e: msg = e.args[0] if not expectok: @@ -259,20 +343,21 @@ class SQLLogic: sep = '|' print('', file=self.out) - def exec_query(self, query, columns, sorting, pyscript, hashlabel, nresult, hash, expected) -> bool: + def exec_query(self, query, columns, sorting, pyscript, hashlabel, nresult, hash, expected, conn=None) -> bool: err = False + crs = conn.cursor() if conn else self.crs try: - self.crs.execute(query) + crs.execute(query) except (pymonetdb.Error, ValueError) as e: self.query_error(query, 'query failed', e.args[0]) return False - data = self.crs.fetchall() - if self.crs.description: - if len(self.crs.description) != len(columns): - self.query_error(query, 'received {} columns, expected {} columns'.format(len(self.crs.description), len(columns)), data=data) + data = crs.fetchall() + if crs.description: + if len(crs.description) != 
len(columns): + self.query_error(query, 'received {} columns, expected {} columns'.format(len(crs.description), len(columns)), data=data) return False - if sorting != 'python' and self.crs.rowcount * len(columns) != nresult: - self.query_error(query, 'received {} rows, expected {} rows'.format(self.crs.rowcount, nresult // len(columns)), data=data) + if sorting != 'python' and crs.rowcount * len(columns) != nresult: + self.query_error(query, 'received {} rows, expected {} rows'.format(crs.rowcount, nresult // len(columns)), data=data) return False if self.res is not None: for row in data: @@ -400,6 +485,12 @@ class SQLLogic: break if line[0] == '#': # skip mal comments break + conn = None + # look for connection string + if line.startswith('@connection'): + conn_params = parse_connection_string(line) + conn = self.get_connection(conn_params.get('conn_id')) or self.add_connection(**conn_params) + line = self.readline() line = line.split() if not line: continue @@ -434,9 +525,9 @@ class SQLLogic: if not skipping: if is_copyfrom_stmt(statement): stmt, stmt_less_data = prepare_copyfrom_stmt(statement) - self.exec_statement(stmt, expectok, err_stmt=stmt_less_data, expected_err_code=expected_err_code, expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount) + self.exec_statement(stmt, expectok, err_stmt=stmt_less_data, expected_err_code=expected_err_code, expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount, conn=conn) else: - self.exec_statement('\n'.join(statement), expectok, expected_err_code=expected_err_code, expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount) + self.exec_statement('\n'.join(statement), expectok, expected_err_code=expected_err_code, expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount, conn=conn) elif line[0] == 'query': columns = line[1] pyscript = None @@ -475,7 +566,7 @@ class SQLLogic: line = self.readline() nresult = len(expected) if not skipping: - 
self.exec_query('\n'.join(query), columns, sorting, pyscript, hashlabel, nresult, hash, expected) + self.exec_query('\n'.join(query), columns, sorting, pyscript, hashlabel, nresult, hash, expected, conn=conn) if __name__ == '__main__': import argparse diff --git a/testing/tfducktest.py b/testing/tfducktest.py new file mode 100644 --- /dev/null +++ b/testing/tfducktest.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +import sys + +def main(): + buff = [] + while True: + line = sys.stdin.readline() + if not line: + break + sline = line.strip() + if sline.startswith('statement') or sline.startswith('query'): + words = sline.split() + if len(words) == 3: + third = words[2] + if third.lower() not in ['rowsort', 'valuesort', 'nosort']: + # must be connection str + buff.append(f'@connection(id={third})\n') + #strip last word + buff.append(' '.join(words[:2]) + '\n') + continue + buff.append(line) + print(''.join(buff)) + +if __name__ == '__main__': + main() + _______________________________________________ checkin-list mailing list checkin-list@monetdb.org https://www.monetdb.org/mailman/listinfo/checkin-list