This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e3665ae2cd [Relax] support scatter ops (#17509)
e3665ae2cd is described below

commit e3665ae2cd3a4764d6ac58b2139a312af8d63216
Author: Archermmt <[email protected]>
AuthorDate: Mon Nov 11 21:00:43 2024 +0800

    [Relax] support scatter ops (#17509)
    
    * support scatter ops
    
    * format fix
    
    * format fix
---
 python/tvm/contrib/msc/core/codegen/codegen.py     |   5 +-
 python/tvm/contrib/msc/core/utils/dataset.py       |  48 +++++++-
 .../msc/framework/torch/frontend/translate.py      |  18 ++-
 python/tvm/contrib/msc/pipeline/pipeline.py        |   3 +-
 .../frontend/torch/base_fx_graph_translator.py     |  32 ++++++
 python/tvm/relax/frontend/torch/fx_translator.py   |   2 +
 src/contrib/msc/framework/torch/torch_opcode.cc    |  56 +++++++++-
 src/contrib/msc/framework/tvm/relax_opcode.cc      |  46 ++++++++
 tests/python/contrib/test_msc/test_graph_build.py  | 124 +++++++++++++++++++--
 .../contrib/test_msc/test_translate_relax.py       |  74 ++++++++++--
 .../contrib/test_msc/test_translate_relay.py       |   4 +-
 .../contrib/test_msc/test_translate_tensorrt.py    |   3 +-
 .../contrib/test_msc/test_translate_torch.py       |  71 +++++++++++-
 tests/python/relax/test_frontend_from_fx.py        |  60 ++++++++++
 14 files changed, 509 insertions(+), 37 deletions(-)

diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py 
b/python/tvm/contrib/msc/core/codegen/codegen.py
index 888f1bad4e..7e3ddd5e07 100644
--- a/python/tvm/contrib/msc/core/codegen/codegen.py
+++ b/python/tvm/contrib/msc/core/codegen/codegen.py
@@ -224,6 +224,7 @@ def relay_to_relax(
     trans_config: Optional[Dict[str, str]] = None,
     build_config: Optional[Dict[str, str]] = None,
     opt_config: Optional[Dict[str, str]] = None,
+    build_folder: msc_utils.MSCDirectory = None,
 ) -> tvm.IRModule:
     """Change relay IRModule to relax MSCGraph.
 
@@ -239,6 +240,8 @@ def relay_to_relax(
         The config for build MSCGraph.
     opt_config: dict
         The config for optimize the relay before translate.
+    build_folder: MSCDirectory
+        The folder for saving scripts and datas.
 
     Returns
     -------
@@ -254,4 +257,4 @@ def relay_to_relax(
         opt_config=opt_config,
     )
 
-    return to_relax(graph, weights, codegen_config={"from_relay": True})
+    return to_relax(graph, weights, codegen_config={"from_relay": True}, 
build_folder=build_folder)
diff --git a/python/tvm/contrib/msc/core/utils/dataset.py 
b/python/tvm/contrib/msc/core/utils/dataset.py
index e6461d1079..9f706dbf74 100644
--- a/python/tvm/contrib/msc/core/utils/dataset.py
+++ b/python/tvm/contrib/msc/core/utils/dataset.py
@@ -20,12 +20,13 @@
 import os
 import shutil
 import json
-from typing import List, Union, Dict, Any
+from typing import List, Union, Dict, Any, Tuple
 import numpy as np
 
 import tvm
 from .arguments import load_dict
 from .info import cast_array, is_array
+from .namespace import MSCFramework
 
 
 def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str], 
style="dict") -> Any:
@@ -64,6 +65,51 @@ def format_datas(datas: Union[List[Any], Dict[str, Any]], 
names: List[str], styl
     raise TypeError("Unexpected style " + str(style))
 
 
+def random_data(
+    info: Union[List, Tuple, dict],
+    framework: str = MSCFramework.MSC,
+    device: str = "cpu",
+    max_val: int = None,
+) -> Any:
+    """Create random data from info
+
+    Parameters
+    ----------
+    info: list| tuple| dict
+        The data info.
+    framework: str
+        The framework.
+    device: str
+        The device.
+    """
+
+    if isinstance(info, (tuple, list)):
+        if len(info) == 1:
+            info = {"name": "data", "shape": info[0], "dtype": "float32"}
+        elif len(info) == 2:
+            info = {"name": "data", "shape": info[0], "dtype": info[1]}
+        elif len(info) == 3:
+            info = {"name": info[0], "shape": info[1], "dtype": info[2]}
+        else:
+            raise Exception("Unexpected info " + str(info))
+    assert isinstance(info, dict) and all(
+        key in info for key in ["shape", "dtype"]
+    ), "shape and dtype should be given to create random data"
+    if info["dtype"] in ("int32", "int64"):
+        if max_val is None:
+            data = np.zeros(info["shape"]).astype(info["dtype"])
+        else:
+            data = np.random.randint(0, high=max_val, 
size=info["shape"]).astype(info["dtype"])
+    elif info["dtype"] == "bool":
+        data = np.random.rand(*info["shape"]).astype("float32")
+        data = np.where(data >= 0.5, True, False)
+    else:
+        data = np.random.rand(*info["shape"]).astype(info["dtype"])
+        if max_val is not None:
+            data *= max_val
+    return cast_array(data, framework, device=device)
+
+
 class BaseDataLoader(object):
     """Basic dataset loader for MSC
 
diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py 
b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
index c8c2844c28..04597bd341 100644
--- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
@@ -25,6 +25,7 @@ from tvm.relax.frontend.torch import from_fx
 from tvm.contrib.msc.core.ir.graph import MSCGraph
 from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs
 from tvm.contrib.msc.core.codegen import relay_to_relax
+from tvm.contrib.msc.core import utils as msc_utils
 
 
 def set_weight_alias(graph: MSCGraph) -> MSCGraph:
@@ -70,6 +71,7 @@ def from_torch(
     opt_config: Optional[Dict[str, str]] = None,
     as_msc: bool = True,
     custom_convert_map: dict = None,
+    build_folder: msc_utils.MSCDirectory = None,
 ) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
     """Change torch nn.Module to MSCGraph.
 
@@ -93,6 +95,8 @@ def from_torch(
         Set to True to return msc graph, otherwise relax mod
     custom_convert_map: dict
         The convert map for plugin
+    build_folder: MSCDirectory
+        The folder for saving scripts and datas.
 
     Returns
     -------
@@ -102,9 +106,15 @@ def from_torch(
         The weights from the IRModule.
     """
 
+    # try to symbolic_trace
     if via_relax:
-        input_info = normalize_inputs(input_info)
-        graph_model, params = torch.fx.symbolic_trace(model), None
+        try:
+            graph_model = torch.fx.symbolic_trace(model)
+        except:  # pylint: disable=bare-except
+            via_relax = False
+
+    if via_relax:
+        input_info, params = normalize_inputs(input_info), None
         with torch.no_grad():
             relax_mod = from_fx(graph_model, input_info, 
custom_convert_map=custom_convert_map)
     else:
@@ -122,7 +132,9 @@ def from_torch(
         relay_mod, params = tvm.relay.frontend.from_pytorch(
             scripted_model, shape_list, custom_convert_map=custom_convert_map
         )
-        relax_mod = relay_to_relax(relay_mod, params, trans_config, 
build_config, opt_config)
+        relax_mod = relay_to_relax(
+            relay_mod, params, trans_config, build_config, opt_config, 
build_folder=build_folder
+        )
     if not as_msc:
         return relax_mod, params
     graph, weights = from_relax(relax_mod, trans_config=trans_config, 
build_config=build_config)
diff --git a/python/tvm/contrib/msc/pipeline/pipeline.py 
b/python/tvm/contrib/msc/pipeline/pipeline.py
index e003f69224..09fc7727a6 100644
--- a/python/tvm/contrib/msc/pipeline/pipeline.py
+++ b/python/tvm/contrib/msc/pipeline/pipeline.py
@@ -21,7 +21,6 @@ import os
 import json
 from typing import Any, Union, List, Tuple
 import traceback
-import numpy as np
 
 from tvm.contrib.msc.core.tools import get_tool_cls, BaseTool
 from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey
@@ -678,7 +677,7 @@ class BasePipeline(object):
             def get_random():
                 def _to_data(inp):
                     shape = [1 if isinstance(d, str) else d for d in inp[1]]
-                    return np.random.rand(*shape).astype(inp[2])
+                    return msc_utils.random_data([shape, inp[2]])
 
                 for _ in range(max_batch):
                     yield {i[0]: _to_data(i) for i in self._config["inputs"]}
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 322ee04e0c..d84993c68d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -783,6 +783,20 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else 
args[1:]
         return self.block_builder.emit(relax.op.reshape(x, dims))
 
+    def _scatter(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        if len(node.args) == 1:
+            dim = node.kwargs["dim"]
+            index = self.env[node.kwargs["index"]]
+            src = self.env[node.kwargs["src"]]
+        elif len(node.args) == 4:
+            dim = node.args[1]
+            index = self.env[node.args[2]]
+            src = self.env[node.args[3]]
+        else:
+            raise Exception("Unexpected args " + str(node.args))
+        return self.block_builder.emit(relax.op.scatter_elements(x, index, 
src, axis=dim))
+
     def _split(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         split_size = node.args[1]
@@ -801,6 +815,24 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 
None)
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
+    def _stack(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        in_args = args[0]
+        assert all(
+            a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis] 
for a in in_args[1:]
+        ), "Expect all dim at {} to be the same, get {}".format(
+            axis, [a.struct_info.shape for a in args]
+        )
+        cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis))
+        s_shape = []
+        for idx, s in enumerate(cat.struct_info.shape):
+            if idx == axis:
+                s_shape.extend([len(in_args), 
in_args[0].struct_info.shape[axis]])
+            else:
+                s_shape.append(s)
+        return self.block_builder.emit(relax.op.reshape(cat, s_shape))
+
     def _tile(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 9fbc95fa7c..746010a4dc 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -676,9 +676,11 @@ class TorchFXImporter(BaseFXGraphImporter):
             "permute": self._permute,
             "repeat": self._repeat,
             "reshape": self._reshape,
+            "scatter": self._scatter,
             "size": self._size,
             "split": self._split,
             "squeeze": self._squeeze,
+            "stack": self._stack,
             "tile": self._tile,
             "transpose": self._transpose,
             "unsqueeze": lambda node: self.block_builder.emit(
diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc 
b/src/contrib/msc/framework/torch/torch_opcode.cc
index 9ae825b804..abac3682fb 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.cc
+++ b/src/contrib/msc/framework/torch/torch_opcode.cc
@@ -214,14 +214,28 @@ class TorchConstantCodeGen : public TorchOpCode {
 
  protected:
   void CodeGenInit() final {
+    const auto& dtype = node()->OutputAt(0)->DTypeName();
+    const auto& ref_name = StringUtils::Replace(node()->name, ".", "_");
     if (node()->HasAttr("scalar")) {
-      if (node()->OutputAt(0)->DTypeName() == "int32") {
+      if (dtype == "int32") {
         stack_.assign(module_ref(), node()->GetTypeAttr<int>("scalar"));
-      } else if (node()->OutputAt(0)->DTypeName() == "int64") {
+      } else if (dtype == "int64") {
         stack_.assign(module_ref(), node()->GetTypeAttr<int64_t>("scalar"));
-      } else if (node()->OutputAt(0)->DTypeName() == "float32") {
+      } else if (dtype == "float32") {
         stack_.assign(module_ref(), node()->GetTypeAttr<float>("scalar"));
       }
+    } else if (dtype == "int32") {
+      stack_.func_call("register_buffer", "", "self")
+          .call_arg(DocUtils::ToStr(ref_name))
+          .inplace_start("torch.IntTensor")
+          .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
+          .inplace_end();
+    } else if (dtype == "int64") {
+      stack_.func_call("register_buffer", "", "self")
+          .call_arg(DocUtils::ToStr(ref_name))
+          .inplace_start("torch.LongTensor")
+          .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
+          .inplace_end();
     } else {
       stack_.func_call("torch.Tensor", "data")
           .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
@@ -565,6 +579,39 @@ class TorchSimpleCodeGen : public TorchOpCode {
   TORCH_OP_CODEGEN_METHODS(TorchSimpleCodeGen);
 };
 
+class TorchScatterElementsCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchScatterElementsCodeGen)
+
+ protected:
+  void CodeGenForward() final {
+    if (node()->InputAt(1)->DTypeName() == "int32") {
+      stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
+    }
+    stack_.op_call()
+        .op_input_arg()
+        .op_arg<int>("axis", "dim")
+        .op_input_arg(1, "index")
+        .op_input_arg(2, "src");
+  }
+};
+
+class TorchScatterNDCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchScatterNDCodeGen)
+
+ protected:
+  void CodeGenForward() final {
+    if (node()->InputAt(1)->DTypeName() == "int32") {
+      stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
+    }
+    // relax add extra dim for indices
+    if (node()->InputAt(1)->Ndim() == node()->OutputAt(0)->Ndim()) {
+      stack_.func_call("squeeze", IdxInput(1), IdxInput(1)).call_arg(-1);
+    }
+    stack_.assign(DocUtils::ToIndex(IdxInput(0), IdxInput(1)), IdxInput(2))
+        .assign(IdxNode(), IdxInput(0));
+  }
+};
+
 class TorchSplitCodeGen : public TorchOpCode {
   TORCH_OP_CODEGEN_METHODS(TorchSplitCodeGen)
 
@@ -719,6 +766,9 @@ const std::shared_ptr<std::unordered_map<String, 
std::shared_ptr<TorchOpCode>>>
   map->emplace("permute_dims", std::make_shared<TorchPermuteDimsCodeGen>("", 
"torch.permute"));
   map->emplace("repeat", std::make_shared<TorchRepeatCodeGen>("", "repeat"));
   map->emplace("reshape", std::make_shared<TorchReshapeCodeGen>("", 
"torch.reshape"));
+  map->emplace("scatter_elements",
+               std::make_shared<TorchScatterElementsCodeGen>("", 
"torch.scatter"));
+  map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
   map->emplace("split", std::make_shared<TorchSplitCodeGen>("", 
"torch.split"));
   map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", 
""));
 
diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc 
b/src/contrib/msc/framework/tvm/relax_opcode.cc
index 73722f9877..a4be884858 100644
--- a/src/contrib/msc/framework/tvm/relax_opcode.cc
+++ b/src/contrib/msc/framework/tvm/relax_opcode.cc
@@ -568,6 +568,34 @@ class RelaxReshapeCodeGen : public RelaxOpCode {
   }
 };
 
+class RelaxScatterElementsCodeGen : public RelaxOpCode {
+  RELAX_OP_CODEGEN_METHODS(RelaxScatterElementsCodeGen)
+
+ protected:
+  void CodeGenBuild() final { 
stack_.op_call().op_inputs_arg(false).op_arg<int>("axis"); }
+};
+
+class RelaxScatterNDCodeGen : public RelaxOpCode {
+  RELAX_OP_CODEGEN_METHODS(RelaxScatterNDCodeGen)
+
+ protected:
+  void CodeGenBuild() final {
+    if (config()->from_relay) {
+      size_t ndim = node()->InputAt(1)->Ndim();
+      std::vector<size_t> axes;
+      axes.push_back(ndim - 1);
+      for (size_t i = 0; i < ndim - 1; i++) {
+        axes.push_back(i);
+      }
+      stack_.func_call("relax.op.permute_dims", IdxInput(1))
+          .call_arg(IdxInput(1))
+          .call_arg(DocUtils::ToList(axes));
+      BuilderEmit(IdxInput(1), "permute_" + std::to_string(node()->index));
+    }
+    stack_.op_call().op_inputs_arg(false).op_str_arg("mode", "reduction");
+  }
+};
+
 class RelaxResize2dCodeGen : public RelaxOpCode {
   RELAX_OP_CODEGEN_METHODS(RelaxResize2dCodeGen)
 
@@ -626,6 +654,20 @@ class RelaxSplitCodeGen : public RelaxOpCode {
   }
 };
 
+class RelaxStackCodeGen : public RelaxOpCode {
+  RELAX_OP_CODEGEN_METHODS(RelaxStackCodeGen)
+
+ protected:
+  void CodeGenBuild() final {
+    stack_.op_call().op_inputs_arg().op_arg<int>("axis");
+    BuilderEmit(IdxNode(), "cat_" + std::to_string(node()->index));
+    const auto& out_shape = GetPrims(node()->OutputAt(0));
+    stack_.func_call("relax.op.reshape", IdxNode())
+        .call_arg(IdxNode())
+        .call_arg(DocUtils::ToList(out_shape), "shape");
+  }
+};
+
 class RelaxTakeCodeGen : public RelaxOpCode {
   RELAX_OP_CODEGEN_METHODS(RelaxTakeCodeGen)
 
@@ -763,7 +805,11 @@ const std::shared_ptr<std::unordered_map<String, 
std::shared_ptr<RelaxOpCode>>>
   map->emplace("permute_dims", 
std::make_shared<RelaxPermuteDimsCodeGen>("relax.op.permute_dims"));
   map->emplace("repeat", 
std::make_shared<RelaxRepeatCodeGen>("relax.op.repeat"));
   map->emplace("reshape", 
std::make_shared<RelaxReshapeCodeGen>("relax.op.reshape"));
+  map->emplace("scatter_elements",
+               
std::make_shared<RelaxScatterElementsCodeGen>("relax.op.scatter_elements"));
+  map->emplace("scatter_nd", 
std::make_shared<RelaxScatterNDCodeGen>("relax.op.scatter_nd"));
   map->emplace("split", std::make_shared<RelaxSplitCodeGen>("relax.op.split"));
+  map->emplace("stack", 
std::make_shared<RelaxStackCodeGen>("relax.op.concat"));
   map->emplace("strided_slice",
                
std::make_shared<RelaxStridedSliceCodeGen>("relax.op.strided_slice"));
   map->emplace("take", std::make_shared<RelaxTakeCodeGen>("relax.op.take"));
diff --git a/tests/python/contrib/test_msc/test_graph_build.py 
b/tests/python/contrib/test_msc/test_graph_build.py
index 76e3147a55..647879378e 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -20,21 +20,16 @@
 
 import pytest
 import torch
-from torch import fx
 from torch.nn import Module
 
 import tvm.testing
-from tvm.relax.frontend.torch import from_fx
-from tvm.contrib.msc.core.frontend import translate, normalize_inputs
+from tvm.contrib.msc.framework.torch.frontend import translate
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
 from tvm.contrib.msc.core import utils as msc_utils
 
 
 def verify_model(torch_model, input_info, expected):
-    input_info = normalize_inputs(input_info)
-    graph_model = fx.symbolic_trace(torch_model)
-    with torch.no_grad():
-        mod = from_fx(graph_model, input_info)
-    graph, _ = translate.from_relax(mod)
+    graph, _ = translate.from_torch(torch_model, input_info)
     inspect = graph.inspect()
     assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with 
expected {}".format(
         inspect, expected
@@ -2389,6 +2384,119 @@ def test_cat(dynamic):
     verify_model(Cat2(), [([1, 3, 10, 10], "float32")], expected2)
 
 
[email protected]("dynamic", [True, False])
+def test_stack(dynamic):
+    """test graph builder for stack"""
+
+    bz = "bz" if dynamic else 1
+
+    class Stack(Module):
+        def forward(self, data, data1, data2):
+            return torch.stack((data, data1, data2), dim=0)
+
+    input_info = [
+        ([bz, 3, 10, 10], "float32"),
+        ([bz, 3, 10, 10], "float32"),
+        ([bz, 3, 10, 10], "float32"),
+    ]
+
+    expected = {
+        "inputs": [
+            {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", 
"layout": ""},
+            {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", 
"layout": ""},
+            {"name": "inp_2", "shape": [bz, 3, 10, 10], "dtype": "float32", 
"layout": ""},
+        ],
+        "outputs": [
+            {
+                "name": "reshape",
+                "shape": [3, bz, 3, 10, 10],
+                "dtype": "float32",
+                "layout": "" if dynamic else "EABCD",
+            }
+        ],
+        "nodes": {"total": 5, "input": 3, "concat": 1, "reshape": 1},
+    }
+
+    if dynamic:
+        expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1}
+
+    verify_model(Stack(), input_info, expected)
+
+
[email protected]("dynamic", [True, False])
+def test_scatter(dynamic):
+    """test graph builder for scatter"""
+
+    bz = "bz" if dynamic else 20
+
+    class Scatter1(Module):
+        def __init__(self):
+            super().__init__()
+            self.index = msc_utils.random_data([(2, 5), "int64"], 
MSCFramework.TORCH, max_val=5)
+
+        def forward(self, data, src):
+            return data.scatter(dim=0, index=self.index, src=src)
+
+    class Scatter2(Module):
+        def forward(self, data, index, src):
+            return data.scatter(0, index, src)
+
+    expected1 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": 
""},
+            {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": 
""},
+        ],
+        "outputs": [
+            {"name": "scatter_elements", "shape": [bz, 20], "dtype": 
"float32", "layout": ""}
+        ],
+        "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_elements": 
1},
+    }
+    expected2 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": 
""},
+            {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": ""},
+            {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": 
""},
+        ],
+        "outputs": [
+            {"name": "scatter_elements", "shape": [bz, 20], "dtype": 
"float32", "layout": ""}
+        ],
+        "nodes": {"total": 4, "input": 3, "scatter_elements": 1},
+    }
+    if dynamic:
+        expected1["prims"] = {"total": 1, "shape": 1}
+        expected2["prims"] = {"total": 1, "shape": 1}
+
+    verify_model(Scatter1(), [([bz, 20], "float32"), ([2, 5], "float32")], 
expected1)
+    verify_model(
+        Scatter2(), [([bz, 20], "float32"), ([2, 5], "int64"), ([2, 5], 
"float32")], expected2
+    )
+
+
+def test_put():
+    """test graph builder for index_put"""
+
+    class IndexPut(Module):
+        def __init__(self):
+            super().__init__()
+            self.index = msc_utils.random_data([(5), "int64"], 
MSCFramework.TORCH, max_val=5)
+
+        def forward(self, data, src):
+            data[self.index] = src
+            return data
+
+    expected = {
+        "inputs": [
+            {"name": "input0", "shape": [10, 20], "dtype": "float32", 
"layout": ""},
+            {"name": "input1", "shape": [5, 20], "dtype": "float32", "layout": 
""},
+        ],
+        "outputs": [{"name": "scatter_nd", "shape": [10, 20], "dtype": 
"float32", "layout": ""}],
+        "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_nd": 1},
+    }
+
+    input_info = [([10, 20], "float32"), ([5, 20], "float32")]
+    verify_model(IndexPut(), input_info, expected)
+
+
 @pytest.mark.parametrize("dynamic", [True, False])
 def test_attention(dynamic):
     """test graph builder for attention"""
