This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c6a5b78690 [Relax] Enhance Relax op and ONNX frontend (#17462)
c6a5b78690 is described below

commit c6a5b7869023f7fd7b2926be847d39d363c13def
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Oct 17 01:05:34 2024 +0800

    [Relax] Enhance Relax op and ONNX frontend (#17462)
---
 include/tvm/relax/attrs/manipulate.h               | 11 +++
 python/tvm/relax/frontend/onnx/onnx_frontend.py    | 66 ++++++++++++++---
 python/tvm/relax/op/__init__.py                    |  5 ++
 python/tvm/relax/op/binary.py                      | 26 +++++++
 python/tvm/relax/op/create.py                      | 68 ++++++++++++++++++
 python/tvm/relax/op/manipulate.py                  | 44 ++++++++++++
 python/tvm/relax/transform/legalize_ops/binary.py  |  3 +-
 python/tvm/relax/transform/legalize_ops/create.py  | 30 ++++++++
 .../tvm/relax/transform/legalize_ops/manipulate.py | 19 +++++
 python/tvm/script/ir_builder/relax/ir.py           | 10 +++
 python/tvm/topi/tensor.py                          | 35 ++++++++-
 src/relax/op/distributed/binary.cc                 |  2 +
 src/relax/op/tensor/binary.cc                      |  2 +
 src/relax/op/tensor/binary.h                       |  6 ++
 src/relax/op/tensor/create.cc                      | 84 ++++++++++++++++++++++
 src/relax/op/tensor/create.h                       | 40 ++++++++++-
 src/relax/op/tensor/manipulate.cc                  | 75 +++++++++++++++++++
 src/relax/op/tensor/manipulate.h                   | 12 ++++
 tests/python/relax/test_frontend_onnx.py           | 26 +++++--
 tests/python/relax/test_op_create.py               | 58 +++++++++++++++
 tests/python/relax/test_op_manipulate.py           | 52 ++++++++++++++
 21 files changed, 657 insertions(+), 17 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h 
b/include/tvm/relax/attrs/manipulate.h
index e53ba3c36e..ea41488354 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -176,6 +176,17 @@ struct ScatterNDAttrs : public 
tvm::AttrsNode<ScatterNDAttrs> {
   }
 };  // struct ScatterNDAttrs
 
+/*!
+ * \brief Attributes used in one_hot operator.
+ *
+ * The one_hot op inserts a new dimension of length `depth` into the shape of
+ * its indices input, at position `axis` of the output.
+ */
+struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
+  /*! \brief Length of the inserted one-hot dimension; must be positive. */
+  int depth;
+  /*! \brief Output position of the one-hot dimension; negative values count from the end. */
+  int axis;
+
+  TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs") {
+    TVM_ATTR_FIELD(depth).describe("Depth of the one hot dimension.");
+    TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill.");
+  }
+};  // struct OneHotAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 43c1ec681a..6c9225070d 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -287,7 +287,7 @@ class Sub(BinaryBase):
     relax_op = relax.op.subtract
 
     @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
+    def _impl_v7(cls, bb, inputs, attr, params):
+        # NOTE(review): this replaces the opset-1 entry point with an opset-7 one;
+        # models using an older opset for Sub are presumably no longer matched by
+        # any converter version -- confirm this is intended.
         return cls.base_impl(bb, inputs, attr, params)
 
 
@@ -298,7 +298,7 @@ class Mul(BinaryBase):
     relax_op = relax.op.multiply
 
     @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
+    def _impl_v7(cls, bb, inputs, attr, params):
+        # NOTE(review): this replaces the opset-1 entry point with an opset-7 one;
+        # models using an older opset for Mul are presumably no longer matched by
+        # any converter version -- confirm this is intended.
         return cls.base_impl(bb, inputs, attr, params)
 
 
@@ -309,7 +309,7 @@ class Div(BinaryBase):
     relax_op = relax.op.divide
 
     @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
+    def _impl_v7(cls, bb, inputs, attr, params):
+        # NOTE(review): this replaces the opset-1 entry point with an opset-7 one;
+        # models using an older opset for Div are presumably no longer matched by
+        # any converter version -- confirm this is intended.
         return cls.base_impl(bb, inputs, attr, params)
 
 
@@ -320,7 +320,24 @@ class Pow(BinaryBase):
     relax_op = relax.op.power
 
     @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
+    def _impl_v7(cls, bb, inputs, attr, params):
+        # NOTE(review): this replaces the opset-1 entry point with an opset-7 one;
+        # models using an older opset for Pow are presumably no longer matched by
+        # any converter version -- confirm this is intended.
+        return cls.base_impl(bb, inputs, attr, params)
+
+
+class Mod(BinaryBase):
+    """Converts an onnx Mod node into an equivalent Relax expression.
+
+    ONNX ``fmod=0`` (the default) asks for a remainder with the sign of the
+    divisor (floor modulo, like ``numpy.mod``); ``fmod=1`` asks for a remainder
+    with the sign of the dividend (truncated modulo, like ``numpy.fmod``).
+    ``numpy_op`` and ``relax_op`` must always describe the same semantics, since
+    ``base_impl`` presumably evaluates ``numpy_op`` when all inputs are constant.
+    """
+
+    # Class-level defaults match the ONNX default ``fmod=0``.
+    numpy_op = _np.mod
+    relax_op = relax.op.floor_mod
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        # ``base_impl`` reads the op pair from ``cls``, so set both on every call.
+        if attr.get("fmod", 0) == 0:
+            # Floor modulo: the result takes the sign of the divisor.
+            cls.numpy_op = _np.mod
+            cls.relax_op = relax.op.floor_mod
+        else:
+            # Truncated modulo: the result takes the sign of the dividend.
+            cls.numpy_op = _np.fmod
+            cls.relax_op = relax.op.mod
         return cls.base_impl(bb, inputs, attr, params)
 
 
