The patch is attached below.

--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sqlalchemy" group.
To post to this group, send email to [email protected]
To unsubscribe from this group, send email to [EMAIL PROTECTED]
For more options, visit this group at 
http://groups.google.com/group/sqlalchemy?hl=en
-~----------~----~----~----~------~----~------~--~---

Index: mssql.py
===================================================================
--- mssql.py    (revision 4043)
+++ mssql.py    (working copy)
@@ -20,7 +20,7 @@
   Note that the start & increment values for sequences are optional
   and will default to 1,1.
 
-* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for 
+* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
   ``INSERT`` s)
 
 * Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
@@ -34,7 +34,7 @@
 
 * pymssql has problems with binary and unicode data that this module
   does **not** work around
-  
+
 """
 
 import datetime, random, warnings, re, sys, operator
@@ -44,7 +44,7 @@
 from sqlalchemy.engine import default, base
 from sqlalchemy import types as sqltypes
 from sqlalchemy.util import Decimal as _python_Decimal
-    
+
 MSSQL_RESERVED_WORDS = util.Set(['function'])
 
 class MSNumeric(sqltypes.Numeric):
@@ -67,9 +67,9 @@
                 # Not sure that this exception is needed
                 return value
             else:
-                return str(value) 
+                return str(value)
         return process
-        
+
     def get_col_spec(self):
         if self.precision is None:
             return "NUMERIC"
@@ -87,7 +87,7 @@
                 return str(value)
             return None
         return process
-        
+
 class MSInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
@@ -116,14 +116,14 @@
         super(MSDate, self).__init__(False)
 
     def get_col_spec(self):
-        return "SMALLDATETIME"
+        return "DATETIME"
 
 class MSTime(sqltypes.Time):
     __zero_date = datetime.date(1900, 1, 1)
 
     def __init__(self, *a, **kw):
         super(MSTime, self).__init__(False)
-    
+
     def get_col_spec(self):
         return "DATETIME"
 
@@ -135,7 +135,7 @@
                 value = datetime.datetime.combine(self.__zero_date, value)
             return value
         return process
-    
+
     def result_processor(self, dialect):
         def process(value):
             if type(value) is datetime.datetime:
@@ -144,7 +144,7 @@
                 return datetime.time(0, 0, 0)
             return value
         return process
-        
+
 class MSDateTime_adodbapi(MSDateTime):
     def result_processor(self, dialect):
         def process(value):
@@ -154,7 +154,7 @@
                 return datetime.datetime(value.year, value.month, value.day)
             return value
         return process
-        
+
 class MSDateTime_pyodbc(MSDateTime):
     def bind_processor(self, dialect):
         def process(value):
@@ -162,7 +162,7 @@
                 return datetime.datetime(value.year, value.month, value.day)
             return value
         return process
-        
+
 class MSDate_pyodbc(MSDate):
     def bind_processor(self, dialect):
         def process(value):
@@ -170,7 +170,7 @@
                 return datetime.datetime(value.year, value.month, value.day)
             return value
         return process
-    
+
     def result_processor(self, dialect):
         def process(value):
             # pyodbc returns SMALLDATETIME values as datetime.datetime(). 
truncate it back to datetime.date()
@@ -178,7 +178,7 @@
                 return value.date()
             return value
         return process
-        
+
 class MSDate_pymssql(MSDate):
     def result_processor(self, dialect):
         def process(value):
@@ -187,11 +187,11 @@
                 return value.date()
             return value
         return process
-        
+
 class MSText(sqltypes.Text):
     def get_col_spec(self):
         if self.dialect.text_as_varchar:
-            return "VARCHAR(max)"            
+            return "VARCHAR(max)"
         else:
             return "TEXT"
 
@@ -238,7 +238,7 @@
                 return None
             return value and True or False
         return process
-    
+
     def bind_processor(self, dialect):
         def process(value):
             if value is True:
@@ -250,27 +250,27 @@
             else:
                 return value and True or False
         return process
-        
+
 class MSTimeStamp(sqltypes.TIMESTAMP):
     def get_col_spec(self):
         return "TIMESTAMP"
-        
+
 class MSMoney(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "MONEY"
-        
+
 class MSSmallMoney(MSMoney):
     def get_col_spec(self):
         return "SMALLMONEY"
-        
+
 class MSUniqueIdentifier(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "UNIQUEIDENTIFIER"
-        
+
 class MSVariant(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "SQL_VARIANT"
-        
+
 def descriptor():
     return {'name':'mssql',
     'description':'MSSQL',
@@ -297,7 +297,7 @@
     def pre_exec(self):
         """MS-SQL has a special mode for inserting non-NULL values
         into IDENTITY columns.
