This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bbca53d2ab [DNNL] Add TensorRequisite concept (#11345)
bbca53d2ab is described below

commit bbca53d2ab354d7e8bed11fc9e1eae13fbee7730
Author: apeskov <[email protected]>
AuthorDate: Thu Jun 2 13:04:12 2022 +0300

    [DNNL] Add TensorRequisite concept (#11345)
    
    Allow to use DNNL runtime in multi instance mode.
    Thread safe execution of Run() method.
    
    Signed-off-by: Alexander Peskov <[email protected]>
---
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc    | 1412 ++++++----------------
 src/runtime/contrib/dnnl/dnnl_tensor_requisite.h |  720 +++++++++++
 src/runtime/contrib/dnnl/dnnl_utils.cc           |   24 +-
 src/runtime/contrib/dnnl/dnnl_utils.h            |   98 +-
 4 files changed, 1239 insertions(+), 1015 deletions(-)

diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index f6a1c3b790..a2417f012e 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -32,7 +32,12 @@
 
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"
-#include "dnnl.hpp"
+
+// TODO(@apeskov): Have to mute warning from dnnl headers.
+//  -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command
+#include <dnnl.hpp>
+
+#include "dnnl_tensor_requisite.h"
 #include "dnnl_utils.h"
 
 namespace tvm {
@@ -43,552 +48,82 @@ using namespace tvm::runtime;
 using namespace tvm::runtime::json;
 
 class DNNLJSONRuntime : public JSONRuntimeBase {
-  using tag = dnnl::memory::format_tag;
-  using dt = dnnl::memory::data_type;
-
  public:
   DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
                   const Array<String> const_names)
-      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+      : JSONRuntimeBase(symbol_name, graph_json, const_names),
+        next_unique_eid_offset_(data_entry_.size()),
+        run_arg_eid_(input_var_eid_) {
+    for (const auto e : outputs_) run_arg_eid_.push_back(EntryID(e));
+  }
 
-  const char* type_key() const { return "dnnl_json"; }
+  const char* type_key() const override { return "dnnl_json"; }
 
   void Init(const Array<NDArray>& consts) override {
-    BuildEngine();
-
     ICHECK_EQ(consts.size(), const_idx_.size())
         << "The number of input constants must match the number of required.";
 
     // Setup constants entries for weights.
     SetupConstants(consts);
+    BuildEngine();
   }
 
-  void Run() override {
-    // Fill in the input buffers.
-    for (size_t i = 0; i < input_nodes_.size(); ++i) {
-      auto eid = EntryID(input_nodes_[i], 0);
-      size_t offset_in_bytes =
-          entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8);
-      size_t buffer_size = GetDataSize(*data_entry_[eid]);
-      write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
-                           offset_in_bytes);
-    }
+  /* Unused stub implementation */
+  void Run() override { LOG(FATAL) << "Unreachable code"; }
 
-    // Invoke the engine through intepreting the stream.
-    for (size_t i = 0; i < net_.size(); ++i) {
-      net_.at(i).execute(stream_, net_args_.at(i));
-    }
-    stream_.wait();
-
-    // Read output buffers.
-    for (size_t i = 0; i < outputs_.size(); ++i) {
-      auto eid = EntryID(outputs_[i]);
-      size_t offset_in_bytes =
-          entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8);
-      size_t buffer_size = GetDataSize(*data_entry_[eid]);
-      read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
-                            offset_in_bytes);
+  /* Thread safe implementation of Run. Keep runtime instance immutable */
+  void Run(const TVMArgs& args) const {
+    auto arg_data_provider = makeIODataProvider(args);
+    auto mem_solver = tensor_registry_.MakeSolver(arg_data_provider);
+    // Execute primitives one by one
+    for (const auto& act : net_) {
+      auto prim = std::get<0>(act);
+      auto arg_reqs = std::get<1>(act);
+
+      // Find proper dnnl::memory buffers
+      std::unordered_map<int, dnnl::memory> mem_args;
+      for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second);
+
+      prim.execute(stream_, mem_args);
     }
   }
 