@@ -523,6 +540,23 @@ class LogSoftmax(OnnxOpConverter):
         return relax.op.nn.log_softmax(inputs[0], axis=axis)
 
 
+class Hardmax(OnnxOpConverter):
+    """Converts an onnx Hardmax node into an equivalent Relax expression.
+
+    Lowered as ``one_hot(argmax(x, axis), 1, 0, depth=x.shape[axis], axis)``:
+    argmax reduces ``axis`` and one_hot inserts a new dimension of length
+    ``x.shape[axis]`` at the same position.
+    """
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", -1)
+        data = inputs[0]
+        dtype = data.struct_info.dtype
+        # one_hot takes a static integer depth, so the reduced axis must have a
+        # known constant length; report that clearly instead of failing in int().
+        shape = data.struct_info.shape
+        axis_dim = shape.values[axis] if isinstance(shape, relax.ShapeExpr) else None
+        if not isinstance(axis_dim, tvm.tir.IntImm):
+            raise ValueError(
+                f"Hardmax requires a static input length along axis {axis}, "
+                f"but got shape {shape}"
+            )
+        axis_len = int(axis_dim)
+        argmax = relax.op.argmax(data, axis=axis)
+        on_value = relax.PrimValue(tvm.tir.const(1.0, dtype))
+        off_value = relax.PrimValue(tvm.tir.const(0.0, dtype))
+        return relax.op.one_hot(argmax, on_value, off_value, axis_len, axis)
+
+
 class Transpose(OnnxOpConverter):
     """Converts an onnx Transpose node into an equivalent Relax expression."""
 
@@ -731,6 +765,20 @@ class Size(OnnxOpConverter):
         return 
relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))
 
 
+class EyeLike(OnnxOpConverter):
+    """Convert an onnx EyeLike node into an equivalent Relax expression.
+
+    The output has the (2-D) shape of the input. The optional ``dtype``
+    attribute selects the output element type; when it is absent, the input's
+    element type is used.
+    """
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        k = attr.get("k", 0)
+        # ONNX allows the output dtype to differ from the input dtype, so honor
+        # the attribute instead of rejecting a mismatch.
+        if "dtype" in attr:
+            dtype = get_type(attr["dtype"])
+        else:
+            dtype = inputs[0].struct_info.dtype
+        return relax.op.eye_like(inputs[0], k, dtype)
+
+
 class Gemm(OnnxOpConverter):
     """Convert an onnx Gemm node into an equivalent Relax expression."""
 
@@ -2520,13 +2568,13 @@ class OneHot(OnnxOpConverter):
         depth = get_constant(inputs[1], params)
         values = get_constant(inputs[2], params)
         axis = attr.get("axis", -1)
-        dtype = values.struct_info.dtype
         assert isinstance(depth, relax.Constant), "Only constant depth 
currently supported."
         depth = depth.data.numpy().tolist()
         assert isinstance(values, relax.Constant), "Only constant values 
currently supported."
         values = values.data.numpy().tolist()
         off_value, on_value = values
-        return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, 
axis, dtype)
+        off_value, on_value = relax.PrimValue(off_value), 
relax.PrimValue(on_value)
+        return relax.op.one_hot(indices, on_value, off_value, depth, axis)
 
 
 class Unique(OnnxOpConverter):
@@ -2800,7 +2848,7 @@ def _get_convert_map():
         "Sub": Sub,
         "Mul": Mul,
         "Div": Div,
-        # "Mod": Mod,
+        "Mod": Mod,
         "Less": Less,
         "LessOrEqual": LessOrEqual,
         "Greater": Greater,
@@ -2870,7 +2918,7 @@ def _get_convert_map():
         "Sigmoid": Sigmoid,
         "Softmax": Softmax,
         "LogSoftmax": LogSoftmax,
-        # "Hardmax": Hardmax,
+        "Hardmax": Hardmax,
         "Transpose": Transpose,
         "Unsqueeze": Unsqueeze,
         "Where": Where,
@@ -2889,7 +2937,7 @@ def _get_convert_map():
         "ScatterND": ScatterND,
         # "Compress": Compress,
         "Size": Size,
