junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459046943
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>&
tensors) {
return ops;
}
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+ void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(builtin::if_then_else())) {
+ has_branch = true;
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const ProducerLoadNode* op) final {
+
buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+
op->indices.end());
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const IfThenElseNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const SelectNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+ bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+ if (auto pv = expr.as<VarNode>()) {
+ return pv == var.get();
+ } else if (auto padd = expr.as<AddNode>()) {
+ return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>())
||
+ (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+ } else if (auto psub = expr.as<SubNode>()) {
+ return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>())
||
+ (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+ } else {
+ return false;
+ }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index,
bool* axis_missing,
+ bool* axis_duplicated, bool* same_order) {
+ auto cop = op.as<te::ComputeOpNode>();
+ if (cop == nullptr) {
+ return false;
+ }
+
+ std::vector<int> index_to_var_idx;
+ std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+ for (const auto& expr : index) {
+ if (!is_const_int(expr)) {
+ bool found = false;
+ for (size_t i = 0; i < cop->axis.size(); ++i) {
+ if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+ index_to_var_idx.push_back(i);
+ var_idx_ct[i]++;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return false;
+ }
+ }
+ }
+
+ *axis_missing = false; // Some axes are missing
+ *axis_duplicated = false; // Some axes appear more than once
+ *same_order = true; // The axis order is the same as op->axis
+ for (int ct : var_idx_ct) {
+ if (ct == 0) {
+ *axis_missing = true;
+ } else if (ct > 1) {
+ *axis_duplicated = true;
+ }
+ }
+ for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+ if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+ *same_order = false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const
VarNode*>* vars) {
+ PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+ if (const VarNode* op = node.as<VarNode>()) {
+ vars->insert(op);
+ }
+ });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+ bool found = false;
+ PostOrderVisit(expr, [&found](const ObjectRef& node) {
+ if (const CallNode* op = node.as<CallNode>()) {
+ if (op->op.as<OpNode>()->name == "tir.exp") {
+ found = true;
+ }
+ }
+ });
+ return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+ auto node = make_object<AccessAnalyzerNode>();
+ OperationMap<bool> has_branch;
+
+ // get all ops
+ node->ops_topo_order = TopoSortOps(tensors);
+
+ arith::Analyzer analyzer;
+
+ // build read & write access map
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+ } else if (auto cop = op.as<te::ComputeOpNode>()) {
+ TensorAccessExtractor extractor;
+ for (const auto& exp : cop->body) {
+ extractor.Extract(exp);
+ }
+
+ // read_by and read_from map
+ for (const auto& iter : extractor.buf_accesses) {
+ std::vector<std::vector<PrimExpr>>& accesses =
node->read_by[iter.first][op];
+ accesses.insert(accesses.begin(), iter.second.begin(),
iter.second.end());
+ }
+
+ node->read_from[op] = std::move(extractor.buf_accesses);
+ has_branch[op] = extractor.has_branch;
+
+ // compute number of common outer iterators
+ for (const auto& pair : node->read_from[op]) {
+ const te::Operation& producer = pair.first;
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ const Array<PrimExpr>& output_shape = op->output_shape(0);
+ const Array<PrimExpr>& producer_shape = producer->output_shape(0);
Review comment:
Does it only work for a `te::Operation` with a single output? Do we have a
fallback solution for operators with multiple outputs, like `argmax`?
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>&
tensors) {
return ops;
}
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+ void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(builtin::if_then_else())) {
+ has_branch = true;
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const ProducerLoadNode* op) final {
+
buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+
op->indices.end());
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const IfThenElseNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const SelectNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+ bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+ if (auto pv = expr.as<VarNode>()) {
+ return pv == var.get();
+ } else if (auto padd = expr.as<AddNode>()) {
+ return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>())
||
+ (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+ } else if (auto psub = expr.as<SubNode>()) {
+ return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>())
||
+ (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+ } else {
+ return false;
+ }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index,
bool* axis_missing,
+ bool* axis_duplicated, bool* same_order) {
+ auto cop = op.as<te::ComputeOpNode>();
+ if (cop == nullptr) {
+ return false;
+ }
+
+ std::vector<int> index_to_var_idx;
+ std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+ for (const auto& expr : index) {
+ if (!is_const_int(expr)) {
+ bool found = false;
+ for (size_t i = 0; i < cop->axis.size(); ++i) {
+ if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+ index_to_var_idx.push_back(i);
+ var_idx_ct[i]++;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return false;
+ }
+ }
+ }
+
+ *axis_missing = false; // Some axes are missing
+ *axis_duplicated = false; // Some axes appear more than once
+ *same_order = true; // The axis order is the same as op->axis
+ for (int ct : var_idx_ct) {
+ if (ct == 0) {
+ *axis_missing = true;
+ } else if (ct > 1) {
+ *axis_duplicated = true;
+ }
+ }
+ for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+ if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+ *same_order = false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const
VarNode*>* vars) {
+ PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+ if (const VarNode* op = node.as<VarNode>()) {
+ vars->insert(op);
+ }
+ });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+ bool found = false;
+ PostOrderVisit(expr, [&found](const ObjectRef& node) {
+ if (const CallNode* op = node.as<CallNode>()) {
+ if (op->op.as<OpNode>()->name == "tir.exp") {
+ found = true;
+ }
+ }
+ });
+ return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+ auto node = make_object<AccessAnalyzerNode>();
+ OperationMap<bool> has_branch;
+
+ // get all ops
+ node->ops_topo_order = TopoSortOps(tensors);
+
+ arith::Analyzer analyzer;
+
+ // build read & write access map
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+ } else if (auto cop = op.as<te::ComputeOpNode>()) {
+ TensorAccessExtractor extractor;
+ for (const auto& exp : cop->body) {
+ extractor.Extract(exp);
+ }
+
+ // read_by and read_from map
+ for (const auto& iter : extractor.buf_accesses) {
+ std::vector<std::vector<PrimExpr>>& accesses =
node->read_by[iter.first][op];
+ accesses.insert(accesses.begin(), iter.second.begin(),
iter.second.end());
+ }
+
+ node->read_from[op] = std::move(extractor.buf_accesses);
+ has_branch[op] = extractor.has_branch;
+
+ // compute number of common outer iterators
+ for (const auto& pair : node->read_from[op]) {
+ const te::Operation& producer = pair.first;
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ const Array<PrimExpr>& output_shape = op->output_shape(0);
+ const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+ int n_common;
+ for (n_common = 0;
+ n_common < static_cast<int>(std::min(output_shape.size(),
producer_shape.size()));
Review comment:
It's a long line... maybe consider moving the RHS out into a named variable.
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>&
tensors) {
return ops;
}
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+ void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(builtin::if_then_else())) {
+ has_branch = true;
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const ProducerLoadNode* op) final {
+
buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+
op->indices.end());
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const IfThenElseNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const SelectNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+ bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+ if (auto pv = expr.as<VarNode>()) {
+ return pv == var.get();
+ } else if (auto padd = expr.as<AddNode>()) {
+ return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>())
||
+ (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+ } else if (auto psub = expr.as<SubNode>()) {
+ return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>())
||
+ (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+ } else {
+ return false;
+ }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index,
bool* axis_missing,
+ bool* axis_duplicated, bool* same_order) {
+ auto cop = op.as<te::ComputeOpNode>();
+ if (cop == nullptr) {
+ return false;
+ }
+
+ std::vector<int> index_to_var_idx;
+ std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+ for (const auto& expr : index) {
+ if (!is_const_int(expr)) {
+ bool found = false;
+ for (size_t i = 0; i < cop->axis.size(); ++i) {
+ if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+ index_to_var_idx.push_back(i);
+ var_idx_ct[i]++;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return false;
+ }
+ }
+ }
+
+ *axis_missing = false; // Some axes are missing
+ *axis_duplicated = false; // Some axes appear more than once
+ *same_order = true; // The axis order is the same as op->axis
+ for (int ct : var_idx_ct) {
+ if (ct == 0) {
+ *axis_missing = true;
+ } else if (ct > 1) {
+ *axis_duplicated = true;
+ }
+ }
+ for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+ if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+ *same_order = false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const
VarNode*>* vars) {
+ PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+ if (const VarNode* op = node.as<VarNode>()) {
+ vars->insert(op);
+ }
+ });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+ bool found = false;
+ PostOrderVisit(expr, [&found](const ObjectRef& node) {
+ if (const CallNode* op = node.as<CallNode>()) {
+ if (op->op.as<OpNode>()->name == "tir.exp") {
+ found = true;
+ }
+ }
+ });
+ return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+ auto node = make_object<AccessAnalyzerNode>();
+ OperationMap<bool> has_branch;
+
+ // get all ops
+ node->ops_topo_order = TopoSortOps(tensors);
+
+ arith::Analyzer analyzer;
+
+ // build read & write access map
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+ } else if (auto cop = op.as<te::ComputeOpNode>()) {
+ TensorAccessExtractor extractor;
+ for (const auto& exp : cop->body) {
+ extractor.Extract(exp);
+ }
+
+ // read_by and read_from map
+ for (const auto& iter : extractor.buf_accesses) {
+ std::vector<std::vector<PrimExpr>>& accesses =
node->read_by[iter.first][op];
+ accesses.insert(accesses.begin(), iter.second.begin(),
iter.second.end());
+ }
+
+ node->read_from[op] = std::move(extractor.buf_accesses);
+ has_branch[op] = extractor.has_branch;
+
+ // compute number of common outer iterators
+ for (const auto& pair : node->read_from[op]) {
+ const te::Operation& producer = pair.first;
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ const Array<PrimExpr>& output_shape = op->output_shape(0);
+ const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+ int n_common;
+ for (n_common = 0;
+ n_common < static_cast<int>(std::min(output_shape.size(),
producer_shape.size()));
+ n_common++) {
+ if (!is_zero(analyzer.Simplify(output_shape[n_common] -
producer_shape[n_common]))) {
+ break;
+ }
+
+ bool direct_access = true;
+ for (const auto& access : access_list) {
+ if (!IsConstShiftEqual(cop->axis[n_common]->var,
access[n_common])) {
+ direct_access = false;
+ break;
+ }
+ }
+
+ if (!direct_access) {
+ break;
+ }
+ }
+
+ node->num_common_outer_iterators[op][producer] = n_common;
+ node->num_common_outer_iterators[producer][op] = n_common;
+ }
+ } else {
+ LOG(FATAL) << "Invalid op: " << op;
+ }
+ }
+
+ // do some static analysis
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->is_injective[op] = true;
+ node->needs_multi_level_tiling[op] = false;
+ node->is_strict_inlineable[op] = false;
+ node->is_output[op] = false;
+ } else if (auto pop = op.as<te::ComputeOpNode>()) {
+ // check whether this op is element-wise and strict-inlineable
+ bool is_injective = true;
+ bool is_strict_inlineable = true;
+
+ bool axis_missing, axis_duplicated, same_order;
+ for (const auto& pair : node->read_from[op]) {
+ const std::vector<std::vector<PrimExpr>>& access = pair.second;
+ for (const auto& index : access) {
+ if (!auto_scheduler::IsInjective(op, index, &axis_missing,
&axis_duplicated,
+ &same_order)) {
+ is_injective = false;
+ is_strict_inlineable = false;
+ break;
+ }
+ if (!same_order || axis_duplicated) {
+ // do not strictly inline transpose
Review comment:
I think I understand this sentence: `transpose` doesn't give `same_order`,
so it is not strictly inlineable. Can we give a formal definition of
"strictly inlineable"?
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>&
tensors) {
return ops;
}
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+ void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(builtin::if_then_else())) {
+ has_branch = true;
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const ProducerLoadNode* op) final {
+
buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+
op->indices.end());
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const IfThenElseNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const SelectNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+ bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+ if (auto pv = expr.as<VarNode>()) {
+ return pv == var.get();
+ } else if (auto padd = expr.as<AddNode>()) {
+ return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>())
||
+ (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+ } else if (auto psub = expr.as<SubNode>()) {
+ return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>())
||
+ (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+ } else {
+ return false;
+ }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index,
bool* axis_missing,
+ bool* axis_duplicated, bool* same_order) {
+ auto cop = op.as<te::ComputeOpNode>();
+ if (cop == nullptr) {
+ return false;
+ }
+
+ std::vector<int> index_to_var_idx;
+ std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+ for (const auto& expr : index) {
+ if (!is_const_int(expr)) {
+ bool found = false;
+ for (size_t i = 0; i < cop->axis.size(); ++i) {
+ if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+ index_to_var_idx.push_back(i);
+ var_idx_ct[i]++;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return false;
+ }
+ }
+ }
+
+ *axis_missing = false; // Some axes are missing
+ *axis_duplicated = false; // Some axes appear more than once
+ *same_order = true; // The axis order is the same as op->axis
+ for (int ct : var_idx_ct) {
+ if (ct == 0) {
+ *axis_missing = true;
+ } else if (ct > 1) {
+ *axis_duplicated = true;
+ }
+ }
+ for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+ if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+ *same_order = false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const
VarNode*>* vars) {
+ PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+ if (const VarNode* op = node.as<VarNode>()) {
+ vars->insert(op);
+ }
+ });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+ bool found = false;
+ PostOrderVisit(expr, [&found](const ObjectRef& node) {
+ if (const CallNode* op = node.as<CallNode>()) {
+ if (op->op.as<OpNode>()->name == "tir.exp") {
+ found = true;
+ }
+ }
+ });
+ return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+ auto node = make_object<AccessAnalyzerNode>();
+ OperationMap<bool> has_branch;
+
+ // get all ops
+ node->ops_topo_order = TopoSortOps(tensors);
+
+ arith::Analyzer analyzer;
+
+ // build read & write access map
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+ } else if (auto cop = op.as<te::ComputeOpNode>()) {
+ TensorAccessExtractor extractor;
+ for (const auto& exp : cop->body) {
+ extractor.Extract(exp);
+ }
+
+ // read_by and read_from map
+ for (const auto& iter : extractor.buf_accesses) {
+ std::vector<std::vector<PrimExpr>>& accesses =
node->read_by[iter.first][op];
+ accesses.insert(accesses.begin(), iter.second.begin(),
iter.second.end());
+ }
+
+ node->read_from[op] = std::move(extractor.buf_accesses);
+ has_branch[op] = extractor.has_branch;
+
+ // compute number of common outer iterators
+ for (const auto& pair : node->read_from[op]) {
+ const te::Operation& producer = pair.first;
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ const Array<PrimExpr>& output_shape = op->output_shape(0);
+ const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+ int n_common;
+ for (n_common = 0;
+ n_common < static_cast<int>(std::min(output_shape.size(),
producer_shape.size()));
+ n_common++) {
+ if (!is_zero(analyzer.Simplify(output_shape[n_common] -
producer_shape[n_common]))) {
+ break;
+ }
+
+ bool direct_access = true;
+ for (const auto& access : access_list) {
+ if (!IsConstShiftEqual(cop->axis[n_common]->var,
access[n_common])) {
+ direct_access = false;
+ break;
+ }
+ }
+
+ if (!direct_access) {
+ break;
+ }
+ }
+
+ node->num_common_outer_iterators[op][producer] = n_common;
+ node->num_common_outer_iterators[producer][op] = n_common;
+ }
+ } else {
+ LOG(FATAL) << "Invalid op: " << op;
+ }
+ }
+
+ // do some static analysis
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->is_injective[op] = true;
+ node->needs_multi_level_tiling[op] = false;
+ node->is_strict_inlineable[op] = false;
+ node->is_output[op] = false;
+ } else if (auto pop = op.as<te::ComputeOpNode>()) {
+ // check whether this op is element-wise and strict-inlineable
+ bool is_injective = true;
+ bool is_strict_inlineable = true;
+
+ bool axis_missing, axis_duplicated, same_order;
+ for (const auto& pair : node->read_from[op]) {
+ const std::vector<std::vector<PrimExpr>>& access = pair.second;
+ for (const auto& index : access) {
+ if (!auto_scheduler::IsInjective(op, index, &axis_missing,
&axis_duplicated,
+ &same_order)) {
+ is_injective = false;
+ is_strict_inlineable = false;
+ break;
+ }
+ if (!same_order || axis_duplicated) {
+ // do not strictly inline transpose
+ is_strict_inlineable = false;
+ }
+ }
+ if (!is_injective) {
+ break;
+ }
+ }
+ if (has_branch[op]) {
+ is_strict_inlineable = false;
+ }
+
+ // don't strictly inline expensive op (e.g. exp)
+ bool has_expensive_op = false;
+ for (const auto& expr : pop->body) {
+ has_expensive_op |= HasExpensiveOp(expr);
+ }
+
+ node->is_injective[op] = is_injective;
+ node->is_strict_inlineable[op] = is_strict_inlineable &&
!has_expensive_op;
+
+ // check whether the op needs multi-level tiling
+ bool needs_multi_level_tiling = false;
+ int n_missing = 0;
+
+ for (const auto& pair : node->read_from[op]) {
+ const std::vector<std::vector<PrimExpr>>& access = pair.second;
+ std::unordered_set<const VarNode*> vars;
+ for (const std::vector<PrimExpr>& indices : access) {
+ for (const PrimExpr& expr : indices) {
+ GatherVars(expr, &vars);
+ }
+ }
+ bool missing = false;
+ for (const auto& axis : pop->axis) {
+ if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get())
== 0) {
+ missing = true;
+ }
+ }
+ if (missing) {
+ n_missing++;
+ }
Review comment:
You don't need the `missing` flag; just break out of the loop:
```suggestion
for (const auto& axis : pop->axis) {
if (GetIntImm(axis->dom->extent) > 1 &&
vars.count(axis->var.get()) == 0) {
++n_missing;
break;
}
}
```
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>&
tensors) {
return ops;
}
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+ void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(builtin::if_then_else())) {
+ has_branch = true;
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const ProducerLoadNode* op) final {
+
buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+
op->indices.end());
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const IfThenElseNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const SelectNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+ bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+ if (auto pv = expr.as<VarNode>()) {
+ return pv == var.get();
+ } else if (auto padd = expr.as<AddNode>()) {
+ return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>())
||
+ (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+ } else if (auto psub = expr.as<SubNode>()) {
+ return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>())
||
+ (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+ } else {
+ return false;
+ }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index,
bool* axis_missing,
+ bool* axis_duplicated, bool* same_order) {
+ auto cop = op.as<te::ComputeOpNode>();
+ if (cop == nullptr) {
+ return false;
+ }
+
+ std::vector<int> index_to_var_idx;
+ std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+ for (const auto& expr : index) {
+ if (!is_const_int(expr)) {
+ bool found = false;
+ for (size_t i = 0; i < cop->axis.size(); ++i) {
+ if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+ index_to_var_idx.push_back(i);
+ var_idx_ct[i]++;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return false;
+ }
+ }
+ }
+
+ *axis_missing = false; // Some axes are missing
+ *axis_duplicated = false; // Some axes appear more than once
+ *same_order = true; // The axis order is the same as op->axis
+ for (int ct : var_idx_ct) {
+ if (ct == 0) {
+ *axis_missing = true;
+ } else if (ct > 1) {
+ *axis_duplicated = true;
+ }
+ }
+ for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+ if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+ *same_order = false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const
VarNode*>* vars) {
+ PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+ if (const VarNode* op = node.as<VarNode>()) {
+ vars->insert(op);
+ }
+ });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+ bool found = false;
+ PostOrderVisit(expr, [&found](const ObjectRef& node) {
+ if (const CallNode* op = node.as<CallNode>()) {
+ if (op->op.as<OpNode>()->name == "tir.exp") {
+ found = true;
+ }
+ }
+ });
+ return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+ auto node = make_object<AccessAnalyzerNode>();
+ OperationMap<bool> has_branch;
+
+ // get all ops
+ node->ops_topo_order = TopoSortOps(tensors);
Review comment:
This function is way too large... consider decomposing it into several
smaller ones.
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>&
tensors) {
return ops;
}
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+ void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(builtin::if_then_else())) {
+ has_branch = true;
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const ProducerLoadNode* op) final {
+
buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+
op->indices.end());
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const IfThenElseNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const SelectNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+ bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+ if (auto pv = expr.as<VarNode>()) {
+ return pv == var.get();
+ } else if (auto padd = expr.as<AddNode>()) {
+ return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>())
||
+ (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+ } else if (auto psub = expr.as<SubNode>()) {
+ return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>())
||
+ (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+ } else {
+ return false;
+ }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index,
bool* axis_missing,
+ bool* axis_duplicated, bool* same_order) {
+ auto cop = op.as<te::ComputeOpNode>();
+ if (cop == nullptr) {
+ return false;
+ }
+
+ std::vector<int> index_to_var_idx;
+ std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+ for (const auto& expr : index) {
+ if (!is_const_int(expr)) {
+ bool found = false;
+ for (size_t i = 0; i < cop->axis.size(); ++i) {
+ if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+ index_to_var_idx.push_back(i);
+ var_idx_ct[i]++;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return false;
+ }
+ }
+ }
+
+ *axis_missing = false; // Some axes are missing
+ *axis_duplicated = false; // Some axes appear more than once
+ *same_order = true; // The axis order is the same as op->axis
+ for (int ct : var_idx_ct) {
+ if (ct == 0) {
+ *axis_missing = true;
+ } else if (ct > 1) {
+ *axis_duplicated = true;
+ }
+ }
+ for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+ if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+ *same_order = false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const
VarNode*>* vars) {
Review comment:
It doesn't have to be `static`; `static` may sometimes interfere with
backtrace printing.
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>&
tensors) {
return ops;
}
-// Estimate number of float operations in an expression
// Extract all tensor accesses in an expr
class TensorAccessExtractor : public StmtExprVisitor {
 public:
  // Entry point: walk `expr` and record every producer-tensor read, and whether
  // any branching construct appears.
  void Extract(PrimExpr expr) { this->VisitExpr(expr); }