- private:
-  tag layout2tag(std::string layout) {
-    static const std::map<std::string, tag> str2tag = {{"nc", tag::nc},
-                                                       {"cn", tag::cn},
-                                                       {"tn", tag::tn},
-                                                       {"nt", tag::nt},
-                                                       {"ncw", tag::ncw},
-                                                       {"nwc", tag::nwc},
-                                                       {"nchw", tag::nchw},
-                                                       {"nhwc", tag::nhwc},
-                                                       {"chwn", tag::chwn},
-                                                       {"ncdhw", tag::ncdhw},
-                                                       {"ndhwc", tag::ndhwc},
-                                                       {"oi", tag::oi},
-                                                       {"io", tag::io},
-                                                       {"oiw", tag::oiw},
-                                                       {"owi", tag::owi},
-                                                       {"wio", tag::wio},
-                                                       {"iwo", tag::iwo},
-                                                       {"oihw", tag::oihw},
-                                                       {"hwio", tag::hwio},
-                                                       {"ohwi", tag::ohwi},
-                                                       {"ihwo", tag::ihwo},
-                                                       {"iohw", tag::iohw},
-                                                       {"oidhw", tag::oidhw},
-                                                       {"dhwio", tag::dhwio},
-                                                       {"odhwi", tag::odhwi},
-                                                       {"iodhw", tag::iodhw},
-                                                       {"idhwo", tag::idhwo},
-                                                       {"goiw", tag::goiw},
-                                                       {"gowi", tag::gowi},
-                                                       {"wigo", tag::wigo},
-                                                       {"gohwi", tag::gohwi},
-                                                       {"goihw", tag::goihw},
-                                                       {"hwigo", tag::hwigo},
-                                                       {"giohw", tag::giohw},
-                                                       {"goidhw", tag::goidhw},
-                                                       {"giodhw", tag::giodhw},
-                                                       {"godhwi", tag::godhwi},
-                                                       {"dhwigo", tag::dhwigo},
-                                                       {"tnc", tag::tnc},
-                                                       {"ntc", tag::ntc},
-                                                       {"ldnc", tag::ldnc},
-                                                       {"ldigo", tag::ldigo},
-                                                       {"ldgoi", tag::ldgoi},
-                                                       {"ldio", tag::ldio},
-                                                       {"ldoi", tag::ldoi},
-                                                       {"ldgo", tag::ldgo},
-                                                       {"nCdhw16c", 
tag::nCdhw16c},
-                                                       {"nCdhw4c", 
tag::nCdhw4c},
-                                                       {"nCdhw8c", 
tag::nCdhw8c},
-                                                       {"nChw16c", 
tag::nChw16c},
-                                                       {"nChw4c", tag::nChw4c},
-                                                       {"nChw8c", tag::nChw8c},
-                                                       {"nCw16c", tag::nCw16c},
-                                                       {"nCw4c", tag::nCw4c},
-                                                       {"nCw8c", tag::nCw8c},
-                                                       {"NCw16n16c", 
tag::NCw16n16c},
-                                                       {"NChw16n16c", 
tag::NChw16n16c},
-                                                       {"NCdhw16n16c", 
tag::NCdhw16n16c},
-                                                       {"NCdhw32n32c", 
tag::NCdhw32n32c},
-                                                       {"NChw32n32c", 
tag::NChw32n32c},
-                                                       {"IOhw16i16o", 
tag::IOhw16i16o},
-                                                       {"OI16i16o", 
tag::OI16i16o},
-                                                       {"OI16i32o", 
tag::OI16i32o},
-                                                       {"OI16i64o", 
tag::OI16i64o},
-                                                       {"OI8i16o2i", 
tag::OI8i16o2i},
-                                                       {"OI8i32o2i", 
tag::OI8i32o2i},
-                                                       {"OI8i64o2i", 
tag::OI8i64o2i},
-                                                       {"OI4i16o4i", 
tag::OI4i16o4i},
-                                                       {"OI4i32o4i", 
tag::OI4i32o4i},
-                                                       {"OI4i64o4i", 
tag::OI4i64o4i},
-                                                       {"Ohwi32o", 
tag::Ohwi32o},
-                                                       {"IOdhw16i16o", 
tag::IOdhw16i16o},
-                                                       {"gIOhw16i16o", 
tag::gIOhw16i16o},
-                                                       {"gOhwi32o", 
tag::gOhwi32o},
-                                                       {"Goidhw16g", 
tag::Goidhw16g},
-                                                       {"IOw16o16i", 
tag::IOw16o16i},
-                                                       {"OIw16i16o", 
tag::OIw16i16o},
-                                                       {"OIw16i32o", 
tag::OIw16i32o},
-                                                       {"OIw16i64o", 
tag::OIw16i64o},
-                                                       {"IOw16i16o", 
tag::IOw16i16o},
-                                                       {"gIOw16i16o", 
tag::gIOw16i16o},
-                                                       {"OIw16o16i", 
tag::OIw16o16i},
-                                                       {"Oiw16o", tag::Oiw16o},
-                                                       {"OIw4i16o4i", 
tag::OIw4i16o4i},
-                                                       {"OIw4i32o4i", 
tag::OIw4i32o4i},
-                                                       {"OIw4i64o4i", 
tag::OIw4i64o4i},
-                                                       {"OIw2i8o4i", 
tag::OIw2i8o4i},
-                                                       {"OIw4i4o", 
tag::OIw4i4o},
-                                                       {"OIw4o4i", 
tag::OIw4o4i},
-                                                       {"Oiw4o", tag::Oiw4o},
-                                                       {"OIw8i16o2i", 
tag::OIw8i16o2i},
-                                                       {"OIw8i32o2i", 
tag::OIw8i32o2i},
-                                                       {"OIw8i64o2i", 
tag::OIw8i64o2i},
-                                                       {"OIw8i8o", 
tag::OIw8i8o},
-                                                       {"OIw8o16i2o", 
tag::OIw8o16i2o},
-                                                       {"OIw8o8i", 
tag::OIw8o8i},
-                                                       {"OIw8o4i", 
tag::OIw8o4i},
-                                                       {"OIw16i16o4i", 
tag::OIw16i16o4i},
-                                                       {"OIw16i32o4i", 
tag::OIw16i32o4i},
-                                                       {"OIw16i48o4i", 
tag::OIw16i48o4i},
-                                                       {"OIw16i64o4i", 
tag::OIw16i64o4i},
-                                                       {"OIw16i16o2i", 
tag::OIw16i16o2i},
-                                                       {"OIw16i32o2i", 
tag::OIw16i32o2i},
-                                                       {"OIw16i48o2i", 
tag::OIw16i48o2i},
-                                                       {"OIw16i64o2i", 
tag::OIw16i64o2i},
-                                                       {"OIw16o16i2o", 
tag::OIw16o16i2o},
-                                                       {"Owi16o", tag::Owi16o},
-                                                       {"OwI16o2i", 
tag::OwI16o2i},
-                                                       {"Owi4o", tag::Owi4o},
-                                                       {"Owi8o", tag::Owi8o},
-                                                       {"IOhw16o16i", 
tag::IOhw16o16i},
-                                                       {"Ohwi16o", 
tag::Ohwi16o},
-                                                       {"OhwI16o2i", 
tag::OhwI16o2i},
-                                                       {"Ohwi4o", tag::Ohwi4o},
-                                                       {"Ohwi8o", tag::Ohwi8o},
-                                                       {"OIhw16i16o", 
tag::OIhw16i16o},
-                                                       {"OIhw16i32o", 
tag::OIhw16i32o},
-                                                       {"OIhw16i64o", 
tag::OIhw16i64o},
-                                                       {"OIhw16o16i", 
tag::OIhw16o16i},
-                                                       {"Oihw16o", 
tag::Oihw16o},
-                                                       {"OIhw4i16o4i", 
tag::OIhw4i16o4i},
-                                                       {"OIhw4i32o4i", 
tag::OIhw4i32o4i},
-                                                       {"OIhw4i64o4i", 
tag::OIhw4i64o4i},
-                                                       {"OIhw4i4o", 
tag::OIhw4i4o},
-                                                       {"OIhw4o4i", 
tag::OIhw4o4i},
-                                                       {"Oihw4o", tag::Oihw4o},
-                                                       {"OIhw8i16o2i", 
tag::OIhw8i16o2i},
-                                                       {"OIhw8i32o2i", 
tag::OIhw8i32o2i},
-                                                       {"OIhw8i64o2i", 
tag::OIhw8i64o2i},
-                                                       {"OIhw8i8o", 
tag::OIhw8i8o},
-                                                       {"OIhw8o16i2o", 
tag::OIhw8o16i2o},
-                                                       {"OIhw8o8i", 
tag::OIhw8o8i},
-                                                       {"OIhw8o4i", 
tag::OIhw8o4i},
-                                                       {"OIhw2i8o4i", 
tag::OIhw2i8o4i},
-                                                       {"IOdhw16o16i", 
tag::IOdhw16o16i},
-                                                       {"Odhwi16o", 
tag::Odhwi16o},
-                                                       {"OdhwI16o2i", 
tag::OdhwI16o2i},
-                                                       {"Odhwi4o", 
tag::Odhwi4o},
-                                                       {"Odhwi8o", 
tag::Odhwi8o},
-                                                       {"OIdhw16i16o", 
tag::OIdhw16i16o},
-                                                       {"OIdhw16i32o", 
tag::OIdhw16i32o},
-                                                       {"OIdhw16i64o", 
tag::OIdhw16i64o},
-                                                       {"OIdhw16o16i", 
tag::OIdhw16o16i},
-                                                       {"Oidhw16o", 
tag::Oidhw16o},
-                                                       {"OIdhw4i4o", 
tag::OIdhw4i4o},
-                                                       {"OIdhw4o4i", 
tag::OIdhw4o4i},
-                                                       {"Oidhw4o", 
tag::Oidhw4o},
-                                                       {"OIdhw8i16o2i", 
tag::OIdhw8i16o2i},
-                                                       {"OIdhw8i32o2i", 
tag::OIdhw8i32o2i},
-                                                       {"OIdhw8i64o2i", 
tag::OIdhw8i64o2i},
-                                                       {"OIdhw4i16o4i", 
tag::OIdhw4i16o4i},
-                                                       {"OIdhw16i16o4i", 
tag::OIdhw16i16o4i},
-                                                       {"OIdhw16i32o4i", 
tag::OIdhw16i32o4i},
-                                                       {"OIdhw16i48o4i", 
tag::OIdhw16i48o4i},
-                                                       {"OIdhw16i64o4i", 
tag::OIdhw16i64o4i},
-                                                       {"OIdhw16i16o2i", 
tag::OIdhw16i16o2i},
-                                                       {"OIdhw16i32o2i", 
tag::OIdhw16i32o2i},
-                                                       {"OIdhw16i48o2i", 
tag::OIdhw16i48o2i},
-                                                       {"OIdhw16i64o2i", 
tag::OIdhw16i64o2i},
-                                                       {"OIdhw4i32o4i", 
tag::OIdhw4i32o4i},
-                                                       {"OIdhw4i64o4i", 
tag::OIdhw4i64o4i},
-                                                       {"OIdhw2i8o4i", 
tag::OIdhw2i8o4i},
-                                                       {"OIdhw8i8o", 
tag::OIdhw8i8o},
-                                                       {"OIdhw8o8i", 
tag::OIdhw8o8i},
-                                                       {"OIdhw8o4i", 
tag::OIdhw8o4i},
-                                                       {"gIOw16o16i", 
tag::gIOw16o16i},
-                                                       {"gOIw16i16o", 
tag::gOIw16i16o},
-                                                       {"gOIw16o16i", 
tag::gOIw16o16i},
-                                                       {"gOiw16o", 
tag::gOiw16o},
-                                                       {"gOIw4i16o4i", 
tag::gOIw4i16o4i},
-                                                       {"gOIw2i8o4i", 
tag::gOIw2i8o4i},
-                                                       {"gOIw4i4o", 
tag::gOIw4i4o},
-                                                       {"gOIw4o4i", 
tag::gOIw4o4i},
-                                                       {"gOiw4o", tag::gOiw4o},
-                                                       {"gOIw8i16o2i", 
tag::gOIw8i16o2i},
-                                                       {"gOIw8i8o", 
tag::gOIw8i8o},
-                                                       {"gOIw8o16i2o", 
tag::gOIw8o16i2o},
-                                                       {"gOIw8o8i", 
tag::gOIw8o8i},
-                                                       {"gOIw8o4i", 
tag::gOIw8o4i},
-                                                       {"gOIw16i16o4i", 
tag::gOIw16i16o4i},
-                                                       {"gOIw16i16o2i", 
tag::gOIw16i16o2i},
-                                                       {"gOIw16o16i2o", 
tag::gOIw16o16i2o},
-                                                       {"gOwi16o", 
tag::gOwi16o},
-                                                       {"gOwI16o2i", 
tag::gOwI16o2i},
-                                                       {"gOwi4o", tag::gOwi4o},
-                                                       {"gOwi8o", tag::gOwi8o},
-                                                       {"Goiw8g", tag::Goiw8g},
-                                                       {"Goiw16g", 
tag::Goiw16g},
-                                                       {"gIOhw16o16i", 
tag::gIOhw16o16i},
-                                                       {"gOhwi16o", 
tag::gOhwi16o},
-                                                       {"gOhwI16o2i", 
tag::gOhwI16o2i},
-                                                       {"gOhwi4o", 
tag::gOhwi4o},
-                                                       {"gOhwi8o", 
tag::gOhwi8o},
-                                                       {"Goihw16g", 
tag::Goihw16g},
-                                                       {"gOIhw16i16o", 
tag::gOIhw16i16o},
-                                                       {"gOIhw16o16i", 
tag::gOIhw16o16i},
-                                                       {"gOihw16o", 
tag::gOihw16o},
-                                                       {"gOIhw4i16o4i", 
tag::gOIhw4i16o4i},
-                                                       {"gOIhw2i8o4i", 
tag::gOIhw2i8o4i},
-                                                       {"gOIhw4i4o", 
tag::gOIhw4i4o},
-                                                       {"gOIhw4o4i", 
tag::gOIhw4o4i},
-                                                       {"gOihw4o", 
tag::gOihw4o},
-                                                       {"Goihw8g", 
tag::Goihw8g},
-                                                       {"gOIhw8i16o2i", 
tag::gOIhw8i16o2i},
-                                                       {"gOIhw8i8o", 
tag::gOIhw8i8o},
-                                                       {"gOIhw8o16i2o", 
tag::gOIhw8o16i2o},
-                                                       {"OIw4o8i8o4i", 
tag::OIw4o8i8o4i},
-                                                       {"OIdhw4o8i8o4i", 
tag::OIdhw4o8i8o4i},
-                                                       {"OIhw4o8i8o4i", 
tag::OIhw4o8i8o4i},
-                                                       {"OIhw2o8i8o2i", 
tag::OIhw2o8i8o2i},
-                                                       {"gOIw4o8i8o4i", 
tag::gOIw4o8i8o4i},
-                                                       {"gOIdhw4o8i8o4i", 
tag::gOIdhw4o8i8o4i},
-                                                       {"gOIhw4o8i8o4i", 
tag::gOIhw4o8i8o4i},
-                                                       {"gOIhw2o8i8o2i", 
tag::gOIhw2o8i8o2i},
-                                                       {"OIhw16i16o4i", 
tag::OIhw16i16o4i},
-                                                       {"OIhw16i32o4i", 
tag::OIhw16i32o4i},
-                                                       {"OIhw16i48o4i", 
tag::OIhw16i48o4i},
-                                                       {"OIhw16i64o4i", 
tag::OIhw16i64o4i},
-                                                       {"OIhw16i16o2i", 
tag::OIhw16i16o2i},
-                                                       {"OIhw16i32o2i", 
tag::OIhw16i32o2i},
-                                                       {"OIhw16i48o2i", 
tag::OIhw16i48o2i},
-                                                       {"OIhw16i64o2i", 
tag::OIhw16i64o2i},
-                                                       {"OIhw16o16i2o", 
tag::OIhw16o16i2o},
-                                                       {"gOIhw16i16o4i", 
tag::gOIhw16i16o4i},
-                                                       {"gOIhw16i16o2i", 
tag::gOIhw16i16o2i},
-                                                       {"gOIhw16o16i2o", 
tag::gOIhw16o16i2o},
-                                                       {"gOIhw8o8i", 
tag::gOIhw8o8i},
-                                                       {"gOIhw8o4i", 
tag::gOIhw8o4i},
-                                                       {"gIOdhw16i16o", 
tag::gIOdhw16i16o},
-                                                       {"gIOdhw16o16i", 
tag::gIOdhw16o16i},
-                                                       {"gOdhwi16o", 
tag::gOdhwi16o},
-                                                       {"gOdhwI16o2i", 
tag::gOdhwI16o2i},
-                                                       {"gOdhwi4o", 
tag::gOdhwi4o},
-                                                       {"gOdhwi8o", 
tag::gOdhwi8o},
-                                                       {"gOIdhw16i16o", 
tag::gOIdhw16i16o},
-                                                       {"gOIdhw16o16i", 
tag::gOIdhw16o16i},
-                                                       {"gOidhw16o", 
tag::gOidhw16o},
-                                                       {"gOIdhw4i4o", 
tag::gOIdhw4i4o},
-                                                       {"gOIdhw4o4i", 
tag::gOIdhw4o4i},
-                                                       {"gOidhw4o", 
tag::gOidhw4o},
-                                                       {"gOIdhw8i16o2i", 
tag::gOIdhw8i16o2i},
-                                                       {"gOIdhw4i16o4i", 
tag::gOIdhw4i16o4i},
-                                                       {"gOIdhw16i16o4i", 
tag::gOIdhw16i16o4i},
-                                                       {"gOIdhw16i16o2i", 
tag::gOIdhw16i16o2i},
-                                                       {"gOIdhw2i8o4i", 
tag::gOIdhw2i8o4i},
-                                                       {"gOIdhw8i8o", 
tag::gOIdhw8i8o},
-                                                       {"gOIdhw8o8i", 
tag::gOIdhw8o8i},
-                                                       {"gOIdhw8o4i", 
tag::gOIdhw8o4i},
-                                                       {"gOIw2i4o2i", 
tag::gOIw2i4o2i},
-                                                       {"gOIhw2i4o2i", 
tag::gOIhw2i4o2i},
-                                                       {"gOIdhw2i4o2i", 
tag::gOIdhw2i4o2i},
-                                                       {"gOIw2o4i2o", 
tag::gOIw2o4i2o},
-                                                       {"gOIhw2o4i2o", 
tag::gOIhw2o4i2o},
-                                                       {"gOIdhw2o4i2o", 
tag::gOIdhw2o4i2o},
-                                                       {"gOIw4i8o2i", 
tag::gOIw4i8o2i},
-                                                       {"gOIhw4i8o2i", 
tag::gOIhw4i8o2i},
-                                                       {"gOIdhw4i8o2i", 
tag::gOIdhw4i8o2i},
-                                                       {"gOIw4o8i2o", 
tag::gOIw4o8i2o},
-                                                       {"gOIhw4o8i2o", 
tag::gOIhw4o8i2o},
-                                                       {"gOIdhw4o8i2o", 
tag::gOIdhw4o8i2o},
-                                                       {"ldOi32o", 
tag::ldOi32o},
-                                                       {"ldOI32o4i", 
tag::ldOI32o4i},
-                                                       {"ldgOi32o", 
tag::ldgOi32o},
-                                                       {"ldgOI32o2i", 
tag::ldgOI32o2i},
-                                                       {"ldgOI32o4i", 
tag::ldgOI32o4i},
-                                                       {"OwI16o4i", 
tag::OwI16o4i},
-                                                       {"OhwI16o4i", 
tag::OhwI16o4i},
-                                                       {"gOwI16o4i", 
tag::gOwI16o4i},
-                                                       {"gOhwI16o4i", 
tag::gOhwI16o4i},
-                                                       {"OdhwI16o4i", 
tag::OdhwI16o4i},
-                                                       {"gOdhwI16o4i", 
tag::gOdhwI16o4i},
-                                                       {"Owi32o", tag::Owi32o},
-                                                       {"OwI32o2i", 
tag::OwI32o2i},
-                                                       {"OwI32o4i", 
tag::OwI32o4i},
-                                                       {"Owi48o", tag::Owi48o},
-                                                       {"OwI48o2i", 
tag::OwI48o2i},
-                                                       {"OwI48o4i", 
tag::OwI48o4i},
-                                                       {"Owi64o", tag::Owi64o},
-                                                       {"OwI64o2i", 
tag::OwI64o2i},
-                                                       {"OwI64o4i", 
tag::OwI64o4i},
-                                                       {"wIo2i", tag::wIo2i},
-                                                       {"wIo4i", tag::wIo4i},
-                                                       {"gOwi32o", 
tag::gOwi32o},
-                                                       {"gOwI32o2i", 
tag::gOwI32o2i},
-                                                       {"gOwI32o4i", 
tag::gOwI32o4i},
-                                                       {"gOwi48o", 
tag::gOwi48o},
-                                                       {"gOwI48o2i", 
tag::gOwI48o2i},
-                                                       {"gOwI48o4i", 
tag::gOwI48o4i},
-                                                       {"gOwi64o", 
tag::gOwi64o},
-                                                       {"gOwI64o2i", 
tag::gOwI64o2i},
-                                                       {"gOwI64o4i", 
tag::gOwI64o4i},
-                                                       {"gwio", tag::gwio},
-                                                       {"gwIo2i", tag::gwIo2i},
-                                                       {"gwIo4i", tag::gwIo4i},
-                                                       {"OhwI32o", 
tag::OhwI32o},
-                                                       {"OhwI32o2i", 
tag::OhwI32o2i},
-                                                       {"OhwI32o4i", 
tag::OhwI32o4i},
-                                                       {"Ohwi48o", 
tag::Ohwi48o},
-                                                       {"OhwI48o2i", 
tag::OhwI48o2i},
-                                                       {"OhwI48o4i", 
tag::OhwI48o4i},
-                                                       {"Ohwi64o", 
tag::Ohwi64o},
-                                                       {"OhwI64o2i", 
tag::OhwI64o2i},
-                                                       {"OhwI64o4i", 
tag::OhwI64o4i},
-                                                       {"hwIo2i", tag::hwIo2i},
-                                                       {"hwIo4i", tag::hwIo4i},
-                                                       {"gOhwI32o", 
tag::gOhwI32o},
-                                                       {"gOhwI32o2i", 
tag::gOhwI32o2i},
-                                                       {"gOhwI32o4i", 
tag::gOhwI32o4i},
-                                                       {"gOhwi48o", 
tag::gOhwi48o},
-                                                       {"gOhwI48o2i", 
tag::gOhwI48o2i},
-                                                       {"gOhwI48o4i", 
tag::gOhwI48o4i},
-                                                       {"gOhwi64o", 
tag::gOhwi64o},
-                                                       {"gOhwI64o2i", 
tag::gOhwI64o2i},
-                                                       {"gOhwI64o4i", 
tag::gOhwI64o4i},
-                                                       {"ghwio", tag::ghwio},
-                                                       {"ghwIo2i", 
tag::ghwIo2i},
-                                                       {"ghwIo4i", 
tag::ghwIo4i},
-                                                       {"Odhwi32o", 
tag::Odhwi32o},
-                                                       {"OdhwI32o2i", 
tag::OdhwI32o2i},
-                                                       {"OdhwI32o4i", 
tag::OdhwI32o4i},
-                                                       {"Odhwi48o", 
tag::Odhwi48o},
-                                                       {"OdhwI48o2i", 
tag::OdhwI48o2i},
-                                                       {"OdhwI48o4i", 
tag::OdhwI48o4i},
-                                                       {"Odhwi64o", 
tag::Odhwi64o},
-                                                       {"OdhwI64o2i", 
tag::OdhwI64o2i},
-                                                       {"OdhwI64o4i", 
tag::OdhwI64o4i},
-                                                       {"dhwIo2i", 
tag::dhwIo2i},
-                                                       {"dhwIo4i", 
tag::dhwIo4i},
-                                                       {"gOdhwi32o", 
tag::gOdhwi32o},
-                                                       {"gOdhwI32o2i", 
tag::gOdhwI32o2i},
-                                                       {"gOdhwI32o4i", 
tag::gOdhwI32o4i},
-                                                       {"gOdhwi48o", 
tag::gOdhwi48o},
-                                                       {"gOdhwI48o2i", 
tag::gOdhwI48o2i},
-                                                       {"gOdhwI48o4i", 
tag::gOdhwI48o4i},
-                                                       {"gOdhwi64o", 
tag::gOdhwi64o},
-                                                       {"gOdhwI64o2i", 
tag::gOdhwI64o2i},
-                                                       {"gOdhwI64o4i", 
tag::gOdhwI64o4i},
-                                                       {"gdhwio", tag::gdhwio},
-                                                       {"gdhwIo2i", 
tag::gdhwIo2i},
-                                                       {"gdhwIo4i", 
tag::gdhwIo4i},
-                                                       {"ldIo32i", 
tag::ldIo32i},
-                                                       {"ldgIo32i", 
tag::ldgIo32i},
-                                                       {"ldgIO32i2o", 
tag::ldgIO32i2o},
-                                                       {"nCdhw32c", 
tag::nCdhw32c},
-                                                       {"nChw32c", 
tag::nChw32c},
-                                                       {"nCw32c", tag::nCw32c},
-                                                       {"NCw32n16c", 
tag::NCw32n16c},
-                                                       {"NChw32n16c", 
tag::NChw32n16c},
-                                                       {"NCdhw32n16c", 
tag::NCdhw32n16c},
-                                                       {"NCw32n32c", 
tag::NCw32n32c},
-                                                       {"OI16i16o4i", 
tag::OI16i16o4i},
-                                                       {"IOw8o16i2o", 
tag::IOw8o16i2o},
-                                                       {"IOhw8o16i2o", 
tag::IOhw8o16i2o},
-                                                       {"Owhi16o", 
tag::Owhi16o},
-                                                       {"OIdhw8o16i2o", 
tag::OIdhw8o16i2o},
-                                                       {"IOdhw8o16i2o", 
tag::IOdhw8o16i2o},
-                                                       {"Goiw4g", tag::Goiw4g},
-                                                       {"gIOw8o16i2o", 
tag::gIOw8o16i2o},
-                                                       {"Goiw32g", 
tag::Goiw32g},
-                                                       {"Goihw4g", 
tag::Goihw4g},
-                                                       {"gIOhw8o16i2o", 
tag::gIOhw8o16i2o},
-                                                       {"Goihw32g", 
tag::Goihw32g},
-                                                       {"gOwhi16o", 
tag::gOwhi16o},
-                                                       {"IOw4i8o8i4o", 
tag::IOw4i8o8i4o},
-                                                       {"IOhw4i8o8i4o", 
tag::IOhw4i8o8i4o},
-                                                       {"IOdhw4i8o8i4o", 
tag::IOdhw4i8o8i4o},
-                                                       {"gIOw4i8o8i4o", 
tag::gIOw4i8o8i4o},
-                                                       {"gIOhw4i8o8i4o", 
tag::gIOhw4i8o8i4o},
-                                                       {"gIOdhw4i8o8i4o", 
tag::gIOdhw4i8o8i4o},
-                                                       {"gOIdhw8o16i2o", 
tag::gOIdhw8o16i2o},
-                                                       {"gIOdhw8o16i2o", 
tag::gIOdhw8o16i2o},
-                                                       {"Goidhw32g", 
tag::Goidhw32g},
-                                                       {"OI16i32o4i", 
tag::OI16i32o4i},
-                                                       {"OI16i48o4i", 
tag::OI16i48o4i},
-                                                       {"OI16i64o4i", 
tag::OI16i64o4i},
-                                                       {"OI16i16o2i", 
tag::OI16i16o2i},
-                                                       {"OI16i32o2i", 
tag::OI16i32o2i},
-                                                       {"OI16i48o2i", 
tag::OI16i48o2i},
-                                                       {"OI16i64o2i", 
tag::OI16i64o2i},
-                                                       {"OwI16i16o2i", 
tag::OwI16i16o2i},
-                                                       {"gOwI16i16o2i", 
tag::gOwI16i16o2i},
-                                                       {"OhwI16i16o2i", 
tag::OhwI16i16o2i},
-                                                       {"gOhwI16i16o2i", 
tag::gOhwI16i16o2i},
-                                                       {"OdhwI16i16o2i", 
tag::OdhwI16i16o2i},
-                                                       {"gOdhwI16i16o2i", 
tag::gOdhwI16i16o2i},
-                                                       {"OwI16i16o4i", 
tag::OwI16i16o4i},
-                                                       {"gOwI16i16o4i", 
tag::gOwI16i16o4i},
-                                                       {"OhwI16i16o4i", 
tag::OhwI16i16o4i},
-                                                       {"gOhwI16i16o4i", 
tag::gOhwI16i16o4i},
-                                                       {"OdhwI16i16o4i", 
tag::OdhwI16i16o4i},
-                                                       {"gOdhwI16i16o4i", 
tag::gOdhwI16i16o4i},
-                                                       {"OwI16i32o2i", 
tag::OwI16i32o2i},
-                                                       {"OwI16i32o4i", 
tag::OwI16i32o4i},
-                                                       {"OwI16i48o2i", 
tag::OwI16i48o2i},
-                                                       {"OwI16i48o4i", 
tag::OwI16i48o4i},
-                                                       {"OwI16i64o2i", 
tag::OwI16i64o2i},
-                                                       {"OwI16i64o4i", 
tag::OwI16i64o4i},
-                                                       {"gOwI16i32o2i", 
tag::gOwI16i32o2i},
-                                                       {"gOwI16i32o4i", 
tag::gOwI16i32o4i},
-                                                       {"gOwI16i48o2i", 
tag::gOwI16i48o2i},
-                                                       {"gOwI16i48o4i", 
tag::gOwI16i48o4i},
-                                                       {"gOwI16i64o2i", 
tag::gOwI16i64o2i},
-                                                       {"gOwI16i64o4i", 
tag::gOwI16i64o4i},
-                                                       {"OhwI16i32o2i", 
tag::OhwI16i32o2i},
-                                                       {"OhwI16i32o4i", 
tag::OhwI16i32o4i},
-                                                       {"OhwI16i48o2i", 
tag::OhwI16i48o2i},
-                                                       {"OhwI16i48o4i", 
tag::OhwI16i48o4i},
-                                                       {"OhwI16i64o2i", 
tag::OhwI16i64o2i},
-                                                       {"OhwI16i64o4i", 
tag::OhwI16i64o4i},
-                                                       {"gOhwI16i32o2i", 
tag::gOhwI16i32o2i},
-                                                       {"gOhwI16i32o4i", 
tag::gOhwI16i32o4i},
-                                                       {"gOhwI16i48o2i", 
tag::gOhwI16i48o2i},
-                                                       {"gOhwI16i48o4i", 
tag::gOhwI16i48o4i},
-                                                       {"gOhwI16i64o2i", 
tag::gOhwI16i64o2i},
-                                                       {"gOhwI16i64o4i", 
tag::gOhwI16i64o4i},
-                                                       {"OdhwI16i32o2i", 
tag::OdhwI16i32o2i},
-                                                       {"OdhwI16i32o4i", 
tag::OdhwI16i32o4i},
-                                                       {"OdhwI16i48o2i", 
tag::OdhwI16i48o2i},
-                                                       {"OdhwI16i48o4i", 
tag::OdhwI16i48o4i},
-                                                       {"OdhwI16i64o2i", 
tag::OdhwI16i64o2i},
-                                                       {"OdhwI16i64o4i", 
tag::OdhwI16i64o4i},
-                                                       {"gOdhwI16i32o2i", 
tag::gOdhwI16i32o2i},
-                                                       {"gOdhwI16i32o4i", 
tag::gOdhwI16i32o4i},
-                                                       {"gOdhwI16i48o2i", 
tag::gOdhwI16i48o2i},
-                                                       {"gOdhwI16i48o4i", 
tag::gOdhwI16i48o4i},
-                                                       {"gOdhwI16i64o2i", 
tag::gOdhwI16i64o2i},
-                                                       {"gOdhwI16i64o4i", 
tag::gOdhwI16i64o4i},
-                                                       {"hwioG16g", 
tag::hwioG16g},
-                                                       {"NCdhw40n32c", 
tag::NCdhw40n32c},
-                                                       {"NChw40n32c", 
tag::NChw40n32c},
-                                                       {"NCw40n32c", 
tag::NCw40n32c},
-                                                       {"OIdhw4o8i8o2i", 
tag::OIdhw4o8i8o2i},
-                                                       {"OIhw4o8i8o2i", 
tag::OIhw4o8i8o2i},
-                                                       {"OIw4o8i8o2i", 
tag::OIw4o8i8o2i},
-                                                       {"gOIdhw4o8i8o2i", 
tag::gOIdhw4o8i8o2i},
-                                                       {"gOIhw4o8i8o2i", 
tag::gOIhw4o8i8o2i},
-                                                       {"gOIw4o8i8o2i", 
tag::gOIw4o8i8o2i},
-                                                       {"IOdhw4i8o8i2o", 
tag::IOdhw4i8o8i2o},
-                                                       {"IOhw4i8o8i2o", 
tag::IOhw4i8o8i2o},
-                                                       {"IOw4i8o8i2o", 
tag::IOw4i8o8i2o},
-                                                       {"gIOdhw4i8o8i2o", 
tag::gIOdhw4i8o8i2o},
-                                                       {"gIOhw4i8o8i2o", 
tag::gIOhw4i8o8i2o},
-                                                       {"gIOw4i8o8i2o", 
tag::gIOw4i8o8i2o},
-                                                       {"NCdhw40n16c", 
tag::NCdhw40n16c},
-                                                       {"NCw40n16c", 
tag::NCw40n16c},
-                                                       {"NChw40n16c", 
tag::NChw40n16c},
-                                                       {"NCw2c32n8c", 
tag::NCw2c32n8c},
-                                                       {"NChw2c32n8c", 
tag::NChw2c32n8c},
-                                                       {"NCdhw2c32n8c", 
tag::NCdhw2c32n8c},
-                                                       {"OIw2i8o16i4o", 
tag::OIw2i8o16i4o},
-                                                       {"OIhw2i8o16i4o", 
tag::OIhw2i8o16i4o},
-                                                       {"OIdhw2i8o16i4o", 
tag::OIdhw2i8o16i4o},
-                                                       {"OIw2o8i16o4i", 
tag::OIw2o8i16o4i},
-                                                       {"OIw2o8i16o2i", 
tag::OIw2o8i16o2i},
-                                                       {"IOw2i8o16i4o", 
tag::IOw2i8o16i4o},
-                                                       {"IOw2i8o16i2o", 
tag::IOw2i8o16i2o},
-                                                       {"OIhw2o8i16o4i", 
tag::OIhw2o8i16o4i},
-                                                       {"OIhw2o8i16o2i", 
tag::OIhw2o8i16o2i},
-                                                       {"IOhw2i8o16i4o", 
tag::IOhw2i8o16i4o},
-                                                       {"IOhw2i8o16i2o", 
tag::IOhw2i8o16i2o},
-                                                       {"OIdhw2o8i16o4i", 
tag::OIdhw2o8i16o4i},
-                                                       {"OIdhw2o8i16o2i", 
tag::OIdhw2o8i16o2i},
-                                                       {"IOdhw2i8o16i4o", 
tag::IOdhw2i8o16i4o},
-                                                       {"IOdhw2i8o16i2o", 
tag::IOdhw2i8o16i2o},
-                                                       {"gOIw2o8i16o2i", 
tag::gOIw2o8i16o2i},
-                                                       {"gIOw2i8o16i2o", 
tag::gIOw2i8o16i2o},
-                                                       {"gIOhw2i8o16i2o", 
tag::gIOhw2i8o16i2o},
-                                                       {"gIOdhw2i8o16i2o", 
tag::gIOdhw2i8o16i2o},
-                                                       {"gOIhw2o8i16o2i", 
tag::gOIhw2o8i16o2i},
-                                                       {"gOIdhw2o8i16o2i", 
tag::gOIdhw2o8i16o2i},
-                                                       {"gOIw2o8i16o4i", 
tag::gOIw2o8i16o4i},
-                                                       {"gOIhw2o8i16o4i", 
tag::gOIhw2o8i16o4i}};
-    std::string key = "";
-    for (const auto& c : layout) {
-      if (std::isalpha(c, std::locale("C"))) {
-        char lower_c = std::tolower(c);
-        if (std::isupper(c) && (layout.find(lower_c) != std::string::npos)) {
-          key.push_back(c);
-        } else {
-          key.push_back(lower_c);
-        }
-      } else if (std::isdigit(c)) {
-        key.push_back(c);
-      } else {
-        LOG(FATAL) << "invalid char '" << c << "' in " << layout << std::endl;
-      }
-    }
-    if (str2tag.count(key) == 0) {
-      LOG(WARNING) << "convert unregistered layout '" << key << "' to 
tag::any";
-      return tag::any;
+  /* Override GetFunction to reimplement Run method */
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& 
sptr_to_self) override {
+    if (this->symbol_name_ == name) {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        ICHECK(this->initialized_) << "The module has not been initialized";
+
+        ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size())
+            << "Found mismatch in the number of provided data entries and 
required.";
+
+        Run(args);
+      });
     } else {
-      return str2tag.at(key);
+      return JSONRuntimeBase::GetFunction(name, sptr_to_self);
+    }
+  }
+
+  /* Same as makeInitDataProvider, but for input/output entries it returns the real DLTensor */
+  TensorRegistry::DLTensorProvider makeIODataProvider(const TVMArgs& args) 
const {
+    auto extract_dl_tensor = [](const TVMArgValue& val) -> const DLTensor* {
+      ICHECK(val.type_code() == kTVMNDArrayHandle || val.type_code() == 
kTVMDLTensorHandle)
+          << "Expect NDArray or DLTensor";
+      return val.IsObjectRef<NDArray>() ? val.operator NDArray().operator->()
+                                        : val.operator DLTensor*();
+    };
+
+    std::map<uint32_t, const DLTensor*> io_map;  // eid to dl tensor map
+    for (size_t i = 0; i < run_arg_eid_.size(); i++) {
+      io_map[run_arg_eid_[i]] = extract_dl_tensor(args[i]);
     }
+
+    // lambda with captured IO data handlers
+    return [io_map](uint32_t eid) -> const DLTensor* { return io_map.at(eid); 
};
   }
 
