masahi commented on a change in pull request #5919:
URL: https://github.com/apache/incubator-tvm/pull/5919#discussion_r447565441



##########
File path: src/runtime/contrib/dnnl/dnnl_json_runtime.cc
##########
@@ -0,0 +1,455 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+ * \brief A simple JSON runtime for DNNL.
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "dnnl.hpp"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::json;
+
+class DNNLJSONRuntime : public JSONRuntimeBase {
+  using tag = dnnl::memory::format_tag;
+  using dt = dnnl::memory::data_type;
+
+ public:
+  /*!
+   * \brief Construct the DNNL JSON runtime for one compiled subgraph.
+   * \param symbol_name Name of the subgraph symbol this module executes.
+   * \param graph_json Serialized JSON representation of the subgraph.
+   * \param const_names Names of the constant (weight) tensors the graph requires.
+   */
+  DNNLJSONRuntime(const std::string& symbol_name, const std::string& 
graph_json,
+                  const Array<String> const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  const char* type_key() const { return "dnnl_json"; }
+
+  /*!
+   * \brief Initialize the runtime: build the DNNL engine and primitives for
+   * the JSON graph, then bind the supplied constant tensors to their entries.
+   * \param consts Constant NDArrays (weights etc.); the count must equal the
+   * number of constants the graph declares.
+   */
+  void Init(const Array<NDArray>& consts) override {
+    BuildEngine();
+
+    // Fail fast if the caller supplied a different number of constants than required.
+    CHECK_EQ(consts.size(), const_idx_.size())
+        << "The number of input constants must match the number of required.";
+
+    // Setup constants entries for weights.
+    SetupConstants(consts);
+  }
+
+  void Run() override {
+    // Fill in the input buffers.
+    for (size_t i = 0; i < input_nodes_.size(); ++i) {
+      auto eid = EntryID(input_nodes_[i], 0);
+      // TODO(@comaniac): Support other data lengths.
+      size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
+      size_t buffer_size = GetDataSize(*data_entry_[eid]);
+      write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, 
buffer_size,
+                           offset_in_bytes);
+    }
+
+    // Invoke the engine through intepreting the stream.
+    for (size_t i = 0; i < net_.size(); ++i) {
+      net_.at(i).execute(stream_, net_args_.at(i));
+    }
+    stream_.wait();
+
+    // Read output buffers.
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      auto eid = EntryID(outputs_[i]);
+      size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
+      size_t buffer_size = GetDataSize(*data_entry_[eid]);
+      read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, 
buffer_size,
+                            offset_in_bytes);
+    }
+  }
+
+ private:
+  // Build up the engine based on the input graph.
+  void BuildEngine() {
+    engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0);
+    stream_ = dnnl::stream(engine_);
+
+    // Build subgraph engine.
+    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+      const auto& node = nodes_[nid];
+      if (node.GetOpType() == "kernel") {
+        CHECK_EQ(node.GetOpType(), "kernel");
+        auto op_name = node.GetOpName();
+        if ("nn.conv2d" == op_name) {
+          Conv2d(nid);
+        } else if ("dnnl.conv2d_relu" == op_name) {
+          Conv2d(nid, true, false);
+        } else if ("dnnl.conv2d_bias_relu" == op_name) {
+          Conv2d(nid, true, true);
+        } else if ("nn.dense" == op_name) {
+          Dense(nid);
+        } else if ("nn.batch_norm" == op_name) {
+          BatchNorm(nid);
+        } else if ("nn.relu" == op_name) {
+          Relu(nid);
+        } else if ("add" == op_name) {
+          Add(nid);
+        } else {
+          LOG(FATAL) << "Unsupported op: " << op_name;
+        }
+      }
+    }
+  }
+
+  // Bind a JSON graph node entry to a DNNL memory.
+  dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, 
dnnl::memory::desc mem_desc,
+                              size_t offset = 0) {
+    auto eid = EntryID(entry);
+    if (entry_out_mem_.count(eid) == 0) {
+      return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset);
+    }
+    return entry_out_mem_[eid].first;
+  }
+
+  // Bind a JSON graph node entry to a given DNNL memory.
+  dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory 
mem,
+                              size_t offset = 0) {
+    auto eid = EntryID(entry);
+    // Since the DNNL memory has been created before calling this function, we assume the entry
+    // has not yet been bound to another DNNL memory; otherwise there may be a memory leak.

Review comment:
       bound




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to