This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8352f2f  [TE][TensorIR] fix tensor attr in create_prim_func (#9764)
8352f2f is described below

commit 8352f2fc0dfe3062b9cbbefba9119e4b650a50bb
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Dec 20 10:14:36 2021 +0800

    [TE][TensorIR] fix tensor attr in create_prim_func (#9764)
    
    * [TE][TensorIR] fix tensor attr in create_prim_func
    
    * Update src/te/operation/create_primfunc.cc
    
    Co-authored-by: Junru Shao <[email protected]>
    
    Co-authored-by: Junru Shao <[email protected]>
---
 src/te/operation/create_primfunc.cc              | 23 ++++++++++++++++++++++-
 tests/python/unittest/test_te_create_primfunc.py | 16 ++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 81f6067..5de0538 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -25,6 +25,7 @@
 #include <algorithm>
 #include <unordered_set>
 
+#include "../../tir/ir/functor_common.h"
 #include "../schedule/graph.h"
 
 namespace tvm {
@@ -144,7 +145,27 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
   }
 
   // Step 6. Add script_parsing_detect_access attr for auto complete the whole 
IR.
-  Map<String, ObjectRef> annotations = compute_op->attrs;
+  Map<String, ObjectRef> annotations;
+  auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
+    if (const auto* tensor_value = value.as<te::TensorNode>()) {
+      return info->tensor2buffers.at(GetRef<te::Tensor>(tensor_value));
+    } else {
+      return value;
+    }
+  };
+
+  for (const auto& pair : compute_op->attrs) {
+    const String& key = pair.first;
+    const ObjectRef& value = pair.second;
+    // TensorIR will not allow Tensor data structure
+    if (value->IsInstance<ArrayNode>()) {
+      const auto array_value = Downcast<Array<ObjectRef>>(value);
+      annotations.Set(key, MutateArray(array_value, mutate_attr));
+    } else {
+      annotations.Set(key, mutate_attr(value));
+    }
+  }
+  // Set script_parsing_detect_access
   annotations.Set(tir::attr::script_parsing_detect_access, 
IntImm(DataType::Int(32), 3));
 
   // Step 7. Create Block and BlockRealize.
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 3a5512b..68ea2ab 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -344,6 +344,21 @@ def test_select_simplify():
     assert script_func.find("Var") == -1
 
 
+def test_tensor_attr():
+    k = te.reduce_axis((0, 128), "k")
+    A = te.placeholder((128, 128), name="A")
+    B = te.placeholder((128, 128), name="B")
+    C = te.compute(
+        (128, 128),
+        lambda x, y: te.sum(A[x, k] * B[y, k], axis=k),
+        name="C",
+        attrs={"layout_free_placeholders": [B]},
+    )
+    func = te.create_prim_func([A, B, C])
+    rt_func = tvm.script.from_source(func.script())
+    tvm.ir.assert_structural_equal(func, rt_func)
+
+
 if __name__ == "__main__":
     test_unique_name()
     test_matmul()
@@ -355,3 +370,4 @@ if __name__ == "__main__":
     test_error_reporting()
     test_constant()
     test_select_simplify()
+    test_tensor_attr()

Reply via email to