shehabgamin commented on PR #16409:
URL: https://github.com/apache/datafusion/pull/16409#issuecomment-2972618052
Not sure if it makes sense to commit the script I used, so I'll paste it
here for now:
```
"""
WARNING:
- This script extracts only basic, straightforward tests.
- It is not comprehensive and will not capture most function tests.
- Intended as a quick-and-dirty tool for generating minimal Spark
function tests.
- Run this script from the root directory of the Sail project.
"""
import glob
import json
import os
import re
from pyspark.sql import SparkSession
# From project root in Sail: https://github.com/lakehq/sail
FUNCTIONS_PATH = "crates/sail-spark-connect/tests/gold_data/function/"
def extract_simple_function_arguments(query):
"""
Extract arguments from simple function calls of pattern:
SELECT SOME_FUNC(ARG0, ARG1, .... ARGN);
Only accepts basic literal arguments - no arrays, nested functions, etc.
Example queries NOT accepted:
- query = "SELECT any(col) FROM VALUES (NULL), (true), (false) AS
tab(col);"
- query = "SELECT array_append(CAST(null as Array<Int>), 2);"
- query = "SELECT array_append(array('b', 'd', 'c', 'a'), 'd');"
- query = "SELECT * FROM explode(collection => array(10, 20));"
- query = "SELECT cast('10' as int);"
Example queries accepted:
- query = "SELECT ceil(5);"
- query = "SELECT ceil(3.1411, -3);"
- query = "SELECT now();"
"""
if any(f in query.lower() for f in ["cast", "map", "from",
"raise_error", "regexp", "rlike", " in "]):
return None
pattern = r'SELECT\s+\w+\s*\(([^)]*)\)\s*;\s*'
match = re.search(pattern, query, re.IGNORECASE | re.DOTALL)
if not match:
return None
args_string = match.group(1).strip()
if not args_string: # Empty function call
return []
# Filter out complex arguments - reject if contains brackets, parens,
etc...
if any(char in args_string for char in ['[', ']', '(', ')']):
return None
arguments = re.split(r',(?=(?:[^"\']*["\'][^"\']*["\'])*[^"\']*$)',
args_string)
arguments = [arg.strip() for arg in arguments if arg.strip()]
return arguments
def extract_function_name(query):
pattern = r'SELECT\s+(\w+)\s*\(([^)]*)\)\s*;\s*'
match = re.search(pattern, query, re.IGNORECASE | re.DOTALL)
if match:
return match.group(1).strip()
return None
def create_typed_query(query, func_name, type_results):
if not type_results:
return query
typed_args = []
for key, spark_type in type_results.items():
if key.startswith('typeof(') and key.endswith(')'):
arg = key[7:-1]
typed_args.append(f"{arg}::{spark_type}")
typed_query = f"SELECT {func_name}({', '.join(typed_args)});"
return [f"# Original Query: {query}", f"# PySpark 3.5.5 Result:
{type_results}", typed_query]
def main():
spark = SparkSession.builder.getOrCreate()
function_dict = {}
json_files = glob.glob(os.path.join(FUNCTIONS_PATH, "*.json"))
num_queries = 0
for file_path in json_files:
with open(file_path, "r") as f:
data = json.load(f)
directory_name = os.path.basename(file_path).removesuffix('.json')
if directory_name not in function_dict:
function_dict[directory_name] = {}
for test in data["tests"]:
if len(test["input"]["schema"]["fields"]) != 1:
# Skip generator tests with multiple fields
continue
if "exception" in test:
# Skip tests that are expected to raise exceptions
continue
query = test["input"]["query"].strip()
arguments = extract_simple_function_arguments(query)
if arguments is not None:
func_name = extract_function_name(query)
if func_name is not None:
func_call = re.sub('select', '', query,
flags=re.IGNORECASE).strip().rstrip(';').strip()
if arguments:
typeof_parts = [f"typeof({arg})" for arg in
arguments]
combined_query = f"SELECT {func_call},
typeof({func_call}), {', '.join(typeof_parts)};"
else:
combined_query = f"SELECT {func_call},
typeof({func_call});"
print(f"ORIGINAL QUERY: {query}\nRUNNING QUERY:
{combined_query}")
try:
result = spark.sql(combined_query).collect()
except Exception as e:
if "CANNOT_PARSE_DATATYPE" in str(e):
print(f"Skipping query due to unsupported
datatype: {e}")
continue
else:
raise
if len(result) != 1:
spark.stop()
raise ValueError(f"Unexpected result length:
{len(result)} for query: {combined_query}")
result_row = result[0]
type_results = {}
for i in range(2, len(result_row)):
col_name = result_row.__fields__[i]
type_results[col_name.lower()] = result_row[i]
typed_query = create_typed_query(query, func_name,
type_results)
if func_name.lower() not in
function_dict[directory_name]:
function_dict[directory_name][func_name.lower()] = []
function_dict[directory_name][func_name.lower()].append(typed_query)
num_queries += 1
print(f"Processed {num_queries} queries from {len(json_files)} JSON
files.")
base_dir = os.path.join("tmp", "slt")
for directory, functions in function_dict.items():
dir_path = os.path.join(base_dir, directory)
os.makedirs(dir_path, exist_ok=True)
for func_name, queries in functions.items():
file_path = os.path.join(dir_path, f"{func_name}.slt")
with open(file_path, 'w') as f:
for query_data in queries:
f.write(f"#{query_data[0]}\n")
f.write(f"#{query_data[1]}\n")
f.write("#query\n")
f.write(f"#{query_data[2]}\n")
f.write("\n")
spark.stop()
return function_dict
if __name__ == "__main__":
main()
```
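For reference, here is the shape of the output. Assuming the gold data includes a test for `SELECT ceil(5);` and that this Spark build (3.5.5) reports `typeof(5)` as `int` (the exact type strings depend on the Spark version), the script would write an entry like the following into `tmp/slt/<source json name>/ceil.slt`:

```
## Original Query: SELECT ceil(5);
## PySpark 3.5.5 Result: {'typeof(5)': 'int'}
#query
#SELECT ceil(5::int);
```

Every line is written as an slt comment, so the generated files are templates to review and fill in by hand rather than runnable tests.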