This is an automated email from the ASF dual-hosted git repository. andrewzhaoluo pushed a commit to branch aluo/rebase-08312022-autotensorization-fq2i-changes in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 6100211a536ea9644d2a153163a022b797cd1dab Author: Andrew Zhao Luo <[email protected]> AuthorDate: Fri Sep 2 14:42:04 2022 -0700 optional complete --- src/tir/ir/buffer_common.h | 16 +++++++++------- src/tir/ir/expr.cc | 8 ++++---- src/tir/ir/stmt.cc | 8 ++++---- src/tir/transforms/inject_ptx_async_copy.cc | 8 ++++---- src/tir/transforms/storage_rewrite.cc | 21 +++++++++------------ 5 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/tir/ir/buffer_common.h b/src/tir/ir/buffer_common.h index 5921c54d98..8dac41a02e 100644 --- a/src/tir/ir/buffer_common.h +++ b/src/tir/ir/buffer_common.h @@ -26,7 +26,7 @@ #include <tvm/ir/type.h> #include <tvm/runtime/data_type.h> -#include <optional> +#include <utility> namespace tvm { namespace tir { @@ -36,20 +36,22 @@ namespace tir { * * \param type The type to be checked. * - * \return An std::optional<DataType> object. If the type is a pointer - * to a primitive type, the object has a value which is the pointed-to - * type. Otherwise the object is nullopt. + * \return A (bool, DataType) pair. If the type is a pointer to a + * primitive, the boolean is true and the DataType is the pointed-to + * type. Otherwise, the boolean is false and the DataType is + * default-constructed. This can be replaced with std::optional with + * C++17 if/when C++17 is required. */ -inline std::optional<runtime::DataType> GetPointerType(const Type& type) { +inline std::pair<bool, runtime::DataType> GetPointerType(const Type& type) { if (type.defined()) { if (auto* ptr_type = type.as<PointerTypeNode>()) { if (auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) { - return prim_type->dtype; + return {true, prim_type->dtype}; } } } - return std::nullopt; + return {false, DataType()}; } } // namespace tir diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 59db4ea410..f841f94b5a 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -648,7 +648,7 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S // annotation tells us otherwise. int element_lanes = 1; auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); - if (pointer_type.has_value()) { + if (pointer_type.first) { // Cannot check element type of array, as it may be different than // the loaded type in some cases. // @@ -663,11 +663,11 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 // for discussion. - // ICHECK(dtype.element_of() == pointer_type->element_of()) + // ICHECK(dtype.element_of() == pointer_type.second.element_of()) // << "Type mismatch, cannot load type " << dtype << " from buffer " << // buffer_var->name_hint - // << " of type " << pointer_type.value(); - element_lanes = pointer_type->lanes(); + // << " of type " << pointer_type.second; + element_lanes = pointer_type.second.lanes(); } // The C-based codegens assume that all loads occur on a array with diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index e21d014fe1..524204f3d3 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -271,7 +271,7 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, // annotation tells us otherwise. int element_lanes = 1; auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); - if (pointer_type.has_value()) { + if (pointer_type.first) { // Currently cannot check element type of array, see Load::Load // for details. @@ -279,10 +279,10 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 // for discussion. - // ICHECK_EQ(value.dtype().element_of(), pointer_type->element_of()) + // ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of()) // << "Type mismatch, cannot store type " << value.dtype() << " into buffer " - // << buffer_var->name_hint << " of type " << pointer_type.value(); - element_lanes = pointer_type->lanes(); + // << buffer_var->name_hint << " of type " << pointer_type.second; + element_lanes = pointer_type.second.lanes(); } ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) || diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 8ee0d054e5..c74ce9d3d2 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -60,21 +60,21 @@ class PTXAsyncCopyInjector : public StmtMutator { if (bytes == 4 || bytes == 8 || bytes == 16) { auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); - ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + ICHECK(dst_elem_type.first && src_elem_type.first) << "Both store and load buffer should have a pointer type annotation."; int index_factor = 1; - if (dst_elem_type.value() != src_elem_type.value()) { + if (dst_elem_type != src_elem_type) { // The only case where src and dst have different dtypes is when the dst shared memory // is a byte buffer generated by merging dynamic shared memory. ICHECK(store->buffer.scope() == "shared.dyn"); - ICHECK(dst_elem_type.value() == DataType::UInt(8)); + ICHECK(dst_elem_type.second == DataType::UInt(8)); // BufferStore/Load have the "pointer reinterpret" semantics according to their // "value" dtype. Their "indices" are supposed to be applied after such pointer cast, // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value; // To replace BufferStore/Load with cp.async, we need to multiply the store index by // the byte size of the "value" dtype, to get the correct offset into the byte buffer. - index_factor = src_elem_type->bytes(); + index_factor = src_elem_type.second.bytes(); } if (indices_lanes == 1) { diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 177017f9a2..d15bed56fd 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -899,7 +899,7 @@ class StoragePlanRewriter : public StmtExprMutator { const StorageScope& scope, size_t const_nbits) { ICHECK(op != nullptr); // Re-use not successful, allocate a new buffer. - auto entry = std::make_unique<StorageEntry>(); + std::unique_ptr<StorageEntry> entry(new StorageEntry()); entry->attach_scope_ = attach_scope; entry->scope = scope; entry->elem_type = op->dtype.element_of(); @@ -1010,11 +1010,11 @@ class StoragePlanRewriter : public StmtExprMutator { // symbolic free list, for non constant items. std::list<StorageEntry*> sym_free_list_; // The allocation attach map - std::unordered_map<const Object*, std::vector<StorageEntry*>> attach_map_; + std::unordered_map<const Object*, std::vector<StorageEntry*> > attach_map_; // The allocation assign map std::unordered_map<const VarNode*, StorageEntry*> alloc_map_; // The allocations - std::vector<std::unique_ptr<StorageEntry>> alloc_vec_; + std::vector<std::unique_ptr<StorageEntry> > alloc_vec_; // The buffer objects being remapped std::unordered_map<const BufferNode*, Buffer> buffer_remap_; // analyzer @@ -1125,8 +1125,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // track the parameter itself. for (Var buffer_var : params) { auto pointer_type = GetPointerType(buffer_var->type_annotation); - if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) { - DataType dtype = pointer_type.value(); + if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) { + DataType dtype = pointer_type.second; PrimExpr extent = 0; OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncBufferMap); } @@ -1190,8 +1190,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { void HandleLetNode(Var let_var) { if (let_var->dtype.is_handle()) { auto pointer_type = GetPointerType(let_var->type_annotation); - if (pointer_type.has_value()) { - OnArrayDeclaration(let_var, pointer_type.value(), 0, BufferVarInfo::kLetNode); + if (pointer_type.first) { + OnArrayDeclaration(let_var, pointer_type.second, 0, BufferVarInfo::kLetNode); } else if (allow_untyped_pointers_) { OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode); } else { @@ -1463,13 +1463,10 @@ class VectorTypeRewriter : public StmtExprMutator { Stmt VisitStmt_(const LetStmtNode* op) final { auto it = rewrite_map_.find(op->var.get()); - PrimExpr value = this->VisitExpr(op->value); - Stmt body = this->VisitStmt(op->body); - Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var; - if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + if (it == rewrite_map_.end()) { return GetRef<Stmt>(op); } - return LetStmt(var, value, body); + return LetStmt(it->second.new_buffer_var, op->value, op->body); } Buffer RemapBuffer(Buffer buf) {
