This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ab3f82669 [Relax][PyTorch] Cleanup Tensor Manipulation and Creation op converters (#17376)
4ab3f82669 is described below

commit 4ab3f82669fb20d77cae47704c857ab39a577417
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Sep 16 23:13:41 2024 +0900

    [Relax][PyTorch] Cleanup Tensor Manipulation and Creation op converters (#17376)
    
    * cleanup `_cat()`
    
    * cleanup `_cumsum()`
    
    * cleanup `_expand()`
    
    * cleanup `_flatten()`
    
    * cleanup `_permute()`
    
    * cleanup `_repeat()`
    
    * cleanup `_reshape()`
    
    * cleanup `_size()`
    
    * cleanup `_split()`
    
    * cleanup `_squeeze()`
    
    * cleanup `_tile()`
    
    * cleanup `_transpose()`
    
    * cleanup `_chunk()`
    
    * cleanup `_arange()`
    
    * cleanup `_empty()`
    
    * cleanup `_inplace_fill()`
    
    * cleanup `_full()`
    
    * cleanup `_index_select()`
    
    * cleanup `_inplace_masked_fill()`
    
    * cleanup `_masked_fill()`
    
    * cleanup `_new_ones()`
    
    * cleanup `_ones()`
    
    * cleanup `_tensor()`
    
    * `_inplace_tril_triu()` is a unary op
    
    * `_batch_norm_2d()` is an nn op
    
    * `_interpolate()` is an nn op
    
    * `_cross_entropy()` is an nn op
    
    * chore
    
    * fix tensor size
---
 python/tvm/relax/frontend/torch/fx_translator.py | 755 +++++++++++------------
 1 file changed, 358 insertions(+), 397 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 4dc49d20ff..983bce0255 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -212,6 +212,20 @@ class TorchFXImporter:
         assert dim is not None
         return self.block_builder.emit(relax.op.nn.softmax(x, dim))
 
+    def _inplace_tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else 0
+            assert isinstance(k, int)
+
+            mutated = self.block_builder.emit(op(x, k))
+            self.env[node.args[0]] = mutated
+            return mutated
+
+        return convert
+
     def _tril_triu(self, op: Callable) -> Callable:
         from torch import fx
 
@@ -356,6 +370,29 @@ class TorchFXImporter:
             res = bias if res is None else 
self.block_builder.emit(relax.op.add(res, bias))
         return res
 
+    def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+        bias = self.params[module.bias]
+        running_mean = self._convert_torch_tensor_to_relax(module.running_mean)
+        running_var = self._convert_torch_tensor_to_relax(module.running_var)
+        eps = module.eps
+
+        res_tuple = self.block_builder.emit(
+            relax.op.nn.batch_norm(
+                x,
+                weight,
+                bias,
+                running_mean,
+                running_var,
+                axis=1,
+                epsilon=eps,
+            )
+        )
+
+        return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
+
     def _conv1d_transpose_impl(
         self,
         x: relax.Expr,
@@ -683,6 +720,40 @@ class TorchFXImporter:
             groups=module.groups,
         )
 
+    def _cross_entropy(self, node: fx.Node) -> relax.Expr:
+        preds = self.env[node.args[0]]
+        targets = self.env[node.args[1]]
+        weights = self.env.get(node.kwargs["weight"], None)
+        reduction = node.kwargs["reduction"]
+        ignore_index = node.kwargs["ignore_index"]
+
+        return self.block_builder.emit(
+            relax.op.nn.nll_loss(
+                relax.op.nn.log_softmax(preds), targets, weights, reduction, 
ignore_index
+            )
+        )
+
+    def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
+        preds = self.env[node.args[0]]
+        targets = self.env[node.args[1]]
+        module = self.named_modules[node.target]
+
+        weights = module.weight
+        if weights is not None:
+            if weights in self.params:
+                weights = self.params[weights]
+            else:
+                weights = relax.const(weights.numpy(), preds.struct_info.dtype)
+
+        reduction = module.reduction
+        ignore_index = module.ignore_index
+
+        return self.block_builder.emit(
+            relax.op.nn.nll_loss(
+                relax.op.nn.log_softmax(preds), targets, weights, reduction, 
ignore_index
+            )
+        )
+
     def _einsum(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
@@ -740,6 +811,80 @@ class TorchFXImporter:
             )
         )
 
+    def _interpolate(self, node: fx.Node) -> relax.Var:
+        # torch.nn.functional.interpolate(
+        #   input, size=None, scale_factor=None, mode='nearest', 
align_corners=None,
+        #   recompute_scale_factor=None, antialias=False)
+        # (TODO) this is a temporary implementation for interpolate that only 
considers NCHW layout
+        # it basically replicates the implementation in 
tvm.relay.frontend.pytorch
+        data = self.env[node.args[0]]
+        size = (
+            node.args[1]
+            if len(node.args) > 1
+            else (node.kwargs["size"] if "size" in node.kwargs else None)
+        )
+        scale_factor = (
+            node.args[2]
+            if len(node.args) > 2
+            else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs 
else None)
+        )
+        method = (
+            node.args[3]
+            if len(node.args) > 3
+            else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest")
+        )
+        align_corners = (
+            node.args[4]
+            if len(node.args) > 4
+            else (node.kwargs["align_corners"] if "align_corners" in 
node.kwargs else None)
+        )
+        recompute_scale_factor = (
+            node.args[5]
+            if len(node.args) > 5
+            else (
+                node.kwargs["recompute_scale_factor"]
+                if "recompute_scale_factor" in node.kwargs
+                else None
+            )
+        )
+        antialias = (
+            node.args[6]
+            if len(node.args) > 6
+            else (node.kwargs["antialias"] if "antialias" in node.kwargs else 
False)
+        )
+
+        assert recompute_scale_factor is None
+        assert antialias is False
+
+        if size is None:
+            shape = self.shape_of(data)
+            assert isinstance(shape, relax.ShapeExpr)
+            if isinstance(scale_factor, tuple):
+                assert len(scale_factor) == len(shape) - 2
+                size = tuple(
+                    int(shape[i].value * scale_factor[i - 2]) for i in 
range(2, len(shape))
+                )
+            else:
+                size = tuple(int(shape[i].value * scale_factor) for i in 
range(2, len(shape)))
+
+        if method.startswith("nearest"):
+            method = "nearest_neighbor"
+        elif method[0:2] == "bi":
+            method = method[2:]
+
+        if method == "nearest_neighbor":
+            coord_trans = "asymmetric"
+        elif align_corners:
+            coord_trans = "align_corners"
+        else:
+            coord_trans = "half_pixel"
+
+        return self.block_builder.emit(
+            relax.op.image.resize2d(
+                data, size, layout="NCHW", method=method, 
coordinate_transformation_mode=coord_trans
+            )
+        )
+
     def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> 