-  std::map<std::string, dnnl::algorithm> elt_name2algo{
+ private:
+  const std::map<std::string, dnnl::algorithm> elt_name2algo{
       {"abs", dnnl::algorithm::eltwise_abs},
       {"exp", dnnl::algorithm::eltwise_exp},
       {"log", dnnl::algorithm::eltwise_log},
@@ -626,64 +161,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     return std::regex_match(op_name, bias_add_pat) ? true : false;
   }
 
-  dnnl::memory::dims TransDims2Plain(dnnl::memory::dims input_dims, 
std::string layout) {
-    std::vector<char> axis = {
-        'N', 'C', 'O', 'I', 'D', 'H', 'W',
-    };
-    dnnl::memory::dims out_dims;
-    std::string::iterator t = layout.begin();
-    // Remove numbers in layout string to match the size of input_dims
-    while (t != layout.end()) {
-      if (*t >= '0' && *t <= '9') {
-        layout.erase(t);
-      } else {
-        t++;
-      }
-    }
-    // Push the correct shapes of each axis into the output_dims
-    for (auto a : axis) {
-      if (layout.find(a) != std::string::npos) {
-        dnnl::memory::dim shape = input_dims[layout.find(a)];
-        char lower_a = std::tolower(a);
-        for (size_t i = 0; i < layout.size(); ++i) {
-          if (lower_a == layout[i]) {
-            shape *= input_dims[i];
-          }
-        }
-        out_dims.push_back(shape);
-      }
-    }
-    // Multiply O and I with G, respectively
-    if (layout.find("G") != std::string::npos) {
-      dnnl::memory::dim G = 1;
-      if (layout.find("g") != std::string::npos) {
-        G = input_dims[layout.find("g")] * input_dims[layout.find("G")];
-      } else {
-        G = input_dims[layout.find("G")];
-      }
-      out_dims[0] *= G;
-      out_dims[1] *= G;
-    }
-    return out_dims;
-  }
-
-  dnnl::memory::dims TransformStr2Dims(std::vector<std::string> strs, bool 
dilates = false) {
-    dnnl::memory::dims out_dims;
-    if (dilates) {
-      std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims),
-                     [](const std::string& str) { return std::stoi(str) - 1; 
});
-    } else {
-      std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims),
-                     [](const std::string& str) { return std::stoi(str); });
-    }
-    return out_dims;
-  }
-
   // Build up the engine based on the input graph.
   void BuildEngine() {
     engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0);
     stream_ = dnnl::stream(engine_);
 
+    std::set<uint32_t> io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end());
+    tensor_registry_ = TensorRegistry(engine_, io_eid_set);
+
     std::regex conv_pat(".*conv[1-3]d.*");
     std::regex deconv_pat(".*deconv[1-3]d.*");
     std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*");
@@ -725,562 +210,471 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     }
   }
 
-  // Bind a JSON graph node entry to a DNNL memory.
-  dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, 
dnnl::memory::desc mem_desc,
-                              size_t offset = 0) {
-    auto eid = EntryID(entry);
-    if (entry_out_mem_.count(eid) == 0) {
-      return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset);
-    }
-    return entry_out_mem_[eid].first;
-  }
-
-  // Bind a JSON graph node entry to a given DNNL memory.
-  dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory 
mem,
-                              size_t offset = 0) {
-    auto eid = EntryID(entry);
-    // Since the DNNL memory has been created before calling this function, we 
assume the entry
-    // has not yet been bound to the other DNNL memory; otherwise it may have 
memory leak.
-    ICHECK_EQ(entry_out_mem_.count(eid), 0);
-
-    entry_out_mem_[eid] = {mem, offset};
-    return entry_out_mem_[eid].first;
-  }
-
   void Convolution(const size_t& nid) {
     auto node = nodes_[nid];
     auto op_name = node.GetOpName();
     dnnl::primitive_attr attr;
+    attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
     bool has_bias = ParsingOpName(op_name, attr);
 
     // Setup attributes.
-    auto data_entry = node.GetInputs()[0];
-    auto weight_entry = node.GetInputs()[1];
-    JSONGraphNodeEntry out_entry(nid, 0);
-    dnnl::memory::dims input_shape = 
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
-    dnnl::memory::dims weight_shape = 
nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
-    dnnl::memory::dims out_shape = 
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
-    dnnl::memory::dim channels =
-        node.GetAttr<std::vector<std::string>>("channels")[0] != ""
-            ? std::stoi(node.GetAttr<std::vector<std::string>>("channels")[0])
-            : out_shape[1];
-    std::vector<std::string> str_strides = 
node.GetAttr<std::vector<std::string>>("strides");
-    std::vector<std::string> str_dilates = 
node.GetAttr<std::vector<std::string>>("dilation");
-    std::vector<std::string> str_padding = 
node.GetAttr<std::vector<std::string>>("padding");
-    std::vector<std::string> str_padding_l(str_padding.begin(),
-                                           str_padding.begin() + 
str_padding.size() / 2);
-    std::vector<std::string> str_padding_r(str_padding.end() - 
str_padding.size() / 2,
-                                           str_padding.end());
-    dnnl::memory::dim groups = 
std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
-    std::string data_layout = 
node.GetAttr<std::vector<std::string>>("data_layout")[0];
-    std::string kernel_layout = 
node.GetAttr<std::vector<std::string>>("kernel_layout")[0];
-
-    // Memory shapes.
-    dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout);
-    dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, 
kernel_layout);
-    dnnl::memory::dims bias_dims = {channels};
-    dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
-    dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
-    dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l);
-    dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
-    dnnl::memory::dims dst_dims = src_dims;
-    dst_dims[1] = channels;
-    weights_dims_[0] = channels;
-    weights_dims_[1] = src_dims[1];
-    for (size_t i = 2; i < src_dims.size(); i++) {
-      dnnl::memory::dim K = weights_dims_[i];
-      dnnl::memory::dim S = strides_dims[i - 2];
-      dnnl::memory::dim D = dilates_dims[i - 2];
-      dnnl::memory::dim PL = padding_dims_l[i - 2];
-      dnnl::memory::dim PR = padding_dims_r[i - 2];
-      dnnl::memory::dim DK = 1 + (K - 1) * (D + 1);
-      dst_dims[i] = (src_dims[i] - DK + PL + PR) / S + 1;
+    auto src_tr = GetInput(nid, 0);
+    auto wgh_tr = GetInput(nid, 1);
+    auto dst_tr = GetOutput(nid, 0);
+    auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1);
+    auto strides = GetNodeAttr<std::vector<int64_t>>(node, "strides");
+    auto dilates = GetNodeAttr<std::vector<int64_t>>(node, "dilation");
+    auto padding = GetNodeAttr<std::vector<int64_t>>(node, "padding");
+    std::vector<int64_t> padding_l(padding.begin(), padding.begin() + 
padding.size() / 2);
+    std::vector<int64_t> padding_r(padding.begin() + padding.size() / 2, 
padding.end());
+    auto groups = GetNodeAttr<int>(node, "groups");
+    auto src_layout = GetNodeAttr<std::string>(node, "data_layout");
+    auto dst_layout = GetNodeAttr<std::string>(node, "out_layout");
+    auto wgh_layout = GetNodeAttr<std::string>(node, "kernel_layout");
+
+    // dst_layout == "" means to use data_layout
+    if (dst_layout.empty()) dst_layout = src_layout;
+
+    // Subtract one for the DNNL representation: no dilation is 0 in DNNL but 1 in Relay.
+    for (auto& d : dilates) d--;
+
+    // Take into account provided layout strings
+    src_tr = src_tr.TreatAs(src_layout);
+    dst_tr = dst_tr.TreatAs(dst_layout);
+    wgh_tr = wgh_tr.TreatAs(wgh_layout);
+
+    // Should support G mixed with O. Like { G*O, I, H, W }
+    // Use { G, O, I, H, W } weight format even if groups == 1
+    if (wgh_layout.find("G") == std::string::npos) {
+      auto w_dims = wgh_tr.dims();
+      w_dims[0] /= groups;
+      w_dims.insert(w_dims.begin(), groups);
+      wgh_tr = wgh_tr.Reshape(w_dims);
     }
 
