This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2fd23843fd Remove create_relax_prim_func (#15828)
2fd23843fd is described below

commit 2fd23843fdc7ee7c6e4bcc888eb8c459d9bbe715
Author: Lesheng Jin <[email protected]>
AuthorDate: Wed Sep 27 01:09:30 2023 -0700

    Remove create_relax_prim_func (#15828)
    
    A follow-up PR of #15817.
---
 python/tvm/relax/utils.py  |  6 ++---
 python/tvm/te/__init__.py  |  1 -
 python/tvm/te/operation.py | 65 ----------------------------------------------
 3 files changed, 3 insertions(+), 69 deletions(-)

diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 02dd941080..a1fa9cafe8 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -27,7 +27,7 @@ from ..runtime import String, convert_to_object
 from . import _ffi_api
 from .expr import Tuple as rx_Tuple
 from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
-from ..te import Tensor as te_Tensor, create_relax_prim_func
+from ..te import Tensor as te_Tensor, create_prim_func
 from ..ir import Array, Attrs, Type, Map
 from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
 
@@ -441,8 +441,8 @@ def gen_call_tir_inputs(
     outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
     unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list + 
tir_kwarg_list)
 
-    inputs = [*te_args] + outs
-    tir_func = create_relax_prim_func(inputs, unbound_tir_vars, "int64")
+    inputs = [*te_args] + outs + unbound_tir_vars
+    tir_func = create_prim_func(inputs, "int64")
 
     if primfunc_attrs:
         tir_func = tir_func.with_attrs(primfunc_attrs)
diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index 40fac0f92f..0907ea2ebf 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -41,7 +41,6 @@ from .tag import tag_scope
 from .operation import placeholder, compute, scan, extern, var, size_var, const
 from .operation import thread_axis, reduce_axis
 from .operation import create_prim_func
-from .operation import create_relax_prim_func
 from .operation import extern_primfunc
 
 from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, 
ExternOp, HybridOp
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index ccd5baf2cd..4edbe247f5 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -611,68 +611,3 @@ def create_prim_func(
     if not isinstance(ops, (list, tuple, Array)):
         ops = [ops]
     return _ffi_api.CreatePrimFunc(ops, index_dtype_override)
-
-
-def create_relax_prim_func(
-    ops: List[_tensor.Tensor],
-    tir_var_list: List[tvm.tir.Var] = None,
-    index_dtype_override: Optional[str] = None,
-) -> tvm.tir.PrimFunc:
-    """Create a TensorIR PrimFunc from tensor expression
-
-    Parameters
-    ----------
-    ops : List[Tensor]
-        The source expression.
-
-    tir_var_list: List[Var]
-        TIR variables to add as parameters to generated PrimFunc
-
-    Example
-    -------
-    We define a matmul kernel using following code:
-
-    .. code-block:: python
-
-        import tvm
-        from tvm import te
-        from tvm.te import create_prim_func
-        import tvm.script
-
-        A = te.placeholder((128, 128), name="A")
-        B = te.placeholder((128, 128), name="B")
-        k = te.reduce_axis((0, 128), "k")
-        C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], 
axis=k), name="C")
-        func = create_prim_func([A, B, C])
-        print(func.script())
-
-    If we want to use TensorIR schedule to do transformations on such kernel,
-    we need to use `create_prim_func([A, B, C])` to create a schedulable 
PrimFunc.
-    The generated function looks like:
-
-    .. code-block:: python
-
-        @T.prim_func
-        def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
-            A = T.match_buffer(a, (128, 128))
-            B = T.match_buffer(b, (128, 128))
-            C = T.match_buffer(c, (128, 128))
-
-            for i, j, k in T.grid(128, 128, 128):
-                with T.block():
-                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
-                    with T.init():
-                        C[vi, vj] = 0.0
-                    C[vi, vj] += A[vi, vk] * B[vj, vk]
-
-    Returns
-    -------
-    func : tir.PrimFunc
-        The created function.
-    """
-    if not isinstance(ops, (list, tuple, Array)):
-        ops = [ops]
-    arg_list = ops
-    if tir_var_list is not None:
-        arg_list += tir_var_list
-    return _ffi_api.CreatePrimFunc(arg_list, index_dtype_override)

Reply via email to