This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2a2081e [TOPI] GPU scatter_add using atomic (#7044)
2a2081e is described below
commit 2a2081e536f26b49506ef38fec820dc196bc6a2f
Author: masahi <[email protected]>
AuthorDate: Tue Dec 8 07:04:34 2020 +0900
[TOPI] GPU scatter_add using atomic (#7044)
* use atomic add for faster 1d scatter add
* update tests
* run black
* more pylint fix
* remove fp64 bincount test
Co-authored-by: masa <[email protected]>
---
python/tvm/relay/frontend/pytorch.py | 17 +++++-
python/tvm/topi/cuda/scatter.py | 80 ++++++++++++++++++++++++++-
tests/python/frontend/pytorch/test_forward.py | 10 ++--
tests/python/relay/test_op_level3.py | 4 ++
4 files changed, 102 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 4f75cf3..d2c52fb 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1921,18 +1921,29 @@ class PyTorchOpConverter:
def bincount(self, inputs, input_types):
data = inputs[0]
weights = inputs[1]
+ input_type = _infer_type(data).checked_type.dtype
+ if input_type == "int64":
+ logging.warning(
+ "Casting an int64 input to int32, since we do not have int64
atomic add"
+ "needed for bincount yet."
+ )
+ data = _op.cast(data, "int32")
maximum = _op.max(data)
- dim = maximum + _expr.const(1, dtype="int64")
+ dim = maximum + _expr.const(1, dtype="int32")
if weights:
weight_type = _infer_type(weights).checked_type
out_dtype = weight_type.dtype
updates = weights
else:
- out_dtype = "int64"
+ out_dtype = "int32"
updates = _op.ones_like(data)
counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
- return _op.scatter_add(counts, data, updates, axis=0)
+ out = _op.scatter_add(counts, data, updates, axis=0)
+ if input_type == "int32":
+ # Torch always outputs int64 results for bincount
+ return _op.cast(out, "int64")
+ return out
def scatter_add(self, inputs, input_types):
data = inputs[0]
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 5e03faf..89c5cd2 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -19,6 +19,7 @@
import tvm
from tvm import te
from ..scatter import _verify_scatter_nd_inputs
+from .nms import atomic_add
def ceil_div(a, b):
@@ -470,6 +471,83 @@ def scatter(data, indices, updates, axis=0):
return out
+def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
+ """Generate scatter add ir for 1d inputs, using atomic_add instruction
+
+ Parameters
+ ----------
+ data : tir.Tensor
+ The input data to the operator.
+
+ indices : tir.Tensor
+ The index locations to update.
+
+ updates : tir.Tensor
+ The values to update.
+
+ axis : int
+ The axis to scatter on
+
+ out : tir.Tensor
+ The output tensor.
+
+ Returns
+ -------
+ ret : tir
+ The computational ir.
+ """
+ assert axis == 0
+ n = data.shape[0]
+
+ ib = tvm.tir.ir_builder.create()
+
+ out_ptr = ib.buffer_ptr(out)
+ data_ptr = ib.buffer_ptr(data)
+
+ max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+
+ with ib.new_scope():
+ nthread_bx = ceil_div(n, nthread_tx)
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * nthread_tx + tx
+ with ib.if_scope(tid < n):
+ out_ptr[tid] = data_ptr[tid]
+
+ indices_ptr = ib.buffer_ptr(indices)
+ updates_ptr = ib.buffer_ptr(updates)
+
+ ni = indices.shape[0]
+
+ atomic_add_return = ib.allocate(updates.dtype, (1,),
name="atomic_add_return", scope="local")
+
+ with ib.new_scope():
+ nthread_bx = ceil_div(ni, nthread_tx)
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * nthread_tx + tx
+
+ with ib.if_scope(tid < ni):
+ index = indices_ptr[tid]
+ with ib.if_scope(index < 0):
+ atomic_add_return[0] = atomic_add(
+ tvm.tir.call_intrin("handle", "tir.address_of",
out_ptr[index + n]),
+ updates_ptr[tid],
+ )
+ with ib.else_scope():
+ atomic_add_return[0] = atomic_add(
+ tvm.tir.call_intrin("handle", "tir.address_of",
out_ptr[index]),
+ updates_ptr[tid],
+ )
+
+ return ib.get()
+
+
def scatter_add(data, indices, updates, axis=0):
"""Update data by adding values in updates at positions defined by indices
@@ -501,7 +579,7 @@ def scatter_add(data, indices, updates, axis=0):
assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions"
ir_funcs = {
- 1: gen_ir_1d,
+ 1: gen_scatter_add_1d_atomic,
2: gen_ir_2d,
3: gen_ir_3d,
4: gen_ir_4d,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6250dff..2dda675 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3355,12 +3355,12 @@ def test_bincount():
def test_fn(x, weights=None):
return torch.bincount(x, weights=weights)
- inp = torch.randint(0, 8, (5,), dtype=torch.int64)
- weights = torch.linspace(0, 1, steps=5)
+ inp = torch.randint(0, 100, (10000,), dtype=torch.int64)
+ weights = torch.linspace(0, 100, steps=10000)
- verify_trace_model(test_fn, [inp], ["llvm"])
- verify_trace_model(test_fn, [inp, weights], ["llvm"])
- verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"])
+ targets = ["llvm", "cuda"]
+ verify_trace_model(test_fn, [inp], targets)
+ verify_trace_model(test_fn, [inp, weights], targets)
if __name__ == "__main__":
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 82d0563..fc1929e 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1017,11 +1017,15 @@ def test_scatter_add():
ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
+ if target == "nvptx":
+ # TODO(masahi): support atomic in LLVM codegen
+ continue
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res,
rtol=1e-5)
verify_scatter_add((10,), (10,), 0)
+ verify_scatter_add((1000,), (1000,), 0)
verify_scatter_add((10, 5), (10, 5), -2)
verify_scatter_add((10, 5), (10, 5), -1)
verify_scatter_add((10, 5), (3, 5), 0)