This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c6a5b78690 [Relax] Enhance Relax op and ONNX frontend (#17462)
c6a5b78690 is described below
commit c6a5b7869023f7fd7b2926be847d39d363c13def
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Oct 17 01:05:34 2024 +0800
[Relax] Enhance Relax op and ONNX frontend (#17462)
---
include/tvm/relax/attrs/manipulate.h | 11 +++
python/tvm/relax/frontend/onnx/onnx_frontend.py | 66 ++++++++++++++---
python/tvm/relax/op/__init__.py | 5 ++
python/tvm/relax/op/binary.py | 26 +++++++
python/tvm/relax/op/create.py | 68 ++++++++++++++++++
python/tvm/relax/op/manipulate.py | 44 ++++++++++++
python/tvm/relax/transform/legalize_ops/binary.py | 3 +-
python/tvm/relax/transform/legalize_ops/create.py | 30 ++++++++
.../tvm/relax/transform/legalize_ops/manipulate.py | 19 +++++
python/tvm/script/ir_builder/relax/ir.py | 10 +++
python/tvm/topi/tensor.py | 35 ++++++++-
src/relax/op/distributed/binary.cc | 2 +
src/relax/op/tensor/binary.cc | 2 +
src/relax/op/tensor/binary.h | 6 ++
src/relax/op/tensor/create.cc | 84 ++++++++++++++++++++++
src/relax/op/tensor/create.h | 40 ++++++++++-
src/relax/op/tensor/manipulate.cc | 75 +++++++++++++++++++
src/relax/op/tensor/manipulate.h | 12 ++++
tests/python/relax/test_frontend_onnx.py | 26 +++++--
tests/python/relax/test_op_create.py | 58 +++++++++++++++
tests/python/relax/test_op_manipulate.py | 52 ++++++++++++++
21 files changed, 657 insertions(+), 17 deletions(-)
diff --git a/include/tvm/relax/attrs/manipulate.h
b/include/tvm/relax/attrs/manipulate.h
index e53ba3c36e..ea41488354 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -176,6 +176,17 @@ struct ScatterNDAttrs : public
tvm::AttrsNode<ScatterNDAttrs> {
}
}; // struct ScatterNDAttrs
+/*!
+ * \brief Attributes used in one_hot operator.
+ *
+ * Holds the two integer settings of the op and registers them (with their
+ * descriptions) through TVM_DECLARE_ATTRS.
+ */
+struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
+ /*! \brief Depth of the one hot dimension. No default is declared for this field. */
+ int depth;
+ /*! \brief Axis to fill. Declared with a default of -1. */
+ int axis;
+
+ TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs") {
+ TVM_ATTR_FIELD(depth).describe("Depth of the one hot dimension.");
+ TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill.");
+ }
+}; // struct OneHotAttrs
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 43c1ec681a..6c9225070d 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -287,7 +287,7 @@ class Sub(BinaryBase):
relax_op = relax.op.subtract
@classmethod
- def _impl_v1(cls, bb, inputs, attr, params):
+ def _impl_v7(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)
@@ -298,7 +298,7 @@ class Mul(BinaryBase):
relax_op = relax.op.multiply
@classmethod
- def _impl_v1(cls, bb, inputs, attr, params):
+ def _impl_v7(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)
@@ -309,7 +309,7 @@ class Div(BinaryBase):
relax_op = relax.op.divide
@classmethod
- def _impl_v1(cls, bb, inputs, attr, params):
+ def _impl_v7(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)
@@ -320,7 +320,24 @@ class Pow(BinaryBase):
relax_op = relax.op.power
@classmethod
- def _impl_v1(cls, bb, inputs, attr, params):
+ def _impl_v7(cls, bb, inputs, attr, params):
+ return cls.base_impl(bb, inputs, attr, params)
+
+
+class Mod(BinaryBase):
+ """Converts an onnx Mod node into an equivalent Relax expression."""
+
+ numpy_op = _np.mod
+ relax_op = relax.op.mod
+
+ @classmethod
+ def _impl_v10(cls, bb, inputs, attr, params):
+ if attr.get("fmod", 0) == 0:
+ cls.numpy_op = _np.fmod
+ cls.relax_op = relax.op.floor_mod
+ else:
+ cls.numpy_op = _np.mod
+ cls.relax_op = relax.op.mod
return cls.base_impl(bb, inputs, attr, params)
@@ -523,6 +540,23 @@ class LogSoftmax(OnnxOpConverter):
return relax.op.nn.log_softmax(inputs[0], axis=axis)
+class Hardmax(OnnxOpConverter):
+ """Converts an onnx Hardmax node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", -1)
+ indices = inputs[0]
+ dtype = indices.struct_info.dtype
+ axis_len = int(inputs[0].struct_info.shape[axis])
+ argmax = relax.op.argmax(indices, axis=axis)
+ on_value = relax.PrimValue(tvm.tir.const(1.0, dtype))
+ off_value = relax.PrimValue(tvm.tir.const(0.0, dtype))
+
+ one_hot = relax.op.one_hot(argmax, on_value, off_value, axis_len, axis)
+ return one_hot
+
+
class Transpose(OnnxOpConverter):
"""Converts an onnx Transpose node into an equivalent Relax expression."""
@@ -731,6 +765,20 @@ class Size(OnnxOpConverter):
return
relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))
+class EyeLike(OnnxOpConverter):
+ """Convert an onnx EyeLike node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v9(cls, bb, inputs, attr, params):
+ k = attr.get("k", 0)
+ input_dtype = inputs[0].struct_info.dtype
+ if "dtype" in attr and get_type(attr["dtype"]) != input_dtype:
+ raise ValueError(
+ f"dtype mismatch between input ({input_dtype}) and attribute
({attr['dtype']})"
+ )
+ return relax.op.eye_like(inputs[0], k, input_dtype)
+
+
class Gemm(OnnxOpConverter):
"""Convert an onnx Gemm node into an equivalent Relax expression."""
@@ -2520,13 +2568,13 @@ class OneHot(OnnxOpConverter):
depth = get_constant(inputs[1], params)
values = get_constant(inputs[2], params)
axis = attr.get("axis", -1)
- dtype = values.struct_info.dtype
assert isinstance(depth, relax.Constant), "Only constant depth
currently supported."
depth = depth.data.numpy().tolist()
assert isinstance(values, relax.Constant), "Only constant values
currently supported."
values = values.data.numpy().tolist()
off_value, on_value = values
- return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth,
axis, dtype)
+ off_value, on_value = relax.PrimValue(off_value),
relax.PrimValue(on_value)
+ return relax.op.one_hot(indices, on_value, off_value, depth, axis)
class Unique(OnnxOpConverter):
@@ -2800,7 +2848,7 @@ def _get_convert_map():
"Sub": Sub,
"Mul": Mul,
"Div": Div,
- # "Mod": Mod,
+ "Mod": Mod,
"Less": Less,
"LessOrEqual": LessOrEqual,
"Greater": Greater,
@@ -2870,7 +2918,7 @@ def _get_convert_map():
"Sigmoid": Sigmoid,
"Softmax": Softmax,
"LogSoftmax": LogSoftmax,
- # "Hardmax": Hardmax,
+ "Hardmax": Hardmax,
"Transpose": Transpose,
"Unsqueeze": Unsqueeze,
"Where": Where,
@@ -2889,7 +2937,7 @@ def _get_convert_map():
"ScatterND": ScatterND,
# "Compress": Compress,
"Size": Size,
- # "EyeLike": EyeLike,
+ "EyeLike": EyeLike,
# Normalization
"BatchNormalization": BatchNormalization,
"LayerNormalization": LayerNormalization,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 84b31ccec0..1603ea2f0f 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -50,6 +50,7 @@ from .binary import (
divide,
equal,
floor_divide,
+ floor_mod,
greater,
greater_equal,
left_shift,
@@ -60,6 +61,7 @@ from .binary import (
logical_xor,
maximum,
minimum,
+ mod,
multiply,
not_equal,
power,
@@ -72,6 +74,8 @@ from .create import (
full_like,
ones,
ones_like,
+ eye,
+ eye_like,
tril,
triu,
zeros,
@@ -89,6 +93,7 @@ from .manipulate import (
flatten,
flip,
layout_transform,
+ one_hot,
permute_dims,
repeat,
reshape,
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 7632235cb3..7a41c8b095 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -139,6 +139,32 @@ def subtract(x1: Expr, x2: Expr) -> Expr:
return _ffi_api.subtract(x1, x2) # type: ignore
+def mod(x1: Expr, x2: Expr) -> Expr:
+ """Modulo with numpy-style broadcasting.
+
+ Parameters
+ ----------
+ x1 : Expr
+ The first input tensor.
+ x2 : Expr
+ The second input tensor.
+
+ Returns
+ -------
+ result : Expr
+ The expression returned by the `mod` FFI binding for `(x1, x2)`.
+ """
+ return _ffi_api.mod(x1, x2) # type: ignore
+
+
+def floor_mod(x1: Expr, x2: Expr) -> Expr:
+ """Floor modulo with numpy-style broadcasting.
+
+ Parameters
+ ----------
+ x1 : Expr
+ The first input tensor.
+ x2 : Expr
+ The second input tensor.
+
+ Returns
+ -------
+ result : Expr
+ The expression returned by the `floor_mod` FFI binding for `(x1, x2)`.
+ """
+ return _ffi_api.floor_mod(x1, x2) # type: ignore
+
+
###################### Comparison operators ######################
diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py
index 092d79a74d..c61d9521a4 100644
--- a/python/tvm/relax/op/create.py
+++ b/python/tvm/relax/op/create.py
@@ -163,6 +163,74 @@ def zeros_like(x: Expr, dtype: Optional[Union[str,
DataType]] = None) -> Expr:
return _ffi_api.zeros_like(x, dtype) # type: ignore
+def eye(
+ n: Union[PrimExprLike, PrimValue],
+ m: Optional[Union[PrimExprLike, PrimValue]] = None,
+ k: Union[PrimExprLike, PrimValue] = 0,
+ dtype: Union[str, DataType] = "float32",
+) -> Expr:
+ """Construct a 2-D tensor with ones on the diagonal and zeros elsewhere.
+
+ Parameters
+ ----------
+ n : Union[PrimExprLike, PrimValue]
+ Number of rows in the output.
+
+ m : Optional[Union[PrimExprLike, PrimValue]]
+ Number of columns in the output. If None, defaults to n.
+
+ k : Union[PrimExprLike, PrimValue]
+ Index of the diagonal: 0 (the default) refers to the main diagonal,
+ a positive value refers to an upper diagonal, and a negative value
+ to a lower diagonal.
+
+ dtype : Union[str, DataType]
+ The data type of the created tensor.
+
+ Returns
+ -------
+ result : relax.Expr
+ The result tensor.
+ """
+ m = n if m is None else m
+ n = n if isinstance(n, PrimValue) else PrimValue(n)
+ m = m if isinstance(m, PrimValue) else PrimValue(m)
+ k = k if isinstance(k, PrimValue) else PrimValue(k)
+ return _ffi_api.eye(n, m, k, dtype) # type: ignore
+
+
+def eye_like(
+ x: Expr,
+ k: Union[PrimExprLike, PrimValue] = 0,
+ dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+ """Return a 2-D tensor with ones on the diagonal and zeros elsewhere,
+ with the same shape as the input tensor.
+
+ Parameters
+ ----------
+ x : relax.Expr
+ The input tensor, which provides the shape, and dtype
+ when the `dtype` field is not specified.
+
+ k : Union[PrimExprLike, PrimValue]
+ Index of the diagonal: 0 (the default) refers to the main diagonal,
+ a positive value refers to an upper diagonal, and a negative value
+ to a lower diagonal.
+
+ dtype : Optional[Union[str, DataType]]
+ The data type of the created tensor.
+ If dtype is not given, it will by default use the dtype of the input
tensor.
+
+ Returns
+ -------
+ result : relax.Expr
+ The result tensor.
+ """
+ k = k if isinstance(k, PrimValue) else PrimValue(k)
+ return _ffi_api.eye_like(x, k, dtype) # type: ignore
+
+
def arange(
start: Union[PrimExprLike, PrimValue],
end: Optional[Union[PrimExprLike, PrimValue]] = None,
diff --git a/python/tvm/relax/op/manipulate.py
b/python/tvm/relax/op/manipulate.py
index 1673a79b08..3210cc8216 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -550,3 +550,47 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr,
reduction: str = "updat
"""
return _ffi_api.scatter_nd(data, indices, updates, reduction) # type:
ignore
+
+
+def one_hot(
+ indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int,
axis: int = -1
+) -> Expr:
+ """Returns a one-hot tensor.
+
+ Parameters
+ ----------
+ indices : relax.Expr
+ The indices to set to `on_value`.
+
+ on_value : relax.PrimValue
+ The value to fill at `indices`.
+
+ off_value : relax.PrimValue
+ The value to fill at other locations.
+
+ depth : int
+ The depth of the one-hot dimension.
+
+ axis : int, optional
+ The axis to fill. Default is -1 which adds a new dimension at the end.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ indices = [0, 1, 2]
+ depth = 3
+ on_value = 1
+ off_value = 0
+
+ one_hot(indices, on_value, off_value, depth) =
+ [[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]]
+ """
+ return _ffi_api.one_hot(indices, on_value, off_value, depth, axis) #
type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py
b/python/tvm/relax/transform/legalize_ops/binary.py
index d28e100edb..41e317f1e0 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -48,7 +48,8 @@ register_legalize("relax.multiply", _binary(topi.multiply))
register_legalize("relax.power", _binary(topi.power))
register_legalize("relax.subtract", _binary(topi.subtract))
register_legalize("relax.equal", _binary(topi.equal))
-
+register_legalize("relax.mod", _binary(topi.mod))
+register_legalize("relax.floor_mod", _binary(topi.floor_mod))
register_legalize("relax.greater", _binary(topi.greater))
register_legalize("relax.greater_equal", _binary(topi.greater_equal))
register_legalize("relax.less", _binary(topi.less))
diff --git a/python/tvm/relax/transform/legalize_ops/create.py
b/python/tvm/relax/transform/legalize_ops/create.py
index 1b022672d0..8bf85e34de 100644
--- a/python/tvm/relax/transform/legalize_ops/create.py
+++ b/python/tvm/relax/transform/legalize_ops/create.py
@@ -70,6 +70,36 @@ register_legalize("relax.tril", _tril_triu(is_upper=False,
primfunc_name="tril")
register_legalize("relax.triu", _tril_triu(is_upper=True,
primfunc_name="triu"))
+def _eye(is_like: bool, primfunc_name: str) -> LegalizeFunc:
+ def eye_call_te(bb: BlockBuilder, call: Call) -> Expr:
+ _convert_to_scalar_const = lambda x: _try_convert_to_scalar_const(x,
python_native=True)
+ if is_like:
+ x = call.args[0]
+ k = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1
else 0
+ n, m = x.struct_info.shape
+ dtype = x.struct_info.dtype
+ else:
+ n = _convert_to_scalar_const(call.args[0])
+ m = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1
else n
+ k = _convert_to_scalar_const(call.args[2]) if len(call.args) > 2
else 0
+ dtype = call.attrs.dtype
+
+ return bb.call_te(
+ topi.eye,
+ n,
+ m,
+ k,
+ dtype,
+ primfunc_name_hint=primfunc_name,
+ )
+
+ return eye_call_te
+
+
+register_legalize("relax.eye", _eye(is_like=False, primfunc_name="eye"))
+register_legalize("relax.eye_like", _eye(is_like=True,
primfunc_name="eye_like"))
+
+
@register_legalize("relax.arange")
def _arange(bb: BlockBuilder, call: Call) -> Expr:
assert len(call.args) == 3
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 105d763403..163085a07c 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -185,6 +185,25 @@ def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.one_hot")
+def _one_hot(bb: BlockBuilder, call: Call) -> Expr:
+ indices, on_value, off_value = call.args
+ if not (isinstance(on_value, relax.PrimValue) and isinstance(off_value,
relax.PrimValue)):
+ raise ValueError("on_value and off_value must be PrimValue")
+ on_value, off_value = on_value.value, off_value.value
+ if on_value.dtype != off_value.dtype:
+ raise ValueError("on_value and off_value must have the same dtype")
+ return bb.call_te(
+ topi.one_hot,
+ indices,
+ on_value,
+ off_value,
+ call.attrs.depth,
+ call.attrs.axis,
+ on_value.dtype,
+ )
+
+
@register_legalize("relax.layout_transform")
def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
def te_layout_transform(data, name):
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index f7847e2af8..049345fcb1 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -85,10 +85,13 @@ from tvm.relax.op import (
ewise_fma,
exp,
expand_dims,
+ eye,
+ eye_like,
flatten,
flip,
floor,
floor_divide,
+ floor_mod,
full,
full_like,
grad,
@@ -119,6 +122,7 @@ from tvm.relax.op import (
memory,
min,
minimum,
+ mod,
multinomial_from_uniform,
multiply,
negative,
@@ -127,6 +131,7 @@ from tvm.relax.op import (
null_value,
ones,
ones_like,
+ one_hot,
permute_dims,
power,
print,
@@ -753,10 +758,13 @@ __all__ = [
"exp",
"expand_dims",
"ext_dev",
+ "eye",
+ "eye_like",
"flatten",
"flip",
"floor",
"floor_divide",
+ "floor_mod",
"full",
"full_like",
"func_attr",
@@ -795,6 +803,7 @@ __all__ = [
"metal",
"min",
"minimum",
+ "mod",
"multinomial_from_uniform",
"multiply",
"negative",
@@ -802,6 +811,7 @@ __all__ = [
"null_value",
"ones",
"ones_like",
+ "one_hot",
"opencl",
"output",
"permute_dims",
diff --git a/python/tvm/topi/tensor.py b/python/tvm/topi/tensor.py
index 31ebe86760..449c599dea 100644
--- a/python/tvm/topi/tensor.py
+++ b/python/tvm/topi/tensor.py
@@ -16,7 +16,11 @@
# under the License.
# pylint:
disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
"""Elementwise operators"""
-from __future__ import absolute_import as _abs
+
+from typing import Optional
+
+from tvm import te
+
from . import cpp
@@ -73,3 +77,32 @@ def full_like(x, fill_value):
The result.
"""
return cpp.full_like(x, fill_value)
+
+
+def eye(n: int, m: Optional[int] = None, k: int = 0, dtype: str = "float32")
-> te.Tensor:
+ """Generate an identity matrix or a matrix with ones on the k-th diagonal.
+
+ Parameters
+ ----------
+ n : int
+ Number of rows
+ m : int, optional
+ Number of columns. If None, defaults to n.
+ k : int, optional
+ Index of the diagonal. 0 (default) refers to the main diagonal.
+ A positive value refers to an upper diagonal, and a negative value
+ to a lower diagonal.
+ dtype : str, optional
+ Data type of the returned array.
+
+ Returns
+ -------
+ y : tvm.te.Tensor
+ The result.
+ """
+ m = m if m is not None else n
+ return te.compute(
+ (n, m),
+ lambda i, j: te.if_then_else(i == j - k, te.const(1, dtype),
te.const(0, dtype)),
+ name="eye",
+ )
diff --git a/src/relax/op/distributed/binary.cc
b/src/relax/op/distributed/binary.cc
index 6ad71e0f85..1e7fa81727 100644
--- a/src/relax/op/distributed/binary.cc
+++ b/src/relax/op/distributed/binary.cc
@@ -42,6 +42,8 @@
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_divide);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(multiply);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(power);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(subtract);
+RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(mod);
+RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_mod);
/***************** Comparison operators *****************/
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index f1dc3d4904..bd4c681c79 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -181,6 +181,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(mod);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_mod);
/***************** Comparison operators *****************/
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 003bcb7e27..b66eb96f84 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -79,6 +79,12 @@ Expr power(Expr x1, Expr x2);
/*! \brief Subtraction with numpy-style broadcasting. */
Expr subtract(Expr x1, Expr x2);
+/*! \brief Modulo with numpy-style broadcasting. */
+Expr mod(Expr x1, Expr x2);
+
+/*! \brief Floor modulo with numpy-style broadcasting. */
+Expr floor_mod(Expr x1, Expr x2);
+
/***************** Comparison operators *****************/
/*! \brief Broadcasted element-wise test for (lhs == rhs). */
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index 7aca1470ae..8696d85f77 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -228,6 +228,90 @@ TVM_REGISTER_OP("relax.zeros_like")
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoOnesLikeZerosLike)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.eye & relax.eye_like */
+// Builds a call to the "relax.eye" op. n, m and k become the three call
+// arguments; dtype is stored in the InitAttrs attached to the call.
+Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype) {
+ ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
+ attrs->dtype = dtype;
+ static const Op& op = Op::Get("relax.eye");
+ return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs),
{});
+}
+
+// Builds a call to the "relax.eye_like" op. x and k become the two call
+// arguments; dtype is stored in the InitAttrs attached to the call.
+Expr eye_like(Expr x, PrimValue k, DataType dtype) {
+ ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
+ attrs->dtype = dtype;
+ static const Op& op = Op::Get("relax.eye_like");
+ return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye);
+TVM_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like);
+
+// Struct-info inference for relax.eye: the result is a tensor of shape (n, m)
+// with the dtype stored in the call's InitAttrs. Reports a fatal diagnostic
+// when the call does not carry exactly three arguments or when n / m is not a
+// PrimValue. The third argument (k) is not inspected here.
+StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) {
+ if (call->args.size() != 3) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Eye op should have 3 arguments: n, m, and k, but got
" << call->args.size()
+ << " arguments");
+ }
+
+ // Extracts the PrimExpr held by a PrimValue argument; `key` is only used to
+ // name the offending argument in the diagnostic.
+ auto get_prim_value = [&ctx](const Expr& expr, std::string key) {
+ if (!expr->IsInstance<PrimValueNode>()) {
+ ctx->ReportFatal(Diagnostic::Error(expr)
+ << "Eye expects the `" << key << "` to be a PrimValue,
but got "
+ << expr->GetTypeKey());
+ }
+ return expr.as<PrimValueNode>()->value;
+ };
+
+ PrimExpr n = get_prim_value(call->args[0], "n");
+ PrimExpr m = get_prim_value(call->args[1], "m");
+
+ DataType dtype = call->attrs.as<InitAttrs>()->dtype;
+ return TensorStructInfo(ShapeExpr({n, m}), dtype);
+}
+
+/*!
+ * \brief Struct-info inference for relax.eye_like.
+ *
+ * The output mirrors the input tensor `x`: same shape (when known), same
+ * vdevice, and either the dtype given in InitAttrs or, when that is void, the
+ * dtype of `x`.
+ *
+ * \param call The relax.eye_like call; expects the arguments (x, k).
+ * \param ctx The block builder used to report diagnostics.
+ * \return The inferred tensor struct info.
+ */
+StructInfo InferStructInfoEyeLike(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye_like op should have 2 arguments: x and k, but got "
+                     << call->args.size() << " arguments");
+  }
+
+  const auto* x_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  if (x_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye_like expects the input `x` to be a Tensor, but got "
+                     << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (x_sinfo->ndim != 2 && x_sinfo->ndim != kUnknownNDim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye_like expects the input tensor to be 2-dimensional, but got "
+                     << x_sinfo->ndim << " dimensions");
+  }
+
+  const auto* attrs = call->attrs.as<InitAttrs>();
+  DataType out_dtype = attrs->dtype.is_void() ? x_sinfo->dtype : attrs->dtype;
+
+  // The input may carry only an ndim (or nothing at all) without a shape
+  // expression. Unwrapping the optional shape would fail in that case, so
+  // fall back to an ndim-only result.
+  if (!x_sinfo->shape.defined()) {
+    return TensorStructInfo(out_dtype, x_sinfo->ndim, x_sinfo->vdevice);
+  }
+  return TensorStructInfo(x_sinfo->shape.value(), out_dtype, x_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.eye")
+ .set_attrs_type<InitAttrs>()
+ .set_num_inputs(3)
+ .add_argument("n", "PrimValue", "Number of rows in the output.")
+ .add_argument("m", "PrimValue", "Number of columns in the output.")
+ .add_argument("k", "PrimValue", "Index of the diagonal.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEye)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+TVM_REGISTER_OP("relax.eye_like")
+ .set_attrs_type<InitAttrs>()
+ .set_num_inputs(2)
+ .add_argument("x", "Tensor", "The input tensor.")
+ .add_argument("k", "PrimValue", "Index of the diagonal.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEyeLike)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.arange */
Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) {
ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h
index 6e7c825523..d88336146d 100644
--- a/src/relax/op/tensor/create.h
+++ b/src/relax/op/tensor/create.h
@@ -72,12 +72,48 @@ Expr ones(Expr shape, DataType dtype);
*/
Expr ones_like(Expr x, DataType dtype);
-/*! \brief Construct a tensor of all zeros, with the input shape and dtype. */
+/*!
+ * \brief Construct a tensor of all zeros, with the input shape and dtype.
+ * \param shape The shape of the created tensor.
+ * \param dtype The data type of the created tensor.
+ * \return The result tensor.
+ */
Expr zeros(Expr shape, DataType dtype);
-/*! \brief Construct a tensor with all zeros, with shape of the input tensor
shape. */
+/*!
+ * \brief Construct a tensor with all zeros, with shape of the input tensor
shape.
+ * \param x The input tensor, which provides the shape, and dtype
+ * when the input dtype is void.
+ * \param dtype The data type of the created tensor. If it is
+ * void, the input tensor's dtype will be used.
+ * \return The result tensor.
+ */
Expr zeros_like(Expr x, DataType dtype);
+/*!
+ * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere.
+ * \param n The number of rows in the output.
+ * \param m The number of columns in the output.
+ * \param k The index of the diagonal. A positive value refers to an upper diagonal,
+ * a negative value to a lower diagonal, and 0 to the main diagonal.
+ * \param dtype The data type of the created tensor.
+ * \return The result tensor.
+ */
+Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype);
+
+/*!
+ * \brief Construct a tensor with ones on the diagonal and zeros elsewhere,
+ * with shape and dtype similar to the input tensor.
+ * \param x The input tensor, which provides the shape, and dtype
+ * when the input dtype is void.
+ * \param k The index of the diagonal. A positive value refers to an upper
diagonal,
+ * a negative value to a lower diagonal, and 0 to the main diagonal.
+ * \param dtype The data type of the created tensor. If it is
+ * void, the input tensor's dtype will be used.
+ * \return The result tensor.
+ */
+Expr eye_like(Expr x, PrimValue k, DataType dtype);
+
/*! \brief Construct a tensor with evenly spaced elements. */
Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype);
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index ca7d0a0945..ba44341302 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -30,6 +30,8 @@
#include <utility>
#include <vector>
+#include "tvm/runtime/data_type.h"
+
namespace tvm {
namespace relax {
@@ -1665,5 +1667,78 @@ TVM_REGISTER_OP("relax.scatter_nd")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.one_hot */
+TVM_REGISTER_NODE_TYPE(OneHotAttrs);
+// Builds a call to the "relax.one_hot" op. indices, on_value and off_value
+// become the call arguments; depth and axis are stored in OneHotAttrs.
+// Aborts via ICHECK when the two fill values differ in dtype or when depth is
+// not positive.
+Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth,
int axis) {
+ ObjectPtr<OneHotAttrs> attrs = make_object<OneHotAttrs>();
+ attrs->depth = depth;
+ attrs->axis = axis;
+
+ // Check if on_value and off_value have the same dtype
+ DataType on_dtype = on_value->value->dtype;
+ DataType off_dtype = off_value->value->dtype;
+ ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have
the same dtype, "
+ << "but got " << on_dtype << " and " <<
off_dtype;
+
+ ICHECK(depth > 0) << "one_hot: depth must be positive, but got " << depth;
+
+ static const Op& op = Op::Get("relax.one_hot");
+ return Call(op, {indices, on_value, off_value}, Attrs(attrs), {});
+} // one_hot
+
+TVM_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot);
+
+// Struct-info inference for relax.one_hot. The output dtype is the dtype of
+// on_value; the output shape is the indices shape with attrs->depth inserted
+// at attrs->axis. When the indices rank or shape is not known, only the dtype
+// (and, if available, the rank = indices rank + 1) is inferred.
+StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) {
+ TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+ const auto* attrs = call->attrs.as<OneHotAttrs>();
+ PrimValue on_value = Downcast<PrimValue>(call->args[1]);
+ PrimValue off_value = Downcast<PrimValue>(call->args[2]);
+ // Check if on_value and off_value have the same dtype
+ ICHECK(on_value->value->dtype == off_value->value->dtype)
+ << "one_hot: on_value and off_value must have the same dtype, "
+ << "but got " << on_value->value->dtype << " and " <<
off_value->value->dtype;
+ DataType dtype = on_value->value->dtype;
+
+ // Check if indices has an integer dtype
+ // (an unknown dtype is tolerated with a warning; a known non-integer dtype is fatal)
+ if (indices_sinfo->IsUnknownDtype()) {
+ LOG(WARNING) << "Data type of indices has not been specified. Assume it
has an integer type.";
+ } else if (!(indices_sinfo->dtype.is_int() ||
indices_sinfo->dtype.is_uint())) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "one_hot op requires the input indices to have integer
dtype. However, the "
+ "given indices dtype is "
+ << indices_sinfo->dtype);
+ }
+ // Check if indices has unknown dimension
+ if (indices_sinfo->IsUnknownNdim()) {
+ return TensorStructInfo(dtype, kUnknownNDim, indices_sinfo->vdevice);
+ }
+ // Get the shape of indices
+ // (if it is not a ShapeExpr, only the output rank can be reported)
+ const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
+ if (indices_shape == nullptr) {
+ return TensorStructInfo(dtype, indices_sinfo->ndim + 1,
indices_sinfo->vdevice);
+ }
+
+ Array<PrimExpr> output_shape = indices_shape->values;
+ int axis = attrs->axis;
+ // A negative axis is normalized against the output rank (indices rank + 1),
+ // so -1 places the one-hot dimension last.
+ if (axis < 0) {
+ axis += output_shape.size() + 1;
+ }
+ ICHECK(0 <= axis && axis <= static_cast<int>(output_shape.size()))
+ << "one_hot: axis must be in the range of [0, " << output_shape.size()
<< "], "
+ << "but got " << axis;
+ output_shape.insert(output_shape.begin() + axis, attrs->depth);
+
+ return TensorStructInfo(ShapeExpr(output_shape), dtype,
indices_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.one_hot")
+ .set_attrs_type<OneHotAttrs>()
+ .set_num_inputs(3)
+ .add_argument("indices", "Tensor", "The indices tensor.")
+ .add_argument("on_value", "PrimValue", "The value to fill at specified
indices.")
+ .add_argument("off_value", "PrimValue", "The value to fill at other
indices.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOneHot)
+ .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index e9fa1131e8..010ceb663e 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -27,6 +27,7 @@
#include <tvm/relax/attrs/manipulate.h>
#include "../op_common.h"
+#include "tvm/relax/expr.h"
namespace tvm {
namespace relax {
@@ -206,6 +207,17 @@ Expr scatter_elements(Expr data, Expr indices, Expr
updates, int axis, String re
*/
Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction);
+/*!
+ * \brief Returns a one-hot tensor.
+ * \param indices The indices to set to `on_value`.
+ * \param on_value The value to fill at `indices`.
+ * \param off_value The value to fill at other locations.
+ * \param depth The depth of the one hot dimension.
+ * \param axis The axis to fill.
+ * \return The computed result.
+ */
+Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth,
int axis);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 1b4c5d281a..46373510b1 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -63,8 +63,11 @@ def generate_random_inputs(
if dtype == "bool":
# random_value = np.random.choice(a=[False, True], size=shape)
random_value = rg.choice(a=[False, True], size=shape)
+ elif dtype.startswith("int"):
+ # Keep non-zero values
+ random_value = rg.integers(low=-63, high=63,
size=shape).astype(dtype)
+ random_value[random_value <= 0] -= 1
else:
- # random_value = np.random.normal(size=shape).astype(dtype)
random_value = rg.standard_normal(size=shape).astype(dtype)
input_values[i.name] = random_value
@@ -246,7 +249,6 @@ def verify_binary_scalar(op_name, attrs={}, domain=None,
dtype=TensorProto.INT32
)
model = helper.make_model(graph, producer_name="binary_test")
- # NOTE: explicitly pass inputs to avoid numerical error
check_correctness(model, opset=opset)
@@ -327,6 +329,16 @@ def test_binary(op_name: str):
verify_binary_scalar(op_name)
[email protected]("int_mode", [True, False])
+def test_mod(int_mode: bool):
+ if int_mode:
+ dtype, fmod = TensorProto.INT32, 0
+ else:
+ dtype, fmod = TensorProto.FLOAT, 1
+ verify_binary("Mod", [1, 32], [1, 32], [1, 32], attrs={"fmod": fmod},
dtype=dtype)
+ verify_binary_scalar("Mod", attrs={"fmod": fmod}, dtype=dtype)
+
+
@pytest.mark.parametrize("num_inputs", [1, 2, 4])
@pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"])
def test_multi_input(op_name: str, num_inputs: int):
@@ -430,6 +442,7 @@ def test_bitwise_shift(direction: str):
"Sigmoid",
"Softmax",
"LogSoftmax",
+ "Hardmax",
"Identity",
],
)
@@ -445,7 +458,7 @@ def test_unary(op_name: str):
output_dtype = TensorProto.BOOL
else:
output_dtype = TensorProto.FLOAT
- verify_unary(op_name, [32, 32], input_dtype=input_dtype, output_dtype=output_dtype)
+ verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, output_dtype=output_dtype)
@pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16])
@@ -567,6 +580,11 @@ def test_size():
check_correctness(model)
+@pytest.mark.parametrize("k", [-1, 0, 1])
+def test_eye_like(k: int):
+ verify_unary("EyeLike", [32, 32], attrs={"k": k})
+
+
@pytest.mark.parametrize("alpha", [None, 0.25, 1.0])
@pytest.mark.parametrize("beta", [None, 0.35, 1.0])
@pytest.mark.parametrize("useC", [False, True])
@@ -966,7 +984,7 @@ def test_cumsum1():
)
model = helper.make_model(graph, producer_name="cumsum_graph")
- check_correctness(model)
+ check_correctness(model, inputs={"axis": np.array([0], dtype=np.int32)})
@pytest.mark.parametrize("axis", [[0, 2], None])
diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py
index 1e895169f6..67f3470191 100644
--- a/tests/python/relax/test_op_create.py
+++ b/tests/python/relax/test_op_create.py
@@ -545,6 +545,64 @@ def test_ones_like_zeros_like_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.zeros_like(x1))
+def test_eye_infer_struct_info():
+ bb = relax.BlockBuilder()
+
+ _check_inference(bb, relax.op.eye(3), relax.TensorStructInfo((3, 3), "float32"))
+ _check_inference(bb, relax.op.eye(2, 4), relax.TensorStructInfo((2, 4), "float32"))
+ _check_inference(bb, relax.op.eye(3, dtype="int64"), relax.TensorStructInfo((3, 3), "int64"))
+ _check_inference(bb, relax.op.eye(3, 5, k=1), relax.TensorStructInfo((3, 5), "float32"))
+ _check_inference(bb, relax.op.eye(3, 5, k=-2), relax.TensorStructInfo((3, 5), "float32"))
+
+
+def test_eye_infer_struct_info_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ m = tir.Var("m", "int64")
+ k = tir.Var("k", "int64")
+
+ _check_inference(bb, relax.op.eye(n), relax.TensorStructInfo((n, n), "float32"))
+ _check_inference(bb, relax.op.eye(n, m), relax.TensorStructInfo((n, m), "float32"))
+ _check_inference(bb, relax.op.eye(n, k=k), relax.TensorStructInfo((n, n), "float32"))
+
+
+def test_eye_like_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((3, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 5), "int64"))
+ x2 = relax.Var("x", R.Tensor((3, 3)))
+
+ _check_inference(bb, relax.op.eye_like(x0), relax.TensorStructInfo((3, 4), "float32"))
+ _check_inference(bb, relax.op.eye_like(x1), relax.TensorStructInfo((2, 5), "int64"))
+ _check_inference(bb, relax.op.eye_like(x2), relax.TensorStructInfo((3, 3), dtype=""))
+ _check_inference(bb, relax.op.eye_like(x0, k=1), relax.TensorStructInfo((3, 4), "float32"))
+ _check_inference(
+ bb, relax.op.eye_like(x1, dtype="float32"), relax.TensorStructInfo((2, 5), "float32")
+ )
+
+
+def test_eye_like_infer_struct_info_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ m = tir.Var("m", "int64")
+ x = relax.Var("x", R.Tensor((n, m), "float32"))
+ k = tir.Var("k", "int64")
+
+ _check_inference(bb, relax.op.eye_like(x), relax.TensorStructInfo((n, m), "float32"))
+ _check_inference(bb, relax.op.eye_like(x, k=k), relax.TensorStructInfo((n, m), "float32"))
+
+
+def test_eye_like_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
+ x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.eye_like(x0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.eye_like(x1))
+
+
def test_arange_infer_struct_info():
bb = relax.BlockBuilder()
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index e958b03e4c..f6aefc8591 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -3377,5 +3377,57 @@ def test_scatter_nd_infer_struct_info():
)
+def test_one_hot_infer_struct_info():
+ bb = relax.BlockBuilder()
+
+ # Test case 1: Basic usage
+ i0 = relax.Var("indices", R.Tensor((3,), "int32"))
+ _check_inference(
+ bb,
+ relax.op.one_hot(i0, relax.PrimValue(1.0), relax.PrimValue(0.0), 5),
+ relax.TensorStructInfo((3, 5), "float32"),
+ )
+
+ # Test case 2: With specified axis
+ i1 = relax.Var("indices", R.Tensor((2, 2), "int32"))
+ _check_inference(
+ bb,
+ relax.op.one_hot(i1, relax.PrimValue(1), relax.PrimValue(0), 3, axis=1),
+ relax.TensorStructInfo((2, 3, 2), "int64"),
+ )
+
+ # Test case 3: With symbolic shape
+ n = tir.Var("n", "int64")
+ i2 = relax.Var("indices", R.Tensor((n,), "int32"))
+ _check_inference(
+ bb,
+ relax.op.one_hot(i2, relax.PrimValue(1.0), relax.PrimValue(0.0), 4),
+ relax.TensorStructInfo((n, 4), "float32"),
+ )
+
+ # Test case 4: With unknown shape
+ i3 = relax.Var("indices", R.Tensor("int32"))
+ _check_inference(
+ bb,
+ relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0.0), 6),
+ relax.TensorStructInfo(dtype="float32"),
+ )
+
+ # Test case 5: With different on_value and off_value dtypes
+ i3 = relax.Var("indices", R.Tensor((2, 3), "int32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0), 5))
+
+ # Test case 6: With invalid indices dtype
+ i4 = relax.Var("indices", R.Tensor((2, 3), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.one_hot(i4, relax.PrimValue(1.0), relax.PrimValue(0.0), 5))
+
+ # Test case 7: With invalid depth
+ i5 = relax.Var("indices", R.Tensor((2, 3), "int32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.one_hot(i5, relax.PrimValue(1.0), relax.PrimValue(0.0), -1))
+
+
if __name__ == "__main__":
tvm.testing.main()