This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2fd23843fd Remove create_relax_prim_func (#15828)
2fd23843fd is described below
commit 2fd23843fdc7ee7c6e4bcc888eb8c459d9bbe715
Author: Lesheng Jin <[email protected]>
AuthorDate: Wed Sep 27 01:09:30 2023 -0700
Remove create_relax_prim_func (#15828)
A follow-up PR to #15817.
---
python/tvm/relax/utils.py | 6 ++---
python/tvm/te/__init__.py | 1 -
python/tvm/te/operation.py | 65 ----------------------------------------------
3 files changed, 3 insertions(+), 69 deletions(-)
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 02dd941080..a1fa9cafe8 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -27,7 +27,7 @@ from ..runtime import String, convert_to_object
from . import _ffi_api
from .expr import Tuple as rx_Tuple
from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
-from ..te import Tensor as te_Tensor, create_relax_prim_func
+from ..te import Tensor as te_Tensor, create_prim_func
from ..ir import Array, Attrs, Type, Map
from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
@@ -441,8 +441,8 @@ def gen_call_tir_inputs(
outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list + tir_kwarg_list)
- inputs = [*te_args] + outs
- tir_func = create_relax_prim_func(inputs, unbound_tir_vars, "int64")
+ inputs = [*te_args] + outs + unbound_tir_vars
+ tir_func = create_prim_func(inputs, "int64")
if primfunc_attrs:
tir_func = tir_func.with_attrs(primfunc_attrs)
diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index 40fac0f92f..0907ea2ebf 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -41,7 +41,6 @@ from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var, const
from .operation import thread_axis, reduce_axis
from .operation import create_prim_func
-from .operation import create_relax_prim_func
from .operation import extern_primfunc
from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index ccd5baf2cd..4edbe247f5 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -611,68 +611,3 @@ def create_prim_func(
if not isinstance(ops, (list, tuple, Array)):
ops = [ops]
return _ffi_api.CreatePrimFunc(ops, index_dtype_override)
-
-
-def create_relax_prim_func(
- ops: List[_tensor.Tensor],
- tir_var_list: List[tvm.tir.Var] = None,
- index_dtype_override: Optional[str] = None,
-) -> tvm.tir.PrimFunc:
- """Create a TensorIR PrimFunc from tensor expression
-
- Parameters
- ----------
- ops : List[Tensor]
- The source expression.
-
- tir_var_list: List[Var]
- TIR variables to add as parameters to generated PrimFunc
-
- Example
- -------
- We define a matmul kernel using following code:
-
- .. code-block:: python
-
- import tvm
- from tvm import te
- from tvm.te import create_prim_func
- import tvm.script
-
- A = te.placeholder((128, 128), name="A")
- B = te.placeholder((128, 128), name="B")
- k = te.reduce_axis((0, 128), "k")
- C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C")
- func = create_prim_func([A, B, C])
- print(func.script())
-
- If we want to use TensorIR schedule to do transformations on such kernel,
- we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc.
- The generated function looks like:
-
- .. code-block:: python
-
- @T.prim_func
- def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128, 128))
- C = T.match_buffer(c, (128, 128))
-
- for i, j, k in T.grid(128, 128, 128):
- with T.block():
- vi, vj, vk = T.axis.remap("SSR", [i, j, k])
- with T.init():
- C[vi, vj] = 0.0
- C[vi, vj] += A[vi, vk] * B[vj, vk]
-
- Returns
- -------
- func : tir.PrimFunc
- The created function.
- """
- if not isinstance(ops, (list, tuple, Array)):
- ops = [ops]
- arg_list = ops
- if tir_var_list is not None:
- arg_list += tir_var_list
- return _ffi_api.CreatePrimFunc(arg_list, index_dtype_override)