This is an automated email from the ASF dual-hosted git repository.

bkyryliuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new ac2937a  fix: use nullpool in the celery workers (#10819)
ac2937a is described below

commit ac2937a6c58b63c2723917a1cd111a8363a728a0
Author: Bogdan <[email protected]>
AuthorDate: Thu Sep 10 13:29:57 2020 -0700

    fix: use nullpool in the celery workers (#10819)
    
    * Use nullpool in the celery workers
    
    * Address comments
    
    Co-authored-by: bogdan kyryliuk <[email protected]>
---
 superset/cli.py                   |  12 +-
 superset/sql_lab.py               |  47 +-------
 superset/tasks/alerts/observer.py |  12 +-
 superset/tasks/schedules.py       | 247 ++++++++++++++++++++------------------
 superset/utils/celery.py          |  57 +++++++++
 tests/alerts_tests.py             |  48 ++++----
 tests/schedules_test.py           |   4 +
 7 files changed, 234 insertions(+), 193 deletions(-)

diff --git a/superset/cli.py b/superset/cli.py
index ef17682..f0f7f1e 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -34,6 +34,7 @@ from superset import app, appbuilder, security_manager
 from superset.app import create_app
 from superset.extensions import celery_app, db
 from superset.utils import core as utils
+from superset.utils.celery import session_scope
 from superset.utils.urls import get_url_path
 
 logger = logging.getLogger(__name__)
@@ -619,6 +620,11 @@ def alert() -> None:
     from superset.tasks.schedules import schedule_window
 
     click.secho("Processing one alert loop", fg="green")
-    schedule_window(
-        ScheduleType.alert, datetime.now() - timedelta(1000), datetime.now(), 6000
-    )
+    with session_scope(nullpool=True) as session:
+        schedule_window(
+            report_type=ScheduleType.alert,
+            start_at=datetime.now() - timedelta(1000),
+            stop_at=datetime.now(),
+            resolution=6000,
+            session=session,
+        )
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 09be5e8..796ddba 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -19,33 +19,25 @@ import uuid
 from contextlib import closing
 from datetime import datetime
 from sys import getsizeof
-from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
 
 import backoff
 import msgpack
 import pyarrow as pa
 import simplejson as json
-import sqlalchemy
 from celery.exceptions import SoftTimeLimitExceeded
 from celery.task.base import Task
-from contextlib2 import contextmanager
 from flask_babel import lazy_gettext as _
-from sqlalchemy.orm import Session, sessionmaker
-from sqlalchemy.pool import NullPool
-
-from superset import (
-    app,
-    db,
-    results_backend,
-    results_backend_use_msgpack,
-    security_manager,
-)
+from sqlalchemy.orm import Session
+
+from superset import app, results_backend, results_backend_use_msgpack, security_manager
 from superset.dataframe import df_to_records
 from superset.db_engine_specs import BaseEngineSpec
 from superset.extensions import celery_app
 from superset.models.sql_lab import Query
 from superset.result_set import SupersetResultSet
 from superset.sql_parse import ParsedQuery
+from superset.utils.celery import session_scope
 from superset.utils.core import (
     json_iso_dttm_ser,
     QuerySource,
@@ -121,35 +113,6 @@ def get_query(query_id: int, session: Session) -> Query:
         raise SqlLabException("Failed at getting query")
 
 
-@contextmanager
-def session_scope(nullpool: bool) -> Iterator[Session]:
-    """Provide a transactional scope around a series of operations."""
-    database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
-    if "sqlite" in database_uri:
-        logger.warning(
-            "SQLite Database support for metadata databases will be removed \
-            in a future version of Superset."
-        )
-    if nullpool:
-        engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool)
-        session_class = sessionmaker()
-        session_class.configure(bind=engine)
-        session = session_class()
-    else:
-        session = db.session()
-        session.commit()  # HACK
-
-    try:
-        yield session
-        session.commit()
-    except Exception as ex:
-        session.rollback()
-        logger.exception(ex)
-        raise
-    finally:
-        session.close()
-
-
 @celery_app.task(
     name="sql_lab.get_sql_results",
     bind=True,
diff --git a/superset/tasks/alerts/observer.py b/superset/tasks/alerts/observer.py
index f7c5373..34ff668 100644
--- a/superset/tasks/alerts/observer.py
+++ b/superset/tasks/alerts/observer.py
@@ -20,22 +20,24 @@ from datetime import datetime
 from typing import Optional
 
 import pandas as pd
+from sqlalchemy.orm import Session
 
-from superset import db
 from superset.models.alerts import Alert, SQLObservation
 from superset.sql_parse import ParsedQuery
 
 logger = logging.getLogger("tasks.email_reports")
 
 
-def observe(alert_id: int) -> Optional[str]:
+# Session needs to be passed along in the celery workers and db.session cannot be used.
+# For more info see: https://github.com/apache/incubator-superset/issues/10530
+def observe(alert_id: int, session: Session) -> Optional[str]:
     """
     Runs the SQL query in an alert's SQLObserver and then
     stores the result in a SQLObservation.
     Returns an error message if the observer value was not valid
     """
 
-    alert = db.session.query(Alert).filter_by(id=alert_id).one()
+    alert = session.query(Alert).filter_by(id=alert_id).one()
     sql_observer = alert.sql_observer[0]
 
     value = None
@@ -57,8 +59,8 @@ def observe(alert_id: int) -> Optional[str]:
         error_msg=error_msg,
     )
 
-    db.session.add(observation)
-    db.session.commit()
+    session.add(observation)
+    session.commit()
 
     return error_msg
 
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index 9643f09..7c5ad4a 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -47,8 +47,9 @@ from selenium.common.exceptions import WebDriverException
 from selenium.webdriver import chrome, firefox
 from selenium.webdriver.remote.webdriver import WebDriver
 from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError
+from sqlalchemy.orm import Session
 
-from superset import app, db, security_manager, thumbnail_cache
+from superset import app, security_manager, thumbnail_cache
 from superset.extensions import celery_app, machine_auth_provider_factory
 from superset.models.alerts import Alert, AlertLog
 from superset.models.dashboard import Dashboard
@@ -62,6 +63,7 @@ from superset.models.slice import Slice
 from superset.tasks.alerts.observer import observe
 from superset.tasks.alerts.validator import get_validator_function
 from superset.tasks.slack_util import deliver_slack_msg
+from superset.utils.celery import session_scope
 from superset.utils.core import get_email_address_list, send_email_smtp
 from superset.utils.screenshots import ChartScreenshot, WebDriverProxy
 from superset.utils.urls import get_url_path
@@ -225,7 +227,7 @@ def destroy_webdriver(
         pass
 
 
-def deliver_dashboard(
+def deliver_dashboard(  # pylint: disable=too-many-locals
     dashboard_id: int,
     recipients: Optional[str],
     slack_channel: Optional[str],
@@ -236,69 +238,70 @@ def deliver_dashboard(
     """
     Given a schedule, delivery the dashboard as an email report
     """
-    dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one()
+    with session_scope(nullpool=True) as session:
+        dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one()
 
-    dashboard_url = _get_url_path(
-        "Superset.dashboard", dashboard_id_or_slug=dashboard.id
-    )
-    dashboard_url_user_friendly = _get_url_path(
-        "Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id
-    )
-
-    # Create a driver, fetch the page, wait for the page to render
-    driver = create_webdriver()
-    window = config["WEBDRIVER_WINDOW"]["dashboard"]
-    driver.set_window_size(*window)
-    driver.get(dashboard_url)
-    time.sleep(EMAIL_PAGE_RENDER_WAIT)
-
-    # Set up a function to retry once for the element.
-    # This is buggy in certain selenium versions with firefox driver
-    get_element = getattr(driver, "find_element_by_class_name")
-    element = retry_call(
-        get_element, fargs=["grid-container"], tries=2, delay=EMAIL_PAGE_RENDER_WAIT
-    )
-
-    try:
-        screenshot = element.screenshot_as_png
-    except WebDriverException:
-        # Some webdrivers do not support screenshots for elements.
-        # In such cases, take a screenshot of the entire page.
-        screenshot = driver.screenshot()  # pylint: disable=no-member
-    finally:
-        destroy_webdriver(driver)
-
-    # Generate the email body and attachments
-    report_content = _generate_report_content(
-        delivery_type,
-        screenshot,
-        dashboard.dashboard_title,
-        dashboard_url_user_friendly,
-    )
+        dashboard_url = _get_url_path(
+            "Superset.dashboard", dashboard_id_or_slug=dashboard.id
+        )
+        dashboard_url_user_friendly = _get_url_path(
+            "Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id
+        )
 
-    subject = __(
-        "%(prefix)s %(title)s",
-        prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"],
-        title=dashboard.dashboard_title,
-    )
+        # Create a driver, fetch the page, wait for the page to render
+        driver = create_webdriver()
+        window = config["WEBDRIVER_WINDOW"]["dashboard"]
+        driver.set_window_size(*window)
+        driver.get(dashboard_url)
+        time.sleep(EMAIL_PAGE_RENDER_WAIT)
+
+        # Set up a function to retry once for the element.
+        # This is buggy in certain selenium versions with firefox driver
+        get_element = getattr(driver, "find_element_by_class_name")
+        element = retry_call(
+            get_element, fargs=["grid-container"], tries=2, delay=EMAIL_PAGE_RENDER_WAIT
+        )
 
-    if recipients:
-        _deliver_email(
-            recipients,
-            deliver_as_group,
-            subject,
-            report_content.body,
-            report_content.data,
-            report_content.images,
+        try:
+            screenshot = element.screenshot_as_png
+        except WebDriverException:
+            # Some webdrivers do not support screenshots for elements.
+            # In such cases, take a screenshot of the entire page.
+            screenshot = driver.screenshot()  # pylint: disable=no-member
+        finally:
+            destroy_webdriver(driver)
+
+        # Generate the email body and attachments
+        report_content = _generate_report_content(
+            delivery_type,
+            screenshot,
+            dashboard.dashboard_title,
+            dashboard_url_user_friendly,
         )
-    if slack_channel:
-        deliver_slack_msg(
-            slack_channel,
-            subject,
-            report_content.slack_message,
-            report_content.slack_attachment,
+
+        subject = __(
+            "%(prefix)s %(title)s",
+            prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"],
+            title=dashboard.dashboard_title,
         )
 
+        if recipients:
+            _deliver_email(
+                recipients,
+                deliver_as_group,
+                subject,
+                report_content.body,
+                report_content.data,
+                report_content.images,
+            )
+        if slack_channel:
+            deliver_slack_msg(
+                slack_channel,
+                subject,
+                report_content.slack_message,
+                report_content.slack_attachment,
+            )
+
 
 def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportContent:
     slice_url = _get_url_path(
@@ -362,8 +365,8 @@ def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportConte
     return ReportContent(body, data, None, slack_message, content)
 
 
-def _get_slice_screenshot(slice_id: int) -> ScreenshotData:
-    slice_obj = db.session.query(Slice).get(slice_id)
+def _get_slice_screenshot(slice_id: int, session: Session) -> ScreenshotData:
+    slice_obj = session.query(Slice).get(slice_id)
 
     chart_url = get_url_path("Superset.slice", slice_id=slice_obj.id, standalone="true")
     screenshot = ChartScreenshot(chart_url, slice_obj.digest)
@@ -376,7 +379,7 @@ def _get_slice_screenshot(slice_id: int) -> ScreenshotData:
         user=user, cache=thumbnail_cache, force=True,
     )
 
-    db.session.commit()
+    session.commit()
     return ScreenshotData(image_url, image_data)
 
 
@@ -427,11 +430,12 @@ def deliver_slice(  # pylint: disable=too-many-arguments
     delivery_type: EmailDeliveryType,
     email_format: SliceEmailReportFormat,
     deliver_as_group: bool,
+    session: Session,
 ) -> None:
     """
     Given a schedule, delivery the slice as an email report
     """
-    slc = db.session.query(Slice).filter_by(id=slice_id).one()
+    slc = session.query(Slice).filter_by(id=slice_id).one()
 
     if email_format == SliceEmailReportFormat.data:
         report_content = _get_slice_data(slc, delivery_type)
@@ -477,38 +481,42 @@ def schedule_email_report(  # pylint: disable=unused-argument
     slack_channel: Optional[str] = None,
 ) -> None:
     model_cls = get_scheduler_model(report_type)
-    schedule = db.create_scoped_session().query(model_cls).get(schedule_id)
+    with session_scope(nullpool=True) as session:
+        schedule = session.query(model_cls).get(schedule_id)
 
-    # The user may have disabled the schedule. If so, ignore this
-    if not schedule or not schedule.active:
-        logger.info("Ignoring deactivated schedule")
-        return
-
-    recipients = recipients or schedule.recipients
-    slack_channel = slack_channel or schedule.slack_channel
-    logger.info(
-        "Starting report for slack: %s and recipients: %s.", slack_channel, recipients
-    )
+        # The user may have disabled the schedule. If so, ignore this
+        if not schedule or not schedule.active:
+            logger.info("Ignoring deactivated schedule")
+            return
 
-    if report_type == ScheduleType.dashboard:
-        deliver_dashboard(
-            schedule.dashboard_id,
-            recipients,
+        recipients = recipients or schedule.recipients
+        slack_channel = slack_channel or schedule.slack_channel
+        logger.info(
+            "Starting report for slack: %s and recipients: %s.",
             slack_channel,
-            schedule.delivery_type,
-            schedule.deliver_as_group,
-        )
-    elif report_type == ScheduleType.slice:
-        deliver_slice(
-            schedule.slice_id,
             recipients,
-            slack_channel,
-            schedule.delivery_type,
-            schedule.email_format,
-            schedule.deliver_as_group,
         )
-    else:
-        raise RuntimeError("Unknown report type")
+
+        if report_type == ScheduleType.dashboard:
+            deliver_dashboard(
+                schedule.dashboard_id,
+                recipients,
+                slack_channel,
+                schedule.delivery_type,
+                schedule.deliver_as_group,
+            )
+        elif report_type == ScheduleType.slice:
+            deliver_slice(
+                schedule.slice_id,
+                recipients,
+                slack_channel,
+                schedule.delivery_type,
+                schedule.email_format,
+                schedule.deliver_as_group,
+                session,
+            )
+        else:
+            raise RuntimeError("Unknown report type")
 
 
 @celery_app.task(
@@ -529,9 +537,8 @@ def schedule_alert_query(  # pylint: disable=unused-argument
     slack_channel: Optional[str] = None,
 ) -> None:
     model_cls = get_scheduler_model(report_type)
-
-    try:
-        schedule = db.session.query(model_cls).get(schedule_id)
+    with session_scope(nullpool=True) as session:
+        schedule = session.query(model_cls).get(schedule_id)
 
         # The user may have disabled the schedule. If so, ignore this
         if not schedule or not schedule.active:
@@ -539,15 +546,11 @@ def schedule_alert_query(  # pylint: disable=unused-argument
             return
 
         if report_type == ScheduleType.alert:
-            evaluate_alert(schedule.id, schedule.label, recipients, slack_channel)
+            evaluate_alert(
+                schedule.id, schedule.label, session, recipients, slack_channel
+            )
         else:
             raise RuntimeError("Unknown report type")
-    except NoSuchColumnError as column_error:
-        stats_logger.incr("run_alert_task.error.nosuchcolumnerror")
-        raise column_error
-    except ResourceClosedError as resource_error:
-        stats_logger.incr("run_alert_task.error.resourceclosederror")
-        raise resource_error
 
 
 class AlertState:
@@ -558,6 +561,7 @@ class AlertState:
 
 def deliver_alert(
     alert_id: int,
+    session: Session,
     recipients: Optional[str] = None,
     slack_channel: Optional[str] = None,
 ) -> None:
@@ -566,7 +570,7 @@ def deliver_alert(
     to its respective email and slack recipients
     """
 
-    alert = db.session.query(Alert).get(alert_id)
+    alert = session.query(Alert).get(alert_id)
 
     logging.info("Triggering alert: %s", alert)
 
@@ -588,7 +592,7 @@ def deliver_alert(
             str(alert.observations[-1].value),
             validation_error_message,
             _get_url_path("AlertModelView.show", user_friendly=True, pk=alert_id),
-            _get_slice_screenshot(alert.slice.id),
+            _get_slice_screenshot(alert.slice.id, session),
         )
     else:
         # TODO: dashboard delivery!
@@ -668,6 +672,7 @@ def deliver_slack_alert(alert_content: AlertContent, slack_channel: str) -> None
 def evaluate_alert(
     alert_id: int,
     label: str,
+    session: Session,
     recipients: Optional[str] = None,
     slack_channel: Optional[str] = None,
 ) -> None:
@@ -680,7 +685,7 @@ def evaluate_alert(
 
     try:
         logger.info("Querying observers for alert <%s:%s>", alert_id, label)
-        error_msg = observe(alert_id)
+        error_msg = observe(alert_id, session)
         if error_msg:
             state = AlertState.ERROR
             logging.error(error_msg)
@@ -694,17 +699,17 @@ def evaluate_alert(
     if state != AlertState.ERROR:
         # Don't validate alert on test runs since it may not be triggered
         if recipients or slack_channel:
-            deliver_alert(alert_id, recipients, slack_channel)
+            deliver_alert(alert_id, session, recipients, slack_channel)
             state = AlertState.TRIGGER
         # Validate during regular workflow and deliver only if triggered
-        elif validate_observations(alert_id, label):
-            deliver_alert(alert_id, recipients, slack_channel)
+        elif validate_observations(alert_id, label, session):
+            deliver_alert(alert_id, session, recipients, slack_channel)
             state = AlertState.TRIGGER
         else:
             state = AlertState.PASS
 
-    db.session.commit()
-    alert = db.session.query(Alert).get(alert_id)
+    session.commit()
+    alert = session.query(Alert).get(alert_id)
     if state != AlertState.ERROR:
         alert.last_eval_dttm = dttm_end
     alert.last_state = state
@@ -716,10 +721,10 @@ def evaluate_alert(
             state=state,
         )
     )
-    db.session.commit()
+    session.commit()
 
 
-def validate_observations(alert_id: int, label: str) -> bool:
+def validate_observations(alert_id: int, label: str, session: Session) -> bool:
     """
     Runs an alert's validators to check if it should be triggered or not
     If so, return the name of the validator that returned true
@@ -727,7 +732,7 @@ def validate_observations(alert_id: int, label: str) -> bool:
 
     logger.info("Validating observations for alert <%s:%s>", alert_id, label)
 
-    alert = db.session.query(Alert).get(alert_id)
+    alert = session.query(Alert).get(alert_id)
     if alert.validators:
         validator = alert.validators[0]
         validate = get_validator_function(validator.validator_type)
@@ -760,7 +765,11 @@ def next_schedules(
 
 
 def schedule_window(
-    report_type: str, start_at: datetime, stop_at: datetime, resolution: int
+    report_type: str,
+    start_at: datetime,
+    stop_at: datetime,
+    resolution: int,
+    session: Session,
 ) -> None:
     """
     Find all active schedules and schedule celery tasks for
@@ -772,8 +781,7 @@ def schedule_window(
     if not model_cls:
         return None
 
-    dbsession = db.create_scoped_session()
-    schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
+    schedules = session.query(model_cls).filter(model_cls.active.is_(True))
 
     for schedule in schedules:
         logging.info("Processing schedule %s", schedule)
@@ -810,7 +818,6 @@ def get_scheduler_action(report_type: str) -> Optional[Callable[..., Any]]:
 @celery_app.task(name="email_reports.schedule_hourly")
 def schedule_hourly() -> None:
     """ Celery beat job meant to be invoked hourly """
-
     if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
         logger.info("Scheduled email reports not enabled in config")
         return
@@ -820,8 +827,10 @@ def schedule_hourly() -> None:
     # Get the top of the hour
     start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0, minute=0)
     stop_at = start_at + timedelta(seconds=3600)
-    schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution)
-    schedule_window(ScheduleType.slice, start_at, stop_at, resolution)
+
+    with session_scope(nullpool=True) as session:
+        schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution, session)
+        schedule_window(ScheduleType.slice, start_at, stop_at, resolution, session)
 
 
 @celery_app.task(name="alerts.schedule_check")
@@ -833,5 +842,5 @@ def schedule_alerts() -> None:
         seconds=3600
     )  # process any missed tasks in the past hour
     stop_at = now + timedelta(seconds=1)
-
-    schedule_window(ScheduleType.alert, start_at, stop_at, resolution)
+    with session_scope(nullpool=True) as session:
+        schedule_window(ScheduleType.alert, start_at, stop_at, resolution, session)
diff --git a/superset/utils/celery.py b/superset/utils/celery.py
new file mode 100644
index 0000000..1692e55
--- /dev/null
+++ b/superset/utils/celery.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from typing import Iterator
+
+import sqlalchemy as sa
+from contextlib2 import contextmanager
+from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.pool import NullPool
+
+from superset import app, db
+
+logger = logging.getLogger(__name__)
+
+# Null pool is used for the celery workers due process forking side effects.
+# For more info see: https://github.com/apache/incubator-superset/issues/10530
+@contextmanager
+def session_scope(nullpool: bool) -> Iterator[Session]:
+    """Provide a transactional scope around a series of operations."""
+    database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
+    if "sqlite" in database_uri:
+        logger.warning(
+            "SQLite Database support for metadata databases will be removed \
+            in a future version of Superset."
+        )
+    if nullpool:
+        engine = sa.create_engine(database_uri, poolclass=NullPool)
+        session_class = sessionmaker()
+        session_class.configure(bind=engine)
+        session = session_class()
+    else:
+        session = db.session()
+        session.commit()  # HACK
+
+    try:
+        yield session
+        session.commit()
+    except Exception as ex:
+        session.rollback()
+        logger.exception(ex)
+        raise
+    finally:
+        session.close()
diff --git a/tests/alerts_tests.py b/tests/alerts_tests.py
index a222610..5f6464c 100644
--- a/tests/alerts_tests.py
+++ b/tests/alerts_tests.py
@@ -112,37 +112,37 @@ def test_alert_observer(setup_database):
 
     # Test SQLObserver with int SQL return
     alert1 = create_alert(dbsession, "SELECT 55")
-    observe(alert1.id)
+    observe(alert1.id, dbsession)
     assert alert1.sql_observer[0].observations[-1].value == 55.0
     assert alert1.sql_observer[0].observations[-1].error_msg is None
 
     # Test SQLObserver with double SQL return
     alert2 = create_alert(dbsession, "SELECT 30.0 as wage")
-    observe(alert2.id)
+    observe(alert2.id, dbsession)
     assert alert2.sql_observer[0].observations[-1].value == 30.0
     assert alert2.sql_observer[0].observations[-1].error_msg is None
 
     # Test SQLObserver with NULL result
     alert3 = create_alert(dbsession, "SELECT null as null_result")
-    observe(alert3.id)
+    observe(alert3.id, dbsession)
     assert alert3.sql_observer[0].observations[-1].value is None
     assert alert3.sql_observer[0].observations[-1].error_msg is None
 
     # Test SQLObserver with empty SQL return
     alert4 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1")
-    observe(alert4.id)
+    observe(alert4.id, dbsession)
     assert alert4.sql_observer[0].observations[-1].value is None
     assert alert4.sql_observer[0].observations[-1].error_msg is not None
 
     # Test SQLObserver with str result
     alert5 = create_alert(dbsession, "SELECT 'test_string' as string_value")
-    observe(alert5.id)
+    observe(alert5.id, dbsession)
     assert alert5.sql_observer[0].observations[-1].value is None
     assert alert5.sql_observer[0].observations[-1].error_msg is not None
 
     # Test SQLObserver with two row result
     alert6 = create_alert(dbsession, "SELECT first FROM test_table")
-    observe(alert6.id)
+    observe(alert6.id, dbsession)
     assert alert6.sql_observer[0].observations[-1].value is None
     assert alert6.sql_observer[0].observations[-1].error_msg is not None
 
@@ -150,7 +150,7 @@ def test_alert_observer(setup_database):
     alert7 = create_alert(
         dbsession, "SELECT first, second FROM test_table WHERE first = 1"
     )
-    observe(alert7.id)
+    observe(alert7.id, dbsession)
     assert alert7.sql_observer[0].observations[-1].value is None
     assert alert7.sql_observer[0].observations[-1].error_msg is not None
 
@@ -161,22 +161,22 @@ def test_evaluate_alert(mock_deliver_alert, setup_database):
 
     # Test error with Observer SQL statement
     alert1 = create_alert(dbsession, "$%^&")
-    evaluate_alert(alert1.id, alert1.label)
+    evaluate_alert(alert1.id, alert1.label, dbsession)
     assert alert1.logs[-1].state == AlertState.ERROR
 
     # Test error with alert lacking observer
     alert2 = dbsession.query(Alert).filter_by(label="No Observer").one()
-    evaluate_alert(alert2.id, alert2.label)
+    evaluate_alert(alert2.id, alert2.label, dbsession)
     assert alert2.logs[-1].state == AlertState.ERROR
 
     # Test pass on alert lacking validator
     alert3 = create_alert(dbsession, "SELECT 55")
-    evaluate_alert(alert3.id, alert3.label)
+    evaluate_alert(alert3.id, alert3.label, dbsession)
     assert alert3.logs[-1].state == AlertState.PASS
 
     # Test triggering successful alert
     alert4 = create_alert(dbsession, "SELECT 55", "not null", "{}")
-    evaluate_alert(alert4.id, alert4.label)
+    evaluate_alert(alert4.id, alert4.label, dbsession)
     assert mock_deliver_alert.call_count == 1
     assert alert4.logs[-1].state == AlertState.TRIGGER
 
@@ -214,17 +214,17 @@ def test_not_null_validator(setup_database):
 
     # Test passing SQLObserver with 'null' SQL result
     alert1 = create_alert(dbsession, "SELECT 0")
-    observe(alert1.id)
+    observe(alert1.id, dbsession)
     assert not_null_validator(alert1.sql_observer[0], "{}") is False
 
     # Test passing SQLObserver with empty SQL result
     alert2 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1")
-    observe(alert2.id)
+    observe(alert2.id, dbsession)
     assert not_null_validator(alert2.sql_observer[0], "{}") is False
 
     # Test triggering alert with non-null SQL result
     alert3 = create_alert(dbsession, "SELECT 55")
-    observe(alert3.id)
+    observe(alert3.id, dbsession)
     assert not_null_validator(alert3.sql_observer[0], "{}") is True
 
 
@@ -233,7 +233,7 @@ def test_operator_validator(setup_database):
 
     # Test passing SQLObserver with empty SQL result
     alert1 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1")
-    observe(alert1.id)
+    observe(alert1.id, dbsession)
     assert (
         operator_validator(alert1.sql_observer[0], '{"op": ">=", "threshold": 60}')
         is False
@@ -241,7 +241,7 @@ def test_operator_validator(setup_database):
 
     # Test passing SQLObserver with result that doesn't pass a greater than threshold
     alert2 = create_alert(dbsession, "SELECT 55")
-    observe(alert2.id)
+    observe(alert2.id, dbsession)
     assert (
         operator_validator(alert2.sql_observer[0], '{"op": ">=", "threshold": 60}')
         is False
@@ -283,23 +283,23 @@ def test_validate_observations(setup_database):
 
     # Test False on alert with no validator
     alert1 = create_alert(dbsession, "SELECT 55")
-    assert validate_observations(alert1.id, alert1.label) is False
+    assert validate_observations(alert1.id, alert1.label, dbsession) is False
 
     # Test False on alert with no observations
     alert2 = create_alert(dbsession, "SELECT 55", "not null", "{}")
-    assert validate_observations(alert2.id, alert2.label) is False
+    assert validate_observations(alert2.id, alert2.label, dbsession) is False
 
     # Test False on alert that shouldnt be triggered
     alert3 = create_alert(dbsession, "SELECT 0", "not null", "{}")
-    observe(alert3.id)
-    assert validate_observations(alert3.id, alert3.label) is False
+    observe(alert3.id, dbsession)
+    assert validate_observations(alert3.id, alert3.label, dbsession) is False
 
     # Test True on alert that should be triggered
     alert4 = create_alert(
         dbsession, "SELECT 55", "operator", '{"op": "<=", "threshold": 60}'
     )
-    observe(alert4.id)
-    assert validate_observations(alert4.id, alert4.label) is True
+    observe(alert4.id, dbsession)
+    assert validate_observations(alert4.id, alert4.label, dbsession) is True
 
 
 @patch("superset.tasks.slack_util.WebClient.files_upload")
@@ -311,7 +311,7 @@ def test_deliver_alert_screenshot(
 ):
     dbsession = setup_database
     alert = create_alert(dbsession, "SELECT 55", "not null", "{}")
-    observe(alert.id)
+    observe(alert.id, dbsession)
 
     screenshot = read_fixture("sample.png")
     screenshot_mock.return_value = screenshot
@@ -322,7 +322,7 @@ def test_deliver_alert_screenshot(
         f"http://0.0.0.0:8080/superset/slice/{alert.slice_id}/",
     ]
 
-    deliver_alert(alert_id=alert.id)
+    deliver_alert(alert.id, dbsession)
     assert email_mock.call_args[1]["images"]["screenshot"] == screenshot
     assert file_upload_mock.call_args[1] == {
         "channels": alert.slack_channel,
diff --git a/tests/schedules_test.py b/tests/schedules_test.py
index 77f7070..88b6d1f 100644
--- a/tests/schedules_test.py
+++ b/tests/schedules_test.py
@@ -366,6 +366,7 @@ class TestSchedules(SupersetTestCase):
             schedule.delivery_type,
             schedule.email_format,
             schedule.deliver_as_group,
+            db.session,
         )
         mtime.sleep.assert_called_once()
         driver.screenshot.assert_not_called()
@@ -418,6 +419,7 @@ class TestSchedules(SupersetTestCase):
             schedule.delivery_type,
             schedule.email_format,
             schedule.deliver_as_group,
+            db.session,
         )
 
         mtime.sleep.assert_called_once()
@@ -466,6 +468,7 @@ class TestSchedules(SupersetTestCase):
             schedule.delivery_type,
             schedule.email_format,
             schedule.deliver_as_group,
+            db.session,
         )
 
         send_email_smtp.assert_called_once()
@@ -510,6 +513,7 @@ class TestSchedules(SupersetTestCase):
             schedule.delivery_type,
             schedule.email_format,
             schedule.deliver_as_group,
+            db.session,
         )
 
         send_email_smtp.assert_called_once()

Reply via email to