import sys
import urllib
import warnings

from sqlalchemy import exc
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.engine import reflection
from sqlalchemy.util import asbool


# Upper-cased Vertica type names mapped to the SQLAlchemy types used when
# reflecting columns. Each entry below pairs a group of Vertica aliases with
# the single SQLAlchemy type they all reflect to.
ischema_names = {
    type_name: sqla_type
    for type_names, sqla_type in (
        (('BINARY', 'VARBINARY', 'BYTEA', 'RAW'), sqltypes.BLOB),
        (('BOOLEAN',), sqltypes.BOOLEAN),
        (('CHAR',), sqltypes.CHAR),
        (('VARCHAR', 'VARCHAR2'), sqltypes.VARCHAR),
        (('DATE',), sqltypes.DATE),
        (('DATETIME', 'SMALLDATETIME'), sqltypes.DATETIME),
        (('TIME',), sqltypes.TIME),
        # Not mapped yet: TIME WITH TIMEZONE, TIMESTAMP,
        # TIMESTAMP WITH TIMEZONE, INTERVAL.
        # These floating-point names all share one internal representation.
        (('FLOAT', 'FLOAT8', 'DOUBLE', 'REAL'), sqltypes.FLOAT),
        (('INT', 'INTEGER', 'INT8', 'BIGINT', 'SMALLINT', 'TINYINT'),
         sqltypes.INTEGER),
        (('NUMERIC', 'DECIMAL', 'NUMBER', 'MONEY'), sqltypes.NUMERIC),
    )
    for type_name in type_names
}


