This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch impersonate_email_prefix
in repository https://gitbox.apache.org/repos/asf/superset.git

commit f1a2c266da910cecdeb00ee66dd811681d76dd13
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed May 29 17:26:37 2024 -0400

    feat: impersonate with email prefix
---
 superset/db_engine_specs/trino.py             |  6 +-
 superset/models/core.py                       | 13 ++--
 tests/unit_tests/db_engine_specs/test_base.py |  1 +
 tests/unit_tests/models/core_test.py          | 88 +++++++++++++++++++++++++++
 4 files changed, 103 insertions(+), 5 deletions(-)

diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 600f236b48..51b517f865 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -233,7 +233,11 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
 
         execute_thread = threading.Thread(
             target=_execute,
-            args=(execute_result, execute_event, current_app._get_current_object()),  # pylint: disable=protected-access
+            args=(
+                execute_result,
+                execute_event,
+                current_app._get_current_object(),  # pylint: disable=protected-access
+            ),
         )
         execute_thread.start()
 
diff --git a/superset/models/core.py b/superset/models/core.py
index b933c1694f..9a7e05e94c 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -116,7 +116,9 @@ class ConfigurationMethod(StrEnum):
     DYNAMIC_FORM = "dynamic_form"
 
 
-class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable=too-many-public-methods
+class Database(
+    Model, AuditMixinNullable, ImportExportMixin
+):  # pylint: disable=too-many-public-methods
     """An ORM object that stores Database related information"""
 
     __tablename__ = "dbs"
@@ -390,9 +392,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         return (
             username
             if (username := get_username())
-            else object_url.username
-            if self.impersonate_user
-            else None
+            else object_url.username if self.impersonate_user else None
         )
 
     @contextmanager
@@ -477,6 +477,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         )
 
         effective_username = self.get_effective_user(sqlalchemy_url)
+        if effective_username and extra.get("username_from_email"):
+            user = security_manager.find_user(username=effective_username)
+            if user and user.email:
+                effective_username = user.email.split("@")[0]
+
         oauth2_config = self.get_oauth2_config()
         access_token = (
             get_oauth2_access_token(
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index 0950dcb439..14ccdda851 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import json
 from textwrap import dedent
 from typing import Any
 
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index e653eee716..41af0cb8b2 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -21,6 +21,7 @@ from datetime import datetime
 import pytest
 from pytest_mock import MockFixture
 from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.engine.url import make_url
 
 from superset.connectors.sqla.models import SqlaTable, TableColumn
 from superset.models.core import Database
@@ -289,3 +290,90 @@ def test_get_all_catalog_names(mocker: MockFixture) -> None:
 
     assert database.get_all_catalog_names(force=True) == {"examples", "other"}
     get_inspector.assert_called_with(ssh_tunnel=None)
+
+
+def test_get_sqla_engine(mocker: MockFixture) -> None:
+    """
+    Test `_get_sqla_engine`.
+    """
+    from superset.models.core import Database
+
+    user = mocker.MagicMock()
+    user.email = "alice.doe@example.org"
+    mocker.patch(
+        "superset.models.core.security_manager.find_user",
+        return_value=user,
+    )
+    mocker.patch("superset.models.core.get_username", return_value="alice")
+
+    create_engine = mocker.patch("superset.models.core.create_engine")
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="trino://",
+    )
+    database._get_sqla_engine(nullpool=False)
+
+    create_engine.assert_called_with(
+        make_url("trino:///"),
+        connect_args={"source": "Apache Superset"},
+    )
+
+
+def test_get_sqla_engine_user_impersonation(mocker: MockFixture) -> None:
+    """
+    Test user impersonation in `_get_sqla_engine`.
+    """
+    from superset.models.core import Database
+
+    user = mocker.MagicMock()
+    user.email = "alice.doe@example.org"
+    mocker.patch(
+        "superset.models.core.security_manager.find_user",
+        return_value=user,
+    )
+    mocker.patch("superset.models.core.get_username", return_value="alice")
+
+    create_engine = mocker.patch("superset.models.core.create_engine")
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="trino://",
+        impersonate_user=True,
+    )
+    database._get_sqla_engine(nullpool=False)
+
+    create_engine.assert_called_with(
+        make_url("trino:///"),
+        connect_args={"user": "alice", "source": "Apache Superset"},
+    )
+
+
+def test_get_sqla_engine_user_impersonation_email(mocker: MockFixture) -> None:
+    """
+    Test user impersonation in `_get_sqla_engine` with `username_from_email`.
+    """
+    from superset.models.core import Database
+
+    user = mocker.MagicMock()
+    user.email = "alice.doe@example.org"
+    mocker.patch(
+        "superset.models.core.security_manager.find_user",
+        return_value=user,
+    )
+    mocker.patch("superset.models.core.get_username", return_value="alice")
+
+    create_engine = mocker.patch("superset.models.core.create_engine")
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="trino://",
+        extra=json.dumps({"username_from_email": True}),
+        impersonate_user=True,
+    )
+    database._get_sqla_engine(nullpool=False)
+
+    create_engine.assert_called_with(
+        make_url("trino:///"),
+        connect_args={"user": "alice.doe", "source": "Apache Superset"},
+    )

Reply via email to