This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch client-info-db-extra in repository https://gitbox.apache.org/repos/asf/superset.git
commit 1e049786a1a5134d64f241fa930a98e6281216f7 Author: Beto Dealmeida <[email protected]> AuthorDate: Thu Jun 6 16:44:58 2024 -0400 Add tests --- superset/db_engine_specs/base.py | 10 +++--- superset/models/core.py | 16 +++++----- tests/unit_tests/models/core_test.py | 2 +- tests/unit_tests/sql_lab_test.py | 62 ++++++++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 15 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index ab1f3ece06..bf570948fa 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -130,9 +130,7 @@ builtin_time_grains: dict[str | None, str] = { } -class TimestampExpression( - ColumnClause -): # pylint: disable=abstract-method, too-many-ancestors +class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None: """Sqlalchemy class that can be used to render native column elements respecting engine-specific quoting rules as part of a string-based expression. @@ -390,9 +388,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods max_column_name_length: int | None = None try_remove_schema_from_table_name = True # pylint: disable=invalid-name run_multiple_statements_as_one = False - custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = ( - {} - ) + custom_errors: dict[ + Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]] + ] = {} # Whether the engine supports file uploads # if True, database will be listed as option in the upload file form diff --git a/superset/models/core.py b/superset/models/core.py index 6814e3f576..c8c875e435 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -116,9 +116,7 @@ class ConfigurationMethod(StrEnum): DYNAMIC_FORM = "dynamic_form" -class Database( - Model, AuditMixinNullable, ImportExportMixin -): # pylint: disable=too-many-public-methods +class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods """An ORM object that stores Database related information""" __tablename__ = "dbs" @@ -392,7 +390,9 @@ class Database( return ( username if (username := get_username()) - else object_url.username if self.impersonate_user else None + else object_url.username + if self.impersonate_user + else None ) @contextmanager @@ -1074,8 +1074,8 @@ class Database( admins to create custom OAuth2 clients from the Superset UI, and assign them to specific databases. """ - config = json.loads(self.encrypted_extra or "{}") - oauth2_client_info = config.get("oauth2_client_info", {}) + encrypted_extra = json.loads(self.encrypted_extra or "{}") + oauth2_client_info = encrypted_extra.get("oauth2_client_info", {}) return bool(oauth2_client_info) or self.db_engine_spec.is_oauth2_enabled() def get_oauth2_config(self) -> OAuth2ClientConfig | None: @@ -1087,8 +1087,8 @@ class Database( admins to create custom OAuth2 clients from the Superset UI, and assign them to specific databases. """ - config = json.loads(self.encrypted_extra or "{}") - if oauth2_client_info := config.get("oauth2_client_info"): + encrypted_extra = json.loads(self.encrypted_extra or "{}") + if oauth2_client_info := encrypted_extra.get("oauth2_client_info"): schema = OAuth2ClientConfigSchema() client_config = schema.load(oauth2_client_info) return cast(OAuth2ClientConfig, client_config) diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index f1e7138b55..c4d642baf5 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -451,7 +451,7 @@ def test_raw_connection_oauth(mocker: MockFixture) -> None: sqlalchemy_uri="sqlite://", encrypted_extra=json.dumps(oauth2_client_info), ) - database.db_engine_spec.oauth2_exception = OAuth2Error + database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore get_sqla_engine = mocker.patch.object(database, "get_sqla_engine") get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error( "OAuth2 required" diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 3b2e7690e1..6ac055df13 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -16,12 +16,22 @@ # under the License. # pylint: disable=import-outside-toplevel, invalid-name, unused-argument, too-many-locals +import json +from uuid import UUID + import sqlparse +from freezegun import freeze_time from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session from superset import db +from superset.common.db_query_status import QueryStatus +from superset.errors import ErrorLevel, SupersetErrorType +from superset.exceptions import OAuth2Error +from superset.models.core import Database +from superset.sql_lab import get_sql_results from superset.utils.core import override_user +from tests.unit_tests.models.core_test import oauth2_client_info def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: @@ -218,3 +228,55 @@ def test_sql_lab_insert_rls_as_subquery( query.executed_sql == "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6" ) + + +@freeze_time("2021-04-01T00:00:00Z") +def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: + """ + Test that `get_sql_results` works with OAuth2. + """ + app_context = app.test_request_context() + app_context.push() + + mocker.patch( + "superset.db_engine_specs.base.uuid4", + return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"), + ) + + g = mocker.patch("superset.db_engine_specs.base.g") + g.user = mocker.MagicMock() + g.user.id = 42 + + database = Database( + id=1, + database_name="my_db", + sqlalchemy_uri="sqlite://", + encrypted_extra=json.dumps(oauth2_client_info), + ) + database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore + get_sqla_engine = mocker.patch.object(database, "get_sqla_engine") + get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error( + "OAuth2 required" + ) + + query = mocker.MagicMock() + query.database = database + mocker.patch("superset.sql_lab.get_query", return_value=query) + + payload = get_sql_results(query_id=1, rendered_query="SELECT 1") + assert payload == { + "status": QueryStatus.FAILED, + "error": "You don't have permission to access the data.", + "errors": [ + { + "message": "You don't have permission to access the data.", + "error_type": SupersetErrorType.OAUTH2_REDIRECT, + "level": ErrorLevel.WARNING, + "extra": { + "url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vZXhhbXBsZS5jb20vYXBpL3YxL2RhdGFiYXNlL29hdXRoMi8iLCJ0YWJfaWQiOiJmYjExZjUyOC02ZWJhLTRhOGEtODM3ZS02YjBkMzllZTkxODcifQ%252Ec_m_35xwwSrLgCXwV4aKO192 [...] + "tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187", + "redirect_uri": "http://example.com/api/v1/database/oauth2/", + }, + } + ], + }