-        # "EyeLike": EyeLike,
+        "EyeLike": EyeLike,
         # Normalization
         "BatchNormalization": BatchNormalization,
         "LayerNormalization": LayerNormalization,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 84b31ccec0..1603ea2f0f 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -50,6 +50,7 @@ from .binary import (
     divide,
     equal,
     floor_divide,
+    floor_mod,
     greater,
     greater_equal,
     left_shift,
@@ -60,6 +61,7 @@ from .binary import (
     logical_xor,
     maximum,
     minimum,
+    mod,
     multiply,
     not_equal,
     power,
@@ -72,6 +74,8 @@ from .create import (
     full_like,
     ones,
     ones_like,
+    eye,
+    eye_like,
     tril,
     triu,
     zeros,
@@ -89,6 +93,7 @@ from .manipulate import (
     flatten,
     flip,
     layout_transform,
+    one_hot,
     permute_dims,
     repeat,
     reshape,
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 7632235cb3..7a41c8b095 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -139,6 +139,32 @@ def subtract(x1: Expr, x2: Expr) -> Expr:
     return _ffi_api.subtract(x1, x2)  # type: ignore
 
 
+def mod(x1: Expr, x2: Expr) -> Expr:
+    """Modulo with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    x1 : Expr
+        The first input tensor (dividend).
+    x2 : Expr
+        The second input tensor (divisor).
+
+    Returns
+    -------
+    result : Expr
+        The element-wise remainder of ``x1`` divided by ``x2``.
+
+    Note
+    ----
+    This op is legalized to ``topi.mod`` and is distinct from :func:`floor_mod`.
+    The two presumably differ in the sign of the result for negative operands
+    (truncated vs. floor rounding) -- confirm against ``topi.mod``.
+    """
+    return _ffi_api.mod(x1, x2)  # type: ignore
+
+
+def floor_mod(x1: Expr, x2: Expr) -> Expr:
+    """Floor modulo with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    x1 : Expr
+        The first input tensor (dividend).
+    x2 : Expr
+        The second input tensor (divisor).
+
+    Returns
+    -------
+    result : Expr
+        The element-wise floor-modulo remainder of ``x1`` divided by ``x2``.
+
+    Note
+    ----
+    This op is legalized to ``topi.floor_mod``; see :func:`mod` for the
+    non-floor variant.
+    """
+    return _ffi_api.floor_mod(x1, x2)  # type: ignore
+
+
 ###################### Comparison operators ######################
 
 
diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py
index 092d79a74d..c61d9521a4 100644
--- a/python/tvm/relax/op/create.py
+++ b/python/tvm/relax/op/create.py
@@ -163,6 +163,74 @@ def zeros_like(x: Expr, dtype: Optional[Union[str, 
DataType]] = None) -> Expr:
     return _ffi_api.zeros_like(x, dtype)  # type: ignore
 
 
+def eye(
+    n: Union[PrimExprLike, PrimValue],
+    m: Optional[Union[PrimExprLike, PrimValue]] = None,
+    k: Union[PrimExprLike, PrimValue] = 0,
+    dtype: Union[str, DataType] = "float32",
+) -> Expr:
+    """Construct a 2-D tensor with ones on the diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    n : Union[PrimExprLike, PrimValue]
+        Number of rows in the output.
+
+    m : Optional[Union[PrimExprLike, PrimValue]]
+        Number of columns in the output. If None, defaults to n.
+
+    k : Union[PrimExprLike, PrimValue]
+        Index of the diagonal: 0 (the default) refers to the main diagonal,
+        a positive value refers to an upper diagonal, and a negative value
+        to a lower diagonal.
+
+    dtype : Union[str, DataType]
+        The data type of the created tensor.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result tensor, of shape ``(n, m)``.
+    """
+    # Default to a square matrix, then wrap plain integers / PrimExprs in
+    # PrimValue, since the FFI constructor takes PrimValue arguments.
+    m = n if m is None else m
+    n = n if isinstance(n, PrimValue) else PrimValue(n)
+    m = m if isinstance(m, PrimValue) else PrimValue(m)
+    k = k if isinstance(k, PrimValue) else PrimValue(k)
+    return _ffi_api.eye(n, m, k, dtype)  # type: ignore
+
+
+def eye_like(
+    x: Expr,
+    k: Union[PrimExprLike, PrimValue] = 0,
+    dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+    """Return a 2-D tensor with ones on the diagonal and zeros elsewhere,
+    with the same shape as the input tensor.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The 2-D input tensor, which provides the shape, and dtype
+        when the `dtype` field is not specified.
+
+    k : Union[PrimExprLike, PrimValue]
+        Index of the diagonal: 0 (the default) refers to the main diagonal,
+        a positive value refers to an upper diagonal, and a negative value
+        to a lower diagonal.
+
+    dtype : Optional[Union[str, DataType]]
+        The data type of the created tensor.
+        If dtype is not given, it will by default use the dtype of the input tensor.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result tensor.
+    """
+    # Wrap a plain integer / PrimExpr diagonal offset in PrimValue for the FFI.
+    k = k if isinstance(k, PrimValue) else PrimValue(k)
+    return _ffi_api.eye_like(x, k, dtype)  # type: ignore
+
+
 def arange(
     start: Union[PrimExprLike, PrimValue],
     end: Optional[Union[PrimExprLike, PrimValue]] = None,
diff --git a/python/tvm/relax/op/manipulate.py 
b/python/tvm/relax/op/manipulate.py
index 1673a79b08..3210cc8216 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -550,3 +550,47 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, 
reduction: str = "updat
 
     """
     return _ffi_api.scatter_nd(data, indices, updates, reduction)  # type: 
ignore
+
+
+def one_hot(
+    indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1
+) -> Expr:
+    """Returns a one-hot tensor.
+
+    Parameters
+    ----------
+    indices : relax.Expr
+        The indices to set to `on_value`. Must have an integer dtype.
+
+    on_value : relax.PrimValue
+        The value to fill at `indices`. A plain Python scalar or PrimExpr is
+        wrapped into a ``PrimValue`` automatically.
+
+    off_value : relax.PrimValue
+        The value to fill at other locations. Must have the same dtype as
+        `on_value`; that dtype is also the dtype of the result.
+
+    depth : int
+        The depth of the one-hot dimension. Must be positive.
+
+    axis : int, optional
+        The output axis at which the one-hot dimension is inserted. Negative
+        values count from the end of the output; the default -1 appends a new
+        last dimension.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result, with one more dimension than `indices`.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        indices = [0, 1, 2]
+        depth = 3
+        on_value = 1
+        off_value = 0
+
+        one_hot(indices, on_value, off_value, depth) =
+            [[1, 0, 0],
+             [0, 1, 0],
+             [0, 0, 1]]
+    """
+    # Accept plain scalars / PrimExprs (as in the example above) by wrapping them.
+    on_value = on_value if isinstance(on_value, PrimValue) else PrimValue(on_value)
+    off_value = off_value if isinstance(off_value, PrimValue) else PrimValue(off_value)
+    return _ffi_api.one_hot(indices, on_value, off_value, depth, axis)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py 
b/python/tvm/relax/transform/legalize_ops/binary.py
index d28e100edb..41e317f1e0 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -48,7 +48,9 @@ register_legalize("relax.multiply", _binary(topi.multiply))
 register_legalize("relax.power", _binary(topi.power))
 register_legalize("relax.subtract", _binary(topi.subtract))
+register_legalize("relax.mod", _binary(topi.mod))
+register_legalize("relax.floor_mod", _binary(topi.floor_mod))
 register_legalize("relax.equal", _binary(topi.equal))
 
 register_legalize("relax.greater", _binary(topi.greater))
 register_legalize("relax.greater_equal", _binary(topi.greater_equal))
 register_legalize("relax.less", _binary(topi.less))
diff --git a/python/tvm/relax/transform/legalize_ops/create.py 
b/python/tvm/relax/transform/legalize_ops/create.py
index 1b022672d0..8bf85e34de 100644
--- a/python/tvm/relax/transform/legalize_ops/create.py
+++ b/python/tvm/relax/transform/legalize_ops/create.py
@@ -70,6 +70,36 @@ register_legalize("relax.tril", _tril_triu(is_upper=False, 
primfunc_name="tril")
 register_legalize("relax.triu", _tril_triu(is_upper=True, 
primfunc_name="triu"))
 
 
+def _eye(is_like: bool, primfunc_name: str) -> LegalizeFunc:
+    """Build the legalization function for ``relax.eye`` / ``relax.eye_like``.
+
+    Parameters
+    ----------
+    is_like : bool
+        True for ``relax.eye_like`` (args: ``x, k``), False for ``relax.eye``
+        (args: ``n, m, k``).
+    primfunc_name : str
+        Name hint for the generated PrimFunc.
+    """
+
+    def _to_scalar(expr):
+        # Lower a PrimValue argument to a native Python scalar / PrimExpr for topi.
+        return _try_convert_to_scalar_const(expr, python_native=True)
+
+    def eye_call_te(bb: BlockBuilder, call: Call) -> Expr:
+        if is_like:
+            x = call.args[0]
+            k = _to_scalar(call.args[1]) if len(call.args) > 1 else 0
+            n, m = x.struct_info.shape
+        else:
+            n = _to_scalar(call.args[0])
+            m = _to_scalar(call.args[1]) if len(call.args) > 1 else n
+            k = _to_scalar(call.args[2]) if len(call.args) > 2 else 0
+        # Use the inferred output dtype: for eye_like it honors an explicit dtype
+        # attribute and only falls back to the input's dtype when that is unset.
+        dtype = call.struct_info.dtype
+
+        return bb.call_te(
+            topi.eye,
+            n,
+            m,
+            k,
+            dtype,
+            primfunc_name_hint=primfunc_name,
+        )
+
+    return eye_call_te
+
+
+register_legalize("relax.eye", _eye(is_like=False, primfunc_name="eye"))
+register_legalize("relax.eye_like", _eye(is_like=True, primfunc_name="eye_like"))
+
+
 @register_legalize("relax.arange")
 def _arange(bb: BlockBuilder, call: Call) -> Expr:
     assert len(call.args) == 3
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py 
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 105d763403..163085a07c 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -185,6 +185,25 @@ def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.one_hot")
+def _one_hot(bb: BlockBuilder, call: Call) -> Expr:
+    """Legalize ``relax.one_hot`` into a call to ``topi.one_hot``.
+
+    The on/off values must be ``relax.PrimValue`` nodes of the same dtype; that
+    dtype is passed to ``topi.one_hot`` as the output dtype.
+    """
+    indices, on_value, off_value = call.args
+    if not (isinstance(on_value, relax.PrimValue) and isinstance(off_value, 
relax.PrimValue)):
+        raise ValueError("on_value and off_value must be PrimValue")
+    # Unwrap the PrimValue nodes so topi.one_hot receives the underlying PrimExprs.
+    on_value, off_value = on_value.value, off_value.value
+    if on_value.dtype != off_value.dtype:
+        raise ValueError("on_value and off_value must have the same dtype")
+    return bb.call_te(
+        topi.one_hot,
+        indices,
+        on_value,
+        off_value,
+        call.attrs.depth,
+        call.attrs.axis,
+        on_value.dtype,
+    )
+
+
 @register_legalize("relax.layout_transform")
 def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
     def te_layout_transform(data, name):
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index f7847e2af8..049345fcb1 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -85,10 +85,13 @@ from tvm.relax.op import (
     ewise_fma,
     exp,
     expand_dims,
+    eye,
+    eye_like,
     flatten,
     flip,
     floor,
     floor_divide,
+    floor_mod,
     full,
     full_like,
     grad,
@@ -119,6 +122,7 @@ from tvm.relax.op import (
     memory,
     min,
     minimum,
+    mod,
     multinomial_from_uniform,
     multiply,
     negative,
@@ -127,6 +131,7 @@ from tvm.relax.op import (
     null_value,
     ones,
     ones_like,
+    one_hot,
     permute_dims,
     power,
     print,
@@ -753,10 +758,13 @@ __all__ = [
     "exp",
     "expand_dims",
     "ext_dev",
+    "eye",
+    "eye_like",
     "flatten",
     "flip",
     "floor",
     "floor_divide",
+    "floor_mod",
     "full",
     "full_like",
     "func_attr",
@@ -795,6 +803,7 @@ __all__ = [
     "metal",
     "min",
     "minimum",
+    "mod",
     "multinomial_from_uniform",
     "multiply",
     "negative",
@@ -802,6 +811,7 @@ __all__ = [
     "null_value",
     "ones",
     "ones_like",
+    "one_hot",
     "opencl",
     "output",
     "permute_dims",
diff --git a/python/tvm/topi/tensor.py b/python/tvm/topi/tensor.py
index 31ebe86760..449c599dea 100644
--- a/python/tvm/topi/tensor.py
+++ b/python/tvm/topi/tensor.py
@@ -16,7 +16,11 @@
 # under the License.
 # pylint: 
disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
 """Elementwise operators"""
-from __future__ import absolute_import as _abs
+
+from typing import Optional
+
+from tvm import te
+
 from . import cpp
 
 
@@ -73,3 +77,32 @@ def full_like(x, fill_value):
         The result.
     """
     return cpp.full_like(x, fill_value)
+
+
+def eye(n: int, m: Optional[int] = None, k: int = 0, dtype: str = "float32") 
-> te.Tensor:
+    """Generate an identity matrix or a matrix with ones on the k-th diagonal.
+
+    Element ``(i, j)`` of the result is 1 when ``j - i == k`` and 0 otherwise.
+
+    Parameters
+    ----------
+    n : int
+        Number of rows. The Relax ``eye_like`` legalization may also pass a
+        symbolic shape dimension (PrimExpr) here.
+    m : int, optional
+        Number of columns. If None, defaults to n.
+    k : int, optional
+        Index of the diagonal. 0 (default) refers to the main diagonal.
+        A positive value refers to an upper diagonal, and a negative value
+        to a lower diagonal.
+    dtype : str, optional
+        Data type of the returned array.
+
+    Returns
+    -------
+    y : tvm.te.Tensor
+        The result, of shape ``(n, m)``.
+    """
+    m = m if m is not None else n
+    return te.compute(
+        (n, m),
+        # i == j - k selects the k-th diagonal (k > 0: above the main diagonal).
+        lambda i, j: te.if_then_else(i == j - k, te.const(1, dtype), 
te.const(0, dtype)),
+        name="eye",
+    )
diff --git a/src/relax/op/distributed/binary.cc 
b/src/relax/op/distributed/binary.cc
index 6ad71e0f85..1e7fa81727 100644
--- a/src/relax/op/distributed/binary.cc
+++ b/src/relax/op/distributed/binary.cc
@@ -42,6 +42,8 @@ 
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_divide);
 RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(multiply);
 RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(power);
 RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(subtract);
+// mod / floor_mod reuse the broadcast distributed struct-info inference of the ops above.
+RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(mod);
+RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_mod);
 
 /***************** Comparison operators *****************/
 
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index f1dc3d4904..bd4c681c79 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -181,6 +181,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
+// mod / floor_mod share the generic broadcasting registration of the arithmetic ops above.
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(mod);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_mod);
 
 /***************** Comparison operators *****************/
 
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 003bcb7e27..b66eb96f84 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -79,6 +79,12 @@ Expr power(Expr x1, Expr x2);
 /*! \brief Subtraction with numpy-style broadcasting. */
 Expr subtract(Expr x1, Expr x2);
 
+/*!
+ * \brief Modulo with numpy-style broadcasting.
+ * \note Distinct from floor_mod; presumably a truncated (C-style) remainder --
+ *       confirm against its legalization target (topi.mod).
+ */
+Expr mod(Expr x1, Expr x2);
+
+/*!
+ * \brief Floor modulo with numpy-style broadcasting.
+ * \note Legalized via topi.floor_mod.
+ */
+Expr floor_mod(Expr x1, Expr x2);
+
 /***************** Comparison operators *****************/
 
 /*! \brief Broadcasted element-wise test for (lhs == rhs). */
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index 7aca1470ae..8696d85f77 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -228,6 +228,90 @@ TVM_REGISTER_OP("relax.zeros_like")
     .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoOnesLikeZerosLike)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.eye & relax.eye_like */
+/*!
+ * n, m and k are passed as PrimValue call arguments rather than attributes;
+ * only the dtype is stored in InitAttrs.
+ */
+Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype) {
+  ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("relax.eye");
+  return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), 
{});
+}
+
+/*! A void dtype is resolved to x's dtype in InferStructInfoEyeLike. */
+Expr eye_like(Expr x, PrimValue k, DataType dtype) {
+  ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("relax.eye_like");
+  return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye);
+TVM_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like);
+
+/*! \brief Infer the (n, m) shaped output of relax.eye from its PrimValue arguments. */
+StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye op should have 3 arguments: n, m, and k, but got " << call->args.size()
+                     << " arguments");
+  }
+
+  auto get_prim_value = [&ctx](const Expr& expr, const std::string& key) {
+    if (!expr->IsInstance<PrimValueNode>()) {
+      ctx->ReportFatal(Diagnostic::Error(expr)
+                       << "Eye expects the `" << key << "` to be a PrimValue, but got "
+                       << expr->GetTypeKey());
+    }
+    return expr.as<PrimValueNode>()->value;
+  };
+
+  PrimExpr n = get_prim_value(call->args[0], "n");
+  PrimExpr m = get_prim_value(call->args[1], "m");
+  // k does not affect the output shape, but validate it like n and m so that a
+  // malformed call is rejected here rather than later in the pipeline.
+  get_prim_value(call->args[2], "k");
+
+  DataType dtype = call->attrs.as<InitAttrs>()->dtype;
+  return TensorStructInfo(ShapeExpr({n, m}), dtype);
+}
+
+/*! \brief Infer the output of relax.eye_like: x's 2-D shape with the resolved dtype. */
+StructInfo InferStructInfoEyeLike(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye_like op should have 2 arguments: x and k, but got "
+                     << call->args.size() << " arguments");
+  }
+
+  const auto* x_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  if (x_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye_like expects the input `x` to be a Tensor, but got "
+                     << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (x_sinfo->ndim != 2 && x_sinfo->ndim != kUnknownNDim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye_like expects the input tensor to be 2-dimensional, but got "
+                     << x_sinfo->ndim << " dimensions");
+  }
+
+  const auto* attrs = call->attrs.as<InitAttrs>();
+  // A void dtype attribute means "inherit the input's dtype".
+  DataType out_dtype = attrs->dtype.is_void() ? x_sinfo->dtype : attrs->dtype;
+
+  // The input's shape may be unknown (unknown ndim is accepted above); calling
+  // value() on an undefined Optional would abort, so fall back to a shape-less
+  // 2-D result in that case.
+  if (x_sinfo->shape.defined()) {
+    return TensorStructInfo(x_sinfo->shape.value(), out_dtype, x_sinfo->vdevice);
+  }
+  return TensorStructInfo(out_dtype, /*ndim=*/2, x_sinfo->vdevice);
+}
+
+// Output shape is (n, m), taken from the PrimValue arguments; dtype comes from InitAttrs.
+TVM_REGISTER_OP("relax.eye")
+    .set_attrs_type<InitAttrs>()
+    .set_num_inputs(3)
+    .add_argument("n", "PrimValue", "Number of rows in the output.")
+    .add_argument("m", "PrimValue", "Number of columns in the output.")
+    .add_argument("k", "PrimValue", "Index of the diagonal.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEye)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+// Output takes x's shape; dtype comes from InitAttrs, or from x when that is void.
+TVM_REGISTER_OP("relax.eye_like")
+    .set_attrs_type<InitAttrs>()
+    .set_num_inputs(2)
+    .add_argument("x", "Tensor", "The input tensor.")
+    .add_argument("k", "PrimValue", "Index of the diagonal.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEyeLike)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.arange */
 Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) {
   ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h
index 6e7c825523..d88336146d 100644
--- a/src/relax/op/tensor/create.h
+++ b/src/relax/op/tensor/create.h
@@ -72,12 +72,48 @@ Expr ones(Expr shape, DataType dtype);
  */
 Expr ones_like(Expr x, DataType dtype);
 
-/*! \brief Construct a tensor of all zeros, with the input shape and dtype. */
+/*!
+ * \brief Construct a tensor of all zeros, with the input shape and dtype.
+ * \param shape The shape of the created tensor.
+ * \param dtype The data type of the created tensor.
+ * \return The result tensor.
+ */
 Expr zeros(Expr shape, DataType dtype);
 
-/*! \brief Construct a tensor with all zeros, with shape of the input tensor 
shape. */
+/*!
+ * \brief Construct a tensor of all zeros with the same shape as the input tensor.
+ * \param x The input tensor, which provides the shape, and also the dtype
+ * when the `dtype` argument is void.
+ * \param dtype The data type of the created tensor. If it is
+ * void, the input tensor's dtype will be used.
+ * \return The result tensor.
+ */
 Expr zeros_like(Expr x, DataType dtype);
 
+/*!
+ * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere.
+ * \param n The number of rows in the output.
+ * \param m The number of columns in the output. It is required here; only the
+ *          Python wrapper `relax.op.eye` defaults it to n.
+ * \param k The index of the diagonal. A positive value refers to an upper diagonal,
+ *          a negative value to a lower diagonal, and 0 to the main diagonal.
+ * \param dtype The data type of the created tensor.
+ * \return The result tensor, of shape (n, m).
+ */
+Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype);
+
+/*!
+ * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere,
+ *        with the shape of the (2-D) input tensor.
+ * \param x The input tensor, which provides the shape, and also the dtype
+ * when the `dtype` argument is void.
+ * \param k The index of the diagonal. A positive value refers to an upper diagonal,
+ *          a negative value to a lower diagonal, and 0 to the main diagonal.
+ * \param dtype The data type of the created tensor. If it is
+ * void, the input tensor's dtype will be used.
+ * \return The result tensor.
+ */
+Expr eye_like(Expr x, PrimValue k, DataType dtype);
+
 /*! \brief Construct a tensor with evenly spaced elements. */
 Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype);
 
