tkonolige commented on a change in pull request #7559:
URL: https://github.com/apache/tvm/pull/7559#discussion_r589691387



##########
File path: src/runtime/graph/graph_runtime.cc
##########
@@ -196,31 +198,10 @@ void GraphRuntime::LoadParams(const std::string& 
param_blob) {
 }
 
 void GraphRuntime::LoadParams(dmlc::Stream* strm) {
-  uint64_t header, reserved;
-  ICHECK(strm->Read(&header)) << "Invalid parameters file format";
-  ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
-  ICHECK(strm->Read(&reserved)) << "Invalid parameters file format";
-
-  std::vector<std::string> names;
-  ICHECK(strm->Read(&names)) << "Invalid parameters file format";
-  uint64_t sz;
-  strm->Read(&sz);
-  size_t size = static_cast<size_t>(sz);
-  ICHECK(size == names.size()) << "Invalid parameters file format";
-  for (size_t i = 0; i < size; ++i) {
-    int in_idx = GetInputIndex(names[i]);
-    if (in_idx < 0) {
-      NDArray temp;
-      temp.Load(strm);
-      continue;
-    }
-    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
-    ICHECK_LT(eid, data_entry_.size());
-
-    // The data_entry is allocated on device, NDArray.load always load the 
array into CPU.
-    NDArray temp;
-    temp.Load(strm);
-    data_entry_[eid].CopyFrom(temp);
+  Map<String, NDArray> params = ::tvm::runtime::LoadParams(strm);
+  for (auto& p : params) {
+    uint32_t eid = this->entry_id(input_nodes_[GetInputIndex(p.first)], 0);
+    data_entry_[eid].CopyFrom(p.second);

Review comment:
       That's not possible. The data has to be loaded onto the CPU first and then copied to the device.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to