This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new df2ee5c  Adding app context wrapper to Celery tasks (#8653)
df2ee5c is described below

commit df2ee5cbcb8bc6e5cd310b9b9015509db744a256
Author: Craig Rueda <[email protected]>
AuthorDate: Wed Nov 27 07:06:06 2019 -0800

    Adding app context wrapper to Celery tasks (#8653)
    
    * Adding app context wrapper to Celery tasks
---
 superset/app.py       | 16 ++++++++++++++++
 superset/sql_lab.py   | 37 ++++++++++++++++++-------------------
 tests/celery_tests.py | 23 ++++++++++++++++++++++-
 3 files changed, 56 insertions(+), 20 deletions(-)

diff --git a/superset/app.py b/superset/app.py
index efa0622..5eb246e 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -96,6 +96,22 @@ class SupersetAppInitializer:
     def configure_celery(self) -> None:
         celery_app.config_from_object(self.config["CELERY_CONFIG"])
         celery_app.set_default()
+        flask_app = self.flask_app
+
+        # Here, we want to ensure that every call into Celery task has an app context
+        # setup properly
+        task_base = celery_app.Task
+
+        class AppContextTask(task_base):  # type: ignore
+            # pylint: disable=too-few-public-methods
+            abstract = True
+
+            # Grab each call into the task and set up an app context
+            def __call__(self, *args, **kwargs):
+                with flask_app.app_context():
+                    return task_base.__call__(self, *args, **kwargs)
+
+        celery_app.Task = AppContextTask
 
     @staticmethod
     def init_views() -> None:
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 842f292..d7fe551 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -54,6 +54,7 @@ stats_logger = config["STATS_LOGGER"]
 SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
 SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
 log_query = config["QUERY_LOGGER"]
+logger = logging.getLogger(__name__)
 
 
 class SqlLabException(Exception):
@@ -84,9 +85,9 @@ def handle_query_error(msg, query, session, payload=None):
 
 def get_query_backoff_handler(details):
     query_id = details["kwargs"]["query_id"]
-    logging.error(f"Query with id `{query_id}` could not be retrieved")
+    logger.error(f"Query with id `{query_id}` could not be retrieved")
     stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1))
-    logging.error(f"Query {query_id}: Sleeping for a sec before retrying...")
+    logger.error(f"Query {query_id}: Sleeping for a sec before retrying...")
 
 
 def get_query_giveup_handler(details):
@@ -128,7 +129,7 @@ def session_scope(nullpool):
         session.commit()
     except Exception as e:
         session.rollback()
-        logging.exception(e)
+        logger.exception(e)
         raise
     finally:
         session.close()
@@ -166,7 +167,7 @@ def get_sql_results(
                 expand_data=expand_data,
             )
         except Exception as e:
-            logging.exception(f"Query {query_id}: {e}")
+            logger.exception(f"Query {query_id}: {e}")
             stats_logger.incr("error_sqllab_unhandled")
             query = get_query(query_id, session)
             return handle_query_error(str(e), query, session)
@@ -224,13 +225,13 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
         query.executed_sql = sql
         session.commit()
         with stats_timing("sqllab.query.time_executing_query", stats_logger):
-            logging.info(f"Query {query_id}: Running query: \n{sql}")
+            logger.info(f"Query {query_id}: Running query: \n{sql}")
             db_engine_spec.execute(cursor, sql, async_=True)
-            logging.info(f"Query {query_id}: Handling cursor")
+            logger.info(f"Query {query_id}: Handling cursor")
             db_engine_spec.handle_cursor(cursor, query, session)
 
         with stats_timing("sqllab.query.time_fetching_results", stats_logger):
