This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 38e726aab1 [Relax][PyTorch] Cleanup unary op converters (#17356)
38e726aab1 is described below

commit 38e726aab191d5c16a7d98b2191a5f97f7fef410
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Thu Sep 12 04:18:07 2024 +0900

    [Relax][PyTorch] Cleanup unary op converters (#17356)
    
    * classify into 9 types of ops
    
    * introduce `_unary_op()`
    
    * cleanup `_clamp()`
    
    * cleanup `_gelu()`
    
    * cleanup `_hardsigmoid()` and `_hardswish()`
    
    * cleanup `_leakyrelu()`
    
    * cleanup `_log_softmax()`
    
    * cleanup `_round()`
    
    * cleanup `_softmax()`
    
    * cleanup `_tril_triu()`
    
    * replace `fx.node.Node` with `fx.Node`
---
 python/tvm/relax/frontend/torch/fx_translator.py | 566 ++++++++++++-----------
 1 file changed, 288 insertions(+), 278 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index aed38d7c49..8d66343254 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -35,7 +35,7 @@ class TorchFXImporter:
         import torch  # type: ignore
         from torch import fx
 
-        self.env: Dict[fx.node.Node, relax.Expr] = {}
+        self.env: Dict[fx.Node, relax.Expr] = {}
         self.params: Dict[torch.Tensor, relax.Expr] = {}
         self.named_modules: Dict[str, torch.Module] = None
         self.block_builder: relax.BlockBuilder = None
@@ -108,7 +108,7 @@ class TorchFXImporter:
     def _retrieve_args(self, node):
         from torch import fx
 
-        if isinstance(node, fx.node.Node):
+        if isinstance(node, fx.Node):
             return self.env[node]
         elif isinstance(node, tuple):
             return tuple(self._retrieve_args(x) for x in node)
@@ -136,33 +136,113 @@ class TorchFXImporter:
         lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs)
         return self.block_builder.emit(op(lhs, rhs))
 
-    ########## Arithmetic ##########
+    ########## Unary Ops ##########
 
-    def _exp(self, node: fx.node.Node) -> relax.Var:
-        return self.block_builder.emit(relax.op.exp(self.env[node.args[0]]))
+    def _unary_op(self, op: Callable) -> Callable:
+        from torch import fx
 
-    def _sigmoid(self, node: fx.node.Node) -> relax.Var:
-        return 
self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))
+        def convert(node: fx.Node) -> relax.Var:
+            return self.block_builder.emit(op(self.env[node.args[0]]))
 
-    def _sqrt(self, node: fx.node.Node) -> relax.Expr:
-        arg = self.env[node.args[0]]
-        if isinstance(arg, (int, float)):
-            arg = relax.const(arg, "float32")
-        return self.block_builder.emit(relax.op.sqrt(arg))
+        return convert
 
