tqchen commented on code in PR #18272:
URL: https://github.com/apache/tvm/pull/18272#discussion_r2326322145
##########
ffi/include/tvm/ffi/container/ndarray.h:
##########
@@ -110,6 +109,21 @@ inline size_t GetDataSize(const DLTensor& arr) {
return GetDataSize(size, arr.dtype);
}
+/*!
+ * \brief Infer the stride from shape
+ *
+ * \param shape the input Shape
+ * \return the inferred stride
+ */
+inline Shape InferStrideFromShape(Shape shape) {
+ size_t ndim = shape.size();
+ Array<int64_t> strides(ndim, 1);
Review Comment:
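Do not use `Array`. Instead, use `details::MakeEmptyShape`
(https://github.com/apache/tvm/blob/main/ffi/include/tvm/ffi/container/shape.h#L71),
then fill in the strides in place, along the lines of:
```
// strides: the int64_t buffer backing the shape created by MakeEmptyShape
int64_t stride = 1;
for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
  strides[i] = stride;
  stride *= shape[i];
}
```
We can also move this function into `shape.h`, under `details::`.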
##########
ffi/include/tvm/ffi/container/ndarray.h:
##########
@@ -41,7 +41,6 @@ namespace ffi {
* \return The check result.
*/
inline bool IsContiguous(const DLTensor& arr) {
- if (arr.strides == nullptr) return true;
Review Comment:
Keep this check for now, as this function is defined on a plain `DLTensor`, where `strides` may be `nullptr`.
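A minimal sketch of why the early return matters: DLPack allows `strides == nullptr` to mean a compact row-major tensor, so a `DLTensor`-level check has to accept that case before it reads `strides`. `TensorView` is a hypothetical stand-in for the `DLTensor` fields involved, and the loop is an illustrative contiguity check, not necessarily identical to the one in `ndarray.h`:

```cpp
#include <cstdint>

// Hypothetical stand-in for the DLTensor fields this check reads.
struct TensorView {
  int32_t ndim;
  const int64_t* shape;
  const int64_t* strides;  // may be nullptr: compact row-major in DLPack
};

bool IsContiguous(const TensorView& arr) {
  // No strides means compact row-major by definition.
  if (arr.strides == nullptr) return true;
  int64_t expected_stride = 1;
  for (int32_t i = arr.ndim; i != 0; --i) {
    int32_t k = i - 1;
    // A size-1 dimension is contiguous regardless of its stride value.
    if (arr.shape[k] == 1) continue;
    if (arr.strides[k] != expected_stride) return false;
    expected_stride *= arr.shape[k];
  }
  return true;
}
```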
##########
ffi/include/tvm/ffi/container/ndarray.h:
##########
@@ -202,10 +219,6 @@ class NDArrayObjFromDLPack : public NDArrayObj {
public:
explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) :
tensor_(tensor) {
*static_cast<DLTensor*>(this) = tensor_->dl_tensor;
- // set strides to nullptr if the tensor is contiguous.
- if (IsContiguous(tensor->dl_tensor)) {
Review Comment:
We also need to check whether `tensor->dl_tensor.strides == nullptr` and handle
that case accordingly if needed.
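One possible way to handle this, assuming the intent of the change is for the imported tensor to always carry explicit strides (the removed lines used to null them out for contiguous tensors): when the producer passes `strides == nullptr`, fill in row-major strides owned by the wrapper. `TensorView`, `ManagedTensor`, and `ImportedTensor` are hypothetical stand-ins for `DLTensor`, `DLManagedTensor`, and `NDArrayObjFromDLPack`; this is a sketch, not the actual TVM code:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for DLTensor / DLManagedTensor (only fields used).
struct TensorView {
  int32_t ndim;
  int64_t* shape;
  int64_t* strides;  // nullptr => compact row-major
};
struct ManagedTensor {
  TensorView dl_tensor;
};

// Copies the producer's tensor view; if strides are absent, computes
// row-major strides into storage owned by this object.
class ImportedTensor : public TensorView {
 public:
  explicit ImportedTensor(ManagedTensor* tensor) : tensor_(tensor) {
    *static_cast<TensorView*>(this) = tensor_->dl_tensor;
    if (tensor_->dl_tensor.strides == nullptr) {
      owned_strides_.resize(ndim);
      int64_t stride = 1;
      for (int32_t i = ndim; i != 0; --i) {
        owned_strides_[i - 1] = stride;
        stride *= shape[i - 1];
      }
      strides = owned_strides_.data();
    }
  }
  // `strides` may point into owned_strides_, so copying would dangle.
  ImportedTensor(const ImportedTensor&) = delete;
  ImportedTensor& operator=(const ImportedTensor&) = delete;

 private:
  ManagedTensor* tensor_;
  std::vector<int64_t> owned_strides_;
};
```

Strides supplied by the producer are passed through unchanged; only the `nullptr` case allocates.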
--
This is an automated message from the Apache Git Service.