This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch ref-get-sqla-engine-2
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 158da8d2008fc26eb191a600d538d6796caffc3a
Author: hughhhh <[email protected]>
AuthorDate: Wed Oct 26 14:03:32 2022 -0400

    init
---
 superset/connectors/sqla/models.py               |  12 +--
 superset/connectors/sqla/utils.py                |  39 ++++---
 superset/databases/commands/test_connection.py   |  56 +++++-----
 superset/databases/commands/validate.py          |  46 ++++-----
 superset/datasets/commands/importers/v1/utils.py |   3 +-
 superset/db_engine_specs/base.py                 |   5 +-
 superset/sql_lab.py                              | 126 ++++++++++++-----------
 7 files changed, 154 insertions(+), 133 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 4855dd1af3..98aac4906f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -958,13 +958,13 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         if self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate())
 
-        engine = self.database.get_sqla_engine()
-        sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
-        sql = self._apply_cte(sql, cte)
-        sql = self.mutate_query_from_config(sql)
+        with self.database.get_sqla_engine_with_context() as engine:
+            sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
+            sql = self._apply_cte(sql, cte)
+            sql = self.mutate_query_from_config(sql)
 
-        df = pd.read_sql_query(sql=sql, con=engine)
-        return df[column_name].to_list()
+            df = pd.read_sql_query(sql=sql, con=engine)
+            return df[column_name].to_list()
 
     def mutate_query_from_config(self, sql: str) -> str:
         """Apply config's SQL_QUERY_MUTATOR
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 8151bfd44b..05cf8cea13 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -112,7 +112,6 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
         )
 
     db_engine_spec = dataset.database.db_engine_spec
-    engine = dataset.database.get_sqla_engine(schema=dataset.schema)
     sql = dataset.get_template_processor().process_template(
         dataset.sql, **dataset.template_params_dict
     )
@@ -137,13 +136,18 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
     # TODO(villebro): refactor to use same code that's used by
     # sql_lab.py:execute_sql_statements
     try:
-        with closing(engine.raw_connection()) as conn:
-            cursor = conn.cursor()
-            query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
-            db_engine_spec.execute(cursor, query)
-            result = db_engine_spec.fetch_data(cursor, limit=1)
-            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
-            cols = result_set.columns
+        with dataset.database.get_sqla_engine_with_context(
+            schema=dataset.schema
+        ) as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
+                db_engine_spec.execute(cursor, query)
+                result = db_engine_spec.fetch_data(cursor, limit=1)
+                result_set = SupersetResultSet(
+                    result, cursor.description, db_engine_spec
+                )
+                cols = result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
     return cols
@@ -155,14 +159,17 @@ def get_columns_description(
 ) -> List[ResultSetColumnType]:
     db_engine_spec = database.db_engine_spec
     try:
-        with closing(database.get_sqla_engine().raw_connection()) as conn:
-            cursor = conn.cursor()
-            query = database.apply_limit_to_sql(query, limit=1)
-            cursor.execute(query)
-            db_engine_spec.execute(cursor, query)
-            result = db_engine_spec.fetch_data(cursor, limit=1)
-            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
-            return result_set.columns
+        with database.get_sqla_engine_with_context() as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                query = database.apply_limit_to_sql(query, limit=1)
+                cursor.execute(query)
+                db_engine_spec.execute(cursor, query)
+                result = db_engine_spec.fetch_data(cursor, limit=1)
+                result_set = SupersetResultSet(
+                    result, cursor.description, db_engine_spec
+                )
+                return result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index d7f7d90e49..2865174ff8 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -86,7 +86,6 @@ class TestConnectionDatabaseCommand(BaseCommand):
             database.set_sqlalchemy_uri(uri)
             database.db_engine_spec.mutate_db_for_connection_test(database)
 
-            engine = database.get_sqla_engine()
             event_logger.log_with_context(
                 action="test_connection_attempt",
                 engine=database.db_engine_spec.__name__,
@@ -96,31 +95,36 @@ class TestConnectionDatabaseCommand(BaseCommand):
                 with closing(engine.raw_connection()) as conn:
                     return engine.dialect.do_ping(conn)
 
-            try:
-                alive = func_timeout(
-                    int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()),
-                    ping,
-                    args=(engine,),
-                )
-            except (sqlite3.ProgrammingError, RuntimeError):
-                # SQLite can't run on a separate thread, so ``func_timeout`` fails
-                # RuntimeError catches the equivalent error from duckdb.
-                alive = engine.dialect.do_ping(engine)
-            except FunctionTimedOut as ex:
-                raise SupersetTimeoutException(
-                    error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
-                    message=(
-                        "Please check your connection details and database settings, "
-                        "and ensure that your database is accepting connections, "
-                        "then try connecting again."
-                    ),
-                    level=ErrorLevel.ERROR,
-                    extra={"sqlalchemy_uri": database.sqlalchemy_uri},
-                ) from ex
-            except Exception:  # pylint: disable=broad-except
-                alive = False
-            if not alive:
-                raise DBAPIError(None, None, None)
+            with database.get_sqla_engine_with_context() as engine:
+                try:
+                    alive = func_timeout(
+                        int(
+                            app.config[
+                                "TEST_DATABASE_CONNECTION_TIMEOUT"
+                            ].total_seconds()
+                        ),
+                        ping,
+                        args=(engine,),
+                    )
+                except (sqlite3.ProgrammingError, RuntimeError):
+                    # SQLite can't run on a separate thread, so ``func_timeout`` fails
+                    # RuntimeError catches the equivalent error from duckdb.
+                    alive = engine.dialect.do_ping(engine)
+                except FunctionTimedOut as ex:
+                    raise SupersetTimeoutException(
+                        error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
+                        message=(
+                            "Please check your connection details and database settings, "
+                            "and ensure that your database is accepting connections, "
+                            "then try connecting again."
+                        ),
+                        level=ErrorLevel.ERROR,
+                        extra={"sqlalchemy_uri": database.sqlalchemy_uri},
+                    ) from ex
+                except Exception:  # pylint: disable=broad-except
+                    alive = False
+                if not alive:
+                    raise DBAPIError(None, None, None)
 
             # Log succesful connection test with engine
             event_logger.log_with_context(
diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py
index a8956257fa..a92fb79f83 100644
--- a/superset/databases/commands/validate.py
+++ b/superset/databases/commands/validate.py
@@ -101,30 +101,30 @@ class ValidateDatabaseParametersCommand(BaseCommand):
         database.set_sqlalchemy_uri(sqlalchemy_uri)
         database.db_engine_spec.mutate_db_for_connection_test(database)
 
-        engine = database.get_sqla_engine()
-        try:
-            with closing(engine.raw_connection()) as conn:
-                alive = engine.dialect.do_ping(conn)
-        except Exception as ex:
-            url = make_url_safe(sqlalchemy_uri)
-            context = {
-                "hostname": url.host,
-                "password": url.password,
-                "port": url.port,
-                "username": url.username,
-                "database": url.database,
-            }
-            errors = database.db_engine_spec.extract_errors(ex, context)
-            raise DatabaseTestConnectionFailedError(errors) from ex
+        with database.get_sqla_engine_with_context() as engine:
+            try:
+                with closing(engine.raw_connection()) as conn:
+                    alive = engine.dialect.do_ping(conn)
+            except Exception as ex:
+                url = make_url_safe(sqlalchemy_uri)
+                context = {
+                    "hostname": url.host,
+                    "password": url.password,
+                    "port": url.port,
+                    "username": url.username,
+                    "database": url.database,
+                }
+                errors = database.db_engine_spec.extract_errors(ex, context)
+                raise DatabaseTestConnectionFailedError(errors) from ex
 
-        if not alive:
-            raise DatabaseOfflineError(
-                SupersetError(
-                    message=__("Database is offline."),
-                    error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
-                    level=ErrorLevel.ERROR,
-                ),
-            )
+            if not alive:
+                raise DatabaseOfflineError(
+                    SupersetError(
+                        message=__("Database is offline."),
+                        error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
+                        level=ErrorLevel.ERROR,
+                    ),
+                )
 
     def validate(self) -> None:
         database_id = self._properties.get("id")
diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py
index ba2b7df261..7d3998b3bb 100644
--- a/superset/datasets/commands/importers/v1/utils.py
+++ b/superset/datasets/commands/importers/v1/utils.py
@@ -168,7 +168,8 @@ def load_data(
         connection = session.connection()
     else:
         logger.warning("Loading data outside the import transaction")
-        connection = database.get_sqla_engine()
+        with database.get_sqla_engine_with_context() as engine:
+            connection = engine
 
     df.to_sql(
         dataset.table_name,
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index dabed0c7ae..9dd7594dc7 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -472,7 +472,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         schema: Optional[str] = None,
         source: Optional[utils.QuerySource] = None,
     ) -> Engine:
-        return database.get_sqla_engine(schema=schema, source=source)
+        with database.get_sqla_engine_with_context(
+            schema=schema, source=source
+        ) as engine:
+            return engine
 
     @classmethod
     def get_timestamp_expr(
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 96afc7f51e..6d9903c8f0 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -463,61 +463,66 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
             )
         )
 
-    engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
-    # Sharing a single connection and cursor across the
-    # execution of all statements (if many)
-    with closing(engine.raw_connection()) as conn:
-        # closing the connection closes the cursor as well
-        cursor = conn.cursor()
-        cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
-        if cancel_query_id is not None:
-            query.set_extra_json_key(cancel_query_key, cancel_query_id)
-            session.commit()
-        statement_count = len(statements)
-        for i, statement in enumerate(statements):
-            # Check if stopped
-            session.refresh(query)
-            if query.status == QueryStatus.STOPPED:
-                payload.update({"status": query.status})
-                return payload
-
-            # For CTAS we create the table only on the last statement
-            apply_ctas = query.select_as_cta and (
-                query.ctas_method == CtasMethod.VIEW
-                or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
-            )
-
-            # Run statement
-            msg = f"Running statement {i+1} out of {statement_count}"
-            logger.info("Query %s: %s", str(query_id), msg)
-            query.set_extra_json_key("progress", msg)
-            session.commit()
-            try:
-                result_set = execute_sql_statement(
-                    statement,
-                    query,
-                    session,
-                    cursor,
-                    log_params,
-                    apply_ctas,
-                )
-            except SqlLabQueryStoppedException:
-                payload.update({"status": QueryStatus.STOPPED})
-                return payload
-            except Exception as ex:  # pylint: disable=broad-except
-                msg = str(ex)
-                prefix_message = (
-                    f"[Statement {i+1} out of {statement_count}]"
-                    if statement_count > 1
-                    else ""
-                )
-                payload = handle_query_error(
-                    ex, query, session, payload, prefix_message
-                )
-                return payload
-        # Commit the connection so CTA queries will create the table.
-        conn.commit()
+    with database.get_sqla_engine_with_context(
+        query.schema, source=QuerySource.SQL_LAB
+    ) as engine:
+        # Sharing a single connection and cursor across the
+        # execution of all statements (if many)
+        with closing(engine.raw_connection()) as conn:
+            # closing the connection closes the cursor as well
+            cursor = conn.cursor()
+            cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
+            if cancel_query_id is not None:
+                query.set_extra_json_key(cancel_query_key, cancel_query_id)
+                session.commit()
+            statement_count = len(statements)
+            for i, statement in enumerate(statements):
+                # Check if stopped
+                session.refresh(query)
+                if query.status == QueryStatus.STOPPED:
+                    payload.update({"status": query.status})
+                    return payload
+
+                # For CTAS we create the table only on the last statement
+                apply_ctas = query.select_as_cta and (
+                    query.ctas_method == CtasMethod.VIEW
+                    or (
+                        query.ctas_method == CtasMethod.TABLE
+                        and i == len(statements) - 1
+                    )
+                )
+
+                # Run statement
+                msg = f"Running statement {i+1} out of {statement_count}"
+                logger.info("Query %s: %s", str(query_id), msg)
+                query.set_extra_json_key("progress", msg)
+                session.commit()
+                try:
+                    result_set = execute_sql_statement(
+                        statement,
+                        query,
+                        session,
+                        cursor,
+                        log_params,
+                        apply_ctas,
+                    )
+                except SqlLabQueryStoppedException:
+                    payload.update({"status": QueryStatus.STOPPED})
+                    return payload
+                except Exception as ex:  # pylint: disable=broad-except
+                    msg = str(ex)
+                    prefix_message = (
+                        f"[Statement {i+1} out of {statement_count}]"
+                        if statement_count > 1
+                        else ""
+                    )
+                    payload = handle_query_error(
+                        ex, query, session, payload, prefix_message
+                    )
+                    return payload
+
+            # Commit the connection so CTA queries will create the table.
+            conn.commit()
 
     # Success, updating the query entry in database
     query.rows = result_set.size
@@ -622,10 +627,11 @@ def cancel_query(query: Query) -> bool:
     if cancel_query_id is None:
         return False
 
-    engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
-
-    with closing(engine.raw_connection()) as conn:
-        with closing(conn.cursor()) as cursor:
-            return query.database.db_engine_spec.cancel_query(
-                cursor, query, cancel_query_id
-            )
+    with query.database.get_sqla_engine_with_context(
+        query.schema, source=QuerySource.SQL_LAB
+    ) as engine:
+        with closing(engine.raw_connection()) as conn:
+            with closing(conn.cursor()) as cursor:
+                return query.database.db_engine_spec.cancel_query(
+                    cursor, query, cancel_query_id
+                )
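
For reference, the change applied at every call site above is the same: instead of taking a bare
Engine from database.get_sqla_engine() and using it with no explicit scope, each use of the engine
now happens inside a "with database.get_sqla_engine_with_context() as engine:" block. A minimal
sketch of the new usage, assuming a Superset Database instance; the fetch_values helper below is
illustrative and not part of this commit:

    import pandas as pd

    def fetch_values(database, sql: str) -> list:
        # Previously: engine = database.get_sqla_engine(); the engine could then be
        # used anywhere afterwards, with no explicit scope for setup or cleanup.
        # Now: the engine is only used while the context is open, so whatever the
        # context manager does on entry/exit brackets the actual query.
        with database.get_sqla_engine_with_context() as engine:
            df = pd.read_sql_query(sql=sql, con=engine)
            return df.iloc[:, 0].to_list()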

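The with-blocks in this commit only require that get_sqla_engine_with_context() behave as a context
manager yielding a SQLAlchemy Engine. Below is a generic sketch of such a wrapper built on
contextlib.contextmanager; it illustrates the mechanism only and is not necessarily Superset's
actual implementation (in particular, the engine.dispose() cleanup step is an assumption):

    from contextlib import contextmanager
    from typing import Iterator

    from sqlalchemy import create_engine
    from sqlalchemy.engine import Engine

    @contextmanager
    def sqla_engine_with_context(sqlalchemy_uri: str) -> Iterator[Engine]:
        # Create the engine, lend it to the caller for the duration of the
        # with-block, and release pooled connections once the block exits.
        engine = create_engine(sqlalchemy_uri)
        try:
            yield engine
        finally:
            engine.dispose()  # assumed cleanup; see note above

    # Example usage, mirroring the call sites in the diff (closing is the same
    # contextlib.closing helper used there):
    # with sqla_engine_with_context("sqlite://") as engine:
    #     with closing(engine.raw_connection()) as conn:
    #         cursor = conn.cursor()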