This is an automated email from the ASF dual-hosted git repository. junrushao pushed a commit to branch junrushao1994-patch-1 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit b992246bdc0e53b642b98e9a17190eb73d0a624f Author: Junru Shao <[email protected]> AuthorDate: Thu Sep 2 21:47:31 2021 -0700 [TensorIR][Minor] Allow Tuple/Array in TE lowering --- python/tvm/te/operation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 6af3429..a0b9b43 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -22,13 +22,13 @@ from typing import List import tvm._ffi import tvm.tir import tvm.tir._ffi_api - from tvm._ffi.base import string_types +from tvm.ir import Array from tvm.runtime import convert +from . import _ffi_api from . import tag as _tag from . import tensor as _tensor -from . import _ffi_api def placeholder(shape, dtype=None, name="placeholder"): @@ -431,6 +431,7 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None): def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: """Create a TensorIR PrimFunc from tensor expression + Parameters ---------- ops : List[Tensor] @@ -473,6 +474,6 @@ def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: func : tir.PrimFunc The created function. """ - if not isinstance(ops, list): + if not isinstance(ops, (list, tuple, Array)): ops = [ops] return _ffi_api.CreatePrimFunc(ops)
