betodealmeida commented on code in PR #26767:
URL: https://github.com/apache/superset/pull/26767#discussion_r1482116701


##########
superset/sql_parse.py:
##########
@@ -252,6 +253,182 @@ def __eq__(self, __o: object) -> bool:
         return str(self) == str(__o)
 
 
+def extract_tables_from_statement(
+    statement: exp.Expression,
+    dialect: Optional[Dialects],
+) -> set[Table]:
+    """
+    Extract all table references in a single statement.
+
+    Please not that this is not trivial; consider the following queries:
+
+        DESCRIBE some_table;
+        SHOW PARTITIONS FROM some_table;
+        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM 
masked_name;
+
+    See the unit tests for other tricky cases.
+    """
+    sources: Iterable[exp.Table]
+
+    if isinstance(statement, exp.Describe):
+        # A `DESCRIBE` query has no sources in sqlglot, so we need to 
explicitly
+        # query for all tables.
+        sources = statement.find_all(exp.Table)
+    elif isinstance(statement, exp.Command):
+        # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
+        # `SELECT` statetement in order to extract tables.
+        literal = statement.find(exp.Literal)
+        if not literal:
+            return set()
+
+        pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
+        sources = pseudo_query.find_all(exp.Table)
+    else:
+        sources = [
+            source
+            for scope in traverse_scope(statement)
+            for source in scope.sources.values()
+            if isinstance(source, exp.Table) and not is_cte(source, scope)
+        ]
+
+    return {
+        Table(
+            source.name,
+            source.db if source.db != "" else None,
+            source.catalog if source.catalog != "" else None,
+        )
+        for source in sources
+    }
+
+
+def is_cte(source: exp.Table, scope: Scope) -> bool:
+    """
+    Is the source a CTE?
+
+    CTEs in the parent scope look like tables (and are represented by
+    exp.Table objects), but should not be considered as such;
+    otherwise a user with access to table `foo` could access any table
+    with a query like this:
+
+        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
+
+    """
+    parent_sources = scope.parent.sources if scope.parent else {}
+    ctes_in_scope = {
+        name
+        for name, parent_scope in parent_sources.items()
+        if isinstance(parent_scope, Scope) and parent_scope.scope_type == 
ScopeType.CTE
+    }
+
+    return source.name in ctes_in_scope
+
+
+class SQLQuery:
+    """
+    A SQL query, with 0+ statements.
+    """
+
+    def __init__(
+        self,
+        query: str,
+        engine: Optional[str] = None,
+    ):
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+        self.statements = [
+            SQLStatement(statement, engine=engine)
+            for statement in parse(query, dialect=dialect)
+            if statement
+        ]
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL query.
+        """
+        return ";\n".join(statement.format(comments) for statement in 
self.statements)
+
+    def get_settings(self) -> dict[str, str]:
+        """
+        Return the settings for the SQL query.
+
+            >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'")
+            >>> statement.get_settings()
+            {"foo": "'baz'"}
+
+        """
+        settings: dict[str, str] = {}
+        for statement in self.statements:
+            settings.update(statement.get_settings())
+
+        return settings
+
+
+class SQLStatement:
+    """
+    A SQL statement.
+
+    This class provides helper methods to manipulate and introspect SQL.
+    """
+
+    def __init__(
+        self,
+        statement: Union[str, exp.Expression],
+        engine: Optional[str] = None,
+    ):
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+        if isinstance(statement, str):
+            try:
+                self._parsed = self._parse_statement(statement, dialect)
+            except ParseError as ex:
+                raise SupersetParseError(statement, engine) from ex
+        else:
+            self._parsed = statement
+
+        self._dialect = dialect
+        self.tables = extract_tables_from_statement(self._parsed, dialect)
+
+    @staticmethod
+    def _parse_statement(
+        sql_statement: str,
+        dialect: Optional[Dialects],
+    ) -> exp.Expression:
+        """
+        Parse a single SQL statement.
+        """
+        statements = [
+            statement
+            for statement in sqlglot.parse(sql_statement, dialect=dialect)
+            if statement
+        ]
+        if len(statements) != 1:
+            raise ValueError("SQLStatement should have exactly one statement")
+
+        return statements[0]
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL statement.
+        """
+        write = Dialect.get_or_raise(self._dialect)
+        return write.generate(self._parsed, copy=False, comments=comments, 
pretty=True)
+
+    def get_settings(self) -> dict[str, str]:

Review Comment:
   This method is meant for statements that set parameters using the
`SET foo = 'bar'` syntax, which seems to be more universal than `USE ...`. We
have a few DB engine specs that need to know if `search_path` was set to
something different for security reasons, which is why I added this method.
And while today it's only used to check `search_path`, I used a more generic
name because in the future it could be leveraged in different ways.
   
   I agree that there's functionality here that is DB-specific, but I'm hoping 
to address that by always requiring the dialect to be specified together with 
the SQL when instantiating the `SQLStatement`, so we can leave the heavy work 
to `sqlglot`. If we need custom functionality we can always subclass the 
`SQLStatement` for specific dialects, though hopefully we'd be able to avoid 
that by implementing the dialect upstream in `sqlglot` directly.
   
   The main reason for having DB-specific handling here and not in the DB 
engine specs is that I'd like all SQL parsing to be done only by these new 
classes, so we can have a cleaner interface. So I actually want to remove the 
existing custom DB engine spec methods related to SQL parsing, since the DB 
engine specs already do too much.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: notifications-unsubscr...@superset.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscr...@superset.apache.org
For additional commands, e-mail: notifications-h...@superset.apache.org

Reply via email to