This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 26d9082b9a5bb098901187daca69aa407827fec4
Author: Jensen <[email protected]>
AuthorDate: Wed Apr 10 14:26:48 2024 +0800

    [Feature](function) Add function strcmp (#33272)
---
 be/src/vec/functions/function_string.cpp           |  1 +
 be/src/vec/functions/function_string.h             | 62 +++++++++++++++++++
 be/test/vec/function/function_string_test.cpp      | 66 ++++++++++++++++++++
 .../doris/catalog/BuiltinScalarFunctions.java      |  2 +
 .../trees/expressions/functions/scalar/Strcmp.java | 70 ++++++++++++++++++++++
 .../expressions/visitor/ScalarFunctionVisitor.java |  5 ++
 gensrc/script/doris_builtins_functions.py          |  5 +-
 7 files changed, 210 insertions(+), 1 deletion(-)

diff --git a/be/src/vec/functions/function_string.cpp 
b/be/src/vec/functions/function_string.cpp
index 69f8699b5c8..bfbd57f4747 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -1159,6 +1159,7 @@ void register_function_string(SimpleFunctionFactory& 
factory) {
     factory.register_function<FunctionMaskPartial<false>>();
     factory.register_function<FunctionSubReplace<SubReplaceThreeImpl>>();
     factory.register_function<FunctionSubReplace<SubReplaceFourImpl>>();
+    factory.register_function<FunctionStrcmp>();
 
     /// @TEMPORARY: for be_exec_version=3
     
factory.register_alternative_function<FunctionSubstringOld<Substr3ImplOld>>();
diff --git a/be/src/vec/functions/function_string.h 
b/be/src/vec/functions/function_string.h
index 9ae686f3398..515f9ad11ac 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -294,6 +294,68 @@ private:
     }
 };
 
+class FunctionStrcmp : public IFunction {
+public:
+    static constexpr auto name = "strcmp";
+
+    static FunctionPtr create() { return std::make_shared<FunctionStrcmp>(); }
+
+    String get_name() const override { return name; }
+
+    size_t get_number_of_arguments() const override { return 2; }
+
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
+        return std::make_shared<DataTypeInt8>();
+    }
+
+    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) const override 
{
+        const auto& [arg0_column, arg0_const] =
+                unpack_if_const(block.get_by_position(arguments[0]).column);
+        const auto& [arg1_column, arg1_const] =
+                unpack_if_const(block.get_by_position(arguments[1]).column);
+
+        auto result_column = ColumnInt8::create(input_rows_count);
+
+        if (auto arg0 = check_and_get_column<ColumnString>(arg0_column.get())) 
{
+            if (auto arg1 = 
check_and_get_column<ColumnString>(arg1_column.get())) {
+                if (arg0_const) {
+                    scalar_vector(arg0->get_data_at(0), *arg1, *result_column);
+                } else if (arg1_const) {
+                    vector_scalar(*arg0, arg1->get_data_at(0), *result_column);
+                } else {
+                    vector_vector(*arg0, *arg1, *result_column);
+                }
+            }
+        }
+
+        block.replace_by_position(result, std::move(result_column));
+        return Status::OK();
+    }
+
+private:
+    static void scalar_vector(const StringRef str, const ColumnString& vec1, 
ColumnInt8& res) {
+        size_t size = vec1.size();
+        for (size_t i = 0; i < size; ++i) {
+            res.get_data()[i] = str.compare(vec1.get_data_at(i));
+        }
+    }
+
+    static void vector_scalar(const ColumnString& vec0, const StringRef str, 
ColumnInt8& res) {
+        size_t size = vec0.size();
+        for (size_t i = 0; i < size; ++i) {
+            res.get_data()[i] = vec0.get_data_at(i).compare(str);
+        }
+    }
+
+    static void vector_vector(const ColumnString& vec0, const ColumnString& 
vec1, ColumnInt8& res) {
+        size_t size = vec0.size();
+        for (size_t i = 0; i < size; ++i) {
+            res.get_data()[i] = 
vec0.get_data_at(i).compare(vec1.get_data_at(i));
+        }
+    }
+};
+
 struct SubstringUtilOld {
     static constexpr auto name = "substring";
 
diff --git a/be/test/vec/function/function_string_test.cpp 
b/be/test/vec/function/function_string_test.cpp
index 2c0fecbb300..6e3c2bba957 100644
--- a/be/test/vec/function/function_string_test.cpp
+++ b/be/test/vec/function/function_string_test.cpp
@@ -1199,4 +1199,70 @@ TEST(function_string_test, function_uuid_test) {
     }
 }
 