  void VisitExpr_(const CallNode* op) final {
    // tir.if_then_else is a branch even though it appears as an intrinsic call.
    if (op->op.same_as(builtin::if_then_else())) {
      has_branch = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const ProducerLoadNode* op) final {
    // Record the multi-dimensional index vector used to read this producer's op.
    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
                                                                      op->indices.end());
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const IfThenElseNode* op) final {
    has_branch = true;
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const SelectNode* op) final {
    // Select(cond, a, b) also counts as a branch.
    has_branch = true;
    StmtExprVisitor::VisitExpr_(op);
  }

  // Maps each accessed operation to the list of index vectors used to read it.
  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
  // True if the visited expr contains any branching construct (if_then_else,
  // IfThenElse stmt, or Select).
  bool has_branch{false};
};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+ if (auto pv = expr.as<VarNode>()) {
+ return pv == var.get();
+ } else if (auto padd = expr.as<AddNode>()) {
+ return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>())
||
+ (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+ } else if (auto psub = expr.as<SubNode>()) {
+ return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>())
||
+ (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+ } else {
+ return false;
+ }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index,
bool* axis_missing,
+ bool* axis_duplicated, bool* same_order) {
+ auto cop = op.as<te::ComputeOpNode>();
+ if (cop == nullptr) {
+ return false;
+ }
+
+ std::vector<int> index_to_var_idx;
+ std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+ for (const auto& expr : index) {
+ if (!is_const_int(expr)) {
+ bool found = false;
+ for (size_t i = 0; i < cop->axis.size(); ++i) {
+ if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+ index_to_var_idx.push_back(i);
+ var_idx_ct[i]++;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return false;
+ }
+ }
+ }
+
+ *axis_missing = false; // Some axes are missing
+ *axis_duplicated = false; // Some axes appear more than once
+ *same_order = true; // The axis order is the same as op->axis
+ for (int ct : var_idx_ct) {
+ if (ct == 0) {
+ *axis_missing = true;
+ } else if (ct > 1) {
+ *axis_duplicated = true;
+ }
+ }
+ for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+ if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+ *same_order = false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const
VarNode*>* vars) {
+ PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+ if (const VarNode* op = node.as<VarNode>()) {
+ vars->insert(op);
+ }
+ });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+ bool found = false;
+ PostOrderVisit(expr, [&found](const ObjectRef& node) {
+ if (const CallNode* op = node.as<CallNode>()) {
+ if (op->op.as<OpNode>()->name == "tir.exp") {
+ found = true;
+ }
+ }
+ });
+ return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+ auto node = make_object<AccessAnalyzerNode>();
+ OperationMap<bool> has_branch;
+
+ // get all ops
+ node->ops_topo_order = TopoSortOps(tensors);
+
+ arith::Analyzer analyzer;
+
+ // build read & write access map
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+ } else if (auto cop = op.as<te::ComputeOpNode>()) {
+ TensorAccessExtractor extractor;
+ for (const auto& exp : cop->body) {
+ extractor.Extract(exp);
+ }
+
+ // read_by and read_from map
+ for (const auto& iter : extractor.buf_accesses) {
+ std::vector<std::vector<PrimExpr>>& accesses =
node->read_by[iter.first][op];
+ accesses.insert(accesses.begin(), iter.second.begin(),
iter.second.end());
+ }
+
+ node->read_from[op] = std::move(extractor.buf_accesses);
+ has_branch[op] = extractor.has_branch;
+
+ // compute number of common outer iterators
+ for (const auto& pair : node->read_from[op]) {
+ const te::Operation& producer = pair.first;
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ const Array<PrimExpr>& output_shape = op->output_shape(0);
+ const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+ int n_common;
+ for (n_common = 0;
+ n_common < static_cast<int>(std::min(output_shape.size(),
producer_shape.size()));
+ n_common++) {
+ if (!is_zero(analyzer.Simplify(output_shape[n_common] -
producer_shape[n_common]))) {
+ break;
+ }
+
+ bool direct_access = true;
+ for (const auto& access : access_list) {
+ if (!IsConstShiftEqual(cop->axis[n_common]->var,
access[n_common])) {
+ direct_access = false;
+ break;
+ }
+ }
+
+ if (!direct_access) {
+ break;
+ }
+ }
+
+ node->num_common_outer_iterators[op][producer] = n_common;
+ node->num_common_outer_iterators[producer][op] = n_common;
+ }
+ } else {
+ LOG(FATAL) << "Invalid op: " << op;
+ }
+ }
+
+ // do some static analysis
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->is_injective[op] = true;
+ node->needs_multi_level_tiling[op] = false;
+ node->is_strict_inlineable[op] = false;
+ node->is_output[op] = false;
+ } else if (auto pop = op.as<te::ComputeOpNode>()) {
+ // check whether this op is element-wise and strict-inlineable
+ bool is_injective = true;
+ bool is_strict_inlineable = true;
+
+ bool axis_missing, axis_duplicated, same_order;
+ for (const auto& pair : node->read_from[op]) {
+ const std::vector<std::vector<PrimExpr>>& access = pair.second;
+ for (const auto& index : access) {
+ if (!auto_scheduler::IsInjective(op, index, &axis_missing,
&axis_duplicated,
+ &same_order)) {
+ is_injective = false;
+ is_strict_inlineable = false;
+ break;
+ }
+ if (!same_order || axis_duplicated) {
+ // do not strictly inline transpose
+ is_strict_inlineable = false;
+ }
+ }
+ if (!is_injective) {
+ break;
+ }
+ }
+ if (has_branch[op]) {
+ is_strict_inlineable = false;
+ }
+
+ // don't strictly inline expensive op (e.g. exp)
+ bool has_expensive_op = false;
+ for (const auto& expr : pop->body) {
+ has_expensive_op |= HasExpensiveOp(expr);
+ }
+
+ node->is_injective[op] = is_injective;
+ node->is_strict_inlineable[op] = is_strict_inlineable &&
!has_expensive_op;
+
+ // check whether the op needs multi-level tiling
+ bool needs_multi_level_tiling = false;
+ int n_missing = 0;
+
+ for (const auto& pair : node->read_from[op]) {
+ const std::vector<std::vector<PrimExpr>>& access = pair.second;
Review comment:
I feel like we should have a consistent naming convention, at least inside
a function. There are three places in this function where "pair.second" is
called "access" or "access_list", and each of its elements is called "index",
"indices", or "access" — maybe it is better to come up with a consistent
naming for them.
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>&
tensors) {
return ops;
}
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+ void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(builtin::if_then_else())) {
+ has_branch = true;
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const ProducerLoadNode* op) final {
+
buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+
op->indices.end());
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const IfThenElseNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const SelectNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+ bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+ if (auto pv = expr.as<VarNode>()) {
+ return pv == var.get();
+ } else if (auto padd = expr.as<AddNode>()) {
+ return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>())
||
+ (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+ } else if (auto psub = expr.as<SubNode>()) {
+ return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>())
||
+ (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+ } else {
+ return false;
+ }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index,
bool* axis_missing,
+ bool* axis_duplicated, bool* same_order) {
+ auto cop = op.as<te::ComputeOpNode>();
+ if (cop == nullptr) {
+ return false;
+ }
+
+ std::vector<int> index_to_var_idx;
+ std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+ for (const auto& expr : index) {
+ if (!is_const_int(expr)) {
+ bool found = false;
+ for (size_t i = 0; i < cop->axis.size(); ++i) {
+ if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+ index_to_var_idx.push_back(i);
+ var_idx_ct[i]++;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return false;
+ }
+ }
+ }
+
+ *axis_missing = false; // Some axes are missing
+ *axis_duplicated = false; // Some axes appear more than once
+ *same_order = true; // The axis order is the same as op->axis
+ for (int ct : var_idx_ct) {
+ if (ct == 0) {
+ *axis_missing = true;
+ } else if (ct > 1) {
+ *axis_duplicated = true;
+ }
+ }
+ for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+ if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+ *same_order = false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const
VarNode*>* vars) {
+ PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+ if (const VarNode* op = node.as<VarNode>()) {
+ vars->insert(op);
+ }
+ });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+ bool found = false;
+ PostOrderVisit(expr, [&found](const ObjectRef& node) {
+ if (const CallNode* op = node.as<CallNode>()) {
+ if (op->op.as<OpNode>()->name == "tir.exp") {
+ found = true;
+ }
+ }
+ });
+ return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+ auto node = make_object<AccessAnalyzerNode>();
+ OperationMap<bool> has_branch;
+
+ // get all ops
+ node->ops_topo_order = TopoSortOps(tensors);
+
+ arith::Analyzer analyzer;
+
+ // build read & write access map
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+ } else if (auto cop = op.as<te::ComputeOpNode>()) {
+ TensorAccessExtractor extractor;
+ for (const auto& exp : cop->body) {
+ extractor.Extract(exp);
+ }
+
+ // read_by and read_from map
+ for (const auto& iter : extractor.buf_accesses) {
+ std::vector<std::vector<PrimExpr>>& accesses =
node->read_by[iter.first][op];
+ accesses.insert(accesses.begin(), iter.second.begin(),
iter.second.end());
+ }
+
+ node->read_from[op] = std::move(extractor.buf_accesses);
+ has_branch[op] = extractor.has_branch;
+
+ // compute number of common outer iterators
+ for (const auto& pair : node->read_from[op]) {
+ const te::Operation& producer = pair.first;
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ const Array<PrimExpr>& output_shape = op->output_shape(0);
+ const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+ int n_common;
+ for (n_common = 0;
+ n_common < static_cast<int>(std::min(output_shape.size(),
producer_shape.size()));
+ n_common++) {
+ if (!is_zero(analyzer.Simplify(output_shape[n_common] -
producer_shape[n_common]))) {
+ break;
+ }
+
+ bool direct_access = true;
+ for (const auto& access : access_list) {
+ if (!IsConstShiftEqual(cop->axis[n_common]->var,
access[n_common])) {
+ direct_access = false;
+ break;
+ }
+ }
+
+ if (!direct_access) {
+ break;
+ }
+ }
+
+ node->num_common_outer_iterators[op][producer] = n_common;
+ node->num_common_outer_iterators[producer][op] = n_common;
+ }
+ } else {
+ LOG(FATAL) << "Invalid op: " << op;
+ }
+ }
+
+ // do some static analysis
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->is_injective[op] = true;
+ node->needs_multi_level_tiling[op] = false;
+ node->is_strict_inlineable[op] = false;
+ node->is_output[op] = false;
+ } else if (auto pop = op.as<te::ComputeOpNode>()) {
Review comment:
Why is it named `pop`? I thought the convention was `cop`...
```suggestion
} else if (const auto* cop = op.as<te::ComputeOpNode>()) {
```
##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>&
tensors) {
return ops;
}
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+ void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(builtin::if_then_else())) {
+ has_branch = true;
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const ProducerLoadNode* op) final {
+
buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+
op->indices.end());
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const IfThenElseNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const SelectNode* op) final {
+ has_branch = true;
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+ bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+ if (auto pv = expr.as<VarNode>()) {
+ return pv == var.get();
+ } else if (auto padd = expr.as<AddNode>()) {
+ return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>())
||
+ (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+ } else if (auto psub = expr.as<SubNode>()) {
+ return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>())
||
+ (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+ } else {
+ return false;
+ }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index,
bool* axis_missing,
+ bool* axis_duplicated, bool* same_order) {
+ auto cop = op.as<te::ComputeOpNode>();
+ if (cop == nullptr) {
+ return false;
+ }
+
+ std::vector<int> index_to_var_idx;
+ std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+ for (const auto& expr : index) {
+ if (!is_const_int(expr)) {
+ bool found = false;
+ for (size_t i = 0; i < cop->axis.size(); ++i) {
+ if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+ index_to_var_idx.push_back(i);
+ var_idx_ct[i]++;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return false;
+ }
+ }
+ }
+
+ *axis_missing = false; // Some axes are missing
+ *axis_duplicated = false; // Some axes appear more than once
+ *same_order = true; // The axis order is the same as op->axis
+ for (int ct : var_idx_ct) {
+ if (ct == 0) {
+ *axis_missing = true;
+ } else if (ct > 1) {
+ *axis_duplicated = true;
+ }
+ }
+ for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+ if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+ *same_order = false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const
VarNode*>* vars) {
+ PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+ if (const VarNode* op = node.as<VarNode>()) {
+ vars->insert(op);
+ }
+ });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+ bool found = false;
+ PostOrderVisit(expr, [&found](const ObjectRef& node) {
+ if (const CallNode* op = node.as<CallNode>()) {
+ if (op->op.as<OpNode>()->name == "tir.exp") {
+ found = true;
+ }
+ }
+ });
+ return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+ auto node = make_object<AccessAnalyzerNode>();
+ OperationMap<bool> has_branch;
+
+ // get all ops
+ node->ops_topo_order = TopoSortOps(tensors);
+
+ arith::Analyzer analyzer;
+
+ // build read & write access map
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+ } else if (auto cop = op.as<te::ComputeOpNode>()) {
+ TensorAccessExtractor extractor;
+ for (const auto& exp : cop->body) {
+ extractor.Extract(exp);
+ }
+
+ // read_by and read_from map
+ for (const auto& iter : extractor.buf_accesses) {
+ std::vector<std::vector<PrimExpr>>& accesses =
node->read_by[iter.first][op];
+ accesses.insert(accesses.begin(), iter.second.begin(),
iter.second.end());
+ }
+
+ node->read_from[op] = std::move(extractor.buf_accesses);
+ has_branch[op] = extractor.has_branch;
+
+ // compute number of common outer iterators
+ for (const auto& pair : node->read_from[op]) {
+ const te::Operation& producer = pair.first;
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ const Array<PrimExpr>& output_shape = op->output_shape(0);
+ const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+ int n_common;
+ for (n_common = 0;
+ n_common < static_cast<int>(std::min(output_shape.size(),
producer_shape.size()));
+ n_common++) {
+ if (!is_zero(analyzer.Simplify(output_shape[n_common] -
producer_shape[n_common]))) {
+ break;
+ }
+
+ bool direct_access = true;
Review comment:
Hmm, just curious: why is it named `direct_access`?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]