-    dnnl::memory::dims weights_dims = weights_dims_;
-    if (groups > 1) {
-      weights_dims = {groups, channels / groups, src_dims[1] / groups};
-      weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, 
weights_dims_.end());
-      if (kernel_layout == "OIHW") {
-        kernel_layout.insert(0, "G");
-      }
+    // Assumption that bias is correct and can be squeezed to 1D
+    bias_tr = bias_tr.Reshape({dst_tr.dims()[1]});
+
+    // TODO(@apeskov): This is a workaround. For a padded blocked tensor format we do not know the
+    //  original shapes. Example: tensor {1, 10, 224, 224} with layout "NCHW8c" will lead to tensor
+    //  {1, 2, 224, 224, 8}. The same result is obtained for shapes {1, 11, 224, 224} or {1, 15, 224, 224}.
+    //
+    // Let's try to compensate for it in the weight tensor. The weight IC should match the source IC.
+    // Example src: [1, 3, 224, 224] with layout NCHW
+    //         wgh: [16, 3, 3, 3] with layout OIHW2i8o -> [2, 2, 3, 3, 2, 8]
+    if (wgh_tr.dims()[2] != src_tr.dims()[1] / groups) {
+      auto wgh_croped_dims = wgh_tr.dims();
+      wgh_croped_dims[2] = src_tr.dims()[1];
+      auto zero_offset = dnnl::memory::dims(wgh_tr.dims().size(), 0);
+      wgh_tr = wgh_tr.Crop(wgh_croped_dims, zero_offset);
     }
 
-    // Memory descriptions.
-    auto dtype = 
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
-    auto conv_src_md = dnnl::memory::desc(src_dims, dtype, 
layout2tag(data_layout));
-    auto conv_weights_md = dnnl::memory::desc(weights_dims, dtype, 
layout2tag(kernel_layout));
-    auto conv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::any);
-    auto conv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
-
     // Conv description.
-    auto conv_desc =
-        has_bias ? dnnl::convolution_forward::desc(
-                       dnnl::prop_kind::forward_inference, 
dnnl::algorithm::convolution_direct,
-                       conv_src_md, conv_weights_md, conv_bias_md, 
conv_dst_md, strides_dims,
-                       dilates_dims, padding_dims_l, padding_dims_r)
-                 : 
dnnl::convolution_forward::desc(dnnl::prop_kind::forward_inference,
-                                                   
dnnl::algorithm::convolution_direct, conv_src_md,
-                                                   conv_weights_md, 
conv_dst_md, strides_dims,
-                                                   dilates_dims, 
padding_dims_l, padding_dims_r);
+    auto conv_desc = dnnl::convolution_forward::desc(
+        dnnl::prop_kind::forward_inference, 
dnnl::algorithm::convolution_direct,
+        src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(), 
bias_tr.LayoutAny().desc(),
+        dst_tr.LayoutAny().desc(), strides, dilates, padding_l, padding_r);
 
     // Enable elementwise post-ops.
     auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, 
attr, engine_);
 
-    // Push to the network.
-    auto conv = dnnl::convolution_forward(conv_prim_desc);
-    net_.push_back(conv);
-
-    // Data memory.
-    auto conv_src_memory = BindDNNLMemory(data_entry, conv_src_md);
+    src_tr = src_tr.RequestLayout(conv_prim_desc.src_desc());
+    wgh_tr = wgh_tr.RequestLayout(conv_prim_desc.weights_desc());
+    dst_tr = dst_tr.RequestLayout(conv_prim_desc.dst_desc());
+    bias_tr = bias_tr.RequestLayout(conv_prim_desc.bias_desc());
 
-    // Weight memory.
-    auto conv_weights_memory = BindDNNLMemory(weight_entry, 
conv_prim_desc.weights_desc());
+    auto scratchpad_tr = 
TensorRequisite::AsIs(conv_prim_desc.scratchpad_desc());
 
-    // Output memory.
-    auto conv_dst_memory = BindDNNLMemory(out_entry, 
conv_prim_desc.dst_desc());
-
-    // Bias memory.
-    auto conv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_);
-    if (has_bias) {
-      auto bias_entry = node.GetInputs()[2];
-      BindDNNLMemory(bias_entry, conv_bias_memory);
-
-      // Bind memory buffers.
-      net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory},
-                           {DNNL_ARG_WEIGHTS, conv_weights_memory},
-                           {DNNL_ARG_BIAS, conv_bias_memory},
-                           {DNNL_ARG_DST, conv_dst_memory}});
-    } else {
-      // Bind memory buffers.
-      net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory},
-                           {DNNL_ARG_WEIGHTS, conv_weights_memory},
-                           {DNNL_ARG_DST, conv_dst_memory}});
-    }
+    Submit(dnnl::convolution_forward(conv_prim_desc), {{DNNL_ARG_SRC, src_tr},
+                                                       {DNNL_ARG_WEIGHTS, 
wgh_tr},
+                                                       {DNNL_ARG_BIAS, 
bias_tr},
+                                                       {DNNL_ARG_SCRATCHPAD, 
scratchpad_tr},
+                                                       {DNNL_ARG_DST, 
dst_tr}});
   }
 
   void Deconvolution(const size_t& nid) {
     auto node = nodes_[nid];
     auto op_name = node.GetOpName();
     dnnl::primitive_attr attr;
+    attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
     bool has_bias = ParsingOpName(op_name, attr);
 
     // Setup attributes.
-    auto data_entry = node.GetInputs()[0];
-    auto weight_entry = node.GetInputs()[1];
-    JSONGraphNodeEntry out_entry(nid, 0);
-    dnnl::memory::dims input_shape = 
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
-    dnnl::memory::dims weight_shape = 
nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
-    dnnl::memory::dims out_shape = 
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
-    dnnl::memory::dim channels =
-        node.GetAttr<std::vector<std::string>>("channels")[0] != ""
-            ? std::stoi(node.GetAttr<std::vector<std::string>>("channels")[0])
-            : out_shape[1];
-    std::vector<std::string> str_strides = 
node.GetAttr<std::vector<std::string>>("strides");
-    std::vector<std::string> str_dilates = 
node.GetAttr<std::vector<std::string>>("dilation");
-    std::vector<std::string> str_padding = 
node.GetAttr<std::vector<std::string>>("padding");
-    std::vector<std::string> str_padding_l(str_padding.begin(),
-                                           str_padding.begin() + 
str_padding.size() / 2);
-    std::vector<std::string> str_padding_r(str_padding.end() - 
str_padding.size() / 2,
-                                           str_padding.end());
-    std::vector<std::string> str_out_padding =
-        node.GetAttr<std::vector<std::string>>("output_padding");
-    dnnl::memory::dim groups = 
std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
-    std::string data_layout = 
node.GetAttr<std::vector<std::string>>("data_layout")[0];
-    std::string kernel_layout = 
node.GetAttr<std::vector<std::string>>("kernel_layout")[0];
-
-    // Memory shapes.
-    dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout);
-    dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, 
kernel_layout);
-    // legalize shape IOHW with layout OIHW
-    if (weights_dims_[0] == src_dims[1] && weights_dims_[1] == channels) {
-      std::swap(weights_dims_[0], weights_dims_[1]);
-      if (kernel_layout.find("OI") == 0) {
-        kernel_layout.replace(kernel_layout.find("OI"), 2, "IO");
-      }
-    }
-    weights_dims_[0] = channels;
-    weights_dims_[1] = src_dims[1];
-    dnnl::memory::dims bias_dims = {channels};
-    dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
-    dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
-    dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l);
-    dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
-    dnnl::memory::dims out_padding = TransformStr2Dims(str_out_padding);
-    dnnl::memory::dims dst_dims = src_dims;
-    dst_dims[1] = channels;
-    for (size_t i = 2; i < src_dims.size(); i++) {
-      dnnl::memory::dim K = weights_dims_[i];
-      dnnl::memory::dim S = strides_dims[i - 2];
-      dnnl::memory::dim D = dilates_dims[i - 2];
-      dnnl::memory::dim PL = padding_dims_l[i - 2];
-      dnnl::memory::dim PR = padding_dims_r[i - 2];
-      dnnl::memory::dim OP = out_padding[i - 2];
-      dnnl::memory::dim DK = 1 + (K - 1) * (D + 1);
-      dst_dims[i] = S * (src_dims[i] - 1) + DK - PL - PR + OP;
+    auto src_tr = GetInput(nid, 0);
+    auto wgh_tr = GetInput(nid, 1);
+    auto dst_tr = GetOutput(nid, 0);
+    auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1);
+
+    auto strides = GetNodeAttr<std::vector<int64_t>>(node, "strides");
+    auto dilates = GetNodeAttr<std::vector<int64_t>>(node, "dilation");
+    auto padding = GetNodeAttr<std::vector<int64_t>>(node, "padding");
+    std::vector<int64_t> padding_l(padding.begin(), padding.begin() + 
padding.size() / 2);
+    std::vector<int64_t> padding_r(padding.begin() + padding.size() / 2, 
padding.end());
+    auto groups = GetNodeAttr<int>(node, "groups");
+    auto src_layout = GetNodeAttr<std::string>(node, "data_layout");
+    auto dst_layout = GetNodeAttr<std::string>(node, "out_layout");
+    auto wgh_layout = GetNodeAttr<std::string>(node, "kernel_layout");
+
+    // dst_layout == "" means to use data_layout
+    if (dst_layout.empty()) dst_layout = src_layout;
+
+    // Minus one for DNNL representation. No dilation for DNNL is 0, for relay 
is 1.
+    for (auto& d : dilates) d--;
+
+    // TODO(@apeskov): WA. conv3dTranspose uses wrong layout specifier. IO 
instead of OI.
+    auto wgh_logic_layout = TensorRequisite::DefaultLogicLayoutFor(wgh_layout);
+    if (wgh_logic_layout == "OIDHW") wgh_logic_layout = "IODHW";
+    if (wgh_logic_layout == "GOIDHW") wgh_logic_layout = "GIODHW";
+
+    // Take into account provided layout strings
+    src_tr = src_tr.TreatAs(src_layout);
+    dst_tr = dst_tr.TreatAs(dst_layout);
+    wgh_tr = wgh_tr.TreatAs(wgh_layout, wgh_logic_layout);
+
+    // Should support G mixed with O. Like { G*O, I, H, W }
+    if (wgh_layout.find("G") == std::string::npos) {
+      auto w_dims = wgh_tr.dims();
+      w_dims[0] /= groups;
+      w_dims.insert(w_dims.begin(), groups);
+      wgh_tr = wgh_tr.Reshape(w_dims);
     }
 
-    dnnl::memory::dims weights_dims = weights_dims_;
-    if (groups > 1) {
-      weights_dims = {groups, channels / groups, src_dims[1] / groups};
-      weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, 
weights_dims_.end());
-    }
+    // Assumption that bias is correct and can be squeezed to 1D
+    bias_tr = bias_tr.Reshape({dst_tr.dims()[1]});
 
-    // Memory descriptions.
-    auto dtype = 
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
-    auto deconv_src_md = dnnl::memory::desc(src_dims, dtype, 
layout2tag(data_layout));
-    auto deconv_weights_md = dnnl::memory::desc(weights_dims, dtype, 
layout2tag(kernel_layout));
-    auto deconv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::x);
-    auto deconv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
-
-    // Transposed covn2d description.
-    auto deconv_desc =
-        has_bias ? dnnl::deconvolution_forward::desc(
-                       dnnl::prop_kind::forward_inference, 
dnnl::algorithm::deconvolution_direct,
-                       deconv_src_md, deconv_weights_md, deconv_bias_md, 
deconv_dst_md,
-                       strides_dims, dilates_dims, padding_dims_l, 
padding_dims_r)
-                 : dnnl::deconvolution_forward::desc(
-                       dnnl::prop_kind::forward_inference, 
dnnl::algorithm::deconvolution_direct,
-                       deconv_src_md, deconv_weights_md, deconv_dst_md, 
strides_dims, dilates_dims,
-                       padding_dims_l, padding_dims_r);
+    // Deconvolution description.
+    auto deconv_desc = dnnl::deconvolution_forward::desc(
+        dnnl::prop_kind::forward_inference, 
dnnl::algorithm::deconvolution_direct,
+        src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(), 
bias_tr.LayoutAny().desc(),
+        dst_tr.LayoutAny().desc(), strides, dilates, padding_l, padding_r);
 
     // Enable elementwise post-ops.
     auto deconv_prim_desc = 
dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);
 
-    // Push to the network.
-    auto deconv = dnnl::deconvolution_forward(deconv_prim_desc);
-    net_.push_back(deconv);
-
-    // Data memory.
-    auto deconv_src_memory = BindDNNLMemory(data_entry, deconv_src_md);
-
-    // Weight memory.
-    auto deconv_weights_memory = BindDNNLMemory(weight_entry, 
deconv_prim_desc.weights_desc());
-
-    // Output memory.
-    auto deconv_dst_memory = BindDNNLMemory(out_entry, 
deconv_prim_desc.dst_desc());
+    src_tr = src_tr.RequestLayout(deconv_prim_desc.src_desc());
+    wgh_tr = wgh_tr.RequestLayout(deconv_prim_desc.weights_desc());
+    dst_tr = dst_tr.RequestLayout(deconv_prim_desc.dst_desc());
+    bias_tr = bias_tr.RequestLayout(deconv_prim_desc.bias_desc());
 
-    // Bias memory.
-    auto deconv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, 
engine_);
-    if (has_bias) {
-      auto bias_entry = node.GetInputs()[2];
-      BindDNNLMemory(bias_entry, deconv_bias_memory);
+    auto scratchpad_tr = 
TensorRequisite::AsIs(deconv_prim_desc.scratchpad_desc());
 
-      // Bind memory buffers.
-      net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory},
-                           {DNNL_ARG_WEIGHTS, deconv_weights_memory},
-                           {DNNL_ARG_BIAS, deconv_bias_memory},
-                           {DNNL_ARG_DST, deconv_dst_memory}});
-    } else {
-      // Bind memory buffers.
-      net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory},
-                           {DNNL_ARG_WEIGHTS, deconv_weights_memory},
-                           {DNNL_ARG_DST, deconv_dst_memory}});
-    }
+    Submit(dnnl::deconvolution_forward(deconv_prim_desc), {{DNNL_ARG_SRC, 
src_tr},
+                                                           {DNNL_ARG_WEIGHTS, 
wgh_tr},
+                                                           {DNNL_ARG_BIAS, 
bias_tr},
+                                                           
{DNNL_ARG_SCRATCHPAD, scratchpad_tr},
+                                                           {DNNL_ARG_DST, 
dst_tr}});
   }
 
   void Dense(const size_t& nid) {
     auto node = nodes_[nid];
     auto op_name = node.GetOpName();
     dnnl::primitive_attr attr;
+    attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
     bool has_bias = ParsingOpName(op_name, attr);
 
     // Setup attributes.
-    auto data_entry = node.GetInputs()[0];
-    auto weight_entry = node.GetInputs()[1];
-    JSONGraphNodeEntry out_entry(nid, 0);
-    dnnl::memory::dims input_shape = 
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
-    dnnl::memory::dims weight_shape = 
nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
-    dnnl::memory::dims out_shape = 
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
-    dnnl::memory::dim OC = out_shape[1];
-
-    // Memory shapes.
-    dnnl::memory::dims data_dims = input_shape;
-    dnnl::memory::dims weight_dims = weight_shape;
-    dnnl::memory::dims bias_dims = {OC};
-    dnnl::memory::dims out_dims = out_shape;
-
-    // Memory descriptions.
-    auto dl_dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_];
-    auto dtype = dtype_dl2dnnl(dl_dtype);
-    auto data_md = dnnl::memory::desc({data_dims, dtype, tag::nc});
-    auto weight_md = dnnl::memory::desc({weight_dims, dtype, tag::nc});
-    auto bias_md = dnnl::memory::desc({bias_dims, dtype, tag::x});
-    auto dst_md = dnnl::memory::desc({out_dims, dtype, tag::nc});
+    auto src_tr = GetInput(nid, 0);
+    auto wgh_tr = GetInput(nid, 1);
+    auto dst_tr = GetOutput(nid, 0);
+    auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1);
+
+    // Assumption that bias is correct and can be squeezed to 1D
+    bias_tr = bias_tr.Reshape({dst_tr.dims()[1]});
 
     // Dense description.
-    auto dense_desc = 
dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md,
-                                                        weight_md, bias_md, 
dst_md);
+    auto dense_desc = dnnl::inner_product_forward::desc(
+        dnnl::prop_kind::forward_inference, src_tr.LayoutAny().desc(), 
wgh_tr.LayoutAny().desc(),
+        bias_tr.LayoutAny().desc(), dst_tr.LayoutAny().desc());
 
     // Enable elementwise post-ops.
     auto dense_prim_desc = 
dnnl::inner_product_forward::primitive_desc(dense_desc, attr, engine_);
 
-    auto dense = dnnl::inner_product_forward(dense_prim_desc);
-    net_.push_back(dense);
+    src_tr = src_tr.RequestLayout(dense_prim_desc.src_desc());
+    wgh_tr = wgh_tr.RequestLayout(dense_prim_desc.weights_desc());
+    dst_tr = dst_tr.RequestLayout(dense_prim_desc.dst_desc());
+    bias_tr = bias_tr.RequestLayout(dense_prim_desc.bias_desc());
 
