This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new efe3a79 [Torch] Support bincount and scatter_add ops (#6740)
efe3a79 is described below
commit efe3a79aacd934ea5ffb13170230bf199a473e72
Author: masahi <[email protected]>
AuthorDate: Sat Oct 24 23:00:36 2020 +0900
[Torch] Support bincount and scatter_add ops (#6740)
---
python/tvm/relay/frontend/pytorch.py | 33 +++++++++++++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 32 ++++++++++++++++++--------
2 files changed, 56 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py
b/python/tvm/relay/frontend/pytorch.py
index c8fbd5a..c41d680 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2357,6 +2357,37 @@ def _empty():
return _impl
def _bincount():
    """Converter for aten::bincount.

    Counts occurrences of each value in a 1-D int64 tensor, optionally
    weighting each occurrence. Implemented as scatter_add into a zero
    tensor of length max(data) + 1, computed symbolically at runtime.
    """

    def _impl(inputs, input_types):
        data = inputs[0]
        weights = inputs[1]
        # Output length is max(data) + 1; max is a runtime Relay expression.
        maximum = _op.max(data)
        dim = maximum + _expr.const(1, dtype="int64")
        # Explicit None check: `weights` is either a Relay expression or
        # None when the optional weights argument was not supplied.
        # (Truthiness of a Relay expr is not a reliable presence test.)
        if weights is not None:
            weight_type = _infer_type(weights).checked_type
            # Result dtype follows the weights' dtype (matches PyTorch).
            out_dtype = weight_type.dtype
            updates = weights
        else:
            # Unweighted bincount returns int64 counts.
            out_dtype = "int64"
            updates = _op.ones_like(data)

        counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
        return _op.scatter_add(counts, data, updates, axis=0)

    return _impl
+
+
def _scatter_add():
    """Converter for aten::scatter_add.

    Accumulates values from `src` into `data` at the positions given by
    `index`, along the requested axis.
    """

    def _impl(inputs, input_types):
        # PyTorch argument order: (self, dim, index, src).
        data, axis, index, src = inputs[0], inputs[1], inputs[2], inputs[3]
        return _op.scatter_add(data, index, src, axis=axis)

    return _impl
+
+
def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
import torch
@@ -2699,6 +2730,8 @@ def _get_convert_map(prelude, default_dtype):
"aten::tensor": _identity(), # used for example in tensor(1.0)
"aten::numel": _numel(),
"aten::empty": _empty(),
+ "aten::bincount": _bincount(),
+ "aten::scatter_add": _scatter_add(),
}
return convert_map
diff --git a/tests/python/frontend/pytorch/test_forward.py
b/tests/python/frontend/pytorch/test_forward.py
index 54c3daf..e997ebe 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3139,26 +3139,27 @@ def test_forward_nonzero():
def test_forward_scatter():
- class Scatter(Module):
- def __init__(self, dim=0):
- super().__init__()
- self.dim = dim
+ # integer cannot be traced
+ def test_fn_scatter(dim):
+ return lambda data, index, src: torch.scatter(data, dim=dim,
index=index, src=src)
- def forward(self, data, index, src):
- return torch.scatter(data, dim=self.dim, index=index, src=src)
+ def test_fn_scatter_add(dim):
+ return lambda data, index, src: torch.scatter_add(data, dim=dim,
index=index, src=src)
in_data = torch.zeros(3, 5)
in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
in_src = torch.rand(2, 5)
# TODO: add scatter gpu schedule to enable gpu test.
- verify_trace_model(Scatter(), [in_data, in_index, in_src], ["llvm"])
+ verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src],
["llvm"])
+ verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src],
["llvm"])
in_data = torch.zeros(2, 4)
in_index = torch.tensor([[2], [3]])
in_src = torch.rand(2, 1)
- # TODO: add scatter gpu schedule to enable gpu test.
- verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"])
+ # TODO: add scatter gpu schedule to enable gpu test.
+ verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src],
["llvm"])
+ verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src],
["llvm"])
def test_numel():
@@ -3350,6 +3351,18 @@ def test_convert_torch_script_with_input_types():
assert tvm.ir.structural_equal(expected_mod, mod["main"],
map_free_vars=True)
def test_bincount():
    """Exercise aten::bincount: unweighted, float32 and float64 weights."""

    def test_fn(x, weights=None):
        return torch.bincount(x, weights=weights)

    values = torch.randint(0, 8, (5,), dtype=torch.int64)
    float_weights = torch.linspace(0, 1, steps=5)

    # Unweighted, float32-weighted, and float64-weighted variants.
    verify_trace_model(test_fn, [values], ["llvm"])
    verify_trace_model(test_fn, [values, float_weights], ["llvm"])
    verify_trace_model(test_fn, [values, float_weights.to(torch.float64)], ["llvm"])
+
+
if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
@@ -3476,6 +3489,7 @@ if __name__ == "__main__":
test_forward_nonzero()
test_forward_scatter()
test_numel()
+ test_bincount()
# Model tests
test_resnet18()