Copilot commented on code in PR #16480: URL: https://github.com/apache/pinot/pull/16480#discussion_r2246670472
########## pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/TableNameExtractor.java: ########## @@ -0,0 +1,432 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.client; + +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import javax.annotation.Nullable; +import org.apache.calcite.sql.SqlBasicCall; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlJoin; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOrderBy; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlWith; +import org.apache.calcite.sql.SqlWithItem; +import org.apache.pinot.sql.parsers.CalciteSqlParser; +import org.apache.pinot.sql.parsers.SqlNodeAndOptions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Helper class to extract table names from Calcite SqlNode tree. 
+ */ +public class TableNameExtractor { + private static final Logger LOGGER = LoggerFactory.getLogger(TableNameExtractor.class); + // Static map of reserved SQL keywords loaded from config file + private static final Map<String, Boolean> RESERVED_KEYWORDS = loadReservedKeywords(); + + /** + * Returns the name of all the tables used in a sql query. + * + * @param query The SQL query string to analyze + * @return name of all the tables used in a sql query, or null if parsing fails + */ + @Nullable + public static String[] resolveTableName(String query) { + SqlNodeAndOptions sqlNodeAndOptions; + try { + sqlNodeAndOptions = CalciteSqlParser.compileToSqlNodeAndOptions(query); + } catch (Exception e) { + LOGGER.error("Cannot parse table name from query: {}. Fallback to broker selector default.", query, e); + return null; + } + try { + Set<String> tableNames = extractTableNamesFromMultiStageQuery(sqlNodeAndOptions.getSqlNode()); + if (tableNames != null) { + return tableNames.toArray(new String[0]); + } + } catch (Exception e) { + LOGGER.error("Cannot extract table name from query: {}. Fallback to broker selector default.", query, e); + } + return null; + } + + /** + * Extracts table names from a multi-stage query using Calcite SQL AST traversal. 
+ * + * @param sqlNode The root SqlNode of the parsed query + * @return Set of table names found in the query + */ + private static Set<String> extractTableNamesFromMultiStageQuery(SqlNode sqlNode) { + TableNameExtractor extractor = new TableNameExtractor(); + try { + extractor.extractTableNames(sqlNode); + return extractor.getTableNames(); + } catch (Exception e) { + LOGGER.debug("Failed to extract table names from multi-stage query", e); + return Collections.emptySet(); + } + } + + private final Set<String> _tableNames = new HashSet<>(); + private final Set<String> _cteNames = new HashSet<>(); + private boolean _inFromClause = false; + + public Set<String> getTableNames() { + return _tableNames; + } + + public void extractTableNames(SqlNode node) { + if (node == null) { + return; + } + if (node instanceof SqlWith) { + visitWith((SqlWith) node); + } else if (node instanceof SqlOrderBy) { + visitOrderBy((SqlOrderBy) node); + } else if (node instanceof SqlWithItem) { + visitWithItem((SqlWithItem) node); + } else if (node instanceof SqlSelect) { + visitSelect((SqlSelect) node); + } else if (node instanceof SqlJoin) { + visitJoin((SqlJoin) node); + } else if (node instanceof SqlBasicCall) { + visitBasicCall((SqlBasicCall) node); + } else if (node instanceof SqlIdentifier) { + visitIdentifier((SqlIdentifier) node); + } else if (node instanceof SqlNodeList) { + visitNodeList((SqlNodeList) node); + } else { + // Handle unknown node types by trying to access operands + visitUnknownNode(node); + } + } + + private void visitWith(SqlWith with) { + // Visit the WITH list (CTE definitions) + if (with.withList != null) { + visitNodeList(with.withList); + } + // Visit the main query body + if (with.body != null) { + extractTableNames(with.body); + } + } + + private void visitOrderBy(SqlOrderBy orderBy) { + // Visit the main query - this is the most important part + if (orderBy.query != null) { + extractTableNames(orderBy.query); + } + // Visit ORDER BY expressions for potential 
subqueries + if (orderBy.orderList != null) { + // Don't set inFromClause=true for ORDER BY expressions + // as they typically contain column references, not table names + visitNodeList(orderBy.orderList); + } + // Visit OFFSET clause if it contains subqueries (rare but possible) + if (orderBy.offset != null) { + extractTableNames(orderBy.offset); + } + // Visit FETCH/LIMIT clause if it contains subqueries (rare but possible) + if (orderBy.fetch != null) { + extractTableNames(orderBy.fetch); + } + } + + private void visitWithItem(SqlWithItem withItem) { + // Track the CTE name so we don't treat it as a table later + if (withItem.name != null) { + String cteName = withItem.name.getSimple(); + _cteNames.add(cteName); + } + // Extract table names from the CTE query definition, not the CTE alias + if (withItem.query != null) { + extractTableNames(withItem.query); + } + } + + private void visitSelect(SqlSelect select) { + // Visit FROM clause - this is where we expect to find table names + if (select.getFrom() != null) { + _inFromClause = true; + extractTableNames(select.getFrom()); + _inFromClause = false; + } + // Visit other clauses for subqueries + if (select.getWhere() != null) { + extractTableNames(select.getWhere()); + } + if (select.getGroup() != null) { + visitNodeList(select.getGroup()); + } + if (select.getHaving() != null) { + extractTableNames(select.getHaving()); + } + if (select.getOrderList() != null) { + visitNodeList(select.getOrderList()); + } + if (select.getSelectList() != null) { + visitNodeList(select.getSelectList()); + } + } + + private void visitJoin(SqlJoin join) { + // Visit both sides of the join - ensure they're processed as FROM clause items + boolean wasInFromClause = _inFromClause; + if (join.getLeft() != null) { + _inFromClause = true; + extractTableNames(join.getLeft()); + } + if (join.getRight() != null) { + _inFromClause = true; + extractTableNames(join.getRight()); + } + // Visit join condition but not as part of FROM clause context 
+ // This handles potential subqueries in join conditions while avoiding + // incorrectly extracting column references as table names + if (join.getCondition() != null) { + _inFromClause = false; + extractTableNames(join.getCondition()); + } + // Restore original context + _inFromClause = wasInFromClause; + } + + private void visitBasicCall(SqlBasicCall call) { + String operatorName = call.getOperator().getName().toUpperCase(); + if (operatorName.equals("AS")) { + // Handle table aliases like "tableA AS a" + // For AS operations, the first operand is the actual table name + if (call.getOperandList().size() > 0 && call.getOperandList().get(0) != null) { + extractTableNames(call.getOperandList().get(0)); + } + } else if (operatorName.equals("WITH")) { + // Handle CTE (Common Table Expression) + visitWithClause(call); + } else if (operatorName.equals("VALUES")) { + // Handle VALUES clause - usually doesn't contain table references + // Skip this to avoid false positives + } else { + // For other basic calls, visit all operands + for (SqlNode operand : call.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } + + private void visitIdentifier(SqlIdentifier identifier) { + // Only extract table names when we're in a FROM clause + if (_inFromClause && identifier.names.size() >= 1) { + String tableName = identifier.names.get(identifier.names.size() - 1); + // Filter out SQL keywords, system identifiers, and CTE names + if (!isReservedKeyword(tableName) && !tableName.startsWith("$") && !_cteNames.contains(tableName)) { + _tableNames.add(tableName); + } + } + } + + /** + * Visit a SqlNodeList by visiting each node in the list. + */ + private void visitNodeList(SqlNodeList nodeList) { + if (nodeList != null) { + for (SqlNode node : nodeList) { + if (node != null) { + extractTableNames(node); + } + } + } + } + + /** + * Handle unknown node types by attempting to visit their operands. 
+ */ + private void visitUnknownNode(SqlNode node) { + try { + // Try to get operands list using reflection or common methods + if (node.getKind() != null) { + switch (node.getKind().name()) { + case "WITH": + visitWithClause(node); + break; + case "ORDER_BY": + visitOrderByCall(node); + break; + default: + // For other unknown nodes, try to visit operands if they exist + visitNodeOperands(node); + break; + } + } else { + visitNodeOperands(node); + } + } catch (Exception e) { + // Ignore reflection errors and continue + } + } + + /** + * Handle WITH clause (CTE - Common Table Expression). + */ + private void visitWithClause(SqlNode node) { + try { + // WITH clause typically has operands: [with_list, query] + if (node instanceof SqlBasicCall) { + SqlBasicCall withCall = (SqlBasicCall) node; + for (SqlNode operand : withCall.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } catch (Exception e) { + // Fallback to generic operand handling + visitNodeOperands(node); + } + } + + /** + * Handle ORDER BY clause - this method is now replaced by visitOrderBy(SqlOrderBy). + * Keeping for backward compatibility with visitUnknownNode. + */ + private void visitOrderByCall(SqlNode node) { + try { + if (node instanceof SqlBasicCall) { + SqlBasicCall orderByCall = (SqlBasicCall) node; + // ORDER BY typically has [query, order_list] + for (SqlNode operand : orderByCall.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } catch (Exception e) { + visitNodeOperands(node); + } + } + + /** + * Generic method to visit node operands when specific handling is not available. 
+ */ + private void visitNodeOperands(SqlNode node) { + try { + // Try to access operands through common interface + if (node instanceof SqlBasicCall) { + SqlBasicCall call = (SqlBasicCall) node; + for (SqlNode operand : call.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } catch (Exception e) { + // Nothing more we can do + } + } + + /** + * Check if the given name is a reserved SQL keyword that shouldn't be treated as a table name. + */ + private boolean isReservedKeyword(String name) { + if (name == null) { + return true; + } + String upperName = name.toUpperCase(); + return RESERVED_KEYWORDS.containsKey(upperName); + } + + /** + * Load reserved SQL keywords from the SqlParserImplConstants. + * This method uses the generated constants from the parser to get all reserved keywords. + */ + private static Map<String, Boolean> loadReservedKeywords() { + Map<String, Boolean> reservedKeywords = new HashMap<>(); + try { + // Use reflection to access SqlParserImplConstants.tokenImage + Class<?> constantsClass = Class.forName("org.apache.pinot.sql.parsers.parser.SqlParserImplConstants"); Review Comment: The hard-coded class name string makes the code fragile to refactoring. If the generated parser class is renamed or moved, nothing fails at compile time; the breakage only surfaces at runtime, when `Class.forName` throws `ClassNotFoundException`. Consider using a constant or extracting this to a configuration property to improve maintainability. Note that the suggestion below references `SQL_PARSER_IMPL_CONSTANTS_CLASS`, which must also be declared, e.g. `private static final String SQL_PARSER_IMPL_CONSTANTS_CLASS = "org.apache.pinot.sql.parsers.parser.SqlParserImplConstants";`. Alternatively, if the generated class is on this module's compile classpath, referencing `SqlParserImplConstants.class` directly would catch such a rename at compile time. ```suggestion Class<?> constantsClass = Class.forName(SQL_PARSER_IMPL_CONSTANTS_CLASS); ``` ########## pinot-clients/pinot-java-client/src/test/java/org/apache/pinot/client/TableNameExtractorTest.java: ########## @@ -0,0 +1,612 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.client; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + + +/** + * Tests for the TableNameExtractor class. + */ +public class TableNameExtractorTest { + + @Test + public void testResolveTableNameWithSingleQuery() { + // Test that single queries work correctly + String singleQuery = "SELECT * FROM myTable WHERE id > 100"; + + String[] tableNames = TableNameExtractor.resolveTableName(singleQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "myTable", "Should resolve the correct table name"); + } + + @Test + public void testResolveTableNameWithSingleStatementAlias() { + String singleStatementQuery = "SELECT stats.* FROM airlineStats stats LIMIT 10"; + String[] tableNames = TableNameExtractor.resolveTableName(singleStatementQuery); + + assertNotNull(tableNames); + assertEquals(tableNames.length, 1); + assertEquals(tableNames[0], "airlineStats"); + } + + @Test + public void testResolveTableNameWithMultiStatementQuery() { + // Test the fix for issue #11823: CalciteSQLParser error with multi-statement queries + String multiStatementQuery = "SET useMultistageEngine=true;\nSELECT stats.* FROM airlineStats stats LIMIT 10"; + + // This should not throw a 
ClassCastException anymore + String[] tableNames = TableNameExtractor.resolveTableName(multiStatementQuery); + + // Should successfully resolve the table name + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "airlineStats", "Should resolve the correct table name"); + } + + @Test + public void testResolveTableNameWithMultipleSetStatements() { + // Test with multiple SET statements + String multiSetQuery = "SET useMultistageEngine=true;\nSET timeoutMs=10000;\nSELECT * FROM testTable"; + + String[] tableNames = TableNameExtractor.resolveTableName(multiSetQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "testTable", "Should resolve the correct table name"); + } + + @Test + public void testResolveTableNameWithMultipleSetStatementsAndJoin() { + String multiStatementQuery = "SET useMultistageEngine=true;\nSET maxRowsInJoin=1000;\n" + + "SELECT stats.* FROM airlineStats stats LIMIT 10"; + String[] tableNames = TableNameExtractor.resolveTableName(multiStatementQuery); + + assertNotNull(tableNames, "Table names should be resolved for queries with multiple SET statements"); + assertEquals(tableNames.length, 1); + assertEquals(tableNames[0], "airlineStats"); + } + + @Test + public void testResolveTableNameWithJoin() { + // Test with JOIN queries + String joinQuery = "SELECT * FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id"; + + String[] tableNames = TableNameExtractor.resolveTableName(joinQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("table1"), "Should contain table1"); + assertTrue(Arrays.asList(tableNames).contains("table2"), "Should contain table2"); + } + + @Test + public void 
testResolveTableNameWithJoinQueryAndSetStatements() { + String joinQuery = "SET useMultistageEngine=true;\n" + + "SELECT a.col1, b.col2 FROM tableA a JOIN tableB b ON a.id = b.id"; + String[] tableNames = TableNameExtractor.resolveTableName(joinQuery); + + assertNotNull(tableNames, "Table names should be resolved for join queries with SET statements"); + assertEquals(tableNames.length, 2); + + Set<String> expectedTableNames = new HashSet<>(Arrays.asList("tableA", "tableB")); + Set<String> actualTableNames = new HashSet<>(Arrays.asList(tableNames)); + assertEquals(actualTableNames, expectedTableNames); + } + + @Test + public void testResolveTableNameWithExplicitAlias() { + // Test with explicit AS alias + String aliasQuery = "SELECT u.name FROM users AS u WHERE u.active = true"; + + String[] tableNames = TableNameExtractor.resolveTableName(aliasQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "users", "Should resolve the actual table name, not the alias"); + } + + @Test + public void testResolveTableNameWithImplicitAlias() { + // Test with implicit alias (no AS keyword) + String implicitAliasQuery = "SELECT o.id, u.name FROM orders o JOIN users u ON o.user_id = u.id"; + + String[] tableNames = TableNameExtractor.resolveTableName(implicitAliasQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + } + + @Test + public void testResolveTableNameWithCTE() { + // Test with Common Table Expression (CTE) + String cteQuery = "WITH active_users AS (SELECT * FROM users WHERE active = true) " + + "SELECT au.name FROM active_users au JOIN orders o ON au.id = o.user_id"; + + String[] 
tableNames = TableNameExtractor.resolveTableName(cteQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table from CTE"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + } + + @Test + public void testResolveTableNameWithNestedCTE() { + // Test with nested CTEs + String nestedCteQuery = "WITH user_orders AS (" + + " SELECT u.id, u.name, o.order_date " + + " FROM users u JOIN orders o ON u.id = o.user_id" + + "), recent_orders AS (" + + " SELECT * FROM user_orders WHERE order_date > '2023-01-01'" + + ") " + + "SELECT ro.name FROM recent_orders ro JOIN products p ON ro.id = p.user_id"; + + String[] tableNames = TableNameExtractor.resolveTableName(nestedCteQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 3, "Should resolve three tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + assertTrue(Arrays.asList(tableNames).contains("products"), "Should contain products table"); + } + + @Test + public void testResolveTableNameWithSubqueryAlias() { + // Test with subquery alias + String subqueryQuery = "SELECT t.name FROM (SELECT * FROM users WHERE active = true) AS t " + + "JOIN orders o ON t.id = o.user_id"; + + String[] tableNames = TableNameExtractor.resolveTableName(subqueryQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table from subquery"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + } + + @Test + public void 
testResolveTableNameWithComplexJoinAndAliases() { + // Test with multiple JOINs and various alias styles + String complexQuery = "SELECT u.name, o.total, p.title " + + "FROM users AS u " + + "INNER JOIN orders o ON u.id = o.user_id " + + "LEFT JOIN order_items oi ON o.id = oi.order_id " + + "RIGHT JOIN products AS p ON oi.product_id = p.id " + + "WHERE u.active = true"; + + String[] tableNames = TableNameExtractor.resolveTableName(complexQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 4, "Should resolve four tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + assertTrue(Arrays.asList(tableNames).contains("order_items"), "Should contain order_items table"); + assertTrue(Arrays.asList(tableNames).contains("products"), "Should contain products table"); + } + + @Test + public void testResolveTableNameWithJoinConditionSubquery() { + // Test with subquery in join condition + String joinSubqueryQuery = "SELECT u.name, o.total " + + "FROM users u " + + "JOIN orders o ON u.id = o.user_id " + + "AND o.id IN (SELECT order_id FROM order_items WHERE quantity > 5)"; + + String[] tableNames = TableNameExtractor.resolveTableName(joinSubqueryQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 3, "Should resolve three tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + assertTrue(Arrays.asList(tableNames).contains("order_items"), + "Should contain order_items table from subquery"); + } + + @Test + public void testResolveTableNameWithOrderBy() { + // Test with ORDER BY clause + String orderByQuery = "SELECT * FROM users ORDER BY name"; + String[] tableNames = 
TableNameExtractor.resolveTableName(orderByQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "users", "Should resolve the correct table name"); + } + + @Test + public void testResolveTableNameWithOrderBySubquery() { + // Test with subquery in ORDER BY clause (rare but possible) + String orderBySubqueryQuery = "SELECT * FROM users u ORDER BY " + + "(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id)"; + String[] tableNames = TableNameExtractor.resolveTableName(orderBySubqueryQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + } + + @Test + public void testResolveTableNameWithInvalidQuery() { + String invalidQuery = "INVALID SQL QUERY"; + String[] tableNames = TableNameExtractor.resolveTableName(invalidQuery); + + // Should return null when query cannot be parsed (fallback to default broker selector) + assertNull(tableNames); + } + + @Test + public void testResolveTableNameWithOnlySetStatements() { + String onlySetQuery = "SET useMultistageEngine=true;"; + String[] tableNames = TableNameExtractor.resolveTableName(onlySetQuery); + + // Should return null when there's no actual query statement + assertNull(tableNames); + } + + @Test + public void testResolveTableNameWithNullQuery() { + String[] tableNames = TableNameExtractor.resolveTableName(null); + + // Should return null when query is null + assertNull(tableNames); + } + + @Test + public void testResolveTableNameWithEmptyQuery() { + String[] tableNames = TableNameExtractor.resolveTableName(""); + + // Should return null when query is empty + assertNull(tableNames); + } + + /** + * Data provider for SQL queries and their 
expected table names. + * This makes it easy to add new test cases by simply adding entries to this array. + * + * @return Object array containing: [testName, sqlQuery, expectedTableNames] + */ + @org.testng.annotations.DataProvider(name = "sqlQueries") Review Comment: [nitpick] Using the fully qualified annotation name is inconsistent with how `@Test` is used in this file. Only `org.testng.annotations.Test` is currently imported, so add `import org.testng.annotations.DataProvider;` and use `@DataProvider(name = "sqlQueries")` instead, for consistency and readability. ```suggestion @DataProvider(name = "sqlQueries") ``` ########## pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/TableNameExtractor.java: ########## @@ -0,0 +1,432 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.pinot.client; + +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import javax.annotation.Nullable; +import org.apache.calcite.sql.SqlBasicCall; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlJoin; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOrderBy; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlWith; +import org.apache.calcite.sql.SqlWithItem; +import org.apache.pinot.sql.parsers.CalciteSqlParser; +import org.apache.pinot.sql.parsers.SqlNodeAndOptions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Helper class to extract table names from Calcite SqlNode tree. + */ +public class TableNameExtractor { + private static final Logger LOGGER = LoggerFactory.getLogger(TableNameExtractor.class); + // Static map of reserved SQL keywords loaded from config file + private static final Map<String, Boolean> RESERVED_KEYWORDS = loadReservedKeywords(); + + /** + * Returns the name of all the tables used in a sql query. + * + * @param query The SQL query string to analyze + * @return name of all the tables used in a sql query, or null if parsing fails + */ + @Nullable + public static String[] resolveTableName(String query) { + SqlNodeAndOptions sqlNodeAndOptions; + try { + sqlNodeAndOptions = CalciteSqlParser.compileToSqlNodeAndOptions(query); + } catch (Exception e) { + LOGGER.error("Cannot parse table name from query: {}. Fallback to broker selector default.", query, e); + return null; + } + try { + Set<String> tableNames = extractTableNamesFromMultiStageQuery(sqlNodeAndOptions.getSqlNode()); + if (tableNames != null) { + return tableNames.toArray(new String[0]); + } + } catch (Exception e) { + LOGGER.error("Cannot extract table name from query: {}. 
Fallback to broker selector default.", query, e); + } + return null; + } + + /** + * Extracts table names from a multi-stage query using Calcite SQL AST traversal. + * + * @param sqlNode The root SqlNode of the parsed query + * @return Set of table names found in the query + */ + private static Set<String> extractTableNamesFromMultiStageQuery(SqlNode sqlNode) { + TableNameExtractor extractor = new TableNameExtractor(); + try { + extractor.extractTableNames(sqlNode); + return extractor.getTableNames(); + } catch (Exception e) { + LOGGER.debug("Failed to extract table names from multi-stage query", e); + return Collections.emptySet(); + } + } + + private final Set<String> _tableNames = new HashSet<>(); + private final Set<String> _cteNames = new HashSet<>(); + private boolean _inFromClause = false; + + public Set<String> getTableNames() { + return _tableNames; + } + + public void extractTableNames(SqlNode node) { + if (node == null) { + return; + } + if (node instanceof SqlWith) { + visitWith((SqlWith) node); + } else if (node instanceof SqlOrderBy) { + visitOrderBy((SqlOrderBy) node); + } else if (node instanceof SqlWithItem) { + visitWithItem((SqlWithItem) node); + } else if (node instanceof SqlSelect) { + visitSelect((SqlSelect) node); + } else if (node instanceof SqlJoin) { + visitJoin((SqlJoin) node); + } else if (node instanceof SqlBasicCall) { + visitBasicCall((SqlBasicCall) node); + } else if (node instanceof SqlIdentifier) { + visitIdentifier((SqlIdentifier) node); + } else if (node instanceof SqlNodeList) { + visitNodeList((SqlNodeList) node); + } else { + // Handle unknown node types by trying to access operands + visitUnknownNode(node); + } + } + + private void visitWith(SqlWith with) { + // Visit the WITH list (CTE definitions) + if (with.withList != null) { + visitNodeList(with.withList); + } + // Visit the main query body + if (with.body != null) { + extractTableNames(with.body); + } + } + + private void visitOrderBy(SqlOrderBy orderBy) { + // Visit 
the main query - this is the most important part + if (orderBy.query != null) { + extractTableNames(orderBy.query); + } + // Visit ORDER BY expressions for potential subqueries + if (orderBy.orderList != null) { + // Don't set inFromClause=true for ORDER BY expressions + // as they typically contain column references, not table names + visitNodeList(orderBy.orderList); + } + // Visit OFFSET clause if it contains subqueries (rare but possible) + if (orderBy.offset != null) { + extractTableNames(orderBy.offset); + } + // Visit FETCH/LIMIT clause if it contains subqueries (rare but possible) + if (orderBy.fetch != null) { + extractTableNames(orderBy.fetch); + } + } + + private void visitWithItem(SqlWithItem withItem) { + // Track the CTE name so we don't treat it as a table later + if (withItem.name != null) { + String cteName = withItem.name.getSimple(); + _cteNames.add(cteName); + } + // Extract table names from the CTE query definition, not the CTE alias + if (withItem.query != null) { + extractTableNames(withItem.query); + } + } + + private void visitSelect(SqlSelect select) { + // Visit FROM clause - this is where we expect to find table names + if (select.getFrom() != null) { + _inFromClause = true; + extractTableNames(select.getFrom()); + _inFromClause = false; + } + // Visit other clauses for subqueries + if (select.getWhere() != null) { + extractTableNames(select.getWhere()); + } + if (select.getGroup() != null) { + visitNodeList(select.getGroup()); + } + if (select.getHaving() != null) { + extractTableNames(select.getHaving()); + } + if (select.getOrderList() != null) { + visitNodeList(select.getOrderList()); + } + if (select.getSelectList() != null) { + visitNodeList(select.getSelectList()); + } + } + + private void visitJoin(SqlJoin join) { + // Visit both sides of the join - ensure they're processed as FROM clause items + boolean wasInFromClause = _inFromClause; + if (join.getLeft() != null) { + _inFromClause = true; + extractTableNames(join.getLeft()); + 
} + if (join.getRight() != null) { + _inFromClause = true; + extractTableNames(join.getRight()); + } + // Visit join condition but not as part of FROM clause context + // This handles potential subqueries in join conditions while avoiding + // incorrectly extracting column references as table names + if (join.getCondition() != null) { + _inFromClause = false; + extractTableNames(join.getCondition()); + } + // Restore original context + _inFromClause = wasInFromClause; + } + + private void visitBasicCall(SqlBasicCall call) { + String operatorName = call.getOperator().getName().toUpperCase(); + if (operatorName.equals("AS")) { + // Handle table aliases like "tableA AS a" + // For AS operations, the first operand is the actual table name + if (call.getOperandList().size() > 0 && call.getOperandList().get(0) != null) { + extractTableNames(call.getOperandList().get(0)); + } + } else if (operatorName.equals("WITH")) { + // Handle CTE (Common Table Expression) + visitWithClause(call); + } else if (operatorName.equals("VALUES")) { + // Handle VALUES clause - usually doesn't contain table references + // Skip this to avoid false positives + } else { + // For other basic calls, visit all operands + for (SqlNode operand : call.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } + + private void visitIdentifier(SqlIdentifier identifier) { + // Only extract table names when we're in a FROM clause + if (_inFromClause && identifier.names.size() >= 1) { + String tableName = identifier.names.get(identifier.names.size() - 1); + // Filter out SQL keywords, system identifiers, and CTE names + if (!isReservedKeyword(tableName) && !tableName.startsWith("$") && !_cteNames.contains(tableName)) { + _tableNames.add(tableName); + } + } + } + + /** + * Visit a SqlNodeList by visiting each node in the list. 
+ */ + private void visitNodeList(SqlNodeList nodeList) { + if (nodeList != null) { + for (SqlNode node : nodeList) { + if (node != null) { + extractTableNames(node); + } + } + } + } + + /** + * Handle unknown node types by attempting to visit their operands. + */ + private void visitUnknownNode(SqlNode node) { + try { + // Try to get operands list using reflection or common methods + if (node.getKind() != null) { + switch (node.getKind().name()) { + case "WITH": + visitWithClause(node); + break; + case "ORDER_BY": + visitOrderByCall(node); + break; + default: + // For other unknown nodes, try to visit operands if they exist + visitNodeOperands(node); + break; + } + } else { + visitNodeOperands(node); + } + } catch (Exception e) { + // Ignore reflection errors and continue Review Comment: Catching and ignoring generic `Exception` in `visitUnknownNode` could mask important parsing errors. Consider logging these exceptions at debug level via the class's existing static `LOGGER`, or catching only the specific exceptions you expect to ignore. ```suggestion // Log reflection errors at debug level and continue LOGGER.debug("Exception encountered while visiting unknown node: {}", node, e); ``` ########## pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/TableNameExtractor.java: ########## @@ -0,0 +1,432 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.client; + +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import javax.annotation.Nullable; +import org.apache.calcite.sql.SqlBasicCall; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlJoin; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOrderBy; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlWith; +import org.apache.calcite.sql.SqlWithItem; +import org.apache.pinot.sql.parsers.CalciteSqlParser; +import org.apache.pinot.sql.parsers.SqlNodeAndOptions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Helper class to extract table names from Calcite SqlNode tree. + */ +public class TableNameExtractor { + private static final Logger LOGGER = LoggerFactory.getLogger(TableNameExtractor.class); + // Static map of reserved SQL keywords loaded from config file + private static final Map<String, Boolean> RESERVED_KEYWORDS = loadReservedKeywords(); + + /** + * Returns the name of all the tables used in a sql query. 
+ * + * @param query The SQL query string to analyze + * @return name of all the tables used in a sql query, or null if parsing fails + */ + @Nullable + public static String[] resolveTableName(String query) { + SqlNodeAndOptions sqlNodeAndOptions; + try { + sqlNodeAndOptions = CalciteSqlParser.compileToSqlNodeAndOptions(query); + } catch (Exception e) { + LOGGER.error("Cannot parse table name from query: {}. Fallback to broker selector default.", query, e); + return null; + } + try { + Set<String> tableNames = extractTableNamesFromMultiStageQuery(sqlNodeAndOptions.getSqlNode()); + if (tableNames != null) { + return tableNames.toArray(new String[0]); + } + } catch (Exception e) { + LOGGER.error("Cannot extract table name from query: {}. Fallback to broker selector default.", query, e); + } + return null; + } + + /** + * Extracts table names from a multi-stage query using Calcite SQL AST traversal. + * + * @param sqlNode The root SqlNode of the parsed query + * @return Set of table names found in the query + */ + private static Set<String> extractTableNamesFromMultiStageQuery(SqlNode sqlNode) { + TableNameExtractor extractor = new TableNameExtractor(); + try { + extractor.extractTableNames(sqlNode); + return extractor.getTableNames(); + } catch (Exception e) { + LOGGER.debug("Failed to extract table names from multi-stage query", e); + return Collections.emptySet(); + } + } + + private final Set<String> _tableNames = new HashSet<>(); + private final Set<String> _cteNames = new HashSet<>(); + private boolean _inFromClause = false; + + public Set<String> getTableNames() { + return _tableNames; + } + + public void extractTableNames(SqlNode node) { + if (node == null) { + return; + } + if (node instanceof SqlWith) { + visitWith((SqlWith) node); + } else if (node instanceof SqlOrderBy) { + visitOrderBy((SqlOrderBy) node); + } else if (node instanceof SqlWithItem) { + visitWithItem((SqlWithItem) node); + } else if (node instanceof SqlSelect) { + visitSelect((SqlSelect) 
node); + } else if (node instanceof SqlJoin) { + visitJoin((SqlJoin) node); + } else if (node instanceof SqlBasicCall) { + visitBasicCall((SqlBasicCall) node); + } else if (node instanceof SqlIdentifier) { + visitIdentifier((SqlIdentifier) node); + } else if (node instanceof SqlNodeList) { + visitNodeList((SqlNodeList) node); + } else { + // Handle unknown node types by trying to access operands + visitUnknownNode(node); + } + } + + private void visitWith(SqlWith with) { + // Visit the WITH list (CTE definitions) + if (with.withList != null) { + visitNodeList(with.withList); + } + // Visit the main query body + if (with.body != null) { + extractTableNames(with.body); + } + } + + private void visitOrderBy(SqlOrderBy orderBy) { + // Visit the main query - this is the most important part + if (orderBy.query != null) { + extractTableNames(orderBy.query); + } + // Visit ORDER BY expressions for potential subqueries + if (orderBy.orderList != null) { + // Don't set inFromClause=true for ORDER BY expressions + // as they typically contain column references, not table names + visitNodeList(orderBy.orderList); + } + // Visit OFFSET clause if it contains subqueries (rare but possible) + if (orderBy.offset != null) { + extractTableNames(orderBy.offset); + } + // Visit FETCH/LIMIT clause if it contains subqueries (rare but possible) + if (orderBy.fetch != null) { + extractTableNames(orderBy.fetch); + } + } + + private void visitWithItem(SqlWithItem withItem) { + // Track the CTE name so we don't treat it as a table later + if (withItem.name != null) { + String cteName = withItem.name.getSimple(); + _cteNames.add(cteName); + } + // Extract table names from the CTE query definition, not the CTE alias + if (withItem.query != null) { + extractTableNames(withItem.query); + } + } + + private void visitSelect(SqlSelect select) { + // Visit FROM clause - this is where we expect to find table names + if (select.getFrom() != null) { + _inFromClause = true; + 
extractTableNames(select.getFrom()); + _inFromClause = false; + } + // Visit other clauses for subqueries + if (select.getWhere() != null) { + extractTableNames(select.getWhere()); + } + if (select.getGroup() != null) { + visitNodeList(select.getGroup()); + } + if (select.getHaving() != null) { + extractTableNames(select.getHaving()); + } + if (select.getOrderList() != null) { + visitNodeList(select.getOrderList()); + } + if (select.getSelectList() != null) { + visitNodeList(select.getSelectList()); + } + } + + private void visitJoin(SqlJoin join) { + // Visit both sides of the join - ensure they're processed as FROM clause items + boolean wasInFromClause = _inFromClause; + if (join.getLeft() != null) { + _inFromClause = true; + extractTableNames(join.getLeft()); + } + if (join.getRight() != null) { + _inFromClause = true; + extractTableNames(join.getRight()); + } + // Visit join condition but not as part of FROM clause context + // This handles potential subqueries in join conditions while avoiding + // incorrectly extracting column references as table names + if (join.getCondition() != null) { + _inFromClause = false; + extractTableNames(join.getCondition()); + } + // Restore original context + _inFromClause = wasInFromClause; + } + + private void visitBasicCall(SqlBasicCall call) { + String operatorName = call.getOperator().getName().toUpperCase(); + if (operatorName.equals("AS")) { + // Handle table aliases like "tableA AS a" + // For AS operations, the first operand is the actual table name + if (call.getOperandList().size() > 0 && call.getOperandList().get(0) != null) { + extractTableNames(call.getOperandList().get(0)); + } + } else if (operatorName.equals("WITH")) { + // Handle CTE (Common Table Expression) + visitWithClause(call); + } else if (operatorName.equals("VALUES")) { + // Handle VALUES clause - usually doesn't contain table references + // Skip this to avoid false positives + } else { + // For other basic calls, visit all operands + for (SqlNode 
operand : call.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } + + private void visitIdentifier(SqlIdentifier identifier) { + // Only extract table names when we're in a FROM clause + if (_inFromClause && identifier.names.size() >= 1) { + String tableName = identifier.names.get(identifier.names.size() - 1); + // Filter out SQL keywords, system identifiers, and CTE names + if (!isReservedKeyword(tableName) && !tableName.startsWith("$") && !_cteNames.contains(tableName)) { + _tableNames.add(tableName); + } + } + } + + /** + * Visit a SqlNodeList by visiting each node in the list. + */ + private void visitNodeList(SqlNodeList nodeList) { + if (nodeList != null) { + for (SqlNode node : nodeList) { + if (node != null) { + extractTableNames(node); + } + } + } + } + + /** + * Handle unknown node types by attempting to visit their operands. + */ + private void visitUnknownNode(SqlNode node) { + try { + // Try to get operands list using reflection or common methods + if (node.getKind() != null) { + switch (node.getKind().name()) { + case "WITH": + visitWithClause(node); + break; + case "ORDER_BY": + visitOrderByCall(node); + break; + default: + // For other unknown nodes, try to visit operands if they exist + visitNodeOperands(node); + break; + } + } else { + visitNodeOperands(node); + } + } catch (Exception e) { + // Ignore reflection errors and continue + } + } + + /** + * Handle WITH clause (CTE - Common Table Expression). 
+ */ + private void visitWithClause(SqlNode node) { + try { + // WITH clause typically has operands: [with_list, query] + if (node instanceof SqlBasicCall) { + SqlBasicCall withCall = (SqlBasicCall) node; + for (SqlNode operand : withCall.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } catch (Exception e) { + // Fallback to generic operand handling + visitNodeOperands(node); + } + } + + /** + * Handle ORDER BY clause - this method is now replaced by visitOrderBy(SqlOrderBy). + * Keeping for backward compatibility with visitUnknownNode. + */ + private void visitOrderByCall(SqlNode node) { + try { + if (node instanceof SqlBasicCall) { + SqlBasicCall orderByCall = (SqlBasicCall) node; + // ORDER BY typically has [query, order_list] + for (SqlNode operand : orderByCall.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } catch (Exception e) { + visitNodeOperands(node); + } + } + + /** + * Generic method to visit node operands when specific handling is not available. + */ + private void visitNodeOperands(SqlNode node) { Review Comment: This method catches and ignores all exceptions without logging or handling specific cases. Consider adding debug logging to help with troubleshooting when AST traversal fails. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
