This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new beef1f7c75 [TE] Support using tir::Var as CreatePrimFunc args (#15817)
beef1f7c75 is described below
commit beef1f7c758fb23d661199142f6c3195adf2c946
Author: Lesheng Jin <[email protected]>
AuthorDate: Tue Sep 26 18:33:55 2023 -0700
[TE] Support using tir::Var as CreatePrimFunc args (#15817)
- Allow `CreatePrimFunc` args to be a mixture of `te::Tensor` and
`tir::Var`.
- Merge `CreatePrimFunc` and `CreateRelaxPrimFunc` into a single `CreatePrimFunc`; the separate `te.CreateRelaxPrimFunc` FFI entry is removed.
```python
idx = te.var("idx", dtype="int64")
m = te.var("m", dtype="int64")
n = te.var("n", dtype="int64")
tensor = te.placeholder((m, n), name="tensor")
slice0 = te.compute((idx, n), lambda i, j: tensor[i, j], name="slice")
# use idx as an arg
te.create_prim_func([tensor, idx, slice0])
```
---
python/tvm/te/operation.py | 11 ++--
src/te/operation/create_primfunc.cc | 64 ++++++++++--------------
src/te/operation/create_primfunc.h | 8 ++-
tests/python/unittest/test_te_create_primfunc.py | 29 +++++++++++
4 files changed, 66 insertions(+), 46 deletions(-)
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 372e99f654..ccd5baf2cd 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -19,7 +19,7 @@ import inspect
# pylint: disable=invalid-name
from numbers import Integral as _Integral
-from typing import List, Optional
+from typing import List, Optional, Union
import tvm._ffi
import tvm.arith._ffi_api
@@ -567,12 +567,12 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None):
def create_prim_func(
- ops: List[_tensor.Tensor], index_dtype_override: Optional[str] = None
+ ops: List[Union[_tensor.Tensor, tvm.tir.Var]], index_dtype_override: Optional[str] = None
) -> tvm.tir.PrimFunc:
"""Create a TensorIR PrimFunc from tensor expression
Parameters
----------
- ops : List[Tensor]
+ ops : List[Union[_tensor.Tensor, tvm.tir.Var]]
The source expression.
Example
-------
@@ -672,4 +672,7 @@ def create_relax_prim_func(
"""
if not isinstance(ops, (list, tuple, Array)):
ops = [ops]
- return _ffi_api.CreateRelaxPrimFunc(ops, tir_var_list, index_dtype_override)
+ arg_list = ops
+ if tir_var_list is not None:
+ arg_list += tir_var_list
+ return _ffi_api.CreatePrimFunc(arg_list, index_dtype_override)
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index d3daf6e6c3..8a5be5ad93 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -570,10 +570,9 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
}
TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
- Array<te::Tensor> arg_list = args[0];
+ Array<ObjectRef> arg_list = args[0];
std::optional<DataType> index_dtype_override{std::nullopt};
// Add conversion to make std::optional compatible with FFI.
- ICHECK_EQ(args.size(), 2);
if (args[1].type_code() != kTVMNullptr) {
index_dtype_override = args[1].operator DataType();
}
@@ -581,26 +580,23 @@ TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue*
});
// Relax version impl
-PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
- const Array<Stmt>& root_stmts, CreateFuncInfo* info,
- const Optional<Array<tir::Var>> tir_var_list) {
+PrimFunc GenerateAndCompletePrimFunc(const Array<ObjectRef>& arg_tir_var_list,
+ const Array<Stmt>& root_stmts, CreateFuncInfo* info) {
Array<Var> parameters;
Map<Var, Buffer> buffer_map;
- for (const te::Tensor& tensor : arg_list) {
- Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
- parameters.push_back(arg);
- auto it = info->tensor2buffers.find(tensor);
- ICHECK(it != info->tensor2buffers.end());
- buffer_map.Set(arg, it->second);
- }
-
- // add additional arguments for tir vars that are left unbound by match
buffer
- if (tir_var_list) {
- for (const Var& v : tir_var_list.value()) {
- parameters.push_back(v);
+ for (const ObjectRef& x : arg_tir_var_list) {
+ if (auto n = x.as<te::TensorNode>()) {
+ te::Tensor tensor = GetRef<te::Tensor>(n);
+ Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
+ parameters.push_back(arg);
+ auto it = info->tensor2buffers.find(tensor);
+ ICHECK(it != info->tensor2buffers.end());
+ buffer_map.Set(arg, it->second);
+ } else if (auto n = x.as<tir::VarNode>()) {
+ tir::Var var = GetRef<tir::Var>(n);
+ parameters.push_back(var);
}
}
-
PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters),
/*body=*/SeqStmt::Flatten(root_stmts),
/*ret_type=*/VoidType(),
@@ -613,19 +609,25 @@ PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
return func;
}
-PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
+PrimFunc CreatePrimFuncWithConstants(const Array<ObjectRef>& arg_list,
const Array<runtime::NDArray>& constants,
- const Optional<Array<tir::Var>>& tir_var_list,
std::optional<DataType> index_dtype_override) {
+ Array<te::Tensor> tensor_arg_list;
+ for (const ObjectRef& x : arg_list) {
+ if (auto tensor_node = x.as<te::TensorNode>()) {
+ te::Tensor tensor = GetRef<te::Tensor>(tensor_node);
+ tensor_arg_list.push_back(tensor);
+ }
+ }
// Infomations used in CreatePrimFunc and its sub-functions.
- CreateFuncInfo info(arg_list);
+ CreateFuncInfo info(tensor_arg_list);
// Root body stmts.
Array<Stmt> root_stmts;
// Analyzer
arith::Analyzer analyzer;
// Step 1. Create ordered array of operations and validate they are supported.
- Array<te::Operation> order = CollectOrderedOps(arg_list);
+ Array<te::Operation> order = CollectOrderedOps(tensor_arg_list);
// Step 2. Initialize buffer binds map
InitializeBufferBinds(order, &info);
@@ -634,7 +636,7 @@ PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
for (const te::Operation& op : order) {
RewriteStageToBlock(op, &info, &root_stmts, &analyzer);
}
- auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info, tir_var_list);
+ auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info);
func = tir::BindParams(func, constants);
if (index_dtype_override.has_value()) {
func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func));
@@ -643,22 +645,10 @@ PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
return result;
}
-PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
- const Optional<Array<tir::Var>> tir_var_list,
+PrimFunc CreatePrimFunc(const Array<ObjectRef>& arg_list,
std::optional<DataType> index_dtype_override) {
- return CreatePrimFuncWithConstants(arg_list, {}, tir_var_list, index_dtype_override);
+ return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override);
}
-TVM_REGISTER_GLOBAL("te.CreateRelaxPrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
- Array<te::Tensor> arg_list = args[0];
- Optional<Array<tir::Var>> tir_var_list = args[1];
- std::optional<DataType> index_dtype_override{std::nullopt};
- // Add conversion to make std::optional compatible with FFI.
- if (args[2].type_code() != kTVMNullptr) {
- index_dtype_override = args[2].operator DataType();
- }
- *ret = CreatePrimFunc(arg_list, tir_var_list, index_dtype_override);
-});
-
} // namespace tir
} // namespace tvm
diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h
index 946f024849..496ee45ba4 100644
--- a/src/te/operation/create_primfunc.h
+++ b/src/te/operation/create_primfunc.h
@@ -45,8 +45,7 @@ PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
// Relax version
// TODO(relax-team) combine with the relay version
/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
-PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
- const Optional<Array<tir::Var>> tir_var_list,
+PrimFunc CreatePrimFunc(const Array<ObjectRef>& arg_list,
std::optional<DataType> index_dtype_override);
/*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the
@@ -54,10 +53,9 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
* Constant tensors will not be part of the parameters of the created PrimFunc, instead constants
* will be embedded in the body as AllocateConstNode.
*/
-PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
+PrimFunc CreatePrimFuncWithConstants(const Array<ObjectRef>& arg_list,
const Array<runtime::NDArray>& constants,
- const Optional<Array<tir::Var>>& tir_var_list,
- std::optional<DataType> index_dtype_override = std::nullopt);
+ std::optional<DataType> index_dtype_override);
} // namespace tir
} // namespace tvm
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 2598d620ba..326ef2b8ce 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -785,5 +785,34 @@ def test_extern_with_explicit_buffer_access():
_check_workload(te_extern, tir_extern)
+def te_slice_with_var_input():
+ idx = te.var("idx", dtype="int64")
+ m = te.var("m", dtype="int64")
+ n = te.var("n", dtype="int64")
+ tensor = te.placeholder((m, n), name="tensor")
+ slice0 = te.compute((idx, n), lambda i, j: tensor[i, j], name="slice")
+ return [tensor, idx, slice0]
+
+
+@T.prim_func
+def tir_slice_with_var_input(var_tensor: T.handle, idx: T.int64, var_slice: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
+ m, n = T.int64(), T.int64()
+ tensor = T.match_buffer(var_tensor, (m, n))
+ slice = T.match_buffer(var_slice, (idx, n))
+ # with T.block("root"):
+ for i, j in T.grid(idx, n):
+ with T.block("slice"):
+ v_i = T.axis.spatial(idx, i)
+ v_j = T.axis.spatial(n, j)
+ T.reads(tensor[v_i, v_j])
+ T.writes(slice[v_i, v_j])
+ slice[v_i, v_j] = tensor[v_i, v_j]
+
+
+def test_with_var_input():
+ _check_workload(te_slice_with_var_input, tir_slice_with_var_input, index_dtype_override="int64")
+
+
if __name__ == "__main__":
tvm.testing.main()