relax.Var:
         from torch.fx.immutable_collections import immutable_list
         import numpy as np  # type: ignore
@@ -913,230 +1058,106 @@ class TorchFXImporter:
 
         return convert
 
-    ########## DataType ##########
-
-    def _float(self, node: fx.Node) -> relax.Var:
-        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"float32"))
-
-    def _half(self, node: fx.Node) -> relax.Var:
-        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"float16"))
+    ########## Manipulation ##########
 
-    def _to(self, node: fx.Node) -> relax.Var:
-        import torch
+    def _cat(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
 
+    def _chunk(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        if len(node.args) == 2:
-            if isinstance(node.args[1], torch.dtype):
-                dtype = TorchFXImporter._convert_data_type(node.args[1], 
self.env)
-                return self.block_builder.emit(relax.op.astype(x, dtype))
-        elif "dtype" in node.kwargs:
-            dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], 
self.env)
-            return self.block_builder.emit(relax.op.astype(x, dtype))
-        return x
+        chunks = node.args[1]
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.op.split(x, chunks, dim))
 
-    def _type(self, node: fx.Node) -> relax.Var:
+    def _cumsum(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
-        return self.block_builder.emit(relax.op.astype(x, dtype))
 
-    ########## Creation ##########
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 
None)
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), 
self.env)
+        else:
+            dtype = None
+        if "out" in node.kwargs:
+            raise ValueError("specifying out for cumsum is not supported yet")
 
-    def _arange(self, node: fx.Node) -> relax.Var:
-        import torch
+        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
 
