merrymercy commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r485321344
##########
File path: python/tvm/auto_scheduler/compute_dag.py
##########
@@ -81,12 +81,16 @@ def apply_steps_from_state(self, state):
state : Union[State, StateObject]
The state from which we get transform steps.
+ layout_rewrite: Bool
+ Rewrite the layout of placeholder to make it
+ most frendly for the generated schedule to read from.
Review comment:
```suggestion
layout_rewrite: Bool
Rewrite the layout of placeholders specified by
"layout_free_placeholders" attr
to make it most friendly for the generated schedule to read from.
```
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,319 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+class IndexRewriter : public StmtExprMutator {
+ public:
+ IndexRewriter(const te::Operation& placeholder_op, const std::string&
new_layout)
+ : placeholder_op_(placeholder_op) {
+ ParseKernelLayout(new_layout, &new_shape_, &new_names_);
+ }
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = AxisBaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names_.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names_[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape_[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape_[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ Array<PrimExpr> new_shape_;
+ std::vector<std::string> new_names_;
+};
+
+std::string get_orig_layout(std::set<std::string>* placeholder_axis_names,
const te::Operation& op,
Review comment:
Code style
get_orig_layout -> GetOrigLayout
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,319 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+class IndexRewriter : public StmtExprMutator {
+ public:
+ IndexRewriter(const te::Operation& placeholder_op, const std::string&
new_layout)
+ : placeholder_op_(placeholder_op) {
+ ParseKernelLayout(new_layout, &new_shape_, &new_names_);
+ }
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = AxisBaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names_.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names_[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape_[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape_[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ Array<PrimExpr> new_shape_;
+ std::vector<std::string> new_names_;
+};
+
+std::string get_orig_layout(std::set<std::string>* placeholder_axis_names,
const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = AxisBaseName(CleanName(Downcast<Var>(e)->name_hint));
+ }
+
+ placeholder_axis_names->insert(axis_name);
+ os << placeholder->shape[i++] << axis_name;
+ }
+ }
+
+ CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+ std::string orig_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+ //
::tvm::relay::KernelLayoutTransformer::global_orig_layouts_queue.push_back(orig_layout);
+ return orig_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state,
const int stage_id,
Review comment:
Naming style: get_new_layout -> GetNewLayout
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]