Changeset: 160a16d35a5e for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=160a16d35a5e
Added Files:
testing/tfducktest.py
Modified Files:
testing/Mz.py.in
testing/sqllogictest.py
Branch: default
Log Message:
multiple connections in sqllogic tests poc
diffs (253 lines):
diff --git a/testing/Mz.py.in b/testing/Mz.py.in
--- a/testing/Mz.py.in
+++ b/testing/Mz.py.in
@@ -2841,7 +2841,7 @@ def main(argv) :
prgreen('OK')
print()
print('failed={}, skipped={}'.format(failed, skipped))
- print('Ran {} test in {:7.3f}'.format(test_count - skipped, t_))
+ print('Ran {} test in {:7.3f}s'.format(test_count - skipped, t_))
if verbose:
for TSTDIR, TST in FAILED_TESTS:
prred('ERROR\t')
diff --git a/testing/sqllogictest.py b/testing/sqllogictest.py
--- a/testing/sqllogictest.py
+++ b/testing/sqllogictest.py
@@ -41,6 +41,21 @@ skipidx = re.compile(r'create index .* \
class SQLLogicSyntaxError(Exception):
pass
+class SQLLogicConnection(object):
+    '''One named database connection used by a sqllogic test.
+
+    Holds the connection handle (dbh), an optional pre-made cursor (crs)
+    and the language ('sql', or another mapi language) it was opened with.
+    '''
+    def __init__(self, conn_id, dbh, crs=None, language='sql'):
+        self.conn_id = conn_id
+        self.dbh = dbh
+        self.crs = crs
+        self.language = language
+
+    def cursor(self):
+        '''Return this connection's cursor, creating it on first use.
+
+        The cursor is cached in self.crs so repeated calls (one per executed
+        statement/query) reuse it instead of opening a fresh cursor each time
+        that nobody ever closes.
+        '''
+        if self.crs is None:
+            if self.language == 'sql':
+                self.crs = self.dbh.cursor()
+            else:
+                self.crs = MapiCursor(self.dbh)
+        return self.crs
+
def is_copyfrom_stmt(stmt:[str]=[]):
try:
index = stmt.index('<COPY_INTO_DATA>')
@@ -66,6 +81,32 @@ def prepare_copyfrom_stmt(stmt:[str]=[])
except ValueError:
return stmt
+def parse_connection_string(s: str) -> dict:
+    '''Parse a connection directive such as ``@connection(id=con1)`` or
+    ``@connection(id=con1, username=u, password=p)``.
+
+    Returns a dict with key ``conn_id`` and, optionally, ``username`` and
+    ``password``, in the shape accepted by ``SQLLogic.add_connection(**res)``.
+
+    Raises SQLLogicSyntaxError when the directive is malformed, has an
+    unknown, duplicate or empty parameter, lacks ``id``, or gives only one
+    of ``username``/``password``.
+    '''
+    prefix = '@connection('
+    # tolerate surrounding whitespace / a trailing newline from the reader
+    s = s.strip()
+    if not (s.startswith(prefix) and s.endswith(')')):
+        raise SQLLogicSyntaxError('invalid connection string!')
+    allowed = ('conn_id', 'username', 'password')
+    res = dict()
+    for p in s[len(prefix):-1].split(','):
+        # split on the first '=' only, so values (e.g. passwords) may contain '='
+        k, sep, v = p.partition('=')
+        k = k.strip()
+        v = v.strip()
+        if k == 'id':
+            k = 'conn_id'
+        # explicit checks rather than assert: asserts are removed under -O
+        if not sep or not v or k not in allowed or k in res:
+            raise SQLLogicSyntaxError('invalid connection parameters definition!')
+        res[k] = v
+    if 'conn_id' not in res:
+        raise SQLLogicSyntaxError('invalid connection parameters definition, id missing!')
+    # credentials are optional, but must be given together
+    if ('username' in res) != ('password' in res):
+        raise SQLLogicSyntaxError('invalid connection parameters definition, username or password missing!')
+    return res
+
class SQLLogic:
def __init__(self, report=None, out=sys.stdout):
self.dbh = None
@@ -74,6 +115,10 @@ class SQLLogic:
self.res = None
self.rpt = report
self.language = 'sql'
+ self.conn_map = dict()
+ self.database = None
+ self.hostname = None
+ self.port = None
def __enter__(self):
return self
@@ -84,6 +129,9 @@ class SQLLogic:
def connect(self, username='monetdb', password='monetdb',
hostname='localhost', port=None, database='demo',
language='sql'):
self.language = language
+ self.hostname = hostname
+ self.port = port
+ self.database = database
if language == 'sql':
self.dbh = pymonetdb.connect(username=username,
password=password,
@@ -103,7 +151,41 @@ class SQLLogic:
port=port)
self.crs = MapiCursor(dbh)
+    def add_connection(self, conn_id, username='monetdb', password='monetdb'):
+        '''Open an additional connection and register it under conn_id.
+
+        The new connection reuses the hostname, port, database and language
+        stored by connect(); only the credentials may differ. If conn_id is
+        already registered, the existing connection is returned as-is (the
+        credentials are not re-checked).
+
+        Returns the SQLLogicConnection registered for conn_id.
+        '''
+        conn = self.conn_map.get(conn_id)
+        if conn is not None:
+            # the original fell through here and implicitly returned None
+            return conn
+        if self.language == 'sql':
+            dbh = pymonetdb.connect(username=username,
+                                    password=password,
+                                    hostname=self.hostname,
+                                    port=self.port,
+                                    database=self.database,
+                                    autocommit=True)
+            crs = dbh.cursor()
+        else:
+            dbh = malmapi.Connection()
+            dbh.connect(database=self.database,
+                        username=username,
+                        password=password,
+                        language=self.language,
+                        hostname=self.hostname,
+                        port=self.port)
+            crs = MapiCursor(dbh)
+        conn = SQLLogicConnection(conn_id, dbh=dbh, crs=crs,
+                                  language=self.language)
+        self.conn_map[conn_id] = conn
+        return conn
+
+    def get_connection(self, conn_id):
+        '''Look up a registered connection; None when conn_id is unknown.'''
+        try:
+            return self.conn_map[conn_id]
+        except KeyError:
+            return None
+
def close(self):
+ for k in self.conn_map:
+ conn = self.conn_map[k]
+ conn.dbh.close()
+ self.conn_map.clear()
if self.crs:
self.crs.close()
self.crs = None
@@ -111,6 +193,7 @@ class SQLLogic:
self.dbh.close()
self.dbh = None
+
def drop(self):
if self.language != 'sql':
return
@@ -154,12 +237,13 @@ class SQLLogic:
except pymonetdb.Error:
pass
- def exec_statement(self, statement, expectok, err_stmt=None,
expected_err_code=None, expected_err_msg=None, expected_rowcount=None):
+ def exec_statement(self, statement, expectok, err_stmt=None,
expected_err_code=None, expected_err_msg=None, expected_rowcount=None,
conn=None):
+ crs = conn.cursor() if conn else self.crs
if skipidx.search(statement) is not None:
# skip creation of ascending or descending index
return
try:
- affected_rowcount = self.crs.execute(statement)
+ affected_rowcount = crs.execute(statement)
except (pymonetdb.Error, ValueError) as e:
msg = e.args[0]
if not expectok:
@@ -259,20 +343,21 @@ class SQLLogic:
sep = '|'
print('', file=self.out)
- def exec_query(self, query, columns, sorting, pyscript, hashlabel,
nresult, hash, expected) -> bool:
+ def exec_query(self, query, columns, sorting, pyscript, hashlabel,
nresult, hash, expected, conn=None) -> bool:
err = False
+ crs = conn.cursor() if conn else self.crs
try:
- self.crs.execute(query)
+ crs.execute(query)
except (pymonetdb.Error, ValueError) as e:
self.query_error(query, 'query failed', e.args[0])
return False
- data = self.crs.fetchall()
- if self.crs.description:
- if len(self.crs.description) != len(columns):
- self.query_error(query, 'received {} columns, expected {}
columns'.format(len(self.crs.description), len(columns)), data=data)
+ data = crs.fetchall()
+ if crs.description:
+ if len(crs.description) != len(columns):
+ self.query_error(query, 'received {} columns, expected {}
columns'.format(len(crs.description), len(columns)), data=data)
return False
- if sorting != 'python' and self.crs.rowcount * len(columns) != nresult:
- self.query_error(query, 'received {} rows, expected {}
rows'.format(self.crs.rowcount, nresult // len(columns)), data=data)
+ if sorting != 'python' and crs.rowcount * len(columns) != nresult:
+ self.query_error(query, 'received {} rows, expected {}
rows'.format(crs.rowcount, nresult // len(columns)), data=data)
return False
if self.res is not None:
for row in data:
@@ -400,6 +485,12 @@ class SQLLogic:
break
if line[0] == '#': # skip mal comments
break
+ conn = None
+ # look for connection string
+ if line.startswith('@connection'):
+ conn_params = parse_connection_string(line)
+ conn = self.get_connection(conn_params.get('conn_id')) or
self.add_connection(**conn_params)
+ line = self.readline()
line = line.split()
if not line:
continue
@@ -434,9 +525,9 @@ class SQLLogic:
if not skipping:
if is_copyfrom_stmt(statement):
stmt, stmt_less_data = prepare_copyfrom_stmt(statement)
- self.exec_statement(stmt, expectok,
err_stmt=stmt_less_data, expected_err_code=expected_err_code,
expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount)
+ self.exec_statement(stmt, expectok,
err_stmt=stmt_less_data, expected_err_code=expected_err_code,
expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount,
conn=conn)
else:
- self.exec_statement('\n'.join(statement), expectok,
expected_err_code=expected_err_code, expected_err_msg=expected_err_msg,
expected_rowcount=expected_rowcount)
+ self.exec_statement('\n'.join(statement), expectok,
expected_err_code=expected_err_code, expected_err_msg=expected_err_msg,
expected_rowcount=expected_rowcount, conn=conn)
elif line[0] == 'query':
columns = line[1]
pyscript = None
@@ -475,7 +566,7 @@ class SQLLogic:
line = self.readline()
nresult = len(expected)
if not skipping:
- self.exec_query('\n'.join(query), columns, sorting,
pyscript, hashlabel, nresult, hash, expected)
+ self.exec_query('\n'.join(query), columns, sorting,
pyscript, hashlabel, nresult, hash, expected, conn=conn)
if __name__ == '__main__':
import argparse
diff --git a/testing/tfducktest.py b/testing/tfducktest.py
new file mode 100644
--- /dev/null
+++ b/testing/tfducktest.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+
+import sys
+
+def main(infile=None, outfile=None):
+    '''Convert DuckDB-style connection labels into @connection directives.
+
+    Reads a sqllogictest script from infile (default: sys.stdin) and writes
+    the converted script to outfile (default: sys.stdout). A "statement" or
+    "query" header line of exactly three words whose third word is not a
+    sort mode (rowsort, valuesort, nosort; case-insensitive) is taken to
+    carry a connection id. It is replaced by an "@connection(id=<third>)"
+    line followed by the header's first two words. Every other line is
+    copied unchanged, and no extra trailing newline is added.
+    '''
+    if infile is None:
+        infile = sys.stdin
+    if outfile is None:
+        outfile = sys.stdout
+    sort_modes = ('rowsort', 'valuesort', 'nosort')
+    for line in infile:
+        words = line.split()
+        # compare the record keyword exactly; a prefix test would also match
+        # other words such as "statements" or "querying"
+        if (len(words) == 3 and words[0] in ('statement', 'query')
+                and words[2].lower() not in sort_modes):
+            outfile.write(f'@connection(id={words[2]})\n')
+            # keep the header without the connection id
+            outfile.write(' '.join(words[:2]) + '\n')
+        else:
+            outfile.write(line)
+
+if __name__ == '__main__':
+    main()
+
_______________________________________________
checkin-list mailing list
[email protected]
https://www.monetdb.org/mailman/listinfo/checkin-list