-        start_end_step = [None, None, None]
-        if "start" in node.kwargs:
-            start_end_step[0] = node.kwargs["start"]
-        if "end" in node.kwargs:
-            start_end_step[1] = node.kwargs["end"]
-        if "step" in node.kwargs:
-            start_end_step[2] = node.kwargs["step"]
+    def _expand(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        sizes = args[1:] if len(args) > 2 else args[1]
+        broadcast_shape, in_shape = [], self.shape_of(args[0])
+        for idx, i in enumerate(sizes):
+            if isinstance(i, int) and i == -1:
+                broadcast_shape.append(in_shape[idx])
+            else:
+                broadcast_shape.append(i)
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], 
broadcast_shape))
 
-        if len(node.args) == 1:
-            assert start_end_step[1] is None
-            start_end_step[1] = node.args[0]
-        elif len(node.args) == 2:
-            assert start_end_step[0] is None
-            assert start_end_step[1] is None
-            start_end_step[0] = node.args[0]
-            start_end_step[1] = node.args[1]
-        elif len(node.args) == 3:
-            assert start_end_step[0] is None
-            assert start_end_step[1] is None
-            assert start_end_step[2] is None
-            start_end_step[0] = node.args[0]
-            start_end_step[1] = node.args[1]
-            start_end_step[2] = node.args[2]
+    def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
+        shape = self.shape_of(x)
+        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
+        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
+        flattened = reduce(lambda x, y: x * y, [shape[i] for i in 
range(start_dim, end_dim + 1)])
+        new_shape = (
+            [shape[i] for i in range(0, start_dim)]
+            + [flattened]
+            + [shape[i] for i in range(end_dim + 1, len(shape))]
+        )
+        return self.block_builder.emit(relax.op.reshape(x, new_shape))
 
-        if start_end_step[0] is None:
-            start_end_step[0] = 0
-        if start_end_step[2] is None:
-            start_end_step[2] = 1
+    def _flatten(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        start_dim = node.args[1] if len(node.args) >= 2 else 
node.kwargs.get("start_dim", 0)
+        end_dim = node.args[2] if len(node.args) == 3 else 
node.kwargs.get("end_dim", -1)
+        return self._flatten_impl(x, start_dim, end_dim)
 
-        if "dtype" in node.kwargs:
-            dtype = 
TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
-        elif any([isinstance(x, float) for x in start_end_step]):
-            dtype = 
TorchFXImporter._convert_data_type(torch.get_default_dtype())
-        else:
-            dtype = "int64"
-        start_end_step = [
-            self.env[x] if isinstance(x, torch.fx.Node) else x for x in 
start_end_step
-        ]
-        return self.block_builder.emit(relax.op.arange(*start_end_step, 
dtype=dtype))
+    def _flatten_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        start_dim = module.start_dim
+        end_dim = module.end_dim
+        return self._flatten_impl(x, start_dim, end_dim)
 
-    def _empty(self, node: fx.Node) -> relax.Var:
-        dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), 
self.env)
-        return self.block_builder.emit(relax.op.zeros(node.args, dtype))
+    def _permute(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
 
-    def _inplace_fill(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
-        dtype = x.struct_info.dtype
-        value = args[1] if isinstance(args[1], relax.Expr) else 
relax.const(args[1], dtype)
-        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, 
value, dtype))
-        self.env[node.args[0]] = filled
-        return filled
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else 
args[1:]
+        return self.block_builder.emit(relax.op.permute_dims(x, dims))
 
-    def _tensor(self, node: fx.Node) -> relax.Var:
-        dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
-        if isinstance(node.args[0], float):
-            return relax.const(node.args[0], dtype if dtype is not None else 
"float32")
-        elif isinstance(node.args[0], int):
-            return relax.const(node.args[0], dtype if dtype is not None else 
"int64")
-        raise ValueError("torch.tensor with value not a float or int is not 
accepted")
+    def _repeat(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
 
-    def _inplace_tril_triu(self, op: Callable) -> Callable:
-        from torch import fx
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else 
args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
 
-        def convert(node: fx.Node) -> relax.Var:
-            x = self.env[node.args[0]]
-            k = node.args[1] if len(node.args) > 1 else 0
-            assert isinstance(k, int)
-
-            mutated = self.block_builder.emit(op(x, k))
-            self.env[node.args[0]] = mutated
-            return mutated
-
-        return convert
-
-    def _new_ones(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        self_var = args[0]
-        size = args[1:]
-        if not isinstance(size, (list, tuple)):
-            size = (size,)
-        size = relax.ShapeExpr(size)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, self_var.struct_info.dtype),
-                self_var.struct_info.dtype,
-            )
-        )
-
-    def _ones(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = args[0]
-        if not isinstance(size, (list, tuple)):
-            size = (size,)
-        size = relax.ShapeExpr(size)
-        dtype = (
-            TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), 
self.env)
-            if "dtype" in node.kwargs
-            else TorchFXImporter._convert_data_type(torch.get_default_dtype(), 
self.env)
-        )
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, dtype),
-                dtype,
-            )
-        )
-
-    def _full(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = args[0]
-        if not isinstance(size, (list, tuple)):
-            size = (size,)
-        size = relax.ShapeExpr(size)
-        dtype = (
-            TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), 
self.env)
-            if "dtype" in node.kwargs
-            else TorchFXImporter._convert_data_type(torch.get_default_dtype(), 
self.env)
-        )
-        value = args[1] if isinstance(args[1], relax.expr.Constant) else 
relax.const(args[1], dtype)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                value,
-                dtype,
-            )
-        )
-
-    ########## Manipulation ##########
-
-    def _cat(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+    def _reshape(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
 
-    def _expand(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
-        broadcast_shape, in_shape = [], self.shape_of(args[0])
-        for idx, i in enumerate(args[1:]):
-            if isinstance(i, int) and i == -1:
-                broadcast_shape.append(in_shape[idx])
-            else:
-                broadcast_shape.append(i)
-        return self.block_builder.emit(relax.op.broadcast_to(args[0], 
broadcast_shape))
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else 
args[1:]
+        return self.block_builder.emit(relax.op.reshape(x, dims))
 
-    def _flatten(self, node: fx.Node) -> relax.Var:
+    def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
-        if node.target in self.named_modules:
-            module = self.named_modules[node.target]
-            start_dim = module.start_dim
-            end_dim = module.end_dim
-        else:
-            start_dim = node.args[1] if len(node.args) >= 2 else 0
-            end_dim = node.args[2] if len(node.args) == 3 else -1
         shape = self.shape_of(x)
-        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
-        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
-        flattened = reduce(lambda x, y: x * y, [shape[i] for i in 
range(start_dim, end_dim + 1)])
-        new_shape = (
-            [shape[i] for i in range(0, start_dim)]
-            + [flattened]
-            + [shape[i] for i in range(end_dim + 1, len(shape))]
-        )
-        return self.block_builder.emit(relax.op.reshape(x, new_shape))
-
-    def _permute(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        if isinstance(args[1], (torch.Size, tuple, list)):
-            return self.block_builder.emit(relax.op.permute_dims(args[0], 
tuple(args[1])))
-        return self.block_builder.emit(relax.op.permute_dims(args[0], 
args[1:]))
-
-    def _reshape(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        if isinstance(args[1], (torch.Size, tuple, list)):
-            return self.block_builder.emit(relax.op.reshape(args[0], 
tuple(args[1])))
-        return self.block_builder.emit(relax.op.reshape(args[0], args[1:]))
+        if len(node.args) == 1:
+            assert isinstance(shape, relax.ShapeExpr)
+            return shape
+        assert len(node.args) == 2
+        idx = node.args[1]
+        return self.shape_of(x)[idx].value
 
     def _split(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         split_size = node.args[1]
-        if "dim" in node.kwargs:
-            dim = node.kwargs["dim"]
-        else:
-            dim = 0
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
         if isinstance(split_size, (list, tuple)):
             n_section = []
             for s in split_size[:-1]:
@@ -1146,17 +1167,18 @@ class TorchFXImporter:
             n_section = (self.shape_of(x)[dim].value + split_size - 1) // 
split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
-    def _chunk(self, node: fx.Node) -> relax.Var:
+    def _squeeze(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        chunks = node.args[1]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 
None)
+        return self.block_builder.emit(relax.op.squeeze(x, dim))
 
-        if "dim" in node.kwargs:
-            dim = node.kwargs["dim"]
-        elif len(node.args) > 2:
-            dim = node.args[2]
-        else:
-            dim = 0
-        return self.block_builder.emit(relax.op.split(x, chunks, dim))
+    def _tile(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else 
args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
 
     def _transpose(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
@@ -1164,50 +1186,80 @@ class TorchFXImporter:
         full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], 
full_idx[args[1]]
         return self.block_builder.emit(relax.op.permute_dims(args[0], 
full_idx))
 
-    def _squeeze(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-
-        if "dim" in node.kwargs:
-            dim = node.kwargs["dim"]
-        elif len(node.args) > 1:
-            dim = node.args[1]
-        else:
-            dim = None
-        return self.block_builder.emit(relax.op.squeeze(x, dim))
+    ########## Creation ##########
 
-    def _repeat(self, node: fx.Node) -> relax.Var:
+    def _arange(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
-        args = self.retrieve_args(node)
-        if isinstance(args[1], (torch.Size, tuple, list)):
-            return self.block_builder.emit(relax.op.tile(args[0], 
tuple(args[1])))
-        return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
-
-    def _tile(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
+        start_end_step = [None, None, None]
+        if "start" in node.kwargs:
+            start_end_step[0] = node.kwargs["start"]
+        if "end" in node.kwargs:
+            start_end_step[1] = node.kwargs["end"]
+        if "step" in node.kwargs:
+            start_end_step[2] = node.kwargs["step"]
 
-        args = self.retrieve_args(node)
-        if isinstance(args[1], (torch.Size, tuple, list)):
-            return self.block_builder.emit(relax.op.tile(args[0], 
tuple(args[1])))
-        return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
+        if len(node.args) == 1:
+            assert start_end_step[1] is None
+            start_end_step[1] = node.args[0]
+        elif len(node.args) == 2:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+        elif len(node.args) == 3:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            assert start_end_step[2] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+            start_end_step[2] = node.args[2]
 
-    def _cumsum(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
+        if start_end_step[0] is None:
+            start_end_step[0] = 0
+        if start_end_step[2] is None:
+            start_end_step[2] = 1
 
-        if "dim" in node.kwargs:
-            dim = node.kwargs["dim"]
-        elif len(node.args) > 1:
-            dim = node.args[1]
-        else:
-            dim = None
         if "dtype" in node.kwargs:
-            dtype = 
TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), 
self.env)
+        elif any([isinstance(x, float) for x in start_end_step]):
+            dtype = self._convert_data_type(torch.get_default_dtype())
         else:
-            dtype = None
-        if "out" in node.kwargs:
-            raise ValueError("specifying out for cumsum is not supported yet")
+            dtype = "int64"
+        start_end_step = [
+            self.env[x] if isinstance(x, torch.fx.Node) else x for x in 
start_end_step
+        ]
+        return self.block_builder.emit(relax.op.arange(*start_end_step, 
dtype=dtype))
 
-        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+    def _empty(self, node: fx.Node) -> relax.Var:
+        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
+
+    def _inplace_fill(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else 
relax.const(args[1], dtype)
+        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, 
value, dtype))
+        self.env[node.args[0]] = filled
+        return filled
+
+    def _full(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) 
else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        value = args[1] if isinstance(args[1], relax.expr.Constant) else 
relax.const(args[1], dtype)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                value,
+                dtype,
+            )
+        )
 
     def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
@@ -1215,14 +1267,6 @@ class TorchFXImporter:
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.take(x, index, dim))
 
-    def _masked_fill(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        mask = self.env[node.args[1]]
-        value = node.args[2]
-        rx_value = relax.const(value)
-        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
-        return self.block_builder.emit(relax.op.where(mask, values, x))
-
     def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
@@ -1233,168 +1277,79 @@ class TorchFXImporter:
         self.env[node.args[0]] = output
         return output
 
-    ########## Neural Network ##########
-
-    def _softmax(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        if node.target in self.named_modules:
-            module = self.named_modules[node.target]
-            dim = module.dim
-        else:
-            nargs = len(node.args)
-            dim = node.args[1] if nargs > 1 else node.kwargs["dim"]
-        assert dim is not None
-        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
-
-    def _batch_norm_2d(self, node: fx.Node) -> relax.Var:
+    def _masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        module = self.named_modules[node.target]
-        weight = self.params[module.weight]
-        bias = self.params[module.bias]
-        running_mean = self._convert_torch_tensor_to_relax(module.running_mean)
-        running_var = self._convert_torch_tensor_to_relax(module.running_var)
-        eps = module.eps
+        mask = self.env[node.args[1]]
+        rx_value = relax.const(node.args[2])
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+        return self.block_builder.emit(relax.op.where(mask, values, x))
 
-        res_tuple = self.block_builder.emit(
-            relax.op.nn.batch_norm(
-                x,
-                weight,
-                bias,
-                running_mean,
-                running_var,
-                axis=1,
-                epsilon=eps,
+    def _new_ones(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        self_var = args[0]
+        size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
+        if not isinstance(size, (list, tuple)):
+            size = (size,)
+        size = relax.ShapeExpr(size)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, self_var.struct_info.dtype),
+                self_var.struct_info.dtype,
             )
         )
 
-        return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
+    def _ones(self, node: fx.Node) -> relax.Var:
+        import torch
 
-    def _interpolate(self, node: fx.Node) -> relax.Var:
-        # torch.nn.functional.interpolate(
-        #   input, size=None, scale_factor=None, mode='nearest', 
align_corners=None,
-        #   recompute_scale_factor=None, antialias=False)
-        # (TODO) this is a temporary implementation for interpolate that only 
considers NCHW layout
-        # it basically replicates the implementation in 
tvm.relay.frontend.pytorch
-        data = self.env[node.args[0]]
-        size = (
-            node.args[1]
-            if len(node.args) > 1
-            else (node.kwargs["size"] if "size" in node.kwargs else None)
-        )
-        scale_factor = (
-            node.args[2]
-            if len(node.args) > 2
-            else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs 
else None)
-        )
-        method = (
-            node.args[3]
-            if len(node.args) > 3
-            else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest")
-        )
-        align_corners = (
-            node.args[4]
-            if len(node.args) > 4
-            else (node.kwargs["align_corners"] if "align_corners" in 
node.kwargs else None)
-        )
-        recompute_scale_factor = (
-            node.args[5]
-            if len(node.args) > 5
-            else (
-                node.kwargs["recompute_scale_factor"]
-                if "recompute_scale_factor" in node.kwargs
-                else None
-            )
-        )
-        antialias = (
-            node.args[6]
-            if len(node.args) > 6
-            else (node.kwargs["antialias"] if "antialias" in node.kwargs else 
False)
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) 
else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
         )
-
-        assert recompute_scale_factor is None
-        assert antialias is False
-
-        if size is None:
-            shape = self.shape_of(data)
-            assert isinstance(shape, relax.ShapeExpr)
-            if isinstance(scale_factor, tuple):
-                assert len(scale_factor) == len(shape) - 2
-                size = tuple(
-                    int(shape[i].value * scale_factor[i - 2]) for i in 
range(2, len(shape))
-                )
-            else:
-                size = tuple(int(shape[i].value * scale_factor) for i in 
range(2, len(shape)))
-
-        if method.startswith("nearest"):
-            method = "nearest_neighbor"
-        elif method[0:2] == "bi":
-            method = method[2:]
-
-        if method == "nearest_neighbor":
-            coord_trans = "asymmetric"
-        elif align_corners:
-            coord_trans = "align_corners"
-        else:
-            coord_trans = "half_pixel"
-
         return self.block_builder.emit(
-            relax.op.image.resize2d(
-                data, size, layout="NCHW", method=method, 
coordinate_transformation_mode=coord_trans
+            relax.op.full(
+                size,
+                relax.const(1, dtype),
+                dtype,
             )
         )
 
-    def _cross_entropy(self, node: fx.Node) -> relax.Expr:
-        preds = self.env[node.args[0]]
-        targets = self.env[node.args[1]]
-
-        # functional.cross_entropy
-        if node.target not in self.named_modules:
-            weights = node.kwargs["weight"]
-            if weights is not None:
-                weights = self.env[weights]
-            reduction = node.kwargs["reduction"]
-            ignore_index = node.kwargs["ignore_index"]
-
-            return self.block_builder.emit(
-                relax.op.nn.nll_loss(
-                    relax.op.nn.log_softmax(preds), targets, weights, 
reduction, ignore_index
-                )
-            )
+    def _tensor(self, node: fx.Node) -> relax.Var:
+        dtype = node.kwargs.get("dtype", None)
+        if isinstance(node.args[0], float):
+            return relax.const(node.args[0], dtype if dtype is not None else 
"float32")
+        elif isinstance(node.args[0], int):
+            return relax.const(node.args[0], dtype if dtype is not None else 
"int64")
+        raise ValueError("torch.tensor with value not a float or int is not 
accepted")
 
-        module = self.named_modules[node.target]
+    ########## DataType ##########
 
-        weights = module.weight
-        if weights is not None:
-            if weights in self.params:
-                weights = self.params[weights]
-            else:
-                weights = relax.const(weights.numpy(), preds.struct_info.dtype)
-        reduction = module.reduction
-        ignore_index = module.ignore_index
+    def _float(self, node: fx.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"float32"))
 
-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, 
ignore_index
-            )
-        )
+    def _half(self, node: fx.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"float16"))
 
-    ########## Others ##########
+    def _to(self, node: fx.Node) -> relax.Var:
+        import torch
 
-    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
-        shape = self.shape_of(x)
-        idx = node.args[1]
-        return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = TorchFXImporter._convert_data_type(node.args[1], 
self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], 
self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
 
-    def _size(self, node: fx.Node) -> relax.Expr:
+    def _type(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        shape = self.shape_of(x)
-        if len(node.args) == 1:
-            assert isinstance(shape, relax.ShapeExpr)
-            return shape
-        assert len(node.args) == 2
-        idx = node.args[1]
-        return self.shape_of(x)[idx].value
+        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
+        return self.block_builder.emit(relax.op.astype(x, dtype))
+
+    ########## Others ##########
 
     def _getattr(self, node: fx.Node) -> relax.Var:
         if isinstance(self.env[node.args[0]], relax.Expr):
@@ -1485,6 +1440,12 @@ class TorchFXImporter:
         else:
             assert False
 
+    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
+        x = self.env[node.args[0]]
+        shape = self.shape_of(x)
+        idx = node.args[1]
+        return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
+
     def create_convert_map(self):
         import operator
         from torch import nn
@@ -1511,20 +1472,20 @@ class TorchFXImporter:
             # neural network
             nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
             nn.AvgPool2d: self._avg_pool2d_module,
-            nn.BatchNorm2d: self._batch_norm_2d,
+            nn.BatchNorm2d: self._batch_norm_2d_module,
             nn.Conv1d: self._conv1d_module,
             nn.Conv2d: self._conv2d_module,
             nn.Conv3d: self._conv3d_module,
             nn.ConvTranspose1d: self._conv1d_transpose_module,
             nn.ConvTranspose2d: self._conv2d_transpose_module,
-            nn.CrossEntropyLoss: self._cross_entropy,
+            nn.CrossEntropyLoss: self._cross_entropy_module,
             nn.GroupNorm: self._group_norm_module,
             nn.LayerNorm: self._layer_norm_module,
             nn.Linear: self._linear_module,
             nn.MaxPool2d: self._max_pool2d_module,
             nn.modules.sparse.Embedding: self._embedding_module,
             # tensor manipulation
-            nn.Flatten: self._flatten,
+            nn.Flatten: self._flatten_module,
             ## call_function and call_method
             # unary
             "acos": self._unary_op(relax.op.acos),
@@ -1603,6 +1564,7 @@ class TorchFXImporter:
             "argmin": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "cat": self._cat,
+            "chunk": self._chunk,
             "concat": self._cat,
             "contiguous": lambda node: self.env[node.args[0]],
             "cumsum": self._cumsum,
@@ -1622,7 +1584,6 @@ class TorchFXImporter:
             "view": self._reshape,
             # tensor creation
             "arange": self._arange,
-            "chunk": self._chunk,
             "empty": self._empty,
             "fill_": self._inplace_fill,
             "full": self._full,
@@ -1632,11 +1593,11 @@ class TorchFXImporter:
             "new_ones": self._new_ones,
             "ones": self._ones,
             "tensor": self._tensor,
-            "to": self._to,
             # datatype
             "astype": self._type,
             "float": self._float,
             "half": self._half,
+            "to": self._to,
             "type": self._type,
             # other
             "getattr": self._getattr,


Reply via email to