icemelon9 commented on a change in pull request #5459:
URL: https://github.com/apache/incubator-tvm/pull/5459#discussion_r427061276
##########
File path: src/relay/transforms/pattern_util.h
##########
@@ -325,6 +345,65 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
return tvm::StructuralEqual()(a, b);
}
+/*!
+ * \brief Convert an element of a NDArray with type int or float to scalar.
+ * \param array Input NDArray
+ * \param i element index
+ * \return Converted scalar value.
+ */
+static inline double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+ if (array->dtype.code == kDLInt) {
+ if (array->dtype.bits == 8) {
+ return reinterpret_cast<int8_t*>(array->data)[i];
+ } else if (array->dtype.bits == 16) {
+ return reinterpret_cast<int16_t*>(array->data)[i];
+ } else if (array->dtype.bits == 32) {
+ return reinterpret_cast<int32_t*>(array->data)[i];
+ } else if (array->dtype.bits == 64) {
+ return reinterpret_cast<int64_t*>(array->data)[i];
+ }
+ } else if (array->dtype.code == kDLUInt) {
+ if (array->dtype.bits == 8) {
+ return reinterpret_cast<uint8_t*>(array->data)[i];
+ } else if (array->dtype.bits == 16) {
+ return reinterpret_cast<uint16_t*>(array->data)[i];
+ } else if (array->dtype.bits == 32) {
+ return reinterpret_cast<uint32_t*>(array->data)[i];
+ } else if (array->dtype.bits == 64) {
+ return reinterpret_cast<uint64_t*>(array->data)[i];
+ }
+ } else if (array->dtype.code == kDLFloat) {
+#if (__ARM_FP16_FORMAT_IEEE == 1)
+ if (array->dtype.bits == 16) {
+ return reinterpret_cast<__fp16*>(array->data)[i];
+ }
+#endif
+ if (array->dtype.bits == 32) {
+ return reinterpret_cast<float*>(array->data)[i];
+ } else if (array->dtype.bits == 64) {
+ return reinterpret_cast<double*>(array->data)[i];
+ }
+ }
+ LOG(FATAL) << "Unknown data type: " <<
tvm::runtime::DLDataType2String(array->dtype);
+ // make compiler happy
+ return -std::numeric_limits<double>::infinity();
+}
+
+/*!
+ * \brief Convert a NDArray with type int or float to Array<Integer>.
+ * \param array Input NDArray
+ * \return Converted Array.
+ */
+static inline Array<Integer> ToVector(const runtime::NDArray& array) {
+ size_t len = array.Shape().front();
Review comment:
Should we check `ndim == 1` here? `array.Shape().front()` equals the element count only when the array is one-dimensional.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org