-    // Memories.
-    auto data_memory = BindDNNLMemory(data_entry, data_md);
-    auto weight_memory = BindDNNLMemory(weight_entry, weight_md);
+    auto scratchpad_tr = 
TensorRequisite::AsIs(dense_prim_desc.scratchpad_desc());
 
-    // Bias memory.
-    auto bias_memory = dnnl::memory(bias_md, engine_);
-    if (has_bias) {
-      auto bias_entry = node.GetInputs()[2];
-      BindDNNLMemory(bias_entry, bias_memory);
-    } else {
-      float bias[OC] = {0};
-      write_to_dnnl_memory(bias, bias_memory, OC * ((dl_dtype.bits + 7) / 8));
-    }
-
-    // Output memory.
-    auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());
-
-    net_args_.push_back({{DNNL_ARG_SRC, data_memory},
-                         {DNNL_ARG_WEIGHTS, weight_memory},
-                         {DNNL_ARG_BIAS, bias_memory},
-                         {DNNL_ARG_DST, dst_memory}});
+    Submit(dnnl::inner_product_forward(dense_prim_desc), {{DNNL_ARG_SRC, 
src_tr},
+                                                          {DNNL_ARG_WEIGHTS, 
wgh_tr},
+                                                          {DNNL_ARG_BIAS, 
bias_tr},
+                                                          
{DNNL_ARG_SCRATCHPAD, scratchpad_tr},
+                                                          {DNNL_ARG_DST, 
dst_tr}});
   }
 
   void BatchNorm(const size_t& nid) {
     auto node = nodes_[nid];
 
-    auto data_entry = node.GetInputs()[0];
-    auto gamma_entry = node.GetInputs()[1];
-    auto beta_entry = node.GetInputs()[2];
-    auto mean_entry = node.GetInputs()[3];
-    auto variance_entry = node.GetInputs()[4];
-    dnnl::memory::dims data_shape = 
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
-    dnnl::memory::dim IC = data_shape[1];
-    float epsilon = 
std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
+    auto src_tr = GetInput(nid, 0);
+    auto gamma_tr = GetInput(nid, 1);
+    auto beta_tr = GetInput(nid, 2);
+    auto mean_tr = GetInput(nid, 3);
+    auto var_tr = GetInput(nid, 4);
+    auto dst_tr = GetOutput(nid, 0);
 
-    // Memory description.
-    auto dtype = 
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
-    dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype);
+    auto axis = GetNodeAttr<int>(node, "axis");
+    auto epsilon = GetNodeAttr<float>(node, "epsilon");
+    auto center = GetNodeAttr<bool>(node, "center");
+    auto scale = GetNodeAttr<bool>(node, "scale");
+
+    ICHECK(axis == 1 && center && scale) << "Unimplemented BatchNorm case";
 
-    // BN description.
     auto bn_desc = dnnl::batch_normalization_forward::desc(
-        dnnl::prop_kind::forward_inference, data_md, epsilon,
+        dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon,
         dnnl::normalization_flags::use_global_stats | 
dnnl::normalization_flags::use_scale_shift);
     auto bn_prim_desc = 
dnnl::batch_normalization_forward::primitive_desc(bn_desc, engine_);
-    auto bn = dnnl::batch_normalization_forward(bn_prim_desc);
-    net_.push_back(bn);
-
-    // Memories.
-    auto data_memory = BindDNNLMemory(data_entry, data_md);
-    JSONGraphNodeEntry out_entry(nid, 0);
-    auto out_memory = BindDNNLMemory(out_entry, data_md);
-    auto mean_memory = BindDNNLMemory(mean_entry, bn_prim_desc.mean_desc());
-    auto variance_memory = BindDNNLMemory(variance_entry, 
bn_prim_desc.variance_desc());
-
-    // In DNNL, weight is composed of gamma+beta, so we point them to the same 
DNNL memory but
-    // assign an offset to beta data for runtime serialization.
-    auto weight_memory = BindDNNLMemory(gamma_entry, 
bn_prim_desc.weights_desc(), 0);
-    BindDNNLMemory(beta_entry, weight_memory, IC);
-
-    net_args_.push_back({{DNNL_ARG_SRC, data_memory},
-                         {DNNL_ARG_DST, out_memory},
-                         {DNNL_ARG_SCALE_SHIFT, weight_memory},
-                         {DNNL_ARG_MEAN, mean_memory},
-                         {DNNL_ARG_VARIANCE, variance_memory}});
+
+    // Concatenate scale and shift tensors
+    auto scale_shift_tr = TensorRequisite::AsIs(bn_prim_desc.weights_desc(), 
GenUniqueEid());
+    auto sc_sh_dims = scale_shift_tr.dims();
+    ICHECK(sc_sh_dims.size() == 2);
+    ICHECK(sc_sh_dims[0] == 2);
+    sc_sh_dims[0] /= 2;
+    auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze();
+    auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze();
+
+    auto register_copy = [this](const TensorRequisite& src, const 
TensorRequisite& dst) {
+      dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, 
dst.desc());
+      Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, 
dst}});
+    };
+
+    register_copy(gamma_tr, scale_tr);
+    register_copy(beta_tr, shift_tr);
+
+    Submit(dnnl::batch_normalization_forward(bn_prim_desc), {{DNNL_ARG_SRC, 
src_tr},
+                                                             {DNNL_ARG_DST, 
dst_tr},
+                                                             
{DNNL_ARG_SCALE_SHIFT, scale_shift_tr},
+                                                             {DNNL_ARG_MEAN, 
mean_tr},
+                                                             
{DNNL_ARG_VARIANCE, var_tr}});
   }
 
   void Pooling(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
 
+    auto src_tr = GetInput(nid, 0);
+    auto dst_tr = GetOutput(nid, 0);
+
     // Setup attributes.
-    auto data_entry = node.GetInputs()[0];
-    JSONGraphNodeEntry out_entry(nid, 0);
-    dnnl::memory::dims input_shape = 
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
-    dnnl::memory::dims out_shape = 
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
-    std::vector<std::string> str_kernel = 
node.GetAttr<std::vector<std::string>>("pool_size");
-    std::vector<std::string> str_strides = 
node.GetAttr<std::vector<std::string>>("strides");
-    std::vector<std::string> str_padding = 
node.GetAttr<std::vector<std::string>>("padding");
-    std::vector<std::string> str_padding_l(str_padding.begin(),
-                                           str_padding.begin() + 
str_padding.size() / 2);
-    std::vector<std::string> str_padding_r(str_padding.end() - 
str_padding.size() / 2,
-                                           str_padding.end());
-    std::vector<std::string> str_dilates = 
node.GetAttr<std::vector<std::string>>("dilation");
-    std::string layout = node.GetAttr<std::vector<std::string>>("layout")[0];
+    auto strides = GetNodeAttr<std::vector<int64_t>>(node, "strides");
+    auto dilates = GetNodeAttr<std::vector<int64_t>>(node, "dilation");
+    auto padding = GetNodeAttr<std::vector<int64_t>>(node, "padding");
+    std::vector<int64_t> padding_l(padding.begin(), padding.begin() + 
padding.size() / 2);
+    std::vector<int64_t> padding_r(padding.begin() + padding.size() / 2, 
padding.end());
+    auto kernel = GetNodeAttr<std::vector<int64_t>>(node, "pool_size");
+    auto src_layout = GetNodeAttr<std::string>(node, "layout");
+    auto dst_layout = GetNodeAttr<std::string>(node, "out_layout");
+
+    // dst_layout == "" means to use the source layout (the "layout" attribute)
+    if (dst_layout.empty()) dst_layout = src_layout;
+
+    // Minus one for DNNL representation. No dilation for DNNL is 0, for relay 
is 1.
+    for (auto& d : dilates) d--;
+
+    // Take into account provided layout strings
+    src_tr = src_tr.TreatAs(src_layout);
+    dst_tr = dst_tr.TreatAs(dst_layout);
 
     // Attributes related to AvgPool
     if (algo == dnnl::algorithm::pooling_avg) {
-      int int_countpad = 
std::stoi(node.GetAttr<std::vector<std::string>>("count_include_pad")[0]);
-      bool count_include_pad = int_countpad != 0 ? true : false;
-      algo = count_include_pad ? dnnl::algorithm::pooling_avg_include_padding
-                               : dnnl::algorithm::pooling_avg_exclude_padding;
+      auto include_pad = GetNodeAttr<bool>(node, "count_include_pad");
+      algo = include_pad ? dnnl::algorithm::pooling_avg_include_padding
+                         : dnnl::algorithm::pooling_avg_exclude_padding;
     }
 
-    dnnl::memory::dims src_dims = TransDims2Plain(input_shape, layout);
-    dnnl::memory::dims dst_dims = TransDims2Plain(out_shape, layout);
-    dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel);
-    dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
-    dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
-    dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l);
-    dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
-
-    // Memory descriptions.
-    auto dtype = 
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
-    auto pool_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(layout));
-    auto pool_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
-
     // Pooling description.
-    auto pool_desc = 
dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, algo,
-                                                 pool_src_md, pool_dst_md, 
strides_dims,
-                                                 kernel_dims, padding_dims_l, 
padding_dims_r);
-
-    auto pool_prim_desc = dnnl::pooling_forward::primitive_desc(pool_desc, 
engine_, true);
-    auto pool = dnnl::pooling_forward(pool_prim_desc);
-    net_.push_back(pool);
+    auto pool_desc = dnnl::pooling_v2_forward::desc(
+        dnnl::prop_kind::forward_inference, algo, src_tr.desc(),  //<= Do not 
use any for src tensor
+        dst_tr.LayoutAny().desc(), strides, kernel, dilates, padding_l, 
padding_r);
+    auto pool_prim_desc = dnnl::pooling_v2_forward::primitive_desc(pool_desc, 
engine_);
 
-    // Memories.
-    auto pool2d_src_memory = BindDNNLMemory(data_entry, pool_src_md);
+    src_tr = src_tr.RequestLayout(pool_prim_desc.src_desc());
+    dst_tr = dst_tr.RequestLayout(pool_prim_desc.dst_desc());
 
-    auto pool2d_dst_memory = BindDNNLMemory(out_entry, 
pool_prim_desc.dst_desc());
+    auto scratchpad_tr = 
TensorRequisite::AsIs(pool_prim_desc.scratchpad_desc());
 
-    // Bind memory buffers.
-    net_args_.push_back({{DNNL_ARG_SRC, pool2d_src_memory}, {DNNL_ARG_DST, 
pool2d_dst_memory}});
+    Submit(dnnl::pooling_v2_forward(pool_prim_desc),
+           {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}, 
{DNNL_ARG_SCRATCHPAD, scratchpad_tr}});
   }
 
   void Eltwise(const size_t& nid) {
     auto node = nodes_[nid];
     auto op_name = node.GetOpName();
-    auto algo = elt_name2algo[op_name];
+    auto algo = elt_name2algo.at(op_name);
+
+    auto src_tr = GetInput(nid, 0);
+    auto dst_tr = GetOutput(nid, 0);
 
-    auto data_entry = node.GetInputs()[0];
-    dnnl::memory::dims shape = 
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
-    auto dtype = 
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
-    dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype);
     float alpha = 0., beta = 0.;
     if (op_name == "clip") {
-      alpha = std::stof(node.GetAttr<std::vector<std::string>>("a_min")[0]);
-      beta = std::stof(node.GetAttr<std::vector<std::string>>("a_max")[0]);
+      alpha = GetNodeAttr<float>(node, "a_min");
+      beta = GetNodeAttr<float>(node, "a_max");
     } else if (op_name == "nn.leaky_relu") {
-      alpha = std::stof(node.GetAttr<std::vector<std::string>>("alpha")[0]);
+      alpha = GetNodeAttr<float>(node, "alpha");
     }
 
-    auto elt_desc =
-        dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, 
data_md, alpha, beta);
+    auto elt_desc = 
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo,
+                                                src_tr.desc(), alpha, beta);
     auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, 
engine_);
-    ICHECK(data_md == elt_prim_desc.dst_desc());
-
-    auto elt = dnnl::eltwise_forward(elt_prim_desc);
-    net_.push_back(elt);
+    ICHECK(src_tr.desc() == elt_prim_desc.dst_desc());
 
-    auto data_memory = BindDNNLMemory(data_entry, data_md);
-    JSONGraphNodeEntry out_entry(nid, 0);
-    auto out_memory = BindDNNLMemory(out_entry, data_md);
-
-    net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, 
out_memory}});
+    Submit(dnnl::eltwise_forward(elt_prim_desc), {{DNNL_ARG_SRC, src_tr}, 
{DNNL_ARG_DST, dst_tr}});
   }
 
   void Softmax(const size_t& nid) {
     auto node = nodes_[nid];
 
-    auto data_entry = node.GetInputs()[0];
-    dnnl::memory::dims shape = 
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
-    int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
+    auto src_tr = GetInput(nid, 0);
+    auto dst_tr = GetOutput(nid, 0);
+
+    auto axis = GetNodeAttr<int>(node, "axis");
     if (axis < 0) {
-      axis = shape.size() + axis;
+      axis = src_tr.dims().size() + axis;
     }
-    auto dtype = 
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
-    dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype);
 
     auto softmax_desc =
-        dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, 
data_md, axis);
+        dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, 
src_tr.desc(), axis);
     auto softmax_prim_desc = 
dnnl::softmax_forward::primitive_desc(softmax_desc, engine_);
-    ICHECK(data_md == softmax_prim_desc.dst_desc());
-
-    auto softmax = dnnl::softmax_forward(softmax_prim_desc);
-    net_.push_back(softmax);
+    ICHECK(dst_tr.desc() == softmax_prim_desc.dst_desc());
 
-    auto data_memory = BindDNNLMemory(data_entry, data_md);
-    JSONGraphNodeEntry out_entry(nid, 0);
-    auto out_memory = BindDNNLMemory(out_entry, data_md);
-
-    net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, 
out_memory}});
+    Submit(dnnl::softmax_forward(softmax_prim_desc),
+           {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}});
   }
 
   void Binary(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
+    ICHECK_EQ(node.GetInputs().size(), 2U);
 
     // Memory and compute description.
-    std::vector<dnnl::memory::dims> data_dims;
-    std::vector<dnnl::memory::desc> data_mds;
-    std::vector<dnnl::memory> data_memories;
+    auto lhs_tr = GetInput(nid, 0);
+    auto rhs_tr = GetInput(nid, 1);
+    auto dst_tr = GetOutput(nid, 0);
 
-    ICHECK_EQ(node.GetInputs().size(), 2U);
-    for (auto entry : node.GetInputs()) {
-      auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_];
-      auto dtype = 
dtype_dl2dnnl(nodes_[entry.id_].GetOpDataType()[entry.index_]);
-      dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype);
-
-      data_dims.push_back(data_shape);
-      data_mds.push_back(data_md);
-      data_memories.push_back(BindDNNLMemory(entry, data_md));
-    }
-    ICHECK(data_dims[0] == data_dims[1]);
-    auto out_md = data_mds[0];
-    JSONGraphNodeEntry out_entry(nid, 0);
-    auto out_memory = BindDNNLMemory(out_entry, out_md);
+    lhs_tr = lhs_tr.Broadcast(dst_tr.dims());
+    rhs_tr = rhs_tr.Broadcast(dst_tr.dims());
 
-    auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1], 
out_md);
+    auto binary_desc = dnnl::binary::desc(algo, lhs_tr.desc(), rhs_tr.desc(), 
dst_tr.desc());
     auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_);
-    auto binary = dnnl::binary(binary_prim_desc);
-    net_.push_back(binary);
 
-    net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]},
-                         {DNNL_ARG_SRC_1, data_memories[1]},
-                         {DNNL_ARG_DST, out_memory}});
+    Submit(dnnl::binary(binary_prim_desc),
+           {{DNNL_ARG_SRC_0, lhs_tr}, {DNNL_ARG_SRC_1, rhs_tr}, {DNNL_ARG_DST, 
dst_tr}});
+  }
+
+  template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0>
+  T AttrConvert(std::vector<std::string> val) {
+    ICHECK_EQ(val.size(), 1);
+    return std::stol(val[0]);
+  }
+
+  template <typename T, std::enable_if_t<std::is_floating_point<T>::value, 
int> = 0>
+  T AttrConvert(std::vector<std::string> val) {
+    ICHECK_EQ(val.size(), 1);
+    return std::stof(val[0]);
+  }
+
+  template <typename T, std::enable_if_t<std::is_same<T, std::string>::value, 
int> = 0>
+  T AttrConvert(std::vector<std::string> val) {
+    ICHECK_EQ(val.size(), 1);
+    return val[0];
+  }
+
+  template <typename T,
+            std::enable_if_t<std::is_same<T, std::vector<typename 
T::value_type>>::value, int> = 0>
+  T AttrConvert(std::vector<std::string> val) {
+    T res;
+    for (const auto& el : val) res.push_back(AttrConvert<typename 
T::value_type>({el}));
+    return res;
+  }
+
+  /*!
+   * \brief Helper to extract node attribute with ability to specify default 
value and result type.
+   */
+  template <typename T>
+  const T GetNodeAttr(const json::JSONGraphNode& node, std::string name,
+                      std::vector<std::string> def = {}) {
+    auto attr = node.HasAttr(name) ? 
node.GetAttr<std::vector<std::string>>(name) : def;
+    return AttrConvert<T>(attr);
   }
 