-        
+
         Activate it if the feature is turned on and needed.
         """
         if self.compiled.isinsert:
@@ -328,7 +328,7 @@
         and fetch recently inserted IDENTIFY values (works only for
         one column).
         """
-        
+
         if self.compiled.isinsert and self.HASIDENT and not self.IINSERT:
             if not len(self._last_inserted_ids) or self._last_inserted_ids[0] 
is None:
                 if self.dialect.use_scope_identity:
@@ -339,17 +339,17 @@
                 self._last_inserted_ids = [int(row[0])] + 
self._last_inserted_ids[1:]
                 # print "LAST ROW ID", self._last_inserted_ids
         super(MSSQLExecutionContext, self).post_exec()
-    
+
     _ms_is_select = re.compile(r'\s*(?:SELECT|sp_columns)',
                                re.I | re.UNICODE)
-    
+
     def returns_rows_text(self, statement):
         return self._ms_is_select.match(statement) is not None
 
 
-class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):    
+class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
     def pre_exec(self):
-        """where appropriate, issue "select scope_identity()" in the same 
statement"""                
+        """where appropriate, issue "select scope_identity()" in the same 
statement"""
         super(MSSQLExecutionContext_pyodbc, self).pre_exec()
         if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) \
                 and len(self.parameters) == 1 and 
self.dialect.use_scope_identity:
@@ -418,7 +418,7 @@
             return dialect(*args, **kwargs)
         else:
             return object.__new__(cls, *args, **kwargs)
-                
+
     def __init__(self, auto_identity_insert=True, **params):
         super(MSSQLDialect, self).__init__(**params)
         self.auto_identity_insert = auto_identity_insert
@@ -442,7 +442,7 @@
             else:
                 raise ImportError('No DBAPI module detected for MSSQL - please 
install pyodbc, pymssql, or adodbapi')
     dbapi = classmethod(dbapi)
-    
+
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username='user')
         opts.update(url.query)
@@ -477,20 +477,20 @@
 
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
-            
+
     def do_execute(self, cursor, statement, params, context=None, **kwargs):
         if params == {}:
             params = ()
         try:
             super(MSSQLDialect, self).do_execute(cursor, statement, params, 
context=context, **kwargs)
-        finally:        
+        finally:
             if context.IINSERT:
                 cursor.execute("SET IDENTITY_INSERT %s OFF" % 
self.identifier_preparer.format_table(context.compiled.statement.table))
-         
+
     def do_executemany(self, cursor, statement, params, context=None, 
**kwargs):
         try:
             super(MSSQLDialect, self).do_executemany(cursor, statement, 
params, context=context, **kwargs)
-        finally:        
+        finally:
             if context.IINSERT:
                 cursor.execute("SET IDENTITY_INSERT %s OFF" % 
self.identifier_preparer.format_table(context.compiled.statement.table))
 
