snuyanzin commented on code in PR #28113:
URL: https://github.com/apache/flink/pull/28113#discussion_r3196563126


##########
flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java:
##########
@@ -968,13 +969,390 @@ void testRawLiteralScalarFunction() throws Exception {
         
assertThat(TestCollectionTableFactory.getResult()).containsExactlyInAnyOrder(sinkData);
     }
 
+    /**
+     * Runs {@code sql} against a two-row {@code SourceTable} ("Bob", "Alice") and verifies both the
+     * emitted rows (in order) and how many times each counting UDF was invoked.
+     *
+     * <p>{@code Det} is backed by {@code CountingUpperScalarFunction} and {@code Nondet} by {@code
+     * NonDeterministicCountingScalarFunction}. Both static counters are reset before the query runs,
+     * so the asserted call counts reflect only this query.
+     *
+     * @param sql the query under test
+     * @param expectedRows the rows the query must emit, in order
+     * @param expectedDetCalls expected number of {@code Det} invocations
+     * @param expectedNonDetCalls expected number of {@code Nondet} invocations
+     */
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("inputForTestCalcLocalRefReuse")
+    void testCalcLocalRefReuse(
+            String sql, List<Row> expectedRows, int expectedDetCalls, int expectedNonDetCalls) {
+        final List<Row> sourceData = List.of(Row.of("Bob"), Row.of("Alice"));
+
+        TestCollectionTableFactory.reset();
+        TestCollectionTableFactory.initData(sourceData);
+        CountingUpperScalarFunction.COUNT.set(0);
+        NonDeterministicCountingScalarFunction.COUNT.set(0);
+
+        tEnv().createTemporarySystemFunction("Det", CountingUpperScalarFunction.class);
+        tEnv().createTemporarySystemFunction(
+                        "Nondet", NonDeterministicCountingScalarFunction.class);
+        tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')");
+
+        final List<Row> actual = CollectionUtil.iteratorToList(tEnv().executeSql(sql).collect());
+
+        assertThat(actual).containsExactlyElementsOf(expectedRows);
+        assertThat(CountingUpperScalarFunction.COUNT.get())
+                .as("Deterministic invocations")
+                .isEqualTo(expectedDetCalls);
+        assertThat(NonDeterministicCountingScalarFunction.COUNT.get())
+                .as("Non-deterministic invocations")
+                .isEqualTo(expectedNonDetCalls);
+    }
+
+    /**
+     * Cases for {@link #testCalcLocalRefReuse}. Each entry is: the SQL query, the expected rows in
+     * emission order, the expected number of {@code Det} invocations, and the expected number of
+     * {@code Nondet} invocations. Counts are stated for the two source rows "Bob" and "Alice".
+     *
+     * @return the parameter sets for the parameterized test
+     */
+    static Stream<Arguments> inputForTestCalcLocalRefReuse() {
+        return Stream.of(
+                Arguments.of(
+                        "SELECT Det(s), Det(s), Det(s) FROM SourceTable",
+                        List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")),
+                        2, // expected localref calls: rows × 1 (cached)
+                        0),
+                Arguments.of(
+                        "SELECT Det(s), Det(s), UPPER(s) FROM SourceTable",
+                        List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")),
+                        2, // rows × 1 (cached); built-in UPPER not counted
+                        0),
+                Arguments.of(
+                        "SELECT Det(Det(s)), Det(Det(s)), Det(Det(s)) FROM SourceTable",
+                        List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")),
+                        4, // rows × 2 layers
+                        0),
+                Arguments.of(
+                        "SELECT Nondet(s), Nondet(s), Nondet(s) FROM SourceTable",
+                        List.of(
+                                Row.of("BOB_1", "BOB_2", "BOB_3"),
+                                Row.of("ALICE_4", "ALICE_5", "ALICE_6")),
+                        0,
+                        6 // rows × 3 projections
+                        ),
+                Arguments.of(
+                        "SELECT Nondet(Det(s)), Nondet(Det(s)), Nondet(Det(s)) FROM SourceTable",
+                        List.of(
+                                Row.of("BOB_1", "BOB_2", "BOB_3"),
+                                Row.of("ALICE_4", "ALICE_5", "ALICE_6")),
+                        2, // rows × 1 (inner cached)
+                        6 // rows × 3 projections
+                        ),
+                Arguments.of(
+                        "SELECT Det(Nondet(s)), Det(Nondet(s)), Det(Nondet(s)) FROM SourceTable",
+                        List.of(
+                                Row.of("BOB_1", "BOB_2", "BOB_3"),
+                                Row.of("ALICE_4", "ALICE_5", "ALICE_6")),
+                        6, // rows × 3 (nondet input disables cache)
+                        6 // rows × 3 projections
+                        ),
+                // shared Det in filter → cached once per row
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Det(s) IS NOT NULL AND Det(s) <> '' AND Det(s) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        2,
+                        0),
+                // mixed UDF + built-in
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Det(s) IS NOT NULL AND Det(s) <> '' AND UPPER(s) <> ''",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        2,
+                        0),
+                // nested Det in filter; both layers cached
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Det(Det(s)) IS NOT NULL"
+                                + " AND Det(Det(s)) <> '' AND Det(Det(s)) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        4,
+                        0),
+                // non-deterministic in filter — never cached
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Nondet(s) IS NOT NULL"
+                                + " AND Nondet(s) <> '' AND Nondet(s) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        0,
+                        6),
+                // outer nondet, inner Det cached
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Nondet(Det(s)) IS NOT NULL"
+                                + " AND Nondet(Det(s)) <> '' AND Nondet(Det(s)) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        2,
+                        6),
+                // Det with nondet input → cache bypassed
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Det(Nondet(s)) IS NOT NULL"
+                                + " AND Det(Nondet(s)) <> '' AND Det(Nondet(s)) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        6,
+                        6),
+                // filter ↔ projection share via unified program
+                Arguments.of(
+                        "SELECT Det(s) FROM SourceTable WHERE Det(s) = 'BOB'",
+                        List.of(Row.of("BOB")),
+                        2,
+                        0),
+                Arguments.of(
+                        "SELECT Det(s), Det(s) FROM SourceTable WHERE Det(s) = 'BOB'",
+                        List.of(Row.of("BOB", "BOB")),
+                        2,
+                        0),
+
+                // ---------------------------------------------------------------------------
+                // JSON construction scenarios. These verify that the localref / RexProgram CSE
+                // cache also fires when the shared sub-expression is wrapped inside (or itself
+                // is) a JSON_OBJECT / JSON_ARRAY / JSON_STRING call.
+                // ---------------------------------------------------------------------------
+
+                // JSON_OBJECT × 2 sharing inner Det → cached once per row.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Det(s)),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Det(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}"),
+                                Row.of("{\"a\":\"ALICE\"}", "{\"b\":\"ALICE\"}")),
+                        2, // rows × 1 (cached)
+                        0),
+                // JSON_ARRAY × 2 sharing inner Det → cached.
+                Arguments.of(
+                        "SELECT JSON_ARRAY(Det(s)), JSON_ARRAY(Det(s)) FROM SourceTable",
+                        List.of(
+                                Row.of("[\"BOB\"]", "[\"BOB\"]"),
+                                Row.of("[\"ALICE\"]", "[\"ALICE\"]")),
+                        2,
+                        0),
+                // JSON_STRING × 2 sharing inner Det → cached.
+                Arguments.of(
+                        "SELECT JSON_STRING(Det(s)), JSON_STRING(Det(s)) FROM SourceTable",
+                        List.of(Row.of("\"BOB\"", "\"BOB\""), Row.of("\"ALICE\"", "\"ALICE\"")),
+                        2,
+                        0),
+                // Mixed JSON_OBJECT + JSON_ARRAY sharing same Det.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s)), JSON_ARRAY(Det(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"k\":\"BOB\"}", "[\"BOB\"]"),
+                                Row.of("{\"k\":\"ALICE\"}", "[\"ALICE\"]")),
+                        2,
+                        0),
+                // Mixed JSON_OBJECT + JSON_STRING sharing same Det.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s)), JSON_STRING(Det(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"k\":\"BOB\"}", "\"BOB\""),
+                                Row.of("{\"k\":\"ALICE\"}", "\"ALICE\"")),
+                        2,
+                        0),
+                // JSON_OBJECT × 3 sharing same Det → cached across all 3 sites.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Det(s)),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Det(s)),"
+                                + " JSON_OBJECT(KEY 'c' VALUE Det(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}", "{\"c\":\"BOB\"}"),
+                                Row.of(
+                                        "{\"a\":\"ALICE\"}",
+                                        "{\"b\":\"ALICE\"}",
+                                        "{\"c\":\"ALICE\"}")),
+                        2,
+                        0),
+                // Nested Det(Det(s)) inside two JSON_OBJECT projections → both layers cached.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Det(Det(s))),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Det(Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}"),
+                                Row.of("{\"a\":\"ALICE\"}", "{\"b\":\"ALICE\"}")),
+                        4, // rows × 2 layers
+                        0),
+                // Nondet inside two JSON_OBJECT projections → never cached.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Nondet(s)),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Nondet(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"),
+                                Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")),
+                        0,
+                        4 // rows × 2 projections
+                        ),
+                // Outer Nondet, inner Det inside two JSON_OBJECT projections — Det cached.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Nondet(Det(s))),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Nondet(Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"),
+                                Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")),
+                        2, // inner Det cached
+                        4),
+                // Outer Det, inner Nondet → outer cache disabled by nondet operand.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Det(Nondet(s))),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Det(Nondet(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"),
+                                Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")),
+                        4, // outer Det not cached (nondet operand)
+                        4),
+                // Filter ↔ JSON projection share Det via unified program.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s))"
+                                + " FROM SourceTable WHERE Det(s) = 'BOB'",
+                        List.of(Row.of("{\"k\":\"BOB\"}")),
+                        2,
+                        0),
+                // Shared inner JSON_OBJECT(KEY 'k' VALUE Det(s)) inside two outer JSON_OBJECT
+                // projections — verifies CSE works when the cached node is itself a JSON
+                // construction call (and validates the JSON helpers' RexLocalRef deref path
+                // along the way).
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'k' VALUE Det(s))),"
+                                + " JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'k' VALUE Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of(
+                                        "{\"outer1\":{\"k\":\"BOB\"}}",
+                                        "{\"outer2\":{\"k\":\"BOB\"}}"),
+                                Row.of(
+                                        "{\"outer1\":{\"k\":\"ALICE\"}}",
+                                        "{\"outer2\":{\"k\":\"ALICE\"}}")),
+                        2,
+                        0),
+                // Shared inner JSON_ARRAY(Det(s)) inside two outer JSON_OBJECT projections.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(Det(s))),"
+                                + " JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":[\"BOB\"]}", "{\"b\":[\"BOB\"]}"),
+                                Row.of("{\"a\":[\"ALICE\"]}", "{\"b\":[\"ALICE\"]}")),
+                        2,
+                        0),
+                // Shared inner JSON_OBJECT(KEY 'k' VALUE Det(s)) inside two JSON_ARRAY
+                // projections.
+                Arguments.of(
+                        "SELECT JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE Det(s))),"
+                                + " JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("[{\"k\":\"BOB\"}]", "[{\"k\":\"BOB\"}]"),
+                                Row.of("[{\"k\":\"ALICE\"}]", "[{\"k\":\"ALICE\"}]")),
+                        2,
+                        0));
+    }
+
+    @Test
+    void testLocalRefReuseForMixedArgs() {
+        final List<Row> sourceData = List.of(Row.of("Bob"), Row.of("Alice"));
+        final int callSites = 2;
+
+        TestCollectionTableFactory.reset();
+        TestCollectionTableFactory.initData(sourceData);
+        CountingUpperScalarFunction.COUNT.set(0);
+        NonDeterministicCountingScalarFunction.COUNT.set(0);
+        CountingConcat3ScalarFunction.COUNT.set(0);
+
+        tEnv().createTemporarySystemFunction("Det", 
CountingUpperScalarFunction.class);
+        tEnv().createTemporarySystemFunction(
+                        "Nondet", 
NonDeterministicCountingScalarFunction.class);
+        tEnv().createTemporarySystemFunction("Concat3", 
CountingConcat3ScalarFunction.class);
+        tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH 
('connector' = 'COLLECTION')");
+
+        final List<Row> actual =
+                CollectionUtil.iteratorToList(
+                        tEnv().executeSql(
+                                        "SELECT Concat3(Det(s), Nondet(s), 
Det(s)),"
+                                                + " Concat3(Det(s), Nondet(s), 
Det(s))"
+                                                + " FROM SourceTable")
+                                .collect());
+
+        assertThat(actual)
+                .containsExactly(
+                        Row.of("BOB/BOB_1/BOB", "BOB/BOB_2/BOB"),
+                        Row.of("ALICE/ALICE_3/ALICE", "ALICE/ALICE_4/ALICE"));
+
+        
assertThat(CountingUpperScalarFunction.COUNT.get()).isEqualTo(sourceData.size());
+        assertThat(NonDeterministicCountingScalarFunction.COUNT.get())
+                .isEqualTo(sourceData.size() * callSites);
+        // Concat3 is deterministic however has non-deterministic input
+        assertThat(CountingConcat3ScalarFunction.COUNT.get())
+                .isEqualTo(sourceData.size() * callSites);
+    }
+
+    @Test
+    void testCalcSharesSubExpressionBetweenFilterAndProjection() {
+        final List<Row> sourceData =
+                List.of(Row.of("Bob"), Row.of("Bob"), Row.of("Alice"), 
Row.of("Alice"));
+
+        TestCollectionTableFactory.reset();
+        TestCollectionTableFactory.initData(sourceData);
+        CountingUpperScalarFunction.COUNT.set(0);
+
+        tEnv().createTemporarySystemFunction("CountingUpper", 
CountingUpperScalarFunction.class);
+        tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH 
('connector' = 'COLLECTION')");
+
+        final List<Row> actual =
+                CollectionUtil.iteratorToList(
+                        tEnv().executeSql(
+                                        "SELECT CountingUpper(s) FROM 
SourceTable"
+                                                + " WHERE CountingUpper(s) = 
'BOB' AND CountingUpper(s) <> 'BOB2'")
+                                .collect());
+
+        assertThat(actual).containsExactly(Row.of("BOB"), Row.of("BOB"));
+
+        // Filter and projection share via the unified RexProgram, so the UDF 
runs once per
+        // source row regardless of how many call sites name it.
+        
assertThat(CountingUpperScalarFunction.COUNT.get()).isEqualTo(sourceData.size());
+    }
+
+    /**
+     * Pins the CASE-WHEN guard interaction with the RexLocalRef cache.
+     *
+     * <p>Prior to scoped caching, RexProgramBuilder collapsed the division 
{@code a / b} into a

Review Comment:
   sorry, dropped and added instructions



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to