-  // Read from DNNL memory (+offset) and write to the handle.
-  inline void read_from_dnnl_memory(void* handle, const dnnl::memory& mem, 
size_t size,
-                                    size_t offset = 0) {
-    uint8_t* src = static_cast<uint8_t*>(mem.get_data_handle());
-    std::copy(src + offset, src + offset + size, 
static_cast<uint8_t*>(handle));
+  TensorRequisite GetInput(const size_t& nid, const int idx) {
+    if (idx == -1) return {};  // -1 reserved value for empty input.
+
+    const JSONGraphNode& node = nodes_[nid];
+
+    ICHECK_LT(idx, node.GetInputs().size());
+    auto data_entry = node.GetInputs()[idx];
+
+    auto shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+    auto dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_];
+    auto eid = node_row_ptr_[data_entry.id_] + data_entry.index_;
+    auto const_dl_tensor = data_entry_[eid];
+
+    auto desc = MakePlainDesc(shape, dtype);
+
+    TensorRequisite res;
+    if (const_dl_tensor) {
+      ICHECK(const_dl_tensor->data);
+      ICHECK(const_dl_tensor->strides == nullptr);
+      auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data);
+      res = TensorRequisite::AsIs(mem, eid);
+    } else {
+      res = TensorRequisite::AsIs(desc, eid);
+    }
+    return res;
   }
 
