This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 68c4c3a  Prevent 'main' database connection creation (#8038)
68c4c3a is described below

commit 68c4c3a0b9af083304713132534d607c9830e858
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Sun Sep 8 10:18:09 2019 -0700

    Prevent 'main' database connection creation (#8038)
    
    * prevent 'main' database connection creation
    
    * fix tests
    
    * removing get_main_database
    
    * Kill get_main_database
    
    * Point to examples tables
---
 superset/cli.py                   | 102 ++++++++++++-----------------------
 superset/security.py              |   7 ++-
 superset/utils/core.py            |  10 ----
 superset/views/core.py            |   6 +--
 tests/base_tests.py               |  41 ++++++++++++--
 tests/celery_tests.py             |  39 +++++++-------
 tests/core_tests.py               |  68 ++++++++----------------
 tests/dict_import_export_tests.py |   4 +-
 tests/model_tests.py              |   6 +--
 tests/sql_parse_tests.py          |  22 ++++----
 tests/sql_validator_tests.py      |   6 +--
 tests/sqla_models_tests.py        |   6 +--
 tests/sqllab_tests.py             | 109 +++++++++++++++++---------------------
 13 files changed, 186 insertions(+), 240 deletions(-)

diff --git a/superset/cli.py b/superset/cli.py
index 5721220..b30b654 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -48,7 +48,6 @@ def make_shell_context():
 @app.cli.command()
 def init():
     """Inits the Superset application"""
-    utils.get_or_create_main_db()
     utils.get_example_database()
     appbuilder.add_permissions(update_perms=True)
     security_manager.sync_role_definitions()
@@ -430,75 +429,40 @@ def load_test_users_run():
     Syncs permissions for those users/roles
     """
     if config.get("TESTING"):
-        security_manager.sync_role_definitions()
-        gamma_sqllab_role = security_manager.add_role("gamma_sqllab")
-        for perm in security_manager.find_role("Gamma").permissions:
-            security_manager.add_permission_role(gamma_sqllab_role, perm)
-        utils.get_or_create_main_db()
-        db_perm = utils.get_main_database().perm
-        security_manager.add_permission_view_menu("database_access", db_perm)
-        db_pvm = security_manager.find_permission_view_menu(
-            view_menu_name=db_perm, permission_name="database_access"
+
+        sm = security_manager
+
+        examples_db = utils.get_example_database()
+
+        examples_pv = sm.add_permission_view_menu("database_access", 
examples_db.perm)
+
+        sm.sync_role_definitions()
+        gamma_sqllab_role = sm.add_role("gamma_sqllab")
+        sm.add_permission_role(gamma_sqllab_role, examples_pv)
+
+        for role in ["Gamma", "sql_lab"]:
+            for perm in sm.find_role(role).permissions:
+                sm.add_permission_role(gamma_sqllab_role, perm)
+
+        users = (
+            ("admin", "Admin"),
+            ("gamma", "Gamma"),
+            ("gamma2", "Gamma"),
+            ("gamma_sqllab", "gamma_sqllab"),
+            ("alpha", "Alpha"),
         )
-        gamma_sqllab_role.permissions.append(db_pvm)
-        for perm in security_manager.find_role("sql_lab").permissions:
-            security_manager.add_permission_role(gamma_sqllab_role, perm)
-
-        admin = security_manager.find_user("admin")
-        if not admin:
-            security_manager.add_user(
-                "admin",
-                "admin",
-                " user",
-                "[email protected]",
-                security_manager.find_role("Admin"),
-                password="general",
-            )
-
-        gamma = security_manager.find_user("gamma")
-        if not gamma:
-            security_manager.add_user(
-                "gamma",
-                "gamma",
-                "user",
-                "[email protected]",
-                security_manager.find_role("Gamma"),
-                password="general",
-            )
-
-        gamma2 = security_manager.find_user("gamma2")
-        if not gamma2:
-            security_manager.add_user(
-                "gamma2",
-                "gamma2",
-                "user",
-                "[email protected]",
-                security_manager.find_role("Gamma"),
-                password="general",
-            )
-
-        gamma_sqllab_user = security_manager.find_user("gamma_sqllab")
-        if not gamma_sqllab_user:
-            security_manager.add_user(
-                "gamma_sqllab",
-                "gamma_sqllab",
-                "user",
-                "[email protected]",
-                gamma_sqllab_role,
-                password="general",
-            )
-
-        alpha = security_manager.find_user("alpha")
-        if not alpha:
-            security_manager.add_user(
-                "alpha",
-                "alpha",
-                "user",
-                "[email protected]",
-                security_manager.find_role("Alpha"),
-                password="general",
-            )
-        security_manager.get_session.commit()
+        for username, role in users:
+            user = sm.find_user(username)
+            if not user:
+                sm.add_user(
+                    username,
+                    username,
+                    "user",
+                    username + "@fab.org",
+                    sm.find_role(role),
+                    password="general",
+                )
+        sm.get_session.commit()
 
 
 @app.cli.command()
diff --git a/superset/security.py b/superset/security.py
index bea1571..918d7d1 100644
--- a/superset/security.py
+++ b/superset/security.py
@@ -200,7 +200,6 @@ class SupersetSecurityManager(SecurityManager):
         :param database: The Superset database
         :returns: Whether the user can access the Superset database
         """
-
         return (
             self.all_datasource_access()
             or self.all_database_access()
@@ -269,9 +268,9 @@ class SupersetSecurityManager(SecurityManager):
         :param tables: The list of denied SQL table names
         :returns: The error message
         """
-
-        return f"""You need access to the following tables: {", 
".join(tables)}, all
-            database access or `all_datasource_access` permission"""
+        quoted_tables = [f"`{t}`" for t in tables]
+        return f"""You need access to the following tables: {", 
".join(quoted_tables)},
+            `all_database_access` or `all_datasource_access` permission"""
 
     def get_table_access_link(self, tables: List[str]) -> Optional[str]:
         """
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 8e70455..118d65b 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -936,10 +936,6 @@ def user_label(user: User) -> Optional[str]:
     return None
 
 
-def get_or_create_main_db():
-    get_main_database()
-
-
 def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
     from superset import db
     from superset.models import core as models
@@ -957,12 +953,6 @@ def get_or_create_db(database_name, sqlalchemy_uri, *args, 
**kwargs):
     return database
 
 
-def get_main_database():
-    from superset import conf
-
-    return get_or_create_db("main", conf.get("SQLALCHEMY_DATABASE_URI"))
-
-
 def get_example_database():
     from superset import conf
 
diff --git a/superset/views/core.py b/superset/views/core.py
index af4fe7e..87e9400 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -2705,11 +2705,7 @@ class Superset(BaseSupersetView):
             query.sql, query.database, query.schema
         )
         if rejected_tables:
-            flash(
-                security_manager.get_table_access_error_msg(
-                    "{}".format(rejected_tables)
-                )
-            )
+            flash(security_manager.get_table_access_error_msg(rejected_tables))
             return redirect("/")
         blob = None
         if results_backend and query.results_key:
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 8ac5bcd..8fc96cc 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -28,7 +28,7 @@ from superset.connectors.druid.models import DruidCluster, 
DruidDatasource
 from superset.connectors.sqla.models import SqlaTable
 from superset.models import core as models
 from superset.models.core import Database
-from superset.utils.core import get_main_database
+from superset.utils.core import get_example_database
 
 BASE_DIR = app.config.get("BASE_DIR")
 
@@ -168,6 +168,12 @@ class SupersetTestCase(unittest.TestCase):
             ):
                 security_manager.del_permission_role(public_role, perm)
 
+    def _get_database_by_name(self, database_name="main"):
+        if database_name == "examples":
+            return get_example_database()
+        else:
+            raise ValueError("Database doesn't exist")
+
     def run_sql(
         self,
         sql,
@@ -175,11 +181,12 @@ class SupersetTestCase(unittest.TestCase):
         user_name=None,
         raise_on_error=False,
         query_limit=None,
+        database_name="examples",
     ):
         if user_name:
             self.logout()
-            self.login(username=(user_name if user_name else "admin"))
-        dbid = get_main_database().id
+            self.login(username=(user_name or "admin"))
+        dbid = self._get_database_by_name(database_name).id
         resp = self.get_json_resp(
             "/superset/sql_json/",
             raise_on_error=False,
@@ -195,11 +202,35 @@ class SupersetTestCase(unittest.TestCase):
             raise Exception("run_sql failed")
         return resp
 
-    def validate_sql(self, sql, client_id=None, user_name=None, 
raise_on_error=False):
+    def create_fake_db(self):
+        self.login(username="admin")
+        database_name = "fake_db_100"
+        db_id = 100
+        extra = """{
+            "schemas_allowed_for_csv_upload":
+            ["this_schema_is_allowed", "this_schema_is_allowed_too"]
+        }"""
+
+        return self.get_or_create(
+            cls=models.Database,
+            criteria={"database_name": database_name},
+            session=db.session,
+            id=db_id,
+            extra=extra,
+        )
+
+    def validate_sql(
+        self,
+        sql,
+        client_id=None,
+        user_name=None,
+        raise_on_error=False,
+        database_name="examples",
+    ):
         if user_name:
             self.logout()
             self.login(username=(user_name if user_name else "admin"))
-        dbid = get_main_database().id
+        dbid = self._get_database_by_name(database_name).id
         resp = self.get_json_resp(
             "/superset/validate_sql_json/",
             raise_on_error=False,
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 7935448..d77529a 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -28,7 +28,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
 from superset.models.helpers import QueryStatus
 from superset.models.sql_lab import Query
 from superset.sql_parse import ParsedQuery
-from superset.utils.core import get_main_database
+from superset.utils.core import get_example_database
 from .base_tests import SupersetTestCase
 
 
@@ -132,20 +132,20 @@ class CeleryTestCase(SupersetTestCase):
         return json.loads(resp.data)
 
     def test_run_sync_query_dont_exist(self):
-        main_db = get_main_database()
+        main_db = get_example_database()
         db_id = main_db.id
         sql_dont_exist = "SELECT name FROM table_dont_exist"
         result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true")
         self.assertTrue("error" in result1)
 
     def test_run_sync_query_cta(self):
-        main_db = get_main_database()
+        main_db = get_example_database()
         backend = main_db.backend
         db_id = main_db.id
         tmp_table_name = "tmp_async_22"
         self.drop_table_if_exists(tmp_table_name, main_db)
-        perm_name = "can_sql_json"
-        sql_where = "SELECT name FROM ab_permission WHERE 
name='{}'".format(perm_name)
+        name = "James"
+        sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
         result = self.run_sql(
             db_id, sql_where, "2", tmp_table=tmp_table_name, cta="true"
         )
@@ -162,9 +162,9 @@ class CeleryTestCase(SupersetTestCase):
             self.assertGreater(len(results["data"]), 0)
 
     def test_run_sync_query_cta_no_data(self):
-        main_db = get_main_database()
+        main_db = get_example_database()
         db_id = main_db.id
-        sql_empty_result = "SELECT * FROM ab_user WHERE id=666"
+        sql_empty_result = "SELECT * FROM birth_names WHERE name='random'"
         result3 = self.run_sql(db_id, sql_empty_result, "3")
         self.assertEqual(QueryStatus.SUCCESS, result3["query"]["state"])
         self.assertEqual([], result3["data"])
@@ -183,12 +183,12 @@ class CeleryTestCase(SupersetTestCase):
         return self.run_sql(db_id, sql)
 
     def test_run_async_query(self):
-        main_db = get_main_database()
+        main_db = get_example_database()
         db_id = main_db.id
 
         self.drop_table_if_exists("tmp_async_1", main_db)
 
-        sql_where = "SELECT name FROM ab_role WHERE name='Admin'"
+        sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
         result = self.run_sql(
             db_id, sql_where, "4", async_="true", tmp_table="tmp_async_1", 
cta="true"
         )
@@ -202,12 +202,13 @@ class CeleryTestCase(SupersetTestCase):
 
         query = self.get_query_by_id(result["query"]["serverId"])
         self.assertEqual(QueryStatus.SUCCESS, query.status)
+
         self.assertTrue("FROM tmp_async_1" in query.select_sql)
         self.assertEqual(
             "CREATE TABLE tmp_async_1 AS \n"
-            "SELECT name FROM ab_role "
-            "WHERE name='Admin'\n"
-            "LIMIT 666",
+            "SELECT name FROM birth_names "
+            "WHERE name='James' "
+            "LIMIT 10",
             query.executed_sql,
         )
         self.assertEqual(sql_where, query.sql)
@@ -216,13 +217,14 @@ class CeleryTestCase(SupersetTestCase):
         self.assertEqual(True, query.select_as_cta_used)
 
     def test_run_async_query_with_lower_limit(self):
-        main_db = get_main_database()
+        main_db = get_example_database()
         db_id = main_db.id
-        self.drop_table_if_exists("tmp_async_2", main_db)
+        tmp_table = "tmp_async_2"
+        self.drop_table_if_exists(tmp_table, main_db)
 
-        sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1"
+        sql_where = "SELECT name FROM birth_names LIMIT 1"
         result = self.run_sql(
-            db_id, sql_where, "5", async_="true", tmp_table="tmp_async_2", 
cta="true"
+            db_id, sql_where, "5", async_="true", tmp_table=tmp_table, 
cta="true"
         )
         assert result["query"]["state"] in (
             QueryStatus.PENDING,
@@ -234,10 +236,9 @@ class CeleryTestCase(SupersetTestCase):
 
         query = self.get_query_by_id(result["query"]["serverId"])
         self.assertEqual(QueryStatus.SUCCESS, query.status)
-        self.assertTrue("FROM tmp_async_2" in query.select_sql)
+        self.assertTrue(f"FROM {tmp_table}" in query.select_sql)
         self.assertEqual(
-            "CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role "
-            "WHERE name='Alpha' LIMIT 1",
+            f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names 
LIMIT 1",
             query.executed_sql,
         )
         self.assertEqual(sql_where, query.sql)
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 5db0b77..3c2e8ea 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -346,13 +346,13 @@ class CoreTests(SupersetTestCase):
 
     def test_testconn(self, username="admin"):
         self.login(username=username)
-        database = utils.get_main_database()
+        database = utils.get_example_database()
 
         # validate that the endpoint works with the password-masked sqlalchemy 
uri
         data = json.dumps(
             {
                 "uri": database.safe_sqlalchemy_uri(),
-                "name": "main",
+                "name": "examples",
                 "impersonate_user": False,
             }
         )
@@ -366,7 +366,7 @@ class CoreTests(SupersetTestCase):
         data = json.dumps(
             {
                 "uri": database.sqlalchemy_uri_decrypted,
-                "name": "main",
+                "name": "examples",
                 "impersonate_user": False,
             }
         )
@@ -377,7 +377,7 @@ class CoreTests(SupersetTestCase):
         assert response.headers["Content-Type"] == "application/json"
 
     def test_custom_password_store(self):
-        database = utils.get_main_database()
+        database = utils.get_example_database()
         conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted)
 
         def custom_password_store(uri):
@@ -395,13 +395,13 @@ class CoreTests(SupersetTestCase):
         # validate that sending a password-masked uri does not over-write the 
decrypted
         # uri
         self.login(username=username)
-        database = utils.get_main_database()
+        database = utils.get_example_database()
         sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted
         url = "databaseview/edit/{}".format(database.id)
         data = {k: database.__getattribute__(k) for k in 
DatabaseView.add_columns}
         data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
         self.client.post(url, data=data)
-        database = utils.get_main_database()
+        database = utils.get_example_database()
         self.assertEqual(sqlalchemy_uri_decrypted, 
database.sqlalchemy_uri_decrypted)
 
     def test_warm_up_cache(self):
@@ -460,51 +460,51 @@ class CoreTests(SupersetTestCase):
     def test_csv_endpoint(self):
         self.login("admin")
         sql = """
-            SELECT first_name, last_name
-            FROM ab_user
-            WHERE first_name='admin'
+            SELECT name
+            FROM birth_names
+            WHERE name = 'James'
+            LIMIT 1
         """
         client_id = "{}".format(random.getrandbits(64))[:10]
         self.run_sql(sql, client_id, raise_on_error=True)
 
         resp = self.get_resp("/superset/csv/{}".format(client_id))
         data = csv.reader(io.StringIO(resp))
-        expected_data = csv.reader(io.StringIO("first_name,last_name\nadmin, 
user\n"))
+        expected_data = csv.reader(io.StringIO("name\nJames\n"))
 
-        sql = "SELECT first_name FROM ab_user WHERE first_name LIKE '%admin%'"
         client_id = "{}".format(random.getrandbits(64))[:10]
         self.run_sql(sql, client_id, raise_on_error=True)
 
         resp = self.get_resp("/superset/csv/{}".format(client_id))
         data = csv.reader(io.StringIO(resp))
-        expected_data = csv.reader(io.StringIO("first_name\nadmin\n"))
+        expected_data = csv.reader(io.StringIO("name\nJames\n"))
 
         self.assertEqual(list(expected_data), list(data))
         self.logout()
 
     def test_extra_table_metadata(self):
         self.login("admin")
-        dbid = utils.get_main_database().id
+        dbid = utils.get_example_database().id
         self.get_json_resp(
-            f"/superset/extra_table_metadata/{dbid}/" 
"ab_permission_view/panoramix/"
+            f"/superset/extra_table_metadata/{dbid}/birth_names/superset/"
         )
 
     def test_process_template(self):
-        maindb = utils.get_main_database()
+        maindb = utils.get_example_database()
         sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
         tp = jinja_context.get_template_processor(database=maindb)
         rendered = tp.process_template(sql)
         self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)
 
     def test_get_template_kwarg(self):
-        maindb = utils.get_main_database()
+        maindb = utils.get_example_database()
         s = "{{ foo }}"
         tp = jinja_context.get_template_processor(database=maindb, foo="bar")
         rendered = tp.process_template(s)
         self.assertEqual("bar", rendered)
 
     def test_template_kwarg(self):
-        maindb = utils.get_main_database()
+        maindb = utils.get_example_database()
         s = "{{ foo }}"
         tp = jinja_context.get_template_processor(database=maindb)
         rendered = tp.process_template(s, foo="bar")
@@ -517,23 +517,12 @@ class CoreTests(SupersetTestCase):
         self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00")
 
     def test_table_metadata(self):
-        maindb = utils.get_main_database()
-        backend = maindb.backend
-        data = 
self.get_json_resp("/superset/table/{}/ab_user/null/".format(maindb.id))
-        self.assertEqual(data["name"], "ab_user")
+        maindb = utils.get_example_database()
+        data = 
self.get_json_resp(f"/superset/table/{maindb.id}/birth_names/null/")
+        self.assertEqual(data["name"], "birth_names")
         assert len(data["columns"]) > 5
         assert data.get("selectStar").startswith("SELECT")
 
-        # Engine specific tests
-        if backend in ("mysql", "postgresql"):
-            self.assertEqual(data.get("primaryKey").get("type"), "pk")
-            self.assertEqual(data.get("primaryKey").get("column_names")[0], 
"id")
-            self.assertEqual(len(data.get("foreignKeys")), 2)
-            if backend == "mysql":
-                self.assertEqual(len(data.get("indexes")), 7)
-            elif backend == "postgresql":
-                self.assertEqual(len(data.get("indexes")), 5)
-
     def test_fetch_datasource_metadata(self):
         self.login(username="admin")
         url = "/superset/fetch_datasource_metadata?" "datasourceKey=1__table"
@@ -746,24 +735,11 @@ class CoreTests(SupersetTestCase):
     def test_schemas_access_for_csv_upload_endpoint(
         self, mock_all_datasource_access, mock_database_access, 
mock_schemas_accessible
     ):
+        self.login(username="admin")
+        dbobj = self.create_fake_db()
         mock_all_datasource_access.return_value = False
         mock_database_access.return_value = False
         mock_schemas_accessible.return_value = ["this_schema_is_allowed_too"]
-        database_name = "fake_db_100"
-        db_id = 100
-        extra = """{
-            "schemas_allowed_for_csv_upload":
-            ["this_schema_is_allowed", "this_schema_is_allowed_too"]
-        }"""
-
-        self.login(username="admin")
-        dbobj = self.get_or_create(
-            cls=models.Database,
-            criteria={"database_name": database_name},
-            session=db.session,
-            id=db_id,
-            extra=extra,
-        )
         data = self.get_json_resp(
             url="/superset/schemas_access_for_csv_upload?db_id={db_id}".format(
                 db_id=dbobj.id
diff --git a/tests/dict_import_export_tests.py 
b/tests/dict_import_export_tests.py
index fafa024..26080def 100644
--- a/tests/dict_import_export_tests.py
+++ b/tests/dict_import_export_tests.py
@@ -23,7 +23,7 @@ import yaml
 from superset import db
 from superset.connectors.druid.models import DruidColumn, DruidDatasource, 
DruidMetric
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
-from superset.utils.core import get_main_database
+from superset.utils.core import get_example_database
 from .base_tests import SupersetTestCase
 
 DBREF = "dict_import__export_test"
@@ -63,7 +63,7 @@ class DictImportExportTests(SupersetTestCase):
         params = {DBREF: id, "database_name": database_name}
 
         dict_rep = {
-            "database_id": get_main_database().id,
+            "database_id": get_example_database().id,
             "table_name": name,
             "schema": schema,
             "id": id,
diff --git a/tests/model_tests.py b/tests/model_tests.py
index f65db84..5033981 100644
--- a/tests/model_tests.py
+++ b/tests/model_tests.py
@@ -22,7 +22,7 @@ from sqlalchemy.engine.url import make_url
 
 from superset import app
 from superset.models.core import Database
-from superset.utils.core import get_example_database, get_main_database, 
QueryStatus
+from superset.utils.core import get_example_database, QueryStatus
 from .base_tests import SupersetTestCase
 
 
@@ -149,7 +149,7 @@ class DatabaseModelTestCase(SupersetTestCase):
             assert sql.startswith(expected)
 
     def test_single_statement(self):
-        main_db = get_main_database()
+        main_db = get_example_database()
 
         if main_db.backend == "mysql":
             df = main_db.get_df("SELECT 1", None)
@@ -159,7 +159,7 @@ class DatabaseModelTestCase(SupersetTestCase):
             self.assertEquals(df.iat[0, 0], 1)
 
     def test_multi_statement(self):
-        main_db = get_main_database()
+        main_db = get_example_database()
 
         if main_db.backend == "mysql":
             df = main_db.get_df("USE superset; SELECT 1", None)
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index a7f925c..7c6d420 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -449,41 +449,41 @@ class SupersetTestCase(unittest.TestCase):
         self.assertEquals({"SalesOrderHeader"}, self.extract_tables(query))
 
     def test_get_query_with_new_limit_comment(self):
-        sql = "SELECT * FROM ab_user -- SOME COMMENT"
+        sql = "SELECT * FROM birth_names -- SOME COMMENT"
         parsed = sql_parse.ParsedQuery(sql)
         newsql = parsed.get_query_with_new_limit(1000)
         self.assertEquals(newsql, sql + "\nLIMIT 1000")
 
     def test_get_query_with_new_limit_comment_with_limit(self):
-        sql = "SELECT * FROM ab_user -- SOME COMMENT WITH LIMIT 555"
+        sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
         parsed = sql_parse.ParsedQuery(sql)
         newsql = parsed.get_query_with_new_limit(1000)
         self.assertEquals(newsql, sql + "\nLIMIT 1000")
 
     def test_get_query_with_new_limit(self):
-        sql = "SELECT * FROM ab_user LIMIT 555"
+        sql = "SELECT * FROM birth_names LIMIT 555"
         parsed = sql_parse.ParsedQuery(sql)
         newsql = parsed.get_query_with_new_limit(1000)
-        expected = "SELECT * FROM ab_user LIMIT 1000"
+        expected = "SELECT * FROM birth_names LIMIT 1000"
         self.assertEquals(newsql, expected)
 
     def test_basic_breakdown_statements(self):
         multi_sql = """
-        SELECT * FROM ab_user;
-        SELECT * FROM ab_user LIMIT 1;
+        SELECT * FROM birth_names;
+        SELECT * FROM birth_names LIMIT 1;
         """
         parsed = sql_parse.ParsedQuery(multi_sql)
         statements = parsed.get_statements()
         self.assertEquals(len(statements), 2)
-        expected = ["SELECT * FROM ab_user", "SELECT * FROM ab_user LIMIT 1"]
+        expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names 
LIMIT 1"]
         self.assertEquals(statements, expected)
 
     def test_messy_breakdown_statements(self):
         multi_sql = """
         SELECT 1;\t\n\n\n  \t
         \t\nSELECT 2;
-        SELECT * FROM ab_user;;;
-        SELECT * FROM ab_user LIMIT 1
+        SELECT * FROM birth_names;;;
+        SELECT * FROM birth_names LIMIT 1
         """
         parsed = sql_parse.ParsedQuery(multi_sql)
         statements = parsed.get_statements()
@@ -491,8 +491,8 @@ class SupersetTestCase(unittest.TestCase):
         expected = [
             "SELECT 1",
             "SELECT 2",
-            "SELECT * FROM ab_user",
-            "SELECT * FROM ab_user LIMIT 1",
+            "SELECT * FROM birth_names",
+            "SELECT * FROM birth_names LIMIT 1",
         ]
         self.assertEquals(statements, expected)
 
diff --git a/tests/sql_validator_tests.py b/tests/sql_validator_tests.py
index 3db732a..3313f7e 100644
--- a/tests/sql_validator_tests.py
+++ b/tests/sql_validator_tests.py
@@ -53,7 +53,7 @@ class SqlValidatorEndpointTests(SupersetTestCase):
         app.config["SQL_VALIDATORS_BY_ENGINE"] = {}
 
         resp = self.validate_sql(
-            "SELECT * FROM ab_user", client_id="1", raise_on_error=False
+            "SELECT * FROM birth_names", client_id="1", raise_on_error=False
         )
         self.assertIn("error", resp)
         self.assertIn("no SQL validator is configured", resp["error"])
@@ -97,7 +97,7 @@ class SqlValidatorEndpointTests(SupersetTestCase):
         validator.validate.side_effect = Exception("Kaboom!")
 
         resp = self.validate_sql(
-            "SELECT * FROM ab_user", client_id="1", raise_on_error=False
+            "SELECT * FROM birth_names", client_id="1", raise_on_error=False
         )
         self.assertIn("error", resp)
         self.assertIn("Kaboom!", resp["error"])
@@ -186,7 +186,7 @@ class PrestoValidatorTests(SupersetTestCase):
         #    validator for sqlite, this test will fail because the validator
         #    will no longer error out.
         resp = self.validate_sql(
-            "SELECT * FROM ab_user", client_id="1", raise_on_error=False
+            "SELECT * FROM birth_names", client_id="1", raise_on_error=False
         )
         self.assertIn("error", resp)
         self.assertIn("no SQL validator is configured", resp["error"])
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index a46f388..326311a 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -16,7 +16,7 @@
 # under the License.
 from superset.connectors.sqla.models import SqlaTable, TableColumn
 from superset.db_engine_specs.druid import DruidEngineSpec
-from superset.utils.core import get_main_database
+from superset.utils.core import get_example_database
 from .base_tests import SupersetTestCase
 
 
@@ -43,7 +43,7 @@ class DatabaseModelTestCase(SupersetTestCase):
 
     def test_has_extra_cache_keys(self):
         query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
-        table = SqlaTable(sql=query, database=get_main_database())
+        table = SqlaTable(sql=query, database=get_example_database())
         query_obj = {
             "granularity": None,
             "from_dttm": None,
@@ -60,7 +60,7 @@ class DatabaseModelTestCase(SupersetTestCase):
 
     def test_has_no_extra_cache_keys(self):
         query = "SELECT 'abc' as user"
-        table = SqlaTable(sql=query, database=get_main_database())
+        table = SqlaTable(sql=query, database=get_example_database())
         query_obj = {
             "granularity": None,
             "from_dttm": None,
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 1774c26..cad88f3 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -17,18 +17,20 @@
 """Unit tests for Sql Lab"""
 from datetime import datetime, timedelta
 import json
-import unittest
 
-from flask_appbuilder.security.sqla import models as ab_models
 import prison
 
 from superset import db, security_manager
 from superset.dataframe import SupersetDataFrame
 from superset.db_engine_specs import BaseEngineSpec
 from superset.models.sql_lab import Query
-from superset.utils.core import datetime_to_epoch, get_main_database
+from superset.utils.core import datetime_to_epoch, get_example_database
 from .base_tests import SupersetTestCase
 
+QUERY_1 = "SELECT * FROM birth_names LIMIT 1"
+QUERY_2 = "SELECT * FROM NO_TABLE"
+QUERY_3 = "SELECT * FROM birth_names LIMIT 10"
+
 
 class SqlLabTests(SupersetTestCase):
     """Testings for Sql Lab"""
@@ -39,17 +41,9 @@ class SqlLabTests(SupersetTestCase):
     def run_some_queries(self):
         db.session.query(Query).delete()
         db.session.commit()
-        self.run_sql(
-            "SELECT * FROM ab_user", client_id="client_id_1", user_name="admin"
-        )
-        self.run_sql(
-            "SELECT * FROM NO_TABLE", client_id="client_id_3", user_name="admin"
-        )
-        self.run_sql(
-            "SELECT * FROM ab_permission",
-            client_id="client_id_2",
-            user_name="gamma_sqllab",
-        )
+        self.run_sql(QUERY_1, client_id="client_id_1", user_name="admin")
+        self.run_sql(QUERY_2, client_id="client_id_3", user_name="admin")
+        self.run_sql(QUERY_3, client_id="client_id_2", user_name="gamma_sqllab")
         self.logout()
 
     def tearDown(self):
@@ -61,7 +55,7 @@ class SqlLabTests(SupersetTestCase):
     def test_sql_json(self):
         self.login("admin")
 
-        data = self.run_sql("SELECT * FROM ab_user", "1")
+        data = self.run_sql("SELECT * FROM birth_names LIMIT 10", "1")
         self.assertLess(0, len(data["data"]))
 
         data = self.run_sql("SELECT * FROM unexistant_table", "2")
@@ -71,8 +65,8 @@ class SqlLabTests(SupersetTestCase):
         self.login("admin")
 
         multi_sql = """
-        SELECT first_name FROM ab_user;
-        SELECT first_name FROM ab_user;
+        SELECT * FROM birth_names LIMIT 1;
+        SELECT * FROM birth_names LIMIT 2;
         """
         data = self.run_sql(multi_sql, "2234")
         self.assertLess(0, len(data["data"]))
@@ -80,24 +74,18 @@ class SqlLabTests(SupersetTestCase):
     def test_explain(self):
         self.login("admin")
 
-        data = self.run_sql("EXPLAIN SELECT * FROM ab_user", "1")
+        data = self.run_sql("EXPLAIN SELECT * FROM birth_names", "1")
         self.assertLess(0, len(data["data"]))
 
     def test_sql_json_has_access(self):
-        main_db = get_main_database()
-        security_manager.add_permission_view_menu("database_access", main_db.perm)
-        db.session.commit()
-        main_db_permission_view = (
-            db.session.query(ab_models.PermissionView)
-            .join(ab_models.ViewMenu)
-            .join(ab_models.Permission)
-            .filter(ab_models.ViewMenu.name == "[main].(id:1)")
-            .filter(ab_models.Permission.name == "database_access")
-            .first()
+        examples_db = get_example_database()
+        examples_db_permission_view = security_manager.add_permission_view_menu(
+            "database_access", examples_db.perm
         )
+
         astronaut = security_manager.add_role("Astronaut")
-        security_manager.add_permission_role(astronaut, main_db_permission_view)
-        # Astronaut role is Gamma + sqllab +  main db permissions
+        security_manager.add_permission_role(astronaut, examples_db_permission_view)
+        # Astronaut role is Gamma + sqllab + db permissions
         for perm in security_manager.find_role("Gamma").permissions:
             security_manager.add_permission_role(astronaut, perm)
         for perm in security_manager.find_role("sql_lab").permissions:
@@ -113,7 +101,7 @@ class SqlLabTests(SupersetTestCase):
                 astronaut,
                 password="general",
             )
-        data = self.run_sql("SELECT * FROM ab_user", "3", user_name="gagarin")
+        data = self.run_sql(QUERY_1, "3", user_name="gagarin")
         db.session.query(Query).delete()
         db.session.commit()
         self.assertLess(0, len(data["data"]))
@@ -132,8 +120,8 @@ class SqlLabTests(SupersetTestCase):
         self.assertEquals(2, len(data))
 
         # Run 2 more queries
-        self.run_sql("SELECT * FROM ab_user LIMIT 1", client_id="client_id_4")
-        self.run_sql("SELECT * FROM ab_user LIMIT 2", client_id="client_id_5")
+        self.run_sql("SELECT * FROM birth_names LIMIT 1", client_id="client_id_4")
+        self.run_sql("SELECT * FROM birth_names LIMIT 2", client_id="client_id_5")
         self.login("admin")
         data = self.get_json_resp("/superset/queries/0")
         self.assertEquals(4, len(data))
@@ -141,7 +129,7 @@ class SqlLabTests(SupersetTestCase):
         now = datetime.now() + timedelta(days=1)
         query = (
             db.session.query(Query)
-            .filter_by(sql="SELECT * FROM ab_user LIMIT 1")
+            .filter_by(sql="SELECT * FROM birth_names LIMIT 1")
             .first()
         )
         query.changed_on = now
@@ -160,11 +148,15 @@ class SqlLabTests(SupersetTestCase):
     def test_search_query_on_db_id(self):
         self.run_some_queries()
         self.login("admin")
+        examples_dbid = get_example_database().id
+
         # Test search queries on database Id
-        data = self.get_json_resp("/superset/search_queries?database_id=1")
+        data = self.get_json_resp(
+            f"/superset/search_queries?database_id={examples_dbid}"
+        )
         self.assertEquals(3, len(data))
         db_ids = [k["dbId"] for k in data]
-        self.assertEquals([1, 1, 1], db_ids)
+        self.assertEquals([examples_dbid for i in range(3)], db_ids)
 
         resp = self.get_resp("/superset/search_queries?database_id=-1")
         data = json.loads(resp)
@@ -205,19 +197,19 @@ class SqlLabTests(SupersetTestCase):
     def test_search_query_on_text(self):
         self.run_some_queries()
         self.login("admin")
-        url = "/superset/search_queries?search_text=permission"
+        url = "/superset/search_queries?search_text=birth"
         data = self.get_json_resp(url)
-        self.assertEquals(1, len(data))
-        self.assertIn("permission", data[0]["sql"])
+        self.assertEquals(2, len(data))
+        self.assertIn("birth", data[0]["sql"])
 
     def test_search_query_on_time(self):
         self.run_some_queries()
         self.login("admin")
         first_query_time = (
-            db.session.query(Query).filter_by(sql="SELECT * FROM ab_user").one()
+            db.session.query(Query).filter_by(sql=QUERY_1).one()
         ).start_time
         second_query_time = (
-            db.session.query(Query).filter_by(sql="SELECT * FROM ab_permission").one()
+            db.session.query(Query).filter_by(sql=QUERY_3).one()
         ).start_time
         # Test search queries on time filter
         from_time = "from={}".format(int(first_query_time))
@@ -265,7 +257,7 @@ class SqlLabTests(SupersetTestCase):
 
     def test_alias_duplicate(self):
         self.run_sql(
-            "SELECT username as col, id as col, username FROM ab_user",
+            "SELECT name as col, gender as col FROM birth_names LIMIT 10",
             client_id="2e2df3",
             user_name="admin",
             raise_on_error=True,
@@ -281,7 +273,7 @@ class SqlLabTests(SupersetTestCase):
 
     def test_df_conversion_tuple(self):
         cols = ["string_col", "int_col", "list_col", "float_col"]
-        data = [(u"Text", 111, [123], 1.0)]
+        data = [("Text", 111, [123], 1.0)]
         cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
 
         self.assertEquals(len(data), cdf.size)
@@ -296,6 +288,7 @@ class SqlLabTests(SupersetTestCase):
         self.assertEquals(len(cols), len(cdf.columns))
 
     def test_sqllab_viz(self):
+        examples_dbid = get_example_database().id
         payload = {
             "chartType": "dist_bar",
             "datasourceName": "test_viz_flow_table",
@@ -316,11 +309,10 @@ class SqlLabTests(SupersetTestCase):
                 },
             ],
             "sql": """\
-                SELECT viz_type, count(1) as ccount
-                FROM slices
-                WHERE viz_type LIKE '%a%'
-                GROUP BY viz_type""",
-            "dbId": 1,
+                SELECT *
+                FROM birth_names
+                LIMIT 10""",
+            "dbId": examples_dbid,
         }
         data = {"data": json.dumps(payload)}
         resp = self.get_json_resp("/superset/sqllab_viz/", data=data)
@@ -329,20 +321,20 @@ class SqlLabTests(SupersetTestCase):
     def test_sql_limit(self):
         self.login("admin")
         test_limit = 1
-        data = self.run_sql("SELECT * FROM ab_user", client_id="sql_limit_1")
+        data = self.run_sql("SELECT * FROM birth_names", client_id="sql_limit_1")
         self.assertGreater(len(data["data"]), test_limit)
         data = self.run_sql(
-            "SELECT * FROM ab_user", client_id="sql_limit_2", query_limit=test_limit
+            "SELECT * FROM birth_names", client_id="sql_limit_2", query_limit=test_limit
         )
         self.assertEquals(len(data["data"]), test_limit)
         data = self.run_sql(
-            "SELECT * FROM ab_user LIMIT {}".format(test_limit),
+            "SELECT * FROM birth_names LIMIT {}".format(test_limit),
             client_id="sql_limit_3",
             query_limit=test_limit + 1,
         )
         self.assertEquals(len(data["data"]), test_limit)
         data = self.run_sql(
-            "SELECT * FROM ab_user LIMIT {}".format(test_limit + 1),
+            "SELECT * FROM birth_names LIMIT {}".format(test_limit + 1),
             client_id="sql_limit_4",
             query_limit=test_limit,
         )
@@ -406,6 +398,7 @@ class SqlLabTests(SupersetTestCase):
 
     def test_api_database(self):
         self.login("admin")
+        self.create_fake_db()
 
         arguments = {
             "keys": [],
@@ -415,12 +408,8 @@ class SqlLabTests(SupersetTestCase):
             "page": 0,
             "page_size": -1,
         }
-        expected_results = ["examples", "fake_db_100", "main"]
         url = "api/v1/database/?{}={}".format("q", prison.dumps(arguments))
-        data = self.get_json_resp(url)
-        for i, expected_result in enumerate(expected_results):
-            self.assertEquals(expected_result, data["result"][i]["database_name"])
-
-
-if __name__ == "__main__":
-    unittest.main()
+        self.assertEquals(
+            {"examples", "fake_db_100"},
+            {r.get("database_name") for r in self.get_json_resp(url)["result"]},
+        )

Reply via email to