This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 24feeed612 [GLUTEN-7717][CH] [ARM]fix compile issue for
SparkFunctionRoundHalfUp (#7718)
24feeed612 is described below
commit 24feeed61243dbd233bd27644aad439647a180ef
Author: loudongfeng <[email protected]>
AuthorDate: Wed Oct 30 11:19:10 2024 +0800
[GLUTEN-7717][CH] [ARM]fix compile issue for SparkFunctionRoundHalfUp
(#7718)
* [GLUTEN-7717][CH] [ARM]fix compile issue for SparkFunctionRoundHalfUp
* fix hand type issue
---
.../Functions/SparkFunctionRoundHalfUp.h | 63 ++++++++++++++++++----
1 file changed, 52 insertions(+), 11 deletions(-)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
index 0bd28b116d..432595e091 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
@@ -27,11 +27,15 @@ namespace local_engine
{
using namespace DB;
-template <typename T>
+template <typename T, Vectorize vectorize>
class BaseFloatRoundingHalfUpComputation;
+#ifdef __SSE4_1__
+
+/// vectorized implementation for x86
+
template <>
-class BaseFloatRoundingHalfUpComputation<Float32>
+class BaseFloatRoundingHalfUpComputation<Float32, Vectorize::Yes>
{
public:
using ScalarType = Float32;
@@ -59,7 +63,7 @@ public:
};
template <>
-class BaseFloatRoundingHalfUpComputation<Float64>
+class BaseFloatRoundingHalfUpComputation<Float64, Vectorize::Yes>
{
public:
using ScalarType = Float64;
@@ -86,13 +90,43 @@ public:
static VectorType prepare(size_t scale) { return load1(scale); }
};
+/// end __SSE4_1__
+#endif
+
+/// Sequential implementation for ARM. Also used for scalar arguments
+
+template <typename T>
+class BaseFloatRoundingHalfUpComputation<T, Vectorize::No>
+{
+public:
+ using ScalarType = T;
+ using VectorType = T;
+ static const size_t data_count = 1;
+
+ static VectorType load(const ScalarType * in) { return *in; }
+ static VectorType load1(const ScalarType in) { return in; }
+ static VectorType store(ScalarType * out, ScalarType val) { return *out = val;}
+ static VectorType multiply(VectorType val, VectorType scale) { return val * scale; }
+ static VectorType divide(VectorType val, VectorType scale) { return val / scale; }
+ template <RoundingMode mode>
+ static VectorType apply(VectorType val)
+ {
+ return roundWithMode(val, mode);
+ }
+
+ static VectorType prepare(size_t scale)
+ {
+ return load1(scale);
+ }
+};
+
/** Implementation of low-level round-off functions for floating-point values.
*/
-template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
-class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T>
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, Vectorize vectorize>
+class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T, vectorize>
{
- using Base = BaseFloatRoundingHalfUpComputation<T>;
+ using Base = BaseFloatRoundingHalfUpComputation<T, vectorize>;
public:
static inline void compute(const T * __restrict in, const typename Base::VectorType & scale, T * __restrict out)
@@ -124,15 +158,22 @@ struct FloatRoundingHalfUpImpl
private:
static_assert(!is_decimal<T>);
- using Op = FloatRoundingHalfUpComputation<T, rounding_mode, scale_mode>;
- using Data = std::array<T, Op::data_count>;
+ template <Vectorize vectorize =
+#ifdef __SSE4_1__
+ Vectorize::Yes
+#else
+ Vectorize::No
+#endif
+ >
+ using Op = FloatRoundingHalfUpComputation<T, rounding_mode, scale_mode, vectorize>;
+ using Data = std::array<T, Op<>::data_count>;
using ColumnType = ColumnVector<T>;
using Container = typename ColumnType::Container;
public:
static NO_INLINE void apply(const Container & in, size_t scale, Container & out)
{
- auto mm_scale = Op::prepare(scale);
+ auto mm_scale = Op<>::prepare(scale);
const size_t data_count = std::tuple_size<Data>();
@@ -144,7 +185,7 @@ public:
while (p_in < limit)
{
- Op::compute(p_in, mm_scale, p_out);
+ Op<>::compute(p_in, mm_scale, p_out);
p_in += data_count;
p_out += data_count;
}
@@ -157,7 +198,7 @@ public:
size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in);
memcpy(&tmp_src, p_in, tail_size_bytes);
- Op::compute(reinterpret_cast<T *>(&tmp_src), mm_scale, reinterpret_cast<T *>(&tmp_dst));
+ Op<>::compute(reinterpret_cast<T *>(&tmp_src), mm_scale, reinterpret_cast<T *>(&tmp_dst));
memcpy(p_out, &tmp_dst, tail_size_bytes);
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]