Hi, list!

I've been experimenting with the web.db source code and managed to hack in a few
lines of code that make web.db methods double-quote table and column
names.

Here's a Python interactive session that shows what it does:

>>> import web
>>> db = web.database(dbn='sqlite', db='test.sqlite3')
>>> db.query('create table "test" ("id" integer primary key autoincrement, "name" varchar(40));')
...

>>> db.insert('test', name='testing')
0.0 (2): INSERT INTO "test" ("name") VALUES ('testing')
...

>>> names = ['jimmy', 'benny', 'sammy']
>>> insert_vals = [{'name': name} for name in names]
>>> db.multiple_insert('test', values=insert_vals)
0.0 (4): INSERT INTO "test" ("name") VALUES ('jimmy')
...

>>> db.where('test', name='jimmy')
0.0 (10): SELECT * FROM "test" WHERE "name" = 'jimmy'
...

>>> db.select('test', what='test.name')
0.0 (11): SELECT "test"."name" FROM "test"
...


I've attached the diff. It was created against the latest commit on
GitHub:

http://github.com/webpy/webpy/commit/cef03ac4eb50480faf4dafb66163ed97bcd19b9b

The code for this is hosted on my github fork:

http://github.com/foxbunny/webpy/tree/master

Sorry if the code is a bit dodgy; I'm still learning. The idea was to
demonstrate the feature rather than to provide a solid implementation.


Best regards,

-- 
Branko

eml: [email protected]
alt: [email protected]
blg1: http://sudologic.blogspot.com/
blg2: http://brankovukelic.blogspot.com/
img: http://picasaweb.google.com/bg.branko
twt: http://www.twitter.com/foxbunny/

--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"web.py" group.
To post to this group, send email to [email protected]
To unsubscribe from this group, send email to [email protected]
For more options, visit this group at http://groups.google.com/group/webpy?hl=en
-~----------~----~----~----~------~----~------~--~---

diff --git a/web/db.py b/web/db.py
index 5c207cf..a969810 100644
--- a/web/db.py
+++ b/web/db.py
@@ -315,10 +315,17 @@ def sqldqlist(lst):
         '"a", "b"'
         >>> sqldqlist(u'abc')
         u'"abc"'
+        >>> sqldqlist('table.col')
+        '"table"."col"'
 
     """
     if isinstance(lst, basestring):
-        return '"%s"' % lst
+        # Quote every dot-separated part on its own so that qualified
+        # names ('table.col', 'schema.table.col') keep their structure.
+        # Embedded double quotes are doubled (standard SQL identifier
+        # escaping) so a name cannot break out of its quotes.
+        parts = lst.split('.')
+        return '.'.join(['"%s"' % part.replace('"', '""') for part in parts])
     else:
-        return ', '.join(['"%s"' % item for item in lst])
+        return ', '.join([sqldqlist(item) for item in lst])
 
@@ -365,7 +372,7 @@ def sqlwhere(dictionary, grouping=' AND '):
         >>> sqlwhere({'a': 'a', 'b': 'b'}).query()
-        'a = %s AND b = %s'
+        '"a" = %s AND "b" = %s'
     """
-    return SQLQuery.join([k + ' = ' + sqlparam(v) for k, v in dictionary.items()], grouping)
+    return SQLQuery.join([sqldqlist(k) + ' = ' + sqlparam(v) for k, v in dictionary.items()], grouping)
 
 def sqlquote(a): 
@@ -576,7 +583,7 @@ class DB:
     
     def _where(self, where, vars): 
         if isinstance(where, (int, long)):
-            where = "id = " + sqlparam(where)
+            where = '"id" = ' + sqlparam(where)
         #@@@ for backward-compatibility
         elif isinstance(where, (list, tuple)) and len(where) == 2:
             where = SQLQuery(where[0], where[1])
@@ -661,15 +668,16 @@ class DB:
         """
         where = []
         for k, v in kwargs.iteritems():
-            where.append(k + ' = ' + sqlquote(v))
+            where.append(sqldqlist(k) + ' = ' + sqlquote(v))
         return self.select(table, what=what, order=order, 
                group=group, limit=limit, offset=offset, _test=_test, 
                where=SQLQuery.join(where, ' AND '))
     
     def sql_clauses(self, what, tables, where, group, order, limit, offset): 
+        if isinstance(what, basestring) and all(c.isalnum() or c in '_.' for c in what): what = sqldqlist(what)
         return (
             ('SELECT', what),
-            ('FROM', sqllist(tables)),
+            ('FROM', sqldqlist(tables)),
             ('WHERE', where),
             ('GROUP BY', group),
             ('ORDER BY', order),
@@ -679,7 +687,7 @@ class DB:
     def gen_clause(self, sql, val, vars): 
         if isinstance(val, (int, long)):
             if sql == 'WHERE':
-                nout = 'id = ' + sqlquote(val)
+                nout = '"id" = ' + sqlquote(val)
             else:
                 nout = SQLQuery(val)
         #@@@
@@ -714,9 +722,10 @@ class DB:
         def q(x): return "(" + x + ")"
         
         if values:
-            _keys = SQLQuery.join(values.keys(), ', ')
+            quoted_keys = [sqldqlist(key) for key in values.keys()]
+            _keys = SQLQuery.join(quoted_keys, ', ')
             _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ')
-            sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values)
+            sql_query = "INSERT INTO %s " % sqldqlist(tablename) + q(_keys) + ' VALUES ' + q(_values)
         else:
-            sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % tablename)
+            sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % sqldqlist(tablename))
 
@@ -776,7 +785,9 @@ class DB:
             if v.keys() != keys:
                 raise ValueError, 'Bad data'
 
-        sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys))) 
+        quoted_keys = [sqldqlist(key) for key in keys]
+        sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (sqldqlist(tablename),
+                                                              ', '.join(quoted_keys))) 
 
         data = []
         for row in values:
@@ -830,7 +841,7 @@ class DB:
         where = self._where(where, vars)
 
         query = (
-          "UPDATE " + sqllist(tables) + 
+          "UPDATE " + sqldqlist(tables) + 
           " SET " + sqlwhere(values, ', ') + 
           " WHERE " + where)
 
@@ -854,7 +865,7 @@ class DB:
         if vars is None: vars = {}
         where = self._where(where, vars)
 
-        q = 'DELETE FROM ' + table
+        q = 'DELETE FROM ' + sqldqlist(table)
         if where: q += ' WHERE ' + where
         if using: q += ' USING ' + sqllist(using)
 
@@ -988,7 +999,7 @@ class FirebirdDB(DB):
             ('FIRST', limit),
             ('SKIP', offset),
             ('', what),
-            ('FROM', sqllist(tables)),
+            ('FROM', sqldqlist(tables)),
             ('WHERE', where),
             ('GROUP BY', group),
             ('ORDER BY', order)
@@ -1007,7 +1018,7 @@ class MSSQLDB(DB):
         return (
             ('SELECT', what),
             ('TOP', limit),
-            ('FROM', sqllist(tables)),
+            ('FROM', sqldqlist(tables)),
             ('WHERE', where),
             ('GROUP BY', group),
             ('ORDER BY', order),

Reply via email to