diff --git a/src/relax/op/tensor/manipulate.cc 
b/src/relax/op/tensor/manipulate.cc
index ca7d0a0945..ba44341302 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -30,6 +30,8 @@
 #include <utility>
 #include <vector>
 
+#include "tvm/runtime/data_type.h"
+
 namespace tvm {
 namespace relax {
 
@@ -1665,5 +1667,78 @@ TVM_REGISTER_OP("relax.scatter_nd")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.one_hot */
+TVM_REGISTER_NODE_TYPE(OneHotAttrs);
+/*!
+ * Build a relax.one_hot call. depth and axis are stored as attributes, while the
+ * on/off fill values are passed as PrimValue arguments.
+ */
+Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, 
int axis) {
+  ObjectPtr<OneHotAttrs> attrs = make_object<OneHotAttrs>();
+  attrs->depth = depth;
+  attrs->axis = axis;
+
+  // Check if on_value and off_value have the same dtype
+  DataType on_dtype = on_value->value->dtype;
+  DataType off_dtype = off_value->value->dtype;
+  ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have 
the same dtype, "
+                                << "but got " << on_dtype << " and " << 
off_dtype;
+
+  ICHECK(depth > 0) << "one_hot: depth must be positive, but got " << depth;
+
+  static const Op& op = Op::Get("relax.one_hot");
+  return Call(op, {indices, on_value, off_value}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot);
+
+/*!
+ * \brief Infer the output of relax.one_hot: the indices shape with a `depth`-sized
+ *        dimension inserted at `axis`, and the dtype of the on/off values.
+ */
+StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+  const auto* attrs = call->attrs.as<OneHotAttrs>();
+  // The fill values must be PrimValues. Report a diagnostic (instead of letting a
+  // Downcast abort) so malformed IR yields a user-facing error.
+  const auto* on_value = call->args[1].as<PrimValueNode>();
+  const auto* off_value = call->args[2].as<PrimValueNode>();
+  if (on_value == nullptr || off_value == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "one_hot expects on_value and off_value to be PrimValue, but got "
+                     << call->args[1]->GetTypeKey() << " and " << call->args[2]->GetTypeKey());
+  }
+  // on_value and off_value must share a dtype, which becomes the output dtype.
+  DataType dtype = on_value->value->dtype;
+  if (off_value->value->dtype != dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "one_hot: on_value and off_value must have the same dtype, "
+                     << "but got " << dtype << " and " << off_value->value->dtype);
+  }
+
+  // Check if indices has an integer dtype
+  if (indices_sinfo->IsUnknownDtype()) {
+    LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type.";
+  } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "one_hot op requires the input indices to have integer dtype. However, the "
+                        "given indices dtype is "
+                     << indices_sinfo->dtype);
+  }
+  // Check if indices has unknown dimension
+  if (indices_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(dtype, kUnknownNDim, indices_sinfo->vdevice);
+  }
+
+  // The one-hot dimension is inserted into an output of rank ndim + 1, so a
+  // negative axis is normalized against that rank. Validate it even when only
+  // the rank (not the concrete shape) of indices is known.
+  int output_ndim = indices_sinfo->ndim + 1;
+  int axis = attrs->axis < 0 ? attrs->axis + output_ndim : attrs->axis;
+  if (axis < 0 || axis >= output_ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "one_hot: axis must be in the range of [0, " << indices_sinfo->ndim << "], "
+                     << "but got " << axis);
+  }
+
+  // Get the shape of indices
+  const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
+  if (indices_shape == nullptr) {
+    return TensorStructInfo(dtype, output_ndim, indices_sinfo->vdevice);
+  }
+
+  Array<PrimExpr> output_shape = indices_shape->values;
+  output_shape.insert(output_shape.begin() + axis, attrs->depth);
+  return TensorStructInfo(ShapeExpr(output_shape), dtype, indices_sinfo->vdevice);
+}
+
+// depth/axis live in OneHotAttrs; the output dtype comes from the PrimValue fill values.
+TVM_REGISTER_OP("relax.one_hot")
+    .set_attrs_type<OneHotAttrs>()
+    .set_num_inputs(3)
+    .add_argument("indices", "Tensor", "The indices tensor.")
+    .add_argument("on_value", "PrimValue", "The value to fill at specified 
indices.")
+    .add_argument("off_value", "PrimValue", "The value to fill at other 
indices.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOneHot)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index e9fa1131e8..010ceb663e 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -27,6 +27,7 @@
 #include <tvm/relax/attrs/manipulate.h>
 
 #include "../op_common.h"
+#include "tvm/relax/expr.h"
 
 namespace tvm {
 namespace relax {
@@ -206,6 +207,17 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re
  */
 Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction);
 
+/*!
+ * \brief Returns a one-hot tensor.
+ * \param indices The indices to set to `on_value`; must have an integer dtype.
+ * \param on_value The value to fill at `indices`.
+ * \param off_value The value to fill at other locations.
+ * \param depth The size of the inserted one-hot dimension.
+ * \param axis Output position of the one-hot dimension; negative values count from the end.
+ * \return A tensor shaped like `indices` with a `depth`-sized dimension inserted at `axis`.
+ */
+Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 1b4c5d281a..46373510b1 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -63,8 +63,11 @@ def generate_random_inputs(
         if dtype == "bool":
             # random_value = np.random.choice(a=[False, True], size=shape)
             random_value = rg.choice(a=[False, True], size=shape)
+        elif dtype.startswith("int"):
+            # Keep non-zero values
+            random_value = rg.integers(low=-63, high=63, 
size=shape).astype(dtype)
+            random_value[random_value <= 0] -= 1
         else:
-            # random_value = np.random.normal(size=shape).astype(dtype)
             random_value = rg.standard_normal(size=shape).astype(dtype)
         input_values[i.name] = random_value
 
@@ -246,7 +249,6 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
     )
 
     model = helper.make_model(graph, producer_name="binary_test")
-    # NOTE: explicitly pass inputs to avoid numerical error
     check_correctness(model, opset=opset)
 
 
@@ -327,6 +329,16 @@ def test_binary(op_name: str):
     verify_binary_scalar(op_name)
 
 
+@pytest.mark.parametrize("int_mode", [True, False])
+def test_mod(int_mode: bool):
+    if int_mode:
+        dtype, fmod = TensorProto.INT32, 0
+    else:
+        dtype, fmod = TensorProto.FLOAT, 1  # ONNX Mod requires fmod=1 for floating-point inputs
+    verify_binary("Mod", [1, 32], [1, 32], [1, 32], attrs={"fmod": fmod}, dtype=dtype)
+    verify_binary_scalar("Mod", attrs={"fmod": fmod}, dtype=dtype)
+
+
 @pytest.mark.parametrize("num_inputs", [1, 2, 4])
 @pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"])
 def test_multi_input(op_name: str, num_inputs: int):
@@ -430,6 +442,7 @@ def test_bitwise_shift(direction: str):
         "Sigmoid",
         "Softmax",
         "LogSoftmax",
+        "Hardmax",
         "Identity",
     ],
 )
