This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new dc1e580e61 [Unity][Relax] Add masked_fill operator (#15077)
dc1e580e61 is described below

commit dc1e580e61f30f1ae8e2688658e2467928a6e516
Author: Valery Chernov <[email protected]>
AuthorDate: Tue Jun 13 02:25:31 2023 +0400

    [Unity][Relax] Add masked_fill operator (#15077)
    
    add masked_fill op to relax
    
    Co-authored-by: Valery Chernov <[email protected]>
---
 python/tvm/relax/op/__init__.py              |  1 +
 python/tvm/relax/op/{__init__.py => mask.py} | 54 +++++++++++-----------------
 2 files changed, 21 insertions(+), 34 deletions(-)

diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index d9af245d79..b8b5b5f22e 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -25,6 +25,7 @@ from .datatype import *
 from .index import *
 from .linear_algebra import *
 from .manipulate import *
+from .mask import *
 from .op_attrs import *
 from .statistical import *
 from .search import *
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/mask.py
similarity index 51%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/mask.py
index d9af245d79..4fc94b9cf4 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/mask.py
@@ -14,39 +14,25 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""Operators with mask."""
+from . import _ffi_api
+from ..expr import Expr
 
-# Operators
-from .base import *
-from .binary import *
-from .create import *
-from .datatype import *
-from .index import *
-from .linear_algebra import *
-from .manipulate import *
-from .op_attrs import *
-from .statistical import *
-from .search import *
-from .set import *
-from .ternary import *
-from .unary import *
-from . import builtin
-from . import grad
-from . import image
-from . import memory
-from . import nn
 
-# Operator gradient functions
-from . import _op_gradient
-
-
-def _register_op_make():
-    # pylint: disable=import-outside-toplevel
-    from . import _ffi_api
-    from .. import expr
-
-    expr._op_ffi_api = _ffi_api  # type: ignore
-
-
-_register_op_make()
+def masked_fill(x: Expr, mask: Expr, value: Expr):
+    """Fill a tensor with a specified value at the positions selected by a mask.
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data to the operator.
+    mask : relax.Expr
+        The condition: where it is true the result takes ``value``, else ``x``.
+    value : relax.Expr
+        The fill value; ``x`` itself is not modified.
+    Returns
+    -------
+    result : relax.Expr
+        A new tensor equal to ``x`` except at masked positions, which hold ``value``.
+    """
+    values = _ffi_api.full_like(x, value, None)  # type: ignore  # 3rd arg: dtype=None
+    return _ffi_api.where(mask, values, x)  # type: ignore

Reply via email to