betodealmeida commented on code in PR #26767:
URL: https://github.com/apache/superset/pull/26767#discussion_r1464065388


##########
superset/sql_parse.py:
##########
@@ -252,6 +253,163 @@ def __eq__(self, __o: object) -> bool:
         return str(self) == str(__o)
 
 
+def extract_tables_from_statement(

Review Comment:
   This function was originally a method in `ParsedQuery`; I just converted it 
into a function without any big changes (see below).



##########
superset/sql_parse.py:
##########
@@ -252,6 +253,163 @@ def __eq__(self, __o: object) -> bool:
         return str(self) == str(__o)
 
 
+def extract_tables_from_statement(
+    statement: exp.Expression,
+    dialect: Optional[Dialects],
+) -> set[Table]:
+    """
+    Extract all table references in a single statement.
+
+    Please note that this is not trivial; consider the following queries:
+
+        DESCRIBE some_table;
+        SHOW PARTITIONS FROM some_table;
+        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
+
+    See the unit tests for other tricky cases.
+    """
+    sources: Iterable[exp.Table]
+
+    if isinstance(statement, exp.Describe):
+        # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
+        # query for all tables.
+        sources = statement.find_all(exp.Table)
+    elif isinstance(statement, exp.Command):
+        # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
+        # `SELECT` statement in order to extract tables.
+        literal = statement.find(exp.Literal)
+        if not literal:
+            return set()
+
+        pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
+        sources = pseudo_query.find_all(exp.Table)
+    else:
+        sources = [
+            source
+            for scope in traverse_scope(statement)
+            for source in scope.sources.values()
+            if isinstance(source, exp.Table) and not is_cte(source, scope)
+        ]
+
+    return {
+        Table(
+            source.name,
+            source.db if source.db != "" else None,
+            source.catalog if source.catalog != "" else None,
+        )
+        for source in sources
+    }
+
+
+def is_cte(source: exp.Table, scope: Scope) -> bool:

Review Comment:
   This function was originally a method in `ParsedQuery`; I just converted it 
into a function without any big changes (see below).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: notifications-unsubscr...@superset.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscr...@superset.apache.org
For additional commands, e-mail: notifications-h...@superset.apache.org

Reply via email to