@@ -445,7 +458,7 @@ def test_unary(op_name: str):
         output_dtype = TensorProto.BOOL
     else:
         output_dtype = TensorProto.FLOAT
-    verify_unary(op_name, [32, 32], input_dtype=input_dtype, 
output_dtype=output_dtype)
+    verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, 
output_dtype=output_dtype)
 
 
 @pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, 
TensorProto.FLOAT16])
@@ -567,6 +580,11 @@ def test_size():
     check_correctness(model)
 
 
+@pytest.mark.parametrize("k", [-1, 0, 1])
+def test_eye_like(k: int):
+    verify_unary("EyeLike", [32, 32], attrs={"k": k})  # k: diagonal offset (0 = main diagonal)
+
+
 @pytest.mark.parametrize("alpha", [None, 0.25, 1.0])
 @pytest.mark.parametrize("beta", [None, 0.35, 1.0])
 @pytest.mark.parametrize("useC", [False, True])
@@ -966,7 +984,7 @@ def test_cumsum1():
     )
 
     model = helper.make_model(graph, producer_name="cumsum_graph")
-    check_correctness(model)
+    check_correctness(model, inputs={"axis": np.array([0], dtype=np.int32)})
 
 
 @pytest.mark.parametrize("axis", [[0, 2], None])
diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py
index 1e895169f6..67f3470191 100644
--- a/tests/python/relax/test_op_create.py
+++ b/tests/python/relax/test_op_create.py
@@ -545,6 +545,64 @@ def test_ones_like_zeros_like_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.zeros_like(x1))
 
 
+def test_eye_infer_struct_info():
+    bb = relax.BlockBuilder()  # k shifts the diagonal but never changes the output shape
+
+    _check_inference(bb, relax.op.eye(3), relax.TensorStructInfo((3, 3), "float32"))
+    _check_inference(bb, relax.op.eye(2, 4), relax.TensorStructInfo((2, 4), "float32"))
+    _check_inference(bb, relax.op.eye(3, dtype="int64"), relax.TensorStructInfo((3, 3), "int64"))
+    _check_inference(bb, relax.op.eye(3, 5, k=1), relax.TensorStructInfo((3, 5), "float32"))
+    _check_inference(bb, relax.op.eye(3, 5, k=-2), relax.TensorStructInfo((3, 5), "float32"))
+
+
+def test_eye_infer_struct_info_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    m = tir.Var("m", "int64")
+    k = tir.Var("k", "int64")  # a symbolic diagonal offset must not affect the shape
+
+    _check_inference(bb, relax.op.eye(n), relax.TensorStructInfo((n, n), "float32"))
+    _check_inference(bb, relax.op.eye(n, m), relax.TensorStructInfo((n, m), "float32"))
+    _check_inference(bb, relax.op.eye(n, k=k), relax.TensorStructInfo((n, n), "float32"))
+
+
+def test_eye_like_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 5), "int64"))
+    x2 = relax.Var("x", R.Tensor((3, 3)))  # no dtype: the result dtype stays unknown
+
+    _check_inference(bb, relax.op.eye_like(x0), relax.TensorStructInfo((3, 4), "float32"))
+    _check_inference(bb, relax.op.eye_like(x1), relax.TensorStructInfo((2, 5), "int64"))
+    _check_inference(bb, relax.op.eye_like(x2), relax.TensorStructInfo((3, 3), dtype=""))
+    _check_inference(bb, relax.op.eye_like(x0, k=1), relax.TensorStructInfo((3, 4), "float32"))
+    _check_inference(
+        bb, relax.op.eye_like(x1, dtype="float32"), relax.TensorStructInfo((2, 5), "float32")
+    )
+
+
+def test_eye_like_infer_struct_info_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    m = tir.Var("m", "int64")
+    x = relax.Var("x", R.Tensor((n, m), "float32"))
+    k = tir.Var("k", "int64")
+
+    _check_inference(bb, relax.op.eye_like(x), relax.TensorStructInfo((n, m), "float32"))
+    _check_inference(bb, relax.op.eye_like(x, k=k), relax.TensorStructInfo((n, m), "float32"))
+
+
+def test_eye_like_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
+
+    with pytest.raises(TVMError):  # x0 carries ShapeStructInfo, not a tensor
+        bb.normalize(relax.op.eye_like(x0))
+    with pytest.raises(TVMError):  # x1 carries FuncStructInfo, not a tensor
+        bb.normalize(relax.op.eye_like(x1))
+
+
 def test_arange_infer_struct_info():
     bb = relax.BlockBuilder()
 
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index e958b03e4c..f6aefc8591 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -3377,5 +3377,57 @@ def test_scatter_nd_infer_struct_info():
     )
 
 
