This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 29ce66eeb4 Change tir::GetPointerType to return std::optional<DataType> (#12458)
29ce66eeb4 is described below

commit 29ce66eeb46eae3c73079177609e2319be0366dd
Author: Krzysztof Parzyszek <[email protected]>
AuthorDate: Tue Aug 16 15:34:34 2022 -0500

    Change tir::GetPointerType to return std::optional<DataType> (#12458)
    
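    It previously returned a std::pair<bool, DataType> to emulate
    std::optional, which could not be used before C++17 was required.
    Callers now check has_value() instead of .first, and read the
    pointed-to type through value() or operator-> instead of .second.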
---
 src/tir/ir/buffer_common.h                  | 16 +++++++---------
 src/tir/ir/expr.cc                          |  8 ++++----
 src/tir/ir/stmt.cc                          |  8 ++++----
 src/tir/transforms/inject_ptx_async_copy.cc |  8 ++++----
 src/tir/transforms/storage_rewrite.cc       |  8 ++++----
 5 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/src/tir/ir/buffer_common.h b/src/tir/ir/buffer_common.h
index 8dac41a02e..5921c54d98 100644
--- a/src/tir/ir/buffer_common.h
+++ b/src/tir/ir/buffer_common.h
@@ -26,7 +26,7 @@
 #include <tvm/ir/type.h>
 #include <tvm/runtime/data_type.h>
 
-#include <utility>
+#include <optional>
 
 namespace tvm {
 namespace tir {
@@ -36,22 +36,20 @@ namespace tir {
  *
  * \param type The type to be checked.
  *
- * \return A (bool, DataType) pair.  If the type is a pointer to a
- * primitive, the boolean is true and the DataType is the pointed-to
- * type.  Otherwise, the boolean is false and the DataType is
- * default-constructed.  This can be replaced with std::optional with
- * C++17 if/when C++17 is required.
+ * \return A std::optional<DataType>. If the type is a pointer to a
+ * primitive type, the optional holds the pointed-to type; otherwise
+ * it is empty (std::nullopt).
  */
-inline std::pair<bool, runtime::DataType> GetPointerType(const Type& type) {
+inline std::optional<runtime::DataType> GetPointerType(const Type& type) {
   if (type.defined()) {
     if (auto* ptr_type = type.as<PointerTypeNode>()) {
       if (auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
-        return {true, prim_type->dtype};
+        return prim_type->dtype;
       }
     }
   }
 
-  return {false, DataType()};
+  return std::nullopt;
 }
 
 }  // namespace tir
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index f841f94b5a..59db4ea410 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -648,7 +648,7 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, 
PrimExpr predicate, S
   // annotation tells us otherwise.
   int element_lanes = 1;
   auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
-  if (pointer_type.first) {
+  if (pointer_type.has_value()) {
     // Cannot check element type of array, as it may be different than
     // the loaded type in some cases.
     //
@@ -663,11 +663,11 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr 
index, PrimExpr predicate, S
     // See 
https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
     // for discussion.
 
-    // ICHECK(dtype.element_of() == pointer_type.second.element_of())
+    // ICHECK(dtype.element_of() == pointer_type->element_of())
     //     << "Type mismatch, cannot load type " << dtype << " from buffer " <<
     //     buffer_var->name_hint
-    //     << " of type " << pointer_type.second;
-    element_lanes = pointer_type.second.lanes();
+    //     << " of type " << pointer_type.value();
+    element_lanes = pointer_type->lanes();
   }
 
   // The C-based codegens assume that all loads occur on a array with
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 524204f3d3..e21d014fe1 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -271,7 +271,7 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr 
index, PrimExpr predicate,
   // annotation tells us otherwise.
   int element_lanes = 1;
   auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
-  if (pointer_type.first) {
+  if (pointer_type.has_value()) {
     // Currently cannot check element type of array, see Load::Load
     // for details.
 
@@ -279,10 +279,10 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr 
index, PrimExpr predicate,
     // See 
https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
     // for discussion.
 
-    // ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of())
+    // ICHECK_EQ(value.dtype().element_of(), pointer_type->element_of())
     //     << "Type mismatch, cannot store type " << value.dtype() << " into 
buffer "
-    //     << buffer_var->name_hint << " of type " << pointer_type.second;
-    element_lanes = pointer_type.second.lanes();
+    //     << buffer_var->name_hint << " of type " << pointer_type.value();
+    element_lanes = pointer_type->lanes();
   }
 
   ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) ||
diff --git a/src/tir/transforms/inject_ptx_async_copy.cc 
b/src/tir/transforms/inject_ptx_async_copy.cc
index c74ce9d3d2..8ee0d054e5 100644
--- a/src/tir/transforms/inject_ptx_async_copy.cc
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -60,21 +60,21 @@ class PTXAsyncCopyInjector : public StmtMutator {
           if (bytes == 4 || bytes == 8 || bytes == 16) {
             auto dst_elem_type = 
GetPointerType(store->buffer->data->type_annotation);
             auto src_elem_type = 
GetPointerType(load->buffer->data->type_annotation);
-            ICHECK(dst_elem_type.first && src_elem_type.first)
+            ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
                 << "Both store and load buffer should have a pointer type 
annotation.";
 
             int index_factor = 1;
-            if (dst_elem_type != src_elem_type) {
+            if (dst_elem_type.value() != src_elem_type.value()) {
               // The only case where src and dst have different dtypes is when 
the dst shared memory
               // is a byte buffer generated by merging dynamic shared memory.
               ICHECK(store->buffer.scope() == "shared.dyn");
-              ICHECK(dst_elem_type.second == DataType::UInt(8));
+              ICHECK(dst_elem_type.value() == DataType::UInt(8));
               // BufferStore/Load have the "pointer reinterpret" semantics 
according to their
               // "value" dtype. Their "indices" are supposed to be applied 
after such pointer cast,
               // for example: ((*float16)(byte_buffer))[buffer->indices] = 
fp16_value;
               // To replace BufferStore/Load with cp.async, we need to 
multiply the store index by
               // the byte size of the "value" dtype, to get the correct offset 
into the byte buffer.
-              index_factor = src_elem_type.second.bytes();
+              index_factor = src_elem_type->bytes();
             }
 
             if (indices_lanes == 1) {
diff --git a/src/tir/transforms/storage_rewrite.cc 
b/src/tir/transforms/storage_rewrite.cc
index ee6b3c836b..421af99ab3 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -1125,8 +1125,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
     // track the parameter itself.
     for (Var buffer_var : params) {
       auto pointer_type = GetPointerType(buffer_var->type_annotation);
-      if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) {
-        DataType dtype = pointer_type.second;
+      if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) {
+        DataType dtype = pointer_type.value();
         PrimExpr extent = 0;
         OnArrayDeclaration(buffer_var, dtype, extent, 
BufferVarInfo::kPrimFuncBufferMap);
       }
@@ -1190,8 +1190,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
   void HandleLetNode(Var let_var) {
     if (let_var->dtype.is_handle()) {
       auto pointer_type = GetPointerType(let_var->type_annotation);
-      if (pointer_type.first) {
-        OnArrayDeclaration(let_var, pointer_type.second, 0, 
BufferVarInfo::kLetNode);
+      if (pointer_type.has_value()) {
+        OnArrayDeclaration(let_var, pointer_type.value(), 0, 
BufferVarInfo::kLetNode);
       } else if (allow_untyped_pointers_) {
         OnArrayDeclaration(let_var, let_var->dtype, 0, 
BufferVarInfo::kLetNode);
       } else {

Reply via email to