This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new d8bb2d3 refactor(db_engine_specs): Removing top-level import of app (#14366)
d8bb2d3 is described below
commit d8bb2d3e62724febbdfde73771eae4cdc0666079
Author: John Bodley <[email protected]>
AuthorDate: Wed Apr 28 15:47:32 2021 +1200
refactor(db_engine_specs): Removing top-level import of app (#14366)
Co-authored-by: John Bodley <[email protected]>
---
superset/db_engine_specs/base.py | 13 +--
superset/db_engine_specs/hive.py | 21 ++--
superset/db_engine_specs/presto.py | 6 +-
tests/conftest.py | 4 +-
tests/db_engine_specs/athena_tests.py | 2 -
tests/db_engine_specs/base_engine_spec_tests.py | 149 +++++++++++++-----------
tests/db_engine_specs/hive_tests.py | 90 +++++++-------
7 files changed, 143 insertions(+), 142 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 7ea73e3..3f91d58 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -40,7 +40,7 @@ import pandas as pd
import sqlparse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
-from flask import g
+from flask import current_app, g
from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema
from sqlalchemy import column, DateTime, select, types
@@ -55,7 +55,7 @@ from sqlalchemy.sql.expression import ColumnClause, Select,
TextAsFrom
from sqlalchemy.types import String, TypeEngine, UnicodeText
from typing_extensions import TypedDict
-from superset import app, security_manager, sql_parse
+from superset import security_manager, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.models.sql_types.base import literal_dttm_type_factory
@@ -80,7 +80,6 @@ class TimeGrain(NamedTuple): # pylint:
disable=too-few-public-methods
QueryStatus = utils.QueryStatus
-config = app.config
builtin_time_grains: Dict[Optional[str], str] = {
None: __("Original value"),
@@ -369,7 +368,7 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
ret_list = []
time_grains = builtin_time_grains.copy()
- time_grains.update(config["TIME_GRAIN_ADDONS"])
+ time_grains.update(current_app.config["TIME_GRAIN_ADDONS"])
for duration, func in cls.get_time_grain_expressions().items():
if duration in time_grains:
name = time_grains[duration]
@@ -448,9 +447,9 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
"""
# TODO: use @memoize decorator or similar to avoid recomputation on
every call
time_grain_expressions = cls._time_grain_expressions.copy()
- grain_addon_expressions = config["TIME_GRAIN_ADDON_EXPRESSIONS"]
+ grain_addon_expressions =
current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
time_grain_expressions.update(grain_addon_expressions.get(cls.engine,
{}))
- denylist: List[str] = config["TIME_GRAIN_DENYLIST"]
+ denylist: List[str] = current_app.config["TIME_GRAIN_DENYLIST"]
for key in denylist:
time_grain_expressions.pop(key)
@@ -977,7 +976,7 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
"""
parsed_query = ParsedQuery(statement)
sql = parsed_query.stripped()
- sql_query_mutator = config["SQL_QUERY_MUTATOR"]
+ sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
if sql_query_mutator:
sql = sql_query_mutator(sql, user_name, security_manager, database)
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 6996beb..66c68b3 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -27,7 +27,7 @@ import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
-from flask import g
+from flask import current_app, g
from sqlalchemy import Column, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
@@ -35,7 +35,6 @@ from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
-from superset import app, conf
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.exceptions import SupersetException
@@ -50,12 +49,8 @@ if TYPE_CHECKING:
QueryStatus = utils.QueryStatus
-config = app.config
logger = logging.getLogger(__name__)
-tracking_url_trans = conf.get("TRACKING_URL_TRANSFORMER")
-hive_poll_interval = conf.get("HIVE_POLL_INTERVAL")
-
def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
"""
@@ -70,7 +65,7 @@ def upload_to_s3(filename: str, upload_prefix: str, table:
Table) -> str:
# Optional dependency
import boto3 # pylint: disable=import-error
- bucket_path = config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
+ bucket_path = current_app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
if not bucket_path:
logger.info("No upload bucket specified")
@@ -229,7 +224,7 @@ class HiveEngineSpec(PrestoEngineSpec):
)
with tempfile.NamedTemporaryFile(
- dir=config["UPLOAD_FOLDER"], suffix=".parquet"
+ dir=current_app.config["UPLOAD_FOLDER"], suffix=".parquet"
) as file:
pq.write_table(pa.Table.from_pandas(df), where=file.name)
@@ -243,9 +238,9 @@ class HiveEngineSpec(PrestoEngineSpec):
),
location=upload_to_s3(
filename=file.name,
- upload_prefix=config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
- database, g.user, table.schema
- ),
+ upload_prefix=current_app.config[
+ "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"
+ ](database, g.user, table.schema),
table=table,
),
)
@@ -356,7 +351,7 @@ class HiveEngineSpec(PrestoEngineSpec):
str(query_id),
tracking_url,
)
- tracking_url = tracking_url_trans(tracking_url)
+ tracking_url = current_app.config["TRACKING_URL_TRANSFORMER"](tracking_url)
logger.info(
"Query %s: Transformation applied: %s",
str(query_id),
@@ -374,7 +369,7 @@ class HiveEngineSpec(PrestoEngineSpec):
last_log_line = len(log_lines)
if needs_commit:
session.commit()
- time.sleep(hive_poll_interval)
+ time.sleep(current_app.config["HIVE_POLL_INTERVAL"])
polled = cursor.poll()
@classmethod
diff --git a/superset/db_engine_specs/presto.py
b/superset/db_engine_specs/presto.py
index c6cec6a..32741b1 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -39,6 +39,7 @@ from urllib import parse
import pandas as pd
import simplejson as json
+from flask import current_app
from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import Column, literal_column, types
from sqlalchemy.engine.base import Engine
@@ -49,7 +50,7 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from sqlalchemy.types import TypeEngine
-from superset import app, cache_manager, is_feature_enabled
+from superset import cache_manager, is_feature_enabled
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.exceptions import SupersetTemplateException
@@ -94,7 +95,6 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
QueryStatus = utils.QueryStatus
-config = app.config
logger = logging.getLogger(__name__)
@@ -940,7 +940,7 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint:
disable=too-many-public-metho
"""Updates progress information"""
query_id = query.id
poll_interval = query.database.connect_args.get(
- "poll_interval", config["PRESTO_POLL_INTERVAL"]
+ "poll_interval", current_app.config["PRESTO_POLL_INTERVAL"]
)
logger.info("Query %i: Polling the cursor for progress", query_id)
polled = cursor.poll()
diff --git a/tests/conftest.py b/tests/conftest.py
index b04a76c..456c8fb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,7 +21,6 @@ import pytest
from sqlalchemy.engine import Engine
from tests.test_app import app
-
from superset import db
from superset.utils.core import get_example_database, json_dumps_w_dates
@@ -38,6 +37,9 @@ def app_context():
@pytest.fixture(autouse=True, scope="session")
def setup_sample_data() -> Any:
+ # TODO(john-bodley): Determine a cleaner way of setting up the sample data
without
+ # relying on `tests.test_app.app` leveraging an `app` fixture which is
purposely
+ # scoped to the function level to ensure tests remain idempotent.
with app.app_context():
setup_presto_if_needed()
diff --git a/tests/db_engine_specs/athena_tests.py
b/tests/db_engine_specs/athena_tests.py
index 92160db..d928a98 100644
--- a/tests/db_engine_specs/athena_tests.py
+++ b/tests/db_engine_specs/athena_tests.py
@@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from tests.test_app import app # isort:skip
-
from superset.db_engine_specs.athena import AthenaEngineSpec
from tests.db_engine_specs.base_tests import TestDbEngineSpec
diff --git a/tests/db_engine_specs/base_engine_spec_tests.py
b/tests/db_engine_specs/base_engine_spec_tests.py
index b8b82ab..097e0ba 100644
--- a/tests/db_engine_specs/base_engine_spec_tests.py
+++ b/tests/db_engine_specs/base_engine_spec_tests.py
@@ -167,25 +167,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
"SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec
)
- def test_time_grain_denylist(self):
- with app.app_context():
- app.config["TIME_GRAIN_DENYLIST"] = ["PT1M"]
- time_grain_functions =
SqliteEngineSpec.get_time_grain_expressions()
- self.assertNotIn("PT1M", time_grain_functions)
-
- def test_time_grain_addons(self):
- with app.app_context():
- app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"}
- app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
- "sqlite": {"PTXM": "ABC({col})"}
- }
- time_grains = SqliteEngineSpec.get_time_grains()
- time_grain_addon = time_grains[-1]
- self.assertEqual("PTXM", time_grain_addon.duration)
- self.assertEqual("x seconds", time_grain_addon.label)
- app.config["TIME_GRAIN_ADDONS"] = {}
- app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
-
def test_engine_time_grain_validity(self):
time_grains = set(builtin_time_grains.keys())
# loop over all subclasses of BaseEngineSpec
@@ -198,43 +179,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
intersection = time_grains.intersection(defined_grains)
self.assertSetEqual(defined_grains, intersection, engine)
- def test_get_time_grain_with_config(self):
- """ Should concatenate from configs and then sort in the proper order
"""
- app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
- "mysql": {
- "PT2H": "foo",
- "PT4H": "foo",
- "PT6H": "foo",
- "PT8H": "foo",
- "PT10H": "foo",
- "PT12H": "foo",
- "PT1S": "foo",
- }
- }
- time_grains = MySQLEngineSpec.get_time_grain_expressions()
- self.assertEqual(
- list(time_grains.keys()),
- [
- None,
- "PT1S",
- "PT1M",
- "PT1H",
- "PT2H",
- "PT4H",
- "PT6H",
- "PT8H",
- "PT10H",
- "PT12H",
- "P1D",
- "P1W",
- "P1M",
- "P0.25Y",
- "P1Y",
- "1969-12-29T00:00:00Z/P1W",
- ],
- )
- app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
-
def test_get_time_grain_expressions(self):
time_grains = MySQLEngineSpec.get_time_grain_expressions()
self.assertEqual(
@@ -253,18 +197,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
],
)
- def test_get_time_grain_with_unkown_values(self):
- """Should concatenate from configs and then sort in the proper order
- putting unknown patterns at the end"""
- app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
- "mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",}
- }
- time_grains = MySQLEngineSpec.get_time_grain_expressions()
- self.assertEqual(
- list(time_grains)[-1], "weird",
- )
- app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {}
-
def test_get_table_names(self):
inspector = mock.Mock()
inspector.get_table_names = mock.Mock(return_value=["schema.table",
"table_2"])
@@ -339,3 +271,81 @@ def test_is_readonly():
     assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
     assert is_readonly("SHOW CATALOGS")
     assert is_readonly("SHOW TABLES")
+
+
+def test_time_grain_denylist():
+    """Grains listed in TIME_GRAIN_DENYLIST should be dropped."""
+    # mock.patch.dict updates app.config in place and restores it on exit, even if
+    # an assertion fails, so the Flask Config object and other tests are unaffected.
+    with mock.patch.dict(app.config, {"TIME_GRAIN_DENYLIST": ["PT1M"]}):
+        with app.app_context():
+            time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
+            assert "PT1M" not in time_grain_functions
+
+
+def test_time_grain_addons():
+    """A TIME_GRAIN_ADDONS grain should be exposed last, with its configured label."""
+    with mock.patch.dict(
+        app.config,
+        {
+            "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
+            "TIME_GRAIN_ADDON_EXPRESSIONS": {"sqlite": {"PTXM": "ABC({col})"}},
+        },
+    ):
+        with app.app_context():
+            time_grains = SqliteEngineSpec.get_time_grains()
+            time_grain_addon = time_grains[-1]
+            assert "PTXM" == time_grain_addon.duration
+            assert "x seconds" == time_grain_addon.label
+
+
+def test_get_time_grain_with_config():
+    """Addon expressions from config should be merged with the built-in grains"""
+    addon_expressions = {
+        "mysql": {
+            "PT2H": "foo",
+            "PT4H": "foo",
+            "PT6H": "foo",
+            "PT8H": "foo",
+            "PT10H": "foo",
+            "PT12H": "foo",
+            "PT1S": "foo",
+        }
+    }
+
+    with mock.patch.dict(
+        app.config, {"TIME_GRAIN_ADDON_EXPRESSIONS": addon_expressions}
+    ):
+        with app.app_context():
+            time_grains = MySQLEngineSpec.get_time_grain_expressions()
+            assert set(time_grains.keys()) == {
+                None,
+                "PT1S",
+                "PT1M",
+                "PT1H",
+                "PT2H",
+                "PT4H",
+                "PT6H",
+                "PT8H",
+                "PT10H",
+                "PT12H",
+                "P1D",
+                "P1W",
+                "P1M",
+                "P0.25Y",
+                "P1Y",
+                "1969-12-29T00:00:00Z/P1W",
+            }
+
+
+def test_get_time_grain_with_unknown_values():
+    """Should concatenate from configs and then sort in the proper order
+    putting unknown patterns at the end"""
+    addon_expressions = {"mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo"}}
+
+    with mock.patch.dict(
+        app.config, {"TIME_GRAIN_ADDON_EXPRESSIONS": addon_expressions}
+    ):
+        with app.app_context():
+            time_grains = MySQLEngineSpec.get_time_grain_expressions()
+            assert list(time_grains)[-1] == "weird"
diff --git a/tests/db_engine_specs/hive_tests.py
b/tests/db_engine_specs/hive_tests.py
index 4d50518..1e978b7 100644
--- a/tests/db_engine_specs/hive_tests.py
+++ b/tests/db_engine_specs/hive_tests.py
@@ -21,12 +21,11 @@ from unittest import mock
import pytest
import pandas as pd
from sqlalchemy.sql import select
-from tests.test_app import app
-with app.app_context():
- from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
+from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
from superset.sql_parse import Table, ParsedQuery
+from tests.test_app import app
def test_0_progress():
@@ -170,10 +169,6 @@ def test_df_to_csv() -> None:
)
[email protected](
- "superset.db_engine_specs.hive.config",
- {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
-)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
def test_df_to_sql_if_exists_fail(mock_g):
mock_g.user = True
@@ -185,10 +180,6 @@ def test_df_to_sql_if_exists_fail(mock_g):
)
[email protected](
- "superset.db_engine_specs.hive.config",
- {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
-)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
def test_df_to_sql_if_exists_fail_with_schema(mock_g):
mock_g.user = True
@@ -203,13 +194,10 @@ def test_df_to_sql_if_exists_fail_with_schema(mock_g):
         )
 
 
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
-)
+@mock.patch.dict(app.config, {"CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""})
 @mock.patch("superset.db_engine_specs.hive.g", spec={})
 @mock.patch("superset.db_engine_specs.hive.upload_to_s3")
 def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
     mock_upload_to_s3.return_value = "mock-location"
     mock_g.user = True
     mock_database = mock.MagicMock()
@@ -218,23 +206,21 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
     mock_database.get_sqla_engine.return_value.execute = mock_execute
     table_name = "foobar"
 
-    HiveEngineSpec.df_to_sql(
-        mock_database,
-        Table(table=table_name),
-        pd.DataFrame(),
-        {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
-    )
+    with app.app_context():
+        HiveEngineSpec.df_to_sql(
+            mock_database,
+            Table(table=table_name),
+            pd.DataFrame(),
+            {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
+        )
 
     mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}")
 
 
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""},
-)
+@mock.patch.dict(app.config, {"CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: ""})
 @mock.patch("superset.db_engine_specs.hive.g", spec={})
 @mock.patch("superset.db_engine_specs.hive.upload_to_s3")
 def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
     mock_upload_to_s3.return_value = "mock-location"
     mock_g.user = True
     mock_database = mock.MagicMock()
@@ -244,14 +230,15 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
     table_name = "foobar"
     schema = "schema"
 
-    HiveEngineSpec.df_to_sql(
-        mock_database,
-        Table(table=table_name, schema=schema),
-        pd.DataFrame(),
-        {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
-    )
+    with app.app_context():
+        HiveEngineSpec.df_to_sql(
+            mock_database,
+            Table(table=table_name, schema=schema),
+            pd.DataFrame(),
+            {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
+        )
 
     mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
 
 
 def test_is_readonly():
@@ -284,39 +271,36 @@ def test_s3_upload_prefix(schema: str, upload_prefix: str) -> None:
 
 
 def test_upload_to_s3_no_bucket_path():
-    with pytest.raises(
-        Exception,
-        match="No upload bucket specified. You can specify one in the config file.",
-    ):
-        upload_to_s3("filename", "prefix", Table("table"))
+    with app.app_context():
+        with pytest.raises(
+            Exception,
+            match="No upload bucket specified. You can specify one in the config file.",
+        ):
+            upload_to_s3("filename", "prefix", Table("table"))
 
 
 @mock.patch("boto3.client")
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
-)
+@mock.patch.dict(app.config, {"CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"})
 def test_upload_to_s3_client_error(client):
     from botocore.exceptions import ClientError
 
     client.return_value.upload_file.side_effect = ClientError(
         {"Error": {}}, "operation_name"
     )
 
-    with pytest.raises(ClientError):
-        upload_to_s3("filename", "prefix", Table("table"))
+    with app.app_context():
+        with pytest.raises(ClientError):
+            upload_to_s3("filename", "prefix", Table("table"))
 
 
 @mock.patch("boto3.client")
-@mock.patch(
-    "superset.db_engine_specs.hive.config",
-    {**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
-)
+@mock.patch.dict(app.config, {"CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"})
 def test_upload_to_s3_success(client):
     client.return_value.upload_file.return_value = True
 
-    location = upload_to_s3("filename", "prefix", Table("table"))
-    assert f"s3a://bucket/prefix/table" == location
+    with app.app_context():
+        location = upload_to_s3("filename", "prefix", Table("table"))
+        assert "s3a://bucket/prefix/table" == location
 
 
 def test_fetch_data_query_error():