This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new dc1e580e61 [Unity][Relax] Add masked_fill operator (#15077)
dc1e580e61 is described below
commit dc1e580e61f30f1ae8e2688658e2467928a6e516
Author: Valery Chernov <[email protected]>
AuthorDate: Tue Jun 13 02:25:31 2023 +0400
[Unity][Relax] Add masked_fill operator (#15077)
add masked_fill op to relax
Co-authored-by: Valery Chernov <[email protected]>
---
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/{__init__.py => mask.py} | 54 +++++++++++-----------------
2 files changed, 21 insertions(+), 34 deletions(-)
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index d9af245d79..b8b5b5f22e 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -25,6 +25,7 @@ from .datatype import *
from .index import *
from .linear_algebra import *
from .manipulate import *
+from .mask import *
from .op_attrs import *
from .statistical import *
from .search import *
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/mask.py
similarity index 51%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/mask.py
index d9af245d79..4fc94b9cf4 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/mask.py
@@ -14,39 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""Operators with mask."""
+from . import _ffi_api
+from ..expr import Expr
-# Operators
-from .base import *
-from .binary import *
-from .create import *
-from .datatype import *
-from .index import *
-from .linear_algebra import *
-from .manipulate import *
-from .op_attrs import *
-from .statistical import *
-from .search import *
-from .set import *
-from .ternary import *
-from .unary import *
-from . import builtin
-from . import grad
-from . import image
-from . import memory
-from . import nn
-# Operator gradient functions
-from . import _op_gradient
-
-
-def _register_op_make():
- # pylint: disable=import-outside-toplevel
- from . import _ffi_api
- from .. import expr
-
- expr._op_ffi_api = _ffi_api # type: ignore
-
-
-_register_op_make()
+def masked_fill(x: Expr, mask: Expr, value: Expr):
+ """Fill a tensor with a specified value at the positions selected by a mask.
+
+ The operator is composed from two existing operators: a tensor is built
+ with ``full_like(x, value)`` and then combined with the input through
+ ``where(mask, filled, x)``.
+
+ Parameters
+ ----------
+ x : relax.Expr
+ The input data to the operator. It is also passed to ``full_like``
+ as the reference tensor and to ``where`` as the last argument.
+
+ mask : relax.Expr
+ The mask, passed as the first argument of ``where``. Presumably the
+ entries that evaluate to true mark the positions that receive
+ ``value`` -- confirm against the semantics of ``where``.
+
+ value : relax.Expr
+ The value to set in the input tensor, passed to ``full_like`` as the
+ fill value.
+
+ Returns
+ -------
+ result : relax.Expr
+ The filled tensor, i.e. the expression returned by ``where``.
+ """
+ # NOTE(review): the FFI entry point is called directly with two arguments.
+ # If the registered ``full_like`` also expects a dtype argument (as the
+ # public Python wrapper may forward), this call needs it too -- confirm.
+ values = _ffi_api.full_like(x, value) # type: ignore
+ # Combine the filled tensor with the original input according to the mask.
+ return _ffi_api.where(mask, values, x) # type: ignore