This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch ref-get-sqla-engine-2
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 158da8d2008fc26eb191a600d538d6796caffc3a
Author: hughhhh <[email protected]>
AuthorDate: Wed Oct 26 14:03:32 2022 -0400

    init
---
 superset/connectors/sqla/models.py               |  12 +--
 superset/connectors/sqla/utils.py                |  39 ++++---
 superset/databases/commands/test_connection.py   |  56 +++++-----
 superset/databases/commands/validate.py          |  46 ++++-----
 superset/datasets/commands/importers/v1/utils.py |   3 +-
 superset/db_engine_specs/base.py                 |   5 +-
 superset/sql_lab.py                              | 126 ++++++++++++-----------
 7 files changed, 154 insertions(+), 133 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 4855dd1af3..98aac4906f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -958,13 +958,13 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         if self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate())
 
-        engine = self.database.get_sqla_engine()
-        sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
-        sql = self._apply_cte(sql, cte)
-        sql = self.mutate_query_from_config(sql)
+        with self.database.get_sqla_engine_with_context() as engine:
+            sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
+            sql = self._apply_cte(sql, cte)
+            sql = self.mutate_query_from_config(sql)
 
-        df = pd.read_sql_query(sql=sql, con=engine)
-        return df[column_name].to_list()
+            df = pd.read_sql_query(sql=sql, con=engine)
+            return df[column_name].to_list()
 
     def mutate_query_from_config(self, sql: str) -> str:
         """Apply config's SQL_QUERY_MUTATOR
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 8151bfd44b..05cf8cea13 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -112,7 +112,6 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
         )
 
     db_engine_spec = dataset.database.db_engine_spec
-    engine = dataset.database.get_sqla_engine(schema=dataset.schema)
     sql = dataset.get_template_processor().process_template(
         dataset.sql, **dataset.template_params_dict
     )
@@ -137,13 +136,18 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
     # TODO(villebro): refactor to use same code that's used by
     #  sql_lab.py:execute_sql_statements
     try:
-        with closing(engine.raw_connection()) as conn:
-            cursor = conn.cursor()
-            query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
-            db_engine_spec.execute(cursor, query)
-            result = db_engine_spec.fetch_data(cursor, limit=1)
-            result_set = SupersetResultSet(result, cursor.description, 
db_engine_spec)
-            cols = result_set.columns
+        with dataset.database.get_sqla_engine_with_context(
+            schema=dataset.schema
+        ) as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                query = dataset.database.apply_limit_to_sql(statements[0], 
limit=1)
+                db_engine_spec.execute(cursor, query)
+                result = db_engine_spec.fetch_data(cursor, limit=1)
+                result_set = SupersetResultSet(
+                    result, cursor.description, db_engine_spec
+                )
+                cols = result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
     return cols
@@ -155,14 +159,17 @@ def get_columns_description(
 ) -> List[ResultSetColumnType]:
     db_engine_spec = database.db_engine_spec
     try:
-        with closing(database.get_sqla_engine().raw_connection()) as conn:
-            cursor = conn.cursor()
-            query = database.apply_limit_to_sql(query, limit=1)
-            cursor.execute(query)
-            db_engine_spec.execute(cursor, query)
-            result = db_engine_spec.fetch_data(cursor, limit=1)
-            result_set = SupersetResultSet(result, cursor.description, 
db_engine_spec)
-            return result_set.columns
+        with database.get_sqla_engine_with_context() as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                query = database.apply_limit_to_sql(query, limit=1)
+                cursor.execute(query)
+                db_engine_spec.execute(cursor, query)
+                result = db_engine_spec.fetch_data(cursor, limit=1)
+                result_set = SupersetResultSet(
+                    result, cursor.description, db_engine_spec
+                )
+                return result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
 
diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index d7f7d90e49..2865174ff8 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -86,7 +86,6 @@ class TestConnectionDatabaseCommand(BaseCommand):
             database.set_sqlalchemy_uri(uri)
             database.db_engine_spec.mutate_db_for_connection_test(database)
 
-            engine = database.get_sqla_engine()
             event_logger.log_with_context(
                 action="test_connection_attempt",
                 engine=database.db_engine_spec.__name__,
@@ -96,31 +95,36 @@ class TestConnectionDatabaseCommand(BaseCommand):
                 with closing(engine.raw_connection()) as conn:
                     return engine.dialect.do_ping(conn)
 
-            try:
-                alive = func_timeout(
-                    
int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()),
-                    ping,
-                    args=(engine,),
-                )
-            except (sqlite3.ProgrammingError, RuntimeError):
-                # SQLite can't run on a separate thread, so ``func_timeout`` 
fails
-                # RuntimeError catches the equivalent error from duckdb.
-                alive = engine.dialect.do_ping(engine)
-            except FunctionTimedOut as ex:
-                raise SupersetTimeoutException(
-                    error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
-                    message=(
-                        "Please check your connection details and database 
settings, "
-                        "and ensure that your database is accepting 
connections, "
-                        "then try connecting again."
-                    ),
-                    level=ErrorLevel.ERROR,
-                    extra={"sqlalchemy_uri": database.sqlalchemy_uri},
-                ) from ex
-            except Exception:  # pylint: disable=broad-except
-                alive = False
-            if not alive:
-                raise DBAPIError(None, None, None)
+            with database.get_sqla_engine_with_context() as engine:
+                try:
+                    alive = func_timeout(
+                        int(
+                            app.config[
+                                "TEST_DATABASE_CONNECTION_TIMEOUT"
+                            ].total_seconds()
+                        ),
+                        ping,
+                        args=(engine,),
+                    )
+                except (sqlite3.ProgrammingError, RuntimeError):
+                    # SQLite can't run on a separate thread, so 
``func_timeout`` fails
+                    # RuntimeError catches the equivalent error from duckdb.
+                    alive = engine.dialect.do_ping(engine)
+                except FunctionTimedOut as ex:
+                    raise SupersetTimeoutException(
+                        
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
+                        message=(
+                            "Please check your connection details and database 
settings, "
+                            "and ensure that your database is accepting 
connections, "
+                            "then try connecting again."
+                        ),
+                        level=ErrorLevel.ERROR,
+                        extra={"sqlalchemy_uri": database.sqlalchemy_uri},
+                    ) from ex
+                except Exception:  # pylint: disable=broad-except
+                    alive = False
+                if not alive:
+                    raise DBAPIError(None, None, None)
 
             # Log succesful connection test with engine
             event_logger.log_with_context(
diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py
index a8956257fa..a92fb79f83 100644
--- a/superset/databases/commands/validate.py
+++ b/superset/databases/commands/validate.py
@@ -101,30 +101,30 @@ class ValidateDatabaseParametersCommand(BaseCommand):
         database.set_sqlalchemy_uri(sqlalchemy_uri)
         database.db_engine_spec.mutate_db_for_connection_test(database)
 
-        engine = database.get_sqla_engine()
-        try:
-            with closing(engine.raw_connection()) as conn:
-                alive = engine.dialect.do_ping(conn)
-        except Exception as ex:
-            url = make_url_safe(sqlalchemy_uri)
-            context = {
-                "hostname": url.host,
-                "password": url.password,
-                "port": url.port,
-                "username": url.username,
-                "database": url.database,
-            }
-            errors = database.db_engine_spec.extract_errors(ex, context)
-            raise DatabaseTestConnectionFailedError(errors) from ex
+        with database.get_sqla_engine_with_context() as engine:
+            try:
+                with closing(engine.raw_connection()) as conn:
+                    alive = engine.dialect.do_ping(conn)
+            except Exception as ex:
+                url = make_url_safe(sqlalchemy_uri)
+                context = {
+                    "hostname": url.host,
+                    "password": url.password,
+                    "port": url.port,
+                    "username": url.username,
+                    "database": url.database,
+                }
+                errors = database.db_engine_spec.extract_errors(ex, context)
+                raise DatabaseTestConnectionFailedError(errors) from ex
 
-        if not alive:
-            raise DatabaseOfflineError(
-                SupersetError(
-                    message=__("Database is offline."),
-                    error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
-                    level=ErrorLevel.ERROR,
-                ),
-            )
+            if not alive:
+                raise DatabaseOfflineError(
+                    SupersetError(
+                        message=__("Database is offline."),
+                        error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
+                        level=ErrorLevel.ERROR,
+                    ),
+                )
 
     def validate(self) -> None:
         database_id = self._properties.get("id")
diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py
index ba2b7df261..7d3998b3bb 100644
--- a/superset/datasets/commands/importers/v1/utils.py
+++ b/superset/datasets/commands/importers/v1/utils.py
@@ -168,7 +168,8 @@ def load_data(
         connection = session.connection()
     else:
         logger.warning("Loading data outside the import transaction")
-        connection = database.get_sqla_engine()
+        with database.get_sqla_engine_with_context() as engine:
+            connection = engine
 
     df.to_sql(
         dataset.table_name,
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index dabed0c7ae..9dd7594dc7 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -472,7 +472,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         schema: Optional[str] = None,
         source: Optional[utils.QuerySource] = None,
     ) -> Engine:
-        return database.get_sqla_engine(schema=schema, source=source)
+        with database.get_sqla_engine_with_context(
+            schema=schema, source=source
+        ) as engine:
+            return engine
 
     @classmethod
     def get_timestamp_expr(
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 96afc7f51e..6d9903c8f0 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -463,61 +463,66 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
             )
         )
 
-    engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
-    # Sharing a single connection and cursor across the
-    # execution of all statements (if many)
-    with closing(engine.raw_connection()) as conn:
-        # closing the connection closes the cursor as well
-        cursor = conn.cursor()
-        cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
-        if cancel_query_id is not None:
-            query.set_extra_json_key(cancel_query_key, cancel_query_id)
-            session.commit()
-        statement_count = len(statements)
-        for i, statement in enumerate(statements):
-            # Check if stopped
-            session.refresh(query)
-            if query.status == QueryStatus.STOPPED:
-                payload.update({"status": query.status})
-                return payload
-
-            # For CTAS we create the table only on the last statement
-            apply_ctas = query.select_as_cta and (
-                query.ctas_method == CtasMethod.VIEW
-                or (query.ctas_method == CtasMethod.TABLE and i == 
len(statements) - 1)
-            )
-
-            # Run statement
-            msg = f"Running statement {i+1} out of {statement_count}"
-            logger.info("Query %s: %s", str(query_id), msg)
-            query.set_extra_json_key("progress", msg)
-            session.commit()
-            try:
-                result_set = execute_sql_statement(
-                    statement,
-                    query,
-                    session,
-                    cursor,
-                    log_params,
-                    apply_ctas,
-                )
-            except SqlLabQueryStoppedException:
-                payload.update({"status": QueryStatus.STOPPED})
-                return payload
-            except Exception as ex:  # pylint: disable=broad-except
-                msg = str(ex)
-                prefix_message = (
-                    f"[Statement {i+1} out of {statement_count}]"
-                    if statement_count > 1
-                    else ""
+    with database.get_sqla_engine_with_context(
+        query.schema, source=QuerySource.SQL_LAB
+    ) as engine:
+        # Sharing a single connection and cursor across the
+        # execution of all statements (if many)
+        with closing(engine.raw_connection()) as conn:
+            # closing the connection closes the cursor as well
+            cursor = conn.cursor()
+            cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
+            if cancel_query_id is not None:
+                query.set_extra_json_key(cancel_query_key, cancel_query_id)
+                session.commit()
+            statement_count = len(statements)
+            for i, statement in enumerate(statements):
+                # Check if stopped
+                session.refresh(query)
+                if query.status == QueryStatus.STOPPED:
+                    payload.update({"status": query.status})
+                    return payload
+
+                # For CTAS we create the table only on the last statement
+                apply_ctas = query.select_as_cta and (
+                    query.ctas_method == CtasMethod.VIEW
+                    or (
+                        query.ctas_method == CtasMethod.TABLE
+                        and i == len(statements) - 1
+                    )
                 )
-                payload = handle_query_error(
-                    ex, query, session, payload, prefix_message
-                )
-                return payload
 
-        # Commit the connection so CTA queries will create the table.
-        conn.commit()
+                # Run statement
+                msg = f"Running statement {i+1} out of {statement_count}"
+                logger.info("Query %s: %s", str(query_id), msg)
+                query.set_extra_json_key("progress", msg)
+                session.commit()
+                try:
+                    result_set = execute_sql_statement(
+                        statement,
+                        query,
+                        session,
+                        cursor,
+                        log_params,
+                        apply_ctas,
+                    )
+                except SqlLabQueryStoppedException:
+                    payload.update({"status": QueryStatus.STOPPED})
+                    return payload
+                except Exception as ex:  # pylint: disable=broad-except
+                    msg = str(ex)
+                    prefix_message = (
+                        f"[Statement {i+1} out of {statement_count}]"
+                        if statement_count > 1
+                        else ""
+                    )
+                    payload = handle_query_error(
+                        ex, query, session, payload, prefix_message
+                    )
+                    return payload
+
+            # Commit the connection so CTA queries will create the table.
+            conn.commit()
 
     # Success, updating the query entry in database
     query.rows = result_set.size
@@ -622,10 +627,11 @@ def cancel_query(query: Query) -> bool:
     if cancel_query_id is None:
         return False
 
-    engine = query.database.get_sqla_engine(query.schema, 
source=QuerySource.SQL_LAB)
-
-    with closing(engine.raw_connection()) as conn:
-        with closing(conn.cursor()) as cursor:
-            return query.database.db_engine_spec.cancel_query(
-                cursor, query, cancel_query_id
-            )
+    with query.database.get_sqla_engine_with_context(
+        query.schema, source=QuerySource.SQL_LAB
+    ) as engine:
+        with closing(engine.raw_connection()) as conn:
+            with closing(conn.cursor()) as cursor:
+                return query.database.db_engine_spec.cancel_query(
+                    cursor, query, cancel_query_id
+                )

Reply via email to