This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch pinot-functions
in repository https://gitbox.apache.org/repos/asf/superset.git

commit b0a5874948109c4cc63586052984dc24004038f6
Author: Beto Dealmeida <robe...@dealmeida.net>
AuthorDate: Thu Oct 2 07:31:20 2025 -0400

    fix(pinot): more functions
---
 superset/sql/dialects/pinot.py               |  30 +++
 tests/unit_tests/sql/dialects/pinot_tests.py | 267 +++++++++++++++++++++++++++
 2 files changed, 297 insertions(+)

diff --git a/superset/sql/dialects/pinot.py b/superset/sql/dialects/pinot.py
index ef8ad14be5..e0aca30817 100644
--- a/superset/sql/dialects/pinot.py
+++ b/superset/sql/dialects/pinot.py
@@ -114,6 +114,36 @@ class Pinot(MySQL):
                 e.args.get("start"),
                 e.args.get("length"),
             ),
+            exp.StrPosition: lambda self, e: self.func(
+                "STRPOS",
+                e.this,
+                e.args.get("substr"),
+                e.args.get("position"),
+            ),
+            exp.StartsWith: lambda self, e: self.func(
+                "STARTSWITH",
+                e.this,
+                e.args.get("expression"),
+            ),
+            exp.Chr: lambda self, e: self.func(
+                "CHR",
+                *e.args.get("expressions", []),
+            ),
+            exp.Mod: lambda self, e: self.func(
+                "MOD",
+                e.this,
+                e.args.get("expression"),
+            ),
+            exp.ArrayAgg: lambda self, e: self.func(
+                "ARRAY_AGG",
+                e.this,
+            ),
+            exp.JSONExtractScalar: lambda self, e: self.func(
+                "JSON_EXTRACT_SCALAR",
+                e.this,
+                e.args.get("expression"),
+                e.args.get("variant"),
+            ),
         }
         # Remove DATE_TRUNC transformation - Pinot supports standard SQL DATE_TRUNC
         TRANSFORMS.pop(exp.DateTrunc, None)
diff --git a/tests/unit_tests/sql/dialects/pinot_tests.py b/tests/unit_tests/sql/dialects/pinot_tests.py
index 02cbfc54e6..974fa6f207 100644
--- a/tests/unit_tests/sql/dialects/pinot_tests.py
+++ b/tests/unit_tests/sql/dialects/pinot_tests.py
@@ -611,3 +611,270 @@ def test_substr_cross_dialect_generation() -> None:
     mysql_output = parsed.sql(dialect="mysql")
     assert "SUBSTRING(" in mysql_output
     assert pinot_output != mysql_output  # They should be different