diff --git a/tests/python/contrib/test_msc/test_translate_relax.py 
b/tests/python/contrib/test_msc/test_translate_relax.py
index 64d00bb092..27a02844e1 100644
--- a/tests/python/contrib/test_msc/test_translate_relax.py
+++ b/tests/python/contrib/test_msc/test_translate_relax.py
@@ -18,27 +18,26 @@
 """ Test translate from relax. """
 
 import torch
-from torch import fx
 from torch.nn import Module
 
 import numpy as np
 
 import tvm.testing
-from tvm.relax.frontend.torch import from_fx
-from tvm.contrib.msc.core.frontend import translate
+from tvm.contrib.msc.framework.torch.frontend import translate as 
torch_translate
+
 from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
+from tvm.contrib.msc.core.frontend import translate as core_translate
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
 
 
 def verify_model(torch_model, input_info, opt_config=None):
     """Compare torch module IR"""
 
-    graph_model = fx.symbolic_trace(torch_model)
-    with torch.no_grad():
-        orig_mod = from_fx(graph_model, input_info)
-
+    orig_mod, _ = torch_translate.from_torch(torch_model, input_info, 
as_msc=False)
     target = "llvm"
     dev = tvm.cpu()
-    args = [tvm.nd.array(np.random.random(size=shape).astype(dtype)) for 
shape, dtype in input_info]
+    args = [msc_utils.random_data(i, MSCFramework.TVM) for i in input_info]
 
     def _tvm_runtime_to_np(obj):
         if isinstance(obj, tvm.runtime.NDArray):
