This is an automated email from the ASF dual-hosted git repository.
willbarrett pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new e1db016 test: presto engine spec tests (#12594)
e1db016 is described below
commit e1db016a6c324c64eec0b3ad2c1c0c8d9b85d813
Author: Karol Kostrzewa <[email protected]>
AuthorDate: Thu Jan 21 21:53:54 2021 +0100
test: presto engine spec tests (#12594)
* test get_table_names
* test _get_full_name
* add test_split_data_type
* test _show_columns
* add test_is_column_name_quoted
* test select_star
* test get_view_names
* test estimate_statement_cost
* test get_all_datasource_names
* test get_create_view
* test _extract_error_message
* fix typo
---
tests/db_engine_specs/presto_tests.py | 300 ++++++++++++++++++++++++++++++++++
1 file changed, 300 insertions(+)
diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py
index 9a493d3..9b2024d 100644
--- a/tests/db_engine_specs/presto_tests.py
+++ b/tests/db_engine_specs/presto_tests.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from collections import namedtuple
from unittest import mock, skipUnless
import pandas as pd
@@ -23,6 +24,7 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.sql_parse import ParsedQuery
+from superset.utils.core import DatasourceName
from tests.db_engine_specs.base_tests import TestDbEngineSpec
@@ -38,6 +40,45 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
[], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
)
+ @mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
+ def test_get_view_names(self, mock_is_feature_enabled):
+ mock_is_feature_enabled.return_value = True
+ mock_execute = mock.MagicMock()
+ mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
+ database = mock.MagicMock()
+ database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
+ mock_execute
+ )
+ database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
+ mock_fetchall
+ )
+ result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
+ mock_execute.assert_called_once_with(
+ "SELECT table_name FROM information_schema.views", {}
+ )
+ assert result == ["a", "d"]
+
+ @mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
+ def test_get_view_names_with_schema(self, mock_is_feature_enabled):
+ mock_is_feature_enabled.return_value = True
+ mock_execute = mock.MagicMock()
+ mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
+ database = mock.MagicMock()
+ database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
+ mock_execute
+ )
+ database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
+ mock_fetchall
+ )
+ schema = "schema"
+ result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
+ mock_execute.assert_called_once_with(
+ "SELECT table_name FROM information_schema.views "
+ "WHERE table_schema=%(schema)s",
+ {"schema": schema},
+ )
+ assert result == ["a", "d"]
+
def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
@@ -516,6 +557,265 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
assert sqla_type is None
+ @mock.patch(
+ "superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
+ )
+ @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
+ @mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
+ def test_get_table_names_no_split_views_from_tables(
+ self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
+ ):
+ mock_get_view_names.return_value = ["view1", "view2"]
+ table_names = ["table1", "table2", "view1", "view2"]
+ mock_get_table_names.return_value = table_names
+ mock_is_feature_enabled.return_value = False
+ tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
+ assert tables == table_names
+
+ @mock.patch(
+ "superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
+ )
+ @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
+ @mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
+ def test_get_table_names_split_views_from_tables(
+ self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
+ ):
+ mock_get_view_names.return_value = ["view1", "view2"]
+ table_names = ["table1", "table2", "view1", "view2"]
+ mock_get_table_names.return_value = table_names
+ mock_is_feature_enabled.return_value = True
+ tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
+ assert sorted(tables) == sorted(table_names)
+
+ @mock.patch(
+ "superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
+ )
+ @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
+ @mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
+ def test_get_table_names_split_views_from_tables_no_tables(
+ self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
+ ):
+ mock_get_view_names.return_value = []
+ table_names = []
+ mock_get_table_names.return_value = table_names
+ mock_is_feature_enabled.return_value = True
+ tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
+ assert tables == []
+
+ def test_get_full_name(self):
+ names = [
+ ("part1", "part2"),
+ ("part11", "part22"),
+ ]
+ result = PrestoEngineSpec._get_full_name(names)
+ assert result == "part1.part11"
+
+ def test_get_full_name_empty_tuple(self):
+ names = [
+ ("part1", "part2"),
+ ("", "part3"),
+ ("part4", "part5"),
+ ("", "part6"),
+ ]
+ result = PrestoEngineSpec._get_full_name(names)
+ assert result == "part1.part4"
+
+ def test_split_data_type(self):
+ data_type = "value1 value2"
+ result = PrestoEngineSpec._split_data_type(data_type, " ")
+ assert result == ["value1", "value2"]
+
+ data_type = "value1,value2"
+ result = PrestoEngineSpec._split_data_type(data_type, ",")
+ assert result == ["value1", "value2"]
+
+ data_type = '"value,1",value2'
+ result = PrestoEngineSpec._split_data_type(data_type, ",")
+ assert result == ['"value,1"', "value2"]
+
+ def test_show_columns(self):
+ inspector = mock.MagicMock()
+ inspector.engine.dialect.identifier_preparer.quote_identifier = (
+ lambda x: f'"{x}"'
+ )
+ mock_execute = mock.MagicMock(return_value=["a", "b"])
+ inspector.bind.execute = mock_execute
+ table_name = "table_name"
+ result = PrestoEngineSpec._show_columns(inspector, table_name, None)
+ assert result == ["a", "b"]
+ mock_execute.assert_called_once_with(f'SHOW COLUMNS FROM "{table_name}"')
+
+ def test_show_columns_with_schema(self):
+ inspector = mock.MagicMock()
+ inspector.engine.dialect.identifier_preparer.quote_identifier = (
+ lambda x: f'"{x}"'
+ )
+ mock_execute = mock.MagicMock(return_value=["a", "b"])
+ inspector.bind.execute = mock_execute
+ table_name = "table_name"
+ schema = "schema"
+ result = PrestoEngineSpec._show_columns(inspector, table_name, schema)
+ assert result == ["a", "b"]
+ mock_execute.assert_called_once_with(
+ f'SHOW COLUMNS FROM "{schema}"."{table_name}"'
+ )
+
+ def test_is_column_name_quoted(self):
+ column_name = "mock"
+ assert PrestoEngineSpec._is_column_name_quoted(column_name) is False
+
+ column_name = '"mock'
+ assert PrestoEngineSpec._is_column_name_quoted(column_name) is False
+
+ column_name = '"moc"k'
+ assert PrestoEngineSpec._is_column_name_quoted(column_name) is False
+
+ column_name = '"moc"k"'
+ assert PrestoEngineSpec._is_column_name_quoted(column_name) is True
+
+ @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.select_star")
+ def test_select_star_no_presto_expand_data(self, mock_select_star):
+ database = mock.Mock()
+ table_name = "table_name"
+ engine = mock.Mock()
+ cols = [
+ {"col1": "val1"},
+ {"col2": "val2"},
+ ]
+ PrestoEngineSpec.select_star(database, table_name, engine, cols=cols)
+ mock_select_star.assert_called_once_with(
+ database, table_name, engine, None, 100, False, True, True, cols
+ )
+
+ @mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
+ @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.select_star")
+ def test_select_star_presto_expand_data(
+ self, mock_select_star, mock_is_feature_enabled
+ ):
+ mock_is_feature_enabled.return_value = True
+ database = mock.Mock()
+ table_name = "table_name"
+ engine = mock.Mock()
+ cols = [
+ {"name": "val1"},
+ {"name": "val2<?!@#$312,/'][p098"},
+ {"name": ".val2"},
+ {"name": "val2."},
+ {"name": "val.2"},
+ {"name": ".val2."},
+ ]
+ PrestoEngineSpec.select_star(
+ database, table_name, engine, show_cols=True, cols=cols
+ )
+ mock_select_star.assert_called_once_with(
+ database,
+ table_name,
+ engine,
+ None,
+ 100,
+ True,
+ True,
+ True,
+ [{"name": "val1"}, {"name": "val2<?!@#$312,/'][p098"},],
+ )
+
+ def test_estimate_statement_cost(self):
+ mock_cursor = mock.MagicMock()
+ estimate_json = {"a": "b"}
+ mock_cursor.fetchone.return_value = [
+ '{"a": "b"}',
+ ]
+ result = PrestoEngineSpec.estimate_statement_cost(
+ "SELECT * FROM brth_names", mock_cursor
+ )
+ assert result == estimate_json
+
+ def test_estimate_statement_cost_invalid_syntax(self):
+ mock_cursor = mock.MagicMock()
+ mock_cursor.execute.side_effect = Exception()
+ with self.assertRaises(Exception):
+ PrestoEngineSpec.estimate_statement_cost(
+ "DROP TABLE brth_names", mock_cursor
+ )
+
+ def test_get_all_datasource_names(self):
+ df = pd.DataFrame.from_dict(
+ {"table_schema": ["schema1", "schema2"], "table_name": ["name1", "name2"]}
+ )
+ database = mock.MagicMock()
+ database.get_df.return_value = df
+ result = PrestoEngineSpec.get_all_datasource_names(database, "table")
+ expected_result = [
+ DatasourceName(schema="schema1", table="name1"),
+ DatasourceName(schema="schema2", table="name2"),
+ ]
+ assert result == expected_result
+
+ def test_get_create_view(self):
+ mock_execute = mock.MagicMock()
+ mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
+ database = mock.MagicMock()
+ database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
+ mock_execute
+ )
+ database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
+ mock_fetchall
+ )
+ database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
+ False
+ )
+ schema = "schema"
+ table = "table"
+ result = PrestoEngineSpec.get_create_view(database, schema=schema, table=table)
+ assert result == "a"
+ mock_execute.assert_called_once_with(f"SHOW CREATE VIEW {schema}.{table}")
+
+ def test_get_create_view_exception(self):
+ mock_execute = mock.MagicMock(side_effect=Exception())
+ database = mock.MagicMock()
+ database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
+ mock_execute
+ )
+ schema = "schema"
+ table = "table"
+ with self.assertRaises(Exception):
+ PrestoEngineSpec.get_create_view(database, schema=schema, table=table)
+
+ def test_get_create_view_database_error(self):
+ from pyhive.exc import DatabaseError
+
+ mock_execute = mock.MagicMock(side_effect=DatabaseError())
+ database = mock.MagicMock()
+ database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
+ mock_execute
+ )
+ schema = "schema"
+ table = "table"
+ result = PrestoEngineSpec.get_create_view(database, schema=schema, table=table)
+ assert result is None
+
+ def test_extract_error_message_orig(self):
+ DatabaseError = namedtuple("DatabaseError", ["error_dict"])
+ db_err = DatabaseError(
+ {"errorName": "name", "errorLocation": "location", "message": "msg"}
+ )
+ exception = Exception()
+ exception.orig = db_err
+ result = PrestoEngineSpec._extract_error_message(exception)
+ assert result == "name at location: msg"
+
+ def test_extract_error_message_db_errr(self):
+ from pyhive.exc import DatabaseError
+
+ exception = DatabaseError({"message": "Err message"})
+ result = PrestoEngineSpec._extract_error_message(exception)
+ assert result == "Err message"
+
+ def test_extract_error_message_general_exception(self):
+ exception = Exception("Err message")
+ result = PrestoEngineSpec._extract_error_message(exception)
+ assert result == "Err message"
+
def test_is_readonly():
def is_readonly(sql: str) -> bool: