This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e3665ae2cd [Relax] support scatter ops (#17509)
e3665ae2cd is described below
commit e3665ae2cd3a4764d6ac58b2139a312af8d63216
Author: Archermmt <[email protected]>
AuthorDate: Mon Nov 11 21:00:43 2024 +0800
[Relax] support scatter ops (#17509)
* support scatter ops
* format fix
* format fix
---
python/tvm/contrib/msc/core/codegen/codegen.py | 5 +-
python/tvm/contrib/msc/core/utils/dataset.py | 48 +++++++-
.../msc/framework/torch/frontend/translate.py | 18 ++-
python/tvm/contrib/msc/pipeline/pipeline.py | 3 +-
.../frontend/torch/base_fx_graph_translator.py | 32 ++++++
python/tvm/relax/frontend/torch/fx_translator.py | 2 +
src/contrib/msc/framework/torch/torch_opcode.cc | 56 +++++++++-
src/contrib/msc/framework/tvm/relax_opcode.cc | 46 ++++++++
tests/python/contrib/test_msc/test_graph_build.py | 124 +++++++++++++++++++--
.../contrib/test_msc/test_translate_relax.py | 74 ++++++++++--
.../contrib/test_msc/test_translate_relay.py | 4 +-
.../contrib/test_msc/test_translate_tensorrt.py | 3 +-
.../contrib/test_msc/test_translate_torch.py | 71 +++++++++++-
tests/python/relax/test_frontend_from_fx.py | 60 ++++++++++
14 files changed, 509 insertions(+), 37 deletions(-)
diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py
index 888f1bad4e..7e3ddd5e07 100644
--- a/python/tvm/contrib/msc/core/codegen/codegen.py
+++ b/python/tvm/contrib/msc/core/codegen/codegen.py
@@ -224,6 +224,7 @@ def relay_to_relax(
trans_config: Optional[Dict[str, str]] = None,
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
+ build_folder: msc_utils.MSCDirectory = None,
) -> tvm.IRModule:
"""Change relay IRModule to relax MSCGraph.
@@ -239,6 +240,8 @@ def relay_to_relax(
The config for build MSCGraph.
opt_config: dict
The config for optimize the relay before translate.
+ build_folder: MSCDirectory
+ The folder for saving scripts and datas.
Returns
-------
@@ -254,4 +257,4 @@ def relay_to_relax(
opt_config=opt_config,
)
- return to_relax(graph, weights, codegen_config={"from_relay": True})
+ return to_relax(graph, weights, codegen_config={"from_relay": True},
build_folder=build_folder)
diff --git a/python/tvm/contrib/msc/core/utils/dataset.py b/python/tvm/contrib/msc/core/utils/dataset.py
index e6461d1079..9f706dbf74 100644
--- a/python/tvm/contrib/msc/core/utils/dataset.py
+++ b/python/tvm/contrib/msc/core/utils/dataset.py
@@ -20,12 +20,13 @@
import os
import shutil
import json
-from typing import List, Union, Dict, Any
+from typing import List, Union, Dict, Any, Tuple
import numpy as np
import tvm
from .arguments import load_dict
from .info import cast_array, is_array
+from .namespace import MSCFramework
def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str],
style="dict") -> Any:
@@ -64,6 +65,51 @@ def format_datas(datas: Union[List[Any], Dict[str, Any]],
names: List[str], styl
raise TypeError("Unexpected style " + str(style))
+def random_data(
+ info: Union[List, Tuple, dict],
+ framework: str = MSCFramework.MSC,
+ device: str = "cpu",
+ max_val: int = None,
+) -> Any:
+ """Create random data from info
+
+ Parameters
+ ----------
+ info: list| tuple| dict
+ The data info.
+ framework: str
+ The framework.
+ device: str
+ The device.
+ """
+
+ if isinstance(info, (tuple, list)):
+ if len(info) == 1:
+ info = {"name": "data", "shape": info[0], "dtype": "float32"}
+ elif len(info) == 2:
+ info = {"name": "data", "shape": info[0], "dtype": info[1]}
+ elif len(info) == 3:
+ info = {"name": info[0], "shape": info[1], "dtype": info[2]}
+ else:
+ raise Exception("Unexpected info " + str(info))
+ assert isinstance(info, dict) and all(
+ key in info for key in ["shape", "dtype"]
+ ), "shape and dtype should be given to create randome data"
+ if info["dtype"] in ("int32", "int64"):
+ if max_val is None:
+ data = np.zeros(info["shape"]).astype(info["dtype"])
+ else:
+ data = np.random.randint(0, high=max_val,
size=info["shape"]).astype(info["dtype"])
+ elif info["dtype"] == "bool":
+ data = np.random.rand(*info["shape"]).astype("float32")
+ data = np.where(data >= 0.5, True, False)
+ else:
+ data = np.random.rand(*info["shape"]).astype(info["dtype"])
+ if max_val is not None:
+ data *= max_val
+ return cast_array(data, framework, device=device)
+
+
class BaseDataLoader(object):
"""Basic dataset loader for MSC
diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
index c8c2844c28..04597bd341 100644
--- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
@@ -25,6 +25,7 @@ from tvm.relax.frontend.torch import from_fx
from tvm.contrib.msc.core.ir.graph import MSCGraph
from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs
from tvm.contrib.msc.core.codegen import relay_to_relax
+from tvm.contrib.msc.core import utils as msc_utils
def set_weight_alias(graph: MSCGraph) -> MSCGraph:
@@ -70,6 +71,7 @@ def from_torch(
opt_config: Optional[Dict[str, str]] = None,
as_msc: bool = True,
custom_convert_map: dict = None,
+ build_folder: msc_utils.MSCDirectory = None,
) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
"""Change torch nn.Module to MSCGraph.
@@ -93,6 +95,8 @@ def from_torch(
Set to to return msc graph, otherwise relax mod
custom_convert_map: dict
The convert map for plugin
+ build_folder: MSCDirectory
+ The folder for saving scripts and datas.
Returns
-------
@@ -102,9 +106,15 @@ def from_torch(
The weights from the IRModule.
"""
+ # try to symbolic_trace
if via_relax:
- input_info = normalize_inputs(input_info)
- graph_model, params = torch.fx.symbolic_trace(model), None
+ try:
+ graph_model = torch.fx.symbolic_trace(model)
+ except: # pylint: disable=bare-except
+ via_relax = False
+
+ if via_relax:
+ input_info, params = normalize_inputs(input_info), None
with torch.no_grad():
relax_mod = from_fx(graph_model, input_info,
custom_convert_map=custom_convert_map)
else:
@@ -122,7 +132,9 @@ def from_torch(
relay_mod, params = tvm.relay.frontend.from_pytorch(
scripted_model, shape_list, custom_convert_map=custom_convert_map
)
- relax_mod = relay_to_relax(relay_mod, params, trans_config,
build_config, opt_config)
+ relax_mod = relay_to_relax(
+ relay_mod, params, trans_config, build_config, opt_config,
build_folder=build_folder
+ )
if not as_msc:
return relax_mod, params
graph, weights = from_relax(relax_mod, trans_config=trans_config,
build_config=build_config)
diff --git a/python/tvm/contrib/msc/pipeline/pipeline.py b/python/tvm/contrib/msc/pipeline/pipeline.py
index e003f69224..09fc7727a6 100644
--- a/python/tvm/contrib/msc/pipeline/pipeline.py
+++ b/python/tvm/contrib/msc/pipeline/pipeline.py
@@ -21,7 +21,6 @@ import os
import json
from typing import Any, Union, List, Tuple
import traceback
-import numpy as np
from tvm.contrib.msc.core.tools import get_tool_cls, BaseTool
from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey
@@ -678,7 +677,7 @@ class BasePipeline(object):
def get_random():
def _to_data(inp):
shape = [1 if isinstance(d, str) else d for d in inp[1]]
- return np.random.rand(*shape).astype(inp[2])
+ return msc_utils.random_data([shape, inp[2]])
for _ in range(max_batch):
yield {i[0]: _to_data(i) for i in self._config["inputs"]}
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 322ee04e0c..d84993c68d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -783,6 +783,20 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else
args[1:]
return self.block_builder.emit(relax.op.reshape(x, dims))
+ def _scatter(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ if len(node.args) == 1:
+ dim = node.kwargs["dim"]
+ index = self.env[node.kwargs["index"]]
+ src = self.env[node.kwargs["src"]]
+ elif len(node.args) == 4:
+ dim = node.args[1]
+ index = self.env[node.args[2]]
+ src = self.env[node.args[3]]
+ else:
+ raise Exception("Unexpected args " + str(node.args))
+ return self.block_builder.emit(relax.op.scatter_elements(x, index,
src, axis=dim))
+
def _split(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
split_size = node.args[1]
@@ -801,6 +815,24 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
None)
return self.block_builder.emit(relax.op.squeeze(x, dim))
+ def _stack(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+ in_args = args[0]
+ assert all(
+ a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis]
for a in in_args[1:]
+ ), "Expect all dim at {} to be the same, get {}".format(
+ axis, [a.struct_info.shape for a in args]
+ )
+ cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis))
+ s_shape = []
+ for idx, s in enumerate(cat.struct_info.shape):
+ if idx == axis:
+ s_shape.extend([len(in_args),
in_args[0].struct_info.shape[axis]])
+ else:
+ s_shape.append(s)
+ return self.block_builder.emit(relax.op.reshape(cat, s_shape))
+
def _tile(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 9fbc95fa7c..746010a4dc 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -676,9 +676,11 @@ class TorchFXImporter(BaseFXGraphImporter):
"permute": self._permute,
"repeat": self._repeat,
"reshape": self._reshape,
+ "scatter": self._scatter,
"size": self._size,
"split": self._split,
"squeeze": self._squeeze,
+ "stack": self._stack,
"tile": self._tile,
"transpose": self._transpose,
"unsqueeze": lambda node: self.block_builder.emit(
diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc
index 9ae825b804..abac3682fb 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.cc
+++ b/src/contrib/msc/framework/torch/torch_opcode.cc
@@ -214,14 +214,28 @@ class TorchConstantCodeGen : public TorchOpCode {
protected:
void CodeGenInit() final {
+ const auto& dtype = node()->OutputAt(0)->DTypeName();
+ const auto& ref_name = StringUtils::Replace(node()->name, ".", "_");
if (node()->HasAttr("scalar")) {
- if (node()->OutputAt(0)->DTypeName() == "int32") {
+ if (dtype == "int32") {
stack_.assign(module_ref(), node()->GetTypeAttr<int>("scalar"));
- } else if (node()->OutputAt(0)->DTypeName() == "int64") {
+ } else if (dtype == "int64") {
stack_.assign(module_ref(), node()->GetTypeAttr<int64_t>("scalar"));
- } else if (node()->OutputAt(0)->DTypeName() == "float32") {
+ } else if (dtype == "float32") {
stack_.assign(module_ref(), node()->GetTypeAttr<float>("scalar"));
}
+ } else if (dtype == "int32") {
+ stack_.func_call("register_buffer", "", "self")
+ .call_arg(DocUtils::ToStr(ref_name))
+ .inplace_start("torch.IntTensor")
+ .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
+ .inplace_end();
+ } else if (dtype == "int64") {
+ stack_.func_call("register_buffer", "", "self")
+ .call_arg(DocUtils::ToStr(ref_name))
+ .inplace_start("torch.LongTensor")
+ .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
+ .inplace_end();
} else {
stack_.func_call("torch.Tensor", "data")
.call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
@@ -565,6 +579,39 @@ class TorchSimpleCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchSimpleCodeGen);
};
+class TorchScatterElementsCodeGen : public TorchOpCode {
+ TORCH_OP_CODEGEN_METHODS(TorchScatterElementsCodeGen)
+
+ protected:
+ void CodeGenForward() final {
+ if (node()->InputAt(1)->DTypeName() == "int32") {
+ stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
+ }
+ stack_.op_call()
+ .op_input_arg()
+ .op_arg<int>("axis", "dim")
+ .op_input_arg(1, "index")
+ .op_input_arg(2, "src");
+ }
+};
+
+class TorchScatterNDCodeGen : public TorchOpCode {
+ TORCH_OP_CODEGEN_METHODS(TorchScatterNDCodeGen)
+
+ protected:
+ void CodeGenForward() final {
+ if (node()->InputAt(1)->DTypeName() == "int32") {
+ stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
+ }
+ // relax add extra dim for indices
+ if (node()->InputAt(1)->Ndim() == node()->OutputAt(0)->Ndim()) {
+ stack_.func_call("squeeze", IdxInput(1), IdxInput(1)).call_arg(-1);
+ }
+ stack_.assign(DocUtils::ToIndex(IdxInput(0), IdxInput(1)), IdxInput(2))
+ .assign(IdxNode(), IdxInput(0));
+ }
+};
+
class TorchSplitCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchSplitCodeGen)
@@ -719,6 +766,9 @@ const std::shared_ptr<std::unordered_map<String,
std::shared_ptr<TorchOpCode>>>
map->emplace("permute_dims", std::make_shared<TorchPermuteDimsCodeGen>("",
"torch.permute"));
map->emplace("repeat", std::make_shared<TorchRepeatCodeGen>("", "repeat"));
map->emplace("reshape", std::make_shared<TorchReshapeCodeGen>("",
"torch.reshape"));
+ map->emplace("scatter_elements",
+ std::make_shared<TorchScatterElementsCodeGen>("",
"torch.scatter"));
+ map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
map->emplace("split", std::make_shared<TorchSplitCodeGen>("",
"torch.split"));
map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("",
""));
diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc
index 73722f9877..a4be884858 100644
--- a/src/contrib/msc/framework/tvm/relax_opcode.cc
+++ b/src/contrib/msc/framework/tvm/relax_opcode.cc
@@ -568,6 +568,34 @@ class RelaxReshapeCodeGen : public RelaxOpCode {
}
};
+class RelaxScatterElementsCodeGen : public RelaxOpCode {
+ RELAX_OP_CODEGEN_METHODS(RelaxScatterElementsCodeGen)
+
+ protected:
+ void CodeGenBuild() final {
stack_.op_call().op_inputs_arg(false).op_arg<int>("axis"); }
+};
+
+class RelaxScatterNDCodeGen : public RelaxOpCode {
+ RELAX_OP_CODEGEN_METHODS(RelaxScatterNDCodeGen)
+
+ protected:
+ void CodeGenBuild() final {
+ if (config()->from_relay) {
+ size_t ndim = node()->InputAt(1)->Ndim();
+ std::vector<size_t> axes;
+ axes.push_back(ndim - 1);
+ for (size_t i = 0; i < ndim - 1; i++) {
+ axes.push_back(i);
+ }
+ stack_.func_call("relax.op.permute_dims", IdxInput(1))
+ .call_arg(IdxInput(1))
+ .call_arg(DocUtils::ToList(axes));
+ BuilderEmit(IdxInput(1), "permute_" + std::to_string(node()->index));
+ }
+ stack_.op_call().op_inputs_arg(false).op_str_arg("mode", "reduction");
+ }
+};
+
class RelaxResize2dCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxResize2dCodeGen)
@@ -626,6 +654,20 @@ class RelaxSplitCodeGen : public RelaxOpCode {
}
};
+class RelaxStackCodeGen : public RelaxOpCode {
+ RELAX_OP_CODEGEN_METHODS(RelaxStackCodeGen)
+
+ protected:
+ void CodeGenBuild() final {
+ stack_.op_call().op_inputs_arg().op_arg<int>("axis");
+ BuilderEmit(IdxNode(), "cat_" + std::to_string(node()->index));
+ const auto& out_shape = GetPrims(node()->OutputAt(0));
+ stack_.func_call("relax.op.reshape", IdxNode())
+ .call_arg(IdxNode())
+ .call_arg(DocUtils::ToList(out_shape), "shape");
+ }
+};
+
class RelaxTakeCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxTakeCodeGen)
@@ -763,7 +805,11 @@ const std::shared_ptr<std::unordered_map<String,
std::shared_ptr<RelaxOpCode>>>
map->emplace("permute_dims",
std::make_shared<RelaxPermuteDimsCodeGen>("relax.op.permute_dims"));
map->emplace("repeat",
std::make_shared<RelaxRepeatCodeGen>("relax.op.repeat"));
map->emplace("reshape",
std::make_shared<RelaxReshapeCodeGen>("relax.op.reshape"));
+ map->emplace("scatter_elements",
+
std::make_shared<RelaxScatterElementsCodeGen>("relax.op.scatter_elements"));
+ map->emplace("scatter_nd",
std::make_shared<RelaxScatterNDCodeGen>("relax.op.scatter_nd"));
map->emplace("split", std::make_shared<RelaxSplitCodeGen>("relax.op.split"));
+ map->emplace("stack",
std::make_shared<RelaxStackCodeGen>("relax.op.concat"));
map->emplace("strided_slice",
std::make_shared<RelaxStridedSliceCodeGen>("relax.op.strided_slice"));
map->emplace("take", std::make_shared<RelaxTakeCodeGen>("relax.op.take"));
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 76e3147a55..647879378e 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -20,21 +20,16 @@
import pytest
import torch
-from torch import fx
from torch.nn import Module
import tvm.testing
-from tvm.relax.frontend.torch import from_fx
-from tvm.contrib.msc.core.frontend import translate, normalize_inputs
+from tvm.contrib.msc.framework.torch.frontend import translate
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils
def verify_model(torch_model, input_info, expected):
- input_info = normalize_inputs(input_info)
- graph_model = fx.symbolic_trace(torch_model)
- with torch.no_grad():
- mod = from_fx(graph_model, input_info)
- graph, _ = translate.from_relax(mod)
+ graph, _ = translate.from_torch(torch_model, input_info)
inspect = graph.inspect()
assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with
expected {}".format(
inspect, expected
@@ -2389,6 +2384,119 @@ def test_cat(dynamic):
verify_model(Cat2(), [([1, 3, 10, 10], "float32")], expected2)
+@pytest.mark.parametrize("dynamic", [True, False])
+def test_stack(dynamic):
+ """test graph builder for stack"""
+
+ bz = "bz" if dynamic else 1
+
+ class Stack(Module):
+ def forward(self, data, data1, data2):
+ return torch.stack((data, data1, data2), dim=0)
+
+ input_info = [
+ ([bz, 3, 10, 10], "float32"),
+ ([bz, 3, 10, 10], "float32"),
+ ([bz, 3, 10, 10], "float32"),
+ ]
+
+ expected = {
+ "inputs": [
+ {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32",
"layout": ""},
+ {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32",
"layout": ""},
+ {"name": "inp_2", "shape": [bz, 3, 10, 10], "dtype": "float32",
"layout": ""},
+ ],
+ "outputs": [
+ {
+ "name": "reshape",
+ "shape": [3, bz, 3, 10, 10],
+ "dtype": "float32",
+ "layout": "" if dynamic else "EABCD",
+ }
+ ],
+ "nodes": {"total": 5, "input": 3, "concat": 1, "reshape": 1},
+ }
+
+ if dynamic:
+ expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1}
+
+ verify_model(Stack(), input_info, expected)
+
+
+@pytest.mark.parametrize("dynamic", [True, False])
+def test_scatter(dynamic):
+ """test graph builder for scatter"""
+
+ bz = "bz" if dynamic else 20
+
+ class Scatter1(Module):
+ def __init__(self):
+ super().__init__()
+ self.index = msc_utils.random_data([(2, 5), "int64"],
MSCFramework.TORCH, max_val=5)
+
+ def forward(self, data, src):
+ return data.scatter(dim=0, index=self.index, src=src)
+
+ class Scatter2(Module):
+ def forward(self, data, index, src):
+ return data.scatter(0, index, src)
+
+ expected1 = {
+ "inputs": [
+ {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout":
""},
+ {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout":
""},
+ ],
+ "outputs": [
+ {"name": "scatter_elements", "shape": [bz, 20], "dtype":
"float32", "layout": ""}
+ ],
+ "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_elements":
1},
+ }
+ expected2 = {
+ "inputs": [
+ {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout":
""},
+ {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": ""},
+ {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout":
""},
+ ],
+ "outputs": [
+ {"name": "scatter_elements", "shape": [bz, 20], "dtype":
"float32", "layout": ""}
+ ],
+ "nodes": {"total": 4, "input": 3, "scatter_elements": 1},
+ }
+ if dynamic:
+ expected1["prims"] = {"total": 1, "shape": 1}
+ expected2["prims"] = {"total": 1, "shape": 1}
+
+ verify_model(Scatter1(), [([bz, 20], "float32"), ([2, 5], "float32")],
expected1)
+ verify_model(
+ Scatter2(), [([bz, 20], "float32"), ([2, 5], "int64"), ([2, 5],
"float32")], expected2
+ )
+
+
+def test_put():
+ """test graph builder for index_put"""
+
+ class IndexPut(Module):
+ def __init__(self):
+ super().__init__()
+ self.index = msc_utils.random_data([(5), "int64"],
MSCFramework.TORCH, max_val=5)
+
+ def forward(self, data, src):
+ data[self.index] = src
+ return data
+
+ expected = {
+ "inputs": [
+ {"name": "input0", "shape": [10, 20], "dtype": "float32",
"layout": ""},
+ {"name": "input1", "shape": [5, 20], "dtype": "float32", "layout":
""},
+ ],
+ "outputs": [{"name": "scatter_nd", "shape": [10, 20], "dtype":
"float32", "layout": ""}],
+ "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_nd": 1},
+ }
+
+ input_info = [([10, 20], "float32"), ([5, 20], "float32")]
+ verify_model(IndexPut(), input_info, expected)
+
+
@pytest.mark.parametrize("dynamic", [True, False])
def test_attention(dynamic):
"""test graph builder for attention"""
diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py
index 64d00bb092..27a02844e1 100644
--- a/tests/python/contrib/test_msc/test_translate_relax.py
+++ b/tests/python/contrib/test_msc/test_translate_relax.py
@@ -18,27 +18,26 @@
""" Test translate from relax. """
import torch
-from torch import fx
from torch.nn import Module
import numpy as np
import tvm.testing
-from tvm.relax.frontend.torch import from_fx
-from tvm.contrib.msc.core.frontend import translate
+from tvm.contrib.msc.framework.torch.frontend import translate as
torch_translate
+
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
+from tvm.contrib.msc.core.frontend import translate as core_translate
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
def verify_model(torch_model, input_info, opt_config=None):
"""Compare torch module IR"""
- graph_model = fx.symbolic_trace(torch_model)
- with torch.no_grad():
- orig_mod = from_fx(graph_model, input_info)
-
+ orig_mod, _ = torch_translate.from_torch(torch_model, input_info,
as_msc=False)
target = "llvm"
dev = tvm.cpu()
- args = [tvm.nd.array(np.random.random(size=shape).astype(dtype)) for
shape, dtype in input_info]
+ args = [msc_utils.random_data(i, MSCFramework.TVM) for i in input_info]
def _tvm_runtime_to_np(obj):
if isinstance(obj, tvm.runtime.NDArray):
@@ -60,7 +59,7 @@ def verify_model(torch_model, input_info, opt_config=None):
return _tvm_runtime_to_np(res)
rt_mod = tvm_codegen.to_relax(
- *translate.from_relax(orig_mod, opt_config=opt_config),
+ *core_translate.from_relax(orig_mod, opt_config=opt_config),
codegen_config={"explicit_name": False},
)
@@ -1153,6 +1152,63 @@ def test_cat():
verify_model(Cat2(), [([1, 3, 10, 10], "float32")])
+def test_stack():
+ """test relax translator for stack"""
+
+ class Stack1(Module):
+ def forward(self, data, data1, data2):
+ return torch.stack((data, data1, data2), dim=0)
+
+ class Stack2(Module):
+ def forward(self, data):
+ const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
+ const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
+ return torch.stack((data, const1, const2), dim=1)
+
+ input_info = [
+ ([1, 3, 10, 10], "float32"),
+ ([1, 3, 10, 10], "float32"),
+ ([1, 3, 10, 10], "float32"),
+ ]
+ verify_model(Stack1(), input_info)
+ verify_model(Stack2(), [([1, 3, 10, 10], "float32")])
+
+
+def test_scatter():
+ """test relax translator for scatter"""
+
+ class Scatter1(Module):
+ def __init__(self):
+ super().__init__()
+ self.index = msc_utils.random_data([(2, 5), "int64"],
MSCFramework.TORCH, max_val=5)
+
+ def forward(self, data, src):
+ return data.scatter(dim=0, index=self.index, src=src)
+
+ class Scatter2(Module):
+ def forward(self, data, index, src):
+ return data.scatter(0, index, src)
+
+ verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")])
+ verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2,
5], "float32")])
+
+
+def test_put():
+ """test relax translator for index_put"""
+
+ class IndexPut(Module):
+ def __init__(self):
+ super().__init__()
+ self.index = msc_utils.random_data([(5), "int64"],
MSCFramework.TORCH, max_val=5)
+
+ def forward(self, data, src):
+ data[self.index] = src
+ return data
+
+ input_info = [([10, 20], "float32"), ([5, 20], "float32")]
+ verify_model(IndexPut(), input_info)
+
+
def test_attention():
"""test relax translator for attention"""
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py
index ebba339a4a..801893e9de 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -18,8 +18,6 @@
""" Test translate from relay. """
-import numpy as np
-
import torch
from torch import fx
from torch.nn import Module
@@ -66,7 +64,7 @@ def verify_model(torch_model, input_info, opt_config=None,
codegen_config=None,
expected = tvm.relax.transform.CanonicalizeBindings()(expected)
# graph from relay
- datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
+ datas = [msc_utils.random_data(i) for i in input_info]
torch_datas = [torch.from_numpy(i) for i in datas]
with torch.no_grad():
scripted_model = torch.jit.trace(torch_model,
tuple(torch_datas)).eval() # type: ignore
diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py
index 6d87ca8753..e0fd39249a 100644
--- a/tests/python/contrib/test_msc/test_translate_tensorrt.py
+++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py
@@ -18,7 +18,6 @@
""" Test translate for TensorrRT. """
import pytest
-import numpy as np
import torch
from torch import fx
@@ -91,7 +90,7 @@ def verify_model(torch_model, input_info, **trans_config):
"""Build model and verify results"""
graph_model = fx.symbolic_trace(torch_model)
- datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
+ datas = [msc_utils.random_data(i) for i in input_info]
torch_datas = [torch.from_numpy(i) for i in datas]
with torch.no_grad():
golden = torch_model(*torch_datas)
diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py
index 55bae682ef..6ed28c0ac0 100644
--- a/tests/python/contrib/test_msc/test_translate_torch.py
+++ b/tests/python/contrib/test_msc/test_translate_torch.py
@@ -17,24 +17,24 @@
""" Test translate from torch. """
-import numpy as np
-
import torch
from torch.nn import Module
import tvm.testing
from tvm.contrib.msc.framework.torch.frontend import translate
from tvm.contrib.msc.framework.torch import codegen
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
def verify_model(torch_model, input_info, via_relax=True):
"""Compare torch module results"""
- graph, weights = translate.from_torch(torch_model, input_info,
via_relax=via_relax)
- model = codegen.to_torch(graph, weights)
- torch_datas = [torch.from_numpy(np.random.rand(*i[0]).astype(i[1])) for i
in input_info]
+ torch_datas = [msc_utils.random_data(i, MSCFramework.TORCH) for i in
input_info]
with torch.no_grad():
golden = torch_model(*torch_datas)
+ graph, weights = translate.from_torch(torch_model, input_info,
via_relax=via_relax)
+ model = codegen.to_torch(graph, weights)
with torch.no_grad():
if not graph.get_inputs():
result = model()
@@ -1128,6 +1128,67 @@ def test_cat():
verify_model(Cat2(), [([1, 3, 10, 10], "float32")], via_relax)
+def test_stack():
+ """test torch translator for stack"""
+
+ class Stack1(Module):
+ def forward(self, data, data1, data2):
+ return torch.stack((data, data1, data2), dim=0)
+
+ class Stack2(Module):
+ def forward(self, data):
+ const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
+ const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
+ return torch.stack((data, const1, const2), dim=1)
+
+ input_info = [
+ ([1, 3, 10, 10], "float32"),
+ ([1, 3, 10, 10], "float32"),
+ ([1, 3, 10, 10], "float32"),
+ ]
+ for via_relax in [True, False]:
+ verify_model(Stack1(), input_info, via_relax)
+ verify_model(Stack2(), [([1, 3, 10, 10], "float32")], via_relax)
+
+
+def test_scatter():
+ """test torch translator for scatter"""
+
+ class Scatter1(Module):
+ def __init__(self):
+ super().__init__()
+ self.index = msc_utils.random_data([(2, 5), "int64"],
MSCFramework.TORCH, max_val=5)
+
+ def forward(self, data, src):
+ return data.scatter(dim=0, index=self.index, src=src)
+
+ class Scatter2(Module):
+ def forward(self, data, index, src):
+ return data.scatter(0, index, src)
+
+ for via_relax in [True, False]:
+ verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")],
via_relax)
+ verify_model(
+ Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5],
"float32")], via_relax
+ )
+
+
+def test_put():
+ """test torch translator for index_put"""
+
+ class IndexPut(Module):
+ def __init__(self):
+ super().__init__()
+ self.index = msc_utils.random_data([(5), "int64"],
MSCFramework.TORCH, max_val=5)
+
+ def forward(self, data, src):
+ data[self.index] = src
+ return data
+
+ input_info = [([10, 20], "float32"), ([5, 20], "float32")]
+ verify_model(IndexPut(), input_info, False)
+
+
def test_attention():
"""test torch translator for attention"""
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 2cabcba325..08331f0861 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3963,5 +3963,65 @@ def test_sym_size_int():
verify_model(SymSizeInt1(dim=-2), [([1, 3, 4], "float32")], {}, Expected1)
+def test_stack():
+
+ input_info = [
+ ([1, 3, 10, 10], "float32"),
+ ([1, 3, 10, 10], "float32"),
+ ([1, 3, 10, 10], "float32"),
+ ]
+
+ class Stack(Module):
+ def forward(self, data, data1, data2):
+ return torch.stack((data, data1, data2), dim=0)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ inp_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ inp_2: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((3, 1, 3, 10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((3, 3, 10, 10), dtype="float32") = R.concat(
+ (inp_0, inp_1, inp_2), axis=0
+ )
+ lv1: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = R.reshape(
+ lv, R.shape([3, 1, 3, 10, 10])
+ )
+ gv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ verify_model(Stack(), input_info, {}, expected)
+
+
+def test_scatter():
+ input_info = [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5],
"float32")]
+
+ class Scatter(Module):
+ def forward(self, data, index, src):
+ return data.scatter(dim=0, index=index, src=src)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((20, 20), dtype="float32"),
+ inp_1: R.Tensor((2, 5), dtype="int64"),
+ inp_2: R.Tensor((2, 5), dtype="float32"),
+ ) -> R.Tensor((20, 20), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((20, 20), dtype="float32") = R.scatter_elements(
+ inp_0, inp_1, inp_2, axis=0, reduction="update"
+ )
+ gv: R.Tensor((20, 20), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Scatter(), input_info, {}, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()