This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b9fa576ab3 [BugFix] Use shape dtype on ArgReduce to determine return type (#12083)
b9fa576ab3 is described below
commit b9fa576ab35ef67e0a7fbaf669109f010e66e20c
Author: Everton Constantino <[email protected]>
AuthorDate: Thu Jul 14 14:54:56 2022 -0300
[BugFix] Use shape dtype on ArgReduce to determine return type (#12083)
Fix ArgReduce automatic return type inference so that it uses the dtype of
the input tensor's shape instead of a hard-coded Int32.
Add tests covering argmax and argmin with int32 and int64 shape dtypes.
---
src/relay/op/tensor/reduce.cc | 2 +-
tests/python/relay/test_type_infer.py | 38 +++++++++++++++++++++++++++++++++++
2 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index fba2a60cec..2b1afc6e55 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -338,7 +338,7 @@ bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& att
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
- reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
+ reporter->Assign(types[1], TensorType(oshape, data->shape[0].dtype()));
return true;
}
/*!
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index af64ce714d..b0b7ef0481 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -23,6 +23,8 @@ from tvm import IRModule, parser, relay, te
from tvm.relay import analysis, op, transform
from tvm.relay.op import op as _op
+import numpy as np
+
def infer_mod(mod, annotate_spans=True):
if annotate_spans:
@@ -544,6 +546,42 @@ def test_repeat_register():
assert "Operator custom_log3 is registered before" in str(cm.execption)
+def test_argreduce_infer_return_type():
+ x_shape = (1, 1)
+ broadcast_shape = [1, 1]
+ shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x:
np.int64(x))]
+
+ # Testing with argmax
+ for (sdtype, conv) in shape_dtypes:
+ x = relay.var("data", relay.TensorType(x_shape, "float32"))
+ broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape,
dtype=sdtype))
+ argmax = relay.op.argmax(broadcast_to, axis=[1])
+
+ f = relay.Function([x], argmax)
+ assert_has_type(
+ f,
+ relay.FuncType(
+ [relay.TensorType(broadcast_shape, "float32")],
+ relay.TensorType([conv(1)], dtype=sdtype),
+ ),
+ )
+
+ # Testing with argmin
+ for (sdtype, conv) in shape_dtypes:
+ x = relay.var("data", relay.TensorType(x_shape, "float32"))
+ broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape,
dtype=sdtype))
+ argmin = relay.op.argmin(broadcast_to, axis=[1])
+
+ f = relay.Function([x], argmin)
+ assert_has_type(
+ f,
+ relay.FuncType(
+ [relay.TensorType(broadcast_shape, "float32")],
+ relay.TensorType([conv(1)], dtype=sdtype),
+ ),
+ )
+
+
if __name__ == "__main__":
import sys