This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 72627b1 Adding YAML Import-Export for Datasources to CLI (#3978)
72627b1 is described below
commit 72627b1761718d95f1fb3f6da9e9b5d3ad1e4065
Author: fabianmenges <[email protected]>
AuthorDate: Tue Dec 5 14:14:52 2017 -0500
Adding YAML Import-Export for Datasources to CLI (#3978)
* Adding import and export for databases
* Linting
---
docs/import_export_datasources.rst | 103 +++++++++++
setup.py | 2 +
superset/cli.py | 81 ++++++++-
superset/connectors/druid/models.py | 30 +++-
superset/connectors/druid/views.py | 6 +-
superset/connectors/sqla/models.py | 15 +-
superset/connectors/sqla/views.py | 4 +-
superset/dict_import_export_util.py | 65 +++++++
superset/models/core.py | 8 +-
superset/models/helpers.py | 172 ++++++++++++++++++
superset/views/base.py | 24 +++
superset/views/core.py | 14 +-
tests/dict_import_export_tests.py | 350 ++++++++++++++++++++++++++++++++++++
tests/import_export_tests.py | 4 +-
14 files changed, 848 insertions(+), 30 deletions(-)
diff --git a/docs/import_export_datasources.rst b/docs/import_export_datasources.rst
new file mode 100644
index 0000000..38d9101
--- /dev/null
+++ b/docs/import_export_datasources.rst
@@ -0,0 +1,103 @@
+Importing and Exporting Datasources
+===================================
+
+The Superset CLI allows you to import and export datasources to and from YAML.
+Datasources include both databases and Druid clusters. The data is expected to be organized in the following hierarchy (a YAML sketch of the same structure is shown after the tree): ::
+
+ .
+ ├──databases
+ | ├──database_1
+ | | ├──table_1
+ | | | ├──columns
+ | | | | ├──column_1
+ | | | | ├──column_2
+ | | | | └──... (more columns)
+ | | | └──metrics
+ | | | ├──metric_1
+ | | | ├──metric_2
+ | | | └──... (more metrics)
+ | | └── ... (more tables)
+ | └── ... (more databases)
+ └──druid_clusters
+ ├──cluster_1
+ | ├──datasource_1
+ | | ├──columns
+ | | | ├──column_1
+ | | | ├──column_2
+ | | | └──... (more columns)
+ | | └──metrics
+ | | ├──metric_1
+ | | ├──metric_2
+ | | └──... (more metrics)
+ | └── ... (more datasources)
+ └── ... (more clusters)
+
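+In YAML terms the same hierarchy maps onto nested keys. Below is a purely
+illustrative sketch; every name (``examples_db``, ``local_druid`` and so on)
+is a placeholder, not something shipped with Superset: ::
+
+    # all names below are placeholders
+    databases:
+    - database_name: examples_db
+      tables:
+      - table_name: my_table
+        columns:
+        - column_name: my_column
+        metrics:
+        - metric_name: my_metric
+    druid_clusters:
+    - cluster_name: local_druid
+      datasources:
+      - datasource_name: my_datasource
+        columns:
+        - column_name: my_column
+        metrics:
+        - metric_name: my_metric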
+
+Exporting Datasources to YAML
+-----------------------------
+You can print your current datasources to stdout by running: ::
+
+ superset export_datasources
+
+
+To save your datasources to a file run: ::
+
+ superset export_datasources -f <filename>
+
+
+By default, default (null) values will be omitted. Use the ``-d`` flag to include them.
+If you want back references to be included (e.g. a column to include the id of the table it belongs to), use the ``-b`` flag.
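+
+For example, to write all datasources (including default values and back
+references) to a hypothetical file ``datasources.yaml``, the flags can be
+combined: ::
+
+    superset export_datasources -f datasources.yaml -d -b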
+
+Alternatively you can export datasources using the UI: ::
+
+1. Open **Sources** -> **Databases** to export all tables associated with one or more databases (**Tables** for one or more tables, **Druid Clusters** for clusters, **Druid Datasources** for datasources).
+2. Select the items you would like to export
+3. Click **Actions** -> **Export to YAML**
+4. If you want to import an item that you exported through the UI, you will need to nest it inside its parent element, e.g. a `database` needs to be nested under `databases`, and a `table` needs to be nested inside a `database` element.
+
+Exporting the complete supported YAML schema
+--------------------------------------------
+To obtain an exhaustive list of all fields you can import via the YAML import, run: ::
+
+ superset export_datasource_schema
+
+Again, you can use the ``-b`` flag to include back references.
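+
+One possible workflow (the filename here is only an example) is to redirect the
+schema to a file, edit it, and later re-import it with ``import_datasources``: ::
+
+    superset export_datasource_schema > datasource_schema.yaml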
+
+
+Importing Datasources from YAML
+-------------------------------
+To import datasources from one or more YAML files, run: ::
+
+ superset import_datasources -p <path or filename>
+
+If you supply a path, all files ending with ``*.yaml`` or ``*.yml`` will be parsed.
+You can apply additional flags, e.g.: ::
+
+ superset import_datasources -p <path> -r
+
+This will search the supplied path recursively.
+
+The sync flag ``-s`` takes parameters in order to sync the supplied elements with your file. Be careful, this can delete the contents of your meta database. For example: ::
+
+ superset import_datasources -p <path / filename> -s columns,metrics
+
+This will sync all ``metrics`` and ``columns`` for all datasources found in ``<path / filename>`` with the Superset meta database. This means columns and metrics not specified in YAML will be deleted. If you were to add ``tables`` to ``columns,metrics``, those would be synchronized as well.
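+
+As a sketch of that broader sync, note the additional ``tables`` entry: ::
+
+    superset import_datasources -p <path / filename> -s columns,metrics,tables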
+
+
+If you don't supply the sync flag (``-s``), importing will only add and update (override) fields.
+For example, you can add a ``verbose_name`` to the column ``ds`` in the table ``random_time_series`` from the example datasets by saving the following YAML to a file and then running the ``import_datasources`` command (a concrete invocation is sketched after the snippet). ::
+
+ databases:
+ - database_name: main
+ tables:
+ - table_name: random_time_series
+ columns:
+ - column_name: ds
+ verbose_name: datetime
+
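+Assuming the snippet above were saved as, say, ``add_verbose_name.yaml`` (a
+hypothetical filename), it could then be applied with: ::
+
+    superset import_datasources -p add_verbose_name.yaml
+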
diff --git a/setup.py b/setup.py
index ec12dea..1feebf0 100644
--- a/setup.py
+++ b/setup.py
@@ -64,9 +64,11 @@ setup(
'markdown==2.6.8',
'pandas==0.20.3',
'parsedatetime==2.0.0',
+ 'pathlib2==2.3.0',
'pydruid==0.3.1',
'PyHive>=0.4.0',
'python-dateutil==2.6.0',
+ 'pyyaml>=3.11',
'requests==2.17.3',
'simplejson==3.10.0',
'six==1.10.0',
diff --git a/superset/cli.py b/superset/cli.py
index ddbdf4b..16500ac 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -7,12 +7,15 @@ from __future__ import unicode_literals
from datetime import datetime
import logging
from subprocess import Popen
+from sys import stdout
from colorama import Fore, Style
from flask_migrate import MigrateCommand
from flask_script import Manager
+from pathlib2 import Path
+import yaml
-from superset import app, db, security, utils
+from superset import app, db, dict_import_export_util, security, utils
config = app.config
celery_app = utils.get_celery_app(config)
@@ -178,6 +181,82 @@ def refresh_druid(datasource, merge):
session.commit()
+@manager.option(
+ '-p', '--path', dest='path',
+ help='Path to a single YAML file or path containing multiple YAML '
+ 'files to import (*.yaml or *.yml)')
+@manager.option(
+ '-s', '--sync', dest='sync', default='',
+ help='comma separated list of element types to synchronize '
+ 'e.g. "metrics,columns" deletes metrics and columns in the DB '
+ 'that are not specified in the YAML file')
+@manager.option(
+ '-r', '--recursive', dest='recursive', action='store_true',
+ help='recursively search the path for yaml files')
+def import_datasources(path, sync, recursive=False):
+ """Import datasources from YAML"""
+ sync_array = sync.split(',')
+ p = Path(path)
+ files = []
+ if p.is_file():
+ files.append(p)
+ elif p.exists() and not recursive:
+ files.extend(p.glob('*.yaml'))
+ files.extend(p.glob('*.yml'))
+ elif p.exists() and recursive:
+ files.extend(p.rglob('*.yaml'))
+ files.extend(p.rglob('*.yml'))
+ for f in files:
+ logging.info('Importing datasources from file %s', f)
+ try:
+ with f.open() as data_stream:
+ dict_import_export_util.import_from_dict(
+ db.session,
+ yaml.load(data_stream),
+ sync=sync_array)
+ except Exception as e:
+ logging.error('Error when importing datasources from file %s', f)
+ logging.error(e)
+
+
+@manager.option(
+ '-f', '--datasource-file', default=None, dest='datasource_file',
+ help='Specify the file to export to')
+@manager.option(
+ '-p', '--print', action='store_true', dest='print_stdout',
+ help='Print YAML to stdout')
+@manager.option(
+ '-b', '--back-references', action='store_true', dest='back_references',
+ help='Include parent back references')
+@manager.option(
+ '-d', '--include-defaults', action='store_true', dest='include_defaults',
+ help='Include fields containing defaults')
+def export_datasources(print_stdout, datasource_file,
+ back_references, include_defaults):
+ """Export datasources to YAML"""
+ data = dict_import_export_util.export_to_dict(
+ session=db.session,
+ recursive=True,
+ back_references=back_references,
+ include_defaults=include_defaults)
+ if print_stdout or not datasource_file:
+ yaml.safe_dump(data, stdout, default_flow_style=False)
+ if datasource_file:
+ logging.info('Exporting datasources to %s', datasource_file)
+ with open(datasource_file, 'w') as data_stream:
+ yaml.safe_dump(data, data_stream, default_flow_style=False)
+
+
+@manager.option(
+ '-b', '--back-references', action='store_true',
+ help='Include parent back references')
+def export_datasource_schema(back_references):
+ """Export datasource YAML schema to stdout"""
+ data = dict_import_export_util.export_schema_to_dict(
+ back_references=back_references)
+ yaml.safe_dump(data, stdout, default_flow_style=False)
+
+
@manager.command
def update_datasources_cache():
"""Refresh sqllab datasources cache"""
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 8444515..bf7e176 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -28,7 +28,9 @@ from sqlalchemy.orm import backref, relationship
from superset import conf, db, import_util, sm, utils
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
-from superset.models.helpers import AuditMixinNullable, QueryResult, set_perm
+from superset.models.helpers import (
+ AuditMixinNullable, ImportMixin, QueryResult, set_perm,
+)
from superset.utils import (
DimSelector, DTTM_ALIAS, flasher, MetricPermException,
)
@@ -60,7 +62,7 @@ class CustomPostAggregator(Postaggregator):
self.post_aggregator = post_aggregator
-class DruidCluster(Model, AuditMixinNullable):
+class DruidCluster(Model, AuditMixinNullable, ImportMixin):
"""ORM object referencing the Druid clusters"""
@@ -81,6 +83,11 @@ class DruidCluster(Model, AuditMixinNullable):
metadata_last_refreshed = Column(DateTime)
cache_timeout = Column(Integer)
+ export_fields = ('cluster_name', 'coordinator_host', 'coordinator_port',
+ 'coordinator_endpoint', 'broker_host', 'broker_port',
+ 'broker_endpoint', 'cache_timeout')
+ export_children = ['datasources']
+
def __repr__(self):
return self.verbose_name if self.verbose_name else self.cluster_name
@@ -219,6 +226,7 @@ class DruidColumn(Model, BaseColumn):
"""ORM model for storing Druid datasource column metadata"""
__tablename__ = 'columns'
+ __table_args__ = (UniqueConstraint('column_name', 'datasource_id'),)
datasource_id = Column(
Integer,
@@ -233,8 +241,9 @@ class DruidColumn(Model, BaseColumn):
export_fields = (
'datasource_id', 'column_name', 'is_active', 'type', 'groupby',
'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable',
- 'description', 'dimension_spec_json',
+ 'description', 'dimension_spec_json', 'verbose_name',
)
+ export_parent = 'datasource'
def __repr__(self):
return self.column_name
@@ -360,6 +369,7 @@ class DruidMetric(Model, BaseMetric):
"""ORM object referencing Druid metrics for a datasource"""
__tablename__ = 'metrics'
+ __table_args__ = (UniqueConstraint('metric_name', 'datasource_id'),)
datasource_id = Column(
Integer,
ForeignKey('datasources.id'))
@@ -374,6 +384,7 @@ class DruidMetric(Model, BaseMetric):
'metric_name', 'verbose_name', 'metric_type', 'datasource_id',
'json', 'description', 'is_restricted', 'd3format',
)
+ export_parent = 'datasource'
@property
def expression(self):
@@ -409,6 +420,7 @@ class DruidDatasource(Model, BaseDatasource):
"""ORM object referencing Druid datasources (tables)"""
__tablename__ = 'datasources'
+ __table_args__ = (UniqueConstraint('datasource_name', 'cluster_name'),)
type = 'druid'
query_langtage = 'json'
@@ -438,6 +450,9 @@ class DruidDatasource(Model, BaseDatasource):
'cluster_name', 'offset', 'cache_timeout', 'params',
)
+ export_parent = 'cluster'
+ export_children = ['columns', 'metrics']
+
@property
def database(self):
return self.cluster
@@ -556,9 +571,12 @@ class DruidDatasource(Model, BaseDatasource):
v2nums = [int_or_0(n) for n in v2.split('.')]
v1nums = (v1nums + [0, 0, 0])[:3]
v2nums = (v2nums + [0, 0, 0])[:3]
- return v1nums[0] > v2nums[0] or \
- (v1nums[0] == v2nums[0] and v1nums[1] > v2nums[1]) or \
- (v1nums[0] == v2nums[0] and v1nums[1] == v2nums[1] and v1nums[2] > v2nums[2])
+ return (
+ v1nums[0] > v2nums[0] or
+ (v1nums[0] == v2nums[0] and v1nums[1] > v2nums[1]) or
+ (v1nums[0] == v2nums[0] and v1nums[1] == v2nums[1] and
+ v1nums[2] > v2nums[2])
+ )
def latest_metadata(self):
"""Returns segment metadata from the latest segment"""
diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py
index 713a43c..ad3664b 100644
--- a/superset/connectors/druid/views.py
+++ b/superset/connectors/druid/views.py
@@ -14,7 +14,7 @@ from superset.utils import has_access
from superset.views.base import (
BaseSupersetView, DatasourceFilter, DeleteMixin,
get_datasource_exist_error_mgs, ListWidgetWithCheckboxes, SupersetModelView,
- validate_json,
+ validate_json, YamlExportMixin,
)
from . import models
@@ -122,7 +122,7 @@ class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
appbuilder.add_view_no_menu(DruidMetricInlineView)
-class DruidClusterModelView(SupersetModelView, DeleteMixin): # noqa
+class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin): # noqa
datamodel = SQLAInterface(models.DruidCluster)
list_title = _('List Druid Cluster')
@@ -168,7 +168,7 @@ appbuilder.add_view(
category_icon='fa-database',)
-class DruidDatasourceModelView(DatasourceModelView, DeleteMixin): # noqa
+class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
datamodel = SQLAInterface(models.DruidDatasource)
list_title = _('List Druid Datasource')
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 5990622..7e276a6 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -13,6 +13,7 @@ from sqlalchemy import (
select, String, Text,
)
from sqlalchemy.orm import backref, relationship
+from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, literal_column, table, text
from sqlalchemy.sql.expression import TextAsFrom
import sqlparse
@@ -31,6 +32,7 @@ class TableColumn(Model, BaseColumn):
"""ORM object for table columns, each table can have multiple columns"""
__tablename__ = 'table_columns'
+ __table_args__ = (UniqueConstraint('table_id', 'column_name'),)
table_id = Column(Integer, ForeignKey('tables.id'))
table = relationship(
'SqlaTable',
@@ -47,6 +49,7 @@ class TableColumn(Model, BaseColumn):
'filterable', 'expression', 'description', 'python_date_format',
'database_expression',
)
+ export_parent = 'table'
@property
def sqla_col(self):
@@ -120,6 +123,7 @@ class SqlMetric(Model, BaseMetric):
"""ORM object for metrics, each table can have multiple metrics"""
__tablename__ = 'sql_metrics'
+ __table_args__ = (UniqueConstraint('table_id', 'metric_name'),)
table_id = Column(Integer, ForeignKey('tables.id'))
table = relationship(
'SqlaTable',
@@ -130,6 +134,7 @@ class SqlMetric(Model, BaseMetric):
export_fields = (
'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
'description', 'is_restricted', 'd3format')
+ export_parent = 'table'
@property
def sqla_col(self):
@@ -162,6 +167,8 @@ class SqlaTable(Model, BaseDatasource):
column_class = TableColumn
__tablename__ = 'tables'
+ __table_args__ = (UniqueConstraint('database_id', 'table_name'),)
+
table_name = Column(String(250))
main_dttm_col = Column(String(250))
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
@@ -179,15 +186,13 @@ class SqlaTable(Model, BaseDatasource):
sql = Column(Text)
baselink = 'tablemodelview'
+
export_fields = (
'table_name', 'main_dttm_col', 'description', 'default_endpoint',
'database_id', 'offset', 'cache_timeout', 'schema',
'sql', 'params')
-
- __table_args__ = (
- sa.UniqueConstraint(
- 'database_id', 'schema', 'table_name',
- name='_customer_location_uc'),)
+ export_parent = 'database'
+ export_children = ['metrics', 'columns']
def __repr__(self):
return self.name
diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py
index 47358fa..2f394ce 100644
--- a/superset/connectors/sqla/views.py
+++ b/superset/connectors/sqla/views.py
@@ -12,7 +12,7 @@ from superset.connectors.base.views import DatasourceModelView
from superset.utils import has_access
from superset.views.base import (
DatasourceFilter, DeleteMixin, get_datasource_exist_error_mgs,
- ListWidgetWithCheckboxes, SupersetModelView,
+ ListWidgetWithCheckboxes, SupersetModelView, YamlExportMixin,
)
from . import models
@@ -148,7 +148,7 @@ class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
appbuilder.add_view_no_menu(SqlMetricInlineView)
-class TableModelView(DatasourceModelView, DeleteMixin): # noqa
+class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
datamodel = SQLAInterface(models.SqlaTable)
list_title = _('List Tables')
diff --git a/superset/dict_import_export_util.py b/superset/dict_import_export_util.py
new file mode 100644
index 0000000..26cfc5d
--- /dev/null
+++ b/superset/dict_import_export_util.py
@@ -0,0 +1,65 @@
+import logging
+
+from superset.connectors.druid.models import DruidCluster
+from superset.models.core import Database
+
+
+DATABASES_KEY = 'databases'
+DRUID_CLUSTERS_KEY = 'druid_clusters'
+
+
+def export_schema_to_dict(back_references):
+ """Exports the supported import/export schema to a dictionary"""
+ databases = [Database.export_schema(recursive=True,
+ include_parent_ref=back_references)]
+ clusters = [DruidCluster.export_schema(recursive=True,
+ include_parent_ref=back_references)]
+ data = dict()
+ if databases:
+ data[DATABASES_KEY] = databases
+ if clusters:
+ data[DRUID_CLUSTERS_KEY] = clusters
+ return data
+
+
+def export_to_dict(session,
+ recursive,
+ back_references,
+ include_defaults):
+ """Exports databases and druid clusters to a dictionary"""
+ logging.info('Starting export')
+ dbs = session.query(Database)
+ databases = [database.export_to_dict(recursive=recursive,
+ include_parent_ref=back_references,
+ include_defaults=include_defaults) for database in dbs]
+ logging.info('Exported %d %s', len(databases), DATABASES_KEY)
+ cls = session.query(DruidCluster)
+ clusters = [cluster.export_to_dict(recursive=recursive,
+ include_parent_ref=back_references,
+ include_defaults=include_defaults) for cluster in cls]
+ logging.info('Exported %d %s', len(clusters), DRUID_CLUSTERS_KEY)
+ data = dict()
+ if databases:
+ data[DATABASES_KEY] = databases
+ if clusters:
+ data[DRUID_CLUSTERS_KEY] = clusters
+ return data
+
+
+def import_from_dict(session, data, sync=[]):
+ """Imports databases and druid clusters from dictionary"""
+ if isinstance(data, dict):
+ logging.info('Importing %d %s',
+ len(data.get(DATABASES_KEY, [])),
+ DATABASES_KEY)
+ for database in data.get(DATABASES_KEY, []):
+ Database.import_from_dict(session, database, sync=sync)
+
+ logging.info('Importing %d %s',
+ len(data.get(DRUID_CLUSTERS_KEY, [])),
+ DRUID_CLUSTERS_KEY)
+ for datasource in data.get(DRUID_CLUSTERS_KEY, []):
+ DruidCluster.import_from_dict(session, datasource, sync=sync)
+ session.commit()
+ else:
+ logging.info('Supplied object is not a dictionary.')
diff --git a/superset/models/core.py b/superset/models/core.py
index 68c305f..2c6e8b0 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -28,6 +28,7 @@ from sqlalchemy.engine.url import make_url
from sqlalchemy.orm import relationship, subqueryload
from sqlalchemy.orm.session import make_transient
from sqlalchemy.pool import NullPool
+from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import text
from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy_utils import EncryptedType
@@ -537,12 +538,13 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin):
})
-class Database(Model, AuditMixinNullable):
+class Database(Model, AuditMixinNullable, ImportMixin):
"""An ORM object that stores Database related information"""
__tablename__ = 'dbs'
type = 'table'
+ __table_args__ = (UniqueConstraint('database_name'),)
id = Column(Integer, primary_key=True)
verbose_name = Column(String(250), unique=True)
@@ -567,6 +569,10 @@ class Database(Model, AuditMixinNullable):
perm = Column(String(1000))
custom_password_store = config.get('SQLALCHEMY_CUSTOM_PASSWORD_STORE')
impersonate_user = Column(Boolean, default=False)
+ export_fields = ('database_name', 'sqlalchemy_uri', 'cache_timeout',
+ 'expose_in_sqllab', 'allow_run_sync', 'allow_run_async',
+ 'allow_ctas', 'extra')
+ export_children = ['tables']
def __repr__(self):
return self.verbose_name if self.verbose_name else self.database_name
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index d4ae9f4..02b2cf2 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -6,6 +6,7 @@ from __future__ import unicode_literals
from datetime import datetime
import json
+import logging
import re
from flask import escape, Markup
@@ -13,13 +14,184 @@ from flask_appbuilder.models.decorators import renders
from flask_appbuilder.models.mixins import AuditMixin
import humanize
import sqlalchemy as sa
+from sqlalchemy import and_, or_, UniqueConstraint
from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy.orm.exc import MultipleResultsFound
+import yaml
from superset import sm
from superset.utils import QueryStatus
class ImportMixin(object):
+ export_parent = None
+ # The name of the attribute
+ # with the SQL Alchemy back reference
+
+ export_children = []
+ # List of (str) names of attributes
+ # with the SQL Alchemy forward references
+
+ export_fields = []
+ # The names of the attributes
+ # that are available for import and export
+
+ @classmethod
+ def _parent_foreign_key_mappings(cls):
+ """Get a mapping of foreign name to the local name of foreign keys"""
+ parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
+ if parent_rel:
+ return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
+ return {}
+
+ @classmethod
+ def _unique_constrains(cls):
+ """Get all (single column and multi column) unique constraints"""
+ unique = [{c.name for c in u.columns} for u in cls.__table_args__
+ if isinstance(u, UniqueConstraint)]
+ unique.extend({c.name} for c in cls.__table__.columns if c.unique)
+ return unique
+
+ @classmethod
+ def export_schema(cls, recursive=True, include_parent_ref=False):
+ """Export schema as a dictionary"""
+ parent_excludes = {}
+ if not include_parent_ref:
+ parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
+ if parent_ref:
+ parent_excludes = {c.name for c in parent_ref.local_columns}
+
+ def formatter(c): return ('{0} Default ({1})'.format(
+ str(c.type), c.default.arg) if c.default else str(c.type))
+
+ schema = {c.name: formatter(c) for c in cls.__table__.columns
+ if (c.name in cls.export_fields and
+ c.name not in parent_excludes)}
+ if recursive:
+ for c in cls.export_children:
+ child_class = cls.__mapper__.relationships[c].argument.class_
+ schema[c] = [child_class.export_schema(recursive=recursive,
+ include_parent_ref=include_parent_ref)]
+ return schema
+
+ @classmethod
+ def import_from_dict(cls, session, dict_rep, parent=None,
+ recursive=True, sync=[]):
+ """Import obj from a dictionary"""
+ parent_refs = cls._parent_foreign_key_mappings()
+ export_fields = set(cls.export_fields) | set(parent_refs.keys())
+ new_children = {c: dict_rep.get(c) for c in cls.export_children
+ if c in dict_rep}
+ unique_constrains = cls._unique_constrains()
+
+ filters = [] # Using these filters to check if obj already exists
+
+ # Remove fields that should not get imported
+ for k in list(dict_rep):
+ if k not in export_fields:
+ del dict_rep[k]
+
+ if not parent:
+ if cls.export_parent:
+ for p in parent_refs.keys():
+ if p not in dict_rep:
+ raise RuntimeError(
+ '{0}: Missing field {1}'.format(cls.__name__, p))
+ else:
+ # Set foreign keys to parent obj
+ for k, v in parent_refs.items():
+ dict_rep[k] = getattr(parent, v)
+
+ # Add filter for parent obj
+ filters.extend([getattr(cls, k) == dict_rep.get(k)
+ for k in parent_refs.keys()])
+
+ # Add filter for unique constraints
+ ucs = [and_(*[getattr(cls, k) == dict_rep.get(k)
+ for k in cs if dict_rep.get(k) is not None])
+ for cs in unique_constrains]
+ filters.append(or_(*ucs))
+
+ # Check if object already exists in DB, break if more than one is found
+ try:
+ obj_query = session.query(cls).filter(and_(*filters))
+ obj = obj_query.one_or_none()
+ except MultipleResultsFound as e:
+ logging.error('Error importing %s \n %s \n %s', cls.__name__,
+ str(obj_query),
+ yaml.safe_dump(dict_rep))
+ raise e
+
+ if not obj:
+ is_new_obj = True
+ # Create new DB object
+ obj = cls(**dict_rep)
+ logging.info('Importing new %s %s', obj.__tablename__, str(obj))
+ if cls.export_parent and parent:
+ setattr(obj, cls.export_parent, parent)
+ session.add(obj)
+ else:
+ is_new_obj = False
+ logging.info('Updating %s %s', obj.__tablename__, str(obj))
+ # Update columns
+ for k, v in dict_rep.items():
+ setattr(obj, k, v)
+
+ # Recursively create children
+ if recursive:
+ for c in cls.export_children:
+ child_class = cls.__mapper__.relationships[c].argument.class_
+ added = []
+ for c_obj in new_children.get(c, []):
+ added.append(child_class.import_from_dict(session=session,
+ dict_rep=c_obj,
+ parent=obj,
+ sync=sync))
+ # If children should get synced, delete the ones that did not
+ # get updated.
+ if c in sync and not is_new_obj:
+ back_refs = child_class._parent_foreign_key_mappings()
+ delete_filters = [getattr(child_class, k) ==
+ getattr(obj, back_refs.get(k))
+ for k in back_refs.keys()]
+ to_delete = set(session.query(child_class).filter(
+ and_(*delete_filters))).difference(set(added))
+ for o in to_delete:
+ logging.info('Deleting %s %s', c, str(obj))
+ session.delete(o)
+
+ return obj
+
+ def export_to_dict(self, recursive=True, include_parent_ref=False,
+ include_defaults=False):
+ """Export obj to dictionary"""
+ cls = self.__class__
+ parent_excludes = {}
+ if recursive and not include_parent_ref:
+ parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
+ if parent_ref:
+ parent_excludes = {c.name for c in parent_ref.local_columns}
+ dict_rep = {c.name: getattr(self, c.name)
+ for c in cls.__table__.columns
+ if (c.name in self.export_fields and
+ c.name not in parent_excludes and
+ (include_defaults or (
+ getattr(self, c.name) is not None and
+ (not c.default or
+ getattr(self, c.name) != c.default.arg))))
+ }
+ if recursive:
+ for c in self.export_children:
+ # sorting to make lists of children stable
+ dict_rep[c] = sorted([child.export_to_dict(
+ recursive=recursive,
+ include_parent_ref=include_parent_ref,
+ include_defaults=include_defaults)
+ for child in getattr(self, c)],
+ key=lambda k: sorted(k.items()))
+
+ return dict_rep
+
def override(self, obj):
"""Overrides the plain fields of the dashboard."""
for field in obj.__class__.export_fields:
diff --git a/superset/views/base.py b/superset/views/base.py
index 7bc55d2..a909ed0 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -1,3 +1,4 @@
+from datetime import datetime
import functools
import json
import logging
@@ -11,6 +12,7 @@ from flask_appbuilder.widgets import ListWidget
from flask_babel import get_locale
from flask_babel import gettext as __
from flask_babel import lazy_gettext as _
+import yaml
from superset import appbuilder, conf, db, sm, sql_parse, utils
from superset.connectors.connector_registry import ConnectorRegistry
@@ -41,6 +43,15 @@ def json_error_response(msg=None, status=500, stacktrace=None, payload=None):
status=status, mimetype='application/json')
+def generate_download_headers(extension, filename=None):
+ filename = filename if filename else datetime.now().strftime('%Y%m%d_%H%M%S')
+ content_disp = 'attachment; filename={}.{}'.format(filename, extension)
+ headers = {
+ 'Content-Disposition': content_disp,
+ }
+ return headers
+
+
def api(f):
"""
A decorator to label an endpoint as an API. Catches uncaught exceptions and
@@ -219,6 +230,19 @@ def validate_json(form, field): # noqa
raise Exception(_("json isn't valid"))
+class YamlExportMixin(object):
+ @action('yaml_export', __('Export to YAML'), __('Export to YAML?'), 'fa-download')
+ def yaml_export(self, items):
+ if not isinstance(items, list):
+ items = [items]
+
+ data = [t.export_to_dict() for t in items]
+ return Response(
+ yaml.safe_dump(data),
+ headers=generate_download_headers('yaml'),
+ mimetype='application/text')
+
+
class DeleteMixin(object):
def _delete(self, pk):
"""
diff --git a/superset/views/core.py b/superset/views/core.py
index 8677a17..e221cb8 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -44,8 +44,9 @@ from superset.models.sql_lab import Query
from superset.sql_parse import SupersetQuery
from superset.utils import has_access, merge_extra_filters, QueryStatus
from .base import (
- api, BaseSupersetView, CsvResponse, DeleteMixin, get_error_msg,
- get_user_roles, json_error_response, SupersetFilter, SupersetModelView,
+ api, BaseSupersetView, CsvResponse, DeleteMixin,
+ generate_download_headers, get_error_msg, get_user_roles,
+ json_error_response, SupersetFilter, SupersetModelView, YamlExportMixin,
)
config = app.config
@@ -161,16 +162,9 @@ class DashboardFilter(SupersetFilter):
return query
-def generate_download_headers(extension):
- filename = datetime.now().strftime('%Y%m%d_%H%M%S')
- content_disp = 'attachment; filename={}.{}'.format(filename, extension)
- headers = {
- 'Content-Disposition': content_disp,
- }
- return headers
-class DatabaseView(SupersetModelView, DeleteMixin): # noqa
+class DatabaseView(SupersetModelView, DeleteMixin, YamlExportMixin): # noqa
datamodel = SQLAInterface(models.Database)
list_title = _('List Databases')
diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py
new file mode 100644
index 0000000..592930a
--- /dev/null
+++ b/tests/dict_import_export_tests.py
@@ -0,0 +1,350 @@
+"""Unit tests for Superset"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import json
+import unittest
+
+import yaml
+
+from superset import db
+from superset.connectors.druid.models import (
+ DruidColumn, DruidDatasource, DruidMetric)
+from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+from .base_tests import SupersetTestCase
+
+DBREF = 'dict_import__export_test'
+NAME_PREFIX = 'dict_'
+ID_PREFIX = 20000
+
+
+class DictImportExportTests(SupersetTestCase):
+ """Testing export import functionality for dashboards"""
+
+ def __init__(self, *args, **kwargs):
+ super(DictImportExportTests, self).__init__(*args, **kwargs)
+
+ @classmethod
+ def delete_imports(cls):
+ # Imported data clean up
+ session = db.session
+ for table in session.query(SqlaTable):
+ if DBREF in table.params_dict:
+ session.delete(table)
+ for datasource in session.query(DruidDatasource):
+ if DBREF in datasource.params_dict:
+ session.delete(datasource)
+ session.commit()
+
+ @classmethod
+ def setUpClass(cls):
+ cls.delete_imports()
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.delete_imports()
+
+ def create_table(
+ self, name, schema='', id=0, cols_names=[], metric_names=[]):
+ database_name = 'main'
+ name = '{0}{1}'.format(NAME_PREFIX, name)
+ params = {DBREF: id, 'database_name': database_name}
+
+ dict_rep = {
+ 'database_id': self.get_main_database(db.session).id,
+ 'table_name': name,
+ 'schema': schema,
+ 'id': id,
+ 'params': json.dumps(params),
+ 'columns': [{'column_name': c}
+ for c in cols_names],
+ 'metrics': [{'metric_name': c} for c in metric_names],
+ }
+
+ table = SqlaTable(
+ id=id,
+ schema=schema,
+ table_name=name,
+ params=json.dumps(params),
+ )
+ for col_name in cols_names:
+ table.columns.append(TableColumn(column_name=col_name))
+ for metric_name in metric_names:
+ table.metrics.append(SqlMetric(metric_name=metric_name))
+ return table, dict_rep
+
+ def create_druid_datasource(
+ self, name, id=0, cols_names=[], metric_names=[]):
+ name = '{0}{1}'.format(NAME_PREFIX, name)
+ cluster_name = 'druid_test'
+ params = {DBREF: id, 'database_name': cluster_name}
+ dict_rep = {
+ 'cluster_name': cluster_name,
+ 'datasource_name': name,
+ 'id': id,
+ 'params': json.dumps(params),
+ 'columns': [{'column_name': c} for c in cols_names],
+ 'metrics': [{'metric_name': c} for c in metric_names],
+ }
+
+ datasource = DruidDatasource(
+ id=id,
+ datasource_name=name,
+ cluster_name=cluster_name,
+ params=json.dumps(params),
+ )
+ for col_name in cols_names:
+ datasource.columns.append(DruidColumn(column_name=col_name))
+ for metric_name in metric_names:
+ datasource.metrics.append(DruidMetric(metric_name=metric_name))
+ return datasource, dict_rep
+
+ def get_datasource(self, datasource_id):
+ return db.session.query(DruidDatasource).filter_by(
+ id=datasource_id).first()
+
+ def get_table_by_name(self, name):
+ return db.session.query(SqlaTable).filter_by(
+ table_name=name).first()
+
+ def yaml_compare(self, obj_1, obj_2):
+ obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False)
+ obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False)
+ self.assertEquals(obj_1_str, obj_2_str)
+
+ def assert_table_equals(self, expected_ds, actual_ds):
+ self.assertEquals(expected_ds.table_name, actual_ds.table_name)
+ self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
+ self.assertEquals(expected_ds.schema, actual_ds.schema)
+ self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics))
+ self.assertEquals(len(expected_ds.columns), len(actual_ds.columns))
+ self.assertEquals(
+ set([c.column_name for c in expected_ds.columns]),
+ set([c.column_name for c in actual_ds.columns]))
+ self.assertEquals(
+ set([m.metric_name for m in expected_ds.metrics]),
+ set([m.metric_name for m in actual_ds.metrics]))
+
+ def assert_datasource_equals(self, expected_ds, actual_ds):
+ self.assertEquals(
+ expected_ds.datasource_name, actual_ds.datasource_name)
+ self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
+ self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics))
+ self.assertEquals(len(expected_ds.columns), len(actual_ds.columns))
+ self.assertEquals(
+ set([c.column_name for c in expected_ds.columns]),
+ set([c.column_name for c in actual_ds.columns]))
+ self.assertEquals(
+ set([m.metric_name for m in expected_ds.metrics]),
+ set([m.metric_name for m in actual_ds.metrics]))
+
+ def test_import_table_no_metadata(self):
+ table, dict_table = self.create_table('pure_table', id=ID_PREFIX + 1)
+ new_table = SqlaTable.import_from_dict(db.session, dict_table)
+ db.session.commit()
+ imported_id = new_table.id
+ imported = self.get_table(imported_id)
+ self.assert_table_equals(table, imported)
+ self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
+
+ def test_import_table_1_col_1_met(self):
+ table, dict_table = self.create_table(
+ 'table_1_col_1_met', id=ID_PREFIX + 2,
+ cols_names=['col1'], metric_names=['metric1'])
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ db.session.commit()
+ imported = self.get_table(imported_table.id)
+ self.assert_table_equals(table, imported)
+ self.assertEquals(
+ {DBREF: ID_PREFIX + 2, 'database_name': 'main'},
+ json.loads(imported.params))
+ self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
+
+ def test_import_table_2_col_2_met(self):
+ table, dict_table = self.create_table(
+ 'table_2_col_2_met', id=ID_PREFIX + 3, cols_names=['c1', 'c2'],
+ metric_names=['m1', 'm2'])
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ db.session.commit()
+ imported = self.get_table(imported_table.id)
+ self.assert_table_equals(table, imported)
+ self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
+
+ def test_import_table_override_append(self):
+ table, dict_table = self.create_table(
+ 'table_override', id=ID_PREFIX + 3,
+ cols_names=['col1'],
+ metric_names=['m1'])
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ db.session.commit()
+ table_over, dict_table_over = self.create_table(
+ 'table_override', id=ID_PREFIX + 3,
+ cols_names=['new_col1', 'col2', 'col3'],
+ metric_names=['new_metric1'])
+ imported_over_table = SqlaTable.import_from_dict(
+ db.session,
+ dict_table_over)
+ db.session.commit()
+
+ imported_over = self.get_table(imported_over_table.id)
+ self.assertEquals(imported_table.id, imported_over.id)
+ expected_table, _ = self.create_table(
+ 'table_override', id=ID_PREFIX + 3,
+ metric_names=['new_metric1', 'm1'],
+ cols_names=['col1', 'new_col1', 'col2', 'col3'])
+ self.assert_table_equals(expected_table, imported_over)
+ self.yaml_compare(expected_table.export_to_dict(),
+ imported_over.export_to_dict())
+
+ def test_import_table_override_sync(self):
+ table, dict_table = self.create_table(
+ 'table_override', id=ID_PREFIX + 3,
+ cols_names=['col1'],
+ metric_names=['m1'])
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ db.session.commit()
+ table_over, dict_table_over = self.create_table(
+ 'table_override', id=ID_PREFIX + 3,
+ cols_names=['new_col1', 'col2', 'col3'],
+ metric_names=['new_metric1'])
+ imported_over_table = SqlaTable.import_from_dict(
+ session=db.session,
+ dict_rep=dict_table_over,
+ sync=['metrics', 'columns'])
+ db.session.commit()
+
+ imported_over = self.get_table(imported_over_table.id)
+ self.assertEquals(imported_table.id, imported_over.id)
+ expected_table, _ = self.create_table(
+ 'table_override', id=ID_PREFIX + 3,
+ metric_names=['new_metric1'],
+ cols_names=['new_col1', 'col2', 'col3'])
+ self.assert_table_equals(expected_table, imported_over)
+ self.yaml_compare(
+ expected_table.export_to_dict(),
+ imported_over.export_to_dict())
+
+ def test_import_table_override_identical(self):
+ table, dict_table = self.create_table(
+ 'copy_cat', id=ID_PREFIX + 4,
+ cols_names=['new_col1', 'col2', 'col3'],
+ metric_names=['new_metric1'])
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ db.session.commit()
+ copy_table, dict_copy_table = self.create_table(
+ 'copy_cat', id=ID_PREFIX + 4,
+ cols_names=['new_col1', 'col2', 'col3'],
+ metric_names=['new_metric1'])
+ imported_copy_table = SqlaTable.import_from_dict(db.session,
+ dict_copy_table)
+ db.session.commit()
+ self.assertEquals(imported_table.id, imported_copy_table.id)
+ self.assert_table_equals(copy_table, self.get_table(imported_table.id))
+ self.yaml_compare(imported_copy_table.export_to_dict(),
+ imported_table.export_to_dict())
+
+ def test_import_druid_no_metadata(self):
+ datasource, dict_datasource = self.create_druid_datasource(
+ 'pure_druid', id=ID_PREFIX + 1)
+ imported_cluster = DruidDatasource.import_from_dict(db.session,
+ dict_datasource)
+ db.session.commit()
+ imported = self.get_datasource(imported_cluster.id)
+ self.assert_datasource_equals(datasource, imported)
+
+ def test_import_druid_1_col_1_met(self):
+ datasource, dict_datasource = self.create_druid_datasource(
+ 'druid_1_col_1_met', id=ID_PREFIX + 2,
+ cols_names=['col1'], metric_names=['metric1'])
+ imported_cluster = DruidDatasource.import_from_dict(db.session,
+ dict_datasource)
+ db.session.commit()
+ imported = self.get_datasource(imported_cluster.id)
+ self.assert_datasource_equals(datasource, imported)
+ self.assertEquals(
+ {DBREF: ID_PREFIX + 2, 'database_name': 'druid_test'},
+ json.loads(imported.params))
+
+ def test_import_druid_2_col_2_met(self):
+ datasource, dict_datasource = self.create_druid_datasource(
+ 'druid_2_col_2_met', id=ID_PREFIX + 3, cols_names=['c1', 'c2'],
+ metric_names=['m1', 'm2'])
+ imported_cluster = DruidDatasource.import_from_dict(db.session,
+ dict_datasource)
+ db.session.commit()
+ imported = self.get_datasource(imported_cluster.id)
+ self.assert_datasource_equals(datasource, imported)
+
+ def test_import_druid_override_append(self):
+ datasource, dict_datasource = self.create_druid_datasource(
+ 'druid_override', id=ID_PREFIX + 3, cols_names=['col1'],
+ metric_names=['m1'])
+ imported_cluster = DruidDatasource.import_from_dict(db.session,
+ dict_datasource)
+ db.session.commit()
+ table_over, table_over_dict = self.create_druid_datasource(
+ 'druid_override', id=ID_PREFIX + 3,
+ cols_names=['new_col1', 'col2', 'col3'],
+ metric_names=['new_metric1'])
+ imported_over_cluster = DruidDatasource.import_from_dict(
+ db.session,
+ table_over_dict)
+ db.session.commit()
+ imported_over = self.get_datasource(imported_over_cluster.id)
+ self.assertEquals(imported_cluster.id, imported_over.id)
+ expected_datasource, _ = self.create_druid_datasource(
+ 'druid_override', id=ID_PREFIX + 3,
+ metric_names=['new_metric1', 'm1'],
+ cols_names=['col1', 'new_col1', 'col2', 'col3'])
+ self.assert_datasource_equals(expected_datasource, imported_over)
+
+ def test_import_druid_override_sync(self):
+ datasource, dict_datasource = self.create_druid_datasource(
+ 'druid_override', id=ID_PREFIX + 3, cols_names=['col1'],
+ metric_names=['m1'])
+ imported_cluster = DruidDatasource.import_from_dict(
+ db.session,
+ dict_datasource)
+ db.session.commit()
+ table_over, table_over_dict = self.create_druid_datasource(
+ 'druid_override', id=ID_PREFIX + 3,
+ cols_names=['new_col1', 'col2', 'col3'],
+ metric_names=['new_metric1'])
+ imported_over_cluster = DruidDatasource.import_from_dict(
+ session=db.session,
+ dict_rep=table_over_dict,
+ sync=['metrics', 'columns']) # syncing metrics and columns
+ db.session.commit()
+ imported_over = self.get_datasource(imported_over_cluster.id)
+ self.assertEquals(imported_cluster.id, imported_over.id)
+ expected_datasource, _ = self.create_druid_datasource(
+ 'druid_override', id=ID_PREFIX + 3,
+ metric_names=['new_metric1'],
+ cols_names=['new_col1', 'col2', 'col3'])
+ self.assert_datasource_equals(expected_datasource, imported_over)
+
+ def test_import_druid_override_identical(self):
+ datasource, dict_datasource = self.create_druid_datasource(
+ 'copy_cat', id=ID_PREFIX + 4,
+ cols_names=['new_col1', 'col2', 'col3'],
+ metric_names=['new_metric1'])
+ imported = DruidDatasource.import_from_dict(session=db.session,
+ dict_rep=dict_datasource)
+ db.session.commit()
+ copy_datasource, dict_cp_datasource = self.create_druid_datasource(
+ 'copy_cat', id=ID_PREFIX + 4,
+ cols_names=['new_col1', 'col2', 'col3'],
+ metric_names=['new_metric1'])
+ imported_copy = DruidDatasource.import_from_dict(db.session,
+ dict_cp_datasource)
+ db.session.commit()
+
+ self.assertEquals(imported.id, imported_copy.id)
+ self.assert_datasource_equals(
+ copy_datasource, self.get_datasource(imported.id))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index 2a9f069..d51b959 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -441,7 +441,7 @@ class ImportExportTests(SupersetTestCase):
cols_names=['col1', 'new_col1', 'col2', 'col3'])
self.assert_table_equals(expected_table, imported_over)
- def test_import_table_override_idential(self):
+ def test_import_table_override_identical(self):
table = self.create_table(
'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
@@ -505,7 +505,7 @@ class ImportExportTests(SupersetTestCase):
cols_names=['col1', 'new_col1', 'col2', 'col3'])
self.assert_datasource_equals(expected_datasource, imported_over)
- def test_import_druid_override_idential(self):
+ def test_import_druid_override_identical(self):
datasource = self.create_druid_datasource(
'copy_cat', id=10005, cols_names=['new_col1', 'col2', 'col3'],
metric_names=['new_metric1'])
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].