This is an automated email from the ASF dual-hosted git repository.
zclll pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 579dd1013b5 [Feature] Add FE constant folding for cosine_similarity
and standardize test patterns (#60403)
579dd1013b5 is described below
commit 579dd1013b53d90cfb46c3b7ca07ca21fad583ed
Author: Copilot <[email protected]>
AuthorDate: Tue Feb 24 21:59:45 2026 +0800
[Feature] Add FE constant folding for cosine_similarity and standardize
test patterns (#60403)
Adds the `cosine_similarity` scalar function (BE implementation, FE function
registration, and FE constant folding) and replaces try-catch blocks in
regression tests with the standard `test{sql, exception}` pattern.
Co-authored-by: copilot-swe-agent[bot]
<[email protected]>
Co-authored-by: zclllyybb <[email protected]>
---
.../functions/array/function_array_distance.cpp | 18 +++
.../vec/functions/array/function_array_distance.h | 6 +
.../function_array_cosine_similarity_test.cpp | 159 +++++++++++++++++++++
.../doris/catalog/BuiltinScalarFunctions.java | 2 +
.../trees/expressions/ExpressionEvaluator.java | 2 +
.../functions/executable/ArrayArithmetic.java | 88 ++++++++++++
.../functions/scalar/CosineSimilarity.java | 75 ++++++++++
.../expressions/visitor/ScalarFunctionVisitor.java | 5 +
.../test_array_distance_functions.out | 58 ++++++++
.../test_array_distance_functions.groovy | 141 ++++++++++++++++--
10 files changed, 539 insertions(+), 15 deletions(-)
diff --git a/be/src/vec/functions/array/function_array_distance.cpp
b/be/src/vec/functions/array/function_array_distance.cpp
index 0d7b40f2700..e070652f492 100644
--- a/be/src/vec/functions/array/function_array_distance.cpp
+++ b/be/src/vec/functions/array/function_array_distance.cpp
@@ -38,10 +38,28 @@ float CosineDistance::distance(const float* x, const float*
y, size_t d) {
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+// Cosine similarity dot(x, y) / (||x|| * ||y||) of two d-element float vectors.
+// Returns 0 when either vector has zero norm, which also covers d == 0.
+float CosineSimilarity::distance(const float* x, const float* y, size_t d) {
+    float dot_prod = 0;
+    float squared_x = 0;
+    float squared_y = 0;
+    for (size_t i = 0; i < d; ++i) {
+        dot_prod += x[i] * y[i];
+        squared_x += x[i] * x[i];
+        squared_y += y[i] * y[i];
+    }
+    if (squared_x == 0 || squared_y == 0) {
+        return 0.0f;
+    }
+    // Multiply the squared norms in double: in float the product overflows to inf
+    // (making the result 0) or underflows to 0 (making it inf/NaN) for vectors whose
+    // individual norms are perfectly representable. sqrt(a * b) lies between
+    // min(a, b) and max(a, b), so narrowing it back to float cannot overflow or hit 0.
+    const auto norm = static_cast<float>(sqrt(static_cast<double>(squared_x) * squared_y));
+    return dot_prod / norm;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
void register_function_array_distance(SimpleFunctionFactory& factory) {
factory.register_function<FunctionArrayDistance<L1Distance>>();
factory.register_function<FunctionArrayDistance<L2Distance>>();
factory.register_function<FunctionArrayDistance<CosineDistance>>();
+ factory.register_function<FunctionArrayDistance<CosineSimilarity>>();
factory.register_function<FunctionArrayDistance<InnerProduct>>();
factory.register_function<FunctionArrayDistance<L2DistanceApproximate>>();
factory.register_function<FunctionArrayDistance<InnerProductApproximate>>();
diff --git a/be/src/vec/functions/array/function_array_distance.h
b/be/src/vec/functions/array/function_array_distance.h
index 946f28b7e50..a3b6fb7adc5 100644
--- a/be/src/vec/functions/array/function_array_distance.h
+++ b/be/src/vec/functions/array/function_array_distance.h
@@ -69,6 +69,12 @@ public:
static float distance(const float* x, const float* y, size_t d);
};
+// Cosine similarity of two float vectors, registered as the SQL function
+// "cosine_similarity". Implemented in function_array_distance.cpp.
+class CosineSimilarity {
+public:
+    static constexpr auto name = "cosine_similarity";
+
+    // dot(x, y) / (||x|| * ||y||) over d elements; 0 if either vector has zero norm.
+    static float distance(const float* x, const float* y, size_t d);
+};
+
class L2DistanceApproximate : public L2Distance {
public:
static constexpr auto name = "l2_distance_approximate";
diff --git a/be/test/vec/function/function_array_cosine_similarity_test.cpp
b/be/test/vec/function/function_array_cosine_similarity_test.cpp
new file mode 100644
index 00000000000..b5f5c734093
--- /dev/null
+++ b/be/test/vec/function/function_array_cosine_similarity_test.cpp
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <string>
+
+#include "function_test_util.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type_number.h"
+
+namespace doris::vectorized {
+
+TEST(function_cosine_similarity_test, cosine_similarity) {
+    std::string func_name = "cosine_similarity";
+    TestArray empty_arr;
+
+    // Every case below calls cosine_similarity(Array<Float>, Array<Float>).
+    InputTypeSet input_types = {PrimitiveType::TYPE_ARRAY, PrimitiveType::TYPE_FLOAT,
+                                PrimitiveType::TYPE_ARRAY, PrimitiveType::TYPE_FLOAT};
+
+    // Evaluates the function on each row of `data_set` via check_function.
+    auto check = [&](DataSet data_set) {
+        static_cast<void>(
+                check_function<DataTypeFloat32, false>(func_name, input_types, data_set));
+    };
+
+    // Identical vectors: similarity 1.
+    {
+        TestArray vec1 = {Float32(1.0), Float32(2.0), Float32(3.0)};
+        TestArray vec2 = {Float32(1.0), Float32(2.0), Float32(3.0)};
+        check({{{vec1, vec2}, Float32(1.0)}});
+    }
+
+    // Orthogonal vectors: similarity 0.
+    {
+        TestArray vec1 = {Float32(1.0), Float32(0.0)};
+        TestArray vec2 = {Float32(0.0), Float32(1.0)};
+        check({{{vec1, vec2}, Float32(0.0)}});
+    }
+
+    // Opposite vectors: similarity -1.
+    {
+        TestArray vec1 = {Float32(1.0), Float32(2.0), Float32(3.0)};
+        TestArray vec2 = {Float32(-1.0), Float32(-2.0), Float32(-3.0)};
+        check({{{vec1, vec2}, Float32(-1.0)}});
+    }
+
+    // A zero vector on either side (or both) yields 0 instead of dividing by zero.
+    {
+        TestArray zero3 = {Float32(0.0), Float32(0.0), Float32(0.0)};
+        TestArray nonzero = {Float32(1.0), Float32(2.0), Float32(3.0)};
+        TestArray zero2 = {Float32(0.0), Float32(0.0)};
+        check({{{zero3, nonzero}, Float32(0.0)},
+               {{nonzero, zero3}, Float32(0.0)},
+               {{zero2, zero2}, Float32(0.0)}});
+    }
+
+    // Empty arrays yield 0.
+    check({{{empty_arr, empty_arr}, Float32(0.0)}});
+
+    // Known value: cos_sim([1,2,3], [3,5,7]) = 34 / sqrt(14 * 83) = 34 / sqrt(1162) ≈ 0.9974149;
+    // the expected value is computed in float.
+    {
+        TestArray vec1 = {Float32(1.0), Float32(2.0), Float32(3.0)};
+        TestArray vec2 = {Float32(3.0), Float32(5.0), Float32(7.0)};
+        float expected = 34.0f / std::sqrt(14.0f * 83.0f);
+        check({{{vec1, vec2}, Float32(expected)}});
+    }
+
+    // 2D: cos_sim([3,4], [4,3]) = 24 / sqrt(25 * 25) = 24 / 25 = 0.96.
+    {
+        TestArray vec1 = {Float32(3.0), Float32(4.0)};
+        TestArray vec2 = {Float32(4.0), Float32(3.0)};
+        check({{{vec1, vec2}, Float32(0.96)}});
+    }
+
+    // Single element, same sign: similarity 1 regardless of magnitude.
+    {
+        TestArray vec1 = {Float32(5.0)};
+        TestArray vec2 = {Float32(10.0)};
+        check({{{vec1, vec2}, Float32(1.0)}});
+    }
+
+    // Negated vector: similarity -1.
+    {
+        TestArray vec1 = {Float32(-1.0), Float32(-2.0)};
+        TestArray vec2 = {Float32(1.0), Float32(2.0)};
+        check({{{vec1, vec2}, Float32(-1.0)}});
+    }
+
+    // Mixed signs, exact negation: similarity -1.
+    {
+        TestArray vec1 = {Float32(1.0), Float32(-1.0), Float32(1.0)};
+        TestArray vec2 = {Float32(-1.0), Float32(1.0), Float32(-1.0)};
+        check({{{vec1, vec2}, Float32(-1.0)}});
+    }
+}
+
+} // namespace doris::vectorized
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
index f9fb681b7ce..c8719dbaeb5 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
@@ -145,6 +145,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.ConvertTz;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cos;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cosh;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.CosineDistance;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.CosineSimilarity;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountEqual;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.CountSubstring;
@@ -713,6 +714,7 @@ public class BuiltinScalarFunctions implements
FunctionHelper {
scalar(Cosh.class, "cosh"),
scalar(Cot.class, "cot"),
scalar(CosineDistance.class, "cosine_distance"),
+ scalar(CosineSimilarity.class, "cosine_similarity"),
scalar(CountEqual.class, "countequal"),
scalar(CountSubstring.class, "count_substrings"),
scalar(CreateMap.class, "map"),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
index 17504ceb8e2..02d35c10d4c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.NotSupportedException;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
+import
org.apache.doris.nereids.trees.expressions.functions.executable.ArrayArithmetic;
import
org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire;
import
org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeArithmetic;
import
org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeExtractAndTransform;
@@ -177,6 +178,7 @@ public enum ExpressionEvaluator {
}
ImmutableMultimap.Builder<String, Method> mapBuilder = new
ImmutableMultimap.Builder<>();
List<Class<?>> classes = ImmutableList.of(
+ ArrayArithmetic.class,
DateTimeAcquire.class,
DateTimeExtractAndTransform.class,
DateLiteral.class,
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/ArrayArithmetic.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/ArrayArithmetic.java
new file mode 100644
index 00000000000..d1cd087fce6
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/ArrayArithmetic.java
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.executable;
+
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.ExecFunction;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+
+import java.util.List;
+
+/**
+ * Executable functions for array operations.
+ */
+public class ArrayArithmetic {
+
+    /**
+     * Constant-folds cosine_similarity(array1, array2) = dot(x, y) / (||x|| * ||y||).
+     *
+     * <p>Elements are accumulated in double and the final quotient is narrowed to float.
+     * Empty arrays and zero vectors fold to 0.
+     *
+     * @param array1 left operand; must not contain null elements
+     * @param array2 right operand; must not contain null elements and must match array1 in length
+     * @return a FloatLiteral holding the similarity
+     * @throws AnalysisException if either array contains a null element or the lengths differ
+     */
+    @ExecFunction(name = "cosine_similarity")
+    public static Expression cosineSimilarity(ArrayLiteral array1, ArrayLiteral array2) {
+        List<Literal> left = array1.getValue();
+        List<Literal> right = array2.getValue();
+
+        // Null elements are rejected before the length check, matching the original order.
+        if (left.stream().anyMatch(NullLiteral.class::isInstance)
+                || right.stream().anyMatch(NullLiteral.class::isInstance)) {
+            throw new AnalysisException("function cosine_similarity cannot have null");
+        }
+
+        int size = left.size();
+        if (size != right.size()) {
+            throw new AnalysisException("function cosine_similarity have different input element sizes of array: "
+                    + size + " and " + right.size());
+        }
+
+        double dot = 0.0;
+        double leftNormSq = 0.0;
+        double rightNormSq = 0.0;
+        for (int i = 0; i < size; i++) {
+            double a = ((Number) left.get(i).getValue()).doubleValue();
+            double b = ((Number) right.get(i).getValue()).doubleValue();
+            dot += a * b;
+            leftNormSq += a * a;
+            rightNormSq += b * b;
+        }
+
+        // Empty arrays land here too: both squared norms are still 0.
+        if (leftNormSq == 0.0 || rightNormSq == 0.0) {
+            return new FloatLiteral(0.0f);
+        }
+        return new FloatLiteral((float) (dot / Math.sqrt(leftNormSq * rightNormSq)));
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineSimilarity.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineSimilarity.java
new file mode 100644
index 00000000000..2c4662ab3b4
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineSimilarity.java
@@ -0,0 +1,75 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.scalar;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
+import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.FloatType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * cosine_similarity function
+ */
+public class CosineSimilarity extends ScalarFunction implements ExplicitlyCastableSignature,
+        BinaryExpression, AlwaysNotNullable {
+
+    /** The single overload: two float arrays in, one float out. */
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            FunctionSignature.ret(FloatType.INSTANCE)
+                    .args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE))
+    );
+
+    /**
+     * Builds cosine_similarity(left, right).
+     */
+    public CosineSimilarity(Expression left, Expression right) {
+        super("cosine_similarity", left, right);
+    }
+
+    /** Rebuilds the call from existing params so withChildren can reuse the bound signature. */
+    private CosineSimilarity(ScalarFunctionParams functionParams) {
+        super(functionParams);
+    }
+
+    /**
+     * Returns the same function applied to the given two children.
+     */
+    @Override
+    public CosineSimilarity withChildren(List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 2);
+        return new CosineSimilarity(getFunctionParams(children));
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitCosineSimilarity(this, context);
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
index 94336d03b92..c32d2157d50 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
@@ -156,6 +156,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.ConvertTz;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cos;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cosh;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.CosineDistance;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.CosineSimilarity;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountEqual;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.CountSubstring;
@@ -1115,6 +1116,10 @@ public interface ScalarFunctionVisitor<R, C> {
return visitScalarFunction(cosineDistance, context);
}
+    /** Visits a cosine_similarity call; by default it is handled like any other scalar function. */
+    default R visitCosineSimilarity(CosineSimilarity cosineSimilarity, C context) {
+        return visitScalarFunction(cosineSimilarity, context);
+    }
+
default R visitCountEqual(CountEqual countequal, C context) {
return visitScalarFunction(countequal, context);
}
diff --git
a/regression-test/data/query_p0/sql_functions/array_functions/test_array_distance_functions.out
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_distance_functions.out
index 021b0d34120..714da1f5f9b 100644
---
a/regression-test/data/query_p0/sql_functions/array_functions/test_array_distance_functions.out
+++
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_distance_functions.out
@@ -29,3 +29,61 @@
-- !sql --
0.0
+-- !cosine_sim_identical --
+1.0
+
+-- !cosine_sim_orthogonal --
+0.0
+
+-- !cosine_sim_opposite --
+-1.0
+
+-- !cosine_sim_float --
+0.9746318
+
+-- !cosine_sim_known --
+0.9974149
+
+-- !cosine_sim_single --
+1.0
+
+-- !cosine_sim_2d --
+0.96
+
+-- !cosine_sim_negative --
+-1.0
+
+-- !cosine_sim_zero_first --
+0.0
+
+-- !cosine_sim_zero_second --
+0.0
+
+-- !cosine_sim_both_zero --
+0.0
+
+-- !cosine_sim_single_zero --
+0.0
+
+-- !cosine_sim_mixed --
+-1.0
+
+-- !cosine_sim_distance_relation --
+1.0
+
+-- !cosine_sim_empty --
+0.0
+
+-- !cosine_sim_large --
+0.98387
+
+-- !cosine_sim_small --
+0.9838699
+
+-- !cosine_sim_table --
+1 1.0
+2 0.0
+3 0.9746318
+4 -1.0
+5 0.96
+
diff --git
a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_distance_functions.groovy
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_distance_functions.groovy
index b97f15ac29c..db1cae31054 100644
---
a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_distance_functions.groovy
+++
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_distance_functions.groovy
@@ -51,40 +51,151 @@ suite("test_array_distance_functions") {
}
// abnormal test cases
- try {
+ test {
sql "SELECT l2_distance([0, 0], [1])"
- } catch (Exception ex) {
- assert("${ex}".contains("function l2_distance have different input
element sizes"))
+ exception "function l2_distance have different input element sizes"
}
- try {
+ test {
sql "SELECT cosine_distance([NULL], [NULL, NULL])"
- } catch (Exception ex) {
- assert("${ex}".contains("function cosine_distance cannot have null"))
+ exception "function cosine_distance cannot have null"
}
// Test cases for the nullable array offset fix
// These cases specifically test scenarios where absolute offsets might
differ
// but actual array sizes are the same (should pass) or different (should
fail)
- try {
+ test {
sql "SELECT l1_distance([1.0, 2.0, 3.0], [4.0, 5.0])"
- } catch (Exception ex) {
- assert("${ex}".contains("function l1_distance have different input
element sizes"))
+ exception "function l1_distance have different input element sizes"
}
- try {
+ test {
sql "SELECT inner_product([1.0], [2.0, 3.0, 4.0])"
- } catch (Exception ex) {
- assert("${ex}".contains("function inner_product have different input
element sizes"))
+ exception "function inner_product have different input element sizes"
}
- try {
+ test {
sql "SELECT l1_distance([1, 2, 3], [0, NULL, 0])"
- } catch (Exception ex) {
- assert("${ex}".contains("function l1_distance cannot have null"))
+ exception "function l1_distance cannot have null"
}
// Edge case: empty arrays should work
qt_sql "SELECT l1_distance(CAST([] as ARRAY<DOUBLE>), CAST([] as
ARRAY<DOUBLE>))"
qt_sql "SELECT l2_distance(CAST([] as ARRAY<DOUBLE>), CAST([] as
ARRAY<DOUBLE>))"
+
+ // =========================
+ // cosine_similarity tests
+ // =========================
+
+ // Basic test: identical vectors have similarity of 1.0
+ qt_cosine_sim_identical "SELECT cosine_similarity([1, 2, 3], [1, 2, 3])"
+
+ // Basic test: orthogonal vectors have similarity of 0.0
+ qt_cosine_sim_orthogonal "SELECT cosine_similarity([1, 0], [0, 1])"
+
+ // Basic test: opposite vectors have similarity of -1.0
+ qt_cosine_sim_opposite "SELECT cosine_similarity([1, 2, 3], [-1, -2, -3])"
+
+ // Test with float arrays
+ qt_cosine_sim_float "SELECT cosine_similarity([1.0, 2.0, 3.0], [4.0, 5.0,
6.0])"
+
+ // Test known value: cos(theta) = (1*3 + 2*5 + 3*7) / (sqrt(14) *
sqrt(83)) = 34 / sqrt(1162) ≈ 0.9974
+ qt_cosine_sim_known "SELECT cosine_similarity([1, 2, 3], [3, 5, 7])"
+
+ // Test with single element arrays
+ qt_cosine_sim_single "SELECT cosine_similarity([5], [10])"
+
+ // Test with 2D vectors
+ qt_cosine_sim_2d "SELECT cosine_similarity([3, 4], [4, 3])"
+
+ // Test with negative values
+ qt_cosine_sim_negative "SELECT cosine_similarity([-1, -2], [1, 2])"
+
+ // Test zero vector handling: returns 0.0 when either vector is zero
+ qt_cosine_sim_zero_first "SELECT cosine_similarity([0, 0, 0], [1, 2, 3])"
+ qt_cosine_sim_zero_second "SELECT cosine_similarity([1, 2, 3], [0, 0, 0])"
+ qt_cosine_sim_both_zero "SELECT cosine_similarity([0, 0], [0, 0])"
+ qt_cosine_sim_single_zero "SELECT cosine_similarity([0], [0])"
+
+ // Test with mixed positive and negative
+ qt_cosine_sim_mixed "SELECT cosine_similarity([1, -1, 1], [-1, 1, -1])"
+
+ // Test relationship with cosine_distance: cosine_similarity = 1 -
cosine_distance
+ // For non-zero vectors, these should sum to 1.0
+ qt_cosine_sim_distance_relation "SELECT cosine_similarity([1, 2, 3], [3,
5, 7]) + cosine_distance([1, 2, 3], [3, 5, 7])"
+
+ // Test empty arrays
+ qt_cosine_sim_empty "SELECT cosine_similarity(CAST([] as ARRAY<FLOAT>),
CAST([] as ARRAY<FLOAT>))"
+
+ // Test NULL handling: should throw exception
+ test {
+ sql "SELECT cosine_similarity([1, 2, 3], NULL)"
+ exception "function cosine_similarity cannot be null"
+ }
+
+ test {
+ sql "SELECT cosine_similarity(NULL, [1, 2, 3])"
+ exception "function cosine_similarity cannot be null"
+ }
+
+ test {
+ sql "SELECT cosine_similarity(NULL, NULL)"
+ exception "function cosine_similarity cannot be null"
+ }
+
+ // Test array with NULL element: should throw exception
+ test {
+ sql "SELECT cosine_similarity([1, NULL, 3], [4, 5, 6])"
+ exception "function cosine_similarity cannot have null"
+ }
+
+ test {
+ sql "SELECT cosine_similarity([1, 2, 3], [4, NULL, 6])"
+ exception "function cosine_similarity cannot have null"
+ }
+
+ // Test different array sizes: should throw exception
+ test {
+ sql "SELECT cosine_similarity([1, 2], [1, 2, 3])"
+ exception "function cosine_similarity have different input element
sizes"
+ }
+
+ test {
+ sql "SELECT cosine_similarity([1, 2, 3, 4], [1, 2])"
+ exception "function cosine_similarity have different input element
sizes"
+ }
+
+ // Test large values
+ qt_cosine_sim_large "SELECT cosine_similarity([1000000, 2000000],
[3000000, 4000000])"
+
+ // Test small values
+ qt_cosine_sim_small "SELECT cosine_similarity([0.001, 0.002], [0.003,
0.004])"
+
+ // Test with multiple rows using table
+ sql "DROP TABLE IF EXISTS test_cosine_similarity_table"
+ sql """
+ CREATE TABLE test_cosine_similarity_table (
+ id INT,
+ vec1 ARRAY<FLOAT>,
+ vec2 ARRAY<FLOAT>
+ ) ENGINE=OLAP
+ DUPLICATE KEY(id)
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES (
+ "replication_num" = "1"
+ )
+ """
+
+ sql """
+ INSERT INTO test_cosine_similarity_table VALUES
+ (1, [1, 0, 0], [1, 0, 0]),
+ (2, [1, 0, 0], [0, 1, 0]),
+ (3, [1, 2, 3], [4, 5, 6]),
+ (4, [1, 1], [-1, -1]),
+ (5, [3, 4], [4, 3])
+ """
+
+ qt_cosine_sim_table "SELECT id, cosine_similarity(vec1, vec2) as
similarity FROM test_cosine_similarity_table ORDER BY id"
+
+ sql "DROP TABLE IF EXISTS test_cosine_similarity_table"
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]