This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bbca53d2ab [DNNL] Add TensorRequisite concept (#11345)
bbca53d2ab is described below
commit bbca53d2ab354d7e8bed11fc9e1eae13fbee7730
Author: apeskov <[email protected]>
AuthorDate: Thu Jun 2 13:04:12 2022 +0300
[DNNL] Add TensorRequisite concept (#11345)
Allow the DNNL runtime to be used in multi-instance mode.
Thread-safe execution of the Run() method.
Signed-off-by: Alexander Peskov <[email protected]>
---
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 1412 ++++++----------------
src/runtime/contrib/dnnl/dnnl_tensor_requisite.h | 720 +++++++++++
src/runtime/contrib/dnnl/dnnl_utils.cc | 24 +-
src/runtime/contrib/dnnl/dnnl_utils.h | 98 +-
4 files changed, 1239 insertions(+), 1015 deletions(-)
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index f6a1c3b790..a2417f012e 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -32,7 +32,12 @@
#include "../json/json_node.h"
#include "../json/json_runtime.h"
-#include "dnnl.hpp"
+
+// TODO(@apeskov): Have to mute warning from dnnl headers.
+// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command
+#include <dnnl.hpp>
+
+#include "dnnl_tensor_requisite.h"
#include "dnnl_utils.h"
namespace tvm {
@@ -43,552 +48,82 @@ using namespace tvm::runtime;
using namespace tvm::runtime::json;
class DNNLJSONRuntime : public JSONRuntimeBase {
- using tag = dnnl::memory::format_tag;
- using dt = dnnl::memory::data_type;
-
public:
DNNLJSONRuntime(const std::string& symbol_name, const std::string&
graph_json,
const Array<String> const_names)
- : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+ : JSONRuntimeBase(symbol_name, graph_json, const_names),
+ next_unique_eid_offset_(data_entry_.size()),
+ run_arg_eid_(input_var_eid_) {
+ for (const auto e : outputs_) run_arg_eid_.push_back(EntryID(e));
+ }
- const char* type_key() const { return "dnnl_json"; }
+ const char* type_key() const override { return "dnnl_json"; }
void Init(const Array<NDArray>& consts) override {
- BuildEngine();
-
ICHECK_EQ(consts.size(), const_idx_.size())
<< "The number of input constants must match the number of required.";
// Setup constants entries for weights.
SetupConstants(consts);
+ BuildEngine();
}
- void Run() override {
- // Fill in the input buffers.
- for (size_t i = 0; i < input_nodes_.size(); ++i) {
- auto eid = EntryID(input_nodes_[i], 0);
- size_t offset_in_bytes =
- entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) /
8);
- size_t buffer_size = GetDataSize(*data_entry_[eid]);
- write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first,
buffer_size,
- offset_in_bytes);
- }
+ /* Unused stub implementation */
+ void Run() override { LOG(FATAL) << "Unreachable code"; }
- // Invoke the engine through intepreting the stream.
- for (size_t i = 0; i < net_.size(); ++i) {
- net_.at(i).execute(stream_, net_args_.at(i));
- }
- stream_.wait();
-
- // Read output buffers.
- for (size_t i = 0; i < outputs_.size(); ++i) {
- auto eid = EntryID(outputs_[i]);
- size_t offset_in_bytes =
- entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) /
8);
- size_t buffer_size = GetDataSize(*data_entry_[eid]);
- read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first,
buffer_size,
- offset_in_bytes);
+ /* Thread safe implementation of Run. Keep runtime instance immutable */
+ void Run(const TVMArgs& args) const {
+ auto arg_data_provider = makeIODataProvider(args);
+ auto mem_solver = tensor_registry_.MakeSolver(arg_data_provider);
+ // Execute primitives one by one
+ for (const auto& act : net_) {
+ auto prim = std::get<0>(act);
+ auto arg_reqs = std::get<1>(act);
+
+ // Find proper dnnl::memory buffers
+ std::unordered_map<int, dnnl::memory> mem_args;
+ for (const auto& kvp : arg_reqs) mem_args[kvp.first] =
mem_solver(kvp.second);
+
+ prim.execute(stream_, mem_args);
}
}
- private:
- tag layout2tag(std::string layout) {
- static const std::map<std::string, tag> str2tag = {{"nc", tag::nc},
- {"cn", tag::cn},
- {"tn", tag::tn},
- {"nt", tag::nt},
- {"ncw", tag::ncw},
- {"nwc", tag::nwc},
- {"nchw", tag::nchw},
- {"nhwc", tag::nhwc},
- {"chwn", tag::chwn},
- {"ncdhw", tag::ncdhw},
- {"ndhwc", tag::ndhwc},
- {"oi", tag::oi},
- {"io", tag::io},
- {"oiw", tag::oiw},
- {"owi", tag::owi},
- {"wio", tag::wio},
- {"iwo", tag::iwo},
- {"oihw", tag::oihw},
- {"hwio", tag::hwio},
- {"ohwi", tag::ohwi},
- {"ihwo", tag::ihwo},
- {"iohw", tag::iohw},
- {"oidhw", tag::oidhw},
- {"dhwio", tag::dhwio},
- {"odhwi", tag::odhwi},
- {"iodhw", tag::iodhw},
- {"idhwo", tag::idhwo},
- {"goiw", tag::goiw},
- {"gowi", tag::gowi},
- {"wigo", tag::wigo},
- {"gohwi", tag::gohwi},
- {"goihw", tag::goihw},
- {"hwigo", tag::hwigo},
- {"giohw", tag::giohw},
- {"goidhw", tag::goidhw},
- {"giodhw", tag::giodhw},
- {"godhwi", tag::godhwi},
- {"dhwigo", tag::dhwigo},
- {"tnc", tag::tnc},
- {"ntc", tag::ntc},
- {"ldnc", tag::ldnc},
- {"ldigo", tag::ldigo},
- {"ldgoi", tag::ldgoi},
- {"ldio", tag::ldio},
- {"ldoi", tag::ldoi},
- {"ldgo", tag::ldgo},
- {"nCdhw16c",
tag::nCdhw16c},
- {"nCdhw4c",
tag::nCdhw4c},
- {"nCdhw8c",
tag::nCdhw8c},
- {"nChw16c",
tag::nChw16c},
- {"nChw4c", tag::nChw4c},
- {"nChw8c", tag::nChw8c},
- {"nCw16c", tag::nCw16c},
- {"nCw4c", tag::nCw4c},
- {"nCw8c", tag::nCw8c},
- {"NCw16n16c",
tag::NCw16n16c},
- {"NChw16n16c",
tag::NChw16n16c},
- {"NCdhw16n16c",
tag::NCdhw16n16c},
- {"NCdhw32n32c",
tag::NCdhw32n32c},
- {"NChw32n32c",
tag::NChw32n32c},
- {"IOhw16i16o",
tag::IOhw16i16o},
- {"OI16i16o",
tag::OI16i16o},
- {"OI16i32o",
tag::OI16i32o},
- {"OI16i64o",
tag::OI16i64o},
- {"OI8i16o2i",
tag::OI8i16o2i},
- {"OI8i32o2i",
tag::OI8i32o2i},
- {"OI8i64o2i",
tag::OI8i64o2i},
- {"OI4i16o4i",
tag::OI4i16o4i},
- {"OI4i32o4i",
tag::OI4i32o4i},
- {"OI4i64o4i",
tag::OI4i64o4i},
- {"Ohwi32o",
tag::Ohwi32o},
- {"IOdhw16i16o",
tag::IOdhw16i16o},
- {"gIOhw16i16o",
tag::gIOhw16i16o},
- {"gOhwi32o",
tag::gOhwi32o},
- {"Goidhw16g",
tag::Goidhw16g},
- {"IOw16o16i",
tag::IOw16o16i},
- {"OIw16i16o",
tag::OIw16i16o},
- {"OIw16i32o",
tag::OIw16i32o},
- {"OIw16i64o",
tag::OIw16i64o},
- {"IOw16i16o",
tag::IOw16i16o},
- {"gIOw16i16o",
tag::gIOw16i16o},
- {"OIw16o16i",
tag::OIw16o16i},
- {"Oiw16o", tag::Oiw16o},
- {"OIw4i16o4i",
tag::OIw4i16o4i},
- {"OIw4i32o4i",
tag::OIw4i32o4i},
- {"OIw4i64o4i",
tag::OIw4i64o4i},
- {"OIw2i8o4i",
tag::OIw2i8o4i},
- {"OIw4i4o",
tag::OIw4i4o},
- {"OIw4o4i",
tag::OIw4o4i},
- {"Oiw4o", tag::Oiw4o},
- {"OIw8i16o2i",
tag::OIw8i16o2i},
- {"OIw8i32o2i",
tag::OIw8i32o2i},
- {"OIw8i64o2i",
tag::OIw8i64o2i},
- {"OIw8i8o",
tag::OIw8i8o},
- {"OIw8o16i2o",
tag::OIw8o16i2o},
- {"OIw8o8i",
tag::OIw8o8i},
- {"OIw8o4i",
tag::OIw8o4i},
- {"OIw16i16o4i",
tag::OIw16i16o4i},
- {"OIw16i32o4i",
tag::OIw16i32o4i},
- {"OIw16i48o4i",
tag::OIw16i48o4i},
- {"OIw16i64o4i",
tag::OIw16i64o4i},
- {"OIw16i16o2i",
tag::OIw16i16o2i},
- {"OIw16i32o2i",
tag::OIw16i32o2i},
- {"OIw16i48o2i",
tag::OIw16i48o2i},
- {"OIw16i64o2i",
tag::OIw16i64o2i},
- {"OIw16o16i2o",
tag::OIw16o16i2o},
- {"Owi16o", tag::Owi16o},
- {"OwI16o2i",
tag::OwI16o2i},
- {"Owi4o", tag::Owi4o},
- {"Owi8o", tag::Owi8o},
- {"IOhw16o16i",
tag::IOhw16o16i},
- {"Ohwi16o",
tag::Ohwi16o},
- {"OhwI16o2i",
tag::OhwI16o2i},
- {"Ohwi4o", tag::Ohwi4o},
- {"Ohwi8o", tag::Ohwi8o},
- {"OIhw16i16o",
tag::OIhw16i16o},
- {"OIhw16i32o",
tag::OIhw16i32o},
- {"OIhw16i64o",
tag::OIhw16i64o},
- {"OIhw16o16i",
tag::OIhw16o16i},
- {"Oihw16o",
tag::Oihw16o},
- {"OIhw4i16o4i",
tag::OIhw4i16o4i},
- {"OIhw4i32o4i",
tag::OIhw4i32o4i},
- {"OIhw4i64o4i",
tag::OIhw4i64o4i},
- {"OIhw4i4o",
tag::OIhw4i4o},
- {"OIhw4o4i",
tag::OIhw4o4i},
- {"Oihw4o", tag::Oihw4o},
- {"OIhw8i16o2i",
tag::OIhw8i16o2i},
- {"OIhw8i32o2i",
tag::OIhw8i32o2i},
- {"OIhw8i64o2i",
tag::OIhw8i64o2i},
- {"OIhw8i8o",
tag::OIhw8i8o},
- {"OIhw8o16i2o",
tag::OIhw8o16i2o},
- {"OIhw8o8i",
tag::OIhw8o8i},
- {"OIhw8o4i",
tag::OIhw8o4i},
- {"OIhw2i8o4i",
tag::OIhw2i8o4i},
- {"IOdhw16o16i",
tag::IOdhw16o16i},
- {"Odhwi16o",
tag::Odhwi16o},
- {"OdhwI16o2i",
tag::OdhwI16o2i},
- {"Odhwi4o",
tag::Odhwi4o},
- {"Odhwi8o",
tag::Odhwi8o},
- {"OIdhw16i16o",
tag::OIdhw16i16o},
- {"OIdhw16i32o",
tag::OIdhw16i32o},
- {"OIdhw16i64o",
tag::OIdhw16i64o},
- {"OIdhw16o16i",
tag::OIdhw16o16i},
- {"Oidhw16o",
tag::Oidhw16o},
- {"OIdhw4i4o",
tag::OIdhw4i4o},
- {"OIdhw4o4i",
tag::OIdhw4o4i},
- {"Oidhw4o",
tag::Oidhw4o},
- {"OIdhw8i16o2i",
tag::OIdhw8i16o2i},
- {"OIdhw8i32o2i",
tag::OIdhw8i32o2i},
- {"OIdhw8i64o2i",
tag::OIdhw8i64o2i},
- {"OIdhw4i16o4i",
tag::OIdhw4i16o4i},
- {"OIdhw16i16o4i",
tag::OIdhw16i16o4i},
- {"OIdhw16i32o4i",
tag::OIdhw16i32o4i},
- {"OIdhw16i48o4i",
tag::OIdhw16i48o4i},
- {"OIdhw16i64o4i",
tag::OIdhw16i64o4i},
- {"OIdhw16i16o2i",
tag::OIdhw16i16o2i},
- {"OIdhw16i32o2i",
tag::OIdhw16i32o2i},
- {"OIdhw16i48o2i",
tag::OIdhw16i48o2i},
- {"OIdhw16i64o2i",
tag::OIdhw16i64o2i},
- {"OIdhw4i32o4i",
tag::OIdhw4i32o4i},
- {"OIdhw4i64o4i",
tag::OIdhw4i64o4i},
- {"OIdhw2i8o4i",
tag::OIdhw2i8o4i},
- {"OIdhw8i8o",
tag::OIdhw8i8o},
- {"OIdhw8o8i",
tag::OIdhw8o8i},
- {"OIdhw8o4i",
tag::OIdhw8o4i},
- {"gIOw16o16i",
tag::gIOw16o16i},
- {"gOIw16i16o",
tag::gOIw16i16o},
- {"gOIw16o16i",
tag::gOIw16o16i},
- {"gOiw16o",
tag::gOiw16o},
- {"gOIw4i16o4i",
tag::gOIw4i16o4i},
- {"gOIw2i8o4i",
tag::gOIw2i8o4i},
- {"gOIw4i4o",
tag::gOIw4i4o},
- {"gOIw4o4i",
tag::gOIw4o4i},
- {"gOiw4o", tag::gOiw4o},
- {"gOIw8i16o2i",
tag::gOIw8i16o2i},
- {"gOIw8i8o",
tag::gOIw8i8o},
- {"gOIw8o16i2o",
tag::gOIw8o16i2o},
- {"gOIw8o8i",
tag::gOIw8o8i},
- {"gOIw8o4i",
tag::gOIw8o4i},
- {"gOIw16i16o4i",
tag::gOIw16i16o4i},
- {"gOIw16i16o2i",
tag::gOIw16i16o2i},
- {"gOIw16o16i2o",
tag::gOIw16o16i2o},
- {"gOwi16o",
tag::gOwi16o},
- {"gOwI16o2i",
tag::gOwI16o2i},
- {"gOwi4o", tag::gOwi4o},
- {"gOwi8o", tag::gOwi8o},
- {"Goiw8g", tag::Goiw8g},
- {"Goiw16g",
tag::Goiw16g},
- {"gIOhw16o16i",
tag::gIOhw16o16i},
- {"gOhwi16o",
tag::gOhwi16o},
- {"gOhwI16o2i",
tag::gOhwI16o2i},
- {"gOhwi4o",
tag::gOhwi4o},
- {"gOhwi8o",
tag::gOhwi8o},
- {"Goihw16g",
tag::Goihw16g},
- {"gOIhw16i16o",
tag::gOIhw16i16o},
- {"gOIhw16o16i",
tag::gOIhw16o16i},
- {"gOihw16o",
tag::gOihw16o},
- {"gOIhw4i16o4i",
tag::gOIhw4i16o4i},
- {"gOIhw2i8o4i",
tag::gOIhw2i8o4i},
- {"gOIhw4i4o",
tag::gOIhw4i4o},
- {"gOIhw4o4i",
tag::gOIhw4o4i},
- {"gOihw4o",
tag::gOihw4o},
- {"Goihw8g",
tag::Goihw8g},
- {"gOIhw8i16o2i",
tag::gOIhw8i16o2i},
- {"gOIhw8i8o",
tag::gOIhw8i8o},
- {"gOIhw8o16i2o",
tag::gOIhw8o16i2o},
- {"OIw4o8i8o4i",
tag::OIw4o8i8o4i},
- {"OIdhw4o8i8o4i",
tag::OIdhw4o8i8o4i},
- {"OIhw4o8i8o4i",
tag::OIhw4o8i8o4i},
- {"OIhw2o8i8o2i",
tag::OIhw2o8i8o2i},
- {"gOIw4o8i8o4i",
tag::gOIw4o8i8o4i},
- {"gOIdhw4o8i8o4i",
tag::gOIdhw4o8i8o4i},
- {"gOIhw4o8i8o4i",
tag::gOIhw4o8i8o4i},
- {"gOIhw2o8i8o2i",
tag::gOIhw2o8i8o2i},
- {"OIhw16i16o4i",
tag::OIhw16i16o4i},
- {"OIhw16i32o4i",
tag::OIhw16i32o4i},
- {"OIhw16i48o4i",
tag::OIhw16i48o4i},
- {"OIhw16i64o4i",
tag::OIhw16i64o4i},
- {"OIhw16i16o2i",
tag::OIhw16i16o2i},
- {"OIhw16i32o2i",
tag::OIhw16i32o2i},
- {"OIhw16i48o2i",
tag::OIhw16i48o2i},
- {"OIhw16i64o2i",
tag::OIhw16i64o2i},
- {"OIhw16o16i2o",
tag::OIhw16o16i2o},
- {"gOIhw16i16o4i",
tag::gOIhw16i16o4i},
- {"gOIhw16i16o2i",
tag::gOIhw16i16o2i},
- {"gOIhw16o16i2o",
tag::gOIhw16o16i2o},
- {"gOIhw8o8i",
tag::gOIhw8o8i},
- {"gOIhw8o4i",
tag::gOIhw8o4i},
- {"gIOdhw16i16o",
tag::gIOdhw16i16o},
- {"gIOdhw16o16i",
tag::gIOdhw16o16i},
- {"gOdhwi16o",
tag::gOdhwi16o},
- {"gOdhwI16o2i",
tag::gOdhwI16o2i},
- {"gOdhwi4o",
tag::gOdhwi4o},
- {"gOdhwi8o",
tag::gOdhwi8o},
- {"gOIdhw16i16o",
tag::gOIdhw16i16o},
- {"gOIdhw16o16i",
tag::gOIdhw16o16i},
- {"gOidhw16o",
tag::gOidhw16o},
- {"gOIdhw4i4o",
tag::gOIdhw4i4o},
- {"gOIdhw4o4i",
tag::gOIdhw4o4i},
- {"gOidhw4o",
tag::gOidhw4o},
- {"gOIdhw8i16o2i",
tag::gOIdhw8i16o2i},
- {"gOIdhw4i16o4i",
tag::gOIdhw4i16o4i},
- {"gOIdhw16i16o4i",
tag::gOIdhw16i16o4i},
- {"gOIdhw16i16o2i",
tag::gOIdhw16i16o2i},
- {"gOIdhw2i8o4i",
tag::gOIdhw2i8o4i},
- {"gOIdhw8i8o",
tag::gOIdhw8i8o},
- {"gOIdhw8o8i",
tag::gOIdhw8o8i},
- {"gOIdhw8o4i",
tag::gOIdhw8o4i},
- {"gOIw2i4o2i",
tag::gOIw2i4o2i},
- {"gOIhw2i4o2i",
tag::gOIhw2i4o2i},
- {"gOIdhw2i4o2i",
tag::gOIdhw2i4o2i},
- {"gOIw2o4i2o",
tag::gOIw2o4i2o},
- {"gOIhw2o4i2o",
tag::gOIhw2o4i2o},
- {"gOIdhw2o4i2o",
tag::gOIdhw2o4i2o},
- {"gOIw4i8o2i",
tag::gOIw4i8o2i},
- {"gOIhw4i8o2i",
tag::gOIhw4i8o2i},
- {"gOIdhw4i8o2i",
tag::gOIdhw4i8o2i},
- {"gOIw4o8i2o",
tag::gOIw4o8i2o},
- {"gOIhw4o8i2o",
tag::gOIhw4o8i2o},
- {"gOIdhw4o8i2o",
tag::gOIdhw4o8i2o},
- {"ldOi32o",
tag::ldOi32o},
- {"ldOI32o4i",
tag::ldOI32o4i},
- {"ldgOi32o",
tag::ldgOi32o},
- {"ldgOI32o2i",
tag::ldgOI32o2i},
- {"ldgOI32o4i",
tag::ldgOI32o4i},
- {"OwI16o4i",
tag::OwI16o4i},
- {"OhwI16o4i",
tag::OhwI16o4i},
- {"gOwI16o4i",
tag::gOwI16o4i},
- {"gOhwI16o4i",
tag::gOhwI16o4i},
- {"OdhwI16o4i",
tag::OdhwI16o4i},
- {"gOdhwI16o4i",
tag::gOdhwI16o4i},
- {"Owi32o", tag::Owi32o},
- {"OwI32o2i",
tag::OwI32o2i},
- {"OwI32o4i",
tag::OwI32o4i},
- {"Owi48o", tag::Owi48o},
- {"OwI48o2i",
tag::OwI48o2i},
- {"OwI48o4i",
tag::OwI48o4i},
- {"Owi64o", tag::Owi64o},
- {"OwI64o2i",
tag::OwI64o2i},
- {"OwI64o4i",
tag::OwI64o4i},
- {"wIo2i", tag::wIo2i},
- {"wIo4i", tag::wIo4i},
- {"gOwi32o",
tag::gOwi32o},
- {"gOwI32o2i",
tag::gOwI32o2i},
- {"gOwI32o4i",
tag::gOwI32o4i},
- {"gOwi48o",
tag::gOwi48o},
- {"gOwI48o2i",
tag::gOwI48o2i},
- {"gOwI48o4i",
tag::gOwI48o4i},
- {"gOwi64o",
tag::gOwi64o},
- {"gOwI64o2i",
tag::gOwI64o2i},
- {"gOwI64o4i",
tag::gOwI64o4i},
- {"gwio", tag::gwio},
- {"gwIo2i", tag::gwIo2i},
- {"gwIo4i", tag::gwIo4i},
- {"OhwI32o",
tag::OhwI32o},
- {"OhwI32o2i",
tag::OhwI32o2i},
- {"OhwI32o4i",
tag::OhwI32o4i},
- {"Ohwi48o",
tag::Ohwi48o},
- {"OhwI48o2i",
tag::OhwI48o2i},
- {"OhwI48o4i",
tag::OhwI48o4i},
- {"Ohwi64o",
tag::Ohwi64o},
- {"OhwI64o2i",
tag::OhwI64o2i},
- {"OhwI64o4i",
tag::OhwI64o4i},
- {"hwIo2i", tag::hwIo2i},
- {"hwIo4i", tag::hwIo4i},
- {"gOhwI32o",
tag::gOhwI32o},
- {"gOhwI32o2i",
tag::gOhwI32o2i},
- {"gOhwI32o4i",
tag::gOhwI32o4i},
- {"gOhwi48o",
tag::gOhwi48o},
- {"gOhwI48o2i",
tag::gOhwI48o2i},
- {"gOhwI48o4i",
tag::gOhwI48o4i},
- {"gOhwi64o",
tag::gOhwi64o},
- {"gOhwI64o2i",
tag::gOhwI64o2i},
- {"gOhwI64o4i",
tag::gOhwI64o4i},
- {"ghwio", tag::ghwio},
- {"ghwIo2i",
tag::ghwIo2i},
- {"ghwIo4i",
tag::ghwIo4i},
- {"Odhwi32o",
tag::Odhwi32o},
- {"OdhwI32o2i",
tag::OdhwI32o2i},
- {"OdhwI32o4i",
tag::OdhwI32o4i},
- {"Odhwi48o",
tag::Odhwi48o},
- {"OdhwI48o2i",
tag::OdhwI48o2i},
- {"OdhwI48o4i",
tag::OdhwI48o4i},
- {"Odhwi64o",
tag::Odhwi64o},
- {"OdhwI64o2i",
tag::OdhwI64o2i},
- {"OdhwI64o4i",
tag::OdhwI64o4i},
- {"dhwIo2i",
tag::dhwIo2i},
- {"dhwIo4i",
tag::dhwIo4i},
- {"gOdhwi32o",
tag::gOdhwi32o},
- {"gOdhwI32o2i",
tag::gOdhwI32o2i},
- {"gOdhwI32o4i",
tag::gOdhwI32o4i},
- {"gOdhwi48o",
tag::gOdhwi48o},
- {"gOdhwI48o2i",
tag::gOdhwI48o2i},
- {"gOdhwI48o4i",
tag::gOdhwI48o4i},
- {"gOdhwi64o",
tag::gOdhwi64o},
- {"gOdhwI64o2i",
tag::gOdhwI64o2i},
- {"gOdhwI64o4i",
tag::gOdhwI64o4i},
- {"gdhwio", tag::gdhwio},
- {"gdhwIo2i",
tag::gdhwIo2i},
- {"gdhwIo4i",
tag::gdhwIo4i},
- {"ldIo32i",
tag::ldIo32i},
- {"ldgIo32i",
tag::ldgIo32i},
- {"ldgIO32i2o",
tag::ldgIO32i2o},
- {"nCdhw32c",
tag::nCdhw32c},
- {"nChw32c",
tag::nChw32c},
- {"nCw32c", tag::nCw32c},
- {"NCw32n16c",
tag::NCw32n16c},
- {"NChw32n16c",
tag::NChw32n16c},
- {"NCdhw32n16c",
tag::NCdhw32n16c},
- {"NCw32n32c",
tag::NCw32n32c},
- {"OI16i16o4i",
tag::OI16i16o4i},
- {"IOw8o16i2o",
tag::IOw8o16i2o},
- {"IOhw8o16i2o",
tag::IOhw8o16i2o},
- {"Owhi16o",
tag::Owhi16o},
- {"OIdhw8o16i2o",
tag::OIdhw8o16i2o},
- {"IOdhw8o16i2o",
tag::IOdhw8o16i2o},
- {"Goiw4g", tag::Goiw4g},
- {"gIOw8o16i2o",
tag::gIOw8o16i2o},
- {"Goiw32g",
tag::Goiw32g},
- {"Goihw4g",
tag::Goihw4g},
- {"gIOhw8o16i2o",
tag::gIOhw8o16i2o},
- {"Goihw32g",
tag::Goihw32g},
- {"gOwhi16o",
tag::gOwhi16o},
- {"IOw4i8o8i4o",
tag::IOw4i8o8i4o},
- {"IOhw4i8o8i4o",
tag::IOhw4i8o8i4o},
- {"IOdhw4i8o8i4o",
tag::IOdhw4i8o8i4o},
- {"gIOw4i8o8i4o",
tag::gIOw4i8o8i4o},
- {"gIOhw4i8o8i4o",
tag::gIOhw4i8o8i4o},
- {"gIOdhw4i8o8i4o",
tag::gIOdhw4i8o8i4o},
- {"gOIdhw8o16i2o",
tag::gOIdhw8o16i2o},
- {"gIOdhw8o16i2o",
tag::gIOdhw8o16i2o},
- {"Goidhw32g",
tag::Goidhw32g},
- {"OI16i32o4i",
tag::OI16i32o4i},
- {"OI16i48o4i",
tag::OI16i48o4i},
- {"OI16i64o4i",
tag::OI16i64o4i},
- {"OI16i16o2i",
tag::OI16i16o2i},
- {"OI16i32o2i",
tag::OI16i32o2i},
- {"OI16i48o2i",
tag::OI16i48o2i},
- {"OI16i64o2i",
tag::OI16i64o2i},
- {"OwI16i16o2i",
tag::OwI16i16o2i},
- {"gOwI16i16o2i",
tag::gOwI16i16o2i},
- {"OhwI16i16o2i",
tag::OhwI16i16o2i},
- {"gOhwI16i16o2i",
tag::gOhwI16i16o2i},
- {"OdhwI16i16o2i",
tag::OdhwI16i16o2i},
- {"gOdhwI16i16o2i",
tag::gOdhwI16i16o2i},
- {"OwI16i16o4i",
tag::OwI16i16o4i},
- {"gOwI16i16o4i",
tag::gOwI16i16o4i},
- {"OhwI16i16o4i",
tag::OhwI16i16o4i},
- {"gOhwI16i16o4i",
tag::gOhwI16i16o4i},
- {"OdhwI16i16o4i",
tag::OdhwI16i16o4i},
- {"gOdhwI16i16o4i",
tag::gOdhwI16i16o4i},
- {"OwI16i32o2i",
tag::OwI16i32o2i},
- {"OwI16i32o4i",
tag::OwI16i32o4i},
- {"OwI16i48o2i",
tag::OwI16i48o2i},
- {"OwI16i48o4i",
tag::OwI16i48o4i},
- {"OwI16i64o2i",
tag::OwI16i64o2i},
- {"OwI16i64o4i",
tag::OwI16i64o4i},
- {"gOwI16i32o2i",
tag::gOwI16i32o2i},
- {"gOwI16i32o4i",
tag::gOwI16i32o4i},
- {"gOwI16i48o2i",
tag::gOwI16i48o2i},
- {"gOwI16i48o4i",
tag::gOwI16i48o4i},
- {"gOwI16i64o2i",
tag::gOwI16i64o2i},
- {"gOwI16i64o4i",
tag::gOwI16i64o4i},
- {"OhwI16i32o2i",
tag::OhwI16i32o2i},
- {"OhwI16i32o4i",
tag::OhwI16i32o4i},
- {"OhwI16i48o2i",
tag::OhwI16i48o2i},
- {"OhwI16i48o4i",
tag::OhwI16i48o4i},
- {"OhwI16i64o2i",
tag::OhwI16i64o2i},
- {"OhwI16i64o4i",
tag::OhwI16i64o4i},
- {"gOhwI16i32o2i",
tag::gOhwI16i32o2i},
- {"gOhwI16i32o4i",
tag::gOhwI16i32o4i},
- {"gOhwI16i48o2i",
tag::gOhwI16i48o2i},
- {"gOhwI16i48o4i",
tag::gOhwI16i48o4i},
- {"gOhwI16i64o2i",
tag::gOhwI16i64o2i},
- {"gOhwI16i64o4i",
tag::gOhwI16i64o4i},
- {"OdhwI16i32o2i",
tag::OdhwI16i32o2i},
- {"OdhwI16i32o4i",
tag::OdhwI16i32o4i},
- {"OdhwI16i48o2i",
tag::OdhwI16i48o2i},
- {"OdhwI16i48o4i",
tag::OdhwI16i48o4i},
- {"OdhwI16i64o2i",
tag::OdhwI16i64o2i},
- {"OdhwI16i64o4i",
tag::OdhwI16i64o4i},
- {"gOdhwI16i32o2i",
tag::gOdhwI16i32o2i},
- {"gOdhwI16i32o4i",
tag::gOdhwI16i32o4i},
- {"gOdhwI16i48o2i",
tag::gOdhwI16i48o2i},
- {"gOdhwI16i48o4i",
tag::gOdhwI16i48o4i},
- {"gOdhwI16i64o2i",
tag::gOdhwI16i64o2i},
- {"gOdhwI16i64o4i",
tag::gOdhwI16i64o4i},
- {"hwioG16g",
tag::hwioG16g},
- {"NCdhw40n32c",
tag::NCdhw40n32c},
- {"NChw40n32c",
tag::NChw40n32c},
- {"NCw40n32c",
tag::NCw40n32c},
- {"OIdhw4o8i8o2i",
tag::OIdhw4o8i8o2i},
- {"OIhw4o8i8o2i",
tag::OIhw4o8i8o2i},
- {"OIw4o8i8o2i",
tag::OIw4o8i8o2i},
- {"gOIdhw4o8i8o2i",
tag::gOIdhw4o8i8o2i},
- {"gOIhw4o8i8o2i",
tag::gOIhw4o8i8o2i},
- {"gOIw4o8i8o2i",
tag::gOIw4o8i8o2i},
- {"IOdhw4i8o8i2o",
tag::IOdhw4i8o8i2o},
- {"IOhw4i8o8i2o",
tag::IOhw4i8o8i2o},
- {"IOw4i8o8i2o",
tag::IOw4i8o8i2o},
- {"gIOdhw4i8o8i2o",
tag::gIOdhw4i8o8i2o},
- {"gIOhw4i8o8i2o",
tag::gIOhw4i8o8i2o},
- {"gIOw4i8o8i2o",
tag::gIOw4i8o8i2o},
- {"NCdhw40n16c",
tag::NCdhw40n16c},
- {"NCw40n16c",
tag::NCw40n16c},
- {"NChw40n16c",
tag::NChw40n16c},
- {"NCw2c32n8c",
tag::NCw2c32n8c},
- {"NChw2c32n8c",
tag::NChw2c32n8c},
- {"NCdhw2c32n8c",
tag::NCdhw2c32n8c},
- {"OIw2i8o16i4o",
tag::OIw2i8o16i4o},
- {"OIhw2i8o16i4o",
tag::OIhw2i8o16i4o},
- {"OIdhw2i8o16i4o",
tag::OIdhw2i8o16i4o},
- {"OIw2o8i16o4i",
tag::OIw2o8i16o4i},
- {"OIw2o8i16o2i",
tag::OIw2o8i16o2i},
- {"IOw2i8o16i4o",
tag::IOw2i8o16i4o},
- {"IOw2i8o16i2o",
tag::IOw2i8o16i2o},
- {"OIhw2o8i16o4i",
tag::OIhw2o8i16o4i},
- {"OIhw2o8i16o2i",
tag::OIhw2o8i16o2i},
- {"IOhw2i8o16i4o",
tag::IOhw2i8o16i4o},
- {"IOhw2i8o16i2o",
tag::IOhw2i8o16i2o},
- {"OIdhw2o8i16o4i",
tag::OIdhw2o8i16o4i},
- {"OIdhw2o8i16o2i",
tag::OIdhw2o8i16o2i},
- {"IOdhw2i8o16i4o",
tag::IOdhw2i8o16i4o},
- {"IOdhw2i8o16i2o",
tag::IOdhw2i8o16i2o},
- {"gOIw2o8i16o2i",
tag::gOIw2o8i16o2i},
- {"gIOw2i8o16i2o",
tag::gIOw2i8o16i2o},
- {"gIOhw2i8o16i2o",
tag::gIOhw2i8o16i2o},
- {"gIOdhw2i8o16i2o",
tag::gIOdhw2i8o16i2o},
- {"gOIhw2o8i16o2i",
tag::gOIhw2o8i16o2i},
- {"gOIdhw2o8i16o2i",
tag::gOIdhw2o8i16o2i},
- {"gOIw2o8i16o4i",
tag::gOIw2o8i16o4i},
- {"gOIhw2o8i16o4i",
tag::gOIhw2o8i16o4i}};
- std::string key = "";
- for (const auto& c : layout) {
- if (std::isalpha(c, std::locale("C"))) {
- char lower_c = std::tolower(c);
- if (std::isupper(c) && (layout.find(lower_c) != std::string::npos)) {
- key.push_back(c);
- } else {
- key.push_back(lower_c);
- }
- } else if (std::isdigit(c)) {
- key.push_back(c);
- } else {
- LOG(FATAL) << "invalid char '" << c << "' in " << layout << std::endl;
- }
- }
- if (str2tag.count(key) == 0) {
- LOG(WARNING) << "convert unregistered layout '" << key << "' to
tag::any";
- return tag::any;
+ /* Override GetFunction to reimplement Run method */
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) override {
+ if (this->symbol_name_ == name) {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ ICHECK(this->initialized_) << "The module has not been initialized";
+
+ ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size())
+ << "Found mismatch in the number of provided data entries and
required.";
+
+ Run(args);
+ });
} else {
- return str2tag.at(key);
+ return JSONRuntimeBase::GetFunction(name, sptr_to_self);
+ }
+ }
+
+ /* Same as makeInitDataProvider but in case of InputOutput return real
DLTensor */
+ TensorRegistry::DLTensorProvider makeIODataProvider(const TVMArgs& args)
const {
+ auto extract_dl_tensor = [](const TVMArgValue& val) -> const DLTensor* {
+ ICHECK(val.type_code() == kTVMNDArrayHandle || val.type_code() ==
kTVMDLTensorHandle)
+ << "Expect NDArray or DLTensor";
+ return val.IsObjectRef<NDArray>() ? val.operator NDArray().operator->()
+ : val.operator DLTensor*();
+ };
+
+ std::map<uint32_t, const DLTensor*> io_map; // eid to dl tensor map
+ for (size_t i = 0; i < run_arg_eid_.size(); i++) {
+ io_map[run_arg_eid_[i]] = extract_dl_tensor(args[i]);
}
+
+ // lambda with captured IO data handlers
+ return [io_map](uint32_t eid) -> const DLTensor* { return io_map.at(eid);
};
}
- std::map<std::string, dnnl::algorithm> elt_name2algo{
+ private:
+ const std::map<std::string, dnnl::algorithm> elt_name2algo{
{"abs", dnnl::algorithm::eltwise_abs},
{"exp", dnnl::algorithm::eltwise_exp},
{"log", dnnl::algorithm::eltwise_log},
@@ -626,64 +161,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
return std::regex_match(op_name, bias_add_pat) ? true : false;
}
- dnnl::memory::dims TransDims2Plain(dnnl::memory::dims input_dims,
std::string layout) {
- std::vector<char> axis = {
- 'N', 'C', 'O', 'I', 'D', 'H', 'W',
- };
- dnnl::memory::dims out_dims;
- std::string::iterator t = layout.begin();
- // Remove numbers in layout string to match the size of input_dims
- while (t != layout.end()) {
- if (*t >= '0' && *t <= '9') {
- layout.erase(t);
- } else {
- t++;
- }
- }
- // Push the correct shapes of each axis into the output_dims
- for (auto a : axis) {
- if (layout.find(a) != std::string::npos) {
- dnnl::memory::dim shape = input_dims[layout.find(a)];
- char lower_a = std::tolower(a);
- for (size_t i = 0; i < layout.size(); ++i) {
- if (lower_a == layout[i]) {
- shape *= input_dims[i];
- }
- }
- out_dims.push_back(shape);
- }
- }
- // Multiply O and I with G, respectively
- if (layout.find("G") != std::string::npos) {
- dnnl::memory::dim G = 1;
- if (layout.find("g") != std::string::npos) {
- G = input_dims[layout.find("g")] * input_dims[layout.find("G")];
- } else {
- G = input_dims[layout.find("G")];
- }
- out_dims[0] *= G;
- out_dims[1] *= G;
- }
- return out_dims;
- }
-
- dnnl::memory::dims TransformStr2Dims(std::vector<std::string> strs, bool
dilates = false) {
- dnnl::memory::dims out_dims;
- if (dilates) {
- std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims),
- [](const std::string& str) { return std::stoi(str) - 1;
});
- } else {
- std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims),
- [](const std::string& str) { return std::stoi(str); });
- }
- return out_dims;
- }
-
// Build up the engine based on the input graph.
void BuildEngine() {
engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0);
stream_ = dnnl::stream(engine_);
+ std::set<uint32_t> io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end());
+ tensor_registry_ = TensorRegistry(engine_, io_eid_set);
+
std::regex conv_pat(".*conv[1-3]d.*");
std::regex deconv_pat(".*deconv[1-3]d.*");
std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*");
@@ -725,562 +210,471 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
}
}
- // Bind a JSON graph node entry to a DNNL memory.
- dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry,
dnnl::memory::desc mem_desc,
- size_t offset = 0) {
- auto eid = EntryID(entry);
- if (entry_out_mem_.count(eid) == 0) {
- return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset);
- }
- return entry_out_mem_[eid].first;
- }
-
- // Bind a JSON graph node entry to a given DNNL memory.
- dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory
mem,
- size_t offset = 0) {
- auto eid = EntryID(entry);
- // Since the DNNL memory has been created before calling this function, we
assume the entry
- // has not yet been bound to the other DNNL memory; otherwise it may have
memory leak.
- ICHECK_EQ(entry_out_mem_.count(eid), 0);
-
- entry_out_mem_[eid] = {mem, offset};
- return entry_out_mem_[eid].first;
- }
-
void Convolution(const size_t& nid) {
auto node = nodes_[nid];
auto op_name = node.GetOpName();
dnnl::primitive_attr attr;
+ attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
bool has_bias = ParsingOpName(op_name, attr);
// Setup attributes.
- auto data_entry = node.GetInputs()[0];
- auto weight_entry = node.GetInputs()[1];
- JSONGraphNodeEntry out_entry(nid, 0);
- dnnl::memory::dims input_shape =
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
- dnnl::memory::dims weight_shape =
nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
- dnnl::memory::dims out_shape =
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
- dnnl::memory::dim channels =
- node.GetAttr<std::vector<std::string>>("channels")[0] != ""
- ? std::stoi(node.GetAttr<std::vector<std::string>>("channels")[0])
- : out_shape[1];
- std::vector<std::string> str_strides =
node.GetAttr<std::vector<std::string>>("strides");
- std::vector<std::string> str_dilates =
node.GetAttr<std::vector<std::string>>("dilation");
- std::vector<std::string> str_padding =
node.GetAttr<std::vector<std::string>>("padding");
- std::vector<std::string> str_padding_l(str_padding.begin(),
- str_padding.begin() +
str_padding.size() / 2);
- std::vector<std::string> str_padding_r(str_padding.end() -
str_padding.size() / 2,
- str_padding.end());
- dnnl::memory::dim groups =
std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
- std::string data_layout =
node.GetAttr<std::vector<std::string>>("data_layout")[0];
- std::string kernel_layout =
node.GetAttr<std::vector<std::string>>("kernel_layout")[0];
-
- // Memory shapes.
- dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout);
- dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape,
kernel_layout);
- dnnl::memory::dims bias_dims = {channels};
- dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
- dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
- dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l);
- dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
- dnnl::memory::dims dst_dims = src_dims;
- dst_dims[1] = channels;
- weights_dims_[0] = channels;
- weights_dims_[1] = src_dims[1];
- for (size_t i = 2; i < src_dims.size(); i++) {
- dnnl::memory::dim K = weights_dims_[i];
- dnnl::memory::dim S = strides_dims[i - 2];
- dnnl::memory::dim D = dilates_dims[i - 2];
- dnnl::memory::dim PL = padding_dims_l[i - 2];
- dnnl::memory::dim PR = padding_dims_r[i - 2];
- dnnl::memory::dim DK = 1 + (K - 1) * (D + 1);
- dst_dims[i] = (src_dims[i] - DK + PL + PR) / S + 1;
+ auto src_tr = GetInput(nid, 0);
+ auto wgh_tr = GetInput(nid, 1);
+ auto dst_tr = GetOutput(nid, 0);
+ auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1);
+ auto strides = GetNodeAttr<std::vector<int64_t>>(node, "strides");
+ auto dilates = GetNodeAttr<std::vector<int64_t>>(node, "dilation");
+ auto padding = GetNodeAttr<std::vector<int64_t>>(node, "padding");
+ std::vector<int64_t> padding_l(padding.begin(), padding.begin() +
padding.size() / 2);
+ std::vector<int64_t> padding_r(padding.begin() + padding.size() / 2,
padding.end());
+ auto groups = GetNodeAttr<int>(node, "groups");
+ auto src_layout = GetNodeAttr<std::string>(node, "data_layout");
+ auto dst_layout = GetNodeAttr<std::string>(node, "out_layout");
+ auto wgh_layout = GetNodeAttr<std::string>(node, "kernel_layout");
+
+ // dst_layout == "" means to use data_layout
+ if (dst_layout.empty()) dst_layout = src_layout;
+
+ // Minus one for DNNL representation. No dilation for DNNL is 0, for relay
is 1.
+ for (auto& d : dilates) d--;
+
+ // Take into account provided layout strings
+ src_tr = src_tr.TreatAs(src_layout);
+ dst_tr = dst_tr.TreatAs(dst_layout);
+ wgh_tr = wgh_tr.TreatAs(wgh_layout);
+
+ // Should support G mixed with O. Like { G*O, I, H, W }
+ // Use { G, O, I, H, W } weight format even if groups == 1
+ if (wgh_layout.find("G") == std::string::npos) {
+ auto w_dims = wgh_tr.dims();
+ w_dims[0] /= groups;
+ w_dims.insert(w_dims.begin(), groups);
+ wgh_tr = wgh_tr.Reshape(w_dims);
}
- dnnl::memory::dims weights_dims = weights_dims_;
- if (groups > 1) {
- weights_dims = {groups, channels / groups, src_dims[1] / groups};
- weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2,
weights_dims_.end());
- if (kernel_layout == "OIHW") {
- kernel_layout.insert(0, "G");
- }
+ // Assumption that bias is correct and can be squeezed to 1D
+ bias_tr = bias_tr.Reshape({dst_tr.dims()[1]});
+
+ // TODO(@apeskov): This is WA. In case of padded blocked tensor format we
do not know original
+ // shapes. Example tensor {1, 10, 224, 224} with layout "NCNH8c" will
lead to tensor
+ // {1, 2, 224, 224, 8}. Identically as for shapes {1, 11, 224, 224} or
{1, 15, 224, 224}.
+ //
+ // Let's try to compensate it for weight tensor. Weight IC should match
with source IC.
+ // Example src: [1, 3, 224, 224] with layout NCHW
+ // wgh: [16, 3, 3, 3] with layout OIHW2i8o -> [2, 2, 3, 3, 2, 8]
+ if (wgh_tr.dims()[2] != src_tr.dims()[1] / groups) {
+ auto wgh_croped_dims = wgh_tr.dims();
+ wgh_croped_dims[2] = src_tr.dims()[1];
+ auto zero_offset = dnnl::memory::dims(wgh_tr.dims().size(), 0);
+ wgh_tr = wgh_tr.Crop(wgh_croped_dims, zero_offset);
}
- // Memory descriptions.
- auto dtype =
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
- auto conv_src_md = dnnl::memory::desc(src_dims, dtype,
layout2tag(data_layout));
- auto conv_weights_md = dnnl::memory::desc(weights_dims, dtype,
layout2tag(kernel_layout));
- auto conv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::any);
- auto conv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
-
// Conv description.
- auto conv_desc =
- has_bias ? dnnl::convolution_forward::desc(
- dnnl::prop_kind::forward_inference,
dnnl::algorithm::convolution_direct,
- conv_src_md, conv_weights_md, conv_bias_md,
conv_dst_md, strides_dims,
- dilates_dims, padding_dims_l, padding_dims_r)
- :
dnnl::convolution_forward::desc(dnnl::prop_kind::forward_inference,
-
dnnl::algorithm::convolution_direct, conv_src_md,
- conv_weights_md,
conv_dst_md, strides_dims,
- dilates_dims,
padding_dims_l, padding_dims_r);
+ auto conv_desc = dnnl::convolution_forward::desc(
+ dnnl::prop_kind::forward_inference,
dnnl::algorithm::convolution_direct,
+ src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(),
bias_tr.LayoutAny().desc(),
+ dst_tr.LayoutAny().desc(), strides, dilates, padding_l, padding_r);
// Enable elementwise post-ops.
auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc,
attr, engine_);
- // Push to the network.
- auto conv = dnnl::convolution_forward(conv_prim_desc);
- net_.push_back(conv);
-
- // Data memory.
- auto conv_src_memory = BindDNNLMemory(data_entry, conv_src_md);
+ src_tr = src_tr.RequestLayout(conv_prim_desc.src_desc());
+ wgh_tr = wgh_tr.RequestLayout(conv_prim_desc.weights_desc());
+ dst_tr = dst_tr.RequestLayout(conv_prim_desc.dst_desc());
+ bias_tr = bias_tr.RequestLayout(conv_prim_desc.bias_desc());
- // Weight memory.
- auto conv_weights_memory = BindDNNLMemory(weight_entry,
conv_prim_desc.weights_desc());
+ auto scratchpad_tr =
TensorRequisite::AsIs(conv_prim_desc.scratchpad_desc());
- // Output memory.
- auto conv_dst_memory = BindDNNLMemory(out_entry,
conv_prim_desc.dst_desc());
-
- // Bias memory.
- auto conv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_);
- if (has_bias) {
- auto bias_entry = node.GetInputs()[2];
- BindDNNLMemory(bias_entry, conv_bias_memory);
-
- // Bind memory buffers.
- net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory},
- {DNNL_ARG_WEIGHTS, conv_weights_memory},
- {DNNL_ARG_BIAS, conv_bias_memory},
- {DNNL_ARG_DST, conv_dst_memory}});
- } else {
- // Bind memory buffers.
- net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory},
- {DNNL_ARG_WEIGHTS, conv_weights_memory},
- {DNNL_ARG_DST, conv_dst_memory}});
- }
+ Submit(dnnl::convolution_forward(conv_prim_desc), {{DNNL_ARG_SRC, src_tr},
+ {DNNL_ARG_WEIGHTS,
wgh_tr},
+ {DNNL_ARG_BIAS,
bias_tr},
+ {DNNL_ARG_SCRATCHPAD,
scratchpad_tr},
+ {DNNL_ARG_DST,
dst_tr}});
}
void Deconvolution(const size_t& nid) {
auto node = nodes_[nid];
auto op_name = node.GetOpName();
dnnl::primitive_attr attr;
+ attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
bool has_bias = ParsingOpName(op_name, attr);
// Setup attributes.
- auto data_entry = node.GetInputs()[0];
- auto weight_entry = node.GetInputs()[1];
- JSONGraphNodeEntry out_entry(nid, 0);
- dnnl::memory::dims input_shape =
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
- dnnl::memory::dims weight_shape =
nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
- dnnl::memory::dims out_shape =
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
- dnnl::memory::dim channels =
- node.GetAttr<std::vector<std::string>>("channels")[0] != ""
- ? std::stoi(node.GetAttr<std::vector<std::string>>("channels")[0])
- : out_shape[1];
- std::vector<std::string> str_strides =
node.GetAttr<std::vector<std::string>>("strides");
- std::vector<std::string> str_dilates =
node.GetAttr<std::vector<std::string>>("dilation");
- std::vector<std::string> str_padding =
node.GetAttr<std::vector<std::string>>("padding");
- std::vector<std::string> str_padding_l(str_padding.begin(),
- str_padding.begin() +
str_padding.size() / 2);
- std::vector<std::string> str_padding_r(str_padding.end() -
str_padding.size() / 2,
- str_padding.end());
- std::vector<std::string> str_out_padding =
- node.GetAttr<std::vector<std::string>>("output_padding");
- dnnl::memory::dim groups =
std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
- std::string data_layout =
node.GetAttr<std::vector<std::string>>("data_layout")[0];
- std::string kernel_layout =
node.GetAttr<std::vector<std::string>>("kernel_layout")[0];
-
- // Memory shapes.
- dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout);
- dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape,
kernel_layout);
- // legalize shape IOHW with layout OIHW
- if (weights_dims_[0] == src_dims[1] && weights_dims_[1] == channels) {
- std::swap(weights_dims_[0], weights_dims_[1]);
- if (kernel_layout.find("OI") == 0) {
- kernel_layout.replace(kernel_layout.find("OI"), 2, "IO");
- }
- }
- weights_dims_[0] = channels;
- weights_dims_[1] = src_dims[1];
- dnnl::memory::dims bias_dims = {channels};
- dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
- dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
- dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l);
- dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
- dnnl::memory::dims out_padding = TransformStr2Dims(str_out_padding);
- dnnl::memory::dims dst_dims = src_dims;
- dst_dims[1] = channels;
- for (size_t i = 2; i < src_dims.size(); i++) {
- dnnl::memory::dim K = weights_dims_[i];
- dnnl::memory::dim S = strides_dims[i - 2];
- dnnl::memory::dim D = dilates_dims[i - 2];
- dnnl::memory::dim PL = padding_dims_l[i - 2];
- dnnl::memory::dim PR = padding_dims_r[i - 2];
- dnnl::memory::dim OP = out_padding[i - 2];
- dnnl::memory::dim DK = 1 + (K - 1) * (D + 1);
- dst_dims[i] = S * (src_dims[i] - 1) + DK - PL - PR + OP;
+ auto src_tr = GetInput(nid, 0);
+ auto wgh_tr = GetInput(nid, 1);
+ auto dst_tr = GetOutput(nid, 0);
+ auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1);
+
+ auto strides = GetNodeAttr<std::vector<int64_t>>(node, "strides");
+ auto dilates = GetNodeAttr<std::vector<int64_t>>(node, "dilation");
+ auto padding = GetNodeAttr<std::vector<int64_t>>(node, "padding");
+ std::vector<int64_t> padding_l(padding.begin(), padding.begin() +
padding.size() / 2);
+ std::vector<int64_t> padding_r(padding.begin() + padding.size() / 2,
padding.end());
+ auto groups = GetNodeAttr<int>(node, "groups");
+ auto src_layout = GetNodeAttr<std::string>(node, "data_layout");
+ auto dst_layout = GetNodeAttr<std::string>(node, "out_layout");
+ auto wgh_layout = GetNodeAttr<std::string>(node, "kernel_layout");
+
+ // dst_layout == "" means to use data_layout
+ if (dst_layout.empty()) dst_layout = src_layout;
+
+ // Minus one for DNNL representation. No dilation for DNNL is 0, for relay
is 1.
+ for (auto& d : dilates) d--;
+
+ // TODO(@apeskov): WA. conv3dTranspose uses wrong layout specifier. IO
instead of OI.
+ auto wgh_logic_layout = TensorRequisite::DefaultLogicLayoutFor(wgh_layout);
+ if (wgh_logic_layout == "OIDHW") wgh_logic_layout = "IODHW";
+ if (wgh_logic_layout == "GOIDHW") wgh_logic_layout = "GIODHW";
+
+ // Take into account provided layout strings
+ src_tr = src_tr.TreatAs(src_layout);
+ dst_tr = dst_tr.TreatAs(dst_layout);
+ wgh_tr = wgh_tr.TreatAs(wgh_layout, wgh_logic_layout);
+
+ // Should support G mixed with O. Like { G*O, I, H, W }
+ if (wgh_layout.find("G") == std::string::npos) {
+ auto w_dims = wgh_tr.dims();
+ w_dims[0] /= groups;
+ w_dims.insert(w_dims.begin(), groups);
+ wgh_tr = wgh_tr.Reshape(w_dims);
}
- dnnl::memory::dims weights_dims = weights_dims_;
- if (groups > 1) {
- weights_dims = {groups, channels / groups, src_dims[1] / groups};
- weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2,
weights_dims_.end());
- }
+ // Assumption that bias is correct and can be squeezed to 1D
+ bias_tr = bias_tr.Reshape({dst_tr.dims()[1]});
- // Memory descriptions.
- auto dtype =
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
- auto deconv_src_md = dnnl::memory::desc(src_dims, dtype,
layout2tag(data_layout));
- auto deconv_weights_md = dnnl::memory::desc(weights_dims, dtype,
layout2tag(kernel_layout));
- auto deconv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::x);
- auto deconv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
-
- // Transposed covn2d description.
- auto deconv_desc =
- has_bias ? dnnl::deconvolution_forward::desc(
- dnnl::prop_kind::forward_inference,
dnnl::algorithm::deconvolution_direct,
- deconv_src_md, deconv_weights_md, deconv_bias_md,
deconv_dst_md,
- strides_dims, dilates_dims, padding_dims_l,
padding_dims_r)
- : dnnl::deconvolution_forward::desc(
- dnnl::prop_kind::forward_inference,
dnnl::algorithm::deconvolution_direct,
- deconv_src_md, deconv_weights_md, deconv_dst_md,
strides_dims, dilates_dims,
- padding_dims_l, padding_dims_r);
+ // Deconv description.
+ auto deconv_desc = dnnl::deconvolution_forward::desc(
+ dnnl::prop_kind::forward_inference,
dnnl::algorithm::deconvolution_direct,
+ src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(),
bias_tr.LayoutAny().desc(),
+ dst_tr.LayoutAny().desc(), strides, dilates, padding_l, padding_r);
// Enable elementwise post-ops.
auto deconv_prim_desc =
dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);
- // Push to the network.
- auto deconv = dnnl::deconvolution_forward(deconv_prim_desc);
- net_.push_back(deconv);
-
- // Data memory.
- auto deconv_src_memory = BindDNNLMemory(data_entry, deconv_src_md);
-
- // Weight memory.
- auto deconv_weights_memory = BindDNNLMemory(weight_entry,
deconv_prim_desc.weights_desc());
-
- // Output memory.
- auto deconv_dst_memory = BindDNNLMemory(out_entry,
deconv_prim_desc.dst_desc());
+ src_tr = src_tr.RequestLayout(deconv_prim_desc.src_desc());
+ wgh_tr = wgh_tr.RequestLayout(deconv_prim_desc.weights_desc());
+ dst_tr = dst_tr.RequestLayout(deconv_prim_desc.dst_desc());
+ bias_tr = bias_tr.RequestLayout(deconv_prim_desc.bias_desc());
- // Bias memory.
- auto deconv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x},
engine_);
- if (has_bias) {
- auto bias_entry = node.GetInputs()[2];
- BindDNNLMemory(bias_entry, deconv_bias_memory);
+ auto scratchpad_tr =
TensorRequisite::AsIs(deconv_prim_desc.scratchpad_desc());
- // Bind memory buffers.
- net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory},
- {DNNL_ARG_WEIGHTS, deconv_weights_memory},
- {DNNL_ARG_BIAS, deconv_bias_memory},
- {DNNL_ARG_DST, deconv_dst_memory}});
- } else {
- // Bind memory buffers.
- net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory},
- {DNNL_ARG_WEIGHTS, deconv_weights_memory},
- {DNNL_ARG_DST, deconv_dst_memory}});
- }
+ Submit(dnnl::deconvolution_forward(deconv_prim_desc), {{DNNL_ARG_SRC,
src_tr},
+ {DNNL_ARG_WEIGHTS,
wgh_tr},
+ {DNNL_ARG_BIAS,
bias_tr},
+
{DNNL_ARG_SCRATCHPAD, scratchpad_tr},
+ {DNNL_ARG_DST,
dst_tr}});
}
void Dense(const size_t& nid) {
auto node = nodes_[nid];
auto op_name = node.GetOpName();
dnnl::primitive_attr attr;
+ attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
bool has_bias = ParsingOpName(op_name, attr);
// Setup attributes.
- auto data_entry = node.GetInputs()[0];
- auto weight_entry = node.GetInputs()[1];
- JSONGraphNodeEntry out_entry(nid, 0);
- dnnl::memory::dims input_shape =
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
- dnnl::memory::dims weight_shape =
nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
- dnnl::memory::dims out_shape =
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
- dnnl::memory::dim OC = out_shape[1];
-
- // Memory shapes.
- dnnl::memory::dims data_dims = input_shape;
- dnnl::memory::dims weight_dims = weight_shape;
- dnnl::memory::dims bias_dims = {OC};
- dnnl::memory::dims out_dims = out_shape;
-
- // Memory descriptions.
- auto dl_dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_];
- auto dtype = dtype_dl2dnnl(dl_dtype);
- auto data_md = dnnl::memory::desc({data_dims, dtype, tag::nc});
- auto weight_md = dnnl::memory::desc({weight_dims, dtype, tag::nc});
- auto bias_md = dnnl::memory::desc({bias_dims, dtype, tag::x});
- auto dst_md = dnnl::memory::desc({out_dims, dtype, tag::nc});
+ auto src_tr = GetInput(nid, 0);
+ auto wgh_tr = GetInput(nid, 1);
+ auto dst_tr = GetOutput(nid, 0);
+ auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1);
+
+ // Assumption that bias is correct and can be squeezed to 1D
+ bias_tr = bias_tr.Reshape({dst_tr.dims()[1]});
// Dense description.
- auto dense_desc =
dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md,
- weight_md, bias_md,
dst_md);
+ auto dense_desc = dnnl::inner_product_forward::desc(
+ dnnl::prop_kind::forward_inference, src_tr.LayoutAny().desc(),
wgh_tr.LayoutAny().desc(),
+ bias_tr.LayoutAny().desc(), dst_tr.LayoutAny().desc());
// Enable elementwise post-ops.
auto dense_prim_desc =
dnnl::inner_product_forward::primitive_desc(dense_desc, attr, engine_);
- auto dense = dnnl::inner_product_forward(dense_prim_desc);
- net_.push_back(dense);
+ src_tr = src_tr.RequestLayout(dense_prim_desc.src_desc());
+ wgh_tr = wgh_tr.RequestLayout(dense_prim_desc.weights_desc());
+ dst_tr = dst_tr.RequestLayout(dense_prim_desc.dst_desc());
+ bias_tr = bias_tr.RequestLayout(dense_prim_desc.bias_desc());
- // Memories.
- auto data_memory = BindDNNLMemory(data_entry, data_md);
- auto weight_memory = BindDNNLMemory(weight_entry, weight_md);
+ auto scratchpad_tr =
TensorRequisite::AsIs(dense_prim_desc.scratchpad_desc());
- // Bias memory.
- auto bias_memory = dnnl::memory(bias_md, engine_);
- if (has_bias) {
- auto bias_entry = node.GetInputs()[2];
- BindDNNLMemory(bias_entry, bias_memory);
- } else {
- float bias[OC] = {0};
- write_to_dnnl_memory(bias, bias_memory, OC * ((dl_dtype.bits + 7) / 8));
- }
-
- // Output memory.
- auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());
-
- net_args_.push_back({{DNNL_ARG_SRC, data_memory},
- {DNNL_ARG_WEIGHTS, weight_memory},
- {DNNL_ARG_BIAS, bias_memory},
- {DNNL_ARG_DST, dst_memory}});
+ Submit(dnnl::inner_product_forward(dense_prim_desc), {{DNNL_ARG_SRC,
src_tr},
+ {DNNL_ARG_WEIGHTS,
wgh_tr},
+ {DNNL_ARG_BIAS,
bias_tr},
+
{DNNL_ARG_SCRATCHPAD, scratchpad_tr},
+ {DNNL_ARG_DST,
dst_tr}});
}
void BatchNorm(const size_t& nid) {
auto node = nodes_[nid];
- auto data_entry = node.GetInputs()[0];
- auto gamma_entry = node.GetInputs()[1];
- auto beta_entry = node.GetInputs()[2];
- auto mean_entry = node.GetInputs()[3];
- auto variance_entry = node.GetInputs()[4];
- dnnl::memory::dims data_shape =
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
- dnnl::memory::dim IC = data_shape[1];
- float epsilon =
std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
+ auto src_tr = GetInput(nid, 0);
+ auto gamma_tr = GetInput(nid, 1);
+ auto beta_tr = GetInput(nid, 2);
+ auto mean_tr = GetInput(nid, 3);
+ auto var_tr = GetInput(nid, 4);
+ auto dst_tr = GetOutput(nid, 0);
- // Memory description.
- auto dtype =
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
- dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype);
+ auto axis = GetNodeAttr<int>(node, "axis");
+ auto epsilon = GetNodeAttr<float>(node, "epsilon");
+ auto center = GetNodeAttr<bool>(node, "center");
+ auto scale = GetNodeAttr<bool>(node, "scale");
+
+ ICHECK(axis == 1 && center && scale) << "Unimplemented BatchNorm case";
- // BN description.
auto bn_desc = dnnl::batch_normalization_forward::desc(
- dnnl::prop_kind::forward_inference, data_md, epsilon,
+ dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon,
dnnl::normalization_flags::use_global_stats |
dnnl::normalization_flags::use_scale_shift);
auto bn_prim_desc =
dnnl::batch_normalization_forward::primitive_desc(bn_desc, engine_);
- auto bn = dnnl::batch_normalization_forward(bn_prim_desc);
- net_.push_back(bn);
-
- // Memories.
- auto data_memory = BindDNNLMemory(data_entry, data_md);
- JSONGraphNodeEntry out_entry(nid, 0);
- auto out_memory = BindDNNLMemory(out_entry, data_md);
- auto mean_memory = BindDNNLMemory(mean_entry, bn_prim_desc.mean_desc());
- auto variance_memory = BindDNNLMemory(variance_entry,
bn_prim_desc.variance_desc());
-
- // In DNNL, weight is composed of gamma+beta, so we point them to the same
DNNL memory but
- // assign an offset to beta data for runtime serialization.
- auto weight_memory = BindDNNLMemory(gamma_entry,
bn_prim_desc.weights_desc(), 0);
- BindDNNLMemory(beta_entry, weight_memory, IC);
-
- net_args_.push_back({{DNNL_ARG_SRC, data_memory},
- {DNNL_ARG_DST, out_memory},
- {DNNL_ARG_SCALE_SHIFT, weight_memory},
- {DNNL_ARG_MEAN, mean_memory},
- {DNNL_ARG_VARIANCE, variance_memory}});
+
+ // Concatenate scale and shift tensors
+ auto scale_shift_tr = TensorRequisite::AsIs(bn_prim_desc.weights_desc(),
GenUniqueEid());
+ auto sc_sh_dims = scale_shift_tr.dims();
+ ICHECK(sc_sh_dims.size() == 2);
+ ICHECK(sc_sh_dims[0] == 2);
+ sc_sh_dims[0] /= 2;
+ auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze();
+ auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze();
+
+ auto register_copy = [this](const TensorRequisite& src, const
TensorRequisite& dst) {
+ dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_,
dst.desc());
+ Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST,
dst}});
+ };
+
+ register_copy(gamma_tr, scale_tr);
+ register_copy(beta_tr, shift_tr);
+
+ Submit(dnnl::batch_normalization_forward(bn_prim_desc), {{DNNL_ARG_SRC,
src_tr},
+ {DNNL_ARG_DST,
dst_tr},
+
{DNNL_ARG_SCALE_SHIFT, scale_shift_tr},
+ {DNNL_ARG_MEAN,
mean_tr},
+
{DNNL_ARG_VARIANCE, var_tr}});
}
void Pooling(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];
+ auto src_tr = GetInput(nid, 0);
+ auto dst_tr = GetOutput(nid, 0);
+
// Setup attributes.
- auto data_entry = node.GetInputs()[0];
- JSONGraphNodeEntry out_entry(nid, 0);
- dnnl::memory::dims input_shape =
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
- dnnl::memory::dims out_shape =
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
- std::vector<std::string> str_kernel =
node.GetAttr<std::vector<std::string>>("pool_size");
- std::vector<std::string> str_strides =
node.GetAttr<std::vector<std::string>>("strides");
- std::vector<std::string> str_padding =
node.GetAttr<std::vector<std::string>>("padding");
- std::vector<std::string> str_padding_l(str_padding.begin(),
- str_padding.begin() +
str_padding.size() / 2);
- std::vector<std::string> str_padding_r(str_padding.end() -
str_padding.size() / 2,
- str_padding.end());
- std::vector<std::string> str_dilates =
node.GetAttr<std::vector<std::string>>("dilation");
- std::string layout = node.GetAttr<std::vector<std::string>>("layout")[0];
+ auto strides = GetNodeAttr<std::vector<int64_t>>(node, "strides");
+ auto dilates = GetNodeAttr<std::vector<int64_t>>(node, "dilation");
+ auto padding = GetNodeAttr<std::vector<int64_t>>(node, "padding");
+ std::vector<int64_t> padding_l(padding.begin(), padding.begin() +
padding.size() / 2);
+ std::vector<int64_t> padding_r(padding.begin() + padding.size() / 2,
padding.end());
+ auto kernel = GetNodeAttr<std::vector<int64_t>>(node, "pool_size");
+ auto src_layout = GetNodeAttr<std::string>(node, "layout");
+ auto dst_layout = GetNodeAttr<std::string>(node, "out_layout");
+
+ // dst_layout == "" means to use data_layout
+ if (dst_layout.empty()) dst_layout = src_layout;
+
+ // Minus one for DNNL representation. No dilation for DNNL is 0, for relay
is 1.
+ for (auto& d : dilates) d--;
+
+ // Take into account provided layout strings
+ src_tr = src_tr.TreatAs(src_layout);
+ dst_tr = dst_tr.TreatAs(dst_layout);
// Attributes related to AvgPool
if (algo == dnnl::algorithm::pooling_avg) {
- int int_countpad =
std::stoi(node.GetAttr<std::vector<std::string>>("count_include_pad")[0]);
- bool count_include_pad = int_countpad != 0 ? true : false;
- algo = count_include_pad ? dnnl::algorithm::pooling_avg_include_padding
- : dnnl::algorithm::pooling_avg_exclude_padding;
+ auto include_pad = GetNodeAttr<bool>(node, "count_include_pad");
+ algo = include_pad ? dnnl::algorithm::pooling_avg_include_padding
+ : dnnl::algorithm::pooling_avg_exclude_padding;
}
- dnnl::memory::dims src_dims = TransDims2Plain(input_shape, layout);
- dnnl::memory::dims dst_dims = TransDims2Plain(out_shape, layout);
- dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel);
- dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
- dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
- dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l);
- dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
-
- // Memory descriptions.
- auto dtype =
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
- auto pool_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(layout));
- auto pool_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
-
// Pooling description.
- auto pool_desc =
dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, algo,
- pool_src_md, pool_dst_md,
strides_dims,
- kernel_dims, padding_dims_l,
padding_dims_r);
-
- auto pool_prim_desc = dnnl::pooling_forward::primitive_desc(pool_desc,
engine_, true);
- auto pool = dnnl::pooling_forward(pool_prim_desc);
- net_.push_back(pool);
+ auto pool_desc = dnnl::pooling_v2_forward::desc(
+ dnnl::prop_kind::forward_inference, algo, src_tr.desc(), //<= Do not
use any for src tensor
+ dst_tr.LayoutAny().desc(), strides, kernel, dilates, padding_l,
padding_r);
+ auto pool_prim_desc = dnnl::pooling_v2_forward::primitive_desc(pool_desc,
engine_);
- // Memories.
- auto pool2d_src_memory = BindDNNLMemory(data_entry, pool_src_md);
+ src_tr = src_tr.RequestLayout(pool_prim_desc.src_desc());
+ dst_tr = dst_tr.RequestLayout(pool_prim_desc.dst_desc());
- auto pool2d_dst_memory = BindDNNLMemory(out_entry,
pool_prim_desc.dst_desc());
+ auto scratchpad_tr =
TensorRequisite::AsIs(pool_prim_desc.scratchpad_desc());
- // Bind memory buffers.
- net_args_.push_back({{DNNL_ARG_SRC, pool2d_src_memory}, {DNNL_ARG_DST,
pool2d_dst_memory}});
+ Submit(dnnl::pooling_v2_forward(pool_prim_desc),
+ {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr},
{DNNL_ARG_SCRATCHPAD, scratchpad_tr}});
}
void Eltwise(const size_t& nid) {
auto node = nodes_[nid];
auto op_name = node.GetOpName();
- auto algo = elt_name2algo[op_name];
+ auto algo = elt_name2algo.at(op_name);
+
+ auto src_tr = GetInput(nid, 0);
+ auto dst_tr = GetOutput(nid, 0);
- auto data_entry = node.GetInputs()[0];
- dnnl::memory::dims shape =
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
- auto dtype =
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
- dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype);
float alpha = 0., beta = 0.;
if (op_name == "clip") {
- alpha = std::stof(node.GetAttr<std::vector<std::string>>("a_min")[0]);
- beta = std::stof(node.GetAttr<std::vector<std::string>>("a_max")[0]);
+ alpha = GetNodeAttr<float>(node, "a_min");
+ beta = GetNodeAttr<float>(node, "a_max");
} else if (op_name == "nn.leaky_relu") {
- alpha = std::stof(node.GetAttr<std::vector<std::string>>("alpha")[0]);
+ alpha = GetNodeAttr<float>(node, "alpha");
}
- auto elt_desc =
- dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo,
data_md, alpha, beta);
+ auto elt_desc =
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo,
+ src_tr.desc(), alpha, beta);
auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc,
engine_);
- ICHECK(data_md == elt_prim_desc.dst_desc());
-
- auto elt = dnnl::eltwise_forward(elt_prim_desc);
- net_.push_back(elt);
+ ICHECK(src_tr.desc() == elt_prim_desc.dst_desc());
- auto data_memory = BindDNNLMemory(data_entry, data_md);
- JSONGraphNodeEntry out_entry(nid, 0);
- auto out_memory = BindDNNLMemory(out_entry, data_md);
-
- net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST,
out_memory}});
+ Submit(dnnl::eltwise_forward(elt_prim_desc), {{DNNL_ARG_SRC, src_tr},
{DNNL_ARG_DST, dst_tr}});
}
void Softmax(const size_t& nid) {
auto node = nodes_[nid];
- auto data_entry = node.GetInputs()[0];
- dnnl::memory::dims shape =
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
- int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
+ auto src_tr = GetInput(nid, 0);
+ auto dst_tr = GetOutput(nid, 0);
+
+ auto axis = GetNodeAttr<int>(node, "axis");
if (axis < 0) {
- axis = shape.size() + axis;
+ axis = src_tr.dims().size() + axis;
}
- auto dtype =
dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
- dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype);
auto softmax_desc =
- dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference,
data_md, axis);
+ dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference,
src_tr.desc(), axis);
auto softmax_prim_desc =
dnnl::softmax_forward::primitive_desc(softmax_desc, engine_);
- ICHECK(data_md == softmax_prim_desc.dst_desc());
-
- auto softmax = dnnl::softmax_forward(softmax_prim_desc);
- net_.push_back(softmax);
+ ICHECK(dst_tr.desc() == softmax_prim_desc.dst_desc());
- auto data_memory = BindDNNLMemory(data_entry, data_md);
- JSONGraphNodeEntry out_entry(nid, 0);
- auto out_memory = BindDNNLMemory(out_entry, data_md);
-
- net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST,
out_memory}});
+ Submit(dnnl::softmax_forward(softmax_prim_desc),
+ {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}});
}
void Binary(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];
+ ICHECK_EQ(node.GetInputs().size(), 2U);
// Memory and compute description.
- std::vector<dnnl::memory::dims> data_dims;
- std::vector<dnnl::memory::desc> data_mds;
- std::vector<dnnl::memory> data_memories;
+ auto lhs_tr = GetInput(nid, 0);
+ auto rhs_tr = GetInput(nid, 1);
+ auto dst_tr = GetOutput(nid, 0);
- ICHECK_EQ(node.GetInputs().size(), 2U);
- for (auto entry : node.GetInputs()) {
- auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_];
- auto dtype =
dtype_dl2dnnl(nodes_[entry.id_].GetOpDataType()[entry.index_]);
- dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype);
-
- data_dims.push_back(data_shape);
- data_mds.push_back(data_md);
- data_memories.push_back(BindDNNLMemory(entry, data_md));
- }
- ICHECK(data_dims[0] == data_dims[1]);
- auto out_md = data_mds[0];
- JSONGraphNodeEntry out_entry(nid, 0);
- auto out_memory = BindDNNLMemory(out_entry, out_md);
+ lhs_tr = lhs_tr.Broadcast(dst_tr.dims());
+ rhs_tr = rhs_tr.Broadcast(dst_tr.dims());
- auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1],
out_md);
+ auto binary_desc = dnnl::binary::desc(algo, lhs_tr.desc(), rhs_tr.desc(),
dst_tr.desc());
auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_);
- auto binary = dnnl::binary(binary_prim_desc);
- net_.push_back(binary);
- net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]},
- {DNNL_ARG_SRC_1, data_memories[1]},
- {DNNL_ARG_DST, out_memory}});
+ Submit(dnnl::binary(binary_prim_desc),
+ {{DNNL_ARG_SRC_0, lhs_tr}, {DNNL_ARG_SRC_1, rhs_tr}, {DNNL_ARG_DST,
dst_tr}});
+ }
+
+ template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0>
+ T AttrConvert(std::vector<std::string> val) {
+ ICHECK_EQ(val.size(), 1);
+ return std::stol(val[0]);
+ }
+
+ template <typename T, std::enable_if_t<std::is_floating_point<T>::value,
int> = 0>
+ T AttrConvert(std::vector<std::string> val) {
+ ICHECK_EQ(val.size(), 1);
+ return std::stof(val[0]);
+ }
+
+ template <typename T, std::enable_if_t<std::is_same<T, std::string>::value,
int> = 0>
+ T AttrConvert(std::vector<std::string> val) {
+ ICHECK_EQ(val.size(), 1);
+ return val[0];
+ }
+
+ template <typename T,
+ std::enable_if_t<std::is_same<T, std::vector<typename
T::value_type>>::value, int> = 0>
+ T AttrConvert(std::vector<std::string> val) {
+ T res;
+ for (const auto& el : val) res.push_back(AttrConvert<typename
T::value_type>({el}));
+ return res;
+ }
+
+ /*!
+ * \brief Helper to extract node attribute with ability to specify default
value and result type.
+ */
+ template <typename T>
+ const T GetNodeAttr(const json::JSONGraphNode& node, std::string name,
+ std::vector<std::string> def = {}) {
+ auto attr = node.HasAttr(name) ?
node.GetAttr<std::vector<std::string>>(name) : def;
+ return AttrConvert<T>(attr);
}
- // Read from DNNL memory (+offset) and write to the handle.
- inline void read_from_dnnl_memory(void* handle, const dnnl::memory& mem,
size_t size,
- size_t offset = 0) {
- uint8_t* src = static_cast<uint8_t*>(mem.get_data_handle());
- std::copy(src + offset, src + offset + size,
static_cast<uint8_t*>(handle));
+ TensorRequisite GetInput(const size_t& nid, const int idx) {
+ if (idx == -1) return {}; // -1 reserved value for empty input.
+
+ const JSONGraphNode& node = nodes_[nid];
+
+ ICHECK_LT(idx, node.GetInputs().size());
+ auto data_entry = node.GetInputs()[idx];
+
+ auto shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+ auto dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_];
+ auto eid = node_row_ptr_[data_entry.id_] + data_entry.index_;
+ auto const_dl_tensor = data_entry_[eid];
+
+ auto desc = MakePlainDesc(shape, dtype);
+
+ TensorRequisite res;
+ if (const_dl_tensor) {
+ ICHECK(const_dl_tensor->data);
+ ICHECK(const_dl_tensor->strides == nullptr);
+ auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data);
+ res = TensorRequisite::AsIs(mem, eid);
+ } else {
+ res = TensorRequisite::AsIs(desc, eid);
+ }
+ return res;
}
- // Read from the handle and write to DNNL memory (+offset).
- inline void write_to_dnnl_memory(void* handle, const dnnl::memory& mem,
size_t size,
- size_t offset = 0) {
- uint8_t* dst = static_cast<uint8_t*>(mem.get_data_handle());
- std::copy(reinterpret_cast<uint8_t*>(handle),
reinterpret_cast<uint8_t*>(handle) + size,
- dst + offset);
+ TensorRequisite GetOutput(const size_t& nid, const int idx) {
+ if (idx == -1) return {}; // -1 reserved value for empty input.
+
+ const JSONGraphNode& node = nodes_[nid];
+
+ ICHECK_LT(idx, node.GetNumOutput());
+ auto shape = node.GetOpShape()[idx];
+ auto dtype = node.GetOpDataType()[idx];
+ auto eid = node_row_ptr_[nid] + static_cast<uint32_t>(idx);
+
+ ICHECK(data_entry_[eid] == nullptr);
+ auto desc = MakePlainDesc(shape, dtype);
+
+ return TensorRequisite::AsIs(desc, eid).Backward();
}
- // Generate DNNL memory description and infer the data layout by the given
shape.
- inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims&
shape, dt dtype) {
- dnnl::memory::desc data_md;
- switch (shape.size()) {
- case 2:
- data_md = dnnl::memory::desc({shape, dtype, tag::ab});
- break;
- case 3:
- data_md = dnnl::memory::desc({shape, dtype, tag::abc});
- break;
- case 4:
- data_md = dnnl::memory::desc({shape, dtype, tag::abcd});
- break;
- case 5:
- data_md = dnnl::memory::desc({shape, dtype, tag::abcde});
- break;
- default:
- LOG(FATAL) << "Unsupported data shape dimension: " << shape.size();
- break;
+ /*! \brief Helper function to register primitive into execution queue */
+ void Submit(const dnnl::primitive& prim,
+ const std::unordered_map<int, TensorRequisite>& tr_args) {
+ // Register all provided TR arguments
+ std::unordered_map<int, TensorRegistry::ArgId> prim_arg_id;
+ TensorRegistry::ActionQue post_prim_actions;
+ for (const auto& kvp : tr_args) {
+ const auto& key = kvp.first;
+ const auto& tr = kvp.second;
+
+ if (!tr.defined()) continue; // empty arg is admitted. Just skip it
+ auto arg_id = tensor_registry_.Register(tr, tr.IsReversed() ?
&post_prim_actions : &net_);
+ prim_arg_id[key] = arg_id;
}
- return data_md;
+
+ // Register main primitive
+ net_.push_back({prim, prim_arg_id});
+
+ // Register post actions
+ net_.insert(net_.end(), post_prim_actions.begin(),
post_prim_actions.end());
}
+ uint32_t GenUniqueEid() { return next_unique_eid_offset_++; }
+
/* The dnnl engine. */
dnnl::engine engine_;
/* The dnnl stream. */
dnnl::stream stream_;
/* The network layers that are represented in dnnl primitives. */
- std::vector<dnnl::primitive> net_;
- /* The memory that is consumed by arguments. */
- std::vector<std::unordered_map<int, dnnl::memory>> net_args_;
- /* The entry ID to its corresponding output memory. */
- std::unordered_map<uint32_t, std::pair<dnnl::memory, size_t>> entry_out_mem_;
+ TensorRegistry::ActionQue net_;
+ /* Storage for all memory objects */
+ TensorRegistry tensor_registry_;
+ /* Generator of new unique eid which doesn't match with existing data entry
*/
+ uint32_t next_unique_eid_offset_;
+ /* Map of Run arg idx to corresponding eid */
+ std::vector<uint32_t> run_arg_eid_;
};
runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json,
diff --git a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
new file mode 100644
index 0000000000..d02ceff5de
--- /dev/null
+++ b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
@@ -0,0 +1,720 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
+ * \brief Helper TR wrapper to simplify tensors processing
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_
+#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_
+
+#include <dlpack/dlpack.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+// TODO(@apeskov): Have to mute warning from dnnl headers.
+// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command
+#include <dnnl.hpp>
+
+#include "dnnl_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace utils;
+
+/*!
+ * \brief Helper object to simplify tensor transformation description.
+ *
+ * Allow to specify original source tensor and future actions which should be
applied to it.
+ * Can be treated as sequence of reordering or reinterpretation of original
source tensor.
+ * Finally TR can be solved as proper interpretation of source memory buffer,
or sequence of
+ * dnnl::reorder operators which will provide desired data.
+ *
+ * \note An empty TR object allows any manipulation and returns an empty TR.
+ *
+ * \sa TensorRegistry
+ *
+ * Example:
+ * \code
+ * dnnl::memory src_mem = ...; // 5D tensor, shape {5, 2, 128, 128, 8}
+ *
+ * // Construct TR
+ * auto tr = TensorRequisite::AsIs(src_mem, eid); // 5D
+ *
+ * // describe sequence of layout transformation
+ * tr = tr.TreatAs("ABCD8b"); // 4D
+ * tr = tr.Permute({0, 2, 3, 1}); // Permute axes NCHW -> NHWC
+ * tr = tr.Crop({1, 128, 128, 16}, {0, 0, 0}); // extract first batch
element
+ * tr = tr.Squeeze(); // 1D
+ *
+ * // register TR
+ * TensorRegistry t_reg;
+ * auto t_id = t_reg.Register(tr, &action_queue);
+ *
+ * // Get final dnnl::memory object
+ * auto solver = t_reg.MakeSolver(ext_tensor_provider);
+ * auto mem = solver(t_id);
+ * \endcode
+ *
+ */
+class TensorRequisite {
+ public:
+ using Tid = uint32_t;
+ static constexpr Tid kUndefinedTid = std::numeric_limits<uint32_t>::max() -
1;
+
+ /*! \brief Empty constructor */
+ TensorRequisite() {}
+
+ /*! \brief Construct TR on top of existing memory object */
+ static TensorRequisite AsIs(const dnnl::memory& mem, Tid id = kUndefinedTid)
{
+ auto res = AsIs(mem.get_desc(), id);
+ if (mem.get_data_handle() != nullptr) res.mem_ = mem;
+ return res;
+ }
+
+ /*! \brief Construct TR on top of existing memory descriptor object */
+ static TensorRequisite AsIs(const dnnl::memory::desc& desc, Tid id =
kUndefinedTid) {
+ return {desc, {}, false, {}, id, false};
+ }
+
+ /*! \brief return logical shape of tensor */
+ dnnl::memory::dims dims() const { return t_desc_.dims(); }
+
+ /*! \brief return data type of tensor */
+ dnnl::memory::data_type data_type() const { return t_desc_.data_type(); }
+
+ /*! \brief return tensor desc */
+ dnnl::memory::desc desc() const { return t_desc_; }
+
+ /*! \brief Make TR with backward dataflow */
+ TensorRequisite Backward() const {
+ if (!defined()) return *this;
+ ICHECK(orig_ == nullptr);
+ return {t_desc_, orig_, reinterpret_, mem_, eid_, true};
+ }
+
+ /*! \brief Produce TR with permuted axes */
+ TensorRequisite Permute(const std::vector<int>& permutation) const {
+ if (!defined()) return *this; // nothing for empty TR
+
+ auto orig = std::make_shared<TensorRequisite>(*this);
+ // reinterpret memory buffer with new strides
+ auto desc = t_desc_.permute_axes(permutation);
+ return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_};
+ }
+
+ /*! \brief Produce TR with reinterpret data of original tr */
+ TensorRequisite Reshape(const dnnl::memory::dims& shape) const {
+ if (!defined()) return *this; // nothing for empty TR
+ if (t_desc_.dims() == shape) return *this;
+
+ auto orig = std::make_shared<TensorRequisite>(*this);
+ // reinterpret memory buffer with new strides
+ auto desc = t_desc_.reshape(shape);
+ return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_};
+ }
+
+ /*! \brief Produce TR with broadcasted values */
+ TensorRequisite Broadcast(const dnnl::memory::dims& shape) const {
+ if (!defined()) return *this; // nothing for empty TR
+ if (t_desc_.dims() == shape) return *this;
+ ICHECK(!reverse_data_flow_);
+
+ auto orig = std::make_shared<TensorRequisite>(*this);
+
+ // numpy like broadcast
+ auto extended_dims = t_desc_.dims();
+ auto one_filled = dnnl::memory::dims(shape.size() - extended_dims.size(),
1);
+ extended_dims.insert(extended_dims.begin(), one_filled.begin(),
one_filled.end());
+ auto desc = t_desc_.reshape(extended_dims);
+ for (size_t i = 0; i < extended_dims.size(); i++) {
+ if (extended_dims[i] == shape[i]) continue;
+ ICHECK(extended_dims[i] == 1);
+ ICHECK(desc.data.dims[i] == desc.data.padded_dims[i]);
+
+ desc.data.dims[i] = shape[i];
+ desc.data.padded_dims[i] = shape[i];
+ desc.data.format_desc.blocking.strides[i] = 0;
+ }
+
+ // reinterpret memory buffer with new strides
+ return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_};
+ }
+
+ /*! \brief Produce TR with sub memory view (ROI) */
+ TensorRequisite Crop(const dnnl::memory::dims& shape, const
dnnl::memory::dims& offset) const {
+ if (!defined()) return *this; // nothing for empty TR
+
+ ICHECK_EQ(shape.size(), t_desc_.dims().size());
+ ICHECK_EQ(offset.size(), t_desc_.dims().size());
+
+ auto orig = std::make_shared<TensorRequisite>(*this);
+ // reinterpret memory buffer with new strides
+ auto desc = t_desc_.submemory_desc(shape, offset, /*allow_empty=*/true);
+
+ // Originally DNNL implementation is very limited. Let's slightly enhance
it.
+ if (!desc && t_desc_.data.format_kind == dnnl_blocked) {
+ bool offset_is_zero =
+ std::all_of(offset.begin(), offset.end(), [](auto el) { return el ==
0; });
+
+ dnnl::memory::dims block_sizes(t_desc_.dims().size(), 1);
+ for (int i = 0; i < t_desc_.data.format_desc.blocking.inner_nblks; i++)
+ block_sizes[t_desc_.data.format_desc.blocking.inner_idxs[i]] *=
+ t_desc_.data.format_desc.blocking.inner_blks[i];
+
+ bool shape_reduction_less_than_block = true;
+ for (int i = 0; i < t_desc_.data.ndims; i++) {
+ shape_reduction_less_than_block &= t_desc_.data.dims[i] - shape[i] <
block_sizes[i];
+ }
+
+ // This is auto padded case. Just update dims value.
+ if (offset_is_zero && shape_reduction_less_than_block) {
+ desc = t_desc_;
+ std::copy(shape.begin(), shape.end(), desc.data.dims);
+ }
+ }
+
+ ICHECK(desc);
+
+ return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_};
+ }
+
+ /*! \brief Produce TR with squeeze shape */
+ TensorRequisite Squeeze(const dnnl::memory::dims& dims_to_squeeze = {})
const {
+ if (!defined()) return *this; // nothing for empty TR
+
+ dnnl::memory::dims squeezed_dims;
+ if (dims_to_squeeze.empty()) {
+ for (auto d : t_desc_.dims())
+ if (d != 1) squeezed_dims.push_back(d);
+ } else {
+ for (size_t i = 0; i < t_desc_.dims().size(); i++)
+ if (std::find(dims_to_squeeze.begin(), dims_to_squeeze.end(), i) ==
dims_to_squeeze.end())
+ squeezed_dims.push_back(t_desc_.dims()[i]);
+ }
+
+ if (squeezed_dims.empty()) squeezed_dims = {1};
+
+ auto orig = std::make_shared<TensorRequisite>(*this);
+ // reinterpret memory buffer with new strides
+ auto desc = t_desc_.reshape(squeezed_dims);
+ return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_};
+ }
+
+ /*! \brief Produce TR with specified layout descriptor */
+ TensorRequisite RequestLayout(dnnl::memory::desc desc) const {
+ if (!defined()) return *this; // nothing for empty TR
+
+ // If it's the same desc just return self
+ if (desc == t_desc_) return *this;
+
+ ICHECK(t_desc_.dims() == desc.dims()) << "Requested layout is not
compatible with "
+ "presented shape";
+
+ auto orig = std::make_shared<TensorRequisite>(*this);
+ return {desc, orig, false, {}, kUndefinedTid, reverse_data_flow_};
+ }
+
+ /*! \brief Define which logical dims ordering is default for particular
layout string. */
+ static std::string DefaultLogicLayoutFor(const std::string& layout) {
+ // Rank is all non digit marked dims
+ auto it = layout.begin();
+ while (it != layout.end() && !std::isdigit(*it)) it++;
+ int rank = std::distance(layout.begin(), it);
+
+ static const std::vector<std::string> sparse_dims = {"W", "HW", "DHW"};
+ if (layout.find("N") != std::string::npos) return "NC" + sparse_dims[rank
- 3];
+ if (layout.find("G") != std::string::npos) return "GOI" + sparse_dims[rank
- 4];
+ if (layout.find("O") != std::string::npos) return "OI" + sparse_dims[rank
- 3];
+
+    LOG(FATAL) << "Unknown layout " << layout << ". There is no default scheme
to handle it";
+ return {};
+ }
+
+ /*!
+ * \brief Treat TR shape as described in layout string.
+ *
+ * Blocked dimensions will be concatenated and put into the proper shape
position corresponding to the
+ * desired_logic_layout argument. If the desired logic layout was not
provided it will be deduced
+ * automatically based on some internal heuristics.
+ *
+ * Limitation 1. Blocking dims should be dense. Dims marked with digits use
natural strides.
+ * Limitation 2. Blocking dims are innermost. Dims marked like 8c, 4o goes
after regular
+ * dimensions. NC8cHW4h4cD is not valid tensor in terms of
DNNL. And cannot be
+ * achieved with memory reinterpretation, so data copy is
required. Proper layout
+ * looks like NCHWD_8c4h4c, first part is outer dims, second
digits marked part is
+ * innermost.
+ */
+ TensorRequisite TreatAs(const std::string& layout, std::string
desired_logic_layout = "") const {
+ if (desired_logic_layout.empty()) desired_logic_layout =
DefaultLogicLayoutFor(layout);
+
+ const auto origin_dims = dims();
+
+ // split layout string to tokens {size, tag} like {16, 'C'}, {4, 'O'}
+ std::vector<std::pair<int, char>> layout_tokens;
+ for (auto it = layout.begin(); it != layout.end();) {
+ auto start = it;
+ while (std::isdigit(*it)) it++;
+ int blk_size = start == it ? -1 : std::stoi(std::string{start, it});
+ layout_tokens.push_back({blk_size, std::toupper(*it)});
+ it++;
+ }
+
+ // check applicability of layout
+ auto it = layout_tokens.begin();
+ while (it != layout_tokens.end() && it->first == -1) it++;
+ int rank = std::distance(layout_tokens.begin(), it);
+ while (it != layout_tokens.end()) {
+ ICHECK_NE(it->first, -1) << "DNNL limitation. Blocking dims should be
innermost. "
+ << "But received layout is " << layout;
+ it++;
+ }
+
+ ICHECK_EQ(layout_tokens.size(), origin_dims.size());
+ ICHECK_EQ(rank, desired_logic_layout.size()) << layout;
+
+ std::vector<std::pair<int, char>> outermost_tokens(layout_tokens.begin(),
+ layout_tokens.begin() +
rank);
+ std::vector<std::pair<int, char>> innermost_tokens(layout_tokens.begin() +
rank,
+ layout_tokens.end());
+ // define dim resulting dim positions
+ std::map<char, int> dim_position_by_tag;
+ for (size_t i = 0; i < desired_logic_layout.size(); i++)
+ dim_position_by_tag[std::toupper(desired_logic_layout[i])] = i;
+
+ // Construct resulting desc by modifying original one
+ dnnl::memory::desc res_desc = t_desc_;
+
+ memset(&res_desc.data.format_desc.blocking, 0,
sizeof(res_desc.data.format_desc.blocking));
+ std::fill(res_desc.data.dims, res_desc.data.dims + DNNL_MAX_NDIMS, 0);
+ std::fill(res_desc.data.padded_dims, res_desc.data.padded_dims +
DNNL_MAX_NDIMS, 0);
+
+ res_desc.data.ndims = rank;
+ res_desc.data.format_desc.blocking.inner_nblks = innermost_tokens.size();
+
+ auto res_dims = res_desc.data.dims;
+ auto res_strides = res_desc.data.format_desc.blocking.strides;
+ auto res_inner_blks = res_desc.data.format_desc.blocking.inner_blks;
+ auto res_inner_idxs = res_desc.data.format_desc.blocking.inner_idxs;
+
+ std::fill(res_dims, res_dims + rank, 1);
+
+ int orig_dim_idx = 0;
+ for (const auto& p : outermost_tokens) {
+ auto tag = p.second;
+ auto dim_size = origin_dims[orig_dim_idx];
+
+ auto result_dim_position = dim_position_by_tag[tag];
+ res_dims[result_dim_position] *= dim_size;
+ res_strides[result_dim_position] =
t_desc_.data.format_desc.blocking.strides[orig_dim_idx];
+ orig_dim_idx++;
+ }
+ for (const auto& p : innermost_tokens) {
+ auto tag = p.second;
+ auto dim_size = origin_dims[orig_dim_idx];
+ auto result_dim_position = dim_position_by_tag[tag];
+ ICHECK_EQ(p.first, dim_size)
+ << "Blocking layout is not applicable to tensor with shape: " <<
origin_dims
+ << ". Requested layout is " << layout;
+
+ res_dims[result_dim_position] *= dim_size;
+ *res_inner_blks++ = dim_size;
+ *res_inner_idxs++ = result_dim_position;
+ orig_dim_idx++;
+ }
+
+ // Assume tensor is dense. There is no additional padding.
+ std::copy(res_desc.data.dims, res_desc.data.dims + rank,
res_desc.data.padded_dims);
+
+ if (t_desc_ == res_desc) return *this;
+
+ auto orig = std::make_shared<TensorRequisite>(*this);
+ return {res_desc, orig, true, {}, kUndefinedTid, reverse_data_flow_};
+ }
+
+ /*!
+ * \brief Produce TR with unspecified layout.
+ *
+ * Cannot be registered in TensorRegistry. Only for querying DNNL for
preferred layouts.
+ */
+ TensorRequisite LayoutAny() const {
+ auto orig = std::make_shared<TensorRequisite>(*this);
+ // Recreate tensor desc with layout 'any'
+ dnnl::memory::desc any_desc{t_desc_.dims(), t_desc_.data_type(),
dnnl::memory::format_tag::any};
+ return {any_desc, orig, false, {}, kUndefinedTid, reverse_data_flow_};
+ }
+
+  /*! \brief Check if TR is constant. */
+ bool IsConstant() const {
+ if (orig_) return orig_->IsConstant();
+ return mem_.operator bool();
+ }
+
+  /*! \brief Check if tensor is scalar. */
+ bool IsScalar() const { return t_desc_.dims().size() == 1 &&
t_desc_.dims()[0] == 1; }
+
+ /*! \brief Return const data memory if available. */
+ dnnl::memory GetConstData() const {
+ if (mem_) return mem_;
+ if (!orig_) return {};
+
+ if (auto orig_const_data = orig_->GetConstData()) {
+ if (reinterpret_) {
+ return {t_desc_, orig_const_data.get_engine(),
orig_const_data.get_data_handle()};
+ } else {
+ auto eng = orig_const_data.get_engine();
+ auto res = dnnl::memory{t_desc_, eng};
+ dnnl::reorder(orig_const_data, res).execute(dnnl::stream(eng),
orig_const_data, res);
+ return res;
+ }
+ }
+ return {};
+ }
+
+ /*!
+ * \brief Return const data memory in form of vector.
+ *
+ * Same as GetConstData but use std::vector instead of dnnl::memory. Works
only for 1D tensor
+ * and scalar TRs. Useful for specification of 1D DNNL attributes like
zero_point or
+ * per_channel_scale
+ */
+ template <typename T>
+ std::vector<T> GetConstDataLikeVec() const {
+ auto const_data = GetConstData();
+ auto desc = const_data.get_desc();
+ ICHECK(desc.data_type() == utils::DnnlDType<T>());
+ ICHECK(desc.dims().size() == 1);
+
+ auto size = desc.get_size() / sizeof(T);
+ auto ptr = static_cast<T*>(const_data.get_data_handle());
+
+ return std::vector<T>(ptr, ptr + size);
+ }
+
+ /*! \brief Get value of constant scalar tensor if possible. */
+ template <typename T>
+ T GetConstScalarData() const {
+ ICHECK(IsConstant());
+ ICHECK(IsScalar());
+ auto const_data = GetConstData();
+ auto desc = const_data.get_desc();
+ ICHECK(desc.data_type() == utils::DnnlDType<T>());
+
+ auto ptr = static_cast<T*>(const_data.get_data_handle());
+ return *ptr;
+ }
+
+ /*! \brief Check if tensor is not empty. */
+ bool defined() const { return !t_desc_.is_zero(); }
+
+ /*! \brief Same as defined */
+ operator bool() const { return defined(); }
+
+ /*!
+ * \brief Check if tensor represent a reversed data flow.
+ * Useful for describing output processing
+ */
+ bool IsReversed() const { return reverse_data_flow_; }
+
+ private:
+ TensorRequisite(const dnnl::memory::desc& t_desc, const
std::shared_ptr<TensorRequisite>& orig,
+ bool reinterpret, const dnnl::memory& const_mem, uint32_t
eid,
+ bool reverse_data_flow)
+ : t_desc_(t_desc),
+ orig_(orig),
+ reinterpret_(reinterpret),
+ mem_(const_mem),
+ eid_(eid),
+ reverse_data_flow_(reverse_data_flow) {
+ if (mem_) ICHECK(!orig_ && !reverse_data_flow_ && eid_ == kUndefinedTid);
+ if (eid_ != kUndefinedTid) ICHECK(!orig_);
+ }
+
+ /* Descriptor of particular tensor */
+ dnnl::memory::desc t_desc_ = {};
+ /* Parent TR object which is referred from this TR */
+ std::shared_ptr<TensorRequisite> orig_ = {};
+ /* Flag to specify which action should be done with orig TR, reordering or
reinterpretation */
+ bool reinterpret_ = false;
+ /* Const memory object if available */
+ dnnl::memory mem_ = {};
+ /* Entry ID of tensor if available */
+ uint32_t eid_ = kUndefinedTid;
+
+ /*
+ * Flag to describe reverse data flow case
+ * All operation on queue will be executed in reverse order. Actual for dst
tensor description
+ */
+ bool reverse_data_flow_ = false;
+
+ friend class TensorRegistry;
+};
+
+/*!
+ * \brief The registry of tensors. Implement matching of provided TRs and real
memory buffers.
+ *
+ * Registration of TR performed by calling method Register(), which will
return ArgId object.
+ * ArgId can be mapped to real memory via memory solver created by method
MakeSolver().
+ */
+class TensorRegistry {
+ private:
+ enum ArgReqFlag {
+    CONST,        ///< Constant tensor. ExecutionCTX independent
+    TMP_STORAGE,  ///< Intermediate tensors. Stored inside TensorRegistry.
Inaccessible outside
+    EXT_EID,      ///< External data. Input or Output.
+ };
+
+ public:
+ struct ArgId {
+ TensorRegistry::ArgReqFlag flag_;
+ uint32_t idx_;
+ };
+
+ using Action = std::tuple<dnnl::primitive, std::unordered_map<int, ArgId>>;
+ using ActionQue = std::vector<Action>;
+ using DLTensorProvider = std::function<const DLTensor*(uint32_t)>;
+ using MemSolver = std::function<const dnnl::memory(ArgId)>;
+
+ TensorRegistry() = default;
+ TensorRegistry(const dnnl::engine& eng, const std::set<uint32_t>& ext_io_eid)
+ : tmp_mem_collection_(1), ext_io_eid_(ext_io_eid), eng_(eng),
stream_(eng) {}
+
+ /*!
+ * \brief Register TR to registry
+ *
+ * Resolution of TR may lead to introduction of intermediate memory buffers
and additional
+ * transformation actions which should be performed before or after usage of
corresponding memory
+ * buffer. Additional actions will be append to provided actions queue.
Corresponding to
+ * tr.IsReversed() value actions should be executed before or after usage of
resulting ArgId.
+ *
+ * \param tr tensor requisite sequence to register
+ * \param action resulting action queue. If TR resolution is required
execution of some
+ * transformation actions they will be put here
+ * \return associated ArgId. Should be used as argument for MemSolver.
+ */
+ ArgId Register(const TensorRequisite& tr, ActionQue* action) {
+ // 1) Constant tensor. Direct reference
+ if (auto const_data = tr.GetConstData()) {
+ auto idx = const_mem_collection_.size();
+ const_mem_collection_.push_back(const_data);
+ return MakeArgReq(ArgReqFlag::CONST, static_cast<uint32_t>(idx));
+ }
+
+ // 2) EID mapped tensor. Direct reference
+ if (tr.eid_ != TensorRequisite::kUndefinedTid) {
+ if (ext_io_eid_.count(tr.eid_) == 0) { // Not IO tensor, means it's
intermediate
+ if (eid2idx_tmp_.count(tr.eid_)) {
+ auto idx = eid2idx_tmp_.at(tr.eid_);
+ return MakeArgReq(ArgReqFlag::TMP_STORAGE, idx);
+ } else {
+ // register himself
+ auto idx = tmp_mem_collection_.size();
+ tmp_mem_collection_.push_back(tr.t_desc_);
+ eid2idx_tmp_[tr.eid_] = idx;
+ return MakeArgReq(ArgReqFlag::TMP_STORAGE,
static_cast<uint32_t>(idx));
+ }
+ } else {
+ auto idx = ext_mem_collection_.size();
+ ext_mem_collection_.push_back({tr.eid_, tr.t_desc_});
+ return MakeArgReq(ArgReqFlag::EXT_EID, static_cast<uint32_t>(idx));
+ }
+ }
+
+ // 3) Tensors with transform actions
+ if (tr.orig_) {
+ // recursive register of orig TR
+ auto orig_arg_req = Register(*tr.orig_, action);
+ if (tr.reinterpret_) {
+ return RegisterReinterpret(orig_arg_req, tr.t_desc_);
+ } else {
+ return RegisterReorder(orig_arg_req, tr.t_desc_,
tr.reverse_data_flow_, action);
+ }
+ }
+
+ // 4) Scratchpad
+ ICHECK(!tr.orig_ && !tr.mem_ && tr.eid_ == TensorRequisite::kUndefinedTid);
+ auto idx = tmp_mem_collection_.size();
+ tmp_mem_collection_.push_back(tr.t_desc_);
+ tmp_mem_mapping_[idx] = 0; // zero position tmp mem object is reserved
for scratchpads
+
+ auto scratchpad_size = tr.t_desc_.get_size();
+ auto glob_scratchpad_size = tmp_mem_collection_[0].get_size();
+ if (scratchpad_size > glob_scratchpad_size) {
+ tmp_mem_collection_[0] =
+ dnnl::memory::desc({static_cast<dnnl::memory::dim>(scratchpad_size)},
+ dnnl::memory::data_type::u8,
dnnl::memory::format_tag::a);
+ }
+ return MakeArgReq(TMP_STORAGE, static_cast<uint32_t>(idx));
+ }
+
+ /*!
+ * \brief Construct memory solver for all registered TRs.
+ * \param ext_provider callback to resolve external IO buffers
+ * \return memory solver object to match ArgId to dnnl::memory objects
+ */
+ MemSolver MakeSolver(const DLTensorProvider& ext_provider) const {
+ return MemSolverImpl(eng_, ext_provider, const_mem_collection_,
ext_mem_collection_,
+ tmp_mem_collection_, tmp_mem_mapping_);
+ }
+
+ private:
+ ArgId RegisterReinterpret(ArgId src_ar, const dnnl::memory::desc& desc) {
+ switch (src_ar.flag_) {
+ case TMP_STORAGE: {
+ auto idx = tmp_mem_collection_.size();
+ tmp_mem_collection_.push_back(desc);
+ tmp_mem_mapping_[idx] = src_ar.idx_;
+ return MakeArgReq(TMP_STORAGE, idx);
+ }
+ case EXT_EID: {
+ auto ext_req = ext_mem_collection_[src_ar.idx_];
+ auto idx = ext_mem_collection_.size();
+ ext_mem_collection_.push_back({ext_req.first, desc});
+ return MakeArgReq(EXT_EID, idx);
+ }
+ default:
+ LOG(FATAL) << "Unknown case";
+ }
+ return {};
+ }
+
+ ArgId RegisterReorder(ArgId src_ar, const dnnl::memory::desc& desc, bool
reverse_data_flow,
+ ActionQue* action) {
+ ICHECK(src_ar.flag_ == TMP_STORAGE || src_ar.flag_ == EXT_EID);
+
+ auto src_desc = src_ar.flag_ == TMP_STORAGE ?
tmp_mem_collection_[src_ar.idx_]
+ :
ext_mem_collection_[src_ar.idx_].second;
+ auto idx = tmp_mem_collection_.size();
+ tmp_mem_collection_.push_back(desc);
+ auto dst_ar = MakeArgReq(TMP_STORAGE, idx);
+
+ // reorder action submit
+ if (reverse_data_flow) {
+ auto reorder_pd = dnnl::reorder::primitive_desc(eng_, desc, eng_,
src_desc);
+ action->insert(action->begin(),
+ {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, dst_ar},
{DNNL_ARG_TO, src_ar}}});
+ } else {
+ auto reorder_pd = dnnl::reorder::primitive_desc(eng_, src_desc, eng_,
desc);
+ action->push_back(
+ {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, src_ar}, {DNNL_ARG_TO,
dst_ar}}});
+ }
+ return dst_ar;
+ }
+ /*! \brief Implementation of memory solver */
+ class MemSolverImpl {
+ public:
+ MemSolverImpl(const dnnl::engine& eng, const DLTensorProvider&
ext_data_provider,
+ const std::vector<dnnl::memory>& const_mems,
+ const std::vector<std::pair<uint32_t, dnnl::memory::desc>>&
ext_mems,
+ const std::vector<dnnl::memory::desc>& tmp_mem_descs,
+ const std::map<size_t, size_t>& tmp_mem_mapping)
+ : eng_(eng),
+ ext_data_provider_(ext_data_provider),
+ const_mems_(const_mems),
+ ext_mems_(ext_mems) {
+ // Construct temp memory objects on the fly. While we have no scratchpads
+ // support on VM/GraphExecutor level.
+ tmp_mems_.resize(tmp_mem_descs.size());
+ for (size_t i = 0; i < tmp_mem_descs.size(); i++) {
+ auto found = tmp_mem_mapping.find(i);
+
+ if (found != tmp_mem_mapping.end()) {
+ auto reuse_hdl = tmp_mems_[found->second].get_data_handle();
+ tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_, reuse_hdl);
+ } else {
+ tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_);
+ }
+ }
+ }
+
+ /*! \brief Find memory object associated with provided ArgId */
+ dnnl::memory operator()(const ArgId& ar) const {
+ switch (ar.flag_) {
+ case CONST:
+ return const_mems_.at(ar.idx_);
+ case TMP_STORAGE:
+ return tmp_mems_.at(ar.idx_);
+ case EXT_EID: {
+ auto eid_and_desc = ext_mems_.at(ar.idx_);
+ auto eid = eid_and_desc.first;
+ auto desc = eid_and_desc.second;
+
+ auto ext_dl_tensor = ext_data_provider_(eid);
+ ICHECK(ext_dl_tensor->data);
+ return dnnl::memory{desc, eng_, ext_dl_tensor->data};
+ }
+ }
+ return {};
+ }
+
+ private:
+ const dnnl::engine& eng_;
+ const DLTensorProvider& ext_data_provider_;
+ const std::vector<dnnl::memory>& const_mems_;
+ const std::vector<std::pair<uint32_t, dnnl::memory::desc>>& ext_mems_;
+ std::vector<dnnl::memory> tmp_mems_;
+ };
+
+ ArgId MakeArgReq(ArgReqFlag flag, uint32_t idx) { return {flag, idx}; }
+
+ /* Collection of const memory objects. */
+ std::vector<dnnl::memory> const_mem_collection_;
+
+ /* Collection of intermediate memory descriptors. Zero position is reserved
for scratchpads. */
+ std::vector<dnnl::memory::desc> tmp_mem_collection_;
+
+ /* Mapping of some temp buffer on previously registered. */
+ std::map<size_t, size_t> tmp_mem_mapping_;
+
+ /* Collection of external_intermediate memory objects.
+ * first - eid of external buffer to ask
+ * second - t_desc describes how to treat external buffer */
+ std::vector<std::pair<uint32_t, dnnl::memory::desc>> ext_mem_collection_;
+
+ /* Map of eid to index of temp buffer in tmp_mem_collection_ */
+ std::unordered_map<uint32_t, size_t> eid2idx_tmp_;
+
+ /* List of external eid */
+ std::set<uint32_t> ext_io_eid_;
+
+ /* Engine of all tensors existing in this registry */
+ dnnl::engine eng_;
+
+ /* Execution stream use to reorder const data */
+ dnnl::stream stream_;
+};
+
+} // namespace contrib
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_
diff --git a/src/runtime/contrib/dnnl/dnnl_utils.cc
b/src/runtime/contrib/dnnl/dnnl_utils.cc
index 7e79f1c939..23992209f2 100644
--- a/src/runtime/contrib/dnnl/dnnl_utils.cc
+++ b/src/runtime/contrib/dnnl/dnnl_utils.cc
@@ -23,11 +23,14 @@
#include "dnnl_utils.h"
+#include "tvm/runtime/logging.h"
+
namespace tvm {
namespace runtime {
namespace contrib {
-using dt = dnnl::memory::data_type;
-dt dtype_dl2dnnl(DLDataType dltype) {
+
+dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype) {
+ using dt = dnnl::memory::data_type;
dt dnnl_type = dt::undef;
if (dltype.code == DataType::TypeCode::kFloat) {
if (dltype.bits == 16) {
@@ -51,6 +54,23 @@ dt dtype_dl2dnnl(DLDataType dltype) {
}
return dnnl_type;
}
+
+dnnl::memory::dims shape_dl2dnnl(const std::vector<int64_t>& shape) {
+ if (shape.empty()) return {1}; // DNNL scalar representation is 1D tensor
+ return shape;
+}
+
+dnnl::memory::desc MakePlainDesc(const std::vector<int64_t>& shape, DLDataType
dltype) {
+ auto dnnl_shape = shape_dl2dnnl(shape);
+ auto dnnl_dtype = dtype_dl2dnnl(dltype);
+
+ auto dnnl_plain_strides = dnnl::memory::dims(dnnl_shape.size(), 1);
+ for (int i = dnnl_shape.size() - 2; i >= 0; i--)
+ dnnl_plain_strides[i] = dnnl_plain_strides[i + 1] * dnnl_shape[i + 1];
+
+ return {dnnl_shape, dnnl_dtype, dnnl_plain_strides};
+}
+
} // namespace contrib
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/contrib/dnnl/dnnl_utils.h
b/src/runtime/contrib/dnnl/dnnl_utils.h
index 4fb236f96f..a598b67044 100644
--- a/src/runtime/contrib/dnnl/dnnl_utils.h
+++ b/src/runtime/contrib/dnnl/dnnl_utils.h
@@ -18,16 +18,23 @@
*/
/*!
- * \file src/runtime/contrib/dnnl/dnnl_utils.h
- * \brief utils for DNNL.
+ * \file src/runtime/contrib/dnnl/dnnl_utils.h
+ * \brief Some DNNL specific utility functions
*/
#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
-#include <tvm/runtime/data_type.h>
+#include <cstdint>
+#include <ostream>
+#include <string>
+#include <vector>
-#include "dnnl.hpp"
+// TODO(@apeskov): Have to mute warning from dnnl headers.
+// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command
+#include <dnnl.hpp>
+
+#include "tvm/runtime/data_type.h"
namespace tvm {
namespace runtime {
@@ -40,7 +47,90 @@ namespace contrib {
*/
dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype);
+/*!
+ * \brief Converter TVM shape to DNNL dims
+ * \param shape tvm shape
+ * \return dims in terms of dnnl
+ */
+dnnl::memory::dims shape_dl2dnnl(const std::vector<int64_t>& shape);
+
+/*!
+ * \brief Construct plain tensor descriptor
+ * \param shape provided shape
+ * \param dltype provided data type
+ * \return resulting plain tensor desc
+ */
+dnnl::memory::desc MakePlainDesc(const std::vector<int64_t>& shape, DLDataType
dltype);
+
+namespace utils {
+
+/*! \brief Pretty printer util for shape */
+inline std::ostream& operator<<(std::ostream& o, const dnnl::memory::dims&
dims) {
+ o << "[";
+ auto d = dims.begin();
+ if (d != dims.end()) o << *d++;
+ while (d != dims.end()) o << "," << *d++;
+ o << "]";
+ return o;
+}
+
+/*! \brief Pretty printer util for data type */
+inline std::ostream& operator<<(std::ostream& o, const
dnnl::memory::data_type& type) {
+ std::string name = "undef";
+ switch (type) {
+ case dnnl::memory::data_type::undef:
+ name = "undef";
+ break;
+ case dnnl::memory::data_type::f32:
+ name = "fp32";
+ break;
+ case dnnl::memory::data_type::f16:
+ name = "fp16";
+ break;
+ case dnnl::memory::data_type::bf16:
+ name = "bf16";
+ break;
+ case dnnl::memory::data_type::s32:
+ name = "i32";
+ break;
+ case dnnl::memory::data_type::s8:
+ name = "i8";
+ break;
+ case dnnl::memory::data_type::u8:
+ name = "u8";
+ break;
+ }
+ o << name;
+ return o;
+}
+
+/*! \brief Converter data type template arg to runtime object */
+template <typename T>
+inline dnnl::memory::data_type DnnlDType();
+
+template <>
+inline dnnl::memory::data_type DnnlDType<int>() {
+ return dnnl::memory::data_type::s32;
+}
+
+template <>
+inline dnnl::memory::data_type DnnlDType<float>() {
+ return dnnl::memory::data_type::f32;
+}
+
+template <>
+inline dnnl::memory::data_type DnnlDType<uint8_t>() {
+ return dnnl::memory::data_type::u8;
+}
+
+template <>
+inline dnnl::memory::data_type DnnlDType<int8_t>() {
+ return dnnl::memory::data_type::s8;
+}
+
+} // namespace utils
} // namespace contrib
} // namespace runtime
} // namespace tvm
+
#endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_