This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 5771fdda4f [KYUUBI #7190] Fix - Presto SQLAlchemy dialect did not implement get_view_names
5771fdda4f is described below

commit 5771fdda4fccca25e27549bad3a78afa49b47700
Author: Cheng Pan <[email protected]>
AuthorDate: Mon Sep 22 16:39:09 2025 +0800

    [KYUUBI #7190] Fix - Presto SQLAlchemy dialect did not implement get_view_names
    
    ### Why are the changes needed?
    The Presto SQLAlchemy dialect did not implement the `get_view_names` method,
    so inspecting a schema resulted in an exception. This was discovered in the
    Superset repo while updating the pandas package, which now calls
    `get_view_names`.
    
    ### How was this patch tested?
    Basic Python tests (`test_reflect_table_names` and `test_reflect_view_names`)
    have been added. However, all SQLAlchemy dialects in this repo would benefit
    from running the full SQLAlchemy dialect test suite instead of these bespoke
    tests.
    
    Closes #7190 from rad-pat/fix-presto-dialect.
    
    Closes #7190
    
    c2d06f7e0 [Cheng Pan] Update python/pyhive/sqlalchemy_presto.py
    27396977e [Cheng Pan] Update python/pyhive/sqlalchemy_presto.py
    1c7b62850 [Cheng Pan] Update python/pyhive/sqlalchemy_presto.py
    2e7040a14 [Cheng Pan] Update python/pyhive/sqlalchemy_presto.py
    89d3f55fe [Cheng Pan] Update python/pyhive/__init__.py
    b8deadcdb [Pat Buxton] Bump python version to 0.7.1
    ab829ee04 [Pat Buxton] Fix - Presto SQLAlchemy dialect did not implement get_view_names
    
    Lead-authored-by: Cheng Pan <[email protected]>
    Co-authored-by: Pat Buxton <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
---
 python/pyhive/sqlalchemy_presto.py            | 35 ++++++++++++++++++++++++++-
 python/pyhive/tests/test_sqlalchemy_presto.py | 29 +++++++++++++++++++++-
 2 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/python/pyhive/sqlalchemy_presto.py b/python/pyhive/sqlalchemy_presto.py
index 33a41bae3e..f5a256fb8d 100644
--- a/python/pyhive/sqlalchemy_presto.py
+++ b/python/pyhive/sqlalchemy_presto.py
@@ -23,7 +23,7 @@ except ImportError:
     from sqlalchemy.dialects import mysql
     mysql_tinyinteger = mysql.base.MSTinyInteger
 from sqlalchemy.engine import default
-from sqlalchemy.sql import compiler
+from sqlalchemy.sql import compiler, bindparam
 from sqlalchemy.sql.compiler import SQLCompiler
 
 from pyhive import presto
@@ -204,12 +204,45 @@ class PrestoDialect(default.DefaultDialect):
         else:
             return []
 
+    def _get_default_schema_name(self, connection):
+        #'SELECT CURRENT_SCHEMA()'
+        return super()._get_default_schema_name(connection)
+
     def get_table_names(self, connection, schema=None, **kw):
         query = 'SHOW TABLES'
+        # N.B. This is incorrect, if no schema is provided, the 
current/default schema should be used
+        #  with a call to an overridden 
self._get_default_schema_name(connection), but I could not
+        #  see how to implement that as there is no CURRENT_SCHEMA function
+        #  default_schema = self._get_default_schema_name(connection)
+
         if schema:
             query += ' FROM ' + 
self.identifier_preparer.quote_identifier(schema)
         return [row.Table for row in connection.execute(text(query))]
 
+    def get_view_names(self, connection, schema=None, **kw):
+        if schema:
+            view_name_query = """
+                SELECT table_name
+                FROM information_schema.views
+                WHERE table_schema = :schema
+            """
+            query = text(view_name_query).bindparams(
+                bindparam("schema", type_=types.Unicode)
+            )
+        else:
+            # N.B. This is incorrect, if no schema is provided, the 
current/default schema should
+            #  be used with a call to 
self._get_default_schema_name(connection), but I could not
+            #  see how to implement that
+            #  default_schema = self._get_default_schema_name(connection)
+            view_name_query = """
+                SELECT table_name
+                FROM information_schema.views
+            """
+            query = text(view_name_query)
+
+        result = connection.execute(query, dict(schema=schema))
+        return [row[0] for row in result]
+
     def do_rollback(self, dbapi_connection):
         # No transactions for Presto
         pass
diff --git a/python/pyhive/tests/test_sqlalchemy_presto.py b/python/pyhive/tests/test_sqlalchemy_presto.py
index 336dd12e24..e8b04ea249 100644
--- a/python/pyhive/tests/test_sqlalchemy_presto.py
+++ b/python/pyhive/tests/test_sqlalchemy_presto.py
@@ -102,4 +102,31 @@ class TestSqlAlchemyPresto(unittest.TestCase, SqlAlchemyTestCase):
             self.assertFalse(insp.has_table("THIS_TABLE_DOSE_not_exist"))
         else:
             self.assertFalse(Table('THIS_TABLE_DOSE_NOT_EXIST', 
MetaData(bind=engine)).exists())
-            self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', 
MetaData(bind=engine)).exists())
\ No newline at end of file
+            self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', 
MetaData(bind=engine)).exists())
+
+    @with_engine_connection
+    def test_reflect_table_names(self, engine, connection):
+        sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", 
sqlalchemy.__version__).group(1))
+        if sqlalchemy_version >= 1.4:
+            insp = sqlalchemy.inspect(engine)
+            table_names = insp.get_table_names()
+            self.assertIn("one_row", table_names)
+            self.assertIn("one_row_complex", table_names)
+            self.assertIn("many_rows", table_names)
+            self.assertNotIn("THIS_TABLE_DOES_not_exist", table_names)
+        else:
+            self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
+            self.assertTrue(Table('one_row_complex', 
MetaData(bind=engine)).exists())
+            self.assertTrue(Table('many_rows', MetaData(bind=engine)).exists())
+            self.assertFalse(Table('THIS_TABLE_DOES_not_exist', 
MetaData(bind=engine)).exists())
+
+    @with_engine_connection
+    def test_reflect_view_names(self, engine, connection):
+        sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", 
sqlalchemy.__version__).group(1))
+        if sqlalchemy_version >= 1.4:
+            insp = sqlalchemy.inspect(engine)
+            view_names = insp.get_view_names()
+            self.assertNotIn("one_row", view_names)
+            self.assertNotIn("one_row_complex", view_names)
+            self.assertNotIn("many_rows", view_names)
+            self.assertNotIn("THIS_TABLE_DOES_not_exist", view_names)

Reply via email to