mistercrunch closed pull request #3575: [Feature] Task scheduling system +
Druid cluster refresh task
URL: https://github.com/apache/incubator-superset/pull/3575
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/setup.py b/setup.py
index ec12deafb9..580c3593ae 100644
--- a/setup.py
+++ b/setup.py
@@ -75,6 +75,7 @@ def get_git_sha():
'sqlparse==0.2.3',
'thrift>=0.9.3',
'thrift-sasl>=0.2.1',
+ 'crontab==0.21.3',
],
extras_require={
'cors': ['Flask-Cors>=2.0.0'],
diff --git a/superset/__init__.py b/superset/__init__.py
index 099edc133e..c58665f48d 100644
--- a/superset/__init__.py
+++ b/superset/__init__.py
@@ -161,3 +161,4 @@ def index(self):
ConnectorRegistry.register_sources(module_datasource_map)
from superset import views # noqa
+from superset.tasks import views # noqa
diff --git a/superset/connectors/druid/models.py
b/superset/connectors/druid/models.py
index 90b4dc063e..a103914d8b 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -106,29 +106,38 @@ def get_druid_version(self):
def refresh_datasources(
self,
- datasource_name=None,
+ datasource_names=None,
merge_flag=True,
refreshAll=True):
- """Refresh metadata of all datasources in the cluster
- If ``datasource_name`` is specified, only that datasource is updated
+ """
+ Refresh metadata of all datasources in the cluster
+ If ``datasource_names`` is specified, only those datasources
+ are updated
"""
self.druid_version = self.get_druid_version()
ds_list = self.get_datasources()
blacklist = conf.get('DRUID_DATA_SOURCE_BLACKLIST', [])
- ds_refresh = []
- if not datasource_name:
+ if not datasource_names:
ds_refresh = list(filter(lambda ds: ds not in blacklist, ds_list))
- elif datasource_name not in blacklist and datasource_name in ds_list:
- ds_refresh.append(datasource_name)
else:
- return
+ if not isinstance(datasource_names, list):
+ # can pass a single name or a list of names
+ datasource_names = [datasource_names]
+ ds_refresh = list(filter(
+ lambda ds: ds not in blacklist and ds in ds_list,
+ datasource_names,
+ ))
self.refresh(ds_refresh, merge_flag, refreshAll)
+ flasher("Refreshed metadata from cluster [{}]"
+ .format(self.cluster_name), "info")
def refresh(self, datasource_names, merge_flag, refreshAll):
"""
Fetches metadata for the specified datasources and
merges to the Superset database
"""
+ if not datasource_names or not len(datasource_names):
+ return
session = db.session
ds_list = (
session.query(DruidDatasource)
diff --git a/superset/connectors/druid/views.py
b/superset/connectors/druid/views.py
index 713a43c36b..6a4fa06fe0 100644
--- a/superset/connectors/druid/views.py
+++ b/superset/connectors/druid/views.py
@@ -3,6 +3,7 @@
from flask import flash, Markup, redirect
from flask_appbuilder import CompactCRUDMixin, expose
+from flask_appbuilder.actions import action
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import gettext as __
from flask_babel import lazy_gettext as _
@@ -154,6 +155,29 @@ def pre_add(self, cluster):
def pre_update(self, cluster):
self.pre_add(cluster)
+ def refresh_items(self, items, refreshAll=True):
+ if not isinstance(items, list):
+ items = [items]
+ for it in items:
+ it.refresh_datasources(refreshAll=refreshAll)
+ return redirect('/druiddatasourcemodelview/list/')
+
+ @action(
+ "scan_cluster",
+ "Scan",
+ "Scan selected clusters for new datasources?",
+ "fa-search-plus")
+ def scan_clusters(self, items):
+ return self.refresh_items(items, False)
+
+ @action(
+ "refresh_cluster",
+ "Refresh",
+ "Refresh metadata for selected clusters?",
+ "fa-refresh")
+ def refresh_clusters(self, items):
+ return self.refresh_items(items)
+
def _delete(self, pk):
DeleteMixin._delete(self, pk)
@@ -257,6 +281,28 @@ def post_add(self, datasource):
def post_update(self, datasource):
self.post_add(datasource)
+ @action(
+ "refresh_datasource",
+ "Refresh",
+ "Refresh selected datasources?",
+ "fa-refresh")
+ def refresh_datasources(self, items):
+ to_refresh = {}
+ if not isinstance(items, list):
+ items = [items]
+ for item in items:
+ c_name = item.cluster.cluster_name
+ if c_name in to_refresh:
+ to_refresh[c_name]['ds'].append(item.datasource_name)
+ else:
+ to_refresh[c_name] = {
+ 'ds': [item.datasource_name],
+ 'cluster': item.cluster,
+ }
+ for c in to_refresh.values():
+ c['cluster'].refresh_datasources(c['ds'])
+ return redirect(self.get_default_url())
+
def _delete(self, pk):
DeleteMixin._delete(self, pk)
@@ -291,10 +337,6 @@ def refresh_datasources(self, refreshAll=True):
logging.exception(e)
return redirect('/druidclustermodelview/list/')
cluster.metadata_last_refreshed = datetime.now()
- flash(
- 'Refreshed metadata from cluster '
- '[' + cluster.cluster_name + ']',
- 'info')
session.commit()
return redirect('/druiddatasourcemodelview/list/')
@@ -317,7 +359,7 @@ def scan_new_datasources(self):
category='Sources',
category_label=__('Sources'),
category_icon='fa-database',
- icon='fa-refresh')
+ icon="fa-search-plus")
appbuilder.add_link(
'Refresh Druid Metadata',
label=__('Refresh Druid Metadata'),
@@ -325,7 +367,7 @@ def scan_new_datasources(self):
category='Sources',
category_label=__('Sources'),
category_icon='fa-database',
- icon='fa-cog')
+ icon="fa-refresh")
appbuilder.add_separator('Sources', )
diff --git a/superset/migrations/versions/2c7417d4af40_.py
b/superset/migrations/versions/2c7417d4af40_.py
new file mode 100644
index 0000000000..2a6a0fa077
--- /dev/null
+++ b/superset/migrations/versions/2c7417d4af40_.py
@@ -0,0 +1,22 @@
+"""empty message
+
+Revision ID: 2c7417d4af40
+Revises: ('9fb37232bbff', 'f959a6652acd')
+Create Date: 2017-10-02 18:52:18.023368
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '2c7417d4af40'
+down_revision = ('9fb37232bbff', 'f959a6652acd')
+
+from alembic import op
+import sqlalchemy as sa
+
+
+def upgrade():
+ pass
+
+
+def downgrade():
+ pass
diff --git a/superset/migrations/versions/9fb37232bbff_create_tasks_table.py
b/superset/migrations/versions/9fb37232bbff_create_tasks_table.py
new file mode 100644
index 0000000000..0c0049fb61
--- /dev/null
+++ b/superset/migrations/versions/9fb37232bbff_create_tasks_table.py
@@ -0,0 +1,34 @@
+"""create tasks table
+
+Revision ID: 9fb37232bbff
+Revises: 472d2f73dfd4
+Create Date: 2017-10-02 16:33:12.945687
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '9fb37232bbff'
+down_revision = '472d2f73dfd4'
+
+
+from alembic import op
+from datetime import datetime
+import sqlalchemy as sa
+
+
+def upgrade():
+ op.create_table(
+ 'refresh_tasks',
+ sa.Column('id', sa.Integer, primary_key=True),
+ sa.Column('crontab_str', sa.String(120), nullable=False),
+ sa.Column('config', sa.Text, nullable=False),
+ sa.Column('description', sa.String(250), nullable=True),
+ sa.Column('created_on', sa.DateTime, default=datetime.now,
nullable=True),
+ sa.Column('changed_on', sa.DateTime, default=datetime.now,
nullable=True, onupdate=datetime.now),
+ sa.Column('created_by_fk', sa.Integer(), sa.ForeignKey("ab_user.id"),
nullable=True),
+ sa.Column('changed_by_fk', sa.Integer(), sa.ForeignKey("ab_user.id"),
nullable=True),
+ )
+
+
+def downgrade():
+ op.drop_table('refresh_tasks')
diff --git a/superset/tasks/__init__.py b/superset/tasks/__init__.py
new file mode 100644
index 0000000000..54a32b6dda
--- /dev/null
+++ b/superset/tasks/__init__.py
@@ -0,0 +1,3 @@
+from . import models # noqa
+from . import views # noqa
+from . import utils # noqa
diff --git a/superset/tasks/manager.py b/superset/tasks/manager.py
new file mode 100644
index 0000000000..d01de6f117
--- /dev/null
+++ b/superset/tasks/manager.py
@@ -0,0 +1,179 @@
+import threading
+
+from time import time, sleep
+
+from superset import db
+
+from .models import CronTask
+from .processor import execute_task_config
+
+try:
+ import Queue as Q
+except ImportError:
+ import queue as Q
+
+
+class ManagedTask:
+ """
+ Wrapper class for `CronTask`. This object retains a reference to
+ `task` and is placed in the `TaskManager` priority queue.
+ """
+
+ def __init__(self, task):
+ self.task = task
+ self.valid = True
+ # managed tasks are compared by absolute execution time
+ self.execution_time = task.abs_execution_time()
+
+ def __repr__(self):
+ return "Task id={}".format(self.task.id)
+
+ def invalidate(self):
+ """
+ Tasks are invalidated so that they do not run,
+ instead of trying to remove them from the queue
+ """
+ self.valid = False
+
+ def run(self):
+ if not self.valid:
+ return False
+ """
+ Pass the task configuration JSON to run the task
+ """
+ return execute_task_config(self.task.config_json())
+
+ def is_repeating(self):
+ return self.task.is_repeating() and self.valid
+
+ def __cmp__(self, other):
+ if self.execution_time > other.execution_time:
+ return 1
+ elif self.execution_time < other.execution_time:
+ return -1
+ return 0
+
+ def __lt__(self, other):
+ return self.execution_time < other.execution_time
+
+ def __eq__(self, other):
+ return self.execution_time == other.execution_time
+
+
+class TaskThread(threading.Thread):
+ """
+ Thread subclass which supports passing a target
+ function and arguments.
+ """
+ def __init__(self, target, *args):
+ self.target = target
+ self.args = args
+ threading.Thread.__init__(self)
+
+ def run(self):
+ self.target(*self.args)
+
+
+# wrapper function for thread
+def _run_task(managed_task):
+ return managed_task.run()
+
+
+class TaskManager:
+
+ def __init__(self, existing_tasks=(), max_tasks=512, tick_delay=60):
+ self.task_queue = Q.PriorityQueue(max_tasks)
+ self.is_ticking = False
+ self.tick_delay = tick_delay
+ # keep a reference to existing tasks in the queue
+ self.managed_tasks = {}
+ # add existing tasks to the queue
+ for existing_task in existing_tasks:
+ self.enqueue_task(existing_task, False)
+
+ def enqueue_task(self, task, start_if_stopped=True):
+ if task.id in self.managed_tasks:
+ # task ID already maps to a task, cancel existing task first
+ self.cancel_task(task.id)
+ # create a new managed task and add to queue
+ managed_task = ManagedTask(task)
+ self.managed_tasks[task.id] = managed_task
+ self.task_queue.put_nowait(managed_task)
+ # start ticking if not already
+ if not self.is_ticking and start_if_stopped:
+ self.start_ticking()
+
+ def cancel_task(self, task_id):
+ if task_id not in self.managed_tasks:
+ # task ID does not exist, no task to cancel
+ return False
+ task = self.managed_tasks[task_id]
+ task.invalidate() # task will not run
+ del self.managed_tasks[task_id] # remove from task map
+
+ def _tick(self):
+ while self.is_ticking:
+ # process tasks if any
+ if len(self.task_queue.queue):
+ # remove incoming invalidated tasks
+ while (
+ len(self.task_queue.queue) and
+ not self.task_queue.queue[0].valid
+ ):
+ self.task_queue.get_nowait()
+ self.task_queue.task_done()
+ # if there are no more tasks, stop ticking
+ if not len(self.task_queue.queue):
+ self.is_ticking = False
+ break
+ # process remaining valid tasks
+ time_now = time()
+ # run all tasks that need to be
+ while (
+ len(self.task_queue.queue) and
+ time_now >= self.task_queue.queue[0].execution_time
+ ):
+ # dequeue the next task
+ next_task = self.task_queue.get_nowait()
+ if next_task.valid:
+ # dispatch task into another thread
+ task_thread = TaskThread(_run_task, next_task)
+ task_thread.daemon = True
+ task_thread.start()
+ # remove from task map
+ del self.managed_tasks[next_task.task.id]
+ self.task_queue.task_done()
+ # if the task is repeating, push it to the queue
+ if next_task.is_repeating():
+ self.enqueue_task(next_task.task)
+ if not len(self.task_queue.queue):
+ self.is_ticking = False
+ break
+ else:
+ # stop ticking when no more tasks
+ self.is_ticking = False
+ break
+ sleep(self.tick_delay)
+
+ def start_ticking(self):
+ if len(self.task_queue.queue):
+ # only start ticking if not already doing so
+ if not self.is_ticking:
+ self.is_ticking = True
+ # self.thread is the ticking thread
+ self.thread = threading.Thread(target=self._tick)
+ self.thread.daemon = True
+ self.thread.start()
+ else:
+ self.is_ticking = False
+
+
+try:
+ _existing_tasks = db.session.query(CronTask).all()
+except:
+ _existing_tasks = []
+if 'task_manager' not in globals():
+ # tick every 20 seconds
+ task_manager = TaskManager(_existing_tasks, 1024, 20)
+ task_manager.start_ticking()
+_existing_tasks = None
diff --git a/superset/tasks/models.py b/superset/tasks/models.py
new file mode 100644
index 0000000000..d17261571a
--- /dev/null
+++ b/superset/tasks/models.py
@@ -0,0 +1,64 @@
+import json
+
+from flask_appbuilder import Model
+
+from sqlalchemy import (
+ Column, String, Text, Integer
+)
+
+from superset.utils import memoized
+from superset.models.helpers import AuditMixinNullable
+
+from datetime import datetime
+from time import time
+from crontab import CronTab
+
+from .utils import round_time
+
+
+class CronTask(Model, AuditMixinNullable):
+ """An scheduled and repeating task"""
+
+ __tablename__ = 'refresh_tasks'
+ id = Column(Integer, primary_key=True)
+ # crontab expression string
+ crontab_str = Column(String(120), nullable=False)
+ # JSON config for this task
+ config = Column(Text, nullable=False)
+ description = Column(String(250), nullable=True)
+
+ def __repr__(self):
+ return "{}: {}".format(str(self.id), self.crontab_str)
+
+ def is_repeating(self):
+ return True
+
+ @memoized
+ def get_perm(self):
+ return "[Task].(id:{})".format(self.id)
+
+ @memoized
+ def config_json(self):
+ return json.loads(self.config)
+
+ @memoized
+ def crontab_obj(self):
+ entry = CronTab(self.crontab_str)
+ return entry
+
+ def time_to_execution(self):
+ """Returns the time in seconds until this task executes"""
+ return self.crontab_obj().next()
+
+ def time_to_execution_nearest_sec(self):
+ return round(self.time_to_execution())
+
+ def abs_execution_time(self):
+ # execution time since epoch, rounded to nearest second
+ return round(self.time_to_execution() + time())
+
+ def next_execution_date(self):
+ """Returns the `datetime` of the next execution"""
+ timestamp = self.time_to_execution() + time()
+ date = datetime.fromtimestamp(timestamp)
+ return round_time(date)
diff --git a/superset/tasks/processor.py b/superset/tasks/processor.py
new file mode 100644
index 0000000000..def5cc8989
--- /dev/null
+++ b/superset/tasks/processor.py
@@ -0,0 +1,26 @@
+from .tasklist import task_list
+
+
+def validate_config(config):
+ """
+ Performs general configuration validation and then
+ passes the JSON to the specific task type for validation
+ """
+ if 'type' not in config:
+ raise ValueError("Task 'type' must be specified in config")
+ task_type = config['type']
+ if task_type not in task_list:
+ raise ValueError("Task type {} does not exist".format(task_type))
+ Task = task_list[task_type]
+ return Task.validate_task_config(config)
+
+
+def execute_task_config(config):
+ """
+ Create the appropriate task, passing it
+ the configuration and then run.
+ """
+ if config['type'] not in task_list:
+ return False # ignore tasks that do not exist
+ task_list[config['type']](config).execute()
+ return True
diff --git a/superset/tasks/tasklist.py b/superset/tasks/tasklist.py
new file mode 100644
index 0000000000..1217b84d15
--- /dev/null
+++ b/superset/tasks/tasklist.py
@@ -0,0 +1,60 @@
+from superset import db
+from superset.connectors.druid.models import DruidCluster
+
+
+class BaseTask:
+ """
+ Task base class, which other classes derive.
+ Subclasses must implement `execute` and `validate_config`.
+ """
+
+ def __init__(self, config):
+ self.config = config
+
+ def execute(self):
+ return False
+
+ @classmethod
+ def validate_task_config(cls, config):
+ return False
+
+
+class DruidClusterRefreshTask(BaseTask):
+ """
+ This task will trigger a refresh of Druid clusters
+ specified as an `array` in `clusters`. If `refresh_all` is
+ set to `false`, the task will only check for new datasources
+ instead of refreshing all metadata.
+ """
+
+ def __repr__(self):
+ return "druid_cluster_refresh"
+
+ @classmethod
+ def validate_task_config(cls, config):
+ if 'clusters' not in config:
+ raise ValueError("Key 'clusters' is not defined")
+ if not isinstance(config['clusters'], list):
+ raise TypeError("'clusters' should be a list")
+ if not len(config['clusters']):
+ raise ValueError("No clusters specified")
+ return True
+
+ def execute(self):
+ # grab the list of cluster names
+ cluster_names = self.config['clusters']
+ # whether to force refresh all, default `True`
+ refreshAll = self.config.get('refresh_all', True)
+ clusters = (
+ db.session.query(DruidCluster)
+ .filter(DruidCluster.cluster_name.in_(cluster_names))
+ )
+ # start refreshing
+ for cluster in clusters:
+ cluster.refresh_datasources(refreshAll=refreshAll)
+
+
+# list of task classes live here
+task_list = {
+ 'druid_cluster_refresh': DruidClusterRefreshTask,
+}
diff --git a/superset/tasks/utils.py b/superset/tasks/utils.py
new file mode 100644
index 0000000000..a5c2b22b3e
--- /dev/null
+++ b/superset/tasks/utils.py
@@ -0,0 +1,32 @@
+import json
+
+from .processor import validate_config
+
+from crontab import CronTab
+from datetime import timedelta
+
+
+def is_valid_crontab_str(crontab_str):
+ """Validates a crontab expression"""
+ try:
+ CronTab(crontab_str)
+ except ValueError:
+ return False
+ return True
+
+
+def round_time(dt, roundTo=60):
+ """Rounds a datetime to, by default, the nearest minute"""
+ seconds = (dt.replace(tzinfo=None) - dt.min).seconds
+ rounding = (seconds + roundTo/2) // roundTo * roundTo
+ return dt + timedelta(0, rounding - seconds, -dt.microsecond)
+
+
+def is_valid_task_config(config):
+ """Performs basic JSON validation and then
+ passes the object for task validation"""
+ try:
+ config_json = json.loads(config)
+ except ValueError:
+ return False
+ return validate_config(config_json)
diff --git a/superset/tasks/views.py b/superset/tasks/views.py
new file mode 100644
index 0000000000..0f96f4cab5
--- /dev/null
+++ b/superset/tasks/views.py
@@ -0,0 +1,95 @@
+from flask_appbuilder.models.sqla.interface import SQLAInterface
+from flask_babel import gettext as __
+from flask_babel import lazy_gettext as _
+
+from superset import appbuilder
+from superset.views.base import SupersetModelView, DeleteMixin
+
+from .manager import task_manager
+from .models import CronTask
+from .utils import is_valid_crontab_str, is_valid_task_config
+
+
+class CronTaskModelView(SupersetModelView, DeleteMixin):
+ """Model view for all cron tasks"""
+ datamodel = SQLAInterface(CronTask)
+
+ list_title = _('List Cron Tasks')
+ show_title = _('Show Cron Task')
+ add_title = _('Create Cron Task')
+ edit_title = _('Edit Cron Task')
+
+ list_columns = [
+ 'next_execution_date',
+ 'created_on',
+ 'description',
+ ]
+ add_columns = [
+ 'crontab_str',
+ 'description',
+ 'config',
+ ]
+ edit_columns = add_columns
+ show_columns = [
+ 'description',
+ 'crontab_str',
+ 'next_execution_date',
+ 'time_to_execution_nearest_sec',
+ 'config',
+ 'created_by',
+ 'created_on',
+ 'changed_by',
+ 'changed_on',
+ ]
+ base_order = ('created_on', 'desc')
+ label_columns = {
+ 'description': _('Description'),
+ 'crontab_str': _('Crontab expression'),
+ 'next_execution_date': _('Next execution date'),
+ 'config': _('Task config'),
+ 'time_to_execution_nearest_sec': _('Time until next execution'),
+ }
+ description_columns = {
+ 'crontab_str': _(
+ 'The crontab expression describing when this task should run'),
+ 'config': _(
+ 'Configuration describing the `type` of task and its options'),
+ }
+
+ def pre_add(self, task):
+ # Validate crontab expression and config JSON
+ if not is_valid_crontab_str(task.crontab_str):
+ raise ValueError(
+ "Task has invalid crontab expression: {}"
+ .format(task.crontab_str)
+ )
+ if not is_valid_task_config(task.config):
+ raise ValueError(
+ "Task has invalid configuration json"
+ )
+
+ def post_add(self, task):
+ # After adding the task, immediately enqueue it
+ task_manager.enqueue_task(task)
+
+ def pre_delete(self, task):
+ # cancel the task before it is deleted
+ task_manager.cancel_task(task.id)
+
+ def pre_update(self, task):
+ self.pre_add(task)
+
+ def post_update(self, task):
+ self.post_add(task)
+
+ def _delete(self, pk):
+ DeleteMixin._delete(self, pk)
+
+
+appbuilder.add_view(
+ CronTaskModelView,
+ "Cron Tasks",
+ label=__("Cron Tasks"),
+ icon="fa-list-ul",
+ category="",
+ category_icon='')
diff --git a/superset/templates/appbuilder/general/model/list.html
b/superset/templates/appbuilder/general/model/list.html
index 73710fd2fe..cce10b66e7 100644
--- a/superset/templates/appbuilder/general/model/list.html
+++ b/superset/templates/appbuilder/general/model/list.html
@@ -15,4 +15,4 @@
</div>
{{ lib.panel_end() }}
-{% endblock %}
\ No newline at end of file
+{% endblock %}
diff --git a/superset/templates/appbuilder/general/widgets/base_list.html
b/superset/templates/appbuilder/general/widgets/base_list.html
index 69ebf14479..ddf5a6fcbc 100644
--- a/superset/templates/appbuilder/general/widgets/base_list.html
+++ b/superset/templates/appbuilder/general/widgets/base_list.html
@@ -28,16 +28,18 @@
{% block end_content scoped %}
{% endblock %}
-<div class="form-actions-container">
- {{ lib.render_actions(actions, modelview_name) }}
-</div>
+ <div class="form-actions-container">
+ {{ lib.render_actions(actions, modelview_name) }}
+ </div>
+
{{ lib.action_form(actions,modelview_name) }}
-<div class="pagination-container pull-right">
- <strong>{{ _('Record Count') }}:</strong> {{ count }}
- {{ lib.render_pagination(page, page_size, count, modelview_name) }}
- {{ lib.render_set_page_size(page, page_size, count, modelview_name) }}
-</div>
+ <div class="pagination-container pull-right">
+ <strong>{{ _('Record Count') }}:</strong> {{ count }}
+ {{ lib.render_pagination(page, page_size, count, modelview_name) }}
+ {{ lib.render_set_page_size(page, page_size, count, modelview_name) }}
+ </div>
+
<script language="javascript">
var modelActions = new AdminActions();
</script>
diff --git a/tests/task_tests.py b/tests/task_tests.py
new file mode 100644
index 0000000000..2986a242d1
--- /dev/null
+++ b/tests/task_tests.py
@@ -0,0 +1,361 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+
+from mock import Mock
+from datetime import datetime, timedelta
+from time import mktime, time
+
+from superset import db
+
+from .base_tests import SupersetTestCase
+
+from superset.tasks.manager import (
+ TaskManager, ManagedTask,
+ TaskThread, _run_task, task_manager
+)
+from superset.tasks.tasklist import (
+ BaseTask, DruidClusterRefreshTask
+)
+from superset.tasks.processor import (
+ validate_config, execute_task_config
+)
+from superset.tasks.utils import (
+ is_valid_crontab_str, round_time, is_valid_task_config
+)
+from superset.tasks.models import CronTask
+from superset.tasks.views import CronTaskModelView
+from superset.connectors.druid.models import DruidCluster
+
+
+class TasksTestCase(unittest.TestCase):
+ def test_execute_nonexisting_task(self):
+ fake_task_config = {'type': 'does_not_exist'}
+ result = execute_task_config(fake_task_config)
+ self.assertFalse(result)
+
+ def test_execute_task_and_validate_config(self):
+ invalid_config = {'badkey': 'badval'}
+ with self.assertRaises(ValueError):
+ validate_config(invalid_config)
+ invalid_config = {'type': 'badtype'}
+ with self.assertRaises(ValueError):
+ validate_config(invalid_config)
+
+ def test_druid_refresh_task_validate_config(self):
+ invalid_config = {'type': 'druid_cluster_refresh'}
+ with self.assertRaises(ValueError):
+ validate_config(invalid_config)
+ invalid_config['clusters'] = 'somestring'
+ with self.assertRaises(TypeError):
+ validate_config(invalid_config)
+ invalid_config['clusters'] = []
+ with self.assertRaises(ValueError):
+ validate_config(invalid_config)
+ valid_config = {
+ 'type': 'druid_cluster_refresh',
+ 'clusters': ['A', 'B'],
+ }
+ self.assertTrue(validate_config(valid_config))
+
+ def test_base_task(self):
+ base_task = BaseTask({'type': 'fake_task'})
+ self.assertFalse(base_task.execute())
+ self.assertFalse(BaseTask.validate_task_config(base_task.config))
+
+ def test_is_valid_config_json(self):
+ invalid_json_str = '{"invalid_key: invalid_va'
+ self.assertFalse(is_valid_task_config(invalid_json_str))
+ valid_json_str = (
+ '{"type": "druid_cluster_refresh", "clusters": ["a", "b"]}'
+ )
+ self.assertTrue(is_valid_task_config(valid_json_str))
+
+ def test_is_valid_crontab_str(self):
+ invalid_crontab_str = '5 * *'
+ self.assertFalse(is_valid_crontab_str(invalid_crontab_str))
+ valid_crontab_str = '5 10 * * *'
+ self.assertTrue(is_valid_crontab_str(valid_crontab_str))
+
+ def test_round_time(self):
+ test_datetime = datetime(2000, 1, 1, 5, 25, 20, 624555)
+ to_nearest_second = round_time(test_datetime, roundTo=1)
+ to_nearest_halfMinute = round_time(test_datetime, roundTo=30)
+ to_nearest_minute = round_time(test_datetime, roundTo=60)
+ self.assertEqual(to_nearest_second.microsecond, 0)
+ self.assertEqual(to_nearest_second.second, 20)
+ self.assertEqual(to_nearest_halfMinute.microsecond, 0)
+ self.assertEqual(to_nearest_halfMinute.second, 30)
+ self.assertEqual(to_nearest_halfMinute.minute, 25)
+ self.assertEqual(to_nearest_minute.microsecond, 0)
+ self.assertEqual(to_nearest_minute.second, 0)
+ self.assertEqual(to_nearest_minute.minute, 25)
+
+ def test_managed_task_class(self):
+ test_task = Mock()
+ test_task.abs_execution_time = Mock(return_value=10)
+ test_task.id = 99999
+ test_task.config_json = Mock(return_value={
+ 'type': 'druid_cluster_refresh',
+ 'clusters': ['__fake_cluster']
+ })
+ test_task.is_repeating = Mock(return_value=True)
+ managed_task = ManagedTask(test_task)
+ self.assertEqual('Task id=99999', managed_task.__repr__())
+ self.assertEqual(True, managed_task.is_repeating())
+ test_task2 = Mock()
+ test_task2.abs_execution_time = Mock(return_value=100)
+ managed_task2 = ManagedTask(test_task2)
+ self.assertEqual(-1, managed_task.__cmp__(managed_task2))
+ self.assertEqual(0, managed_task.__cmp__(managed_task))
+ self.assertEqual(1, managed_task2.__cmp__(managed_task))
+ self.assertTrue(managed_task.__lt__(managed_task2))
+ self.assertTrue(managed_task.__eq__(managed_task))
+ self.assertFalse(managed_task.__eq__(managed_task2))
+ self.assertTrue(managed_task.run())
+ managed_task.invalidate()
+ self.assertFalse(_run_task(managed_task))
+ self.assertFalse(managed_task.is_repeating())
+
+ def test_task_thread_class(self):
+ watcher = {'val': 0}
+
+ def test_target(arg1, arg2):
+ watcher['val'] += arg1 * arg2
+
+ test_thread = TaskThread(test_target, 5, 6)
+ test_thread.run()
+ self.assertEqual(30, watcher['val'])
+
+ def test_task_manager(self):
+ watcher = {
+ 'fake_id': 9990,
+ 'enqueue': 0
+ }
+
+ def get_fake_task(abs_time=0, repeat=False, id=None):
+ fake_task = Mock()
+ if not id:
+ fake_task.id = watcher['fake_id']
+ watcher['fake_id'] += 1
+ else:
+ fake_task.id = id
+ fake_task.abs_execution_time = Mock(
+ return_value=abs_time
+ )
+ fake_task.config_json = Mock(
+ return_value={'type': 'dummytask'}
+ )
+ fake_task.is_repeating = Mock(return_value=repeat)
+ return fake_task, fake_task.id
+
+ f_task1, fid1 = get_fake_task()
+ f_task2, fid2 = get_fake_task()
+ f_task3, fid3 = get_fake_task()
+ f_existing_tasks = [f_task1, f_task2, f_task3]
+ test_tm = TaskManager(f_existing_tasks, tick_delay=0)
+ self.assertEqual(3, len(test_tm.task_queue.queue))
+ self.assertEqual(3, len(test_tm.managed_tasks))
+ for fid in [fid1, fid2, fid3]:
+ self.assertIn(fid, test_tm.managed_tasks)
+ test_tm.enqueue_task(f_task1, False)
+ test_tm.enqueue_task(f_task2, False)
+ self.assertEqual(5, len(test_tm.task_queue.queue))
+ self.assertEqual(3, len(test_tm.managed_tasks))
+ self.assertFalse(test_tm.task_queue.queue[0].valid)
+ self.assertFalse(test_tm.task_queue.queue[1].valid)
+ self.assertTrue(test_tm.task_queue.queue[3].valid)
+ self.assertTrue(test_tm.task_queue.queue[4].valid)
+ self.assertFalse(test_tm.cancel_task(-5))
+ test_tm.cancel_task(fid1)
+ self.assertEqual(2, len(test_tm.managed_tasks))
+ self.assertEqual(5, len(test_tm.task_queue.queue))
+ self.assertFalse(test_tm.task_queue.queue[3].valid)
+ test_tm.is_ticking = True
+ test_tm._tick()
+ self.assertEqual(0, len(test_tm.managed_tasks))
+ self.assertEqual(0, len(test_tm.task_queue.queue))
+ self.assertFalse(test_tm.is_ticking)
+ test_tm.is_ticking = True
+ test_tm._tick()
+ self.assertFalse(test_tm.is_ticking)
+ test_tm.is_ticking = True
+ test_tm.start_ticking()
+ self.assertFalse(test_tm.is_ticking)
+ for f_task in [f_task1, f_task2, f_task3]:
+ test_tm.enqueue_task(f_task, False)
+ for fid in [fid1, fid2, fid3]:
+ test_tm.cancel_task(fid)
+ self.assertEqual(0, len(test_tm.managed_tasks))
+ test_tm.is_ticking = True
+ test_tm._tick()
+ self.assertEqual(0, len(test_tm.task_queue.queue))
+ f_task4, fid4 = get_fake_task(repeat=True)
+ test_tm.enqueue_task(f_task4, False)
+
+ def enqueue_side_effect(task):
+ watcher['enqueue'] += 1
+
+ test_tm.enqueue_task = Mock(
+ side_effect=enqueue_side_effect
+ )
+ test_tm.is_ticking = True
+ test_tm._tick()
+ self.assertEqual(watcher['enqueue'], 1)
+ self.assertEqual(0, len(test_tm.task_queue.queue))
+ test_tm = TaskManager(tick_delay=1)
+ f_task5, fid5 = get_fake_task(repeat=False)
+ datetime_now = datetime.now()
+ datetime_run = datetime_now + timedelta(0, 3)
+ f_runtime_5 = mktime(datetime_run.timetuple())
+ f_task5.abs_execution_time = Mock(return_value=f_runtime_5)
+ test_tm.enqueue_task(f_task5)
+ test_tm.thread.join()
+ self.assertEqual(0, len(test_tm.task_queue.queue))
+
+
+class DBTaskTestCase(SupersetTestCase):
+ def __init__(self, *args, **kwargs):
+ super(DBTaskTestCase, self).__init__(*args, **kwargs)
+
+ def test_cron_task_model(self):
+ self.login(username='admin')
+ crontask1 = (
+ db.session.query(CronTask)
+ .filter_by(id=99991)
+ .first()
+ )
+ crontask2 = (
+ db.session.query(CronTask)
+ .filter_by(id=99992)
+ .first()
+ )
+ if crontask1:
+ db.session.delete(crontask1)
+ if crontask2:
+ db.session.delete(crontask2)
+ db.session.commit()
+ crontask1 = CronTask(
+ id=99991,
+ crontab_str='30 * * * *',
+ config='{"type": "faketask"}',
+ description='fake test for testing',
+ )
+ crontask2 = CronTask(
+ id=99992,
+ crontab_str='45 * * * *',
+ config='{"type": "faketask"}',
+ description='another fake test for testing',
+ )
+ db.session.add(crontask1)
+ db.session.add(crontask2)
+ db.session.commit()
+ expected = '99991: 30 * * * *'
+ self.assertEqual(expected, crontask1.__repr__())
+ expected = '99992: 45 * * * *'
+ self.assertEqual(expected, crontask2.__repr__())
+ self.assertTrue(crontask1.is_repeating())
+ expected = '[Task].(id:99991)'
+ self.assertEqual(expected, crontask1.get_perm())
+ expected = '[Task].(id:99992)'
+ self.assertEqual(expected, crontask2.get_perm())
+ expected = {'type': 'faketask'}
+ self.assertEqual(expected, crontask1.config_json())
+ self.assertEqual(expected, crontask2.config_json())
+ cronobj1 = crontask1.crontab_obj()
+ cronobj2 = crontask2.crontab_obj()
+ half_hour = datetime(2000, 1, 1, 1, 30)
+ three_quarter_hour = datetime(2000, 1, 1, 1, 45)
+ self.assertTrue(cronobj1.test(half_hour))
+ self.assertTrue(cronobj2.test(three_quarter_hour))
+ quarter_hour = datetime(2000, 1, 1, 1, 15)
+ self.assertFalse(cronobj1.test(quarter_hour))
+ self.assertFalse(cronobj2.test(quarter_hour))
+ self.assertTrue(crontask1.time_to_execution() > 0)
+ time_to_exec_sec = crontask1.time_to_execution_nearest_sec()
+ self.assertEqual(time_to_exec_sec, round(time_to_exec_sec))
+ timestamp_now = time()
+ self.assertTrue(crontask1.abs_execution_time() >= round(timestamp_now))
+ abs_exec_time = crontask1.abs_execution_time()
+ self.assertEqual(round(abs_exec_time), abs_exec_time)
+ self.assertTrue(crontask1.next_execution_date() >= datetime.now())
+ self.logout()
+
+ def test_crontask_model_view(self):
+ f_modelview = CronTaskModelView()
+ fake_task = Mock()
+ fake_task.crontab_str = '5 * *'
+ with self.assertRaises(ValueError):
+ f_modelview.pre_update(fake_task)
+ fake_task.crontab_str = '* * * * *'
+ fake_task.config = '{"invalid"'
+ with self.assertRaises(ValueError):
+ f_modelview.pre_update(fake_task)
+ fake_task.config = '{"type": "faketask"}'
+ f_modelview.post_update(fake_task)
+ self.assertTrue(task_manager.is_ticking)
+ self.assertEqual(1, len(task_manager.task_queue.queue))
+ f_modelview.pre_delete(fake_task)
+ self.assertEqual(0, len(task_manager.managed_tasks))
+
+ def test_execute_druid_refresh_task(self):
+ self.login(username='admin')
+ cluster1 = (
+ db.session.query(DruidCluster)
+ .filter_by(cluster_name='test_cluster1')
+ .first()
+ )
+ if cluster1:
+ db.session.delete(cluster1)
+ db.session.commit()
+ cluster2 = (
+ db.session.query(DruidCluster)
+ .filter_by(cluster_name='test_cluster2')
+ .first()
+ )
+ if cluster2:
+ db.session.delete(cluster2)
+ db.session.commit()
+
+ cluster1 = DruidCluster(
+ cluster_name='test_cluster1',
+ coordinator_host='localhost',
+ coordinator_port=7979,
+ broker_host='localhost',
+ broker_port=7980,
+ metadata_last_refreshed=datetime.now()
+ )
+ cluster2 = DruidCluster(
+ cluster_name='test_cluster2',
+ coordinator_host='localhost',
+ coordinator_port=8080,
+ broker_host='localhost',
+ broker_port=8880,
+ metadata_last_refreshed=datetime.now()
+ )
+ db.session.add(cluster1)
+ db.session.add(cluster2)
+ refresh_count = {'val': 0}
+
+ def refresh_side_effect(refreshAll):
+ refresh_count['val'] += 1
+ cluster1.refresh_datasources = Mock(
+ return_value=True,
+ side_effect=refresh_side_effect)
+ cluster2.refresh_datasources = Mock(
+ return_value=True,
+ side_effect=refresh_side_effect)
+ db.session.commit()
+ task_config = {
+ 'type': 'druid_cluster_refresh',
+ 'clusters': ['test_cluster1', 'test_cluster2']
+ }
+ self.assertTrue(execute_task_config(task_config))
+ self.assertEqual(refresh_count['val'], 2)
+ self.assertEqual('druid_cluster_refresh', DruidClusterRefreshTask({
+ 'type': 'druid_cluster_refresh'
+ }).__repr__())
+ self.logout()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services