This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch rename-get_sqla_engine_with_context
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to
refs/heads/rename-get_sqla_engine_with_context by this push:
new 2186d97555 chore: rename get_sqla_engine_with_context
2186d97555 is described below
commit 2186d97555e78d9c5fe73da7aec54bb3b432dc42
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Apr 12 09:37:18 2024 -0400
chore: rename get_sqla_engine_with_context
---
superset/commands/database/test_connection.py | 4 +---
superset/commands/database/validate.py | 2 +-
superset/commands/dataset/importers/v1/utils.py | 2 +-
superset/db_engine_specs/base.py | 2 +-
superset/db_engine_specs/bigquery.py | 2 +-
superset/examples/bart_lines.py | 2 +-
superset/examples/birth_names.py | 4 ++--
superset/examples/country_map.py | 2 +-
superset/examples/energy.py | 2 +-
superset/examples/flights.py | 2 +-
superset/examples/long_lat.py | 2 +-
superset/examples/multiformat_time_series.py | 2 +-
superset/examples/paris.py | 2 +-
superset/examples/random_time_series.py | 2 +-
superset/examples/sf_population_polygons.py | 2 +-
superset/examples/supported_charts_dashboard.py | 2 +-
superset/examples/world_bank.py | 2 +-
superset/extensions/metadb.py | 2 +-
superset/models/core.py | 27 ++++++++++++----------
superset/models/dashboard.py | 2 +-
superset/models/helpers.py | 4 ++--
superset/sql_lab.py | 2 +-
superset/sql_validators/presto_db.py | 4 +---
superset/utils/core.py | 2 +-
superset/utils/mock_data.py | 2 +-
tests/conftest.py | 2 +-
tests/integration_tests/celery_tests.py | 2 +-
tests/integration_tests/conftest.py | 4 ++--
tests/integration_tests/csv_upload_tests.py | 20 ++++++++--------
tests/integration_tests/databases/api_tests.py | 2 +-
tests/integration_tests/datasets/api_tests.py | 10 ++++----
tests/integration_tests/datasets/commands_tests.py | 4 ++--
tests/integration_tests/datasource_tests.py | 4 ++--
.../db_engine_specs/hive_tests.py | 4 ++--
.../integration_tests/fixtures/energy_dashboard.py | 4 ++--
.../fixtures/unicode_dashboard.py | 4 ++--
.../fixtures/world_bank_dashboard.py | 4 ++--
tests/integration_tests/model_tests.py | 26 ++++++++++-----------
tests/integration_tests/reports/commands_tests.py | 4 ++--
tests/integration_tests/sql_validator_tests.py | 2 +-
tests/integration_tests/sqla_models_tests.py | 2 +-
tests/integration_tests/sqllab_tests.py | 6 ++---
tests/unit_tests/extensions/test_sqlalchemy.py | 4 ++--
tests/unit_tests/models/core_test.py | 2 +-
tests/unit_tests/models/helpers_test.py | 6 ++---
45 files changed, 99 insertions(+), 100 deletions(-)
diff --git a/superset/commands/database/test_connection.py
b/superset/commands/database/test_connection.py
index 431918c6bc..6bf69bbb87 100644
--- a/superset/commands/database/test_connection.py
+++ b/superset/commands/database/test_connection.py
@@ -138,9 +138,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
with closing(engine.raw_connection()) as conn:
return engine.dialect.do_ping(conn)
- with database.get_sqla_engine_with_context(
- override_ssh_tunnel=ssh_tunnel
- ) as engine:
+ with database.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as
engine:
try:
alive = func_timeout(
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
diff --git a/superset/commands/database/validate.py
b/superset/commands/database/validate.py
index 83bbc4e90a..e550f51d70 100644
--- a/superset/commands/database/validate.py
+++ b/superset/commands/database/validate.py
@@ -102,7 +102,7 @@ class ValidateDatabaseParametersCommand(BaseCommand):
database.db_engine_spec.mutate_db_for_connection_test(database)
alive = False
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
diff --git a/superset/commands/dataset/importers/v1/utils.py
b/superset/commands/dataset/importers/v1/utils.py
index 04fc81e241..50bb916b07 100644
--- a/superset/commands/dataset/importers/v1/utils.py
+++ b/superset/commands/dataset/importers/v1/utils.py
@@ -217,7 +217,7 @@ def load_data(data_uri: str, dataset: SqlaTable, database:
Database) -> None:
)
else:
logger.warning("Loading data outside the import transaction")
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
df.to_sql(
dataset.table_name,
con=engine,
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index bcb4035c9c..ec1cc741d5 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -771,7 +771,7 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
... connection.execute(sql)
"""
- return database.get_sqla_engine_with_context(schema=schema,
source=source)
+ return database.get_sqla_engine(schema=schema, source=source)
@classmethod
def get_timestamp_expr(
diff --git a/superset/db_engine_specs/bigquery.py
b/superset/db_engine_specs/bigquery.py
index a8d834276e..63860e8aa8 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -456,7 +456,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint:
disable=too-many-public-met
In BigQuery, a catalog is called a "project".
"""
engine: Engine
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
client = cls._get_client(engine)
projects = client.list_projects()
diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py
index ad96aecac4..9ce27d4952 100644
--- a/superset/examples/bart_lines.py
+++ b/superset/examples/bart_lines.py
@@ -29,7 +29,7 @@ from .helpers import get_example_url,
get_table_connector_registry
def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl_name = "bart_lines"
database = get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index c9e38f1686..2e711bef29 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -63,7 +63,7 @@ def load_data(tbl_name: str, database: Database, sample: bool
= False) -> None:
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
pdf = pdf.head(100) if sample else pdf
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
pdf.to_sql(
@@ -91,7 +91,7 @@ def load_birth_names(
) -> None:
"""Loading birth name dataset from a zip file in the repo"""
database = get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
tbl_name = "birth_names"
diff --git a/superset/examples/country_map.py b/superset/examples/country_map.py
index 3caf637584..59c257bc80 100644
--- a/superset/examples/country_map.py
+++ b/superset/examples/country_map.py
@@ -40,7 +40,7 @@ def load_country_map_data(only_metadata: bool = False, force:
bool = False) -> N
tbl_name = "birth_france_by_region"
database = database_utils.get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/examples/energy.py b/superset/examples/energy.py
index 998ee97a30..1f11c0f3f5 100644
--- a/superset/examples/energy.py
+++ b/superset/examples/energy.py
@@ -42,7 +42,7 @@ def load_energy(
tbl_name = "energy_usage"
database = database_utils.get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/examples/flights.py b/superset/examples/flights.py
index c7890cfa18..a42df2023c 100644
--- a/superset/examples/flights.py
+++ b/superset/examples/flights.py
@@ -27,7 +27,7 @@ def load_flights(only_metadata: bool = False, force: bool =
False) -> None:
"""Loading random time series data from a zip file in the repo"""
tbl_name = "flights"
database = database_utils.get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/examples/long_lat.py b/superset/examples/long_lat.py
index 6f7cc64020..95cccadc24 100644
--- a/superset/examples/long_lat.py
+++ b/superset/examples/long_lat.py
@@ -39,7 +39,7 @@ def load_long_lat_data(only_metadata: bool = False, force:
bool = False) -> None
"""Loading lat/long data from a csv file in the repo"""
tbl_name = "long_lat"
database = database_utils.get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/examples/multiformat_time_series.py
b/superset/examples/multiformat_time_series.py
index 4c1e796316..91799b2c2c 100644
--- a/superset/examples/multiformat_time_series.py
+++ b/superset/examples/multiformat_time_series.py
@@ -39,7 +39,7 @@ def load_multiformat_time_series( # pylint:
disable=too-many-locals
"""Loading time series data from a zip file in the repo"""
tbl_name = "multiformat_time_series"
database = get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/examples/paris.py b/superset/examples/paris.py
index fa5c77b84d..cea784be77 100644
--- a/superset/examples/paris.py
+++ b/superset/examples/paris.py
@@ -28,7 +28,7 @@ from .helpers import get_example_url,
get_table_connector_registry
def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False)
-> None:
tbl_name = "paris_iris_mapping"
database = database_utils.get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/examples/random_time_series.py
b/superset/examples/random_time_series.py
index 4a2d10aee9..9b5306781d 100644
--- a/superset/examples/random_time_series.py
+++ b/superset/examples/random_time_series.py
@@ -37,7 +37,7 @@ def load_random_time_series_data(
"""Loading random time series data from a zip file in the repo"""
tbl_name = "random_time_series"
database = database_utils.get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/examples/sf_population_polygons.py
b/superset/examples/sf_population_polygons.py
index ba5905f58a..d97ffd3ae5 100644
--- a/superset/examples/sf_population_polygons.py
+++ b/superset/examples/sf_population_polygons.py
@@ -30,7 +30,7 @@ def load_sf_population_polygons(
) -> None:
tbl_name = "sf_population_polygons"
database = database_utils.get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/examples/supported_charts_dashboard.py
b/superset/examples/supported_charts_dashboard.py
index 6ca33a87c9..371f03d18b 100644
--- a/superset/examples/supported_charts_dashboard.py
+++ b/superset/examples/supported_charts_dashboard.py
@@ -439,7 +439,7 @@ def load_supported_charts_dashboard() -> None:
"""Loading a dashboard featuring supported charts"""
database = get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
tbl_name = "birth_names"
diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py
index 5e895fd78a..74ea2c43ad 100644
--- a/superset/examples/world_bank.py
+++ b/superset/examples/world_bank.py
@@ -48,7 +48,7 @@ def load_world_bank_health_n_pop( # pylint:
disable=too-many-locals, too-many-s
"""Loads the world bank health dataset, slices and a dashboard"""
tbl_name = "wb_health_population"
database = superset.utils.database.get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)
diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py
index bdfe1ae1e7..ea6ce118c9 100644
--- a/superset/extensions/metadb.py
+++ b/superset/extensions/metadb.py
@@ -315,7 +315,7 @@ class SupersetShillelaghAdapter(Adapter):
# store this callable for later whenever we need an engine
self.engine_context = partial(
- database.get_sqla_engine_with_context,
+ database.get_sqla_engine,
self.schema,
)
diff --git a/superset/models/core.py b/superset/models/core.py
index 92f6946f1e..bfd4c39593 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -382,7 +382,7 @@ class Database(
)
@contextmanager
- def get_sqla_engine_with_context(
+ def get_sqla_engine(
self,
schema: str | None = None,
nullpool: bool = True,
@@ -424,6 +424,11 @@ class Database(
sqlalchemy_uri=sqlalchemy_uri,
)
+ # The `get_sqla_engine_with_context` was renamed to `get_sqla_engine`, but we kept a
+ # reference to the old method to prevent breaking third-party applications.
+ # TODO (betodealmeida): Remove in 5.0
+ get_sqla_engine_with_context = get_sqla_engine
+
def _get_sqla_engine(
self,
schema: str | None = None,
@@ -531,7 +536,7 @@ class Database(
nullpool: bool = True,
source: utils.QuerySource | None = None,
) -> Connection:
- with self.get_sqla_engine_with_context(
+ with self.get_sqla_engine(
schema=schema, nullpool=nullpool, source=source
) as engine:
with closing(engine.raw_connection()) as conn:
@@ -574,7 +579,7 @@ class Database(
mutator: Callable[[pd.DataFrame], None] | None = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
- with self.get_sqla_engine_with_context(schema) as engine:
+ with self.get_sqla_engine(schema) as engine:
engine_url = engine.url
mutate_after_split = config["MUTATE_AFTER_SPLIT"]
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
@@ -636,7 +641,7 @@ class Database(
return df
def compile_sqla_query(self, qry: Select, schema: str | None = None) ->
str:
- with self.get_sqla_engine_with_context(schema) as engine:
+ with self.get_sqla_engine(schema) as engine:
sql = str(qry.compile(engine, compile_kwargs={"literal_binds":
True}))
# pylint: disable=protected-access
@@ -656,7 +661,7 @@ class Database(
cols: list[ResultSetColumnType] | None = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
- with self.get_sqla_engine_with_context(schema) as engine:
+ with self.get_sqla_engine(schema) as engine:
return self.db_engine_spec.select_star(
self,
table_name,
@@ -753,9 +758,7 @@ class Database(
def get_inspector_with_context(
self, ssh_tunnel: SSHTunnel | None = None
) -> Inspector:
- with self.get_sqla_engine_with_context(
- override_ssh_tunnel=ssh_tunnel
- ) as engine:
+ with self.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine:
yield sqla.inspect(engine)
@cache_util.memoized_func(
@@ -835,7 +838,7 @@ class Database(
def get_table(self, table_name: str, schema: str | None = None) -> Table:
extra = self.get_extra()
meta = MetaData(**extra.get("metadata_params", {}))
- with self.get_sqla_engine_with_context() as engine:
+ with self.get_sqla_engine() as engine:
return Table(
table_name,
meta,
@@ -939,11 +942,11 @@ class Database(
return self.perm # type: ignore
def has_table(self, table: Table) -> bool:
- with self.get_sqla_engine_with_context() as engine:
+ with self.get_sqla_engine() as engine:
return engine.has_table(table.table_name, table.schema or None)
def has_table_by_name(self, table_name: str, schema: str | None = None) ->
bool:
- with self.get_sqla_engine_with_context() as engine:
+ with self.get_sqla_engine() as engine:
return engine.has_table(table_name, schema)
@classmethod
@@ -962,7 +965,7 @@ class Database(
return view_name in view_names
def has_view(self, view_name: str, schema: str | None = None) -> bool:
- with self.get_sqla_engine_with_context(schema) as engine:
+ with self.get_sqla_engine(schema) as engine:
return engine.run_callable(
self._has_view, engine.dialect, view_name, schema
)
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index 0a0d789c7a..aa961a2ffa 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -217,7 +217,7 @@ class Dashboard(AuditMixinNullable, ImportExportMixin,
Model):
@property
def sqla_metadata(self) -> None:
# pylint: disable=no-member
- with self.get_sqla_engine_with_context() as engine:
+ with self.get_sqla_engine() as engine:
meta = MetaData(bind=engine)
meta.reflect()
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 4b22873903..ad90e664ba 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -1390,7 +1390,7 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
if self.fetch_values_predicate:
qry =
qry.where(self.get_fetch_values_predicate(template_processor=tp))
- with self.database.get_sqla_engine_with_context() as engine:
+ with self.database.get_sqla_engine() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
@@ -1992,7 +1992,7 @@ class ExploreMixin: # pylint:
disable=too-many-public-methods
and db_engine_spec.allows_hidden_cc_in_orderby
and col.name in [select_col.name for select_col in
select_exprs]
):
- with self.database.get_sqla_engine_with_context() as engine:
+ with self.database.get_sqla_engine() as engine:
quote = engine.dialect.identifier_preparer.quote
col = literal_column(quote(col.name))
direction = sa.asc if ascending else sa.desc
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index e34f7e2fde..e87ae9c5b7 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -644,7 +644,7 @@ def cancel_query(query: Query) -> bool:
if cancel_query_id is None:
return False
- with query.database.get_sqla_engine_with_context(
+ with query.database.get_sqla_engine(
query.schema, source=QuerySource.SQL_LAB
) as engine:
with closing(engine.raw_connection()) as conn:
diff --git a/superset/sql_validators/presto_db.py
b/superset/sql_validators/presto_db.py
index 4852f70ee4..8e7d8c7209 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -160,9 +160,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
logger.info("Validating %i statement(s)", len(statements))
# todo(hughhh): update this to use new database.get_raw_connection()
# this function keeps stalling CI
- with database.get_sqla_engine_with_context(
- schema, source=QuerySource.SQL_LAB
- ) as engine:
+ with database.get_sqla_engine(schema, source=QuerySource.SQL_LAB) as
engine:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
annotations: list[SQLValidationAnnotation] = []
diff --git a/superset/utils/core.py b/superset/utils/core.py
index de1034ddb0..988baed0af 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1170,7 +1170,7 @@ def get_example_default_schema() -> str | None:
Return the default schema of the examples database, if any.
"""
database = get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
return inspect(engine).default_schema_name
diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py
index 67bd9ad73e..fc082ecb45 100644
--- a/superset/utils/mock_data.py
+++ b/superset/utils/mock_data.py
@@ -184,7 +184,7 @@ def add_data(
database = get_example_database()
table_exists = database.has_table_by_name(table_name)
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
if columns is None:
if not table_exists:
raise Exception( # pylint: disable=broad-exception-raised
diff --git a/tests/conftest.py b/tests/conftest.py
index 9d13e58170..c659a85243 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -72,7 +72,7 @@ def example_db_provider() -> Callable[[], Database]:
@fixture(scope="session")
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
with app.app_context():
- with example_db_provider().get_sqla_engine_with_context() as engine:
+ with example_db_provider().get_sqla_engine() as engine:
return engine
diff --git a/tests/integration_tests/celery_tests.py
b/tests/integration_tests/celery_tests.py
index 5774d8920a..384e6674af 100644
--- a/tests/integration_tests/celery_tests.py
+++ b/tests/integration_tests/celery_tests.py
@@ -113,7 +113,7 @@ def drop_table_if_exists(table_name: str, table_type:
CtasMethod) -> None:
"""Drop table if it exists, works on any DB"""
sql = f"DROP {table_type} IF EXISTS {table_name}"
database = get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
engine.execute(sql)
diff --git a/tests/integration_tests/conftest.py
b/tests/integration_tests/conftest.py
index b90416587c..cc11c4df47 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -212,7 +212,7 @@ def setup_presto_if_needed():
if backend in {"presto", "hive"}:
database = get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
drop_from_schema(engine, CTAS_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}")
@@ -343,7 +343,7 @@ def physical_dataset():
example_database = get_example_database()
- with example_database.get_sqla_engine_with_context() as engine:
+ with example_database.get_sqla_engine() as engine:
quoter = get_identifier_quoter(engine.name)
# sqlite can only execute one statement at a time
engine.execute(
diff --git a/tests/integration_tests/csv_upload_tests.py
b/tests/integration_tests/csv_upload_tests.py
index 741f4c1bc9..85be02cff3 100644
--- a/tests/integration_tests/csv_upload_tests.py
+++ b/tests/integration_tests/csv_upload_tests.py
@@ -71,7 +71,7 @@ def _setup_csv_upload():
yield
upload_db = get_upload_db()
- with upload_db.get_sqla_engine_with_context() as engine:
+ with upload_db.get_sqla_engine() as engine:
engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}")
@@ -268,7 +268,7 @@ def test_import_csv_enforced_schema(mock_event_logger):
table=CSV_UPLOAD_TABLE_W_SCHEMA,
)
- with get_upload_db().get_sqla_engine_with_context() as engine:
+ with get_upload_db().get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * from {ADMIN_SCHEMA_NAME}.{CSV_UPLOAD_TABLE_W_SCHEMA}
ORDER BY b"
).fetchall()
@@ -294,7 +294,7 @@ def test_import_csv_enforced_schema(mock_event_logger):
assert success_msg in resp
# Clean up
- with get_upload_db().get_sqla_engine_with_context() as engine:
+ with get_upload_db().get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {full_table_name}")
@@ -380,7 +380,7 @@ def test_import_csv(mock_event_logger):
extra={"null_values": '["", "john"]', "if_exists": "replace"},
)
# make sure that john and empty string are replaced with None
- with test_db.get_sqla_engine_with_context() as engine:
+ with test_db.get_sqla_engine() as engine:
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE} ORDER BY
c").fetchall()
assert data == [(None, 1, "x"), ("paul", 2, None)]
# default null values
@@ -390,7 +390,7 @@ def test_import_csv(mock_event_logger):
assert data == [("john", 1, "x"), ("paul", 2, None)]
# cleanup
- with get_upload_db().get_sqla_engine_with_context() as engine:
+ with get_upload_db().get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {full_table_name}")
# with dtype
@@ -403,12 +403,12 @@ def test_import_csv(mock_event_logger):
# you can change the type to something compatible, like an object to string
# or an int to a float
# file upload should work as normal
- with test_db.get_sqla_engine_with_context() as engine:
+ with test_db.get_sqla_engine() as engine:
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE} ORDER BY
b").fetchall()
assert data == [("john", 1), ("paul", 2)]
# cleanup
- with get_upload_db().get_sqla_engine_with_context() as engine:
+ with get_upload_db().get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {full_table_name}")
# with dtype - wrong type
@@ -475,7 +475,7 @@ def test_import_excel(mock_event_logger):
table=EXCEL_UPLOAD_TABLE,
)
- with test_db.get_sqla_engine_with_context() as engine:
+ with test_db.get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * from {EXCEL_UPLOAD_TABLE} ORDER BY b"
).fetchall()
@@ -541,7 +541,7 @@ def test_import_parquet(mock_event_logger):
)
assert success_msg_f1 in resp
- with test_db.get_sqla_engine_with_context() as engine:
+ with test_db.get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * from {PARQUET_UPLOAD_TABLE} ORDER BY b"
).fetchall()
@@ -554,7 +554,7 @@ def test_import_parquet(mock_event_logger):
success_msg_f2 = f"Columnar file {escaped_parquet(ZIP_FILENAME)} uploaded
to table {escaped_double_quotes(full_table_name)}"
assert success_msg_f2 in resp
- with test_db.get_sqla_engine_with_context() as engine:
+ with test_db.get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * from {PARQUET_UPLOAD_TABLE} ORDER BY b"
).fetchall()
diff --git a/tests/integration_tests/databases/api_tests.py
b/tests/integration_tests/databases/api_tests.py
index b944242259..e25a74e40c 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -895,7 +895,7 @@ class TestDatabaseApi(SupersetTestCase):
if database.backend == "mysql":
query = query.replace('"', "`")
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
engine.execute(query)
self.login(ADMIN_USERNAME)
diff --git a/tests/integration_tests/datasets/api_tests.py
b/tests/integration_tests/datasets/api_tests.py
index 939c03a4e4..3597bcdb08 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -718,7 +718,7 @@ class TestDatasetApi(SupersetTestCase):
return
example_db = get_example_database()
- with example_db.get_sqla_engine_with_context() as engine:
+ with example_db.get_sqla_engine() as engine:
engine.execute(
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as
two"
)
@@ -739,7 +739,7 @@ class TestDatasetApi(SupersetTestCase):
uri = f'api/v1/dataset/{data.get("id")}'
rv = self.client.delete(uri)
assert rv.status_code == 200
- with example_db.get_sqla_engine_with_context() as engine:
+ with example_db.get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names")
def test_create_dataset_validate_database(self):
@@ -800,7 +800,7 @@ class TestDatasetApi(SupersetTestCase):
mock_get_table.return_value = None
example_db = get_example_database()
- with example_db.get_sqla_engine_with_context() as engine:
+ with example_db.get_sqla_engine() as engine:
engine = engine
dialect = engine.dialect
@@ -2389,7 +2389,7 @@ class TestDatasetApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
examples_db = get_example_database()
- with examples_db.get_sqla_engine_with_context() as engine:
+ with examples_db.get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_api")
engine.execute("CREATE TABLE test_create_sqla_table_api AS SELECT
2 as col")
@@ -2415,7 +2415,7 @@ class TestDatasetApi(SupersetTestCase):
db.session.delete(table)
db.session.commit()
- with examples_db.get_sqla_engine_with_context() as engine:
+ with examples_db.get_sqla_engine() as engine:
engine.execute("DROP TABLE test_create_sqla_table_api")
@pytest.mark.usefixtures(
diff --git a/tests/integration_tests/datasets/commands_tests.py
b/tests/integration_tests/datasets/commands_tests.py
index cdf3cb6d97..8063466933 100644
--- a/tests/integration_tests/datasets/commands_tests.py
+++ b/tests/integration_tests/datasets/commands_tests.py
@@ -563,7 +563,7 @@ class TestCreateDatasetCommand(SupersetTestCase):
def test_create_dataset_command(self):
examples_db = get_example_database()
- with examples_db.get_sqla_engine_with_context() as engine:
+ with examples_db.get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS test_create_dataset_command")
engine.execute(
"CREATE TABLE test_create_dataset_command AS SELECT 2 as col"
@@ -585,7 +585,7 @@ class TestCreateDatasetCommand(SupersetTestCase):
self.assertEqual([owner.username for owner in table.owners],
["admin"])
db.session.delete(table)
- with examples_db.get_sqla_engine_with_context() as engine:
+ with examples_db.get_sqla_engine() as engine:
engine.execute("DROP TABLE test_create_dataset_command")
db.session.commit()
diff --git a/tests/integration_tests/datasource_tests.py
b/tests/integration_tests/datasource_tests.py
index 4b02bb59a9..34da3df35a 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -47,7 +47,7 @@ def create_test_table_context(database: Database):
schema = get_example_default_schema()
full_table_name = f"{schema}.test_table" if schema else "test_table"
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
engine.execute(
f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as
first, 2 as second"
)
@@ -56,7 +56,7 @@ def create_test_table_context(database: Database):
yield db.session
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
engine.execute(f"DROP TABLE {full_table_name}")
diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py
b/tests/integration_tests/db_engine_specs/hive_tests.py
index 374d99c02e..d4b2e14d58 100644
--- a/tests/integration_tests/db_engine_specs/hive_tests.py
+++ b/tests/integration_tests/db_engine_specs/hive_tests.py
@@ -193,7 +193,7 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3,
mock_g):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
-
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute
= (
+ mock_database.get_sqla_engine.return_value.__enter__.return_value.execute
= (
mock_execute
)
table_name = "foobar"
@@ -220,7 +220,7 @@ def
test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
-
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute
= (
+ mock_database.get_sqla_engine.return_value.__enter__.return_value.execute
= (
mock_execute
)
table_name = "foobar"
diff --git a/tests/integration_tests/fixtures/energy_dashboard.py
b/tests/integration_tests/fixtures/energy_dashboard.py
index 9687fb4aff..5d938e0541 100644
--- a/tests/integration_tests/fixtures/energy_dashboard.py
+++ b/tests/integration_tests/fixtures/energy_dashboard.py
@@ -38,7 +38,7 @@ ENERGY_USAGE_TBL_NAME = "energy_usage"
def load_energy_table_data():
with app.app_context():
database = get_example_database()
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
df = _get_dataframe()
df.to_sql(
ENERGY_USAGE_TBL_NAME,
@@ -52,7 +52,7 @@ def load_energy_table_data():
)
yield
with app.app_context():
- with get_example_database().get_sqla_engine_with_context() as engine:
+ with get_example_database().get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS energy_usage")
diff --git a/tests/integration_tests/fixtures/unicode_dashboard.py
b/tests/integration_tests/fixtures/unicode_dashboard.py
index 78178bcde7..e68e8f0799 100644
--- a/tests/integration_tests/fixtures/unicode_dashboard.py
+++ b/tests/integration_tests/fixtures/unicode_dashboard.py
@@ -37,7 +37,7 @@ UNICODE_TBL_NAME = "unicode_test"
@pytest.fixture(scope="session")
def load_unicode_data():
with app.app_context():
- with get_example_database().get_sqla_engine_with_context() as engine:
+ with get_example_database().get_sqla_engine() as engine:
_get_dataframe().to_sql(
UNICODE_TBL_NAME,
engine,
@@ -51,7 +51,7 @@ def load_unicode_data():
yield
with app.app_context():
- with get_example_database().get_sqla_engine_with_context() as engine:
+ with get_example_database().get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS unicode_test")
diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py
b/tests/integration_tests/fixtures/world_bank_dashboard.py
index a53cd76aa9..6c3d29eb43 100644
--- a/tests/integration_tests/fixtures/world_bank_dashboard.py
+++ b/tests/integration_tests/fixtures/world_bank_dashboard.py
@@ -50,7 +50,7 @@ def load_world_bank_data():
"country_name": String(255),
"region": String(255),
}
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
_get_dataframe(database).to_sql(
WB_HEALTH_POPULATION,
engine,
@@ -64,7 +64,7 @@ def load_world_bank_data():
yield
with app.app_context():
- with get_example_database().get_sqla_engine_with_context() as engine:
+ with get_example_database().get_sqla_engine() as engine:
engine.execute("DROP TABLE IF EXISTS wb_health_population")
diff --git a/tests/integration_tests/model_tests.py
b/tests/integration_tests/model_tests.py
index 2a4c33a281..b9cbd9332e 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -56,22 +56,22 @@ class TestDatabaseModel(SupersetTestCase):
sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive/default"
model = Database(database_name="test_database",
sqlalchemy_uri=sqlalchemy_uri)
- with model.get_sqla_engine_with_context() as engine:
+ with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("hive/default", db)
- with model.get_sqla_engine_with_context(schema="core_db") as engine:
+ with model.get_sqla_engine(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("hive/core_db", db)
sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive"
model = Database(database_name="test_database",
sqlalchemy_uri=sqlalchemy_uri)
- with model.get_sqla_engine_with_context() as engine:
+ with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("hive", db)
- with model.get_sqla_engine_with_context(schema="core_db") as engine:
+ with model.get_sqla_engine(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("hive/core_db", db)
@@ -79,11 +79,11 @@ class TestDatabaseModel(SupersetTestCase):
sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod"
model = Database(database_name="test_database",
sqlalchemy_uri=sqlalchemy_uri)
- with model.get_sqla_engine_with_context() as engine:
+ with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("prod", db)
- with model.get_sqla_engine_with_context(schema="foo") as engine:
+ with model.get_sqla_engine(schema="foo") as engine:
db = make_url(engine.url).database
self.assertEqual("prod", db)
@@ -97,11 +97,11 @@ class TestDatabaseModel(SupersetTestCase):
sqlalchemy_uri = "hive://[email protected]:10000/default?auth=NOSASL"
model = Database(database_name="test_database",
sqlalchemy_uri=sqlalchemy_uri)
- with model.get_sqla_engine_with_context() as engine:
+ with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("default", db)
- with model.get_sqla_engine_with_context(schema="core_db") as engine:
+ with model.get_sqla_engine(schema="core_db") as engine:
db = make_url(engine.url).database
self.assertEqual("core_db", db)
@@ -112,11 +112,11 @@ class TestDatabaseModel(SupersetTestCase):
sqlalchemy_uri = "mysql://root@localhost/superset"
model = Database(database_name="test_database",
sqlalchemy_uri=sqlalchemy_uri)
- with model.get_sqla_engine_with_context() as engine:
+ with model.get_sqla_engine() as engine:
db = make_url(engine.url).database
self.assertEqual("superset", db)
- with model.get_sqla_engine_with_context(schema="staging") as engine:
+ with model.get_sqla_engine(schema="staging") as engine:
db = make_url(engine.url).database
self.assertEqual("staging", db)
@@ -130,12 +130,12 @@ class TestDatabaseModel(SupersetTestCase):
with override_user(example_user):
model.impersonate_user = True
- with model.get_sqla_engine_with_context() as engine:
+ with model.get_sqla_engine() as engine:
username = make_url(engine.url).username
self.assertEqual(example_user.username, username)
model.impersonate_user = False
- with model.get_sqla_engine_with_context() as engine:
+ with model.get_sqla_engine() as engine:
username = make_url(engine.url).username
self.assertNotEqual(example_user.username, username)
@@ -295,7 +295,7 @@ class TestDatabaseModel(SupersetTestCase):
db = get_example_database()
table_name = "energy_usage"
sql = db.select_star(table_name, show_cols=False,
latest_partition=False)
- with db.get_sqla_engine_with_context() as engine:
+ with db.get_sqla_engine() as engine:
quote = engine.dialect.identifier_preparer.quote_identifier
source = quote(table_name) if db.backend in {"presto", "hive"} else
table_name
diff --git a/tests/integration_tests/reports/commands_tests.py
b/tests/integration_tests/reports/commands_tests.py
index 0c353d1fab..9e92841a62 100644
--- a/tests/integration_tests/reports/commands_tests.py
+++ b/tests/integration_tests/reports/commands_tests.py
@@ -150,13 +150,13 @@ def assert_log(state: str, error_message: Optional[str] =
None):
@contextmanager
def create_test_table_context(database: Database):
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
engine.execute("CREATE TABLE test_table AS SELECT 1 as first, 2 as
second")
engine.execute("INSERT INTO test_table (first, second) VALUES (1, 2)")
engine.execute("INSERT INTO test_table (first, second) VALUES (3, 4)")
yield db.session
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
engine.execute("DROP TABLE test_table")
diff --git a/tests/integration_tests/sql_validator_tests.py
b/tests/integration_tests/sql_validator_tests.py
index ae8b160ae1..850cc9ada3 100644
--- a/tests/integration_tests/sql_validator_tests.py
+++ b/tests/integration_tests/sql_validator_tests.py
@@ -38,7 +38,7 @@ class TestPrestoValidator(SupersetTestCase):
self.validator = PrestoDBSQLValidator
self.database = MagicMock()
self.database_engine = (
-
self.database.get_sqla_engine_with_context.return_value.__enter__.return_value
+ self.database.get_sqla_engine.return_value.__enter__.return_value
)
self.database_conn = self.database_engine.raw_connection.return_value
self.database_cursor = self.database_conn.cursor.return_value
diff --git a/tests/integration_tests/sqla_models_tests.py
b/tests/integration_tests/sqla_models_tests.py
index 6cae6f6a14..0359317e3a 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -313,7 +313,7 @@ class TestDatabaseModel(SupersetTestCase):
query = table.database.compile_sqla_query(sqla_query.sqla_query)
database = table.database
- with database.get_sqla_engine_with_context() as engine:
+ with database.get_sqla_engine() as engine:
quote = engine.dialect.identifier_preparer.quote_identifier
for metric_label in {"metric using jinja macro", "same but different"}:
diff --git a/tests/integration_tests/sqllab_tests.py
b/tests/integration_tests/sqllab_tests.py
index 8f4c42ee28..ccc76a039a 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -212,7 +212,7 @@ class TestSqlLab(SupersetTestCase):
# assertions
db.session.commit()
examples_db = get_example_database()
- with examples_db.get_sqla_engine_with_context() as engine:
+ with examples_db.get_sqla_engine() as engine:
data = engine.execute(
f"SELECT * FROM admin_database.{tmp_table_name}"
).fetchall()
@@ -296,7 +296,7 @@ class TestSqlLab(SupersetTestCase):
"SchemaUser", ["SchemaPermission", "Gamma", "sql_lab"]
)
- with examples_db.get_sqla_engine_with_context() as engine:
+ with examples_db.get_sqla_engine() as engine:
engine.execute(
f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS
SELECT 1 as c1, 2 as c2"
)
@@ -325,7 +325,7 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(1, len(data["data"]))
db.session.query(Query).delete()
- with get_example_database().get_sqla_engine_with_context() as engine:
+ with get_example_database().get_sqla_engine() as engine:
engine.execute(f"DROP TABLE IF EXISTS
{CTAS_SCHEMA_NAME}.test_table")
db.session.commit()
diff --git a/tests/unit_tests/extensions/test_sqlalchemy.py
b/tests/unit_tests/extensions/test_sqlalchemy.py
index df36dc44ef..24c849f55a 100644
--- a/tests/unit_tests/extensions/test_sqlalchemy.py
+++ b/tests/unit_tests/extensions/test_sqlalchemy.py
@@ -59,7 +59,7 @@ def database1(session: Session) -> Iterator["Database"]:
@pytest.fixture
def table1(session: Session, database1: "Database") -> Iterator[None]:
- with database1.get_sqla_engine_with_context() as engine:
+ with database1.get_sqla_engine() as engine:
conn = engine.connect()
conn.execute("CREATE TABLE table1 (a INTEGER NOT NULL PRIMARY KEY, b
INTEGER)")
conn.execute("INSERT INTO table1 (a, b) VALUES (1, 10), (2, 20)")
@@ -92,7 +92,7 @@ def database2(session: Session) -> Iterator["Database"]:
@pytest.fixture
def table2(session: Session, database2: "Database") -> Iterator[None]:
- with database2.get_sqla_engine_with_context() as engine:
+ with database2.get_sqla_engine() as engine:
conn = engine.connect()
conn.execute("CREATE TABLE table2 (a INTEGER NOT NULL PRIMARY KEY, b
TEXT)")
conn.execute("INSERT INTO table2 (a, b) VALUES (1, 'ten'), (2,
'twenty')")
diff --git a/tests/unit_tests/models/core_test.py
b/tests/unit_tests/models/core_test.py
index 5d6c1fcbcc..beefd3ea3c 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -220,7 +220,7 @@ def test_get_prequeries(mocker: MockFixture) -> None:
"""
mocker.patch.object(
Database,
- "get_sqla_engine_with_context",
+ "get_sqla_engine",
)
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
diff --git a/tests/unit_tests/models/helpers_test.py
b/tests/unit_tests/models/helpers_test.py
index 6d9597c0d6..e3c59cbcb0 100644
--- a/tests/unit_tests/models/helpers_test.py
+++ b/tests/unit_tests/models/helpers_test.py
@@ -54,13 +54,13 @@ def test_values_for_column(mocker: MockerFixture, session:
Session) -> None:
# since we're using an in-memory SQLite database, make sure we always
# return the same engine where the table was created
@contextmanager
- def mock_get_sqla_engine_with_context():
+ def mock_get_sqla_engine():
yield engine
mocker.patch.object(
database,
- "get_sqla_engine_with_context",
- new=mock_get_sqla_engine_with_context,
+ "get_sqla_engine",
+ new=mock_get_sqla_engine,
)
table = SqlaTable(