This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2a2081e  [TOPI] GPU scatter_add using atomic (#7044)
2a2081e is described below

commit 2a2081e536f26b49506ef38fec820dc196bc6a2f
Author: masahi <[email protected]>
AuthorDate: Tue Dec 8 07:04:34 2020 +0900

    [TOPI] GPU scatter_add using atomic (#7044)
    
    * use atomic add for faster 1d scatter add
    
    * update tests
    
    * run black
    
    * more pylint fix
    
    * remove fp64 bincount test
    
    Co-authored-by: masa <[email protected]>
---
 python/tvm/relay/frontend/pytorch.py          | 17 +++++-
 python/tvm/topi/cuda/scatter.py               | 80 ++++++++++++++++++++++++++-
 tests/python/frontend/pytorch/test_forward.py | 10 ++--
 tests/python/relay/test_op_level3.py          |  4 ++
 4 files changed, 102 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 4f75cf3..d2c52fb 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1921,18 +1921,29 @@ class PyTorchOpConverter:
     def bincount(self, inputs, input_types):
         data = inputs[0]
         weights = inputs[1]
+        input_type = _infer_type(data).checked_type.dtype
+        if input_type == "int64":
+            logging.warning(
+                "Casting an int64 input to int32, since we do not have int64 atomic add"
+                "needed for bincount yet."
+            )
+            data = _op.cast(data, "int32")
         maximum = _op.max(data)
-        dim = maximum + _expr.const(1, dtype="int64")
+        dim = maximum + _expr.const(1, dtype="int32")
         if weights:
             weight_type = _infer_type(weights).checked_type
             out_dtype = weight_type.dtype
             updates = weights
         else:
-            out_dtype = "int64"
+            out_dtype = "int32"
             updates = _op.ones_like(data)
 
         counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
-        return _op.scatter_add(counts, data, updates, axis=0)
+        out = _op.scatter_add(counts, data, updates, axis=0)
+        if input_type == "int32":
+            # Torch always outputs int64 results for bincount
+            return _op.cast(out, "int64")
+        return out
 
     def scatter_add(self, inputs, input_types):
         data = inputs[0]
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 5e03faf..89c5cd2 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -19,6 +19,7 @@
 import tvm
 from tvm import te
 from ..scatter import _verify_scatter_nd_inputs
+from .nms import atomic_add
 
 
 def ceil_div(a, b):
@@ -470,6 +471,83 @@ def scatter(data, indices, updates, axis=0):
     return out
 
 
+def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
+    """Generate scatter add ir for 1d inputs, using atomic_add instruction
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    with ib.new_scope():
+        nthread_bx = ceil_div(n, nthread_tx)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+        with ib.if_scope(tid < n):
+            out_ptr[tid] = data_ptr[tid]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local")
+
+    with ib.new_scope():
+        nthread_bx = ceil_div(ni, nthread_tx)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+
+        with ib.if_scope(tid < ni):
+            index = indices_ptr[tid]
+            with ib.if_scope(index < 0):
+                atomic_add_return[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]),
+                    updates_ptr[tid],
+                )
+            with ib.else_scope():
+                atomic_add_return[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]),
+                    updates_ptr[tid],
+                )
+
+    return ib.get()
+
+
 def scatter_add(data, indices, updates, axis=0):
     """Update data by adding values in updates at positions defined by indices
 
@@ -501,7 +579,7 @@ def scatter_add(data, indices, updates, axis=0):
     assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions"
 
     ir_funcs = {
-        1: gen_ir_1d,
+        1: gen_scatter_add_1d_atomic,
         2: gen_ir_2d,
         3: gen_ir_3d,
         4: gen_ir_4d,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6250dff..2dda675 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3355,12 +3355,12 @@ def test_bincount():
     def test_fn(x, weights=None):
         return torch.bincount(x, weights=weights)
 
-    inp = torch.randint(0, 8, (5,), dtype=torch.int64)
-    weights = torch.linspace(0, 1, steps=5)
+    inp = torch.randint(0, 100, (10000,), dtype=torch.int64)
+    weights = torch.linspace(0, 100, steps=10000)
 
-    verify_trace_model(test_fn, [inp], ["llvm"])
-    verify_trace_model(test_fn, [inp, weights], ["llvm"])
-    verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"])
+    targets = ["llvm", "cuda"]
+    verify_trace_model(test_fn, [inp], targets)
+    verify_trace_model(test_fn, [inp, weights], targets)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 82d0563..fc1929e 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1017,11 +1017,15 @@ def test_scatter_add():
         ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
+                if target == "nvptx":
+                    # TODO(masahi): support atomic in LLVM codegen
+                    continue
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
     verify_scatter_add((10,), (10,), 0)
+    verify_scatter_add((1000,), (1000,), 0)
     verify_scatter_add((10, 5), (10, 5), -2)
     verify_scatter_add((10, 5), (10, 5), -1)
     verify_scatter_add((10, 5), (3, 5), 0)

Reply via email to