uchenily commented on code in PR #55184:
URL: https://github.com/apache/doris/pull/55184#discussion_r2297598993
##########
be/src/vec/functions/array/function_array_distance.h:
##########
@@ -110,7 +105,7 @@ class FunctionArrayDistance : public IFunction {
bool use_default_implementation_for_nulls() const override { return false;
}
Review Comment:
done
##########
be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp:
##########
@@ -78,19 +75,11 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const
RowDescriptor& row_des
vir_col_expr->debug_string());
}
- std::shared_ptr<vectorized::VCastExpr> cast_to_array_expr =
-
std::dynamic_pointer_cast<vectorized::VCastExpr>(distance_fn_call->children()[0]);
-
- if (cast_to_array_expr == nullptr) {
- return Status::InternalError("Ann topn expr expect cast_to_array_expr,
got\n{}",
-
distance_fn_call->children()[0]->debug_string());
- }
-
std::shared_ptr<vectorized::VSlotRef> slot_ref =
-
std::dynamic_pointer_cast<vectorized::VSlotRef>(cast_to_array_expr->children()[0]);
+
std::dynamic_pointer_cast<vectorized::VSlotRef>(distance_fn_call->children()[0]);
Review Comment:
If we add cast_to_array, the following error will occur:
[INTERNAL_ERROR]Ann topn expr expect cast_to_array_expr, got
SlotRef(slot_id=1 type=Array(Nullable(FLOAT)))
It seems that no CastExpr is generated here.
##########
be/src/vec/functions/array/function_array_distance.h:
##########
@@ -140,27 +137,14 @@ class FunctionArrayDistance : public IFunction {
}
// prepare return data
- auto dst = ColumnFloat64::create(input_rows_count);
+ auto dst = ColumnType::create(input_rows_count);
auto& dst_data = dst->get_data();
- auto dst_null_column = ColumnUInt8::create(input_rows_count, 0);
- auto& dst_null_data = dst_null_column->get_data();
const auto& offsets1 = *arr1.offsets_ptr;
const auto& offsets2 = *arr2.offsets_ptr;
- const auto& nested_col1 = assert_cast<const
ColumnFloat64*>(arr1.nested_col.get());
- const auto& nested_col2 = assert_cast<const
ColumnFloat64*>(arr2.nested_col.get());
+ const auto& nested_col1 = assert_cast<const
ColumnType*>(arr1.nested_col.get());
Review Comment:
Now that all distance functions have been changed to PropagateNullable, the
block data type is FLOAT rather than Nullable(FLOAT), so I removed the null_column.
##########
be/src/vec/functions/array/function_array_distance.h:
##########
@@ -35,82 +37,77 @@ class L1Distance {
public:
static constexpr auto name = "l1_distance";
struct State {
- double sum = 0;
+ float sum = 0;
};
- static void accumulate(State& state, double x, double y) { state.sum +=
fabs(x - y); }
- static double finalize(const State& state) { return state.sum; }
+ static void accumulate(State& state, float x, float y) { state.sum +=
fabs(x - y); }
+ static float finalize(const State& state) { return state.sum; }
};
class L2Distance {
public:
static constexpr auto name = "l2_distance";
struct State {
- double sum = 0;
+ float sum = 0;
};
- static void accumulate(State& state, double x, double y) { state.sum += (x
- y) * (x - y); }
- static double finalize(const State& state) { return sqrt(state.sum); }
+ static void accumulate(State& state, float x, float y) { state.sum += (x -
y) * (x - y); }
+ static float finalize(const State& state) { return sqrt(state.sum); }
};
class InnerProduct {
public:
static constexpr auto name = "inner_product";
struct State {
- double sum = 0;
+ float sum = 0;
};
- static void accumulate(State& state, double x, double y) { state.sum += x
* y; }
- static double finalize(const State& state) { return state.sum; }
+ static void accumulate(State& state, float x, float y) { state.sum += x *
y; }
+ static float finalize(const State& state) { return state.sum; }
};
class CosineDistance {
public:
static constexpr auto name = "cosine_distance";
struct State {
- double dot_prod = 0;
- double squared_x = 0;
- double squared_y = 0;
+ float dot_prod = 0;
+ float squared_x = 0;
+ float squared_y = 0;
};
- static void accumulate(State& state, double x, double y) {
+ static void accumulate(State& state, float x, float y) {
state.dot_prod += x * y;
state.squared_x += x * x;
state.squared_y += y * y;
}
- static double finalize(const State& state) {
+ static float finalize(const State& state) {
return 1 - state.dot_prod / sqrt(state.squared_x * state.squared_y);
}
};
-class L2DistanceApproximate {
+class L2DistanceApproximate : public L2Distance {
public:
static constexpr auto name = "l2_distance_approximate";
- struct State {
- double sum = 0;
- };
- static void accumulate(State& state, double x, double y) { state.sum += (x
- y) * (x - y); }
- static double finalize(const State& state) { return sqrt(state.sum); }
};
-class InnerProductApproximate {
+class InnerProductApproximate : public InnerProduct {
public:
static constexpr auto name = "inner_product_approximate";
- struct State {
- double sum = 0;
- };
- static void accumulate(State& state, double x, double y) { state.sum += x
* y; }
- static double finalize(const State& state) { return state.sum; }
};
-template <typename DistanceImpl>
+template <typename DistanceImpl, PrimitiveType Type>
class FunctionArrayDistance : public IFunction {
public:
+ using DataType = PrimitiveTypeTraits<Type>::DataType;
Review Comment:
Yes, it is not necessary at present; I will change it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]