-    def _rsqrt(self, node: fx.node.Node) -> relax.Expr:
-        arg = self.env[node.args[0]]
-        if isinstance(arg, (int, float)):
-            arg = relax.const(arg, "float32")
-        return self.block_builder.emit(relax.op.rsqrt(arg))
+    def _clamp(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
+        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+        if not isinstance(a_min, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant min value for torch.clamp/clip, "
+                f"but got {a_min} with type {type(a_min)}"
+            )
+        if not isinstance(a_max, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant max value for torch.clamp/clip, "
+                f"but got {a_max} with type {type(a_max)}"
+            )
+        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+
+    def _gelu(self, node: fx.Node) -> relax.Expr:
+        approximate = node.kwargs.get("approximate", "none")
+        if approximate == "none":
+            return 
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
+        elif approximate == "tanh":
+            return 
self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
+        else:
+            raise KeyError("Unregonized approximate algorithm for gelu: 
{}.".format(approximate))
+
+    def _hardsigmoid(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, 
dtype)))
+
+    def _hardswish(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        x2 = relax.op.divide(x1, relax.const(6, dtype))
+        return self.block_builder.emit(relax.op.multiply(x, x2))
+
+    def _leakyrelu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = node.args[1] if len(node.args) > 1 else 
node.kwargs.get("negative_slope", 0.01)
+        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
+
+    def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        alpha = module.negative_slope
+        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
+
+    def _log_softmax(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 
-1)
+        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
+
+    def _log_softmax_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        dim = module.dim
+        assert dim is not None
+        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
 
-    def _round(self, node: fx.node.Node) -> relax.Expr:
-        if "decimals" in node.kwargs and node.kwargs["decimals"] != 0:
+    def _round(self, node: fx.Node) -> relax.Expr:
+        if node.kwargs.get("decimals", 0) != 0:
             raise ValueError("specifying decimals for round is not supported 
yet")
         arg = self.env[node.args[0]]
         return self.block_builder.emit(relax.op.round(arg))
 
-    def _add(self, node: fx.node.Node) -> relax.Expr:
+    def _softmax(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 
-1)
+        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+    def _softmax_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        dim = module.dim
+        assert dim is not None
+        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+    def _tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else 
node.kwargs.get("diagonal", 0)
+            assert isinstance(k, int)
+            return self.block_builder.emit(op(x, k))
+
+        return convert
+
+    ########## Arithmetic ##########
+
+    def _add(self, node: fx.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
             return self._call_binary_op(relax.op.add, lhs, rhs)
@@ -176,103 +256,54 @@ class TorchFXImporter:
             )
         return lhs + rhs
 
-    def _max(self, node: fx.node.Node) -> relax.Expr:
+    def _max(self, node: fx.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
             return self._call_binary_op(relax.op.maximum, lhs, rhs)
 
-    def _floordiv(self, node: fx.node.Node) -> relax.Expr:
+    def _floordiv(self, node: fx.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
             return self._call_binary_op(relax.op.floor_divide, lhs, rhs)
         return lhs // rhs
 
-    def _mul(self, node: fx.node.Node) -> relax.Expr:
+    def _mul(self, node: fx.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
             return self._call_binary_op(relax.op.multiply, lhs, rhs)
         return lhs * rhs
 
-    def _pow(self, node: fx.node.Node) -> relax.Expr:
+    def _pow(self, node: fx.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
             return self._call_binary_op(relax.op.power, lhs, rhs)
         return lhs**rhs
 
-    def _neg(self, node: fx.node.Node) -> relax.Expr:
-        x = self.env[node.args[0]]
-        return self.block_builder.emit(relax.op.negative(x))
-
-    def _sub(self, node: fx.node.Node) -> relax.Expr:
+    def _sub(self, node: fx.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
             return self._call_binary_op(relax.op.subtract, lhs, rhs)
         return lhs - rhs
 
-    def _truediv(self, node: fx.node.Node) -> relax.Expr:
+    def _truediv(self, node: fx.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
             return self._call_binary_op(relax.op.divide, lhs, rhs)
         return lhs / rhs
 
-    def _clamp(self, node: fx.node.Node) -> relax.Expr:
-        args = self.retrieve_args(node)
-        a_min = node.kwargs["min"]
-        a_max = node.kwargs["max"]
-        if not isinstance(a_min, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant min value for torch.clamp/clip, "
-                f"but got {a_min} with type {type(a_min)}"
-            )
-        if not isinstance(a_max, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant max value for torch.clamp/clip, "
-                f"but got {a_max} with type {type(a_max)}"
-            )
-        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
-
-    def _gelu(self, node: fx.node.Node) -> relax.Expr:
-        if "approximate" not in node.kwargs:
-            approximate = "none"
-        else:
-            approximate = node.kwargs["approximate"]
-        if approximate == "none":
-            return 
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
-        elif approximate == "tanh":
-            return 
self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
-        else:
-            raise KeyError("Unregonized approximate algorithm for gelu: 
{}.".format(approximate))
-
-    def _hardsigmoid(self, node: fx.node.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dtype = x.struct_info.dtype
-        x0 = relax.op.add(x, relax.const(3, dtype))
-        x1 = relax.op.clip(x0, 0, 6)
-        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, 
dtype)))
-
-    def _hardswish(self, node: fx.node.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dtype = x.struct_info.dtype
-        x0 = relax.op.add(x, relax.const(3, dtype))
-        x1 = relax.op.clip(x0, 0, 6)
-        x2 = relax.op.divide(x1, relax.const(6, dtype))
-        return self.block_builder.emit(relax.op.multiply(x, x2))
-
     ########## Compare ##########
 
-    def _lt(self, node: fx.node.Node) -> relax.Expr:
+    def _lt(self, node: fx.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         return self._call_binary_op(relax.op.less, lhs, rhs)
 
-    def _eq(self, node: fx.node.Node) -> relax.Expr:
+    def _eq(self, node: fx.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         return self._call_binary_op(relax.op.equal, lhs, rhs)
 
     ########## Creation ##########
 
-    def _arange(self, node: fx.node.Node) -> relax.Var:
+    def _arange(self, node: fx.Node) -> relax.Var:
         import torch
 
         start_end_step = [None, None, None]
@@ -311,15 +342,15 @@ class TorchFXImporter:
         else:
             dtype = "int64"
         start_end_step = [
-            self.env[x] if isinstance(x, torch.fx.node.Node) else x for x in 
start_end_step
+            self.env[x] if isinstance(x, torch.fx.Node) else x for x in 
start_end_step
         ]
         return self.block_builder.emit(relax.op.arange(*start_end_step, 
dtype=dtype))
 
-    def _empty(self, node: fx.node.Node) -> relax.Var:
+    def _empty(self, node: fx.Node) -> relax.Var:
         dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), 
self.env)
         return self.block_builder.emit(relax.op.zeros(node.args, dtype))
 
-    def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
+    def _inplace_fill(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
         dtype = x.struct_info.dtype
@@ -328,7 +359,7 @@ class TorchFXImporter:
         self.env[node.args[0]] = filled
         return filled
 
-    def _tensor(self, node: fx.node.Node) -> relax.Var:
+    def _tensor(self, node: fx.Node) -> relax.Var:
         dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
         if isinstance(node.args[0], float):
             return relax.const(node.args[0], dtype if dtype is not None else 
"float32")
@@ -336,21 +367,10 @@ class TorchFXImporter:
             return relax.const(node.args[0], dtype if dtype is not None else 
"int64")
         raise ValueError("torch.tensor with value not a float or int is not 
accepted")
 
-    def _tril_triu(self, op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.node.Node) -> relax.Var:
-            x = self.env[node.args[0]]
-            k = node.args[1] if len(node.args) > 1 else 0
-            assert isinstance(k, int)
-            return self.block_builder.emit(op(x, k))
-
-        return convert
-
     def _inplace_tril_triu(self, op: Callable) -> Callable:
         from torch import fx
 
-        def convert(node: fx.node.Node) -> relax.Var:
+        def convert(node: fx.Node) -> relax.Var:
             x = self.env[node.args[0]]
             k = node.args[1] if len(node.args) > 1 else 0
             assert isinstance(k, int)
@@ -361,7 +381,7 @@ class TorchFXImporter:
 
         return convert
 
-    def _new_ones(self, node: fx.node.Node) -> relax.Var:
+    def _new_ones(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         self_var = args[0]
         size = args[1:]
@@ -376,7 +396,7 @@ class TorchFXImporter:
             )
         )
 
-    def _ones(self, node: fx.node.Node) -> relax.Var:
+    def _ones(self, node: fx.Node) -> relax.Var:
         import torch
 
         args = self.retrieve_args(node)
@@ -397,7 +417,7 @@ class TorchFXImporter:
             )
         )
 
-    def _full(self, node: fx.node.Node) -> relax.Var:
+    def _full(self, node: fx.Node) -> relax.Var:
         import torch
 
         args = self.retrieve_args(node)
@@ -421,14 +441,14 @@ class TorchFXImporter:
 
     ########## Statistical ##########
 
-    def _sum(self, node: fx.node.Node) -> relax.Var:
+    def _sum(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
         if len(args) == 1:
             return self.block_builder.emit(relax.op.sum(args[0], 
keepdims=keepdim))
         return self.block_builder.emit(relax.op.sum(args[0], args[1]))
 
-    def _mean(self, node: fx.node.Node) -> relax.Var:
+    def _mean(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
         if len(args) == 1:
@@ -437,18 +457,18 @@ class TorchFXImporter:
 
     ########## DataType ##########
 
-    def _float(self, node: fx.node.Node) -> relax.Var:
+    def _float(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"float32"))
 
-    def _half(self, node: fx.node.Node) -> relax.Var:
+    def _half(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"float16"))
 
-    def _type(self, node: fx.node.Node) -> relax.Var:
+    def _type(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
         return self.block_builder.emit(relax.op.astype(x, dtype))
 
-    def _to(self, node: fx.node.Node) -> relax.Var:
+    def _to(self, node: fx.Node) -> relax.Var:
         import torch
 
         x = self.env[node.args[0]]
@@ -466,7 +486,7 @@ class TorchFXImporter:
     def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
         return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, 
out_dtype="float32"))
 
-    def _matmul(self, node: fx.node.Node) -> relax.Var:
+    def _matmul(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         res = self._matmul_impl(
             args[0],
@@ -474,7 +494,7 @@ class TorchFXImporter:
         )
         return res
 
-    def _addmm(self, node: fx.node.Node) -> relax.Var:
+    def _addmm(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         y = self.env[node.args[1]]
         z = self.env[node.args[2]]
@@ -496,7 +516,7 @@ class TorchFXImporter:
             res = bias if res is None else 
self.block_builder.emit(relax.op.add(bias, res))
         return res
 
-    def _baddbmm(self, node: fx.node.Node) -> relax.Var:
+    def _baddbmm(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         a = self.env[node.args[1]]
         b = self.env[node.args[2]]
@@ -518,7 +538,7 @@ class TorchFXImporter:
             res = bias if res is None else 
self.block_builder.emit(relax.op.add(res, bias))
         return res
 
-    def _einsum(self, node: fx.node.Node) -> relax.Var:
+    def _einsum(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
         args = self.retrieve_args(node)
@@ -526,7 +546,7 @@ class TorchFXImporter:
             return self.block_builder.emit(relax.op.einsum(tuple(args[1]), 
args[0]))
         return self.block_builder.emit(relax.op.einsum(args[1:], args[0]))
 
-    def _unbind(self, node: fx.node.Node) -> relax.Var:
+    def _unbind(self, node: fx.Node) -> relax.Var:
         if len(node.args) == 2:
             assert isinstance(node.args[1], int), "Expected 2nd argument of 
unbind as int"
             dim = node.args[1]
@@ -544,12 +564,12 @@ class TorchFXImporter:
 
     ########## Manipulation ##########
 
-    def _cat(self, node: fx.node.Node) -> relax.Var:
+    def _cat(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
         return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
 
-    def _expand(self, node: fx.node.Node) -> relax.Var:
+    def _expand(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         broadcast_shape, in_shape = [], self.shape_of(args[0])
         for idx, i in enumerate(args[1:]):
@@ -559,7 +579,7 @@ class TorchFXImporter:
                 broadcast_shape.append(i)
         return self.block_builder.emit(relax.op.broadcast_to(args[0], 
broadcast_shape))
 
-    def _flatten(self, node: fx.node.Node) -> relax.Var:
+    def _flatten(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         if node.target in self.named_modules:
             module = self.named_modules[node.target]
@@ -579,7 +599,7 @@ class TorchFXImporter:
         )
         return self.block_builder.emit(relax.op.reshape(x, new_shape))
 
-    def _permute(self, node: fx.node.Node) -> relax.Var:
+    def _permute(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
         args = self.retrieve_args(node)
@@ -587,7 +607,7 @@ class TorchFXImporter:
             return self.block_builder.emit(relax.op.permute_dims(args[0], 
tuple(args[1])))
         return self.block_builder.emit(relax.op.permute_dims(args[0], 
args[1:]))
 
-    def _reshape(self, node: fx.node.Node) -> relax.Var:
+    def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
         args = self.retrieve_args(node)
@@ -595,7 +615,7 @@ class TorchFXImporter:
             return self.block_builder.emit(relax.op.reshape(args[0], 
tuple(args[1])))
         return self.block_builder.emit(relax.op.reshape(args[0], args[1:]))
 
-    def _split(self, node: fx.node.Node) -> relax.Var:
+    def _split(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         split_size = node.args[1]
         if "dim" in node.kwargs:
@@ -611,7 +631,7 @@ class TorchFXImporter:
             n_section = (self.shape_of(x)[dim].value + split_size - 1) // 
split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
-    def _chunk(self, node: fx.node.Node) -> relax.Var:
+    def _chunk(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         chunks = node.args[1]
 
@@ -623,13 +643,13 @@ class TorchFXImporter:
             dim = 0
         return self.block_builder.emit(relax.op.split(x, chunks, dim))
 
-    def _transpose(self, node: fx.node.Node) -> relax.Var:
+    def _transpose(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         full_idx = list(range(len(self.shape_of(args[0]))))
         full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], 
full_idx[args[1]]
         return self.block_builder.emit(relax.op.permute_dims(args[0], 
full_idx))
 
-    def _squeeze(self, node: fx.node.Node) -> relax.Var:
+    def _squeeze(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
 
         if "dim" in node.kwargs:
@@ -640,7 +660,7 @@ class TorchFXImporter:
             dim = None
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
-    def _repeat(self, node: fx.node.Node) -> relax.Var:
+    def _repeat(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
         args = self.retrieve_args(node)
@@ -648,7 +668,7 @@ class TorchFXImporter:
             return self.block_builder.emit(relax.op.tile(args[0], 
tuple(args[1])))
         return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
 
-    def _tile(self, node: fx.node.Node) -> relax.Var:
+    def _tile(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
         args = self.retrieve_args(node)
@@ -656,7 +676,7 @@ class TorchFXImporter:
             return self.block_builder.emit(relax.op.tile(args[0], 
tuple(args[1])))
         return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
 
-    def _cumsum(self, node: fx.node.Node) -> relax.Var:
+    def _cumsum(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
 
         if "dim" in node.kwargs:
@@ -674,13 +694,13 @@ class TorchFXImporter:
 
         return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
 
-    def _index_select(self, node: fx.node.Node) -> relax.Var:
+    def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1]
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.take(x, index, dim))
 
-    def _masked_fill(self, node: fx.node.Node) -> relax.Var:
+    def _masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
         value = node.args[2]
@@ -688,7 +708,7 @@ class TorchFXImporter:
         values = self.block_builder.emit(relax.op.full_like(x, rx_value))
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
-    def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var:
+    def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
         value = node.args[2]
@@ -703,7 +723,7 @@ class TorchFXImporter:
     def _argmax_argmin(self, op: Callable) -> Callable:
         from torch import fx
 
-        def convert(node: fx.node.Node):
+        def convert(node: fx.Node):
             x = self.env[node.args[0]]
             dim = None
             keepdims = False
@@ -726,14 +746,14 @@ class TorchFXImporter:
 
     ########## Neural Network ##########
 
-    def _linear(self, node: fx.node.Node) -> relax.Var:
+    def _linear(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
         bias = None if module.bias is None else self.params[module.bias]
         return self.block_builder.emit(relax.op.linear(x, weight, bias, 
"float32"))
 
-    def _linear_functional(self, node: fx.node.Node) -> relax.Var:
+    def _linear_functional(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
         weight = args[1]
@@ -770,7 +790,7 @@ class TorchFXImporter:
         bias = relax.op.reshape(bias, (1, -1, 1))
         return self.block_builder.emit(relax.op.add(conv1d, bias))
 
-    def _conv1d(self, node: fx.node.Node) -> relax.Var:
+    def _conv1d(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
@@ -788,7 +808,7 @@ class TorchFXImporter:
             groups=module.groups,
         )
 
-    def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
+    def _conv1d_functional(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
         weight = args[1]
@@ -838,7 +858,7 @@ class TorchFXImporter:
         bias = relax.op.reshape(bias, (1, -1, 1))
         return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
 
-    def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
+    def _conv1d_transpose(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
@@ -856,7 +876,7 @@ class TorchFXImporter:
             groups=module.groups,
         )
 
-    def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
+    def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
         weight = args[1]
@@ -905,7 +925,7 @@ class TorchFXImporter:
         bias = relax.op.reshape(bias, (1, -1, 1, 1))
         return self.block_builder.emit(relax.op.add(conv2d, bias))
 
-    def _conv2d(self, node: fx.node.Node) -> relax.Var:
+    def _conv2d(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
@@ -923,7 +943,7 @@ class TorchFXImporter:
             groups=module.groups,
         )
 
-    def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
+    def _conv2d_functional(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
         weight = args[1]
@@ -973,7 +993,7 @@ class TorchFXImporter:
         bias = relax.op.reshape(bias, (1, -1, 1, 1))
         return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
 
-    def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var:
+    def _conv2d_transpose(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
@@ -991,7 +1011,7 @@ class TorchFXImporter:
             groups=module.groups,
         )
 
-    def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
+    def _conv2d_transpose_functional(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
         weight = args[1]
@@ -1040,7 +1060,7 @@ class TorchFXImporter:
         bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
         return self.block_builder.emit(relax.op.add(conv3d, bias))
 
-    def _conv3d(self, node: fx.node.Node) -> relax.Var:
+    def _conv3d(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
@@ -1058,7 +1078,7 @@ class TorchFXImporter:
             groups=module.groups,
         )
 
-    def _conv3d_functional(self, node: fx.node.Node) -> relax.Var:
+    def _conv3d_functional(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
         weight = args[1]
@@ -1077,7 +1097,7 @@ class TorchFXImporter:
             groups=groups,
         )
 
-    def _max_pool2d(self, node: fx.node.Node) -> relax.Var:
+    def _max_pool2d(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         if node.target in self.named_modules:
             module = self.named_modules[node.target]
@@ -1108,7 +1128,7 @@ class TorchFXImporter:
             )
         )
 
-    def _avg_pool2d(self, node: fx.node.Node) -> relax.Var:
+    def _avg_pool2d(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         if node.target in self.named_modules:
             module = self.named_modules[node.target]
@@ -1154,7 +1174,7 @@ class TorchFXImporter:
     def _adaptive_avg_pool2d(self, is_module: bool) -> Callable:
         from torch import fx
 
-        def _impl(node: fx.node.Node) -> relax.Var:
+        def _impl(node: fx.Node) -> relax.Var:
             if is_module:
                 module = self.named_modules[node.target]
                 x = self.env[node.args[0]]
@@ -1168,7 +1188,7 @@ class TorchFXImporter:
 
         return _impl
 
-    def _softmax(self, node: fx.node.Node) -> relax.Var:
+    def _softmax(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         if node.target in self.named_modules:
             module = self.named_modules[node.target]
@@ -1179,29 +1199,7 @@ class TorchFXImporter:
         assert dim is not None
         return self.block_builder.emit(relax.op.nn.softmax(x, dim))
 
-    def _log_softmax(self, node: fx.node.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        if node.target in self.named_modules:
-            module = self.named_modules[node.target]
-            dim = module.dim
-        else:
-            nargs = len(node.args)
-            dim = node.args[1] if nargs > 1 else node.kwargs["dim"]
-        assert dim is not None
-        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
-
-    def _leakyrelu(self, node: fx.node.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        if node.target in self.named_modules:
-            module = self.named_modules[node.target]
-            alpha = module.negative_slope
-        else:
-            nargs = len(node.args)
-            alpha = node.args[1] if nargs > 1 else 
node.kwargs["negative_slope"]
-        assert alpha is not None
-        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
-
-    def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var:
+    def _batch_norm_2d(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
@@ -1224,7 +1222,7 @@ class TorchFXImporter:
 
         return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
 
-    def _layer_norm(self, node: fx.node.Node) -> relax.Var:
+    def _layer_norm(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
         from torch.fx.immutable_collections import immutable_list
         import numpy as np  # type: ignore
@@ -1291,7 +1289,7 @@ class TorchFXImporter:
             )
         )
 
-    def _group_norm(self, node: fx.node.Node) -> relax.Var:
+    def _group_norm(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
         x = self.env[node.args[0]]
@@ -1317,7 +1315,7 @@ class TorchFXImporter:
             )
         )
 
-    def _embedding(self, node: fx.node.Node) -> relax.Var:
+    def _embedding(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
@@ -1333,7 +1331,7 @@ class TorchFXImporter:
             embedding = self.block_builder.emit(relax.op.take(weight, x, 
axis=0))
             return self.block_builder.emit(relax.op.reshape(embedding, 
[*x_shape, emb_size]))
 
-    def _interpolate(self, node: fx.node.Node) -> relax.Var:
+    def _interpolate(self, node: fx.Node) -> relax.Var:
         # torch.nn.functional.interpolate(
         #   input, size=None, scale_factor=None, mode='nearest', 
align_corners=None,
         #   recompute_scale_factor=None, antialias=False)
@@ -1407,7 +1405,7 @@ class TorchFXImporter:
             )
         )
 
-    def _cross_entropy(self, node: fx.node.Node) -> relax.Expr:
+    def _cross_entropy(self, node: fx.Node) -> relax.Expr:
         preds = self.env[node.args[0]]
         targets = self.env[node.args[1]]
 
@@ -1442,7 +1440,7 @@ class TorchFXImporter:
             )
         )
 
-    def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var:
+    def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
         assert (
             len(node.args) <= 4
         ), "Dropout is not supported, and is_causal should be called by 
kwargs."
@@ -1464,13 +1462,13 @@ class TorchFXImporter:
 
     ########## Others ##########
 
-    def _sym_size_int(self, node: fx.node.Node) -> relax.Expr:
+    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
         idx = node.args[1]
         return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
 
-    def _size(self, node: fx.node.Node) -> relax.Expr:
+    def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
         if len(node.args) == 1:
@@ -1480,7 +1478,7 @@ class TorchFXImporter:
         idx = node.args[1]
         return self.shape_of(x)[idx].value
 
-    def _getattr(self, node: fx.node.Node) -> relax.Var:
+    def _getattr(self, node: fx.Node) -> relax.Var:
         if isinstance(self.env[node.args[0]], relax.Expr):
             if node.args[1] == "dtype":
                 return self.env[node.args[0]].struct_info.dtype
@@ -1488,7 +1486,7 @@ class TorchFXImporter:
                 return self.shape_of(self.env[node.args[0]])
         return getattr(self.env[node.args[0]], node.args[1])
 
-    def _getitem(self, node: fx.node.Node) -> relax.Var:
+    def _getitem(self, node: fx.Node) -> relax.Var:
         import torch
 
         x = self.env[node.args[0]]
@@ -1510,7 +1508,7 @@ class TorchFXImporter:
             shape = self.shape_of(x)
             non_ellipsis_cnt = 0
             for index in node.args[1]:
-                if isinstance(index, (int, slice, torch.fx.node.Node)):
+                if isinstance(index, (int, slice, torch.fx.Node)):
                     non_ellipsis_cnt += 1
             for index in node.args[1]:
                 if isinstance(index, int):
@@ -1534,7 +1532,7 @@ class TorchFXImporter:
                         stride.append(1)
                         stride_axes.append(i)
                         i += 1
-                elif isinstance(index, torch.fx.node.Node):
+                elif isinstance(index, torch.fx.Node):
                     node_index = self.env[index]
                     if not isinstance(node_index, relax.Expr):
                         raise ValueError(
@@ -1573,142 +1571,154 @@ class TorchFXImporter:
         from torch import nn
         from torch import fx
 
-        self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], 
relax.Var]] = {
-            # call_module
-            nn.Linear: self._linear,
+        self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.Node], 
relax.Var]] = {
+            ## call_module
+            # unary
+            nn.Dropout: lambda node: self.env[node.args[0]],
+            nn.GELU: self._gelu,
+            nn.Hardsigmoid: self._hardsigmoid,
+            nn.Hardswish: self._hardswish,
+            nn.Identity: lambda node: self.env[node.args[0]],
+            nn.LeakyReLU: self._leakyrelu_module,
+            nn.LogSoftmax: self._log_softmax_module,
+            nn.ReLU: self._unary_op(relax.op.nn.relu),
+            nn.ReLU6: lambda node: self.block_builder.emit(
+                relax.op.clip(self.env[node.args[0]], 0, 6)
+            ),
+            nn.Sigmoid: self._unary_op(relax.op.sigmoid),
+            nn.SiLU: self._unary_op(relax.op.nn.silu),
+            nn.Softmax: self._softmax_module,
+            nn.Tanh: self._unary_op(relax.op.tanh),
+            # neural network
+            nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True),
+            nn.AvgPool2d: self._avg_pool2d,
+            nn.BatchNorm2d: self._batch_norm_2d,
             nn.Conv1d: self._conv1d,
             nn.Conv2d: self._conv2d,
             nn.Conv3d: self._conv3d,
             nn.ConvTranspose1d: self._conv1d_transpose,
             nn.ConvTranspose2d: self._conv2d_transpose,
-            nn.MaxPool2d: self._max_pool2d,
-            nn.AvgPool2d: self._avg_pool2d,
-            nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True),
-            nn.Softmax: self._softmax,
-            nn.LogSoftmax: self._log_softmax,
-            nn.ReLU: lambda node: 
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
-            nn.LeakyReLU: self._leakyrelu,
-            nn.ReLU6: lambda node: self.block_builder.emit(
-                relax.op.clip(self.env[node.args[0]], 0, 6)
-            ),
-            nn.GELU: self._gelu,
-            nn.Sigmoid: self._sigmoid,
-            nn.Tanh: lambda node: 
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
-            nn.SiLU: lambda node: 
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
-            nn.Hardsigmoid: self._hardsigmoid,
-            nn.Hardswish: self._hardswish,
-            nn.Flatten: self._flatten,
-            nn.BatchNorm2d: self._batch_norm_2d,
-            nn.LayerNorm: self._layer_norm,
+            nn.CrossEntropyLoss: self._cross_entropy,
             nn.GroupNorm: self._group_norm,
-            nn.Dropout: lambda node: self.env[node.args[0]],
-            nn.Identity: lambda node: self.env[node.args[0]],
+            nn.LayerNorm: self._layer_norm,
+            nn.Linear: self._linear,
+            nn.MaxPool2d: self._max_pool2d,
             nn.modules.sparse.Embedding: self._embedding,
-            nn.CrossEntropyLoss: self._cross_entropy,
-            # call_function and call_method
-            "sin": lambda node: 
self.block_builder.emit(relax.op.sin(self.env[node.args[0]])),
-            "cos": lambda node: 
self.block_builder.emit(relax.op.cos(self.env[node.args[0]])),
-            "tan": lambda node: 
self.block_builder.emit(relax.op.tan(self.env[node.args[0]])),
-            "asin": lambda node: 
self.block_builder.emit(relax.op.asin(self.env[node.args[0]])),
-            "acos": lambda node: 
self.block_builder.emit(relax.op.acos(self.env[node.args[0]])),
-            "atan": lambda node: 
self.block_builder.emit(relax.op.atan(self.env[node.args[0]])),
-            "sinh": lambda node: 
self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])),
-            "cosh": lambda node: 
self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])),
-            "tanh": lambda node: 
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
-            "asinh": lambda node: 
self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])),
-            "acosh": lambda node: 
self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])),
-            "atanh": lambda node: 
self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])),
-            "exp": self._exp,
-            "iadd": self._add,
+            # tensor manipulation
+            nn.Flatten: self._flatten,
+            ## call_function and call_method
+            # unary
+            "acos": self._unary_op(relax.op.acos),
+            "acosh": self._unary_op(relax.op.acosh),
+            "asin": self._unary_op(relax.op.asin),
+            "asinh": self._unary_op(relax.op.asinh),
+            "atan": self._unary_op(relax.op.atan),
+            "atanh": self._unary_op(relax.op.atanh),
+            "clamp": self._clamp,
+            "cos": self._unary_op(relax.op.cos),
+            "cosh": self._unary_op(relax.op.cosh),
+            "dropout": lambda node: self.env[node.args[0]],
+            "exp": self._unary_op(relax.op.exp),
+            "gelu": self._gelu,
+            "hardsigmoid": self._hardsigmoid,
+            "hardswish": self._hardswish,
+            "leaky_relu": self._leakyrelu,
+            "log_softmax": self._log_softmax,
+            "neg": self._unary_op(relax.op.negative),
+            "relu": self._unary_op(relax.op.nn.relu),
+            "round": self._round,
+            "rsqrt": self._unary_op(relax.op.rsqrt),
+            "sigmoid": self._unary_op(relax.op.sigmoid),
+            "silu": self._unary_op(relax.op.nn.silu),
+            "sin": self._unary_op(relax.op.sin),
+            "sinh": self._unary_op(relax.op.sinh),
+            "softmax": self._softmax,
+            "sqrt": self._unary_op(relax.op.sqrt),
+            "tan": self._unary_op(relax.op.tan),
+            "tanh": self._unary_op(relax.op.tanh),
+            "tril_": self._inplace_tril_triu(relax.op.tril),
+            "tril": self._tril_triu(relax.op.tril),
+            "triu_": self._inplace_tril_triu(relax.op.triu),
+            "triu": self._tril_triu(relax.op.triu),
+            # binary
             "add": self._add,
+            "eq": self._eq,
             "floordiv": self._floordiv,
+            "iadd": self._add,
+            "lt": self._lt,
+            "matmul": self._matmul,
+            "max": self._max,
             "mul": self._mul,
-            "sub": self._sub,
             "pow": self._pow,
-            "sigmoid": self._sigmoid,
-            "sqrt": self._sqrt,
-            "round": self._round,
-            "lt": self._lt,
-            "eq": self._eq,
+            "sub": self._sub,
             "truediv": self._truediv,
-            "fill_": self._inplace_fill,
-            "new_ones": self._new_ones,
-            "arange": self._arange,
-            "empty": self._empty,
-            "tensor": self._tensor,
-            "tril": self._tril_triu(relax.op.tril),
-            "triu": self._tril_triu(relax.op.triu),
-            "tril_": self._inplace_tril_triu(relax.op.tril),
-            "triu_": self._inplace_tril_triu(relax.op.triu),
-            "sum": self._sum,
-            "float": self._float,
-            "half": self._half,
-            "type": self._type,
-            "astype": self._type,
-            "matmul": self._matmul,
-            "conv1d": self._conv1d_functional,
+            # neural network
+            "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
+            "addmm": self._addmm,
+            "avg_pool2d": self._avg_pool2d,
+            "baddbmm": self._baddbmm,
+            "bmm": self._matmul,
             "conv_transpose1d": self._conv1d_transpose_functional,
-            "conv2d": self._conv2d_functional,
             "conv_transpose2d": self._conv2d_transpose_functional,
+            "conv1d": self._conv1d_functional,
+            "conv2d": self._conv2d_functional,
             "conv3d": self._conv3d_functional,
+            "cross_entropy": self._cross_entropy,
+            "einsum": self._einsum,
+            "interpolate": self._interpolate,
+            "layer_norm": self._layer_norm,
             "linear": self._linear_functional,
-            "addmm": self._addmm,
-            "baddbmm": self._baddbmm,
-            "bmm": self._matmul,
+            "max_pool2d": self._max_pool2d,
+            "scaled_dot_product_attention": self._scaled_dot_product_attention,
+            "stochastic_depth": lambda node: self.env[node.args[0]],
+            "unbind": self._unbind,
+            # statistical
+            "mean": self._mean,
+            "sum": self._sum,
+            # search
+            "argmax": self._argmax_argmin(relax.op.argmax),
+            "argmin": self._argmax_argmin(relax.op.argmin),
+            # tensor manipulation
             "cat": self._cat,
             "concat": self._cat,
+            "contiguous": lambda node: self.env[node.args[0]],
+            "cumsum": self._cumsum,
             "expand": self._expand,
             "flatten": self._flatten,
             "permute": self._permute,
             "repeat": self._repeat,
             "reshape": self._reshape,
+            "size": self._size,
             "split": self._split,
+            "squeeze": self._squeeze,
             "tile": self._tile,
-            "cumsum": self._cumsum,
-            "chunk": self._chunk,
             "transpose": self._transpose,
-            "squeeze": self._squeeze,
             "unsqueeze": lambda node: self.block_builder.emit(
                 relax.op.expand_dims(self.env[node.args[0]], node.args[1])
             ),
             "view": self._reshape,
-            "argmax": self._argmax_argmin(relax.op.argmax),
-            "argmin": self._argmax_argmin(relax.op.argmin),
-            "softmax": self._softmax,
-            "log_softmax": self._log_softmax,
-            "dropout": lambda node: self.env[node.args[0]],
-            "stochastic_depth": lambda node: self.env[node.args[0]],
-            "clamp": self._clamp,
-            "relu": lambda node: 
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
-            "leaky_relu": self._leakyrelu,
-            "gelu": self._gelu,
-            "silu": lambda node: 
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
-            "hardsigmoid": self._hardsigmoid,
-            "hardswish": self._hardswish,
-            "interpolate": self._interpolate,
-            "sym_size.int": self._sym_size_int,
-            "size": self._size,
-            "getattr": self._getattr,
-            "getitem": self._getitem,
-            "contiguous": lambda node: self.env[node.args[0]],
-            "to": self._to,
-            "max_pool2d": self._max_pool2d,
-            "avg_pool2d": self._avg_pool2d,
-            "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
-            "layer_norm": self._layer_norm,
+            # tensor creation
+            "arange": self._arange,
+            "chunk": self._chunk,
+            "empty": self._empty,
+            "fill_": self._inplace_fill,
+            "full": self._full,
             "index_select": self._index_select,
+            "masked_fill_": self._inplace_masked_fill,
             "masked_fill": self._masked_fill,
+            "new_ones": self._new_ones,
             "ones": self._ones,
-            "full": self._full,
-            "masked_fill_": self._inplace_masked_fill,
-            "mean": self._mean,
-            "rsqrt": self._rsqrt,
-            "neg": self._neg,
-            "max": self._max,
-            "cross_entropy": self._cross_entropy,
-            "scaled_dot_product_attention": self._scaled_dot_product_attention,
-            "einsum": self._einsum,
-            "unbind": self._unbind,
+            "tensor": self._tensor,
+            "to": self._to,
+            # datatype
+            "astype": self._type,
+            "float": self._float,
+            "half": self._half,
+            "type": self._type,
+            # other
+            "getattr": self._getattr,
+            "getitem": self._getitem,
+            "sym_size.int": self._sym_size_int,
         }
 
     def update_convert_map(self, custom_convert_map: dict):


Reply via email to