This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 29ce66eeb4 Change tir::GetPointerType to return std::optional<DataType> (#12458)
29ce66eeb4 is described below
commit 29ce66eeb46eae3c73079177609e2319be0366dd
Author: Krzysztof Parzyszek <[email protected]>
AuthorDate: Tue Aug 16 15:34:34 2022 -0500
Change tir::GetPointerType to return std::optional<DataType> (#12458)
Previously, GetPointerType returned a std::pair<bool, DataType> to emulate the behavior of std::optional.
---
src/tir/ir/buffer_common.h | 16 +++++++---------
src/tir/ir/expr.cc | 8 ++++----
src/tir/ir/stmt.cc | 8 ++++----
src/tir/transforms/inject_ptx_async_copy.cc | 8 ++++----
src/tir/transforms/storage_rewrite.cc | 8 ++++----
5 files changed, 23 insertions(+), 25 deletions(-)
diff --git a/src/tir/ir/buffer_common.h b/src/tir/ir/buffer_common.h
index 8dac41a02e..5921c54d98 100644
--- a/src/tir/ir/buffer_common.h
+++ b/src/tir/ir/buffer_common.h
@@ -26,7 +26,7 @@
#include <tvm/ir/type.h>
#include <tvm/runtime/data_type.h>
-#include <utility>
+#include <optional>
namespace tvm {
namespace tir {
@@ -36,22 +36,20 @@ namespace tir {
*
* \param type The type to be checked.
*
- * \return A (bool, DataType) pair. If the type is a pointer to a
- * primitive, the boolean is true and the DataType is the pointed-to
- * type. Otherwise, the boolean is false and the DataType is
- * default-constructed. This can be replaced with std::optional with
- * C++17 if/when C++17 is required.
+ * \return An std::optional<DataType> object. If the type is a pointer
+ * to a primitive type, the object has a value which is the pointed-to
+ * type. Otherwise the object is nullopt.
*/
-inline std::pair<bool, runtime::DataType> GetPointerType(const Type& type) {
+inline std::optional<runtime::DataType> GetPointerType(const Type& type) {
if (type.defined()) {
if (auto* ptr_type = type.as<PointerTypeNode>()) {
if (auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
- return {true, prim_type->dtype};
+ return prim_type->dtype;
}
}
}
- return {false, DataType()};
+ return std::nullopt;
}
} // namespace tir
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index f841f94b5a..59db4ea410 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -648,7 +648,7 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index,
PrimExpr predicate, S
// annotation tells us otherwise.
int element_lanes = 1;
auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
- if (pointer_type.first) {
+ if (pointer_type.has_value()) {
// Cannot check element type of array, as it may be different than
// the loaded type in some cases.
//
@@ -663,11 +663,11 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr
index, PrimExpr predicate, S
// See
https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
// for discussion.
- // ICHECK(dtype.element_of() == pointer_type.second.element_of())
+ // ICHECK(dtype.element_of() == pointer_type->element_of())
// << "Type mismatch, cannot load type " << dtype << " from buffer " <<
// buffer_var->name_hint
- // << " of type " << pointer_type.second;
- element_lanes = pointer_type.second.lanes();
+ // << " of type " << pointer_type.value();
+ element_lanes = pointer_type->lanes();
}
// The C-based codegens assume that all loads occur on a array with
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 524204f3d3..e21d014fe1 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -271,7 +271,7 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr
index, PrimExpr predicate,
// annotation tells us otherwise.
int element_lanes = 1;
auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
- if (pointer_type.first) {
+ if (pointer_type.has_value()) {
// Currently cannot check element type of array, see Load::Load
// for details.
@@ -279,10 +279,10 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr
index, PrimExpr predicate,
// See
https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
// for discussion.
- // ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of())
+ // ICHECK_EQ(value.dtype().element_of(), pointer_type->element_of())
// << "Type mismatch, cannot store type " << value.dtype() << " into
buffer "
- // << buffer_var->name_hint << " of type " << pointer_type.second;
- element_lanes = pointer_type.second.lanes();
+ // << buffer_var->name_hint << " of type " << pointer_type.value();
+ element_lanes = pointer_type->lanes();
}
ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) ||
diff --git a/src/tir/transforms/inject_ptx_async_copy.cc
b/src/tir/transforms/inject_ptx_async_copy.cc
index c74ce9d3d2..8ee0d054e5 100644
--- a/src/tir/transforms/inject_ptx_async_copy.cc
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -60,21 +60,21 @@ class PTXAsyncCopyInjector : public StmtMutator {
if (bytes == 4 || bytes == 8 || bytes == 16) {
auto dst_elem_type =
GetPointerType(store->buffer->data->type_annotation);
auto src_elem_type =
GetPointerType(load->buffer->data->type_annotation);
- ICHECK(dst_elem_type.first && src_elem_type.first)
+ ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
<< "Both store and load buffer should have a pointer type
annotation.";
int index_factor = 1;
- if (dst_elem_type != src_elem_type) {
+ if (dst_elem_type.value() != src_elem_type.value()) {
// The only case where src and dst have different dtypes is when
the dst shared memory
// is a byte buffer generated by merging dynamic shared memory.
ICHECK(store->buffer.scope() == "shared.dyn");
- ICHECK(dst_elem_type.second == DataType::UInt(8));
+ ICHECK(dst_elem_type.value() == DataType::UInt(8));
// BufferStore/Load have the "pointer reinterpret" semantics
according to their
// "value" dtype. Their "indices" are supposed to be applied
after such pointer cast,
// for example: ((*float16)(byte_buffer))[buffer->indices] =
fp16_value;
// To replace BufferStore/Load with cp.async, we need to
multiply the store index by
// the byte size of the "value" dtype, to get the correct offset
into the byte buffer.
- index_factor = src_elem_type.second.bytes();
+ index_factor = src_elem_type->bytes();
}
if (indices_lanes == 1) {
diff --git a/src/tir/transforms/storage_rewrite.cc
b/src/tir/transforms/storage_rewrite.cc
index ee6b3c836b..421af99ab3 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -1125,8 +1125,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
// track the parameter itself.
for (Var buffer_var : params) {
auto pointer_type = GetPointerType(buffer_var->type_annotation);
- if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) {
- DataType dtype = pointer_type.second;
+ if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) {
+ DataType dtype = pointer_type.value();
PrimExpr extent = 0;
OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncBufferMap);
}
@@ -1190,8 +1190,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
void HandleLetNode(Var let_var) {
if (let_var->dtype.is_handle()) {
auto pointer_type = GetPointerType(let_var->type_annotation);
- if (pointer_type.first) {
- OnArrayDeclaration(let_var, pointer_type.second, 0,
BufferVarInfo::kLetNode);
+ if (pointer_type.has_value()) {
+ OnArrayDeclaration(let_var, pointer_type.value(), 0,
BufferVarInfo::kLetNode);
} else if (allow_untyped_pointers_) {
OnArrayDeclaration(let_var, let_var->dtype, 0,
BufferVarInfo::kLetNode);
} else {