This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch client-info-db-extra
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/client-info-db-extra by this
push:
new 4e31392305 Add tests
4e31392305 is described below
commit 4e31392305f719e4f04b24a56ecedba565066200
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Jun 6 16:44:58 2024 -0400
Add tests
---
superset/db_engine_specs/base.py | 10 +++---
superset/models/core.py | 16 +++++-----
tests/unit_tests/models/core_test.py | 2 +-
tests/unit_tests/sql_lab_test.py | 62 ++++++++++++++++++++++++++++++++++++
4 files changed, 75 insertions(+), 15 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index ab1f3ece06..bf570948fa 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -130,9 +130,7 @@ builtin_time_grains: dict[str | None, str] = {
}
-class TimestampExpression(
- ColumnClause
-): # pylint: disable=abstract-method, too-many-ancestors
+class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
"""Sqlalchemy class that can be used to render native column elements respecting
engine-specific quoting rules as part of a string-based expression.
@@ -390,9 +388,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
max_column_name_length: int | None = None
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False
- custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
- {}
- )
+ custom_errors: dict[
+ Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
+ ] = {}
# Whether the engine supports file uploads
# if True, database will be listed as option in the upload file form
diff --git a/superset/models/core.py b/superset/models/core.py
index 6814e3f576..c8c875e435 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -116,9 +116,7 @@ class ConfigurationMethod(StrEnum):
DYNAMIC_FORM = "dynamic_form"
-class Database(
- Model, AuditMixinNullable, ImportExportMixin
-): # pylint: disable=too-many-public-methods
+class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods
"""An ORM object that stores Database related information"""
__tablename__ = "dbs"
@@ -392,7 +390,9 @@ class Database(
return (
username
if (username := get_username())
- else object_url.username if self.impersonate_user else None
+ else object_url.username
+ if self.impersonate_user
+ else None
)
@contextmanager
@@ -1074,8 +1074,8 @@ class Database(
admins to create custom OAuth2 clients from the Superset UI, and assign them to
specific databases.
"""
- config = json.loads(self.encrypted_extra or "{}")
- oauth2_client_info = config.get("oauth2_client_info", {})
+ encrypted_extra = json.loads(self.encrypted_extra or "{}")
+ oauth2_client_info = encrypted_extra.get("oauth2_client_info", {})
return bool(oauth2_client_info) or self.db_engine_spec.is_oauth2_enabled()
def get_oauth2_config(self) -> OAuth2ClientConfig | None:
@@ -1087,8 +1087,8 @@ class Database(
admins to create custom OAuth2 clients from the Superset UI, and assign them to
specific databases.
"""
- config = json.loads(self.encrypted_extra or "{}")
- if oauth2_client_info := config.get("oauth2_client_info"):
+ encrypted_extra = json.loads(self.encrypted_extra or "{}")
+ if oauth2_client_info := encrypted_extra.get("oauth2_client_info"):
schema = OAuth2ClientConfigSchema()
client_config = schema.load(oauth2_client_info)
return cast(OAuth2ClientConfig, client_config)
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index f1e7138b55..c4d642baf5 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -451,7 +451,7 @@ def test_raw_connection_oauth(mocker: MockFixture) -> None:
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
- database.db_engine_spec.oauth2_exception = OAuth2Error
+ database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
"OAuth2 required"
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 3b2e7690e1..3b6f9088a0 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -16,12 +16,22 @@
# under the License.
# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, too-many-locals
+import json
+from uuid import UUID
+
import sqlparse
+from freezegun import freeze_time
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
from superset import db
+from superset.common.db_query_status import QueryStatus
+from superset.errors import ErrorLevel, SupersetErrorType
+from superset.exceptions import OAuth2Error
+from superset.models.core import Database
+from superset.sql_lab import get_sql_results
from superset.utils.core import override_user
+from tests.unit_tests.models.core_test import oauth2_client_info
def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
@@ -218,3 +228,55 @@ def test_sql_lab_insert_rls_as_subquery(
query.executed_sql
== "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6"
)
+
+
+@freeze_time("2021-04-01T00:00:00Z")
+def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
+ """
+ Test that `get_sql_results` works with OAuth2.
+ """
+ app_context = app.test_request_context()
+ app_context.push()
+
+ mocker.patch(
+ "superset.db_engine_specs.base.uuid4",
+ return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"),
+ )
+
+ g = mocker.patch("superset.db_engine_specs.base.g")
+ g.user = mocker.MagicMock()
+ g.user.id = 42
+
+ database = Database(
+ id=1,
+ database_name="my_db",
+ sqlalchemy_uri="sqlite://",
+ encrypted_extra=json.dumps(oauth2_client_info),
+ )
+ database.db_engine_spec.oauth2_exception = OAuth2Error
+ get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
+ get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
+ "OAuth2 required"
+ )
+
+ query = mocker.MagicMock()
+ query.database = database
+ mocker.patch("superset.sql_lab.get_query", return_value=query)
+
+ payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
+ assert payload == {
+ "status": QueryStatus.FAILED,
+ "error": "You don't have permission to access the data.",
+ "errors": [
+ {
+ "message": "You don't have permission to access the data.",
+ "error_type": SupersetErrorType.OAUTH2_REDIRECT,
+ "level": ErrorLevel.WARNING,
+ "extra": {
+ "url":
"https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vZXhhbXBsZS5jb20vYXBpL3YxL2RhdGFiYXNlL29hdXRoMi8iLCJ0YWJfaWQiOiJmYjExZjUyOC02ZWJhLTRhOGEtODM3ZS02YjBkMzllZTkxODcifQ%252Ec_m_35xwwSrLgCXwV4aKO192
[...]
+ "tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
+ "redirect_uri":
"http://example.com/api/v1/database/oauth2/",
+ },
+ }
+ ],
+ }