This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0b2358c2e4 [Relay] make "ToScalar" support directly obtaining
"int64_t" (#16324)
0b2358c2e4 is described below
commit 0b2358c2e4656648f726d4a16507ef2513451ad5
Author: mawnja <[email protected]>
AuthorDate: Thu Jan 11 03:21:25 2024 +0800
[Relay] make "ToScalar" support directly obtaining "int64_t" (#16324)
On Windows, "long double" is 64 bits wide instead of 128 bits as on Linux, so
converting a value through "long double" to "int64_t" can overflow. "ToScalar"
can now return "int64_t" directly to avoid this.
Co-authored-by: wenjian.ma <[email protected]>
---
src/relay/transforms/pattern_utils.h | 43 ++++++++++++++++++++---------------
src/relay/transforms/simplify_expr.cc | 2 +-
2 files changed, 26 insertions(+), 19 deletions(-)
diff --git a/src/relay/transforms/pattern_utils.h
b/src/relay/transforms/pattern_utils.h
index 50c2e00298..b26bd76496 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -468,43 +468,43 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
* \param i element index
* \return Converted scalar value, or None if conversion failed
*/
-static inline std::optional<long double> TryToScalar(const runtime::NDArray&
array, size_t i = 0) {
+template <typename T>
+static inline std::optional<T> TryToScalar(const runtime::NDArray& array,
size_t i = 0) {
if (array->dtype.code == kDLInt) {
if (array->dtype.bits == 8) {
- return std::optional<long
double>(reinterpret_cast<int8_t*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<int8_t*>(array->data)[i]);
} else if (array->dtype.bits == 16) {
- return std::optional<long
double>(reinterpret_cast<int16_t*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<int16_t*>(array->data)[i]);
} else if (array->dtype.bits == 32) {
- return std::optional<long
double>(reinterpret_cast<int32_t*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<int32_t*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
- return std::optional<long
double>(reinterpret_cast<int64_t*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<int64_t*>(array->data)[i]);
}
} else if (array->dtype.code == kDLUInt) {
if (array->dtype.bits == 1) { // bool
- return std::optional<long
double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]);
} else if (array->dtype.bits == 8) {
- return std::optional<long
double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]);
} else if (array->dtype.bits == 16) {
- return std::optional<long
double>(reinterpret_cast<uint16_t*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<uint16_t*>(array->data)[i]);
} else if (array->dtype.bits == 32) {
- return std::optional<long
double>(reinterpret_cast<uint32_t*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<uint32_t*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
- return std::optional<long
double>(reinterpret_cast<uint64_t*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<uint64_t*>(array->data)[i]);
}
} else if (array->dtype.code == kDLFloat) {
if (array->dtype.bits == 16) {
- return std::optional<long double>(
- __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
- reinterpret_cast<uint16_t*>(array->data)[i]));
+ return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 10, float,
uint32_t, 23>(
+ reinterpret_cast<uint16_t*>(array->data)[i]));
}
if (array->dtype.bits == 32) {
- return std::optional<long
double>(reinterpret_cast<float*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<float*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
- return std::optional<long
double>(reinterpret_cast<double*>(array->data)[i]);
+ return std::optional<T>(reinterpret_cast<double*>(array->data)[i]);
}
} else if (array->dtype.code == kDLBfloat) {
if (array->dtype.bits == 16) {
- return std::optional<long double>(__extendXfYf2__<uint16_t, uint16_t, 7,
float, uint32_t, 23>(
+ return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 7, float,
uint32_t, 23>(
reinterpret_cast<uint16_t*>(array->data)[i]));
}
}
@@ -517,8 +517,15 @@ static inline std::optional<long double> TryToScalar(const
runtime::NDArray& arr
* \param i element index
* \return Converted scalar value
*/
+template <typename T>
+static inline T ToScalar(const runtime::NDArray& array, size_t i = 0) {
+ auto try_value = TryToScalar<T>(array, i);
+ ICHECK(try_value) << "Unknown data type: " <<
tvm::runtime::DLDataType2String(array->dtype);
+ return try_value.value();
+}
+
static inline long double ToScalar(const runtime::NDArray& array, size_t i =
0) {
- auto try_value = TryToScalar(array, i);
+ auto try_value = TryToScalar<long double>(array, i);
ICHECK(try_value) << "Unknown data type: " <<
tvm::runtime::DLDataType2String(array->dtype);
return try_value.value();
}
@@ -534,7 +541,7 @@ static inline Array<Integer> ToVector(const
runtime::NDArray& array) {
size_t len = array.Shape().front();
Array<Integer> out;
for (size_t i = 0; i < len; ++i) {
- long double elem_val = ToScalar(array, i);
+ uint64_t elem_val = ToScalar<uint64_t>(array, i);
out.push_back(Integer(IntImm(DataType::Int(32),
static_cast<int64_t>(elem_val))));
}
return out;
diff --git a/src/relay/transforms/simplify_expr.cc
b/src/relay/transforms/simplify_expr.cc
index 208c9821b6..8036d301e1 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -794,7 +794,7 @@ class EliminateIdentityRewrite : public DFPatternRewrite {
if (!IsScalar(GetRef<Expr>(constant))) {
return false;
}
- auto value = TryToScalar(constant->data, 0);
+ auto value = TryToScalar<long double>(constant->data, 0);
if (!value) {
// unsupported dtype
return false;