-            logging.debug(
+            logger.debug(
                 "Query {}: Fetching data for query object: {}".format(
                     query_id, query.to_dict()
                 )
@@ -238,16 +239,16 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
             data = db_engine_spec.fetch_data(cursor, query.limit)
 
     except SoftTimeLimitExceeded as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabTimeoutException(
             "SQL Lab timeout. This environment's policy is to kill queries "
             "after {} seconds.".format(SQLLAB_TIMEOUT)
         )
     except Exception as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabException(db_engine_spec.extract_error_message(e))
 
-    logging.debug(f"Query {query_id}: Fetching cursor description")
+    logger.debug(f"Query {query_id}: Fetching cursor description")
     cursor_description = cursor.description
     return SupersetDataFrame(data, cursor_description, db_engine_spec)
 
@@ -255,7 +256,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
 def _serialize_payload(
     payload: dict, use_msgpack: Optional[bool] = False
 ) -> Union[bytes, str]:
-    logging.debug(f"Serializing to msgpack: {use_msgpack}")
+    logger.debug(f"Serializing to msgpack: {use_msgpack}")
     if use_msgpack:
         return msgpack.dumps(payload, default=json_iso_dttm_ser, use_bin_type=True)
     else:
@@ -324,9 +325,9 @@ def execute_sql_statements(
     # Breaking down into multiple statements
     parsed_query = ParsedQuery(rendered_query)
     statements = parsed_query.get_statements()
-    logging.info(f"Query {query_id}: Executing {len(statements)} statement(s)")
+    logger.info(f"Query {query_id}: Executing {len(statements)} statement(s)")
 
-    logging.info(f"Query {query_id}: Set query to 'running'")
+    logger.info(f"Query {query_id}: Set query to 'running'")
     query.status = QueryStatus.RUNNING
     query.start_running_time = now_as_float()
     session.commit()
@@ -350,7 +351,7 @@ def execute_sql_statements(
 
                 # Run statement
                 msg = f"Running statement {i+1} out of {statement_count}"
-                logging.info(f"Query {query_id}: {msg}")
+                logger.info(f"Query {query_id}: {msg}")
                 query.set_extra_json_key("progress", msg)
                 session.commit()
                 try:
@@ -396,9 +397,7 @@ def execute_sql_statements(
 
     if store_results and results_backend:
         key = str(uuid.uuid4())
-        logging.info(
-            f"Query {query_id}: Storing results in results backend, key: {key}"
-        )
+        logger.info(f"Query {query_id}: Storing results in results backend, key: {key}")
         with stats_timing("sqllab.query.results_backend_write", stats_logger):
             with stats_timing(
                 "sqllab.query.results_backend_write_serialization", stats_logger
@@ -411,10 +410,10 @@ def execute_sql_statements(
                 cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]
 
             compressed = zlib_compress(serialized_payload)
-            logging.debug(
+            logger.debug(
                 f"*** serialized payload size: {getsizeof(serialized_payload)}"
             )
-            logging.debug(f"*** compressed payload size: {getsizeof(compressed)}")
+            logger.debug(f"*** compressed payload size: {getsizeof(compressed)}")
             results_backend.set(key, compressed, cache_timeout)
         query.results_key = key
 
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 954c84d..521be0b 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -23,10 +23,14 @@ import time
 import unittest
 import unittest.mock as mock
 
-from tests.test_app import app  # isort:skip
+import flask
+from flask import current_app
+
+from tests.test_app import app
 from superset import db, sql_lab
 from superset.dataframe import SupersetDataFrame
 from superset.db_engine_specs.base import BaseEngineSpec
+from superset.extensions import celery_app
 from superset.models.helpers import QueryStatus
 from superset.models.sql_lab import Query
 from superset.sql_parse import ParsedQuery
@@ -69,6 +73,23 @@ class UtilityFunctionTests(SupersetTestCase):
         )
 
 
+class AppContextTests(SupersetTestCase):
+    def test_in_app_context(self):
+        @celery_app.task()
+        def my_task():
+            self.assertTrue(current_app)
+
+        # Make sure we can call tasks with an app already setup
+        my_task()
+
+        # Make sure the app gets pushed onto the stack properly
+        try:
+            popped_app = flask._app_ctx_stack.pop()
+            my_task()
+        finally:
+            flask._app_ctx_stack.push(popped_app)
+
+
 class CeleryTestCase(SupersetTestCase):
     def get_query_by_name(self, sql):
         session = db.session

Reply via email to