@@ -60,7 +59,7 @@ def verify_model(torch_model, input_info, opt_config=None):
         return _tvm_runtime_to_np(res)
 
     rt_mod = tvm_codegen.to_relax(
-        *translate.from_relax(orig_mod, opt_config=opt_config),
+        *core_translate.from_relax(orig_mod, opt_config=opt_config),
         codegen_config={"explicit_name": False},
     )
 
@@ -1153,6 +1152,63 @@ def test_cat():
     verify_model(Cat2(), [([1, 3, 10, 10], "float32")])
 
 
+def test_stack():
+    """test relax translator for stack"""
+
+    class Stack1(Module):
+        def forward(self, data, data1, data2):
+            return torch.stack((data, data1, data2), dim=0)
+
+    class Stack2(Module):
+        def forward(self, data):
+            const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
+            const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
+            return torch.stack((data, const1, const2), dim=1)
+
+    input_info = [
+        ([1, 3, 10, 10], "float32"),
+        ([1, 3, 10, 10], "float32"),
+        ([1, 3, 10, 10], "float32"),
+    ]
+    verify_model(Stack1(), input_info)
+    verify_model(Stack2(), [([1, 3, 10, 10], "float32")])
+
+
+def test_scatter():
+    """test relax translator for scatter"""
+
+    class Scatter1(Module):
+        def __init__(self):
+            super().__init__()
+            self.index = msc_utils.random_data([(2, 5), "int64"], 
MSCFramework.TORCH, max_val=5)
+
+        def forward(self, data, src):
+            return data.scatter(dim=0, index=self.index, src=src)
+
+    class Scatter2(Module):
+        def forward(self, data, index, src):
+            return data.scatter(0, index, src)
+
+    verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")])
+    verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 
5], "float32")])
+
+
+def test_put():
+    """test relax translator for index_put"""
+
+    class IndexPut(Module):
+        def __init__(self):
+            super().__init__()
+            self.index = msc_utils.random_data([(5), "int64"], 
MSCFramework.TORCH, max_val=5)
+
+        def forward(self, data, src):
+            data[self.index] = src
+            return data
+
+    input_info = [([10, 20], "float32"), ([5, 20], "float32")]
+    verify_model(IndexPut(), input_info)
+
+
 def test_attention():
     """test relax translator for attention"""
 
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py 
b/tests/python/contrib/test_msc/test_translate_relay.py
index ebba339a4a..801893e9de 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -18,8 +18,6 @@
 
 """ Test translate from relay. """
 
-import numpy as np
-
 import torch
 from torch import fx
 from torch.nn import Module
@@ -66,7 +64,7 @@ def verify_model(torch_model, input_info, opt_config=None, 
codegen_config=None,
     expected = tvm.relax.transform.CanonicalizeBindings()(expected)
 
     # graph from relay
-    datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
+    datas = [msc_utils.random_data(i) for i in input_info]
     torch_datas = [torch.from_numpy(i) for i in datas]
     with torch.no_grad():
         scripted_model = torch.jit.trace(torch_model, 
tuple(torch_datas)).eval()  # type: ignore
diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py 
b/tests/python/contrib/test_msc/test_translate_tensorrt.py
index 6d87ca8753..e0fd39249a 100644
--- a/tests/python/contrib/test_msc/test_translate_tensorrt.py
+++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py
@@ -18,7 +18,6 @@
 """ Test translate for TensorrRT. """
 
 import pytest
-import numpy as np
 
 import torch
 from torch import fx
@@ -91,7 +90,7 @@ def verify_model(torch_model, input_info, **trans_config):
     """Build model and verify results"""
 
     graph_model = fx.symbolic_trace(torch_model)
-    datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
+    datas = [msc_utils.random_data(i) for i in input_info]
     torch_datas = [torch.from_numpy(i) for i in datas]
     with torch.no_grad():
         golden = torch_model(*torch_datas)
diff --git a/tests/python/contrib/test_msc/test_translate_torch.py 
b/tests/python/contrib/test_msc/test_translate_torch.py
index 55bae682ef..6ed28c0ac0 100644
--- a/tests/python/contrib/test_msc/test_translate_torch.py
+++ b/tests/python/contrib/test_msc/test_translate_torch.py
@@ -17,24 +17,24 @@
 
 """ Test translate from torch. """
 
-import numpy as np
-
 import torch
 from torch.nn import Module
 
 import tvm.testing
 from tvm.contrib.msc.framework.torch.frontend import translate
 from tvm.contrib.msc.framework.torch import codegen
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
 
 
 def verify_model(torch_model, input_info, via_relax=True):
     """Compare torch module results"""
 
-    graph, weights = translate.from_torch(torch_model, input_info, 
via_relax=via_relax)
-    model = codegen.to_torch(graph, weights)
-    torch_datas = [torch.from_numpy(np.random.rand(*i[0]).astype(i[1])) for i 
in input_info]
+    torch_datas = [msc_utils.random_data(i, MSCFramework.TORCH) for i in 
input_info]
     with torch.no_grad():
         golden = torch_model(*torch_datas)
+    graph, weights = translate.from_torch(torch_model, input_info, 
via_relax=via_relax)
+    model = codegen.to_torch(graph, weights)
     with torch.no_grad():
         if not graph.get_inputs():
             result = model()
@@ -1128,6 +1128,67 @@ def test_cat():
         verify_model(Cat2(), [([1, 3, 10, 10], "float32")], via_relax)
 
 
+def test_stack():
+    """test torch translator for stack"""
+
+    class Stack1(Module):
+        def forward(self, data, data1, data2):
+            return torch.stack((data, data1, data2), dim=0)
+
+    class Stack2(Module):
+        def forward(self, data):
+            const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
+            const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
+            return torch.stack((data, const1, const2), dim=1)
+
+    input_info = [
+        ([1, 3, 10, 10], "float32"),
+        ([1, 3, 10, 10], "float32"),
+        ([1, 3, 10, 10], "float32"),
+    ]
+    for via_relax in [True, False]:
+        verify_model(Stack1(), input_info, via_relax)
+        verify_model(Stack2(), [([1, 3, 10, 10], "float32")], via_relax)
+
+
+def test_scatter():
+    """test torch translator for scatter"""
+
+    class Scatter1(Module):
+        def __init__(self):
+            super().__init__()
+            self.index = msc_utils.random_data([(2, 5), "int64"], 
MSCFramework.TORCH, max_val=5)
+
+        def forward(self, data, src):
+            return data.scatter(dim=0, index=self.index, src=src)
+
+    class Scatter2(Module):
+        def forward(self, data, index, src):
+            return data.scatter(0, index, src)
+
+    for via_relax in [True, False]:
+        verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")], 
via_relax)
+        verify_model(
+            Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], 
"float32")], via_relax
+        )
+
+
+def test_put():
+    """test torch translator for index_put"""
+
+    class IndexPut(Module):
+        def __init__(self):
+            super().__init__()
+            self.index = msc_utils.random_data([(5), "int64"], 
MSCFramework.TORCH, max_val=5)
+
+        def forward(self, data, src):
+            data[self.index] = src
+            return data
+
+    input_info = [([10, 20], "float32"), ([5, 20], "float32")]
+    verify_model(IndexPut(), input_info, False)
+
+
 def test_attention():
     """test torch translator for attention"""
 
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 2cabcba325..08331f0861 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3963,5 +3963,65 @@ def test_sym_size_int():
     verify_model(SymSizeInt1(dim=-2), [([1, 3, 4], "float32")], {}, Expected1)
 
 
+def test_stack():
+
+    input_info = [
+        ([1, 3, 10, 10], "float32"),
+        ([1, 3, 10, 10], "float32"),
+        ([1, 3, 10, 10], "float32"),
+    ]
+
+    class Stack(Module):
+        def forward(self, data, data1, data2):
+            return torch.stack((data, data1, data2), dim=0)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            inp_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            inp_2: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((3, 1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 3, 10, 10), dtype="float32") = R.concat(
+                    (inp_0, inp_1, inp_2), axis=0
+                )
+                lv1: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = R.reshape(
+                    lv, R.shape([3, 1, 3, 10, 10])
+                )
+                gv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(Stack(), input_info, {}, expected)
+
+
+def test_scatter():
+    input_info = [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], 
"float32")]
+
+    class Scatter(Module):
+        def forward(self, data, index, src):
+            return data.scatter(dim=0, index=index, src=src)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((20, 20), dtype="float32"),
+            inp_1: R.Tensor((2, 5), dtype="int64"),
+            inp_2: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tensor((20, 20), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((20, 20), dtype="float32") = R.scatter_elements(
+                    inp_0, inp_1, inp_2, axis=0, reduction="update"
+                )
+                gv: R.Tensor((20, 20), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Scatter(), input_info, {}, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to