This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new cfda30c81a fix(sqllab): Force trino client async execution (#24859)
cfda30c81a is described below
commit cfda30c81a8ee06924b37db889c1d1ba77e2bc41
Author: Rob Moore <[email protected]>
AuthorDate: Wed Sep 6 22:20:26 2023 +0100
fix(sqllab): Force trino client async execution (#24859)
---
superset/db_engine_specs/base.py | 18 +++++++
superset/db_engine_specs/trino.py | 66 +++++++++++++++++++++++---
superset/sql_lab.py | 7 +--
tests/unit_tests/db_engine_specs/test_trino.py | 31 +++++++++++-
tests/unit_tests/sql_lab_test.py | 10 ++--
5 files changed, 114 insertions(+), 18 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index ebeea81ec6..4ca71340a9 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1040,6 +1040,24 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
query object"""
# TODO: Fix circular import error caused by importing sql_lab.Query
+ @classmethod
+ def execute_with_cursor(
+ cls, cursor: Any, sql: str, query: Query, session: Session
+ ) -> None:
+ """
+ Trigger execution of a query and handle the resulting cursor.
+
+ For most implementations this just makes calls to `execute` and
+ `handle_cursor` consecutively, but in some engines (e.g. Trino) we may
+ need to handle client limitations such as lack of async support and
+ perform a more complicated operation to get information from the cursor
+ in a timely manner and facilitate operations such as query stop
+ """
+ logger.debug("Query %d: Running query: %s", query.id, sql)
+ cls.execute(cursor, sql, async_=True)
+ logger.debug("Query %d: Handling cursor", query.id)
+ cls.handle_cursor(cursor, query, session)
+
@classmethod
def extract_error_message(cls, ex: Exception) -> str:
return f"{cls.engine} error: {cls._extract_error_message(ex)}"
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index da0a56e100..cc8e531dc2 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -18,6 +18,8 @@ from __future__ import annotations
import contextlib
import logging
+import threading
+import time
from typing import Any, TYPE_CHECKING
import simplejson as json
@@ -147,14 +149,21 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) ->
None:
- if tracking_url := cls.get_tracking_url(cursor):
- query.tracking_url = tracking_url
+ """
+ Handle a trino client cursor.
+
+ WARNING: if you execute a query, it will block until complete and you
+ will not be able to handle the cursor until complete. Use
+ `execute_with_cursor` instead, to handle this asynchronously.
+ """
# Adds the executed query id to the extra payload so the query can be
cancelled
- query.set_extra_json_key(
- key=QUERY_CANCEL_KEY,
- value=(cancel_query_id := cursor.stats["queryId"]),
- )
+ cancel_query_id = cursor.query_id
+ logger.debug("Query %d: queryId %s found in cursor", query.id,
cancel_query_id)
+ query.set_extra_json_key(key=QUERY_CANCEL_KEY, value=cancel_query_id)
+
+ if tracking_url := cls.get_tracking_url(cursor):
+ query.tracking_url = tracking_url
session.commit()
@@ -169,6 +178,51 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
super().handle_cursor(cursor=cursor, query=query, session=session)
+ @classmethod
+ def execute_with_cursor(
+ cls, cursor: Any, sql: str, query: Query, session: Session
+ ) -> None:
+ """
+ Trigger execution of a query and handle the resulting cursor.
+
+ Trino's client blocks until the query is complete, so we need to run it
+ in another thread and invoke `handle_cursor` to poll for the query ID
+ to appear on the cursor in parallel.
+ """
+ execute_result: dict[str, Any] = {}
+
+ def _execute(results: dict[str, Any]) -> None:
+ logger.debug("Query %d: Running query: %s", query.id, sql)
+
+ # Pass result / exception information back to the parent thread
+ try:
+ cls.execute(cursor, sql)
+ results["complete"] = True
+ except Exception as ex: # pylint: disable=broad-except
+ results["complete"] = True
+ results["error"] = ex
+
+ execute_thread = threading.Thread(target=_execute,
args=(execute_result,))
+ execute_thread.start()
+
+ # Wait for a query ID to be available before handling the cursor, as
+ # it's required by that method; it may never become available on error.
+ while not cursor.query_id and not execute_result.get("complete"):
+ time.sleep(0.1)
+
+ logger.debug("Query %d: Handling cursor", query.id)
+ cls.handle_cursor(cursor, query, session)
+
+ # Block until the query completes; same behaviour as the client itself
+ logger.debug("Query %d: Waiting for query to complete", query.id)
+ while not execute_result.get("complete"):
+ time.sleep(0.5)
+
+ # Unfortunately we'll mangle the stack trace due to the thread, but
+ # throwing the original exception allows mapping database errors as
normal
+ if err := execute_result.get("error"):
+ raise err
+
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
if QUERY_CANCEL_KEY not in query.extra:
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 196b48b1d2..4d71e23d88 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -191,7 +191,7 @@ def get_sql_results( # pylint: disable=too-many-arguments
return handle_query_error(ex, query, session)
-def execute_sql_statement( # pylint:
disable=too-many-arguments,too-many-statements
+def execute_sql_statement( # pylint: disable=too-many-arguments
sql_statement: str,
query: Query,
session: Session,
@@ -271,10 +271,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem
)
session.commit()
with stats_timing("sqllab.query.time_executing_query", stats_logger):
- logger.debug("Query %d: Running query: %s", query.id, sql)
- db_engine_spec.execute(cursor, sql, async_=True)
- logger.debug("Query %d: Handling cursor", query.id)
- db_engine_spec.handle_cursor(cursor, query, session)
+ db_engine_spec.execute_with_cursor(cursor, sql, query, session)
with stats_timing("sqllab.query.time_fetching_results", stats_logger):
logger.debug(
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 963953d18b..1b50a683a0 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -352,7 +352,7 @@ def test_handle_cursor_early_cancel(
query_id = "myQueryId"
cursor_mock = engine_mock.return_value.__enter__.return_value
- cursor_mock.stats = {"queryId": query_id}
+ cursor_mock.query_id = query_id
session_mock = mocker.MagicMock()
query = Query()
@@ -366,3 +366,32 @@ def test_handle_cursor_early_cancel(
assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
else:
assert cancel_query_mock.call_args is None
+
+
+def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
+ """Test that `execute_with_cursor` fetches query ID from the cursor"""
+ from superset.db_engine_specs.trino import TrinoEngineSpec
+
+ query_id = "myQueryId"
+
+ mock_cursor = mocker.MagicMock()
+ mock_cursor.query_id = None
+
+ mock_query = mocker.MagicMock()
+ mock_session = mocker.MagicMock()
+
+ def _mock_execute(*args, **kwargs):
+ mock_cursor.query_id = query_id
+
+ mock_cursor.execute.side_effect = _mock_execute
+
+ TrinoEngineSpec.execute_with_cursor(
+ cursor=mock_cursor,
+ sql="SELECT 1 FROM foo",
+ query=mock_query,
+ session=mock_session,
+ )
+
+ mock_query.set_extra_json_key.assert_called_once_with(
+ key=QUERY_CANCEL_KEY, value=query_id
+ )
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 29f45eab68..edc1fd2ec4 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -55,8 +55,8 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
)
database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2,
force=True)
- db_engine_spec.execute.assert_called_with(
- cursor, "SELECT 42 AS answer LIMIT 2", async_=True
+ db_engine_spec.execute_with_cursor.assert_called_with(
+ cursor, "SELECT 42 AS answer LIMIT 2", query, session
)
SupersetResultSet.assert_called_with([(42,)], cursor.description,
db_engine_spec)
@@ -106,10 +106,8 @@ def test_execute_sql_statement_with_rls(
101,
force=True,
)
- db_engine_spec.execute.assert_called_with(
- cursor,
- "SELECT * FROM sales WHERE organization_id=42 LIMIT 101",
- async_=True,
+ db_engine_spec.execute_with_cursor.assert_called_with(
+ cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101",
query, session
)
SupersetResultSet.assert_called_with([(42,)], cursor.description,
db_engine_spec)