This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 7fcc2af  [sql] Correct SQL parameter formatting (#5178)
7fcc2af is described below

commit 7fcc2af68f79a8a78e1799feb80647ae90ac9370
Author: John Bodley <4567245+john-bod...@users.noreply.github.com>
AuthorDate: Sat Jul 21 12:01:26 2018 -0700

    [sql] Correct SQL parameter formatting (#5178)
---
 .pylintrc                                          |  2 +-
 superset/connectors/sqla/models.py                 |  9 +--
 superset/db_engine_specs.py                        | 12 ++-
 .../4451805bbaa1_remove_double_percents.py         | 86 ++++++++++++++++++++++
 superset/models/core.py                            | 43 +++++++----
 superset/sql_lab.py                                |  3 +-
 tests/core_tests.py                                |  9 +++
 tests/sqllab_tests.py                              |  2 +-
 tox.ini                                            |  2 +-
 9 files changed, 138 insertions(+), 30 deletions(-)
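
Note (not part of the patch): the doubled percent signs this commit compensates for come from SQLAlchemy's compile-time escaping on dialects whose DBAPI uses the 'format'/'pyformat' paramstyle (MySQL, Presto, Hive). A minimal illustration, assuming SQLAlchemy 1.2-era behaviour:

    # Illustrative only -- shows why compiled SQL ends up with '%%' on
    # format/pyformat-paramstyle dialects (MySQL shown here).
    from sqlalchemy import text
    from sqlalchemy.dialects import mysql

    dialect = mysql.dialect()  # paramstyle 'format' -> '%' gets escaped
    sql = str(text("first_name LIKE '%admin%'").compile(dialect=dialect))
    # sql == "first_name LIKE '%%admin%%'"

    # The new Database.compile_sqla_query() strips the escaping before the
    # SQL is displayed or executed without bind parameters:
    if dialect.identifier_preparer._double_percents:
        sql = sql.replace('%%', '%')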

diff --git a/.pylintrc b/.pylintrc
index 820637d..016b04e 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -282,7 +282,7 @@ ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuil
 # List of class names for which member attributes should not be checked (useful
 # for classes with dynamically set attributes). This supports the use of
 # qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session
+ignored-classes=contextlib.closing,optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session
 
 # List of members which are set dynamically and missed by pylint inference
 # system, and so shouldn't trigger E1101 when accessed. Python regular
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 3c5b18e..c86d4ea 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -12,7 +12,6 @@ from flask import escape, Markup
 from flask_appbuilder import Model
 from flask_babel import lazy_gettext as _
 import pandas as pd
-import six
 import sqlalchemy as sa
 from sqlalchemy import (
     and_, asc, Boolean, Column, DateTime, desc, ForeignKey, Integer, or_,
@@ -427,14 +426,8 @@ class SqlaTable(Model, BaseDatasource):
             table=self, database=self.database, **kwargs)
 
     def get_query_str(self, query_obj):
-        engine = self.database.get_sqla_engine()
         qry = self.get_sqla_query(**query_obj)
-        sql = six.text_type(
-            qry.compile(
-                engine,
-                compile_kwargs={'literal_binds': True},
-            ),
-        )
+        sql = self.database.compile_sqla_query(qry)
         logging.info(sql)
         sql = sqlparse.format(sql, reindent=True)
         if query_obj['is_prequery']:
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 97b8095..2b74541 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -65,7 +65,6 @@ class BaseEngineSpec(object):
     """Abstract class for database engine specific configurations"""
 
     engine = 'base'  # str as defined in sqlalchemy.engine.engine
-    cursor_execute_kwargs = {}
     time_grains = tuple()
     time_groupby_inline = False
     limit_method = LimitMethod.FORCE_LIMIT
@@ -331,6 +330,10 @@ class BaseEngineSpec(object):
     def normalize_column_name(column_name):
         return column_name
 
+    @staticmethod
+    def execute(cursor, query, async=False):
+        cursor.execute(query)
+
 
 class PostgresBaseEngineSpec(BaseEngineSpec):
     """ Abstract class for Postgres 'like' databases """
@@ -558,7 +561,6 @@ class SqliteEngineSpec(BaseEngineSpec):
 
 class MySQLEngineSpec(BaseEngineSpec):
     engine = 'mysql'
-    cursor_execute_kwargs = {'args': {}}
     time_grains = (
         Grain('Time Column', _('Time Column'), '{col}', None),
         Grain('second', _('second'), 'DATE_ADD(DATE({col}), '
@@ -639,7 +641,6 @@ class MySQLEngineSpec(BaseEngineSpec):
 
 class PrestoEngineSpec(BaseEngineSpec):
     engine = 'presto'
-    cursor_execute_kwargs = {'parameters': None}
 
     time_grains = (
         Grain('Time Column', _('Time Column'), '{col}', None),
@@ -938,7 +939,6 @@ class HiveEngineSpec(PrestoEngineSpec):
     """Reuses PrestoEngineSpec functionality."""
 
     engine = 'hive'
-    cursor_execute_kwargs = {'async': True}
 
     # Scoping regex at class level to avoid recompiling
     # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
@@ -1230,6 +1230,10 @@ class HiveEngineSpec(PrestoEngineSpec):
             configuration['hive.server2.proxy.user'] = username
         return configuration
 
+    @staticmethod
+    def execute(cursor, query, async=False):
+        cursor.execute(query, async=async)
+
 
 class MssqlEngineSpec(BaseEngineSpec):
     engine = 'mssql'
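
Note (not part of the patch): the per-spec cursor_execute_kwargs dicts removed above are replaced by the execute() hook, so specs needing driver-specific execution arguments override the method instead. A rough sketch of a hypothetical spec (the class name and the poll_interval argument are invented for illustration):

    # Hypothetical example only -- 'ExampleEngineSpec' and 'poll_interval'
    # are made up to show how the new hook is meant to be overridden.
    from superset.db_engine_specs import BaseEngineSpec

    class ExampleEngineSpec(BaseEngineSpec):
        engine = 'example'

        @staticmethod
        def execute(cursor, query, async=False):
            # Pass driver-specific options directly to the DBAPI cursor
            # rather than via a shared cursor_execute_kwargs dict.
            cursor.execute(query, poll_interval=1)
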
diff --git a/superset/migrations/versions/4451805bbaa1_remove_double_percents.py b/superset/migrations/versions/4451805bbaa1_remove_double_percents.py
new file mode 100644
index 0000000..2e57b39
--- /dev/null
+++ b/superset/migrations/versions/4451805bbaa1_remove_double_percents.py
@@ -0,0 +1,86 @@
+"""remove double percents
+
+Revision ID: 4451805bbaa1
+Revises: bddc498dd179
+Create Date: 2018-06-13 10:20:35.846744
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '4451805bbaa1'
+down_revision = 'bddc498dd179'
+
+
+from alembic import op
+import json
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import Column, create_engine, ForeignKey, Integer, String, Text
+
+from superset import db
+
+Base = declarative_base()
+
+
+class Slice(Base):
+    __tablename__ = 'slices'
+
+    id = Column(Integer, primary_key=True)
+    datasource_id = Column(Integer, ForeignKey('tables.id'))
+    datasource_type = Column(String(200))
+    params = Column(Text)
+
+
+class Table(Base):
+    __tablename__ = 'tables'
+
+    id = Column(Integer, primary_key=True)
+    database_id = Column(Integer, ForeignKey('dbs.id'))
+
+
+class Database(Base):
+    __tablename__ = 'dbs'
+
+    id = Column(Integer, primary_key=True)
+    sqlalchemy_uri = Column(String(1024))
+
+
+def replace(source, target):
+    bind = op.get_bind()
+    session = db.Session(bind=bind)
+
+    query = (
+        session.query(Slice, Database)
+        .join(Table)
+        .join(Database)
+        .filter(Slice.datasource_type == 'table')
+        .all()
+    )
+
+    for slc, database in query:
+        try:
+            engine = create_engine(database.sqlalchemy_uri)
+
+            if engine.dialect.identifier_preparer._double_percents:
+                params = json.loads(slc.params)
+
+                if 'adhoc_filters' in params:
+                    for filt in params['adhoc_filters']:
+                        if 'sqlExpression' in filt:
+                            filt['sqlExpression'] = (
+                                filt['sqlExpression'].replace(source, target)
+                            )
+
+                    slc.params = json.dumps(params, sort_keys=True)
+        except Exception:
+            pass
+
+    session.commit()
+    session.close()
+
+
+def upgrade():
+    replace('%%', '%')
+
+
+def downgrade():
+    replace('%', '%%')
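
Note (not part of the migration): the upgrade un-escapes percent signs inside adhoc-filter SQL expressions stored in slice params for databases whose dialect double-escapes them; the downgrade re-escapes. A small standalone example of the transformation (the params value is made up):

    # Standalone illustration of replace('%%', '%') applied to slice params.
    import json

    params = {
        'adhoc_filters': [
            {'clause': 'WHERE', 'sqlExpression': "name LIKE '%%ad%%'"},
        ],
    }

    for filt in params['adhoc_filters']:
        if 'sqlExpression' in filt:
            filt['sqlExpression'] = filt['sqlExpression'].replace('%%', '%')

    print(json.dumps(params, sort_keys=True))
    # {"adhoc_filters": [{"clause": "WHERE", "sqlExpression": "name LIKE '%ad%'"}]}
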
diff --git a/superset/models/core.py b/superset/models/core.py
index 13021e7..d50cf4f 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -6,6 +6,7 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
+from contextlib import closing
 from copy import copy, deepcopy
 from datetime import datetime
 import functools
@@ -19,6 +20,7 @@ from flask_appbuilder.models.decorators import renders
 from future.standard_library import install_aliases
 import numpy
 import pandas as pd
+import six
 import sqlalchemy as sqla
 from sqlalchemy import (
     Boolean, Column, create_engine, DateTime, ForeignKey, Integer,
@@ -749,12 +751,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
 
     def get_df(self, sql, schema):
         sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)]
-        eng = self.get_sqla_engine(schema=schema)
-
-        for i in range(len(sqls) - 1):
-            eng.execute(sqls[i])
-
-        df = pd.read_sql_query(sqls[-1], eng)
+        engine = self.get_sqla_engine(schema=schema)
 
         def needs_conversion(df_series):
             if df_series.empty:
@@ -763,15 +760,35 @@ class Database(Model, AuditMixinNullable, ImportMixin):
                 return True
             return False
 
-        for k, v in df.dtypes.items():
-            if v.type == numpy.object_ and needs_conversion(df[k]):
-                df[k] = df[k].apply(utils.json_dumps_w_dates)
-        return df
+        with closing(engine.raw_connection()) as conn:
+            with closing(conn.cursor()) as cursor:
+                for sql in sqls:
+                    self.db_engine_spec.execute(cursor, sql)
+                df = pd.DataFrame.from_records(
+                    data=list(cursor.fetchall()),
+                    columns=[col_desc[0] for col_desc in cursor.description],
+                    coerce_float=True,
+                )
+
+                for k, v in df.dtypes.items():
+                    if v.type == numpy.object_ and needs_conversion(df[k]):
+                        df[k] = df[k].apply(utils.json_dumps_w_dates)
+                return df
 
     def compile_sqla_query(self, qry, schema=None):
-        eng = self.get_sqla_engine(schema=schema)
-        compiled = qry.compile(eng, compile_kwargs={'literal_binds': True})
-        return '{}'.format(compiled)
+        engine = self.get_sqla_engine(schema=schema)
+
+        sql = six.text_type(
+            qry.compile(
+                engine,
+                compile_kwargs={'literal_binds': True},
+            ),
+        )
+
+        if engine.dialect.identifier_preparer._double_percents:
+            sql = sql.replace('%%', '%')
+
+        return sql
 
     def select_star(
             self, table_name, schema=None, limit=100, show_cols=False,
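
Note (not part of the patch): get_df() now runs every statement on a single raw DBAPI cursor wrapped in contextlib.closing (hence the .pylintrc change) and builds the DataFrame from cursor.fetchall(), so no client-side parameter interpolation touches the SQL. The same pattern, sketched with the stdlib sqlite3 driver so it runs on its own:

    # Minimal sketch of the cursor-handling pattern, using sqlite3 so it
    # runs without a Superset database connection.
    from contextlib import closing
    import sqlite3

    import pandas as pd

    with closing(sqlite3.connect(':memory:')) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute("SELECT 'admin' AS first_name")
            df = pd.DataFrame.from_records(
                data=list(cursor.fetchall()),
                columns=[col_desc[0] for col_desc in cursor.description],
                coerce_float=True,
            )

    print(df)
    #   first_name
    # 0      admin
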
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index b45cbbb..a626b68 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -172,8 +172,7 @@ def execute_sql(
         cursor = conn.cursor()
         logging.info('Running query: \n{}'.format(executed_sql))
         logging.info(query.executed_sql)
-        cursor.execute(query.executed_sql,
-                       **db_engine_spec.cursor_execute_kwargs)
+        db_engine_spec.execute(cursor, query.executed_sql, async=True)
         logging.info('Handling cursor')
         db_engine_spec.handle_cursor(cursor, query, session)
         logging.info('Fetching data: {}'.format(query.to_dict()))
diff --git a/tests/core_tests.py b/tests/core_tests.py
index f1a0179..eb95f1f 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -436,6 +436,15 @@ class CoreTests(SupersetTestCase):
         expected_data = csv.reader(
             io.StringIO('first_name,last_name\nadmin, user\n'))
 
+        sql = "SELECT first_name FROM ab_user WHERE first_name LIKE '%admin%'"
+        client_id = '{}'.format(random.getrandbits(64))[:10]
+        self.run_sql(sql, client_id, raise_on_error=True)
+
+        resp = self.get_resp('/superset/csv/{}'.format(client_id))
+        data = csv.reader(io.StringIO(resp))
+        expected_data = csv.reader(
+            io.StringIO('first_name\nadmin\n'))
+
         self.assertEqual(list(expected_data), list(data))
         self.logout()
 
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index a3bb564..51c336b 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -254,7 +254,7 @@ class SqlLabTests(SupersetTestCase):
             'sql': """\
                 SELECT viz_type, count(1) as ccount
                 FROM slices
-                WHERE viz_type LIKE '%%a%%'
+                WHERE viz_type LIKE '%a%'
                 GROUP BY viz_type""",
             'dbId': 1,
         }
diff --git a/tox.ini b/tox.ini
index 6f3c9fd..464ab1b 100644
--- a/tox.ini
+++ b/tox.ini
@@ -37,7 +37,7 @@ setenv =
     SUPERSET_CONFIG = tests.superset_test_config
     SUPERSET_HOME = {envtmpdir}
     py27-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset?charset=utf8
-    py34-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
+    py{34,36}-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
     py{27,34,36}-postgres: SUPERSET__SQLALCHEMY_DATABASE_URI = postgresql+psycopg2://postgresuser:pguserpassword@localhost/superset
     py{27,34,36}-sqlite: SUPERSET__SQLALCHEMY_DATABASE_URI = sqlite:////{envtmpdir}/superset.db
 whitelist_externals =
