junrushao1994 commented on a change in pull request #9871: URL: https://github.com/apache/tvm/pull/9871#discussion_r785346258
########## File path: src/tir/schedule/primitive/blockize_tensorize.cc ########## @@ -0,0 +1,638 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief ScheduleError that the bindings of the inner block are not divisible by the subspace + * represented by the outer loops. + */ +class SubspaceNotDivisibleError : public ScheduleError { + public: + explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, Block inner_block) + : mod_(std::move(mod)), + scope_loop_(std::move(scope_loop)), + inner_block_(std::move(inner_block)) {} + + String FastErrorString() const final { + return "ScheduleError: The bindings of the inner block can not be blockized."; + } + + String DetailRenderTemplate() const final { + return "ScheduleError: The bindings of the inner block {0} can not be blockized by the loops " + "starting at {1}."; + } + + IRModule mod() const final { return mod_; } + + Array<ObjectRef> LocationsOfInterest() const final { return {inner_block_, scope_loop_}; } + + private: + IRModule mod_; + For scope_loop_; + Block inner_block_; +}; + +/*! 
+ * \brief Detect if bindings are a trivial case of the subspace division where we can divide the + * block iter bindings into two categories: + * 1. The binding covers no inner loop vars. + * 2. The binding covers only inner loop vars. + * + * The bindings are not required to be quasi-affine. + * + * \param iter_vars The input iterators + * \param bindings The values of iter_vars + * \param outer_iters Iterators outside the subspace. + * \param inner_iters Iterators of the subspace + * \param predicate The predicate constraints on the input iterators. + * \return The result of the subspace division. + */ +Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter_vars, + const Array<PrimExpr>& bindings, + const Array<Var>& outer_iters, + const Array<Var>& inner_iters, + const PrimExpr& predicate) { + if (!is_one(predicate)) return {}; + std::vector<Array<arith::IterMark>> res; + std::unordered_set<const VarNode*> outer_loop_vars; + std::unordered_set<const VarNode*> inner_loop_vars; + for (const Var& var : outer_iters) { + outer_loop_vars.insert(var.get()); + } + for (const Var& var : inner_iters) { + inner_loop_vars.insert(var.get()); + } + const arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1); + + for (size_t i = 0; i < bindings.size(); ++i) { + bool outer = UsesVar( + bindings[i], [&outer_loop_vars](const VarNode* var) { return outer_loop_vars.count(var); }); + bool inner = UsesVar( + bindings[i], [&inner_loop_vars](const VarNode* var) { return inner_loop_vars.count(var); }); + arith::IterMark iter_mark; + if (bindings[i]->IsInstance<VarNode>()) { + iter_mark = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + iter_mark = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } + if (outer && !inner) { + arith::IterMark outer{nullptr}; + const auto& outer_iter = iter_mark; + const auto& inner_iter = 
unit_iter_mark; + res.push_back({outer_iter, inner_iter}); + } else if (inner && !outer) { + const auto& outer_iter = unit_iter_mark; + const auto& inner_iter = iter_mark; + res.push_back({outer_iter, inner_iter}); + } else if (!outer && !inner) { + const auto& outer_iter = unit_iter_mark; + const auto& inner_iter = unit_iter_mark; + res.push_back({outer_iter, inner_iter}); + } else { + return {}; + } + } + res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)), + arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))}); + return res; +} + +/*! + * \brief Generate the blockized init block. + * \param block The original block with init. + * \param inner_block_realize The block realize of the inner block after blockize. + * \param inner_loops The inner loops after blockize. + * \return The subtree of the init block and its outer loops. + */ +Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize, + const std::vector<const ForNode*>& inner_loops) { + Array<IterVar> init_block_iters; + Array<PrimExpr> init_bindings; + const Block& inner_block = inner_block_realize->block; + + // Step 1: Collect data-parallel block iters + for (size_t i = 0; i < inner_block->iter_vars.size(); i++) { + const IterVar& iter_var = inner_block->iter_vars[i]; + const PrimExpr& binding = inner_block_realize->iter_values[i]; + if (iter_var->iter_type == IterVarType::kDataPar && + UsesVar(block->init.value(), + [&iter_var](const VarNode* var) { return var == iter_var->var.get(); })) { + init_block_iters.push_back(iter_var); + init_bindings.push_back(binding); + } + } + + // Step 2: Collect loops related to iters of the init block + std::vector<const ForNode*> init_loops; + for (const ForNode* inner_loop : inner_loops) { + for (const PrimExpr& init_binding : init_bindings) { + if (UsesVar(init_binding, + [inner_loop](const VarNode* var) { return var == inner_loop->loop_var.get(); })) { + init_loops.push_back(inner_loop); + } + } + } + + // Step 3: 
Create new block iters for the init block + Map<Var, PrimExpr> subst_map; + for (size_t i = 0; i < init_block_iters.size(); i++) { + IterVar new_iter_var = init_block_iters[i]; + auto* new_init_var_node = new_iter_var.CopyOnWrite(); + Var old_var = new_iter_var->var; + new_init_var_node->var = old_var.copy_with_suffix("_init"); + subst_map.Set(old_var, new_iter_var->var); + init_block_iters.Set(i, std::move(new_iter_var)); + } + + // Step 4: Generate loop nests and the init block + Block init_block{/*iter_vars=*/init_block_iters, // + /*reads=*/{}, // + /*writes=*/block->writes, // + /*name_hint=*/block->name_hint + "_init", // + /*body=*/block->init.value(), // + /*init=*/NullOpt}; + Stmt new_init = BlockRealize( + /*iter_values=*/init_bindings, + /*predicate=*/inner_block_realize->predicate, + /*block=*/std::move(init_block)); + + // Step 5: Generate the parent loops for the init block + for (const ForNode* init_loop : init_loops) { + ObjectPtr<ForNode> new_loop = make_object<ForNode>(*init_loop); + new_loop->loop_var = init_loop->loop_var.copy_with_suffix(""); + subst_map.Set(init_loop->loop_var, new_loop->loop_var); + new_loop->body = std::move(new_init); + new_init = For(new_loop); + } + + // Step 6: Substitute with new loop variables and block iters to prevent duplication of + // variables in the outer block. + new_init = Substitute(new_init, subst_map); + + return new_init; +} + +/*! + * \brief A helper to collect the parent loops of the block. The loops are divided into two groups, + * 'outer_loops', and 'inner_loops', by a specified loop as the separator. 'outer_loops' are the + * ancestor loops of the separator loop. 'inner_loops' include the separator loop itself, and its + * successor loops. It is possible that 'outer_loops' is empty. + */ +class LoopSubspaceCollector { + public: + /*! + * \brief Collect the parent loops of the block and store the result in the corresponding fields. + * \param block_sref The sref to the target block. 
+ * \param loop_sref The sref to the separator loop. + */ + void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) { + bool inner = true; + for (StmtSRefNode* current_sref = block_sref->parent; + current_sref && current_sref->stmt->IsInstance<ForNode>(); + current_sref = current_sref->parent) { + const auto* current_loop = current_sref->StmtAs<ForNode>(); + ICHECK(current_loop); + if (inner) { + inner_loops.push_back(current_loop); + inner_loop_vars.push_back(current_loop->loop_var); + } else { + outer_loops.push_back(current_loop); + outer_loop_vars.push_back(current_loop->loop_var); + } + loop_var_domain.Set(current_loop->loop_var, + Range::FromMinExtent(current_loop->min, current_loop->extent)); + if (current_sref == loop_sref.get()) inner = false; + } + } + /*! \brief Outer loops which are ancestors of the separator. */ + std::vector<const ForNode*> outer_loops; + /*! \brief Inner loops which are the separator itself or its successors. */ + std::vector<const ForNode*> inner_loops; + /*! \brief Loop variables of the outer loops. */ + Array<Var> outer_loop_vars; + /*! \brief Loop variables of the inner loops. */ + Array<Var> inner_loop_vars; + /*! \brief Domain of the loop variables. */ + Map<Var, Range> loop_var_domain; +}; + +/*! + * \brief Check the bindings of the block iters can be divided by a subspace collected by the + * collector. + * \param mod The current IR module. + * \param block_realize The block realize to be checked. + * \param collector The collector which has collected the loops of the block. + * \param analyzer The arithmetic analyzer. + * \return The result of the subspace division. + * \throws ScheduleError If the bindings are not divisible by the subspace. 
+ */ +Array<Array<arith::IterMark>> CheckSubspaceDivisible(const IRModule& mod, + const BlockRealize& block_realize, + const LoopSubspaceCollector& collector, + arith::Analyzer* analyzer) { + const Block& block = block_realize->block; + DiagnosticContext diag_ctx(DiagnosticContext::Default(mod)); + + Array<Array<arith::IterMark>> division = + arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain, + collector.inner_loop_vars, block_realize->predicate, + /*require_bijective=*/false, analyzer, diag_ctx); + + if (division.empty()) { + // If we can't do perfect subspace division, check if it is a trivial case of subspace division. + // In this case, we can still blockize. + division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values, + collector.outer_loop_vars, collector.inner_loop_vars, + block_realize->predicate); + } + if (division.empty()) { + throw SubspaceNotDivisibleError(mod, GetRef<For>(collector.inner_loops.back()), block); + } + return division; +} + +/*! + * \brief The binding extractor to compute the bindings of the outer and the inner blocks after + * blockize. + */ +class BlockizedBindingExtractor { + public: + /*! + * \brief Extract bindings for blockize. + * \param iter_vars The iter vars of the original inner block. + * \param division The result of the subspace division. + */ + void ExtractBindings(const Array<IterVar>& iter_vars, + const Array<Array<arith::IterMark>>& division) { + ICHECK(iter_vars.size() + 1 == division.size()); Review comment: ```suggestion ICHECK_EQ(iter_vars.size() + 1, division.size()); ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
