comaniac commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r472339009
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -40,6 +40,7 @@
#include <vector>
#include "../arith/pattern_match.h"
+#include "search_policy/utils.h"
Review comment:
- All utility functions should be moved to `utils.h`.
- All the function names should follow the C++ naming convention. For example,
`ParseKernelLayout` instead of `parse_kernel_layout`.
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
Review comment:
Provide more comments in this function to help future maintenance.
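For instance, a header along these lines (a sketch only; the wording is inferred from what the parser does, so adjust as needed):
```
/*!
 * \brief Parse a kernel layout string into per-axis extents and axis names.
 * \param layout The layout string: a concatenation of `<extent><axis-name>` pairs.
 * \param shape Output: the extent of each axis in the layout.
 * \param axes Output: the name of each axis in the layout.
 */
```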
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
Review comment:
- Can we inline this function?
- `BaseName` is too general.
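For example (a sketch only; `AxisBaseName` is just an illustrative replacement name):
```
// Return the part of an iterator name before the last '_'
// (or the whole string if there is no '_').
inline std::string AxisBaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
```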
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
Review comment:
s/uint/size_t/
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
Review comment:
This function needs more comments.
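For example, comments along these lines would help (my reading of the logic below; adjust as needed):
```
// Only rewrite loads from the layout-free placeholder.
// Walking the axes of the new layout from right to left, the new index
// for axis i is
//   indexmod(indexdiv(ori_arg, div_factor), new_shape[i])
// where ori_arg is the original index whose base name matches axis i,
// and div_factor accumulates the extents already consumed by axes
// sharing that base name.
```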
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+ }
+
+ placeholder_axis_names->insert(axis_name);
+ os << placeholder->shape[i++] << axis_name;
+ }
+ }
+
+ CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+ std::string ori_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+ return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+ const Stage& stage, const te::Operation& op,
+ const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+ std::ostringstream os;
+ Array<Iterator> stage_iters;
+
+ auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+ int attach_pos = -1;
+ size_t iters_before_attach = 0;
+ if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+ auto attach = attach_it->second;
+ const auto& attach_stage = state->stages[attach.first];
+ attach_pos = attach.second;
+ stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+ attach_stage->iters.begin() + attach_pos + 1);
+ }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+ std::vector<Iterator> iters;
+ for (size_t i = 0; i < stage_iters.size(); ++i) {
+ const auto& iter = stage_iters[i];
+ if (iter->ori_iters.empty()) {
+ iters.push_back(iter);
+ } else {
+ for (const Iterator& ori_iter : iter->ori_iters) {
+ iters.push_back(ori_iter);
+ }
+ }
+ if (static_cast<int>(i) == attach_pos) {
+ iters_before_attach = iters.size();
+ }
+ }
+
+ std::vector<std::string> new_names;
+ std::vector<std::string> new_axis_names;
+ for (const Iterator& iter : iters) {
+ std::set<std::string> ori_iter_names;
+ ExtractOriginalIterators(iter->name, &ori_iter_names);
+ // fused iters have been replaced with iter->ori_iters.
+ // So there should be only one ori iter name extracted from iter->name.
+ CHECK_EQ(ori_iter_names.size(), 1);
+ auto ori_iter_name = BaseName(*ori_iter_names.begin());
+ new_axis_names.push_back(ori_iter_name);
+ }
+ for (size_t i = 0; i < new_axis_names.size(); ++i) {
+ auto iter = iters[i];
+ std::string ori_iter_name;
+ if (i < iters_before_attach) {
+ ori_iter_name = new_axis_names[i + iters_before_attach];
+ } else {
+ ori_iter_name = new_axis_names[i];
+ }
+ if (placeholder_axis_names.count(ori_iter_name)) {
+ os << iter->range->extent << ori_iter_name;
+ new_names.push_back(ori_iter_name);
+ new_shape->push_back(iter->range->extent);
+ }
+ }
+ std::string new_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+ return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+ ComputeDAGNode* pdag = this->CopyOnWrite();
+ auto node = make_object<StateNode>();
+ node->transform_steps = transform_steps;
+ node->concrete = true;
+ const State& state = InferBound(State(node));
+ OperationSet handled_ops;
+ int stage_id = -1;
+ for (const auto& stage : state->stages) {
+ stage_id += 1;
Review comment:
It'd be better to use `for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id)` if you need the ID.
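That is, something like (a minimal sketch):
```
for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id) {
  const Stage& stage = state->stages[stage_id];
  ...
}
```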
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
Review comment:
It seems like `new_layout_` is fixed after the rewriter is constructed.
Accordingly, `new_shape` and `new_names` should also be fixed. IMHO, we should
be able to figure out the new shape and names in the constructor as well, to
make the logic clearer.
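Something like this, for instance (a sketch; the members `new_shape_` and `new_names_` are made up):
```
IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
    : placeholder_op_(placeholder_op), new_layout_(new_layout) {
  // new_layout_ never changes after construction, so parse it once here
  // instead of on every ProducerLoad visit.
  parse_kernel_layout(new_layout_, &new_shape_, &new_names_);
}
```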
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
Review comment:
Ditto: utility function, naming, and comments.
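E.g. the renamed signature could look like this (sketch only):
```
std::string GetOrigLayout(std::set<std::string>* placeholder_axis_names,
                          const te::Operation& op, const te::Tensor& placeholder);
```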
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+ }
+
+ placeholder_axis_names->insert(axis_name);
+ os << placeholder->shape[i++] << axis_name;
+ }
+ }
+
+ CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+ std::string ori_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+ return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
Review comment:
Ditto: utility function, naming, and comments.
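E.g. (sketch only):
```
std::string GetNewLayout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
                         const Stage& stage, const te::Operation& op,
                         const te::Tensor& placeholder,
                         const std::set<std::string>& placeholder_axis_names);
```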
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+ }
+
+ placeholder_axis_names->insert(axis_name);
+ os << placeholder->shape[i++] << axis_name;
+ }
+ }
+
+ CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+ std::string ori_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+ return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+ const Stage& stage, const te::Operation& op,
+ const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+ std::ostringstream os;
+ Array<Iterator> stage_iters;
+
+ auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+ int attach_pos = -1;
+ size_t iters_before_attach = 0;
+ if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+ auto attach = attach_it->second;
+ const auto& attach_stage = state->stages[attach.first];
+ attach_pos = attach.second;
+ stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+ attach_stage->iters.begin() + attach_pos + 1);
+ }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+ std::vector<Iterator> iters;
+ for (size_t i = 0; i < stage_iters.size(); ++i) {
+ const auto& iter = stage_iters[i];
+ if (iter->ori_iters.empty()) {
+ iters.push_back(iter);
+ } else {
+ for (const Iterator& ori_iter : iter->ori_iters) {
+ iters.push_back(ori_iter);
+ }
+ }
+ if (static_cast<int>(i) == attach_pos) {
+ iters_before_attach = iters.size();
+ }
+ }
+
+ std::vector<std::string> new_names;
+ std::vector<std::string> new_axis_names;
+ for (const Iterator& iter : iters) {
+ std::set<std::string> ori_iter_names;
+ ExtractOriginalIterators(iter->name, &ori_iter_names);
+ // fused iters have been replaced with iter->ori_iters.
+ // So there should be only one ori iter name extracted from iter->name.
+ CHECK_EQ(ori_iter_names.size(), 1);
+ auto ori_iter_name = BaseName(*ori_iter_names.begin());
+ new_axis_names.push_back(ori_iter_name);
+ }
+ for (size_t i = 0; i < new_axis_names.size(); ++i) {
+ auto iter = iters[i];
+ std::string ori_iter_name;
+ if (i < iters_before_attach) {
+ ori_iter_name = new_axis_names[i + iters_before_attach];
+ } else {
+ ori_iter_name = new_axis_names[i];
+ }
+ if (placeholder_axis_names.count(ori_iter_name)) {
+ os << iter->range->extent << ori_iter_name;
+ new_names.push_back(ori_iter_name);
+ new_shape->push_back(iter->range->extent);
+ }
+ }
+ std::string new_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+ return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+ ComputeDAGNode* pdag = this->CopyOnWrite();
+ auto node = make_object<StateNode>();
+ node->transform_steps = transform_steps;
+ node->concrete = true;
+ const State& state = InferBound(State(node));
+ OperationSet handled_ops;
+ int stage_id = -1;
+ for (const auto& stage : state->stages) {
+ stage_id += 1;
+ const te::Operation& op = stage->op;
+ if (op->IsInstance<te::ComputeOpNode>()) {
Review comment:
Since this statement is pretty long, I'd suggest
```
if (!op->IsInstance<te::ComputeOpNode>()) {
continue;
}
```
so that we can remove one level of indentation.
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+ }
+
+ placeholder_axis_names->insert(axis_name);
+ os << placeholder->shape[i++] << axis_name;
+ }
+ }
+
+ CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+ std::string ori_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+ return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+ const Stage& stage, const te::Operation& op,
+ const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+ std::ostringstream os;
+ Array<Iterator> stage_iters;
+
+ auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+ int attach_pos = -1;
+ size_t iters_before_attach = 0;
+ if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+ auto attach = attach_it->second;
+ const auto& attach_stage = state->stages[attach.first];
+ attach_pos = attach.second;
+ stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+ attach_stage->iters.begin() + attach_pos + 1);
+ }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+ std::vector<Iterator> iters;
+ for (size_t i = 0; i < stage_iters.size(); ++i) {
+ const auto& iter = stage_iters[i];
+ if (iter->ori_iters.empty()) {
+ iters.push_back(iter);
+ } else {
+ for (const Iterator& ori_iter : iter->ori_iters) {
+ iters.push_back(ori_iter);
+ }
+ }
+ if (static_cast<int>(i) == attach_pos) {
+ iters_before_attach = iters.size();
+ }
+ }
+
+ std::vector<std::string> new_names;
+ std::vector<std::string> new_axis_names;
+ for (const Iterator& iter : iters) {
+ std::set<std::string> ori_iter_names;
+ ExtractOriginalIterators(iter->name, &ori_iter_names);
+ // fused iters have been replaced with iter->ori_iters.
+ // So there should be only one ori iter name extracted from iter->name.
+ CHECK_EQ(ori_iter_names.size(), 1);
+ auto ori_iter_name = BaseName(*ori_iter_names.begin());
+ new_axis_names.push_back(ori_iter_name);
+ }
+ for (size_t i = 0; i < new_axis_names.size(); ++i) {
+ auto iter = iters[i];
+ std::string ori_iter_name;
+ if (i < iters_before_attach) {
+ ori_iter_name = new_axis_names[i + iters_before_attach];
+ } else {
+ ori_iter_name = new_axis_names[i];
+ }
+ if (placeholder_axis_names.count(ori_iter_name)) {
+ os << iter->range->extent << ori_iter_name;
+ new_names.push_back(ori_iter_name);
+ new_shape->push_back(iter->range->extent);
+ }
+ }
+ std::string new_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+ return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+ ComputeDAGNode* pdag = this->CopyOnWrite();
+ auto node = make_object<StateNode>();
+ node->transform_steps = transform_steps;
+ node->concrete = true;
+ const State& state = InferBound(State(node));
+ OperationSet handled_ops;
+ int stage_id = -1;
+ for (const auto& stage : state->stages) {
+ stage_id += 1;
+ const te::Operation& op = stage->op;
+ if (op->IsInstance<te::ComputeOpNode>()) {
+ const Map<String, ObjectRef>& attrs = op->attrs;
+ if (attrs.count(layout_free_placeholders_key)) {
+ const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+        Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
+ for (const auto& placeholder : placeholders) {
+ const auto& placeholder_op = placeholder->op;
+
+ // Check whether this placeholder has already been handled
+ if (handled_ops.count(placeholder_op)) {
+ continue;
+ }
+
+ // skip the op that is not direct consumer of this placeholder,
+ // mostly due to cache read/write.
+ bool direct_consumer = false;
+ for (auto& t : op->InputTensors()) {
+ if (t->op == placeholder_op) {
+ direct_consumer = true;
+ break;
+ }
+ }
+ if (!direct_consumer) {
+ continue;
+ }
+
+ std::set<std::string> placeholder_axis_names;
+ get_ori_layout(&placeholder_axis_names, op, placeholder);
+
+ Array<PrimExpr> new_shape;
+          std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage, op,
+                                                  placeholder, placeholder_axis_names);
+
+ handled_ops.insert(placeholder_op);
+
+ Array<te::Operation> old_ops = pdag->ops;
+ ArrayNode* pops = pdag->ops.CopyOnWrite();
+
+ // Create new placeholder
+ te::Operation new_placeholder_op;
+          new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
+                                                 placeholder_op.as<te::PlaceholderOpNode>()->dtype);
+
+ te::Operation new_compute_op, old_compute_op;
+ Array<PrimExpr> new_body;
+ IndexRewriter index_rewriter(placeholder_op, new_layout);
+ for (auto& op : old_ops) {
+ if (auto* pop = op.as<te::ComputeOpNode>()) {
+ bool need_update = false;
+ for (auto& t : op->InputTensors()) {
+ if (t->op == placeholder_op) {
+ need_update = true;
+ break;
+ }
+ }
+ if (need_update) {
+ for (auto& body : pop->body) {
+ new_body.push_back(index_rewriter.Rewrite(body));
+ }
+ old_compute_op = op;
+ CHECK(!new_compute_op.defined());
+ new_compute_op =
+                te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
+ }
+ }
+ }
+
+ // construct the map from old_op to new_op
+ std::unordered_map<te::Operation, te::Operation> updated_ops;
+ for (size_t i = 0; i < old_ops.size(); ++i) {
+ auto old_op = old_ops[i];
+ if (old_op == placeholder_op) {
+ pops->SetItem(i, new_placeholder_op);
+ updated_ops[placeholder_op] = new_placeholder_op;
+ } else if (old_op == old_compute_op) {
+ pops->SetItem(i, new_compute_op);
+ updated_ops[old_compute_op] = new_compute_op;
+ } else {
+ pops->SetItem(i, old_op);
+ }
+ }
+
+          // Because ops is sorted in topo-order, only do one pass linear scan here.
Review comment:
Add comments explaining the purpose.
```suggestion
// Because ops is sorted in topo-order, we only need one pass to (what).
```
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+ }
+
+ placeholder_axis_names->insert(axis_name);
+ os << placeholder->shape[i++] << axis_name;
+ }
+ }
+
+ CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+ std::string ori_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+ return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+ const Stage& stage, const te::Operation& op,
+ const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+ std::ostringstream os;
+ Array<Iterator> stage_iters;
+
+ auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+ int attach_pos = -1;
+ size_t iters_before_attach = 0;
+ if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+ auto attach = attach_it->second;
+ const auto& attach_stage = state->stages[attach.first];
+ attach_pos = attach.second;
+ stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+ attach_stage->iters.begin() + attach_pos + 1);
+ }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+ std::vector<Iterator> iters;
+ for (size_t i = 0; i < stage_iters.size(); ++i) {
+ const auto& iter = stage_iters[i];
+ if (iter->ori_iters.empty()) {
+ iters.push_back(iter);
+ } else {
+ for (const Iterator& ori_iter : iter->ori_iters) {
+ iters.push_back(ori_iter);
+ }
+ }
+ if (static_cast<int>(i) == attach_pos) {
+ iters_before_attach = iters.size();
+ }
+ }
+
+ std::vector<std::string> new_names;
+ std::vector<std::string> new_axis_names;
+ for (const Iterator& iter : iters) {
+ std::set<std::string> ori_iter_names;
+ ExtractOriginalIterators(iter->name, &ori_iter_names);
+ // fused iters have been replaced with iter->ori_iters.
+ // So there should be only one ori iter name extracted from iter->name.
+ CHECK_EQ(ori_iter_names.size(), 1);
+ auto ori_iter_name = BaseName(*ori_iter_names.begin());
+ new_axis_names.push_back(ori_iter_name);
+ }
+ for (size_t i = 0; i < new_axis_names.size(); ++i) {
+ auto iter = iters[i];
+ std::string ori_iter_name;
+ if (i < iters_before_attach) {
+ ori_iter_name = new_axis_names[i + iters_before_attach];
+ } else {
+ ori_iter_name = new_axis_names[i];
+ }
+ if (placeholder_axis_names.count(ori_iter_name)) {
+ os << iter->range->extent << ori_iter_name;
+ new_names.push_back(ori_iter_name);
+ new_shape->push_back(iter->range->extent);
+ }
+ }
+ std::string new_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+ return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+ ComputeDAGNode* pdag = this->CopyOnWrite();
+ auto node = make_object<StateNode>();
+ node->transform_steps = transform_steps;
+ node->concrete = true;
+ const State& state = InferBound(State(node));
+ OperationSet handled_ops;
+ int stage_id = -1;
+ for (const auto& stage : state->stages) {
+ stage_id += 1;
+ const te::Operation& op = stage->op;
+ if (op->IsInstance<te::ComputeOpNode>()) {
+ const Map<String, ObjectRef>& attrs = op->attrs;
+ if (attrs.count(layout_free_placeholders_key)) {
+ const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+        Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
+ for (const auto& placeholder : placeholders) {
+ const auto& placeholder_op = placeholder->op;
+
+ // Check whether this placeholder has already been handled
+ if (handled_ops.count(placeholder_op)) {
+ continue;
+ }
+
+ // skip the op that is not direct consumer of this placeholder,
+ // mostly due to cache read/write.
+ bool direct_consumer = false;
+ for (auto& t : op->InputTensors()) {
+ if (t->op == placeholder_op) {
+ direct_consumer = true;
+ break;
+ }
+ }
+ if (!direct_consumer) {
+ continue;
+ }
+
+ std::set<std::string> placeholder_axis_names;
+ get_ori_layout(&placeholder_axis_names, op, placeholder);
+
+ Array<PrimExpr> new_shape;
+          std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage, op,
+                                                  placeholder, placeholder_axis_names);
+
+ handled_ops.insert(placeholder_op);
+
+ Array<te::Operation> old_ops = pdag->ops;
+ ArrayNode* pops = pdag->ops.CopyOnWrite();
+
+ // Create new placeholder
+ te::Operation new_placeholder_op;
+          new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
+                                                 placeholder_op.as<te::PlaceholderOpNode>()->dtype);
+
+ te::Operation new_compute_op, old_compute_op;
Review comment:
Comment on what this loop is for.
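For example (my reading of the loop; adjust as appropriate):
```
// Rewrite the body of every compute op that directly reads the old
// placeholder so that its load indices follow the new layout.
```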
##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -118,6 +123,8 @@ class IteratorNode : public Object {
IteratorKind iter_kind;
/*! \brief The annotation type of this iterator. */
IteratorAnnotation annotation;
+ /*! The original iterators before fusion. */
+ std::vector<Iterator> ori_iters;
Review comment:
Same opinion.
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+ }
+
+ placeholder_axis_names->insert(axis_name);
+ os << placeholder->shape[i++] << axis_name;
+ }
+ }
+
+ CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+ std::string ori_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+ return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+ const Stage& stage, const te::Operation& op,
+ const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+ std::ostringstream os;
+ Array<Iterator> stage_iters;
+
+ auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+ int attach_pos = -1;
+ size_t iters_before_attach = 0;
+ if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+ auto attach = attach_it->second;
+ const auto& attach_stage = state->stages[attach.first];
+ attach_pos = attach.second;
+ stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+ attach_stage->iters.begin() + attach_pos + 1);
+ }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+ std::vector<Iterator> iters;
+ for (size_t i = 0; i < stage_iters.size(); ++i) {
+ const auto& iter = stage_iters[i];
+ if (iter->ori_iters.empty()) {
+ iters.push_back(iter);
+ } else {
+ for (const Iterator& ori_iter : iter->ori_iters) {
+ iters.push_back(ori_iter);
+ }
+ }
+ if (static_cast<int>(i) == attach_pos) {
+ iters_before_attach = iters.size();
+ }
+ }
+
+ std::vector<std::string> new_names;
+ std::vector<std::string> new_axis_names;
+ for (const Iterator& iter : iters) {
+ std::set<std::string> ori_iter_names;
+ ExtractOriginalIterators(iter->name, &ori_iter_names);
+ // fused iters have been replaced with iter->ori_iters.
+ // So there should be only one ori iter name extracted from iter->name.
+ CHECK_EQ(ori_iter_names.size(), 1);
+ auto ori_iter_name = BaseName(*ori_iter_names.begin());
+ new_axis_names.push_back(ori_iter_name);
+ }
+ for (size_t i = 0; i < new_axis_names.size(); ++i) {
+ auto iter = iters[i];
+ std::string ori_iter_name;
+ if (i < iters_before_attach) {
+ ori_iter_name = new_axis_names[i + iters_before_attach];
+ } else {
+ ori_iter_name = new_axis_names[i];
+ }
+ if (placeholder_axis_names.count(ori_iter_name)) {
+ os << iter->range->extent << ori_iter_name;
+ new_names.push_back(ori_iter_name);
+ new_shape->push_back(iter->range->extent);
+ }
+ }
+ std::string new_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+ return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+ ComputeDAGNode* pdag = this->CopyOnWrite();
+ auto node = make_object<StateNode>();
+ node->transform_steps = transform_steps;
+ node->concrete = true;
+ const State& state = InferBound(State(node));
+ OperationSet handled_ops;
+ int stage_id = -1;
+ for (const auto& stage : state->stages) {
+ stage_id += 1;
+ const te::Operation& op = stage->op;
+ if (op->IsInstance<te::ComputeOpNode>()) {
+ const Map<String, ObjectRef>& attrs = op->attrs;
+ if (attrs.count(layout_free_placeholders_key)) {
+ const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+        Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
+ for (const auto& placeholder : placeholders) {
+ const auto& placeholder_op = placeholder->op;
+
+ // Check whether this placeholder has already been handled
+ if (handled_ops.count(placeholder_op)) {
+ continue;
+ }
+
+ // skip the op that is not direct consumer of this placeholder,
+ // mostly due to cache read/write.
Review comment:
```suggestion
// Skip the op that is not direct consumer of this placeholder.
// This is usually caused by cache read/write.
```
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+ }
+
+ placeholder_axis_names->insert(axis_name);
+ os << placeholder->shape[i++] << axis_name;
+ }
+ }
+
+ CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+ std::string ori_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+ // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+ return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+ const Stage& stage, const te::Operation& op,
+ const te::Tensor& placeholder,
+ const std::set<std::string>& placeholder_axis_names) {
+ std::ostringstream os;
+ Array<Iterator> stage_iters;
+
+ auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+ int attach_pos = -1;
+ size_t iters_before_attach = 0;
+ if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+ auto attach = attach_it->second;
+ const auto& attach_stage = state->stages[attach.first];
+ attach_pos = attach.second;
+ stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+ attach_stage->iters.begin() + attach_pos + 1);
+ }
+
+ stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+ std::vector<Iterator> iters;
+ for (size_t i = 0; i < stage_iters.size(); ++i) {
+ const auto& iter = stage_iters[i];
+ if (iter->ori_iters.empty()) {
+ iters.push_back(iter);
+ } else {
+ for (const Iterator& ori_iter : iter->ori_iters) {
+ iters.push_back(ori_iter);
+ }
+ }
+ if (static_cast<int>(i) == attach_pos) {
+ iters_before_attach = iters.size();
+ }
+ }
+
+ std::vector<std::string> new_names;
+ std::vector<std::string> new_axis_names;
+ for (const Iterator& iter : iters) {
+ std::set<std::string> ori_iter_names;
+ ExtractOriginalIterators(iter->name, &ori_iter_names);
+ // fused iters have been replaced with iter->ori_iters.
+ // So there should be only one ori iter name extracted from iter->name.
+ CHECK_EQ(ori_iter_names.size(), 1);
+ auto ori_iter_name = BaseName(*ori_iter_names.begin());
+ new_axis_names.push_back(ori_iter_name);
+ }
+ for (size_t i = 0; i < new_axis_names.size(); ++i) {
+ auto iter = iters[i];
+ std::string ori_iter_name;
+ if (i < iters_before_attach) {
+ ori_iter_name = new_axis_names[i + iters_before_attach];
+ } else {
+ ori_iter_name = new_axis_names[i];
+ }
+ if (placeholder_axis_names.count(ori_iter_name)) {
+ os << iter->range->extent << ori_iter_name;
+ new_names.push_back(ori_iter_name);
+ new_shape->push_back(iter->range->extent);
+ }
+ }
+ std::string new_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+ // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+ return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+ ComputeDAGNode* pdag = this->CopyOnWrite();
+ auto node = make_object<StateNode>();
+ node->transform_steps = transform_steps;
+ node->concrete = true;
+ const State& state = InferBound(State(node));
+ OperationSet handled_ops;
+ int stage_id = -1;
+ for (const auto& stage : state->stages) {
+ stage_id += 1;
+ const te::Operation& op = stage->op;
+ if (op->IsInstance<te::ComputeOpNode>()) {
+ const Map<String, ObjectRef>& attrs = op->attrs;
+ if (attrs.count(layout_free_placeholders_key)) {
+ const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+ Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
+ for (const auto& placeholder : placeholders) {
+ const auto& placeholder_op = placeholder->op;
+
+ // Check whether this placeholder has already been handled
+ if (handled_ops.count(placeholder_op)) {
+ continue;
+ }
+
+ // skip the op that is not direct consumer of this placeholder,
+ // mostly due to cache read/write.
+ bool direct_consumer = false;
+ for (auto& t : op->InputTensors()) {
+ if (t->op == placeholder_op) {
+ direct_consumer = true;
+ break;
+ }
+ }
+ if (!direct_consumer) {
+ continue;
+ }
+
+ std::set<std::string> placeholder_axis_names;
+ get_ori_layout(&placeholder_axis_names, op, placeholder);
+
+ Array<PrimExpr> new_shape;
+ std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage, op,
+ placeholder, placeholder_axis_names);
+
+ handled_ops.insert(placeholder_op);
+
+ Array<te::Operation> old_ops = pdag->ops;
+ ArrayNode* pops = pdag->ops.CopyOnWrite();
+
+ // Create new placeholder
+ te::Operation new_placeholder_op;
+ new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
+ placeholder_op.as<te::PlaceholderOpNode>()->dtype);
+
+ te::Operation new_compute_op, old_compute_op;
+ Array<PrimExpr> new_body;
+ IndexRewriter index_rewriter(placeholder_op, new_layout);
+ for (auto& op : old_ops) {
+ if (auto* pop = op.as<te::ComputeOpNode>()) {
+ bool need_update = false;
+ for (auto& t : op->InputTensors()) {
+ if (t->op == placeholder_op) {
+ need_update = true;
+ break;
+ }
+ }
+ if (need_update) {
+ for (auto& body : pop->body) {
+ new_body.push_back(index_rewriter.Rewrite(body));
+ }
+ old_compute_op = op;
+ CHECK(!new_compute_op.defined());
+ new_compute_op =
+ te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
+ }
+ }
+ }
+
+ // construct the map from old_op to new_op
+ std::unordered_map<te::Operation, te::Operation> updated_ops;
+ for (size_t i = 0; i < old_ops.size(); ++i) {
+ auto old_op = old_ops[i];
+ if (old_op == placeholder_op) {
+ pops->SetItem(i, new_placeholder_op);
+ updated_ops[placeholder_op] = new_placeholder_op;
+ } else if (old_op == old_compute_op) {
+ pops->SetItem(i, new_compute_op);
+ updated_ops[old_compute_op] = new_compute_op;
+ } else {
+ pops->SetItem(i, old_op);
+ }
+ }
+
+ // Because ops is sorted in topo-order, only do one pass linear scan here.
+ for (size_t i = 0; i < pops->size(); ++i) {
+ auto old_op = Downcast<te::Operation>(pops->at(i));
+ if (auto* pop = old_op.as<te::ComputeOpNode>()) {
+ auto inputs = pop->InputTensors();
+ std::unordered_map<te::Tensor, te::Tensor> rmap;
+ for (auto input : inputs) {
+ auto it = updated_ops.find(input->op);
+ te::Operation new_op;
+ while (it != updated_ops.end()) {
+ new_op = it->second;
+ it = updated_ops.find(new_op);
+ }
+ if (new_op.defined()) {
+ int index = input->value_index;
+ rmap[input] = new_op.output(index);
+ }
+ }
+ if (!rmap.empty()) {
+ te::Operation new_op = pop->ReplaceInputs(old_op, rmap);
+ updated_ops[old_op] = new_op;
+ pops->SetItem(i, new_op);
+ }
+ }
+ }
+
+ pdag->init_state = State(pdag->ops);
+
+ Array<te::Tensor> old_tensors = pdag->tensors;
+ ArrayNode* ptensors = pdag->tensors.CopyOnWrite();
+
+ for (size_t i = 0; i < old_tensors.size(); ++i) {
Review comment:
Add a comment explaining what this loop is for.
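For example, something along these lines (just a sketch; I am assuming this loop redirects each entry of `pdag->tensors` whose op was replaced to the matching output of its new op):
```suggestion
// Rewrite the tensor list: redirect each tensor whose op was replaced
// above to the corresponding output of the new op.
for (size_t i = 0; i < old_tensors.size(); ++i) {
```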
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+ IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
Review comment:
Why do you need to assign a name to `axis_name`? It seems the value is
never used in the rest of this function.
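If it really is dead here, the branch could shrink to just the check. A minimal sketch (untested), assuming nothing else in this function reads `axis_name` for the `IntImm` case:
```cpp
if (const auto* pimm = arg.as<IntImmNode>()) {
  // Constant indices must be 0; there is no named axis to record.
  CHECK_EQ(pimm->value, 0);
} else {
  std::string axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
  CHECK_EQ(name_to_arg.count(axis_name), 0);
  name_to_arg[axis_name] = arg;
}
```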
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+ IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
Review comment:
Remove this blank line.
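An empty suggestion block does exactly that:
```suggestion
```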
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+ IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
Review comment:
Be more specific about what the assumption is here.
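For example, a comment along these lines (a sketch; please adjust it to the actual invariant):
```cpp
// A constant index into the placeholder is only supported when it is the
// literal 0, i.e., the axis is not iterated over; any other constant access
// pattern cannot be handled by layout rewriting.
CHECK_EQ(pimm->value, 0);
```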
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+ std::vector<std::string>* axes) {
+ int32_t factor = 0;
+ std::string axis = "";
+ for (char c : std::string(layout)) {
+ if (c >= 'A' && c <= 'z') {
+ axis += c;
+ if (factor != 0) {
+ shape->push_back(factor);
+ factor = 0;
+ }
+ } else if (c >= '0' && c <= '9') {
+ factor = factor * 10 + c - '0';
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ axis = "";
+ }
+ } else {
+ LOG(FATAL) << "Invalid layout " << layout;
+ }
+ }
+ if (!axis.empty()) {
+ axes->push_back(axis);
+ }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+ IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+ : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+ PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+ PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+ te::Tensor t = Downcast<te::Tensor>(op->producer);
+ if (t->op == placeholder_op_) {
+ Array<PrimExpr> new_shape;
+ std::vector<std::string> new_names;
+ parse_kernel_layout(new_layout_, &new_shape, &new_names);
+ std::unordered_map<std::string, PrimExpr> name_to_arg;
+ for (const auto& arg : op->indices) {
+ std::string axis_name;
+ if (const auto* pimm = arg.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+ CHECK_EQ(name_to_arg.count(axis_name), 0);
+ name_to_arg[axis_name] = arg;
+ }
+ }
+
+ std::unordered_map<std::string, PrimExpr> div_factors;
+ std::vector<PrimExpr> r_new_args;
+ for (int i = new_names.size() - 1; i >= 0; --i) {
+ auto ori_iter_name = new_names[i];
+ auto name_it = name_to_arg.find(ori_iter_name);
+ CHECK(name_it != name_to_arg.end());
+ PrimExpr ori_arg = name_it->second;
+
+ PrimExpr mod_factor = new_shape[i];
+
+ PrimExpr div_factor = 1;
+ if (div_factors.count(ori_iter_name)) {
+ div_factor = div_factors[ori_iter_name];
+ }
+ div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+ PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+ r_new_args.push_back(new_arg);
+ }
+
+ Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+ std::make_move_iterator(r_new_args.rend()));
+ return ProducerLoad(op->producer, new_args);
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const te::Operation& placeholder_op_;
+ const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+ const te::Tensor& placeholder) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+ extractor.Extract(exp);
+ }
+
+ std::ostringstream os;
+ uint i = 0;
+ const auto& placeholder_op = placeholder->op;
+ CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+ for (const auto& ev : extractor.read_access[placeholder_op]) {
+ for (const auto& e : ev) {
+ std::string axis_name;
+ if (const auto* pimm = e.as<IntImmNode>()) {
+ CHECK_EQ(pimm->value, 0);
+ axis_name = "IntImm";
+ } else {
+ axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+ }
+
+ placeholder_axis_names->insert(axis_name);
+ os << placeholder->shape[i++] << axis_name;
+ }
+ }
+
+ CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+ std::string ori_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+ // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+ return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+ const Stage& stage, const te::Operation& op,
+ const te::Tensor& placeholder,
+ const std::set<std::string>& placeholder_axis_names) {
+ std::ostringstream os;
+ Array<Iterator> stage_iters;
+
+ auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+ int attach_pos = -1;
+ size_t iters_before_attach = 0;
+ if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+ auto attach = attach_it->second;
+ const auto& attach_stage = state->stages[attach.first];
+ attach_pos = attach.second;
+ stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+ attach_stage->iters.begin() + attach_pos + 1);
+ }
+
+ stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+ std::vector<Iterator> iters;
+ for (size_t i = 0; i < stage_iters.size(); ++i) {
+ const auto& iter = stage_iters[i];
+ if (iter->ori_iters.empty()) {
+ iters.push_back(iter);
+ } else {
+ for (const Iterator& ori_iter : iter->ori_iters) {
+ iters.push_back(ori_iter);
+ }
+ }
+ if (static_cast<int>(i) == attach_pos) {
+ iters_before_attach = iters.size();
+ }
+ }
+
+ std::vector<std::string> new_names;
+ std::vector<std::string> new_axis_names;
+ for (const Iterator& iter : iters) {
+ std::set<std::string> ori_iter_names;
+ ExtractOriginalIterators(iter->name, &ori_iter_names);
+ // fused iters have been replaced with iter->ori_iters.
+ // So there should be only one ori iter name extracted from iter->name.
+ CHECK_EQ(ori_iter_names.size(), 1);
+ auto ori_iter_name = BaseName(*ori_iter_names.begin());
+ new_axis_names.push_back(ori_iter_name);
+ }
+ for (size_t i = 0; i < new_axis_names.size(); ++i) {
+ auto iter = iters[i];
+ std::string ori_iter_name;
+ if (i < iters_before_attach) {
+ ori_iter_name = new_axis_names[i + iters_before_attach];
+ } else {
+ ori_iter_name = new_axis_names[i];
+ }
+ if (placeholder_axis_names.count(ori_iter_name)) {
+ os << iter->range->extent << ori_iter_name;
+ new_names.push_back(ori_iter_name);
+ new_shape->push_back(iter->range->extent);
+ }
+ }
+ std::string new_layout = os.str();
+ os.str("");
+ // TODO(minmin): uncomment this line for relay integration
+ // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+ return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+ ComputeDAGNode* pdag = this->CopyOnWrite();
+ auto node = make_object<StateNode>();
+ node->transform_steps = transform_steps;
+ node->concrete = true;
+ const State& state = InferBound(State(node));
+ OperationSet handled_ops;
+ int stage_id = -1;
+ for (const auto& stage : state->stages) {
+ stage_id += 1;
+ const te::Operation& op = stage->op;
+ if (op->IsInstance<te::ComputeOpNode>()) {
+ const Map<String, ObjectRef>& attrs = op->attrs;
+ if (attrs.count(layout_free_placeholders_key)) {
Review comment:
```suggestion
if (attrs.count(layout_free_placeholders_key) == 0) {
continue;
}
```
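Besides reading better, the early `continue` removes one indentation level from the placeholder-handling loop below, which is already deeply nested.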
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]