+TEST(function_string_test, function_strcmp_test) {
+    std::string func_name = "strcmp";
+    {
+        InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
+
+        DataSet data_set = {{{Null(), Null()}, Null()},
+                            {{std::string(""), std::string("")}, (int8_t)0},
+                            {{std::string("test"), std::string("test")}, 
(int8_t)0},
+                            {{std::string("test1"), std::string("test")}, 
(int8_t)1},
+                            {{std::string("test"), std::string("test1")}, 
(int8_t)-1},
+                            {{Null(), std::string("test")}, Null()},
+                            {{std::string("test"), Null()}, Null()},
+                            {{VARCHAR(""), VARCHAR("")}, (int8_t)0},
+                            {{VARCHAR("test"), VARCHAR("test")}, (int8_t)0},
+                            {{VARCHAR("test1"), VARCHAR("test")}, (int8_t)1},
+                            {{VARCHAR("test"), VARCHAR("test1")}, (int8_t)-1},
+                            {{Null(), VARCHAR("test")}, Null()},
+                            {{VARCHAR("test"), Null()}, Null()}};
+        static_cast<void>(check_function<DataTypeInt8, true>(func_name, 
input_types, data_set));
+    }
+    {
+        InputTypeSet input_types = {Consted {TypeIndex::String}, 
TypeIndex::String};
+        DataSet data_set = {{{Null(), Null()}, Null()},
+                            {{std::string(""), std::string("")}, (int8_t)0},
+                            {{std::string("test"), std::string("test")}, 
(int8_t)0},
+                            {{std::string("test1"), std::string("test")}, 
(int8_t)1},
+                            {{std::string("test"), std::string("test1")}, 
(int8_t)-1},
+                            {{Null(), std::string("test")}, Null()},
+                            {{std::string("test"), Null()}, Null()},
+                            {{VARCHAR(""), VARCHAR("")}, (int8_t)0},
+                            {{VARCHAR("test"), VARCHAR("test")}, (int8_t)0},
+                            {{VARCHAR("test1"), VARCHAR("test")}, (int8_t)1},
+                            {{VARCHAR("test"), VARCHAR("test1")}, (int8_t)-1},
+                            {{Null(), VARCHAR("test")}, Null()},
+                            {{VARCHAR("test"), Null()}, Null()}};
+
+        for (const auto& line : data_set) {
+            DataSet const_dataset = {line};
+            static_cast<void>(
+                    check_function<DataTypeInt8, true>(func_name, input_types, 
const_dataset));
+        }
+    }
+    {
+        InputTypeSet input_types = {TypeIndex::String, Consted 
{TypeIndex::String}};
+        DataSet data_set = {{{Null(), Null()}, Null()},
+                            {{std::string(""), std::string("")}, (int8_t)0},
+                            {{std::string("test"), std::string("test")}, 
(int8_t)0},
+                            {{std::string("test1"), std::string("test")}, 
(int8_t)1},
+                            {{std::string("test"), std::string("test1")}, 
(int8_t)-1},
+                            {{Null(), std::string("test")}, Null()},
+                            {{std::string("test"), Null()}, Null()},
+                            {{VARCHAR(""), VARCHAR("")}, (int8_t)0},
+                            {{VARCHAR("test"), VARCHAR("test")}, (int8_t)0},
+                            {{VARCHAR("test1"), VARCHAR("test")}, (int8_t)1},
+                            {{VARCHAR("test"), VARCHAR("test1")}, (int8_t)-1},
+                            {{Null(), VARCHAR("test")}, Null()},
+                            {{VARCHAR("test"), Null()}, Null()}};
+
+        for (const auto& line : data_set) {
+            DataSet const_dataset = {line};
+            static_cast<void>(
+                    check_function<DataTypeInt8, true>(func_name, input_types, 
const_dataset));
+        }
+    }
+}
+
 } // namespace doris::vectorized
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
index d27b0f3a311..1654a2098db 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
@@ -384,6 +384,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.scalar.StartsWith;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.StrLeft;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.StrRight;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.StrToDate;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Strcmp;
 import 
