piiswrong commented on a change in pull request #7947: [WIP] Refactor infer storage function for sparse operators.
URL: https://github.com/apache/incubator-mxnet/pull/7947#discussion_r140920207
 
 

 ##########
 File path: src/operator/operator_common.h
 ##########
 @@ -462,6 +497,72 @@ class SparseTempStorage {
   Storage::Handle  handle_;
 };
 
+/*!
+ * \brief Get a human-readable description of an operator's storage types.
+ * \param attrs     node attributes; the operator name (attrs.op->name) and
+ *                  every entry of attrs.dict are printed
+ * \param dev_mask  device mask of the context the operator runs in
+ * \param in_attrs  storage type of each input, printed via common::stype_string
+ * \param out_attrs storage type of each output, printed via common::stype_string
+ * \return multi-line string listing the operator name, input/output storage
+ *         types, params and the context dev_mask
+ */
+inline std::string operator_stype_string(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         const std::vector<int>& in_attrs,
+                                         const std::vector<int>& out_attrs) {
+  // Joins storage types as "a, b, c" without a trailing separator.
+  auto join_stypes = [](const std::vector<int>& stypes) -> std::string {
+    std::string joined;
+    for (size_t i = 0; i < stypes.size(); ++i) {
+      if (i != 0) joined += ", ";
+      joined += common::stype_string(stypes[i]);
+    }
+    return joined;
+  };
+  std::string result;
+  result += "operator = " + attrs.op->name + "\n";
+  result += "input storage types = [" + join_stypes(in_attrs) + "]\n";
+  result += "output storage types = [" + join_stypes(out_attrs) + "]\n";
+  result += "params = {";
+  bool first = true;
+  // Bind by const reference so each key/value pair is not copied.
+  for (const auto& kv : attrs.dict) {
+    if (!first) result += ", ";
+    first = false;
+    result += "\"" + kv.first + "\" : " + kv.second;
+  }
+  result += "}\n";
+  result += "context.dev_mask = " + std::to_string(dev_mask);
+  return result;
+}
+
+/*!
+ * \brief Get a human-readable description of an operator invocation.
+ *
+ * Collects the storage type of every input and output array and delegates the
+ * formatting to operator_stype_string(), passing ctx.run_ctx.ctx.dev_mask().
+ * \param attrs   attributes of the node being executed
+ * \param ctx     operator context; only its run context's dev_mask is read
+ * \param inputs  input arrays
+ * \param req     output write requests (not currently included in the string)
+ * \param outputs output arrays
+ * \return the string produced by operator_stype_string()
+ */
+inline std::string operator_string(const nnvm::NodeAttrs& attrs,
+                                  const OpContext& ctx,
+                                  const std::vector<NDArray>& inputs,
+                                  const std::vector<OpReqType>& req,
+                                  const std::vector<NDArray>& outputs) {
+  // Binding by const reference avoids constructing a copy of each NDArray.
+  auto get_stype = [](const NDArray& arr) -> int { return arr.storage_type(); };
+  std::vector<int> in_stypes;
+  in_stypes.reserve(inputs.size());
+  std::transform(inputs.begin(), inputs.end(), std::back_inserter(in_stypes), get_stype);
+  std::vector<int> out_stypes;
+  out_stypes.reserve(outputs.size());
+  std::transform(outputs.begin(), outputs.end(), std::back_inserter(out_stypes), get_stype);
+  return operator_stype_string(attrs, ctx.run_ctx.ctx.dev_mask(), in_stypes, out_stypes);
+}
+
+/*! \brief log storage fallback event
+ */
+inline void LogStorageFallback(const nnvm::NodeAttrs& attrs,
+                               const int dev_mask,
+                               const std::vector<int>* in_attrs,
+                               const std::vector<int>* out_attrs) {
+  using namespace op;
+  auto warning_printed = 
dmlc::ThreadLocalStore<std::unordered_set<std::string>>::Get();
+  bool log_verbose = dmlc::GetEnv("MXNET_STORAGE_FALLBACK_LOG_VERBOSE", true);
+  if (log_verbose) {
+    std::string warning = operator_stype_string(attrs, dev_mask, *in_attrs, 
*out_attrs);
+    if (warning_printed->find(warning) == warning_printed->end()) {
+      LOG(INFO) << "\nStorage fallback detected:\n" << warning
 
 Review comment:
   This message doesn't say anything about which operator is falling back
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to