fsk119 commented on code in PR #27121:
URL: https://github.com/apache/flink/pull/27121#discussion_r2438063386


##########
flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml:
##########
@@ -98,22 +102,142 @@ LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], 
rowtime=[$4], proctime=[$5], e=[$
     <Resource name="optimized rel plan">
       <![CDATA[
 Calc(select=[a, b, c, d, rowtime, PROCTIME_MATERIALIZE(proctime) AS proctime, 
e, f, g, score])
-+- Correlate(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), 
$cor0.d, 10)], 
correlate=[table(VECTOR_SEARCH(TABLE(),DESCRIPTOR('g'),$cor0.d,10))], 
select=[a,b,c,d,rowtime,proctime,e,f,g,score], rowType=[RecordType(INTEGER a, 
BIGINT b, VARCHAR(2147483647) c, FLOAT ARRAY d, TIMESTAMP(3) *ROWTIME* rowtime, 
TIMESTAMP_LTZ(3) *PROCTIME* proctime, INTEGER e, BIGINT f, FLOAT ARRAY g, 
DOUBLE score)], joinType=[INNER])
++- 
VectorSearchTableFunction(table=[default_catalog.default_database.VectorTable], 
joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[d], topK=[10], 
select=[a, b, c, d, rowtime, proctime, e, f, g, score])
    +- WatermarkAssigner(rowtime=[rowtime], watermark=[-(rowtime, 1000:INTERVAL 
SECOND)])
       +- Calc(select=[a, b, c, d, rowtime, PROCTIME() AS proctime])
          +- TableSourceScan(table=[[default_catalog, default_database, 
QueryTable]], fields=[a, b, c, d, rowtime])
 ]]>
     </Resource>
   </TestCase>
