This is an automated email from the ASF dual-hosted git repository.
shengkai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 56bf7c887fd [FLINK-38424][planner] Support to parse VECTOR_SEARCH
function (#27039)
56bf7c887fd is described below
commit 56bf7c887fd4470d926afc6690550833739727b5
Author: Shengkai <[email protected]>
AuthorDate: Thu Oct 16 11:15:20 2025 +0800
[FLINK-38424][planner] Support to parse VECTOR_SEARCH function (#27039)
---
.../calcite/sql/validate/SqlValidatorImpl.java | 37 +++-
.../functions/sql/FlinkSqlOperatorTable.java | 4 +
.../sql/ml/SqlVectorSearchTableFunction.java | 239 +++++++++++++++++++++
.../planner/functions/utils/SqlValidatorUtils.java | 22 +-
.../stream/sql/VectorSearchTableFunctionTest.java | 224 +++++++++++++++++++
.../stream/sql/VectorSearchTableFunctionTest.xml | 141 ++++++++++++
6 files changed, 653 insertions(+), 14 deletions(-)
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
index 624fb7e6d7b..89317e352f6 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
@@ -16,7 +16,9 @@
*/
package org.apache.calcite.sql.validate;
+import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@@ -177,10 +179,12 @@ import static org.apache.calcite.util.Util.first;
*
* <p>Lines 2571 ~ 2588, CALCITE-7217, should be removed after upgrading
Calcite to 1.41.0.
*
- * <p>Lines 3895 ~ 3899, 6574 ~ 6580 Flink improves Optimize the retrieval of
sub-operands in
+ * <p>Lines 2618 ~ 2631, set the correct scope for VECTOR_SEARCH.
+ *
+ * <p>Lines 3920 ~ 3925, 6599 ~ 6606 Flink optimizes the retrieval of
sub-operands in
* SqlCall when using NamedParameters at {@link SqlValidatorImpl#checkRollUp}.
*
- * <p>Lines 5315 ~ 5321, FLINK-24352 Add null check for temporal table check
on SqlSnapshot.
+ * <p>Lines 5340 ~ 5347, FLINK-24352 Add null check for temporal table check
on SqlSnapshot.
*/
public class SqlValidatorImpl implements SqlValidatorWithHints {
// ~ Static fields/initializers
---------------------------------------------
@@ -2570,6 +2574,10 @@ public class SqlValidatorImpl implements
SqlValidatorWithHints {
case LATERAL:
// ----- FLINK MODIFICATION BEGIN -----
SqlBasicCall sbc = (SqlBasicCall) node;
+ // Put the usingScope which is a JoinScope,
+ // in order to make visible the left items
+ // of the JOIN tree.
+ scopes.put(node, usingScope);
registerFrom(
parentScope,
usingScope,
@@ -2580,10 +2588,6 @@ public class SqlValidatorImpl implements
SqlValidatorWithHints {
extendList,
forceNullable,
true);
- // Put the usingScope which is a JoinScope,
- // in order to make visible the left items
- // of the JOIN tree.
- scopes.put(node, usingScope);
return sbc;
// ----- FLINK MODIFICATION END -----
@@ -2614,6 +2618,27 @@ public class SqlValidatorImpl implements
SqlValidatorWithHints {
scopes.put(node, getSelectScope(call1.operand(0)));
return newNode;
}
+
+ // Related to CALCITE-4077
+ // ----- FLINK MODIFICATION BEGIN -----
+ FlinkSqlCallBinding binding =
+ new FlinkSqlCallBinding(this, getEmptyScope(),
call1);
+ if (op instanceof SqlVectorSearchTableFunction
+ && binding.operand(0)
+ .isA(
+ new HashSet<>(
+
Collections.singletonList(SqlKind.SELECT)))) {
+ boolean queryColumnIsNotLiteral =
+ binding.operand(2).getKind() !=
SqlKind.LITERAL;
+ if (!queryColumnIsNotLiteral && !lateral) {
+ throw new ValidationException(
+ "The query column is not literal, please
use LATERAL TABLE to run VECTOR_SEARCH.");
+ }
+ SqlValidatorScope scope = getSelectScope((SqlSelect)
binding.operand(0));
+ scopes.put(enclosingNode, scope);
+ return newNode;
+ }
+ // ----- FLINK MODIFICATION END -----
}
// Put the usingScope which can be a JoinScope
// or a SelectScope, in order to see the left items
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
index 5ddbebd98c9..4469b376420 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
@@ -22,6 +22,7 @@ import org.apache.flink.table.api.TableException;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import
org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction;
import
org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
import org.apache.flink.table.planner.plan.type.FlinkReturnTypes;
import
org.apache.flink.table.planner.plan.type.NumericExceptFirstOperandChecker;
@@ -1328,6 +1329,9 @@ public class FlinkSqlOperatorTable extends
ReflectiveSqlOperatorTable {
// MODEL TABLE FUNCTIONS
public static final SqlFunction ML_EVALUATE = new
SqlMLEvaluateTableFunction();
+ // SEARCH FUNCTIONS
+ public static final SqlFunction VECTOR_SEARCH = new
SqlVectorSearchTableFunction();
+
// Catalog Functions
public static final SqlFunction CURRENT_DATABASE =
BuiltInSqlFunction.newBuilder()
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java
new file mode 100644
index 00000000000..a655efdf9f0
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.sql.ml;
+
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
+
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlCallBinding;
+import org.apache.calcite.sql.SqlFunction;
+import org.apache.calcite.sql.SqlFunctionCategory;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperandCountRange;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.SqlOperatorBinding;
+import org.apache.calcite.sql.SqlTableFunction;
+import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlOperandCountRanges;
+import org.apache.calcite.sql.type.SqlOperandMetadata;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.validate.SqlNameMatcher;
+import org.apache.calcite.util.Util;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+import static
org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
+
+/**
+ * {@link SqlVectorSearchTableFunction} implements an operator for search.
+ *
+ * <p>It allows four parameters:
+ *
+ * <ol>
+ * <li>a table
+ * <li>a descriptor to provide a column name from the input table
+ * <li>a query column from the left table
+ * <li>a literal value for top k
+ * </ol>
+ */
+public class SqlVectorSearchTableFunction extends SqlFunction implements
SqlTableFunction {
+
+ private static final String PARAM_SEARCH_TABLE = "SEARCH_TABLE";
+ private static final String PARAM_COLUMN_TO_SEARCH = "COLUMN_TO_SEARCH";
+ private static final String PARAM_COLUMN_TO_QUERY = "COLUMN_TO_QUERY";
+ private static final String PARAM_TOP_K = "TOP_K";
+
+ private static final String OUTPUT_SCORE = "score";
+
+ public SqlVectorSearchTableFunction() {
+ super(
+ "VECTOR_SEARCH",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.CURSOR,
+ null,
+ new OperandMetadataImpl(),
+ SqlFunctionCategory.SYSTEM);
+ }
+
+ @Override
+ public SqlReturnTypeInference getRowTypeInference() {
+ return new SqlReturnTypeInference() {
+ @Override
+ public @Nullable RelDataType inferReturnType(SqlOperatorBinding
opBinding) {
+ final RelDataTypeFactory typeFactory =
opBinding.getTypeFactory();
+ final RelDataType inputRowType = opBinding.getOperandType(0);
+
+ return typeFactory
+ .builder()
+ .kind(inputRowType.getStructKind())
+ .addAll(inputRowType.getFieldList())
+ .addAll(
+ SqlValidatorUtils.makeOutputUnique(
+ inputRowType.getFieldList(),
+ Collections.singletonList(
+ new RelDataTypeFieldImpl(
+ OUTPUT_SCORE,
+ 0,
+
typeFactory.createSqlType(
+
SqlTypeName.DOUBLE)))))
+ .build();
+ }
+ };
+ }
+
+ @Override
+ public boolean argumentMustBeScalar(int ordinal) {
+ return ordinal != 0;
+ }
+
+ private static class OperandMetadataImpl implements SqlOperandMetadata {
+
+ private static final List<String> PARAMETERS =
+ Collections.unmodifiableList(
+ Arrays.asList(
+ PARAM_SEARCH_TABLE,
+ PARAM_COLUMN_TO_SEARCH,
+ PARAM_COLUMN_TO_QUERY,
+ PARAM_TOP_K));
+
+ @Override
+ public List<RelDataType> paramTypes(RelDataTypeFactory
relDataTypeFactory) {
+ return Collections.nCopies(
+ PARAMETERS.size(),
relDataTypeFactory.createSqlType(SqlTypeName.ANY));
+ }
+
+ @Override
+ public List<String> paramNames() {
+ return PARAMETERS;
+ }
+
+ @Override
+ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean
throwOnFailure) {
+ // check vector table contains descriptor columns
+ if
(!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) {
+ return
SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
+ callBinding, throwOnFailure);
+ }
+
+ List<SqlNode> operands = callBinding.operands();
+ // check descriptor has one column
+ SqlCall descriptor = (SqlCall) operands.get(1);
+ List<SqlNode> descriptorCols = descriptor.getOperandList();
+ if (descriptorCols.size() != 1) {
+ return SqlValidatorUtils.throwExceptionOrReturnFalse(
+ Optional.of(
+ new ValidationException(
+ String.format(
+ "Expect parameter
COLUMN_TO_SEARCH for VECTOR_SEARCH only contains one column, but multiple
columns are found in operand %s.",
+ descriptor))),
+ throwOnFailure);
+ }
+
+ // check descriptor type is ARRAY<FLOAT> or ARRAY<DOUBLE>
+ RelDataType searchTableType = callBinding.getOperandType(0);
+ SqlNameMatcher matcher =
callBinding.getValidator().getCatalogReader().nameMatcher();
+ SqlIdentifier columnName = (SqlIdentifier) descriptorCols.get(0);
+ String descriptorColName =
+ columnName.isSimple() ? columnName.getSimple() :
Util.last(columnName.names);
+ int index = matcher.indexOf(searchTableType.getFieldNames(),
descriptorColName);
+ RelDataType targetType =
searchTableType.getFieldList().get(index).getType();
+ LogicalType targetLogicalType = toLogicalType(targetType);
+
+ if (!(targetLogicalType.is(LogicalTypeRoot.ARRAY)
+ && ((ArrayType) (targetLogicalType))
+ .getElementType()
+ .isAnyOf(LogicalTypeRoot.FLOAT,
LogicalTypeRoot.DOUBLE))) {
+ return SqlValidatorUtils.throwExceptionOrReturnFalse(
+ Optional.of(
+ new ValidationException(
+ String.format(
+ "Expect search column `%s`
type is ARRAY<FLOAT> or ARRAY<DOUBLE>, but its type is %s.",
+ columnName, targetType))),
+ throwOnFailure);
+ }
+
+ // check query type is ARRAY<FLOAT> or ARRAY<DOUBLE>
+ LogicalType sourceLogicalType =
toLogicalType(callBinding.getOperandType(2));
+ if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType,
targetLogicalType)) {
+ return SqlValidatorUtils.throwExceptionOrReturnFalse(
+ Optional.of(
+ new ValidationException(
+ String.format(
+ "Can not cast the query column
type %s to target type %s. Please keep the query column type is same to the
search column type.",
+ sourceLogicalType,
targetType))),
+ throwOnFailure);
+ }
+
+ // check topK is literal
+ LogicalType topKType =
toLogicalType(callBinding.getOperandType(3));
+ if (!operands.get(3).getKind().equals(SqlKind.LITERAL)
+ || !topKType.is(LogicalTypeRoot.INTEGER)) {
+ return SqlValidatorUtils.throwExceptionOrReturnFalse(
+ Optional.of(
+ new ValidationException(
+ String.format(
+ "Expect parameter topK is
integer literal in VECTOR_SEARCH, but it is %s with type %s.",
+ operands.get(3), topKType))),
+ throwOnFailure);
+ }
+
+ return true;
+ }
+
+ @Override
+ public SqlOperandCountRange getOperandCountRange() {
+ return SqlOperandCountRanges.between(4, 4);
+ }
+
+ @Override
+ public String getAllowedSignatures(SqlOperator op, String opName) {
+ return opName + "(TABLE table_name, DESCRIPTOR(query_column),
search_column, top_k)";
+ }
+
+ @Override
+ public Consistency getConsistency() {
+ return Consistency.NONE;
+ }
+
+ @Override
+ public boolean isOptional(int i) {
+ return false;
+ }
+
+ @Override
+ public boolean isFixedParameters() {
+ return true;
+ }
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java
index 42e381a6062..66b58499e09 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java
@@ -160,27 +160,33 @@ public class SqlValidatorUtils {
/**
* Make output field names unique from input field names by appending
index. For example, Input
* has field names {@code a, b, c} and output has field names {@code b, c,
d}. After calling
- * this function, new output field names will be {@code b0, c0, d}.
Duplicate names are not
- * checked inside input and output itself.
+ * this function, new output field names will be {@code b0, c0, d}.
+ *
+ * <p>We assume that input fields in the input parameter are uniquely
named, just as the output
+ * fields in the output parameter are.
*
* @param input Input fields
* @param output Output fields
- * @return
+ * @return output fields with unique names.
*/
public static List<RelDataTypeField> makeOutputUnique(
List<RelDataTypeField> input, List<RelDataTypeField> output) {
- final Set<String> inputFieldNames = new HashSet<>();
+ final Set<String> uniqueNames = new HashSet<>();
for (RelDataTypeField field : input) {
- inputFieldNames.add(field.getName());
+ uniqueNames.add(field.getName());
}
List<RelDataTypeField> result = new ArrayList<>();
for (RelDataTypeField field : output) {
String fieldName = field.getName();
- if (inputFieldNames.contains(fieldName)) {
- fieldName += "0"; // Append index to make it unique
+ int count = 0;
+ String candidate = fieldName;
+ while (uniqueNames.contains(candidate)) {
+ candidate = fieldName + count;
+ count++;
}
- result.add(new RelDataTypeFieldImpl(fieldName, field.getIndex(),
field.getType()));
+ uniqueNames.add(candidate);
+ result.add(new RelDataTypeFieldImpl(candidate, field.getIndex(),
field.getType()));
}
return result;
}
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
new file mode 100644
index 00000000000..5d85e6b88eb
--- /dev/null
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.stream.sql;
+
+import org.apache.flink.core.testutils.FlinkAssertions;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.api.ValidationException;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
+import org.apache.flink.table.planner.utils.TableTestBase;
+import org.apache.flink.table.planner.utils.TableTestUtil;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
/** Test for {@link SqlVectorSearchTableFunction}. */
public class VectorSearchTableFunctionTest extends TableTestBase {

    private TableTestUtil util;

    @BeforeEach
    public void setup() {
        util = streamTestUtil(TableConfig.getDefault());

        // Create test table
        // QueryTable is the "left" input supplying the query vector column `d`,
        // plus event-time/processing-time attributes.
        util.tableEnv()
                .executeSql(
                        "CREATE TABLE QueryTable (\n"
                                + "  a INT,\n"
                                + "  b BIGINT,\n"
                                + "  c STRING,\n"
                                + "  d ARRAY<FLOAT>,\n"
                                + "  rowtime TIMESTAMP(3),\n"
                                + "  proctime as PROCTIME(),\n"
                                + "  WATERMARK FOR rowtime AS rowtime - INTERVAL '1' SECOND\n"
                                + ") with (\n"
                                + " 'connector' = 'values'\n"
                                + ")");

        // VectorTable is the search table; `g` is the vector column addressed by
        // DESCRIPTOR in the tests below.
        util.tableEnv()
                .executeSql(
                        "CREATE TABLE VectorTable (\n"
                                + "  e INT,\n"
                                + "  f BIGINT,\n"
                                + "  g ARRAY<FLOAT>\n"
                                + ") with (\n"
                                + " 'connector' = 'values'\n"
                                + ")");
    }

    @Test
    void testSimple() {
        // Correlated query column (QueryTable.d) together with LATERAL: valid plan.
        String sql =
                "SELECT * FROM QueryTable, LATERAL TABLE(\n"
                        + "VECTOR_SEARCH(\n"
                        + "  TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, 10"
                        + ")\n"
                        + ")";
        util.verifyRelPlan(sql);
    }

    @Test
    void testLiteralValue() {
        // A literal query vector passes validation even though the row comes from
        // LATERAL TABLE. NOTE(review): the plan then fails with a TableException
        // during optimization — presumably physical translation of VECTOR_SEARCH
        // is not supported yet; confirm.
        String sql =
                "SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
        assertThatThrownBy(() -> util.verifyRelPlan(sql))
                .satisfies(
                        FlinkAssertions.anyCauseMatches(
                                TableException.class,
                                "FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n"
                                        + "+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, VectorTable]], fields=[e, f, g])"));
    }

    @Test
    void testLiteralValueWithoutLateralKeyword() {
        // Same literal query vector without the LATERAL keyword: validation also
        // succeeds; the failure surfaces later as the same optimizer TableException.
        String sql =
                "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
        assertThatThrownBy(() -> util.verifyRelPlan(sql))
                .satisfies(
                        FlinkAssertions.anyCauseMatches(
                                TableException.class,
                                "FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n"
                                        + "+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, VectorTable]], fields=[e, f, g])"));
    }

    @Test
    void testNamedArgument() {
        // All four operands passed by name (SEARCH_TABLE, COLUMN_TO_QUERY,
        // COLUMN_TO_SEARCH, TOP_K).
        String sql =
                "SELECT * FROM QueryTable, LATERAL TABLE(\n"
                        + "VECTOR_SEARCH(\n"
                        + "  SEARCH_TABLE => TABLE VectorTable,\n"
                        + "  COLUMN_TO_QUERY => QueryTable.d,\n"
                        + "  COLUMN_TO_SEARCH => DESCRIPTOR(`g`),\n"
                        + "  TOP_K => 10"
                        + " )\n"
                        + ")";
        util.verifyRelPlan(sql);
    }

    @Test
    void testOutOfOrderNamedArgument() {
        // Named arguments in a different order than the declared parameter list.
        String sql =
                "SELECT * FROM QueryTable, LATERAL TABLE(\n"
                        + "VECTOR_SEARCH(\n"
                        + "  COLUMN_TO_QUERY => QueryTable.d,\n"
                        + "  COLUMN_TO_SEARCH => DESCRIPTOR(`g`),\n"
                        + "  TOP_K => 10,\n"
                        + "  SEARCH_TABLE => TABLE VectorTable\n"
                        + "  )\n"
                        + ")";
        util.verifyRelPlan(sql);
    }

    @Test
    void testNameConflicts() {
        // The search table already has columns score/score0/score1, so the
        // appended similarity column must be renamed to the next free name
        // (score2) — exercises SqlValidatorUtils.makeOutputUnique.
        util.tableEnv()
                .executeSql(
                        "CREATE TABLE NameConflictTable(\n"
                                + "  a INT,\n"
                                + "  score ARRAY<FLOAT>,\n"
                                + "  score0 ARRAY<FLOAT>,\n"
                                + "  score1 ARRAY<FLOAT>\n"
                                + ") WITH (\n"
                                + "  'connector' = 'values'\n"
                                + ")");
        util.verifyRelPlan(
                "SELECT * FROM QueryTable, LATERAL TABLE(\n"
                        + "VECTOR_SEARCH(\n"
                        + "  TABLE NameConflictTable, DESCRIPTOR(`score`), QueryTable.d, 10))");
    }

    @Test
    void testDescriptorTypeIsNotExpected() {
        // Descriptor column `f` is BIGINT, not ARRAY<FLOAT>/ARRAY<DOUBLE>.
        String sql =
                "SELECT * FROM QueryTable, LATERAL TABLE(\n"
                        + "VECTOR_SEARCH(\n"
                        + "  TABLE VectorTable, DESCRIPTOR(`f`), QueryTable.d, 10"
                        + ")\n"
                        + ")";
        assertThatThrownBy(() -> util.verifyRelPlan(sql))
                .satisfies(
                        FlinkAssertions.anyCauseMatches(
                                ValidationException.class,
                                "Expect search column `f` type is ARRAY<FLOAT> or ARRAY<DOUBLE>, but its type is BIGINT."));
    }

    @Test
    void testDescriptorContainsMultipleColumns() {
        // COLUMN_TO_SEARCH must reference exactly one column.
        String sql =
                "SELECT * FROM QueryTable, LATERAL TABLE(\n"
                        + "VECTOR_SEARCH(\n"
                        + "  TABLE VectorTable, DESCRIPTOR(`f`, `g`), QueryTable.d, 10"
                        + ")\n"
                        + ")";
        assertThatThrownBy(() -> util.verifyRelPlan(sql))
                .satisfies(
                        FlinkAssertions.anyCauseMatches(
                                ValidationException.class,
                                "Expect parameter COLUMN_TO_SEARCH for VECTOR_SEARCH only contains one column, but multiple columns are found in operand DESCRIPTOR(`f`, `g`)."));
    }

    @Test
    void testQueryColumnIsNotArray() {
        // Query column `c` is STRING and cannot be implicitly cast to the search
        // column's ARRAY<FLOAT> type.
        String sql =
                "SELECT * FROM QueryTable, LATERAL TABLE(\n"
                        + "VECTOR_SEARCH(\n"
                        + "  TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.c, 10"
                        + ")\n"
                        + ")";
        assertThatThrownBy(() -> util.verifyRelPlan(sql))
                .satisfies(
                        FlinkAssertions.anyCauseMatches(
                                ValidationException.class,
                                "Can not cast the query column type STRING to target type FLOAT ARRAY. Please keep the query column type is same to the search column type."));
    }

    @Test
    void testIllegalTopKValue1() {
        // TOP_K is a literal but not of an integer type (10.0 is DECIMAL).
        String sql =
                "SELECT * FROM QueryTable, LATERAL TABLE(\n"
                        + "VECTOR_SEARCH(\n"
                        + "  TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, 10.0"
                        + ")\n"
                        + ")";
        assertThatThrownBy(() -> util.verifyRelPlan(sql))
                .satisfies(
                        FlinkAssertions.anyCauseMatches(
                                ValidationException.class,
                                "Expect parameter topK is integer literal in VECTOR_SEARCH, but it is 10.0 with type DECIMAL(3, 1) NOT NULL."));
    }

    @Test
    void testIllegalTopKValue2() {
        // TOP_K is an integer-typed column reference, not a literal.
        String sql =
                "SELECT * FROM QueryTable, LATERAL TABLE(\n"
                        + "VECTOR_SEARCH(\n"
                        + "  TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, QueryTable.a"
                        + ")\n"
                        + ")";
        assertThatThrownBy(() -> util.verifyRelPlan(sql))
                .satisfies(
                        FlinkAssertions.anyCauseMatches(
                                ValidationException.class,
                                "Expect parameter topK is integer literal in VECTOR_SEARCH, but it is QueryTable.a with type INT."));
    }
}
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
new file mode 100644
index 00000000000..8aca81dc52d
--- /dev/null
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml
@@ -0,0 +1,141 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testNameConflicts">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(
+VECTOR_SEARCH(
+ TABLE NameConflictTable, DESCRIPTOR(`score`), QueryTable.d, 10))]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], proctime=[$5],
a0=[$6], score=[$7], score0=[$8], score1=[$9], score2=[$10])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{3}])
+ :- LogicalWatermarkAssigner(rowtime=[rowtime], watermark=[-($4,
1000:INTERVAL SECOND)])
+ : +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4],
proctime=[PROCTIME()])
+ : +- LogicalTableScan(table=[[default_catalog, default_database,
QueryTable]])
+ +- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0),
DESCRIPTOR(_UTF-16LE'score'), $cor0.d, 10)], rowType=[RecordType(INTEGER a,
FLOAT ARRAY score, FLOAT ARRAY score0, FLOAT ARRAY score1, DOUBLE score2)])
+ +- LogicalProject(a=[$0], score=[$1], score0=[$2], score1=[$3])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
NameConflictTable]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[a, b, c, d, rowtime, PROCTIME_MATERIALIZE(proctime) AS proctime,
a0, score, score0, score1, score2])
++- Correlate(invocation=[VECTOR_SEARCH(TABLE(#0),
DESCRIPTOR(_UTF-16LE'score'), $cor0.d, 10)],
correlate=[table(VECTOR_SEARCH(TABLE(),DESCRIPTOR('score'),$cor0.d,10))],
select=[a,b,c,d,rowtime,proctime,a0,score,score0,score1,score2],
rowType=[RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, FLOAT ARRAY d,
TIMESTAMP(3) *ROWTIME* rowtime, TIMESTAMP_LTZ(3) *PROCTIME* proctime, INTEGER
a0, FLOAT ARRAY score, FLOAT ARRAY score0, FLOAT ARRAY score1, DOUBLE score2)],
joinType=[INNER])
+ +- WatermarkAssigner(rowtime=[rowtime], watermark=[-(rowtime, 1000:INTERVAL
SECOND)])
+ +- Calc(select=[a, b, c, d, rowtime, PROCTIME() AS proctime])
+ +- TableSourceScan(table=[[default_catalog, default_database,
QueryTable]], fields=[a, b, c, d, rowtime])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testNamedArgument">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(
+VECTOR_SEARCH(
+ SEARCH_TABLE => TABLE VectorTable,
+ COLUMN_TO_QUERY => QueryTable.d,
+ COLUMN_TO_SEARCH => DESCRIPTOR(`g`),
+ TOP_K => 10 )
+)]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], proctime=[$5],
e=[$6], f=[$7], g=[$8], score=[$9])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{3}])
+ :- LogicalWatermarkAssigner(rowtime=[rowtime], watermark=[-($4,
1000:INTERVAL SECOND)])
+ : +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4],
proctime=[PROCTIME()])
+ : +- LogicalTableScan(table=[[default_catalog, default_database,
QueryTable]])
+ +- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0),
DESCRIPTOR(_UTF-16LE'g'), $cor0.d, 10)], rowType=[RecordType(INTEGER e, BIGINT
f, FLOAT ARRAY g, DOUBLE score)])
+ +- LogicalProject(e=[$0], f=[$1], g=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
VectorTable]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[a, b, c, d, rowtime, PROCTIME_MATERIALIZE(proctime) AS proctime,
e, f, g, score])
++- Correlate(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'),
$cor0.d, 10)],
correlate=[table(VECTOR_SEARCH(TABLE(),DESCRIPTOR('g'),$cor0.d,10))],
select=[a,b,c,d,rowtime,proctime,e,f,g,score], rowType=[RecordType(INTEGER a,
BIGINT b, VARCHAR(2147483647) c, FLOAT ARRAY d, TIMESTAMP(3) *ROWTIME* rowtime,
TIMESTAMP_LTZ(3) *PROCTIME* proctime, INTEGER e, BIGINT f, FLOAT ARRAY g,
DOUBLE score)], joinType=[INNER])
+ +- WatermarkAssigner(rowtime=[rowtime], watermark=[-(rowtime, 1000:INTERVAL
SECOND)])
+ +- Calc(select=[a, b, c, d, rowtime, PROCTIME() AS proctime])
+ +- TableSourceScan(table=[[default_catalog, default_database,
QueryTable]], fields=[a, b, c, d, rowtime])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testSimple">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(
+VECTOR_SEARCH(
+ TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, 10)
+)]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], proctime=[$5],
e=[$6], f=[$7], g=[$8], score=[$9])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{3}])
+ :- LogicalWatermarkAssigner(rowtime=[rowtime], watermark=[-($4,
1000:INTERVAL SECOND)])
+ : +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4],
proctime=[PROCTIME()])
+ : +- LogicalTableScan(table=[[default_catalog, default_database,
QueryTable]])
+ +- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0),
DESCRIPTOR(_UTF-16LE'g'), $cor0.d, 10)], rowType=[RecordType(INTEGER e, BIGINT
f, FLOAT ARRAY g, DOUBLE score)])
+ +- LogicalProject(e=[$0], f=[$1], g=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
VectorTable]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[a, b, c, d, rowtime, PROCTIME_MATERIALIZE(proctime) AS proctime,
e, f, g, score])
++- Correlate(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'),
$cor0.d, 10)],
correlate=[table(VECTOR_SEARCH(TABLE(),DESCRIPTOR('g'),$cor0.d,10))],
select=[a,b,c,d,rowtime,proctime,e,f,g,score], rowType=[RecordType(INTEGER a,
BIGINT b, VARCHAR(2147483647) c, FLOAT ARRAY d, TIMESTAMP(3) *ROWTIME* rowtime,
TIMESTAMP_LTZ(3) *PROCTIME* proctime, INTEGER e, BIGINT f, FLOAT ARRAY g,
DOUBLE score)], joinType=[INNER])
+ +- WatermarkAssigner(rowtime=[rowtime], watermark=[-(rowtime, 1000:INTERVAL
SECOND)])
+ +- Calc(select=[a, b, c, d, rowtime, PROCTIME() AS proctime])
+ +- TableSourceScan(table=[[default_catalog, default_database,
QueryTable]], fields=[a, b, c, d, rowtime])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testOutOfOrderNamedArgument">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(
+VECTOR_SEARCH(
+ COLUMN_TO_QUERY => QueryTable.d,
+ COLUMN_TO_SEARCH => DESCRIPTOR(`g`),
+ TOP_K => 10,
+ SEARCH_TABLE => TABLE VectorTable
+ )
+)]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4], proctime=[$5],
e=[$6], f=[$7], g=[$8], score=[$9])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{3}])
+ :- LogicalWatermarkAssigner(rowtime=[rowtime], watermark=[-($4,
1000:INTERVAL SECOND)])
+ : +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], rowtime=[$4],
proctime=[PROCTIME()])
+ : +- LogicalTableScan(table=[[default_catalog, default_database,
QueryTable]])
+ +- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0),
DESCRIPTOR(_UTF-16LE'g'), $cor0.d, 10)], rowType=[RecordType(INTEGER e, BIGINT
f, FLOAT ARRAY g, DOUBLE score)])
+ +- LogicalProject(e=[$0], f=[$1], g=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
VectorTable]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[a, b, c, d, rowtime, PROCTIME_MATERIALIZE(proctime) AS proctime,
e, f, g, score])
++- Correlate(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'),
$cor0.d, 10)],
correlate=[table(VECTOR_SEARCH(TABLE(),DESCRIPTOR('g'),$cor0.d,10))],
select=[a,b,c,d,rowtime,proctime,e,f,g,score], rowType=[RecordType(INTEGER a,
BIGINT b, VARCHAR(2147483647) c, FLOAT ARRAY d, TIMESTAMP(3) *ROWTIME* rowtime,
TIMESTAMP_LTZ(3) *PROCTIME* proctime, INTEGER e, BIGINT f, FLOAT ARRAY g,
DOUBLE score)], joinType=[INNER])
+ +- WatermarkAssigner(rowtime=[rowtime], watermark=[-(rowtime, 1000:INTERVAL
SECOND)])
+ +- Calc(select=[a, b, c, d, rowtime, PROCTIME() AS proctime])
+ +- TableSourceScan(table=[[default_catalog, default_database,
QueryTable]], fields=[a, b, c, d, rowtime])
+]]>
+ </Resource>
+ </TestCase>
+</Root>