This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 68c4c3a Prevent 'main' database connection creation (#8038)
68c4c3a is described below
commit 68c4c3a0b9af083304713132534d607c9830e858
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Sun Sep 8 10:18:09 2019 -0700
Prevent 'main' database connection creation (#8038)
* prevent 'main' database connection creation
* fix tests
* removing get_main_database
* Kill get_main_database
* Point to examples tables
---
superset/cli.py | 102 ++++++++++++-----------------------
superset/security.py | 7 ++-
superset/utils/core.py | 10 ----
superset/views/core.py | 6 +--
tests/base_tests.py | 41 ++++++++++++--
tests/celery_tests.py | 39 +++++++-------
tests/core_tests.py | 68 ++++++++----------------
tests/dict_import_export_tests.py | 4 +-
tests/model_tests.py | 6 +--
tests/sql_parse_tests.py | 22 ++++----
tests/sql_validator_tests.py | 6 +--
tests/sqla_models_tests.py | 6 +--
tests/sqllab_tests.py | 109 +++++++++++++++++---------------------
13 files changed, 186 insertions(+), 240 deletions(-)
diff --git a/superset/cli.py b/superset/cli.py
index 5721220..b30b654 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -48,7 +48,6 @@ def make_shell_context():
@app.cli.command()
def init():
"""Inits the Superset application"""
- utils.get_or_create_main_db()
utils.get_example_database()
appbuilder.add_permissions(update_perms=True)
security_manager.sync_role_definitions()
@@ -430,75 +429,40 @@ def load_test_users_run():
Syncs permissions for those users/roles
"""
if config.get("TESTING"):
- security_manager.sync_role_definitions()
- gamma_sqllab_role = security_manager.add_role("gamma_sqllab")
- for perm in security_manager.find_role("Gamma").permissions:
- security_manager.add_permission_role(gamma_sqllab_role, perm)
- utils.get_or_create_main_db()
- db_perm = utils.get_main_database().perm
- security_manager.add_permission_view_menu("database_access", db_perm)
- db_pvm = security_manager.find_permission_view_menu(
- view_menu_name=db_perm, permission_name="database_access"
+
+ sm = security_manager
+
+ examples_db = utils.get_example_database()
+
+ examples_pv = sm.add_permission_view_menu("database_access", examples_db.perm)
+
+ sm.sync_role_definitions()
+ gamma_sqllab_role = sm.add_role("gamma_sqllab")
+ sm.add_permission_role(gamma_sqllab_role, examples_pv)
+
+ for role in ["Gamma", "sql_lab"]:
+ for perm in sm.find_role(role).permissions:
+ sm.add_permission_role(gamma_sqllab_role, perm)
+
+ users = (
+ ("admin", "Admin"),
+ ("gamma", "Gamma"),
+ ("gamma2", "Gamma"),
+ ("gamma_sqllab", "gamma_sqllab"),
+ ("alpha", "Alpha"),
)
- gamma_sqllab_role.permissions.append(db_pvm)
- for perm in security_manager.find_role("sql_lab").permissions:
- security_manager.add_permission_role(gamma_sqllab_role, perm)
-
- admin = security_manager.find_user("admin")
- if not admin:
- security_manager.add_user(
- "admin",
- "admin",
- " user",
- "[email protected]",
- security_manager.find_role("Admin"),
- password="general",
- )
-
- gamma = security_manager.find_user("gamma")
- if not gamma:
- security_manager.add_user(
- "gamma",
- "gamma",
- "user",
- "[email protected]",
- security_manager.find_role("Gamma"),
- password="general",
- )
-
- gamma2 = security_manager.find_user("gamma2")
- if not gamma2:
- security_manager.add_user(
- "gamma2",
- "gamma2",
- "user",
- "[email protected]",
- security_manager.find_role("Gamma"),
- password="general",
- )
-
- gamma_sqllab_user = security_manager.find_user("gamma_sqllab")
- if not gamma_sqllab_user:
- security_manager.add_user(
- "gamma_sqllab",
- "gamma_sqllab",
- "user",
- "[email protected]",
- gamma_sqllab_role,
- password="general",
- )
-
- alpha = security_manager.find_user("alpha")
- if not alpha:
- security_manager.add_user(
- "alpha",
- "alpha",
- "user",
- "[email protected]",
- security_manager.find_role("Alpha"),
- password="general",
- )
- security_manager.get_session.commit()
+ for username, role in users:
+ user = sm.find_user(username)
+ if not user:
+ sm.add_user(
+ username,
+ username,
+ "user",
+ username + "@fab.org",
+ sm.find_role(role),
+ password="general",
+ )
+ sm.get_session.commit()
@app.cli.command()
diff --git a/superset/security.py b/superset/security.py
index bea1571..918d7d1 100644
--- a/superset/security.py
+++ b/superset/security.py
@@ -200,7 +200,6 @@ class SupersetSecurityManager(SecurityManager):
:param database: The Superset database
:returns: Whether the user can access the Superset database
"""
-
return (
self.all_datasource_access()
or self.all_database_access()
@@ -269,9 +268,9 @@ class SupersetSecurityManager(SecurityManager):
:param tables: The list of denied SQL table names
:returns: The error message
"""
-
- return f"""You need access to the following tables: {", ".join(tables)}, all
- database access or `all_datasource_access` permission"""
+ quoted_tables = [f"`{t}`" for t in tables]
+ return f"""You need access to the following tables: {", ".join(quoted_tables)},
+ `all_database_access` or `all_datasource_access` permission"""
def get_table_access_link(self, tables: List[str]) -> Optional[str]:
"""
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 8e70455..118d65b 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -936,10 +936,6 @@ def user_label(user: User) -> Optional[str]:
return None
-def get_or_create_main_db():
- get_main_database()
-
-
def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
from superset import db
from superset.models import core as models
@@ -957,12 +953,6 @@ def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
return database
-def get_main_database():
- from superset import conf
-
- return get_or_create_db("main", conf.get("SQLALCHEMY_DATABASE_URI"))
-
-
def get_example_database():
from superset import conf
diff --git a/superset/views/core.py b/superset/views/core.py
index af4fe7e..87e9400 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -2705,11 +2705,7 @@ class Superset(BaseSupersetView):
query.sql, query.database, query.schema
)
if rejected_tables:
- flash(
- security_manager.get_table_access_error_msg(
- "{}".format(rejected_tables)
- )
- )
+ flash(security_manager.get_table_access_error_msg(rejected_tables))
return redirect("/")
blob = None
if results_backend and query.results_key:
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 8ac5bcd..8fc96cc 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -28,7 +28,7 @@ from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.core import Database
-from superset.utils.core import get_main_database
+from superset.utils.core import get_example_database
BASE_DIR = app.config.get("BASE_DIR")
@@ -168,6 +168,12 @@ class SupersetTestCase(unittest.TestCase):
):
security_manager.del_permission_role(public_role, perm)
+ def _get_database_by_name(self, database_name="main"):
+ if database_name == "examples":
+ return get_example_database()
+ else:
+ raise ValueError("Database doesn't exist")
+
def run_sql(
self,
sql,
@@ -175,11 +181,12 @@ class SupersetTestCase(unittest.TestCase):
user_name=None,
raise_on_error=False,
query_limit=None,
+ database_name="examples",
):
if user_name:
self.logout()
- self.login(username=(user_name if user_name else "admin"))
- dbid = get_main_database().id
+ self.login(username=(user_name or "admin"))
+ dbid = self._get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/sql_json/",
raise_on_error=False,
@@ -195,11 +202,35 @@ class SupersetTestCase(unittest.TestCase):
raise Exception("run_sql failed")
return resp
- def validate_sql(self, sql, client_id=None, user_name=None, raise_on_error=False):
+ def create_fake_db(self):
+ self.login(username="admin")
+ database_name = "fake_db_100"
+ db_id = 100
+ extra = """{
+ "schemas_allowed_for_csv_upload":
+ ["this_schema_is_allowed", "this_schema_is_allowed_too"]
+ }"""
+
+ return self.get_or_create(
+ cls=models.Database,
+ criteria={"database_name": database_name},
+ session=db.session,
+ id=db_id,
+ extra=extra,
+ )
+
+ def validate_sql(
+ self,
+ sql,
+ client_id=None,
+ user_name=None,
+ raise_on_error=False,
+ database_name="examples",
+ ):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
- dbid = get_main_database().id
+ dbid = self._get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/validate_sql_json/",
raise_on_error=False,
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 7935448..d77529a 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -28,7 +28,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery
-from superset.utils.core import get_main_database
+from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
@@ -132,20 +132,20 @@ class CeleryTestCase(SupersetTestCase):
return json.loads(resp.data)
def test_run_sync_query_dont_exist(self):
- main_db = get_main_database()
+ main_db = get_example_database()
db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist"
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true")
self.assertTrue("error" in result1)
def test_run_sync_query_cta(self):
- main_db = get_main_database()
+ main_db = get_example_database()
backend = main_db.backend
db_id = main_db.id
tmp_table_name = "tmp_async_22"
self.drop_table_if_exists(tmp_table_name, main_db)
- perm_name = "can_sql_json"
- sql_where = "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name)
+ name = "James"
+ sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
result = self.run_sql(
db_id, sql_where, "2", tmp_table=tmp_table_name, cta="true"
)
@@ -162,9 +162,9 @@ class CeleryTestCase(SupersetTestCase):
self.assertGreater(len(results["data"]), 0)
def test_run_sync_query_cta_no_data(self):
- main_db = get_main_database()
+ main_db = get_example_database()
db_id = main_db.id
- sql_empty_result = "SELECT * FROM ab_user WHERE id=666"
+ sql_empty_result = "SELECT * FROM birth_names WHERE name='random'"
result3 = self.run_sql(db_id, sql_empty_result, "3")
self.assertEqual(QueryStatus.SUCCESS, result3["query"]["state"])
self.assertEqual([], result3["data"])
@@ -183,12 +183,12 @@ class CeleryTestCase(SupersetTestCase):
return self.run_sql(db_id, sql)
def test_run_async_query(self):
- main_db = get_main_database()
+ main_db = get_example_database()
db_id = main_db.id
self.drop_table_if_exists("tmp_async_1", main_db)
- sql_where = "SELECT name FROM ab_role WHERE name='Admin'"
+ sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id, sql_where, "4", async_="true", tmp_table="tmp_async_1", cta="true"
)
@@ -202,12 +202,13 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
+
self.assertTrue("FROM tmp_async_1" in query.select_sql)
self.assertEqual(
"CREATE TABLE tmp_async_1 AS \n"
- "SELECT name FROM ab_role "
- "WHERE name='Admin'\n"
- "LIMIT 666",
+ "SELECT name FROM birth_names "
+ "WHERE name='James' "
+ "LIMIT 10",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)
@@ -216,13 +217,14 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(True, query.select_as_cta_used)
def test_run_async_query_with_lower_limit(self):
- main_db = get_main_database()
+ main_db = get_example_database()
db_id = main_db.id
- self.drop_table_if_exists("tmp_async_2", main_db)
+ tmp_table = "tmp_async_2"
+ self.drop_table_if_exists(tmp_table, main_db)
- sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1"
+ sql_where = "SELECT name FROM birth_names LIMIT 1"
result = self.run_sql(
- db_id, sql_where, "5", async_="true", tmp_table="tmp_async_2", cta="true"
+ db_id, sql_where, "5", async_="true", tmp_table=tmp_table, cta="true"
)
assert result["query"]["state"] in (
QueryStatus.PENDING,
@@ -234,10 +236,9 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
- self.assertTrue("FROM tmp_async_2" in query.select_sql)
+ self.assertTrue(f"FROM {tmp_table}" in query.select_sql)
self.assertEqual(
- "CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role "
- "WHERE name='Alpha' LIMIT 1",
+ f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 5db0b77..3c2e8ea 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -346,13 +346,13 @@ class CoreTests(SupersetTestCase):
def test_testconn(self, username="admin"):
self.login(username=username)
- database = utils.get_main_database()
+ database = utils.get_example_database()
# validate that the endpoint works with the password-masked sqlalchemy uri
data = json.dumps(
{
"uri": database.safe_sqlalchemy_uri(),
- "name": "main",
+ "name": "examples",
"impersonate_user": False,
}
)
@@ -366,7 +366,7 @@ class CoreTests(SupersetTestCase):
data = json.dumps(
{
"uri": database.sqlalchemy_uri_decrypted,
- "name": "main",
+ "name": "examples",
"impersonate_user": False,
}
)
@@ -377,7 +377,7 @@ class CoreTests(SupersetTestCase):
assert response.headers["Content-Type"] == "application/json"
def test_custom_password_store(self):
- database = utils.get_main_database()
+ database = utils.get_example_database()
conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted)
def custom_password_store(uri):
@@ -395,13 +395,13 @@ class CoreTests(SupersetTestCase):
# validate that sending a password-masked uri does not over-write the decrypted
# uri
self.login(username=username)
- database = utils.get_main_database()
+ database = utils.get_example_database()
sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted
url = "databaseview/edit/{}".format(database.id)
data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns}
data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
self.client.post(url, data=data)
- database = utils.get_main_database()
+ database = utils.get_example_database()
self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)
def test_warm_up_cache(self):
@@ -460,51 +460,51 @@ class CoreTests(SupersetTestCase):
def test_csv_endpoint(self):
self.login("admin")
sql = """
- SELECT first_name, last_name
- FROM ab_user
- WHERE first_name='admin'
+ SELECT name
+ FROM birth_names
+ WHERE name = 'James'
+ LIMIT 1
"""
client_id = "{}".format(random.getrandbits(64))[:10]
self.run_sql(sql, client_id, raise_on_error=True)
resp = self.get_resp("/superset/csv/{}".format(client_id))
data = csv.reader(io.StringIO(resp))
- expected_data = csv.reader(io.StringIO("first_name,last_name\nadmin, user\n"))
+ expected_data = csv.reader(io.StringIO("name\nJames\n"))
- sql = "SELECT first_name FROM ab_user WHERE first_name LIKE '%admin%'"
client_id = "{}".format(random.getrandbits(64))[:10]
self.run_sql(sql, client_id, raise_on_error=True)
resp = self.get_resp("/superset/csv/{}".format(client_id))
data = csv.reader(io.StringIO(resp))
- expected_data = csv.reader(io.StringIO("first_name\nadmin\n"))
+ expected_data = csv.reader(io.StringIO("name\nJames\n"))
self.assertEqual(list(expected_data), list(data))
self.logout()
def test_extra_table_metadata(self):
self.login("admin")
- dbid = utils.get_main_database().id
+ dbid = utils.get_example_database().id
self.get_json_resp(
- f"/superset/extra_table_metadata/{dbid}/" "ab_permission_view/panoramix/"
+ f"/superset/extra_table_metadata/{dbid}/birth_names/superset/"
)
def test_process_template(self):
- maindb = utils.get_main_database()
+ maindb = utils.get_example_database()
sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(sql)
self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)
def test_get_template_kwarg(self):
- maindb = utils.get_main_database()
+ maindb = utils.get_example_database()
s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb, foo="bar")
rendered = tp.process_template(s)
self.assertEqual("bar", rendered)
def test_template_kwarg(self):
- maindb = utils.get_main_database()
+ maindb = utils.get_example_database()
s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(s, foo="bar")
@@ -517,23 +517,12 @@ class CoreTests(SupersetTestCase):
self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00")
def test_table_metadata(self):
- maindb = utils.get_main_database()
- backend = maindb.backend
- data = self.get_json_resp("/superset/table/{}/ab_user/null/".format(maindb.id))
- self.assertEqual(data["name"], "ab_user")
+ maindb = utils.get_example_database()
+ data = self.get_json_resp(f"/superset/table/{maindb.id}/birth_names/null/")
+ self.assertEqual(data["name"], "birth_names")
assert len(data["columns"]) > 5
assert data.get("selectStar").startswith("SELECT")
- # Engine specific tests
- if backend in ("mysql", "postgresql"):
- self.assertEqual(data.get("primaryKey").get("type"), "pk")
- self.assertEqual(data.get("primaryKey").get("column_names")[0], "id")
- self.assertEqual(len(data.get("foreignKeys")), 2)
- if backend == "mysql":
- self.assertEqual(len(data.get("indexes")), 7)
- elif backend == "postgresql":
- self.assertEqual(len(data.get("indexes")), 5)
-
def test_fetch_datasource_metadata(self):
self.login(username="admin")
url = "/superset/fetch_datasource_metadata?" "datasourceKey=1__table"
@@ -746,24 +735,11 @@ class CoreTests(SupersetTestCase):
def test_schemas_access_for_csv_upload_endpoint(
self, mock_all_datasource_access, mock_database_access, mock_schemas_accessible
):
+ self.login(username="admin")
+ dbobj = self.create_fake_db()
mock_all_datasource_access.return_value = False
mock_database_access.return_value = False
mock_schemas_accessible.return_value = ["this_schema_is_allowed_too"]
- database_name = "fake_db_100"
- db_id = 100
- extra = """{
- "schemas_allowed_for_csv_upload":
- ["this_schema_is_allowed", "this_schema_is_allowed_too"]
- }"""
-
- self.login(username="admin")
- dbobj = self.get_or_create(
- cls=models.Database,
- criteria={"database_name": database_name},
- session=db.session,
- id=db_id,
- extra=extra,
- )
data = self.get_json_resp(
url="/superset/schemas_access_for_csv_upload?db_id={db_id}".format(
db_id=dbobj.id
diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py
index fafa024..26080def 100644
--- a/tests/dict_import_export_tests.py
+++ b/tests/dict_import_export_tests.py
@@ -23,7 +23,7 @@ import yaml
from superset import db
from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
-from superset.utils.core import get_main_database
+from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
DBREF = "dict_import__export_test"
@@ -63,7 +63,7 @@ class DictImportExportTests(SupersetTestCase):
params = {DBREF: id, "database_name": database_name}
dict_rep = {
- "database_id": get_main_database().id,
+ "database_id": get_example_database().id,
"table_name": name,
"schema": schema,
"id": id,
diff --git a/tests/model_tests.py b/tests/model_tests.py
index f65db84..5033981 100644
--- a/tests/model_tests.py
+++ b/tests/model_tests.py
@@ -22,7 +22,7 @@ from sqlalchemy.engine.url import make_url
from superset import app
from superset.models.core import Database
-from superset.utils.core import get_example_database, get_main_database, QueryStatus
+from superset.utils.core import get_example_database, QueryStatus
from .base_tests import SupersetTestCase
@@ -149,7 +149,7 @@ class DatabaseModelTestCase(SupersetTestCase):
assert sql.startswith(expected)
def test_single_statement(self):
- main_db = get_main_database()
+ main_db = get_example_database()
if main_db.backend == "mysql":
df = main_db.get_df("SELECT 1", None)
@@ -159,7 +159,7 @@ class DatabaseModelTestCase(SupersetTestCase):
self.assertEquals(df.iat[0, 0], 1)
def test_multi_statement(self):
- main_db = get_main_database()
+ main_db = get_example_database()
if main_db.backend == "mysql":
df = main_db.get_df("USE superset; SELECT 1", None)
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index a7f925c..7c6d420 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -449,41 +449,41 @@ class SupersetTestCase(unittest.TestCase):
self.assertEquals({"SalesOrderHeader"}, self.extract_tables(query))
def test_get_query_with_new_limit_comment(self):
- sql = "SELECT * FROM ab_user -- SOME COMMENT"
+ sql = "SELECT * FROM birth_names -- SOME COMMENT"
parsed = sql_parse.ParsedQuery(sql)
newsql = parsed.get_query_with_new_limit(1000)
self.assertEquals(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit_comment_with_limit(self):
- sql = "SELECT * FROM ab_user -- SOME COMMENT WITH LIMIT 555"
+ sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
parsed = sql_parse.ParsedQuery(sql)
newsql = parsed.get_query_with_new_limit(1000)
self.assertEquals(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit(self):
- sql = "SELECT * FROM ab_user LIMIT 555"
+ sql = "SELECT * FROM birth_names LIMIT 555"
parsed = sql_parse.ParsedQuery(sql)
newsql = parsed.get_query_with_new_limit(1000)
- expected = "SELECT * FROM ab_user LIMIT 1000"
+ expected = "SELECT * FROM birth_names LIMIT 1000"
self.assertEquals(newsql, expected)
def test_basic_breakdown_statements(self):
multi_sql = """
- SELECT * FROM ab_user;
- SELECT * FROM ab_user LIMIT 1;
+ SELECT * FROM birth_names;
+ SELECT * FROM birth_names LIMIT 1;
"""
parsed = sql_parse.ParsedQuery(multi_sql)
statements = parsed.get_statements()
self.assertEquals(len(statements), 2)
- expected = ["SELECT * FROM ab_user", "SELECT * FROM ab_user LIMIT 1"]
+ expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
self.assertEquals(statements, expected)
def test_messy_breakdown_statements(self):
multi_sql = """
SELECT 1;\t\n\n\n \t
\t\nSELECT 2;
- SELECT * FROM ab_user;;;
- SELECT * FROM ab_user LIMIT 1
+ SELECT * FROM birth_names;;;
+ SELECT * FROM birth_names LIMIT 1
"""
parsed = sql_parse.ParsedQuery(multi_sql)
statements = parsed.get_statements()
@@ -491,8 +491,8 @@ class SupersetTestCase(unittest.TestCase):
expected = [
"SELECT 1",
"SELECT 2",
- "SELECT * FROM ab_user",
- "SELECT * FROM ab_user LIMIT 1",
+ "SELECT * FROM birth_names",
+ "SELECT * FROM birth_names LIMIT 1",
]
self.assertEquals(statements, expected)
diff --git a/tests/sql_validator_tests.py b/tests/sql_validator_tests.py
index 3db732a..3313f7e 100644
--- a/tests/sql_validator_tests.py
+++ b/tests/sql_validator_tests.py
@@ -53,7 +53,7 @@ class SqlValidatorEndpointTests(SupersetTestCase):
app.config["SQL_VALIDATORS_BY_ENGINE"] = {}
resp = self.validate_sql(
- "SELECT * FROM ab_user", client_id="1", raise_on_error=False
+ "SELECT * FROM birth_names", client_id="1", raise_on_error=False
)
self.assertIn("error", resp)
self.assertIn("no SQL validator is configured", resp["error"])
@@ -97,7 +97,7 @@ class SqlValidatorEndpointTests(SupersetTestCase):
validator.validate.side_effect = Exception("Kaboom!")
resp = self.validate_sql(
- "SELECT * FROM ab_user", client_id="1", raise_on_error=False
+ "SELECT * FROM birth_names", client_id="1", raise_on_error=False
)
self.assertIn("error", resp)
self.assertIn("Kaboom!", resp["error"])
@@ -186,7 +186,7 @@ class PrestoValidatorTests(SupersetTestCase):
# validator for sqlite, this test will fail because the validator
# will no longer error out.
resp = self.validate_sql(
- "SELECT * FROM ab_user", client_id="1", raise_on_error=False
+ "SELECT * FROM birth_names", client_id="1", raise_on_error=False
)
self.assertIn("error", resp)
self.assertIn("no SQL validator is configured", resp["error"])
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index a46f388..326311a 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -16,7 +16,7 @@
# under the License.
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
-from superset.utils.core import get_main_database
+from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
@@ -43,7 +43,7 @@ class DatabaseModelTestCase(SupersetTestCase):
def test_has_extra_cache_keys(self):
query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
- table = SqlaTable(sql=query, database=get_main_database())
+ table = SqlaTable(sql=query, database=get_example_database())
query_obj = {
"granularity": None,
"from_dttm": None,
@@ -60,7 +60,7 @@ class DatabaseModelTestCase(SupersetTestCase):
def test_has_no_extra_cache_keys(self):
query = "SELECT 'abc' as user"
- table = SqlaTable(sql=query, database=get_main_database())
+ table = SqlaTable(sql=query, database=get_example_database())
query_obj = {
"granularity": None,
"from_dttm": None,
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 1774c26..cad88f3 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -17,18 +17,20 @@
"""Unit tests for Sql Lab"""
from datetime import datetime, timedelta
import json
-import unittest
-from flask_appbuilder.security.sqla import models as ab_models
import prison
from superset import db, security_manager
from superset.dataframe import SupersetDataFrame
from superset.db_engine_specs import BaseEngineSpec
from superset.models.sql_lab import Query
-from superset.utils.core import datetime_to_epoch, get_main_database
+from superset.utils.core import datetime_to_epoch, get_example_database
from .base_tests import SupersetTestCase
+QUERY_1 = "SELECT * FROM birth_names LIMIT 1"
+QUERY_2 = "SELECT * FROM NO_TABLE"
+QUERY_3 = "SELECT * FROM birth_names LIMIT 10"
+
class SqlLabTests(SupersetTestCase):
"""Testings for Sql Lab"""
@@ -39,17 +41,9 @@ class SqlLabTests(SupersetTestCase):
def run_some_queries(self):
db.session.query(Query).delete()
db.session.commit()
- self.run_sql(
- "SELECT * FROM ab_user", client_id="client_id_1", user_name="admin"
- )
- self.run_sql(
- "SELECT * FROM NO_TABLE", client_id="client_id_3", user_name="admin"
- )
- self.run_sql(
- "SELECT * FROM ab_permission",
- client_id="client_id_2",
- user_name="gamma_sqllab",
- )
+ self.run_sql(QUERY_1, client_id="client_id_1", user_name="admin")
+ self.run_sql(QUERY_2, client_id="client_id_3", user_name="admin")
+ self.run_sql(QUERY_3, client_id="client_id_2", user_name="gamma_sqllab")
self.logout()
def tearDown(self):
@@ -61,7 +55,7 @@ class SqlLabTests(SupersetTestCase):
def test_sql_json(self):
self.login("admin")
- data = self.run_sql("SELECT * FROM ab_user", "1")
+ data = self.run_sql("SELECT * FROM birth_names LIMIT 10", "1")
self.assertLess(0, len(data["data"]))
data = self.run_sql("SELECT * FROM unexistant_table", "2")
@@ -71,8 +65,8 @@ class SqlLabTests(SupersetTestCase):
self.login("admin")
multi_sql = """
- SELECT first_name FROM ab_user;
- SELECT first_name FROM ab_user;
+ SELECT * FROM birth_names LIMIT 1;
+ SELECT * FROM birth_names LIMIT 2;
"""
data = self.run_sql(multi_sql, "2234")
self.assertLess(0, len(data["data"]))
@@ -80,24 +74,18 @@ class SqlLabTests(SupersetTestCase):
def test_explain(self):
self.login("admin")
- data = self.run_sql("EXPLAIN SELECT * FROM ab_user", "1")
+ data = self.run_sql("EXPLAIN SELECT * FROM birth_names", "1")
self.assertLess(0, len(data["data"]))
def test_sql_json_has_access(self):
- main_db = get_main_database()
- security_manager.add_permission_view_menu("database_access", main_db.perm)
- db.session.commit()
- main_db_permission_view = (
- db.session.query(ab_models.PermissionView)
- .join(ab_models.ViewMenu)
- .join(ab_models.Permission)
- .filter(ab_models.ViewMenu.name == "[main].(id:1)")
- .filter(ab_models.Permission.name == "database_access")
- .first()
+ examples_db = get_example_database()
+ examples_db_permission_view = security_manager.add_permission_view_menu(
+ "database_access", examples_db.perm
)
+
astronaut = security_manager.add_role("Astronaut")
- security_manager.add_permission_role(astronaut, main_db_permission_view)
- # Astronaut role is Gamma + sqllab + main db permissions
+ security_manager.add_permission_role(astronaut, examples_db_permission_view)
+ # Astronaut role is Gamma + sqllab + db permissions
for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(astronaut, perm)
for perm in security_manager.find_role("sql_lab").permissions:
@@ -113,7 +101,7 @@ class SqlLabTests(SupersetTestCase):
astronaut,
password="general",
)
- data = self.run_sql("SELECT * FROM ab_user", "3", user_name="gagarin")
+ data = self.run_sql(QUERY_1, "3", user_name="gagarin")
db.session.query(Query).delete()
db.session.commit()
self.assertLess(0, len(data["data"]))
@@ -132,8 +120,8 @@ class SqlLabTests(SupersetTestCase):
self.assertEquals(2, len(data))
# Run 2 more queries
- self.run_sql("SELECT * FROM ab_user LIMIT 1", client_id="client_id_4")
- self.run_sql("SELECT * FROM ab_user LIMIT 2", client_id="client_id_5")
+ self.run_sql("SELECT * FROM birth_names LIMIT 1", client_id="client_id_4")
+ self.run_sql("SELECT * FROM birth_names LIMIT 2", client_id="client_id_5")
self.login("admin")
data = self.get_json_resp("/superset/queries/0")
self.assertEquals(4, len(data))
@@ -141,7 +129,7 @@ class SqlLabTests(SupersetTestCase):
now = datetime.now() + timedelta(days=1)
query = (
db.session.query(Query)
- .filter_by(sql="SELECT * FROM ab_user LIMIT 1")
+ .filter_by(sql="SELECT * FROM birth_names LIMIT 1")
.first()
)
query.changed_on = now
@@ -160,11 +148,15 @@ class SqlLabTests(SupersetTestCase):
def test_search_query_on_db_id(self):
self.run_some_queries()
self.login("admin")
+ examples_dbid = get_example_database().id
+
# Test search queries on database Id
- data = self.get_json_resp("/superset/search_queries?database_id=1")
+ data = self.get_json_resp(
+ f"/superset/search_queries?database_id={examples_dbid}"
+ )
self.assertEquals(3, len(data))
db_ids = [k["dbId"] for k in data]
- self.assertEquals([1, 1, 1], db_ids)
+ self.assertEquals([examples_dbid for i in range(3)], db_ids)
resp = self.get_resp("/superset/search_queries?database_id=-1")
data = json.loads(resp)
@@ -205,19 +197,19 @@ class SqlLabTests(SupersetTestCase):
def test_search_query_on_text(self):
self.run_some_queries()
self.login("admin")
- url = "/superset/search_queries?search_text=permission"
+ url = "/superset/search_queries?search_text=birth"
data = self.get_json_resp(url)
- self.assertEquals(1, len(data))
- self.assertIn("permission", data[0]["sql"])
+ self.assertEquals(2, len(data))
+ self.assertIn("birth", data[0]["sql"])
def test_search_query_on_time(self):
self.run_some_queries()
self.login("admin")
first_query_time = (
- db.session.query(Query).filter_by(sql="SELECT * FROM ab_user").one()
+ db.session.query(Query).filter_by(sql=QUERY_1).one()
).start_time
second_query_time = (
- db.session.query(Query).filter_by(sql="SELECT * FROM ab_permission").one()
+ db.session.query(Query).filter_by(sql=QUERY_3).one()
).start_time
# Test search queries on time filter
from_time = "from={}".format(int(first_query_time))
@@ -265,7 +257,7 @@ class SqlLabTests(SupersetTestCase):
def test_alias_duplicate(self):
self.run_sql(
- "SELECT username as col, id as col, username FROM ab_user",
+ "SELECT name as col, gender as col FROM birth_names LIMIT 10",
client_id="2e2df3",
user_name="admin",
raise_on_error=True,
@@ -281,7 +273,7 @@ class SqlLabTests(SupersetTestCase):
def test_df_conversion_tuple(self):
cols = ["string_col", "int_col", "list_col", "float_col"]
- data = [(u"Text", 111, [123], 1.0)]
+ data = [("Text", 111, [123], 1.0)]
cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
self.assertEquals(len(data), cdf.size)
@@ -296,6 +288,7 @@ class SqlLabTests(SupersetTestCase):
self.assertEquals(len(cols), len(cdf.columns))
def test_sqllab_viz(self):
+ examples_dbid = get_example_database().id
payload = {
"chartType": "dist_bar",
"datasourceName": "test_viz_flow_table",
@@ -316,11 +309,10 @@ class SqlLabTests(SupersetTestCase):
},
],
"sql": """\
- SELECT viz_type, count(1) as ccount
- FROM slices
- WHERE viz_type LIKE '%a%'
- GROUP BY viz_type""",
- "dbId": 1,
+ SELECT *
+ FROM birth_names
+ LIMIT 10""",
+ "dbId": examples_dbid,
}
data = {"data": json.dumps(payload)}
resp = self.get_json_resp("/superset/sqllab_viz/", data=data)
@@ -329,20 +321,20 @@ class SqlLabTests(SupersetTestCase):
def test_sql_limit(self):
self.login("admin")
test_limit = 1
- data = self.run_sql("SELECT * FROM ab_user", client_id="sql_limit_1")
+ data = self.run_sql("SELECT * FROM birth_names", client_id="sql_limit_1")
self.assertGreater(len(data["data"]), test_limit)
data = self.run_sql(
- "SELECT * FROM ab_user", client_id="sql_limit_2", query_limit=test_limit
+ "SELECT * FROM birth_names", client_id="sql_limit_2", query_limit=test_limit
)
self.assertEquals(len(data["data"]), test_limit)
data = self.run_sql(
- "SELECT * FROM ab_user LIMIT {}".format(test_limit),
+ "SELECT * FROM birth_names LIMIT {}".format(test_limit),
client_id="sql_limit_3",
query_limit=test_limit + 1,
)
self.assertEquals(len(data["data"]), test_limit)
data = self.run_sql(
- "SELECT * FROM ab_user LIMIT {}".format(test_limit + 1),
+ "SELECT * FROM birth_names LIMIT {}".format(test_limit + 1),
client_id="sql_limit_4",
query_limit=test_limit,
)
@@ -406,6 +398,7 @@ class SqlLabTests(SupersetTestCase):
def test_api_database(self):
self.login("admin")
+ self.create_fake_db()
arguments = {
"keys": [],
@@ -415,12 +408,8 @@ class SqlLabTests(SupersetTestCase):
"page": 0,
"page_size": -1,
}
- expected_results = ["examples", "fake_db_100", "main"]
url = "api/v1/database/?{}={}".format("q", prison.dumps(arguments))
- data = self.get_json_resp(url)
- for i, expected_result in enumerate(expected_results):
- self.assertEquals(expected_result, data["result"][i]["database_name"])
-
-
-if __name__ == "__main__":
- unittest.main()
+ self.assertEquals(
+ {"examples", "fake_db_100"},
+ {r.get("database_name") for r in self.get_json_resp(url)["result"]},
+ )