-  // Read from the handle and write to DNNL memory (+offset).
-  inline void write_to_dnnl_memory(void* handle, const dnnl::memory& mem, 
size_t size,
-                                   size_t offset = 0) {
-    uint8_t* dst = static_cast<uint8_t*>(mem.get_data_handle());
-    std::copy(reinterpret_cast<uint8_t*>(handle), 
reinterpret_cast<uint8_t*>(handle) + size,
-              dst + offset);
+  TensorRequisite GetOutput(const size_t& nid, const int idx) {
+    if (idx == -1) return {};  // -1 reserved value for empty output.
+
+    const JSONGraphNode& node = nodes_[nid];
+
+    ICHECK_LT(idx, node.GetNumOutput());
+    auto shape = node.GetOpShape()[idx];
+    auto dtype = node.GetOpDataType()[idx];
+    auto eid = node_row_ptr_[nid] + static_cast<uint32_t>(idx);
+
+    ICHECK(data_entry_[eid] == nullptr);
+    auto desc = MakePlainDesc(shape, dtype);
+
+    return TensorRequisite::AsIs(desc, eid).Backward();
   }
 
-  // Generate DNNL memory description and infer the data layout by the given 
shape.
-  inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& 
shape, dt dtype) {
-    dnnl::memory::desc data_md;
-    switch (shape.size()) {
-      case 2:
-        data_md = dnnl::memory::desc({shape, dtype, tag::ab});
-        break;
-      case 3:
-        data_md = dnnl::memory::desc({shape, dtype, tag::abc});
-        break;
-      case 4:
-        data_md = dnnl::memory::desc({shape, dtype, tag::abcd});
-        break;
-      case 5:
-        data_md = dnnl::memory::desc({shape, dtype, tag::abcde});
-        break;
-      default:
-        LOG(FATAL) << "Unsupported data shape dimension: " << shape.size();
-        break;
+  /*! \brief Helper function to register primitive into execution queue */
+  void Submit(const dnnl::primitive& prim,
+              const std::unordered_map<int, TensorRequisite>& tr_args) {
+    // Register all provided TR arguments
+    std::unordered_map<int, TensorRegistry::ArgId> prim_arg_id;
+    TensorRegistry::ActionQue post_prim_actions;
+    for (const auto& kvp : tr_args) {
+      const auto& key = kvp.first;
+      const auto& tr = kvp.second;
+
+      if (!tr.defined()) continue;  // empty arg is admitted. Just skip it
+      auto arg_id = tensor_registry_.Register(tr, tr.IsReversed() ? 
&post_prim_actions : &net_);
+      prim_arg_id[key] = arg_id;
     }
-    return data_md;
+
+    // Register main primitive
+    net_.push_back({prim, prim_arg_id});
+
+    // Register post actions
+    net_.insert(net_.end(), post_prim_actions.begin(), 
post_prim_actions.end());
   }
 
+  uint32_t GenUniqueEid() { return next_unique_eid_offset_++; }
+
   /* The dnnl engine. */
   dnnl::engine engine_;
   /* The dnnl stream. */
   dnnl::stream stream_;
   /* The network layers that are represented in dnnl primitives. */
-  std::vector<dnnl::primitive> net_;
-  /* The memory that is consumed by arguments. */
-  std::vector<std::unordered_map<int, dnnl::memory>> net_args_;
-  /* The entry ID to its corresponding output memory. */
-  std::unordered_map<uint32_t, std::pair<dnnl::memory, size_t>> entry_out_mem_;
+  TensorRegistry::ActionQue net_;
+  /* Storage for all memory objects */
+  TensorRegistry tensor_registry_;
+  /* Generator of new unique eid which doesn't match with existing data entry 
*/
+  uint32_t next_unique_eid_offset_;
+  /* Map of Run arg idx to corresponding eid */
+  std::vector<uint32_t> run_arg_eid_;
 };
 
 runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json,
diff --git a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h 
b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
new file mode 100644
index 0000000000..d02ceff5de
--- /dev/null
+++ b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
@@ -0,0 +1,720 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
+ * \brief Helper TR wrapper to simplify tensors processing
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_
+#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_
+
+#include <dlpack/dlpack.h>
+
+#include <algorithm>
+#include <cctype>
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+// TODO(@apeskov): Have to mute warning from dnnl headers.
+//  -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command
+#include <dnnl.hpp>
+
+#include "dnnl_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace utils;
+
+/*!
+ * \brief Helper object to simplify tensor transformation description.
+ *
+ * Allow to specify original source tensor and future actions which should be 
applied to it.
+ * Can be treated as sequence of reordering or reinterpretation of original 
source tensor.
+ * Finally TR can be solved as proper interpretation of source memory buffer, 
or sequence of
+ * dnnl::reorder operators which will provide desired data.
+ *
+ * \note Empty TR object allow any manipulation. Empty TR will be returned.
+ *
+ * \sa TensorRegistry
+ *
+ * Example:
+ * \code
+ *   dnnl::memory src_mem = ...;  // 5D tensor, shape {5, 2, 128, 128, 8}
+ *
+ *   // Construct TR
+ *   auto tr = TensorRequisite::AsIs(src_mem, eid);  // 5D
+ *
+ *   // describe sequence of layout transformation
+ *   tr = tr.TreatAs("ABCD8b", "ABCD");  // 4D
+ *   tr = tr.Permute({0, 2, 3, 1});  // Permute axes NCHW -> NHWC
+ *   tr = tr.Crop({1, 128, 128, 16}, {0, 0, 0, 0});  // extract first batch element
+ *   tr = tr.Squeeze();  // 3D, unit batch dim is dropped
+ *
+ *   // register TR
+ *   TensorRegistry t_reg;
+ *   TensorRegistry::ActionQue actions;  // receives actions required to resolve the TR
+ *   auto t_id = t_reg.Register(tr, &actions);
+ *
+ *   // Get final dnnl::memory object
+ *   auto solver = t_reg.MakeSolver(ext_tensor_provider);
+ *   auto mem = solver(t_id);
+ * \endcode
+ *
+ */
+class TensorRequisite {
+ public:
+  using Tid = uint32_t;
+  static constexpr Tid kUndefinedTid = std::numeric_limits<uint32_t>::max() - 
1;
+
+  /*! \brief Empty constructor */
+  TensorRequisite() {}
+
+  /*!
+   * \brief Wrap an existing memory object into TR.
+   *
+   * The memory object itself is stored inside of TR only if it has a non-null data handle.
+   *
+   * \param mem memory object to describe
+   * \param id entry ID to associate with TR
+   * \return resulting TR
+   */
+  static TensorRequisite AsIs(const dnnl::memory& mem, Tid id = kUndefinedTid) {
+    TensorRequisite wrapped = AsIs(mem.get_desc(), id);
+    const bool has_data = mem.get_data_handle() != nullptr;
+    if (has_data) {
+      wrapped.mem_ = mem;
+    }
+    return wrapped;
+  }
+
+  /*!
+   * \brief Wrap an existing memory descriptor into TR.
+   * \param desc descriptor of tensor
+   * \param id entry ID to associate with TR
+   * \return TR without parent, without const data and with forward data flow
+   */
+  static TensorRequisite AsIs(const dnnl::memory::desc& desc, Tid id = kUndefinedTid) {
+    return TensorRequisite(desc, /*orig=*/nullptr, /*reinterpret=*/false, dnnl::memory(), id,
+                           /*reverse_data_flow=*/false);
+  }
+
+  /*! \brief Return logical shape of tensor, as reported by the tensor descriptor */
+  dnnl::memory::dims dims() const { return t_desc_.dims(); }
+
+  /*! \brief Return data type of tensor, as reported by the tensor descriptor */
+  dnnl::memory::data_type data_type() const { return t_desc_.data_type(); }
+
+  /*! \brief Return a copy of the tensor descriptor */
+  dnnl::memory::desc desc() const { return t_desc_; }
+
+  /*!
+   * \brief Make TR with backward dataflow.
+   *
+   * Applicable only to TR without parent. Empty TR is returned as is.
+   */
+  TensorRequisite Backward() const {
+    if (defined()) {
+      ICHECK(orig_ == nullptr);
+      return TensorRequisite(t_desc_, orig_, reinterpret_, mem_, eid_,
+                             /*reverse_data_flow=*/true);
+    }
+    return *this;
+  }
+
+  /*!
+   * \brief Produce TR with permuted axes.
+   * \param permutation axes permutation, passed as is to dnnl::memory::desc::permute_axes
+   * \return new TR on top of this one. Empty TR is returned as is
+   */
+  TensorRequisite Permute(const std::vector<int>& permutation) const {
+    if (!defined()) {
+      return *this;
+    }
+
+    // no data movement, source buffer is just described with permuted descriptor
+    const dnnl::memory::desc permuted_desc = t_desc_.permute_axes(permutation);
+    return TensorRequisite(permuted_desc, std::make_shared<TensorRequisite>(*this),
+                           /*reinterpret=*/true, dnnl::memory(), kUndefinedTid,
+                           reverse_data_flow_);
+  }
+
+  /*!
+   * \brief Produce TR with data of original TR reinterpreted with another shape.
+   * \param shape desired shape
+   * \return new TR on top of this one. Empty TR and TR which already has requested shape are
+   *         returned as is
+   */
+  TensorRequisite Reshape(const dnnl::memory::dims& shape) const {
+    const bool nothing_to_do = !defined() || t_desc_.dims() == shape;
+    if (nothing_to_do) return *this;
+
+    // no data movement, source buffer is just described with new dims
+    const dnnl::memory::desc reshaped_desc = t_desc_.reshape(shape);
+    return TensorRequisite(reshaped_desc, std::make_shared<TensorRequisite>(*this),
+                           /*reinterpret=*/true, dnnl::memory(), kUndefinedTid,
+                           reverse_data_flow_);
+  }
+
+  /*!
+   * \brief Produce TR with broadcasted values
+   *
+   * Numpy like broadcast. Missing leading dims are treated as 1, each dim of size 1 is expanded
+   * to requested size by means of zero stride. Is not applicable to TR with reversed data flow.
+   *
+   * \param shape desired shape. Its rank should not be less than rank of this TR
+   * \return new TR on top of this one. Empty TR and TR which already has requested shape are
+   *         returned as is
+   */
+  TensorRequisite Broadcast(const dnnl::memory::dims& shape) const {
+    if (!defined()) return *this;  // nothing for empty TR
+    if (t_desc_.dims() == shape) return *this;
+    ICHECK(!reverse_data_flow_);
+
+    auto extended_dims = t_desc_.dims();
+    // Guard unsigned subtraction below. Broadcast cannot reduce rank of tensor.
+    ICHECK_GE(shape.size(), extended_dims.size())
+        << "Cannot broadcast tensor with shape " << extended_dims << " to shape " << shape;
+
+    auto orig = std::make_shared<TensorRequisite>(*this);
+
+    // numpy like broadcast
+    auto one_filled = dnnl::memory::dims(shape.size() - extended_dims.size(), 1);
+    extended_dims.insert(extended_dims.begin(), one_filled.begin(), one_filled.end());
+    auto desc = t_desc_.reshape(extended_dims);
+    for (size_t i = 0; i < extended_dims.size(); i++) {
+      if (extended_dims[i] == shape[i]) continue;
+      ICHECK(extended_dims[i] == 1);
+      ICHECK(desc.data.dims[i] == desc.data.padded_dims[i]);
+
+      // zero stride makes all positions along this dim refer to the same element
+      desc.data.dims[i] = shape[i];
+      desc.data.padded_dims[i] = shape[i];
+      desc.data.format_desc.blocking.strides[i] = 0;
+    }
+
+    // reinterpret memory buffer with new strides
+    return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_};
+  }
+
+  /*!
+   * \brief Produce TR with sub memory view (ROI)
+   *
+   * In addition to regular DNNL submemory the auto padded case of blocked layout is supported:
+   * zero offset and reduction of each dim is non-negative and less than block size of this dim.
+   * In that case only dims of descriptor are updated.
+   *
+   * \param shape shape of region. Should have the same rank as this TR
+   * \param offset offset of region. Should have the same rank as this TR
+   * \return new TR on top of this one. Empty TR is returned as is
+   */
+  TensorRequisite Crop(const dnnl::memory::dims& shape, const dnnl::memory::dims& offset) const {
+    if (!defined()) return *this;  // nothing for empty TR
+
+    ICHECK_EQ(shape.size(), t_desc_.dims().size());
+    ICHECK_EQ(offset.size(), t_desc_.dims().size());
+
+    auto orig = std::make_shared<TensorRequisite>(*this);
+    // reinterpret memory buffer with new strides
+    auto desc = t_desc_.submemory_desc(shape, offset, /*allow_empty=*/true);
+
+    // Originally DNNL implementation is very limited. Let's slightly enhance it.
+    if (!desc && t_desc_.data.format_kind == dnnl_blocked) {
+      bool offset_is_zero =
+          std::all_of(offset.begin(), offset.end(), [](auto el) { return el == 0; });
+
+      // total block size per logical dim
+      dnnl::memory::dims block_sizes(t_desc_.dims().size(), 1);
+      for (int i = 0; i < t_desc_.data.format_desc.blocking.inner_nblks; i++)
+        block_sizes[t_desc_.data.format_desc.blocking.inner_idxs[i]] *=
+            t_desc_.data.format_desc.blocking.inner_blks[i];
+
+      bool shape_reduction_less_than_block = true;
+      for (int i = 0; i < t_desc_.data.ndims; i++) {
+        const auto reduction = t_desc_.data.dims[i] - shape[i];
+        // Crop may only shrink a dim. Negative reduction would enlarge the tensor.
+        shape_reduction_less_than_block &= reduction >= 0 && reduction < block_sizes[i];
+      }
+
+      // This is auto padded case. Just update dims value.
+      if (offset_is_zero && shape_reduction_less_than_block) {
+        desc = t_desc_;
+        std::copy(shape.begin(), shape.end(), desc.data.dims);
+      }
+    }
+
+    ICHECK(desc);
+
+    return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_};
+  }
+
+  /*!
+   * \brief Produce TR with squeezed shape.
+   * \param dims_to_squeeze indexes of dims to remove. If empty, all dims of size 1 are removed
+   * \return new TR on top of this one. Empty TR is returned as is. If no dims are left the
+   *         resulting shape is {1}
+   */
+  TensorRequisite Squeeze(const dnnl::memory::dims& dims_to_squeeze = {}) const {
+    if (!defined()) return *this;  // nothing for empty TR
+
+    const dnnl::memory::dims cur_dims = t_desc_.dims();
+    const bool squeeze_all_units = dims_to_squeeze.empty();
+
+    dnnl::memory::dims kept_dims;
+    for (size_t axis = 0; axis < cur_dims.size(); ++axis) {
+      const bool drop =
+          squeeze_all_units
+              ? cur_dims[axis] == 1
+              : std::count(dims_to_squeeze.begin(), dims_to_squeeze.end(), axis) != 0;
+      if (!drop) kept_dims.push_back(cur_dims[axis]);
+    }
+    if (kept_dims.empty()) kept_dims.push_back(1);
+
+    // no data movement, source buffer is just described with new dims
+    const dnnl::memory::desc squeezed_desc = t_desc_.reshape(kept_dims);
+    return TensorRequisite(squeezed_desc, std::make_shared<TensorRequisite>(*this),
+                           /*reinterpret=*/true, dnnl::memory(), kUndefinedTid,
+                           reverse_data_flow_);
+  }
+
+  /*!
+   * \brief Produce TR with specified layout descriptor.
+   * \param desc requested descriptor. Should have the same dims as this TR
+   * \return new TR on top of this one, marked as not a reinterpretation. Empty TR and TR which
+   *         already has requested descriptor are returned as is
+   */
+  TensorRequisite RequestLayout(dnnl::memory::desc desc) const {
+    if (!defined() || desc == t_desc_) return *this;
+
+    ICHECK(t_desc_.dims() == desc.dims()) << "Requested layout is not compatible with "
+                                             "presented shape";
+
+    return TensorRequisite(desc, std::make_shared<TensorRequisite>(*this), /*reinterpret=*/false,
+                           dnnl::memory(), kUndefinedTid, reverse_data_flow_);
+  }
+
+  /*!
+   * \brief Define which logical dims ordering is default for particular layout string.
+   *
+   * Prefix of resulting layout is selected by marker dim found in layout string: "NC" for N,
+   * "GOI" for G, "OI" for O. Remaining dims are treated as spatial ones: none, W, HW or DHW.
+   *
+   * \param layout layout string, like NCHW8c
+   * \return default logical layout. Fatal error if there is no default scheme for this layout
+   */
+  static std::string DefaultLogicLayoutFor(const std::string& layout) {
+    // Rank is all non digit marked dims
+    auto it = layout.begin();
+    // cast is required, std::isdigit is undefined for negative char values
+    while (it != layout.end() && !std::isdigit(static_cast<unsigned char>(*it))) it++;
+    const int rank = static_cast<int>(std::distance(layout.begin(), it));
+
+    std::string prefix;
+    if (layout.find("N") != std::string::npos) {
+      prefix = "NC";
+    } else if (layout.find("G") != std::string::npos) {
+      prefix = "GOI";
+    } else if (layout.find("O") != std::string::npos) {
+      prefix = "OI";
+    } else {
+      LOG(FATAL) << "Unknown layout " << layout << ". There is no default scheme to handle it";
+      return {};
+    }
+
+    // Indexed by number of spatial dims. Zero spatial dims is valid, e.g. layout "NC".
+    static const std::vector<std::string> sparse_dims = {"", "W", "HW", "DHW"};
+    const int num_spatial = rank - static_cast<int>(prefix.size());
+    ICHECK(num_spatial >= 0 && num_spatial < static_cast<int>(sparse_dims.size()))
+        << "Layout " << layout << " has unsupported rank " << rank;
+    return prefix + sparse_dims[num_spatial];
+  }
+
+  /*!
+   * \brief Treat TR shape as described in layout string.
+   *
+   * Blocked dimensions will be concatenated and put into proper shape position corresponding to
+   * desired_logic_layout argument. If desired logic layout was not provided it will be deduced
+   * automatically based on some internal heuristics.
+   *
+   * Limitation 1. Blocking dims should be dense. Dims marked with digits use natural strides.
+   * Limitation 2. Blocking dims are innermost. Dims marked like 8c, 4o goes after regular
+   *               dimensions. NC8cHW4h4cD is not valid tensor in terms of DNNL. And cannot be
+   *               achieved with memory reinterpretation, so data copy is required. Proper layout
+   *               looks like NCHWD_8c4h4c, first part is outer dims, second digits marked part is
+   *               innermost.
+   *
+   * \param layout layout string which describes each dim of this TR, like NCHW8c
+   * \param desired_logic_layout order of logical dims in resulting TR, like NCHW
+   * \return new TR on top of this one. Empty TR and TR which already has resulting descriptor
+   *         are returned as is
+   */
+  TensorRequisite TreatAs(const std::string& layout, std::string desired_logic_layout = "") const {
+    if (!defined()) return *this;  // nothing for empty TR
+    if (desired_logic_layout.empty()) desired_logic_layout = DefaultLogicLayoutFor(layout);
+
+    const auto origin_dims = dims();
+
+    // split layout string to tokens {size, tag} like {16, 'C'}, {4, 'O'}
+    std::vector<std::pair<int, char>> layout_tokens;
+    for (auto it = layout.begin(); it != layout.end();) {
+      auto start = it;
+      while (it != layout.end() && std::isdigit(static_cast<unsigned char>(*it))) it++;
+      // block size has to be followed by dim tag, otherwise there is nothing to dereference
+      ICHECK(it != layout.end()) << "Layout " << layout << " ends with block size without dim tag";
+      int blk_size = start == it ? -1 : std::stoi(std::string{start, it});
+      char tag = static_cast<char>(std::toupper(static_cast<unsigned char>(*it)));
+      layout_tokens.push_back({blk_size, tag});
+      it++;
+    }
+
+    // check applicability of layout
+    auto it = layout_tokens.begin();
+    while (it != layout_tokens.end() && it->first == -1) it++;
+    int rank = std::distance(layout_tokens.begin(), it);
+    while (it != layout_tokens.end()) {
+      ICHECK_NE(it->first, -1) << "DNNL limitation. Blocking dims should be innermost. "
+                               << "But received layout is " << layout;
+      it++;
+    }
+
+    ICHECK_EQ(layout_tokens.size(), origin_dims.size());
+    ICHECK_EQ(static_cast<size_t>(rank), desired_logic_layout.size()) << layout;
+
+    std::vector<std::pair<int, char>> outermost_tokens(layout_tokens.begin(),
+                                                       layout_tokens.begin() + rank);
+    std::vector<std::pair<int, char>> innermost_tokens(layout_tokens.begin() + rank,
+                                                       layout_tokens.end());
+    // define resulting dim positions
+    std::map<char, int> dim_position_by_tag;
+    for (size_t i = 0; i < desired_logic_layout.size(); i++) {
+      char tag = static_cast<char>(std::toupper(static_cast<unsigned char>(desired_logic_layout[i])));
+      dim_position_by_tag[tag] = static_cast<int>(i);
+    }
+
+    // Checked lookup. Unknown tag should not be silently mapped to position 0.
+    auto position_of = [&](char tag) {
+      auto found = dim_position_by_tag.find(tag);
+      ICHECK(found != dim_position_by_tag.end())
+          << "Dim '" << tag << "' of layout " << layout << " is absent in logic layout "
+          << desired_logic_layout;
+      return found->second;
+    };
+
+    // Construct resulting desc by modifying original one
+    dnnl::memory::desc res_desc = t_desc_;
+
+    memset(&res_desc.data.format_desc.blocking, 0, sizeof(res_desc.data.format_desc.blocking));
+    std::fill(res_desc.data.dims, res_desc.data.dims + DNNL_MAX_NDIMS, 0);
+    std::fill(res_desc.data.padded_dims, res_desc.data.padded_dims + DNNL_MAX_NDIMS, 0);
+
+    res_desc.data.ndims = rank;
+    res_desc.data.format_desc.blocking.inner_nblks = innermost_tokens.size();
+
+    auto res_dims = res_desc.data.dims;
+    auto res_strides = res_desc.data.format_desc.blocking.strides;
+    auto res_inner_blks = res_desc.data.format_desc.blocking.inner_blks;
+    auto res_inner_idxs = res_desc.data.format_desc.blocking.inner_idxs;
+
+    // sizes are accumulated by multiplication, so start from 1
+    std::fill(res_dims, res_dims + rank, 1);
+
+    int orig_dim_idx = 0;
+    for (const auto& p : outermost_tokens) {
+      auto tag = p.second;
+      auto dim_size = origin_dims[orig_dim_idx];
+
+      auto result_dim_position = position_of(tag);
+      res_dims[result_dim_position] *= dim_size;
+      res_strides[result_dim_position] = t_desc_.data.format_desc.blocking.strides[orig_dim_idx];
+      orig_dim_idx++;
+    }
+    for (const auto& p : innermost_tokens) {
+      auto tag = p.second;
+      auto dim_size = origin_dims[orig_dim_idx];
+      auto result_dim_position = position_of(tag);
+      ICHECK_EQ(p.first, dim_size)
+          << "Blocking layout is not applicable to tensor with shape: " << origin_dims
+          << ". Requested layout is " << layout;
+
+      res_dims[result_dim_position] *= dim_size;
+      *res_inner_blks++ = dim_size;
+      *res_inner_idxs++ = result_dim_position;
+      orig_dim_idx++;
+    }
+
+    // Assume tensor is dense. There is no additional padding.
+    std::copy(res_desc.data.dims, res_desc.data.dims + rank, res_desc.data.padded_dims);
+
+    if (t_desc_ == res_desc) return *this;
+
+    auto orig = std::make_shared<TensorRequisite>(*this);
+    return {res_desc, orig, true, {}, kUndefinedTid, reverse_data_flow_};
+  }
+
+  /*!
+   * \brief Produce TR with unspecified layout.
+   *
+   * Cannot be registered in TensorRegistry. Only for querying DNNL for preferred layouts.
+   *
+   * \return new TR with the same dims and data type but with layout 'any'. Empty TR is returned
+   *         as is
+   */
+  TensorRequisite LayoutAny() const {
+    if (!defined()) return *this;  // nothing for empty TR
+
+    auto orig = std::make_shared<TensorRequisite>(*this);
+    // Recreate tensor desc with layout 'any'
+    dnnl::memory::desc any_desc{t_desc_.dims(), t_desc_.data_type(), dnnl::memory::format_tag::any};
+    return {any_desc, orig, false, {}, kUndefinedTid, reverse_data_flow_};
+  }
+
+  /*!
+   * \brief Check if TR is constant.
+   *
+   * TR is constant if the root of its parent chain holds a memory object.
+   */
+  bool IsConstant() const {
+    const TensorRequisite* root = this;
+    while (root->orig_) {
+      root = root->orig_.get();
+    }
+    return static_cast<bool>(root->mem_);
+  }
+
+  /*! \brief Check if tensor is scalar, which means 1D tensor with single element. */
+  bool IsScalar() const {
+    const dnnl::memory::dims shape = t_desc_.dims();
+    return shape.size() == 1 && shape[0] == 1;
+  }
+
+  /*!
+   * \brief Return const data memory if available.
+   *
+   * For TR with parent the const data of parent is either reinterpreted with descriptor of this
+   * TR (the same data handle is used) or reordered into newly allocated memory object.
+   *
+   * \return memory object with const data, or empty memory object if TR is not constant
+   */
+  dnnl::memory GetConstData() const {
+    if (mem_) return mem_;
+    if (!orig_) return {};
+
+    if (auto orig_const_data = orig_->GetConstData()) {
+      if (reinterpret_) {
+        return {t_desc_, orig_const_data.get_engine(), orig_const_data.get_data_handle()};
+      } else {
+        auto eng = orig_const_data.get_engine();
+        auto res = dnnl::memory{t_desc_, eng};
+        dnnl::stream reorder_stream(eng);
+        dnnl::reorder(orig_const_data, res).execute(reorder_stream, orig_const_data, res);
+        // Execution may be asynchronous. Result should be ready before it is returned.
+        reorder_stream.wait();
+        return res;
+      }
+    }
+    return {};
+  }
+
+  /*!
+   * \brief Return const data memory in form of vector.
+   *
+   * Same as GetConstData but use std::vector instead of dnnl::memory. Works only for 1D tensor
+   * and scalar TRs. Useful for specification of 1D DNNL attributes like zero_point or
+   * per_channel_scale
+   *
+   * \tparam T element type. Should match data type of tensor
+   * \return copy of const data. Fatal error if TR has no const data
+   */
+  template <typename T>
+  std::vector<T> GetConstDataLikeVec() const {
+    auto const_data = GetConstData();
+    // Empty memory object has no descriptor to query
+    ICHECK(const_data) << "TR is not constant. There is no data to extract";
+    auto desc = const_data.get_desc();
+    ICHECK(desc.data_type() == utils::DnnlDType<T>());
+    ICHECK(desc.dims().size() == 1);
+
+    auto size = desc.get_size() / sizeof(T);
+    auto ptr = static_cast<T*>(const_data.get_data_handle());
+
+    return std::vector<T>(ptr, ptr + size);
+  }
+
+  /*!
+   * \brief Get value of constant scalar tensor if possible.
+   * \tparam T type of value. Should match data type of tensor
+   * \return scalar value. Fatal error if TR is not a constant scalar of type T
+   */
+  template <typename T>
+  T GetConstScalarData() const {
+    ICHECK(IsConstant());
+    ICHECK(IsScalar());
+
+    const dnnl::memory scalar_mem = GetConstData();
+    const dnnl::memory::desc desc = scalar_mem.get_desc();
+    ICHECK(desc.data_type() == utils::DnnlDType<T>());
+
+    return *static_cast<const T*>(scalar_mem.get_data_handle());
+  }
+
+  /*! \brief Check if tensor is not empty, which means its descriptor is not a zero descriptor. */
+  bool defined() const { return !t_desc_.is_zero(); }
+
+  /*! \brief Same as defined. Conversion is implicit. */
+  operator bool() const { return defined(); }
+
+  /*!
+   * \brief Check if tensor represent a reversed data flow.
+   * Useful for describing output processing
+   * \return value of reverse data flow flag of this TR
+   */
+  bool IsReversed() const { return reverse_data_flow_; }
+
+ private:
+  /*!
+   * \brief Construct TR from all its components.
+   *
+   * Checked invariants: TR with const memory has no parent, no entry ID and forward data flow.
+   * TR with entry ID has no parent.
+   */
+  TensorRequisite(const dnnl::memory::desc& t_desc, const std::shared_ptr<TensorRequisite>& orig,
+                  bool reinterpret, const dnnl::memory& const_mem, uint32_t eid,
+                  bool reverse_data_flow)
+      : t_desc_(t_desc),
+        orig_(orig),
+        reinterpret_(reinterpret),
+        mem_(const_mem),
+        eid_(eid),
+        reverse_data_flow_(reverse_data_flow) {
+    if (mem_) {
+      ICHECK(!orig_ && !reverse_data_flow_ && eid_ == kUndefinedTid);
+    }
+    if (eid_ != kUndefinedTid) {
+      ICHECK(!orig_);
+    }
+  }
+
+  /* Descriptor of particular tensor  */
+  dnnl::memory::desc t_desc_ = {};
+  /* Parent TR object which is referred from this TR */
+  std::shared_ptr<TensorRequisite> orig_ = {};
+  /* Flag to specify which action should be done with orig TR, reordering or 
reinterpretation */
+  bool reinterpret_ = false;
+  /* Const memory object if available */
+  dnnl::memory mem_ = {};
+  /* Entry ID of tensor if available */
+  uint32_t eid_ = kUndefinedTid;
+
+  /*
+   * Flag to describe reverse data flow case
+   * All operation on queue will be executed in reverse order. Actual for dst 
tensor description
+   */
+  bool reverse_data_flow_ = false;
+
+  friend class TensorRegistry;
+};
+
+/*!
+ * \brief The registry of tensors. Implement matching of provided TRs and real 
memory buffers.
+ *
+ * Registration of TR performed by calling method Register(), which will 
return ArgId object.
+ * ArgId can be mapped to real memory via memory solver created by method 
MakeSolver().
+ */
+class TensorRegistry {
+ private:
+  enum ArgReqFlag {
+    CONST,        /// < Constant tensor. ExecutionCTX independent
+    TMP_STORAGE,  /// < Intermediate tensors. Stored inside TensorRegistry. 
Inaccessible outside
+    EXT_EID,      /// < External data. Input or Output.
+  };
+
+ public:
+  struct ArgId {
+    TensorRegistry::ArgReqFlag flag_;
+    uint32_t idx_;
+  };
+
+  using Action = std::tuple<dnnl::primitive, std::unordered_map<int, ArgId>>;
+  using ActionQue = std::vector<Action>;
+  using DLTensorProvider = std::function<const DLTensor*(uint32_t)>;
+  using MemSolver = std::function<const dnnl::memory(ArgId)>;
+
+  TensorRegistry() = default;
+  TensorRegistry(const dnnl::engine& eng, const std::set<uint32_t>& ext_io_eid)
+      : tmp_mem_collection_(1), ext_io_eid_(ext_io_eid), eng_(eng), 
stream_(eng) {}
+
+  /*!
+   * \brief Register TR to registry
+   *
+   * Resolution of TR may lead to introduction of intermediate memory buffers and additional
+   * transformation actions which should be performed before or after usage of corresponding memory
+   * buffer. Additional actions will be append to provided actions queue. Corresponding to
+   * tr.IsReversed() value actions should be executed before or after usage of resulting ArgId.
+   *
+   * \param tr tensor requisite sequence to register
+   * \param action resulting action queue. If TR resolution is required execution of some
+   *               transformation actions they will be put here
+   * \return associated ArgId. Should be used as argument for MemSolver.
+   */
+  ArgId Register(const TensorRequisite& tr, ActionQue* action) {
+    // 1) Constant tensor. Direct reference
+    const dnnl::memory const_data = tr.GetConstData();
+    if (const_data) {
+      const_mem_collection_.push_back(const_data);
+      const auto const_idx = const_mem_collection_.size() - 1;
+      return MakeArgReq(ArgReqFlag::CONST, static_cast<uint32_t>(const_idx));
+    }
+
+    // 2) EID mapped tensor. Direct reference
+    const bool has_eid = tr.eid_ != TensorRequisite::kUndefinedTid;
+    if (has_eid) {
+      const bool is_ext_io = ext_io_eid_.count(tr.eid_) != 0;
+      if (is_ext_io) {
+        ext_mem_collection_.push_back({tr.eid_, tr.t_desc_});
+        const auto ext_idx = ext_mem_collection_.size() - 1;
+        return MakeArgReq(ArgReqFlag::EXT_EID, static_cast<uint32_t>(ext_idx));
+      }
+
+      // Not IO tensor, means it's intermediate. Reuse the slot if this eid is already known.
+      const auto known = eid2idx_tmp_.find(tr.eid_);
+      if (known != eid2idx_tmp_.end()) {
+        return MakeArgReq(ArgReqFlag::TMP_STORAGE, static_cast<uint32_t>(known->second));
+      }
+
+      // register himself
+      const auto new_idx = tmp_mem_collection_.size();
+      tmp_mem_collection_.push_back(tr.t_desc_);
+      eid2idx_tmp_[tr.eid_] = new_idx;
+      return MakeArgReq(ArgReqFlag::TMP_STORAGE, static_cast<uint32_t>(new_idx));
+    }
+
+    // 3) Tensors with transform actions
+    if (tr.orig_) {
+      // parent TR is registered first, then this TR is derived from it
+      const ArgId orig_arg_req = Register(*tr.orig_, action);
+      return tr.reinterpret_
+                 ? RegisterReinterpret(orig_arg_req, tr.t_desc_)
+                 : RegisterReorder(orig_arg_req, tr.t_desc_, tr.reverse_data_flow_, action);
+    }
+
+    // 4) Scratchpad
+    ICHECK(!tr.orig_ && !tr.mem_ && tr.eid_ == TensorRequisite::kUndefinedTid);
+    const auto scratchpad_idx = tmp_mem_collection_.size();
+    tmp_mem_collection_.push_back(tr.t_desc_);
+    tmp_mem_mapping_[scratchpad_idx] = 0;  // zero position tmp mem object is reserved for scratchpads
+
+    // Shared scratchpad buffer should be large enough for the biggest request
+    const auto requested_size = tr.t_desc_.get_size();
+    const auto shared_size = tmp_mem_collection_[0].get_size();
+    if (requested_size > shared_size) {
+      tmp_mem_collection_[0] =
+          dnnl::memory::desc({static_cast<dnnl::memory::dim>(requested_size)},
+                             dnnl::memory::data_type::u8, dnnl::memory::format_tag::a);
+    }
+    return MakeArgReq(TMP_STORAGE, static_cast<uint32_t>(scratchpad_idx));
+  }
+
+  /*!
+   * \brief Construct memory solver for all registered TRs.
+   *
+   * Solver is constructed from engine and memory collections of this registry.
+   *
+   * \param ext_provider callback to resolve external IO buffers
+   * \return memory solver object to match ArgId to dnnl::memory objects
+   */
+  MemSolver MakeSolver(const DLTensorProvider& ext_provider) const {
+    MemSolver solver = MemSolverImpl(eng_, ext_provider, const_mem_collection_,
+                                     ext_mem_collection_, tmp_mem_collection_, tmp_mem_mapping_);
+    return solver;
+  }
+
+ private:
+  /*!
+   * \brief Register tensor which is a reinterpretation of already registered one.
+   * \param src_ar ArgId of source tensor. Only TMP_STORAGE and EXT_EID are supported
+   * \param desc descriptor to treat source buffer with
+   * \return ArgId of the same kind as source one
+   */
+  ArgId RegisterReinterpret(ArgId src_ar, const dnnl::memory::desc& desc) {
+    if (src_ar.flag_ == TMP_STORAGE) {
+      // new temp entry mapped onto the source temp buffer
+      const auto new_idx = tmp_mem_collection_.size();
+      tmp_mem_collection_.push_back(desc);
+      tmp_mem_mapping_[new_idx] = src_ar.idx_;
+      return MakeArgReq(TMP_STORAGE, static_cast<uint32_t>(new_idx));
+    }
+
+    if (src_ar.flag_ == EXT_EID) {
+      // the same external eid with another descriptor. Take eid before collection is modified.
+      const uint32_t src_eid = ext_mem_collection_[src_ar.idx_].first;
+      const auto new_idx = ext_mem_collection_.size();
+      ext_mem_collection_.push_back({src_eid, desc});
+      return MakeArgReq(EXT_EID, static_cast<uint32_t>(new_idx));
+    }
+
+    LOG(FATAL) << "Unknown case";
+    return {};
+  }
+
+  /*!
+   * \brief Register tensor which is a reordered version of already registered one.
+   *
+   * New temp buffer is introduced and reorder action is put into provided queue. For reversed
+   * data flow the reorder goes from new buffer to source one and is put at queue beginning,
+   * otherwise it goes from source to new buffer and is put at queue end.
+   *
+   * \param src_ar ArgId of source tensor. Only TMP_STORAGE and EXT_EID are supported
+   * \param desc descriptor of resulting tensor
+   * \param reverse_data_flow direction of data flow
+   * \param action queue to put reorder action into
+   * \return ArgId of new temp buffer
+   */
+  ArgId RegisterReorder(ArgId src_ar, const dnnl::memory::desc& desc, bool reverse_data_flow,
+                        ActionQue* action) {
+    ICHECK(src_ar.flag_ == TMP_STORAGE || src_ar.flag_ == EXT_EID);
+
+    // Source desc is copied before tmp_mem_collection_ is modified
+    dnnl::memory::desc src_desc;
+    if (src_ar.flag_ == TMP_STORAGE) {
+      src_desc = tmp_mem_collection_[src_ar.idx_];
+    } else {
+      src_desc = ext_mem_collection_[src_ar.idx_].second;
+    }
+
+    const auto dst_idx = tmp_mem_collection_.size();
+    tmp_mem_collection_.push_back(desc);
+    const ArgId dst_ar = MakeArgReq(TMP_STORAGE, static_cast<uint32_t>(dst_idx));
+
+    // reorder action submit
+    if (!reverse_data_flow) {
+      dnnl::reorder::primitive_desc reorder_pd(eng_, src_desc, eng_, desc);
+      action->push_back(
+          {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, src_ar}, {DNNL_ARG_TO, dst_ar}}});
+      return dst_ar;
+    }
+
+    dnnl::reorder::primitive_desc reorder_pd(eng_, desc, eng_, src_desc);
+    action->insert(action->begin(),
+                   {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, dst_ar}, {DNNL_ARG_TO, src_ar}}});
+    return dst_ar;
+  }
+  /*!
+   * \brief Implementation of memory solver
+   *
+   * Engine and collections of const and external tensors are kept by reference, so the objects
+   * passed to constructor should outlive the solver. External data provider is kept by value,
+   * so temporary callable can be passed safely.
+   */
+  class MemSolverImpl {
+   public:
+    /*!
+     * \param eng engine to create memory objects on
+     * \param ext_data_provider callback to resolve external IO buffers by eid
+     * \param const_mems collection of const memory objects
+     * \param ext_mems collection of {eid, desc} of external buffers
+     * \param tmp_mem_descs descriptors of temp buffers
+     * \param tmp_mem_mapping mapping of temp buffer idx to idx of buffer to share data with
+     */
+    MemSolverImpl(const dnnl::engine& eng, const DLTensorProvider& ext_data_provider,
+                  const std::vector<dnnl::memory>& const_mems,
+                  const std::vector<std::pair<uint32_t, dnnl::memory::desc>>& ext_mems,
+                  const std::vector<dnnl::memory::desc>& tmp_mem_descs,
+                  const std::map<size_t, size_t>& tmp_mem_mapping)
+        : eng_(eng),
+          ext_data_provider_(ext_data_provider),
+          const_mems_(const_mems),
+          ext_mems_(ext_mems) {
+      // Construct temp memory objects on the fly. While we have no scratchpads
+      // support on VM/GraphExecutor level.
+      tmp_mems_.resize(tmp_mem_descs.size());
+      for (size_t i = 0; i < tmp_mem_descs.size(); i++) {
+        auto found = tmp_mem_mapping.find(i);
+
+        if (found != tmp_mem_mapping.end()) {
+          // mapped buffer shares data handle with previously constructed one
+          auto reuse_hdl = tmp_mems_[found->second].get_data_handle();
+          tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_, reuse_hdl);
+        } else {
+          tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_);
+        }
+      }
+    }
+
+    /*! \brief Find memory object associated with provided ArgId */
+    dnnl::memory operator()(const ArgId& ar) const {
+      switch (ar.flag_) {
+        case CONST:
+          return const_mems_.at(ar.idx_);
+        case TMP_STORAGE:
+          return tmp_mems_.at(ar.idx_);
+        case EXT_EID: {
+          const auto& eid_and_desc = ext_mems_.at(ar.idx_);
+          const auto eid = eid_and_desc.first;
+          const auto& desc = eid_and_desc.second;
+
+          const DLTensor* ext_dl_tensor = ext_data_provider_(eid);
+          ICHECK(ext_dl_tensor) << "External tensor with eid " << eid << " is not provided";
+          ICHECK(ext_dl_tensor->data);
+          return dnnl::memory{desc, eng_, ext_dl_tensor->data};
+        }
+      }
+      return {};
+    }
+
+   private:
+    const dnnl::engine& eng_;
+    /* Kept by value. Reference to argument of constructor may dangle. */
+    DLTensorProvider ext_data_provider_;
+    const std::vector<dnnl::memory>& const_mems_;
+    const std::vector<std::pair<uint32_t, dnnl::memory::desc>>& ext_mems_;
+    std::vector<dnnl::memory> tmp_mems_;
+  };
+
+  ArgId MakeArgReq(ArgReqFlag flag, uint32_t idx) { return {flag, idx}; }  // pack flag and index into ArgId
+
+  /* Collection of const memory objects. */
+  std::vector<dnnl::memory> const_mem_collection_;
+
+  /* Collection of intermediate memory descriptors. Zero position is reserved 
for scratchpads. */
+  std::vector<dnnl::memory::desc> tmp_mem_collection_;
+
+  /* Mapping of some temp buffer on previously registered. */
+  std::map<size_t, size_t> tmp_mem_mapping_;
+
+  /* Collection of external_intermediate memory objects.
+   *  first  - eid of external buffer to ask
+   *  second - t_desc describes how to treat external buffer */
+  std::vector<std::pair<uint32_t, dnnl::memory::desc>> ext_mem_collection_;
+
+  /* Map of eid to index of temp buffer in tmp_mem_collection_ */
+  std::unordered_map<uint32_t, size_t> eid2idx_tmp_;
+
+  /* List of external eid */
+  std::set<uint32_t> ext_io_eid_;
+
+  /* Engine of all tensors existing in this registry */
+  dnnl::engine eng_;
+
+  /* Execution stream use to reorder const data */
+  dnnl::stream stream_;
+};
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_
diff --git a/src/runtime/contrib/dnnl/dnnl_utils.cc 
b/src/runtime/contrib/dnnl/dnnl_utils.cc
index 7e79f1c939..23992209f2 100644
--- a/src/runtime/contrib/dnnl/dnnl_utils.cc
+++ b/src/runtime/contrib/dnnl/dnnl_utils.cc
@@ -23,11 +23,14 @@
 
 #include "dnnl_utils.h"
 
