Changeset: 160a16d35a5e for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=160a16d35a5e
Added Files:
        testing/tfducktest.py
Modified Files:
        testing/Mz.py.in
        testing/sqllogictest.py
Branch: default
Log Message:

multiple connections in sqllogic tests poc


diffs (253 lines):

diff --git a/testing/Mz.py.in b/testing/Mz.py.in
--- a/testing/Mz.py.in
+++ b/testing/Mz.py.in
@@ -2841,7 +2841,7 @@ def main(argv) :
                 prgreen('OK')
             print()
             print('failed={}, skipped={}'.format(failed, skipped))
-            print('Ran {} test in {:7.3f}'.format(test_count - skipped, t_))
+            print('Ran {} test in {:7.3f}s'.format(test_count - skipped, t_))
             if verbose:
                 for TSTDIR, TST in FAILED_TESTS:
                     prred('ERROR\t')
diff --git a/testing/sqllogictest.py b/testing/sqllogictest.py
--- a/testing/sqllogictest.py
+++ b/testing/sqllogictest.py
@@ -41,6 +41,21 @@ skipidx = re.compile(r'create index .* \
 class SQLLogicSyntaxError(Exception):
     pass
 
+class SQLLogicConnection(object):
+    def __init__(self, conn_id, dbh, crs=None, language='sql'):
+        self.conn_id = conn_id
+        self.dbh = dbh
+        self.crs = crs
+        self.language = language
+
+    def cursor(self):
+        if self.crs:
+            return self.crs
+        if self.language == 'sql':
+            return self.dbh.cursor()
+        return MapiCursor(self.dbh)
+
+
 def is_copyfrom_stmt(stmt:[str]=[]):
     try:
         index = stmt.index('<COPY_INTO_DATA>')
@@ -66,6 +81,32 @@ def prepare_copyfrom_stmt(stmt:[str]=[])
     except ValueError:
         return stmt
 
+def parse_connection_string(s: str) -> dict:
+    '''parse strings like @connection(id=con1, ...)
+    '''
+    res = dict()
+    if not (s.startswith('@connection(') and s.endswith(')')):
+        raise SQLLogicSyntaxError('invalid connection string!')
+    params = s[12:-1].split(',')
+    for p in params:
+        p = p.strip()
+        try:
+            k, v = p.split('=')
+            if k == 'id':
+                k = 'conn_id'
+            assert k in ['conn_id', 'username', 'password']
+            assert res.get(k) is None
+            res[k] = v
+        except (ValueError, AssertionError) as e:
+            raise SQLLogicSyntaxError('invalid connection paramaters definition!')
+    if len(res.keys()) > 1:
+        try:
+            assert res.get('username')
+            assert res.get('password')
+        except AssertionError as e:
+            raise SQLLogicSyntaxError('invalid connection paramaters definition, username or password missing!')
+    return res
+
 class SQLLogic:
     def __init__(self, report=None, out=sys.stdout):
         self.dbh = None
@@ -74,6 +115,10 @@ class SQLLogic:
         self.res = None
         self.rpt = report
         self.language = 'sql'
+        self.conn_map = dict()
+        self.database = None
+        self.hostname = None
+        self.port = None
 
     def __enter__(self):
         return self
@@ -84,6 +129,9 @@ class SQLLogic:
     def connect(self, username='monetdb', password='monetdb',
                 hostname='localhost', port=None, database='demo', language='sql'):
         self.language = language
+        self.hostname = hostname
+        self.port = port
+        self.database = database
         if language == 'sql':
             self.dbh = pymonetdb.connect(username=username,
                                      password=password,
@@ -103,7 +151,41 @@ class SQLLogic:
                                      port=port)
             self.crs = MapiCursor(dbh)
 
+    def add_connection(self, conn_id, username='monetdb', password='monetdb'):
+        if self.conn_map.get(conn_id, None) is None:
+            hostname = self.hostname
+            port = self.port
+            database = self.database
+            language = self.language
+            if language == 'sql':
+                dbh  = pymonetdb.connect(username=username,
+                                     password=password,
+                                     hostname=hostname,
+                                     port=port,
+                                     database=database,
+                                     autocommit=True)
+                crs = dbh.cursor()
+            else:
+                dbh = malmapi.Connection()
+                dbh.connect(database=database,
+                         username=username,
+                         password=password,
+                         language=language,
+                         hostname=hostname,
+                         port=port)
+                crs = MapiCursor(dbh)
+            conn = SQLLogicConnection(conn_id, dbh=dbh, crs=crs, language=language)
+            self.conn_map[conn_id] = conn
+            return conn
+
+    def get_connection(self, conn_id):
+        return self.conn_map.get(conn_id)
+
     def close(self):
+        for k in self.conn_map:
+            conn = self.conn_map[k]
+            conn.dbh.close()
+        self.conn_map.clear()
         if self.crs:
             self.crs.close()
             self.crs = None
@@ -111,6 +193,7 @@ class SQLLogic:
             self.dbh.close()
             self.dbh = None
 
+
     def drop(self):
         if self.language != 'sql':
             return
@@ -154,12 +237,13 @@ class SQLLogic:
             except pymonetdb.Error:
                 pass
 
-    def exec_statement(self, statement, expectok, err_stmt=None, expected_err_code=None, expected_err_msg=None, expected_rowcount=None):
+    def exec_statement(self, statement, expectok, err_stmt=None, expected_err_code=None, expected_err_msg=None, expected_rowcount=None, conn=None):
+        crs = conn.cursor() if conn else self.crs
         if skipidx.search(statement) is not None:
             # skip creation of ascending or descending index
             return
         try:
-            affected_rowcount = self.crs.execute(statement)
+            affected_rowcount = crs.execute(statement)
         except (pymonetdb.Error, ValueError) as e:
             msg = e.args[0]
             if not expectok:
@@ -259,20 +343,21 @@ class SQLLogic:
                     sep = '|'
                 print('', file=self.out)
 
-    def exec_query(self, query, columns, sorting, pyscript, hashlabel, nresult, hash, expected) -> bool:
+    def exec_query(self, query, columns, sorting, pyscript, hashlabel, nresult, hash, expected, conn=None) -> bool:
         err = False
+        crs = conn.cursor() if conn else self.crs
         try:
-            self.crs.execute(query)
+            crs.execute(query)
         except (pymonetdb.Error, ValueError) as e:
             self.query_error(query, 'query failed', e.args[0])
             return False
-        data = self.crs.fetchall()
-        if self.crs.description:
-            if len(self.crs.description) != len(columns):
-                self.query_error(query, 'received {} columns, expected {} columns'.format(len(self.crs.description), len(columns)), data=data)
+        data = crs.fetchall()
+        if crs.description:
+            if len(crs.description) != len(columns):
+                self.query_error(query, 'received {} columns, expected {} columns'.format(len(crs.description), len(columns)), data=data)
                 return False
-        if sorting != 'python' and self.crs.rowcount * len(columns) != nresult:
-            self.query_error(query, 'received {} rows, expected {} rows'.format(self.crs.rowcount, nresult // len(columns)), data=data)
+        if sorting != 'python' and crs.rowcount * len(columns) != nresult:
+            self.query_error(query, 'received {} rows, expected {} rows'.format(crs.rowcount, nresult // len(columns)), data=data)
             return False
         if self.res is not None:
             for row in data:
@@ -400,6 +485,12 @@ class SQLLogic:
                 break
             if line[0] == '#': # skip mal comments
                 break
+            conn = None
+            # look for connection string
+            if line.startswith('@connection'):
+                conn_params = parse_connection_string(line)
+                conn = self.get_connection(conn_params.get('conn_id')) or self.add_connection(**conn_params)
+                line = self.readline()
             line = line.split()
             if not line:
                 continue
@@ -434,9 +525,9 @@ class SQLLogic:
                 if not skipping:
                     if is_copyfrom_stmt(statement):
                         stmt, stmt_less_data = prepare_copyfrom_stmt(statement)
-                        self.exec_statement(stmt, expectok, err_stmt=stmt_less_data, expected_err_code=expected_err_code, expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount)
+                        self.exec_statement(stmt, expectok, err_stmt=stmt_less_data, expected_err_code=expected_err_code, expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount, conn=conn)
                     else:
-                        self.exec_statement('\n'.join(statement), expectok, expected_err_code=expected_err_code, expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount)
+                        self.exec_statement('\n'.join(statement), expectok, expected_err_code=expected_err_code, expected_err_msg=expected_err_msg, expected_rowcount=expected_rowcount, conn=conn)
             elif line[0] == 'query':
                 columns = line[1]
                 pyscript = None
@@ -475,7 +566,7 @@ class SQLLogic:
                         line = self.readline()
                     nresult = len(expected)
                 if not skipping:
-                    self.exec_query('\n'.join(query), columns, sorting, pyscript, hashlabel, nresult, hash, expected)
+                    self.exec_query('\n'.join(query), columns, sorting, pyscript, hashlabel, nresult, hash, expected, conn=conn)
 
 if __name__ == '__main__':
     import argparse
diff --git a/testing/tfducktest.py b/testing/tfducktest.py
new file mode 100644
--- /dev/null
+++ b/testing/tfducktest.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+
+import sys
+
+def main():
+    buff = []
+    while True:
+        line = sys.stdin.readline()
+        if not line:
+            break
+        sline = line.strip()
+        if sline.startswith('statement') or sline.startswith('query'):
+            words = sline.split()
+            if len(words) == 3:
+               third = words[2]
+               if third.lower() not in ['rowsort', 'valuesort', 'nosort']:
+                   # must be connection str
+                   buff.append(f'@connection(id={third})\n')
+                   #strip last word
+                   buff.append(' '.join(words[:2]) + '\n')
+                   continue
+        buff.append(line)
+    print(''.join(buff))
+
+if __name__ == '__main__':
+    main()
+
_______________________________________________
checkin-list mailing list
checkin-list@monetdb.org
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to