-  <TestCase name="testOutOfOrderNamedArgument">
+  <TestCase name="testSearchTableWithCalc">
     <Resource name="sql">
       <![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(
 VECTOR_SEARCH(
-    COLUMN_TO_QUERY => QueryTable.d,
-    COLUMN_TO_SEARCH => DESCRIPTOR(`g`),
-    TOP_K => 10,
-    SEARCH_TABLE => TABLE VectorTable
-  )
+    TABLE VectorTableWithProctime, DESCRIPTOR(`g`), QueryTable.d, 10))]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], proctime=[$5], 
e=[$6], f=[$7], g=[$8], proctime0=[$9], score=[$10])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{3}])
+   :- LogicalWatermarkAssigner(rowtime=[rowtime], watermark=[-($4, 
1000:INTERVAL SECOND)])
+   :  +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], 
proctime=[PROCTIME()])
+   :     +- LogicalTableScan(table=[[default_catalog, default_database, 
QueryTable]])
+   +- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), 
DESCRIPTOR(_UTF-16LE'g'), $cor0.d, 10)], rowType=[RecordType(INTEGER e, BIGINT 
f, FLOAT ARRAY g, TIMESTAMP_LTZ(3) *PROCTIME* proctime, DOUBLE score)])
+      +- LogicalProject(e=[$0], f=[$1], g=[$2], proctime=[$3])
+         +- LogicalProject(e=[$0], f=[$1], g=[$2], proctime=[PROCTIME()])
+            +- LogicalTableScan(table=[[default_catalog, default_database, 
VectorTableWithProctime]])
+]]>
+    </Resource>
+    <Resource name="optimized rel plan">
+      <![CDATA[
+Calc(select=[a, b, c, d, rowtime, PROCTIME_MATERIALIZE(proctime) AS proctime, 
e, f, g, PROCTIME_MATERIALIZE(proctime0) AS proctime0, score])
++- 
VectorSearchTableFunction(table=[default_catalog.default_database.VectorTableWithProctime],
 joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[d], topK=[10], 
select=[a, b, c, d, rowtime, proctime, e, f, g, PROCTIME() AS proctime, score])
+   +- WatermarkAssigner(rowtime=[rowtime], watermark=[-(rowtime, 1000:INTERVAL 
SECOND)])
+      +- Calc(select=[a, b, c, d, rowtime, PROCTIME() AS proctime])
+         +- TableSourceScan(table=[[default_catalog, default_database, 
QueryTable]], fields=[a, b, c, d, rowtime])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testSearchTableWithDescriptorUsingMetadata">
+    <Resource name="sql">
+      <![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(
+  VECTOR_SEARCH(
+    TABLE VectorTableWithMetadata,
+    DESCRIPTOR(`f`),
+    QueryTable.d,
+    10  )
+)]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], proctime=[$5], 
e=[$6], f=[$7], g=[$8], h=[$9], score=[$10])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{3}])
+   :- LogicalWatermarkAssigner(rowtime=[rowtime], watermark=[-($4, 
1000:INTERVAL SECOND)])
+   :  +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], 
proctime=[PROCTIME()])
+   :     +- LogicalTableScan(table=[[default_catalog, default_database, 
QueryTable]])
+   +- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), 
DESCRIPTOR(_UTF-16LE'f'), $cor0.d, 10)], rowType=[RecordType(INTEGER e, FLOAT 
ARRAY f, FLOAT ARRAY g, INTEGER h, DOUBLE score)])
+      +- LogicalProject(e=[$0], f=[$1], g=[$2], h=[$3])
+         +- LogicalProject(e=[$0], f=[$2], g=[$1], h=[+($0, 1)])
+            +- LogicalTableScan(table=[[default_catalog, default_database, 
VectorTableWithMetadata, metadata=[f]]])
+]]>
+    </Resource>
+    <Resource name="optimized rel plan">
+      <![CDATA[
+Calc(select=[a, b, c, d, rowtime, PROCTIME_MATERIALIZE(proctime) AS proctime, 
e, f, g, h, score])
++- 
VectorSearchTableFunction(table=[default_catalog.default_database.VectorTableWithMetadata],
 joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[d], topK=[10], 
select=[a, b, c, d, rowtime, proctime, e, f, g, +(e, 1) AS h, score])
+   +- WatermarkAssigner(rowtime=[rowtime], watermark=[-(rowtime, 1000:INTERVAL 
SECOND)])
+      +- Calc(select=[a, b, c, d, rowtime, PROCTIME() AS proctime])
+         +- TableSourceScan(table=[[default_catalog, default_database, 
QueryTable]], fields=[a, b, c, d, rowtime])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testSearchTableWithProjection">
+    <Resource name="sql">
+      <![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(
+VECTOR_SEARCH(
+    (SELECT e, g, proctime FROM VectorTableWithProctime), DESCRIPTOR(`g`), 
QueryTable.d, 10))]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], proctime=[$5], 
e=[$6], g=[$7], proctime0=[$8], score=[$9])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{3}])
+   :- LogicalWatermarkAssigner(rowtime=[rowtime], watermark=[-($4, 
1000:INTERVAL SECOND)])
+   :  +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], 
proctime=[PROCTIME()])
+   :     +- LogicalTableScan(table=[[default_catalog, default_database, 
QueryTable]])
+   +- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), 
DESCRIPTOR(_UTF-16LE'g'), $cor0.d, 10)], rowType=[RecordType(INTEGER e, FLOAT 
ARRAY g, TIMESTAMP_LTZ(3) *PROCTIME* proctime, DOUBLE score)])
+      +- LogicalProject(e=[$0], g=[$2], proctime=[$3])
+         +- LogicalProject(e=[$0], f=[$1], g=[$2], proctime=[PROCTIME()])
+            +- LogicalTableScan(table=[[default_catalog, default_database, 
VectorTableWithProctime]])
+]]>
+    </Resource>
+    <Resource name="optimized rel plan">
+      <![CDATA[
+Calc(select=[a, b, c, d, rowtime, PROCTIME_MATERIALIZE(proctime) AS proctime, 
e, g, PROCTIME_MATERIALIZE(proctime0) AS proctime0, score])
++- 
VectorSearchTableFunction(table=[default_catalog.default_database.VectorTableWithProctime],
 joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[d], topK=[10], 
select=[a, b, c, d, rowtime, proctime, e, g, PROCTIME() AS proctime, score])

Review Comment:
   Nope. It just influences the plan display.  



##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalVectorSearchTableFunctionRule.java:
##########
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.stream;
+
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import 
org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
+import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
+import org.apache.flink.table.planner.plan.nodes.exec.spec.VectorSearchSpec;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRel;
+import 
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import 
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalVectorSearchTableFunction;
+import org.apache.flink.table.planner.plan.schema.TableSourceTable;
+import org.apache.flink.table.planner.plan.utils.FunctionCallUtil;
+
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptTable;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.plan.volcano.RelSubset;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Calc;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.TableScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
+import org.apache.calcite.rex.RexFieldAccess;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexProgramBuilder;
+import org.apache.calcite.util.Util;
+import org.immutables.value.Value;
+
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Rule to convert a {@link FlinkLogicalCorrelate} with VECTOR_SEARCH call 
into a {@link
+ * StreamPhysicalVectorSearchTableFunction}.
+ */
[email protected]
+public class StreamPhysicalVectorSearchTableFunctionRule
+        extends RelRule<StreamPhysicalVectorSearchTableFunctionRule.Config> {
+
+    public static final StreamPhysicalVectorSearchTableFunctionRule INSTANCE =
+            Config.DEFAULT.toRule();
+
+    protected StreamPhysicalVectorSearchTableFunctionRule(Config config) {
+        super(config);
+    }
+
+    @Override
+    public boolean matches(RelOptRuleCall call) {
+        FlinkLogicalTableFunctionScan scan = call.rel(2);
+        RexNode rexNode = scan.getCall();
+        if (!(rexNode instanceof RexCall)) {
+            return false;
+        }
+        RexCall rexCall = (RexCall) rexNode;
+        return rexCall.getOperator() instanceof SqlVectorSearchTableFunction;
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        // QUERY_TABLE
+        RelNode input = call.rel(1);
+        final RelNode newInput = RelOptRule.convert(input, 
FlinkConventions.STREAM_PHYSICAL());
+
+        // SEARCH_TABLE
+        FlinkLogicalCorrelate correlate = call.rel(0);
+        FlinkLogicalTableFunctionScan vectorSearchCall = call.rel(2);
+        String functionName = ((RexCall) 
vectorSearchCall.getCall()).getOperator().getName();
+        SearchTableExtractor extractor = new 
SearchTableExtractor(functionName);
+        extractor.visit(vectorSearchCall.getInput(0));
+
+        call.transformTo(
+                new StreamPhysicalVectorSearchTableFunction(
+                        correlate.getCluster(),
+                        
correlate.getTraitSet().replace(FlinkConventions.STREAM_PHYSICAL()),
+                        newInput,
+                        extractor.searchTable,
+                        extractor.calcProgram == null
+                                ? null
+                                : pullUpRexProgram(
+                                        input.getCluster(),
+                                        extractor.searchTable.getRowType(),
+                                        vectorSearchCall.getRowType(),
+                                        extractor.calcProgram),
+                        buildVectorSearchSpec(
+                                correlate, vectorSearchCall, 
extractor.searchTable, functionName),
+                        correlate.getRowType()));
+    }
+
+    private VectorSearchSpec buildVectorSearchSpec(
+            FlinkLogicalCorrelate correlate,
+            FlinkLogicalTableFunctionScan scan,
+            RelOptTable searchTable,
+            String functionName) {
+        JoinRelType joinType = correlate.getJoinType();
+        if (joinType != JoinRelType.INNER && joinType != JoinRelType.LEFT) {
+            throw new TableException(
+                    String.format(
+                            "%s only supports INNER JOIN and LEFT JOIN, but 
get %s JOIN.",
+                            functionName, joinType));
+        }
+
+        RexCall functionCall = (RexCall) scan.getCall();
+
+        // COLUMN_TO_SEARCH
+        RexCall descriptorCall = (RexCall) functionCall.getOperands().get(1);
+        RexNode searchColumn = descriptorCall.getOperands().get(0);
+        if (!(searchColumn instanceof RexLiteral)) {
+            throw new TableException(
+                    String.format(
+                            "%s got an unknown parameter column_to_search in 
descriptor: %s.",
+                            functionName, searchColumn));
+        }
+        int searchIndex =
+                searchTable
+                        .getRowType()
+                        .getFieldNames()
+                        .indexOf(RexLiteral.stringValue(searchColumn));
+        if (searchIndex == -1) {
+            throw new TableException(
+                    String.format(
+                            "%s can not find column `%s` in the search_table 
%s physical output type. Currently, Flink doesn't support to use computed 
column as the search column.",
+                            functionName,
+                            RexLiteral.stringValue(searchColumn),
+                            String.join(".", searchTable.getQualifiedName())));
+        }
+
+        // COLUMN_TO_QUERY
+        FunctionCallUtil.FunctionParam queryColumn =
+                getQueryColumnParam(functionCall.getOperands().get(2), 
correlate, functionName);
+
+        Map<Integer, FunctionCallUtil.FunctionParam> searchColumns = new 
LinkedHashMap<>();
+        searchColumns.put(searchIndex, queryColumn);
+
+        // TOP_K
+        RexLiteral topK = (RexLiteral) functionCall.getOperands().get(3);
+        FunctionCallUtil.Constant topKParam =
+                new 
FunctionCallUtil.Constant(FlinkTypeFactory.toLogicalType(topK.getType()), topK);
+
+        return new VectorSearchSpec(joinType, searchColumns, topKParam);
+    }
+
+    private FunctionCallUtil.FunctionParam getQueryColumnParam(
+            RexNode queryColumn, FlinkLogicalCorrelate correlate, String 
functionName) {
+        if (queryColumn instanceof RexFieldAccess) {
+            RexNode refNode = ((RexFieldAccess) 
queryColumn).getReferenceExpr();
+            if (refNode instanceof RexFieldAccess) {
+                // nested field unsupported
+                throw new TableException(
+                        String.format(
+                                "%s does not support nested field in parameter 
column_to_query, but get %s.",
+                                functionName, queryColumn));
+            } else if 
(!(correlate.getCorrelationId().equals(((RexCorrelVariable) refNode).id))) {
+                throw new TableException(
+                        String.format(
+                                "This is a bug. Planner can not resolve the 
correlation in %s. Please file an issue.",
+                                functionName));
+            }
+            return new FunctionCallUtil.FieldRef(
+                    ((RexFieldAccess) queryColumn).getField().getIndex());
+        } else {
+            throw new TableException(
+                    String.format(
+                            "Expect function %s's parameter column_to_query is 
literal or field reference, but get expression %s. ",
+                            functionName, queryColumn));
+        }
+    }
+
+    /**
+     * Pull up the Calc under the VectorSearchCall.
+     *
+     * <p>Note: The vector search operator actually fetch the data and then do 
the calculation. So
+     * pull up the calc is to align the behaviour.
+     */
+    private RexProgram pullUpRexProgram(
+            RelOptCluster cluster,
+            RelDataType scanOutputType,
+            RelDataType originFunctionCallType,
+            RexProgram originProgram) {
+        RelDataType searchOutputType =
+                cluster.getTypeFactory()
+                        .builder()
+                        .kind(scanOutputType.getStructKind())
+                        .addAll(scanOutputType.getFieldList())
+                        .add(Util.last(originFunctionCallType.getFieldList()))
+                        .build();
+        RelDataType newOutputType =
+                cluster.getTypeFactory()
+                        .builder()
+                        .kind(originProgram.getOutputRowType().getStructKind())
+                        
.addAll(originProgram.getOutputRowType().getFieldList())
+                        .add(Util.last(originFunctionCallType.getFieldList()))
+                        .build();
+        List<RexNode> exprs = new ArrayList<>(originProgram.getExprList());
+        exprs.add(
+                cluster.getRexBuilder()
+                        .makeInputRef(searchOutputType, 
searchOutputType.getFieldCount() - 1));
+        return RexProgramBuilder.create(
+                        cluster.getRexBuilder(),
+                        searchOutputType,
+                        exprs,
+                        originProgram.getProjectList(),
+                        originProgram.getCondition(),
+                        newOutputType,
+                        true,
+                        null)
+                .getProgram();
+    }
+
+    @Value.Immutable
+    public interface Config extends RelRule.Config {
+
+        Config DEFAULT =
+                
ImmutableStreamPhysicalVectorSearchTableFunctionRule.Config.builder()
+                        .build()
+                        .withOperandSupplier(
+                                b0 ->
+                                        b0.operand(FlinkLogicalCorrelate.class)
+                                                .inputs(
+                                                        b1 ->
+                                                                
b1.operand(FlinkLogicalRel.class)
+                                                                        
.anyInputs(),
+                                                        b2 ->
+                                                                b2.operand(
+                                                                               
 FlinkLogicalTableFunctionScan
+                                                                               
         .class)
+                                                                        
.anyInputs()))
+                        
.withDescription("StreamPhysicalVectorSearchTableFunctionRule");
+
+        @Override
+        default StreamPhysicalVectorSearchTableFunctionRule toRule() {
+            return new StreamPhysicalVectorSearchTableFunctionRule(this);
+        }
+    }
+
+    /**
+     * A utility class to extract table source and calc program.
+     *
+     * <p>Supported tree structure:
+     *
+     * <pre>{@code
+     * Calc(without filter) —— TableScan
+     * TableScan
+     * }</pre>
+     */
+    static class SearchTableExtractor {
+
+        enum NodeType {
+            CALC,
+            SCAN
+        }
+
+        @Nullable RexProgram calcProgram;
+        TableSourceTable searchTable;
+
+        private final String functionName;
+        private NodeType parentNode;
+
+        SearchTableExtractor(String functionName) {
+            this.functionName = functionName;
+        }
+
+        private void visit(RelNode rel) {
+            if (rel instanceof RelSubset) {
+                rel = ((RelSubset) rel).getBestOrOriginal();
+            }
+
+            NodeType currentNode = transform(rel);
+            switch (currentNode) {
+                case CALC:
+                    if (parentNode != null) {
+                        throw new RelOptPlanner.CannotPlanException(
+                                String.format(
+                                        "%s assumes calc to be the first node 
in parameter search_table, but it has a parent %s.",

Review Comment:
   I don't think we can add a test here, because the relational structure we 
currently support is too limited — only scan and calc nodes are allowed.
   
   Within that structure, it's also not possible to construct a calc node as a child of a scan node, or to build a tree like calc → calc → scan, because calc nodes are often merged or eliminated during optimization (see CalcRemoveRule and CalcMergeRule). So there is no SQL query that would reach this error branch, which is why a test can't cover it.
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to