This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8352f2f [TE][TensorIR] fix tensor attr in create_prim_func (#9764)
8352f2f is described below
commit 8352f2fc0dfe3062b9cbbefba9119e4b650a50bb
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Dec 20 10:14:36 2021 +0800
[TE][TensorIR] fix tensor attr in create_prim_func (#9764)
* [TE][TensorIR] fix tensor attr in create_prim_func
* Update src/te/operation/create_primfunc.cc
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
---
src/te/operation/create_primfunc.cc | 23 ++++++++++++++++++++++-
tests/python/unittest/test_te_create_primfunc.py | 16 ++++++++++++++++
2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 81f6067..5de0538 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -25,6 +25,7 @@
#include <algorithm>
#include <unordered_set>
+#include "../../tir/ir/functor_common.h"
#include "../schedule/graph.h"
namespace tvm {
@@ -144,7 +145,27 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
}
// Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
- Map<String, ObjectRef> annotations = compute_op->attrs;
+ Map<String, ObjectRef> annotations;
+ auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
+ if (const auto* tensor_value = value.as<te::TensorNode>()) {
+ return info->tensor2buffers.at(GetRef<te::Tensor>(tensor_value));
+ } else {
+ return value;
+ }
+ };
+
+ for (const auto& pair : compute_op->attrs) {
+ const String& key = pair.first;
+ const ObjectRef& value = pair.second;
+ // TensorIR will not allow Tensor data structure
+ if (value->IsInstance<ArrayNode>()) {
+ const auto array_value = Downcast<Array<ObjectRef>>(value);
+ annotations.Set(key, MutateArray(array_value, mutate_attr));
+ } else {
+ annotations.Set(key, mutate_attr(value));
+ }
+ }
+ // Set script_parsing_detect_access
annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
// Step 7. Create Block and BlockRealize.
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 3a5512b..68ea2ab 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -344,6 +344,21 @@ def test_select_simplify():
assert script_func.find("Var") == -1
+def test_tensor_attr():
+ k = te.reduce_axis((0, 128), "k")
+ A = te.placeholder((128, 128), name="A")
+ B = te.placeholder((128, 128), name="B")
+ C = te.compute(
+ (128, 128),
+ lambda x, y: te.sum(A[x, k] * B[y, k], axis=k),
+ name="C",
+ attrs={"layout_free_placeholders": [B]},
+ )
+ func = te.create_prim_func([A, B, C])
+ rt_func = tvm.script.from_source(func.script())
+ tvm.ir.assert_structural_equal(func, rt_func)
+
+
if __name__ == "__main__":
test_unique_name()
test_matmul()
@@ -355,3 +370,4 @@ if __name__ == "__main__":
test_error_reporting()
test_constant()
test_select_simplify()
+ test_tensor_attr()