This is an automated email from the ASF dual-hosted git repository.

willbarrett pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new e1db016  test: presto engine spec tests (#12594)
e1db016 is described below

commit e1db016a6c324c64eec0b3ad2c1c0c8d9b85d813
Author: Karol Kostrzewa <[email protected]>
AuthorDate: Thu Jan 21 21:53:54 2021 +0100

    test: presto engine spec tests (#12594)
    
    * test get_table_names
    
    * test _get_full_name
    
    * add test_split_data_type
    
    * test _show_columns
    
    * add test_is_column_name_quoted
    
    * test select_star
    
    * test get_view_names
    
    * test estimate_statement_cost
    
    * test get_all_datasource_names
    
    * test get_create_view
    
    * test _extract_error_message
    
    * fix typo
---
 tests/db_engine_specs/presto_tests.py | 300 ++++++++++++++++++++++++++++++++++
 1 file changed, 300 insertions(+)

diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py
index 9a493d3..9b2024d 100644
--- a/tests/db_engine_specs/presto_tests.py
+++ b/tests/db_engine_specs/presto_tests.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from collections import namedtuple
 from unittest import mock, skipUnless
 
 import pandas as pd
@@ -23,6 +24,7 @@ from sqlalchemy.sql import select
 
 from superset.db_engine_specs.presto import PrestoEngineSpec
 from superset.sql_parse import ParsedQuery
+from superset.utils.core import DatasourceName
 from tests.db_engine_specs.base_tests import TestDbEngineSpec
 
 
@@ -38,6 +40,45 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
             [], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
         )
 
+    @mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
+    def test_get_view_names(self, mock_is_feature_enabled):
+        mock_is_feature_enabled.return_value = True
+        mock_execute = mock.MagicMock()
+        mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", 
"e"]])
+        database = mock.MagicMock()
+        
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute
 = (
+            mock_execute
+        )
+        
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall
 = (
+            mock_fetchall
+        )
+        result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
+        mock_execute.assert_called_once_with(
+            "SELECT table_name FROM information_schema.views", {}
+        )
+        assert result == ["a", "d"]
+
+    @mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
+    def test_get_view_names_with_schema(self, mock_is_feature_enabled):
+        mock_is_feature_enabled.return_value = True
+        mock_execute = mock.MagicMock()
+        mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", 
"e"]])
+        database = mock.MagicMock()
+        
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute
 = (
+            mock_execute
+        )
+        
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall
 = (
+            mock_fetchall
+        )
+        schema = "schema"
+        result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
+        mock_execute.assert_called_once_with(
+            "SELECT table_name FROM information_schema.views "
+            "WHERE table_schema=%(schema)s",
+            {"schema": schema},
+        )
+        assert result == ["a", "d"]
+
     def verify_presto_column(self, column, expected_results):
         inspector = mock.Mock()
         inspector.engine.dialect.identifier_preparer.quote_identifier = 
mock.Mock()
@@ -516,6 +557,265 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
         sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
         assert sqla_type is None
 
