parthchandra commented on code in PR #21508: URL: https://github.com/apache/datafusion/pull/21508#discussion_r3061073535
########## datafusion/spark/scripts/validate_slt.py: ########## @@ -0,0 +1,1210 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Validate hardcoded expected values in .slt (sqllogictest) test files +by running the same queries against PySpark and comparing results. + +Usage: + python validate_slt.py # Run all .slt files + python validate_slt.py --path math/abs.slt # Single file + python validate_slt.py --path string/ # All files in subdirectory + python validate_slt.py --verbose # Show details + python validate_slt.py --show-skipped # Show skipped queries +""" + +import argparse +import math +import os +import re +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +# --------------------------------------------------------------------------- +# Arrow type -> Spark type mapping +# --------------------------------------------------------------------------- +ARROW_TO_SPARK_TYPE = { + "Int8": "TINYINT", + "Int16": "SMALLINT", + "Int32": "INT", + "Int64": "BIGINT", + "UInt8": "SMALLINT", + "UInt16": "INT", + "UInt32": "BIGINT", + "UInt64": "BIGINT", + "Float16": "FLOAT", + "Float32": "FLOAT", + "Float64": "DOUBLE", + "Utf8": "STRING", + "Boolean": "BOOLEAN", + "Binary": "BINARY", + "Date32": "DATE", + "Date64": "DATE", +} + +# DataFusion cast type -> Spark type mapping +DF_TO_SPARK_CAST_TYPE = { + "TINYINT": "TINYINT", + "SMALLINT": "SMALLINT", + "INT": "INT", + "INTEGER": "INT", + "BIGINT": "BIGINT", + "FLOAT": "FLOAT", + "REAL": "FLOAT", + "DOUBLE": "DOUBLE", + "STRING": "STRING", + "VARCHAR": "STRING", + "TEXT": "STRING", + "BOOLEAN": "BOOLEAN", + "BINARY": "BINARY", + "DATE": "DATE", + "TIMESTAMP": "TIMESTAMP", + # PostgreSQL-style aliases used in some .slt files + "FLOAT8": "DOUBLE", + "FLOAT4": "FLOAT", + "INT8": "BIGINT", + "INT4": "INT", + "INT2": "SMALLINT", + "BYTEA": "BINARY", +} + +# Unsupported Arrow types for Spark +UNSUPPORTED_ARROW_TYPES = { + "Utf8View", + "LargeUtf8", + "LargeBinary", + "BinaryView", +} + +# --------------------------------------------------------------------------- +# SLT record types +# --------------------------------------------------------------------------- + + +@dataclass +class QueryRecord: + """A 'query <TYPE_CODES> [rowsort]' block.""" + + type_codes: str + sql: str + expected: list[str] + rowsort: bool + line_number: int + in_ansi_block: bool = False + + +@dataclass +class ErrorRecord: + """A 'query error <pattern>' or 'statement error <pattern>' block.""" + + pattern: str + sql: str + line_number: int + kind: str = "query" # "query" or "statement" + in_ansi_block: bool = False + + +@dataclass +class StatementRecord: + """A 'statement ok' block (DDL/config).""" + + sql: str + line_number: int + in_ansi_block: bool = False + + +# --------------------------------------------------------------------------- +# 1. SLT Parser +# --------------------------------------------------------------------------- + + +def parse_slt(filepath: str) -> list: + """Parse an .slt file into a list of records.""" + with open(filepath) as f: + lines = f.readlines() + + records = [] + i = 0 + in_ansi_mode = False + + while i < len(lines): + line = lines[i].rstrip("\n") + + # Skip blank lines and comments + if not line.strip() or line.strip().startswith("#"): + i += 1 + continue + + # query error <pattern> + m = re.match(r"^query\s+error\s+(.*)", line) + if m: + pattern = m.group(1).strip() + line_num = i + 1 + i += 1 + sql_lines = [] + while i < len(lines) and lines[i].strip() and not lines[i].strip().startswith("#"): + stripped = lines[i].rstrip("\n") + if ( + re.match(r"^query\s", stripped) + or re.match(r"^statement\s", stripped) + ): + break + sql_lines.append(stripped) + i += 1 + records.append( + ErrorRecord( + pattern=pattern, + sql="\n".join(sql_lines), + line_number=line_num, + kind="query", + in_ansi_block=in_ansi_mode, + ) + ) + continue + + # statement error <pattern> + m = re.match(r"^statement\s+error\s*(.*)", line) + if m: + pattern = m.group(1).strip() + line_num = i + 1 + i += 1 + sql_lines = [] + while i < len(lines) and lines[i].strip() and not lines[i].strip().startswith("#"): + stripped = lines[i].rstrip("\n") + if ( + re.match(r"^query\s", stripped) + or re.match(r"^statement\s", stripped) + ): + break + sql_lines.append(stripped) + i += 1 + records.append( + ErrorRecord( + pattern=pattern, + sql="\n".join(sql_lines), + line_number=line_num, + kind="statement", + in_ansi_block=in_ansi_mode, + ) + ) + continue + + # statement ok + m = re.match(r"^statement\s+ok\s*$", line) + if m: + line_num = i + 1 + i += 1 + sql_lines = [] + while i < len(lines) and lines[i].strip() and not lines[i].strip().startswith("#"): + stripped = lines[i].rstrip("\n") + if ( + re.match(r"^query\s", stripped) + or re.match(r"^statement\s", stripped) + ): + break + sql_lines.append(stripped) + i += 1 + sql = "\n".join(sql_lines) + + # Track ANSI mode from statements + if re.search( + r"set\s+datafusion\.execution\.enable_ansi_mode\s*=\s*true", + sql, + re.IGNORECASE, + ): + in_ansi_mode = True + elif re.search( + r"set\s+datafusion\.execution\.enable_ansi_mode\s*=\s*false", + sql, + re.IGNORECASE, + ): + in_ansi_mode = False + + records.append( + StatementRecord( + sql=sql, line_number=line_num, in_ansi_block=in_ansi_mode + ) + ) + continue + + # query <TYPE_CODES> [rowsort] + m = re.match(r"^query\s+(\S+)(\s+rowsort)?\s*$", line) + if m: + type_codes = m.group(1) + rowsort = m.group(2) is not None + line_num = i + 1 + i += 1 + + # Collect SQL lines until ---- + sql_lines = [] + while i < len(lines) and lines[i].rstrip("\n") != "----": + sql_lines.append(lines[i].rstrip("\n")) + i += 1 + + # Skip the ---- separator + if i < len(lines) and lines[i].rstrip("\n") == "----": + i += 1 + + # Collect expected result lines until blank line or next record. + # Note: do NOT treat # as a comment here — result values can + # start with # (e.g., soundex('#') -> '#'). + expected = [] + while i < len(lines): + result_line = lines[i].rstrip("\n") + if result_line == "": + i += 1 + break + if re.match(r"^(query|statement)\s", result_line): + break + # A ## comment line in the results section signals end of results + if result_line.startswith("##"): + break + expected.append(result_line) + i += 1 + + records.append( + QueryRecord( + type_codes=type_codes, + sql="\n".join(sql_lines), + expected=expected, + rowsort=rowsort, + line_number=line_num, + in_ansi_block=in_ansi_mode, + ) + ) + continue + + # Unknown line, skip + i += 1 + + return records + + +# --------------------------------------------------------------------------- +# 2. SQL Translator (DataFusion -> PySpark) +# --------------------------------------------------------------------------- + + +def _translate_cast_type(df_type: str) -> Optional[str]: + """Map a DataFusion type name to a Spark SQL type name.""" + upper = df_type.upper().strip() + # Direct match + if upper in DF_TO_SPARK_CAST_TYPE: + return DF_TO_SPARK_CAST_TYPE[upper] + # DECIMAL(p, s) - skip if precision > 38 (Spark max) + if upper.startswith("DECIMAL"): + m = re.match(r"DECIMAL\(\s*(\d+)", upper) + if m and int(m.group(1)) > 38: + raise _SkipQuery(f"Decimal precision {m.group(1)} exceeds Spark max of 38") + return df_type # pass through + return upper # pass through and hope for the best + + +class _SkipQuery(Exception): + """Signal that a query should be skipped.""" + + pass + + +def _replace_arrow_cast_nested(sql: str) -> str: + """Replace arrow_cast(...) handling nested parentheses properly.""" + result = [] + i = 0 + while i < len(sql): + if sql[i:].startswith("arrow_cast("): + start = i + i += len("arrow_cast(") + depth = 1 + inner_start = i + while i < len(sql) and depth > 0: + if sql[i] == "(": + depth += 1 + elif sql[i] == ")": + depth -= 1 + i += 1 + inner = sql[inner_start : i - 1] + + # Find the last top-level comma + depth = 0 + last_comma = -1 + for idx, ch in enumerate(inner): + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + elif ch == "," and depth == 0: Review Comment: what if the cast string contains a comma? ########## datafusion/spark/scripts/validate_slt.py: ########## @@ -0,0 +1,1210 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Validate hardcoded expected values in .slt (sqllogictest) test files +by running the same queries against PySpark and comparing results. + +Usage: + python validate_slt.py # Run all .slt files + python validate_slt.py --path math/abs.slt # Single file + python validate_slt.py --path string/ # All files in subdirectory + python validate_slt.py --verbose # Show details + python validate_slt.py --show-skipped # Show skipped queries +""" + +import argparse +import math +import os +import re +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +# --------------------------------------------------------------------------- +# Arrow type -> Spark type mapping +# --------------------------------------------------------------------------- +ARROW_TO_SPARK_TYPE = { + "Int8": "TINYINT", + "Int16": "SMALLINT", + "Int32": "INT", + "Int64": "BIGINT", + "UInt8": "SMALLINT", + "UInt16": "INT", + "UInt32": "BIGINT", + "UInt64": "BIGINT", + "Float16": "FLOAT", + "Float32": "FLOAT", + "Float64": "DOUBLE", + "Utf8": "STRING", + "Boolean": "BOOLEAN", + "Binary": "BINARY", + "Date32": "DATE", + "Date64": "DATE", +} + +# DataFusion cast type -> Spark type mapping +DF_TO_SPARK_CAST_TYPE = { + "TINYINT": "TINYINT", + "SMALLINT": "SMALLINT", + "INT": "INT", + "INTEGER": "INT", + "BIGINT": "BIGINT", + "FLOAT": "FLOAT", + "REAL": "FLOAT", + "DOUBLE": "DOUBLE", + "STRING": "STRING", + "VARCHAR": "STRING", + "TEXT": "STRING", + "BOOLEAN": "BOOLEAN", + "BINARY": "BINARY", + "DATE": "DATE", + "TIMESTAMP": "TIMESTAMP", + # PostgreSQL-style aliases used in some .slt files + "FLOAT8": "DOUBLE", + "FLOAT4": "FLOAT", + "INT8": "BIGINT", + "INT4": "INT", + "INT2": "SMALLINT", + "BYTEA": "BINARY", +} + +# Unsupported Arrow types for Spark +UNSUPPORTED_ARROW_TYPES = { + "Utf8View", + "LargeUtf8", + "LargeBinary", + "BinaryView", +} + +# --------------------------------------------------------------------------- +# SLT record types +# --------------------------------------------------------------------------- + + +@dataclass +class QueryRecord: + """A 'query <TYPE_CODES> [rowsort]' block.""" + + type_codes: str + sql: str + expected: list[str] + rowsort: bool + line_number: int + in_ansi_block: bool = False + + +@dataclass +class ErrorRecord: + """A 'query error <pattern>' or 'statement error <pattern>' block.""" + + pattern: str + sql: str + line_number: int + kind: str = "query" # "query" or "statement" + in_ansi_block: bool = False + + +@dataclass +class StatementRecord: + """A 'statement ok' block (DDL/config).""" + + sql: str + line_number: int + in_ansi_block: bool = False + + +# --------------------------------------------------------------------------- +# 1. SLT Parser +# --------------------------------------------------------------------------- + + +def parse_slt(filepath: str) -> list: + """Parse an .slt file into a list of records.""" + with open(filepath) as f: + lines = f.readlines() + + records = [] + i = 0 + in_ansi_mode = False + + while i < len(lines): + line = lines[i].rstrip("\n") + + # Skip blank lines and comments + if not line.strip() or line.strip().startswith("#"): + i += 1 + continue + + # query error <pattern> + m = re.match(r"^query\s+error\s+(.*)", line) + if m: + pattern = m.group(1).strip() + line_num = i + 1 + i += 1 + sql_lines = [] + while i < len(lines) and lines[i].strip() and not lines[i].strip().startswith("#"): + stripped = lines[i].rstrip("\n") + if ( + re.match(r"^query\s", stripped) + or re.match(r"^statement\s", stripped) + ): + break + sql_lines.append(stripped) + i += 1 + records.append( + ErrorRecord( + pattern=pattern, + sql="\n".join(sql_lines), + line_number=line_num, + kind="query", + in_ansi_block=in_ansi_mode, + ) + ) + continue + + # statement error <pattern> + m = re.match(r"^statement\s+error\s*(.*)", line) + if m: + pattern = m.group(1).strip() + line_num = i + 1 + i += 1 + sql_lines = [] + while i < len(lines) and lines[i].strip() and not lines[i].strip().startswith("#"): + stripped = lines[i].rstrip("\n") + if ( + re.match(r"^query\s", stripped) + or re.match(r"^statement\s", stripped) + ): + break + sql_lines.append(stripped) + i += 1 + records.append( + ErrorRecord( + pattern=pattern, + sql="\n".join(sql_lines), + line_number=line_num, + kind="statement", + in_ansi_block=in_ansi_mode, + ) + ) + continue + + # statement ok + m = re.match(r"^statement\s+ok\s*$", line) + if m: + line_num = i + 1 + i += 1 + sql_lines = [] + while i < len(lines) and lines[i].strip() and not lines[i].strip().startswith("#"): + stripped = lines[i].rstrip("\n") + if ( + re.match(r"^query\s", stripped) + or re.match(r"^statement\s", stripped) + ): + break + sql_lines.append(stripped) + i += 1 + sql = "\n".join(sql_lines) + + # Track ANSI mode from statements + if re.search( + r"set\s+datafusion\.execution\.enable_ansi_mode\s*=\s*true", + sql, + re.IGNORECASE, + ): + in_ansi_mode = True + elif re.search( + r"set\s+datafusion\.execution\.enable_ansi_mode\s*=\s*false", + sql, + re.IGNORECASE, + ): + in_ansi_mode = False + + records.append( + StatementRecord( + sql=sql, line_number=line_num, in_ansi_block=in_ansi_mode + ) + ) + continue + + # query <TYPE_CODES> [rowsort] + m = re.match(r"^query\s+(\S+)(\s+rowsort)?\s*$", line) + if m: + type_codes = m.group(1) + rowsort = m.group(2) is not None + line_num = i + 1 + i += 1 + + # Collect SQL lines until ---- + sql_lines = [] + while i < len(lines) and lines[i].rstrip("\n") != "----": + sql_lines.append(lines[i].rstrip("\n")) + i += 1 + + # Skip the ---- separator + if i < len(lines) and lines[i].rstrip("\n") == "----": + i += 1 + + # Collect expected result lines until blank line or next record. + # Note: do NOT treat # as a comment here — result values can + # start with # (e.g., soundex('#') -> '#'). + expected = [] + while i < len(lines): + result_line = lines[i].rstrip("\n") + if result_line == "": + i += 1 + break + if re.match(r"^(query|statement)\s", result_line): + break + # A ## comment line in the results section signals end of results + if result_line.startswith("##"): + break + expected.append(result_line) + i += 1 + + records.append( + QueryRecord( + type_codes=type_codes, + sql="\n".join(sql_lines), + expected=expected, + rowsort=rowsort, + line_number=line_num, + in_ansi_block=in_ansi_mode, + ) + ) + continue + + # Unknown line, skip + i += 1 + + return records + + +# --------------------------------------------------------------------------- +# 2. SQL Translator (DataFusion -> PySpark) +# --------------------------------------------------------------------------- + + +def _translate_cast_type(df_type: str) -> Optional[str]: + """Map a DataFusion type name to a Spark SQL type name.""" + upper = df_type.upper().strip() + # Direct match + if upper in DF_TO_SPARK_CAST_TYPE: + return DF_TO_SPARK_CAST_TYPE[upper] + # DECIMAL(p, s) - skip if precision > 38 (Spark max) + if upper.startswith("DECIMAL"): + m = re.match(r"DECIMAL\(\s*(\d+)", upper) + if m and int(m.group(1)) > 38: + raise _SkipQuery(f"Decimal precision {m.group(1)} exceeds Spark max of 38") + return df_type # pass through + return upper # pass through and hope for the best Review Comment: Yeah, this is a bit dicey. So an unsupported type would create a failed test (implying a bug instead of being unsupported)? ########## datafusion/spark/scripts/validate_slt.py: ########## @@ -0,0 +1,1210 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Validate hardcoded expected values in .slt (sqllogictest) test files +by running the same queries against PySpark and comparing results. + +Usage: + python validate_slt.py # Run all .slt files + python validate_slt.py --path math/abs.slt # Single file + python validate_slt.py --path string/ # All files in subdirectory + python validate_slt.py --verbose # Show details + python validate_slt.py --show-skipped # Show skipped queries +""" + +import argparse +import math +import os +import re +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +# --------------------------------------------------------------------------- +# Arrow type -> Spark type mapping +# --------------------------------------------------------------------------- +ARROW_TO_SPARK_TYPE = { + "Int8": "TINYINT", + "Int16": "SMALLINT", + "Int32": "INT", + "Int64": "BIGINT", + "UInt8": "SMALLINT", + "UInt16": "INT", + "UInt32": "BIGINT", + "UInt64": "BIGINT", + "Float16": "FLOAT", + "Float32": "FLOAT", + "Float64": "DOUBLE", + "Utf8": "STRING", + "Boolean": "BOOLEAN", + "Binary": "BINARY", + "Date32": "DATE", + "Date64": "DATE", +} + +# DataFusion cast type -> Spark type mapping +DF_TO_SPARK_CAST_TYPE = { + "TINYINT": "TINYINT", + "SMALLINT": "SMALLINT", + "INT": "INT", + "INTEGER": "INT", + "BIGINT": "BIGINT", + "FLOAT": "FLOAT", + "REAL": "FLOAT", + "DOUBLE": "DOUBLE", + "STRING": "STRING", + "VARCHAR": "STRING", + "TEXT": "STRING", + "BOOLEAN": "BOOLEAN", + "BINARY": "BINARY", + "DATE": "DATE", + "TIMESTAMP": "TIMESTAMP", + # PostgreSQL-style aliases used in some .slt files + "FLOAT8": "DOUBLE", + "FLOAT4": "FLOAT", + "INT8": "BIGINT", + "INT4": "INT", + "INT2": "SMALLINT", + "BYTEA": "BINARY", +} + +# Unsupported Arrow types for Spark +UNSUPPORTED_ARROW_TYPES = { + "Utf8View", + "LargeUtf8", + "LargeBinary", + "BinaryView", +} + +# --------------------------------------------------------------------------- +# SLT record types +# --------------------------------------------------------------------------- + + +@dataclass +class QueryRecord: + """A 'query <TYPE_CODES> [rowsort]' block.""" + + type_codes: str + sql: str + expected: list[str] + rowsort: bool + line_number: int + in_ansi_block: bool = False + + +@dataclass +class ErrorRecord: + """A 'query error <pattern>' or 'statement error <pattern>' block.""" + + pattern: str + sql: str + line_number: int + kind: str = "query" # "query" or "statement" + in_ansi_block: bool = False + + +@dataclass +class StatementRecord: + """A 'statement ok' block (DDL/config).""" + + sql: str + line_number: int + in_ansi_block: bool = False + + +# --------------------------------------------------------------------------- +# 1. SLT Parser +# --------------------------------------------------------------------------- + + +def parse_slt(filepath: str) -> list: + """Parse an .slt file into a list of records.""" + with open(filepath) as f: + lines = f.readlines() + + records = [] + i = 0 + in_ansi_mode = False + + while i < len(lines): + line = lines[i].rstrip("\n") + + # Skip blank lines and comments + if not line.strip() or line.strip().startswith("#"): + i += 1 + continue + + # query error <pattern> + m = re.match(r"^query\s+error\s+(.*)", line) + if m: + pattern = m.group(1).strip() + line_num = i + 1 + i += 1 + sql_lines = [] + while i < len(lines) and lines[i].strip() and not lines[i].strip().startswith("#"): + stripped = lines[i].rstrip("\n") + if ( + re.match(r"^query\s", stripped) + or re.match(r"^statement\s", stripped) + ): + break + sql_lines.append(stripped) + i += 1 + records.append( + ErrorRecord( + pattern=pattern, + sql="\n".join(sql_lines), + line_number=line_num, + kind="query", + in_ansi_block=in_ansi_mode, + ) + ) + continue + + # statement error <pattern> + m = re.match(r"^statement\s+error\s*(.*)", line) + if m: + pattern = m.group(1).strip() + line_num = i + 1 + i += 1 + sql_lines = [] + while i < len(lines) and lines[i].strip() and not lines[i].strip().startswith("#"): + stripped = lines[i].rstrip("\n") + if ( + re.match(r"^query\s", stripped) + or re.match(r"^statement\s", stripped) + ): + break + sql_lines.append(stripped) + i += 1 + records.append( + ErrorRecord( + pattern=pattern, + sql="\n".join(sql_lines), + line_number=line_num, + kind="statement", + in_ansi_block=in_ansi_mode, + ) + ) + continue + + # statement ok + m = re.match(r"^statement\s+ok\s*$", line) + if m: + line_num = i + 1 + i += 1 + sql_lines = [] + while i < len(lines) and lines[i].strip() and not lines[i].strip().startswith("#"): + stripped = lines[i].rstrip("\n") + if ( + re.match(r"^query\s", stripped) + or re.match(r"^statement\s", stripped) + ): + break + sql_lines.append(stripped) + i += 1 + sql = "\n".join(sql_lines) + + # Track ANSI mode from statements + if re.search( + r"set\s+datafusion\.execution\.enable_ansi_mode\s*=\s*true", + sql, + re.IGNORECASE, + ): + in_ansi_mode = True + elif re.search( + r"set\s+datafusion\.execution\.enable_ansi_mode\s*=\s*false", + sql, + re.IGNORECASE, + ): + in_ansi_mode = False + + records.append( + StatementRecord( + sql=sql, line_number=line_num, in_ansi_block=in_ansi_mode + ) + ) + continue + + # query <TYPE_CODES> [rowsort] + m = re.match(r"^query\s+(\S+)(\s+rowsort)?\s*$", line) + if m: + type_codes = m.group(1) + rowsort = m.group(2) is not None + line_num = i + 1 + i += 1 + + # Collect SQL lines until ---- + sql_lines = [] + while i < len(lines) and lines[i].rstrip("\n") != "----": + sql_lines.append(lines[i].rstrip("\n")) + i += 1 + + # Skip the ---- separator + if i < len(lines) and lines[i].rstrip("\n") == "----": + i += 1 + + # Collect expected result lines until blank line or next record. + # Note: do NOT treat # as a comment here — result values can + # start with # (e.g., soundex('#') -> '#'). + expected = [] + while i < len(lines): + result_line = lines[i].rstrip("\n") + if result_line == "": + i += 1 + break + if re.match(r"^(query|statement)\s", result_line): + break + # A ## comment line in the results section signals end of results + if result_line.startswith("##"): + break + expected.append(result_line) + i += 1 + + records.append( + QueryRecord( + type_codes=type_codes, + sql="\n".join(sql_lines), + expected=expected, + rowsort=rowsort, + line_number=line_num, + in_ansi_block=in_ansi_mode, + ) + ) + continue + + # Unknown line, skip + i += 1 + + return records + + +# --------------------------------------------------------------------------- +# 2. SQL Translator (DataFusion -> PySpark) +# --------------------------------------------------------------------------- + + +def _translate_cast_type(df_type: str) -> Optional[str]: + """Map a DataFusion type name to a Spark SQL type name.""" + upper = df_type.upper().strip() + # Direct match + if upper in DF_TO_SPARK_CAST_TYPE: + return DF_TO_SPARK_CAST_TYPE[upper] + # DECIMAL(p, s) - skip if precision > 38 (Spark max) + if upper.startswith("DECIMAL"): + m = re.match(r"DECIMAL\(\s*(\d+)", upper) + if m and int(m.group(1)) > 38: + raise _SkipQuery(f"Decimal precision {m.group(1)} exceeds Spark max of 38") + return df_type # pass through + return upper # pass through and hope for the best + + +class _SkipQuery(Exception): + """Signal that a query should be skipped.""" + + pass + + +def _replace_arrow_cast_nested(sql: str) -> str: + """Replace arrow_cast(...) handling nested parentheses properly.""" + result = [] + i = 0 + while i < len(sql): + if sql[i:].startswith("arrow_cast("): + start = i + i += len("arrow_cast(") + depth = 1 + inner_start = i + while i < len(sql) and depth > 0: + if sql[i] == "(": + depth += 1 + elif sql[i] == ")": + depth -= 1 + i += 1 + inner = sql[inner_start : i - 1] + + # Find the last top-level comma + depth = 0 + last_comma = -1 + for idx, ch in enumerate(inner): + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + elif ch == "," and depth == 0: + last_comma = idx + + if last_comma == -1: + result.append(sql[start:i]) + continue + + expr = inner[:last_comma].strip() + arrow_type_raw = inner[last_comma + 1 :].strip().strip("'\"") + + if arrow_type_raw in UNSUPPORTED_ARROW_TYPES: + raise _SkipQuery(f"unsupported Arrow type: {arrow_type_raw}") + if arrow_type_raw.startswith("Dictionary("): + raise _SkipQuery(f"unsupported Arrow type: {arrow_type_raw}") + if arrow_type_raw.startswith(("LargeList(", "FixedSizeList(")): + raise _SkipQuery(f"unsupported Arrow type: {arrow_type_raw}") + if arrow_type_raw.startswith("List("): + # List(X) -> ARRAY<spark_type> - skip for now + raise _SkipQuery(f"unsupported Arrow type: {arrow_type_raw}") + + spark_type = ARROW_TO_SPARK_TYPE.get(arrow_type_raw) + if spark_type is None: + raise _SkipQuery(f"unmapped Arrow type: {arrow_type_raw}") + + result.append(f"CAST({expr} AS {spark_type})") + else: + result.append(sql[i]) + i += 1 + + return "".join(result) + + +def _translate_casts(sql: str) -> str: + """Translate DataFusion :: cast syntax to Spark CAST() syntax.""" + # Order matters: most specific patterns first + + result = sql + changed = True + while changed: + changed = False + + # 1. Parenthesized expressions: (expr)::TYPE + # Walk through looking for ):: and then find matching ( + i = 0 + while i < len(result): + if result[i] == ")" and result[i + 1 : i + 3] == "::": + # Walk backwards to find matching ( + depth = 0 + j = i + while j >= 0: + if result[j] == ")": + depth += 1 + elif result[j] == "(": + depth -= 1 + if depth == 0: + break + j -= 1 + if j >= 0 and depth == 0: + # Check if ( is preceded by a function name + func_start = j + while func_start > 0 and (result[func_start - 1].isalnum() or result[func_start - 1] == "_"): + func_start -= 1 + + paren_expr = result[j + 1 : i] + # Extract type after :: + type_start = i + 3 + type_end = type_start + while type_end < len(result) and ( + result[type_end].isalnum() or result[type_end] == "_" + ): + type_end += 1 + # Check for DECIMAL(p,s) style + if type_end < len(result) and result[type_end] == "(": + paren_depth = 1 + type_end += 1 + while type_end < len(result) and paren_depth > 0: + if result[type_end] == "(": + paren_depth += 1 + elif result[type_end] == ")": + paren_depth -= 1 + type_end += 1 + + cast_type = result[i + 3 : type_end] + spark_type = _translate_cast_type(cast_type) + + if func_start < j: + # func(...)::TYPE -> CAST(func(...) AS TYPE) + func_call = result[func_start : i + 1] + replacement = f"CAST({func_call} AS {spark_type})" + result = result[:func_start] + replacement + result[type_end:] + else: + # (expr)::TYPE -> CAST(expr AS TYPE) + replacement = f"CAST({paren_expr} AS {spark_type})" + result = result[:j] + replacement + result[type_end:] + changed = True + break + i += 1 + if changed: + continue + + # 2. String literals: 'val'::TYPE + m = re.search(r"'([^']*)'::(\w+(?:\([^)]*\))?)", result) Review Comment: Will this handle a literal with an escaped quote? for instance - `'Andy''s'` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