org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.SubBitmap;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.SubReplace;
@@ -843,6 +844,7 @@ public class BuiltinScalarFunctions implements 
FunctionHelper {
             scalar(StX.class, "st_x"),
             scalar(StY.class, "st_y"),
             scalar(StartsWith.class, "starts_with"),
+            scalar(Strcmp.class, "strcmp"),
             scalar(StrLeft.class, "strleft"),
             scalar(StrRight.class, "strright"),
             scalar(StrToDate.class, "str_to_date"),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Strcmp.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Strcmp.java
new file mode 100644
index 00000000000..b9aaff85fce
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Strcmp.java
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.scalar;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.StringType;
+import org.apache.doris.nereids.types.TinyIntType;
+import org.apache.doris.nereids.types.VarcharType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * ScalarFunction 'strcmp'.
+ */
+public class Strcmp extends ScalarFunction
+        implements BinaryExpression, ExplicitlyCastableSignature, 
PropagateNullable {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            
FunctionSignature.ret(TinyIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, 
VarcharType.SYSTEM_DEFAULT),
+            
FunctionSignature.ret(TinyIntType.INSTANCE).args(StringType.INSTANCE, 
StringType.INSTANCE));
+
+    /**
+     * constructor with 2 argument.
+     */
+    public Strcmp(Expression arg0, Expression arg1) {
+        super("strcmp", arg0, arg1);
+    }
+
+    /**
+     * withChildren.
+     */
+    @Override
+    public Strcmp withChildren(List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 2);
+        return new Strcmp(children.get(0), children.get(1));
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitStrcmp(this, context);
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
index baa801f7786..cb26a8bf4dc 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
@@ -382,6 +382,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.scalar.StartsWith;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.StrLeft;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.StrRight;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.StrToDate;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Strcmp;
 import 
org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.SubBitmap;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.SubReplace;
@@ -2016,6 +2017,10 @@ public interface ScalarFunctionVisitor<R, C> {
         return visitScalarFunction(inttoUuid, context);
     }
 
+    default R visitStrcmp(Strcmp strcmp, C context) {
+        return visitScalarFunction(strcmp, context);
+    }
+
     default R visitVersion(Version version, C context) {
         return visitScalarFunction(version, context);
     }
diff --git a/gensrc/script/doris_builtins_functions.py 
b/gensrc/script/doris_builtins_functions.py
index 0e7615829d9..fdc08755307 100644
--- a/gensrc/script/doris_builtins_functions.py
+++ b/gensrc/script/doris_builtins_functions.py
@@ -1619,6 +1619,7 @@ visible_functions = {
         [['sub_replace'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'INT', 'INT'], 
'ALWAYS_NULLABLE'],
 
         [['char'], 'VARCHAR', ['VARCHAR', 'INT', '...'], 'ALWAYS_NULLABLE'],
+        [['strcmp'], 'INT', ['VARCHAR', 'VARCHAR'], 'DEPEND_ON_ARGUMENT'],
 
         [['substr', 'substring'], 'STRING', ['STRING', 'INT'], 
'DEPEND_ON_ARGUMENT'],
         [['substr', 'substring'], 'STRING', ['STRING', 'INT', 'INT'], 
'DEPEND_ON_ARGUMENT'],
@@ -1670,7 +1671,8 @@ visible_functions = {
         [['split_part'], 'STRING', ['STRING', 'STRING', 'INT'], 
'ALWAYS_NULLABLE'],
         [['substring_index'], 'STRING', ['STRING', 'STRING', 'INT'], 
'DEPEND_ON_ARGUMENT'],
         [['url_decode'], 'STRING', ['STRING'], ''],
-        [['random_bytes'], 'STRING', ['INT'], '']
+        [['random_bytes'], 'STRING', ['INT'], ''],
+        [['strcmp'], 'INT', ['STRING', 'STRING'], 'DEPEND_ON_ARGUMENT']
     ],
 
 
@@ -2245,6 +2247,7 @@ null_result_with_one_null_param_functions = [
     'fmod',
     'substr',
     'substring',
+    'strcmp'
     'append_trailing_char_if_absent',
     'ST_X',
     'ST_Y',


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to