junrushao1994 commented on a change in pull request #8467:
URL: https://github.com/apache/tvm/pull/8467#discussion_r672678925
##########
File path: src/tir/schedule/concrete_schedule.cc
##########
@@ -258,6 +258,95 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
}
/******** Schedule: loops manipulation ********/
+
+LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
+ CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
+ Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
+ StmtSRef result{nullptr};
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = tir::Fuse(state_, loop_srefs);
+ TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
+ this->state_->DebugVerify();
+ return CreateRV<LoopRV>(result);
+}
+
+Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
+ const Array<Optional<ExprRV>>& factor_rvs) {
+ class NotSingleInferFactorError : public ScheduleError {
+ public:
+ explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: only one factor can be specified as -1 or none";
+ }
+
+ String DetailRenderTemplate() const final {
+ return "Only one factor can be specified as -1 or none";
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ IRModule mod_;
+ };
+
+ class WrongFactorProductError : public ScheduleError {
+ public:
+ explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The product of factors is not larger than or equal to the extent of "
+ "loop";
+ }
+
+ String DetailRenderTemplate() const final {
+ return "The product of factors is not larger than or equal to the extent of loop {0}";
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+
+ IRModule mod_;
+ For loop_;
+ };
+ // Prepare for the splitting
+ StmtSRef loop_sref = this->GetSRef(loop_rv);
+ const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+ Array<PrimExpr> factors;
+ factors.reserve(factor_rvs.size());
+ int infer_index = -1;
+ PrimExpr tot_length = 1;
+ Array<StmtSRef> results;
+ TVM_TIR_SCHEDULE_BEGIN();
+ // infer factor if needed and check validity of factors
+ for (size_t i = 0; i < factor_rvs.size(); i++) {
+ if (!factor_rvs[i].defined()) {
+ factors.push_back(Integer(-1));
+ if (infer_index == -1) {
+ infer_index = i;
+ } else {
+ throw NotSingleInferFactorError(state_->mod);
+ }
+ } else {
+ PrimExpr factor = this->Get(factor_rvs[i].value());
+ factors.push_back(factor);
+ tot_length *= factor;
+ }
+ }
+ arith::Analyzer analyzer;
+ if (infer_index != -1) {
+ factors.Set(infer_index,
+ analyzer.Simplify(floordiv(loop->extent + tot_length - 1, tot_length)));
+ } else if (!analyzer.CanProve(tot_length >= loop->extent)) {
Review comment:
Use `this->analyzer` here instead of constructing a new local `arith::Analyzer` inside `Split`.
##########
File path: src/tir/schedule/concrete_schedule.cc
##########
@@ -258,6 +258,95 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
}
/******** Schedule: loops manipulation ********/
+
+LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
+ CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
+ Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
+ StmtSRef result{nullptr};
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = tir::Fuse(state_, loop_srefs);
+ TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
+ this->state_->DebugVerify();
+ return CreateRV<LoopRV>(result);
+}
+
+Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
+ const Array<Optional<ExprRV>>& factor_rvs) {
+ class NotSingleInferFactorError : public ScheduleError {
+ public:
+ explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: only one factor can be specified as -1 or none";
+ }
+
+ String DetailRenderTemplate() const final {
+ return "Only one factor can be specified as -1 or none";
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ IRModule mod_;
+ };
+
+ class WrongFactorProductError : public ScheduleError {
+ public:
+ explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The product of factors is not larger than or equal to the extent of "
+ "loop";
+ }
+
+ String DetailRenderTemplate() const final {
+ return "The product of factors is not larger than or equal to the extent of loop {0}";
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+
+ IRModule mod_;
+ For loop_;
+ };
+ // Prepare for the splitting
+ StmtSRef loop_sref = this->GetSRef(loop_rv);
+ const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+ Array<PrimExpr> factors;
+ factors.reserve(factor_rvs.size());
+ int infer_index = -1;
+ PrimExpr tot_length = 1;
+ Array<StmtSRef> results;
+ TVM_TIR_SCHEDULE_BEGIN();
+ // infer factor if needed and check validity of factors
+ for (size_t i = 0; i < factor_rvs.size(); i++) {
+ if (!factor_rvs[i].defined()) {
+ factors.push_back(Integer(-1));
+ if (infer_index == -1) {
+ infer_index = i;
+ } else {
+ throw NotSingleInferFactorError(state_->mod);
+ }
+ } else {
+ PrimExpr factor = this->Get(factor_rvs[i].value());
+ factors.push_back(factor);
+ tot_length *= factor;
+ }
+ }
+ arith::Analyzer analyzer;
+ if (infer_index != -1) {
+ factors.Set(infer_index,
+ analyzer.Simplify(floordiv(loop->extent + tot_length - 1, tot_length)));
+ } else if (!analyzer.CanProve(tot_length >= loop->extent)) {
+ LOG(INFO) << infer_index;
Review comment:
Remove this? The `LOG(INFO) << infer_index;` line looks like leftover debug output.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]