This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push:
new f751ca4e045 [branch-2.1](functions) fix be crash for function
random_bytes and mask_first/last_n (#36003)
f751ca4e045 is described below
commit f751ca4e045a069bca01be77b2bf3f6c4ba0e200
Author: zclllyybb <[email protected]>
AuthorDate: Fri Jun 7 10:30:41 2024 +0800
[branch-2.1](functions) fix be crash for function random_bytes and
mask_first/last_n (#36003)
pick #35884
---
be/src/vec/functions/function_string.h | 31 +++++++++++-----------
.../expressions/functions/scalar/MaskFirstN.java | 8 ++++++
.../expressions/functions/scalar/MaskLastN.java | 8 ++++++
.../correctness_p0/test_mask_function.groovy | 17 ++++++++++++
.../nereids_function_p0/scalar_function/R.groovy | 4 +++
5 files changed, 52 insertions(+), 16 deletions(-)
diff --git a/be/src/vec/functions/function_string.h
b/be/src/vec/functions/function_string.h
index fbaed751c7d..31c6cbb5ecb 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -792,10 +792,7 @@ public:
Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override
{
- DCHECK_GE(arguments.size(), 1);
- DCHECK_LE(arguments.size(), 2);
-
- int n = -1;
+ int n = -1; // means unassigned
auto res = ColumnString::create();
auto col =
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
@@ -803,17 +800,20 @@ public:
if (arguments.size() == 2) {
const auto& col = *block.get_by_position(arguments[1]).column;
+ // the 2nd arg is const. checked in fe.
+ if (col.get_int(0) < 0) [[unlikely]] {
+ return Status::InvalidArgument(
+ "function {} only accept non-negative input for 2nd
argument but got {}",
+ name, col.get_int(0));
+ }
n = col.get_int(0);
- } else if (arguments.size() > 2) {
- return Status::InvalidArgument(
- fmt::format("too many arguments for function {}",
get_name()));
}
- if (n == -1) {
+ if (n == -1) { // no 2nd arg, just mask all
FunctionMask::vector_mask(source_column, *res,
FunctionMask::DEFAULT_UPPER_MASK,
FunctionMask::DEFAULT_LOWER_MASK,
FunctionMask::DEFAULT_NUMBER_MASK);
- } else if (n >= 0) {
+ } else { // n >= 0
vector(source_column, n, *res);
}
@@ -2901,19 +2901,18 @@ public:
ColumnPtr argument_column =
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
- const auto* length_col =
check_and_get_column<ColumnInt32>(argument_column.get());
-
- if (!length_col) {
- return Status::InternalError("Not supported input argument type");
- }
+ const auto* length_col = assert_cast<const
ColumnInt32*>(argument_column.get());
std::vector<uint8_t> random_bytes;
std::random_device rd;
std::mt19937 gen(rd());
for (size_t i = 0; i < input_rows_count; ++i) {
- UInt64 length = length_col->get64(i);
- random_bytes.resize(length);
+ if (length_col->get_element(i) < 0) [[unlikely]] {
+ return Status::InvalidArgument("argument {} of function {} at
row {} was invalid.",
+ length_col->get_element(i),
name, i);
+ }
+ random_bytes.resize(length_col->get_element(i));
std::uniform_int_distribution<uint8_t> distribution(0, 255);
for (auto& byte : random_bytes) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java
index 81a968067c2..33e19d468e8 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
@@ -65,6 +66,13 @@ public class MaskFirstN extends ScalarFunction implements
ExplicitlyCastableSign
return new MaskFirstN(children.get(0), children.get(1));
}
+ @Override
+ public void checkLegalityAfterRewrite() {
+ if (arity() == 2 && !child(1).isLiteral()) {
+ throw new AnalysisException("mask_first_n must accept literal for
2nd argument");
+ }
+ }
+
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java
index cb8246f04ab..eafb85ee89b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
@@ -65,6 +66,13 @@ public class MaskLastN extends ScalarFunction implements
ExplicitlyCastableSigna
return new MaskLastN(children.get(0), children.get(1));
}
+ @Override
+ public void checkLegalityAfterRewrite() {
+ if (arity() == 2 && !child(1).isLiteral()) {
+ throw new AnalysisException("mask_last_n must accept literal for
2nd argument");
+ }
+ }
+
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
diff --git a/regression-test/suites/correctness_p0/test_mask_function.groovy
b/regression-test/suites/correctness_p0/test_mask_function.groovy
index b242e72eccc..b7717ab183c 100644
--- a/regression-test/suites/correctness_p0/test_mask_function.groovy
+++ b/regression-test/suites/correctness_p0/test_mask_function.groovy
@@ -75,4 +75,21 @@ suite("test_mask_function") {
qt_select_digital_masking """
select digital_masking(13812345678);
"""
+
+ test {
+ sql """ select mask_last_n("12345", -100); """
+ exception "function mask_last_n only accept non-negative input for 2nd
argument but got -100"
+ }
+ test {
+ sql """ select mask_first_n("12345", -100); """
+ exception "function mask_first_n only accept non-negative input for
2nd argument but got -100"
+ }
+ test {
+ sql """ select mask_last_n("12345", id) from table_mask_test; """
+ exception "mask_last_n must accept literal for 2nd argument"
+ }
+ test {
+ sql """ select mask_first_n("12345", id) from table_mask_test; """
+ exception "mask_first_n must accept literal for 2nd argument"
+ }
}
diff --git
a/regression-test/suites/nereids_function_p0/scalar_function/R.groovy
b/regression-test/suites/nereids_function_p0/scalar_function/R.groovy
index fa58e6d0cb2..1110ed3a47a 100644
--- a/regression-test/suites/nereids_function_p0/scalar_function/R.groovy
+++ b/regression-test/suites/nereids_function_p0/scalar_function/R.groovy
@@ -101,4 +101,8 @@ suite("nereids_scalar_fn_R") {
qt_sql_rtrim_String_String_notnull "select rtrim(kstr, '1') from
fn_test_not_nullable order by kstr"
sql "SELECT random_bytes(7);"
qt_sql_random_bytes "SELECT random_bytes(null);"
+ test {
+ sql " select random_bytes(-1); "
+ exception "argument -1 of function random_bytes at row 0 was
invalid"
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]