This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new d8bb2d3  refactor(db_engine_specs): Removing top-level import of app (#14366)
d8bb2d3 is described below

commit d8bb2d3e62724febbdfde73771eae4cdc0666079
Author: John Bodley <[email protected]>
AuthorDate: Wed Apr 28 15:47:32 2021 +1200

    refactor(db_engine_specs): Removing top-level import of app (#14366)
    
    Co-authored-by: John Bodley <[email protected]>
---
 superset/db_engine_specs/base.py                |  13 +--
 superset/db_engine_specs/hive.py                |  21 ++--
 superset/db_engine_specs/presto.py              |   6 +-
 tests/conftest.py                               |   4 +-
 tests/db_engine_specs/athena_tests.py           |   2 -
 tests/db_engine_specs/base_engine_spec_tests.py | 149 +++++++++++++-----------
 tests/db_engine_specs/hive_tests.py             |  90 +++++++-------
 7 files changed, 143 insertions(+), 142 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 7ea73e3..3f91d58 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -40,7 +40,7 @@ import pandas as pd
 import sqlparse
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
-from flask import g
+from flask import current_app, g
 from flask_babel import gettext as __, lazy_gettext as _
 from marshmallow import fields, Schema
 from sqlalchemy import column, DateTime, select, types
@@ -55,7 +55,7 @@ from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom
 from sqlalchemy.types import String, TypeEngine, UnicodeText
 from typing_extensions import TypedDict
 
-from superset import app, security_manager, sql_parse
+from superset import security_manager, sql_parse
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.models.sql_lab import Query
 from superset.models.sql_types.base import literal_dttm_type_factory
@@ -80,7 +80,6 @@ class TimeGrain(NamedTuple):  # pylint: disable=too-few-public-methods
 
 
 QueryStatus = utils.QueryStatus
-config = app.config
 
 builtin_time_grains: Dict[Optional[str], str] = {
     None: __("Original value"),
@@ -369,7 +368,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         ret_list = []
         time_grains = builtin_time_grains.copy()
-        time_grains.update(config["TIME_GRAIN_ADDONS"])
+        time_grains.update(current_app.config["TIME_GRAIN_ADDONS"])
         for duration, func in cls.get_time_grain_expressions().items():
             if duration in time_grains:
                 name = time_grains[duration]
@@ -448,9 +447,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         # TODO: use @memoize decorator or similar to avoid recomputation on every call
         time_grain_expressions = cls._time_grain_expressions.copy()
-        grain_addon_expressions = config["TIME_GRAIN_ADDON_EXPRESSIONS"]
+        grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
         time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {}))
-        denylist: List[str] = config["TIME_GRAIN_DENYLIST"]
+        denylist: List[str] = current_app.config["TIME_GRAIN_DENYLIST"]
         for key in denylist:
             time_grain_expressions.pop(key)
 
@@ -977,7 +976,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         parsed_query = ParsedQuery(statement)
         sql = parsed_query.stripped()
-        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
+        sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
         if sql_query_mutator:
             sql = sql_query_mutator(sql, user_name, security_manager, database)
 
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 6996beb..66c68b3 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -27,7 +27,7 @@ import numpy as np
 import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
-from flask import g
+from flask import current_app, g
 from sqlalchemy import Column, text
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.engine.reflection import Inspector
@@ -35,7 +35,6 @@ from sqlalchemy.engine.url import make_url, URL
 from sqlalchemy.orm import Session
 from sqlalchemy.sql.expression import ColumnClause, Select
 
-from superset import app, conf
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.presto import PrestoEngineSpec
 from superset.exceptions import SupersetException
@@ -50,12 +49,8 @@ if TYPE_CHECKING:
 
 
 QueryStatus = utils.QueryStatus
-config = app.config
 logger = logging.getLogger(__name__)
 
-tracking_url_trans = conf.get("TRACKING_URL_TRANSFORMER")
-hive_poll_interval = conf.get("HIVE_POLL_INTERVAL")
-
 
 def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
     """
@@ -70,7 +65,7 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
     # Optional dependency
     import boto3  # pylint: disable=import-error
 
-    bucket_path = config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
+    bucket_path = current_app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
 
     if not bucket_path:
         logger.info("No upload bucket specified")
@@ -229,7 +224,7 @@ class HiveEngineSpec(PrestoEngineSpec):
         )
 
         with tempfile.NamedTemporaryFile(
-            dir=config["UPLOAD_FOLDER"], suffix=".parquet"
+            dir=current_app.config["UPLOAD_FOLDER"], suffix=".parquet"
         ) as file:
             pq.write_table(pa.Table.from_pandas(df), where=file.name)
 
@@ -243,9 +238,9 @@ class HiveEngineSpec(PrestoEngineSpec):
                 ),
                 location=upload_to_s3(
                     filename=file.name,
-                    upload_prefix=config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
-                        database, g.user, table.schema
-                    ),
+                    upload_prefix=current_app.config[
+                        "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"
+                    ](database, g.user, table.schema),
                     table=table,
                 ),
             )
@@ -356,7 +351,9 @@ class HiveEngineSpec(PrestoEngineSpec):
                             str(query_id),
                             tracking_url,
                         )
-                        tracking_url = tracking_url_trans(tracking_url)
+                        tracking_url = current_app.config["TRACKING_URL_TRANSFORMER"](
+                            tracking_url
+                        )
                         logger.info(
                             "Query %s: Transformation applied: %s",
                             str(query_id),
@@ -374,7 +371,7 @@ class HiveEngineSpec(PrestoEngineSpec):
                     last_log_line = len(log_lines)
                 if needs_commit:
                     session.commit()
-            time.sleep(hive_poll_interval)
+            time.sleep(current_app.config["HIVE_POLL_INTERVAL"])
             polled = cursor.poll()
 
     @classmethod
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index c6cec6a..32741b1 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -39,6 +39,7 @@ from urllib import parse
 
 import pandas as pd
 import simplejson as json
+from flask import current_app
 from flask_babel import gettext as __, lazy_gettext as _
 from sqlalchemy import Column, literal_column, types
 from sqlalchemy.engine.base import Engine
@@ -49,7 +50,7 @@ from sqlalchemy.orm import Session
 from sqlalchemy.sql.expression import ColumnClause, Select
 from sqlalchemy.types import TypeEngine
 
-from superset import app, cache_manager, is_feature_enabled
+from superset import cache_manager, is_feature_enabled
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.errors import SupersetErrorType
 from superset.exceptions import SupersetTemplateException
@@ -94,7 +95,6 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
 
 
 QueryStatus = utils.QueryStatus
-config = app.config
 logger = logging.getLogger(__name__)
 
 
@@ -940,7 +940,7 @@ class PrestoEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-metho
         """Updates progress information"""
         query_id = query.id
         poll_interval = query.database.connect_args.get(
-            "poll_interval", config["PRESTO_POLL_INTERVAL"]
+            "poll_interval", current_app.config["PRESTO_POLL_INTERVAL"]
         )
         logger.info("Query %i: Polling the cursor for progress", query_id)
         polled = cursor.poll()
diff --git a/tests/conftest.py b/tests/conftest.py
index b04a76c..456c8fb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,7 +21,6 @@ import pytest
 from sqlalchemy.engine import Engine
 
 from tests.test_app import app
-
 from superset import db
 from superset.utils.core import get_example_database, json_dumps_w_dates
 
@@ -38,6 +37,9 @@ def app_context():
 
 @pytest.fixture(autouse=True, scope="session")
 def setup_sample_data() -> Any:
+    # TODO(john-bodley): Determine a cleaner way of setting up the sample data without
+    # relying on `tests.test_app.app` leveraging an `app` fixture which is purposely
+    # scoped to the function level to ensure tests remain idempotent.
     with app.app_context():
         setup_presto_if_needed()
 
diff --git a/tests/db_engine_specs/athena_tests.py b/tests/db_engine_specs/athena_tests.py
index 92160db..d928a98 100644
--- a/tests/db_engine_specs/athena_tests.py
+++ b/tests/db_engine_specs/athena_tests.py
@@ -14,8 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from tests.test_app import app  # isort:skip
-
 from superset.db_engine_specs.athena import AthenaEngineSpec
 from tests.db_engine_specs.base_tests import TestDbEngineSpec
 
diff --git a/tests/db_engine_specs/base_engine_spec_tests.py b/tests/db_engine_specs/base_engine_spec_tests.py
index b8b82ab..097e0ba 100644
--- a/tests/db_engine_specs/base_engine_spec_tests.py
+++ b/tests/db_engine_specs/base_engine_spec_tests.py
@@ -167,25 +167,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
             "SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec
         )
 
-    def test_time_grain_denylist(self):
-        with app.app_context():
-            app.config["TIME_GRAIN_DENYLIST"] = ["PT1M"]
-            time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
-            self.assertNotIn("PT1M", time_grain_functions)
-
-    def test_time_grain_addons(self):
-        with app.app_context():
-            app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"}
-            app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
-                "sqlite": {"PTXM": "ABC({col})"}
-            }
-            time_grains = SqliteEngineSpec.get_time_grains()
-            time_grain_addon = time_grains[-1]
-            self.assertEqual("PTXM", time_grain_addon.duration)
-            self.assertEqual("x seconds", time_grain_addon.label)
-            app.config["TIME_GRAIN_ADDONS"] = {}
-            app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
-
     def test_engine_time_grain_validity(self):
         time_grains = set(builtin_time_grains.keys())
         # loop over all subclasses of BaseEngineSpec
@@ -198,43 +179,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
                 intersection = time_grains.intersection(defined_grains)
                 self.assertSetEqual(defined_grains, intersection, engine)
 
-    def test_get_time_grain_with_config(self):
-        """ Should concatenate from configs and then sort in the proper order """
-        app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
-            "mysql": {
-                "PT2H": "foo",
-                "PT4H": "foo",
-                "PT6H": "foo",
-                "PT8H": "foo",
-                "PT10H": "foo",
-                "PT12H": "foo",
-                "PT1S": "foo",
-            }
-        }
-        time_grains = MySQLEngineSpec.get_time_grain_expressions()
-        self.assertEqual(
-            list(time_grains.keys()),
-            [
-                None,
-                "PT1S",
-                "PT1M",
-                "PT1H",
-                "PT2H",
-                "PT4H",
-                "PT6H",
-                "PT8H",
-                "PT10H",
-                "PT12H",
-                "P1D",
-                "P1W",
-                "P1M",
-                "P0.25Y",
-                "P1Y",
-                "1969-12-29T00:00:00Z/P1W",
-            ],
-        )
-        app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
-
     def test_get_time_grain_expressions(self):
         time_grains = MySQLEngineSpec.get_time_grain_expressions()
         self.assertEqual(
@@ -253,18 +197,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
             ],
         )
 
-    def test_get_time_grain_with_unkown_values(self):
-        """Should concatenate from configs and then sort in the proper order
-        putting unknown patterns at the end"""
-        app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
-            "mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",}
-        }
-        time_grains = MySQLEngineSpec.get_time_grain_expressions()
-        self.assertEqual(
-            list(time_grains)[-1], "weird",
-        )
-        app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
-
     def test_get_table_names(self):
         inspector = mock.Mock()
         inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
@@ -339,3 +271,75 @@ def test_is_readonly():
     assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
     assert is_readonly("SHOW CATALOGS")
     assert is_readonly("SHOW TABLES")
+
+
+def test_time_grain_denylist():
+    with app.app_context(), mock.patch.dict(
+        app.config, {"TIME_GRAIN_DENYLIST": ["PT1M"]}
+    ):
+        time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
+        assert "PT1M" not in time_grain_functions
+
+
+def test_time_grain_addons():
+    with app.app_context(), mock.patch.dict(
+        app.config,
+        {
+            "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
+            "TIME_GRAIN_ADDON_EXPRESSIONS": {"sqlite": {"PTXM": "ABC({col})"}},
+        },
+    ):
+        time_grains = SqliteEngineSpec.get_time_grains()
+        time_grain_addon = time_grains[-1]
+        assert "PTXM" == time_grain_addon.duration
+        assert "x seconds" == time_grain_addon.label
+
+
+def test_get_time_grain_with_config():
+    """ Should concatenate from configs and then sort in the proper order """
+    addon_expressions = {
+        "mysql": {
+            "PT2H": "foo",
+            "PT4H": "foo",
+            "PT6H": "foo",
+            "PT8H": "foo",
+            "PT10H": "foo",
+            "PT12H": "foo",
+            "PT1S": "foo",
+        }
+    }
+
+    with app.app_context(), mock.patch.dict(
+        app.config, {"TIME_GRAIN_ADDON_EXPRESSIONS": addon_expressions}
+    ):
+        time_grains = MySQLEngineSpec.get_time_grain_expressions()
+        assert set(time_grains.keys()) == {
+            None,
+            "PT1S",
+            "PT1M",
+            "PT1H",
+            "PT2H",
+            "PT4H",
+            "PT6H",
+            "PT8H",
+            "PT10H",
+            "PT12H",
+            "P1D",
+            "P1W",
+            "P1M",
+            "P0.25Y",
+            "P1Y",
+            "1969-12-29T00:00:00Z/P1W",
+        }
+
+
+def test_get_time_grain_with_unknown_values():
+    """Should concatenate from configs and then sort in the proper order
+    putting unknown patterns at the end"""
+    addon_expressions = {"mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo"}}
+
+    with app.app_context(), mock.patch.dict(
+        app.config, {"TIME_GRAIN_ADDON_EXPRESSIONS": addon_expressions}
+    ):
+        time_grains = MySQLEngineSpec.get_time_grain_expressions()
+        assert list(time_grains)[-1] == "weird"
diff --git a/tests/db_engine_specs/hive_tests.py b/tests/db_engine_specs/hive_tests.py
index 4d50518..1e978b7 100644
--- a/tests/db_engine_specs/hive_tests.py
+++ b/tests/db_engine_specs/hive_tests.py
@@ -21,12 +21,11 @@ from unittest import mock
 import pytest
 import pandas as pd
 from sqlalchemy.sql import select
-from tests.test_app import app
 
-with app.app_context():
-    from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
+from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
 from superset.exceptions import SupersetException
 from superset.sql_parse import Table, ParsedQuery
+from tests.test_app import app
 
 
 def test_0_progress():
@@ -170,10 +169,6 @@ def test_df_to_csv() -> None:
         )
 
 
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
-)
 @mock.patch("superset.db_engine_specs.hive.g", spec={})
 def test_df_to_sql_if_exists_fail(mock_g):
     mock_g.user = True
@@ -185,10 +180,6 @@ def test_df_to_sql_if_exists_fail(mock_g):
         )
 
 
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
-)
 @mock.patch("superset.db_engine_specs.hive.g", spec={})
 def test_df_to_sql_if_exists_fail_with_schema(mock_g):
     mock_g.user = True
@@ -203,13 +194,10 @@ def test_df_to_sql_if_exists_fail_with_schema(mock_g):
         )
 
 
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
-)
+@mock.patch.dict(app.config, {"CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""})
 @mock.patch("superset.db_engine_specs.hive.g", spec={})
 @mock.patch("superset.db_engine_specs.hive.upload_to_s3")
 def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
     mock_upload_to_s3.return_value = "mock-location"
     mock_g.user = True
     mock_database = mock.MagicMock()
@@ -218,23 +206,21 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
     mock_database.get_sqla_engine.return_value.execute = mock_execute
     table_name = "foobar"
 
-    HiveEngineSpec.df_to_sql(
-        mock_database,
-        Table(table=table_name),
-        pd.DataFrame(),
-        {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
-    )
+    with app.app_context():
+        HiveEngineSpec.df_to_sql(
+            mock_database,
+            Table(table=table_name),
+            pd.DataFrame(),
+            {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
+        )
 
     mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}")
 
 
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
-)
+@mock.patch.dict(app.config, {"CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""})
 @mock.patch("superset.db_engine_specs.hive.g", spec={})
 @mock.patch("superset.db_engine_specs.hive.upload_to_s3")
 def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
     mock_upload_to_s3.return_value = "mock-location"
     mock_g.user = True
     mock_database = mock.MagicMock()
@@ -244,14 +230,15 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
     table_name = "foobar"
     schema = "schema"
 
-    HiveEngineSpec.df_to_sql(
-        mock_database,
-        Table(table=table_name, schema=schema),
-        pd.DataFrame(),
-        {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
-    )
+    with app.app_context():
+        HiveEngineSpec.df_to_sql(
+            mock_database,
+            Table(table=table_name, schema=schema),
+            pd.DataFrame(),
+            {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
+        )
 
     mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
 
 
 def test_is_readonly():
@@ -284,39 +271,36 @@ def test_s3_upload_prefix(schema: str, upload_prefix: str) -> None:
 
 
 def test_upload_to_s3_no_bucket_path():
-    with pytest.raises(
-        Exception,
-        match="No upload bucket specified. You can specify one in the config file.",
-    ):
-        upload_to_s3("filename", "prefix", Table("table"))
+    with app.app_context():
+        with pytest.raises(
+            Exception,
+            match="No upload bucket specified. You can specify one in the config file.",
+        ):
+            upload_to_s3("filename", "prefix", Table("table"))
 
 
 @mock.patch("boto3.client")
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
-)
+@mock.patch.dict(app.config, {"CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"})
 def test_upload_to_s3_client_error(client):
     from botocore.exceptions import ClientError
 
     client.return_value.upload_file.side_effect = ClientError(
         {"Error": {}}, "operation_name"
     )
 
-    with pytest.raises(ClientError):
-        upload_to_s3("filename", "prefix", Table("table"))
+    with app.app_context():
+        with pytest.raises(ClientError):
+            upload_to_s3("filename", "prefix", Table("table"))
 
 
 @mock.patch("boto3.client")
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
-)
+@mock.patch.dict(app.config, {"CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"})
 def test_upload_to_s3_success(client):
     client.return_value.upload_file.return_value = True
 
-    location = upload_to_s3("filename", "prefix", Table("table"))
-    assert f"s3a://bucket/prefix/table" == location
+    with app.app_context():
+        location = upload_to_s3("filename", "prefix", Table("table"))
+        assert "s3a://bucket/prefix/table" == location
 
 
 def test_fetch_data_query_error():

Reply via email to