+#include "tvm/runtime/logging.h"
+
 namespace tvm {
 namespace runtime {
 namespace contrib {
-using dt = dnnl::memory::data_type;
-dt dtype_dl2dnnl(DLDataType dltype) {
+
+dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype) {
+  using dt = dnnl::memory::data_type;
   dt dnnl_type = dt::undef;
   if (dltype.code == DataType::TypeCode::kFloat) {
     if (dltype.bits == 16) {
@@ -51,6 +54,23 @@ dt dtype_dl2dnnl(DLDataType dltype) {
   }
   return dnnl_type;
 }
+
+dnnl::memory::dims shape_dl2dnnl(const std::vector<int64_t>& shape) {
+  // Empty shape means scalar. DNNL scalar representation is 1D tensor with single element.
+  return shape.empty() ? dnnl::memory::dims{1} : dnnl::memory::dims(shape);
+}
+
+dnnl::memory::desc MakePlainDesc(const std::vector<int64_t>& shape, DLDataType dltype) {
+  const dnnl::memory::dims dims = shape_dl2dnnl(shape);
+  const dnnl::memory::data_type dtype = dtype_dl2dnnl(dltype);
+
+  // Dense strides. Innermost dim has stride 1, each outer one is a product of inner sizes.
+  dnnl::memory::dims strides(dims.size(), 1);
+  for (size_t inner = dims.size() - 1; inner > 0; --inner) {
+    strides[inner - 1] = strides[inner] * dims[inner];
+  }
+
+  return dnnl::memory::desc(dims, dtype, strides);
+}
+
 }  // namespace contrib
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/contrib/dnnl/dnnl_utils.h 
b/src/runtime/contrib/dnnl/dnnl_utils.h
index 4fb236f96f..a598b67044 100644
--- a/src/runtime/contrib/dnnl/dnnl_utils.h
+++ b/src/runtime/contrib/dnnl/dnnl_utils.h
@@ -18,16 +18,23 @@
  */
 
 /*!
- * \file src/runtime/contrib/dnnl/dnnl_utils.h
- * \brief utils for DNNL.
+ * \file src/runtime/contrib/dnnl/dnnl_utils.h
+ * \brief Some DNNL specific utility functions
  */
 
 #ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
 #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
 
-#include <tvm/runtime/data_type.h>
+#include <cstdint>
+#include <ostream>
+#include <string>
+#include <vector>
 
-#include "dnnl.hpp"
+// TODO(@apeskov): Have to mute warning from dnnl headers.
+//  -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command
+#include <dnnl.hpp>
+
+#include "tvm/runtime/data_type.h"
 
 namespace tvm {
 namespace runtime {
@@ -40,7 +47,90 @@ namespace contrib {
  */
 dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype);
 
+/*!
+ * \brief Convert TVM shape to DNNL dims
+ * \param shape tvm shape
+ * \return dims in terms of dnnl
+ */
+dnnl::memory::dims shape_dl2dnnl(const std::vector<int64_t>& shape);
+
+/*!
+ * \brief Construct plain tensor descriptor
+ * \param shape provided shape
+ * \param dltype provided data type
+ * \return resulting plain tensor desc
+ */
+dnnl::memory::desc MakePlainDesc(const std::vector<int64_t>& shape, DLDataType 
dltype);
+
+namespace utils {
+
+/*! \brief Print dims as a bracketed comma-separated list, e.g. "[1,2,3]"; empty dims print "[]" */
+inline std::ostream& operator<<(std::ostream& o, const dnnl::memory::dims& 
dims) {
+  o << "[";
+  auto d = dims.begin();
+  if (d != dims.end()) o << *d++;  // first element is written without a leading separator
+  while (d != dims.end()) o << "," << *d++;  // remaining elements are each preceded by ","
+  o << "]";
+  return o;  // return the stream so calls can be chained
+}
+
+/*! \brief Print a short name for a DNNL data type (undef, fp32, fp16, bf16, i32, i8, u8) */
+inline std::ostream& operator<<(std::ostream& o, const 
dnnl::memory::data_type& type) {
+  std::string name = "undef";  // also printed for any value the switch below does not list
+  switch (type) {
+    case dnnl::memory::data_type::undef:
+      name = "undef";
+      break;
+    case dnnl::memory::data_type::f32:
+      name = "fp32";
+      break;
+    case dnnl::memory::data_type::f16:
+      name = "fp16";
+      break;
+    case dnnl::memory::data_type::bf16:
+      name = "bf16";
+      break;
+    case dnnl::memory::data_type::s32:
+      name = "i32";
+      break;
+    case dnnl::memory::data_type::s8:
+      name = "i8";
+      break;
+    case dnnl::memory::data_type::u8:
+      name = "u8";
+      break;
+  }
+  o << name;
+  return o;  // return the stream so calls can be chained
+}
+
+/*! \brief Map a C++ type (template arg) to a DNNL data type; only specializations are defined */
+template <typename T>
+inline dnnl::memory::data_type DnnlDType();
+
+template <>
+inline dnnl::memory::data_type DnnlDType<int>() {
+  return dnnl::memory::data_type::s32;  // NOTE(review): presumes int is 32 bits wide - confirm
+}
+
+template <>
+inline dnnl::memory::data_type DnnlDType<float>() {
+  return dnnl::memory::data_type::f32;
+}
+
+template <>
+inline dnnl::memory::data_type DnnlDType<uint8_t>() {
+  return dnnl::memory::data_type::u8;
+}
+
+template <>
+inline dnnl::memory::data_type DnnlDType<int8_t>() {
+  return dnnl::memory::data_type::s8;
+}
+
+}  // namespace utils
 }  // namespace contrib
 }  // namespace runtime
 }  // namespace tvm
+
 #endif  // TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_

Reply via email to