class VerticaConnection(PyODBCConnector):
    """pyodbc connector that builds Vertica ODBC connection strings."""

    # Width of the interpreter's unicode characters; selects the
    # WIDECHARSIZEIN value appended to every connection string below.
    unicode_type = 'UCS2' if sys.maxunicode < 2 ** 16 else 'UCS4'

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into pyodbc ``connect()`` arguments.

        Three forms are supported:

        * ``odbc_connect=<url-quoted string>`` in the query: the decoded
          string is used verbatim (no user/password or extra options are
          appended to it).
        * DSN: when a ``dsn`` query option is given, or a host is given
          without a database, the host (or ``dsn``) is used as the DSN name.
        * Otherwise a DSN-less string with DRIVER, SERVERNAME, DATABASE and
          PORT (default 5433) is built.

        For the last two forms USERNAME/PASSWORD are added when a user is
        present, and any remaining URL options are passed through as
        ``key=value`` pairs.

        :param url: the SQLAlchemy URL being connected to.
        :returns: ``[[connection_string], connect_kwargs]`` where
            ``connect_kwargs`` holds the boolean pyodbc options ``ansi``,
            ``unicode_results`` and ``autocommit`` when they were supplied.
        """
        keys = url.translate_connect_args(username='user')
        keys.update(url.query)
        query = url.query

        # These are pyodbc.connect() keyword options rather than parts of
        # the ODBC connection string.
        connect_args = {}
        for param in ('ansi', 'unicode_results', 'autocommit'):
            if param in keys:
                connect_args[param] = asbool(keys.pop(param))

        if 'odbc_connect' in keys:
            # urllib.unquote_plus moved to urllib.parse in Python 3.
            try:
                from urllib.parse import unquote_plus
            except ImportError:  # Python 2
                from urllib import unquote_plus
            connectors = [unquote_plus(keys.pop('odbc_connect'))]
        else:
            dsn_connection = 'dsn' in keys or (
                'host' in keys and 'database' not in keys)
            if dsn_connection:
                connectors = [
                    'dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))
                ]
            else:
                connectors = [
                    "DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name),
                    'SERVERNAME=%s' % keys.pop('host', ''),
                    'DATABASE=%s' % keys.pop('database', ''),
                ]
                # Previously the port was computed but never emitted, so a
                # non-default port in the URL was silently ignored. A port
                # given as a query option stays in ``keys`` and is emitted
                # with the other pass-through options below; otherwise use
                # the URL's port, falling back to Vertica's default 5433.
                if 'port' not in query:
                    connectors.append('PORT=%d' % int(keys.pop('port', 5433)))

            user = keys.pop("user", None)
            if user:
                connectors.append("USERNAME=%s" % user)
                connectors.append("PASSWORD=%s" % keys.pop('password', ''))

            # .items() works on both Python 2 and 3 (iteritems is py2-only).
            connectors.extend('%s=%s' % (k, v) for k, v in keys.items())

        # Handle funny unicode behavior in pyodbc.
        # The issue seems isolated to unixODBC though, and not iODBC.
        connectors.append('WIDECHARSIZEOUT=4')
        if self.unicode_type == 'UCS2':
            connectors.append('WIDECHARSIZEIN=2')
        elif self.unicode_type == 'UCS4':
            connectors.append('WIDECHARSIZEIN=4')

        return [[";".join(connectors)], connect_args]


class VerticaDialect(VerticaConnection, SQLiteDialect):
    """SQLAlchemy dialect for Vertica, connecting through pyodbc.

    Connection-string construction comes from ``VerticaConnection``; other
    dialect behaviour is inherited from ``SQLiteDialect``. Table and column
    reflection are overridden here to read Vertica's ``v_catalog`` tables.
    """
    name = 'vertica'
    ischema_names = ischema_names

    def __init__(self, **params):
        super(VerticaDialect, self).__init__(**params)

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        """Return table names from ``v_catalog.tables``.

        :param connection: SQLAlchemy connection used to run the query.
        :param schema: if given, only tables in this schema are returned;
            otherwise tables from every schema are listed.
        :returns: list of table names ordered by schema, then table name.
        """
        clauses = ["SELECT table_name FROM v_catalog.tables"]
        params = {}
        if schema is not None:
            # Bound parameter instead of string interpolation, so schema
            # names containing quotes cannot break (or inject into) the SQL.
            clauses.append("WHERE table_schema = :schema")
            params['schema'] = schema
        clauses.append("ORDER BY table_schema, table_name")

        rs = connection.execute(text(' '.join(clauses)), params)
        return [row[0] for row in rs]

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        """Reflect the columns of ``table_name`` from ``v_catalog.columns``.

        :param connection: SQLAlchemy connection used to run the queries.
        :param table_name: name of the table to reflect.
        :param schema: if given, restrict both the column and primary-key
            lookups to this schema.
        :returns: list of dicts with keys ``name``, ``type``, ``nullable``,
            ``default`` and ``primary_key``, one per row of
            ``v_catalog.columns``.
        """
        params = {'table_name': table_name}
        columns_sql = ("SELECT * FROM v_catalog.columns "
                       "WHERE table_name = :table_name")
        pk_sql = ("SELECT column_name FROM v_catalog.primary_keys "
                  "WHERE table_name = :table_name "
                  "AND constraint_type = 'p'")
        if schema is not None:
            schema_clause = " AND table_schema = :schema"
            columns_sql += schema_clause
            pk_sql += schema_clause
            params['schema'] = schema

        # A set gives O(1) membership tests in the loop below.
        pk_columns = set(
            row[0] for row in connection.execute(text(pk_sql), params))

        columns = []
        for row in connection.execute(text(columns_sql), params):
            name = row.column_name
            columns.append({
                'name': name,
                'type': self._resolve_column_type(row.data_type, name),
                'nullable': row.is_nullable,
                'default': row.column_default,
                'primary_key': name in pk_columns,
            })
        return columns

    def _resolve_column_type(self, data_type, column_name):
        """Map a Vertica ``data_type`` string to a SQLAlchemy type class.

        ``data_type`` may carry a length/precision suffix (e.g.
        ``varchar(80)`` or ``numeric(18,2)``). The suffix is stripped before
        the ``ischema_names`` lookup. Unknown types produce an ``SAWarning``
        and reflect as ``NullType`` instead of aborting reflection with a
        ``KeyError``.
        """
        base_type = data_type.split('(', 1)[0].strip().upper()
        try:
            return self.ischema_names[base_type]
        except KeyError:
            warnings.warn(
                "Did not recognize type '%s' of column '%s'"
                % (data_type, column_name),
                exc.SAWarning)
            return sqltypes.NullType


# Module-level name under which this dialect class is exposed; presumably the
# attribute SQLAlchemy's dialect loader looks up for this module — confirm
# against how the dialect is registered.
dialect = VerticaDialect
