minminsun commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r477112491



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
/*!
 * \brief Strip the trailing "_<suffix>" from an iterator name,
 * e.g. "blockIdx_x" -> "blockIdx", "a_b_c" -> "a_b".
 *
 * If the name contains no '_', `rfind` yields npos and `substr(0, npos)`
 * returns the whole string unchanged.
 */
std::string BaseName(const std::string& str) {
  // char overload of rfind avoids constructing a temporary needle string.
  return str.substr(0, str.rfind('_'));
}
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& 
new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, 
const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // 
::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
/*!
 * \brief Build the new (post-scheduling) layout string for \p placeholder from
 * the loop structure of \p stage in \p state. Only iterators whose base name
 * appears in \p placeholder_axis_names contribute to the layout; their extents
 * are also appended to \p new_shape, in the same "<extent><axis-name>" form
 * that parse_kernel_layout consumes.
 */
std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
                           const Stage& stage, const te::Operation& op,
                           const te::Tensor& placeholder,
                           const std::set<std::string>& placeholder_axis_names) {
  std::ostringstream os;
  Array<Iterator> stage_iters;

  // If this stage is compute_at-ed into another stage, the iterators of the
  // attach stage up to and including the attach position come first.
  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
  int attach_pos = -1;
  size_t iters_before_attach = 0;
  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
    auto attach = attach_it->second;
    const auto& attach_stage = state->stages[attach.first];
    attach_pos = attach.second;
    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
                       attach_stage->iters.begin() + attach_pos + 1);
  }

  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());

  // Expand fused iterators back into their original (pre-fusion) iterators.
  std::vector<Iterator> iters;
  for (size_t i = 0; i < stage_iters.size(); ++i) {
    const auto& iter = stage_iters[i];
    if (iter->ori_iters.empty()) {
      iters.push_back(iter);
    } else {
      for (const Iterator& ori_iter : iter->ori_iters) {
        iters.push_back(ori_iter);
      }
    }
    // Remember how many expanded iterators came from the attach stage.
    if (static_cast<int>(i) == attach_pos) {
      iters_before_attach = iters.size();
    }
  }

  std::vector<std::string> new_names;
  std::vector<std::string> new_axis_names;
  for (const Iterator& iter : iters) {
    std::set<std::string> ori_iter_names;
    ExtractOriginalIterators(iter->name, &ori_iter_names);
    // fused iters have been replaced with iter->ori_iters.
    // So there should be only one ori iter name extracted from iter->name.
    CHECK_EQ(ori_iter_names.size(), 1);
    auto ori_iter_name = BaseName(*ori_iter_names.begin());
    new_axis_names.push_back(ori_iter_name);
  }
  for (size_t i = 0; i < new_axis_names.size(); ++i) {
    auto iter = iters[i];
    std::string ori_iter_name;
    if (i < iters_before_attach) {
      // NOTE(review): iterators that came from the attach stage take the name
      // of the iterator `iters_before_attach` positions later — presumably to
      // pair them with this stage's own iterators. Verify the offset cannot
      // index past new_axis_names.size() for deeply attached stages.
      ori_iter_name = new_axis_names[i + iters_before_attach];
    } else {
      ori_iter_name = new_axis_names[i];
    }
    // Only axes that actually index the placeholder take part in the layout.
    if (placeholder_axis_names.count(ori_iter_name)) {
      os << iter->range->extent << ori_iter_name;
      new_names.push_back(ori_iter_name);
      new_shape->push_back(iter->range->extent);
    }
  }
  std::string new_layout = os.str();
  os.str("");
  // TODO(minmin): uncomment this line for relay integration
  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
  return new_layout;
}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;
+    const te::Operation& op = stage->op;
+    if (op->IsInstance<te::ComputeOpNode>()) {
+      const Map<String, ObjectRef>& attrs = op->attrs;
+      if (attrs.count(layout_free_placeholders_key)) {
+        const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+        Array<te::Tensor> placeholders = 
Downcast<Array<te::Tensor>>(attr_value);
+        for (const auto& placeholder : placeholders) {
+          const auto& placeholder_op = placeholder->op;
+
+          // Check whether this placeholder has already been handled
+          if (handled_ops.count(placeholder_op)) {
+            continue;
+          }
+
+          // skip the op that is not direct consumer of this placeholder,
+          // mostly due to cache read/write.

Review comment:
       Done.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to