+def test_one_hot_infer_struct_info():
+    bb = relax.BlockBuilder()
+
+    # Test case 1: Basic usage
+    i0 = relax.Var("indices", R.Tensor((3,), "int32"))
+    _check_inference(
+        bb,
+        relax.op.one_hot(i0, relax.PrimValue(1.0), relax.PrimValue(0.0), 5),
+        relax.TensorStructInfo((3, 5), "float32"),
+    )
+
+    # Test case 2: With specified axis (depth is inserted at position 1)
+    i1 = relax.Var("indices", R.Tensor((2, 2), "int32"))
+    _check_inference(
+        bb,
+        relax.op.one_hot(i1, relax.PrimValue(1), relax.PrimValue(0), 3, axis=1),
+        relax.TensorStructInfo((2, 3, 2), "int64"),
+    )
+
+    # Test case 3: With symbolic shape
+    n = tir.Var("n", "int64")
+    i2 = relax.Var("indices", R.Tensor((n,), "int32"))
+    _check_inference(
+        bb,
+        relax.op.one_hot(i2, relax.PrimValue(1.0), relax.PrimValue(0.0), 4),
+        relax.TensorStructInfo((n, 4), "float32"),
+    )
+
+    # Test case 4: With unknown shape
+    i3 = relax.Var("indices", R.Tensor("int32"))
+    _check_inference(
+        bb,
+        relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0.0), 6),
+        relax.TensorStructInfo(dtype="float32"),
+    )
+
+    # Test case 5: With different on_value and off_value dtypes
+    i3 = relax.Var("indices", R.Tensor((2, 3), "int32"))  # NOTE(review): rebinds i3 from case 4
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0), 5))
+
+    # Test case 6: With invalid indices dtype
+    i4 = relax.Var("indices", R.Tensor((2, 3), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.one_hot(i4, relax.PrimValue(1.0), relax.PrimValue(0.0), 5))
+
+    # Test case 7: With invalid depth
+    i5 = relax.Var("indices", R.Tensor((2, 3), "int32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.one_hot(i5, relax.PrimValue(1.0), relax.PrimValue(0.0), -1))
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to