This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new efe3a79  [Torch] Support bincount and scatter_add ops (#6740)
efe3a79 is described below

commit efe3a79aacd934ea5ffb13170230bf199a473e72
Author: masahi <[email protected]>
AuthorDate: Sat Oct 24 23:00:36 2020 +0900

    [Torch] Support bincount and scatter_add ops (#6740)
---
 python/tvm/relay/frontend/pytorch.py          | 33 +++++++++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 32 ++++++++++++++++++--------
 2 files changed, 56 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index c8fbd5a..c41d680 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2357,6 +2357,37 @@ def _empty():
     return _impl


+def _bincount():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        weights = inputs[1]
+        maximum = _op.max(data)
+        dim = maximum + _expr.const(1, dtype="int64")
+        if weights:
+            weight_type = _infer_type(weights).checked_type
+            out_dtype = weight_type.dtype
+            updates = weights
+        else:
+            out_dtype = "int64"
+            updates = _op.ones_like(data)
+
+        counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
+        return _op.scatter_add(counts, data, updates, axis=0)
+
+    return _impl
+
+
+def _scatter_add():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        axis = inputs[1]
+        index = inputs[2]
+        src = inputs[3]
+        return _op.scatter_add(data, index, src, axis=axis)
+
+    return _impl
+
+
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
     import torch
@@ -2699,6 +2730,8 @@ def _get_convert_map(prelude, default_dtype):
         "aten::tensor": _identity(),  # used for example in tensor(1.0)
         "aten::numel": _numel(),
         "aten::empty": _empty(),
+        "aten::bincount": _bincount(),
+        "aten::scatter_add": _scatter_add(),
     }
     return convert_map

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 54c3daf..e997ebe 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3139,26 +3139,27 @@ def test_forward_nonzero():


 def test_forward_scatter():
-    class Scatter(Module):
-        def __init__(self, dim=0):
-            super().__init__()
-            self.dim = dim
+    # integer cannot be traced
+    def test_fn_scatter(dim):
+        return lambda data, index, src: torch.scatter(data, dim=dim, index=index, src=src)

-        def forward(self, data, index, src):
-            return torch.scatter(data, dim=self.dim, index=index, src=src)
+    def test_fn_scatter_add(dim):
+        return lambda data, index, src: torch.scatter_add(data, dim=dim, index=index, src=src)

     in_data = torch.zeros(3, 5)
     in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
     in_src = torch.rand(2, 5)
     # TODO: add scatter gpu schedule to enable gpu test.
-    verify_trace_model(Scatter(), [in_data, in_index, in_src], ["llvm"])
+    verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], ["llvm"])
+    verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], ["llvm"])

     in_data = torch.zeros(2, 4)
     in_index = torch.tensor([[2], [3]])
     in_src = torch.rand(2, 1)

-    # TODO: add scatter gpu schedule to enable gpu test.
-    verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"])
+    # # TODO: add scatter gpu schedule to enable gpu test.
+    verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], ["llvm"])
+    verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"])


 def test_numel():
@@ -3350,6 +3351,18 @@ def test_convert_torch_script_with_input_types():
     assert tvm.ir.structural_equal(expected_mod, mod["main"], map_free_vars=True)


+def test_bincount():
+    def test_fn(x, weights=None):
+        return torch.bincount(x, weights=weights)
+
+    inp = torch.randint(0, 8, (5,), dtype=torch.int64)
+    weights = torch.linspace(0, 1, steps=5)
+
+    verify_trace_model(test_fn, [inp], ["llvm"])
+    verify_trace_model(test_fn, [inp, weights], ["llvm"])
+    verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"])
+
+
 if __name__ == "__main__":
     # some structural tests
     test_forward_traced_function()
@@ -3476,6 +3489,7 @@ if __name__ == "__main__":
     test_forward_nonzero()
     test_forward_scatter()
     test_numel()
+    test_bincount()

     # Model tests
     test_resnet18()

Reply via email to