yongwww commented on code in PR #15833:
URL: https://github.com/apache/tvm/pull/15833#discussion_r1340399667


##########
include/tvm/runtime/memory/memory_manager.h:
##########
@@ -37,15 +37,15 @@
 
 namespace tvm {
 namespace runtime {
-namespace vm {
+namespace memory {
 
 struct Buffer {
   /*! \brief The pointer to the allocated block of memory. */
   void* data{nullptr};
   /*! \brief The size of the block. */
   size_t size{0};
   /*! \brief The shape of the tensor. */
-  std::vector<int64_t> shape;
+  ShapeTuple shape;

Review Comment:
   We can remove this member `ShapeTuple shape;` and the related 
`buf.shape.push_back(shape[i]);` in naive_allocator.h; it looks like it is not used.



##########
src/runtime/memory/memory_manager.cc:
##########
@@ -154,29 +154,30 @@ Allocator* MemoryManager::GetAllocator(Device dev) {
   return it->second.get();
 }
 
-NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, 
DLDevice dev) {
+NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev,
+                         Optional<String> mem_scope) {
   VerifyDataType(dtype);
   NDArray::Container* container = new NDArray::Container(nullptr, shape, 
dtype, dev);
   container->SetDeleter(BufferDeleter);
-  size_t size = GetDataSize(container->dl_tensor);
+  size_t size = DeviceAPI::Get(dev)->GetDataSize(container->dl_tensor);
   size_t alignment = GetDataAlignment(container->dl_tensor);
   Buffer* buffer = new Buffer;
-  *buffer = this->Alloc(size, alignment, dtype);
+  if (!mem_scope.defined() || mem_scope == "global") {
+    *buffer = this->Alloc(size, alignment, dtype);
+  } else {
+    *buffer = this->Alloc(shape, dtype, mem_scope.value());
+  }
   container->manager_ctx = reinterpret_cast<void*>(buffer);
   container->dl_tensor.data = buffer->data;
   return NDArray(GetObjectPtr<Object>(container));
 }
 

Review Comment:
   It would be helpful to add the definition of the `Storage` constructor shown below to 
memory_manager.cc; it is not defined yet.
   
   ```
   Storage::Storage(Buffer buffer) {
     auto n = make_object<StorageObj>();
     n->buffer = std::move(buffer);
     data_ = std::move(n);
   }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to