This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new faa8a0ad46 [Unity][nn.Module] Introduce operator `empty` (#16327)
faa8a0ad46 is described below
commit faa8a0ad46d2e3159680df0e09a84e5d6376b1fd
Author: Junru Shao <[email protected]>
AuthorDate: Mon Jan 1 20:30:57 2024 -0800
[Unity][nn.Module] Introduce operator `empty` (#16327)
This PR introduces an operator `op.empty` in the `nn.Module` frontend.
It allocates uninitialized memory from the memory pool, which can be
used as temporary scratchpad memory for handcrafted operators. It also
adds the companion constructor `op.ones`.
---
python/tvm/relax/frontend/nn/op.py | 59 +++++++++++++++++++++++++++++++
tests/python/relax/test_frontend_nn_op.py | 27 ++++++++++++--
2 files changed, 83 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 3197145289..66f023ef9d 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1142,6 +1142,65 @@ def zeros(
return wrap_nested(_op.zeros(shape, dtype), name)
+def ones(
+    shape: Sequence[IntExpr],
+    dtype: str = "float32",
+    name: str = "ones",
+) -> Tensor:
+    """Construct a tensor of all ones, with the input shape and dtype.
+
+    Parameters
+    ----------
+    shape : Sequence[IntExpr]
+        The shape of the created tensor.
+
+    dtype : str
+        The data type of the created tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result tensor, with every element equal to one.
+    """
+    return wrap_nested(_op.ones(shape, dtype), name)
+
+
+def empty(
+    shape: Sequence[IntExpr],
+    dtype: str = "float32",
+    name: str = "empty",
+) -> Tensor:
+    """Allocate a tensor of the given shape and dtype without initializing it.
+
+    The initial contents of the returned tensor are unspecified, so callers
+    must write to it before reading from it.
+
+    Parameters
+    ----------
+    shape : Sequence[IntExpr]
+        Shape of the tensor to allocate.
+
+    dtype : str
+        Element data type of the tensor to allocate.
+
+    name : str
+        Name hint for the resulting tensor.
+
+    Returns
+    -------
+    result : Tensor
+        The newly allocated, uninitialized tensor.
+    """
+    shape_expr = rx.ShapeExpr(shape)  # type: ignore
+    # The allocation is always requested on runtime device index 0.
+    allocated = _op.builtin.alloc_tensor(
+        shape_expr,
+        dtype,
+        runtime_device_index=0,
+    )
+    return wrap_nested(allocated, name)  # type: ignore
+
+
def split(
ary: Tensor,
indices_or_sections: Union[int, Sequence[int]],
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 55870426e4..43f4a9efc0 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -17,12 +17,14 @@
# pylint: disable=missing-docstring, invalid-name
import tvm
import tvm.testing
-from tvm import tir
+from tvm import relax, tir
from tvm.relax.frontend.nn import Module, Tensor, op, spec
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
+# mypy: disable-error-code="attr-defined,valid-type,name-defined"
+
def test_binary():
class Model(Module):
@@ -174,7 +176,7 @@ def test_image():
def test(self, x: Tensor, weight: Tensor, bias: Tensor):
padded = op.pad(x, [0, 0, 0, 0, 1, 1, 1, 1])
conv2d = op.conv2d(padded, weight, bias)
- interpolate = op.interpolate(x, size=[40, 40])
+ interpolate = op.interpolate(x, size=[40, 40]) # type: ignore
return (conv2d, interpolate)
@R.function
@@ -347,7 +349,7 @@ def test_create():
class Model(Module):
def test(self, x: Tensor):
triu_out = op.triu(x)
- full_with_scalar_out = op.full([10, 10], fill_value=10)
+ full_with_scalar_out = op.full([10, 10], fill_value=10) # type:
ignore
full_with_FloatImm_out = op.full(
[10, 10], fill_value=tir.FloatImm(dtype="float32", value=10)
)
@@ -638,5 +640,24 @@ def test_extern():
tvm.ir.assert_structural_equal(irmodule, Expected)
+def test_empty():
+    # The shape/dtype checks run inside the debug callback; record each call so
+    # the test fails instead of passing vacuously if the callback never runs.
+    calls = []
+
+    @tvm.register_func("test_empty_assert", override=True)
+    def test_empty_assert(_lineno, x):
+        # First argument is supplied by `op.debug_func` (presumably call-site
+        # info) and is unused here.
+        assert x.shape == (10, 10)
+        assert x.dtype == "float32"
+        calls.append(True)
+
+    class Model(Module):
+        def test(self):
+            result = op.empty([10, 10], dtype="float32")
+            op.debug_func("test_empty_assert", result)
+            return result
+
+    irmodule, _ = Model().export_tvm(spec={"test": {}}, debug=True)
+    ex = relax.build(irmodule, "llvm")
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    effects = vm["_initialize_effect"]()
+    vm["test"](*effects)
+    assert len(calls) == 1, "debug callback was not invoked exactly once"
+
+
if __name__ == "__main__":
tvm.testing.main()