This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new d65b039 Improve examples & related tests (#7773)
d65b039 is described below
commit d65b039219c9825ad50ec03ad73a1638710c73c9
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Tue Jul 16 21:36:56 2019 -0700
Improve examples & related tests (#7773)
* [WiP] improve load_examples
related to #7472; longer term we will generate the examples by exporting
them into a tarball as in #7472. In the meantime, we need this subset of
the features:
* allow specifying an alternate database connection for examples
* allow a --only-metadata flag to `load_examples` to load only
dashboard and chart definitions; no actual data is loaded
* Improve logging
* Rename data->examples
* Load only if the table does not exist
* By default do not load data; add a --force flag
* fix build
* set published to true
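
For reference, a minimal sketch of how the new flags can be exercised
programmatically (the CLI equivalent would be along the lines of
`superset load_examples --only-metadata`); the call below matches the new
load_examples_run signature in the diff, everything else is assumed:

    from superset.cli import load_examples_run

    # Load only dashboard and chart definitions; skip downloading and
    # inserting the actual example data.
    load_examples_run(load_test_data=False, only_metadata=True)

    # Re-load everything, overwriting example tables that already exist.
    load_examples_run(load_test_data=False, only_metadata=False, force=True)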
---
MANIFEST.in | 2 +-
superset/cli.py | 57 ++++++++------
superset/config.py | 4 +
superset/connectors/connector_registry.py | 15 +---
superset/connectors/druid/models.py | 10 +++
superset/connectors/sqla/models.py | 15 ++++
superset/{data => examples}/__init__.py | 0
superset/{data => examples}/bart_lines.py | 49 ++++++------
superset/{data => examples}/birth_names.py | 60 ++++++++-------
superset/{data => examples}/countries.md | 0
superset/{data => examples}/countries.py | 0
superset/{data => examples}/country_map.py | 74 +++++++++---------
superset/{data => examples}/css_templates.py | 0
superset/{data => examples}/deck.py | 1 +
superset/{data => examples}/energy.py | 37 +++++----
superset/{data => examples}/flights.py | 59 +++++++-------
superset/{data => examples}/helpers.py | 2 +-
superset/{data => examples}/long_lat.py | 89 ++++++++++++----------
superset/{data => examples}/misc_dashboard.py | 0
superset/{data => examples}/multi_line.py | 6 +-
.../{data => examples}/multiformat_time_series.py | 64 +++++++++-------
superset/{data => examples}/paris.py | 42 +++++-----
superset/{data => examples}/random_time_series.py | 41 +++++-----
.../{data => examples}/sf_population_polygons.py | 42 +++++-----
superset/{data => examples}/tabbed_dashboard.py | 21 +----
superset/{data => examples}/unicode_test_data.py | 61 ++++++++-------
superset/{data => examples}/world_bank.py | 47 +++++++-----
superset/models/core.py | 7 +-
superset/models/tags.py | 2 +-
superset/tasks/cache.py | 20 +++--
superset/utils/core.py | 38 +++++----
superset/viz.py | 2 +-
tests/access_tests.py | 4 +-
tests/base_tests.py | 11 +--
tests/celery_tests.py | 10 +--
tests/core_tests.py | 35 ++++-----
tests/db_engine_specs_test.py | 11 +--
tests/dict_import_export_tests.py | 2 +-
tests/import_export_tests.py | 43 ++++++-----
tests/load_examples_test.py | 13 ++--
tests/model_tests.py | 10 +--
tests/sqllab_tests.py | 2 +-
tests/strategy_tests.py | 63 +++++++--------
tests/viz_tests.py | 1 -
tox.ini | 2 +-
45 files changed, 583 insertions(+), 491 deletions(-)
diff --git a/MANIFEST.in b/MANIFEST.in
index f06d131..363faac 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -18,7 +18,7 @@ include NOTICE
include LICENSE.txt
graft licenses/
include README.md
-recursive-include superset/data *
+recursive-include superset/examples *
recursive-include superset/migrations *
recursive-include superset/static *
recursive-exclude superset/static/assets/docs *
diff --git a/superset/cli.py b/superset/cli.py
index cb363c2..c4f8309 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -26,7 +26,7 @@ from colorama import Fore, Style
from pathlib2 import Path
import yaml
-from superset import app, appbuilder, data, db, security_manager
+from superset import app, appbuilder, db, examples, security_manager
from superset.utils import core as utils, dashboard_import_export, dict_import_export
config = app.config
@@ -46,6 +46,7 @@ def make_shell_context():
def init():
"""Inits the Superset application"""
utils.get_or_create_main_db()
+ utils.get_example_database()
appbuilder.add_permissions(update_perms=True)
security_manager.sync_role_definitions()
@@ -67,66 +68,76 @@ def version(verbose):
print(Style.RESET_ALL)
-def load_examples_run(load_test_data):
- print("Loading examples into {}".format(db))
+def load_examples_run(load_test_data, only_metadata=False, force=False):
+ if only_metadata:
+ print("Loading examples metadata")
+ else:
+ examples_db = utils.get_example_database()
+ print(f"Loading examples metadata and related data into {examples_db}")
- data.load_css_templates()
+ examples.load_css_templates()
print("Loading energy related dataset")
- data.load_energy()
+ examples.load_energy(only_metadata, force)
print("Loading [World Bank's Health Nutrition and Population Stats]")
- data.load_world_bank_health_n_pop()
+ examples.load_world_bank_health_n_pop(only_metadata, force)
print("Loading [Birth names]")
- data.load_birth_names()
+ examples.load_birth_names(only_metadata, force)
print("Loading [Unicode test data]")
- data.load_unicode_test_data()
+ examples.load_unicode_test_data(only_metadata, force)
if not load_test_data:
print("Loading [Random time series data]")
- data.load_random_time_series_data()
+ examples.load_random_time_series_data(only_metadata, force)
print("Loading [Random long/lat data]")
- data.load_long_lat_data()
+ examples.load_long_lat_data(only_metadata, force)
print("Loading [Country Map data]")
- data.load_country_map_data()
+ examples.load_country_map_data(only_metadata, force)
print("Loading [Multiformat time series]")
- data.load_multiformat_time_series()
+ examples.load_multiformat_time_series(only_metadata, force)
print("Loading [Paris GeoJson]")
- data.load_paris_iris_geojson()
+ examples.load_paris_iris_geojson(only_metadata, force)
print("Loading [San Francisco population polygons]")
- data.load_sf_population_polygons()
+ examples.load_sf_population_polygons(only_metadata, force)
print("Loading [Flights data]")
- data.load_flights()
+ examples.load_flights(only_metadata, force)
print("Loading [BART lines]")
- data.load_bart_lines()
+ examples.load_bart_lines(only_metadata, force)
print("Loading [Multi Line]")
- data.load_multi_line()
+ examples.load_multi_line(only_metadata)
print("Loading [Misc Charts] dashboard")
- data.load_misc_dashboard()
+ examples.load_misc_dashboard()
print("Loading DECK.gl demo")
- data.load_deck_dash()
+ examples.load_deck_dash()
print("Loading [Tabbed dashboard]")
- data.load_tabbed_dashboard()
+ examples.load_tabbed_dashboard(only_metadata)
@app.cli.command()
@click.option("--load-test-data", "-t", is_flag=True, help="Load additional
test data")
-def load_examples(load_test_data):
+@click.option(
+    "--only-metadata", "-m", is_flag=True, help="Only load metadata, skip actual data"
+)
+@click.option(
+    "--force", "-f", is_flag=True, help="Force load data even if table already exists"
+)
+def load_examples(load_test_data, only_metadata=False, force=False):
"""Loads a set of Slices and Dashboards and a supporting dataset """
- load_examples_run(load_test_data)
+ load_examples_run(load_test_data, only_metadata, force)
@app.cli.command()
@@ -405,7 +416,7 @@ def load_test_users_run():
for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)
utils.get_or_create_main_db()
- db_perm = utils.get_main_database(security_manager.get_session).perm
+ db_perm = utils.get_main_database().perm
security_manager.add_permission_view_menu("database_access", db_perm)
db_pvm = security_manager.find_permission_view_menu(
view_menu_name=db_perm, permission_name="database_access"
diff --git a/superset/config.py b/superset/config.py
index 7679d20..b676e51 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -617,6 +617,10 @@ TALISMAN_CONFIG = {
"force_https_permanent": False,
}
+# URI to database storing the example data, points to
+# SQLALCHEMY_DATABASE_URI by default if set to `None`
+SQLALCHEMY_EXAMPLES_URI = None
+
try:
if CONFIG_PATH_ENV_VAR in os.environ:
# Explicitly import config module that is not in pythonpath; useful
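
As a sketch of the new setting, a superset_config.py could point the
examples at a separate database (the URIs below are hypothetical):

    # superset_config.py -- hypothetical URIs, for illustration only
    SQLALCHEMY_DATABASE_URI = "sqlite:////var/lib/superset/superset.db"

    # When set, example datasets land in this database instead of the
    # metadata database; leaving it as None falls back to
    # SQLALCHEMY_DATABASE_URI.
    SQLALCHEMY_EXAMPLES_URI = "postgresql://superset:superset@localhost/examples"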
diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py
index be31a37..d5e951a 100644
--- a/superset/connectors/connector_registry.py
+++ b/superset/connectors/connector_registry.py
@@ -55,18 +55,9 @@ class ConnectorRegistry(object):
cls, session, datasource_type, datasource_name, schema, database_name
):
datasource_class = ConnectorRegistry.sources[datasource_type]
- datasources = session.query(datasource_class).all()
-
- # Filter datasoures that don't have database.
- db_ds = [
- d
- for d in datasources
- if d.database
- and d.database.name == database_name
- and d.name == datasource_name
- and schema == schema
- ]
- return db_ds[0]
+ return datasource_class.get_datasource_by_name(
+ session, datasource_name, schema, database_name
+ )
@classmethod
def query_datasources_by_permissions(cls, session, database, permissions):
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 3f81ca7..6a0873c 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -732,6 +732,16 @@ class DruidDatasource(Model, BaseDatasource):
return 6 * 24 * 3600 * 1000 # 6 days
return 0
+ @classmethod
+    def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
+ query = (
+ session.query(cls)
+ .join(DruidCluster)
+ .filter(cls.datasource_name == datasource_name)
+ .filter(DruidCluster.cluster_name == database_name)
+ )
+ return query.first()
+
# uses https://en.wikipedia.org/wiki/ISO_8601
# http://druid.io/docs/0.8.0/querying/granularities.html
# TODO: pass origin from the UI
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 6805777..a3f5690 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -374,6 +374,21 @@ class SqlaTable(Model, BaseDatasource):
def database_name(self):
return self.database.name
+ @classmethod
+    def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
+ schema = schema or None
+ query = (
+ session.query(cls)
+ .join(Database)
+ .filter(cls.table_name == datasource_name)
+ .filter(Database.database_name == database_name)
+ )
+ # Handling schema being '' or None, which is easier to handle
+ # in python than in the SQLA query in a multi-dialect way
+ for tbl in query.all():
+ if schema == (tbl.schema or None):
+ return tbl
+
@property
def link(self):
name = escape(self.name)
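
A small sketch of how the registry now delegates lookups (the argument
values are hypothetical; the signature follows the hunks above):

    from superset import db
    from superset.connectors.connector_registry import ConnectorRegistry

    # Dispatches to SqlaTable.get_datasource_by_name (or the Druid
    # equivalent) instead of filtering every datasource in Python.
    datasource = ConnectorRegistry.get_datasource_by_name(
        db.session, "table", "birth_names", None, "examples"
    )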
diff --git a/superset/data/__init__.py b/superset/examples/__init__.py
similarity index 100%
rename from superset/data/__init__.py
rename to superset/examples/__init__.py
diff --git a/superset/data/bart_lines.py b/superset/examples/bart_lines.py
similarity index 56%
rename from superset/data/bart_lines.py
rename to superset/examples/bart_lines.py
index 8e615fc..203a60e 100644
--- a/superset/data/bart_lines.py
+++ b/superset/examples/bart_lines.py
@@ -21,37 +21,42 @@ import polyline
from sqlalchemy import String, Text
from superset import db
-from superset.utils.core import get_or_create_main_db
-from .helpers import TBL, get_example_data
+from superset.utils.core import get_example_database
+from .helpers import get_example_data, TBL
-def load_bart_lines():
+def load_bart_lines(only_metadata=False, force=False):
tbl_name = "bart_lines"
- content = get_example_data("bart-lines.json.gz")
- df = pd.read_json(content, encoding="latin-1")
- df["path_json"] = df.path.map(json.dumps)
- df["polyline"] = df.path.map(polyline.encode)
- del df["path"]
+ database = get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
+
+ if not only_metadata and (not table_exists or force):
+ content = get_example_data("bart-lines.json.gz")
+ df = pd.read_json(content, encoding="latin-1")
+ df["path_json"] = df.path.map(json.dumps)
+ df["polyline"] = df.path.map(polyline.encode)
+ del df["path"]
+
+ df.to_sql(
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={
+ "color": String(255),
+ "name": String(255),
+ "polyline": Text,
+ "path_json": Text,
+ },
+ index=False,
+ )
- df.to_sql(
- tbl_name,
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={
- "color": String(255),
- "name": String(255),
- "polyline": Text,
- "path_json": Text,
- },
- index=False,
- )
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "BART lines"
- tbl.database = get_or_create_main_db()
+ tbl.database = database
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
diff --git a/superset/data/birth_names.py b/superset/examples/birth_names.py
similarity index 93%
rename from superset/data/birth_names.py
rename to superset/examples/birth_names.py
index 9040847..1fcdc55 100644
--- a/superset/data/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -23,7 +23,7 @@ from sqlalchemy.sql import column
from superset import db, security_manager
from superset.connectors.sqla.models import SqlMetric, TableColumn
-from superset.utils.core import get_or_create_main_db
+from superset.utils.core import get_example_database
from .helpers import (
config,
Dash,
@@ -36,33 +36,39 @@ from .helpers import (
)
-def load_birth_names():
+def load_birth_names(only_metadata=False, force=False):
"""Loading birth name dataset from a zip file in the repo"""
- data = get_example_data("birth_names.json.gz")
- pdf = pd.read_json(data)
- pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
- pdf.to_sql(
- "birth_names",
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={
- "ds": DateTime,
- "gender": String(16),
- "state": String(10),
- "name": String(255),
- },
- index=False,
- )
- print("Done loading table!")
- print("-" * 80)
+ # pylint: disable=too-many-locals
+ tbl_name = "birth_names"
+ database = get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
+
+ if not only_metadata and (not table_exists or force):
+ pdf = pd.read_json(get_example_data("birth_names.json.gz"))
+ pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
+ pdf.to_sql(
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={
+ "ds": DateTime,
+ "gender": String(16),
+ "state": String(10),
+ "name": String(255),
+ },
+ index=False,
+ )
+ print("Done loading table!")
+ print("-" * 80)
- print("Creating table [birth_names] reference")
- obj = db.session.query(TBL).filter_by(table_name="birth_names").first()
+ obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
- obj = TBL(table_name="birth_names")
+ print(f"Creating table [{tbl_name}] reference")
+ obj = TBL(table_name=tbl_name)
+ db.session.add(obj)
obj.main_dttm_col = "ds"
- obj.database = get_or_create_main_db()
+ obj.database = database
obj.filter_select_enabled = True
if not any(col.column_name == "num_california" for col in obj.columns):
@@ -79,7 +85,6 @@ def load_birth_names():
col = str(column("num").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="sum__num",
expression=f"SUM({col})"))
- db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
@@ -384,10 +389,12 @@ def load_birth_names():
merge_slice(slc)
print("Creating a dashboard")
- dash = db.session.query(Dash).filter_by(dashboard_title="Births").first()
+ dash = db.session.query(Dash).filter_by(slug="births").first()
if not dash:
dash = Dash()
+ db.session.add(dash)
+ dash.published = True
js = textwrap.dedent(
# pylint: disable=line-too-long
"""\
@@ -649,5 +656,4 @@ def load_birth_names():
dash.dashboard_title = "Births"
dash.position_json = json.dumps(pos, indent=4)
dash.slug = "births"
- db.session.merge(dash)
db.session.commit()
diff --git a/superset/data/countries.md b/superset/examples/countries.md
similarity index 100%
rename from superset/data/countries.md
rename to superset/examples/countries.md
diff --git a/superset/data/countries.py b/superset/examples/countries.py
similarity index 100%
rename from superset/data/countries.py
rename to superset/examples/countries.py
diff --git a/superset/data/country_map.py b/superset/examples/country_map.py
similarity index 61%
rename from superset/data/country_map.py
rename to superset/examples/country_map.py
index d2b12cf..d664e4b 100644
--- a/superset/data/country_map.py
+++ b/superset/examples/country_map.py
@@ -33,44 +33,50 @@ from .helpers import (
)
-def load_country_map_data():
+def load_country_map_data(only_metadata=False, force=False):
"""Loading data for map with country map"""
- csv_bytes = get_example_data(
- "birth_france_data_for_country_map.csv", is_gzip=False, make_bytes=True
- )
- data = pd.read_csv(csv_bytes, encoding="utf-8")
- data["dttm"] = datetime.datetime.now().date()
- data.to_sql( # pylint: disable=no-member
- "birth_france_by_region",
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={
- "DEPT_ID": String(10),
- "2003": BigInteger,
- "2004": BigInteger,
- "2005": BigInteger,
- "2006": BigInteger,
- "2007": BigInteger,
- "2008": BigInteger,
- "2009": BigInteger,
- "2010": BigInteger,
- "2011": BigInteger,
- "2012": BigInteger,
- "2013": BigInteger,
- "2014": BigInteger,
- "dttm": Date(),
- },
- index=False,
- )
- print("Done loading table!")
- print("-" * 80)
+ tbl_name = "birth_france_by_region"
+ database = utils.get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
+
+ if not only_metadata and (not table_exists or force):
+ csv_bytes = get_example_data(
+ "birth_france_data_for_country_map.csv", is_gzip=False,
make_bytes=True
+ )
+ data = pd.read_csv(csv_bytes, encoding="utf-8")
+ data["dttm"] = datetime.datetime.now().date()
+ data.to_sql( # pylint: disable=no-member
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={
+ "DEPT_ID": String(10),
+ "2003": BigInteger,
+ "2004": BigInteger,
+ "2005": BigInteger,
+ "2006": BigInteger,
+ "2007": BigInteger,
+ "2008": BigInteger,
+ "2009": BigInteger,
+ "2010": BigInteger,
+ "2011": BigInteger,
+ "2012": BigInteger,
+ "2013": BigInteger,
+ "2014": BigInteger,
+ "dttm": Date(),
+ },
+ index=False,
+ )
+ print("Done loading table!")
+ print("-" * 80)
+
print("Creating table reference")
-    obj = db.session.query(TBL).filter_by(table_name="birth_france_by_region").first()
+ obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
- obj = TBL(table_name="birth_france_by_region")
+ obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "dttm"
- obj.database = utils.get_or_create_main_db()
+ obj.database = database
if not any(col.metric_name == "avg__2004" for col in obj.metrics):
col = str(column("2004").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="avg__2004",
expression=f"AVG({col})"))
diff --git a/superset/data/css_templates.py b/superset/examples/css_templates.py
similarity index 100%
rename from superset/data/css_templates.py
rename to superset/examples/css_templates.py
diff --git a/superset/data/deck.py b/superset/examples/deck.py
similarity index 99%
rename from superset/data/deck.py
rename to superset/examples/deck.py
index 974be90..6cd0441 100644
--- a/superset/data/deck.py
+++ b/superset/examples/deck.py
@@ -501,6 +501,7 @@ def load_deck_dash():
if not dash:
dash = Dash()
+ dash.published = True
js = POSITION_JSON
pos = json.loads(js)
update_slice_ids(pos, slices)
diff --git a/superset/data/energy.py b/superset/examples/energy.py
similarity index 85%
rename from superset/data/energy.py
rename to superset/examples/energy.py
index a3edb2b..00dae11 100644
--- a/superset/data/energy.py
+++ b/superset/examples/energy.py
@@ -25,36 +25,33 @@ from sqlalchemy.sql import column
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.utils import core as utils
-from .helpers import (
- DATA_FOLDER,
- get_example_data,
- merge_slice,
- misc_dash_slices,
- Slice,
- TBL,
-)
+from .helpers import get_example_data, merge_slice, misc_dash_slices, Slice, TBL
-def load_energy():
+def load_energy(only_metadata=False, force=False):
"""Loads an energy related dataset to use with sankey and graphs"""
tbl_name = "energy_usage"
- data = get_example_data("energy.json.gz")
- pdf = pd.read_json(data)
- pdf.to_sql(
- tbl_name,
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={"source": String(255), "target": String(255), "value": Float()},
- index=False,
- )
+ database = utils.get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
+
+ if not only_metadata and (not table_exists or force):
+ data = get_example_data("energy.json.gz")
+ pdf = pd.read_json(data)
+ pdf.to_sql(
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={"source": String(255), "target": String(255), "value":
Float()},
+ index=False,
+ )
print("Creating table [wb_health_population] reference")
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Energy consumption"
- tbl.database = utils.get_or_create_main_db()
+ tbl.database = database
if not any(col.metric_name == "sum__value" for col in tbl.metrics):
col = str(column("value").compile(db.engine))
diff --git a/superset/data/flights.py b/superset/examples/flights.py
similarity index 52%
rename from superset/data/flights.py
rename to superset/examples/flights.py
index 8876d54..d386ae2 100644
--- a/superset/data/flights.py
+++ b/superset/examples/flights.py
@@ -22,38 +22,45 @@ from superset.utils import core as utils
from .helpers import get_example_data, TBL
-def load_flights():
+def load_flights(only_metadata=False, force=False):
"""Loading random time series data from a zip file in the repo"""
tbl_name = "flights"
- data = get_example_data("flight_data.csv.gz", make_bytes=True)
- pdf = pd.read_csv(data, encoding="latin-1")
-
- # Loading airports info to join and get lat/long
- airports_bytes = get_example_data("airports.csv.gz", make_bytes=True)
- airports = pd.read_csv(airports_bytes, encoding="latin-1")
- airports = airports.set_index("IATA_CODE")
-
- pdf["ds"] = pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" +
pdf.DAY.map(str)
- pdf.ds = pd.to_datetime(pdf.ds)
- del pdf["YEAR"]
- del pdf["MONTH"]
- del pdf["DAY"]
-
- pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
- pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
- pdf.to_sql(
- tbl_name,
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={"ds": DateTime},
- index=False,
- )
+ database = utils.get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
+
+ if not only_metadata and (not table_exists or force):
+ data = get_example_data("flight_data.csv.gz", make_bytes=True)
+ pdf = pd.read_csv(data, encoding="latin-1")
+
+ # Loading airports info to join and get lat/long
+ airports_bytes = get_example_data("airports.csv.gz", make_bytes=True)
+ airports = pd.read_csv(airports_bytes, encoding="latin-1")
+ airports = airports.set_index("IATA_CODE")
+
+ pdf["ds"] = (
+            pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str)
+ )
+ pdf.ds = pd.to_datetime(pdf.ds)
+ del pdf["YEAR"]
+ del pdf["MONTH"]
+ del pdf["DAY"]
+
+ pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
+ pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
+ pdf.to_sql(
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={"ds": DateTime},
+ index=False,
+ )
+
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Random set of flights in the US"
- tbl.database = utils.get_or_create_main_db()
+ tbl.database = database
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
diff --git a/superset/data/helpers.py b/superset/examples/helpers.py
similarity index 97%
rename from superset/data/helpers.py
rename to superset/examples/helpers.py
index a6dd734..cff7da6 100644
--- a/superset/data/helpers.py
+++ b/superset/examples/helpers.py
@@ -38,7 +38,7 @@ TBL = ConnectorRegistry.sources["table"]
config = app.config
-DATA_FOLDER = os.path.join(config.get("BASE_DIR"), "data")
+EXAMPLES_FOLDER = os.path.join(config.get("BASE_DIR"), "examples")
misc_dash_slices = set() # slices assembled in a 'Misc Chart' dashboard
diff --git a/superset/data/long_lat.py b/superset/examples/long_lat.py
similarity index 50%
rename from superset/data/long_lat.py
rename to superset/examples/long_lat.py
index 4ecb618..28b8dbb 100644
--- a/superset/data/long_lat.py
+++ b/superset/examples/long_lat.py
@@ -33,52 +33,59 @@ from .helpers import (
)
-def load_long_lat_data():
+def load_long_lat_data(only_metadata=False, force=False):
"""Loading lat/long data from a csv file in the repo"""
- data = get_example_data("san_francisco.csv.gz", make_bytes=True)
- pdf = pd.read_csv(data, encoding="utf-8")
-    start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
- pdf["datetime"] = [
- start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
- for i in range(len(pdf))
- ]
- pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
- pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
- pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x),
axis=1)
- pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str),
sep=",")
- pdf.to_sql( # pylint: disable=no-member
- "long_lat",
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={
- "longitude": Float(),
- "latitude": Float(),
- "number": Float(),
- "street": String(100),
- "unit": String(10),
- "city": String(50),
- "district": String(50),
- "region": String(50),
- "postcode": Float(),
- "id": String(100),
- "datetime": DateTime(),
- "occupancy": Float(),
- "radius_miles": Float(),
- "geohash": String(12),
- "delimited": String(60),
- },
- index=False,
- )
- print("Done loading table!")
- print("-" * 80)
+ tbl_name = "long_lat"
+ database = utils.get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
+
+ if not only_metadata and (not table_exists or force):
+ data = get_example_data("san_francisco.csv.gz", make_bytes=True)
+ pdf = pd.read_csv(data, encoding="utf-8")
+ start = datetime.datetime.now().replace(
+ hour=0, minute=0, second=0, microsecond=0
+ )
+ pdf["datetime"] = [
+ start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
+ for i in range(len(pdf))
+ ]
+ pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
+ pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
+ pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x:
geohash.encode(*x), axis=1)
+ pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str),
sep=",")
+ pdf.to_sql( # pylint: disable=no-member
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={
+ "longitude": Float(),
+ "latitude": Float(),
+ "number": Float(),
+ "street": String(100),
+ "unit": String(10),
+ "city": String(50),
+ "district": String(50),
+ "region": String(50),
+ "postcode": Float(),
+ "id": String(100),
+ "datetime": DateTime(),
+ "occupancy": Float(),
+ "radius_miles": Float(),
+ "geohash": String(12),
+ "delimited": String(60),
+ },
+ index=False,
+ )
+ print("Done loading table!")
+ print("-" * 80)
print("Creating table reference")
- obj = db.session.query(TBL).filter_by(table_name="long_lat").first()
+ obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
- obj = TBL(table_name="long_lat")
+ obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "datetime"
- obj.database = utils.get_or_create_main_db()
+ obj.database = database
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
diff --git a/superset/data/misc_dashboard.py b/superset/examples/misc_dashboard.py
similarity index 100%
rename from superset/data/misc_dashboard.py
rename to superset/examples/misc_dashboard.py
diff --git a/superset/data/multi_line.py b/superset/examples/multi_line.py
similarity index 93%
rename from superset/data/multi_line.py
rename to superset/examples/multi_line.py
index d2b4d18..390e5f8 100644
--- a/superset/data/multi_line.py
+++ b/superset/examples/multi_line.py
@@ -22,9 +22,9 @@ from .helpers import merge_slice, misc_dash_slices, Slice
from .world_bank import load_world_bank_health_n_pop
-def load_multi_line():
- load_world_bank_health_n_pop()
- load_birth_names()
+def load_multi_line(only_metadata=False):
+ load_world_bank_health_n_pop(only_metadata)
+ load_birth_names(only_metadata)
ids = [
row.id
for row in db.session.query(Slice).filter(
diff --git a/superset/data/multiformat_time_series.py b/superset/examples/multiformat_time_series.py
similarity index 66%
rename from superset/data/multiformat_time_series.py
rename to superset/examples/multiformat_time_series.py
index b33391d..2875d37 100644
--- a/superset/data/multiformat_time_series.py
+++ b/superset/examples/multiformat_time_series.py
@@ -19,7 +19,7 @@ import pandas as pd
from sqlalchemy import BigInteger, Date, DateTime, String
from superset import db
-from superset.utils import core as utils
+from superset.utils.core import get_example_database
from .helpers import (
config,
get_example_data,
@@ -31,38 +31,44 @@ from .helpers import (
)
-def load_multiformat_time_series():
+def load_multiformat_time_series(only_metadata=False, force=False):
"""Loading time series data from a zip file in the repo"""
- data = get_example_data("multiformat_time_series.json.gz")
- pdf = pd.read_json(data)
+ tbl_name = "multiformat_time_series"
+ database = get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
- pdf.ds = pd.to_datetime(pdf.ds, unit="s")
- pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
- pdf.to_sql(
- "multiformat_time_series",
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={
- "ds": Date,
- "ds2": DateTime,
- "epoch_s": BigInteger,
- "epoch_ms": BigInteger,
- "string0": String(100),
- "string1": String(100),
- "string2": String(100),
- "string3": String(100),
- },
- index=False,
- )
- print("Done loading table!")
- print("-" * 80)
- print("Creating table [multiformat_time_series] reference")
-    obj = db.session.query(TBL).filter_by(table_name="multiformat_time_series").first()
+ if not only_metadata and (not table_exists or force):
+ data = get_example_data("multiformat_time_series.json.gz")
+ pdf = pd.read_json(data)
+
+ pdf.ds = pd.to_datetime(pdf.ds, unit="s")
+ pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
+ pdf.to_sql(
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={
+ "ds": Date,
+ "ds2": DateTime,
+ "epoch_s": BigInteger,
+ "epoch_ms": BigInteger,
+ "string0": String(100),
+ "string1": String(100),
+ "string2": String(100),
+ "string3": String(100),
+ },
+ index=False,
+ )
+ print("Done loading table!")
+ print("-" * 80)
+
+ print(f"Creating table [{tbl_name}] reference")
+ obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
- obj = TBL(table_name="multiformat_time_series")
+ obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "ds"
- obj.database = utils.get_or_create_main_db()
+ obj.database = database
dttm_and_expr_dict = {
"ds": [None, None],
"ds2": [None, None],
diff --git a/superset/data/paris.py b/superset/examples/paris.py
similarity index 61%
rename from superset/data/paris.py
rename to superset/examples/paris.py
index 6387272..a4f252b 100644
--- a/superset/data/paris.py
+++ b/superset/examples/paris.py
@@ -21,35 +21,39 @@ from sqlalchemy import String, Text
from superset import db
from superset.utils import core as utils
-from .helpers import TBL, get_example_data
+from .helpers import get_example_data, TBL
-def load_paris_iris_geojson():
+def load_paris_iris_geojson(only_metadata=False, force=False):
tbl_name = "paris_iris_mapping"
+ database = utils.get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
- data = get_example_data("paris_iris.json.gz")
- df = pd.read_json(data)
- df["features"] = df.features.map(json.dumps)
+ if not only_metadata and (not table_exists or force):
+ data = get_example_data("paris_iris.json.gz")
+ df = pd.read_json(data)
+ df["features"] = df.features.map(json.dumps)
+
+ df.to_sql(
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={
+ "color": String(255),
+ "name": String(255),
+ "features": Text,
+ "type": Text,
+ },
+ index=False,
+ )
- df.to_sql(
- tbl_name,
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={
- "color": String(255),
- "name": String(255),
- "features": Text,
- "type": Text,
- },
- index=False,
- )
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Map of Paris"
- tbl.database = utils.get_or_create_main_db()
+ tbl.database = database
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
diff --git a/superset/data/random_time_series.py b/superset/examples/random_time_series.py
similarity index 67%
rename from superset/data/random_time_series.py
rename to superset/examples/random_time_series.py
index 477cb14..26dd003 100644
--- a/superset/data/random_time_series.py
+++ b/superset/examples/random_time_series.py
@@ -23,28 +23,33 @@ from superset.utils import core as utils
from .helpers import config, get_example_data, get_slice_json, merge_slice, Slice, TBL
-def load_random_time_series_data():
+def load_random_time_series_data(only_metadata=False, force=False):
"""Loading random time series data from a zip file in the repo"""
- data = get_example_data("random_time_series.json.gz")
- pdf = pd.read_json(data)
- pdf.ds = pd.to_datetime(pdf.ds, unit="s")
- pdf.to_sql(
- "random_time_series",
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={"ds": DateTime},
- index=False,
- )
- print("Done loading table!")
- print("-" * 80)
+ tbl_name = "random_time_series"
+ database = utils.get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
+
+ if not only_metadata and (not table_exists or force):
+ data = get_example_data("random_time_series.json.gz")
+ pdf = pd.read_json(data)
+ pdf.ds = pd.to_datetime(pdf.ds, unit="s")
+ pdf.to_sql(
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={"ds": DateTime},
+ index=False,
+ )
+ print("Done loading table!")
+ print("-" * 80)
- print("Creating table [random_time_series] reference")
-    obj = db.session.query(TBL).filter_by(table_name="random_time_series").first()
+ print(f"Creating table [{tbl_name}] reference")
+ obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
- obj = TBL(table_name="random_time_series")
+ obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "ds"
- obj.database = utils.get_or_create_main_db()
+ obj.database = database
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
diff --git a/superset/data/sf_population_polygons.py b/superset/examples/sf_population_polygons.py
similarity index 61%
rename from superset/data/sf_population_polygons.py
rename to superset/examples/sf_population_polygons.py
index 49550b3..738ac41 100644
--- a/superset/data/sf_population_polygons.py
+++ b/superset/examples/sf_population_polygons.py
@@ -21,35 +21,39 @@ from sqlalchemy import BigInteger, Text
from superset import db
from superset.utils import core as utils
-from .helpers import TBL, get_example_data
+from .helpers import get_example_data, TBL
-def load_sf_population_polygons():
+def load_sf_population_polygons(only_metadata=False, force=False):
tbl_name = "sf_population_polygons"
+ database = utils.get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
- data = get_example_data("sf_population.json.gz")
- df = pd.read_json(data)
- df["contour"] = df.contour.map(json.dumps)
+ if not only_metadata and (not table_exists or force):
+ data = get_example_data("sf_population.json.gz")
+ df = pd.read_json(data)
+ df["contour"] = df.contour.map(json.dumps)
+
+ df.to_sql(
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={
+ "zipcode": BigInteger,
+ "population": BigInteger,
+ "contour": Text,
+ "area": BigInteger,
+ },
+ index=False,
+ )
- df.to_sql(
- tbl_name,
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={
- "zipcode": BigInteger,
- "population": BigInteger,
- "contour": Text,
- "area": BigInteger,
- },
- index=False,
- )
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Population density of San Francisco"
- tbl.database = utils.get_or_create_main_db()
+ tbl.database = database
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
diff --git a/superset/data/tabbed_dashboard.py b/superset/examples/tabbed_dashboard.py
similarity index 95%
rename from superset/data/tabbed_dashboard.py
rename to superset/examples/tabbed_dashboard.py
index 09f18ad..ab4552c 100644
--- a/superset/data/tabbed_dashboard.py
+++ b/superset/examples/tabbed_dashboard.py
@@ -17,30 +17,13 @@
"""Loads datasets, dashboards and slices in a new superset instance"""
# pylint: disable=C,R,W
import json
-import os
import textwrap
-import pandas as pd
-from sqlalchemy import DateTime, String
-
from superset import db
-from superset.connectors.sqla.models import SqlMetric
-from superset.utils import core as utils
-from .helpers import (
- config,
- Dash,
- DATA_FOLDER,
- get_example_data,
- get_slice_json,
- merge_slice,
- misc_dash_slices,
- Slice,
- TBL,
- update_slice_ids,
-)
+from .helpers import Dash, Slice, update_slice_ids
-def load_tabbed_dashboard():
+def load_tabbed_dashboard(only_metadata=False):
"""Creating a tabbed dashboard"""
print("Creating a dashboard with nested tabs")
diff --git a/superset/data/unicode_test_data.py b/superset/examples/unicode_test_data.py
similarity index 72%
rename from superset/data/unicode_test_data.py
rename to superset/examples/unicode_test_data.py
index 3f3ed55..3f91f3f 100644
--- a/superset/data/unicode_test_data.py
+++ b/superset/examples/unicode_test_data.py
@@ -35,38 +35,43 @@ from .helpers import (
)
-def load_unicode_test_data():
+def load_unicode_test_data(only_metadata=False, force=False):
"""Loading unicode test dataset from a csv file in the repo"""
- data = get_example_data(
- "unicode_utf8_unixnl_test.csv", is_gzip=False, make_bytes=True
- )
- df = pd.read_csv(data, encoding="utf-8")
- # generate date/numeric data
- df["dttm"] = datetime.datetime.now().date()
- df["value"] = [random.randint(1, 100) for _ in range(len(df))]
- df.to_sql( # pylint: disable=no-member
- "unicode_test",
- db.engine,
- if_exists="replace",
- chunksize=500,
- dtype={
- "phrase": String(500),
- "short_phrase": String(10),
- "with_missing": String(100),
- "dttm": Date(),
- "value": Float(),
- },
- index=False,
- )
- print("Done loading table!")
- print("-" * 80)
+ tbl_name = "unicode_test"
+ database = utils.get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
+
+ if not only_metadata and (not table_exists or force):
+ data = get_example_data(
+ "unicode_utf8_unixnl_test.csv", is_gzip=False, make_bytes=True
+ )
+ df = pd.read_csv(data, encoding="utf-8")
+ # generate date/numeric data
+ df["dttm"] = datetime.datetime.now().date()
+ df["value"] = [random.randint(1, 100) for _ in range(len(df))]
+ df.to_sql( # pylint: disable=no-member
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=500,
+ dtype={
+ "phrase": String(500),
+ "short_phrase": String(10),
+ "with_missing": String(100),
+ "dttm": Date(),
+ "value": Float(),
+ },
+ index=False,
+ )
+ print("Done loading table!")
+ print("-" * 80)
print("Creating table [unicode_test] reference")
- obj = db.session.query(TBL).filter_by(table_name="unicode_test").first()
+ obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not obj:
- obj = TBL(table_name="unicode_test")
+ obj = TBL(table_name=tbl_name)
obj.main_dttm_col = "dttm"
- obj.database = utils.get_or_create_main_db()
+ obj.database = database
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
@@ -104,7 +109,7 @@ def load_unicode_test_data():
merge_slice(slc)
print("Creating a dashboard")
-    dash = db.session.query(Dash).filter_by(dashboard_title="Unicode Test").first()
+ dash = db.session.query(Dash).filter_by(slug="unicode-test").first()
if not dash:
dash = Dash()
diff --git a/superset/data/world_bank.py b/superset/examples/world_bank.py
similarity index 93%
rename from superset/data/world_bank.py
rename to superset/examples/world_bank.py
index a64bd2b..fbe4512 100644
--- a/superset/data/world_bank.py
+++ b/superset/examples/world_bank.py
@@ -30,7 +30,7 @@ from superset.utils import core as utils
from .helpers import (
config,
Dash,
- DATA_FOLDER,
+ EXAMPLES_FOLDER,
get_example_data,
get_slice_json,
merge_slice,
@@ -41,34 +41,38 @@ from .helpers import (
)
-def load_world_bank_health_n_pop():
+def load_world_bank_health_n_pop(only_metadata=False, force=False):
"""Loads the world bank health dataset, slices and a dashboard"""
tbl_name = "wb_health_population"
- data = get_example_data("countries.json.gz")
- pdf = pd.read_json(data)
- pdf.columns = [col.replace(".", "_") for col in pdf.columns]
- pdf.year = pd.to_datetime(pdf.year)
- pdf.to_sql(
- tbl_name,
- db.engine,
- if_exists="replace",
- chunksize=50,
- dtype={
- "year": DateTime(),
- "country_code": String(3),
- "country_name": String(255),
- "region": String(255),
- },
- index=False,
- )
+ database = utils.get_example_database()
+ table_exists = database.has_table_by_name(tbl_name)
+
+ if not only_metadata and (not table_exists or force):
+ data = get_example_data("countries.json.gz")
+ pdf = pd.read_json(data)
+ pdf.columns = [col.replace(".", "_") for col in pdf.columns]
+ pdf.year = pd.to_datetime(pdf.year)
+ pdf.to_sql(
+ tbl_name,
+ database.get_sqla_engine(),
+ if_exists="replace",
+ chunksize=50,
+ dtype={
+ "year": DateTime(),
+ "country_code": String(3),
+ "country_name": String(255),
+ "region": String(255),
+ },
+ index=False,
+ )
print("Creating table [wb_health_population] reference")
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
- tbl.description = utils.readfile(os.path.join(DATA_FOLDER, "countries.md"))
+    tbl.description = utils.readfile(os.path.join(EXAMPLES_FOLDER, "countries.md"))
tbl.main_dttm_col = "year"
- tbl.database = utils.get_or_create_main_db()
+ tbl.database = database
tbl.filter_select_enabled = True
metrics = [
@@ -328,6 +332,7 @@ def load_world_bank_health_n_pop():
if not dash:
dash = Dash()
+ dash.published = True
js = textwrap.dedent(
"""\
{
diff --git a/superset/models/core.py b/superset/models/core.py
index 0e5d1db..66f2d0e 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -666,12 +666,13 @@ class Dashboard(Model, AuditMixinNullable, ImportMixin):
)
make_transient(copied_dashboard)
for slc in copied_dashboard.slices:
+ make_transient(slc)
datasource_ids.add((slc.datasource_id, slc.datasource_type))
# add extra params for the import
slc.alter_params(
remote_id=slc.id,
datasource_name=slc.datasource.name,
- schema=slc.datasource.name,
+ schema=slc.datasource.schema,
database_name=slc.datasource.database.name,
)
copied_dashboard.alter_params(remote_id=dashboard_id)
@@ -1169,6 +1170,10 @@ class Database(Model, AuditMixinNullable, ImportMixin):
engine = self.get_sqla_engine()
return engine.has_table(table.table_name, table.schema or None)
+ def has_table_by_name(self, table_name, schema=None):
+ engine = self.get_sqla_engine()
+ return engine.has_table(table_name, schema)
+
@utils.memoized
def get_dialect(self):
sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
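
has_table_by_name is what lets each loader above guard its expensive
df.to_sql call; a minimal sketch of the pattern, assuming the examples
database already exists:

    from superset.utils.core import get_example_database

    database = get_example_database()
    # Skip the costly download/insert when the table is already there.
    if not database.has_table_by_name("bart_lines"):
        pass  # df.to_sql("bart_lines", database.get_sqla_engine(), ...)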
diff --git a/superset/models/tags.py b/superset/models/tags.py
index d318115..0bd6c2b 100644
--- a/superset/models/tags.py
+++ b/superset/models/tags.py
@@ -81,7 +81,7 @@ class TaggedObject(Model, AuditMixinNullable):
object_id = Column(Integer)
object_type = Column(Enum(ObjectTypes))
- tag = relationship("Tag")
+ tag = relationship("Tag", backref="objects")
def get_tag(name, session, type_):
diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py
index d99719c..73dc756 100644
--- a/superset/tasks/cache.py
+++ b/superset/tasks/cache.py
@@ -18,10 +18,8 @@
import json
import logging
-import urllib.parse
from celery.utils.log import get_task_logger
-from flask import url_for
import requests
from requests.exceptions import RequestException
from sqlalchemy import and_, func
@@ -75,13 +73,13 @@ def get_form_data(chart_id, dashboard=None):
return form_data
-def get_url(params):
+def get_url(chart):
"""Return external URL for warming up a given chart/table cache."""
-    baseurl = "http://{SUPERSET_WEBSERVER_ADDRESS}:{SUPERSET_WEBSERVER_PORT}/".format(
- **app.config
- )
with app.test_request_context():
-        return urllib.parse.urljoin(baseurl, url_for("Superset.explore_json", **params))
+        baseurl = "{SUPERSET_WEBSERVER_ADDRESS}:{SUPERSET_WEBSERVER_PORT}".format(
+ **app.config
+ )
+ return f"{baseurl}{chart.url}"
class Strategy:
@@ -136,7 +134,7 @@ class DummyStrategy(Strategy):
session = db.create_scoped_session()
charts = session.query(Slice).all()
-    return [get_url({"form_data": get_form_data(chart.id)}) for chart in charts]
+ return [get_url(chart) for chart in charts]
class TopNDashboardsStrategy(Strategy):
@@ -180,7 +178,7 @@ class TopNDashboardsStrategy(Strategy):
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
-            urls.append(get_url({"form_data": get_form_data(chart.id, dashboard)}))
+ urls.append(get_url(chart))
return urls
@@ -229,7 +227,7 @@ class DashboardTagsStrategy(Strategy):
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
-            urls.append(get_url({"form_data": get_form_data(chart.id, dashboard)}))
+ urls.append(get_url(chart))
# add charts that are tagged
tagged_objects = (
@@ -245,7 +243,7 @@ class DashboardTagsStrategy(Strategy):
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
- urls.append(get_url({"form_data": get_form_data(chart.id)}))
+ urls.append(get_url(chart))
return urls
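
Roughly, under hypothetical config values, the new get_url(chart) behaves
like this (a sketch, not the exact default config):

    # Assuming SUPERSET_WEBSERVER_ADDRESS = "http://localhost" and
    # SUPERSET_WEBSERVER_PORT = 8088, a chart whose .url is
    # "/superset/explore/?form_data=..." yields:
    #   get_url(chart) == "http://localhost:8088/superset/explore/?form_data=..."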
diff --git a/superset/utils/core.py b/superset/utils/core.py
index ebbb082..b0ca2dc 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -942,25 +942,37 @@ def user_label(user: User) -> Optional[str]:
def get_or_create_main_db():
- from superset import conf, db
+ get_main_database()
+
+
+def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
+ from superset import db
from superset.models import core as models
- logging.info("Creating database reference")
- dbobj = get_main_database(db.session)
- if not dbobj:
- dbobj = models.Database(
- database_name="main", allow_csv_upload=True, expose_in_sqllab=True
- )
- dbobj.set_sqlalchemy_uri(conf.get("SQLALCHEMY_DATABASE_URI"))
- db.session.add(dbobj)
+ database = (
+        db.session.query(models.Database).filter_by(database_name=database_name).first()
+ )
+ if not database:
+ logging.info(f"Creating database reference for {database_name}")
+        database = models.Database(database_name=database_name, *args, **kwargs)
+ db.session.add(database)
+
+ database.set_sqlalchemy_uri(sqlalchemy_uri)
db.session.commit()
- return dbobj
+ return database
-def get_main_database(session):
- from superset.models import core as models
+def get_main_database():
+ from superset import conf
+
+ return get_or_create_db("main", conf.get("SQLALCHEMY_DATABASE_URI"))
+
+
+def get_example_database():
+ from superset import conf
-    return session.query(models.Database).filter_by(database_name="main").first()
+    db_uri = conf.get("SQLALCHEMY_EXAMPLES_URI") or conf.get("SQLALCHEMY_DATABASE_URI")
+ return get_or_create_db("examples", db_uri)
def is_adhoc_metric(metric) -> bool:
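
get_or_create_db is idempotent, so repeated calls are safe; a usage
sketch with a hypothetical URI:

    from superset.utils.core import get_example_database, get_or_create_db

    # Returns the existing "examples" Database row if present, otherwise
    # creates it; either way the URI is (re)applied and committed.
    examples_db = get_or_create_db("examples", "sqlite:////tmp/examples.db")

    # Same thing with the URI taken from SQLALCHEMY_EXAMPLES_URI,
    # falling back to SQLALCHEMY_DATABASE_URI when unset.
    examples_db = get_example_database()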
diff --git a/superset/viz.py b/superset/viz.py
index 5498bec..d81e16e 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -1746,7 +1746,7 @@ class WorldMapViz(BaseViz):
return qry
def get_data(self, df):
- from superset.data import countries
+ from superset.examples import countries
fd = self.form_data
cols = [fd.get("entity")]
diff --git a/tests/access_tests.py b/tests/access_tests.py
index 835ac22..999db1f 100644
--- a/tests/access_tests.py
+++ b/tests/access_tests.py
@@ -31,7 +31,7 @@ ROLE_TABLES_PERM_DATA = {
"database": [
{
"datasource_type": "table",
- "name": "main",
+ "name": "examples",
"schema": [{"name": "", "datasources": ["birth_names"]}],
}
],
@@ -42,7 +42,7 @@ ROLE_ALL_PERM_DATA = {
"database": [
{
"datasource_type": "table",
- "name": "main",
+ "name": "examples",
"schema": [{"name": "", "datasources": ["birth_names"]}],
},
{
diff --git a/tests/base_tests.py b/tests/base_tests.py
index cca317e..8ac5bcd 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -168,9 +168,6 @@ class SupersetTestCase(unittest.TestCase):
):
security_manager.del_permission_role(public_role, perm)
- def get_main_database(self):
- return get_main_database(db.session)
-
def run_sql(
self,
sql,
@@ -182,7 +179,7 @@ class SupersetTestCase(unittest.TestCase):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
- dbid = self.get_main_database().id
+ dbid = get_main_database().id
resp = self.get_json_resp(
"/superset/sql_json/",
raise_on_error=False,
@@ -202,7 +199,7 @@ class SupersetTestCase(unittest.TestCase):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
- dbid = self.get_main_database().id
+ dbid = get_main_database().id
resp = self.get_json_resp(
"/superset/validate_sql_json/",
raise_on_error=False,
@@ -223,3 +220,7 @@ class SupersetTestCase(unittest.TestCase):
def test_feature_flags(self):
self.assertEquals(is_feature_enabled("foo"), "bar")
self.assertEquals(is_feature_enabled("super"), "set")
+
+ def get_dash_by_slug(self, dash_slug):
+ sesh = db.session()
+ return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first()
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index f204fa6..dba087a 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -128,14 +128,14 @@ class CeleryTestCase(SupersetTestCase):
return json.loads(resp.data)
def test_run_sync_query_dont_exist(self):
- main_db = get_main_database(db.session)
+ main_db = get_main_database()
db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist"
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true")
self.assertTrue("error" in result1)
def test_run_sync_query_cta(self):
- main_db = get_main_database(db.session)
+ main_db = get_main_database()
backend = main_db.backend
db_id = main_db.id
tmp_table_name = "tmp_async_22"
@@ -158,7 +158,7 @@ class CeleryTestCase(SupersetTestCase):
self.assertGreater(len(results["data"]), 0)
def test_run_sync_query_cta_no_data(self):
- main_db = get_main_database(db.session)
+ main_db = get_main_database()
db_id = main_db.id
sql_empty_result = "SELECT * FROM ab_user WHERE id=666"
result3 = self.run_sql(db_id, sql_empty_result, "3")
@@ -179,7 +179,7 @@ class CeleryTestCase(SupersetTestCase):
return self.run_sql(db_id, sql)
def test_run_async_query(self):
- main_db = get_main_database(db.session)
+ main_db = get_main_database()
db_id = main_db.id
self.drop_table_if_exists("tmp_async_1", main_db)
@@ -212,7 +212,7 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(True, query.select_as_cta_used)
def test_run_async_query_with_lower_limit(self):
- main_db = get_main_database(db.session)
+ main_db = get_main_database()
db_id = main_db.id
self.drop_table_if_exists("tmp_async_2", main_db)
diff --git a/tests/core_tests.py b/tests/core_tests.py
index ea3b889..b401c8c 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -39,7 +39,6 @@ from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.models import core as models
from superset.models.sql_lab import Query
from superset.utils import core as utils
-from superset.utils.core import get_main_database
from superset.views.core import DatabaseView
from .base_tests import SupersetTestCase
from .fixtures.pyodbcRow import Row
@@ -345,7 +344,7 @@ class CoreTests(SupersetTestCase):
def test_testconn(self, username="admin"):
self.login(username=username)
- database = get_main_database(db.session)
+ database = utils.get_main_database()
# validate that the endpoint works with the password-masked sqlalchemy uri
data = json.dumps(
@@ -376,7 +375,7 @@ class CoreTests(SupersetTestCase):
assert response.headers["Content-Type"] == "application/json"
def test_custom_password_store(self):
- database = get_main_database(db.session)
+ database = utils.get_main_database()
conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted)
def custom_password_store(uri):
@@ -394,13 +393,13 @@ class CoreTests(SupersetTestCase):
# validate that sending a password-masked uri does not over-write the decrypted
# uri
self.login(username=username)
- database = get_main_database(db.session)
+ database = utils.get_main_database()
sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted
url = "databaseview/edit/{}".format(database.id)
data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns}
data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
self.client.post(url, data=data)
- database = get_main_database(db.session)
+ database = utils.get_main_database()
self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)
def test_warm_up_cache(self):
@@ -483,27 +482,27 @@ class CoreTests(SupersetTestCase):
def test_extra_table_metadata(self):
self.login("admin")
- dbid = get_main_database(db.session).id
+ dbid = utils.get_main_database().id
self.get_json_resp(
f"/superset/extra_table_metadata/{dbid}/"
"ab_permission_view/panoramix/"
)
def test_process_template(self):
- maindb = get_main_database(db.session)
+ maindb = utils.get_main_database()
sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(sql)
self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)
def test_get_template_kwarg(self):
- maindb = get_main_database(db.session)
+ maindb = utils.get_main_database()
s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb, foo="bar")
rendered = tp.process_template(s)
self.assertEqual("bar", rendered)
def test_template_kwarg(self):
- maindb = get_main_database(db.session)
+ maindb = utils.get_main_database()
s = "{{ foo }}"
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(s, foo="bar")
@@ -516,7 +515,7 @@ class CoreTests(SupersetTestCase):
self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00")
def test_table_metadata(self):
- maindb = get_main_database(db.session)
+ maindb = utils.get_main_database()
backend = maindb.backend
data = self.get_json_resp("/superset/table/{}/ab_user/null/".format(maindb.id))
self.assertEqual(data["name"], "ab_user")
@@ -615,15 +614,16 @@ class CoreTests(SupersetTestCase):
test_file.write("john,1\n")
test_file.write("paul,2\n")
test_file.close()
- main_db_uri = (
-            db.session.query(models.Database).filter_by(database_name="main").one()
- )
+ example_db = utils.get_example_database()
+ example_db.allow_csv_upload = True
+ db_id = example_db.id
+ db.session.commit()
test_file = open(filename, "rb")
form_data = {
"csv_file": test_file,
"sep": ",",
"name": table_name,
- "con": main_db_uri.id,
+ "con": db_id,
"if_exists": "append",
"index_label": "test_label",
"mangle_dupe_cols": False,
@@ -638,8 +638,8 @@ class CoreTests(SupersetTestCase):
try:
# ensure uploaded successfully
- form_post = self.get_resp(url, data=form_data)
- assert 'CSV file "testCSV.csv" uploaded to table' in form_post
+ resp = self.get_resp(url, data=form_data)
+ assert 'CSV file "testCSV.csv" uploaded to table' in resp
finally:
os.remove(filename)
@@ -769,7 +769,8 @@ class CoreTests(SupersetTestCase):
def test_select_star(self):
self.login(username="admin")
- resp = self.get_resp("/superset/select_star/1/birth_names")
+ examples_db = utils.get_example_database()
+        resp = self.get_resp(f"/superset/select_star/{examples_db.id}/birth_names")
self.assertIn("gender", resp)
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index 7ab389e..733b9be 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -39,6 +39,7 @@ from superset.db_engine_specs.pinot import PinotEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.core import Database
+from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
@@ -925,14 +926,14 @@ class DbEngineSpecsTestCase(SupersetTestCase):
) # noqa
def test_column_datatype_to_string(self):
- main_db = self.get_main_database()
- sqla_table = main_db.get_table("energy_usage")
- dialect = main_db.get_dialect()
+ example_db = get_example_database()
+ sqla_table = example_db.get_table("energy_usage")
+ dialect = example_db.get_dialect()
col_names = [
- main_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
+ example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
for c in sqla_table.columns
]
- if main_db.backend == "postgresql":
+ if example_db.backend == "postgresql":
expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
else:
expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py
index ba6766b..fafa024 100644
--- a/tests/dict_import_export_tests.py
+++ b/tests/dict_import_export_tests.py
@@ -63,7 +63,7 @@ class DictImportExportTests(SupersetTestCase):
params = {DBREF: id, "database_name": database_name}
dict_rep = {
- "database_id": get_main_database(db.session).id,
+ "database_id": get_main_database().id,
"table_name": name,
"schema": schema,
"id": id,
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index f48fa2d..d5c3609 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -63,7 +63,7 @@ class ImportExportTests(SupersetTestCase):
name,
ds_id=None,
id=None,
- db_name="main",
+ db_name="examples",
table_name="wb_health_population",
):
params = {
@@ -102,7 +102,7 @@ class ImportExportTests(SupersetTestCase):
)
def create_table(self, name, schema="", id=0, cols_names=[],
metric_names=[]):
- params = {"remote_id": id, "database_name": "main"}
+ params = {"remote_id": id, "database_name": "examples"}
table = SqlaTable(
id=id, schema=schema, table_name=name, params=json.dumps(params)
)
@@ -135,10 +135,6 @@ class ImportExportTests(SupersetTestCase):
def get_dash(self, dash_id):
return db.session.query(models.Dashboard).filter_by(id=dash_id).first()
- def get_dash_by_slug(self, dash_slug):
- sesh = db.session()
- return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first()
-
def get_datasource(self, datasource_id):
return db.session.query(DruidDatasource).filter_by(id=datasource_id).first()
@@ -192,9 +188,21 @@ class ImportExportTests(SupersetTestCase):
self.assertEquals(expected_slc_name, actual_slc_name)
self.assertEquals(expected_slc.datasource_type, actual_slc.datasource_type)
self.assertEquals(expected_slc.viz_type, actual_slc.viz_type)
- self.assertEquals(
- json.loads(expected_slc.params), json.loads(actual_slc.params)
- )
+ exp_params = json.loads(expected_slc.params)
+ actual_params = json.loads(actual_slc.params)
+ diff_params_keys = (
+ "schema",
+ "database_name",
+ "datasource_name",
+ "remote_id",
+ "import_time",
+ )
+ for k in diff_params_keys:
+ if k in actual_params:
+ actual_params.pop(k)
+ if k in exp_params:
+ exp_params.pop(k)
+ self.assertEquals(exp_params, actual_params)
def test_export_1_dashboard(self):
self.login("admin")
@@ -233,11 +241,11 @@ class ImportExportTests(SupersetTestCase):
birth_dash.id, world_health_dash.id
)
resp = self.client.get(export_dash_url)
+ resp_data = json.loads(
+ resp.data.decode("utf-8"), object_hook=utils.decode_dashboards
+ )
exported_dashboards = sorted(
- json.loads(resp.data.decode("utf-8"),
object_hook=utils.decode_dashboards)[
- "dashboards"
- ],
- key=lambda d: d.dashboard_title,
+ resp_data.get("dashboards"), key=lambda d: d.dashboard_title
)
self.assertEquals(2, len(exported_dashboards))
@@ -255,10 +263,7 @@ class ImportExportTests(SupersetTestCase):
)
exported_tables = sorted(
- json.loads(resp.data.decode("utf-8"),
object_hook=utils.decode_dashboards)[
- "datasources"
- ],
- key=lambda t: t.table_name,
+ resp_data.get("datasources"), key=lambda t: t.table_name
)
self.assertEquals(2, len(exported_tables))
self.assert_table_equals(
@@ -297,7 +302,7 @@ class ImportExportTests(SupersetTestCase):
self.assertEquals(imported_slc_2.datasource.perm, imported_slc_2.perm)
def test_import_slices_for_non_existent_table(self):
- with self.assertRaises(IndexError):
+ with self.assertRaises(AttributeError):
models.Slice.import_obj(
self.create_slice("Import Me 3", id=10004,
table_name="non_existent"),
None,
@@ -447,7 +452,7 @@ class ImportExportTests(SupersetTestCase):
imported = self.get_table(imported_id)
self.assert_table_equals(table, imported)
self.assertEquals(
- {"remote_id": 10002, "import_time": 1990, "database_name": "main"},
+ {"remote_id": 10002, "import_time": 1990, "database_name":
"examples"},
json.loads(imported.params),
)
diff --git a/tests/load_examples_test.py b/tests/load_examples_test.py
index 6ca3b2d..0d1db79 100644
--- a/tests/load_examples_test.py
+++ b/tests/load_examples_test.py
@@ -14,23 +14,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from superset import data
+from superset import examples
from superset.cli import load_test_users_run
from .base_tests import SupersetTestCase
class SupersetDataFrameTestCase(SupersetTestCase):
def test_load_css_templates(self):
- data.load_css_templates()
+ examples.load_css_templates()
def test_load_energy(self):
- data.load_energy()
+ examples.load_energy()
def test_load_world_bank_health_n_pop(self):
- data.load_world_bank_health_n_pop()
+ examples.load_world_bank_health_n_pop()
def test_load_birth_names(self):
- data.load_birth_names()
+ examples.load_birth_names()
def test_load_test_users_run(self):
load_test_users_run()
+
+ def test_load_unicode_test_data(self):
+ examples.load_unicode_test_data()
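
These tests can be re-run safely because the loaders in the renamed `superset.examples` package skip re-ingesting data that is already present. A sketch of that guard pattern, where `has_table_by_name` and `get_example_data` are assumed helper names:

    import pandas as pd

    from superset.examples.helpers import get_example_data  # assumed location
    from superset.utils.core import get_example_database


    def load_energy(only_metadata=False, force=False):
        # Guard pattern: skip the slow data ingest when the table
        # already exists, unless the caller forces a reload.
        tbl_name = "energy_usage"
        database = get_example_database()
        table_exists = database.has_table_by_name(tbl_name)  # assumed helper

        if not only_metadata and (not table_exists or force):
            pdf = pd.read_json(get_example_data("energy.json.gz"))
            pdf.to_sql(
                tbl_name,
                database.get_sqla_engine(),
                if_exists="replace",
                chunksize=500,
                index=False,
            )
        # ...then (re)register the table, slice and dashboard metadata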
diff --git a/tests/model_tests.py b/tests/model_tests.py
index 445ec52..fda941b 100644
--- a/tests/model_tests.py
+++ b/tests/model_tests.py
@@ -20,9 +20,9 @@ import unittest
import pandas
from sqlalchemy.engine.url import make_url
-from superset import app, db
+from superset import app
from superset.models.core import Database
-from superset.utils.core import get_main_database, QueryStatus
+from superset.utils.core import get_example_database, get_main_database, QueryStatus
from .base_tests import SupersetTestCase
@@ -101,7 +101,7 @@ class DatabaseModelTestCase(SupersetTestCase):
self.assertNotEquals(example_user, user_name)
def test_select_star(self):
- main_db = get_main_database(db.session)
+ main_db = get_example_database()
table_name = "energy_usage"
sql = main_db.select_star(table_name, show_cols=False, latest_partition=False)
expected = textwrap.dedent(
@@ -124,7 +124,7 @@ class DatabaseModelTestCase(SupersetTestCase):
assert sql.startswith(expected)
def test_single_statement(self):
- main_db = get_main_database(db.session)
+ main_db = get_main_database()
if main_db.backend == "mysql":
df = main_db.get_df("SELECT 1", None)
@@ -134,7 +134,7 @@ class DatabaseModelTestCase(SupersetTestCase):
self.assertEquals(df.iat[0, 0], 1)
def test_multi_statement(self):
- main_db = get_main_database(db.session)
+ main_db = get_main_database()
if main_db.backend == "mysql":
df = main_db.get_df("USE superset; SELECT 1", None)
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 5dc56bf..b16b796 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -83,7 +83,7 @@ class SqlLabTests(SupersetTestCase):
self.assertLess(0, len(data["data"]))
def test_sql_json_has_access(self):
- main_db = get_main_database(db.session)
+ main_db = get_main_database()
security_manager.add_permission_view_menu("database_access",
main_db.perm)
db.session.commit()
main_db_permission_view = (
diff --git a/tests/strategy_tests.py b/tests/strategy_tests.py
index 0f5a20e..0786ed3 100644
--- a/tests/strategy_tests.py
+++ b/tests/strategy_tests.py
@@ -23,14 +23,13 @@ from superset.models.core import Log
from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes
from superset.tasks.cache import (
DashboardTagsStrategy,
- DummyStrategy,
get_form_data,
TopNDashboardsStrategy,
)
from .base_tests import SupersetTestCase
-TEST_URL = "http://0.0.0.0:8081/superset/explore_json"
+URL_PREFIX = "0.0.0.0:8081"
class CacheWarmUpTests(SupersetTestCase):
@@ -141,61 +140,61 @@ class CacheWarmUpTests(SupersetTestCase):
}
self.assertEqual(result, expected)
- def test_dummy_strategy(self):
- strategy = DummyStrategy()
- result = sorted(strategy.get_urls())
- expected = [
- f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+1%7D",
- f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+17%7D",
- f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+18%7D",
- f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+19%7D",
- f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+30%7D",
- f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D",
- f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+8%7D",
- ]
- self.assertEqual(result, expected)
-
def test_top_n_dashboards_strategy(self):
# create a top visited dashboard
db.session.query(Log).delete()
self.login(username="admin")
+ dash = self.get_dash_by_slug("births")
for _ in range(10):
- self.client.get("/superset/dashboard/3/")
+ self.client.get(f"/superset/dashboard/{dash.id}/")
strategy = TopNDashboardsStrategy(1)
result = sorted(strategy.get_urls())
- expected = [f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D"]
+ expected = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
self.assertEqual(result, expected)
+ def reset_tag(self, tag):
+ """Remove associated object from tag, used to reset tests"""
+ if tag.objects:
+ for o in tag.objects:
+ db.session.delete(o)
+ db.session.commit()
+
def test_dashboard_tags(self):
- strategy = DashboardTagsStrategy(["tag1"])
+ tag1 = get_tag("tag1", db.session, TagTypes.custom)
+ # delete first to make test idempotent
+ self.reset_tag(tag1)
+ strategy = DashboardTagsStrategy(["tag1"])
result = sorted(strategy.get_urls())
expected = []
self.assertEqual(result, expected)
- # tag dashboard 3 with `tag1`
+ # tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagTypes.custom)
- object_id = 3
+ dash = self.get_dash_by_slug("births")
+ tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
tagged_object = TaggedObject(
- tag_id=tag1.id, object_id=object_id, object_type=ObjectTypes.dashboard
+ tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard
)
db.session.add(tagged_object)
db.session.commit()
- result = sorted(strategy.get_urls())
- expected = [f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D"]
- self.assertEqual(result, expected)
+ self.assertEqual(sorted(strategy.get_urls()), tag1_urls)
strategy = DashboardTagsStrategy(["tag2"])
+ tag2 = get_tag("tag2", db.session, TagTypes.custom)
+ self.reset_tag(tag2)
result = sorted(strategy.get_urls())
expected = []
self.assertEqual(result, expected)
- # tag chart 30 with `tag2`
- tag2 = get_tag("tag2", db.session, TagTypes.custom)
- object_id = 30
+ # tag first slice
+ dash = self.get_dash_by_slug("unicode-test")
+ slc = dash.slices[0]
+ tag2_urls = [f"{URL_PREFIX}{slc.url}"]
+ object_id = slc.id
tagged_object = TaggedObject(
tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart
)
@@ -203,14 +202,10 @@ class CacheWarmUpTests(SupersetTestCase):
db.session.commit()
result = sorted(strategy.get_urls())
- expected = [f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+30%7D"]
- self.assertEqual(result, expected)
+ self.assertEqual(result, tag2_urls)
strategy = DashboardTagsStrategy(["tag1", "tag2"])
result = sorted(strategy.get_urls())
- expected = [
- f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+30%7D",
- f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D",
- ]
+ expected = sorted(tag1_urls + tag2_urls)
self.assertEqual(result, expected)
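
The rewritten strategy tests derive expected URLs from the slices actually attached to the fixture dashboards rather than from hard-coded slice ids, which keeps them stable across reloads of the examples. The URL assembly they pin down reduces to roughly this (a sketch; in Superset the base URL presumably comes from configuration, while the tests pin it to "0.0.0.0:8081"):

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class Chart:
        url: str


    @dataclass
    class Dashboard:
        slices: List[Chart] = field(default_factory=list)


    def get_url(chart: Chart, baseurl: str = "0.0.0.0:8081") -> str:
        return f"{baseurl}{chart.url}"


    def dashboard_warmup_urls(dashboards: List[Dashboard]) -> List[str]:
        # One warm-up URL per chart across every qualifying dashboard,
        # sorted so comparisons are order-independent, as in the tests.
        return sorted(
            get_url(chart) for dash in dashboards for chart in dash.slices
        )


    births = Dashboard(slices=[Chart(url="/superset/explore/table/1/")])
    assert dashboard_warmup_urls([births]) == [
        "0.0.0.0:8081/superset/explore/table/1/"
    ]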
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index 136fdf8..b5638fa 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -109,7 +109,6 @@ class BaseVizTestCase(SupersetTestCase):
datasource.get_col = Mock(return_value=mock_dttm_col)
mock_dttm_col.python_date_format = "epoch_ms"
result = test_viz.get_df(query_obj)
- print(result)
import logging
logging.info(result)
diff --git a/tox.ini b/tox.ini
index 189862e..3625ed6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -46,7 +46,7 @@ setenv =
PYTHONPATH = {toxinidir}
SUPERSET_CONFIG = tests.superset_test_config
SUPERSET_HOME = {envtmpdir}
- py36-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
+ py36-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset?charset=utf8
py36-postgres: SUPERSET__SQLALCHEMY_DATABASE_URI = postgresql+psycopg2://postgresuser:pguserpassword@localhost/superset
py36-sqlite: SUPERSET__SQLALCHEMY_DATABASE_URI = sqlite:////{envtmpdir}/superset.db
whitelist_externals =
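
Finally, the `?charset=utf8` suffix on the MySQL URI matters for the new `test_load_unicode_test_data`: without it, the MySQL driver can default to a latin-1 connection charset and mangle non-ASCII rows. The equivalent when building an engine directly:

    from sqlalchemy import create_engine

    # The charset query parameter is forwarded to the MySQL driver so
    # unicode values survive the INSERT/SELECT round trip.
    engine = create_engine(
        "mysql://mysqluser:mysqluserpassword@localhost/superset?charset=utf8"
    )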