This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch impersonate_email_prefix in repository https://gitbox.apache.org/repos/asf/superset.git
commit f1a2c266da910cecdeb00ee66dd811681d76dd13
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed May 29 17:26:37 2024 -0400

    feat: impersonate with email prefix
---
 superset/db_engine_specs/trino.py    |  6 +-
 superset/models/core.py              | 13 ++--
 tests/unit_tests/models/core_test.py | 88 +++++++++++++++++++++++++++
 3 files changed, 102 insertions(+), 5 deletions(-)

Review notes (not part of the pushed commit; the hunks below differ from
f1a2c266, so that commit id and the post-image blob ids on the `index`
lines no longer describe this text exactly):
- trino.py: the reformat had moved `# pylint: disable=protected-access`
  onto the closing `),`. An end-of-line pylint pragma covers only its own
  line, so it is kept next to `current_app._get_current_object()`.
- test_base.py: the `import json` hunk is dropped; it was the only change
  to that file and nothing in this commit uses `json` there.
- core_test.py: the `user.email` literals had been replaced by the mail
  archive's address obfuscation ("[email protected]", which contains no
  "@"). They are restored to an address whose local part is `alice.doe`,
  the value the email-prefix test expects. The `extra` payload is written
  as a JSON string literal because this patch never imports `json` in
  core_test.py (the import had gone to test_base.py instead).

diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 600f236b48..51b517f865 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -233,7 +233,11 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
 
         execute_thread = threading.Thread(
             target=_execute,
-            args=(execute_result, execute_event, current_app._get_current_object()),  # pylint: disable=protected-access
+            args=(
+                execute_result,
+                execute_event,
+                current_app._get_current_object(),  # pylint: disable=protected-access
+            ),
         )
         execute_thread.start()
 
diff --git a/superset/models/core.py b/superset/models/core.py
index b933c1694f..9a7e05e94c 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -116,7 +116,9 @@ class ConfigurationMethod(StrEnum):
     DYNAMIC_FORM = "dynamic_form"
 
 
-class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable=too-many-public-methods
+class Database(
+    Model, AuditMixinNullable, ImportExportMixin
+):  # pylint: disable=too-many-public-methods
     """An ORM object that stores Database related information"""
 
     __tablename__ = "dbs"
@@ -390,9 +392,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         return (
             username
             if (username := get_username())
-            else object_url.username
-            if self.impersonate_user
-            else None
+            else object_url.username if self.impersonate_user else None
         )
 
     @contextmanager
@@ -477,6 +477,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         )
 
         effective_username = self.get_effective_user(sqlalchemy_url)
+        if effective_username and extra.get("username_from_email"):
+            user = security_manager.find_user(username=effective_username)
+            if user and user.email:
+                effective_username = user.email.split("@")[0]
+
         oauth2_config = self.get_oauth2_config()
         access_token = (
             get_oauth2_access_token(
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index e653eee716..41af0cb8b2 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -21,6 +21,7 @@ from datetime import datetime
 import pytest
 from pytest_mock import MockFixture
 from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.engine.url import make_url
 
 from superset.connectors.sqla.models import SqlaTable, TableColumn
 from superset.models.core import Database
@@ -289,3 +290,90 @@ def test_get_all_catalog_names(mocker: MockFixture) -> None:
     assert database.get_all_catalog_names(force=True) == {"examples", "other"}
 
     get_inspector.assert_called_with(ssh_tunnel=None)
+
+
+def test_get_sqla_engine(mocker: MockFixture) -> None:
+    """
+    Test `_get_sqla_engine`.
+    """
+    from superset.models.core import Database
+
+    user = mocker.MagicMock()
+    user.email = "alice.doe@example.org"
+    mocker.patch(
+        "superset.models.core.security_manager.find_user",
+        return_value=user,
+    )
+    mocker.patch("superset.models.core.get_username", return_value="alice")
+
+    create_engine = mocker.patch("superset.models.core.create_engine")
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="trino://",
+    )
+    database._get_sqla_engine(nullpool=False)
+
+    create_engine.assert_called_with(
+        make_url("trino:///"),
+        connect_args={"source": "Apache Superset"},
+    )
+
+
+def test_get_sqla_engine_user_impersonation(mocker: MockFixture) -> None:
+    """
+    Test user impersonation in `_get_sqla_engine`.
+    """
+    from superset.models.core import Database
+
+    user = mocker.MagicMock()
+    user.email = "alice.doe@example.org"
+    mocker.patch(
+        "superset.models.core.security_manager.find_user",
+        return_value=user,
+    )
+    mocker.patch("superset.models.core.get_username", return_value="alice")
+
+    create_engine = mocker.patch("superset.models.core.create_engine")
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="trino://",
+        impersonate_user=True,
+    )
+    database._get_sqla_engine(nullpool=False)
+
+    create_engine.assert_called_with(
+        make_url("trino:///"),
+        connect_args={"user": "alice", "source": "Apache Superset"},
+    )
+
+
+def test_get_sqla_engine_user_impersonation_email(mocker: MockFixture) -> None:
+    """
+    Test user impersonation in `_get_sqla_engine` with `username_from_email`.
+    """
+    from superset.models.core import Database
+
+    user = mocker.MagicMock()
+    user.email = "alice.doe@example.org"
+    mocker.patch(
+        "superset.models.core.security_manager.find_user",
+        return_value=user,
+    )
+    mocker.patch("superset.models.core.get_username", return_value="alice")
+
+    create_engine = mocker.patch("superset.models.core.create_engine")
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="trino://",
+        extra='{"username_from_email": true}',
+        impersonate_user=True,
+    )
+    database._get_sqla_engine(nullpool=False)
+
+    create_engine.assert_called_with(
+        make_url("trino:///"),
+        connect_args={"user": "alice.doe", "source": "Apache Superset"},
+    )
