This is an automated email from the ASF dual-hosted git repository.
bkyryliuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new ac2937a fix: use nullpool in the celery workers (#10819)
ac2937a is described below
commit ac2937a6c58b63c2723917a1cd111a8363a728a0
Author: Bogdan <[email protected]>
AuthorDate: Thu Sep 10 13:29:57 2020 -0700
fix: use nullpool in the celery workers (#10819)
* Use nullpool in the celery workers
* Address comments
Co-authored-by: bogdan kyryliuk <[email protected]>
---
superset/cli.py | 12 +-
superset/sql_lab.py | 47 +-------
superset/tasks/alerts/observer.py | 12 +-
superset/tasks/schedules.py | 247 ++++++++++++++++++++------------------
superset/utils/celery.py | 57 +++++++++
tests/alerts_tests.py | 48 ++++----
tests/schedules_test.py | 4 +
7 files changed, 234 insertions(+), 193 deletions(-)
diff --git a/superset/cli.py b/superset/cli.py
index ef17682..f0f7f1e 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -34,6 +34,7 @@ from superset import app, appbuilder, security_manager
from superset.app import create_app
from superset.extensions import celery_app, db
from superset.utils import core as utils
+from superset.utils.celery import session_scope
from superset.utils.urls import get_url_path
logger = logging.getLogger(__name__)
@@ -619,6 +620,11 @@ def alert() -> None:
from superset.tasks.schedules import schedule_window
click.secho("Processing one alert loop", fg="green")
- schedule_window(
- ScheduleType.alert, datetime.now() - timedelta(1000), datetime.now(), 6000
- )
+ with session_scope(nullpool=True) as session:
+ schedule_window(
+ report_type=ScheduleType.alert,
+ start_at=datetime.now() - timedelta(1000),
+ stop_at=datetime.now(),
+ resolution=6000,
+ session=session,
+ )
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 09be5e8..796ddba 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -19,33 +19,25 @@ import uuid
from contextlib import closing
from datetime import datetime
from sys import getsizeof
-from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
import backoff
import msgpack
import pyarrow as pa
import simplejson as json
-import sqlalchemy
from celery.exceptions import SoftTimeLimitExceeded
from celery.task.base import Task
-from contextlib2 import contextmanager
from flask_babel import lazy_gettext as _
-from sqlalchemy.orm import Session, sessionmaker
-from sqlalchemy.pool import NullPool
-
-from superset import (
- app,
- db,
- results_backend,
- results_backend_use_msgpack,
- security_manager,
-)
+from sqlalchemy.orm import Session
+
+from superset import app, results_backend, results_backend_use_msgpack, security_manager
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
+from superset.utils.celery import session_scope
from superset.utils.core import (
json_iso_dttm_ser,
QuerySource,
@@ -121,35 +113,6 @@ def get_query(query_id: int, session: Session) -> Query:
raise SqlLabException("Failed at getting query")
-@contextmanager
-def session_scope(nullpool: bool) -> Iterator[Session]:
- """Provide a transactional scope around a series of operations."""
- database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
- if "sqlite" in database_uri:
- logger.warning(
- "SQLite Database support for metadata databases will be removed \
- in a future version of Superset."
- )
- if nullpool:
- engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool)
- session_class = sessionmaker()
- session_class.configure(bind=engine)
- session = session_class()
- else:
- session = db.session()
- session.commit() # HACK
-
- try:
- yield session
- session.commit()
- except Exception as ex:
- session.rollback()
- logger.exception(ex)
- raise
- finally:
- session.close()
-
-
@celery_app.task(
name="sql_lab.get_sql_results",
bind=True,
diff --git a/superset/tasks/alerts/observer.py b/superset/tasks/alerts/observer.py
index f7c5373..34ff668 100644
--- a/superset/tasks/alerts/observer.py
+++ b/superset/tasks/alerts/observer.py
@@ -20,22 +20,24 @@ from datetime import datetime
from typing import Optional
import pandas as pd
+from sqlalchemy.orm import Session
-from superset import db
from superset.models.alerts import Alert, SQLObservation
from superset.sql_parse import ParsedQuery
logger = logging.getLogger("tasks.email_reports")
-def observe(alert_id: int) -> Optional[str]:
+# Session needs to be passed along in the celery workers and db.session cannot be used.
+# For more info see: https://github.com/apache/incubator-superset/issues/10530
+def observe(alert_id: int, session: Session) -> Optional[str]:
"""
Runs the SQL query in an alert's SQLObserver and then
stores the result in a SQLObservation.
Returns an error message if the observer value was not valid
"""
- alert = db.session.query(Alert).filter_by(id=alert_id).one()
+ alert = session.query(Alert).filter_by(id=alert_id).one()
sql_observer = alert.sql_observer[0]
value = None
@@ -57,8 +59,8 @@ def observe(alert_id: int) -> Optional[str]:
error_msg=error_msg,
)
- db.session.add(observation)
- db.session.commit()
+ session.add(observation)
+ session.commit()
return error_msg
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index 9643f09..7c5ad4a 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -47,8 +47,9 @@ from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
from selenium.webdriver.remote.webdriver import WebDriver
from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError
+from sqlalchemy.orm import Session
-from superset import app, db, security_manager, thumbnail_cache
+from superset import app, security_manager, thumbnail_cache
from superset.extensions import celery_app, machine_auth_provider_factory
from superset.models.alerts import Alert, AlertLog
from superset.models.dashboard import Dashboard
@@ -62,6 +63,7 @@ from superset.models.slice import Slice
from superset.tasks.alerts.observer import observe
from superset.tasks.alerts.validator import get_validator_function
from superset.tasks.slack_util import deliver_slack_msg
+from superset.utils.celery import session_scope
from superset.utils.core import get_email_address_list, send_email_smtp
from superset.utils.screenshots import ChartScreenshot, WebDriverProxy
from superset.utils.urls import get_url_path
@@ -225,7 +227,7 @@ def destroy_webdriver(
pass
-def deliver_dashboard(
+def deliver_dashboard( # pylint: disable=too-many-locals
dashboard_id: int,
recipients: Optional[str],
slack_channel: Optional[str],
@@ -236,69 +238,70 @@ def deliver_dashboard(
"""
Given a schedule, delivery the dashboard as an email report
"""
- dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one()
+ with session_scope(nullpool=True) as session:
+ dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one()
- dashboard_url = _get_url_path(
- "Superset.dashboard", dashboard_id_or_slug=dashboard.id
- )
- dashboard_url_user_friendly = _get_url_path(
- "Superset.dashboard", user_friendly=True,
dashboard_id_or_slug=dashboard.id
- )
-
- # Create a driver, fetch the page, wait for the page to render
- driver = create_webdriver()
- window = config["WEBDRIVER_WINDOW"]["dashboard"]
- driver.set_window_size(*window)
- driver.get(dashboard_url)
- time.sleep(EMAIL_PAGE_RENDER_WAIT)
-
- # Set up a function to retry once for the element.
- # This is buggy in certain selenium versions with firefox driver
- get_element = getattr(driver, "find_element_by_class_name")
- element = retry_call(
- get_element, fargs=["grid-container"], tries=2,
delay=EMAIL_PAGE_RENDER_WAIT
- )
-
- try:
- screenshot = element.screenshot_as_png
- except WebDriverException:
- # Some webdrivers do not support screenshots for elements.
- # In such cases, take a screenshot of the entire page.
- screenshot = driver.screenshot() # pylint: disable=no-member
- finally:
- destroy_webdriver(driver)
-
- # Generate the email body and attachments
- report_content = _generate_report_content(
- delivery_type,
- screenshot,
- dashboard.dashboard_title,
- dashboard_url_user_friendly,
- )
+ dashboard_url = _get_url_path(
+ "Superset.dashboard", dashboard_id_or_slug=dashboard.id
+ )
+ dashboard_url_user_friendly = _get_url_path(
+ "Superset.dashboard", user_friendly=True,
dashboard_id_or_slug=dashboard.id
+ )
- subject = __(
- "%(prefix)s %(title)s",
- prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"],
- title=dashboard.dashboard_title,
- )
+ # Create a driver, fetch the page, wait for the page to render
+ driver = create_webdriver()
+ window = config["WEBDRIVER_WINDOW"]["dashboard"]
+ driver.set_window_size(*window)
+ driver.get(dashboard_url)
+ time.sleep(EMAIL_PAGE_RENDER_WAIT)
+
+ # Set up a function to retry once for the element.
+ # This is buggy in certain selenium versions with firefox driver
+ get_element = getattr(driver, "find_element_by_class_name")
+ element = retry_call(
+ get_element, fargs=["grid-container"], tries=2,
delay=EMAIL_PAGE_RENDER_WAIT
+ )
- if recipients:
- _deliver_email(
- recipients,
- deliver_as_group,
- subject,
- report_content.body,
- report_content.data,
- report_content.images,
+ try:
+ screenshot = element.screenshot_as_png
+ except WebDriverException:
+ # Some webdrivers do not support screenshots for elements.
+ # In such cases, take a screenshot of the entire page.
+ screenshot = driver.screenshot() # pylint: disable=no-member
+ finally:
+ destroy_webdriver(driver)
+
+ # Generate the email body and attachments
+ report_content = _generate_report_content(
+ delivery_type,
+ screenshot,
+ dashboard.dashboard_title,
+ dashboard_url_user_friendly,
)
- if slack_channel:
- deliver_slack_msg(
- slack_channel,
- subject,
- report_content.slack_message,
- report_content.slack_attachment,
+
+ subject = __(
+ "%(prefix)s %(title)s",
+ prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"],
+ title=dashboard.dashboard_title,
)
+ if recipients:
+ _deliver_email(
+ recipients,
+ deliver_as_group,
+ subject,
+ report_content.body,
+ report_content.data,
+ report_content.images,
+ )
+ if slack_channel:
+ deliver_slack_msg(
+ slack_channel,
+ subject,
+ report_content.slack_message,
+ report_content.slack_attachment,
+ )
+
def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) ->
ReportContent:
slice_url = _get_url_path(
@@ -362,8 +365,8 @@ def _get_slice_data(slc: Slice, delivery_type:
EmailDeliveryType) -> ReportConte
return ReportContent(body, data, None, slack_message, content)
-def _get_slice_screenshot(slice_id: int) -> ScreenshotData:
- slice_obj = db.session.query(Slice).get(slice_id)
+def _get_slice_screenshot(slice_id: int, session: Session) -> ScreenshotData:
+ slice_obj = session.query(Slice).get(slice_id)
chart_url = get_url_path("Superset.slice", slice_id=slice_obj.id,
standalone="true")
screenshot = ChartScreenshot(chart_url, slice_obj.digest)
@@ -376,7 +379,7 @@ def _get_slice_screenshot(slice_id: int) -> ScreenshotData:
user=user, cache=thumbnail_cache, force=True,
)
- db.session.commit()
+ session.commit()
return ScreenshotData(image_url, image_data)
@@ -427,11 +430,12 @@ def deliver_slice( # pylint: disable=too-many-arguments
delivery_type: EmailDeliveryType,
email_format: SliceEmailReportFormat,
deliver_as_group: bool,
+ session: Session,
) -> None:
"""
Given a schedule, delivery the slice as an email report
"""
- slc = db.session.query(Slice).filter_by(id=slice_id).one()
+ slc = session.query(Slice).filter_by(id=slice_id).one()
if email_format == SliceEmailReportFormat.data:
report_content = _get_slice_data(slc, delivery_type)
@@ -477,38 +481,42 @@ def schedule_email_report( # pylint: disable=unused-argument
slack_channel: Optional[str] = None,
) -> None:
model_cls = get_scheduler_model(report_type)
- schedule = db.create_scoped_session().query(model_cls).get(schedule_id)
+ with session_scope(nullpool=True) as session:
+ schedule = session.query(model_cls).get(schedule_id)
- # The user may have disabled the schedule. If so, ignore this
- if not schedule or not schedule.active:
- logger.info("Ignoring deactivated schedule")
- return
-
- recipients = recipients or schedule.recipients
- slack_channel = slack_channel or schedule.slack_channel
- logger.info(
- "Starting report for slack: %s and recipients: %s.", slack_channel,
recipients
- )
+ # The user may have disabled the schedule. If so, ignore this
+ if not schedule or not schedule.active:
+ logger.info("Ignoring deactivated schedule")
+ return
- if report_type == ScheduleType.dashboard:
- deliver_dashboard(
- schedule.dashboard_id,
- recipients,
+ recipients = recipients or schedule.recipients
+ slack_channel = slack_channel or schedule.slack_channel
+ logger.info(
+ "Starting report for slack: %s and recipients: %s.",
slack_channel,
- schedule.delivery_type,
- schedule.deliver_as_group,
- )
- elif report_type == ScheduleType.slice:
- deliver_slice(
- schedule.slice_id,
recipients,
- slack_channel,
- schedule.delivery_type,
- schedule.email_format,
- schedule.deliver_as_group,
)
- else:
- raise RuntimeError("Unknown report type")
+
+ if report_type == ScheduleType.dashboard:
+ deliver_dashboard(
+ schedule.dashboard_id,
+ recipients,
+ slack_channel,
+ schedule.delivery_type,
+ schedule.deliver_as_group,
+ )
+ elif report_type == ScheduleType.slice:
+ deliver_slice(
+ schedule.slice_id,
+ recipients,
+ slack_channel,
+ schedule.delivery_type,
+ schedule.email_format,
+ schedule.deliver_as_group,
+ session,
+ )
+ else:
+ raise RuntimeError("Unknown report type")
@celery_app.task(
@@ -529,9 +537,8 @@ def schedule_alert_query( # pylint: disable=unused-argument
slack_channel: Optional[str] = None,
) -> None:
model_cls = get_scheduler_model(report_type)
-
- try:
- schedule = db.session.query(model_cls).get(schedule_id)
+ with session_scope(nullpool=True) as session:
+ schedule = session.query(model_cls).get(schedule_id)
# The user may have disabled the schedule. If so, ignore this
if not schedule or not schedule.active:
@@ -539,15 +546,11 @@ def schedule_alert_query( # pylint: disable=unused-argument
return
if report_type == ScheduleType.alert:
- evaluate_alert(schedule.id, schedule.label, recipients,
slack_channel)
+ evaluate_alert(
+ schedule.id, schedule.label, session, recipients, slack_channel
+ )
else:
raise RuntimeError("Unknown report type")
- except NoSuchColumnError as column_error:
- stats_logger.incr("run_alert_task.error.nosuchcolumnerror")
- raise column_error
- except ResourceClosedError as resource_error:
- stats_logger.incr("run_alert_task.error.resourceclosederror")
- raise resource_error
class AlertState:
@@ -558,6 +561,7 @@ class AlertState:
def deliver_alert(
alert_id: int,
+ session: Session,
recipients: Optional[str] = None,
slack_channel: Optional[str] = None,
) -> None:
@@ -566,7 +570,7 @@ def deliver_alert(
to its respective email and slack recipients
"""
- alert = db.session.query(Alert).get(alert_id)
+ alert = session.query(Alert).get(alert_id)
logging.info("Triggering alert: %s", alert)
@@ -588,7 +592,7 @@ def deliver_alert(
str(alert.observations[-1].value),
validation_error_message,
_get_url_path("AlertModelView.show", user_friendly=True,
pk=alert_id),
- _get_slice_screenshot(alert.slice.id),
+ _get_slice_screenshot(alert.slice.id, session),
)
else:
# TODO: dashboard delivery!
@@ -668,6 +672,7 @@ def deliver_slack_alert(alert_content: AlertContent,
slack_channel: str) -> None
def evaluate_alert(
alert_id: int,
label: str,
+ session: Session,
recipients: Optional[str] = None,
slack_channel: Optional[str] = None,
) -> None:
@@ -680,7 +685,7 @@ def evaluate_alert(
try:
logger.info("Querying observers for alert <%s:%s>", alert_id, label)
- error_msg = observe(alert_id)
+ error_msg = observe(alert_id, session)
if error_msg:
state = AlertState.ERROR
logging.error(error_msg)
@@ -694,17 +699,17 @@ def evaluate_alert(
if state != AlertState.ERROR:
# Don't validate alert on test runs since it may not be triggered
if recipients or slack_channel:
- deliver_alert(alert_id, recipients, slack_channel)
+ deliver_alert(alert_id, session, recipients, slack_channel)
state = AlertState.TRIGGER
# Validate during regular workflow and deliver only if triggered
- elif validate_observations(alert_id, label):
- deliver_alert(alert_id, recipients, slack_channel)
+ elif validate_observations(alert_id, label, session):
+ deliver_alert(alert_id, session, recipients, slack_channel)
state = AlertState.TRIGGER
else:
state = AlertState.PASS
- db.session.commit()
- alert = db.session.query(Alert).get(alert_id)
+ session.commit()
+ alert = session.query(Alert).get(alert_id)
if state != AlertState.ERROR:
alert.last_eval_dttm = dttm_end
alert.last_state = state
@@ -716,10 +721,10 @@ def evaluate_alert(
state=state,
)
)
- db.session.commit()
+ session.commit()
-def validate_observations(alert_id: int, label: str) -> bool:
+def validate_observations(alert_id: int, label: str, session: Session) -> bool:
"""
Runs an alert's validators to check if it should be triggered or not
If so, return the name of the validator that returned true
@@ -727,7 +732,7 @@ def validate_observations(alert_id: int, label: str) ->
bool:
logger.info("Validating observations for alert <%s:%s>", alert_id, label)
- alert = db.session.query(Alert).get(alert_id)
+ alert = session.query(Alert).get(alert_id)
if alert.validators:
validator = alert.validators[0]
validate = get_validator_function(validator.validator_type)
@@ -760,7 +765,11 @@ def next_schedules(
def schedule_window(
- report_type: str, start_at: datetime, stop_at: datetime, resolution: int
+ report_type: str,
+ start_at: datetime,
+ stop_at: datetime,
+ resolution: int,
+ session: Session,
) -> None:
"""
Find all active schedules and schedule celery tasks for
@@ -772,8 +781,7 @@ def schedule_window(
if not model_cls:
return None
- dbsession = db.create_scoped_session()
- schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
+ schedules = session.query(model_cls).filter(model_cls.active.is_(True))
for schedule in schedules:
logging.info("Processing schedule %s", schedule)
@@ -810,7 +818,6 @@ def get_scheduler_action(report_type: str) ->
Optional[Callable[..., Any]]:
@celery_app.task(name="email_reports.schedule_hourly")
def schedule_hourly() -> None:
""" Celery beat job meant to be invoked hourly """
-
if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
logger.info("Scheduled email reports not enabled in config")
return
@@ -820,8 +827,10 @@ def schedule_hourly() -> None:
# Get the top of the hour
start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0,
minute=0)
stop_at = start_at + timedelta(seconds=3600)
- schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution)
- schedule_window(ScheduleType.slice, start_at, stop_at, resolution)
+
+ with session_scope(nullpool=True) as session:
+ schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution, session)
+ schedule_window(ScheduleType.slice, start_at, stop_at, resolution, session)
@celery_app.task(name="alerts.schedule_check")
@@ -833,5 +842,5 @@ def schedule_alerts() -> None:
seconds=3600
) # process any missed tasks in the past hour
stop_at = now + timedelta(seconds=1)
-
- schedule_window(ScheduleType.alert, start_at, stop_at, resolution)
+ with session_scope(nullpool=True) as session:
+ schedule_window(ScheduleType.alert, start_at, stop_at, resolution, session)
diff --git a/superset/utils/celery.py b/superset/utils/celery.py
new file mode 100644
index 0000000..1692e55
--- /dev/null
+++ b/superset/utils/celery.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from typing import Iterator
+
+import sqlalchemy as sa
+from contextlib2 import contextmanager
+from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.pool import NullPool
+
+from superset import app, db
+
+logger = logging.getLogger(__name__)
+
+# Null pool is used for the celery workers due to process forking side effects.
+# For more info see: https://github.com/apache/incubator-superset/issues/10530
+@contextmanager
+def session_scope(nullpool: bool) -> Iterator[Session]:
+ """Provide a transactional scope around a series of operations."""
+ database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
+ if "sqlite" in database_uri:
+ logger.warning(
+ "SQLite Database support for metadata databases will be removed \
+ in a future version of Superset."
+ )
+ if nullpool:
+ engine = sa.create_engine(database_uri, poolclass=NullPool)
+ session_class = sessionmaker()
+ session_class.configure(bind=engine)
+ session = session_class()
+ else:
+ session = db.session()
+ session.commit() # HACK
+
+ try:
+ yield session
+ session.commit()
+ except Exception as ex:
+ session.rollback()
+ logger.exception(ex)
+ raise
+ finally:
+ session.close()
diff --git a/tests/alerts_tests.py b/tests/alerts_tests.py
index a222610..5f6464c 100644
--- a/tests/alerts_tests.py
+++ b/tests/alerts_tests.py
@@ -112,37 +112,37 @@ def test_alert_observer(setup_database):
# Test SQLObserver with int SQL return
alert1 = create_alert(dbsession, "SELECT 55")
- observe(alert1.id)
+ observe(alert1.id, dbsession)
assert alert1.sql_observer[0].observations[-1].value == 55.0
assert alert1.sql_observer[0].observations[-1].error_msg is None
# Test SQLObserver with double SQL return
alert2 = create_alert(dbsession, "SELECT 30.0 as wage")
- observe(alert2.id)
+ observe(alert2.id, dbsession)
assert alert2.sql_observer[0].observations[-1].value == 30.0
assert alert2.sql_observer[0].observations[-1].error_msg is None
# Test SQLObserver with NULL result
alert3 = create_alert(dbsession, "SELECT null as null_result")
- observe(alert3.id)
+ observe(alert3.id, dbsession)
assert alert3.sql_observer[0].observations[-1].value is None
assert alert3.sql_observer[0].observations[-1].error_msg is None
# Test SQLObserver with empty SQL return
alert4 = create_alert(dbsession, "SELECT first FROM test_table WHERE first
= -1")
- observe(alert4.id)
+ observe(alert4.id, dbsession)
assert alert4.sql_observer[0].observations[-1].value is None
assert alert4.sql_observer[0].observations[-1].error_msg is not None
# Test SQLObserver with str result
alert5 = create_alert(dbsession, "SELECT 'test_string' as string_value")
- observe(alert5.id)
+ observe(alert5.id, dbsession)
assert alert5.sql_observer[0].observations[-1].value is None
assert alert5.sql_observer[0].observations[-1].error_msg is not None
# Test SQLObserver with two row result
alert6 = create_alert(dbsession, "SELECT first FROM test_table")
- observe(alert6.id)
+ observe(alert6.id, dbsession)
assert alert6.sql_observer[0].observations[-1].value is None
assert alert6.sql_observer[0].observations[-1].error_msg is not None
@@ -150,7 +150,7 @@ def test_alert_observer(setup_database):
alert7 = create_alert(
dbsession, "SELECT first, second FROM test_table WHERE first = 1"
)
- observe(alert7.id)
+ observe(alert7.id, dbsession)
assert alert7.sql_observer[0].observations[-1].value is None
assert alert7.sql_observer[0].observations[-1].error_msg is not None
@@ -161,22 +161,22 @@ def test_evaluate_alert(mock_deliver_alert,
setup_database):
# Test error with Observer SQL statement
alert1 = create_alert(dbsession, "$%^&")
- evaluate_alert(alert1.id, alert1.label)
+ evaluate_alert(alert1.id, alert1.label, dbsession)
assert alert1.logs[-1].state == AlertState.ERROR
# Test error with alert lacking observer
alert2 = dbsession.query(Alert).filter_by(label="No Observer").one()
- evaluate_alert(alert2.id, alert2.label)
+ evaluate_alert(alert2.id, alert2.label, dbsession)
assert alert2.logs[-1].state == AlertState.ERROR
# Test pass on alert lacking validator
alert3 = create_alert(dbsession, "SELECT 55")
- evaluate_alert(alert3.id, alert3.label)
+ evaluate_alert(alert3.id, alert3.label, dbsession)
assert alert3.logs[-1].state == AlertState.PASS
# Test triggering successful alert
alert4 = create_alert(dbsession, "SELECT 55", "not null", "{}")
- evaluate_alert(alert4.id, alert4.label)
+ evaluate_alert(alert4.id, alert4.label, dbsession)
assert mock_deliver_alert.call_count == 1
assert alert4.logs[-1].state == AlertState.TRIGGER
@@ -214,17 +214,17 @@ def test_not_null_validator(setup_database):
# Test passing SQLObserver with 'null' SQL result
alert1 = create_alert(dbsession, "SELECT 0")
- observe(alert1.id)
+ observe(alert1.id, dbsession)
assert not_null_validator(alert1.sql_observer[0], "{}") is False
# Test passing SQLObserver with empty SQL result
alert2 = create_alert(dbsession, "SELECT first FROM test_table WHERE first
= -1")
- observe(alert2.id)
+ observe(alert2.id, dbsession)
assert not_null_validator(alert2.sql_observer[0], "{}") is False
# Test triggering alert with non-null SQL result
alert3 = create_alert(dbsession, "SELECT 55")
- observe(alert3.id)
+ observe(alert3.id, dbsession)
assert not_null_validator(alert3.sql_observer[0], "{}") is True
@@ -233,7 +233,7 @@ def test_operator_validator(setup_database):
# Test passing SQLObserver with empty SQL result
alert1 = create_alert(dbsession, "SELECT first FROM test_table WHERE first
= -1")
- observe(alert1.id)
+ observe(alert1.id, dbsession)
assert (
operator_validator(alert1.sql_observer[0], '{"op": ">=", "threshold":
60}')
is False
@@ -241,7 +241,7 @@ def test_operator_validator(setup_database):
# Test passing SQLObserver with result that doesn't pass a greater than
threshold
alert2 = create_alert(dbsession, "SELECT 55")
- observe(alert2.id)
+ observe(alert2.id, dbsession)
assert (
operator_validator(alert2.sql_observer[0], '{"op": ">=", "threshold":
60}')
is False
@@ -283,23 +283,23 @@ def test_validate_observations(setup_database):
# Test False on alert with no validator
alert1 = create_alert(dbsession, "SELECT 55")
- assert validate_observations(alert1.id, alert1.label) is False
+ assert validate_observations(alert1.id, alert1.label, dbsession) is False
# Test False on alert with no observations
alert2 = create_alert(dbsession, "SELECT 55", "not null", "{}")
- assert validate_observations(alert2.id, alert2.label) is False
+ assert validate_observations(alert2.id, alert2.label, dbsession) is False
# Test False on alert that shouldnt be triggered
alert3 = create_alert(dbsession, "SELECT 0", "not null", "{}")
- observe(alert3.id)
- assert validate_observations(alert3.id, alert3.label) is False
+ observe(alert3.id, dbsession)
+ assert validate_observations(alert3.id, alert3.label, dbsession) is False
# Test True on alert that should be triggered
alert4 = create_alert(
dbsession, "SELECT 55", "operator", '{"op": "<=", "threshold": 60}'
)
- observe(alert4.id)
- assert validate_observations(alert4.id, alert4.label) is True
+ observe(alert4.id, dbsession)
+ assert validate_observations(alert4.id, alert4.label, dbsession) is True
@patch("superset.tasks.slack_util.WebClient.files_upload")
@@ -311,7 +311,7 @@ def test_deliver_alert_screenshot(
):
dbsession = setup_database
alert = create_alert(dbsession, "SELECT 55", "not null", "{}")
- observe(alert.id)
+ observe(alert.id, dbsession)
screenshot = read_fixture("sample.png")
screenshot_mock.return_value = screenshot
@@ -322,7 +322,7 @@ def test_deliver_alert_screenshot(
f"http://0.0.0.0:8080/superset/slice/{alert.slice_id}/",
]
- deliver_alert(alert_id=alert.id)
+ deliver_alert(alert.id, dbsession)
assert email_mock.call_args[1]["images"]["screenshot"] == screenshot
assert file_upload_mock.call_args[1] == {
"channels": alert.slack_channel,
diff --git a/tests/schedules_test.py b/tests/schedules_test.py
index 77f7070..88b6d1f 100644
--- a/tests/schedules_test.py
+++ b/tests/schedules_test.py
@@ -366,6 +366,7 @@ class TestSchedules(SupersetTestCase):
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
+ db.session,
)
mtime.sleep.assert_called_once()
driver.screenshot.assert_not_called()
@@ -418,6 +419,7 @@ class TestSchedules(SupersetTestCase):
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
+ db.session,
)
mtime.sleep.assert_called_once()
@@ -466,6 +468,7 @@ class TestSchedules(SupersetTestCase):
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
+ db.session,
)
send_email_smtp.assert_called_once()
@@ -510,6 +513,7 @@ class TestSchedules(SupersetTestCase):
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
+ db.session,
)
send_email_smtp.assert_called_once()