+    @mock.patch(
+        
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
+    )
+    @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
+    
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
+    def test_get_table_names_no_split_views_from_tables(
+        self, mock_get_view_names, mock_get_table_names, 
mock_is_feature_enabled
+    ):
+        mock_get_view_names.return_value = ["view1", "view2"]
+        table_names = ["table1", "table2", "view1", "view2"]
+        mock_get_table_names.return_value = table_names
+        mock_is_feature_enabled.return_value = False
+        tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), 
None)
+        assert tables == table_names
+
+    @mock.patch(
+        
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
+    )
+    @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
+    
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
+    def test_get_table_names_split_views_from_tables(
+        self, mock_get_view_names, mock_get_table_names, 
mock_is_feature_enabled
+    ):
+        mock_get_view_names.return_value = ["view1", "view2"]
+        table_names = ["table1", "table2", "view1", "view2"]
+        mock_get_table_names.return_value = table_names
+        mock_is_feature_enabled.return_value = True
+        tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), 
None)
+        assert sorted(tables) == sorted(table_names)
+
+    @mock.patch(
+        
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
+    )
+    @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
+    
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
+    def test_get_table_names_split_views_from_tables_no_tables(
+        self, mock_get_view_names, mock_get_table_names, 
mock_is_feature_enabled
+    ):
+        mock_get_view_names.return_value = []
+        table_names = []
+        mock_get_table_names.return_value = table_names
+        mock_is_feature_enabled.return_value = True
+        tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), 
None)
+        assert tables == []
+
+    def test_get_full_name(self):
+        names = [
+            ("part1", "part2"),
+            ("part11", "part22"),
+        ]
+        result = PrestoEngineSpec._get_full_name(names)
+        assert result == "part1.part11"
+
+    def test_get_full_name_empty_tuple(self):
+        names = [
+            ("part1", "part2"),
+            ("", "part3"),
+            ("part4", "part5"),
+            ("", "part6"),
+        ]
+        result = PrestoEngineSpec._get_full_name(names)
+        assert result == "part1.part4"
+
+    def test_split_data_type(self):
+        data_type = "value1 value2"
+        result = PrestoEngineSpec._split_data_type(data_type, " ")
+        assert result == ["value1", "value2"]
+
+        data_type = "value1,value2"
+        result = PrestoEngineSpec._split_data_type(data_type, ",")
+        assert result == ["value1", "value2"]
+
+        data_type = '"value,1",value2'
+        result = PrestoEngineSpec._split_data_type(data_type, ",")
+        assert result == ['"value,1"', "value2"]
+
+    def test_show_columns(self):
+        inspector = mock.MagicMock()
+        inspector.engine.dialect.identifier_preparer.quote_identifier = (
+            lambda x: f'"{x}"'
+        )
+        mock_execute = mock.MagicMock(return_value=["a", "b"])
+        inspector.bind.execute = mock_execute
+        table_name = "table_name"
+        result = PrestoEngineSpec._show_columns(inspector, table_name, None)
+        assert result == ["a", "b"]
+        mock_execute.assert_called_once_with(f'SHOW COLUMNS FROM 
"{table_name}"')
+
+    def test_show_columns_with_schema(self):
+        inspector = mock.MagicMock()
+        inspector.engine.dialect.identifier_preparer.quote_identifier = (
+            lambda x: f'"{x}"'
+        )
+        mock_execute = mock.MagicMock(return_value=["a", "b"])
+        inspector.bind.execute = mock_execute
+        table_name = "table_name"
+        schema = "schema"
+        result = PrestoEngineSpec._show_columns(inspector, table_name, schema)
+        assert result == ["a", "b"]
+        mock_execute.assert_called_once_with(
+            f'SHOW COLUMNS FROM "{schema}"."{table_name}"'
+        )
+
+    def test_is_column_name_quoted(self):
+        column_name = "mock"
+        assert PrestoEngineSpec._is_column_name_quoted(column_name) is False
+
+        column_name = '"mock'
+        assert PrestoEngineSpec._is_column_name_quoted(column_name) is False
+
+        column_name = '"moc"k'
+        assert PrestoEngineSpec._is_column_name_quoted(column_name) is False
+
+        column_name = '"moc"k"'
+        assert PrestoEngineSpec._is_column_name_quoted(column_name) is True
+
+    @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.select_star")
+    def test_select_star_no_presto_expand_data(self, mock_select_star):
+        database = mock.Mock()
+        table_name = "table_name"
+        engine = mock.Mock()
+        cols = [
+            {"col1": "val1"},
+            {"col2": "val2"},
+        ]
+        PrestoEngineSpec.select_star(database, table_name, engine, cols=cols)
+        mock_select_star.assert_called_once_with(
+            database, table_name, engine, None, 100, False, True, True, cols
+        )
+
+    @mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
+    @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.select_star")
+    def test_select_star_presto_expand_data(
+        self, mock_select_star, mock_is_feature_enabled
+    ):
+        mock_is_feature_enabled.return_value = True
+        database = mock.Mock()
+        table_name = "table_name"
+        engine = mock.Mock()
+        cols = [
+            {"name": "val1"},
+            {"name": "val2<?!@#$312,/'][p098"},
+            {"name": ".val2"},
+            {"name": "val2."},
+            {"name": "val.2"},
+            {"name": ".val2."},
+        ]
+        PrestoEngineSpec.select_star(
+            database, table_name, engine, show_cols=True, cols=cols
+        )
+        mock_select_star.assert_called_once_with(
+            database,
+            table_name,
+            engine,
+            None,
+            100,
+            True,
+            True,
+            True,
+            [{"name": "val1"}, {"name": "val2<?!@#$312,/'][p098"},],
+        )
+
+    def test_estimate_statement_cost(self):
+        mock_cursor = mock.MagicMock()
+        estimate_json = {"a": "b"}
+        mock_cursor.fetchone.return_value = [
+            '{"a": "b"}',
+        ]
+        result = PrestoEngineSpec.estimate_statement_cost(
+            "SELECT * FROM brth_names", mock_cursor
+        )
+        assert result == estimate_json
+
+    def test_estimate_statement_cost_invalid_syntax(self):
+        mock_cursor = mock.MagicMock()
+        mock_cursor.execute.side_effect = Exception()
+        with self.assertRaises(Exception):
+            PrestoEngineSpec.estimate_statement_cost(
+                "DROP TABLE brth_names", mock_cursor
+            )
+
+    def test_get_all_datasource_names(self):
+        df = pd.DataFrame.from_dict(
+            {"table_schema": ["schema1", "schema2"], "table_name": ["name1", 
"name2"]}
+        )
+        database = mock.MagicMock()
+        database.get_df.return_value = df
+        result = PrestoEngineSpec.get_all_datasource_names(database, "table")
+        expected_result = [
+            DatasourceName(schema="schema1", table="name1"),
+            DatasourceName(schema="schema2", table="name2"),
+        ]
+        assert result == expected_result
+
+    def test_get_create_view(self):
+        mock_execute = mock.MagicMock()
+        mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", 
"e"]])
+        database = mock.MagicMock()
+        
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute
 = (
+            mock_execute
+        )
+        
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall
 = (
+            mock_fetchall
+        )
+        
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.poll.return_value
 = (
+            False
+        )
+        schema = "schema"
+        table = "table"
+        result = PrestoEngineSpec.get_create_view(database, schema=schema, 
table=table)
+        assert result == "a"
+        mock_execute.assert_called_once_with(f"SHOW CREATE VIEW 
{schema}.{table}")
+
+    def test_get_create_view_exception(self):
+        mock_execute = mock.MagicMock(side_effect=Exception())
+        database = mock.MagicMock()
+        
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute
 = (
+            mock_execute
+        )
+        schema = "schema"
+        table = "table"
+        with self.assertRaises(Exception):
+            PrestoEngineSpec.get_create_view(database, schema=schema, 
table=table)
+
+    def test_get_create_view_database_error(self):
+        from pyhive.exc import DatabaseError
+
+        mock_execute = mock.MagicMock(side_effect=DatabaseError())
+        database = mock.MagicMock()
+        
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute
 = (
+            mock_execute
+        )
+        schema = "schema"
+        table = "table"
+        result = PrestoEngineSpec.get_create_view(database, schema=schema, 
table=table)
+        assert result is None
+
+    def test_extract_error_message_orig(self):
+        DatabaseError = namedtuple("DatabaseError", ["error_dict"])
+        db_err = DatabaseError(
+            {"errorName": "name", "errorLocation": "location", "message": 
"msg"}
+        )
+        exception = Exception()
+        exception.orig = db_err
+        result = PrestoEngineSpec._extract_error_message(exception)
+        assert result == "name at location: msg"
+
+    def test_extract_error_message_db_errr(self):
+        from pyhive.exc import DatabaseError
+
+        exception = DatabaseError({"message": "Err message"})
+        result = PrestoEngineSpec._extract_error_message(exception)
+        assert result == "Err message"
+
+    def test_extract_error_message_general_exception(self):
+        exception = Exception("Err message")
+        result = PrestoEngineSpec._extract_error_message(exception)
+        assert result == "Err message"
+
 
 def test_is_readonly():
     def is_readonly(sql: str) -> bool:

Reply via email to