@@ -511,7 +511,7 @@
     def raw_connection(self, connection):
         """Pull the raw pymmsql connection out--sensative to 
"pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
         try:
-            # TODO: probably want to move this to individual dialect 
subclasses to 
+            # TODO: probably want to move this to individual dialect 
subclasses to
             # save on the exception throw + simplify
             return connection.connection.__dict__['_pymssqlCnx__cnx']
         except:
@@ -536,14 +536,14 @@
                        and sql.and_(columns.c.table_name==tablename, 
columns.c.table_schema==current_schema)
                        or columns.c.table_name==tablename,
                    )
-        
+
         c = connection.execute(s)
         row  = c.fetchone()
         return row is not None
-        
+
     def reflecttable(self, connection, table, include_columns):
         import sqlalchemy.databases.information_schema as ischema
-        
+
         # Get base columns
         if table.schema is not None:
             current_schema = table.schema
@@ -556,7 +556,7 @@
                        and sql.and_(columns.c.table_name==table.name, 
columns.c.table_schema==current_schema)
                        or columns.c.table_name==table.name,
                    order_by=[columns.c.ordinal_position])
-        
+
         c = connection.execute(s)
         found_table = False
         while True:
@@ -565,9 +565,9 @@
                 break
             found_table = True
             (name, type, nullable, charlen, numericprec, numericscale, 
default) = (
-                row[columns.c.column_name], 
-                row[columns.c.data_type], 
-                row[columns.c.is_nullable] == 'YES', 
+                row[columns.c.column_name],
+                row[columns.c.data_type],
+                row[columns.c.is_nullable] == 'YES',
                 row[columns.c.character_maximum_length],
                 row[columns.c.numeric_precision],
                 row[columns.c.numeric_scale],
@@ -582,21 +582,21 @@
                     args.append(a)
             coltype = self.ischema_names.get(type, None)
             if coltype == MSString and charlen == -1:
-                coltype = MSText()                
+                coltype = MSText()
             else:
                 if coltype is None:
                     warnings.warn(RuntimeWarning("Did not recognize type '%s' 
of column '%s'" % (type, name)))
                     coltype = sqltypes.NULLTYPE
-                    
+
                 elif coltype in (MSNVarchar, AdoMSNVarchar) and charlen == -1:
                     args[0] = None
                 coltype = coltype(*args)
             colargs= []
             if default is not None:
                 colargs.append(schema.PassiveDefault(sql.text(default)))
-                
+
             table.append_column(schema.Column(name, coltype, 
nullable=nullable, autoincrement=False, *colargs))
-        
+
         if not found_table:
             raise exceptions.NoSuchTableError(table.name)
 
@@ -631,7 +631,7 @@
         # Add constraints
         RR = self.uppercase_table(ischema.ref_constraints)    
#information_schema.referential_constraints
         TC = self.uppercase_table(ischema.constraints)        
#information_schema.table_constraints
-        C  = self.uppercase_table(ischema.pg_key_constraints).alias('C') 
#information_schema.constraint_column_usage: the constrained column 
+        C  = self.uppercase_table(ischema.pg_key_constraints).alias('C') 
#information_schema.constraint_column_usage: the constrained column
         R  = self.uppercase_table(ischema.pg_key_constraints).alias('R') 
#information_schema.constraint_column_usage: the referenced column
 
         # Primary key constraints
@@ -672,7 +672,7 @@
 class MSSQLDialect_pymssql(MSSQLDialect):
     supports_sane_rowcount = False
     max_identifier_length = 30
-    
+
     def import_dbapi(cls):
         import pymssql as module
         # pymmsql doesn't have a Binary method.  we use string
@@ -680,7 +680,7 @@
         module.Binary = lambda st: str(st)
         return module
     import_dbapi = classmethod(import_dbapi)
-    
+
     colspecs = MSSQLDialect.colspecs.copy()
     colspecs[sqltypes.Date] = MSDate_pymssql
 
@@ -723,7 +723,7 @@
 ##    This code is leftover from the initial implementation, for reference
 ##    def do_begin(self, connection):
 ##        """implementations might want to put logic here for turning 
autocommit on/off, etc."""
-##        pass  
+##        pass
 
 ##    def do_rollback(self, connection):
 ##        """implementations might want to put logic here for turning 
autocommit on/off, etc."""
@@ -740,7 +740,7 @@
 
 ##    def do_commit(self, connection):
 ##        """implementations might want to put logic here for turning 
autocommit on/off, etc.
-##            do_commit is set for pymmsql connections--ADO seems to handle 
transactions without any issue 
+##            do_commit is set for pymmsql connections--ADO seems to handle 
transactions without any issue
 ##        """
 ##        # ADO Uses Implicit Transactions.
 ##        # This is very pymssql specific.  We use this instead of its commit, 
because it hangs on failed rollbacks.
@@ -757,7 +757,7 @@
     # PyODBC unicode is broken on UCS-4 builds
     supports_unicode = sys.maxunicode == 65535
     supports_unicode_statements = supports_unicode
-    
+
     def __init__(self, **params):
         super(MSSQLDialect_pyodbc, self).__init__(**params)
         # whether use_scope_identity will work depends on the version of pyodbc
@@ -766,12 +766,12 @@
             self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset')
         except:
             pass
-        
+
     def import_dbapi(cls):
         import pyodbc as module
         return module
     import_dbapi = classmethod(import_dbapi)
-    
+
     colspecs = MSSQLDialect.colspecs.copy()
     if supports_unicode:
         colspecs[sqltypes.Unicode] = AdoMSNVarchar
@@ -877,16 +877,17 @@
     def get_select_precolumns(self, select):
         """ MS-SQL puts TOP, it's version of LIMIT here """
         s = select._distinct and "DISTINCT " or ""
-        if select._limit:
+
+        """
+        if select._limit and not select._offset:
             s += "TOP %s " % (select._limit,)
-        if select._offset:
-            raise exceptions.InvalidRequestError('MSSQL does not support LIMIT 
with an offset')
+        """
         return s
 
-    def limit_clause(self, select):    
+    def limit_clause(self, select):
         # Limit in mssql is after the select keyword
         return ""
-            
+
     def _schema_aliased_table(self, table):
         if getattr(table, 'schema', None) is not None:
             if table not in self.tablealiases:
@@ -894,7 +895,7 @@
             return self.tablealiases[table]
         else:
             return None
-            
+
     def visit_table(self, table, mssql_aliased=False, **kwargs):
         if mssql_aliased:
             return super(MSSQLCompiler, self).visit_table(table, **kwargs)
@@ -905,7 +906,7 @@
             return self.process(alias, mssql_aliased=True, **kwargs)
         else:
             return super(MSSQLCompiler, self).visit_table(table, **kwargs)
- 
+
     def visit_alias(self, alias, **kwargs):
         # translate for schema-qualified table aliases
         self.tablealiases[alias.original] = alias
@@ -953,11 +954,100 @@
         else:
             return ""
 
+    import re
+    EMPTY_INSERT = re.compile(r"^INSERT INTO (.+) \(\) VALUES \(\)$")
 
+    def visit_insert(self, insert_statement):
+        result =super(MSSQLCompiler, self).visit_insert(insert_statement)
+        m = self.EMPTY_INSERT.match(result)
+        if m:
+            result = "INSERT INTO %s DEFAULT VALUES" % m.group(1)
+        return result
+
+    def visit_select(self, select, **kwargs):
+        """Look for ``LIMIT`` and OFFSET in a select statement, and if
+        so tries to wrap it in a subquery with ``row_number()`` criterion.
+        """
+
+        if not getattr(select, '_mssql_visit', None) and (select._offset is 
not None or select._limit):
+            select._mssql_visit = True
+
+            aliased_select = select.order_by(None).alias('_mso')
+
+            orderby = select._order_by_clause
+
+            if not orderby:
+                orderby = 
sql.expression.ClauseList(select.oid_column.proxies[0])
+
+            def adapt_column_text(text):
+                id = text.split(".")
+                name = id[0]
+                if len(id) > 1: #is multipart identifier
+                    for f in select.froms :
+                        if f.name.lower() == name.lower():
+                            break
+                    else:
+                        # may not be a table, misleading exception ?
+                        raise exceptions.NoSuchTableError(name)
+                    return str(adapt_orderby_column(f.columns[id[1]]))
+                else:
+                    return "_mso." + name
+
+            def adapt_orderby_expr_text(text):
+                clauses = text.split()
+                clauses[0] = adapt_column_text(clauses[0])
+                return " ".join(clauses)
+
+            def adapt_orderby_list_text(text):
+                return sql.expression._TextClause(", ".join
+                                        (map(adapt_orderby_expr_text, \
+                                        filter(None,
+                                               map(lambda s : s.strip(), \
+                                                   text.split(","))))))
+
+            def adapt_orderby_column(c):
+                r = aliased_select.corresponding_column(c, False)
+                if not r:
+                    locals()["aliased_select"] = 
aliased_select.original.column(c).alias("_mso")
+                r = aliased_select.corresponding_column(c, True)
+                return r
+
+            def adapt_orderby_unary(c):
+                c.element = adapt_orderby_clause(c.element)
+                return c
+
+            def adapt_orderby_clause(c):
+                if isinstance(c, sql.expression._TextClause):
+                    return adapt_orderby_list_text(c.text)
+                elif isinstance(c, sql.expression._UnaryExpression):
+                    return adapt_orderby_unary(c)
+                else:
+                    return adapt_orderby_column(c)
+
+            orderby = sql.expression.ClauseList(*map(adapt_orderby_clause, 
orderby))
+
+            aliased_select = 
aliased_select.select().column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY 
%s)" % orderby).label("_mssql_rn"))
+
+            limitselect = sql.select([c for c in aliased_select.alias().c if 
c.key!='_mssql_rn'])
+
+            if select._offset is not None:
+                # faithfully translate offset 0 to 0,
+                # used as a way to trigger emulated offset even for offset of 0
+                # as this usually result in more optimal query plans
+                # (implies loop join) for eagerloaded entities
+                limitselect.append_whereclause("_mssql_rn>%d" % select._offset)
+
+            if select._limit is not None:
+                limitselect.append_whereclause("_mssql_rn<=%d" % 
(select._limit + (select._offset or 0)))
+
+            return self.process(limitselect, **kwargs)
+        else:
+            return compiler.DefaultCompiler.visit_select(self, select, 
**kwargs)
+
 class MSSQLSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column) + " " + 
column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()
-        
+
         # install a IDENTITY Sequence if we have an implicit IDENTITY column
         if (not getattr(column.table, 'has_sequence', False)) and 
column.primary_key and \
                 column.autoincrement and isinstance(column.type, 
sqltypes.Integer) and not column.foreign_keys:
@@ -974,7 +1064,7 @@
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
-        
+
         return colspec
 
 class MSSQLSchemaDropper(compiler.SchemaDropper):

Reply via email to