This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new beef1f7c75 [TE] Support using tir::Var as CreatePrimFunc args (#15817)
beef1f7c75 is described below

commit beef1f7c758fb23d661199142f6c3195adf2c946
Author: Lesheng Jin <[email protected]>
AuthorDate: Tue Sep 26 18:33:55 2023 -0700

    [TE] Support using tir::Var as CreatePrimFunc args (#15817)
    
    - Allow `CreatePrimFunc` args to be a mixture of `te::Tensor` and 
`tir::Var`.
    - Integrate `CreatePrimFunc` and `CreateRelaxPrimFunc` into one function.
    
    ```python
    idx = te.var("idx", dtype="int64")
    m = te.var("m", dtype="int64")
    n = te.var("n", dtype="int64")
    tensor = te.placeholder((m, n), name="tensor")
    slice0 = te.compute((idx, n), lambda i, j: tensor[i, j], name="slice")
    # use idx as an arg
    te.create_prim_func([tensor, idx, slice0])
    ```
---
 python/tvm/te/operation.py                       | 11 ++--
 src/te/operation/create_primfunc.cc              | 64 ++++++++++--------------
 src/te/operation/create_primfunc.h               |  8 ++-
 tests/python/unittest/test_te_create_primfunc.py | 29 +++++++++++
 4 files changed, 66 insertions(+), 46 deletions(-)

diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 372e99f654..ccd5baf2cd 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -19,7 +19,7 @@ import inspect
 
 # pylint: disable=invalid-name
 from numbers import Integral as _Integral
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import tvm._ffi
 import tvm.arith._ffi_api
@@ -567,12 +567,12 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None):
 
 
 def create_prim_func(
-    ops: List[_tensor.Tensor], index_dtype_override: Optional[str] = None
+    ops: List[Union[_tensor.Tensor, tvm.tir.Var]], index_dtype_override: 
Optional[str] = None
 ) -> tvm.tir.PrimFunc:
     """Create a TensorIR PrimFunc from tensor expression
     Parameters
     ----------
-    ops : List[Tensor]
+    ops : List[Union[_tensor.Tensor, tvm.tir.Var]]
         The source expression.
     Example
     -------
@@ -672,4 +672,7 @@ def create_relax_prim_func(
     """
     if not isinstance(ops, (list, tuple, Array)):
         ops = [ops]
-    return _ffi_api.CreateRelaxPrimFunc(ops, tir_var_list, 
index_dtype_override)
+    arg_list = ops
+    if tir_var_list is not None:
+        arg_list += tir_var_list
+    return _ffi_api.CreatePrimFunc(arg_list, index_dtype_override)
diff --git a/src/te/operation/create_primfunc.cc 
b/src/te/operation/create_primfunc.cc
index d3daf6e6c3..8a5be5ad93 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -570,10 +570,9 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
 }
 
 TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, 
TVMRetValue* ret) {
-  Array<te::Tensor> arg_list = args[0];
+  Array<ObjectRef> arg_list = args[0];
   std::optional<DataType> index_dtype_override{std::nullopt};
   // Add conversion to make std::optional compatible with FFI.
-  ICHECK_EQ(args.size(), 2);
   if (args[1].type_code() != kTVMNullptr) {
     index_dtype_override = args[1].operator DataType();
   }
@@ -581,26 +580,23 @@ 
TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue*
 });
 
 // Relax version impl
-PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
-                                     const Array<Stmt>& root_stmts, 
CreateFuncInfo* info,
-                                     const Optional<Array<tir::Var>> 
tir_var_list) {
+PrimFunc GenerateAndCompletePrimFunc(const Array<ObjectRef>& arg_tir_var_list,
+                                     const Array<Stmt>& root_stmts, 
CreateFuncInfo* info) {
   Array<Var> parameters;
   Map<Var, Buffer> buffer_map;
-  for (const te::Tensor& tensor : arg_list) {
-    Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
-    parameters.push_back(arg);
-    auto it = info->tensor2buffers.find(tensor);
-    ICHECK(it != info->tensor2buffers.end());
-    buffer_map.Set(arg, it->second);
-  }
-
-  // add additional arguments for tir vars that are left unbound by match 
buffer
-  if (tir_var_list) {
-    for (const Var& v : tir_var_list.value()) {
-      parameters.push_back(v);
+  for (const ObjectRef& x : arg_tir_var_list) {
+    if (auto n = x.as<te::TensorNode>()) {
+      te::Tensor tensor = GetRef<te::Tensor>(n);
+      Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
+      parameters.push_back(arg);
+      auto it = info->tensor2buffers.find(tensor);
+      ICHECK(it != info->tensor2buffers.end());
+      buffer_map.Set(arg, it->second);
+    } else if (auto n = x.as<tir::VarNode>()) {
+      tir::Var var = GetRef<tir::Var>(n);
+      parameters.push_back(var);
     }
   }
-
   PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters),
                                      /*body=*/SeqStmt::Flatten(root_stmts),
                                      /*ret_type=*/VoidType(),
@@ -613,19 +609,25 @@ PrimFunc GenerateAndCompletePrimFunc(const 
Array<te::Tensor>& arg_list,
   return func;
 }
 
-PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
+PrimFunc CreatePrimFuncWithConstants(const Array<ObjectRef>& arg_list,
                                      const Array<runtime::NDArray>& constants,
-                                     const Optional<Array<tir::Var>>& 
tir_var_list,
                                      std::optional<DataType> 
index_dtype_override) {
+  Array<te::Tensor> tensor_arg_list;
+  for (const ObjectRef& x : arg_list) {
+    if (auto tensor_node = x.as<te::TensorNode>()) {
+      te::Tensor tensor = GetRef<te::Tensor>(tensor_node);
+      tensor_arg_list.push_back(tensor);
+    }
+  }
   // Infomations used in CreatePrimFunc and its sub-functions.
-  CreateFuncInfo info(arg_list);
+  CreateFuncInfo info(tensor_arg_list);
   // Root body stmts.
   Array<Stmt> root_stmts;
   // Analyzer
   arith::Analyzer analyzer;
 
   // Step 1. Create ordered array of operations and validate they are 
supported.
-  Array<te::Operation> order = CollectOrderedOps(arg_list);
+  Array<te::Operation> order = CollectOrderedOps(tensor_arg_list);
 
   // Step 2. Initialize buffer binds map
   InitializeBufferBinds(order, &info);
@@ -634,7 +636,7 @@ PrimFunc CreatePrimFuncWithConstants(const 
Array<te::Tensor>& arg_list,
   for (const te::Operation& op : order) {
     RewriteStageToBlock(op, &info, &root_stmts, &analyzer);
   }
-  auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info, 
tir_var_list);
+  auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info);
   func = tir::BindParams(func, constants);
   if (index_dtype_override.has_value()) {
     func = 
IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func));
@@ -643,22 +645,10 @@ PrimFunc CreatePrimFuncWithConstants(const 
Array<te::Tensor>& arg_list,
   return result;
 }
 
-PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
-                        const Optional<Array<tir::Var>> tir_var_list,
+PrimFunc CreatePrimFunc(const Array<ObjectRef>& arg_list,
                         std::optional<DataType> index_dtype_override) {
-  return CreatePrimFuncWithConstants(arg_list, {}, tir_var_list, 
index_dtype_override);
+  return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override);
 }
 
-TVM_REGISTER_GLOBAL("te.CreateRelaxPrimFunc").set_body([](TVMArgs args, 
TVMRetValue* ret) {
-  Array<te::Tensor> arg_list = args[0];
-  Optional<Array<tir::Var>> tir_var_list = args[1];
-  std::optional<DataType> index_dtype_override{std::nullopt};
-  // Add conversion to make std::optional compatible with FFI.
-  if (args[2].type_code() != kTVMNullptr) {
-    index_dtype_override = args[2].operator DataType();
-  }
-  *ret = CreatePrimFunc(arg_list, tir_var_list, index_dtype_override);
-});
-
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/te/operation/create_primfunc.h 
b/src/te/operation/create_primfunc.h
index 946f024849..496ee45ba4 100644
--- a/src/te/operation/create_primfunc.h
+++ b/src/te/operation/create_primfunc.h
@@ -45,8 +45,7 @@ PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& 
arg_list,
 // Relax version
 // TODO(relax-team) combine with the relay version
 /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
-PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
-                        const Optional<Array<tir::Var>> tir_var_list,
+PrimFunc CreatePrimFunc(const Array<ObjectRef>& arg_list,
                         std::optional<DataType> index_dtype_override);
 
 /*! \brief The same as above but create a PrimFunc with AllocateConstNode. If 
the size of the
@@ -54,10 +53,9 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
  * Constant tensors will not be part of the parameters of the created 
PrimFunc, instead constants
  * will be embedded in the body as AllocateConstNode.
  */
-PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
+PrimFunc CreatePrimFuncWithConstants(const Array<ObjectRef>& arg_list,
                                      const Array<runtime::NDArray>& constants,
-                                     const Optional<Array<tir::Var>>& 
tir_var_list,
-                                     std::optional<DataType> 
index_dtype_override = std::nullopt);
+                                     std::optional<DataType> 
index_dtype_override);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_te_create_primfunc.py 
b/tests/python/unittest/test_te_create_primfunc.py
index 2598d620ba..326ef2b8ce 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -785,5 +785,34 @@ def test_extern_with_explicit_buffer_access():
     _check_workload(te_extern, tir_extern)
 
 
+def te_slice_with_var_input():
+    idx = te.var("idx", dtype="int64")
+    m = te.var("m", dtype="int64")
+    n = te.var("n", dtype="int64")
+    tensor = te.placeholder((m, n), name="tensor")
+    slice0 = te.compute((idx, n), lambda i, j: tensor[i, j], name="slice")
+    return [tensor, idx, slice0]
+
+
[email protected]_func
+def tir_slice_with_var_input(var_tensor: T.handle, idx: T.int64, var_slice: 
T.handle):
+    T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
+    m, n = T.int64(), T.int64()
+    tensor = T.match_buffer(var_tensor, (m, n))
+    slice = T.match_buffer(var_slice, (idx, n))
+    # with T.block("root"):
+    for i, j in T.grid(idx, n):
+        with T.block("slice"):
+            v_i = T.axis.spatial(idx, i)
+            v_j = T.axis.spatial(n, j)
+            T.reads(tensor[v_i, v_j])
+            T.writes(slice[v_i, v_j])
+            slice[v_i, v_j] = tensor[v_i, v_j]
+
+
+def test_with_var_input():
+    _check_workload(te_slice_with_var_input, tir_slice_with_var_input, 
index_dtype_override="int64")
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to