This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bbb93370cb [Unity] e2e Relax minimum build flow (#13961)
bbb93370cb is described below
commit bbb93370cbb888778a244999fa979a9dee32f4c0
Author: Yuchen Jin <[email protected]>
AuthorDate: Sat Feb 11 09:26:36 2023 -0800
[Unity] e2e Relax minimum build flow (#13961)
This PR introduces the e2e Relax lowering flow (`relax.vm.build`). Tests
for each pass in the flow are added.
Co-Authored-by: Altan Haan <[email protected]>
Co-Authored-by: Andrew Liu <[email protected]>
Co-Authored-by: Hongyi Jin <[email protected]>
Co-Authored-by: Jiawei Liu <[email protected]>
Co-Authored-by: Junru Shao <[email protected]>
Co-Authored-by: Prakalp Srivastava <[email protected]>
Co-Authored-by: Ruihang Lai <[email protected]>
Co-Authored-by: Siyuan Feng <[email protected]>
Co-Authored-by: Steven S. Lyubomirsky <[email protected]>
Co-Authored-by: Sunghyun Park <[email protected]>
Co-Authored-by: Tianqi Chen <[email protected]>
Co-Authored-by: Yong Wu <[email protected]>
Co-Authored-by: Ziheng Jiang <[email protected]>
---
CMakeLists.txt | 1 +
include/tvm/relax/analysis.h | 16 +
include/tvm/relax/backend.h | 7 +
include/tvm/relax/transform.h | 35 +
python/tvm/relax/analysis/analysis.py | 45 +
python/tvm/relax/op/__init__.py | 3 +
python/tvm/relax/op/{ => builtin}/__init__.py | 6 +-
.../relax/op/{__init__.py => builtin/_ffi_api.py} | 9 +-
python/tvm/relax/op/builtin/builtin.py | 44 +
python/tvm/relax/op/manipulate.py | 62 ++
python/tvm/relax/op/{ => memory}/__init__.py | 6 +-
.../relax/op/{__init__.py => memory/_ffi_api.py} | 9 +-
python/tvm/relax/op/memory/memory.py | 108 +++
python/tvm/relax/{op => testing}/__init__.py | 6 +-
python/tvm/relax/testing/nn.py | 194 +++++
python/tvm/relax/transform/transform.py | 53 ++
python/tvm/relax/vm.py | 4 +-
python/tvm/script/ir_builder/relax/ir.py | 6 +
src/relax/analysis/tir_op_pattern_kind.cc | 447 ++++++++++
src/relax/backend/vm/vm_builtin_lower.cc | 208 +++++
src/relax/op/op.cc | 81 ++
src/relax/op/tensor/manipulate.cc | 163 ++++
.../backend.h => src/relax/op/tensor/manipulate.h | 25 +-
src/relax/transform/attach_global_symbol.cc | 68 ++
src/relax/transform/call_tir_rewrite.cc | 137 ++++
src/relax/transform/rewrite_dataflow_reshape.cc | 110 +++
src/relax/transform/to_non_dataflow.cc | 67 ++
tests/python/relax/test_analysis.py | 172 ++++
tests/python/relax/test_transform.py | 141 ++++
.../relax/test_transform_attach_global_symbol.py | 88 ++
.../test_transform_rewrite_dataflow_reshape.py | 166 ++++
tests/python/relax/test_vm_build.py | 908 +++++++++++++++++++++
32 files changed, 3358 insertions(+), 37 deletions(-)
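For reference, a minimal end-to-end sketch of the flow this commit enables. The module below is illustrative only; the entry points (relax.vm.build, relax.VirtualMachine) follow the files touched in this commit, but the exact API surface may shift across unity revisions:

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor((3, 4), "float32")) -> R.Tensor((3, 4), "float32"):
            return R.add(x, x)

    # relax.vm.build runs the lowering passes added in this commit, then codegen.
    ex = relax.vm.build(Mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    res = vm["main"](tvm.nd.array(np.ones((3, 4), "float32")))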
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ab21d90e00..5a28a9acde 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -292,6 +292,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/relax/ir/*.cc
src/relax/op/*.cc
src/relax/analysis/*.cc
+ src/relax/transform/*.cc
src/relax/backend/vm/*.cc
src/relax/utils.cc
)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index ad2bd19aa4..24cfe5b9bf 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -259,6 +259,22 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
*/
TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
arith::Analyzer* ana = nullptr);
+
+/*!
+ * \brief Check if the given PrimFunc is essentially doing a reshape operation.
+ * The reshape operation also includes expand_dims, squeeze, flatten, etc.
+ * \details Here the allowed reshape pattern is: for example, assume the operation is
+ * `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, we check if we can prove that the
+ * flattened index of l_0, ..., l_b under buffer B equals the flattened index of
+ * r_0, ..., r_a under buffer A.
+ * \param func The function to be examined.
+ * \return A boolean indicating if the given PrimFunc is doing a reshape.
+ * \note According to the description above, the returned result can only be
+ * false-negative and cannot be false-positive, since whenever we cannot prove
+ * the equality, we return false. This property guarantees the safety of this function.
+ */
+TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
+
} // namespace relax
} // namespace tvm
diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
index 4ebeacac0f..2fb11f5a6f 100644
--- a/include/tvm/relax/backend.h
+++ b/include/tvm/relax/backend.h
@@ -30,6 +30,13 @@ namespace tvm {
namespace relax {
namespace transform {
+/*!
+ * \brief Perform builtin lowering to map most of the ops to VM builtin functions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass VMBuiltinLower();
+
/*!
* \brief Lower the shape expression in relax to VM shape heap and TIR functions.
*
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index fa288a7f06..ff98b16d25 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -65,6 +65,41 @@ TVM_DLL Pass CreateDataflowBlockPass(
const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required);
+/*!
+ * \brief Transform all dataflow structure to non-dataflow version.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass ToNonDataflow();
+
+/*!
+ * \brief Perform explicit tensor allocation for call_tir.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass CallTIRRewrite();
+
+/*!
+ * \brief Convert all reshape-like call_tir whose corresponding binding
+ * vars are DataflowVars to relax.reshape operator calls. The relax.reshape
+ * calls will be lowered to an external builtin function call in a subsequent
+ * pass, where the external builtin function does a CreateView operation
+ * at runtime, instead of doing a real data copy.
+ * Here "reshape-like" includes reshape, expand_dims, flatten, etc.
+ *
+ * \return The Pass.
+ * \note The pass is applied at the first stage of Relax VM build, before
+ * rewriting call_tir, as this pass requires dataflow information.
+ */
+TVM_DLL Pass RewriteDataflowReshape();
+
+/*!
+ * \brief Attach global_symbol to Relax functions and TIR PrimFuncs for codegen.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass AttachGlobalSymbol();
+
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index d81c477145..27416c3a79 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -162,3 +162,48 @@ def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo:
The corresponding lca result.
"""
return _ffi_api.StructInfoLCA(lhs, rhs) # type: ignore
+
+
+def post_order_visit(expr, fvisit):
+ """Recursively visit the ir in post DFS order node,
+ apply fvisit. Each node is guaranteed to be visited
+ only once.
+
+ Parameters
+ ----------
+ expr : tvm.relay.Expr
+ The input expression.
+
+ fvisit : function
+ The visitor function to be applied.
+ """
+ return _ffi_api.post_order_visit(expr, fvisit) # type: ignore
+
+
+def has_reshape_pattern(func: tir.PrimFunc) -> bool:
+ """Check if the given PrimFunc is essentially doing a reshape operation.
+ The reshape operation also includes expand_dims, squeeze, flatten, etc.
+
+ Here the allowed reshape pattern is: for example, assume the operation is
+ `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, we check if we can prove
+ that the flattened index of l_0, ..., l_b under buffer B equals the
+ flattened index of r_0, ..., r_a under buffer A.
+
+ Parameters
+ ----------
+ func : tir.PrimFunc
+ The function to be examined.
+
+ Returns
+ -------
+ ret : bool
+ A boolean indicating if the given PrimFunc is doing a reshape.
+
+ Notes
+ -----
+ According to the description above, the returned result can only be
+ false-negative and cannot be false-positive, since whenever we cannot
+ prove the equality, we return false. This property guarantees the safety
+ of this function.
+ """
+ return _ffi_api.has_reshape_pattern(func) # type: ignore
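As a usage sketch (not part of this commit's tests, and TVMScript syntax is assumed to match this revision), a PrimFunc that flattens a buffer satisfies the check, since the two flattened indices are provably equal:

    import tvm
    from tvm import relax
    from tvm.script import tir as T

    @T.prim_func
    def flatten(A: T.Buffer((2, 3, 4), "float32"), B: T.Buffer((24,), "float32")):
        for i, j, k in T.grid(2, 3, 4):
            with T.block("flatten"):
                vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                # store index vi*12 + vj*4 + vk equals load index (vi*3 + vj)*4 + vk
                B[vi * 12 + vj * 4 + vk] = A[vi, vj, vk]

    assert relax.analysis.has_reshape_pattern(flatten)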
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 101b0827d6..9a131cdf95 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -20,3 +20,6 @@
# Operators
from .base import *
from .binary import *
+from .manipulate import *
+from . import builtin
+from . import memory
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/builtin/__init__.py
similarity index 91%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/builtin/__init__.py
index 101b0827d6..04837724b1 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/builtin/__init__.py
@@ -15,8 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""Relax builtin operators."""
-# Operators
-from .base import *
-from .binary import *
+from .builtin import *
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/builtin/_ffi_api.py
similarity index 83%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/builtin/_ffi_api.py
index 101b0827d6..42fe8cb652 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/builtin/_ffi_api.py
@@ -13,10 +13,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""FFI APIs for tvm.relax.op.builtin"""
+import tvm._ffi
-# Operators
-from .base import *
-from .binary import *
+tvm._ffi._init_api("relax.op.builtin", __name__)
diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py
new file mode 100644
index 0000000000..0afe6a42d0
--- /dev/null
+++ b/python/tvm/relax/op/builtin/builtin.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+"""The builtin Relax operators."""
+
+from ...expr import Call, Expr
+from ...utils import args_converter
+from . import _ffi_api
+
+
+@args_converter.auto
+def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call:
+ """Construct a Call to allocate a tensor with specific shape, dtype,
runtime_device_index.
+
+ Parameters
+ ----------
+ shape : Expr
+ The shape of the tensor to be allocated.
+
+ dtype : str
+ The datatype of the tensor to be allocated.
+
+ runtime_device_index : int
+ The device index indicating on which device the tensor is to be allocated at runtime.
+ Index -1 is reserved for the host device.
+
+ Returns
+ -------
+ result : Call
+ A relax Call, which gets the allocated tensor.
+ """
+ return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index)  # type: ignore
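A small construction sketch (illustrative only; in practice this op is emitted by the CallTIRRewrite pass rather than written by hand):

    from tvm import relax
    from tvm.relax.op import builtin

    # Call node that allocates a (2, 3) float32 tensor on device index 0.
    call = builtin.alloc_tensor(relax.ShapeExpr([2, 3]), "float32", 0)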
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
new file mode 100644
index 0000000000..fa9c815225
--- /dev/null
+++ b/python/tvm/relax/op/manipulate.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Manipulation operators."""
+from typing import Tuple, Union
+
+from tvm.ir.expr import PrimExpr
+
+
+from . import _ffi_api
+from ..expr import Expr
+
+
+PrimExprLike = Union[int, PrimExpr]
+
+
+def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr:
+ """Reshape the input array.
+
+ ``-1`` infers the dimension of the output shape by using the remainder of
+ the input dimensions, keeping the size of the new array the same as that of
+ the input array.
+ At most one dimension of shape can be -1.
+
+ .. code-block:: python
+
+ x.shape = (2, 3, 4), shape = (6, 1, -1), result.shape = (6, 1, 4)
+ x.shape = (2, 3, 4), shape = (3, -1, 8), result.shape = (3, 1, 8)
+ x.shape = (2, 3, 4), shape = (-1,), result.shape = (24,)
+
+ Parameters
+ ----------
+ x : relax.Expr
+ The input data to the operator.
+
+ shape : Union[Tuple[PrimExprLike], Expr]
+ The new shape. Should be compatible with the original shape.
+
+ Returns
+ -------
+ result : relax.Expr
+ The reshaped result.
+
+ Note
+ ----
+ The ``-1`` inference is only performed at compile-time.
+ That is to say, whenever the dimension length for ``-1`` cannot be inferred
+ at compile time, an error will be thrown.
+ """
+ return _ffi_api.reshape(x, shape) # type: ignore
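For instance, a sketch using the TVMScript binding this commit exports as R.reshape (hypothetical module; the -1 is inferred to 4 since 24 / 6 == 4):

    import tvm
    from tvm.script import relax as R

    @tvm.script.ir_module
    class ReshapeMod:
        @R.function
        def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((6, 4), "float32"):
            return R.reshape(x, (6, -1))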
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/memory/__init__.py
similarity index 91%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/memory/__init__.py
index 101b0827d6..e039590251 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/memory/__init__.py
@@ -15,8 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""Relax memory primitives."""
-# Operators
-from .base import *
-from .binary import *
+from .memory import *
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/memory/_ffi_api.py
similarity index 83%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/memory/_ffi_api.py
index 101b0827d6..475de481b2 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/memory/_ffi_api.py
@@ -13,10 +13,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""FFI APIs for tvm.relax.op.memory"""
+import tvm._ffi
-# Operators
-from .base import *
-from .binary import *
+tvm._ffi._init_api("relax.op.memory", __name__)
diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py
new file mode 100644
index 0000000000..b58b987d2a
--- /dev/null
+++ b/python/tvm/relax/op/memory/memory.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+"""Relax memory primitives."""
+
+from . import _ffi_api
+from ...expr import Expr, Call
+from ...utils import args_converter
+
+
+@args_converter.auto
+def alloc_storage(size: Expr, virtual_device_index: int, storage_scope: str, dtype: str) -> Call:
+ """Construct a Call to allocate a storage with specific size, virtual_device_index,
+ storage_scope and dtype.
+
+ Parameters
+ ----------
+ size : Expr
+ The size of the storage to be allocated.
+
+ virtual_device_index : int
+ The virtual device index indicating on which device the storage is to be allocated.
+ Index -1 is reserved for the host device.
+
+ storage_scope : str
+ The storage scope to allocate the storage to.
+
+ dtype : str
+ The datatype of the storage to be allocated.
+
+ Returns
+ -------
+ result : Call
+ A relax Call, which gets the allocated storage.
+ """
+ return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, dtype)  # type: ignore
+
+
+@args_converter.auto
+def alloc_tensor(storage: Expr, offset: int, shape: Expr, dtype: str) -> Call:
+ """Construct a Call to allocate a tensor on a certain storage starting
from the given offset.
+
+ Parameters
+ ----------
+ storage : Expr
+ The storage to allocate the tensor to.
+
+ offset : int
+ The storage offset to allocate the tensor.
+
+ shape : Expr
+ The shape of the tensor to be allocated.
+
+ dtype : str
+ The datatype of the tensor to be allocated.
+
+ Returns
+ -------
+ result : Call
+ A relax Call, which gets the allocated tensor.
+ """
+ return _ffi_api.alloc_tensor(storage, offset, shape, dtype) # type: ignore
+
+
+@args_converter.auto
+def kill_storage(storage: Expr) -> Call:
+ """Construct a Call to kill a storage.
+
+ Parameters
+ ----------
+ storage : Expr
+ The storage to be killed.
+
+ Returns
+ -------
+ result : Call
+ A relax Call to kill a storage.
+ """
+ return _ffi_api.kill_storage(storage) # type: ignore
+
+
+@args_converter.auto
+def kill_tensor(tensor: Expr) -> Call:
+ """Construct a Call to kill a tensor.
+
+ Parameters
+ ----------
+ tensor : Expr
+ The tensor to be killed.
+
+ Returns
+ -------
+ result : Call
+ A relax Call to kill a tensor.
+ """
+ return _ffi_api.kill_tensor(tensor) # type: ignore
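Taken together, a sketch of the storage/tensor lifecycle these four helpers describe (call structure only; real programs get these calls from memory-planning passes, and would not return a killed tensor):

    from tvm import relax
    from tvm.relax.op import memory

    bb = relax.BlockBuilder()
    x = relax.Var("x", relax.TensorStructInfo([2, 3, 4], "float32"))
    with bb.function("lifecycle", [x]):
        # 96 bytes = 2 * 3 * 4 elements * 4 bytes per float32
        storage = bb.emit(memory.alloc_storage(relax.ShapeExpr([96]), 0, "global", "float32"))
        tensor = bb.emit(memory.alloc_tensor(storage, 0, relax.ShapeExpr([2, 3, 4]), "float32"))
        bb.emit(memory.kill_tensor(tensor))
        bb.emit(memory.kill_storage(storage))
        bb.emit_func_output(tensor)
    mod = bb.get()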
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/testing/__init__.py
similarity index 91%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/testing/__init__.py
index 101b0827d6..ab1dd6f515 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/testing/__init__.py
@@ -15,8 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""The Relax testing namespace containing nn and translator."""
-# Operators
-from .base import *
-from .binary import *
+from .nn import *
diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py
new file mode 100644
index 0000000000..830ddd779f
--- /dev/null
+++ b/python/tvm/relax/testing/nn.py
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""PyTorch-like nn.Module API for constructing workloads."""
+
+
+from typing import List, Any, Callable, Union
+import typing
+import numpy as np # type: ignore
+
+import tvm
+from tvm import relax, topi, tir
+
+
+def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var:
+ return relax.BlockBuilder.current().emit_te(func, *args, **kwargs)
+
+
+class Placeholder(relax.Var):
+ """A placeholder variable that can represent model input."""
+
+ def __init__(
+ self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="data"
+ ):
+ if not isinstance(shape, (list, tuple)):
+ raise TypeError("the shape of Placeholder is expected to be a list or a tuple")
+ super().__init__(
+ relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype)
+ )
+
+
+class Parameter(relax.Var):
+ """A special kind of relax Var that represents model parameter(weight)."""
+
+ def __init__(
+ self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="param"
+ ):
+ if not isinstance(shape, (list, tuple)):
+ raise TypeError("the shape of Parameter is expected to be a list or a tuple")
+ super().__init__(
+ relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype)
+ )
+
+
+class Module:
+ """Base class for all model modules.
+
+ A neural network or a layer can subclass this class.
+
+ Example
+ -------
+ .. code-block:: python
+
+ # Define a linear layer
+ class Linear(Module):
+ def __init__(self, in_features, out_features, bias=True):
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = Parameter((in_features, out_features), name="linear_weight")
+ if bias:
+ self.bias = Parameter((out_features,), name="linear_bias")
+ else:
+ self.bias = None
+
+ # All submodules should implement forward.
+ # Defines the forward computation performed at every call.
+ def forward(self, input: relax.Expr) -> relax.Var:
+ y = emit_te(topi.matmul, input, self.weight)
+ if self.bias is not None:
+ y = emit_te(topi.add, y, self.bias)
+ return y
+ """
+
+ def parameters(self) -> List[Parameter]:
+ """Return the list of parameters in the module."""
+ return _unpack_params(self.__dict__)
+
+ def forward(self, input: relax.Expr):
+ """Define the computation performed at every call."""
+ raise NotImplementedError()
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+
+def _unpack_params(value: object) -> List[relax.Var]:
+ if isinstance(value, Parameter):
+ return [value]
+ if isinstance(value, Module):
+ return value.parameters()
+ if isinstance(value, dict):
+ params = []
+ for v in value.values():
+ params += _unpack_params(v)
+ return params
+ if isinstance(value, (list, tuple)):
+ params = []
+ for v in value:
+ params += _unpack_params(v)
+ return params
+ if value is None or isinstance(value, (int, float, str)):
+ return []
+ raise TypeError("not supported type when unpacking parameters:
{}".format(type(value)))
+
+
+def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]:
+ """Utility function to initialize model's parameters."""
+ shape_dict = {v.name_hint: v.struct_info.shape for v in mod["main"].params}
+ params = []
+ for k, v in shape_dict.items():
+ if k.startswith("data"):
+ continue
+ if isinstance(v, relax.ShapeExpr):
+ shape = []
+ for i in v:
+ if isinstance(i, tir.IntImm):
+ shape.append(int(i))
+ else:
+ raise TypeError("cannot initialize for unknown-shape
parameters.")
+ params.append(tvm.nd.array(np.zeros(shape).astype(np.float32)))
+ else:
+ raise TypeError("cannot initialize for unknown-shape parameters.")
+ return params
+
+
+class Sequential(Module):
+ """A sequential container that concatenates modules in it.
+
+ Example
+ -------
+ .. code-block:: python
+
+ model = nn.Sequential(
+ nn.Conv2d(1, 20, 5),
+ nn.ReLU(),
+ nn.Conv2d(20, 64, 5),
+ nn.ReLU()
+ )
+ """
+
+ def __init__(self, *modules: Module):
+ self.modules = modules
+
+ def forward(self, input: relax.Expr) -> relax.Var:
+ for module in self.modules:
+ input = module(input)
+ return input
+
+
+class ReLU(Module):
+ """Applies the rectified linear unit activation function on the input."""
+
+ def forward(self, input: relax.Expr) -> relax.Var:
+ return emit_te(topi.nn.relu, input)
+
+
+class LogSoftmax(Module):
+ """Applies log softmax activation function on the input."""
+
+ def forward(self, input: relax.Expr) -> relax.Var:
+ return emit_te(topi.nn.log_softmax, input)
+
+
+class Linear(Module):
+ """Applies a linear transformation to the input data: :math:`y = xA +
b`."""
+
+ def __init__(self, in_features, out_features, bias=True):
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = Parameter((in_features, out_features), name="linear_weight")
+ if bias:
+ self.bias = Parameter((out_features,), name="linear_bias")
+ else:
+ self.bias = None
+
+ def forward(self, input: relax.Expr) -> relax.Var:
+ y = emit_te(topi.matmul, input, self.weight)
+ if self.bias is not None:
+ y = emit_te(topi.add, y, self.bias)
+ return y
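A typical use of this API (mirroring the pattern in the docstrings above) builds an IRModule through the BlockBuilder; this sketch assumes the standard unity BlockBuilder interface:

    from tvm import relax
    from tvm.relax.testing import nn

    builder = relax.BlockBuilder()
    with builder.function("main"):
        model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10), nn.LogSoftmax())
        data = nn.Placeholder((1, 784), name="data")
        output = model(data)
        params = [data] + model.parameters()
        builder.emit_func_output(output, params=params)
    mod = builder.get()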
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index f20f06c522..cab18797c6 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -37,6 +37,49 @@ class DataflowBlockPass(tvm.ir.transform.Pass):
"""A pass that works on each tvm.relax.DataflowBlock in a module."""
+def ToNonDataflow() -> tvm.ir.transform.Pass:
+ """Transform all dataflow structure to non-dataflow version.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ return _ffi_api.ToNonDataflow() # type: ignore
+
+
+def CallTIRRewrite() -> tvm.ir.transform.Pass:
+ """Perform explicit tensor allocation for call_tir.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ return _ffi_api.CallTIRRewrite() # type: ignore
+
+
+def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
+ """Convert all reshape-like call_tir to VM reshape operator call.
+ The VM reshape operator calls will be further lowered to a CreateView
+ operation at runtime, instead of doing real data copy.
+ Here "reshape-like" includes reshape, expand_dims, flatten, etc.
+
+ Returns
+ -------
+ ret : tvm.ir.transform.Pass
+ """
+ return _ffi_api.RewriteDataflowReshape() # type: ignore
+
+
+def VMBuiltinLower() -> tvm.ir.transform.Pass:
+ """Lowering generic intrinsic to VM intrinsics.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ return _ffi_api.VMBuiltinLower() # type: ignore
+
+
def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
"""Lower the symbolic shape and argument and match-cast structinfo
matching.
@@ -52,6 +95,16 @@ def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
return _ffi_api.VMShapeLower(emit_err_ctx) # type: ignore
+def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
+ """Attach global_symbol to Relax functions and TIR Primfuncs for codegen.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ return _ffi_api.AttachGlobalSymbol() # type: ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py
index ba16dfb079..ff6bf816b6 100644
--- a/python/tvm/relax/vm.py
+++ b/python/tvm/relax/vm.py
@@ -581,7 +581,9 @@ def build(
if isinstance(target, str):
target = tvm.target.Target(target)
- passes = [relax.transform.ToNonDataflow()]
+ passes = []
+ passes.append(relax.transform.RewriteDataflowReshape())
+ passes.append(relax.transform.ToNonDataflow())
passes.append(relax.transform.CallTIRRewrite())
passes.append(relax.transform.VMBuiltinLower())
passes.append(relax.transform.VMShapeLower())
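Equivalently, the same lowering can be applied manually as a pass sequence (a sketch; `mod` is an assumed Relax IRModule, and relax.vm.build may order or extend these passes differently):

    import tvm
    from tvm import relax

    seq = tvm.transform.Sequential(
        [
            relax.transform.RewriteDataflowReshape(),  # needs dataflow info, so it runs first
            relax.transform.ToNonDataflow(),
            relax.transform.CallTIRRewrite(),
            relax.transform.VMBuiltinLower(),
            relax.transform.VMShapeLower(),
            relax.transform.AttachGlobalSymbol(),
        ]
    )
    lowered = seq(mod)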
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 647ef8f25a..0692ec5683 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -31,13 +31,16 @@ from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const
from tvm.relax.op import (
add,
assert_op,
+ builtin,
call_builtin_with_ctx,
call_tir,
invoke_closure,
make_closure,
+ memory,
multiply,
null_value,
print,
+ reshape,
shape_of,
)
from tvm.relax.struct_info import StructInfo
@@ -381,6 +384,7 @@ __all__ = [
"add",
"arg",
"assert_op",
+ "builtin",
"call_packed",
"call_tir",
"call_builtin_with_ctx",
@@ -396,11 +400,13 @@ __all__ = [
"function",
"invoke_closure",
"make_closure",
+ "memory",
"multiply",
"null_value",
"output",
"prim_value",
"print",
+ "reshape",
"shape_of",
"str",
"tuple",
diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc
new file mode 100644
index 0000000000..b7ac8faddd
--- /dev/null
+++ b/src/relax/analysis/tir_op_pattern_kind.cc
@@ -0,0 +1,447 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+using namespace tir;
+
+class PatternKindAnalyzer : public StmtExprVisitor {
+ public:
+ explicit PatternKindAnalyzer(const tir::PrimFunc& func) {
+ for (const tir::Var& param : func->params) {
+ Optional<Buffer> param_buf = func->buffer_map.Get(param);
+ if (param_buf.defined()) {
+ param_buffers_.insert(param_buf.value());
+ }
+ }
+ }
+
+ private:
+ bool IsOutputBlock(const BlockNode* block) {
+ for (const BufferRegion& write_region : block->writes) {
+ if (param_buffers_.count(write_region->buffer)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ void VisitStmt_(const BufferStoreNode* op) final {
+ // We only support one buffer store in a block (usually generated by TE compute).
+ // If we have already seen a buffer store in the current block, classify it as opaque.
+ if (store_.defined()) {
+ kind_ = relay::kOpaque;
+ return;
+ }
+ store_ = GetRef<BufferStore>(op);
+ StmtVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const BufferLoadNode* op) final {
+ loads_.push_back(GetRef<BufferLoad>(op));
+ ExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const BlockNode* op) final {
+ if (op->name_hint == "root") {
+ // Skip the root block
+ StmtVisitor::VisitStmt(op->body);
+ return;
+ }
+
+ // Step 1. Clear loads and store
+ loads_.clear();
+ store_ = NullOpt;
+ // Step 2. Visit block body.
+ StmtVisitor::VisitStmt(op->body);
+ BufferStore store = store_.value();
+
+ // Step 3. Check the load/store indices pattern.
+ relay::OpPatternKind index_pair_pattern = relay::kElemWise;
+ bool has_elem_wise = false;
+ for (const BufferLoad& load : loads_) {
+ // Since elemwise is stricter than broadcast, and broadcast is stricter than injective,
+ // and the order among the enums is kElemWise < kBroadcast < kInjective,
+ // we can simply use `std::max` to detect these three patterns.
+ // E.g. here there is only one store node but two load nodes, like C[i, j] = A[i, j] + B[i].
+ // Buffers C and A are elemwise but C and B are broadcast, so the whole block follows
+ // the broadcast pattern.
+ if (IsElemwisePattern(store, load)) {
+ index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise);
+ has_elem_wise = true;
+ } else if (IsBroadcastPattern(store, load)) {
+ index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast);
+ } else if (IsInjectivePattern(store, load)) {
+ index_pair_pattern = std::max(index_pair_pattern, relay::kInjective);
+ } else {
+ index_pair_pattern = relay::kOpaque;
+ break;
+ }
+ }
+ // If one index pair is kElemWise and the others are kBroadcast, we regard the block as
+ // kElemWise, e.g. A[i, j] = B[i, j] + C[i]
+ if (index_pair_pattern == relay::kBroadcast && has_elem_wise) {
+ index_pair_pattern = relay::kElemWise;
+ }
+ // If the block index pattern is not opaque, update kind.
+ if (index_pair_pattern != relay::kOpaque) {
+ // This rule for softmax: reduce + injective.
+ if (IsOutputBlock(op) && kind_ == relay::kCommReduce) {
+ kind_ = relay::kOutEWiseFusable;
+ } else {
+ kind_ = std::max(kind_, index_pair_pattern);
+ }
+ return;
+ }
+
+ // Step 4. Check if the block contains reduce axes by looking into block iterators.
+ bool has_reduction = false;
+ Array<tir::Var> reduce_vars;
+ for (const IterVar& it : op->iter_vars) {
+ if (it->iter_type == kCommReduce) {
+ has_reduction = true;
+ reduce_vars.push_back(it->var);
+ }
+ }
+
+ if (has_reduction) {
+ if (IsFMA(op->body)) {
+ // FMA is regarded as kOutEWiseFusable, e.g. Matmul or Conv.
+ kind_ = std::max(kind_, relay::kOutEWiseFusable);
+ return;
+ } else {
+ for (size_t i = 0; i < loads_.size(); ++i) {
+ // If it's not a pure reduce, regard it as kOutEWiseFusable.
+ // This rule works for pooling for now.
+ if (!IsPureReducePattern(reduce_vars, loads_[i]->indices)) {
+ kind_ = std::max(kind_, relay::kOutEWiseFusable);
+ return;
+ }
+ }
+ }
+ kind_ = std::max(kind_, relay::kCommReduce);
+ } else {
+ kind_ = relay::kOpaque;
+ }
+ }
+
+ /********** Helper Functions **********/
+
+ /*! \brief Check if two arrays contain the same elements. */
+ static bool IsSameArray(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+ if (lhs.size() != rhs.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < lhs.size(); ++i) {
+ if (!lhs[i].same_as(rhs[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /*!
+ * \brief Check if the load indices and store indices follow the elemwise pattern.
+ * It's the elemwise pattern iff the load indices and store indices are the same.
+ * E.g. A[i, j] = B[i, j]
+ */
+ static bool IsElemwisePattern(const BufferStore& store, const BufferLoad& load) {
+ return IsSameArray(store->indices, load->indices);
+ }
+
+ /*!
+ * \brief Check if the load indices and store indices follow the broadcast pattern.
+ * It's the broadcast pattern iff all load indices appear in the store indices in order.
+ * E.g. A[i, j] = B[i] is broadcast since all load indices (`i`) are in the store indices.
+ *      A[i, j] = B[i, k] is not broadcast since `k` is not in the store indices.
+ *      A[i, j] = B[j, i] is not broadcast since the load indices are not in the same
+ *      order as the store's.
+ */
+ static bool IsBroadcastPattern(const BufferStore& store, const BufferLoad& load) {
+ size_t ndim_load_buf = load->buffer->shape.size();
+ size_t ndim_store_buf = store->buffer->shape.size();
+
+ for (size_t i = 0, j = 0; i < ndim_load_buf; ++i) {
+ if (is_const_int(load->buffer->shape[i], 1) && is_const_int(load->indices[i], 0)) {
+ // Skip unit load dimensions
+ // E.g. A[i, j] = B[1, j] is still broadcast
+ continue;
+ }
+
+ // Try to find the i-th load index in the store indices.
+ while (j < ndim_store_buf && !store->indices[j].same_as(load->indices[i])) {
+ ++j;
+ }
+
+ // It's not broadcast if we cannot find the load indices in the store indices in order.
+ if (j == ndim_store_buf) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /*!
+ * \brief Check if the load indices and store indices follow the injective pattern.
+ * It's the injective pattern iff all load index vars are in the store indices, regardless
+ * of order. Note that we only support store indices that are direct vars so far, which
+ * can be enhanced later.
+ * E.g. A[i, j] = B[j, i] is injective.
+ *      A[i, j] = B[i - j] is injective since the load index vars are only i, j.
+ */
+ static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) {
+ std::unordered_set<const tir::VarNode*> vars;
+ for (const PrimExpr& store_index : store->indices) {
+ if (const auto* v = store_index.as<tir::VarNode>()) {
+ vars.insert(v);
+ } else {
+ return false;
+ }
+ }
+ for (const PrimExpr& load_index : load->indices) {
+ // Return false if there are vars used in the load indices but not in the store indices.
+ if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /*!
+ * \brief Check if the load indices and store indices allow data reuse.
+ * They allow data reuse iff some vars appear in the load indices but not in the store indices.
+ * E.g. Store = A[i, j] and Load = B[i, j, k] allow data reuse.
+ *      Store = A[i, j] and Load = B[i, j + k] allow data reuse.
+ */
+ static bool IsAllowReusePattern(const BufferStore& store, const BufferLoad& load) {
+ std::unordered_set<const tir::VarNode*> vars;
+ for (const PrimExpr& index : store->indices) {
+ if (const auto* v = index.as<tir::VarNode>()) {
+ vars.insert(v);
+ } else {
+ return false;
+ }
+ }
+ for (const PrimExpr& index : load->indices) {
+ PreOrderVisit(index, [&](const ObjectRef& node) {
+ if (const auto* v = node.as<tir::VarNode>()) {
+ if (vars.count(v)) {
+ vars.erase(v);
+ }
+ }
+ return true;
+ });
+ }
+ return !vars.empty();
+ }
+
+ /*! \brief Check if the stmt is a multiply-add. E.g. C[i, j] += A[i, k] * B[j, k] */
+ static bool IsFMA(const Stmt& body) {
+ if (const auto* store = body.as<BufferStoreNode>()) {
+ if (const auto* add = store->value.as<AddNode>()) {
+ if (const auto* l = add->a.as<BufferLoadNode>()) {
+ if (const auto* r = add->b.as<MulNode>()) {
+ bool incremental =
+ store->buffer.same_as(l->buffer) && IsSameArray(store->indices, l->indices);
+ const auto* l_load = r->a.as<BufferLoadNode>();
+ const auto* r_load = r->b.as<BufferLoadNode>();
+ if (incremental && l_load && r_load) {
+ return IsAllowReusePattern(GetRef<BufferStore>(store), GetRef<BufferLoad>(l_load)) &&
+        IsAllowReusePattern(GetRef<BufferStore>(store), GetRef<BufferLoad>(r_load));
+ }
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ /*!
+ * \brief Check if it is the pure reduce pattern.
+ * It's the pure reduce pattern iff all reduce axes are direct reduce vars.
+ * E.g. A[i] = sum(B[i, j]) is pure reduce.
+ *      A[i] = sum(B[i, j + k]) is not pure reduce.
+ *      Pooling is not pure reduce.
+ */
+ static bool IsPureReducePattern(Array<tir::Var> reduce_loops, Array<PrimExpr> indices) {
+ for (const PrimExpr& e : indices) {
+ int id = -1;
+ if (UsesVar(e, [&](const tir::VarNode* var) {
+ for (size_t i = 0; i < reduce_loops.size(); ++i) {
+ if (reduce_loops[i].get() == var) {
+ id = i;
+ return true;
+ }
+ }
+ return false;
+ })) {
+ if (!reduce_loops[id].same_as(e)) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ private:
+ /*!
+ * \brief The BufferStore node in the current block.
+ * \note We only support one BufferStore node in a block (usually generated by TE compute)
+ */
+ Optional<BufferStore> store_;
+ /*! \brief The BufferLoad nodes in the current block. */
+ Array<BufferLoad> loads_;
+ /*! \brief The result of op pattern. */
+ relay::OpPatternKind kind_ = relay::kElemWise;
+ /*! \brief The buffers from function params, i.e. the input and output buffers. */
+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> param_buffers_;
+
+ public:
+ relay::OpPatternKind GetResult() { return kind_; }
+};
+
+relay::OpPatternKind AnalyzeOpPatternKind(const PrimFunc& func) {
+ PatternKindAnalyzer analyzer(func);
+ analyzer(func->body);
+ return analyzer.GetResult();
+}
+
+bool HasReshapePattern(const PrimFunc& func) {
+ class ReshapeDetector : public StmtVisitor {
+ public:
+ static bool Detect(const Buffer& src_buffer, const Buffer& dst_buffer, Stmt stmt) {
+ ReshapeDetector detector(src_buffer, dst_buffer);
+ detector(stmt);
+ return detector.is_reshape_;
+ }
+
+ private:
+ explicit ReshapeDetector(const Buffer& src_buffer, const Buffer& dst_buffer)
+     : is_reshape_(false), src_buffer_(src_buffer), dst_buffer_(dst_buffer) {}
+
+ void VisitStmt_(const ForNode* loop) final {
+ ana_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+ // To detect the reshape pattern, we require each For to have
+ // either another For or a BlockRealize as body.
+ if (!(loop->body->IsInstance<ForNode>() || loop->body->IsInstance<BlockRealizeNode>())) {
+ return;
+ }
+ this->VisitStmt(loop->body);
+ }
+
+ void VisitStmt_(const BlockRealizeNode* block_realize) final {
+ // Constructing the mapping from block iterators to iterator
+ // binding values. The mapping will be used in the substitution of
+ // the flattened buffer access index.
+ const Block& block = block_realize->block;
+ const Array<IterVar>& block_iter = block->iter_vars;
+ const Array<PrimExpr>& iter_values = block_realize->iter_values;
+ ICHECK_EQ(block_iter.size(), iter_values.size());
+ int n_iter = block_iter.size();
+ for (int i = 0; i < n_iter; ++i) {
+ // To detect the reshape pattern, we require each block iter to be data-parallel.
+ if (block_iter[i]->iter_type != tir::IterVarType::kDataPar) {
+ return;
+ }
+ var_map_.Set(block_iter[i]->var, iter_values[i]);
+ }
+
+ // Recurse into the block.
+ this->VisitStmt(block);
+ }
+
+ void VisitStmt_(const BlockNode* block) final {
+ // Step 0. If the block body is a ForNode, recurse into it.
+ if (block->body->IsInstance<ForNode>()) {
+ this->VisitStmt(block->body);
+ return;
+ }
+
+ // Step 1. Get the load/store pattern of the block body.
+ // To detect the reshape pattern, we require the block body to be a
+ // BufferStore, which has a BufferLoad as value.
+ const auto* buffer_store = block->body.as<BufferStoreNode>();
+ if (buffer_store == nullptr) {
+ return;
+ }
+ const auto* buffer_load = buffer_store->value.as<BufferLoadNode>();
+ if (buffer_load == nullptr) {
+ return;
+ }
+ // Further, we require the buffers being stored to and loaded from to match the
+ // parameters of the PrimFunc, namely `dst_buffer_` and `src_buffer_`.
+ if (!(buffer_store->buffer.same_as(dst_buffer_) &&
+       buffer_load->buffer.same_as(src_buffer_))) {
+ return;
+ }
+
+ // Step 3. Calculate the flattened access index according to the load/store pattern.
+ auto f_calc_flattened_idx = [](const Buffer& buffer, const Array<PrimExpr>& indices) {
+ ICHECK_EQ(indices.size(), buffer->shape.size());
+ int ndim = indices.size();
+ PrimExpr idx = 0;
+ for (int i = 0; i < ndim; ++i) {
+ idx = idx * buffer->shape[i] + indices[i];
+ }
+ return idx;
+ };
+ PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices);
+ PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices);
+
+ // Step 4. Substitute the block iterators in the flattened index
+ // with loop variables, and check if we can prove their equality.
+ src_idx = tir::Substitute(std::move(src_idx), var_map_);
+ dst_idx = tir::Substitute(std::move(dst_idx), var_map_);
+ if (ana_.CanProveEqual(src_idx, dst_idx)) {
+ this->is_reshape_ = true;
+ }
+ }
+
+ bool is_reshape_;
+ /*! \brief The mapping from block vars to block binding values. */
+ Map<tir::Var, PrimExpr> var_map_;
+ const Buffer& src_buffer_;
+ const Buffer& dst_buffer_;
+ arith::Analyzer ana_;
+ };
+
+ if (func->params.size() < 2) {
+ return false;
+ }
+ Optional<Buffer> src_buffer = func->buffer_map.Get(func->params.front());
+ Optional<Buffer> dst_buffer = func->buffer_map.Get(func->params.back());
+ if (!(src_buffer.defined() && dst_buffer.defined())) {
+ return false;
+ }
+
+ // The body of the PrimFunc is expected to be the root BlockRealize.
+ ICHECK(func->body->IsInstance<BlockRealizeNode>());
+ return ReshapeDetector::Detect(src_buffer.value(), dst_buffer.value(), func->body);
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern);
+
+} // namespace relax
+} // namespace tvm
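Conversely, a sketch of a PrimFunc the detector rejects, through the Python binding registered above: a transpose yields flattened indices vi * 4 + vj versus vj * 6 + vi, which cannot be proven equal:

    import tvm
    from tvm import relax
    from tvm.script import tir as T

    @T.prim_func
    def transpose(A: T.Buffer((4, 6), "float32"), B: T.Buffer((6, 4), "float32")):
        for i, j in T.grid(6, 4):
            with T.block("t"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vj, vi]

    assert not relax.analysis.has_reshape_pattern(transpose)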
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc
new file mode 100644
index 0000000000..6613b39626
--- /dev/null
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/backend/vm/vm_builtin_lower.cc
+ * \brief Lowers most builtin functions and packed calls.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/backend.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/type.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relax {
+
+// This pass lowers most ops to VM specific builtins.
+// TODO(relax-team): revisit after PrimValue.
+class VMBuiltinLowerMutator : public ExprMutator {
+ public:
+ using ExprMutator::VisitExpr_;
+
+ // A workaround to remove the CallNodes of killing tensors and storages.
+ void VisitBinding_(const VarBindingNode* binding) final {
+ const auto* call = binding->value.as<CallNode>();
+ if (call != nullptr && (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_)) {
+ return;
+ }
+ ExprMutator::VisitBinding_(binding);
+ }
+
+ Expr VisitExpr_(const CallNode* call_node) final {
+ // post-order mutation
+ Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
+
+ if (call->op == call_tir_dyn_op_) {
+ return CallTIRDyn(call);
+ } else if (call->op == reshape_op_) {
+ return Reshape(call);
+ } else if (call->op == make_closure_op_) {
+ return MakeClosure(call);
+ } else if (call->op == invoke_closure_op_) {
+ return InvokeClosure(call);
+ } else if (call->op == alloc_tensor_op_) {
+ return MakeAllocTensor(call);
+ } else if (call->op == mem_alloc_storage_op_) {
+ return MakeMemAllocStorage(call);
+ } else if (call->op == mem_alloc_tensor_op_) {
+ return MakeMemAllocTensor(call);
+ } else {
+ return call;
+ }
+ }
+
+ Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const {
+ // Question: what if the dtype of tensor_type is unknown?
+ // Symbolic/static shape case
+ if (auto* shape_expr = shape.as<ShapeExprNode>()) {
+ int64_t elem_bytes = runtime::GetVectorBytes(dtype);
+ PrimExpr ret = IntImm(DataType::Int(64), elem_bytes);
+ for (PrimExpr dim : shape_expr->values) {
+ ret = ret * dim;
+ }
+ return ShapeExpr({ret});
+ } else {
+ return Call(builtin_compute_alloc_shape_, {shape, DataTypeImm(dtype)}, Attrs(),
+             {GetStructInfo(shape)});
+ }
+ }
+
+ Expr MakeAllocTensor(const Call& call) {
+ ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
+ DataTypeImm output_dtype = Downcast<DataTypeImm>(call->args[1]);
+ DataType dtype = output_dtype->value;
+ Expr storage_size = ComputeStorageSize(output_shape, dtype);
+ PrimValue runtime_device_index = Downcast<PrimValue>(call->args[2]);
+ Var storage = builder_->Emit(
+ Call(vm_alloc_storage_op_, {storage_size, runtime_device_index, output_dtype}, Attrs()),
+ "storage");
+ Expr shape = call->args[0];
+ PrimValue offset = PrimValue::Int64(0);
+ return Call(vm_alloc_tensor_op_, {storage, offset, shape, DataTypeImm(dtype)}, Attrs());
+ }
+
+ Expr MakeMemAllocStorage(const Call& call) {
+ PrimValue runtime_device_index = Downcast<PrimValue>(call->args[1]);
+ DataTypeImm output_dtype = Downcast<DataTypeImm>(call->args[3]);
+ return Call(vm_alloc_storage_op_, {call->args[0], runtime_device_index, output_dtype}, Attrs());
+ }
+
+ Expr MakeMemAllocTensor(const Call& call) {
+ PrimValue offset = Downcast<PrimValue>(call->args[1]);
+ DataTypeImm dtype = Downcast<DataTypeImm>(call->args[3]);
+ return Call(vm_alloc_tensor_op_, {call->args[0], offset, call->args[2], dtype}, Attrs());
+ }
+
+ Expr CallTIRDyn(const Call& call_node) {
+ ICHECK(call_node->args.size() == 2);
+ ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
+ ICHECK(call_node->args[1]->IsInstance<TupleNode>());
+ Array<Expr> args;
+
+ auto tir_args = Downcast<Tuple>(call_node->args[1]);
+ args.push_back(call_node->args[0]);
+ for (Expr arg : tir_args->fields) {
+ args.push_back(arg);
+ }
+ return Call(builtin_call_tir_dyn_, args, Attrs(), {void_sinfo_});
+ }
+
+ Expr Reshape(const Call& call_node) {
+ ICHECK(call_node->args.size() == 2);
+ ICHECK(call_node->struct_info_.defined());
+ CHECK(call_node->args[1]->IsInstance<ShapeExprNode>())
+ << "VMBuiltinLower expects the shape arg of reshape op to be a
ShapeExpr";
+ return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
+ }
+
+ Expr MakeClosure(const Call& call_node) {
+ ICHECK(call_node->args.size() == 2);
+ ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
+ ICHECK(call_node->args[1]->IsInstance<TupleNode>());
+
+ Array<Expr> args;
+ auto func = call_node->args[0];
+ auto closure_args = Downcast<Tuple>(call_node->args[1]);
+
+ args.push_back(func);
+ for (Expr arg : closure_args->fields) {
+ args.push_back(arg);
+ }
+
+ return Call(builtin_make_closure_, args, Attrs(), {object_sinfo_});
+ }
+
+ Expr InvokeClosure(const Call& call_node) {
+ ICHECK(call_node->args.size() == 2);
+ ICHECK(call_node->args[0]->IsInstance<VarNode>());
+ ICHECK(call_node->args[1]->IsInstance<TupleNode>());
+
+ Array<Expr> args;
+
+ args.push_back(call_node->args[0]);
+
+ // args for the invoke_closure
+ auto invoke_closure_args = Downcast<Tuple>(call_node->args[1]);
+ for (Expr arg : invoke_closure_args->fields) {
+ args.push_back(arg);
+ }
+ return Call(call_builtin_with_ctx_op_, {builtin_invoke_closure_, Tuple(args)}, Attrs(),
+             {object_sinfo_});
+ }
+
+ const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
+ const StructInfo object_sinfo_ = ObjectStructInfo();
+ const StructInfo void_sinfo_ = TupleStructInfo(Array<StructInfo>({}));
+ // object to pattern match.
+ const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
+ const Op& reshape_op_ = Op::Get("relax.reshape");
+ const Op& make_closure_op_ = Op::Get("relax.make_closure");
+ const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
+ const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
+ const Op& mem_alloc_storage_op_ = Op::Get("relax.memory.alloc_storage");
+ const Op& mem_alloc_tensor_op_ = Op::Get("relax.memory.alloc_tensor");
+ const Op& mem_kill_storage_op_ = Op::Get("relax.memory.kill_storage");
+ const Op& mem_kill_tensor_op_ = Op::Get("relax.memory.kill_tensor");
+ // functions to lower to
+ const Op& vm_alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
+ const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
+ // Function to compute allocated shape.
+ const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
+ const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
+ const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
+ const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
+ const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
+};
+
+Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass VMBuiltinLower() {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+     [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(VMBuiltinLower(f)); };
+ return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
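A sketch of applying the pass on its own (assuming `mod` is a Relax IRModule that has already been through CallTIRRewrite):

    from tvm import relax

    # relax.builtin.alloc_tensor and relax.memory.* calls become relax.vm.* ops
    # and vm.builtin.* extern calls; bindings of memory.kill_* are dropped.
    mod = relax.transform.VMBuiltinLower()(mod)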
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index ca66b0a9ef..ba167a45bc 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -226,6 +226,87 @@ Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) {
TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor);
+// memory planning alloc_storage
+
+RELAY_REGISTER_OP("relax.memory.alloc_storage")
+ .set_num_inputs(4)
+ .add_argument("total_space", "Expr", "The total space of the storage to
allocate.")
+ .add_argument(
+ "virtual_device_index", "int64_t",
+ "The virtual device index indicating on which device the storage is to
be allocated, "
+ "Index -1 is reserved for the host device.")
+ .add_argument("storage_scope", "string",
+ "The storage scope of the storage to allocate. Default is
global.")
+ .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
+
+Expr MakeAllocStorage(Expr size, int64_t virtual_device_index, std::string storage_scope,
+                      DataType dtype) {
+ static const Op& op = Op::Get("relax.memory.alloc_storage");
+ return Call(
+ op,
+ {size, PrimValue::Int64(virtual_device_index), StringImm(storage_scope), DataTypeImm(dtype)},
+ Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage);
+
+// memory planning alloc_tensor
+
+StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) {
+ ICHECK(GetStructInfoAs<ShapeStructInfoNode>(call->args[2]))
+     << "must be an Expr of ShapeStructInfo, but got " << call->args[2]->GetTypeKey();
+ DataType out_dtype;
+ if (const auto* dtype_node = call->args[3].as<DataTypeImmNode>()) {
+ const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node);
+ out_dtype = dtype_imm->value;
+ }
+ return TensorStructInfo(call->args[2], out_dtype);
+}
+
+RELAY_REGISTER_OP("relax.memory.alloc_tensor")
+ .set_num_inputs(4)
+ .add_argument("storage", "Expr", "The storage to allocate the tensor to.")
+ .add_argument("offset", "int", "Storage offset to allocate the tensor.")
+ .add_argument("shape", "Expr", "The shape of the tensor to allocate.")
+ .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoMemAllocTensor);
+
+Expr MakeMemAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) {
+ static const Op& op = Op::Get("relax.memory.alloc_tensor");
+ return Call(op, {storage, PrimValue::Int64(offset), shape, DataTypeImm(dtype)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor);
+
+// memory planning kill_storage
+
+RELAY_REGISTER_OP("relax.memory.kill_storage")
+ .set_num_inputs(1)
+ .add_argument("storage", "Expr", "The storage to be killed.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo);
+
+Expr MakeMemKillStorage(Expr storage) {
+ static const Op& op = Op::Get("relax.memory.kill_storage");
+ return Call(op, {storage}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage);
+
+// memory planning kill_tensor
+
+RELAY_REGISTER_OP("relax.memory.kill_tensor")
+ .set_num_inputs(1)
+ .add_argument("tensor", "Expr", "The tensor to be killed.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo);
+
+Expr MakeMemKillTensor(Expr tensor) {
+ static const Op& op = Op::Get("relax.memory.kill_tensor");
+ return Call(op, {tensor}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor);
+
// vm alloc_storage
RELAY_REGISTER_OP("relax.vm.alloc_storage")
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
new file mode 100644
index 0000000000..2088a8306e
--- /dev/null
+++ b/src/relax/op/tensor/manipulate.cc
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file manipulate.cc
+ * \brief Manipulation operators.
+ */
+
+#include "manipulate.h"
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+// Helper function for flatten and reshape.
+PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) {
+ PrimExpr shape_prod = IntImm(DataType::Int(64), 1);
+ for (PrimExpr value : shape_values) {
+ shape_prod *= value;
+ }
+ return shape_prod;
+}
+
+/* relax.reshape */
+Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
+ if (const auto* e = shape.as<ExprNode>()) {
+ return GetRef<Expr>(e);
+ }
+
+  const auto* array = shape.as<ArrayNode>();
+  CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
+                             "Array of PrimExprs. However, the given new shape is "
+                          << shape;
+  int dim_to_infer = -1;
+  PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
+  for (int i = 0; i < static_cast<int>(array->size()); ++i) {
+    const auto* _len = array->at(i).as<PrimExprNode>();
+    CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
+                              "Array of PrimExprs. However, the given new shape is "
+                           << shape;
+    PrimExpr len = GetRef<PrimExpr>(_len);
+    CHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all "
+                                  "integers. However, the given new shape is "
+                               << shape;
+    const auto* int_len = len.as<IntImmNode>();
+    if (int_len != nullptr && int_len->value == -1) {
+      CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, "
+                                    "there are multiple \"-1\" in the given new shape "
+                                 << shape;
+      dim_to_infer = i;
+    } else {
+      CHECK(int_len == nullptr || int_len->value > 0)
+          << "Reshape requires all values in the new shape to be positive except a single \"-1\". "
+             "However, the given new shape is "
+          << shape;
+      // We expect any symbolic value not to signal the intent of -1, and therefore do not check
+      // symbolic values here.
+      new_shape_prod = new_shape_prod * len;
+    }
+  }
+
+  Array<PrimExpr> array_ref = GetRef<Array<PrimExpr>>(array);
+  // When there is no dimension to infer, just return the input array as ShapeExpr.
+  if (dim_to_infer == -1) {
+    return ShapeExpr(array_ref);
+  }
+
+  // Otherwise, we require the input tensor to have known shape value for inference.
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(data);
+  CHECK(data_sinfo != nullptr)
+      << "Reshape expects the input data to be a Tensor. However, the given input is "
+      << data->struct_info_->GetTypeKey();
+  CHECK(data_sinfo->shape.defined())
+      << "Reshape expects the input tensor to have known shape when there is some dimension length "
+         "to infer. However, the given input has no shape.";
+  const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());
+  CHECK(shape_sinfo != nullptr && shape_sinfo->values.defined())
+      << "Reshape expects the input tensor to have known shape when there is some dimension length "
+         "to infer. However, the given input shape is "
+      << data_sinfo->shape << " whose shape value is unknown.";
+
+  arith::Analyzer analyzer;
+  PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
+  array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
+  return ShapeExpr(array_ref);
+}
+
+Expr reshape(Expr x, ObjectRef shape) {
+ Expr shape_in_expr = ConvertNewShapeToExpr(x, shape);
+ static const Op& op = Op::Get("relax.reshape");
+ return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape);
+
+StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call->span) << "Reshape op should take 2 arguments");
+  }
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* new_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call->span)
+                     << "Reshape requires the input data to be Tensor. However, the given one is "
+                     << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (new_shape_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call->span)
+        << "Reshape requires the input new shape to be Shape. However, the given one is "
+        << call->args[1]->struct_info_->GetTypeKey());
+  }
+
+  Optional<Array<PrimExpr>> old_shape_values;
+  if (data_sinfo->shape.defined()) {
+    const auto* old_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());
+    ICHECK_NOTNULL(old_shape_sinfo);
+    old_shape_values = old_shape_sinfo->values;
+  }
+
+  if (new_shape_sinfo->values.defined() && old_shape_values.defined()) {
+    PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_sinfo->values.value());
+    PrimExpr old_shape_prod = ComputeShapeProduct(old_shape_values.value());
+    if (ctx->GetAnalyzer()->CanProve(old_shape_prod != new_shape_prod)) {
+      ctx->ReportFatal(Diagnostic::Error(call->span)
+                       << "Reshape expects the new shape to be convertible from the old shape. "
+                          "However, the old shape is "
+                       << data_sinfo->shape << ", with product " << old_shape_prod
+                       << ", while the new shape is " << call->args[1] << ", with product "
+                       << new_shape_prod);
+    }
+  }
+  return TensorStructInfo(call->args[1], data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.reshape")
+ .set_num_inputs(2)
+ .add_argument("x", "Tensor", "The input tensor.")
+ .add_argument("shape", "Shape", "The input new shape.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoReshape);
+
+} // namespace relax
+} // namespace tvm
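
(A quick sketch of the `-1` inference implemented above, using the
`relax.op.reshape` Python binding this PR adds in
python/tvm/relax/op/manipulate.py; the shapes are illustrative.)

    from tvm import relax
    from tvm.script import relax as R

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    # The old shape product is 24 and the known new dims multiply to 8,
    # so the -1 dimension is inferred as 24 // 8 == 3, giving shape (8, 3).
    y = relax.op.reshape(x, (8, -1))
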
diff --git a/include/tvm/relax/backend.h b/src/relax/op/tensor/manipulate.h
similarity index 58%
copy from include/tvm/relax/backend.h
copy to src/relax/op/tensor/manipulate.h
index 4ebeacac0f..1a3eb0547d 100644
--- a/include/tvm/relax/backend.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -18,27 +18,28 @@
*/
/*!
- * \file tvm/relax/backend.h
- * \brief Relax backend specific transformation passes.
+ * \file manipulate.h
+ * \brief The functions to make Relax tensor manipulation operator calls.
*/
-#ifndef TVM_RELAX_BACKEND_H_
-#define TVM_RELAX_BACKEND_H_
+#ifndef TVM_RELAX_OP_TENSOR_MANIPULATE_H_
+#define TVM_RELAX_OP_TENSOR_MANIPULATE_H_
-#include <tvm/relax/transform.h>
+#include "../op_common.h"
namespace tvm {
namespace relax {
-namespace transform {
/*!
- * \brief Lower the shape expression in relax to VM shape heap and TIR functions.
- *
- * \return The Pass.
+ * \brief Reshape the input array, supporting `-1` inference in the new
+ * shape when the new shape is given as an Array of PrimExpr.
+ * \param x The input data to the operator.
+ * \param shape The new shape. Should be compatible with the original shape.
+ * It is required to be either an Array of PrimExpr, or a Shape in Relax.
+ * \return The reshaped result.
*/
-TVM_DLL Pass VMShapeLower();
+Expr reshape(Expr x, ObjectRef shape);
-} // namespace transform
} // namespace relax
} // namespace tvm
-#endif // TVM_RELAX_BACKEND_H_
+#endif // TVM_RELAX_OP_TENSOR_MANIPULATE_H_
diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc
new file mode 100644
index 0000000000..be779e97bc
--- /dev/null
+++ b/src/relax/transform/attach_global_symbol.cc
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/attach_global_symbol.cc
+ * \brief Attach global_symbol to Relax functions and TIR PrimFuncs for codegen.
+ */
+
+#include <tvm/relax/transform.h>
+#include <tvm/tir/function.h>
+
+namespace tvm {
+namespace relax {
+
+class GlobalSymbolAttacher {
+ public:
+ explicit GlobalSymbolAttacher(IRModule mod) : mod_(mod) {}
+
+ IRModule Attach() {
+ IRModule ret;
+ for (auto& p : mod_->functions) {
+ BaseFunc func = p.second;
+ if (auto* prim_func = func.as<tir::PrimFuncNode>()) {
+        func = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol", p.first->name_hint);
+      } else if (auto* relax_func = func.as<FunctionNode>()) {
+        func = WithAttr(GetRef<Function>(relax_func), "global_symbol", p.first->name_hint);
+ } else {
+ LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey();
+ throw;
+ }
+ ret->Add(p.first, func);
+ }
+ return ret;
+ }
+
+ private:
+ IRModule mod_;
+};
+
+namespace transform {
+
+Pass AttachGlobalSymbol() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return GlobalSymbolAttacher(mod).Attach(); };
+ return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
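
(Usage is a single module-to-module pass invocation; a minimal sketch, assuming
`mod` is an IRModule mixing Relax functions and TIR PrimFuncs.)

    from tvm import relax

    # Afterwards every function carries a "global_symbol" attribute equal to
    # its GlobalVar name hint, as required for codegen.
    mod = relax.transform.AttachGlobalSymbol()(mod)
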
diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc
new file mode 100644
index 0000000000..2ea039e022
--- /dev/null
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/call_tir_rewrite.cc
+ * \brief Perform explicit tensor allocation for call_tir.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/op.h>
+
+#include "../../relay/transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relax {
+
+// ==================
+// CallTIRMutator
+// Perform explicit tensor allocation for call_tir.
+// Example:
+// lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32")
+// -->
+// gv0 = rx.call("relax.builtin.alloc_tensor", [n, m], dtype="float32")
+// rx.call_packed(func, x, gv0)
+
+class CallTIRMutator : public ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+  Expr VisitExpr_(const CallNode* call) override {
+    // post-order mutation
+    Expr expr = VisitExprPostOrder_(call);
+    call = expr.as<CallNode>();
+
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");
+
+    if (call->op == call_tir_op) {
+      Array<Expr> outs;
+      if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr)) {
+        // single output case
+        const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value();
+        ICHECK(tensor_sinfo->shape.defined())
+            << "the TensorStructInfo shape of call_tir has not been populated";
+        outs.push_back(
+            builder_->Emit(Call(alloc_tensor_op,  //
+                                {Downcast<ShapeExpr>(tensor_sinfo->shape.value()),
+                                 DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)},  //
+                                Attrs()),
+                           "alloc"));
+      } else if (const auto& _tuple_sinfo = MatchStructInfo<TupleStructInfo>(expr)) {
+        // multiple output case
+        const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value();
+        for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
+          const auto& field = tuple_sinfo->fields[i];
+
+          ICHECK(field->IsInstance<TensorStructInfoNode>())
+              << "call_tir expects a Tuple of TensorStructInfo, but got " << field
+              << " as an element of TupleStructInfo";
+          const auto& field_tensor = Downcast<TensorStructInfo>(field);
+          ICHECK(field_tensor->shape.defined())
+              << "call_tir expects all TensorStructInfo to have shape, but got " << field_tensor
+              << " as an element of TupleStructInfo";
+          outs.push_back(
+              builder_->Emit(Call(alloc_tensor_op,
+                                  {Downcast<ShapeExpr>(field_tensor->shape.value()),
+                                   DataTypeImm(field_tensor->dtype), PrimValue::Int64(0)},
+                                  Attrs()),
+                             "alloc"));
+        }
+      } else {
+        LOG(FATAL) << "TypeError: The struct info of call_tir expects to be TensorStructInfo or "
+                      "TupleStructInfo, but got "
+                   << expr->struct_info_;
+      }
+
+      Array<Expr> args;
+      if (call->args[1].as<TupleNode>()) {
+        args = Downcast<Tuple>(call->args[1])->fields;
+        args.insert(args.end(), outs.begin(), outs.end());
+
+        if (call->args.size() == 2) {
+          builder_->Emit(Call(call->args[0], args), "_");
+        } else {
+          // unpack semantics
+          args.push_back(call->args[2]);
+          builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_");
+        }
+      } else {
+        args = outs;
+        args.insert(args.begin(), call->args[1]);
+        builder_->Emit(Call(call->args[0], args), "_");
+      }
+
+      if (outs.size() == 1) {
+        return outs[0];
+      }
+      return Tuple(outs);
+    }
+
+    return GetRef<Expr>(call);
+  }
+};
+
+Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass CallTIRRewrite() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallTIRRewrite(f)); };
+ return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
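
(A before/after sketch of the rewrite from the Python side, following the
test_call_tir_rewrite test added below; "test.op.identity" is a test helper
packed function.)

    import tvm
    from tvm import relax
    from tvm.script import relax as R, tir as T

    @tvm.script.ir_module
    class Before:
        @R.function
        def foo(x: R.Tensor(("m", "n"), "float32")):
            m, n = T.var("int64"), T.var("int64")
            gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
            return gv0

    # The call_tir binding is split into an explicit relax.builtin.alloc_tensor
    # plus a destination-passing call of the packed function.
    After = relax.transform.CallTIRRewrite()(Before)
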
diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc
new file mode 100644
index 0000000000..aec0911ecc
--- /dev/null
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/rewrite_dataflow_reshape.cc
+ * \brief Transform all reshape ops within dataflow blocks to the relax.reshape operator.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+class DataflowReshapeRewriter : public ExprMutator {
+ public:
+ explicit DataflowReshapeRewriter(const IRModule& mod) : mod_(mod) {}
+
+ private:
+ using ExprMutator::VisitExpr_;
+
+ BindingBlock VisitBindingBlock(const BindingBlock& block) final {
+ // We only rewrite the bindings inside dataflow blocks.
+ if (const auto* dataflow_block = block.as<DataflowBlockNode>()) {
+ return VisitBindingBlock_(dataflow_block);
+ } else {
+ return block;
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding) final {
+    // We only rewrite the bindings that are not dataflow output (which means
+    // they are not externally referenced).
+ if (!binding->var->IsInstance<DataflowVarNode>()) {
+ this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
+ } else {
+ ExprMutator::VisitBinding_(binding);
+ }
+ }
+
+ Expr VisitExpr_(const CallNode* call) final {
+ if (!IsCallingTIRReshape(call)) {
+ return GetRef<Call>(call);
+ }
+
+ // We bring the calls of reshape PrimFunc back to calls of high-level
+ // relax.reshape op, which will be lowered to calls of the ExternFunc
+ // vm.builtin.reshape in the VMBuiltinLower pass.
+ Array<Expr> args = Downcast<Tuple>(call->args[1])->fields;
+ ICHECK_EQ(args.size(), 1);
+    TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_);
+ ICHECK(res_sinfo->shape.defined());
+ return reshape(args[0], res_sinfo->shape.value());
+ }
+
+ bool IsCallingTIRReshape(const CallNode* call) {
+ static const Op& call_tir_op = Op::Get("relax.call_tir");
+ if (call->op != call_tir_op) {
+ return false;
+ }
+ const auto* gv = call->args[0].as<GlobalVarNode>();
+ if (gv == nullptr) {
+ return false;
+ }
+    const auto* func = mod_->functions.Get(GetRef<GlobalVar>(gv)).as<tir::PrimFuncNode>();
+ ICHECK_NOTNULL(func);
+ return HasReshapePattern(GetRef<tir::PrimFunc>(func));
+ }
+
+ const IRModule& mod_;
+};
+
+Expr RewriteDataflowReshape(const Function& f, const IRModule& mod) {
+ return DataflowReshapeRewriter(mod)(f);
+}
+
+namespace transform {
+
+Pass RewriteDataflowReshape() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(RewriteDataflowReshape(f, m));
+ };
+ return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape")
+ .set_body_typed(RewriteDataflowReshape);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
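
(A sketch of the effect inside a dataflow block, assuming `reshape_prim` is a
PrimFunc that HasReshapePattern accepts; see the rewrite tests added below.)

    from tvm import relax

    # before:  y = R.call_tir(reshape_prim, (x,), out_sinfo=R.Tensor((2, 4, 3), "float32"))
    # after:   y = R.reshape(x, (2, 4, 3))
    mod = relax.transform.RewriteDataflowReshape()(mod)
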
diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc
new file mode 100644
index 0000000000..db2e9d7ee5
--- /dev/null
+++ b/src/relax/transform/to_non_dataflow.cc
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/to_non_dataflow.cc
+ * \brief Transform all dataflow structure to non-dataflow version.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relax {
+
+class ToNonDFMutator : public ExprMutator {
+ public:
+ Var VisitVarDef(const Var& var) final {
+ if (var.as<DataflowVarNode>()) {
+ Var new_var = Var(var->vid, GetStructInfo(var), var->span);
+ this->var_remap_[var->vid] = new_var;
+ return new_var;
+ }
+ return var;
+ }
+
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+ builder_->BeginBindingBlock();
+ for (Binding binding : block->bindings) {
+ this->VisitBinding(binding);
+ }
+ return builder_->EndBlock();
+ }
+};
+
+Expr ToNonDataflow(const Expr& e) { return ToNonDFMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass ToNonDataflow() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(ToNonDataflow(f)); };
+ return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
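
(A minimal usage sketch, where `mod` is any IRModule whose Relax functions use
R.dataflow() blocks.)

    from tvm import relax

    # Dataflow blocks are flattened into ordinary binding blocks and every
    # DataflowVar is replaced by a regular Var with the same struct info.
    mod = relax.transform.ToNonDataflow()(mod)
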
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
new file mode 100644
index 0000000000..20ea1a4593
--- /dev/null
+++ b/tests/python/relax/test_analysis.py
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Set, Union
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm import relax as rx
+from tvm.relax.analysis import has_reshape_pattern
+from tvm.script import relax as R, tir as T
+
+
+def test_reshape_pattern_reshape():
+ @T.prim_func
+ def reshape(
+ rxplaceholder: T.Buffer[(1, 2, 3, 4), "float32"],
+ T_reshape: T.Buffer[(8, 3), "float32"],
+ ):
+ for i0, i1 in T.grid(8, 3):
+ with T.block("T_reshape"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(
+ rxplaceholder[
+ (ax0 * 3 + ax1) // 24,
+ (ax0 * 3 + ax1) % 24 // 12,
+ (ax0 * 3 + ax1) % 12 // 4,
+ (ax0 * 3 + ax1) % 4,
+ ]
+ )
+ T.writes(T_reshape[ax0, ax1])
+ T_reshape[ax0, ax1] = rxplaceholder[
+ (ax0 * 3 + ax1) // 24,
+ (ax0 * 3 + ax1) % 24 // 12,
+ (ax0 * 3 + ax1) % 12 // 4,
+ (ax0 * 3 + ax1) % 4,
+ ]
+
+ assert has_reshape_pattern(reshape)
+
+
+def test_reshape_pattern_reshape_scheduled():
+ @T.prim_func
+ def reshape_scheduled(
+ rxplaceholder: T.Buffer[(1, 2, 3, 4), "float32"],
+ T_reshape: T.Buffer[(8, 3), "float32"],
+ ):
+ for i0_i1_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(24, thread="threadIdx.x"):
+ with T.block("T_reshape"):
+                    ax0 = T.axis.spatial(8, (i0_i1_fused_0 * 24 + i0_i1_fused_1) // 3)
+                    ax1 = T.axis.spatial(3, (i0_i1_fused_0 * 24 + i0_i1_fused_1) % 3)
+ T.reads(
+ rxplaceholder[
+ (ax0 * 3 + ax1) // 24,
+ (ax0 * 3 + ax1) % 24 // 12,
+ (ax0 * 3 + ax1) % 12 // 4,
+ (ax0 * 3 + ax1) % 4,
+ ]
+ )
+ T.writes(T_reshape[ax0, ax1])
+ T_reshape[ax0, ax1] = rxplaceholder[
+ (ax0 * 3 + ax1) // 24,
+ (ax0 * 3 + ax1) % 24 // 12,
+ (ax0 * 3 + ax1) % 12 // 4,
+ (ax0 * 3 + ax1) % 4,
+ ]
+
+ assert has_reshape_pattern(reshape_scheduled)
+
+
+def test_reshape_pattern_expand_dims():
+ @T.prim_func
+ def expand_dims(
+ rxplaceholder: T.Buffer[(2, 3, 4), "float32"],
+ expand_dims: T.Buffer[(2, 1, 1, 1, 3, 1, 4, 1), "float32"],
+ ):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1):
+ with T.block("expand_dims"):
+ i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap(
+ "SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7]
+ )
+ T.reads(rxplaceholder[i0_1, i4_1, i6_1])
+                T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1])
+                expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1] = rxplaceholder[
+                    i0_1, i4_1, i6_1
+                ]
+
+ assert has_reshape_pattern(expand_dims)
+
+
+def test_reshape_pattern_with_raggedness():
+ @T.prim_func
+ def reshape_raggedness(
+ A: T.Buffer[(100, 768), "float32"],
+ src_indptr: T.Buffer[(9,), "int32"],
+ B: T.Buffer[(100, 12, 64), "float32"],
+ ):
+ for b in T.serial(8):
+ with T.block("block0"):
+ vb = T.axis.spatial(8, b)
+ for i in T.serial(src_indptr[vb + 1] - src_indptr[vb]):
+ for h in T.serial(12):
+ for f in T.serial(64):
+ with T.block("block1"):
+ vi, vh, vf = T.axis.remap("SSS", [i, h, f])
+ B[src_indptr[vb] + vi, vh, vf] = A[
+ src_indptr[vb] + vi, vh * 64 + vf
+ ]
+
+ assert has_reshape_pattern(reshape_raggedness)
+
+
+def test_reshape_pattern_reject_seqstmt():
+ @T.prim_func
+    def identity_bias(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+ C = T.alloc_buffer((128, 128), "float32")
+ for i0, i1 in T.grid(4, 4):
+ with T.block("identity"):
+ vi0, vi1 = T.axis.remap("SS", [i0, i1])
+ C[vi0, vi1] = A[vi0, vi1]
+ for i0, i1 in T.grid(4, 4):
+ with T.block("identity"):
+ vi0, vi1 = T.axis.remap("SS", [i0, i1])
+ B[vi0, vi1] = C[vi0, vi1] + T.float32(1)
+
+ @T.prim_func
+    def identity_identity(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+ C = T.alloc_buffer((128, 128), "float32")
+ for i0, i1 in T.grid(4, 4):
+ with T.block("identity"):
+ vi0, vi1 = T.axis.remap("SS", [i0, i1])
+ C[vi0, vi1] = A[vi0, vi1]
+ for i0, i1 in T.grid(4, 4):
+ with T.block("identity"):
+ vi0, vi1 = T.axis.remap("SS", [i0, i1])
+ B[vi0, vi1] = C[vi0, vi1]
+
+ assert not has_reshape_pattern(identity_bias)
+ assert not has_reshape_pattern(identity_identity)
+
+
+def test_reshape_pattern_reject_reduction():
+ @T.prim_func
+    def reduction(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4,), "float32"]):
+ for i0, i1 in T.grid(4, 4):
+ with T.block("identity"):
+ vi0, vi1 = T.axis.remap("SR", [i0, i1])
+ with T.init():
+ B[vi0] = T.float32(0)
+ B[vi0] = B[vi0] + A[vi0, vi1]
+
+ assert not has_reshape_pattern(reduction)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
new file mode 100644
index 0000000000..624b7877cd
--- /dev/null
+++ b/tests/python/relax/test_transform.py
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+from tvm import relax
+from tvm.ir import structural_equal
+from tvm.ir.base import assert_structural_equal
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
+def test_to_non_dataflow():
+ @tvm.script.ir_module
+ class TestToNonDataflow:
+ @R.function
+ def foo(x: R.Tensor(("m", "n"), "float32")):
+ m, n = T.var("int64"), T.var("int64")
+ with R.dataflow():
+                lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
+                gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32"))
+ R.output(gv0)
+ return gv0
+
+ mod = TestToNonDataflow
+
+ old_vars = []
+
+ def fvisit(e):
+ if isinstance(e, relax.Var):
+ nonlocal old_vars
+ old_vars.append(e)
+
+ relax.analysis.post_order_visit(mod["foo"], fvisit)
+ x, lv0, gv0 = old_vars
+
+ new_mod = relax.transform.ToNonDataflow()(mod)
+
+ new_vars = []
+
+ def fvisit(e):
+ if isinstance(e, relax.Var):
+ nonlocal new_vars
+ new_vars.append(e)
+
+ relax.analysis.post_order_visit(new_mod["foo"], fvisit)
+
+ assert x == new_vars[0]
+ assert lv0 != new_vars[1]
+ assert isinstance(lv0, relax.DataflowVar)
+ assert not isinstance(new_vars[1], relax.DataflowVar)
+
+ assert isinstance(gv0, relax.Var)
+ assert isinstance(new_vars[2], relax.Var)
+ assert gv0 == new_vars[2]
+
+
+def test_call_tir_rewrite():
+ @tvm.script.ir_module
+ class TestCallTIRRewrite:
+ @R.function
+ def foo(x: R.Tensor(("m", "n"), "float32")):
+ m, n = T.var("int64"), T.var("int64")
+            gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
+ return gv0
+
+ mod = TestCallTIRRewrite
+
+ # before rewrite
+ v0 = mod["foo"].body.blocks[0].bindings[0].var
+ s0 = mod["foo"].body.blocks[0].bindings[0].value
+ assert isinstance(s0, relax.Call)
+ assert s0.op.name == "relax.call_tir"
+
+ # after rewrite
+ new_mod = relax.transform.CallTIRRewrite()(mod)
+ func = new_mod["foo"]
+
+ block = func.body.blocks[0]
+ assert not isinstance(block, relax.DataflowBlock)
+
+ s1 = block.bindings[0].value
+ assert isinstance(s1, relax.Call)
+ assert s1.op.name == "relax.builtin.alloc_tensor"
+ assert isinstance(s1.args[0], relax.ShapeExpr)
+ assert structural_equal(s1.args[0], s0.sinfo_args[0].shape)
+ s2 = block.bindings[1].value
+ assert s2.op.global_symbol == "test.op.identity"
+
+
+def test_vm_builtin_lower():
+ @tvm.script.ir_module
+ class TestVMBuiltinLower:
+ @R.function
+ def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
+ m, n = T.var("int64"), T.var("int64")
+            alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32")
+            _ = R.call_packed(
+                "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
+            )
+ gv0 = alloc
+ return gv0
+
+ mod = TestVMBuiltinLower
+
+ # after vm builtin lowering
+ new_mod = relax.transform.VMBuiltinLower()(mod)
+ func = new_mod["foo"]
+
+ assert isinstance(new_mod, tvm.IRModule)
+ assert isinstance(func, tvm.relax.expr.Function)
+
+ block = func.body.blocks[0]
+ s1 = block.bindings[0].value
+ assert isinstance(s1, relax.Call)
+ assert s1.op.name == "relax.vm.alloc_storage"
+ s2 = block.bindings[1].value
+ assert isinstance(s2, relax.Call)
+ s3 = block.bindings[2].value
+ assert isinstance(s3, relax.Call)
+ assert isinstance(s3.op, relax.ExternFunc)
+ assert s3.op.global_symbol == "test.op.identity"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py
new file mode 100644
index 0000000000..edfc646e21
--- /dev/null
+++ b/tests/python/relax/test_transform_attach_global_symbol.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+from tvm import tir, relax
+from tvm.ir import assert_structural_equal
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
[email protected]_module
+class Before:
+ @T.prim_func
+ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+ m = T.var("int64")
+ n = T.var("int64")
+ k = T.var("int64")
+ A = T.match_buffer(x, (m, n))
+ B = T.match_buffer(y, (n, k))
+ C = T.match_buffer(z, (m, k))
+
+ for i, j, k in T.grid(m, k, n):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+ @R.function
+    def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) -> R.Tensor:
+ m, n, k = T.var("int64"), T.var("int64"), T.var("int64")
+        gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32"))
+ return gv0
+
+
+def test_basic():
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+ T.func_attr({"global_symbol": "tir_matmul"})
+ m = T.var("int64")
+ n = T.var("int64")
+ k = T.var("int64")
+ A = T.match_buffer(x, (m, n))
+ B = T.match_buffer(y, (n, k))
+ C = T.match_buffer(z, (m, k))
+
+ for i, j, k in T.grid(m, k, n):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+ @R.function
+ def main(
+            x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")
+ ) -> R.Tensor:
+ R.func_attr({"global_symbol": "main"})
+ m, n, k = T.var("int64"), T.var("int64"), T.var("int64")
+            gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32"))
+ return gv0
+
+ before = Before
+ expected = Expected
+ after = relax.transform.AttachGlobalSymbol()(before)
+ assert_structural_equal(after, expected)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
new file mode 100644
index 0000000000..9e67ef2d89
--- /dev/null
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R, tir as T
+
+
+def test_reshape_expand_dims():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def reshape(
+ rxplaceholder: T.Buffer[(T.int64(8), T.int64(3)), "float32"],
+            T_reshape: T.Buffer[(T.int64(2), T.int64(4), T.int64(3)), "float32"],
+ ):
+ for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(
+ rxplaceholder[
+ (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
+ (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3),
+ ]
+ )
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+ T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[
+ (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
+ (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3),
+ ]
+
+ @T.prim_func
+ def expand_dims(
+            rxplaceholder: T.Buffer[(T.int64(2), T.int64(4), T.int64(3)), "float32"],
+ expand_dims: T.Buffer[
+ (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)),
+ "float32",
+ ],
+ ):
+ for i0, i1, i2, i3, i4 in T.grid(
+ T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)
+ ):
+ with T.block("expand_dims"):
+                    i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                    T.reads(rxplaceholder[i0_1, i2_1, i4_1])
+                    T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1])
+                    expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1]
+
+ @R.function
+ def main(
+ x: R.Tensor((8, 3), dtype="float32")
+ ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"):
+ with R.dataflow():
+                y = R.call_tir(reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32"))
+                z = R.call_tir(expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), "float32"))
+ R.output(z)
+ return z
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def reshape(
+ rxplaceholder: T.Buffer[(T.int64(8), T.int64(3)), "float32"],
+            T_reshape: T.Buffer[(T.int64(2), T.int64(4), T.int64(3)), "float32"],
+ ):
+ for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(
+ rxplaceholder[
+                            (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3),
+                            (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) % T.int64(3),
+                        ]
+                    )
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                    T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[
+                        (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3),
+                        (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) % T.int64(3),
+ ]
+
+ @T.prim_func
+ def expand_dims(
+            rxplaceholder: T.Buffer[(T.int64(2), T.int64(4), T.int64(3)), "float32"],
+            expand_dims: T.Buffer[
+                (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), "float32"
+ ],
+ ):
+ for i0, i1, i2, i3, i4 in T.grid(
+ T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)
+ ):
+ with T.block("expand_dims"):
+                    i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                    T.reads(rxplaceholder[i0_1, i2_1, i4_1])
+                    T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1])
+                    expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1]
+
+ @R.function
+ def main(
+ x: R.Tensor((8, 3), dtype="float32")
+ ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"):
+ with R.dataflow():
+ y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3))
+                # Note: `z` is the output var of the dataflow block, and is thus
+                # not expected to be rewritten.
+                z = R.call_tir(
+                    expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), dtype="float32")
+                )
+ R.output(z)
+ return z
+
+ assert relax.analysis.has_reshape_pattern(Module["expand_dims"])
+ mod = relax.transform.RewriteDataflowReshape()(Module)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_reshape_non_dataflow():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def reshape(
+ rxplaceholder: T.Buffer[(T.int64(8), T.int64(3)), "float32"],
+            T_reshape: T.Buffer[(T.int64(2), T.int64(4), T.int64(3)), "float32"],
+ ):
+ for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(
+ rxplaceholder[
+ (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
+ (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3),
+ ]
+ )
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+ T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[
+ (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3),
+ (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3),
+ ]
+
+ @R.function
+    def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"):
+        y = R.call_tir(reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32"))
+ return y
+
+ assert relax.analysis.has_reshape_pattern(Module["reshape"])
+    # The binding var of the call_tir is not a DataflowVar, so the pass makes no change.
+ mod = relax.transform.RewriteDataflowReshape()(Module)
+ tvm.ir.assert_structural_equal(mod, Module)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
new file mode 100644
index 0000000000..8d98f0de32
--- /dev/null
+++ b/tests/python/relax/test_vm_build.py
@@ -0,0 +1,908 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from typing import Tuple, Callable
+
+import sys
+import tempfile
+import numpy as np
+import pytest
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax, rpc, te, tir, topi
+from tvm.contrib import utils
+from tvm.relax.testing import nn
+from tvm.script import relax as R, tir as T
+from tvm.relax.testing.vm import check_saved_func
+
+EXEC_MODE = ["bytecode"]
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_compile_simple(exec_mode):
+ @tvm.script.ir_module
+ class TestVMCompileStage0:
+ @R.function
+        def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
+            z = R.call_packed(
+                "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
+            )
+ return y
+
+ mod = TestVMCompileStage0
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+ inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ vm["foo"](inp1, inp2)
+    tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_match_check(exec_mode):
+ @tvm.script.ir_module
+ class TestMatchCheck:
+ @R.function
+        def foo(x: R.Tensor(["n", "m"], "int32"), y: R.Object) -> R.Tensor(["m", "n"], dtype=None):
+ return y
+
+ mod = TestMatchCheck
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ x0 = tvm.nd.array(np.zeros((1, 2)).astype("int32"))
+ y0 = tvm.nd.array(np.zeros((2, 1)).astype("float32"))
+ y1 = tvm.nd.array(np.zeros((1, 2)).astype("float32"))
+ y2 = tvm.nd.array(np.zeros((2, 1, 1)).astype("float32"))
+
+ vm["foo"](x0, y0)
+
+ with pytest.raises(RuntimeError, match=".*return.*"):
+ vm["foo"](x0, y1)
+
+ with pytest.raises(ValueError, match=".*return.*"):
+ vm["foo"](x0, y2)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_compile_stage2(exec_mode):
+ @tvm.script.ir_module
+ class TestVMCompileStage2:
+ @R.function
+ def foo(x: R.Tensor(dtype="float32")) -> R.Shape:
+ n, m = T.var("int64"), T.var("int64")
+ _ = R.match_cast(x, R.Tensor((n, m), "float32"))
+ return (n * 2, m * 3)
+
+ mod = TestVMCompileStage2
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ shape = (32, 16)
+ arr = tvm.nd.array(np.random.rand(*shape).astype("float32"))
+ res = vm["foo"](arr)
+ assert res[0] == shape[0] * 2
+ assert res[1] == shape[1] * 3
+
+ # dtype mismatch
+ with pytest.raises(ValueError, match=".*dtype.*"):
+ vm["foo"](tvm.nd.array(np.zeros((1, 2)).astype("int32")))
+
+ # ndim mismatch
+ with pytest.raises(ValueError, match=".*match_cast.*ndim.*"):
+ vm["foo"](tvm.nd.array(np.zeros((1,)).astype("float32")))
+
+    # type mismatch
+ with pytest.raises(TypeError):
+ vm["foo"]([])
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_compile_stage3(exec_mode):
+ @tvm.script.ir_module
+ class TestVMCompileStage3:
+ @R.function
+ def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor:
+ with R.dataflow():
+                y = R.call_tir("test.vm.identity", (x,), R.Tensor((32, 16), dtype="float32"))
+ R.output(y)
+ return y
+
+ mod = TestVMCompileStage3
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ shape = (32, 16)
+ inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+ res = vm["foo"](inp)
+ tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_compile_e2e(exec_mode):
+ @tvm.script.ir_module
+ class TestVMCompileE2E:
+ @R.function
+ def foo(x: R.Tensor(dtype="float32")) -> R.Tensor:
+ with R.dataflow():
+ n, m = T.var("int64"), T.var("int64")
+ _ = R.match_cast(x, R.Tensor((n, m), "float32"))
+                y = R.call_tir("test.vm.tile", (x,), R.Tensor((n, m * 2), dtype="float32"))
+ R.output(y)
+ return y
+
+ mod = TestVMCompileE2E
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ shape = (32, 16)
+ inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+ res = check_saved_func(vm, "foo", inp)
+    tvm.testing.assert_allclose(res.numpy(), np.tile(inp.numpy(), (1, 2)), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_compile_e2e_func_param_with_shape(exec_mode):
+ @tvm.script.ir_module
+ class TestVMCompileE2E2:
+ @T.prim_func
+ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+ T.func_attr({"global_symbol": "tir_matmul"})
+ m = T.var("int32")
+ n = T.var("int32")
+ k = T.var("int32")
+ A = T.match_buffer(x, (m, n))
+ B = T.match_buffer(y, (n, k))
+ C = T.match_buffer(z, (m, k))
+
+ for i, j, k in T.grid(m, k, n):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+ @R.function
+ def func(
+            x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")
+        ) -> R.Tensor:
+            m, k = T.var("int64"), T.var("int64")
+            gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32"))
+ return gv0
+
+ mod = TestVMCompileE2E2
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
+ weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
+ res = check_saved_func(vm, "func", data, weight)
+ expected = np.dot(data.numpy(), weight.numpy())
+ tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_emit_te_extern(exec_mode):
+ if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
+ print("skip because extern function is not available")
+ return
+ bb = relax.BlockBuilder()
+ n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+ x = relax.Var("x", R.Tensor([n, m], "float32"))
+ y = relax.Var("y", R.Tensor([m, n], "float32"))
+
+ with bb.function("rx_cblas_matmul", [x, y]):
+        out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False)
+ bb.emit_func_output(out)
+
+ mod = bb.get()
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
+ weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
+ res = check_saved_func(vm, "rx_cblas_matmul", data, weight)
+ expected = np.dot(data.numpy(), weight.numpy())
+ tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_emit_te_concat(exec_mode):
+ # concatenate of two vectors of size (n,) and (m,)
+ bb = relax.BlockBuilder()
+ n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+ x = relax.Var("x", R.Tensor([n], "float32"))
+ y = relax.Var("y", R.Tensor([m], "float32"))
+
+ def te_func(A, B):
+        C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i - n]))
+ return C
+
+ with bb.function("rx_func", [x, y]):
+ x1 = bb.emit_te(te_func, x, y)
+ bb.emit_func_output(x1)
+
+ mod = bb.get()
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ inp = tvm.nd.array(
+ np.random.rand(
+ 1,
+ ).astype(np.float32)
+ )
+ inp2 = tvm.nd.array(
+ np.random.rand(
+ 2,
+ ).astype(np.float32)
+ )
+ res = check_saved_func(vm, "rx_func", inp, inp2)
+ tvm.testing.assert_allclose(
+ res.numpy(), np.append(inp.numpy(), inp2.numpy()), rtol=1e-7, atol=1e-7
+ )
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_emit_te_dtype_change(exec_mode):
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ x = relax.Var("x", R.Tensor([n], "float32"))
+
+ # convert a tensor with dtype of float32 to int16
+ def te_func(A):
+ B = te.compute((n,), lambda i: A[i].astype("int16"))
+ return B
+
+ with bb.function("rx_func", [x]):
+ y = bb.emit_te(te_func, x)
+ bb.emit_func_output(y)
+
+ mod = bb.get()
+
+ new_mod = relax.transform.CallTIRRewrite()(mod)
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ inp = tvm.nd.array(
+ np.random.rand(
+ 1,
+ ).astype(np.float32)
+ )
+ res = check_saved_func(vm, "rx_func", inp)
+ np.testing.assert_allclose(res.numpy(), inp.numpy().astype("int16"))
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_emit_te_floor_symbolic_shape(exec_mode):
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ x = relax.Var("x", R.Tensor([n], "float32"))
+
+ def te_func(A):
+ C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1)
+ return C
+
+ with bb.function("rx_func", [x]):
+ x1 = bb.emit_te(te_func, x)
+ bb.emit_func_output(x1)
+
+ mod = bb.get()
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ shape = (9,)
+ inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+ res = check_saved_func(vm, "rx_func", inp)
+
+ def expected_output():
+ output_shape = (shape[0] // 2,)
+ return inp.numpy()[: output_shape[0]] + 1
+
+    tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_emit_te_constant_param_cpu(exec_mode):
+ x_np = np.random.rand(2, 2).astype("float32")
+ c_np = np.random.rand(2, 2).astype("float32")
+
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 2), "float32"))
+ c = relax.const(c_np, "float32")
+ with bb.function("main", [x]):
+ with bb.dataflow():
+ lv0 = bb.emit_te(topi.add, x, c)
+ gv = bb.emit_output(lv0)
+ bb.emit_func_output(gv)
+
+ mod = bb.get()
+ exec = relax.vm.build(mod, "llvm", exec_mode=exec_mode)
+ dev = tvm.cpu()
+ vm = relax.VirtualMachine(exec, dev)
+
+ add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev))
+    tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
[email protected]_gpu
+def test_vm_emit_te_constant_param_gpu(exec_mode):
+ x_np = np.random.rand(2, 2).astype("float32")
+ c_np = np.random.rand(2, 2).astype("float32")
+
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 2), "float32"))
+ c = relax.const(c_np, "float32")
+ with bb.function("main", [x]):
+ with bb.dataflow():
+ lv0 = bb.emit_te(topi.add, x, c)
+ gv = bb.emit_output(lv0)
+ bb.emit_func_output(gv)
+
+ mod = bb.get()
+ sch = tvm.tir.Schedule(mod, debug_mask="all")
+ loops = sch.get_loops(sch.get_block(name="T_add", func_name="add"))
+ sch.bind(loops[0], "threadIdx.x")
+
+ exec = relax.vm.build(sch.mod, "cuda", exec_mode=exec_mode)
+ dev = tvm.cuda()
+ vm = relax.VirtualMachine(exec, dev)
+
+ add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev))
+    tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_relax_symbolic_shape(exec_mode):
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ x = relax.Var("x", R.Tensor([n], "float32"))
+ y = relax.Var("y", R.Tensor([(n // 2) + 1], "float32"))
+
+ def te_func(A, B):
+ C = te.compute((n,), lambda i: A[i] + B[i // 2])
+ return C
+
+ with bb.function("rx_func", [x, y]):
+ x1 = bb.emit_te(te_func, x, y)
+ bb.emit_func_output(x1)
+
+ mod = bb.get()
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ shape1 = (5,)
+ shape2 = (3,)
+ inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32))
+ inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32))
+ res = check_saved_func(vm, "rx_func", inp, inp2)
+
+ def expected_output():
+ return inp.numpy() + np.repeat(inp2.numpy(), 2)[:5]
+
+    tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_relax_dyn_tir_shape(exec_mode):
+ # case where TIR variables are unbound in generated PrimFunc
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+
+ def te_func(A):
+ C = te.compute((n + 1), lambda i: A[i])
+ return C
+
+ with bb.function("rx_func"):
+ x = nn.Placeholder((n,), dtype="float32", name="x")
+ y = nn.Placeholder((n + 1,), dtype="float32", name="y")
+
+ x1 = bb.emit_te(te_func, y)
+ bb.emit_func_output(x1, params=[x, y])
+
+ mod = bb.get()
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+ ex.mod.export_library("exec.so")
+ exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so"))
+ os.remove("exec.so")
+ assert ex.as_text() == exec1.as_text()
+
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ inp = tvm.nd.array(np.random.rand(2).astype(np.float32))
+ inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32))
+
+ res = check_saved_func(vm, "rx_func", inp, inp2)
+
+    tvm.testing.assert_allclose(res.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_tuple(exec_mode):
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+
+ with bb.function("rx_func"):
+ x = nn.Placeholder((n,), dtype="float32", name="x")
+ y = nn.Placeholder((n,), dtype="float32", name="y")
+ tup = relax.Tuple([x, y])
+ item = tup[0]
+ bb.emit_func_output([tup, item], params=[x, y])
+
+ mod = bb.get()
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ shape = (5,)
+ inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+ inp2 = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+ (res1, res2), res3 = vm["rx_func"](inp, inp2)
+
+    tvm.testing.assert_allclose(res1.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
+    tvm.testing.assert_allclose(res2.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7)
+    tvm.testing.assert_allclose(res3.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_tuplegetitem(exec_mode):
+ @tvm.script.ir_module
+ class TestVMTupleGetItem:
+ @R.function
+ def tuple_get_item(
+ x: R.Tensor(ndim=2, dtype="float32"),
+ y: R.Tensor(ndim=2, dtype="float32"),
+ ):
+ t = (x, y)
+ a = t[0]
+ b = t[1]
+            c = R.call_packed("test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32")))
+ return c
+
+ mod = TestVMTupleGetItem
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+ y_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+ res = check_saved_func(vm, "tuple_get_item", x_inp, y_inp)
+    tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7)
+
+
[email protected]("exec_mode", EXEC_MODE)
+def test_lower_memory_alloc_storage_tensor(exec_mode):
+ @tvm.script.ir_module
+ class TestMemoryAllocStorageTensor:
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="float32")):
+ storage = R.memory.alloc_storage(
+                (24,), virtual_device_index=0, storage_scope="global", dtype="float32"
+ )
+ y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32")
+ _ = copy(x, y)
+ return y
+
+ @T.prim_func
+        def copy(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(2, 3), "float32"]):
+ for i0, i1 in T.grid(2, 3):
+ with T.block("block"):
+ vi0, vi1 = T.axis.remap("SS", [i0, i1])
+ B[vi0, vi1] = A[vi0, vi1]
+
+ mod = TestMemoryAllocStorageTensor
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ x = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+ y = vm["main"](x)
+ tvm.testing.assert_allclose(y.numpy(), x.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_sub_func_call(exec_mode):
+ @tvm.script.ir_module
+ class TestVMSubFunction:
+ @T.prim_func
+ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+ T.func_attr({"global_symbol": "tir_matmul"})
+ m = T.var("int32")
+ n = T.var("int32")
+ k = T.var("int32")
+ A = T.match_buffer(x, (m, n))
+ B = T.match_buffer(y, (n, k))
+ C = T.match_buffer(z, (m, k))
+
+ for i, j, k in T.grid(m, k, n):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+ @R.function
+ def relax_matmul_tir(
+ x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
+ ) -> R.Tensor((32, 32), dtype="float32"):
+ with R.dataflow():
+ gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
+ R.output(gv0)
+ return gv0
+
+ @R.function
+ def relax_matmul_packed(
+ x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
+ ) -> R.Object:
+ gv0 = R.call_packed("test.vm.mul", x, w,
sinfo_args=(R.Tensor(ndim=2, dtype="float32")))
+ return gv0
+
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Object:
+ gv0 = relax_matmul_tir(x, w)
+ gv1 = relax_matmul_packed(gv0, gv0)
+ return gv1
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(TestVMSubFunction, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ x_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32))
+ y_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32))
+ res = check_saved_func(vm, "main", x_inp, y_inp)
+ product = np.dot(x_inp.numpy(), y_inp.numpy())
+ expected = product * product
+ tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_recursion(exec_mode):
+ @tvm.script.ir_module
+ class TestVMRecursion:
+ @R.function
+ def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor:
+ cond = R.call_packed(
+ "test.vm.equal_zero", n, sinfo_args=(R.Tensor(ndim=1,
dtype="float32"))
+ )
+ if cond:
+ res = R.const(1.0)
+ else:
+ gv0 = R.call_packed(
+ "test.vm.subtract_one", n, sinfo_args=(R.Tensor(ndim=1,
dtype="float32"))
+ )
+ tmp = recursion(gv0)
+ res = R.call_packed(
+ "test.vm.add", tmp, tmp, sinfo_args=(R.Tensor(ndim=1,
dtype="float32"))
+ )
+ return res
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(TestVMRecursion, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ inp = np.empty(1).astype("float32")
+ recursion_runs = np.random.randint(1, 10)
+ inp.fill(recursion_runs)
+ inp = tvm.nd.array(inp)
+ res = check_saved_func(vm, "recursion", inp)
+ tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_closure(exec_mode):
+ @tvm.script.ir_module
+ class TestClosure:
+ @R.function
+ def lifted_func_1(x: R.Tensor((2, 3), "float32"), env: R.Tensor((2, 3), "float32")):
+ return R.call_packed("test.vm.add", x, env, sinfo_args=(R.Tensor))
+
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), "float32"),
+ y: R.Tensor((2, 3), "float32"),
+ ):
+ clo = R.make_closure(lifted_func_1, (x,))
+ res = R.invoke_closure(clo, (y,), sinfo_args=(R.Tensor))
+ return res
+
+ mod = TestClosure
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+ y_inp = tvm.nd.array(np.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]], dtype="float32"))
+ res = check_saved_func(vm, "main", x_inp, y_inp)
+ tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy())
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_time_evaluator(exec_mode):
+ @tvm.script.ir_module
+ class TestTimeEvaluator:
+ @R.function
+ def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")):
+ return R.call_packed(
+ "test.vm.add", x, y, sinfo_args=(R.Tensor(ndim=1,
dtype="float32"))
+ )
+
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.vm.build(TestTimeEvaluator, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ x = tvm.nd.array(np.random.rand(1).astype("float32"))
+ y = tvm.nd.array(np.random.rand(1).astype("float32"))
+
+ # ensure we can use time_evaluator with the stateful API
+ vm.set_input("main", x, y)
+ timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("main")
+ # just checking that it has some results at all
+ assert timing_res.results
+
+ # ensure we can use it with a closure
+ vm.save_function("main", "saved_main", x, y)
+ timing_res = vm.time_evaluator("saved_main", tvm.cpu())()
+ assert timing_res.results
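+ # timing_res is a tvm.runtime.BenchmarkResult; besides .results it also
+ # exposes aggregates such as .mean and .median (in seconds)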
+
+
+@tvm.script.ir_module
+class TestVMSetInput:
+ @T.prim_func
+ def test_vm_mul(x: T.handle, y: T.handle, z: T.handle):
+ T.func_attr({"global_symbol": "test_vm_mul"})
+ m = T.var("int32")
+ n = T.var("int32")
+ A = T.match_buffer(x, (m, n))
+ B = T.match_buffer(y, (m, n))
+ C = T.match_buffer(z, (m, n))
+
+ for i, j in T.grid(m, n):
+ with T.block("mul"):
+ vi = T.axis.spatial(m, i)
+ vj = T.axis.spatial(n, j)
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = A[vi, vj] * B[vi, vj]
+
+ # test returning a tuple
+ @R.function
+ def test_vm_tuple(
+ x: R.Tensor((), "int32")
+ ) -> R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")):
+ return (x, x)
+
+ # nested tuple too
+ @R.function
+ def test_vm_nested_tuple(
+ x: R.Tensor((), "int32")
+ ) -> R.Tuple(
+ R.Tuple(
+ R.Tensor((), "int32"),
+ R.Tuple(
+ R.Tensor((), "int32"),
+ ),
+ ),
+ R.Tensor((), "int32"),
+ ):
+ return ((x, (x,)), x)
+
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ gv0 = R.call_tir("test_vm_mul", (x, w), R.Tensor((32, 32), dtype="float32"))
+ return gv0
+
+
+def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+ a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ vm.set_input("main", a, b)
+ vm.invoke_stateful("main")
+ res0 = vm.get_outputs("main")
+
+ data_dict = {"x": a, "w": b}
+ vm.set_input("main", **data_dict)
+ vm.invoke_stateful("main")
+ res1 = vm.get_outputs("main")
+ tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7)
+ tvm.testing.assert_allclose(res0.numpy(), res1.numpy(), rtol=1e-7, atol=1e-7)
+
+ # bug! If you don't bind the NDArray to a var, the memory will get corrupted.
+ # Possibly due to object lifecycles and other FFI issues
+ a = tvm.nd.array(np.array(2).astype("int32"), device)
+ vm.set_input("test_vm_tuple", a)
+ vm.invoke_stateful("test_vm_tuple")
+ res2 = vm.get_outputs("test_vm_tuple")
+ # the results are NDArrays wrapped around scalars,
+ # so we have to get the scalar out of the NDArray
+ assert tuple(map(lambda a: int(a.numpy()), res2)) == (2, 2)
+
+ b = tvm.nd.array(np.array(1).astype("int32"), device)
+ vm.set_input("test_vm_nested_tuple", b)
+ vm.invoke_stateful("test_vm_nested_tuple")
+ res3 = vm.get_outputs("test_vm_nested_tuple")
+ assert len(res3) == 2 and len(res3[0]) == 2 and len(res3[0][1]) == 1
+ result_cast = ((int(res3[0][0].numpy()), (int(res3[0][1][0].numpy()),)), int(res3[1].numpy()))
+ assert result_cast == ((1, (1,)), 1)
+
+
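+# The stateful calling convention exercised above reduces to three steps:
+#   vm.set_input("main", a, b)    # 1. bind inputs to the named function
+#   vm.invoke_stateful("main")    # 2. run with the bound inputs
+#   res = vm.get_outputs("main")  # 3. read back the outputs
+# The trials below check that skipping or reordering these steps fails.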
+def set_input_attempt_stateless(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+ # this should fail: once you set inputs, you cannot run statelessly
+ a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ vm.set_input("main", a, b)
+ # must use invoke stateful!
+ vm["main"]()
+
+
+def set_input_attempt_invoke(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+ # this should fail: if the function needs inputs, you can't invoke directly
+ vm.invoke_stateful("main")
+
+
+def set_input_attempt_get(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+ # this should fail: you can't get outputs without invoking the function first
+ a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ vm.set_input("main", a, b)
+ _ = vm.get_outputs("main")
+
+
+def make_vm(mod, exec_mode) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]:
+ """Returns a local VM for the given mod and the device"""
+ target = tvm.target.Target("llvm", host="llvm")
+ exec = relax.vm.build(TestVMSetInput, target, exec_mode=exec_mode)
+ exec.mod.export_library("exec.so")
+ exec_loaded = relax.vm.Executable(tvm.runtime.load_module("exec.so"))
+ os.remove("exec.so")
+ device = tvm.cpu()
+ return relax.VirtualMachine(exec_loaded, device), device
+
+
+def run_on_rpc(
+ mod: tvm.IRModule,
+ trial_func: Callable[[relax.VirtualMachine, tvm.runtime.Device], None],
+ exec_mode: str,
+):
+ """
+ Sets up a VM over localhost using the given mod and runs the given trial function.
+ The trial function should take a VM and a device.
+ """
+ target = tvm.target.Target("llvm", host="llvm")
+ exec = relax.vm.build(mod, target, exec_mode=exec_mode)
+ temp = utils.tempdir()
+ path = temp.relpath("vm_library.so")
+ exec.mod.export_library(path)
+
+ # Use local rpc server for testing.
+ # Server must use popen so it doesn't inherit the current process state. It
+ # will crash otherwise.
+ # Adapted from relay/test_vm.py
+ def check_remote(server):
+ remote = rpc.connect(server.host, server.port, session_timeout=10)
+
+ # Upload the serialized Executable.
+ remote.upload(path)
+ # Get a handle to remote Executable.
+ rexec = remote.load_module("vm_library.so")
+
+ device = remote.cpu()
+ # Build a VM out of the executable and context.
+ vm = relax.vm.VirtualMachine(exec=rexec, device=device)
+ trial_func(vm, device)
+
+ check_remote(rpc.Server("127.0.0.1"))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_set_input(exec_mode):
+ set_input_trial(*make_vm(TestVMSetInput, exec_mode))
+
+
+def save_function_kwargs_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+ # just checking that we can use kwargs for the args when saving a function
+ a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ vm.save_function("main", "saved_main", x=a, w=b)
+ res0 = vm["saved_main"]()
+ tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_kwargs(exec_mode):
+ save_function_kwargs_trial(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_kwargs_rpc(exec_mode):
+ run_on_rpc(TestVMSetInput, save_function_kwargs_trial, exec_mode)
+
+
+def save_function_time_evaluator_trial(
+ vm: relax.VirtualMachine, device: tvm.runtime.Device
+) -> None:
+ # just checking that the saved function can be called in the time evaluator
+ a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+ vm.save_function("main", "saved_main", a, b)
+ vm.time_evaluator("saved_main", device)()
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_time_evaluator(exec_mode):
+ save_function_time_evaluator_trial(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_time_evaluator_rpc(exec_mode):
+ run_on_rpc(TestVMSetInput, save_function_time_evaluator_trial, exec_mode)
+
+
+# if you set an input, you should not be able to call statelessly
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_stateless_failure(exec_mode):
+ set_input_attempt_stateless(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_stateless_failure_rpc(exec_mode):
+ run_on_rpc(TestVMSetInput, set_input_attempt_stateless, exec_mode)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_invoke_failure(exec_mode):
+ set_input_attempt_invoke(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_invoke_failure_rpc(exec_mode):
+ run_on_rpc(TestVMSetInput, set_input_attempt_invoke, exec_mode)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_get_failure(exec_mode):
+ set_input_attempt_get(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_get_failure_rpc(exec_mode):
+ run_on_rpc(TestVMSetInput, set_input_attempt_get, exec_mode)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()