+
+
+@pytest.mark.parametrize(
+    "function_name,sample_args",
+    [
+        # Math functions
+        ("ABS", "-5"),
+        ("CEIL", "3.14"),
+        ("FLOOR", "3.14"),
+        ("EXP", "2"),
+        ("LN", "10"),
+        ("SQRT", "16"),
+        ("ROUNDDECIMAL", "3.14159, 2"),
+        ("ADD", "1, 2, 3"),
+        ("SUB", "10, 3"),
+        ("MULT", "5, 4"),
+        ("MOD", "10, 3"),
+        # String functions
+        ("UPPER", "'hello'"),
+        ("LOWER", "'HELLO'"),
+        ("REVERSE", "'hello'"),
+        ("SUBSTR", "'hello', 0, 3"),
+        ("CONCAT", "'hello', ' ', 'world'"),
+        ("TRIM", "' hello '"),
+        ("LTRIM", "' hello'"),
+        ("RTRIM", "'hello '"),
+        ("LENGTH", "'hello'"),
+        ("STRPOS", "'hello', 'l', 1"),
+        ("STARTSWITH", "'hello', 'he'"),
+        ("REPLACE", "'hello', 'l', 'r'"),
+        ("RPAD", "'hello', 10, 'x'"),
+        ("LPAD", "'hello', 10, 'x'"),
+        ("CODEPOINT", "'A'"),
+        ("CHR", "65"),
+        ("regexpExtract", "'foo123bar', '[0-9]+'"),
+        ("regexpReplace", "'hello', 'l', 'r'"),
+        ("remove", "'hello', 'l'"),
+        ("urlEncoding", "'hello world'"),
+        ("urlDecoding", "'hello%20world'"),
+        ("fromBase64", "'aGVsbG8='"),
+        ("toUtf8", "'hello'"),
+        ("isSubnetOf", "'192.168.1.1', '192.168.0.0/16'"),
+        # DateTime functions
+        ("DATETRUNC", "'day', timestamp_col"),
+        ("DATETIMECONVERT", "dt_col, '1:HOURS:EPOCH', '1:DAYS:EPOCH', '1:DAYS'"),
+        ("TIMECONVERT", "timestamp_col, 'MILLISECONDS', 'SECONDS'"),
+        ("NOW", ""),
+        ("AGO", "'P1D'"),
+        ("YEAR", "timestamp_col"),
+        ("QUARTER", "timestamp_col"),
+        ("MONTH", "timestamp_col"),
+        ("WEEK", "timestamp_col"),
+        ("DAY", "timestamp_col"),
+        ("HOUR", "timestamp_col"),
+        ("MINUTE", "timestamp_col"),
+        ("SECOND", "timestamp_col"),
+        ("MILLISECOND", "timestamp_col"),
+        ("DAYOFWEEK", "timestamp_col"),
+        ("DAYOFYEAR", "timestamp_col"),
+        ("YEAROFWEEK", "timestamp_col"),
+        ("toEpochSeconds", "timestamp_col"),
+        ("toEpochMinutes", "timestamp_col"),
+        ("toEpochHours", "timestamp_col"),
+        ("toEpochDays", "timestamp_col"),
+        ("fromEpochSeconds", "1234567890"),
+        ("fromEpochMinutes", "20576131"),
+        ("fromEpochHours", "342935"),
+        ("fromEpochDays", "14288"),
+        ("toDateTime", "timestamp_col, 'yyyy-MM-dd'"),
+        ("fromDateTime", "'2024-01-01', 'yyyy-MM-dd'"),
+        ("timezoneHour", "timestamp_col"),
+        ("timezoneMinute", "timestamp_col"),
+        ("DATE_ADD", "'day', 7, NOW()"),
+        ("DATE_SUB", "'day', 7, NOW()"),
+        ("TIMESTAMPADD", "'day', 7, timestamp_col"),
+        ("TIMESTAMPDIFF", "'day', timestamp1, timestamp2"),
+        ("dateTrunc", "'day', timestamp_col"),
+        ("dateDiff", "'day', timestamp1, timestamp2"),
+        ("dateAdd", "'day', 7, timestamp_col"),
+        ("dateBin", "'day', timestamp_col, NOW()"),
+        ("toIso8601", "timestamp_col"),
+        ("fromIso8601", "'2024-01-01T00:00:00Z'"),
+        # Aggregation functions
+        ("COUNT", "*"),
+        ("SUM", "amount"),
+        ("AVG", "value"),
+        ("MIN", "value"),
+        ("MAX", "value"),
+        ("DISTINCTCOUNT", "user_id"),
+        ("DISTINCTCOUNTBITMAP", "user_id"),
+        ("DISTINCTCOUNTHLL", "user_id"),
+        ("DISTINCTCOUNTRAWHLL", "user_id"),
+        ("DISTINCTCOUNTHLLPLUS", "user_id"),
+        ("DISTINCTCOUNTRAWHLLPLUS", "user_id"),
+        ("DISTINCTCOUNTSMARTHLL", "user_id"),
+        ("DISTINCTCOUNTCPCSKETCH", "user_id"),
+        ("DISTINCTCOUNTRAWCPCSKETCH", "user_id"),
+        ("DISTINCTCOUNTTHETASKETCH", "user_id"),
+        ("DISTINCTCOUNTRAWTHETASKETCH", "user_id"),
+        ("DISTINCTCOUNTTUPLESKETCH", "user_id"),
+        ("DISTINCTCOUNTRAWINTEGERSUMTUPLESKETCH", "user_id"),
+        ("DISTINCTCOUNTULL", "user_id"),
+        ("DISTINCTCOUNTRAWULL", "user_id"),
+        ("SEGMENTPARTITIONEDDISTINCTCOUNT", "user_id"),
+        ("SUMVALUESINTEGERSUMTUPLESKETCH", "value"),
+        ("PERCENTILE", "value, 95"),
+        ("PERCENTILEEST", "value, 95"),
+        ("PERCENTILETDIGEST", "value, 95"),
+        ("PERCENTILESMARTTDIGEST", "value, 95"),
+        ("PERCENTILEKLL", "value, 95"),
+        ("PERCENTILEKLLRAW", "value, 95"),
+        ("HISTOGRAM", "value, 10"),
+        ("MODE", "category"),
+        ("MINMAXRANGE", "value"),
+        ("SUMPRECISION", "value, 10"),
+        ("ARG_MIN", "value, id"),
+        ("ARG_MAX", "value, id"),
+        ("COVAR_POP", "x, y"),
+        ("COVAR_SAMP", "x, y"),
+        ("LASTWITHTIME", "value, timestamp_col, 'LONG'"),
+        ("FIRSTWITHTIME", "value, timestamp_col, 'LONG'"),
+        ("ARRAY_AGG", "value"),
+        # Multi-value functions
+        ("COUNTMV", "tags"),
+        ("MAXMV", "scores"),
+        ("MINMV", "scores"),
+        ("SUMMV", "scores"),
+        ("AVGMV", "scores"),
+        ("MINMAXRANGEMV", "scores"),
+        ("PERCENTILEMV", "scores, 95"),
+        ("PERCENTILEESTMV", "scores, 95"),
+        ("PERCENTILETDIGESTMV", "scores, 95"),
+        ("PERCENTILEKLLMV", "scores, 95"),
+        ("DISTINCTCOUNTMV", "tags"),
+        ("DISTINCTCOUNTBITMAPMV", "tags"),
+        ("DISTINCTCOUNTHLLMV", "tags"),
+        ("DISTINCTCOUNTRAWHLLMV", "tags"),
+        ("DISTINCTCOUNTHLLPLUSMV", "tags"),
+        ("DISTINCTCOUNTRAWHLLPLUSMV", "tags"),
+        ("ARRAYLENGTH", "array_col"),
+        ("MAP_VALUE", "map_col, 'key'"),
+        ("VALUEIN", "value, 'val1', 'val2'"),
+        # JSON functions
+        ("JSONEXTRACTSCALAR", "json_col, '$.name', 'STRING'"),
+        ("JSONEXTRACTKEY", "json_col, '$.data'"),
+        ("JSONFORMAT", "json_col"),
+        ("JSONPATH", "json_col, '$.name'"),
+        ("JSONPATHLONG", "json_col, '$.id'"),
+        ("JSONPATHDOUBLE", "json_col, '$.price'"),
+        ("JSONPATHSTRING", "json_col, '$.name'"),
+        ("JSONPATHARRAY", "json_col, '$.items'"),
+        ("JSONPATHARRAYDEFAULTEMPTY", "json_col, '$.items'"),
+        ("TOJSONMAPSTR", "map_col"),
+        ("JSON_MATCH", "json_col, '\"$.name\"=''value'''"),
+        ("JSON_EXTRACT_SCALAR", "json_col, '$.name', 'STRING'"),
+        # Array functions
+        ("arrayReverseInt", "int_array"),
+        ("arrayReverseString", "string_array"),
+        ("arraySortInt", "int_array"),
+        ("arraySortString", "string_array"),
+        ("arrayIndexOfInt", "int_array, 5"),
+        ("arrayIndexOfString", "string_array, 'value'"),
+        ("arrayContainsInt", "int_array, 5"),
+        ("arrayContainsString", "string_array, 'value'"),
+        ("arraySliceInt", "int_array, 0, 3"),
+        ("arraySliceString", "string_array, 0, 3"),
+        ("arrayDistinctInt", "int_array"),
+        ("arrayDistinctString", "string_array"),
+        ("arrayRemoveInt", "int_array, 5"),
+        ("arrayRemoveString", "string_array, 'value'"),
+        ("arrayUnionInt", "int_array1, int_array2"),
+        ("arrayUnionString", "string_array1, string_array2"),
+        ("arrayConcatInt", "int_array1, int_array2"),
+        ("arrayConcatString", "string_array1, string_array2"),
+        ("arrayElementAtInt", "int_array, 0"),
+        ("arrayElementAtString", "string_array, 0"),
+        ("arraySumInt", "int_array"),
+        ("arrayValueConstructor", "1, 2, 3"),
+        ("arrayToString", "array_col, ','"),
+        # Geospatial functions
+        ("ST_DISTANCE", "point1, point2"),
+        ("ST_CONTAINS", "polygon, point"),
+        ("ST_AREA", "polygon"),
+        ("ST_GEOMFROMTEXT", "'POINT(1 2)'"),
+        ("ST_GEOMFROMWKB", "wkb_col"),
+        ("ST_GEOGFROMWKB", "wkb_col"),
+        ("ST_GEOGFROMTEXT", "'POINT(1 2)'"),
+        ("ST_POINT", "1.0, 2.0"),
+        ("ST_POLYGON", "'POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'"),
+        ("ST_ASBINARY", "geom_col"),
+        ("ST_ASTEXT", "geom_col"),
+        ("ST_GEOMETRYTYPE", "geom_col"),
+        ("ST_EQUALS", "geom1, geom2"),
+        ("ST_WITHIN", "geom1, geom2"),
+        ("ST_UNION", "geom1, geom2"),
+        ("ST_GEOMFROMGEOJSON", "'{\"type\":\"Point\",\"coordinates\":[1,2]}'"),
+        ("ST_GEOGFROMGEOJSON", "'{\"type\":\"Point\",\"coordinates\":[1,2]}'"),
+        ("ST_ASGEOJSON", "geom_col"),
+        ("toSphericalGeography", "geom_col"),
+        ("toGeometry", "geog_col"),
+        # Binary/Hash functions
+        ("SHA", "'hello'"),
+        ("SHA256", "'hello'"),
+        ("SHA512", "'hello'"),
+        ("SHA224", "'hello'"),
+        ("MD5", "'hello'"),
+        ("MD2", "'hello'"),
+        ("toBase64", "'hello'"),
+        ("fromUtf8", "bytes_col"),
+        ("MurmurHash2", "'hello'"),
+        ("MurmurHash3Bit32", "'hello'"),
+        # Window functions
+        ("ROW_NUMBER", ""),
+        ("RANK", ""),
+        ("DENSE_RANK", ""),
+        # Funnel analysis
+        ("FunnelMaxStep", "event_col, 'step1', 'step2', 'step3'"),
+        ("FunnelMatchStep", "event_col, 'step1', 'step2', 'step3'"),
+        ("FunnelCompleteCount", "event_col, 'step1', 'step2', 'step3'"),
+        # Text search
+        ("TEXT_MATCH", "text_col, 'search query'"),
+        # Vector functions
+        ("VECTOR_SIMILARITY", "vector1, vector2"),
+        ("l2_distance", "vector1, vector2"),
+        # Lookup
+        ("LOOKUP", "'lookupTable', 'lookupColumn', 'keyColumn', keyValue"),
+        # URL functions
+        ("urlProtocol", "'https://example.com/path'"),
+        ("urlDomain", "'https://example.com/path'"),
+        ("urlPath", "'https://example.com/path'"),
+        ("urlPort", "'https://example.com:8080/path'"),
+        ("urlEncode", "'hello world'"),
+        ("urlDecode", "'hello%20world'"),
+        # Conditional
+        ("COALESCE", "val1, val2, 'default'"),
+        ("NULLIF", "val1, val2"),
+        ("GREATEST", "1, 2, 3"),
+        ("LEAST", "1, 2, 3"),
+        # Other
+        ("REGEXP_LIKE", "'hello', 'h.*'"),
+        ("GROOVY", "'{return arg0 + arg1}', col1, col2"),
+    ],
+)
+def test_pinot_function_names_preserved(function_name: str, sample_args: str) -> None:
+    """
+    Test that Pinot function names are preserved during parse/generate roundtrip.
+
+    This ensures that when we parse Pinot SQL and generate it back, the function
+    names remain unchanged. This is critical for maintaining compatibility with
+    Pinot's function library.
+    """
+    # Special handling for window functions
+    if function_name in ["ROW_NUMBER", "RANK", "DENSE_RANK"]:
+        sql = f"SELECT {function_name}() OVER (ORDER BY col) FROM table"  # noqa: S608
+    else:
+        sql = f"SELECT {function_name}({sample_args}) FROM table"  # noqa: S608
+
+    # Parse with Pinot dialect
+    parsed = sqlglot.parse_one(sql, Pinot)
+
+    # Generate back to Pinot
+    generated = parsed.sql(dialect=Pinot)
+
+    # The function name should be preserved (case-insensitive check)
+    assert function_name.upper() in generated.upper(), (
+        f"Function {function_name} not preserved in output: {generated}"
+    )

Reply via email to