This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5ad2f77 [Relay] Gather op dynamic input support (#9240)
5ad2f77 is described below
commit 5ad2f77403bed9a2bf356cc0d3d785ecc13e6c58
Author: masahi <[email protected]>
AuthorDate: Tue Oct 12 01:22:10 2021 +0900
[Relay] Gather op dynamic input support (#9240)
* support gather op dynamic input
* fix shape func and add test
* remove constness check
* fix shape func output rank
* restore check
Co-authored-by: masa <[email protected]>
---
include/tvm/topi/transform.h | 6 ++++--
python/tvm/relay/op/_transform.py | 20 ++++++++++++++++++++
src/relay/op/tensor/transform.cc | 6 ++++--
tests/python/relay/test_any.py | 22 ++++++++++++++++++++++
4 files changed, 50 insertions(+), 4 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 8d1a49a..3df9caf 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1233,8 +1233,10 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
}
ICHECK_GE(axis, 0);
ICHECK_LT(axis, ndim_d);
- size_t indices_dim_i =
static_cast<size_t>(GetConstInt(indices->shape[axis]));
- ICHECK_GE(indices_dim_i, 1);
+ if (indices->shape[axis].as<IntImmNode>()) {
+ size_t indices_dim_i =
static_cast<size_t>(GetConstInt(indices->shape[axis]));
+ ICHECK_GE(indices_dim_i, 1);
+ }
ICHECK(indices->dtype.is_int());
Array<PrimExpr> out_shape;
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 0284d24..76c8069 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1174,3 +1174,23 @@ def gather_nd_shape_func(attrs, inputs, _):
assert index_rank > 0, "index_rank needs to be specified for dynamic
gather_nd"
return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims),
convert(index_rank))]
+
+
+@script
+def _gather_shape(data_shape, indices_shape, axis):
+ out_shape = output_tensor((data_shape.shape[0],), "int64")
+ for i in range(data_shape.shape[0]):
+ if i != axis:
+ assert (
+ data_shape[i] == indices_shape[i]
+ ), "data and indices size at non-gather axes must be the same"
+ out_shape[i] = indices_shape[i]
+ return out_shape
+
+
+@_reg.register_shape_func("gather", False)
+def gather_shape_func(attrs, inputs, _):
+ """
+ Shape func for gather operator.
+ """
+ return [_gather_shape(inputs[0], inputs[1], attrs.axis)]
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 3781107..fa5b31a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3260,8 +3260,10 @@ bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
oshape.reserve(ndim_data);
for (size_t i = 0; i < ndim_data; ++i) {
if (i == static_cast<size_t>(axis)) {
- const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
- ICHECK_GE(*indice_shape_i, 1);
+ if (indices->shape[i].as<IntImmNode>()) {
+ const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
+ ICHECK_GE(*indice_shape_i, 1);
+ }
} else {
ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
}
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index decddc1..8788faf 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -2064,5 +2064,27 @@ def test_scatter_nd():
verify_scatter_nd(data, indices, updates, out)
[email protected]_gpu
+def test_gather():
+ def verify_gather(data_shape, indices_shape, data_shape_np,
indices_shape_np, axis):
+ x = relay.var("x", relay.TensorType(data_shape, "float32"))
+ y = relay.var("y", relay.TensorType(indices_shape, "int32"))
+ z = relay.gather(x, axis, y)
+
+ mod = tvm.IRModule()
+ mod["main"] = relay.Function([x, y], z)
+
+ data_np = np.random.uniform(size=data_shape_np).astype("float32")
+ indices_np = np.random.randint(low=0, high=2, size=indices_shape_np,
dtype="int32")
+
+ ref_res = tvm.topi.testing.gather_python(data_np, axis, indices_np)
+ check_result([data_np, indices_np], mod, [ref_res])
+
+ verify_gather((relay.Any(),), (relay.Any(),), (10,), (10,), 0)
+ verify_gather((2, 2), (2, relay.Any()), (2, 2), (2, 3), 1)
+ verify_gather((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3), 1)
+ verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2,
3), (1, 3), 0)
+
+
if __name__ == "__main__":
pytest.main([__file__])