This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new f7d3413 Add support for period character in table names (#7453)
f7d3413 is described below
commit f7d3413a501d8b643318fe7c0641eba608a079f5
Author: Ville Brofeldt <[email protected]>
AuthorDate: Sun May 26 06:13:16 2019 +0300
Add support for period character in table names (#7453)
* Move schema name handling in table names from frontend to backend
* Rename all_schema_names to get_all_schema_names
* Fix js errors
* Fix additional js linting errors
* Refactor datasource getters and fix linting errors
* Update js unit tests
* Add python unit test for get_table_names method
* Add python unit test for get_table_names method
* Fix js linting error
---
.../javascripts/components/TableSelector_spec.jsx | 11 +--
.../assets/spec/javascripts/sqllab/fixtures.js | 6 +-
.../src/SqlLab/components/SqlEditorLeftBar.jsx | 15 ++--
superset/assets/src/components/TableSelector.jsx | 9 +--
superset/cli.py | 4 +-
superset/db_engine_specs.py | 84 +++++++++++-----------
superset/models/core.py | 57 ++++++---------
superset/security.py | 6 +-
superset/utils/core.py | 7 +-
superset/views/core.py | 67 +++++++++--------
tests/db_engine_specs_test.py | 19 +++++
11 files changed, 148 insertions(+), 137 deletions(-)
diff --git a/superset/assets/spec/javascripts/components/TableSelector_spec.jsx b/superset/assets/spec/javascripts/components/TableSelector_spec.jsx
index 70e2cca..1366592 100644
--- a/superset/assets/spec/javascripts/components/TableSelector_spec.jsx
+++ b/superset/assets/spec/javascripts/components/TableSelector_spec.jsx
@@ -208,19 +208,20 @@ describe('TableSelector', () => {
it('test 1', () => {
wrapper.instance().changeTable({
- value: 'birth_names',
+ value: { schema: 'main', table: 'birth_names' },
label: 'birth_names',
});
expect(wrapper.state().tableName).toBe('birth_names');
});
- it('test 2', () => {
+ it('should call onTableChange with schema from table object', () => {
+ wrapper.setProps({ schema: null });
wrapper.instance().changeTable({
- value: 'main.my_table',
- label: 'my_table',
+ value: { schema: 'other_schema', table: 'my_table' },
+ label: 'other_schema.my_table',
});
expect(mockedProps.onTableChange.getCall(0).args[0]).toBe('my_table');
- expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('main');
+
expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('other_schema');
});
});
diff --git a/superset/assets/spec/javascripts/sqllab/fixtures.js b/superset/assets/spec/javascripts/sqllab/fixtures.js
index 6471be1..f43f43f 100644
--- a/superset/assets/spec/javascripts/sqllab/fixtures.js
+++ b/superset/assets/spec/javascripts/sqllab/fixtures.js
@@ -329,15 +329,15 @@ export const databases = {
export const tables = {
options: [
{
- value: 'birth_names',
+ value: { schema: 'main', table: 'birth_names' },
label: 'birth_names',
},
{
- value: 'energy_usage',
+ value: { schema: 'main', table: 'energy_usage' },
label: 'energy_usage',
},
{
- value: 'wb_health_population',
+ value: { schema: 'main', table: 'wb_health_population' },
label: 'wb_health_population',
},
],
diff --git a/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx b/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx
index 9d0796c..43ea487 100644
--- a/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx
+++ b/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx
@@ -83,17 +83,10 @@ export default class SqlEditorLeftBar extends React.PureComponent {
this.setState({ tableName: '' });
return;
}
- const namePieces = tableOpt.value.split('.');
- let tableName = namePieces[0];
- let schemaName = this.props.queryEditor.schema;
- if (namePieces.length === 1) {
- this.setState({ tableName });
- } else {
- schemaName = namePieces[0];
- tableName = namePieces[1];
- this.setState({ tableName });
-      this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
- }
+ const schemaName = tableOpt.value.schema;
+ const tableName = tableOpt.value.table;
+ this.setState({ tableName });
+    this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
this.props.actions.addTable(this.props.queryEditor, tableName, schemaName);
}
diff --git a/superset/assets/src/components/TableSelector.jsx b/superset/assets/src/components/TableSelector.jsx
index ba2cebb..940e1c2 100644
--- a/superset/assets/src/components/TableSelector.jsx
+++ b/superset/assets/src/components/TableSelector.jsx
@@ -170,13 +170,8 @@ export default class TableSelector extends React.PureComponent {
this.setState({ tableName: '' });
return;
}
- const namePieces = tableOpt.value.split('.');
- let tableName = namePieces[0];
- let schemaName = this.props.schema;
- if (namePieces.length > 1) {
- schemaName = namePieces[0];
- tableName = namePieces[1];
- }
+ const schemaName = tableOpt.value.schema;
+ const tableName = tableOpt.value.table;
if (this.props.tableNameSticky) {
this.setState({ tableName }, this.onChange);
}
diff --git a/superset/cli.py b/superset/cli.py
index 7b441b4..edb0102 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -288,9 +288,9 @@ def update_datasources_cache():
if database.allow_multi_schema_metadata_fetch:
print('Fetching {} datasources ...'.format(database.name))
try:
- database.all_table_names_in_database(
+ database.get_all_table_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60)
- database.all_view_names_in_database(
+ database.get_all_view_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60)
except Exception as e:
print('{}'.format(str(e)))
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 35a591f..67aba12 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -122,6 +122,7 @@ class BaseEngineSpec(object):
force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0
+ try_remove_schema_from_table_name = True
@classmethod
def get_time_expr(cls, expr, pdf, time_grain, grain):
@@ -279,33 +280,32 @@ class BaseEngineSpec(object):
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
@classmethod
- def fetch_result_sets(cls, db, datasource_type):
- """Returns a list of tables [schema1.table1, schema2.table2, ...]
+ def get_all_datasource_names(cls, db, datasource_type: str) \
+ -> List[utils.DatasourceName]:
+ """Returns a list of all tables or views in database.
- Datasource_type can be 'table' or 'view'.
- Empty schema corresponds to the list of full names of the all
- tables or views: <schema>.<result_set_name>.
+ :param db: Database instance
+ :param datasource_type: Datasource_type can be 'table' or 'view'
+ :return: List of all datasources in database or schema
"""
- schemas = db.all_schema_names(cache=db.schema_cache_enabled,
- cache_timeout=db.schema_cache_timeout,
- force=True)
- all_result_sets = []
+        schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
+                                          cache_timeout=db.schema_cache_timeout,
+                                          force=True)
+ all_datasources: List[utils.DatasourceName] = []
for schema in schemas:
if datasource_type == 'table':
- all_datasource_names = db.all_table_names_in_schema(
+ all_datasources += db.get_all_table_names_in_schema(
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
elif datasource_type == 'view':
- all_datasource_names = db.all_view_names_in_schema(
+ all_datasources += db.get_all_view_names_in_schema(
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
else:
                raise Exception(f'Unsupported datasource_type: {datasource_type}')
- all_result_sets += [
- '{}.{}'.format(schema, t) for t in all_datasource_names]
- return all_result_sets
+ return all_datasources
@classmethod
def handle_cursor(cls, cursor, query, session):
@@ -352,11 +352,17 @@ class BaseEngineSpec(object):
@classmethod
def get_table_names(cls, inspector, schema):
- return sorted(inspector.get_table_names(schema))
+ tables = inspector.get_table_names(schema)
+ if schema and cls.try_remove_schema_from_table_name:
+ tables = [re.sub(f'^{schema}\\.', '', table) for table in tables]
+ return sorted(tables)
@classmethod
def get_view_names(cls, inspector, schema):
- return sorted(inspector.get_view_names(schema))
+ views = inspector.get_view_names(schema)
+ if schema and cls.try_remove_schema_from_table_name:
+ views = [re.sub(f'^{schema}\\.', '', view) for view in views]
+ return sorted(views)
@classmethod
    def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
@@ -528,6 +534,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
class PostgresEngineSpec(PostgresBaseEngineSpec):
engine = 'postgresql'
max_column_name_length = 63
+ try_remove_schema_from_table_name = False
@classmethod
def get_table_names(cls, inspector, schema):
@@ -685,29 +692,25 @@ class SqliteEngineSpec(BaseEngineSpec):
return "datetime({col}, 'unixepoch')"
@classmethod
- def fetch_result_sets(cls, db, datasource_type):
- schemas = db.all_schema_names(cache=db.schema_cache_enabled,
- cache_timeout=db.schema_cache_timeout,
- force=True)
- all_result_sets = []
+ def get_all_datasource_names(cls, db, datasource_type: str) \
+ -> List[utils.DatasourceName]:
+        schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
+                                          cache_timeout=db.schema_cache_timeout,
+                                          force=True)
schema = schemas[0]
if datasource_type == 'table':
- all_datasource_names = db.all_table_names_in_schema(
+ return db.get_all_table_names_in_schema(
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
elif datasource_type == 'view':
- all_datasource_names = db.all_view_names_in_schema(
+ return db.get_all_view_names_in_schema(
schema=schema, force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
else:
raise Exception(f'Unsupported datasource_type: {datasource_type}')
- all_result_sets += [
- '{}.{}'.format(schema, t) for t in all_datasource_names]
- return all_result_sets
-
@classmethod
def convert_dttm(cls, target_type, dttm):
iso = dttm.isoformat().replace('T', ' ')
@@ -1107,24 +1110,19 @@ class PrestoEngineSpec(BaseEngineSpec):
return 'from_unixtime({col})'
@classmethod
- def fetch_result_sets(cls, db, datasource_type):
- """Returns a list of tables [schema1.table1, schema2.table2, ...]
-
- Datasource_type can be 'table' or 'view'.
- Empty schema corresponds to the list of full names of the all
- tables or views: <schema>.<result_set_name>.
- """
- result_set_df = db.get_df(
+ def get_all_datasource_names(cls, db, datasource_type: str) \
+ -> List[utils.DatasourceName]:
+ datasource_df = db.get_df(
"""SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S
ORDER BY concat(table_schema, '.', table_name)""".format(
datasource_type.upper(),
),
None)
- result_sets = []
- for unused, row in result_set_df.iterrows():
- result_sets.append('{}.{}'.format(
- row['table_schema'], row['table_name']))
- return result_sets
+ datasource_names: List[utils.DatasourceName] = []
+ for unused, row in datasource_df.iterrows():
+ datasource_names.append(utils.DatasourceName(
+ schema=row['table_schema'], table=row['table_name']))
+ return datasource_names
@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
@@ -1385,9 +1383,9 @@ class HiveEngineSpec(PrestoEngineSpec):
hive.Cursor.fetch_logs = patched_hive.fetch_logs
@classmethod
- def fetch_result_sets(cls, db, datasource_type):
- return BaseEngineSpec.fetch_result_sets(
- db, datasource_type)
+ def get_all_datasource_names(cls, db, datasource_type: str) \
+ -> List[utils.DatasourceName]:
+ return BaseEngineSpec.get_all_datasource_names(db, datasource_type)
@classmethod
def fetch_data(cls, cursor, limit):
diff --git a/superset/models/core.py b/superset/models/core.py
index e16a234..047a3dd 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -23,6 +23,7 @@ import functools
import json
import logging
import textwrap
+from typing import List
from flask import escape, g, Markup, request
from flask_appbuilder import Model
@@ -65,6 +66,7 @@ metadata = Model.metadata # pylint: disable=no-member
PASSWORD_MASK = 'X' * 10
+
def set_related_perm(mapper, connection, target): # noqa
src_class = target.cls_model
id_ = target.datasource_id
@@ -184,7 +186,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
description=self.description,
cache_timeout=self.cache_timeout)
- @datasource.getter
+ @datasource.getter # type: ignore
@utils.memoized
def get_datasource(self):
return (
@@ -210,7 +212,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
datasource = self.datasource
return datasource.url if datasource else None
- @property
+ @property # type: ignore
@utils.memoized
def viz(self):
d = json.loads(self.params)
@@ -930,100 +932,87 @@ class Database(Model, AuditMixinNullable, ImportMixin):
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{}:schema:None:table_list',
attribute_in_key='id')
- def all_table_names_in_database(self, cache=False,
- cache_timeout=None, force=False):
+    def get_all_table_names_in_database(self, cache: bool = False,
+                                        cache_timeout: bool = None,
+                                        force=False) -> List[utils.DatasourceName]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
return []
- return self.db_engine_spec.fetch_result_sets(self, 'table')
+ return self.db_engine_spec.get_all_datasource_names(self, 'table')
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{}:schema:None:view_list',
attribute_in_key='id')
- def all_view_names_in_database(self, cache=False,
- cache_timeout=None, force=False):
+    def get_all_view_names_in_database(self, cache: bool = False,
+                                       cache_timeout: bool = None,
+                                       force: bool = False) -> List[utils.DatasourceName]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
return []
- return self.db_engine_spec.fetch_result_sets(self, 'view')
+ return self.db_engine_spec.get_all_datasource_names(self, 'view')
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{{}}:schema:{}:table_list'.format(
kwargs.get('schema')),
attribute_in_key='id')
- def all_table_names_in_schema(self, schema, cache=False,
- cache_timeout=None, force=False):
+    def get_all_table_names_in_schema(self, schema: str, cache: bool = False,
+                                      cache_timeout: int = None, force: bool = False):
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
:param schema: schema name
- :type schema: str
:param cache: whether cache is enabled for the function
- :type cache: bool
:param cache_timeout: timeout in seconds for the cache
- :type cache_timeout: int
:param force: whether to force refresh the cache
- :type force: bool
- :return: table list
- :rtype: list
+ :return: list of tables
"""
- tables = []
try:
tables = self.db_engine_spec.get_table_names(
inspector=self.inspector, schema=schema)
+            return [utils.DatasourceName(table=table, schema=schema) for table in tables]
except Exception as e:
logging.exception(e)
- return tables
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{{}}:schema:{}:view_list'.format(
kwargs.get('schema')),
attribute_in_key='id')
- def all_view_names_in_schema(self, schema, cache=False,
- cache_timeout=None, force=False):
+    def get_all_view_names_in_schema(self, schema: str, cache: bool = False,
+                                     cache_timeout: int = None, force: bool = False):
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
:param schema: schema name
- :type schema: str
:param cache: whether cache is enabled for the function
- :type cache: bool
:param cache_timeout: timeout in seconds for the cache
- :type cache_timeout: int
:param force: whether to force refresh the cache
- :type force: bool
- :return: view list
- :rtype: list
+ :return: list of views
"""
- views = []
try:
views = self.db_engine_spec.get_view_names(
inspector=self.inspector, schema=schema)
+            return [utils.DatasourceName(table=view, schema=schema) for view in views]
except Exception as e:
logging.exception(e)
- return views
@cache_util.memoized_func(
key=lambda *args, **kwargs: 'db:{}:schema_list',
attribute_in_key='id')
- def all_schema_names(self, cache=False, cache_timeout=None, force=False):
+    def get_all_schema_names(self, cache: bool = False, cache_timeout: int = None,
+                             force: bool = False) -> List[str]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
:param cache: whether cache is enabled for the function
- :type cache: bool
:param cache_timeout: timeout in seconds for the cache
- :type cache_timeout: int
:param force: whether to force refresh the cache
- :type force: bool
:return: schema list
- :rtype: list
"""
return self.db_engine_spec.get_schema_names(self.inspector)
@@ -1232,7 +1221,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
def datasource(self):
return self.get_datasource
- @datasource.getter
+ @datasource.getter # type: ignore
@utils.memoized
def get_datasource(self):
# pylint: disable=no-member
diff --git a/superset/security.py b/superset/security.py
index f8ae057..89eab5d 100644
--- a/superset/security.py
+++ b/superset/security.py
@@ -17,6 +17,7 @@
# pylint: disable=C,R,W
"""A set of constants and methods to manage permissions and security"""
import logging
+from typing import List
from flask import g
from flask_appbuilder.security.sqla import models as ab_models
@@ -26,6 +27,7 @@ from sqlalchemy import or_
from superset import sql_parse
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import SupersetSecurityException
+from superset.utils.core import DatasourceName
class SupersetSecurityManager(SecurityManager):
@@ -240,7 +242,9 @@ class SupersetSecurityManager(SecurityManager):
subset.add(t.schema)
return sorted(list(subset))
- def accessible_by_user(self, database, datasource_names, schema=None):
+ def get_datasources_accessible_by_user(
+ self, database, datasource_names: List[DatasourceName],
+ schema: str = None) -> List[DatasourceName]:
from superset import db
if self.database_access(database) or self.all_datasource_access():
return datasource_names
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 3b41457..2defa70 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -32,7 +32,7 @@ import signal
import smtplib
import sys
from time import struct_time
-from typing import List, Optional, Tuple
+from typing import List, NamedTuple, Optional, Tuple
from urllib.parse import unquote_plus
import uuid
import zlib
@@ -1100,3 +1100,8 @@ def MediumText() -> Variant:
def shortid() -> str:
return '{}'.format(uuid.uuid4())[-12:]
+
+
+class DatasourceName(NamedTuple):
+ table: str
+ schema: str
diff --git a/superset/views/core.py b/superset/views/core.py
index 883a2d9..0a6ddef 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -22,7 +22,7 @@ import logging
import os
import re
import traceback
-from typing import List # noqa: F401
+from typing import Dict, List # noqa: F401
from urllib import parse
from flask import (
@@ -311,7 +311,7 @@ class DatabaseView(SupersetModelView, DeleteMixin, YamlExportMixin):  # noqa
db.set_sqlalchemy_uri(db.sqlalchemy_uri)
security_manager.add_permission_view_menu('database_access', db.perm)
# adding a new database we always want to force refresh schema list
- for schema in db.all_schema_names():
+ for schema in db.get_all_schema_names():
security_manager.add_permission_view_menu(
'schema_access', security_manager.get_schema_perm(db, schema))
@@ -1545,7 +1545,7 @@ class Superset(BaseSupersetView):
.first()
)
if database:
- schemas = database.all_schema_names(
+ schemas = database.get_all_schema_names(
cache=database.schema_cache_enabled,
cache_timeout=database.schema_cache_timeout,
force=force_refresh)
@@ -1570,50 +1570,57 @@ class Superset(BaseSupersetView):
database = db.session.query(models.Database).filter_by(id=db_id).one()
if schema:
- table_names = database.all_table_names_in_schema(
+ tables = database.get_all_table_names_in_schema(
schema=schema, force=force_refresh,
cache=database.table_cache_enabled,
- cache_timeout=database.table_cache_timeout)
- view_names = database.all_view_names_in_schema(
+ cache_timeout=database.table_cache_timeout) or []
+ views = database.get_all_view_names_in_schema(
schema=schema, force=force_refresh,
cache=database.table_cache_enabled,
- cache_timeout=database.table_cache_timeout)
+ cache_timeout=database.table_cache_timeout) or []
else:
- table_names = database.all_table_names_in_database(
+ tables = database.get_all_table_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60)
- view_names = database.all_view_names_in_database(
+ views = database.get_all_view_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60)
-        table_names = security_manager.accessible_by_user(database, table_names, schema)
-        view_names = security_manager.accessible_by_user(database, view_names, schema)
+ tables = security_manager.get_datasources_accessible_by_user(
+ database, tables, schema)
+ views = security_manager.get_datasources_accessible_by_user(
+ database, views, schema)
+
+        def get_datasource_label(ds_name: utils.DatasourceName) -> str:
+            return ds_name.table if schema else f'{ds_name.schema}.{ds_name.table}'
if substr:
- table_names = [tn for tn in table_names if substr in tn]
- view_names = [vn for vn in view_names if substr in vn]
+            tables = [tn for tn in tables if substr in get_datasource_label(tn)]
+            views = [vn for vn in views if substr in get_datasource_label(vn)]
if not schema and database.default_schemas:
-            def get_schema(tbl_or_view_name):
-                return tbl_or_view_name.split('.')[0] if '.' in tbl_or_view_name else None
-
user_schema = g.user.email.split('@')[0]
valid_schemas = set(database.default_schemas + [user_schema])
-            table_names = [tn for tn in table_names if get_schema(tn) in valid_schemas]
-            view_names = [vn for vn in view_names if get_schema(vn) in valid_schemas]
+ tables = [tn for tn in tables if tn.schema in valid_schemas]
+ views = [vn for vn in views if vn.schema in valid_schemas]
- max_items = config.get('MAX_TABLE_NAMES') or len(table_names)
- total_items = len(table_names) + len(view_names)
- max_tables = len(table_names)
- max_views = len(view_names)
+ max_items = config.get('MAX_TABLE_NAMES') or len(tables)
+ total_items = len(tables) + len(views)
+ max_tables = len(tables)
+ max_views = len(views)
if total_items and substr:
- max_tables = max_items * len(table_names) // total_items
- max_views = max_items * len(view_names) // total_items
-
- table_options = [{'value': tn, 'label': tn}
- for tn in table_names[:max_tables]]
- table_options.extend([{'value': vn, 'label': '[view] {}'.format(vn)}
- for vn in view_names[:max_views]])
+ max_tables = max_items * len(tables) // total_items
+ max_views = max_items * len(views) // total_items
+
+        def get_datasource_value(ds_name: utils.DatasourceName) -> Dict[str, str]:
+ return {'schema': ds_name.schema, 'table': ds_name.table}
+
+ table_options = [{'value': get_datasource_value(tn),
+ 'label': get_datasource_label(tn)}
+ for tn in tables[:max_tables]]
+ table_options.extend([{'value': get_datasource_value(vn),
+ 'label': f'[view] {get_datasource_label(vn)}'}
+ for vn in views[:max_views]])
payload = {
- 'tableLength': len(table_names) + len(view_names),
+ 'tableLength': len(tables) + len(views),
'options': table_options,
}
return json_success(json.dumps(payload))
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index e0d914f..e190014 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -464,3 +464,22 @@ class DbEngineSpecsTestCase(SupersetTestCase):
query = str(sel.compile(dialect=dialect,
compile_kwargs={'literal_binds': True}))
        query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'"  # noqa
self.assertEqual(query, query_expected)
+
+ def test_get_table_names(self):
+ inspector = mock.Mock()
+        inspector.get_table_names = mock.Mock(return_value=['schema.table', 'table_2'])
+ inspector.get_foreign_table_names = mock.Mock(return_value=['table_3'])
+
+ """ Make sure base engine spec removes schema name from table name
+ ie. when try_remove_schema_from_table_name == True. """
+ base_result_expected = ['table', 'table_2']
+ base_result = db_engine_specs.BaseEngineSpec.get_table_names(
+ schema='schema', inspector=inspector)
+ self.assertListEqual(base_result_expected, base_result)
+
+        """ Make sure postgres doesn't try to remove schema name from table name
+ ie. when try_remove_schema_from_table_name == False. """
+ pg_result_expected = ['schema.table', 'table_2', 'table_3']
+ pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
+ schema='schema', inspector=inspector)
+ self.assertListEqual(pg_result_expected, pg_result)