This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git

The following commit(s) were added to refs/heads/main by this push:
     new d777e7c  [TIR][REFACTOR] Enforce allocate to use the correct var pointer hint. (#7216)
d777e7c is described below

commit d777e7c612cf7a9aae4d8433c36f031c6b6f985c
Author:     Tianqi Chen <tqc...@users.noreply.github.com>
AuthorDate: Wed Jan 6 19:59:12 2021 -0500

    [TIR][REFACTOR] Enforce allocate to use the correct var pointer hint. (#7216)

    * [TIR][REFACTOR] Enforce allocate to only accept buffer_var with correct PtrType.

    This is a refactoring step to clean up a legacy issue of opaque buffer vars
    without pointer type information. Now every allocation comes with the right
    pointer data type.

    Places touched:
    - TVMScript parser: add the right info to get the correct pointer type.
    - Cross-thread allreduce: set the right pointer type.
    - Storage rewrite: set up the right pointer type.
    - Custom dtype: remap the variables with the new pointer type.

    * Address comments

    Co-authored-by: Tristan Konolige <tristan.konol...@gmail.com>
---
 include/tvm/tir/op.h                          |   2 +-
 python/tvm/script/parser.py                   |  25 +++--
 python/tvm/script/scope_handler.py            |  13 ++-
 python/tvm/tir/buffer.py                      |   5 +-
 src/driver/driver_api.cc                      |   3 +-
 src/target/source/codegen_cuda.cc             |   6 +-
 src/te/operation/cross_thread_reduction.cc    |   6 +-
 src/tir/ir/buffer.cc                          |  14 ++-
 src/tir/ir/stmt.cc                            |   9 +-
 src/tir/ir/stmt_functor.cc                    |  14 ++-
 src/tir/transforms/lower_custom_datatypes.cc  | 147 ++++++++++++++++++--------
 src/tir/transforms/lower_thread_allreduce.cc  |  16 +--
 src/tir/transforms/storage_rewrite.cc         |  34 +++---
 tests/cpp/ir_functor_test.cc                  |  10 +-
 tests/python/unittest/test_tir_constructor.py |   1 +
 15 files changed, 209 insertions(+), 96 deletions(-)

diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 61481d9..4a907fc 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -1241,7 +1241,7 @@ inline void DivAmbiguityError(const TA& a) {
                 "please call div, indexdiv/indexmod, "
                 "floordiv/floormod or truncdiv/truncmod directly "
                 "to avoid ambiguity in the code. "
-                "Checkout these functions in expr_operator.h.");
+                "Checkout these functions in tir/op.h.");
 }
 
 // The following code are not intended to be used in the codebase.
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index db976d0..33b0bab 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -230,6 +230,19 @@ class TVMScriptParser(Transformer):
         """Match the arguments of a function call in the AST to the required arguments of
         the function. This handles positional arguments, positional arguments specified by
         name, keyword arguments, and varargs.
+
+        Parameters
+        ----------
+        func : Function
+            The function that provides the signature
+
+        node_call: ast.Call
+            The AST call node that calls into the function.
+
+        Returns
+        -------
+        arg_list : list
+            The parsed positional argument.
""" assert isinstance(node_call, ast.Call) # collect arguments @@ -435,8 +448,8 @@ class TVMScriptParser(Transformer): node.rhs.span, ) # Pattern 4 - func.enter_scope(node, self.context) arg_list = self.parse_arg_list(func, node.rhs) + func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) func.body = self.parse_body(node) return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) elif isinstance(func, SpecialStmt): @@ -532,9 +545,9 @@ class TVMScriptParser(Transformer): self.current_col_offset = node.span.start_column self.context.new_scope(nodes=node.body.stmts) # for scope handler process the scope - func.enter_scope(node, self.context) - func.body = self.parse_body(node) arg_list = self.parse_arg_list(func, node.rhs) + func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) + func.body = self.parse_body(node) res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) # exit the scope self.context.pop_scope() @@ -571,9 +584,9 @@ class TVMScriptParser(Transformer): self.current_col_offset = node.body.span.start_column self.context.new_scope(nodes=node.body.stmts) # with scope handler process the scope - func.enter_scope(node, self.context) - func.body = self.parse_body(node) arg_list = self.parse_arg_list(func, node.rhs) + func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) + func.body = self.parse_body(node) res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) # exit the scope self.context.pop_scope() @@ -689,7 +702,7 @@ class TVMScriptParser(Transformer): if isinstance(func, Intrin) and func.stmt: return func.handle(arg_list, node.call.func_name.span) elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol: - func.enter_scope(node, self.context) + func.enter_scope(node, self.context, arg_list, node.call.func_name.span) func.body = self.parse_body(node) return func.exit_scope(node, self.context, arg_list, node.call.func_name.span) elif isinstance(func, SpecialStmt) and not func.def_symbol: diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index 7f252e3..21ed7f6 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -35,7 +35,7 @@ class ScopeHandler: def signature(self): return "tir." 
+ self.func.__name__, get_param_list(self.func) - def enter_scope(self, node, context): + def enter_scope(self, node, context, arg_list, span): pass def exit_scope(self, node, context, arg_list, span): @@ -86,7 +86,7 @@ class Allocate(WithScopeHandler): super().__init__(allocate, concise_scope=True, def_symbol=True) self.buffer_var = None - def enter_scope(self, node, context): + def enter_scope(self, node, context, arg_list, span): # define buffer vars in symbol table if isinstance(node, ast.With): names = WithScopeHandler.get_optional_var_names(node, context) @@ -98,7 +98,12 @@ class Allocate(WithScopeHandler): else: raise Exception("Internal Bug") - self.buffer_var = tvm.te.var(name, "handle", span=from_synr_span(node.lhs.id.span)) + def setup_buffer_var(extents, dtype, scope, condition=True, span=None): + """Setup buffer var for a given type.""" + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) + self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) + + setup_buffer_var(*arg_list, span=from_synr_span(node.lhs.id.span)) context.update_symbol(name, self.buffer_var) @@ -187,7 +192,7 @@ class ForScopeHandler(ScopeHandler): super().__init__(func) self.loop_vars = None - def enter_scope(self, node, context): + def enter_scope(self, node, context, arg_list, span): assert isinstance(node, ast.For) loop_var_names = list() diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 2f50aa8..95966a5 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -247,7 +247,10 @@ def decl_buffer( shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" elem_offset = Var("%s_elem_offset" % name, shape_dtype) if data is None: - data = Var(name, PointerType(PrimType(dtype)), span) + # Bool is represented as uint1 in the IR, but stored as int8 + storage_type = PrimType(dtype) + storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type + data = Var(name, PointerType(storage_type), span) return _ffi_api.Buffer( data, dtype, diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f88b621..bbbb7e3 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -69,7 +69,8 @@ Target DefaultTargetHost(Target target) { tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact) { - auto data = tir::Var(name, PointerType(PrimType(dtype))); + DataType storage_dtype = (dtype == DataType::Bool() ? 
DataType::Int(8) : dtype); + auto data = tir::Var(name, PointerType(PrimType(storage_dtype))); bool has_any = false; if (!compact) { for (const auto& it : shape) { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index c0fb39f..6c73716 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -581,7 +581,11 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; const VarNode* buffer = op->buffer_var.as<VarNode>(); - std::string scope = alloc_storage_scope_.at(buffer); + auto it = alloc_storage_scope_.find(buffer); + ICHECK(it != alloc_storage_scope_.end()) + << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key"; + + std::string scope = it->second; if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index b0fb9b6..da20dd8 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -145,7 +145,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, Array<PrimExpr> lhs; for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; - normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), DataType::Handle()); + normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), + PointerType(PrimType(t))); lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); } Array<PrimExpr> init_value = combiner->identity_element; @@ -175,7 +176,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, freduce_args.push_back(const_true(1)); std::vector<Var> res_handles(size); for (size_t idx = 0; idx < size; ++idx) { - res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle()); + DataType dtype = reduces[idx]->dtype; + res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype))); freduce_args.push_back(res_handles[idx]); } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 23a2b3a..1667eb7 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -46,8 +46,9 @@ Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) { } Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, Span span) { - return Buffer(Var(name, PointerType(PrimType(dtype)), span), dtype, shape, Array<PrimExpr>(), - PrimExpr(), name, "", 0, 0, kDefault, span); + DataType storage_dtype = (dtype == DataType::Bool() ? 
DataType::Int(8) : dtype); + return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape, + Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span); } // Split the given expression w.r.t the add operator @@ -384,9 +385,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides, PrimExpr elem_offset, String name, String scope, int data_alignment, int offset_factor, BufferType buffer_type, Span span) { - ICHECK(IsPointerType(data->type_annotation, dtype)) + DataType storage_dtype = dtype; + // specially handle bool + if (storage_dtype == DataType::Bool()) { + storage_dtype = DataType::Int(8); + } + ICHECK(IsPointerType(data->type_annotation, storage_dtype)) << "Buffer data field expect to have the right pointer type annotation" - << " annotation=" << data->type_annotation << ", dtype=" << dtype; + << " annotation=" << data->type_annotation << ", storage_dtype=" << storage_dtype; auto n = make_object<BufferNode>(); n->data = std::move(data); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 86960d9..fd03046 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -274,9 +274,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition, Stmt body, Span span) { - // TODO(tvm-team): Add invariant check to make sure - // IsPointerPType(buffer_var->type_annotation, dtype) - // once we fix the allocate tvm script printing. + CHECK(IsPointerType(buffer_var->type_annotation, dtype)) + << "The allocated data type (" << dtype + << ") does not match the type annotation of the buffer " << buffer_var << " (" + << buffer_var->type_annotation + << "). 
The data type should be an element of the pointer type."; + for (size_t i = 0; i < extents.size(); ++i) { ICHECK(extents[i].defined()); ICHECK(extents[i].dtype().is_scalar()); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 529380b..e0ccb49 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -480,7 +480,6 @@ class IRSubstitue : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - // NOTE: we do not explicit recursivly mutate op->buffer_var PrimExpr ret = StmtExprMutator::VisitExpr_(op); op = ret.as<LoadNode>(); if (auto mapped_var = vmap_(op->buffer_var)) { @@ -491,7 +490,6 @@ class IRSubstitue : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { - // NOTE: we do not explicit recursivly mutate op->buffer_var Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as<StoreNode>(); if (auto mapped_var = vmap_(op->buffer_var)) { @@ -501,6 +499,18 @@ class IRSubstitue : public StmtExprMutator { } } + Stmt VisitStmt_(const AttrStmtNode* op) final { + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as<AttrStmtNode>(); + // remap var node in attr + if (const auto* var_node = op->node.as<VarNode>()) { + if (auto mapped_var = vmap_(GetRef<Var>(var_node))) { + return AttrStmt(mapped_var, op->attr_key, op->value, op->body); + } + } + return ret; + } + private: std::function<Optional<PrimExpr>(const Var&)> vmap_; }; diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index a3e5a92..21f1b18 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -44,14 +44,14 @@ class CustomDatatypesLowerer : public StmtExprMutator { public: explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} - inline PrimExpr VisitExpr_(const CastNode* op) final { + PrimExpr VisitExpr_(const CastNode* op) final { auto type_code = op->dtype.code(); auto src_type_code = op->value.dtype().code(); // If either datatype is a registered custom datatype, we must lower. 
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
-                       datatype::Registry::Global()->GetTypeRegistered(src_type_code);
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
+                         datatype::Registry::Global()->GetTypeRegistered(src_type_code);
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    if (toBeLowered) {
+    if (to_be_lowered) {
       auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
       ICHECK(lower) << "Cast lowering function for target " << target_ << " destination type "
                     << static_cast<unsigned>(type_code) << " source type "
@@ -61,7 +61,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return expr;
   }
 
-  inline PrimExpr VisitExpr_(const FloatImmNode* imm) final {
+  PrimExpr VisitExpr_(const FloatImmNode* imm) final {
     auto type_code = imm->dtype.code();
     auto e = GetRef<PrimExpr>(imm);
     if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
@@ -73,35 +73,86 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return e;
   }
 
-  inline Stmt VisitStmt_(const AllocateNode* allocate) final {
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
-    Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
-    allocate = stmt.as<AllocateNode>();
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    Var var = GetRef<Var>(op);
 
-    if (toBeLowered) {
+    auto itr = var_remap_.find(var);
+    if (itr != var_remap_.end()) {
+      return itr->second;
+    } else {
+      return std::move(var);
+    }
+  }
+
+  Stmt VisitStmt_(const AllocateNode* allocate) final {
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
+
+    if (to_be_lowered) {
       auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
-      return Allocate(allocate->buffer_var, new_allocate_type, allocate->extents,
-                      allocate->condition, allocate->body);
+      auto new_buffer_var =
+          Var(allocate->buffer_var->name_hint, PointerType(PrimType(new_allocate_type)));
+      var_remap_[allocate->buffer_var] = new_buffer_var;
+
+      Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
+      allocate = stmt.as<AllocateNode>();
+
+      return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition,
+                      allocate->body);
+    } else {
+      return StmtExprMutator::VisitStmt_(allocate);
     }
-    return stmt;
   }
 
-  inline PrimExpr VisitExpr_(const LoadNode* load) final {
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
+  PrimExpr VisitExpr_(const LoadNode* load) final {
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
     PrimExpr expr = StmtExprMutator::VisitExpr_(load);
     load = expr.as<LoadNode>();
-    if (toBeLowered) {
+    if (to_be_lowered) {
       auto new_load_type = DataType::UInt(load->dtype.bits());
-      return Load(new_load_type, load->buffer_var, load->index, load->predicate);
+      auto buffer_var = load->buffer_var;
+      auto it = var_remap_.find(buffer_var);
+      if (it != var_remap_.end()) {
+        buffer_var = it->second;
+      }
+      return Load(new_load_type, buffer_var, load->index, load->predicate);
     }
     return expr;
   }
 
-  inline PrimExpr VisitExpr_(const CallNode* call) final {
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
+  Stmt VisitStmt_(const StoreNode* op) final {
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<StoreNode>();
+
+    auto it = var_remap_.find(op->buffer_var);
+    if (it != var_remap_.end()) {
+      return Store(it->second, op->value, op->index, op->predicate);
+    } else {
+      return ret;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<AttrStmtNode>();
+    // Due to legacy reasons, some attr node can contain
+    // information(e.g. alignment) of buffer variables.
+    // remap these vars when needed
+    // TODO(tvm-team): remove the rewriting once the buffer var
+    // attrs are being refactored into the corresponding definition node
+    if (const auto* var_node = op->node.as<VarNode>()) {
+      auto it = var_remap_.find(GetRef<Var>(var_node));
+      if (it != var_remap_.end()) {
+        return AttrStmt(it->second, op->attr_key, op->value, op->body);
+      }
+    }
+    return ret;
+  }
+
+  PrimExpr VisitExpr_(const CallNode* call) final {
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
     PrimExpr expr = StmtExprMutator::VisitExpr_(call);
     call = expr.as<CallNode>();
-    if (toBeLowered) {
+    if (to_be_lowered) {
       auto op = call->op.as<OpNode>();
       ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented";
       auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code());
@@ -113,38 +164,42 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return expr;
   }
 
-#define DEFINE_MUTATE(OP, NodeName)                                                 \
-  inline PrimExpr VisitExpr_(const NodeName* op) final {                            \
-    auto type_code = op->dtype.code();                                              \
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code);  \
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);                                \
-    op = expr.as<NodeName>();                                                       \
-    if (toBeLowered) {                                                              \
-      auto lower = datatype::Get##OP##LowerFunc(target_, type_code);                \
-      ICHECK(lower) << #OP " lowering function for target " << target_ << " type "  \
-                    << static_cast<unsigned>(type_code) << " not found";            \
-      return (*lower)(expr);                                                        \
-    }                                                                               \
-    return expr;                                                                    \
+#define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName)                                 \
+  PrimExpr VisitExpr_(const NodeName* op) final {                                    \
+    auto type_code = op->dtype.code();                                               \
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);                                 \
+    op = expr.as<NodeName>();                                                        \
+    if (to_be_lowered) {                                                             \
+      auto lower = datatype::Get##OP##LowerFunc(target_, type_code);                 \
+      ICHECK(lower) << #OP " lowering function for target " << target_ << " type "   \
+                    << static_cast<unsigned>(type_code) << " not found";             \
+      return (*lower)(expr);                                                         \
+    }                                                                                \
+    return expr;                                                                     \
   }
 
-  DEFINE_MUTATE(Add, AddNode);
-  DEFINE_MUTATE(Sub, SubNode);
-  DEFINE_MUTATE(Mul, MulNode);
-  DEFINE_MUTATE(Div, DivNode);
-  DEFINE_MUTATE(Mod, ModNode);
-  DEFINE_MUTATE(Min, MinNode);
-  DEFINE_MUTATE(Max, MaxNode);
-  DEFINE_MUTATE(EQ, EQNode);
-  DEFINE_MUTATE(NE, NENode);
-  DEFINE_MUTATE(LT, LTNode);
-  DEFINE_MUTATE(LE, LENode);
-  DEFINE_MUTATE(GT, GTNode);
-  DEFINE_MUTATE(GE, GENode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Add, AddNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Sub, SubNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mul, MulNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Div, DivNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mod, ModNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Min, MinNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Max, MaxNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(EQ, EQNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(NE, NENode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LT, LTNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LE, LENode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GT, GTNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GE, GENode);
   // Later changes may need to add more mutate functions as we support workloads with more ops.
+#undef TVM_DEFINE_MUTATE_CUSTOM_DTYPE
+
  private:
   std::string target_;
+  // remap buffer vars
+  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
 };
 
 namespace transform {
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index c24e26b..f6cb096 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -224,14 +224,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
     PrimExpr index(0);
     for (size_t idx = 0; idx < size; ++idx) {
-      shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
+      Type ptr_type = PointerType(PrimType(types[idx]));
+      shared_bufs[idx] = Var("red_buf" + std::to_string(idx), ptr_type);
       PrimExpr pred = const_true(types[idx].lanes());
       seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred));
 
       // Uses a local variable to store the shuffled data.
       // Later on, this allocation will be properly attached to this statement.
-      Var var("t" + std::to_string(idx), types[idx]);
-      Stmt s = Allocate(var, var.dtype(), {PrimExpr(1)}, pred, Evaluate(0));
+      Var var("t" + std::to_string(idx), ptr_type);
+      Stmt s = Allocate(var, types[idx], {PrimExpr(1)}, pred, Evaluate(0));
       local_vars.push_back(s);
     }
 
@@ -239,14 +240,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // a divergent control flow. Here it uses a variable to cache the current
     // active channels.
     //
-    Var mask_var("mask", DataType::UInt(32));
+    DataType mask_dtype = DataType::UInt(32);
+    Var mask_var("mask", PointerType(PrimType(mask_dtype)));
     {
       PrimExpr pred = const_true(1);
-      PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
+      PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
       seq.emplace_back(Store(mask_var, mask, index, pred));
       // Push allocation with an empty body. Later this will be fixed
       // when the entire body is ready.
-      auto stmt = Allocate(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, Evaluate(0));
+      auto stmt = Allocate(mask_var, mask_dtype, {PrimExpr(1)}, pred, Evaluate(0));
       local_vars.push_back(stmt);
     }
 
@@ -338,7 +340,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // previous iteration on the same buffer.
       seq.emplace_back(SyncThread("shared"));
       for (size_t idx = 0; idx < size; ++idx) {
-        shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
+        shared_bufs[idx] = Var("red_buf" + std::to_string(idx), PointerType(PrimType(types[idx])));
         PrimExpr pred = const_true(types[idx].lanes());
         seq.emplace_back(Store(shared_bufs[idx], values[idx],
                                BufIndex(reduce_index, group_index, reduce_extent), pred));
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 78c5ca7..d4c5ca0 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -23,6 +23,7 @@
  * Re-write data access to enable memory sharing when possible.
 */
 #include <tvm/arith/analyzer.h>
+#include <tvm/ir/type.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/target/target_info.h>
 #include <tvm/tir/analysis.h>
@@ -934,7 +935,12 @@ class VectorAllocRewriter : public StmtExprMutator {
       if (me->base % factor == 0 && me->coeff % factor == 0) {
         extents.Set(extents.size() - 1,
                     extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
-        return Allocate(op->buffer_var, tvec[0], extents, op->condition, op->body);
+        // create a new buffer var
+        DataType new_dtype = tvec[0];
+        Var new_buffer_var(op->buffer_var->name_hint, PointerType(PrimType(new_dtype)));
+        // update the remap req.
+        var_remap_.Set(op->buffer_var, new_buffer_var);
+        return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body);
       }
     }
     return stmt;
@@ -949,23 +955,21 @@ class VectorAllocRewriter : public StmtExprMutator {
 
   // Internal access map
   std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
+  // Variables to remap
+  Map<tir::Var, PrimExpr> var_remap_;
   // internal analyzer
   arith::Analyzer analyzer_;
 };
 
-Stmt StorageRewrite(Stmt stmt) {
-  stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
-  return VectorAllocRewriter()(std::move(stmt));
-}
-
 PrimFunc PointerValueTypeRewrite(PrimFunc f) {
   auto* n = f.CopyOnWrite();
   VectorAllocRewriter rewriter;
-  n->body = rewriter(n->body);
+  n->body = rewriter(std::move(n->body));
+  Map<tir::Var, PrimExpr> var_remap = std::move(rewriter.var_remap_);
   Array<tir::Var> args;
-  Map<tir::Var, PrimExpr> remap_vars;
+
+  // rewrite paramters if needed.
   for (Var var : f->params) {
     if (var.dtype().is_handle()) {
       const auto& tvec = rewriter.acc_map_[var.get()];
@@ -973,15 +977,14 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) {
       if (tvec.size() == 1) {
         tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0])));
         args.push_back(new_var);
-        remap_vars.Set(var, new_var);
-
+        var_remap.Set(var, new_var);
       } else {
         // always set data type to be non vectorized so
         // load/store can still work via scalarization
         if (tvec.size() != 0 && !var->type_annotation.defined()) {
           tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1))));
           args.push_back(new_var);
-          remap_vars.Set(var, new_var);
+          var_remap.Set(var, new_var);
         } else {
           args.push_back(var);
         }
@@ -991,9 +994,13 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) {
     }
   }
 
+  // no variable remap is needed.
+  if (var_remap.size() == 0) return f;
+
+  // remap the variables.
   ICHECK_EQ(args.size(), n->params.size());
   n->params = args;
-  n->body = Substitute(n->body, remap_vars);
+  n->body = Substitute(n->body, var_remap);
   return f;
 }
 
@@ -1003,8 +1010,7 @@ Pass StorageRewrite() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
     n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
-    n->body = VectorAllocRewriter()(std::move(n->body));
-    return f;
+    return PointerValueTypeRewrite(std::move(f));
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
 }
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index 683caaa..9be8398 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -114,8 +114,9 @@ TEST(IRF, StmtVisitor) {
   auto fmaketest = [&]() {
     auto z = x + 1;
     Stmt body = Evaluate(z);
-    Var buffer("b", DataType::Handle());
-    return Allocate(buffer, DataType::Float(32), {z, z}, const_true(), body);
+    DataType dtype = DataType::Float(32);
+    Var buffer("b", PointerType(PrimType(dtype)));
+    return Allocate(buffer, dtype, {z, z}, const_true(), body);
   };
   v(fmaketest());
   ICHECK_EQ(v.count, 3);
@@ -140,8 +141,9 @@ TEST(IRF, StmtMutator) {
   auto fmakealloc = [&]() {
     auto z = x + 1;
     Stmt body = Evaluate(z);
-    Var buffer("b", DataType::Handle());
-    return Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body);
+    DataType dtype = DataType::Float(32);
+    Var buffer("b", PointerType(PrimType(dtype)));
+    return Allocate(buffer, dtype, {1, z}, const_true(), body);
   };
 
   auto fmakeif = [&]() {
diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py
index 3cde5d7..2bf4ba5 100644
--- a/tests/python/unittest/test_tir_constructor.py
+++ b/tests/python/unittest/test_tir_constructor.py
@@ -154,6 +154,7 @@ def test_stmt_constructor():
     assert x.index.value == 10
     assert x.value.value == 1
 
+    buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
     x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
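
For reference, a minimal sketch (not part of the commit) of the invariant the new CHECK in Allocate::Allocate enforces, built only from constructors that already appear in the diff above: the buffer var handed to tir.Allocate must carry a PointerType(PrimType(dtype)) annotation that matches the allocated dtype, while an opaque "handle" var is now rejected.

    import tvm

    # Typed buffer var, mirroring the updated test_tir_constructor.py above.
    buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
    nop = tvm.tir.Evaluate(tvm.tir.const(0, "int32"))
    alloc = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)

    # An opaque handle var, as the TVMScript parser used to create via
    # tvm.te.var(name, "handle"), would now trip the CHECK (left commented out):
    # bad_var = tvm.te.var("buf", "handle")
    # tvm.tir.Allocate(bad_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)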