Author: cito
Date: Wed Jan 13 16:49:35 2016
New Revision: 735

Log:
Implement "upsert" method for PostgreSQL 9.5

A new method upsert() has been added to the DB wrapper class that
nicely complements the existing get/insert/update/delete() methods.

Modified:
   trunk/docs/contents/pg/connection.rst
   trunk/docs/contents/pg/db_wrapper.rst
   trunk/pg.py
   trunk/tests/test_classic_dbwrapper.py

Modified: trunk/docs/contents/pg/connection.rst
==============================================================================
--- trunk/docs/contents/pg/connection.rst       Wed Jan 13 10:50:25 2016        (r734)
+++ trunk/docs/contents/pg/connection.rst       Wed Jan 13 16:49:35 2016        (r735)
@@ -168,9 +168,8 @@
 of tuples/lists that define the values for each inserted row. The rows
 values may contain string, integer, long or double (real) values.
 
-.. note::
+.. warning::
 
-    **Be very careful**:
     This method doesn't type check the fields according to the table definition;
     it just look whether or not it knows how to handle such types.
 

Modified: trunk/docs/contents/pg/db_wrapper.rst
==============================================================================
--- trunk/docs/contents/pg/db_wrapper.rst       Wed Jan 13 10:50:25 2016        (r734)
+++ trunk/docs/contents/pg/db_wrapper.rst       Wed Jan 13 16:49:35 2016        (r735)
@@ -217,6 +217,7 @@
     :param str keyname: name of field to use as key (optional)
     :returns: A dictionary - the keys are the attribute names,
       the values are the row values.
+    :raises ProgrammingError: no primary key or missing privilege
 
 This method is the basic mechanism to get a single row. It assumes
 that the key specifies a unique row. If *keyname* is not specified,
@@ -231,14 +232,15 @@
 insert -- insert a row into a database table
 --------------------------------------------
 
-.. method:: DB.insert(table, [d,] [key = val, ...])
+.. method:: DB.insert(table, [d], [col=val, ...])
 
     Insert a row into a database table
 
     :param str table: name of table
     :param dict d: optional dictionary of values
-    :returns: the inserted values
+    :returns: the inserted values in the database
     :rtype: dict
+    :raises ProgrammingError: missing privilege or conflict
 
 This method inserts a row into a table.  If the optional dictionary is
 not supplied then the required values must be included as keyword/value
@@ -254,14 +256,15 @@
 update -- update a row in a database table
 ------------------------------------------
 
-.. method:: DB.update(table, [d,] [key = val, ...])
+.. method:: DB.update(table, [d], [col=val, ...])
 
     Update a row in a database table
 
     :param str table: name of table
     :param dict d: optional dictionary of values
-    :returns: the new row
+    :returns: the new row in the database
     :rtype: dict
+    :raises ProgrammingError: no primary key or missing privilege
 
 Similar to insert but updates an existing row.  The update is based on the
 OID value as munged by get or passed as keyword, or on the primary key of
@@ -273,6 +276,61 @@
 either in the dictionary where the OID must be munged, or in the keywords
 where it can be simply the string 'oid'.
 
+upsert -- insert a row with conflict resolution
+-----------------------------------------------
+
+.. method:: DB.upsert(table, [d], [col=val, ...])
+
+    Insert a row into a database table with conflict resolution
+
+    :param str table: name of table
+    :param dict d: optional dictionary of values
+    :returns: the new row in the database
+    :rtype: dict
+    :raises ProgrammingError: no primary key or missing privilege
+
+This method inserts a row into a table, but instead of raising a
+ProgrammingError exception in case a row with the same primary key already
+exists, an update will be executed instead.  This will be performed as a
+single atomic operation on the database, so race conditions can be avoided.
+
+Like the insert method, the first parameter is the name of the table and the
+second parameter can be used to pass the values to be inserted as a dictionary.
+
+Unlike the insert and update methods, keyword parameters are not used to
+modify the dictionary, but to specify which columns shall be updated in case
+of a conflict, and in which way:
+
+A value of `False` or `None` means the column shall not be updated,
+a value of `True` means the column shall be updated with the value that
+has been proposed for insertion, i.e. has been passed as value in the
+dictionary.  Columns that are not specified by keywords but appear as keys
+in the dictionary are also updated as if keywords had been passed with the
+value `True`.
+
+So if in the case of a conflict you want to update every column that has been
+passed in the dictionary `d`, you would call ``upsert(cl, d)``. If you don't
+want to do anything in case of a conflict, i.e. leave the existing row as it
+is, call ``upsert(cl, d, **dict.fromkeys(d))``.
+
+If you need more fine-grained control of what gets updated, you can also pass
+strings in the keyword parameters.  These strings will be used as SQL
+expressions for the update columns.  In these expressions you can refer
+to the value that already exists in the table by writing the table prefix
+``included.`` before the column name, and you can refer to the value that
+has been proposed for insertion by writing ``excluded.`` as table prefix.
+
+The dictionary is modified in any case to reflect the values in the database
+after the operation has completed.
+
+.. note::
+
+    The method uses the PostgreSQL "upsert" feature which is only available
+    since PostgreSQL 9.5. With older PostgreSQL versions, you will get a
+    ProgrammingError if you use this method.
+
+.. versionadded:: 5.0
+
 query -- execute a SQL command string
 -------------------------------------
 

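The keyword rules documented for upsert() determine the ``ON CONFLICT`` clause of the generated statement. A minimal sketch of that mapping as a standalone function follows. It assumes that columns which are neither keys of ``d`` nor named by a keyword are left untouched on conflict. ``build_upsert_sql`` is a hypothetical name, the naive quoting stands in for ``DB.escape_identifier()``, and ``$1, $2, ...`` are PostgreSQL positional parameters. This is an illustration, not the code in pg.py.

```python
def build_upsert_sql(table, d, columns, pkey, **kw):
    """Build an INSERT ... ON CONFLICT statement following the upsert rules.

    table   -- table name
    d       -- dict of values proposed for insertion
    columns -- the table's column names, in table order
    pkey    -- list of primary key columns (used as the conflict target)
    kw      -- per-column conflict rules: True, False/None, or an SQL string
    Returns the SQL string and the parameter list for $1, $2, ...
    """
    def quote(name):
        # naive identifier quoting, for illustration only
        return '"%s"' % name.replace('"', '""')

    names = [n for n in columns if n in d]
    if not names:
        raise ValueError('no values to insert')
    params = [d[n] for n in names]
    placeholders = ', '.join('$%d' % (i + 1) for i in range(len(names)))
    updates = []
    for n in columns:
        if n in pkey:
            continue  # key columns form the conflict target, never updated
        # assumption: columns absent from both d and kw are not updated
        rule = kw.get(n, n in d)
        if not rule:
            continue  # False or None: keep the existing value
        if not isinstance(rule, str):
            rule = 'excluded.%s' % quote(n)  # True: use the proposed value
        updates.append('%s = %s' % (quote(n), rule))  # str: raw SQL expression
    action = ('UPDATE SET %s' % ', '.join(updates)) if updates else 'NOTHING'
    sql = ('INSERT INTO %s AS included (%s) VALUES (%s)'
           ' ON CONFLICT (%s) DO %s RETURNING *') % (
        quote(table), ', '.join(quote(n) for n in names), placeholders,
        ', '.join(quote(k) for k in pkey), action)
    return sql, params
```

For example, ``build_upsert_sql('tab', d, ['n', 't'], ['n'], **dict.fromkeys(d))`` yields a statement ending in ``DO NOTHING``, which matches the "leave the existing row as it is" case described in the docs.
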
Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Wed Jan 13 10:50:25 2016        (r734)
+++ trunk/pg.py Wed Jan 13 16:49:35 2016        (r735)
@@ -513,7 +513,7 @@
             if not pkey:
                 raise KeyError('Class %s has no primary key' % cl)
             if len(pkey) > 1:
-                pkey = frozenset([k[0] for k in pkey])
+                pkey = frozenset(k[0] for k in pkey)
             else:
                 pkey = pkey[0][0]
             pkeys[cl] = pkey  # cache it
@@ -624,10 +624,6 @@
         """
         if cl.endswith('*'):  # scan descendant tables?
             cl = cl[:-1].rstrip()  # need parent table name
-        # build qualified class name
-        # To allow users to work with multiple tables,
-        # we munge the name of the "oid" key
-        qoid = _oid_key(cl)
         if not keyname:
             # use the primary key by default
             try:
@@ -638,7 +634,10 @@
         params = []
         param = partial(self._prepare_param, params=params)
         col = self.escape_identifier
-        # We want the oid for later updates if that isn't the key
+        # We want the oid for later updates if that isn't the key.
+        # To allow users to work with multiple tables, we munge
+        # the name of the "oid" key by adding the name of the class.
+        qoid = _oid_key(cl)
         if keyname == 'oid':
             if isinstance(arg, dict):
                 if qoid not in arg:
@@ -648,19 +647,20 @@
             what = '*'
             where = 'oid = %s' % param(arg[qoid], 'int')
         else:
-            if isinstance(keyname, basestring):
-                keyname = (keyname,)
+            keyname = [keyname] if isinstance(
+                keyname, basestring) else sorted(keyname)
             if not isinstance(arg, dict):
                 if len(keyname) > 1:
                     raise _prg_error('Composite key needs dict as arg')
                 arg = dict((k, arg) for k in keyname)
             what = ', '.join(col(k) for k in attnames)
-            where = ' AND '.join(['%s = %s'
-                % (col(k), param(arg[k], attnames[k])) for k in keyname])
+            where = ' AND '.join('%s = %s' % (
+                col(k), param(arg[k], attnames[k])) for k in keyname)
         q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
             what, self._escape_qualified_name(cl), where)
         self._do_debug(q, params)
-        res = self.db.query(q, params).dictresult()
+        q = self.db.query(q, params)
+        res = q.dictresult()
         if not res:
             raise _db_error('No such record in %s where %s' % (cl, where))
         for n, value in res[0].items():
@@ -688,7 +688,8 @@
         although PostgreSQL does.
 
         """
-        qoid = _oid_key(cl)
+        if 'oid' in kw:
+            del kw['oid']
         if d is None:
             d = {}
         d.update(kw)
@@ -698,7 +699,7 @@
         col = self.escape_identifier
         names, values = [], []
         for n in attnames:
-            if n != 'oid' and n in d:
+            if n in d:
                 names.append(col(n))
                 values.append(param(d[n], attnames[n]))
         names, values = ', '.join(names), ', '.join(values)
@@ -706,11 +707,13 @@
         q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
             self._escape_qualified_name(cl), names, values, ret)
         self._do_debug(q, params)
-        res = self.db.query(q, params)
-        res = res.dictresult()[0]
-        for n, value in res.items():
+        q = self.db.query(q, params)
+        res = q.dictresult()
+        if not res:
+            raise _int_error('insert did not return new values')
+        for n, value in res[0].items():
             if n == 'oid':
-                n = qoid
+                n = _oid_key(cl)
             elif attnames.get(n) == 'bytea' and value is not None:
                 value = self.unescape_bytea(value)
             d[n] = value
@@ -726,9 +729,9 @@
         values, etc.
 
         """
-        # Update always works on the oid which get returns if available,
+        # Update always works on the oid which get() returns if available,
         # otherwise use the primary key.  Fail if neither.
-        # Note that we only accept oid key from named args for safety
+        # Note that we only accept oid key from named args for safety.
         qoid = _oid_key(cl)
         if 'oid' in kw:
             kw[qoid] = kw['oid']
@@ -742,19 +745,21 @@
         col = self.escape_identifier
         if qoid in d:
             where = 'oid = %s' % param(d[qoid], 'int')
-            keyname = ()
+            keyname = []
         else:
             try:
                 keyname = self.pkey(cl)
             except KeyError:
                 raise _prg_error('Class %s has no primary key' % cl)
-            if isinstance(keyname, basestring):
-                keyname = (keyname,)
+            keyname = [keyname] if isinstance(
+                keyname, basestring) else sorted(keyname)
             try:
-                where = ' AND '.join(['%s = %s'
-                    % (col(k), param(d[k], attnames[k])) for k in keyname])
+                where = ' AND '.join('%s = %s' % (
+                    col(k), param(d[k], attnames[k])) for k in keyname)
             except KeyError:
-                raise _prg_error('Update needs primary key or oid.')
+                raise _prg_error('update needs primary key or oid')
+        keyname = set(keyname)
+        keyname.add('oid')
         values = []
         for n in attnames:
             if n in d and n not in keyname:
@@ -766,14 +771,122 @@
         q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
             self._escape_qualified_name(cl), values, where, ret)
         self._do_debug(q, params)
-        res = self.db.query(q, params)
-        res = res.dictresult()[0]
-        for n, value in res.items():
-            if n == 'oid':
-                n = qoid
-            elif attnames.get(n) == 'bytea' and value is not None:
-                value = self.unescape_bytea(value)
-            d[n] = value
+        q = self.db.query(q, params)
+        res = q.dictresult()
+        if res:  # may be empty when row does not exist
+            for n, value in res[0].items():
+                if n == 'oid':
+                    n = qoid
+                elif attnames.get(n) == 'bytea' and value is not None:
+                    value = self.unescape_bytea(value)
+                d[n] = value
+        return d
+
+    def upsert(self, cl, d=None, **kw):
+        """Insert a row into a database table with conflict resolution.
+
+        This method inserts a row into a table, but instead of raising a
+        ProgrammingError exception in case a row with the same primary key
+        already exists, an update will be executed instead.  This will be
+        performed as a single atomic operation on the database, so race
+        conditions can be avoided.
+
+        Like the insert method, the first parameter is the name of the
+        table and the second parameter can be used to pass the values to
+        be inserted as a dictionary.
+
+        Unlike the insert and update methods, keyword parameters are not
+        used to modify the dictionary, but to specify which columns shall
+        be updated in case of a conflict, and in which way:
+
+        A value of False or None means the column shall not be updated,
+        a value of True means the column shall be updated with the value
+        that has been proposed for insertion, i.e. has been passed as value
+        in the dictionary.  Columns that are not specified by keywords but
+        appear as keys in the dictionary are also updated as if the
+        keywords had been passed with the value True.
+
+        So if in the case of a conflict you want to update every column that
+        has been passed in the dictionary d, you would call upsert(cl, d).
+        If you don't want to do anything in case of a conflict, i.e. leave
+        the existing row as it is, call upsert(cl, d, **dict.fromkeys(d)).
+
+        If you need more fine-grained control of what gets updated, you can
+        also pass strings in the keyword parameters.  These strings will
+        be used as SQL expressions for the update columns.  In these
+        expressions you can refer to the value that already exists in
+        the table by prefixing the column name with "included.", and to
+        the value that has been proposed for insertion by prefixing the
+        column name with "excluded.".
+
+        The dictionary is modified in any case to reflect the values in
+        the database after the operation has completed.
+
+        Note: The method uses the PostgreSQL "upsert" feature which is
+        only available since PostgreSQL 9.5.
+
+        """
+        if 'oid' in kw:
+            del kw['oid']
+        if d is None:
+            d = {}
+        attnames = self.get_attnames(cl)
+        params = []
+        param = partial(self._prepare_param, params=params)
+        col = self.escape_identifier
+        names, values = [], []
+        for n in attnames:
+            if n in d:
+                names.append(col(n))
+                values.append(param(d[n], attnames[n]))
+        names, values = ', '.join(names), ', '.join(values)
+        try:
+            keyname = self.pkey(cl)
+        except KeyError:
+            raise _prg_error('Class %s has no primary key' % cl)
+        keyname = [keyname] if isinstance(
+            keyname, basestring) else sorted(keyname)
+        try:
+            target = ', '.join(col(k) for k in keyname)
+        except KeyError:
+            raise _prg_error('upsert needs primary key or oid')
+        update = []
+        keyname = set(keyname)
+        keyname.add('oid')
+        for n in attnames:
+            if n not in keyname:
+                value = kw.get(n, n in d)
+                if value:
+                    if not isinstance(value, basestring):
+                        value = 'excluded.%s' % col(n)
+                    update.append('%s = %s' % (col(n), value))
+        if not values:
+            return d
+        do = 'update set %s' % ', '.join(update) if update else 'nothing'
+        ret = 'oid, *' if 'oid' in attnames else '*'
+        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
+            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
+                self._escape_qualified_name(cl), names, values,
+                target, do, ret)
+        self._do_debug(q, params)
+        try:
+            q = self.db.query(q, params)
+        except ProgrammingError:
+            if self.server_version < 90500:
+                raise _prg_error('upsert not supported by PostgreSQL version')
+            raise  # re-raise original error
+        res = q.dictresult()
+        if res:  # may be empty with "do nothing"
+            for n, value in res[0].items():
+                if n == 'oid':
+                    n = _oid_key(cl)
+                elif attnames.get(n) == 'bytea' and value is not None:
+                    value = self.unescape_bytea(value)
+                d[n] = value
+        elif update:
+            raise _int_error('upsert did not return new values')
+        else:
+            self.get(cl, d)
         return d
 
     def clear(self, cl, a=None):
@@ -814,7 +927,7 @@
         # Like update, delete works on the oid.
         # One day we will be testing that the record to be deleted
         # isn't referenced somewhere (or else PostgreSQL will).
-        # Note that we only accept oid key from named args for safety
+        # Note that we only accept oid key from named args for safety.
         qoid = _oid_key(cl)
         if 'oid' in kw:
             kw[qoid] = kw['oid']
@@ -831,19 +944,20 @@
                 keyname = self.pkey(cl)
             except KeyError:
                 raise _prg_error('Class %s has no primary key' % cl)
-            if isinstance(keyname, basestring):
-                keyname = (keyname,)
+            keyname = [keyname] if isinstance(
+                keyname, basestring) else sorted(keyname)
             attnames = self.get_attnames(cl)
             col = self.escape_identifier
             try:
-                where = ' AND '.join(['%s = %s'
-                    % (col(k), param(d[k], attnames[k])) for k in keyname])
+                where = ' AND '.join('%s = %s' % (
+                    col(k), param(d[k], attnames[k])) for k in keyname)
             except KeyError:
-                raise _prg_error('Delete needs primary key or oid.')
+                raise _prg_error('delete needs primary key or oid')
         q = 'DELETE FROM %s WHERE %s' % (
             self._escape_qualified_name(cl), where)
         self._do_debug(q, params)
-        return int(self.db.query(q, params))
+        res = self.db.query(q, params)
+        return int(res)
 
     def notification_handler(self, event, callback, arg_dict={}, timeout=None):
         """Get notification handler that will run the given callback."""

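In this revision get(), update(), upsert() and delete() each repeat the idiom ``[keyname] if isinstance(keyname, basestring) else sorted(keyname)``, because pkey() returns a plain string for a single-column key and a frozenset for a composite key. The following is a minimal sketch of how the normalization and the WHERE clause built from it could be factored out. ``key_list`` and ``key_where`` are hypothetical helper names that do not exist in pg.py. The sketch targets Python 3 (``str`` instead of ``basestring``), and the quoting and ``$n`` placeholders are simplified stand-ins for ``escape_identifier()`` and ``_prepare_param()``.

```python
def key_list(keyname):
    """Normalize a primary key to a sorted list of column names.

    A single-column key is a plain string, a composite key a frozenset;
    sorting gives the generated clause a stable column order.
    """
    if isinstance(keyname, str):
        return [keyname]
    return sorted(keyname)


def key_where(keyname, d):
    """Build a WHERE clause matching the key columns of the row dict d.

    Raises KeyError if d lacks a key column; a caller can turn this into
    a "needs primary key or oid" error, as update() and delete() do.
    """
    keys = key_list(keyname)
    params = [d[k] for k in keys]  # KeyError if a key column is missing
    where = ' AND '.join(
        '"%s" = $%d' % (k.replace('"', '""'), i)
        for i, k in enumerate(keys, 1))
    return where, params
```

Sorting the composite key is what makes the column order of the WHERE and ``ON CONFLICT`` clauses deterministic, since iterating over a frozenset gives no guaranteed order.
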
Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Wed Jan 13 10:50:25 2016        (r734)
+++ trunk/tests/test_classic_dbwrapper.py       Wed Jan 13 16:49:35 2016        (r735)
@@ -132,6 +132,7 @@
             'transaction',
             'unescape_bytea',
             'update',
+            'upsert',
             'use_regtypes',
             'user',
         ]
@@ -867,8 +868,8 @@
         r['t'] = 'u'
         s = update(table, r)
         self.assertEqual(s, r)
-        r = query('select t from "%s" where n=2' % table
-                  ).getresult()[0][0]
+        q = 'select t from "%s" where n=2' % table
+        r = query(q).getresult()[0][0]
         self.assertEqual(r, 'u')
         query('drop table "%s"' % table)
 
@@ -884,10 +885,24 @@
                 "%d, '%s')" % (table, n + 1, t))
         self.assertRaises(pg.ProgrammingError, update,
                           table, dict(t='b'))
-        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
-        r = query('select t from "%s" where n=2' % table
-                  ).getresult()[0][0]
+        s = dict(n=2, t='d')
+        r = update(table, s)
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 2)
+        self.assertEqual(r['t'], 'd')
+        q = 'select t from "%s" where n=2' % table
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 'd')
+        s.update(dict(n=4, t='e'))
+        r = update(table, s)
+        self.assertEqual(r['n'], 4)
+        self.assertEqual(r['t'], 'e')
+        q = 'select t from "%s" where n=2' % table
+        r = query(q).getresult()[0][0]
         self.assertEqual(r, 'd')
+        q = 'select t from "%s" where n=4' % table
+        r = query(q).getresult()
+        self.assertEqual(len(r), 0)
         query('drop table "%s"' % table)
         table = 'update_test_table_2'
         query('drop table if exists "%s"' % table)
@@ -902,8 +917,8 @@
                           table, dict(n=2, t='b'))
         self.assertEqual(update(table,
                                 dict(n=2, m=2, t='x'))['t'], 'x')
-        r = [r[0] for r in query('select t from "%s" where n=2'
-            ' order by m' % table).getresult()]
+        q = 'select t from "%s" where n=2 order by m' % table
+        r = [r[0] for r in query(q).getresult()]
         self.assertEqual(r, ['c', 'x'])
         query('drop table "%s"' % table)
 
@@ -931,6 +946,175 @@
         self.assertEqual(r['Questions?'], 'When?')
         query('drop table "%s"' % table)
 
+    def testUpsert(self):
+        upsert = self.db.upsert
+        query = self.db.query
+        table = 'upsert_test_table'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            "n integer primary key, t text) with oids" % table)
+        s = dict(n=1, t='x')
+        try:
+            r = upsert(table, s)
+        except pg.ProgrammingError as error:
+            if self.db.server_version < 90500:
+                self.skipTest('database does not support upsert')
+            self.fail(str(error))
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 1)
+        self.assertEqual(r['t'], 'x')
+        s.update(n=2, t='y')
+        r = upsert(table, s, **dict.fromkeys(s))
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 2)
+        self.assertEqual(r['t'], 'y')
+        q = 'select n, t from "%s" order by n limit 3' % table
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 'x'), (2, 'y')])
+        s.update(t='z')
+        r = upsert(table, s)
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 2)
+        self.assertEqual(r['t'], 'z')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 'x'), (2, 'z')])
+        s.update(t='n')
+        r = upsert(table, s, t=False)
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 2)
+        self.assertEqual(r['t'], 'z')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 'x'), (2, 'z')])
+        s.update(t='y')
+        r = upsert(table, s, t=True)
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 2)
+        self.assertEqual(r['t'], 'y')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 'x'), (2, 'y')])
+        s.update(t='n')
+        r = upsert(table, s, t="included.t || '2'")
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 2)
+        self.assertEqual(r['t'], 'y2')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
+        s.update(t='y')
+        r = upsert(table, s, t="excluded.t || '3'")
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 2)
+        self.assertEqual(r['t'], 'y3')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
+        s.update(n=1, t='2')
+        r = upsert(table, s, t="included.t || excluded.t")
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 1)
+        self.assertEqual(r['t'], 'x2')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
+        query('drop table "%s"' % table)
+
+    def testUpsertWithCompositeKey(self):
+        upsert = self.db.upsert
+        query = self.db.query
+        table = 'upsert_test_table_2'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            "n integer, m integer, t text, primary key (n, m))" % table)
+        s = dict(n=1, m=2, t='x')
+        try:
+            r = upsert(table, s)
+        except pg.ProgrammingError as error:
+            if self.db.server_version < 90500:
+                self.skipTest('database does not support upsert')
+            self.fail(str(error))
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 1)
+        self.assertEqual(r['m'], 2)
+        self.assertEqual(r['t'], 'x')
+        s.update(m=3, t='y')
+        r = upsert(table, s, **dict.fromkeys(s))
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 1)
+        self.assertEqual(r['m'], 3)
+        self.assertEqual(r['t'], 'y')
+        q = 'select n, m, t from "%s" order by n, m limit 3' % table
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
+        s.update(t='z')
+        r = upsert(table, s)
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 1)
+        self.assertEqual(r['m'], 3)
+        self.assertEqual(r['t'], 'z')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
+        s.update(t='n')
+        r = upsert(table, s, t=False)
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 1)
+        self.assertEqual(r['m'], 3)
+        self.assertEqual(r['t'], 'z')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
+        s.update(t='n')
+        r = upsert(table, s, t=True)
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 1)
+        self.assertEqual(r['m'], 3)
+        self.assertEqual(r['t'], 'n')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
+        s.update(n=2, t='y')
+        r = upsert(table, s, t="'z'")
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 2)
+        self.assertEqual(r['m'], 3)
+        self.assertEqual(r['t'], 'y')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
+        s.update(n=1, t='m')
+        r = upsert(table, s, t='included.t || excluded.t')
+        self.assertIs(r, s)
+        self.assertEqual(r['n'], 1)
+        self.assertEqual(r['m'], 3)
+        self.assertEqual(r['t'], 'nm')
+        r = query(q).getresult()
+        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
+        query('drop table "%s"' % table)
+
+    def testUpsertWithQuotedNames(self):
+        upsert = self.db.upsert
+        query = self.db.query
+        table = 'test table for upsert()'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            '"Prime!" smallint primary key,'
+            '"much space" integer, "Questions?" text)' % table)
+        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
+        try:
+            r = upsert(table, s)
+        except pg.ProgrammingError as error:
+            if self.db.server_version < 90500:
+                self.skipTest('database does not support upsert')
+            self.fail(str(error))
+        self.assertIs(r, s)
+        self.assertEqual(r['Prime!'], 31)
+        self.assertEqual(r['much space'], 9009)
+        self.assertEqual(r['Questions?'], 'Yes.')
+        q = 'select * from "%s" limit 2' % table
+        r = query(q).getresult()
+        self.assertEqual(r, [(31, 9009, 'Yes.')])
+        s.update({'Questions?': 'No.'})
+        r = upsert(table, s)
+        self.assertIs(r, s)
+        self.assertEqual(r['Prime!'], 31)
+        self.assertEqual(r['much space'], 9009)
+        self.assertEqual(r['Questions?'], 'No.')
+        r = query(q).getresult()
+        self.assertEqual(r, [(31, 9009, 'No.')])
+
     def testClear(self):
         clear = self.db.clear
         query = self.db.query
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql
