This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 94b3d2f0f0 refactor: Deprecate ensure_user_is_set in favor of override_user (#20502)
94b3d2f0f0 is described below

commit 94b3d2f0f0e920e667bf2c2d9d3fbfd9ebcc3ffd
Author: John Bodley <[email protected]>
AuthorDate: Tue Jul 5 10:57:40 2022 -0700

    refactor: Deprecate ensure_user_is_set in favor of override_user (#20502)
    
    Co-authored-by: John Bodley <[email protected]>
---
 superset/tasks/async_queries.py                    | 170 +++++++++++----------
 superset/utils/core.py                             |  18 ++-
 tests/integration_tests/access_tests.py            |  24 +--
 .../integration_tests/tasks/async_queries_tests.py |  84 ++--------
 4 files changed, 122 insertions(+), 174 deletions(-)

diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py
index 74adcd080c..1157c5fd37 100644
--- a/superset/tasks/async_queries.py
+++ b/superset/tasks/async_queries.py
@@ -33,6 +33,7 @@ from superset.extensions import (
     security_manager,
 )
 from superset.utils.cache import generate_cache_key, set_and_log_cache
+from superset.utils.core import override_user
 from superset.views.utils import get_datasource_info, get_viz
 
 if TYPE_CHECKING:
@@ -44,16 +45,6 @@ query_timeout = current_app.config[
 ]  # TODO: new config key
 
 
-def ensure_user_is_set(user_id: Optional[int]) -> None:
-    user_is_not_set = not (hasattr(g, "user") and g.user is not None)
-    if user_is_not_set and user_id is not None:
-        # pylint: disable=assigning-non-slot
-        g.user = security_manager.get_user_by_id(user_id)
-    elif user_is_not_set:
-        # pylint: disable=assigning-non-slot
-        g.user = security_manager.get_anonymous_user()
-
-
 def set_form_data(form_data: Dict[str, Any]) -> None:
     # pylint: disable=assigning-non-slot
     g.form_data = form_data
@@ -76,30 +67,35 @@ def load_chart_data_into_cache(
     # pylint: disable=import-outside-toplevel
     from superset.charts.data.commands.get_data_command import ChartDataCommand
 
-    try:
-        ensure_user_is_set(job_metadata.get("user_id"))
-        set_form_data(form_data)
-        query_context = _create_query_context_from_form(form_data)
-        command = ChartDataCommand(query_context)
-        result = command.run(cache=True)
-        cache_key = result["cache_key"]
-        result_url = f"/api/v1/chart/data/{cache_key}"
-        async_query_manager.update_job(
-            job_metadata,
-            async_query_manager.STATUS_DONE,
-            result_url=result_url,
-        )
-    except SoftTimeLimitExceeded as ex:
-        logger.warning("A timeout occurred while loading chart data, error: %s", ex)
-        raise ex
-    except Exception as ex:
-        # TODO: QueryContext should support SIP-40 style errors
-        error = ex.message if hasattr(ex, "message") else str(ex)  # type: ignore # pylint: disable=no-member
-        errors = [{"message": error}]
-        async_query_manager.update_job(
-            job_metadata, async_query_manager.STATUS_ERROR, errors=errors
-        )
-        raise ex
+    user = (
+        security_manager.get_user_by_id(job_metadata.get("user_id"))
+        or security_manager.get_anonymous_user()
+    )
+
+    with override_user(user, force=False):
+        try:
+            set_form_data(form_data)
+            query_context = _create_query_context_from_form(form_data)
+            command = ChartDataCommand(query_context)
+            result = command.run(cache=True)
+            cache_key = result["cache_key"]
+            result_url = f"/api/v1/chart/data/{cache_key}"
+            async_query_manager.update_job(
+                job_metadata,
+                async_query_manager.STATUS_DONE,
+                result_url=result_url,
+            )
+        except SoftTimeLimitExceeded as ex:
+            logger.warning("A timeout occurred while loading chart data, error: %s", ex)
+            raise ex
+        except Exception as ex:
+            # TODO: QueryContext should support SIP-40 style errors
+            error = ex.message if hasattr(ex, "message") else str(ex)  # type: ignore # pylint: disable=no-member
+            errors = [{"message": error}]
+            async_query_manager.update_job(
+                job_metadata, async_query_manager.STATUS_ERROR, errors=errors
+            )
+            raise ex
 
 
 @celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
@@ -110,53 +106,61 @@ def load_explore_json_into_cache(  # pylint: disable=too-many-locals
     force: bool = False,
 ) -> None:
     cache_key_prefix = "ejr-"  # ejr: explore_json request
-    try:
-        ensure_user_is_set(job_metadata.get("user_id"))
-        set_form_data(form_data)
-        datasource_id, datasource_type = get_datasource_info(None, None, form_data)
-
-        # Perform a deep copy here so that below we can cache the original
-        # value of the form_data object. This is necessary since the viz
-        # objects modify the form_data object. If the modified version were
-        # to be cached here, it will lead to a cache miss when clients
-        # attempt to retrieve the value of the completed async query.
-        original_form_data = copy.deepcopy(form_data)
-
-        viz_obj = get_viz(
-            datasource_type=cast(str, datasource_type),
-            datasource_id=datasource_id,
-            form_data=form_data,
-            force=force,
-        )
-        # run query & cache results
-        payload = viz_obj.get_payload()
-        if viz_obj.has_error(payload):
-            raise SupersetVizException(errors=payload["errors"])
-
-        # Cache the original form_data value for async retrieval
-        cache_value = {
-            "form_data": original_form_data,
-            "response_type": response_type,
-        }
-        cache_key = generate_cache_key(cache_value, cache_key_prefix)
-        set_and_log_cache(cache_manager.cache, cache_key, cache_value)
-        result_url = f"/superset/explore_json/data/{cache_key}"
-        async_query_manager.update_job(
-            job_metadata,
-            async_query_manager.STATUS_DONE,
-            result_url=result_url,
-        )
-    except SoftTimeLimitExceeded as ex:
-        logger.warning("A timeout occurred while loading explore json, error: %s", ex)
-        raise ex
-    except Exception as ex:
-        if isinstance(ex, SupersetVizException):
-            errors = ex.errors  # pylint: disable=no-member
-        else:
-            error = ex.message if hasattr(ex, "message") else str(ex)  # type: ignore # pylint: disable=no-member
-            errors = [error]
 
-        async_query_manager.update_job(
-            job_metadata, async_query_manager.STATUS_ERROR, errors=errors
-        )
-        raise ex
+    user = (
+        security_manager.get_user_by_id(job_metadata.get("user_id"))
+        or security_manager.get_anonymous_user()
+    )
+
+    with override_user(user, force=False):
+        try:
+            set_form_data(form_data)
+            datasource_id, datasource_type = get_datasource_info(None, None, form_data)
+
+            # Perform a deep copy here so that below we can cache the original
+            # value of the form_data object. This is necessary since the viz
+            # objects modify the form_data object. If the modified version were
+            # to be cached here, it will lead to a cache miss when clients
+            # attempt to retrieve the value of the completed async query.
+            original_form_data = copy.deepcopy(form_data)
+
+            viz_obj = get_viz(
+                datasource_type=cast(str, datasource_type),
+                datasource_id=datasource_id,
+                form_data=form_data,
+                force=force,
+            )
+            # run query & cache results
+            payload = viz_obj.get_payload()
+            if viz_obj.has_error(payload):
+                raise SupersetVizException(errors=payload["errors"])
+
+            # Cache the original form_data value for async retrieval
+            cache_value = {
+                "form_data": original_form_data,
+                "response_type": response_type,
+            }
+            cache_key = generate_cache_key(cache_value, cache_key_prefix)
+            set_and_log_cache(cache_manager.cache, cache_key, cache_value)
+            result_url = f"/superset/explore_json/data/{cache_key}"
+            async_query_manager.update_job(
+                job_metadata,
+                async_query_manager.STATUS_DONE,
+                result_url=result_url,
+            )
+        except SoftTimeLimitExceeded as ex:
+            logger.warning(
+                "A timeout occurred while loading explore json, error: %s", ex
+            )
+            raise ex
+        except Exception as ex:
+            if isinstance(ex, SupersetVizException):
+                errors = ex.errors  # pylint: disable=no-member
+            else:
+                error = ex.message if hasattr(ex, "message") else str(ex)  # type: ignore # pylint: disable=no-member
+                errors = [error]
+
+            async_query_manager.update_job(
+                job_metadata, async_query_manager.STATUS_ERROR, errors=errors
+            )
+            raise ex
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 336ab4e208..aeb45051b6 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1453,23 +1453,27 @@ def get_user_id() -> Optional[int]:
 
 
 @contextmanager
-def override_user(user: Optional[User]) -> Iterator[Any]:
+def override_user(user: Optional[User], force: bool = True) -> Iterator[Any]:
     """
-    Temporarily override the current user (if defined) per `flask.g`.
+    Temporarily override the current user per `flask.g` with the specified user.
 
     Sometimes, often in the context of async Celery tasks, it is useful to switch the
     current user (which may be undefined) to different one, execute some SQLAlchemy
-    tasks and then revert back to the original one.
+    tasks et al. and then revert back to the original one.
 
     :param user: The override user
+    :param force: Whether to override the current user if set
     """
 
     # pylint: disable=assigning-non-slot
     if hasattr(g, "user"):
-        current = g.user
-        g.user = user
-        yield
-        g.user = current
+        if force or g.user is None:
+            current = g.user
+            g.user = user
+            yield
+            g.user = current
+        else:
+            yield
     else:
         g.user = user
         yield
diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py
index 2e1e897a4f..5ab03055d9 100644
--- a/tests/integration_tests/access_tests.py
+++ b/tests/integration_tests/access_tests.py
@@ -562,34 +562,34 @@ def test_get_username(
     assert get_username() == username
 
 
[email protected](
-    "username",
-    [
-        None,
-        "alpha",
-        "gamma",
-    ],
-)
[email protected]("username", [None, "alpha", "gamma"])
[email protected]("force", [False, True])
 def test_override_user(
     app_context: AppContext,
     mocker: MockFixture,
     username: str,
+    force: bool,
 ) -> None:
     mock_g = mocker.patch("superset.utils.core.g", spec={})
     admin = security_manager.find_user(username="admin")
     user = security_manager.find_user(username)
 
+    with override_user(user, force):
+        assert mock_g.user == user
+
     assert not hasattr(mock_g, "user")
 
-    with override_user(user):
+    mock_g.user = None
+
+    with override_user(user, force):
         assert mock_g.user == user
 
-    assert not hasattr(mock_g, "user")
+    assert mock_g.user is None
 
     mock_g.user = admin
 
-    with override_user(user):
-        assert mock_g.user == user
+    with override_user(user, force):
+        assert mock_g.user == user if force else admin
 
     assert mock_g.user == admin
 
diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py
index 5a51c06601..20d0f39eea 100644
--- a/tests/integration_tests/tasks/async_queries_tests.py
+++ b/tests/integration_tests/tasks/async_queries_tests.py
@@ -28,7 +28,6 @@ from superset.exceptions import SupersetException
 from superset.extensions import async_query_manager, security_manager
 from superset.tasks import async_queries
 from superset.tasks.async_queries import (
-    ensure_user_is_set,
     load_chart_data_into_cache,
     load_explore_json_into_cache,
 )
@@ -58,12 +57,7 @@ class TestAsyncQueries(SupersetTestCase):
             "errors": [],
         }
 
-        with mock.patch.object(
-            async_queries, "ensure_user_is_set"
-        ) as ensure_user_is_set:
-            load_chart_data_into_cache(job_metadata, query_context)
-
-        ensure_user_is_set.assert_called_once_with(user.id)
+        load_chart_data_into_cache(job_metadata, query_context)
         mock_set_form_data.assert_called_once_with(query_context)
         mock_update_job.assert_called_once_with(
             job_metadata, "done", result_url=mock.ANY
@@ -85,11 +79,7 @@ class TestAsyncQueries(SupersetTestCase):
             "errors": [],
         }
         with pytest.raises(ChartDataQueryFailedError):
-            with mock.patch.object(
-                async_queries, "ensure_user_is_set"
-            ) as ensure_user_is_set:
-                load_chart_data_into_cache(job_metadata, query_context)
-            ensure_user_is_set.assert_called_once_with(user.id)
+            load_chart_data_into_cache(job_metadata, query_context)
 
         mock_run_command.assert_called_once_with(cache=True)
         errors = [{"message": "Error: foo"}]
@@ -115,11 +105,11 @@ class TestAsyncQueries(SupersetTestCase):
         with pytest.raises(SoftTimeLimitExceeded):
             with mock.patch.object(
                 async_queries,
-                "ensure_user_is_set",
-            ) as ensure_user_is_set:
-                ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
+                "set_form_data",
+            ) as set_form_data:
+                set_form_data.side_effect = SoftTimeLimitExceeded()
                 load_chart_data_into_cache(job_metadata, form_data)
-            ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
+            set_form_data.assert_called_once_with(form_data, "error", errors=errors)
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @mock.patch.object(async_query_manager, "update_job")
@@ -145,12 +135,7 @@ class TestAsyncQueries(SupersetTestCase):
             "errors": [],
         }
 
-        with mock.patch.object(
-            async_queries, "ensure_user_is_set"
-        ) as ensure_user_is_set:
-            load_explore_json_into_cache(job_metadata, form_data)
-
-        ensure_user_is_set.assert_called_once_with(user.id)
+        load_explore_json_into_cache(job_metadata, form_data)
         mock_update_job.assert_called_once_with(
             job_metadata, "done", result_url=mock.ANY
         )
@@ -172,11 +157,7 @@ class TestAsyncQueries(SupersetTestCase):
         }
 
         with pytest.raises(SupersetException):
-            with mock.patch.object(
-                async_queries, "ensure_user_is_set"
-            ) as ensure_user_is_set:
-                load_explore_json_into_cache(job_metadata, form_data)
-            ensure_user_is_set.assert_called_once_with(user.id)
+            load_explore_json_into_cache(job_metadata, form_data)
 
         mock_set_form_data.assert_called_once_with(form_data)
         errors = ["The dataset associated with this chart no longer exists"]
@@ -202,49 +183,8 @@ class TestAsyncQueries(SupersetTestCase):
         with pytest.raises(SoftTimeLimitExceeded):
             with mock.patch.object(
                 async_queries,
-                "ensure_user_is_set",
-            ) as ensure_user_is_set:
-                ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
+                "set_form_data",
+            ) as set_form_data:
+                set_form_data.side_effect = SoftTimeLimitExceeded()
                 load_explore_json_into_cache(job_metadata, form_data)
-            ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
-
-    def test_ensure_user_is_set(self):
-        g_user_is_set = hasattr(g, "user")
-        original_g_user = g.user if g_user_is_set else None
-
-        if g_user_is_set:
-            del g.user
-
-        self.assertFalse(hasattr(g, "user"))
-        ensure_user_is_set(1)
-        self.assertTrue(hasattr(g, "user"))
-        self.assertFalse(g.user.is_anonymous)
-        self.assertEqual(1, get_user_id())
-
-        del g.user
-
-        self.assertFalse(hasattr(g, "user"))
-        ensure_user_is_set(None)
-        self.assertTrue(hasattr(g, "user"))
-        self.assertTrue(g.user.is_anonymous)
-        self.assertEqual(None, get_user_id())
-
-        del g.user
-
-        g.user = security_manager.get_user_by_id(2)
-        self.assertEqual(2, get_user_id())
-
-        ensure_user_is_set(1)
-        self.assertTrue(hasattr(g, "user"))
-        self.assertFalse(g.user.is_anonymous)
-        self.assertEqual(2, get_user_id())
-
-        ensure_user_is_set(None)
-        self.assertTrue(hasattr(g, "user"))
-        self.assertFalse(g.user.is_anonymous)
-        self.assertEqual(2, get_user_id())
-
-        if g_user_is_set:
-            g.user = original_g_user
-        else:
-            del g.user
+            set_form_data.assert_called_once_with(form_data, "error", errors=errors)

Reply via email to